├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── docs └── images │ └── conver.jpg ├── models ├── 11_vs_11_kaggle │ └── put_weight_here ├── academy_3_vs_1_with_keeper │ └── put_weight_here ├── academy_run_pass_and_shoot_with_keeper │ └── put_weight_here └── academy_run_to_score_with_keeper │ └── put_weight_here ├── requirements.txt ├── results └── results_saved_here ├── scripts └── football │ ├── evaluate.sh │ └── replay2video.py ├── setup.py └── tmarl ├── __init__.py ├── algorithms ├── __init__.py └── r_mappo_distributed │ ├── __init__.py │ ├── mappo_algorithm.py │ └── mappo_module.py ├── configs ├── __init__.py └── config.py ├── drivers ├── __init__.py └── shared_distributed │ ├── base_driver.py │ └── football_driver.py ├── envs ├── __init__.py ├── env_wrappers.py └── football │ ├── __init__.py │ ├── env │ ├── __init__.py │ ├── config.py │ ├── football_env.py │ ├── football_env_core.py │ ├── scenario_builder.py │ └── script_helpers.py │ ├── football.py │ └── scenarios │ ├── 11_vs_11_kaggle.py │ ├── 11_vs_11_lazy.py │ ├── __init__.py │ ├── academy_3_vs_1_with_keeper.py │ ├── academy_corner.py │ ├── academy_counterattack_easy.py │ ├── academy_counterattack_hard.py │ ├── academy_empty_goal.py │ ├── academy_empty_goal_close.py │ ├── academy_pass_and_shoot_with_keeper.py │ ├── academy_run_pass_and_shoot_with_keeper.py │ ├── academy_run_to_score.py │ └── academy_run_to_score_with_keeper.py ├── loggers ├── TSee │ ├── README.md │ └── __init__.py ├── __init__.py └── utils.py ├── networks ├── __init__.py ├── policy_network.py └── utils │ ├── act.py │ ├── distributions.py │ ├── mlp.py │ ├── popart.py │ ├── rnn.py │ └── util.py ├── replay_buffers ├── __init__.py └── normal │ ├── __init__.py │ └── shared_buffer.py ├── runners ├── __init__.py ├── base_evaluator.py ├── base_runner.py └── football │ └── football_evaluator.py ├── utils ├── __init__.py ├── gpu_mem_track.py ├── modelsize_estimate.py ├── multi_discrete.py ├── segment_tree.py ├── util.py └── valuenorm.py └── wrappers ├── TWrapper ├── README.md └── __init__.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # results 131 | *.dump 132 | *.avi 133 | # model 134 | *.pt 135 | .idea/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM reg.real-ai.cn/launchpad/launchpad-tensorflow:latest 2 | MAINTAINER Sen Na 3 | 4 | WORKDIR /tmarl 5 | 6 | COPY . . 7 | 8 | ARG pip_registry='https://mirrors.aliyun.com/pypi/simple/' 9 | RUN pip install -e . 10 | 11 | ARG pip_dependencies='\ 12 | torch \ 13 | wandb \ 14 | setproctitle \ 15 | gym \ 16 | seaborn \ 17 | tensorboardX \ 18 | icecream' 19 | 20 | RUN pip install -i $pip_registry $pip_dependencies 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TiKick 2 | 3 |
4 | 5 |
6 | 7 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 8 | 9 | [Update]: check out our newest GRF agent here: [TiZero: Mastering Multi-Agent Football with Curriculum Learning and Self-Play](https://github.com/OpenRL-Lab/TiZero) 10 | 11 | ### 1.Introduction 12 | 13 | Learning-based agent for Google Research Football 14 | 15 | Code accompanying the paper 16 | "TiKick: Towards Playing Multi-agent Football Full Games from Single-agent Demonstrations". [[arxiv](https://arxiv.org/abs/2110.04507)][[videos](https://sites.google.com/view/tikick)]. The implementation in this repository is heavily based on https://github.com/marlbenchmark/on-policy. 17 | 18 | Update: 19 | - [22.8.11]: 11 vs 11 model is released! Model can be found on [Google Drive](https://drive.google.com/drive/folders/1pUW_7db9Of9zCDZZWoImVgg0_lX5xCt1?usp=sharing). 20 | 21 | ### 2.Installation 22 | ``` 23 | pip install -r requirements.txt 24 | pip install . 25 | ``` 26 | 27 | ### 3.Evaluation with Trained Model 28 | 29 | (a) First, you should download the trained model from Baidu Yun or Google Drive: 30 | 31 | * pre-trained models can be found at: 32 | * Baidu Yun: [Click to download](https://pan.baidu.com/s/11bKsKxs_spXzlpRGCUNlOA) Password:vz3a 33 | * Google Drive: [Click to download](https://drive.google.com/drive/folders/1pUW_7db9Of9zCDZZWoImVgg0_lX5xCt1?usp=sharing) 34 | 35 | (b) Then, you should put the `actor.pt` under `./models/{scenario_name}/`. 36 | 37 | (c) Finally, you can go to the `./scripts/football` folder and execute the evaluation script as below: 38 | 39 | ``` 40 | cd scripts/football 41 | ./evaluate.sh 42 | ``` 43 | 44 | Then the replay file will be saved into `./results/{scenario_name}/replay/`. 45 | 46 | * Hyper-parameters in the evaluation script: 47 | * --replay_save_dir : the replay file will be saved in this directory 48 | * --model_dir : the pre-trained model should be placed under this directory 49 | * --n_eval_rollout_threads : number of parallel envs for evaluating rollout 50 | * --eval_num : total number of evaluation runs 51 | 52 | ### 4.Render with the Replay File 53 | 54 | Once you obtain a replay file, you can convert it to a `.avi` file and watch the game.
55 | This can be easily done via: 56 | 57 | ``` 58 | cd scripts/football 59 | python3 replay2video.py --replay_file ../../results/academy_3_vs_1_with_keeper/replay/your_path.dump 60 | ``` 61 | 62 | The video file will finally be saved to `./results/{scenario_name}/video/` 63 | 64 | ### 5.Cite 65 | 66 | Please cite our paper if you use our codes or our weights in your own work: 67 | 68 | ``` 69 | @misc{huang2021tikick, 70 | title={TiKick: Towards Playing Multi-agent Football Full Games from Single-agent Demonstrations}, 71 | author={Shiyu Huang and Wenze Chen and Longfei Zhang and Ziyang Li and Fengming Zhu and Deheng Ye and Ting Chen and Jun Zhu}, 72 | year={2021}, 73 | eprint={2110.04507}, 74 | archivePrefix={arXiv}, 75 | primaryClass={cs.AI} 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /docs/images/conver.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/docs/images/conver.jpg -------------------------------------------------------------------------------- /models/11_vs_11_kaggle/put_weight_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/models/11_vs_11_kaggle/put_weight_here -------------------------------------------------------------------------------- /models/academy_3_vs_1_with_keeper/put_weight_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/models/academy_3_vs_1_with_keeper/put_weight_here -------------------------------------------------------------------------------- /models/academy_run_pass_and_shoot_with_keeper/put_weight_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/models/academy_run_pass_and_shoot_with_keeper/put_weight_here -------------------------------------------------------------------------------- /models/academy_run_to_score_with_keeper/put_weight_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/models/academy_run_to_score_with_keeper/put_weight_here -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | cffi 3 | setproctitle 4 | icecream 5 | tensorboardX 6 | gym 7 | pysc2 8 | torch 9 | gfootball 10 | tabulate 11 | lz4 12 | kaggle_environments -------------------------------------------------------------------------------- /results/results_saved_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/results/results_saved_here -------------------------------------------------------------------------------- /scripts/football/evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | env="gfootball" 4 | 5 | scenario="11_vs_11_kaggle" 6 | num_agents=10 7 | 8 | #scenario="academy_3_vs_1_with_keeper" 9 | #num_agents=3 10 | 11 | # 
scenario="academy_run_pass_and_shoot_with_keeper" 12 | # num_agents=2 13 | 14 | # scenario="academy_run_to_score_with_keeper" 15 | # num_agents=1 16 | 17 | algo="rmappo" 18 | 19 | CUDA_VISIBLE_DEVICES=0 python3 ../../tmarl/runners/football/football_evaluator.py --env_name ${env} \ 20 | --algorithm_name ${algo} --scenario_name ${scenario} --num_agents ${num_agents} \ 21 | --n_eval_rollout_threads 2 --eval_num 1 --use_eval \ 22 | --replay_save_dir "../../results/$scenario/replay/" \ 23 | --model_dir "../../models/$scenario" 24 | -------------------------------------------------------------------------------- /scripts/football/replay2video.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Script allowing to replay a given trace file. 17 | Example usage: 18 | python replay.py --trace_file=/tmp/dumps/shutdown_20190521-165136974075.dump 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | from tmarl.envs.football.env import script_helpers 26 | 27 | from absl import app 28 | from absl import flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | flags.DEFINE_string('replay_file', None, 'replay file path') 33 | flags.DEFINE_string('video_save_dir', '../../results/videos', 'video save dir') 34 | flags.DEFINE_integer('fps', 10, 'How many frames per second to render') 35 | flags.mark_flag_as_required('replay_file') 36 | 37 | 38 | def main(_): 39 | script_helpers.ScriptHelpers().replay(FLAGS.replay_file, FLAGS.fps,directory=FLAGS.video_save_dir) 40 | 41 | 42 | if __name__ == '__main__': 43 | app.run(main) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | 19 | import os 20 | from setuptools import setup, find_packages 21 | import setuptools 22 | 23 | def get_version() -> str: 24 | # https://packaging.python.org/guides/single-sourcing-package-version/ 25 | init = open(os.path.join("tmarl", "__init__.py"), "r").read().split() 26 | return init[init.index("__version__") + 2][1:-1] 27 | 28 | setup( 29 | name="tmarl", # Replace with your own username 30 | version=get_version(), 31 | description="marl algorithms", 32 | long_description=open("README.md", encoding="utf8").read(), 33 | long_description_content_type="text/markdown", 34 | author="tmarl", 35 | author_email="tmarl_contact@tartrl.cn", 36 | packages=setuptools.find_packages(), 37 | classifiers=[ 38 | "Development Status :: 3 - Alpha", 39 | "Intended Audience :: Science/Research", 40 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 41 | "Topic :: Software Development :: Libraries :: Python Modules", 42 | "Programming Language :: Python :: 3", 43 | "License :: OSI Approved :: Apache License", 44 | "Operating System :: OS Independent", 45 | ], 46 | keywords="multi-agent reinforcement learning algorithms pytorch", 47 | python_requires='>=3.6', 48 | ) 49 | -------------------------------------------------------------------------------- /tmarl/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.3" -------------------------------------------------------------------------------- /tmarl/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/algorithms/__init__.py -------------------------------------------------------------------------------- /tmarl/algorithms/r_mappo_distributed/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /tmarl/algorithms/r_mappo_distributed/mappo_algorithm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tmarl.utils.valuenorm import ValueNorm 4 | 5 | # implement the loss of the MAPPO here 6 | class MAPPOAlgorithm(): 7 | def __init__(self, 8 | args, 9 | init_module, 10 | device=torch.device("cpu")): 11 | 12 | self.device = device 13 | self.tpdv = dict(dtype=torch.float32, device=device) 14 | 15 | self.algo_module = init_module 16 | 17 | 18 | self.clip_param = args.clip_param 19 | self.ppo_epoch = args.ppo_epoch 20 | self.num_mini_batch = args.num_mini_batch 21 | self.data_chunk_length = args.data_chunk_length 22 | self.policy_value_loss_coef = args.policy_value_loss_coef 23 | self.value_loss_coef = args.value_loss_coef 24 | self.entropy_coef = args.entropy_coef 25 | self.max_grad_norm = args.max_grad_norm 26 | self.huber_delta = args.huber_delta 27 | 28 | self._use_recurrent_policy = args.use_recurrent_policy 29 | self._use_naive_recurrent = args.use_naive_recurrent_policy 30 | self._use_max_grad_norm = args.use_max_grad_norm 31 | self._use_clipped_value_loss = args.use_clipped_value_loss 32 | self._use_huber_loss = args.use_huber_loss 33 | self._use_popart = args.use_popart 34 | self._use_valuenorm = args.use_valuenorm 35 | self._use_value_active_masks = args.use_value_active_masks 36 | self._use_policy_active_masks = args.use_policy_active_masks 37 | self._use_policy_vhead = args.use_policy_vhead 38 | 39 | assert (self._use_popart and self._use_valuenorm) == False, ("self._use_popart and self._use_valuenorm can not be set True simultaneously") 40 | 41 | if self._use_popart: 42 | self.value_normalizer = self.algo_module.critic.v_out 43 | if self._use_policy_vhead: 44 | self.policy_value_normalizer = self.algo_module.actor.v_out 45 | elif self._use_valuenorm: 46 | self.value_normalizer = ValueNorm(1, device = self.device) 47 | if self._use_policy_vhead: 48 | self.policy_value_normalizer = ValueNorm(1, device = self.device) 49 | else: 50 | self.value_normalizer = None 51 | if self._use_policy_vhead: 52 | self.policy_value_normalizer = None 53 | 54 | def prep_rollout(self): 55 | self.algo_module.actor.eval() 56 | 57 | -------------------------------------------------------------------------------- /tmarl/algorithms/r_mappo_distributed/mappo_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tmarl.networks.policy_network import PolicyNetwork 4 | 5 | class MAPPOModule: 6 | def __init__(self, args, obs_space, share_obs_space, act_space, device=torch.device("cpu")): 7 | 8 | self.device = device 9 | self.lr = args.lr 10 | self.critic_lr = args.critic_lr 11 | self.opti_eps = args.opti_eps 12 | self.weight_decay = args.weight_decay 13 | 14 | self.obs_space = obs_space 15 | self.share_obs_space = share_obs_space 16 | self.act_space = act_space 17 | 18 | self.actor = PolicyNetwork(args, self.obs_space, self.act_space, self.device) 19 | 20 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.lr, eps=self.opti_eps, weight_decay=self.weight_decay) 21 | 22 | def get_actions(self, share_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None, deterministic=False): 23 | actions, action_log_probs, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic) 24 | 25 | return None, actions, 
action_log_probs, rnn_states_actor, None -------------------------------------------------------------------------------- /tmarl/configs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /tmarl/configs/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import argparse 20 | 21 | def get_config(): 22 | 23 | parser = argparse.ArgumentParser( 24 | description='TiKick', formatter_class=argparse.RawDescriptionHelpFormatter) 25 | 26 | # prepare parameters 27 | parser.add_argument("--algorithm_name", type=str, 28 | default='rmappo', choices=["rmappo"]) 29 | 30 | parser.add_argument("--experiment_name", type=str, default="check", 31 | help="an identifier to distinguish different experiment.") 32 | parser.add_argument("--seed", type=int, default=1, 33 | help="Random seed for numpy/torch") 34 | parser.add_argument("--disable_cuda", action='store_true', default=False, 35 | help="by default False, will use GPU to train; or else will use CPU;") 36 | parser.add_argument("--cuda_deterministic", 37 | action='store_false', default=True, 38 | help="by default, make sure random seed effective. 
if set, bypass such function.") 39 | 40 | parser.add_argument("--n_rollout_threads", type=int, default=2, 41 | help="Number of parallel envs for training rollout") 42 | parser.add_argument("--n_eval_rollout_threads", type=int, default=1, 43 | help="Number of parallel envs for evaluating rollout") 44 | parser.add_argument("--n_render_rollout_threads", type=int, default=1, 45 | help="Number of parallel envs for rendering rollout") 46 | parser.add_argument("--eval_num", type=int, default=1, 47 | help='Number of evaluation rounds to run (default: 1)') 48 | 49 | # env parameters 50 | parser.add_argument("--env_name", type=str, default='StarCraft2', 51 | help="specify the name of environment") 52 | parser.add_argument("--use_obs_instead_of_state", action='store_true', 53 | default=False, help="Whether to use global state or concatenated obs") 54 | 55 | # replay buffer parameters 56 | parser.add_argument("--episode_length", type=int, 57 | default=200, help="Max length for any episode") 58 | 59 | # network parameters 60 | parser.add_argument("--separate_policy", action='store_true', 61 | default=False, help='Whether agents use separate (non-shared) policies') 62 | parser.add_argument("--use_centralized_V", action='store_false', 63 | default=True, help="Whether to use centralized V function") 64 | parser.add_argument("--use_conv1d", action='store_true', 65 | default=False, help="Whether to use conv1d") 66 | parser.add_argument("--stacked_frames", type=int, default=1, 67 | help="Number of frames to stack for actor/critic inputs") 68 | parser.add_argument("--use_stacked_frames", action='store_true', 69 | default=False, help="Whether to use stacked_frames") 70 | parser.add_argument("--hidden_size", type=int, default=256, 71 | help="Dimension of hidden layers for actor/critic networks") # TODO @zoeyuchao. The same comment might be in need of change.
72 | parser.add_argument("--layer_N", type=int, default=3, 73 | help="Number of layers for actor/critic networks") 74 | parser.add_argument("--activation_id", type=int, 75 | default=1, help="choose 0 to use tanh, 1 to use relu, 2 to use leaky relu, 3 to use elu") 76 | parser.add_argument("--use_popart", action='store_true', default=False, 77 | help="by default False, use PopArt to normalize rewards.") 78 | parser.add_argument("--use_valuenorm", action='store_false', default=True, 79 | help="by default True, use running mean and std to normalize rewards.") 80 | parser.add_argument("--use_feature_normalization", action='store_false', 81 | default=True, help="Whether to apply layernorm to the inputs") 82 | parser.add_argument("--use_orthogonal", action='store_false', default=True, 83 | help="Whether to use Orthogonal initialization for weights and 0 initialization for biases") 84 | parser.add_argument("--gain", type=float, default=0.01, 85 | help="The gain # of last action layer") 86 | parser.add_argument("--cnn_layers_params", type=str, default=None, 87 | help="The parameters of cnn layer") 88 | parser.add_argument("--use_maxpool2d", action='store_true', 89 | default=False, help="Whether to apply layernorm to the inputs") 90 | 91 | # recurrent parameters 92 | parser.add_argument("--use_naive_recurrent_policy", action='store_true', 93 | default=False, help='Whether to use a naive recurrent policy') 94 | parser.add_argument("--use_recurrent_policy", action='store_false', 95 | default=True, help='use a recurrent policy') 96 | parser.add_argument("--recurrent_N", type=int, default=1, 97 | help="The number of recurrent layers.") 98 | parser.add_argument("--data_chunk_length", type=int, default=25, 99 | help="Time length of chunks used to train a recurrent_policy") 100 | parser.add_argument("--use_influence_policy", action='store_true', 101 | default=False, help='use a recurrent policy') 102 | parser.add_argument("--influence_layer_N", type=int, default=1, 103 | help="Number of layers for actor/critic networks") 104 | 105 | 106 | # optimizer parameters 107 | parser.add_argument("--lr", type=float, default=5e-4, 108 | help='learning rate (default: 5e-4)') 109 | parser.add_argument("--tau", type=float, default=0.995, 110 | help='soft update polyak (default: 0.995)') 111 | parser.add_argument("--critic_lr", type=float, default=5e-4, 112 | help='critic learning rate (default: 5e-4)') 113 | parser.add_argument("--opti_eps", type=float, default=1e-5, 114 | help='RMSprop optimizer epsilon (default: 1e-5)') 115 | parser.add_argument("--weight_decay", type=float, default=0) 116 | 117 | # ppo parameters 118 | parser.add_argument("--ppo_epoch", type=int, default=15, 119 | help='number of ppo epochs (default: 15)') 120 | parser.add_argument("--use_policy_vhead", 121 | action='store_true', default=False, 122 | help="by default, do not use policy vhead. if set, use policy vhead.") 123 | parser.add_argument("--use_clipped_value_loss", 124 | action='store_false', default=True, 125 | help="by default, clip loss value. 
If set, do not clip loss value.") 126 | parser.add_argument("--clip_param", type=float, default=0.2, 127 | help='ppo clip parameter (default: 0.2)') 128 | parser.add_argument("--num_mini_batch", type=int, default=1, 129 | help='number of batches for ppo (default: 1)') 130 | parser.add_argument("--policy_value_loss_coef", type=float, 131 | default=1, help='policy value loss coefficient (default: 1)') 132 | parser.add_argument("--entropy_coef", type=float, default=0.01, 133 | help='entropy term coefficient (default: 0.01)') 134 | parser.add_argument("--value_loss_coef", type=float, 135 | default=1, help='value loss coefficient (default: 1)') 136 | parser.add_argument("--use_max_grad_norm", 137 | action='store_false', default=True, 138 | help="by default, use max norm of gradients. If set, do not use.") 139 | parser.add_argument("--max_grad_norm", type=float, default=10.0, 140 | help='max norm of gradients (default: 10.0)') 141 | parser.add_argument("--use_gae", action='store_false', 142 | default=True, help='use generalized advantage estimation') 143 | parser.add_argument("--gamma", type=float, default=0.99, 144 | help='discount factor for rewards (default: 0.99)') 145 | parser.add_argument("--gae_lambda", type=float, default=0.95, 146 | help='gae lambda parameter (default: 0.95)') 147 | parser.add_argument("--use_proper_time_limits", action='store_true', 148 | default=False, help='compute returns taking into account time limits') 149 | parser.add_argument("--use_huber_loss", action='store_false', default=True, 150 | help="by default, use huber loss. If set, do not use huber loss.") 151 | parser.add_argument("--use_value_active_masks", 152 | action='store_false', default=True, 153 | help="by default True, whether to mask useless data in value loss.") 154 | parser.add_argument("--use_policy_active_masks", 155 | action='store_false', default=True, 156 | help="by default True, whether to mask useless data in policy loss.") 157 | parser.add_argument("--huber_delta", type=float, 158 | default=10.0, help="coefficient of the huber loss.") 159 | 160 | # save parameters 161 | parser.add_argument("--save_interval", type=int, default=1, 162 | help="time interval between two consecutive model saves.") 163 | 164 | # log parameters 165 | parser.add_argument("--log_interval", type=int, default=5, 166 | help="time interval between two consecutive log prints.") 167 | 168 | # eval parameters 169 | parser.add_argument("--use_eval", action='store_true', default=False, 170 | help="by default, do not start evaluation. If set, start evaluation alongside training.") 171 | parser.add_argument("--eval_interval", type=int, default=25, 172 | help="time interval between two consecutive evaluation runs.") 173 | parser.add_argument("--eval_episodes", type=int, default=64, 174 | help="number of episodes in a single evaluation.") 175 | 176 | # pretrained parameters 177 | parser.add_argument("--model_dir", type=str, default=None, 178 | help="by default None.
set the path to pretrained model.") 179 | 180 | parser.add_argument("--replay_save_dir", type=str, default=None, 181 | help="replay file save dir") 182 | 183 | # replay buffer parameters 184 | 185 | 186 | 187 | return parser 188 | -------------------------------------------------------------------------------- /tmarl/drivers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/drivers/__init__.py -------------------------------------------------------------------------------- /tmarl/drivers/shared_distributed/base_driver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def _t2n(x): 5 | return x.detach().cpu().numpy() 6 | 7 | class Driver(object): 8 | def __init__(self, config, client=None): 9 | 10 | self.all_args = config['all_args'] 11 | self.envs = config['envs'] 12 | self.eval_envs = config['eval_envs'] 13 | self.device = config['device'] 14 | self.num_agents = config['num_agents'] 15 | if 'signal' in config: 16 | self.actor_id = config['signal'].actor_id 17 | self.weight_ids = config['signal'].weight_ids 18 | else: 19 | self.actor_id = 0 20 | self.weight_ids = [0] 21 | 22 | # parameters 23 | self.env_name = self.all_args.env_name 24 | self.algorithm_name = self.all_args.algorithm_name 25 | self.experiment_name = self.all_args.experiment_name 26 | self.use_centralized_V = self.all_args.use_centralized_V 27 | self.use_obs_instead_of_state = self.all_args.use_obs_instead_of_state 28 | self.num_env_steps = self.all_args.num_env_steps if hasattr(self.all_args,'num_env_steps') else self.all_args.eval_num 29 | 30 | self.episode_length = self.all_args.episode_length 31 | self.n_rollout_threads = self.all_args.n_rollout_threads 32 | self.learner_n_rollout_threads = self.all_args.n_rollout_threads 33 | 34 | self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads 35 | self.hidden_size = self.all_args.hidden_size 36 | self.recurrent_N = self.all_args.recurrent_N 37 | 38 | # interval 39 | self.save_interval = self.all_args.save_interval 40 | self.use_eval = self.all_args.use_eval 41 | self.eval_interval = self.all_args.eval_interval 42 | self.log_interval = self.all_args.log_interval 43 | 44 | # dir 45 | self.model_dir = self.all_args.model_dir 46 | 47 | 48 | 49 | if self.algorithm_name == "rmappo": 50 | from tmarl.algorithms.r_mappo_distributed.mappo_algorithm import MAPPOAlgorithm as TrainAlgo 51 | from tmarl.algorithms.r_mappo_distributed.mappo_module import MAPPOModule as AlgoModule 52 | else: 53 | raise NotImplementedError 54 | 55 | if self.envs: 56 | share_observation_space = self.envs.share_observation_space[0] \ 57 | if self.use_centralized_V else self.envs.observation_space[0] 58 | # policy network 59 | self.algo_module = AlgoModule(self.all_args, 60 | self.envs.observation_space[0], 61 | share_observation_space, 62 | self.envs.action_space[0], 63 | device=self.device) 64 | 65 | else: 66 | share_observation_space = self.eval_envs.share_observation_space[0] \ 67 | if self.use_centralized_V else self.eval_envs.observation_space[0] 68 | # policy network 69 | self.algo_module = AlgoModule(self.all_args, 70 | self.eval_envs.observation_space[0], 71 | share_observation_space, 72 | self.eval_envs.action_space[0], 73 | device=self.device) 74 | 75 | if self.model_dir is not None: 76 | self.restore() 77 | 78 | # algorithm 79 | self.trainer = TrainAlgo(self.all_args, 
self.algo_module, device=self.device) 80 | 81 | 82 | # buffer 83 | from tmarl.replay_buffers.normal.shared_buffer import SharedReplayBuffer 84 | 85 | self.buffer = SharedReplayBuffer(self.all_args, 86 | self.num_agents, 87 | self.envs.observation_space[0] if self.envs else self.eval_envs.observation_space[0], 88 | share_observation_space, 89 | self.envs.action_space[0] if self.envs else self.eval_envs.action_space[0]) 90 | 91 | def run(self): 92 | raise NotImplementedError 93 | 94 | def warmup(self): 95 | raise NotImplementedError 96 | 97 | def collect(self, step): 98 | raise NotImplementedError 99 | 100 | def insert(self, data): 101 | raise NotImplementedError 102 | 103 | def restore(self): 104 | policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor.pt', map_location=self.device) 105 | self.algo_module.actor.load_state_dict(policy_actor_state_dict) 106 | -------------------------------------------------------------------------------- /tmarl/drivers/shared_distributed/football_driver.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | 4 | from tmarl.drivers.shared_distributed.base_driver import Driver 5 | 6 | 7 | def _t2n(x): 8 | return x.detach().cpu().numpy() 9 | 10 | 11 | class FootballDriver(Driver): 12 | def __init__(self, config): 13 | super(FootballDriver, self).__init__(config) 14 | 15 | def run(self): 16 | self.trainer.prep_rollout() 17 | episodes = int(self.num_env_steps) 18 | total_num_steps = 0 19 | for episode in range(episodes): 20 | print('Episode {}:'.format(episode)) 21 | 22 | self.eval(total_num_steps) 23 | 24 | def eval(self, total_num_steps): 25 | 26 | eval_episode_rewards = [] 27 | eval_obs, eval_share_obs, eval_available_actions = self.eval_envs.reset() 28 | 29 | agent_num = eval_obs.shape[1] 30 | used_buffer = self.buffer 31 | rnn_shape = [self.n_eval_rollout_threads, agent_num, *used_buffer.rnn_states_critic.shape[3:]] 32 | eval_rnn_states = np.zeros(rnn_shape, dtype=np.float32) 33 | eval_rnn_states_critic = np.zeros(rnn_shape, dtype=np.float32) 34 | eval_masks = np.ones((self.n_eval_rollout_threads, agent_num, 1), dtype=np.float32) 35 | 36 | finished = None 37 | 38 | for eval_step in tqdm(range(3001)): 39 | self.trainer.prep_rollout() 40 | _, eval_action, eval_action_log_prob, eval_rnn_states, _ = \ 41 | self.trainer.algo_module.get_actions(np.concatenate(eval_share_obs), 42 | np.concatenate(eval_obs), 43 | np.concatenate(eval_rnn_states), 44 | None, 45 | np.concatenate(eval_masks), 46 | np.concatenate(eval_available_actions), 47 | deterministic=True) 48 | 49 | eval_actions = np.array( 50 | np.split(_t2n(eval_action), self.n_eval_rollout_threads)) 51 | eval_rnn_states = np.array( 52 | np.split(_t2n(eval_rnn_states), self.n_eval_rollout_threads)) 53 | 54 | 55 | if self.eval_envs.action_space[0].__class__.__name__ == 'Discrete': 56 | eval_actions_env = np.squeeze( 57 | np.eye(self.eval_envs.action_space[0].n)[eval_actions], 2) 58 | else: 59 | raise NotImplementedError 60 | 61 | # Obser reward and next obs 62 | eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_infos, eval_available_actions = \ 63 | self.eval_envs.step(eval_actions_env) 64 | eval_rewards = eval_rewards.reshape([-1, agent_num]) # [roll_out, num_agents] 65 | 66 | if finished is None: 67 | eval_r = eval_rewards[:,:self.num_agents] 68 | eval_episode_rewards.append(eval_r) 69 | finished = eval_dones.copy() 70 | else: 71 | eval_r = (eval_rewards * ~finished)[:,:self.num_agents] 72 | 
eval_episode_rewards.append(eval_r) 73 | finished = eval_dones.copy() | finished 74 | 75 | eval_masks = np.ones( 76 | (self.n_eval_rollout_threads, agent_num, 1), dtype=np.float32) 77 | eval_masks[eval_dones == True] = np.zeros( 78 | ((eval_dones == True).sum(), 1), dtype=np.float32) 79 | eval_rnn_states[eval_dones == True] = np.zeros( 80 | ((eval_dones == True).sum(), self.recurrent_N, self.hidden_size), dtype=np.float32) 81 | 82 | 83 | if finished.all() == True: 84 | break 85 | 86 | eval_episode_rewards = np.array(eval_episode_rewards) # [step,rollout,num_agents] 87 | 88 | ally_goal = np.sum((eval_episode_rewards == 1), axis=0) 89 | enemy_goal = np.sum((eval_episode_rewards == -1), axis=0) 90 | net_goal = np.sum(eval_episode_rewards, axis=0) 91 | winning_rate = np.mean(net_goal, axis=-1) 92 | eval_env_infos = {} 93 | eval_env_infos['eval_average_winning_rate'] = winning_rate>0 94 | eval_env_infos['eval_average_losing_rate'] = winning_rate<0 95 | eval_env_infos['eval_average_draw_rate'] = winning_rate==0 96 | eval_env_infos['eval_average_ally_score'] = ally_goal 97 | eval_env_infos['eval_average_enemy_score'] = enemy_goal 98 | eval_env_infos['eval_average_net_score'] = net_goal 99 | print("\tSuccess Rate: " + str(np.mean(winning_rate>0)) ) 100 | -------------------------------------------------------------------------------- /tmarl/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/envs/__init__.py -------------------------------------------------------------------------------- /tmarl/envs/env_wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from OpenAI Baselines code to work with multi-agent envs 3 | """ 4 | import numpy as np 5 | 6 | from multiprocessing import Process, Pipe 7 | from abc import ABC, abstractmethod 8 | from tmarl.utils.util import tile_images 9 | 10 | 11 | class CloudpickleWrapper(object): 12 | """ 13 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 14 | """ 15 | 16 | def __init__(self, x): 17 | self.x = x 18 | 19 | def __getstate__(self): 20 | import cloudpickle 21 | return cloudpickle.dumps(self.x) 22 | 23 | def __setstate__(self, ob): 24 | import pickle 25 | self.x = pickle.loads(ob) 26 | 27 | class ShareVecEnv(ABC): 28 | """ 29 | An abstract asynchronous, vectorized environment. 30 | Used to batch data from multiple copies of an environment, so that 31 | each observation becomes an batch of observations, and expected action is a batch of actions to 32 | be applied per-environment. 33 | """ 34 | closed = False 35 | viewer = None 36 | 37 | metadata = { 38 | 'render.modes': ['human', 'rgb_array'] 39 | } 40 | 41 | def __init__(self, num_envs, observation_space, share_observation_space, action_space): 42 | self.num_envs = num_envs 43 | self.observation_space = observation_space 44 | self.share_observation_space = share_observation_space 45 | self.action_space = action_space 46 | 47 | @abstractmethod 48 | def reset(self): 49 | """ 50 | Reset all the environments and return an array of 51 | observations, or a dict of observation arrays. 52 | 53 | If step_async is still doing work, that work will 54 | be cancelled and step_wait() should not be called 55 | until step_async() is invoked again. 
56 | """ 57 | pass 58 | 59 | @abstractmethod 60 | def step_async(self, actions): 61 | """ 62 | Tell all the environments to start taking a step 63 | with the given actions. 64 | Call step_wait() to get the results of the step. 65 | 66 | You should not call this if a step_async run is 67 | already pending. 68 | """ 69 | pass 70 | 71 | @abstractmethod 72 | def step_wait(self): 73 | """ 74 | Wait for the step taken with step_async(). 75 | 76 | Returns (obs, rews, dones, infos): 77 | - obs: an array of observations, or a dict of 78 | arrays of observations. 79 | - rews: an array of rewards 80 | - dones: an array of "episode done" booleans 81 | - infos: a sequence of info objects 82 | """ 83 | pass 84 | 85 | def close_extras(self): 86 | """ 87 | Clean up the extra resources, beyond what's in this base class. 88 | Only runs when not self.closed. 89 | """ 90 | pass 91 | 92 | def close(self): 93 | if self.closed: 94 | return 95 | if self.viewer is not None: 96 | self.viewer.close() 97 | self.close_extras() 98 | self.closed = True 99 | 100 | def step(self, actions): 101 | """ 102 | Step the environments synchronously. 103 | 104 | This is available for backwards compatibility. 105 | """ 106 | self.step_async(actions) 107 | return self.step_wait() 108 | 109 | def render(self, mode='human'): 110 | imgs = self.get_images() 111 | bigimg = tile_images(imgs) 112 | if mode == 'human': 113 | self.get_viewer().imshow(bigimg) 114 | return self.get_viewer().isopen 115 | elif mode == 'rgb_array': 116 | return bigimg 117 | else: 118 | raise NotImplementedError 119 | 120 | def get_images(self): 121 | """ 122 | Return RGB images from each environment 123 | """ 124 | raise NotImplementedError 125 | 126 | @property 127 | def unwrapped(self): 128 | if isinstance(self, VecEnvWrapper): 129 | return self.venv.unwrapped 130 | else: 131 | return self 132 | 133 | def get_viewer(self): 134 | if self.viewer is None: 135 | from gym.envs.classic_control import rendering 136 | self.viewer = rendering.SimpleImageViewer() 137 | return self.viewer 138 | 139 | def worker(remote, parent_remote, env_fn_wrapper): 140 | parent_remote.close() 141 | env = env_fn_wrapper.x() 142 | while True: 143 | cmd, data = remote.recv() 144 | if cmd == 'step': 145 | ob, reward, done, info = env.step(data) 146 | 147 | if 'bool' in done.__class__.__name__: 148 | if done: 149 | ob = env.reset() 150 | else: 151 | if np.all(done): 152 | ob = env.reset() 153 | 154 | remote.send((ob, reward, done, info)) 155 | elif cmd == 'reset': 156 | ob = env.reset() 157 | remote.send((ob)) 158 | elif cmd == 'render': 159 | if data == "rgb_array": 160 | fr = env.render(mode=data) 161 | remote.send(fr) 162 | elif data == "human": 163 | env.render(mode=data) 164 | elif cmd == 'reset_task': 165 | ob = env.reset_task() 166 | remote.send(ob) 167 | elif cmd == 'close': 168 | env.close() 169 | remote.close() 170 | break 171 | elif cmd == 'get_spaces': 172 | remote.send((env.observation_space, env.share_observation_space, env.action_space)) 173 | elif cmd == 'get_max_step': 174 | remote.send((env.max_steps)) 175 | elif cmd == 'get_action': # for behavior cloning 176 | action = env.get_action() 177 | remote.send((action)) 178 | else: 179 | raise NotImplementedError 180 | 181 | class SubprocVecEnv(ShareVecEnv): 182 | def __init__(self, env_fns, spaces=None): 183 | """ 184 | envs: list of gym environments to run in subprocesses 185 | """ 186 | self.waiting = False 187 | self.closed = False 188 | nenvs = len(env_fns) 189 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in 
range(nenvs)]) 190 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 191 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 192 | for p in self.ps: 193 | p.daemon = True # if the main process crashes, we should not cause things to hang 194 | p.start() 195 | for remote in self.work_remotes: 196 | remote.close() 197 | 198 | self.remotes[0].send(('get_spaces', None)) 199 | observation_space, share_observation_space, action_space = self.remotes[0].recv() 200 | ShareVecEnv.__init__(self, len(env_fns), observation_space, 201 | share_observation_space, action_space) 202 | 203 | def step_async(self, actions): 204 | for remote, action in zip(self.remotes, actions): 205 | remote.send(('step', action)) 206 | self.waiting = True 207 | 208 | def step_wait(self): 209 | results = [remote.recv() for remote in self.remotes] 210 | self.waiting = False 211 | obs, rews, dones, infos = zip(*results) 212 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 213 | 214 | def reset(self): 215 | for remote in self.remotes: 216 | remote.send(('reset', None)) 217 | obs = [remote.recv() for remote in self.remotes] 218 | return np.stack(obs) 219 | 220 | def get_max_step(self): 221 | for remote in self.remotes: 222 | remote.send(('get_max_step', None)) 223 | return np.stack([remote.recv() for remote in self.remotes]) 224 | 225 | def reset_task(self): 226 | for remote in self.remotes: 227 | remote.send(('reset_task', None)) 228 | return np.stack([remote.recv() for remote in self.remotes]) 229 | 230 | def close(self): 231 | if self.closed: 232 | return 233 | if self.waiting: 234 | for remote in self.remotes: 235 | remote.recv() 236 | for remote in self.remotes: 237 | remote.send(('close', None)) 238 | for p in self.ps: 239 | p.join() 240 | self.closed = True 241 | 242 | def render(self, mode="rgb_array"): 243 | for remote in self.remotes: 244 | remote.send(('render', mode)) 245 | if mode == "rgb_array": 246 | frame = [remote.recv() for remote in self.remotes] 247 | return np.stack(frame) 248 | 249 | def shareworker(remote, parent_remote, env_fn_wrapper): 250 | parent_remote.close() 251 | env = env_fn_wrapper.x() 252 | while True: 253 | cmd, data = remote.recv() 254 | if cmd == 'step': 255 | ob, s_ob, reward, done, info, available_actions = env.step(data) 256 | if 'bool' in done.__class__.__name__: 257 | if done: 258 | ob, s_ob, available_actions = env.reset() 259 | else: 260 | if np.all(done): 261 | ob, s_ob, available_actions = env.reset() 262 | 263 | remote.send((ob, s_ob, reward, done, info, available_actions)) 264 | elif cmd == 'reset': 265 | ob, s_ob, available_actions = env.reset() 266 | remote.send((ob, s_ob, available_actions)) 267 | elif cmd == 'reset_task': 268 | ob = env.reset_task() 269 | remote.send(ob) 270 | elif cmd == 'render': 271 | if data == "rgb_array": 272 | fr = env.render(mode=data) 273 | remote.send(fr) 274 | elif data == "human": 275 | env.render(mode=data) 276 | elif cmd == 'close': 277 | env.close() 278 | remote.close() 279 | break 280 | elif cmd == 'get_spaces': 281 | remote.send( 282 | (env.observation_space, env.share_observation_space, env.action_space)) 283 | elif cmd == 'render_vulnerability': 284 | fr = env.render_vulnerability(data) 285 | remote.send((fr)) 286 | elif cmd == 'get_action': # for behavior cloning 287 | action = env.get_action() 288 | remote.send((action)) 289 | else: 290 | raise NotImplementedError 291 | 292 | class ShareSubprocVecEnv(ShareVecEnv): 293 | def __init__(self, env_fns, 
spaces=None): 294 | """ 295 | envs: list of gym environments to run in subprocesses 296 | """ 297 | self.waiting = False 298 | self.closed = False 299 | nenvs = len(env_fns) 300 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 301 | self.ps = [Process(target=shareworker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 302 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 303 | for p in self.ps: 304 | p.daemon = True # if the main process crashes, we should not cause things to hang 305 | p.start() 306 | for remote in self.work_remotes: 307 | remote.close() 308 | self.remotes[0].send(('get_spaces', None)) 309 | observation_space, share_observation_space, action_space = self.remotes[0].recv( 310 | ) 311 | ShareVecEnv.__init__(self, len(env_fns), observation_space, 312 | share_observation_space, action_space) 313 | 314 | def step_async(self, actions): 315 | for remote, action in zip(self.remotes, actions): 316 | remote.send(('step', action)) 317 | self.waiting = True 318 | 319 | def step_wait(self): 320 | results = [remote.recv() for remote in self.remotes] 321 | self.waiting = False 322 | obs, share_obs, rews, dones, infos, available_actions = zip(*results) 323 | return np.stack(obs), np.stack(share_obs), np.stack(rews), np.stack(dones), infos, np.stack(available_actions) 324 | 325 | def reset(self): 326 | for remote in self.remotes: 327 | remote.send(('reset', None)) 328 | results = [remote.recv() for remote in self.remotes] 329 | obs, share_obs, available_actions = zip(*results) 330 | return np.stack(obs), np.stack(share_obs), np.stack(available_actions) 331 | 332 | def reset_task(self): 333 | for remote in self.remotes: 334 | remote.send(('reset_task', None)) 335 | return np.stack([remote.recv() for remote in self.remotes]) 336 | 337 | def close(self): 338 | if self.closed: 339 | return 340 | if self.waiting: 341 | for remote in self.remotes: 342 | remote.recv() 343 | for remote in self.remotes: 344 | remote.send(('close', None)) 345 | for p in self.ps: 346 | p.join() 347 | self.closed = True 348 | 349 | def get_action(self): # for behavior cloning 350 | for remote in self.remotes: 351 | remote.send(('get_action', None)) 352 | results = [remote.recv() for remote in self.remotes] 353 | return np.concatenate(results) 354 | 355 | 356 | # single env 357 | class DummyVecEnv(ShareVecEnv): 358 | def __init__(self, env_fns): 359 | self.envs = [fn() for fn in env_fns] 360 | env = self.envs[0] 361 | ShareVecEnv.__init__(self, len( 362 | env_fns), env.observation_space, env.share_observation_space, env.action_space) 363 | self.actions = None 364 | 365 | def step_async(self, actions): 366 | self.actions = actions 367 | 368 | def step_wait(self): 369 | results = [env.step(a) for (a, env) in zip(self.actions, self.envs)] 370 | obs, rews, dones, infos = map(np.array, zip(*results)) 371 | for (i, done) in enumerate(dones): 372 | if 'bool' in done.__class__.__name__: 373 | if done: 374 | obs[i] = self.envs[i].reset() 375 | else: 376 | if np.all(done): 377 | obs[i] = self.envs[i].reset() 378 | 379 | self.actions = None 380 | return obs, rews, dones, infos 381 | 382 | def reset(self): 383 | obs = [env.reset() for env in self.envs] 384 | return np.array(obs) 385 | 386 | def get_max_step(self): 387 | return [env.max_steps for env in self.envs] 388 | 389 | def close(self): 390 | for env in self.envs: 391 | env.close() 392 | 393 | def render(self, mode="human", playeridx=None): 394 | if mode == "rgb_array": 395 | if playeridx == None: 396 | 
return np.array([env.render(mode=mode) for env in self.envs]) 397 | else: 398 | return np.array([env.render(mode=mode,playeridx=playeridx) for env in self.envs]) 399 | elif mode == "human": 400 | for env in self.envs: 401 | if playeridx == None: 402 | env.render(mode=mode) 403 | else: 404 | env.render(mode=mode, playeridx=playeridx) 405 | else: 406 | raise NotImplementedError 407 | 408 | class ShareDummyVecEnv(ShareVecEnv): 409 | def __init__(self, env_fns): 410 | self.envs = [fn() for fn in env_fns] 411 | env = self.envs[0] 412 | ShareVecEnv.__init__(self, len( 413 | env_fns), env.observation_space, env.share_observation_space, env.action_space) 414 | self.actions = None 415 | 416 | def step_async(self, actions): 417 | self.actions = actions 418 | 419 | def step_wait(self): 420 | results = [env.step(a) for (a, env) in zip(self.actions, self.envs)] 421 | obs, share_obs, rews, dones, infos, available_actions = map( 422 | np.array, zip(*results)) 423 | 424 | for (i, done) in enumerate(dones): 425 | if 'bool' in done.__class__.__name__: 426 | if done: 427 | obs[i], share_obs[i], available_actions[i] = self.envs[i].reset() 428 | else: 429 | if np.all(done): 430 | obs[i], share_obs[i], available_actions[i] = self.envs[i].reset() 431 | self.actions = None 432 | 433 | return obs, share_obs, rews, dones, infos, available_actions 434 | 435 | def reset(self): 436 | results = [env.reset() for env in self.envs] 437 | obs, share_obs, available_actions = map(np.array, zip(*results)) 438 | return obs, share_obs, available_actions 439 | 440 | def close(self): 441 | for env in self.envs: 442 | env.close() 443 | 444 | def render(self, mode="human"): 445 | if mode == "rgb_array": 446 | return np.array([env.render(mode=mode) for env in self.envs]) 447 | elif mode == "human": 448 | for env in self.envs: 449 | env.render(mode=mode) 450 | else: 451 | raise NotImplementedError 452 | 453 | def save_replay(self): 454 | for env in self.envs: 455 | env.save_replay() 456 | 457 | def get_action(self): # for behavior cloning 458 | results = [env.get_action() for env in self.envs] 459 | return results 460 | -------------------------------------------------------------------------------- /tmarl/envs/football/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /tmarl/envs/football/env/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
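# ==============================================================================
# [Editor's note] Illustrative usage sketch for the vectorized wrappers defined
# in tmarl/envs/env_wrappers.py above; it is not part of the original sources.
# The SimpleNamespace below is a hypothetical stand-in for the parsed runner
# arguments, and the sketch assumes gfootball is installed. ShareSubprocVecEnv
# drives one process per env over a Pipe using the commands handled by
# shareworker() ('step', 'reset', 'get_spaces', 'close', ...), while
# ShareDummyVecEnv keeps every env in the driver process, which is enough for
# single-thread evaluation.
#
# from types import SimpleNamespace
# import numpy as np
# from tmarl.envs.env_wrappers import ShareDummyVecEnv, ShareSubprocVecEnv
# from tmarl.envs.football.football import RllibGFootball
#
# args = SimpleNamespace(num_agents=3, n_rollout_threads=1,
#                        scenario_name='academy_3_vs_1_with_keeper',
#                        representation='raw', rewards='scoring,checkpoints')
#
# def make_env(rank):
#     def _thunk():
#         # rank 0 also switches on episode dumps / rendering inside RllibGFootball
#         return RllibGFootball(args, rank, log_dir='./results', isEval=True)
#     return _thunk
#
# envs = ShareDummyVecEnv([make_env(0)])  # or ShareSubprocVecEnv([make_env(i) for i in range(n)])
# obs, share_obs, avail = envs.reset()    # shapes: (n_envs, num_agents, ...)
# # one-hot actions of shape (n_envs, num_agents, 33); here pick the first legal action
# acts = np.stack([np.eye(33)[np.flatnonzero(a)[0]] for a in avail[0]])[None]
# obs, share_obs, rews, dones, infos, avail = envs.step(acts)
# envs.close()
# ==============================================================================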
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """GFootball Environment.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tmarl.envs.football.env import config 22 | from gfootball.env import football_env 23 | from gfootball.env import observation_preprocessing 24 | from gfootball.env import wrappers 25 | 26 | 27 | def _process_reward_wrappers(env, rewards): 28 | assert 'scoring' in rewards.split(',') 29 | if 'checkpoints' in rewards.split(','): 30 | env = wrappers.CheckpointRewardWrapper(env) 31 | return env 32 | 33 | 34 | def _process_representation_wrappers(env, representation, channel_dimensions): 35 | """Wraps with necessary representation wrappers. 36 | 37 | Args: 38 | env: A GFootball gym environment. 39 | representation: See create_environment.representation comment. 40 | channel_dimensions: (width, height) tuple that represents the dimensions of 41 | SMM or pixels representation. 42 | Returns: 43 | Google Research Football environment. 44 | """ 45 | if representation.startswith('pixels'): 46 | env = wrappers.PixelsStateWrapper(env, 'gray' in representation, 47 | channel_dimensions) 48 | elif representation == 'simple115': 49 | env = wrappers.Simple115StateWrapper(env) 50 | elif representation == 'simple115v2': 51 | env = wrappers.Simple115StateWrapper(env, True) 52 | elif representation == 'extracted': 53 | env = wrappers.SMMWrapper(env, channel_dimensions) 54 | elif representation == 'raw': 55 | pass 56 | else: 57 | raise ValueError('Unsupported representation: {}'.format(representation)) 58 | return env 59 | 60 | 61 | def _apply_output_wrappers(env, rewards, representation, channel_dimensions, 62 | apply_single_agent_wrappers, stacked): 63 | """Wraps with necessary wrappers modifying the output of the environment. 64 | 65 | Args: 66 | env: A GFootball gym environment. 67 | rewards: What rewards to apply. 68 | representation: See create_environment.representation comment. 69 | channel_dimensions: (width, height) tuple that represents the dimensions of 70 | SMM or pixels representation. 71 | apply_single_agent_wrappers: Whether to reduce output to single agent case. 72 | stacked: Should observations be stacked. 73 | Returns: 74 | Google Research Football environment. 
75 | """ 76 | env = _process_reward_wrappers(env, rewards) 77 | env = _process_representation_wrappers(env, representation, 78 | channel_dimensions) 79 | if apply_single_agent_wrappers: 80 | if representation != 'raw': 81 | env = wrappers.SingleAgentObservationWrapper(env) 82 | env = wrappers.SingleAgentRewardWrapper(env) 83 | if stacked: 84 | env = wrappers.FrameStack(env, 4) 85 | env = wrappers.GetStateWrapper(env) 86 | return env 87 | 88 | 89 | def create_environment(env_name='', 90 | stacked=False, 91 | representation='extracted', 92 | rewards='scoring', 93 | write_goal_dumps=False, 94 | write_full_episode_dumps=False, 95 | render=False, 96 | write_video=False, 97 | dump_frequency=1, 98 | logdir='', 99 | extra_players=None, 100 | number_of_left_players_agent_controls=1, 101 | number_of_right_players_agent_controls=0, 102 | channel_dimensions=( 103 | observation_preprocessing.SMM_WIDTH, 104 | observation_preprocessing.SMM_HEIGHT), 105 | other_config_options={}): 106 | """Creates a Google Research Football environment. 107 | 108 | Args: 109 | env_name: a name of a scenario to run, e.g. "11_vs_11_stochastic". 110 | The list of scenarios can be found in directory "scenarios". 111 | stacked: If True, stack 4 observations, otherwise, only the last 112 | observation is returned by the environment. 113 | Stacking is only possible when representation is one of the following: 114 | "pixels", "pixels_gray" or "extracted". 115 | In that case, the stacking is done along the last (i.e. channel) 116 | dimension. 117 | representation: String to define the representation used to build 118 | the observation. It can be one of the following: 119 | 'pixels': the observation is the rendered view of the football field 120 | downsampled to 'channel_dimensions'. The observation size is: 121 | 'channel_dimensions'x3 (or 'channel_dimensions'x12 when "stacked" is 122 | True). 123 | 'pixels_gray': the observation is the rendered view of the football field 124 | in gray scale and downsampled to 'channel_dimensions'. The observation 125 | size is 'channel_dimensions'x1 (or 'channel_dimensions'x4 when stacked 126 | is True). 127 | 'extracted': also referred to as super minimap. The observation is 128 | composed of 4 planes of size 'channel_dimensions'. 129 | Its size is then 'channel_dimensions'x4 (or 'channel_dimensions'x16 when 130 | stacked is True). 131 | The first plane P holds the position of players on the left 132 | team, P[y,x] is 255 if there is a player at position (x,y), otherwise, 133 | its value is 0. 134 | The second plane holds in the same way the position of players 135 | on the right team. 136 | The third plane holds the position of the ball. 137 | The last plane holds the active player. 138 | 'simple115'/'simple115v2': the observation is a vector of size 115. 139 | It holds: 140 | - the ball_position and the ball_direction as (x,y,z) 141 | - one hot encoding of who controls the ball. 142 | [1, 0, 0]: nobody, [0, 1, 0]: left team, [0, 0, 1]: right team. 143 | - one hot encoding of size 11 to indicate who is the active player 144 | in the left team. 145 | - 11 (x,y) positions for each player of the left team. 146 | - 11 (x,y) motion vectors for each player of the left team. 147 | - 11 (x,y) positions for each player of the right team. 148 | - 11 (x,y) motion vectors for each player of the right team. 149 | - one hot encoding of the game mode. Vector of size 7 with the 150 | following meaning: 151 | {NormalMode, KickOffMode, GoalKickMode, FreeKickMode, 152 | CornerMode, ThrowInMode, PenaltyMode}. 
153 | Can only be used when the scenario is a flavor of normal game 154 | (i.e. 11 versus 11 players). 155 | rewards: Comma separated list of rewards to be added. 156 | Currently supported rewards are 'scoring' and 'checkpoints'. 157 | write_goal_dumps: whether to dump traces up to 200 frames before goals. 158 | write_full_episode_dumps: whether to dump traces for every episode. 159 | render: whether to render game frames. 160 | Must be enable when rendering videos or when using pixels 161 | representation. 162 | write_video: whether to dump videos when a trace is dumped. 163 | dump_frequency: how often to write dumps/videos (in terms of # of episodes) 164 | Sub-sample the episodes for which we dump videos to save some disk space. 165 | logdir: directory holding the logs. 166 | extra_players: A list of extra players to use in the environment. 167 | Each player is defined by a string like: 168 | "$player_name:left_players=?,right_players=?,$param1=?,$param2=?...." 169 | number_of_left_players_agent_controls: Number of left players an agent 170 | controls. 171 | number_of_right_players_agent_controls: Number of right players an agent 172 | controls. 173 | channel_dimensions: (width, height) tuple that represents the dimensions of 174 | SMM or pixels representation. 175 | other_config_options: dict that allows directly setting other options in 176 | the Config 177 | Returns: 178 | Google Research Football environment. 179 | """ 180 | assert env_name 181 | 182 | scenario_config = config.Config({'level': env_name}).ScenarioConfig() 183 | players = [('agent:left_players=%d,right_players=%d' % ( 184 | number_of_left_players_agent_controls, 185 | number_of_right_players_agent_controls))] 186 | 187 | # Enable MultiAgentToSingleAgent wrapper? 188 | multiagent_to_singleagent = False 189 | if scenario_config.control_all_players: 190 | if (number_of_left_players_agent_controls in [0, 1] and 191 | number_of_right_players_agent_controls in [0, 1]): 192 | multiagent_to_singleagent = True 193 | players = [('agent:left_players=%d,right_players=%d' % 194 | (scenario_config.controllable_left_players 195 | if number_of_left_players_agent_controls else 0, 196 | scenario_config.controllable_right_players 197 | if number_of_right_players_agent_controls else 0))] 198 | 199 | if extra_players is not None: 200 | players.extend(extra_players) 201 | config_values = { 202 | 'dump_full_episodes': write_full_episode_dumps, 203 | 'dump_scores': write_goal_dumps, 204 | 'players': players, 205 | 'level': env_name, 206 | 'tracesdir': logdir, 207 | 'write_video': write_video, 208 | } 209 | config_values.update(other_config_options) 210 | c = config.Config(config_values) 211 | 212 | env = football_env.FootballEnv(c) 213 | if multiagent_to_singleagent: 214 | env = wrappers.MultiAgentToSingleAgent( 215 | env, number_of_left_players_agent_controls, 216 | number_of_right_players_agent_controls) 217 | if dump_frequency > 1: 218 | env = wrappers.PeriodicDumpWriter(env, dump_frequency, render) 219 | elif render: 220 | env.render() 221 | env = _apply_output_wrappers( 222 | env, rewards, representation, channel_dimensions, 223 | (number_of_left_players_agent_controls + 224 | number_of_right_players_agent_controls == 1), stacked) 225 | return env 226 | 227 | 228 | def create_remote_environment( 229 | username, 230 | token, 231 | model_name='', 232 | track='', 233 | stacked=False, 234 | representation='raw', 235 | rewards='scoring', 236 | channel_dimensions=( 237 | observation_preprocessing.SMM_WIDTH, 238 | 
observation_preprocessing.SMM_HEIGHT), 239 | include_rendering=False): 240 | """Creates a remote Google Research Football environment. 241 | 242 | Args: 243 | username: User name. 244 | token: User token. 245 | model_name: A model identifier to be displayed on the leaderboard. 246 | track: which competition track to connect to. 247 | stacked: If True, stack 4 observations, otherwise, only the last 248 | observation is returned by the environment. 249 | Stacking is only possible when representation is one of the following: 250 | "pixels", "pixels_gray" or "extracted". 251 | In that case, the stacking is done along the last (i.e. channel) 252 | dimension. 253 | representation: See create_environment.representation comment. 254 | rewards: Comma separated list of rewards to be added. 255 | Currently supported rewards are 'scoring' and 'checkpoints'. 256 | channel_dimensions: (width, height) tuple that represents the dimensions of 257 | SMM or pixels representation. 258 | include_rendering: Whether to return frame as part of the output. 259 | Returns: 260 | Google Research Football environment. 261 | """ 262 | from gfootball.env import remote_football_env 263 | env = remote_football_env.RemoteFootballEnv( 264 | username, token, model_name=model_name, track=track, 265 | include_rendering=include_rendering) 266 | env = _apply_output_wrappers( 267 | env, rewards, representation, channel_dimensions, 268 | env._config.number_of_players_agent_controls() == 1, stacked) 269 | return env 270 | -------------------------------------------------------------------------------- /tmarl/envs/football/env/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Config loader.""" 17 | 18 | from __future__ import print_function 19 | 20 | import copy 21 | 22 | from absl import flags 23 | 24 | import gfootball_engine as libgame 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | def parse_player_definition(definition): 29 | """Parses player definition. 30 | 31 | An example of player definition is: "agent:players=4" or "replay:path=...". 32 | 33 | Args: 34 | definition: a string defining a player 35 | 36 | Returns: 37 | A tuple (name, dict). 
38 | """ 39 | name = definition 40 | d = {'left_players': 0, 41 | 'right_players': 0} 42 | if ':' in definition: 43 | (name, params) = definition.split(':') 44 | for param in params.split(','): 45 | (key, value) = param.split('=') 46 | d[key] = value 47 | if d['left_players'] == 0 and d['right_players'] == 0: 48 | d['left_players'] = 1 49 | return name, d 50 | 51 | 52 | def count_players(definition): 53 | """Returns a number of players given a definition.""" 54 | _, player_definition = parse_player_definition(definition) 55 | return (int(player_definition['left_players']) + 56 | int(player_definition['right_players'])) 57 | 58 | 59 | def count_left_players(definition): 60 | """Returns a number of left players given a definition.""" 61 | return int(parse_player_definition(definition)[1]['left_players']) 62 | 63 | 64 | def count_right_players(definition): 65 | """Returns a number of players given a definition.""" 66 | return int(parse_player_definition(definition)[1]['right_players']) 67 | 68 | 69 | def get_agent_number_of_players(players): 70 | """Returns a total number of players controlled by an agent.""" 71 | return sum([count_players(player) for player in players 72 | if player.startswith('agent')]) 73 | 74 | 75 | class Config(object): 76 | 77 | def __init__(self, values=None): 78 | self._values = { 79 | 'action_set': 'default', 80 | 'custom_display_stats': None, 81 | 'display_game_stats': True, 82 | 'dump_full_episodes': False, 83 | 'dump_scores': False, 84 | 'players': ['agent:left_players=1'], 85 | 'level': '11_vs_11_stochastic', 86 | 'physics_steps_per_frame': 10, 87 | 'render_resolution_x': 1280, 88 | 'real_time': False, 89 | 'tracesdir': '/tmp/dumps', 90 | 'video_format': 'avi', 91 | 'video_quality_level': 0, # 0 - low, 1 - medium, 2 - high 92 | 'write_video': False 93 | } 94 | self._values['render_resolution_y'] = int( 95 | 0.5625 * self._values['render_resolution_x']) 96 | if values: 97 | self._values.update(values) 98 | self.NewScenario() 99 | 100 | def number_of_left_players(self): 101 | return sum([count_left_players(player) 102 | for player in self._values['players']]) 103 | 104 | def number_of_right_players(self): 105 | return sum([count_right_players(player) 106 | for player in self._values['players']]) 107 | 108 | def number_of_players_agent_controls(self): 109 | return get_agent_number_of_players(self._values['players']) 110 | 111 | def __eq__(self, other): 112 | assert isinstance(other, self.__class__) 113 | return self._values == other._values and self._scenario_values == other._scenario_values 114 | 115 | def __ne__(self, other): 116 | return not self.__eq__(other) 117 | 118 | def __getitem__(self, key): 119 | if key in self._scenario_values: 120 | return self._scenario_values[key] 121 | return self._values[key] 122 | 123 | def __setitem__(self, key, value): 124 | self._values[key] = value 125 | 126 | def __contains__(self, key): 127 | return key in self._scenario_values or key in self._values 128 | 129 | def get_dictionary(self): 130 | cfg = copy.deepcopy(self._values) 131 | cfg.update(self._scenario_values) 132 | return cfg 133 | 134 | def set_scenario_value(self, key, value): 135 | """Override value of specific config key for a single episode.""" 136 | self._scenario_values[key] = value 137 | 138 | def serialize(self): 139 | return self._values 140 | 141 | def update(self, config): 142 | self._values.update(config) 143 | 144 | def ScenarioConfig(self): 145 | return self._scenario_cfg 146 | 147 | def NewScenario(self, inc = 1): 148 | if 'episode_number' not in 
self._values: 149 | self._values['episode_number'] = 0 150 | self._values['episode_number'] += inc 151 | self._scenario_values = {} 152 | from tmarl.envs.football.env import scenario_builder 153 | self._scenario_cfg = scenario_builder.Scenario(self).ScenarioConfig() 154 | -------------------------------------------------------------------------------- /tmarl/envs/football/env/football_env.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Allows different types of players to play against each other.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import copy 23 | import importlib 24 | from absl import logging 25 | 26 | from tmarl.envs.football.env import config as cfg 27 | from gfootball.env import constants 28 | from gfootball.env import football_action_set 29 | from tmarl.envs.football.env import football_env_core 30 | from gfootball.env import observation_rotation 31 | import gym 32 | import numpy as np 33 | 34 | 35 | class FootballEnv(gym.Env): 36 | """Allows multiple players to play in the same environment.""" 37 | 38 | def __init__(self, config): 39 | self._config = config 40 | player_config = {'index': 0} 41 | # There can be at most one agent at a time. We need to remember its 42 | # team and the index on the team to generate observations appropriately. 
43 | self._agent = None 44 | self._agent_index = -1 45 | self._agent_left_position = -1 46 | self._agent_right_position = -1 47 | self._players = self._construct_players(config['players'], player_config) 48 | self._env = football_env_core.FootballEnvCore(self._config) 49 | self._num_actions = len(football_action_set.get_action_set(self._config)) 50 | self._cached_observation = None 51 | 52 | @property 53 | def action_space(self): 54 | if self._config.number_of_players_agent_controls() > 1: 55 | return gym.spaces.MultiDiscrete( 56 | [self._num_actions] * self._config.number_of_players_agent_controls()) 57 | return gym.spaces.Discrete(self._num_actions) 58 | 59 | def _construct_players(self, definitions, config): 60 | result = [] 61 | left_position = 0 62 | right_position = 0 63 | for definition in definitions: 64 | (name, d) = cfg.parse_player_definition(definition) 65 | config_name = 'player_{}'.format(name) 66 | if config_name in config: 67 | config[config_name] += 1 68 | else: 69 | config[config_name] = 0 70 | try: 71 | player_factory = importlib.import_module( 72 | 'gfootball.env.players.{}'.format(name)) 73 | except ImportError as e: 74 | logging.error('Failed loading player "%s"', name) 75 | logging.error(e) 76 | exit(1) 77 | player_config = copy.deepcopy(config) 78 | player_config.update(d) 79 | player = player_factory.Player(player_config, self._config) 80 | if name == 'agent': 81 | assert not self._agent, 'Only one \'agent\' player allowed' 82 | self._agent = player 83 | self._agent_index = len(result) 84 | self._agent_left_position = left_position 85 | self._agent_right_position = right_position 86 | result.append(player) 87 | left_position += player.num_controlled_left_players() 88 | right_position += player.num_controlled_right_players() 89 | config['index'] += 1 90 | return result 91 | 92 | def _convert_observations(self, original, player, 93 | left_player_position, right_player_position): 94 | """Converts generic observations returned by the environment to 95 | the player specific observations. 96 | 97 | Args: 98 | original: original observations from the environment. 99 | player: player for which to generate observations. 100 | left_player_position: index into observation corresponding to the left 101 | player. 102 | right_player_position: index into observation corresponding to the right 103 | player. 104 | """ 105 | observations = [] 106 | for is_left in [True, False]: 107 | adopted = original if is_left or player.can_play_right( 108 | ) else observation_rotation.flip_observation(original, self._config) 109 | prefix = 'left' if is_left or not player.can_play_right() else 'right' 110 | position = left_player_position if is_left else right_player_position 111 | for x in range(player.num_controlled_left_players() if is_left 112 | else player.num_controlled_right_players()): 113 | o = {} 114 | for v in constants.EXPOSED_OBSERVATIONS: 115 | o[v] = copy.deepcopy(adopted[v]) 116 | assert (len(adopted[prefix + '_agent_controlled_player']) == len( 117 | adopted[prefix + '_agent_sticky_actions'])) 118 | o['designated'] = adopted[prefix + '_team_designated_player'] 119 | if position + x >= len(adopted[prefix + '_agent_controlled_player']): 120 | o['active'] = -1 121 | o['sticky_actions'] = [] 122 | else: 123 | o['active'] = ( 124 | adopted[prefix + '_agent_controlled_player'][position + x]) 125 | o['sticky_actions'] = np.array(copy.deepcopy( 126 | adopted[prefix + '_agent_sticky_actions'][position + x])) 127 | # There is no frame for players on the right ATM. 
128 | if is_left and 'frame' in original: 129 | o['frame'] = original['frame'] 130 | observations.append(o) 131 | return observations 132 | 133 | def _action_to_list(self, a): 134 | if isinstance(a, np.ndarray): 135 | return a.tolist() 136 | if not isinstance(a, list): 137 | return [a] 138 | return a 139 | 140 | def _get_actions(self): 141 | obs = self._env.observation() 142 | left_actions = [] 143 | right_actions = [] 144 | left_player_position = 0 145 | right_player_position = 0 146 | for player in self._players: 147 | adopted_obs = self._convert_observations(obs, player, 148 | left_player_position, 149 | right_player_position) 150 | left_player_position += player.num_controlled_left_players() 151 | right_player_position += player.num_controlled_right_players() 152 | a = self._action_to_list(player.take_action(adopted_obs)) 153 | assert len(adopted_obs) == len( 154 | a), 'Player provided {} actions instead of {}.'.format( 155 | len(a), len(adopted_obs)) 156 | if not player.can_play_right(): 157 | for x in range(player.num_controlled_right_players()): 158 | index = x + player.num_controlled_left_players() 159 | a[index] = observation_rotation.flip_single_action( 160 | a[index], self._config) 161 | left_actions.extend(a[:player.num_controlled_left_players()]) 162 | right_actions.extend(a[player.num_controlled_left_players():]) 163 | actions = left_actions + right_actions 164 | return actions 165 | 166 | def step(self, action): 167 | action = self._action_to_list(action) 168 | if self._agent: 169 | self._agent.set_action(action) 170 | else: 171 | assert len( 172 | action 173 | ) == 0, 'step() received {} actions, but no agent is playing.'.format( 174 | len(action)) 175 | 176 | _, reward, done, info = self._env.step(self._get_actions()) 177 | score_reward = reward 178 | if self._agent: 179 | reward = ([reward] * self._agent.num_controlled_left_players() + 180 | [-reward] * self._agent.num_controlled_right_players()) 181 | self._cached_observation = None 182 | info['score_reward'] = score_reward 183 | return (self.observation(), np.array(reward, dtype=np.float32), done, info) 184 | 185 | def reset(self): 186 | self._env.reset() 187 | for player in self._players: 188 | player.reset() 189 | self._cached_observation = None 190 | return self.observation() 191 | 192 | def observation(self): 193 | if not self._cached_observation: 194 | self._cached_observation = self._env.observation() 195 | if self._agent: 196 | self._cached_observation = self._convert_observations( 197 | self._cached_observation, self._agent, 198 | self._agent_left_position, self._agent_right_position) 199 | return self._cached_observation 200 | 201 | def write_dump(self, name): 202 | return self._env.write_dump(name) 203 | 204 | def close(self): 205 | self._env.close() 206 | 207 | def get_state(self, to_pickle={}): 208 | return self._env.get_state(to_pickle) 209 | 210 | def set_state(self, state): 211 | self._cached_observation = None 212 | return self._env.set_state(state) 213 | 214 | def tracker_setup(self, start, end): 215 | self._env.tracker_setup(start, end) 216 | 217 | def render(self, mode='human'): 218 | self._cached_observation = None 219 | return self._env.render(mode=mode) 220 | 221 | def disable_render(self): 222 | self._cached_observation = None 223 | return self._env.disable_render() 224 | -------------------------------------------------------------------------------- /tmarl/envs/football/env/scenario_builder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 
| # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Class responsible for generating scenarios.""" 17 | 18 | import importlib 19 | import os 20 | import pkgutil 21 | import random 22 | import sys 23 | 24 | from absl import flags 25 | from absl import logging 26 | 27 | import gfootball_engine as libgame 28 | 29 | Player = libgame.FormationEntry 30 | Role = libgame.e_PlayerRole 31 | Team = libgame.e_Team 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | def all_scenarios(): 37 | path = os.path.abspath(__file__) 38 | path = os.path.join(os.path.dirname(os.path.dirname(path)), 'scenarios') 39 | scenarios = [] 40 | for m in pkgutil.iter_modules([path]): 41 | # There was API change in pkgutil between Python 3.5 and 3.6... 42 | if m.__class__ == tuple: 43 | scenarios.append(m[1]) 44 | else: 45 | scenarios.append(m.name) 46 | return scenarios 47 | 48 | 49 | class Scenario(object): 50 | 51 | def __init__(self, config): 52 | # Game config controls C++ engine and is derived from the main config. 53 | self._scenario_cfg = libgame.ScenarioConfig.make() 54 | self._config = config 55 | self._active_team = Team.e_Left 56 | scenario = None 57 | try: 58 | scenario = importlib.import_module('tmarl.envs.football.scenarios.{}'.format(config['level'])) 59 | except ImportError as e: 60 | logging.error('Loading scenario "%s" failed' % config['level']) 61 | logging.error(e) 62 | sys.exit(1) 63 | scenario.build_scenario(self) 64 | self.SetTeam(libgame.e_Team.e_Left) 65 | self._FakePlayersForEmptyTeam(self._scenario_cfg.left_team) 66 | self.SetTeam(libgame.e_Team.e_Right) 67 | self._FakePlayersForEmptyTeam(self._scenario_cfg.right_team) 68 | self._BuildScenarioConfig() 69 | 70 | def _FakePlayersForEmptyTeam(self, team): 71 | if len(team) == 0: 72 | self.AddPlayer(-1.000000, 0.420000, libgame.e_PlayerRole.e_PlayerRole_GK, True) 73 | 74 | def _BuildScenarioConfig(self): 75 | """Builds scenario config from gfootball.environment config.""" 76 | self._scenario_cfg.real_time = self._config['real_time'] 77 | self._scenario_cfg.left_agents = self._config.number_of_left_players() 78 | self._scenario_cfg.right_agents = self._config.number_of_right_players() 79 | # This is needed to record 'game_engine_random_seed' in the dump. 
80 | if 'game_engine_random_seed' not in self._config._values: 81 | self._config.set_scenario_value('game_engine_random_seed', 82 | random.randint(0, 2000000000)) 83 | if not self._scenario_cfg.deterministic: 84 | self._scenario_cfg.game_engine_random_seed = ( 85 | self._config['game_engine_random_seed']) 86 | if 'reverse_team_processing' not in self._config: 87 | self._config['reverse_team_processing'] = ( 88 | bool(self._config['game_engine_random_seed'] % 2)) 89 | if 'reverse_team_processing' in self._config: 90 | self._scenario_cfg.reverse_team_processing = ( 91 | self._config['reverse_team_processing']) 92 | 93 | def config(self): 94 | return self._scenario_cfg 95 | 96 | def SetTeam(self, team): 97 | self._active_team = team 98 | 99 | def AddPlayer(self, x, y, role, lazy=False, controllable=True): 100 | """Build player for the current scenario. 101 | 102 | Args: 103 | x: x coordinate of the player in the range [-1, 1]. 104 | y: y coordinate of the player in the range [-0.42, 0.42]. 105 | role: Player's role in the game (goal keeper etc.). 106 | lazy: Computer doesn't perform any automatic actions for lazy player. 107 | controllable: Whether player can be controlled. 108 | """ 109 | player = Player(x, y, role, lazy, controllable) 110 | if self._active_team == Team.e_Left: 111 | self._scenario_cfg.left_team.append(player) 112 | else: 113 | self._scenario_cfg.right_team.append(player) 114 | 115 | def SetBallPosition(self, ball_x, ball_y): 116 | self._scenario_cfg.ball_position[0] = ball_x 117 | self._scenario_cfg.ball_position[1] = ball_y 118 | 119 | def EpisodeNumber(self): 120 | return self._config['episode_number'] 121 | 122 | def ScenarioConfig(self): 123 | return self._scenario_cfg 124 | -------------------------------------------------------------------------------- /tmarl/envs/football/env/script_helpers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
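# ==============================================================================
# [Editor's note] A minimal sketch, not part of the original file, of how the
# Config / Scenario pair defined above is typically exercised; it assumes the
# gfootball engine is installed. Constructing a Config resolves 'level' to a
# module under tmarl/envs/football/scenarios, runs its build_scenario(), pads an
# empty team with a fake goalkeeper, and exposes the resulting engine-side
# settings through ScenarioConfig().
#
# from tmarl.envs.football.env import config
#
# cfg = config.Config({'level': 'academy_3_vs_1_with_keeper'})
# sc = cfg.ScenarioConfig()
# print(sc.game_duration)                 # 400, as set by this level's build_scenario()
# print(sc.left_agents, sc.right_agents)  # derived from cfg's 'players' entry
# cfg.NewScenario()                       # bumps 'episode_number' and rebuilds the scenario
# print(cfg['episode_number'])            # 2: once from __init__, once from the call above
# ==============================================================================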
14 | 15 | 16 | """Set of functions used by command line scripts.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tmarl.envs.football.env import config 23 | from gfootball.env import football_action_set 24 | from tmarl.envs.football.env import football_env 25 | from gfootball.env import observation_processor 26 | 27 | import copy 28 | import six.moves.cPickle 29 | import os 30 | import tempfile 31 | 32 | 33 | class ScriptHelpers(object): 34 | """Set of methods used by command line scripts.""" 35 | 36 | def __init__(self): 37 | pass 38 | 39 | def __modify_trace(self, replay, fps): 40 | """Adopt replay to the new framerate and add additional steps at the end.""" 41 | trace = [] 42 | min_fps = replay[0]['debug']['config']['physics_steps_per_frame'] 43 | assert fps % min_fps == 0, ( 44 | 'Trace has to be rendered in framerate being multiple of {}'.format( 45 | min_fps)) 46 | assert fps <= 100, ('Framerate of up to 100 is supported') 47 | empty_steps = int(fps / min_fps) - 1 48 | for f in replay: 49 | trace.append(f) 50 | idle_step = copy.deepcopy(f) 51 | idle_step['debug']['action'] = [football_action_set.action_idle 52 | ] * len(f['debug']['action']) 53 | for _ in range(empty_steps): 54 | trace.append(idle_step) 55 | # Add some empty steps at the end, so that we can record videos. 56 | for _ in range(10): 57 | trace.append(idle_step) 58 | return trace 59 | 60 | def __build_players(self, dump_file, spec): 61 | players = [] 62 | for player in spec: 63 | players.extend(['replay:path={},left_players=1'.format( 64 | dump_file)] * config.count_left_players(player)) 65 | players.extend(['replay:path={},right_players=1'.format( 66 | dump_file)] * config.count_right_players(player)) 67 | return players 68 | 69 | def load_dump(self, dump_file): 70 | dump = [] 71 | with open(dump_file, 'rb') as in_fd: 72 | while True: 73 | try: 74 | step = six.moves.cPickle.load(in_fd) 75 | except EOFError: 76 | return dump 77 | dump.append(step) 78 | 79 | def dump_to_txt(self, dump_file, output, include_debug): 80 | with open(output, 'w') as out_fd: 81 | dump = self.load_dump(dump_file) 82 | if not include_debug: 83 | for s in dump: 84 | if 'debug' in s: 85 | del s['debug'] 86 | with open(output, 'w') as f: 87 | f.write(str(dump)) 88 | 89 | def dump_to_video(self, dump_file): 90 | dump = self.load_dump(dump_file) 91 | cfg = config.Config(dump[0]['debug']['config']) 92 | cfg['dump_full_episodes'] = True 93 | cfg['write_video'] = True 94 | cfg['display_game_stats'] = True 95 | processor = observation_processor.ObservationProcessor(cfg) 96 | processor.write_dump('episode_done') 97 | for frame in dump: 98 | processor.update(frame) 99 | 100 | def replay(self, dump, fps=10, config_update={}, directory=None, render=True): 101 | replay = self.load_dump(dump) 102 | trace = self.__modify_trace(replay, fps) 103 | fd, temp_path = tempfile.mkstemp(suffix='.dump') 104 | with open(temp_path, 'wb') as f: 105 | for step in trace: 106 | six.moves.cPickle.dump(step, f) 107 | assert replay[0]['debug']['frame_cnt'] == 0, ( 108 | 'Trace does not start from the beginning of the episode, can not replay') 109 | cfg = config.Config(replay[0]['debug']['config']) 110 | cfg['players'] = self.__build_players(temp_path, cfg['players']) 111 | config_update['physics_steps_per_frame'] = int(100 / fps) 112 | config_update['real_time'] = False 113 | if directory: 114 | config_update['tracesdir'] = directory 115 | config_update['write_video'] = True 116 | # 
my edition 117 | # config_update['display_game_stats'] = False 118 | # config_update['video_quality_level'] = 2 119 | cfg.update(config_update) 120 | env = football_env.FootballEnv(cfg) 121 | if render: 122 | env.render() 123 | env.reset() 124 | done = False 125 | try: 126 | while not done: 127 | _, _, done, _ = env.step([]) 128 | except KeyboardInterrupt: 129 | env.write_dump('shutdown') 130 | exit(1) 131 | os.close(fd) 132 | -------------------------------------------------------------------------------- /tmarl/envs/football/football.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import gym 4 | from ray.rllib.env.multi_agent_env import MultiAgentEnv 5 | 6 | import tmarl.envs.football.env as football_env 7 | 8 | class RllibGFootball(MultiAgentEnv): 9 | """An example of a wrapper for GFootball to make it compatible with rllib.""" 10 | 11 | def __init__(self, all_args, rank, log_dir=None, isEval=False): 12 | 13 | self.num_agents = all_args.num_agents 14 | self.num_rollout = all_args.n_rollout_threads 15 | self.isEval = isEval 16 | self.rank = rank 17 | # create env 18 | # need_render = (rank == 0) and isEval 19 | need_render = (rank == 0) 20 | # and (not isEval or self.use_behavior_cloning) 21 | self.env = football_env.create_environment( 22 | env_name=all_args.scenario_name, stacked=False, 23 | logdir=log_dir, 24 | representation=all_args.representation, 25 | rewards='scoring' if isEval else all_args.rewards, 26 | write_goal_dumps=False, 27 | write_full_episode_dumps=need_render, 28 | render=need_render, 29 | dump_frequency=1 if need_render else 0, 30 | number_of_left_players_agent_controls=self.num_agents, 31 | number_of_right_players_agent_controls=0, 32 | other_config_options={'action_set':'full'}) 33 | # state 34 | self.last_loffside = np.zeros(11) 35 | self.last_roffside = np.zeros(11) 36 | # dimension 37 | self.action_size = 33 38 | 39 | if all_args.scenario_name == "11_vs_11_kaggle": 40 | self.avail_size = 20 41 | else: 42 | self.avail_size = 19 43 | 44 | if all_args.representation == 'raw': 45 | obs_space_dim = 268 46 | obs_space_low = np.zeros(obs_space_dim) - 1e6 47 | obs_space_high = np.zeros(obs_space_dim) + 1e6 48 | obs_space_type = 'float64' 49 | else: 50 | raise NotImplementedError 51 | 52 | self.action_space = [gym.spaces.Discrete( 53 | self.action_size) for _ in range(self.num_agents)] 54 | self.observation_space = [gym.spaces.Box( 55 | low=obs_space_low, 56 | high=obs_space_high, 57 | dtype=obs_space_type) for _ in range(self.num_agents)] 58 | self.share_observation_space = [gym.spaces.Box( 59 | low=obs_space_low, 60 | high=obs_space_high, 61 | dtype=obs_space_type) for _ in range(self.num_agents)] 62 | 63 | def reset(self): 64 | 65 | # available actions 66 | avail_actions = np.ones([self.num_agents, self.action_size]) 67 | avail_actions[:, self.avail_size:] = 0 68 | # state 69 | self.last_loffside = np.zeros(11) 70 | self.last_roffside = np.zeros(11) 71 | # obs 72 | raw_obs = self.env.reset() 73 | raw_obs = self._notFullGame(raw_obs) 74 | obs = self.raw2vec(raw_obs) 75 | share_obs = obs.copy() 76 | 77 | return obs, share_obs, avail_actions 78 | 79 | def step(self, actions): 80 | # step 81 | actions = np.argmax(actions, axis=-1) 82 | raw_o, r, d, info = self.env.step(actions.astype('int32')) 83 | raw_o = self._notFullGame(raw_o) 84 | obs = self.raw2vec(raw_o) 85 | share_obs = obs.copy() 86 | # available actions 87 | avail_actions = np.ones([self.num_agents, self.action_size]) 88 | avail_actions[:, 
self.avail_size:] = 0 89 | # translate to specific form 90 | rewards = [] 91 | infos, dones = [], [] 92 | for i in range(self.num_agents): 93 | infos.append(info) 94 | dones.append(d) 95 | reward = r[i] if self.num_agents > 1 else r 96 | reward = -0.01 if d and reward < 1 and not self.isEval else reward 97 | rewards.append(reward) 98 | rewards = np.expand_dims(np.array(rewards), axis=1) 99 | 100 | return obs, share_obs, rewards, dones, infos, avail_actions 101 | 102 | def seed(self, seed=None): 103 | if seed is None: 104 | np.random.seed(1) 105 | else: 106 | np.random.seed(seed) 107 | 108 | def close(self): 109 | self.env.close() 110 | 111 | def raw2vec(self, raw_obs): 112 | obs = [] 113 | ally = np.array(raw_obs[0]['left_team']) 114 | ally_d = np.array(raw_obs[0]['left_team_direction']) 115 | enemy = np.array(raw_obs[0]['right_team']) 116 | enemy_d = np.array(raw_obs[0]['right_team_direction']) 117 | lo, ro = self.get_offside(raw_obs[0]) 118 | for a in range(self.num_agents): 119 | # prepocess 120 | me = ally[int(raw_obs[a]['active'])] 121 | ball = raw_obs[a]['ball'][:2] 122 | ball_dist = np.linalg.norm(me - ball) 123 | enemy_dist = np.linalg.norm(me - enemy, axis=-1) 124 | to_enemy = enemy - me 125 | to_ally = ally - me 126 | to_ball = ball - me 127 | 128 | o = [] 129 | # shape = 0 130 | o.extend(ally.flatten()) 131 | o.extend(ally_d.flatten()) 132 | o.extend(enemy.flatten()) 133 | o.extend(enemy_d.flatten()) 134 | # shape = 88 135 | o.extend(raw_obs[a]['ball']) 136 | o.extend(raw_obs[a]['ball_direction']) 137 | # shape = 94 138 | if raw_obs[a]['ball_owned_team'] == -1: 139 | o.extend([1, 0, 0]) 140 | if raw_obs[a]['ball_owned_team'] == 0: 141 | o.extend([0, 1, 0]) 142 | if raw_obs[a]['ball_owned_team'] == 1: 143 | o.extend([0, 0, 1]) 144 | # shape = 97 145 | active = [0] * 11 146 | active[raw_obs[a]['active']] = 1 147 | o.extend(active) 148 | # shape = 108 149 | game_mode = [0] * 7 150 | game_mode[raw_obs[a]['game_mode']] = 1 151 | o.extend(game_mode) 152 | # shape = 115 153 | o.extend(raw_obs[a]['sticky_actions'][:10]) 154 | # shape = 125) 155 | ball_dist = 1 if ball_dist > 1 else ball_dist 156 | o.extend([ball_dist]) 157 | # shape = 126) 158 | o.extend(raw_obs[a]['left_team_tired_factor']) 159 | # shape = 137) 160 | o.extend(raw_obs[a]['left_team_yellow_card']) 161 | # shape = 148) 162 | o.extend(raw_obs[a]['left_team_active']) # red cards 163 | # shape = 159) 164 | o.extend(lo) # ! 165 | # shape = 170) 166 | o.extend(ro) # ! 
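# [Editor's note] lo / ro extended just above are the 11-dim per-player offside
# flags for the left and right teams, computed by self.get_offside() further down.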
167 | # shape = 181) 168 | o.extend(enemy_dist) 169 | # shape = 192) 170 | to_ally[:, 0] /= 2 171 | o.extend(to_ally.flatten()) 172 | # shape = 214) 173 | to_enemy[:, 0] /= 2 174 | o.extend(to_enemy.flatten()) 175 | # shape = 236) 176 | to_ball[0] /= 2 177 | o.extend(to_ball.flatten()) 178 | # shape = 238) 179 | 180 | steps_left = raw_obs[a]['steps_left'] 181 | o.extend([1.0 * steps_left / 3001]) # steps left till end 182 | if steps_left > 1500: 183 | steps_left -= 1501 # steps left till halfend 184 | steps_left = 1.0 * min(steps_left, 300.0) # clip 185 | steps_left /= 300.0 186 | o.extend([steps_left]) 187 | 188 | score_ratio = 1.0 * \ 189 | (raw_obs[a]['score'][0] - raw_obs[a]['score'][1]) 190 | score_ratio /= 5.0 191 | score_ratio = min(score_ratio, 1.0) 192 | score_ratio = max(-1.0, score_ratio) 193 | o.extend([score_ratio]) 194 | # shape = 241 195 | o.extend([0.0] * 27) 196 | # shape = 268 197 | 198 | obs.append(o) 199 | 200 | return np.array(obs) 201 | 202 | def get_offside(self, obs): 203 | ball = np.array(obs['ball'][:2]) 204 | ally = np.array(obs['left_team']) 205 | enemy = np.array(obs['right_team']) 206 | 207 | if obs['game_mode'] != 0: 208 | self.last_loffside = np.zeros(11, np.float32) 209 | self.last_roffside = np.zeros(11, np.float32) 210 | return np.zeros(11, np.float32), np.zeros(11, np.float32) 211 | 212 | need_recalc = False 213 | effective_ownball_team = -1 214 | effective_ownball_player = -1 215 | 216 | if obs['ball_owned_team'] > -1: 217 | effective_ownball_team = obs['ball_owned_team'] 218 | effective_ownball_player = obs['ball_owned_player'] 219 | need_recalc = True 220 | else: 221 | ally_dist = np.linalg.norm(ball - ally, axis=-1) 222 | enemy_dist = np.linalg.norm(ball - enemy, axis=-1) 223 | if np.min(ally_dist) < np.min(enemy_dist): 224 | if np.min(ally_dist) < 0.017: 225 | need_recalc = True 226 | effective_ownball_team = 0 227 | effective_ownball_player = np.argmin(ally_dist) 228 | elif np.min(enemy_dist) < np.min(ally_dist): 229 | if np.min(enemy_dist) < 0.017: 230 | need_recalc = True 231 | effective_ownball_team = 1 232 | effective_ownball_player = np.argmin(enemy_dist) 233 | 234 | if not need_recalc: 235 | return self.last_loffside, self.last_roffside 236 | 237 | left_offside = np.zeros(11, np.float32) 238 | right_offside = np.zeros(11, np.float32) 239 | 240 | if effective_ownball_team == 0: 241 | right_xs = [obs['right_team'][k][0] for k in range(1, 11)] 242 | right_xs = np.array(right_xs) 243 | right_xs.sort() 244 | 245 | for k in range(1, 11): 246 | if obs['left_team'][k][0] > right_xs[-1] and k != effective_ownball_player \ 247 | and obs['left_team'][k][0] > 0.0: 248 | left_offside[k] = 1.0 249 | else: 250 | left_xs = [obs['left_team'][k][0] for k in range(1, 11)] 251 | left_xs = np.array(left_xs) 252 | left_xs.sort() 253 | 254 | for k in range(1, 11): 255 | if obs['right_team'][k][0] < left_xs[0] and k != effective_ownball_player \ 256 | and obs['right_team'][k][0] < 0.0: 257 | right_offside[k] = 1.0 258 | 259 | self.last_loffside = left_offside 260 | self.last_roffside = right_offside 261 | 262 | return left_offside, right_offside 263 | 264 | 265 | def _notFullGame(self, raw_obs): 266 | # use this function when there are less than 11 players in the scenario 267 | left_ok = len(raw_obs[0]['left_team']) == 11 268 | right_ok = len(raw_obs[0]['right_team']) == 11 269 | if left_ok and right_ok: 270 | return raw_obs 271 | # set player's coordinate at (-1,0), set player's velocity as (0,0) 272 | for obs in raw_obs: 273 | obs['left_team'] = 
np.array(obs['left_team']) 274 | obs['right_team'] = np.array(obs['right_team']) 275 | obs['left_team_direction'] = np.array(obs['left_team_direction']) 276 | obs['right_team_direction'] = np.array(obs['right_team_direction']) 277 | while len(obs['left_team']) < 11: 278 | obs['left_team'] = np.concatenate([obs['left_team'], np.array([[-1,0]])], axis=0) 279 | obs['left_team_direction'] = np.concatenate([obs['left_team_direction'], np.zeros([1,2])], axis=0) 280 | obs['left_team_tired_factor'] = np.concatenate([obs['left_team_tired_factor'], np.zeros(1)], axis=0) 281 | obs['left_team_yellow_card'] = np.concatenate([obs['left_team_yellow_card'], np.zeros(1)], axis=0) 282 | obs['left_team_active'] = np.concatenate([obs['left_team_active'], np.ones(1)], axis=0) 283 | while len(obs['right_team']) < 11: 284 | obs['right_team'] = np.concatenate([obs['right_team'], np.array([[-1,0]])], axis=0) 285 | obs['right_team_direction'] = np.concatenate([obs['right_team_direction'], np.zeros([1,2])], axis=0) 286 | return raw_obs -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/11_vs_11_kaggle.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 3000 25 | builder.config().second_half = 1500 26 | builder.config().right_team_difficulty = 1.0 27 | builder.config().left_team_difficulty = 1.0 28 | builder.config().deterministic = False 29 | if builder.EpisodeNumber() % 2 == 0: 30 | first_team = Team.e_Left 31 | second_team = Team.e_Right 32 | else: 33 | first_team = Team.e_Right 34 | second_team = Team.e_Left 35 | builder.SetTeam(first_team) 36 | builder.AddPlayer(-1.000000, 0.000000, e_PlayerRole_GK, controllable=False) 37 | builder.AddPlayer(0.000000, 0.020000, e_PlayerRole_RM) 38 | builder.AddPlayer(0.000000, -0.020000, e_PlayerRole_CF) 39 | builder.AddPlayer(-0.422000, -0.19576, e_PlayerRole_LB) 40 | builder.AddPlayer(-0.500000, -0.06356, e_PlayerRole_CB) 41 | builder.AddPlayer(-0.500000, 0.063559, e_PlayerRole_CB) 42 | builder.AddPlayer(-0.422000, 0.195760, e_PlayerRole_RB) 43 | builder.AddPlayer(-0.184212, -0.10568, e_PlayerRole_CM) 44 | builder.AddPlayer(-0.267574, 0.000000, e_PlayerRole_CM) 45 | builder.AddPlayer(-0.184212, 0.105680, e_PlayerRole_CM) 46 | builder.AddPlayer(-0.010000, -0.21610, e_PlayerRole_LM) 47 | builder.SetTeam(second_team) 48 | builder.AddPlayer(-1.000000, 0.000000, e_PlayerRole_GK, controllable=False) 49 | builder.AddPlayer(-0.050000, 0.000000, e_PlayerRole_RM) 50 | builder.AddPlayer(-0.010000, 0.216102, e_PlayerRole_CF) 51 | builder.AddPlayer(-0.422000, -0.19576, e_PlayerRole_LB) 52 | builder.AddPlayer(-0.500000, -0.06356, e_PlayerRole_CB) 53 | builder.AddPlayer(-0.500000, 0.063559, e_PlayerRole_CB) 54 | builder.AddPlayer(-0.422000, 0.195760, e_PlayerRole_RB) 55 | builder.AddPlayer(-0.184212, -0.10568, e_PlayerRole_CM) 56 | builder.AddPlayer(-0.267574, 0.000000, e_PlayerRole_CM) 57 | builder.AddPlayer(-0.184212, 0.105680, e_PlayerRole_CM) 58 | builder.AddPlayer(-0.010000, -0.21610, e_PlayerRole_LM) 59 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/11_vs_11_lazy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 3000 25 | builder.config().second_half = 1500 26 | builder.config().right_team_difficulty = 1.0 27 | builder.config().left_team_difficulty = 1.0 28 | builder.config().deterministic = False 29 | if builder.EpisodeNumber() % 2 == 0: 30 | first_team = Team.e_Left 31 | second_team = Team.e_Right 32 | else: 33 | first_team = Team.e_Right 34 | second_team = Team.e_Left 35 | builder.SetTeam(first_team) 36 | builder.AddPlayer(-1.000000, 0.000000, e_PlayerRole_GK, controllable=False) 37 | builder.AddPlayer(0.000000, 0.020000, e_PlayerRole_RM) 38 | builder.AddPlayer(0.000000, -0.020000, e_PlayerRole_CF) 39 | builder.AddPlayer(-0.422000, -0.19576, e_PlayerRole_LB) 40 | builder.AddPlayer(-0.500000, -0.06356, e_PlayerRole_CB) 41 | builder.AddPlayer(-0.500000, 0.063559, e_PlayerRole_CB) 42 | builder.AddPlayer(-0.422000, 0.195760, e_PlayerRole_RB) 43 | builder.AddPlayer(-0.184212, -0.10568, e_PlayerRole_CM) 44 | builder.AddPlayer(-0.267574, 0.000000, e_PlayerRole_CM) 45 | builder.AddPlayer(-0.184212, 0.105680, e_PlayerRole_CM) 46 | builder.AddPlayer(-0.010000, -0.21610, e_PlayerRole_LM) 47 | builder.SetTeam(second_team) 48 | builder.AddPlayer(-1.000000, 0.000000, e_PlayerRole_GK, controllable=False) 49 | builder.AddPlayer(-0.050000, 0.000000, e_PlayerRole_RM, lazy=True) 50 | builder.AddPlayer(-0.010000, 0.216102, e_PlayerRole_CF, lazy=True) 51 | builder.AddPlayer(-0.422000, -0.19576, e_PlayerRole_LB, lazy=True) 52 | builder.AddPlayer(-0.500000, -0.06356, e_PlayerRole_CB, lazy=True) 53 | builder.AddPlayer(-0.500000, 0.063559, e_PlayerRole_CB, lazy=True) 54 | builder.AddPlayer(-0.422000, 0.195760, e_PlayerRole_RB, lazy=True) 55 | builder.AddPlayer(-0.184212, -0.10568, e_PlayerRole_CM, lazy=True) 56 | builder.AddPlayer(-0.267574, 0.000000, e_PlayerRole_CM, lazy=True) 57 | builder.AddPlayer(-0.184212, 0.105680, e_PlayerRole_CM, lazy=True) 58 | builder.AddPlayer(-0.010000, -0.21610, e_PlayerRole_LM, lazy=True) 59 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
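# ==============================================================================
# [Editor's note] Adding a level means dropping another module into this
# scenarios package: scenario_builder.Scenario imports
# 'tmarl.envs.football.scenarios.<level>' and calls its build_scenario(builder).
# A minimal hypothetical module (say, academy_1_vs_0_demo.py, not shipped with
# this repo) would follow the same pattern as the files around it:
#
# from . import *
#
# def build_scenario(builder):
#     builder.config().game_duration = 400
#     builder.config().deterministic = False
#     builder.config().end_episode_on_score = True
#     builder.config().end_episode_on_out_of_play = True
#     builder.SetBallPosition(0.5, 0.0)
#     builder.SetTeam(Team.e_Left)
#     builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False)
#     builder.AddPlayer(0.4, 0.0, e_PlayerRole_CF)
#     builder.SetTeam(Team.e_Right)
#     builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK)
# ==============================================================================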
14 | 15 | 16 | 17 | 18 | import gfootball_engine as libgame 19 | e_PlayerRole_GK = libgame.e_PlayerRole.e_PlayerRole_GK 20 | e_PlayerRole_CB = libgame.e_PlayerRole.e_PlayerRole_CB 21 | e_PlayerRole_LB = libgame.e_PlayerRole.e_PlayerRole_LB 22 | e_PlayerRole_RB = libgame.e_PlayerRole.e_PlayerRole_RB 23 | e_PlayerRole_DM = libgame.e_PlayerRole.e_PlayerRole_DM 24 | e_PlayerRole_CM = libgame.e_PlayerRole.e_PlayerRole_CM 25 | e_PlayerRole_LM = libgame.e_PlayerRole.e_PlayerRole_LM 26 | e_PlayerRole_RM = libgame.e_PlayerRole.e_PlayerRole_RM 27 | e_PlayerRole_AM = libgame.e_PlayerRole.e_PlayerRole_AM 28 | e_PlayerRole_CF = libgame.e_PlayerRole.e_PlayerRole_CF 29 | Team = libgame.e_Team 30 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_3_vs_1_with_keeper.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | builder.SetBallPosition(0.62, 0.0) 31 | 32 | builder.SetTeam(Team.e_Left) 33 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 34 | builder.AddPlayer(0.6, 0.0, e_PlayerRole_CM) 35 | builder.AddPlayer(0.7, 0.2, e_PlayerRole_CM) 36 | builder.AddPlayer(0.7, -0.2, e_PlayerRole_CM) 37 | 38 | builder.SetTeam(Team.e_Right) 39 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK) 40 | builder.AddPlayer(-0.75, 0.0, e_PlayerRole_CB) 41 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_corner.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = False 30 | 31 | builder.SetBallPosition(0.99, 0.41) 32 | 33 | builder.SetTeam(Team.e_Left) 34 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 35 | builder.AddPlayer(1.0, 0.42, e_PlayerRole_LB) 36 | builder.AddPlayer(0.7, 0.15, e_PlayerRole_CB) 37 | builder.AddPlayer(0.7, 0.05, e_PlayerRole_CB) 38 | builder.AddPlayer(0.7, -0.05, e_PlayerRole_RB) 39 | builder.AddPlayer(0.0, 0.0, e_PlayerRole_CM) 40 | builder.AddPlayer(0.6, 0.35, e_PlayerRole_CM) 41 | builder.AddPlayer(0.8, 0.07, e_PlayerRole_CM) 42 | builder.AddPlayer(0.8, -0.03, e_PlayerRole_LM) 43 | builder.AddPlayer(0.8, -0.13, e_PlayerRole_RM) 44 | builder.AddPlayer(0.7, -0.3, e_PlayerRole_CF) 45 | 46 | builder.SetTeam(Team.e_Right) 47 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK) 48 | builder.AddPlayer(-0.75, -0.18, e_PlayerRole_LB) 49 | builder.AddPlayer(-0.75, -0.08, e_PlayerRole_CB) 50 | builder.AddPlayer(-0.75, 0.02, e_PlayerRole_CB) 51 | builder.AddPlayer(-1.0, -0.1, e_PlayerRole_RB) 52 | builder.AddPlayer(-0.8, -0.25, e_PlayerRole_CM) 53 | builder.AddPlayer(-0.88, -0.07, e_PlayerRole_CM) 54 | builder.AddPlayer(-0.88, 0.03, e_PlayerRole_CM) 55 | builder.AddPlayer(-0.88, 0.13, e_PlayerRole_LM) 56 | builder.AddPlayer(-0.75, 0.25, e_PlayerRole_RM) 57 | builder.AddPlayer(-0.2, 0.0, e_PlayerRole_CF) 58 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_counterattack_easy.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | 31 | builder.SetBallPosition(0.26, -0.11) 32 | 33 | builder.SetTeam(Team.e_Left) 34 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 35 | builder.AddPlayer(-0.672, -0.19576, e_PlayerRole_LB) 36 | builder.AddPlayer(-0.75, -0.06356, e_PlayerRole_CB) 37 | builder.AddPlayer(-0.75, 0.063559, e_PlayerRole_CB) 38 | builder.AddPlayer(-0.672, 0.19576, e_PlayerRole_RB) 39 | builder.AddPlayer(-0.434, -0.10568, e_PlayerRole_CM) 40 | builder.AddPlayer(-0.434, 0.10568, e_PlayerRole_CM) 41 | builder.AddPlayer(0.5, -0.3161, e_PlayerRole_CM) 42 | builder.AddPlayer(0.25, -0.1, e_PlayerRole_LM) 43 | builder.AddPlayer(0.25, 0.1, e_PlayerRole_RM) 44 | builder.AddPlayer(0.35, 0.316102, e_PlayerRole_CF) 45 | 46 | builder.SetTeam(Team.e_Right) 47 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK) 48 | builder.AddPlayer(0.128, -0.19576, e_PlayerRole_LB) 49 | builder.AddPlayer(0.4, -0.06356, e_PlayerRole_CB) 50 | builder.AddPlayer(-0.4, 0.063559, e_PlayerRole_CB) 51 | builder.AddPlayer(0.128, -0.19576, e_PlayerRole_RB) 52 | builder.AddPlayer(0.365, -0.10568, e_PlayerRole_CM) 53 | builder.AddPlayer(0.282, 0.0, e_PlayerRole_CM) 54 | builder.AddPlayer(0.365, 0.10568, e_PlayerRole_CM) 55 | builder.AddPlayer(0.54, -0.3161, e_PlayerRole_LM) 56 | builder.AddPlayer(0.51, 0.0, e_PlayerRole_RM) 57 | builder.AddPlayer(0.54, 0.316102, e_PlayerRole_CF) 58 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_counterattack_hard.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | 31 | builder.SetBallPosition(0.26, -0.11) 32 | 33 | builder.SetTeam(Team.e_Left) 34 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 35 | builder.AddPlayer(-0.672, -0.19576, e_PlayerRole_LB) 36 | builder.AddPlayer(-0.75, -0.06356, e_PlayerRole_CB) 37 | builder.AddPlayer(-0.75, 0.063559, e_PlayerRole_CB) 38 | builder.AddPlayer(-0.672, 0.19576, e_PlayerRole_RB) 39 | builder.AddPlayer(-0.434, -0.10568, e_PlayerRole_CM) 40 | builder.AddPlayer(-0.434, 0.10568, e_PlayerRole_CM) 41 | builder.AddPlayer(0.5, -0.3161, e_PlayerRole_CM) 42 | builder.AddPlayer(0.25, -0.1, e_PlayerRole_LM) 43 | builder.AddPlayer(0.25, 0.1, e_PlayerRole_RM) 44 | builder.AddPlayer(0.35, 0.316102, e_PlayerRole_CF) 45 | 46 | builder.SetTeam(Team.e_Right) 47 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK) 48 | builder.AddPlayer(0.128, -0.19576, e_PlayerRole_LB) 49 | builder.AddPlayer(-0.4, -0.06356, e_PlayerRole_CB) 50 | builder.AddPlayer(-0.4, 0.063559, e_PlayerRole_CB) 51 | builder.AddPlayer(0.128, -0.19576, e_PlayerRole_RB) 52 | builder.AddPlayer(0.365, -0.10568, e_PlayerRole_CM) 53 | builder.AddPlayer(0.282, 0.0, e_PlayerRole_CM) 54 | builder.AddPlayer(0.365, 0.10568, e_PlayerRole_CM) 55 | builder.AddPlayer(0.54, -0.3161, e_PlayerRole_LM) 56 | builder.AddPlayer(0.51, 0.0, e_PlayerRole_RM) 57 | builder.AddPlayer(0.54, 0.316102, e_PlayerRole_CF) 58 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_empty_goal.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | builder.SetBallPosition(0.02, 0.0) 31 | 32 | builder.SetTeam(Team.e_Left) 33 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 34 | builder.AddPlayer(0.0, 0.0, e_PlayerRole_CB) 35 | 36 | builder.SetTeam(Team.e_Right) 37 | builder.AddPlayer(1.0, 0.0, e_PlayerRole_GK) 38 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_empty_goal_close.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | builder.SetBallPosition(0.77, 0.0) 31 | 32 | builder.SetTeam(Team.e_Left) 33 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 34 | builder.AddPlayer(0.75, 0.0, e_PlayerRole_CB) 35 | 36 | builder.SetTeam(Team.e_Right) 37 | builder.AddPlayer(1.0, 0.0, e_PlayerRole_GK) 38 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_pass_and_shoot_with_keeper.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | builder.SetBallPosition(0.7, -0.28) 31 | 32 | builder.SetTeam(Team.e_Left) 33 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 34 | builder.AddPlayer(0.7, 0.0, e_PlayerRole_CB) 35 | builder.AddPlayer(0.7, -0.3, e_PlayerRole_CB) 36 | 37 | builder.SetTeam(Team.e_Right) 38 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK) 39 | builder.AddPlayer(-0.75, 0.3, e_PlayerRole_CB) 40 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_run_pass_and_shoot_with_keeper.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | builder.SetBallPosition(0.7, -0.28) 31 | 32 | builder.SetTeam(Team.e_Left) 33 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 34 | builder.AddPlayer(0.7, 0.0, e_PlayerRole_CB) 35 | builder.AddPlayer(0.7, -0.3, e_PlayerRole_CB) 36 | 37 | builder.SetTeam(Team.e_Right) 38 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK) 39 | builder.AddPlayer(-0.75, 0.1, e_PlayerRole_CB) 40 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_run_to_score.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | builder.SetBallPosition(0.02, 0.0) 31 | 32 | builder.SetTeam(Team.e_Left) 33 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 34 | builder.AddPlayer(0.0, 0.0, e_PlayerRole_CB) 35 | 36 | builder.SetTeam(Team.e_Right) 37 | builder.AddPlayer(1.0, 0.0, e_PlayerRole_GK) 38 | builder.AddPlayer(0.12, 0.2, e_PlayerRole_LB) 39 | builder.AddPlayer(0.12, 0.1, e_PlayerRole_CB) 40 | builder.AddPlayer(0.12, 0.0, e_PlayerRole_CM) 41 | builder.AddPlayer(0.12, -0.1, e_PlayerRole_CB) 42 | builder.AddPlayer(0.12, -0.2, e_PlayerRole_RB) 43 | -------------------------------------------------------------------------------- /tmarl/envs/football/scenarios/academy_run_to_score_with_keeper.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 Google LLC 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | 18 | 19 | 20 | from . 
import * 21 | 22 | 23 | def build_scenario(builder): 24 | builder.config().game_duration = 400 25 | builder.config().deterministic = False 26 | builder.config().offsides = False 27 | builder.config().end_episode_on_score = True 28 | builder.config().end_episode_on_out_of_play = True 29 | builder.config().end_episode_on_possession_change = True 30 | builder.SetBallPosition(0.02, 0.0) 31 | 32 | builder.SetTeam(Team.e_Left) 33 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK, controllable=False) 34 | builder.AddPlayer(0.0, 0.0, e_PlayerRole_CB) 35 | 36 | builder.SetTeam(Team.e_Right) 37 | builder.AddPlayer(-1.0, 0.0, e_PlayerRole_GK) 38 | builder.AddPlayer(0.12, 0.2, e_PlayerRole_LB) 39 | builder.AddPlayer(0.12, 0.1, e_PlayerRole_CB) 40 | builder.AddPlayer(0.12, 0.0, e_PlayerRole_CM) 41 | builder.AddPlayer(0.12, -0.1, e_PlayerRole_CB) 42 | builder.AddPlayer(0.12, -0.2, e_PlayerRole_RB) 43 | -------------------------------------------------------------------------------- /tmarl/loggers/TSee/README.md: -------------------------------------------------------------------------------- 1 | # TSee 2 | 3 | ## Parameter management, log storage and visualization 4 | 5 | 6 | -------------------------------------------------------------------------------- /tmarl/loggers/TSee/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/loggers/TSee/__init__.py -------------------------------------------------------------------------------- /tmarl/loggers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/loggers/__init__.py -------------------------------------------------------------------------------- /tmarl/loggers/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import time 20 | 21 | def timer(function): 22 | """ 23 | Decorator `timer`: measures the running time of the wrapped function. 24 | :param function: the function to be timed 25 | :return: 26 | """ 27 | 28 | def wrapper(*args, **kwargs): 29 | time_start = time.time() 30 | res = function(*args, **kwargs) 31 | cost_time = time.time() - time_start 32 | print("{} running time: {}s".format(function.__name__, cost_time)) 33 | return res 34 | 35 | return wrapper -------------------------------------------------------------------------------- /tmarl/networks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /tmarl/networks/policy_network.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | from tmarl.networks.utils.util import init, check 6 | from tmarl.networks.utils.mlp import MLPBase, MLPLayer 7 | from tmarl.networks.utils.rnn import RNNLayer 8 | from tmarl.networks.utils.act import ACTLayer 9 | from tmarl.networks.utils.popart import PopArt 10 | from tmarl.utils.util import get_shape_from_obs_space 11 | 12 | # networks are defined here 13 | 14 | class PolicyNetwork(nn.Module): 15 | def __init__(self, args, obs_space, action_space, device=torch.device("cpu")): 16 | super(PolicyNetwork, self).__init__() 17 | self.hidden_size = args.hidden_size 18 | 19 | self._gain = args.gain 20 | self._use_orthogonal = args.use_orthogonal 21 | self._activation_id = args.activation_id 22 | self._use_policy_active_masks = args.use_policy_active_masks 23 | self._use_naive_recurrent_policy = args.use_naive_recurrent_policy 24 | self._use_recurrent_policy = args.use_recurrent_policy 25 | self._use_influence_policy = args.use_influence_policy 26 | self._influence_layer_N = args.influence_layer_N 27 | self._use_policy_vhead = args.use_policy_vhead 28 | self._recurrent_N = args.recurrent_N 29 | self.tpdv = dict(dtype=torch.float32, device=device) 30 | 31 | obs_shape = get_shape_from_obs_space(obs_space) 32 | 33 | self._mixed_obs = False 34 | self.base = MLPBase(args, obs_shape, use_attn_internal=False, use_cat_self=True) 35 | 36 | input_size = self.base.output_size 37 | 38 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 39 | self.rnn = RNNLayer(input_size, self.hidden_size, self._recurrent_N, self._use_orthogonal) 40 | input_size = self.hidden_size 41 | 42 | if self._use_influence_policy: 43 | self.mlp = MLPLayer(obs_shape[0], self.hidden_size, 44 | self._influence_layer_N, self._use_orthogonal, self._activation_id) 45 | input_size += self.hidden_size 46 | 47 | self.act = ACTLayer(action_space, input_size, self._use_orthogonal, self._gain) 48 | 49 | if self._use_policy_vhead: 50 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal] 51 | def init_(m): 52 | return init(m, init_method, lambda x: nn.init.constant_(x, 0)) 53 | if self._use_popart: 54 | self.v_out = init_(PopArt(input_size, 1, device=device)) 55 | else: 56 | self.v_out = init_(nn.Linear(input_size, 1)) 57 | 58 | self.to(device) 59 | 60 | def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False): 61 | if self._mixed_obs: 62 | for key in obs.keys(): 63 | obs[key] = check(obs[key]).to(**self.tpdv) 64 | else: 65 | obs = check(obs).to(**self.tpdv) 66 | rnn_states = check(rnn_states).to(**self.tpdv) 67 | masks = check(masks).to(**self.tpdv) 68 | if available_actions is not None: 69 | available_actions = check(available_actions).to(**self.tpdv) 70 | actor_features = self.base(obs) 71 | 72 | if self._use_naive_recurrent_policy or 
self._use_recurrent_policy: 73 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 74 | if self._use_influence_policy: 75 | mlp_obs = self.mlp(obs) 76 | actor_features = torch.cat([actor_features, mlp_obs], dim=1) 77 | actions, action_log_probs = self.act(actor_features, available_actions, deterministic) 78 | 79 | return actions, action_log_probs, rnn_states 80 | 81 | def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None): 82 | if self._mixed_obs: 83 | for key in obs.keys(): 84 | obs[key] = check(obs[key]).to(**self.tpdv) 85 | else: 86 | obs = check(obs).to(**self.tpdv) 87 | 88 | rnn_states = check(rnn_states).to(**self.tpdv) 89 | action = check(action).to(**self.tpdv) 90 | masks = check(masks).to(**self.tpdv) 91 | 92 | if available_actions is not None: 93 | available_actions = check(available_actions).to(**self.tpdv) 94 | 95 | if active_masks is not None: 96 | active_masks = check(active_masks).to(**self.tpdv) 97 | 98 | actor_features = self.base(obs) 99 | 100 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 101 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 102 | 103 | if self._use_influence_policy: 104 | mlp_obs = self.mlp(obs) 105 | actor_features = torch.cat([actor_features, mlp_obs], dim=1) 106 | 107 | action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features, action, available_actions, active_masks = active_masks if self._use_policy_active_masks else None) 108 | 109 | values = self.v_out(actor_features) if self._use_policy_vhead else None 110 | 111 | return action_log_probs, dist_entropy, values 112 | 113 | def get_policy_values(self, obs, rnn_states, masks): 114 | if self._mixed_obs: 115 | for key in obs.keys(): 116 | obs[key] = check(obs[key]).to(**self.tpdv) 117 | else: 118 | obs = check(obs).to(**self.tpdv) 119 | rnn_states = check(rnn_states).to(**self.tpdv) 120 | masks = check(masks).to(**self.tpdv) 121 | 122 | actor_features = self.base(obs) 123 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 124 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 125 | 126 | if self._use_influence_policy: 127 | mlp_obs = self.mlp(obs) 128 | actor_features = torch.cat([actor_features, mlp_obs], dim=1) 129 | 130 | values = self.v_out(actor_features) 131 | 132 | return values -------------------------------------------------------------------------------- /tmarl/networks/utils/act.py: -------------------------------------------------------------------------------- 1 | 2 | from .distributions import Bernoulli, Categorical, DiagGaussian 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ACTLayer(nn.Module): 8 | def __init__(self, action_space, inputs_dim, use_orthogonal, gain): 9 | super(ACTLayer, self).__init__() 10 | self.multidiscrete_action = False 11 | self.continuous_action = False 12 | self.mixed_action = False 13 | 14 | if action_space.__class__.__name__ == "Discrete": 15 | action_dim = action_space.n 16 | self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain) 17 | elif action_space.__class__.__name__ == "Box": 18 | self.continuous_action = True 19 | action_dim = action_space.shape[0] 20 | self.action_out = DiagGaussian(inputs_dim, action_dim, use_orthogonal, gain) 21 | elif action_space.__class__.__name__ == "MultiBinary": 22 | action_dim = action_space.shape[0] 23 | self.action_out = Bernoulli(inputs_dim, action_dim, use_orthogonal, gain) 24 | elif action_space.__class__.__name__ 
== "MultiDiscrete": 25 | self.multidiscrete_action = True 26 | action_dims = action_space.high - action_space.low + 1 27 | self.action_outs = [] 28 | for action_dim in action_dims: 29 | self.action_outs.append(Categorical(inputs_dim, action_dim, use_orthogonal, gain)) 30 | self.action_outs = nn.ModuleList(self.action_outs) 31 | else: # discrete + continous 32 | self.mixed_action = True 33 | continous_dim = action_space[0].shape[0] 34 | discrete_dim = action_space[1].n 35 | self.action_outs = nn.ModuleList([DiagGaussian(inputs_dim, continous_dim, use_orthogonal, gain), Categorical( 36 | inputs_dim, discrete_dim, use_orthogonal, gain)]) 37 | 38 | def forward(self, x, available_actions=None, deterministic=False): 39 | if self.mixed_action : 40 | actions = [] 41 | action_log_probs = [] 42 | for action_out in self.action_outs: 43 | action_logit = action_out(x) 44 | action = action_logit.mode() if deterministic else action_logit.sample() 45 | action_log_prob = action_logit.log_probs(action) 46 | actions.append(action.float()) 47 | action_log_probs.append(action_log_prob) 48 | 49 | actions = torch.cat(actions, -1) 50 | action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True) 51 | 52 | elif self.multidiscrete_action: 53 | actions = [] 54 | action_log_probs = [] 55 | for action_out in self.action_outs: 56 | action_logit = action_out(x) 57 | action = action_logit.mode() if deterministic else action_logit.sample() 58 | action_log_prob = action_logit.log_probs(action) 59 | actions.append(action) 60 | action_log_probs.append(action_log_prob) 61 | 62 | actions = torch.cat(actions, -1) 63 | action_log_probs = torch.cat(action_log_probs, -1) 64 | 65 | elif self.continuous_action: 66 | action_logits = self.action_out(x) 67 | actions = action_logits.mode() if deterministic else action_logits.sample() 68 | action_log_probs = action_logits.log_probs(actions) 69 | 70 | else: 71 | action_logits = self.action_out(x, available_actions) 72 | actions = action_logits.mode() if deterministic else action_logits.sample() 73 | action_log_probs = action_logits.log_probs(actions) 74 | 75 | return actions, action_log_probs 76 | 77 | def get_probs(self, x, available_actions=None): 78 | if self.mixed_action or self.multidiscrete_action: 79 | action_probs = [] 80 | for action_out in self.action_outs: 81 | action_logit = action_out(x) 82 | action_prob = action_logit.probs 83 | action_probs.append(action_prob) 84 | action_probs = torch.cat(action_probs, -1) 85 | elif self.continuous_action: 86 | action_logits = self.action_out(x) 87 | action_probs = action_logits.probs 88 | else: 89 | action_logits = self.action_out(x, available_actions) 90 | action_probs = action_logits.probs 91 | 92 | return action_probs 93 | 94 | def get_log_1mp(self, x, action, available_actions=None, active_masks=None): 95 | action_logits = self.action_out(x, available_actions) 96 | action_prob = torch.gather(action_logits.probs, 1, action.long()) 97 | action_prob = torch.clamp(action_prob, 0, 1-1e-6) 98 | action_log_1mp = torch.log(1 - action_prob) 99 | return action_log_1mp 100 | 101 | def evaluate_actions(self, x, action, available_actions=None, active_masks=None): 102 | if self.mixed_action: 103 | a, b = action.split((2, 1), -1) 104 | b = b.long() 105 | action = [a, b] 106 | action_log_probs = [] 107 | dist_entropy = [] 108 | for action_out, act in zip(self.action_outs, action): 109 | action_logit = action_out(x) 110 | action_log_probs.append(action_logit.log_probs(act)) 111 | if active_masks is not None: 112 | if 
len(action_logit.entropy().shape) == len(active_masks.shape): 113 | dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum()) 114 | else: 115 | dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum()) 116 | else: 117 | dist_entropy.append(action_logit.entropy().mean()) 118 | 119 | action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True) 120 | dist_entropy = dist_entropy[0] * 0.0025 + dist_entropy[1] * 0.01 121 | 122 | elif self.multidiscrete_action: 123 | action = torch.transpose(action, 0, 1) 124 | action_log_probs = [] 125 | dist_entropy = [] 126 | for action_out, act in zip(self.action_outs, action): 127 | action_logit = action_out(x) 128 | action_log_probs.append(action_logit.log_probs(act)) 129 | if active_masks is not None: 130 | dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()) 131 | else: 132 | dist_entropy.append(action_logit.entropy().mean()) 133 | 134 | action_log_probs = torch.cat(action_log_probs, -1) # ! could be wrong 135 | dist_entropy = torch.tensor(dist_entropy).mean() 136 | 137 | elif self.continuous_action: 138 | action_logits = self.action_out(x) 139 | action_log_probs = action_logits.log_probs(action) 140 | if active_masks is not None: 141 | dist_entropy = (action_logits.entropy()*active_masks).sum()/active_masks.sum() 142 | else: 143 | dist_entropy = action_logits.entropy().mean() 144 | else: 145 | action_logits = self.action_out(x, available_actions) 146 | action_log_probs = action_logits.log_probs(action) 147 | if active_masks is not None: 148 | dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum() 149 | else: 150 | dist_entropy = action_logits.entropy().mean() 151 | 152 | return action_log_probs, dist_entropy -------------------------------------------------------------------------------- /tmarl/networks/utils/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .util import init 5 | 6 | """ 7 | Modify standard PyTorch distributions so they are compatible with this code. 
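The Fixed* subclasses below add a uniform log_probs()/mode() interface and keep a trailing action dimension, so ACTLayer can treat Discrete, Box (Gaussian), and MultiBinary actions the same way.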
8 | """ 9 | 10 | # 11 | # Standardize distribution interfaces 12 | # 13 | 14 | # Categorical 15 | class FixedCategorical(torch.distributions.Categorical): 16 | def sample(self): 17 | return super().sample().unsqueeze(-1) 18 | 19 | def log_probs(self, actions): 20 | return ( 21 | super() 22 | .log_prob(actions.squeeze(-1)) 23 | .view(actions.size(0), -1) 24 | .sum(-1) 25 | .unsqueeze(-1) 26 | ) 27 | 28 | def mode(self): 29 | return self.probs.argmax(dim=-1, keepdim=True) 30 | 31 | 32 | # Normal 33 | class FixedNormal(torch.distributions.Normal): 34 | def log_probs(self, actions): 35 | return super().log_prob(actions).sum(-1, keepdim=True) 36 | 37 | def entropy(self): 38 | return super().entropy().sum(-1) 39 | 40 | def mode(self): 41 | return self.mean 42 | 43 | 44 | # Bernoulli 45 | class FixedBernoulli(torch.distributions.Bernoulli): 46 | def log_probs(self, actions): 47 | return super().log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 48 | 49 | def entropy(self): 50 | return super().entropy().sum(-1) 51 | 52 | def mode(self): 53 | return torch.gt(self.probs, 0.5).float() 54 | 55 | 56 | class Categorical(nn.Module): 57 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 58 | super(Categorical, self).__init__() 59 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 60 | def init_(m): 61 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 62 | 63 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 64 | 65 | def forward(self, x, available_actions=None): 66 | x = self.linear(x) 67 | if available_actions is not None: 68 | x[available_actions == 0] = -1e10 69 | return FixedCategorical(logits=x) 70 | 71 | 72 | class DiagGaussian(nn.Module): 73 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 74 | super(DiagGaussian, self).__init__() 75 | 76 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 77 | def init_(m): 78 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 79 | 80 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 81 | self.logstd = AddBias(torch.zeros(num_outputs)) 82 | 83 | def forward(self, x): 84 | action_mean = self.fc_mean(x) 85 | 86 | # An ugly hack for my KFAC implementation.
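# The zeros tensor below only provides the batch shape (and device) for AddBias; the learnable bias it adds is the state-independent log standard deviation of the diagonal Gaussian.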
87 | zeros = torch.zeros(action_mean.size()) 88 | if x.is_cuda: 89 | zeros = zeros.cuda() 90 | 91 | action_logstd = self.logstd(zeros) 92 | return FixedNormal(action_mean, action_logstd.exp()) 93 | 94 | 95 | class Bernoulli(nn.Module): 96 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 97 | super(Bernoulli, self).__init__() 98 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 99 | def init_(m): 100 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 101 | 102 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 103 | 104 | def forward(self, x): 105 | x = self.linear(x) 106 | return FixedBernoulli(logits=x) 107 | 108 | class AddBias(nn.Module): 109 | def __init__(self, bias): 110 | super(AddBias, self).__init__() 111 | self._bias = nn.Parameter(bias.unsqueeze(1)) 112 | 113 | def forward(self, x): 114 | if x.dim() == 2: 115 | bias = self._bias.t().view(1, -1) 116 | else: 117 | bias = self._bias.t().view(1, -1, 1, 1) 118 | 119 | return x + bias 120 | -------------------------------------------------------------------------------- /tmarl/networks/utils/mlp.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | 4 | 5 | from .util import init, get_clones 6 | 7 | class MLPLayer(nn.Module): 8 | def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, activation_id): 9 | super(MLPLayer, self).__init__() 10 | self._layer_N = layer_N 11 | 12 | active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id] 13 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 14 | gain = nn.init.calculate_gain(['tanh', 'relu', 'leaky_relu', 'leaky_relu'][activation_id]) 15 | 16 | def init_(m): 17 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 18 | 19 | self.fc1 = nn.Sequential( 20 | init_(nn.Linear(input_dim, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 21 | self.fc_h = nn.Sequential(init_( 22 | nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 23 | self.fc2 = get_clones(self.fc_h, self._layer_N) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | for i in range(self._layer_N): 28 | x = self.fc2[i](x) 29 | return x 30 | 31 | 32 | class MLPBase(nn.Module): 33 | def __init__(self, args, obs_shape, use_attn_internal=False, use_cat_self=True): 34 | super(MLPBase, self).__init__() 35 | 36 | self._use_feature_normalization = args.use_feature_normalization 37 | self._use_orthogonal = args.use_orthogonal 38 | self._activation_id = args.activation_id 39 | self._use_conv1d = args.use_conv1d 40 | self._stacked_frames = args.stacked_frames 41 | self._layer_N = args.layer_N 42 | self.hidden_size = args.hidden_size 43 | 44 | obs_dim = obs_shape[0] 45 | inputs_dim = obs_dim 46 | 47 | if self._use_feature_normalization: 48 | self.feature_norm = nn.LayerNorm(obs_dim) 49 | 50 | self.mlp = MLPLayer(inputs_dim, self.hidden_size, 51 | self._layer_N, self._use_orthogonal, self._activation_id) 52 | 53 | def forward(self, x): 54 | if self._use_feature_normalization: 55 | x = self.feature_norm(x) 56 | 57 | x = self.mlp(x) 58 | 59 | return x 60 | 61 | @property 62 | def output_size(self): 63 | return self.hidden_size -------------------------------------------------------------------------------- /tmarl/networks/utils/popart.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import 
torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class PopArt(torch.nn.Module): 8 | 9 | def __init__(self, input_shape, output_shape, norm_axes=1, beta=0.99999, epsilon=1e-5, device=torch.device("cpu")): 10 | 11 | super(PopArt, self).__init__() 12 | 13 | self.beta = beta 14 | self.epsilon = epsilon 15 | self.norm_axes = norm_axes 16 | self.tpdv = dict(dtype=torch.float32, device=device) 17 | 18 | self.input_shape = input_shape 19 | self.output_shape = output_shape 20 | 21 | self.weight = nn.Parameter(torch.Tensor(output_shape, input_shape)).to(**self.tpdv) 22 | self.bias = nn.Parameter(torch.Tensor(output_shape)).to(**self.tpdv) 23 | 24 | self.stddev = nn.Parameter(torch.ones(output_shape), requires_grad=False).to(**self.tpdv) 25 | self.mean = nn.Parameter(torch.zeros(output_shape), requires_grad=False).to(**self.tpdv) 26 | self.mean_sq = nn.Parameter(torch.zeros(output_shape), requires_grad=False).to(**self.tpdv) 27 | self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv) 28 | 29 | self.reset_parameters() 30 | 31 | def reset_parameters(self): 32 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 33 | if self.bias is not None: 34 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) 35 | bound = 1 / math.sqrt(fan_in) 36 | torch.nn.init.uniform_(self.bias, -bound, bound) 37 | self.mean.zero_() 38 | self.mean_sq.zero_() 39 | self.debiasing_term.zero_() 40 | 41 | def forward(self, input_vector): 42 | if type(input_vector) == np.ndarray: 43 | input_vector = torch.from_numpy(input_vector) 44 | input_vector = input_vector.to(**self.tpdv) 45 | 46 | return F.linear(input_vector, self.weight, self.bias) 47 | 48 | @torch.no_grad() 49 | def update(self, input_vector): 50 | if type(input_vector) == np.ndarray: 51 | input_vector = torch.from_numpy(input_vector) 52 | input_vector = input_vector.to(**self.tpdv) 53 | 54 | old_mean, old_stddev = self.mean, self.stddev 55 | 56 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes))) 57 | batch_sq_mean = (input_vector ** 2).mean(dim=tuple(range(self.norm_axes))) 58 | 59 | self.mean.mul_(self.beta).add_(batch_mean * (1.0 - self.beta)) 60 | self.mean_sq.mul_(self.beta).add_(batch_sq_mean * (1.0 - self.beta)) 61 | self.debiasing_term.mul_(self.beta).add_(1.0 * (1.0 - self.beta)) 62 | 63 | self.stddev = (self.mean_sq - self.mean ** 2).sqrt().clamp(min=1e-4) 64 | 65 | self.weight = self.weight * old_stddev / self.stddev 66 | self.bias = (old_stddev * self.bias + old_mean - self.mean) / self.stddev 67 | 68 | def debiased_mean_var(self): 69 | debiased_mean = self.mean / self.debiasing_term.clamp(min=self.epsilon) 70 | debiased_mean_sq = self.mean_sq / self.debiasing_term.clamp(min=self.epsilon) 71 | debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2) 72 | return debiased_mean, debiased_var 73 | 74 | def normalize(self, input_vector): 75 | if type(input_vector) == np.ndarray: 76 | input_vector = torch.from_numpy(input_vector) 77 | input_vector = input_vector.to(**self.tpdv) 78 | 79 | mean, var = self.debiased_mean_var() 80 | out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes] 81 | 82 | return out 83 | 84 | def denormalize(self, input_vector): 85 | if type(input_vector) == np.ndarray: 86 | input_vector = torch.from_numpy(input_vector) 87 | input_vector = input_vector.to(**self.tpdv) 88 | 89 | mean, var = self.debiased_mean_var() 90 | out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * 
self.norm_axes] 91 | 92 | out = out.cpu().numpy() 93 | 94 | return out 95 | -------------------------------------------------------------------------------- /tmarl/networks/utils/rnn.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class RNNLayer(nn.Module): 7 | def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal): 8 | super(RNNLayer, self).__init__() 9 | self._recurrent_N = recurrent_N 10 | self._use_orthogonal = use_orthogonal 11 | 12 | self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N) 13 | for name, param in self.rnn.named_parameters(): 14 | if 'bias' in name: 15 | nn.init.constant_(param, 0) 16 | elif 'weight' in name: 17 | if self._use_orthogonal: 18 | nn.init.orthogonal_(param) 19 | else: 20 | nn.init.xavier_uniform_(param) 21 | self.norm = nn.LayerNorm(outputs_dim) 22 | 23 | def forward(self, x, hxs, masks): 24 | if x.size(0) == hxs.size(0): 25 | x, hxs = self.rnn(x.unsqueeze(0), (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous()) 26 | 27 | x = x.squeeze(0) 28 | hxs = hxs.transpose(0, 1) 29 | else: 30 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 31 | N = hxs.size(0) 32 | T = int(x.size(0) / N) 33 | 34 | # unflatten 35 | x = x.view(T, N, x.size(1)) 36 | 37 | # Same deal with masks 38 | masks = masks.view(T, N) 39 | 40 | # Let's figure out which steps in the sequence have a zero for any agent 41 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 42 | has_zeros = ((masks[1:] == 0.0) 43 | .any(dim=-1) 44 | .nonzero() 45 | .squeeze() 46 | .cpu()) 47 | 48 | # +1 to correct the masks[1:] 49 | if has_zeros.dim() == 0: 50 | # Deal with scalar 51 | has_zeros = [has_zeros.item() + 1] 52 | else: 53 | has_zeros = (has_zeros + 1).numpy().tolist() 54 | 55 | # add t=0 and t=T to the list 56 | has_zeros = [0] + has_zeros + [T] 57 | 58 | hxs = hxs.transpose(0, 1) 59 | 60 | outputs = [] 61 | for i in range(len(has_zeros) - 1): 62 | # We can now process steps that don't have any zeros in masks together! 
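# (no environment reset occurs inside [start_idx, end_idx), so the hidden state only needs to be masked once at the start of the chunk)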
63 | # This is much faster 64 | start_idx = has_zeros[i] 65 | end_idx = has_zeros[i + 1] 66 | temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous() 67 | rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp) 68 | outputs.append(rnn_scores) 69 | 70 | # assert len(outputs) == T 71 | # x is a (T, N, -1) tensor 72 | x = torch.cat(outputs, dim=0) 73 | 74 | # flatten 75 | x = x.reshape(T * N, -1) 76 | hxs = hxs.transpose(0, 1) 77 | 78 | x = self.norm(x) 79 | return x, hxs 80 | -------------------------------------------------------------------------------- /tmarl/networks/utils/util.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | def init(module, weight_init, bias_init, gain=1): 9 | weight_init(module.weight.data, gain=gain) 10 | bias_init(module.bias.data) 11 | return module 12 | 13 | def get_clones(module, N): 14 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 15 | 16 | def check(input): 17 | output = torch.from_numpy(input) if type(input) == np.ndarray else input 18 | return output 19 | -------------------------------------------------------------------------------- /tmarl/replay_buffers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/replay_buffers/__init__.py -------------------------------------------------------------------------------- /tmarl/replay_buffers/normal/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /tmarl/runners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/runners/__init__.py -------------------------------------------------------------------------------- /tmarl/runners/base_evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | 20 | import random 21 | 22 | import numpy as np 23 | import torch 24 | 25 | from tmarl.configs.config import get_config 26 | from tmarl.runners.base_runner import Runner 27 | 28 | def set_seed(seed): 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed_all(seed) 34 | 35 | class Evaluator(Runner): 36 | def __init__(self, argv,program_type=None, client=None): 37 | super().__init__(argv) 38 | 39 | parser = get_config() 40 | all_args = self.extra_args_func(argv, parser) 41 | 42 | all_args.cuda = not all_args.disable_cuda 43 | 44 | self.algorithm_name = all_args.algorithm_name 45 | 46 | # cuda 47 | if not all_args.disable_cuda and torch.cuda.is_available(): 48 | device = torch.device("cuda:0") 49 | 50 | if all_args.cuda_deterministic: 51 | torch.backends.cudnn.benchmark = False 52 | torch.backends.cudnn.deterministic = True 53 | else: 54 | print("choose to use cpu...") 55 | device = torch.device("cpu") 56 | 57 | 58 | # run dir 59 | run_dir = self.setup_run_dir(all_args) 60 | 61 | # env init 62 | Env_Class, SubprocVecEnv, DummyVecEnv = self.get_env() 63 | eval_envs = self.env_init( 64 | all_args, Env_Class, SubprocVecEnv, DummyVecEnv) 65 | num_agents = all_args.num_agents 66 | 67 | config = { 68 | "all_args": all_args, 69 | "envs": None, 70 | "eval_envs": eval_envs, 71 | "num_agents": num_agents, 72 | "device": device, 73 | "run_dir": run_dir, 74 | } 75 | self.all_args, self.envs, self.eval_envs, self.config \ 76 | = all_args, None, eval_envs, config 77 | self.driver = self.init_driver() 78 | 79 | def run(self): 80 | # run experiments 81 | self.driver.run() 82 | self.stop() 83 | 84 | def stop(self): 85 | pass 86 | 87 | def extra_args_func(self, argv, parser): 88 | raise NotImplementedError 89 | 90 | def get_env(self): 91 | raise NotImplementedError 92 | 93 | def init_driver(self): 94 | raise NotImplementedError 95 | 96 | def make_eval_env(self, all_args, Env_Class, SubprocVecEnv, DummyVecEnv): 97 | def get_env_fn(rank): 98 | def init_env(): 99 | env = Env_Class(all_args) 100 | env.seed(all_args.seed * 50000 + rank * 10000) 101 | return env 102 | 103 | return init_env 104 | 105 | if all_args.n_eval_rollout_threads == 1: 106 | return DummyVecEnv([get_env_fn(0)]) 107 | else: 108 | return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)]) 109 | 110 | def env_init(self, all_args, Env_Class, SubprocVecEnv, DummyVecEnv): 111 | eval_envs = self.make_eval_env( 112 | all_args, Env_Class, SubprocVecEnv, DummyVecEnv) if all_args.use_eval else None 113 | return eval_envs 114 | 115 | def setup_run_dir(self, all_args): 116 | return None 117 | -------------------------------------------------------------------------------- /tmarl/runners/base_runner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | 20 | import os 21 | import random 22 | import socket 23 | import setproctitle 24 | 25 | import numpy as np 26 | from pathlib import Path 27 | import torch 28 | 29 | from tmarl.configs.config import get_config 30 | 31 | 32 | def set_seed(seed): 33 | random.seed(seed) 34 | np.random.seed(seed) 35 | 36 | torch.manual_seed(seed) 37 | torch.cuda.manual_seed_all(seed) 38 | 39 | 40 | class Runner: 41 | def __init__(self, argv): 42 | self.argv = argv 43 | 44 | def run(self): 45 | # main run 46 | raise NotImplementedError -------------------------------------------------------------------------------- /tmarl/runners/football/football_evaluator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
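# Evaluation entry point for the football environment: builds the (vectorized) evaluation envs and hands control to the shared-policy FootballDriver.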
16 | 17 | """""" 18 | import sys 19 | import os 20 | 21 | from pathlib import Path 22 | 23 | from tmarl.runners.base_evaluator import Evaluator 24 | from tmarl.envs.football.football import RllibGFootball 25 | from tmarl.envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv 26 | 27 | 28 | class FootballEvaluator(Evaluator): 29 | def __init__(self, argv): 30 | super(FootballEvaluator, self).__init__(argv) 31 | 32 | def setup_run_dir(self, all_args): 33 | dump_dir = Path(all_args.replay_save_dir) 34 | if not dump_dir.exists(): 35 | os.makedirs(str(dump_dir)) 36 | self.dump_dir = dump_dir 37 | 38 | return super(FootballEvaluator, self).setup_run_dir(all_args) 39 | 40 | def make_eval_env(self, all_args, Env_Class, SubprocVecEnv, DummyVecEnv): 41 | 42 | def get_env_fn(rank): 43 | def init_env(): 44 | env = Env_Class(all_args, rank, log_dir=str(self.dump_dir), isEval=True) 45 | env.seed(all_args.seed * 50000 + rank * 10000) 46 | return env 47 | return init_env 48 | 49 | if all_args.n_eval_rollout_threads == 1: 50 | return DummyVecEnv([get_env_fn(0)]) 51 | else: 52 | return SubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)]) 53 | 54 | def extra_args_func(self, args, parser): 55 | parser.add_argument('--scenario_name', type=str, 56 | default='simple_spread', help="Which scenario to run on") 57 | parser.add_argument('--num_agents', type=int, 58 | default=0, help="number of players") 59 | 60 | # football config 61 | parser.add_argument('--representation', type=str, 62 | default='raw', help="format of the observation in gfootball env") 63 | parser.add_argument('--rewards', type=str, 64 | default='scoring', help="format of the reward in gfootball env") 65 | parser.add_argument("--render_only", action='store_true', default=False, 66 | help="if true, render without training") 67 | 68 | all_args = parser.parse_known_args(args)[0] 69 | return all_args 70 | 71 | def get_env(self): 72 | return RllibGFootball, ShareSubprocVecEnv, ShareDummyVecEnv 73 | 74 | def init_driver(self): 75 | if not self.all_args.separate_policy: 76 | from tmarl.drivers.shared_distributed.football_driver import FootballDriver as Driver 77 | else: 78 | raise NotImplementedError 79 | driver = Driver(self.config) 80 | return driver 81 | 82 | 83 | def main(argv): 84 | evaluator = FootballEvaluator(argv) 85 | evaluator.run() 86 | 87 | 88 | if __name__ == "__main__": 89 | main(sys.argv[1:]) 90 | -------------------------------------------------------------------------------- /tmarl/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/utils/__init__.py -------------------------------------------------------------------------------- /tmarl/utils/gpu_mem_track.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/Oldpan/Pytorch-Memory-Utils 2 | 3 | import gc 4 | import datetime 5 | import inspect 6 | 7 | import torch 8 | import numpy as np 9 | 10 | dtype_memory_size_dict = { 11 | torch.float64: 64/8, 12 | torch.double: 64/8, 13 | torch.float32: 32/8, 14 | torch.float: 32/8, 15 | torch.float16: 16/8, 16 | torch.half: 16/8, 17 | torch.int64: 64/8, 18 | torch.long: 64/8, 19 | torch.int32: 32/8, 20 | torch.int: 32/8, 21 | torch.int16: 16/8, 22 | torch.short: 16/8, 23 | torch.uint8: 8/8, 24 | torch.int8: 8/8, 25 | } 26 | 27 | # compatibility with torch 1.0 28 | if getattr(torch, "bfloat16", None) is not None: 29
| dtype_memory_size_dict[torch.bfloat16] = 16/8 30 | if getattr(torch, "bool", None) is not None: 31 | dtype_memory_size_dict[torch.bool] = 8/8 # pytorch use 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571 32 | 33 | def get_mem_space(x): 34 | try: 35 | ret = dtype_memory_size_dict[x] 36 | except KeyError: 37 | print(f"dtype {x} is not supported!") 38 | return ret 39 | 40 | class MemTracker(object): 41 | """ 42 | Class used to track pytorch memory usage 43 | Arguments: 44 | detail(bool, default True): whether the function shows the detail gpu memory usage 45 | path(str): where to save log file 46 | verbose(bool, default False): whether show the trivial exception 47 | device(int): GPU number, default is 0 48 | """ 49 | def __init__(self, detail=True, path='', verbose=False, device=0): 50 | self.print_detail = detail 51 | self.last_tensor_sizes = set() 52 | self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt' 53 | self.verbose = verbose 54 | self.begin = True 55 | self.device = device 56 | 57 | def get_tensors(self): 58 | for obj in gc.get_objects(): 59 | try: 60 | if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): 61 | tensor = obj 62 | else: 63 | continue 64 | if tensor.is_cuda: 65 | yield tensor 66 | except Exception as e: 67 | if self.verbose: 68 | print('A trivial exception occured: {}'.format(e)) 69 | 70 | def get_tensor_usage(self): 71 | sizes = [np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype) for tensor in self.get_tensors()] 72 | return np.sum(sizes) / 1024**2 73 | 74 | def get_allocate_usage(self): 75 | return torch.cuda.memory_allocated() / 1024**2 76 | 77 | def clear_cache(self): 78 | gc.collect() 79 | torch.cuda.empty_cache() 80 | 81 | def print_all_gpu_tensor(self, file=None): 82 | for x in self.get_tensors(): 83 | print(x.size(), x.dtype, np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2, file=file) 84 | 85 | def track(self): 86 | """ 87 | Track the GPU memory usage 88 | """ 89 | frameinfo = inspect.stack()[1] 90 | where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function 91 | 92 | with open(self.gpu_profile_fn, 'a+') as f: 93 | 94 | if self.begin: 95 | f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |" 96 | f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb" 97 | f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n") 98 | self.begin = False 99 | 100 | if self.print_detail is True: 101 | ts_list = [(tensor.size(), tensor.dtype) for tensor in self.get_tensors()] 102 | new_tensor_sizes = {(type(x), 103 | tuple(x.size()), 104 | ts_list.count((x.size(), x.dtype)), 105 | np.prod(np.array(x.size()))*get_mem_space(x.dtype)/1024**2, 106 | x.dtype) for x in self.get_tensors()} 107 | for t, s, n, m, data_type in new_tensor_sizes - self.last_tensor_sizes: 108 | f.write(f'+ | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n') 109 | for t, s, n, m, data_type in self.last_tensor_sizes - new_tensor_sizes: 110 | f.write(f'- | {str(n)} * Size:{str(s):<20} | Memory: {str(m*n)[:6]} M | {str(t):<20} | {data_type}\n') 111 | 112 | self.last_tensor_sizes = new_tensor_sizes 113 | 114 | f.write(f"\nAt {where_str:<50}" 115 | f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb" 116 | f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n") 117 | -------------------------------------------------------------------------------- 
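For orientation, the snippet below is a minimal usage sketch of the MemTracker class defined in tmarl/utils/gpu_mem_track.py above. It is not part of the repository: it assumes a CUDA-capable device, an installed `tmarl` package, and the `path='./'` and tensor size are arbitrary choices for illustration.

```python
# Minimal sketch (hypothetical, not repository code): requires a CUDA GPU.
import torch
from tmarl.utils.gpu_mem_track import MemTracker

tracker = MemTracker(detail=True, path='./', device=0)  # log goes to ./<timestamp>-gpu_mem_track.txt

tracker.track()                                  # write a baseline snapshot to the log
x = torch.randn(1024, 1024, device='cuda:0')     # allocate roughly 4 MB of float32 data
tracker.track()                                  # the diff in the log now lists the new tensor

print(f'{tracker.get_tensor_usage():.1f} MB of CUDA tensors currently tracked')
```

--------------------------------------------------------------------------------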
/tmarl/utils/modelsize_estimate.py: -------------------------------------------------------------------------------- 1 | # code from https://github.com/Oldpan/Pytorch-Memory-Utils 2 | 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | 7 | def modelsize(model, input, type_size=4): 8 | para = sum([np.prod(list(p.size())) for p in model.parameters()]) 9 | # print('Model {} : Number of params: {}'.format(model._get_name(), para)) 10 | print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000)) 11 | 12 | input_ = input.clone() 13 | input_.requires_grad_(requires_grad=False) 14 | 15 | mods = list(model.modules()) 16 | out_sizes = [] 17 | 18 | for i in range(1, len(mods)): 19 | m = mods[i] 20 | if isinstance(m, nn.ReLU): 21 | if m.inplace: 22 | continue 23 | out = m(input_) 24 | out_sizes.append(np.array(out.size())) 25 | input_ = out 26 | 27 | total_nums = 0 28 | for i in range(len(out_sizes)): 29 | s = out_sizes[i] 30 | nums = np.prod(np.array(s)) 31 | total_nums += nums 32 | 33 | # print('Model {} : Number of intermedite variables without backward: {}'.format(model._get_name(), total_nums)) 34 | # print('Model {} : Number of intermedite variables with backward: {}'.format(model._get_name(), total_nums*2)) 35 | print('Model {} : intermedite variables: {:3f} M (without backward)' 36 | .format(model._get_name(), total_nums * type_size / 1000 / 1000)) 37 | print('Model {} : intermedite variables: {:3f} M (with backward)' 38 | .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000)) 39 | 40 | -------------------------------------------------------------------------------- /tmarl/utils/multi_discrete.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | # An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates) 5 | # (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py) 6 | class MultiDiscrete(gym.Space): 7 | """ 8 | - The multi-discrete action space consists of a series of discrete action spaces with different parameters 9 | - It can be adapted to both a Discrete action space or a continuous (Box) action space 10 | - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space 11 | - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space where the discrete action space can take any integers from `min` to `max` (both inclusive) 12 | Note: A value of 0 always need to represent the NOOP action. 13 | e.g. 
Nintendo Game Controller 14 | - Can be conceptualized as 3 discrete action spaces: 15 | 1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4 16 | 2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 17 | 3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 18 | - Can be initialized as 19 | MultiDiscrete([ [0,4], [0,1], [0,1] ]) 20 | """ 21 | 22 | def __init__(self, array_of_param_array): 23 | self.low = np.array([x[0] for x in array_of_param_array]) 24 | self.high = np.array([x[1] for x in array_of_param_array]) 25 | self.num_discrete_space = self.low.shape[0] 26 | self.n = np.sum(self.high) + 2 27 | 28 | def sample(self): 29 | """ Returns a array with one sample from each discrete action space """ 30 | # For each row: round(random .* (max - min) + min, 0) 31 | random_array = np.random.rand(self.num_discrete_space) 32 | return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)] 33 | 34 | def contains(self, x): 35 | return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all() 36 | 37 | @property 38 | def shape(self): 39 | return self.num_discrete_space 40 | 41 | def __repr__(self): 42 | return "MultiDiscrete" + str(self.num_discrete_space) 43 | 44 | def __eq__(self, other): 45 | return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high) 46 | -------------------------------------------------------------------------------- /tmarl/utils/segment_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def unique(sorted_array): 5 | """ 6 | More efficient implementation of np.unique for sorted arrays 7 | :param sorted_array: (np.ndarray) 8 | :return:(np.ndarray) sorted_array without duplicate elements 9 | """ 10 | if len(sorted_array) == 1: 11 | return sorted_array 12 | left = sorted_array[:-1] 13 | right = sorted_array[1:] 14 | uniques = np.append(right != left, True) 15 | return sorted_array[uniques] 16 | 17 | 18 | class SegmentTree(object): 19 | def __init__(self, capacity, operation, neutral_element): 20 | """ 21 | Build a Segment Tree data structure. 22 | https://en.wikipedia.org/wiki/Segment_tree 23 | Can be used as regular array that supports Index arrays, but with two 24 | important differences: 25 | a) setting item's value is slightly slower. 26 | It is O(lg capacity) instead of O(1). 27 | b) user has access to an efficient ( O(log segment size) ) 28 | `reduce` operation which reduces `operation` over 29 | a contiguous subsequence of items in the array. 30 | :param capacity: (int) Total size of the array - must be a power of two. 31 | :param operation: (lambda (Any, Any): Any) operation for combining elements (eg. sum, max) must form a 32 | mathematical group together with the set of possible values for array elements (i.e. be associative) 33 | :param neutral_element: (Any) neutral element for the operation above. eg. float('-inf') for max and 0 for sum. 34 | """ 35 | assert capacity > 0 and capacity & ( 36 | capacity - 1) == 0, "capacity must be positive and a power of 2." 
37 | self._capacity = capacity 38 | self._value = [neutral_element for _ in range(2 * capacity)] 39 | self._operation = operation 40 | self.neutral_element = neutral_element 41 | 42 | def _reduce_helper(self, start, end, node, node_start, node_end): 43 | if start == node_start and end == node_end: 44 | return self._value[node] 45 | mid = (node_start + node_end) // 2 46 | if end <= mid: 47 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 48 | else: 49 | if mid + 1 <= start: 50 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 51 | else: 52 | return self._operation( 53 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 54 | self._reduce_helper( 55 | mid + 1, end, 2 * node + 1, mid + 1, node_end) 56 | ) 57 | 58 | def reduce(self, start=0, end=None): 59 | """ 60 | Returns result of applying `self.operation` 61 | to a contiguous subsequence of the array. 62 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 63 | :param start: (int) beginning of the subsequence 64 | :param end: (int) end of the subsequences 65 | :return: (Any) result of reducing self.operation over the specified range of array elements. 66 | """ 67 | if end is None: 68 | end = self._capacity 69 | if end < 0: 70 | end += self._capacity 71 | end -= 1 72 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 73 | 74 | def __setitem__(self, idx, val): 75 | # indexes of the leaf 76 | idxs = idx + self._capacity 77 | self._value[idxs] = val 78 | if isinstance(idxs, int): 79 | idxs = np.array([idxs]) 80 | # go up one level in the tree and remove duplicate indexes 81 | idxs = unique(idxs // 2) 82 | while len(idxs) > 1 or idxs[0] > 0: 83 | # as long as there are non-zero indexes, update the corresponding values 84 | self._value[idxs] = self._operation( 85 | self._value[2 * idxs], 86 | self._value[2 * idxs + 1] 87 | ) 88 | # go up one level in the tree and remove duplicate indexes 89 | idxs = unique(idxs // 2) 90 | 91 | def __getitem__(self, idx): 92 | assert np.max(idx) < self._capacity 93 | assert 0 <= np.min(idx) 94 | return self._value[self._capacity + idx] 95 | 96 | 97 | class SumSegmentTree(SegmentTree): 98 | def __init__(self, capacity): 99 | super(SumSegmentTree, self).__init__( 100 | capacity=capacity, 101 | operation=np.add, 102 | neutral_element=0.0 103 | ) 104 | self._value = np.array(self._value) 105 | 106 | def sum(self, start=0, end=None): 107 | """ 108 | Returns arr[start] + ... + arr[end] 109 | :param start: (int) start position of the reduction (must be >= 0) 110 | :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1) 111 | :return: (Any) reduction of SumSegmentTree 112 | """ 113 | return super(SumSegmentTree, self).reduce(start, end) 114 | 115 | def find_prefixsum_idx(self, prefixsum): 116 | """ 117 | Find the highest index `i` in the array such that 118 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum for each entry in prefixsum 119 | if array values are probabilities, this function 120 | allows to sample indexes according to the discrete 121 | probability efficiently. 
122 | :param prefixsum: (np.ndarray) float upper bounds on the sum of array prefix 123 | :return: (np.ndarray) highest indexes satisfying the prefixsum constraint 124 | """ 125 | if isinstance(prefixsum, float): 126 | prefixsum = np.array([prefixsum]) 127 | assert 0 <= np.min(prefixsum) 128 | assert np.max(prefixsum) <= self.sum() + 1e-5 129 | assert isinstance(prefixsum[0], float) 130 | 131 | idx = np.ones(len(prefixsum), dtype=int) 132 | cont = np.ones(len(prefixsum), dtype=bool) 133 | 134 | while np.any(cont): # while not all nodes are leafs 135 | idx[cont] = 2 * idx[cont] 136 | prefixsum_new = np.where( 137 | self._value[idx] <= prefixsum, prefixsum - self._value[idx], prefixsum) 138 | # prepare update of prefixsum for all right children 139 | idx = np.where(np.logical_or( 140 | self._value[idx] > prefixsum, np.logical_not(cont)), idx, idx + 1) 141 | # Select child node for non-leaf nodes 142 | prefixsum = prefixsum_new 143 | # update prefixsum 144 | cont = idx < self._capacity 145 | # collect leafs 146 | return idx - self._capacity 147 | 148 | 149 | class MinSegmentTree(SegmentTree): 150 | def __init__(self, capacity): 151 | super(MinSegmentTree, self).__init__( 152 | capacity=capacity, 153 | operation=np.minimum, 154 | neutral_element=float('inf') 155 | ) 156 | self._value = np.array(self._value) 157 | 158 | def min(self, start=0, end=None): 159 | """ 160 | Returns min(arr[start], ..., arr[end]) 161 | :param start: (int) start position of the reduction (must be >= 0) 162 | :param end: (int) end position of the reduction (must be < len(arr), can be None for len(arr) - 1) 163 | :return: (Any) reduction of MinSegmentTree 164 | """ 165 | return super(MinSegmentTree, self).reduce(start, end) 166 | -------------------------------------------------------------------------------- /tmarl/utils/util.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import numpy as np 4 | import math 5 | import gym 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.distributed as dist 10 | from torch.autograd import Variable 11 | from gym.spaces import Box, Discrete, Tuple 12 | 13 | 14 | def check(input): 15 | if type(input) == np.ndarray: 16 | return torch.from_numpy(input) 17 | 18 | 19 | def get_gard_norm(it): 20 | sum_grad = 0 21 | for x in it: 22 | if x.grad is None: 23 | continue 24 | sum_grad += x.grad.norm() ** 2 25 | return math.sqrt(sum_grad) 26 | 27 | 28 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 29 | """Decreases the learning rate linearly""" 30 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 31 | for param_group in optimizer.param_groups: 32 | param_group['lr'] = lr 33 | 34 | 35 | def huber_loss(e, d): 36 | a = (abs(e) <= d).float() 37 | b = (e > d).float() 38 | return a*e**2/2 + b*d*(abs(e)-d/2) 39 | 40 | 41 | def mse_loss(e): 42 | return e**2/2 43 | 44 | 45 | def get_shape_from_obs_space(obs_space): 46 | if obs_space.__class__.__name__ == 'Box': 47 | obs_shape = obs_space.shape 48 | elif obs_space.__class__.__name__ == 'list': 49 | obs_shape = obs_space 50 | elif obs_space.__class__.__name__ == 'Dict': 51 | obs_shape = obs_space.spaces 52 | else: 53 | raise NotImplementedError 54 | return obs_shape 55 | 56 | 57 | def get_shape_from_act_space(act_space): 58 | if act_space.__class__.__name__ == 'Discrete': 59 | act_shape = 1 60 | elif act_space.__class__.__name__ == "MultiDiscrete": 61 | act_shape = act_space.shape 62 | elif 
act_space.__class__.__name__ == "Box": 63 | act_shape = act_space.shape[0] 64 | elif act_space.__class__.__name__ == "MultiBinary": 65 | act_shape = act_space.shape[0] 66 | else: # agar 67 | act_shape = act_space[0].shape[0] + 1 68 | return act_shape 69 | 70 | 71 | def tile_images(img_nhwc): 72 | """ 73 | Tile N images into one big PxQ image 74 | (P,Q) are chosen to be as close as possible, and if N 75 | is square, then P=Q. 76 | input: img_nhwc, list or array of images, ndim=4 once turned into array 77 | n = batch index, h = height, w = width, c = channel 78 | returns: 79 | bigim_HWc, ndarray with ndim=3 80 | """ 81 | img_nhwc = np.asarray(img_nhwc) 82 | N, h, w, c = img_nhwc.shape 83 | H = int(np.ceil(np.sqrt(N))) 84 | W = int(np.ceil(float(N)/H)) 85 | img_nhwc = np.array( 86 | list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 87 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 88 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 89 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 90 | return img_Hh_Ww_c 91 | 92 | 93 | def to_torch(input): 94 | return torch.from_numpy(input) if type(input) == np.ndarray else input 95 | 96 | 97 | def to_numpy(x): 98 | return x.detach().cpu().numpy() 99 | 100 | 101 | class FixedCategorical(torch.distributions.Categorical): 102 | def sample(self): 103 | return super().sample() 104 | 105 | def log_probs(self, actions): 106 | return ( 107 | super() 108 | .log_prob(actions.squeeze(-1)) 109 | .view(actions.size(0), -1) 110 | .sum(-1) 111 | .unsqueeze(-1) 112 | ) 113 | 114 | def mode(self): 115 | return self.probs.argmax(dim=-1, keepdim=True) 116 | 117 | 118 | class MultiDiscrete(gym.Space): 119 | """ 120 | - The multi-discrete action space consists of a series of discrete action spaces with different parameters 121 | - It can be adapted to both a Discrete action space or a continuous (Box) action space 122 | - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space 123 | - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space 124 | where the discrete action space can take any integers from `min` to `max` (both inclusive) 125 | Note: A value of 0 always need to represent the NOOP action. 126 | e.g. 
Nintendo Game Controller 127 | - Can be conceptualized as 3 discrete action spaces: 128 | 1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4 129 | 2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 130 | 3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 131 | - Can be initialized as 132 | MultiDiscrete([ [0,4], [0,1], [0,1] ]) 133 | """ 134 | 135 | def __init__(self, array_of_param_array): 136 | self.low = np.array([x[0] for x in array_of_param_array]) 137 | self.high = np.array([x[1] for x in array_of_param_array]) 138 | self.num_discrete_space = self.low.shape[0] 139 | self.n = np.sum(self.high) + 2 140 | 141 | def sample(self): 142 | """ Returns a array with one sample from each discrete action space """ 143 | # For each row: round(random .* (max - min) + min, 0) 144 | random_array = np.random.rand(self.num_discrete_space) 145 | return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)] 146 | 147 | def contains(self, x): 148 | return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all() 149 | 150 | @property 151 | def shape(self): 152 | return self.num_discrete_space 153 | 154 | def __repr__(self): 155 | return "MultiDiscrete" + str(self.num_discrete_space) 156 | 157 | def __eq__(self, other): 158 | return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high) 159 | 160 | 161 | class DecayThenFlatSchedule(): 162 | def __init__(self, 163 | start, 164 | finish, 165 | time_length, 166 | decay="exp"): 167 | 168 | self.start = start 169 | self.finish = finish 170 | self.time_length = time_length 171 | self.delta = (self.start - self.finish) / self.time_length 172 | self.decay = decay 173 | 174 | if self.decay in ["exp"]: 175 | self.exp_scaling = (-1) * self.time_length / \ 176 | np.log(self.finish) if self.finish > 0 else 1 177 | 178 | def eval(self, T): 179 | if self.decay in ["linear"]: 180 | return max(self.finish, self.start - self.delta * T) 181 | elif self.decay in ["exp"]: 182 | return min(self.start, max(self.finish, np.exp(- T / self.exp_scaling))) 183 | pass 184 | 185 | 186 | def huber_loss(e, d): 187 | a = (abs(e) <= d).float() 188 | b = (e > d).float() 189 | return a*e**2/2 + b*d*(abs(e)-d/2) 190 | 191 | 192 | def mse_loss(e): 193 | return e**2 194 | 195 | 196 | def init(module, weight_init, bias_init, gain=1): 197 | weight_init(module.weight.data, gain=gain) 198 | bias_init(module.bias.data) 199 | return module 200 | 201 | 202 | def get_clones(module, N): 203 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 204 | 205 | # https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11 206 | 207 | 208 | def soft_update(target, source, tau): 209 | """ 210 | Perform DDPG soft update (move target params toward source based on weight 211 | factor tau) 212 | Inputs: 213 | target (torch.nn.Module): Net to copy parameters to 214 | source (torch.nn.Module): Net whose parameters to copy 215 | tau (float, 0 < x < 1): Weight factor for update 216 | """ 217 | for target_param, param in zip(target.parameters(), source.parameters()): 218 | target_param.data.copy_( 219 | target_param.data * (1.0 - tau) + param.data * tau) 220 | 221 | # https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15 222 | 223 | 224 | def hard_update(target, source): 225 | """ 226 | Copy network parameters from source to target 227 | Inputs: 228 | target (torch.nn.Module): Net to 
copy parameters to 229 | source (torch.nn.Module): Net whose parameters to copy 230 | """ 231 | for target_param, param in zip(target.parameters(), source.parameters()): 232 | target_param.data.copy_(param.data) 233 | 234 | # https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py 235 | 236 | 237 | def average_gradients(model): 238 | """ Gradient averaging. """ 239 | size = float(dist.get_world_size()) 240 | for param in model.parameters(): 241 | dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0) 242 | param.grad.data /= size 243 | 244 | 245 | def onehot_from_logits(logits, avail_logits=None, eps=0.0): 246 | """ 247 | Given batch of logits, return one-hot sample using epsilon greedy strategy 248 | (based on given epsilon) 249 | """ 250 | # get best (according to current policy) actions in one-hot form 251 | logits = to_torch(logits) 252 | 253 | dim = len(logits.shape) - 1 254 | if avail_logits is not None: 255 | avail_logits = to_torch(avail_logits) 256 | logits[avail_logits == 0] = -1e10 257 | argmax_acs = (logits == logits.max(dim, keepdim=True)[0]).float() 258 | if eps == 0.0: 259 | return argmax_acs 260 | # get random actions in one-hot form 261 | rand_acs = Variable(torch.eye(logits.shape[1])[[np.random.choice( 262 | range(logits.shape[1]), size=logits.shape[0])]], requires_grad=False) 263 | # chooses between best and random actions using epsilon greedy 264 | return torch.stack([argmax_acs[i] if r > eps else rand_acs[i] for i, r in 265 | enumerate(torch.rand(logits.shape[0]))]) 266 | 267 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 268 | 269 | 270 | def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor): 271 | """Sample from Gumbel(0, 1)""" 272 | U = Variable(tens_type(*shape).uniform_(), requires_grad=False) 273 | return -torch.log(-torch.log(U + eps) + eps) 274 | 275 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 276 | 277 | 278 | def gumbel_softmax_sample(logits, avail_logits, temperature, device=torch.device('cpu')): 279 | """ Draw a sample from the Gumbel-Softmax distribution""" 280 | if str(device) == 'cpu': 281 | y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)) 282 | else: 283 | y = (logits.cpu() + sample_gumbel(logits.shape, 284 | tens_type=type(logits.data))).cuda() 285 | 286 | dim = len(logits.shape) - 1 287 | if avail_logits is not None: 288 | avail_logits = to_torch(avail_logits).to(device) 289 | y[avail_logits == 0] = -1e10 290 | return F.softmax(y / temperature, dim=dim) 291 | 292 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 293 | 294 | 295 | def gumbel_softmax(logits, avail_logits=None, temperature=1.0, hard=False, device=torch.device('cpu')): 296 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 297 | Args: 298 | logits: [batch_size, n_class] unnormalized log-probs 299 | temperature: non-negative scalar 300 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 301 | Returns: 302 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 
303 | If hard=True, then the returned sample will be one-hot, otherwise it will 304 | be a probabilitiy distribution that sums to 1 across classes 305 | """ 306 | y = gumbel_softmax_sample(logits, avail_logits, temperature, device) 307 | if hard: 308 | y_hard = onehot_from_logits(y) 309 | y = (y_hard - y).detach() + y 310 | return y 311 | 312 | 313 | def gaussian_noise(shape, std): 314 | return torch.empty(shape).normal_(mean=0, std=std) 315 | 316 | 317 | def get_obs_shape(obs_space): 318 | if obs_space.__class__.__name__ == "Box": 319 | obs_shape = obs_space.shape 320 | elif obs_space.__class__.__name__ == "list": 321 | obs_shape = obs_space 322 | else: 323 | raise NotImplementedError 324 | 325 | return obs_shape 326 | 327 | 328 | def get_dim_from_space(space): 329 | if isinstance(space, Box): 330 | dim = space.shape[0] 331 | elif isinstance(space, Discrete): 332 | dim = space.n 333 | elif isinstance(space, Tuple): 334 | dim = sum([get_dim_from_space(sp) for sp in space]) 335 | elif "MultiDiscrete" in space.__class__.__name__: 336 | return (space.high - space.low) + 1 337 | elif isinstance(space, list): 338 | dim = space[0] 339 | else: 340 | raise Exception("Unrecognized space: ", type(space)) 341 | return dim 342 | 343 | 344 | def get_state_dim(observation_dict, action_dict): 345 | combined_obs_dim = sum([get_dim_from_space(space) 346 | for space in observation_dict.values()]) 347 | combined_act_dim = 0 348 | for space in action_dict.values(): 349 | dim = get_dim_from_space(space) 350 | if isinstance(dim, np.ndarray): 351 | combined_act_dim += int(sum(dim)) 352 | else: 353 | combined_act_dim += dim 354 | return combined_obs_dim, combined_act_dim, combined_obs_dim+combined_act_dim 355 | 356 | 357 | def get_cent_act_dim(action_space): 358 | cent_act_dim = 0 359 | for space in action_space: 360 | dim = get_dim_from_space(space) 361 | if isinstance(dim, np.ndarray): 362 | cent_act_dim += int(sum(dim)) 363 | else: 364 | cent_act_dim += dim 365 | return cent_act_dim 366 | 367 | 368 | def is_discrete(space): 369 | if isinstance(space, Discrete) or "MultiDiscrete" in space.__class__.__name__: 370 | return True 371 | else: 372 | return False 373 | 374 | 375 | def is_multidiscrete(space): 376 | if "MultiDiscrete" in space.__class__.__name__: 377 | return True 378 | else: 379 | return False 380 | 381 | 382 | def make_onehot(int_action, action_dim, seq_len=None): 383 | if type(int_action) == torch.Tensor: 384 | int_action = int_action.cpu().numpy() 385 | if not seq_len: 386 | return np.eye(action_dim)[int_action] 387 | if seq_len: 388 | onehot_actions = [] 389 | for i in range(seq_len): 390 | onehot_action = np.eye(action_dim)[int_action[i]] 391 | onehot_actions.append(onehot_action) 392 | return np.stack(onehot_actions) 393 | 394 | 395 | def avail_choose(x, avail_x=None): 396 | x = to_torch(x) 397 | if avail_x is not None: 398 | avail_x = to_torch(avail_x) 399 | x[avail_x == 0] = -1e10 400 | return x # FixedCategorical(logits=x) 401 | 402 | 403 | def tile_images(img_nhwc): 404 | """ 405 | Tile N images into one big PxQ image 406 | (P,Q) are chosen to be as close as possible, and if N 407 | is square, then P=Q. 
408 | input: img_nhwc, list or array of images, ndim=4 once turned into array 409 | n = batch index, h = height, w = width, c = channel 410 | returns: 411 | bigim_HWc, ndarray with ndim=3 412 | """ 413 | img_nhwc = np.asarray(img_nhwc) 414 | N, h, w, c = img_nhwc.shape 415 | H = int(np.ceil(np.sqrt(N))) 416 | W = int(np.ceil(float(N)/H)) 417 | img_nhwc = np.array( 418 | list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 419 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 420 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 421 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 422 | return img_Hh_Ww_c 423 | -------------------------------------------------------------------------------- /tmarl/utils/valuenorm.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ValueNorm(nn.Module): 9 | """ Normalize a vector of observations - across the first norm_axes dimensions""" 10 | 11 | def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5, device=torch.device("cpu")): 12 | super(ValueNorm, self).__init__() 13 | 14 | self.input_shape = input_shape 15 | self.norm_axes = norm_axes 16 | self.epsilon = epsilon 17 | self.beta = beta 18 | self.per_element_update = per_element_update 19 | self.tpdv = dict(dtype=torch.float32, device=device) 20 | 21 | self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) 22 | self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) 23 | self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv) 24 | 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | self.running_mean.zero_() 29 | self.running_mean_sq.zero_() 30 | self.debiasing_term.zero_() 31 | 32 | def running_mean_var(self): 33 | debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon) 34 | debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon) 35 | debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2) 36 | return debiased_mean, debiased_var 37 | 38 | @torch.no_grad() 39 | def update(self, input_vector): 40 | if type(input_vector) == np.ndarray: 41 | input_vector = torch.from_numpy(input_vector) 42 | input_vector = input_vector.to(**self.tpdv) 43 | 44 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes))) 45 | batch_sq_mean = (input_vector ** 2).mean(dim=tuple(range(self.norm_axes))) 46 | 47 | if self.per_element_update: 48 | batch_size = np.prod(input_vector.size()[:self.norm_axes]) 49 | weight = self.beta ** batch_size 50 | else: 51 | weight = self.beta 52 | 53 | self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight)) 54 | self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight)) 55 | self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight)) 56 | 57 | def normalize(self, input_vector): 58 | # Make sure input is float32 59 | if type(input_vector) == np.ndarray: 60 | input_vector = torch.from_numpy(input_vector) 61 | input_vector = input_vector.to(**self.tpdv) 62 | 63 | mean, var = self.running_mean_var() 64 | out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes] 65 | 66 | return out 67 | 68 | def denormalize(self, input_vector): 69 | """ Transform normalized data back into original distribution """ 70 | if type(input_vector) == np.ndarray: 71 | input_vector = 
torch.from_numpy(input_vector)
72 |         input_vector = input_vector.to(**self.tpdv)
73 | 
74 |         mean, var = self.running_mean_var()
75 |         out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]
76 | 
77 |         out = out.cpu().numpy()
78 | 
79 |         return out
80 | 
--------------------------------------------------------------------------------
/tmarl/wrappers/TWrapper/README.md:
--------------------------------------------------------------------------------
1 | # TWrapper
2 | 
3 | ## Distributed multi-agent wrapper
4 | 
5 | 
--------------------------------------------------------------------------------
/tmarl/wrappers/TWrapper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/wrappers/TWrapper/__init__.py
--------------------------------------------------------------------------------
/tmarl/wrappers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TARTRL/TiKick/3521627be1b4c65215157f748fa0df6a35f61c23/tmarl/wrappers/__init__.py
--------------------------------------------------------------------------------
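As a closing illustration, here is a minimal, hypothetical usage sketch of the ValueNorm utility from tmarl/utils/valuenorm.py above. The batch shape and the use of NumPy input are illustrative assumptions, not taken from the repository; only the class name, constructor, and the update/normalize/denormalize methods come from the file shown earlier.

```python
# Hypothetical sketch of exercising ValueNorm on a batch of value targets.
import numpy as np
from tmarl.utils.valuenorm import ValueNorm

value_norm = ValueNorm(input_shape=1)                  # normalize scalar value targets
returns = np.random.randn(256, 1).astype(np.float32)   # e.g. a batch of discounted returns

value_norm.update(returns)                     # update the debiased running mean/variance
normalized = value_norm.normalize(returns)     # torch.Tensor, roughly zero mean and unit variance
restored = value_norm.denormalize(normalized)  # back to the original scale, returned as a NumPy array
```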