├── .gitattributes ├── .github ├── unittest │ └── install_dependencies.sh └── workflows │ ├── pre_commit.yml │ ├── tests-linux.yml │ ├── tests-mac.yml │ └── tests-windows.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── codecov.yml ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ └── js │ │ └── version_alert.js │ ├── _templates │ ├── autosummary │ │ ├── class.rst │ │ ├── class_no_inherit.rst │ │ ├── class_no_undoc.rst │ │ ├── class_private.rst │ │ └── class_private_no_undoc.rst │ └── breadcrumbs.html │ ├── conf.py │ ├── index.rst │ ├── modules │ ├── root.rst │ └── simulator.rst │ └── usage │ ├── citing.rst │ ├── installation.rst │ ├── notebooks.rst │ └── running.rst ├── mpe_comparison ├── README.md ├── mpe_performance_comparison.py └── vmas_vs_mpe_graphs │ ├── VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz.pdf │ ├── VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz.tex │ ├── VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti.pdf │ ├── VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti.tex │ └── pickled │ ├── VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz_range_1_100_num_10.pkl │ ├── VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz_range_1_30000_num_75.pkl │ ├── VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti_range_1_100_num_10.pkl │ └── VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti_range_1_30000_num_100.pkl ├── notebooks ├── Simulation_and_training_in_VMAS_and_BenchMARL.ipynb ├── VMAS_RLlib.ipynb └── VMAS_Use_vmas_environment.ipynb ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── test_lidar.py ├── test_scenarios │ ├── __init__.py │ ├── test_balance.py │ ├── test_discovery.py │ ├── test_dispersion.py │ ├── test_dropout.py │ ├── test_flocking.py │ ├── test_football.py │ ├── test_give_way.py │ ├── test_navigation.py │ ├── test_passage.py │ ├── test_reverse_transport.py │ ├── test_transport.py │ ├── test_waterfall.py │ └── test_wheel.py ├── test_vmas.py └── test_wrappers │ ├── __init__.py │ ├── test_gym_wrapper.py │ ├── test_gymnasium_vec_wrapper.py │ └── test_gymnasium_wrapper.py └── vmas ├── __init__.py ├── examples ├── __init__.py ├── rllib.py ├── run_heuristic.py └── use_vmas_env.py ├── interactive_rendering.py ├── make_env.py ├── scenarios ├── __init__.py ├── balance.py ├── ball_passage.py ├── ball_trajectory.py ├── buzz_wire.py ├── debug │ ├── __init__.py │ ├── asym_joint.py │ ├── circle_trajectory.py │ ├── diff_drive.py │ ├── drone.py │ ├── goal.py │ ├── het_mass.py │ ├── kinematic_bicycle.py │ ├── line_trajectory.py │ ├── pollock.py │ ├── vel_control.py │ └── waterfall.py ├── discovery.py ├── dispersion.py ├── dropout.py ├── flocking.py ├── football.py ├── give_way.py ├── joint_passage.py ├── joint_passage_size.py ├── mpe │ ├── __init__.py │ ├── simple.py │ ├── simple_adversary.py │ ├── simple_crypto.py │ ├── simple_push.py │ ├── simple_reference.py │ ├── simple_speaker_listener.py │ ├── simple_spread.py │ ├── simple_tag.py │ └── simple_world_comm.py ├── multi_give_way.py ├── navigation.py ├── passage.py ├── reverse_transport.py ├── road_traffic.py ├── sampling.py ├── transport.py ├── wheel.py └── wind_flocking.py ├── scenarios_data └── road_traffic │ └── road_traffic_cpm_lab.xml └── simulator ├── __init__.py ├── controllers ├── __init__.py └── velocity_controller.py ├── core.py ├── dynamics ├── __init__.py ├── common.py ├── diff_drive.py ├── drone.py ├── 
forward.py ├── holonomic.py ├── holonomic_with_rot.py ├── kinematic_bicycle.py ├── roatation.py └── static.py ├── environment ├── __init__.py ├── environment.py ├── gym │ ├── __init__.py │ ├── base.py │ ├── gym.py │ ├── gymnasium.py │ └── gymnasium_vec.py └── rllib.py ├── heuristic_policy.py ├── joints.py ├── physics.py ├── rendering.py ├── scenario.py ├── secrcode.ttf ├── sensors.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | notebooks/* linguist-documentation 2 | mpe_comparison/* linguist-documentation 3 | -------------------------------------------------------------------------------- /.github/unittest/install_dependencies.sh: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2024. 3 | # ProrokLab (https://www.proroklab.org/) 4 | # All rights reserved. 5 | # 6 | 7 | 8 | python -m pip install --upgrade pip 9 | 10 | pip install -e ".[gymnasium]" 11 | 12 | python -m pip install flake8 pytest pytest-cov tqdm matplotlib==3.8 13 | python -m pip install cvxpylayers # Navigation heuristic 14 | -------------------------------------------------------------------------------- /.github/workflows/pre_commit.yml: -------------------------------------------------------------------------------- 1 | 2 | name: pre-commit 3 | 4 | 5 | on: 6 | push: 7 | branches: [ $default-branch , "main" , "dev" ] 8 | pull_request: 9 | branches: [ $default-branch , "main" ] 10 | 11 | permissions: 12 | contents: read 13 | 14 | 15 | jobs: 16 | pre-commit: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | matrix: 20 | python-version: [ "3.11" ] 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - uses: pre-commit/action@v3.0.1 28 | -------------------------------------------------------------------------------- /.github/workflows/tests-linux.yml: -------------------------------------------------------------------------------- 1 | 2 | name: pytest-linux 3 | 4 | on: 5 | push: 6 | branches: [ $default-branch , "main" , "dev" ] 7 | pull_request: 8 | branches: [ $default-branch , "main" ] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | build: 15 | 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: [ ubuntu-latest ] 21 | python-version: ["3.11"] 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | sudo apt-get update 31 | sudo apt-get install python3-opengl xvfb 32 | bash .github/unittest/install_dependencies.sh 33 | - name: Test with pytest 34 | run: | 35 | xvfb-run -s "-screen 0 1024x768x24" pytest tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=. 
--cov-report=xml --cov-report=html 36 | - name: Upload coverage to Codecov 37 | uses: codecov/codecov-action@v3 38 | env: 39 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 40 | with: 41 | fail_ci_if_error: false 42 | -------------------------------------------------------------------------------- /.github/workflows/tests-mac.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | name: pytest-mac 4 | 5 | on: 6 | push: 7 | branches: [ $default-branch , "main" , "dev" ] 8 | pull_request: 9 | branches: [ $default-branch , "main" ] 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | os: [ macos-latest ] 22 | python-version: ["3.11"] 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v3 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies 31 | run: | 32 | bash .github/unittest/install_dependencies.sh 33 | - name: Test with pytest 34 | run: | 35 | pytest tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html 36 | -------------------------------------------------------------------------------- /.github/workflows/tests-windows.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | name: pytest-windows 4 | 5 | on: 6 | push: 7 | branches: [ $default-branch , "main" , "dev" ] 8 | pull_request: 9 | branches: [ $default-branch , "main" ] 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | os: [ windows-latest ] 22 | python-version: ["3.11"] 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python ${{ matrix.python-version }} 27 | uses: actions/setup-python@v3 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | - name: Install dependencies 31 | run: | 32 | bash .github/unittest/install_dependencies.sh 33 | - name: Test with pytest 34 | run: | 35 | pytest tests/ --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Docs 2 | docs/output/ 3 | docs/source/generated/ 4 | docs/build/ 5 | 6 | ### JetBrains template 7 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 8 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 9 | 10 | # User-specific stuff 11 | .idea/**/workspace.xml 12 | .idea/**/tasks.xml 13 | .idea/**/usage.statistics.xml 14 | .idea/**/dictionaries 15 | .idea/**/shelf 16 | 17 | # Generated files 18 | .idea/**/contentModel.xml 19 | 20 | # Sensitive or high-churn files 21 | .idea/**/dataSources/ 22 | .idea/**/dataSources.ids 23 | .idea/**/dataSources.local.xml 24 | .idea/**/sqlDataSources.xml 25 | .idea/**/dynamic.xml 26 | .idea/**/uiDesigner.xml 27 | .idea/**/dbnavigator.xml 28 | 29 | # Gradle 30 | .idea/**/gradle.xml 31 | .idea/**/libraries 32 | 33 | # Gradle and Maven with auto-import 34 | # When using Gradle or Maven with auto-import, you should exclude module files, 35 | # since they will be recreated, and may cause churn. Uncomment if using 36 | # auto-import. 
37 | # .idea/artifacts 38 | # .idea/compiler.xml 39 | # .idea/jarRepositories.xml 40 | # .idea/modules.xml 41 | # .idea/*.iml 42 | # .idea/modules 43 | # *.iml 44 | # *.ipr 45 | 46 | # CMake 47 | cmake-build-*/ 48 | 49 | # Mongo Explorer plugin 50 | .idea/**/mongoSettings.xml 51 | 52 | # File-based project format 53 | *.iws 54 | 55 | .idea 56 | 57 | wandb 58 | 59 | junit 60 | 61 | # IntelliJ 62 | out/ 63 | 64 | # mpeltonen/sbt-idea plugin 65 | .idea_modules/ 66 | 67 | # JIRA plugin 68 | atlassian-ide-plugin.xml 69 | 70 | # Cursive Clojure plugin 71 | .idea/replstate.xml 72 | 73 | # Crashlytics plugin (for Android Studio and IntelliJ) 74 | com_crashlytics_export_strings.xml 75 | crashlytics.properties 76 | crashlytics-build.properties 77 | fabric.properties 78 | 79 | # Editor-based Rest Client 80 | .idea/httpRequests 81 | 82 | # Android studio 3.1+ serialized cache file 83 | .idea/caches/build_file_checksums.ser 84 | 85 | ### macOS template 86 | # General 87 | .DS_Store 88 | .AppleDouble 89 | .LSOverride 90 | 91 | # Icon must end with two \r 92 | Icon 93 | 94 | # Thumbnails 95 | ._* 96 | 97 | # Files that might appear in the root of a volume 98 | .DocumentRevisions-V100 99 | .fseventsd 100 | .Spotlight-V100 101 | .TemporaryItems 102 | .Trashes 103 | .VolumeIcon.icns 104 | .com.apple.timemachine.donotpresent 105 | 106 | # Directories potentially created on remote AFP share 107 | .AppleDB 108 | .AppleDesktop 109 | Network Trash Folder 110 | Temporary Items 111 | .apdisk 112 | 113 | ### Python template 114 | # Byte-compiled / optimized / DLL files 115 | __pycache__/ 116 | *.py[cod] 117 | *$py.class 118 | 119 | # C extensions 120 | *.so 121 | 122 | # Distribution / packaging 123 | .Python 124 | build/ 125 | develop-eggs/ 126 | dist/ 127 | downloads/ 128 | eggs/ 129 | .eggs/ 130 | lib/ 131 | lib64/ 132 | parts/ 133 | sdist/ 134 | var/ 135 | wheels/ 136 | share/python-wheels/ 137 | *.egg-info/ 138 | .installed.cfg 139 | *.egg 140 | MANIFEST 141 | 142 | # PyInstaller 143 | # Usually these files are written by a python script from a template 144 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 145 | *.manifest 146 | *.spec 147 | 148 | # Installer logs 149 | pip-log.txt 150 | pip-delete-this-directory.txt 151 | 152 | # Unit test / coverage reports 153 | htmlcov/ 154 | .tox/ 155 | .nox/ 156 | .coverage 157 | .coverage.* 158 | .cache 159 | nosetests.xml 160 | coverage.xml 161 | *.cover 162 | *.py,cover 163 | .hypothesis/ 164 | .pytest_cache/ 165 | cover/ 166 | 167 | # Translations 168 | *.mo 169 | *.pot 170 | 171 | # Django stuff: 172 | *.log 173 | local_settings.py 174 | db.sqlite3 175 | db.sqlite3-journal 176 | 177 | # Flask stuff: 178 | instance/ 179 | .webassets-cache 180 | 181 | # Scrapy stuff: 182 | .scrapy 183 | 184 | # Sphinx documentation 185 | docs/_build/ 186 | 187 | # PyBuilder 188 | .pybuilder/ 189 | target/ 190 | 191 | # Jupyter Notebook 192 | .ipynb_checkpoints 193 | 194 | # IPython 195 | profile_default/ 196 | ipython_config.py 197 | 198 | # pyenv 199 | # For a library or package, you might want to ignore these files since the code is 200 | # intended to run in multiple environments; otherwise, check them in: 201 | # .python-version 202 | 203 | # pipenv 204 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
205 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 206 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 207 | # install all needed dependencies. 208 | #Pipfile.lock 209 | 210 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 211 | __pypackages__/ 212 | 213 | # Celery stuff 214 | celerybeat-schedule 215 | celerybeat.pid 216 | 217 | # SageMath parsed files 218 | *.sage.py 219 | 220 | # Environments 221 | .env 222 | .venv 223 | env/ 224 | venv/ 225 | ENV/ 226 | env.bak/ 227 | venv.bak/ 228 | 229 | # Spyder project settings 230 | .spyderproject 231 | .spyproject 232 | 233 | # Rope project settings 234 | .ropeproject 235 | 236 | # mkdocs documentation 237 | /site 238 | 239 | # mypy 240 | .mypy_cache/ 241 | .dmypy.json 242 | dmypy.json 243 | 244 | # Pyre type checker 245 | .pyre/ 246 | 247 | # pytype static type analyzer 248 | .pytype/ 249 | 250 | # Cython debug symbols 251 | cython_debug/ 252 | 253 | !/.idea/multiagentparticlesimulator.iml 254 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mpe_comparison/mpe"] 2 | path = mpe_comparison/mpe 3 | url = https://github.com/matteobettini/multiagent-particle-envs.git 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: check-toml 6 | - id: check-yaml 7 | exclude: packaging/.* 8 | - id: mixed-line-ending 9 | args: [--fix=lf] 10 | - id: end-of-file-fixer 11 | 12 | - repo: https://github.com/omnilib/ufmt 13 | rev: v2.0.0b2 14 | hooks: 15 | - id: ufmt 16 | additional_dependencies: 17 | - black == 22.3.0 18 | - usort == 1.0.3 19 | - libcst == 1.4.0 20 | 21 | - repo: https://github.com/pycqa/flake8 22 | rev: 4.0.1 23 | hooks: 24 | - id: flake8 25 | args: [--config=setup.cfg] 26 | additional_dependencies: 27 | - flake8-bugbear==22.10.27 28 | - flake8-comprehensions==3.10.1 29 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.10" 13 | 14 | # Build documentation in the "docs/" directory with Sphinx 15 | sphinx: 16 | fail_on_warning: true 17 | configuration: docs/source/conf.py 18 | 19 | # Optionally build your docs in additional formats such as PDF and ePub 20 | formats: 21 | - epub 22 | 23 | # Optional but recommended, declare the Python requirements required 24 | # to build your documentation 25 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 26 | python: 27 | install: 28 | - requirements: docs/requirements.txt 29 | # Install our python package before building the docs 30 | - method: pip 31 | path: . 
32 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: If you use this software, please cite it as below. 3 | title: VMAS 4 | authors: 5 | - family-names: Bettini 6 | given-names: Matteo 7 | preferred-citation: 8 | type: conference-paper 9 | title: "VMAS: A Vectorized Multi-Agent Simulator for Collective Robot Learning" 10 | authors: 11 | - family-names: Bettini 12 | given-names: Matteo 13 | - family-names: Kortvelesy 14 | given-names: Ryan 15 | - family-names: Blumenkamp 16 | given-names: Jan 17 | - family-names: Prorok 18 | given-names: Amanda 19 | collection-title: Proceedings of the 16th International Symposium on Distributed Autonomous Robotic Systems 20 | publisher: Springer 21 | year: 2023 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to VMAS 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Installing the library 6 | 7 | To contribute, it is suggested to install the library (or your fork of it) from source: 8 | 9 | ```bash 10 | git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git 11 | cd VectorizedMultiAgentSimulator 12 | python setup.py develop 13 | ``` 14 | 15 | ## Formatting your code 16 | 17 | Before your PR is ready, you'll probably want your code to be checked. This can be done easily by installing 18 | ``` 19 | pip install pre-commit 20 | ``` 21 | and running 22 | ``` 23 | pre-commit run --all-files 24 | ``` 25 | from within the vmas cloned directory. 26 | 27 | You can also install [pre-commit hooks](https://pre-commit.com/) (using `pre-commit install` 28 | ). You can disable the check by appending `-n` to your commit command: `git commit -m -n` 29 | 30 | ## Pull Requests 31 | We actively welcome your pull requests. 32 | 33 | 1. Fork the repo and create your branch from `main`. 34 | 2. If you've added code that should be tested, add tests. 35 | 3. If you've changed APIs, update the documentation. 36 | 4. Ensure the test suite and the documentation pass. 37 | 5. Make sure your code lints. 38 | 39 | When submitting a PR, we encourage you to link it to the related issue (if any) and add some tags to it. 40 | 41 | 42 | ## License 43 | By contributing to vmas, you agree that your contributions will be licensed 44 | under the license of the project 45 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include vmas/simulator/secrcode.ttf 2 | recursive-include vmas/scenarios_data * 3 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | patch: off 4 | project: off 5 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/matteobettini/benchmarl_sphinx_theme.git 2 | numpy 3 | torch 4 | pyglet<=1.5.27 5 | gym 6 | six 7 | -------------------------------------------------------------------------------- /docs/source/_static/js/version_alert.js: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2024. 3 | * ProrokLab (https://www.proroklab.org/) 4 | * All rights reserved. 5 | */ 6 | 7 | 8 | function warnOnLatestVersion() { 9 | if (!window.READTHEDOCS_DATA || window.READTHEDOCS_DATA.version !== "latest") { 10 | return; // not on ReadTheDocs and not latest. 11 | } 12 | 13 | var note = document.createElement('div'); 14 | note.setAttribute('class', 'admonition note'); 15 | note.innerHTML = "

<p class='admonition-title'>Note</p>" + 16 | "<p>" + 17 | "This documentation is for an unreleased development version. " + 18 | "Click here to access the documentation of the current stable release." + 19 | "</p>
"; 20 | 21 | var parent = document.querySelector('#vmas'); 22 | if (parent) 23 | parent.insertBefore(note, parent.querySelector('h1')); 24 | } 25 | 26 | document.addEventListener('DOMContentLoaded', warnOnLatestVersion); 27 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | :undoc-members: 9 | :inherited-members: 10 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class_no_inherit.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | :undoc-members: 9 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class_no_undoc.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | :inherited-members: 9 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class_private.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | :undoc-members: 9 | :private-members: 10 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class_private_no_undoc.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline }} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :show-inheritance: 7 | :members: 8 | :private-members: 9 | -------------------------------------------------------------------------------- /docs/source/_templates/breadcrumbs.html: -------------------------------------------------------------------------------- 1 | {%- extends "sphinx_rtd_theme/breadcrumbs.html" %} 2 | 3 | {% block breadcrumbs_aside %} 4 | {% endblock %} 9 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | # Configuration file for the Sphinx documentation builder. 
6 | import os.path as osp 7 | import sys 8 | 9 | import benchmarl_sphinx_theme 10 | 11 | import vmas 12 | 13 | # -- Project information 14 | 15 | project = "VMAS" 16 | copyright = "ProrokLab" 17 | author = "Matteo Bettini" 18 | version = vmas.__version__ 19 | 20 | 21 | # -- General configuration 22 | sys.path.append(osp.join(osp.dirname(benchmarl_sphinx_theme.__file__), "extension")) 23 | 24 | extensions = [ 25 | "sphinx.ext.duration", 26 | "sphinx.ext.doctest", 27 | "sphinx.ext.autodoc", 28 | "sphinx.ext.autosummary", 29 | "sphinx.ext.napoleon", 30 | "sphinx.ext.intersphinx", 31 | "sphinx.ext.viewcode", 32 | "patch", 33 | ] 34 | 35 | add_module_names = False 36 | autodoc_member_order = "bysource" 37 | toc_object_entries = False 38 | 39 | intersphinx_mapping = { 40 | "python": ("https://docs.python.org/3/", None), 41 | "sphinx": ("https://www.sphinx-doc.org/en/master/", None), 42 | "torch": ("https://pytorch.org/docs/stable/", None), 43 | "torchrl": ("https://pytorch.org/rl/stable/", None), 44 | "tensordict": ("https://pytorch.org/tensordict/stable", None), 45 | "benchmarl": ("https://benchmarl.readthedocs.io/en/latest/", None), 46 | } 47 | intersphinx_disabled_domains = ["std"] 48 | 49 | templates_path = ["_templates"] 50 | html_static_path = [ 51 | osp.join(osp.dirname(benchmarl_sphinx_theme.__file__), "static"), 52 | "_static", 53 | ] 54 | 55 | 56 | html_theme = "sphinx_rtd_theme" 57 | # html_logo = ( 58 | # "https://raw.githubusercontent.com/matteobettini/benchmarl_sphinx_theme/master/benchmarl" 59 | # "_sphinx_theme/static/img/benchmarl_logo.png" 60 | # ) 61 | html_theme_options = {"logo_only": False, "navigation_depth": 2} 62 | # html_favicon = ('') 63 | html_css_files = [ 64 | "css/mytheme.css", 65 | ] 66 | 67 | # -- Options for EPUB output 68 | epub_show_urls = "footnote" 69 | 70 | 71 | def setup(app): 72 | def rst_jinja_render(app, _, source): 73 | rst_context = {"vmas": vmas} 74 | source[0] = app.builder.templates.render_string(source[0], rst_context) 75 | 76 | app.connect("source-read", rst_jinja_render) 77 | app.add_js_file("js/version_alert.js") 78 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | VMAS 3 | ==== 4 | 5 | .. discord_button:: 6 | https://discord.gg/dg8txxDW5t 7 | 8 | .. figure:: https://raw.githubusercontent.com/matteobettini/vmas-media/master/media/vmas_scenarios_more.gif 9 | :align: center 10 | 11 | 12 | :github:`null` `GitHub `__ 13 | 14 | Docs are currently being written, please bear with us. 15 | 16 | Forthcoming doc sections: 17 | 18 | 19 | - Concepts (explanation of all the lib features) (divide in basic, advanced and medium) 20 | - Input and output spaces 21 | - Resetting and vectorization (which properties are vectorized) 22 | - Implementing your scenario 23 | - Customizing entities 24 | - Controllers 25 | - Dynamics models 26 | - Communication actions 27 | - Extra actions 28 | - Customizing the world 29 | - Sensors 30 | - Joints 31 | - Differentiability 32 | - Rendering 33 | - Plot function under rendering 34 | - Interactive rendering 35 | 36 | - Scenarios description and renderings 37 | 38 | - Full package reference with docstrings for all public functions 39 | 40 | .. toctree:: 41 | :maxdepth: 1 42 | :caption: Using 43 | 44 | usage/notebooks 45 | usage/installation 46 | usage/running 47 | usage/citing 48 | 49 | .. toctree:: 50 | :maxdepth: 1 51 | :caption: Concepts 52 | 53 | 54 | .. 
toctree:: 55 | :maxdepth: 1 56 | :caption: Package Reference 57 | 58 | modules/root 59 | modules/simulator 60 | -------------------------------------------------------------------------------- /docs/source/modules/root.rst: -------------------------------------------------------------------------------- 1 | vmas 2 | ==== 3 | 4 | 5 | .. automodule:: vmas 6 | :members: 7 | -------------------------------------------------------------------------------- /docs/source/modules/simulator.rst: -------------------------------------------------------------------------------- 1 | vmas.simulator 2 | ============== 3 | 4 | .. currentmodule:: vmas.simulator 5 | 6 | .. contents:: Contents 7 | :local: 8 | 9 | Scenario 10 | -------- 11 | 12 | .. currentmodule:: vmas.simulator.scenario 13 | 14 | .. autosummary:: 15 | :nosignatures: 16 | :toctree: ../generated 17 | :template: autosummary/class_no_undoc.rst 18 | 19 | BaseScenario 20 | -------------------------------------------------------------------------------- /docs/source/usage/citing.rst: -------------------------------------------------------------------------------- 1 | Citing 2 | ====== 3 | 4 | If you use VMAS in your research please use the following BibTeX entry: 5 | 6 | 7 | .. code-block:: bibtex 8 | 9 | @inproceedings{bettini2022vmas, 10 | title = {VMAS: A Vectorized Multi-Agent Simulator for Collective Robot Learning}, 11 | author = {Bettini, Matteo and Kortvelesy, Ryan and Blumenkamp, Jan and Prorok, Amanda}, 12 | year = {2022}, 13 | booktitle = {Proceedings of the 16th International Symposium on Distributed Autonomous Robotic Systems}, 14 | publisher = {Springer}, 15 | series = {DARS '22} 16 | } 17 | -------------------------------------------------------------------------------- /docs/source/usage/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | 5 | Install from PyPi 6 | ----------------- 7 | 8 | You can install `VMAS `__ from PyPi. 9 | 10 | .. code-block:: console 11 | 12 | pip install vmas 13 | 14 | Install from source 15 | ------------------- 16 | 17 | If you want to install the current main version (more up to date than latest release), you can do: 18 | 19 | .. code-block:: console 20 | 21 | git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git 22 | cd VectorizedMultiAgentSimulator 23 | pip install -e . 24 | 25 | 26 | Install optional requirements 27 | ----------------------------- 28 | 29 | By default, vmas has only the core requirements. 30 | Here are some optional packages you may want to install. 31 | 32 | Wrappers 33 | ^^^^^^^^ 34 | 35 | If you want to use VMAS environment wrappers, you may want to install VMAS 36 | with the following options: 37 | 38 | .. code-block:: console 39 | 40 | # install gymnasium for gymnasium wrapper 41 | pip install vmas[gymnasium] 42 | 43 | # install rllib for rllib wrapper 44 | pip install vmas[rllib] 45 | 46 | 47 | Training 48 | ^^^^^^^^ 49 | 50 | You may want to install one of the following training libraries 51 | 52 | .. code-block:: console 53 | 54 | pip install benchmarl 55 | pip install torchrl 56 | pip install "ray[rllib]"==2.1.0 # We support versions "ray[rllib]<=2.2,>=1.13" 57 | 58 | Utils 59 | ^^^^^ 60 | 61 | You may want to install the following additional tools 62 | 63 | .. 
code-block:: console 64 | 65 | # install rendering dependencies 66 | pip install vmas[render] 67 | # install testing dependencies 68 | pip install vmas[test] 69 | -------------------------------------------------------------------------------- /docs/source/usage/notebooks.rst: -------------------------------------------------------------------------------- 1 | Notebooks 2 | ========= 3 | 4 | In the following you can find a list of :colab:`null` Google Colab notebooks to help you learn how to use VMAS: 5 | 6 | - :colab:`null` `Using a VMAS environment `_. Here is a simple notebook that you can run to create, step and render any scenario in VMAS. It reproduces the ``use_vmas_env.py`` script in the ``examples`` folder 7 | - :colab:`null` `Creating a VMAS scenario and training it in BenchMARL `_. We will create a scenario where multiple robots with different embodiments need to navigate to their goals while avoiding each other (as well as obstacles) and train it using MAPPO and MLP/GNN policies. 8 | - :colab:`null` `Training VMAS in BenchMARL (suggested) `_. In this notebook, we show how to use VMAS in BenchMARL, TorchRL's MARL training library 9 | - :colab:`null` `Training VMAS in TorchRL `_. In this notebook, `available in the TorchRL docs `__, we show how to use any VMAS scenario in TorchRL. It will guide you through the full pipeline needed to train agents using MAPPO/IPPO. 10 | - :colab:`null` `Training competitive VMAS MPE in TorchRL `_. In this notebook, `available in the TorchRL docs `__, we show how to solve a Competitive Multi-Agent Reinforcement Learning (MARL) problem using MADDPG/IDDPG. 11 | - :colab:`null` `Training VMAS in RLlib `_. In this notebook, we show how to use any VMAS scenario in RLlib. It reproduces the ``rllib.py`` script in the ``examples`` folder. 12 | -------------------------------------------------------------------------------- /docs/source/usage/running.rst: -------------------------------------------------------------------------------- 1 | Running 2 | ======= 3 | 4 | To use the simulator, simply create an environment by passing the name of the scenario 5 | you want (from the ``scenarios`` folder) to the :class:`vmas.make_env` function. 6 | The function arguments are explained in the documentation. The function returns an environment 7 | object which you can step and reset. 8 | 9 | .. code-block:: python 10 | 11 | import vmas 12 | 13 | # Create the environment 14 | env = vmas.make_env( 15 | scenario="waterfall", # can be scenario name or BaseScenario class 16 | num_envs=32, 17 | device="cpu", # Or "cuda" for GPU 18 | continuous_actions=True, 19 | max_steps=None, # Defines the horizon. None is infinite horizon. 20 | seed=None, # Seed of the environment 21 | n_agents=3 # Additional arguments you want to pass to the scenario 22 | ) 23 | # Reset it 24 | obs = env.reset() 25 | 26 | # Step it with deterministic actions (all agents take their maximum range action) 27 | for _ in range(10): 28 | obs, rews, dones, info = env.step(env.get_random_actions()) 29 | 30 | Here is a python example on how you can execute vmas environments. 31 | 32 | .. python_example_button:: 33 | https://github.com/proroklab/VectorizedMultiAgentSimulator/blob/main/vmas/examples/use_vmas_env.py 34 | 35 | The `Concepts` documentation contains a series of sections that 36 | can help you get familiar with further VMAS functionalities. 
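You can also render the environment while stepping it. Below is a minimal sketch, assuming the optional
rendering dependencies are installed (``pip install vmas[render]``) and that the environment exposes a
Gym-style ``render(mode=...)`` method (see ``vmas/simulator/rendering.py`` in the tree above); the exact
keyword arguments may differ between versions:

.. code-block:: python

    import vmas

    env = vmas.make_env(scenario="waterfall", num_envs=32, device="cpu")
    obs = env.reset()

    frames = []
    for _ in range(10):
        # Step all 32 vectorized environments with random actions
        obs, rews, dones, info = env.step(env.get_random_actions())
        # Render the first environment in the batch to an RGB array
        # (use mode="human" to open an interactive window instead)
        frames.append(env.render(mode="rgb_array"))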
37 | -------------------------------------------------------------------------------- /mpe_comparison/README.md: -------------------------------------------------------------------------------- 1 | # VMAS vs MPE 2 | 3 | ## Installing MPE 4 | ``` 5 | git submodule update --init --recursive 6 | ``` 7 | Then `cd` to `mpe_comparison/mpe` and 8 | ``` 9 | pip install -e . 10 | ``` 11 | ## Running the comparison 12 | ``` 13 | python mpe_performance_comparison.py 14 | ``` 15 | -------------------------------------------------------------------------------- /mpe_comparison/vmas_vs_mpe_graphs/VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/VectorizedMultiAgentSimulator/acd9b7aca3ca5718f58f03a993c5ef3e920c6af9/mpe_comparison/vmas_vs_mpe_graphs/VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz.pdf -------------------------------------------------------------------------------- /mpe_comparison/vmas_vs_mpe_graphs/VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz.tex: -------------------------------------------------------------------------------- 1 | % This file was created with tikzplotlib v0.10.1. 2 | \begin{tikzpicture} 3 | 4 | \definecolor{darkgray176}{RGB}{176,176,176} 5 | \definecolor{darkorange25512714}{RGB}{255,127,14} 6 | \definecolor{lightgray204}{RGB}{204,204,204} 7 | \definecolor{steelblue31119180}{RGB}{31,119,180} 8 | 9 | \begin{axis}[ 10 | legend cell align={left}, 11 | legend style={ 12 | fill opacity=0.8, 13 | draw opacity=1, 14 | text opacity=1, 15 | at={(0.03,0.97)}, 16 | anchor=north west, 17 | draw=lightgray204 18 | }, 19 | tick align=outside, 20 | tick pos=left, 21 | title={Execution time of 'simple\_spread' for 100 steps on Intel(R) Xeon(R) Gold 6248R CPU @ 3.00GHz}, 22 | x grid style={darkgray176}, 23 | xlabel={Number of parallel environments}, 24 | xmin=-1498.95, xmax=31499.95, 25 | xtick style={color=black}, 26 | y grid style={darkgray176}, 27 | ylabel={Seconds}, 28 | ymin=-51.8446296572685, ymax=1089.55411685705, 29 | ytick style={color=black} 30 | ] 31 | \addplot [semithick, steelblue31119180] 32 | table {% 33 | 1 0.0371315479278564 34 | 811.783813476562 27.769702911377 35 | 1217.17565917969 42.5418014526367 36 | 1622.56762695312 56.5375900268555 37 | 2433.35131835938 82.8487319946289 38 | 2838.7431640625 96.2251586914062 39 | 3244.13525390625 111.793746948242 40 | 4054.9189453125 139.141738891602 41 | 4460.31103515625 151.861343383789 42 | 4865.70263671875 165.839370727539 43 | 5271.0947265625 178.706115722656 44 | 5676.486328125 196.890808105469 45 | 6081.87841796875 208.800445556641 46 | 6487.2705078125 221.729568481445 47 | 6892.662109375 236.750442504883 48 | 7298.05419921875 250.950592041016 49 | 7703.44580078125 258.885955810547 50 | 8108.837890625 281.903930664062 51 | 8514.2294921875 298.704986572266 52 | 8919.6220703125 311.454803466797 53 | 9325.013671875 325.562866210938 54 | 9730.4052734375 339.120513916016 55 | 10135.796875 356.37646484375 56 | 10541.189453125 372.168365478516 57 | 10946.5810546875 406.632476806641 58 | 11351.97265625 419.812591552734 59 | 11757.365234375 441.653930664062 60 | 12162.7568359375 445.096435546875 61 | 12568.1484375 444.572143554688 62 | 13378.9326171875 457.838684082031 63 | 14189.7158203125 492.410217285156 64 | 14595.1083984375 515.853942871094 65 | 15000.5 521.164916992188 66 | 15405.8916015625 537.571655273438 67 | 15811.2841796875 556.857421875 68 | 
16216.67578125 561.490966796875 69 | 16622.068359375 571.675537109375 70 | 17027.458984375 583.693237304688 71 | 17432.8515625 600.141357421875 72 | 17838.244140625 612.76025390625 73 | 18649.02734375 652.895141601562 74 | 19054.41796875 675.385620117188 75 | 19459.810546875 695.513671875 76 | 19865.203125 697.362670898438 77 | 20270.59375 734.109313964844 78 | 20675.986328125 731.173156738281 79 | 21081.37890625 731.542419433594 80 | 21486.76953125 751.052856445312 81 | 21892.162109375 757.402770996094 82 | 22297.5546875 767.898010253906 83 | 22702.9453125 783.299621582031 84 | 23108.337890625 809 85 | 23513.73046875 816.264587402344 86 | 23919.12109375 826.696960449219 87 | 24324.513671875 846.113037109375 88 | 24729.90625 853.262573242188 89 | 25135.296875 866.059143066406 90 | 25540.689453125 880.532836914062 91 | 26351.47265625 913.076965332031 92 | 26756.865234375 925.855346679688 93 | 27162.255859375 935.866882324219 94 | 27567.6484375 950.4296875 95 | 27973.041015625 981.765686035156 96 | 28378.431640625 986.556518554688 97 | 28783.82421875 997.012451171875 98 | 29189.216796875 1026.40380859375 99 | 29594.607421875 1036.822265625 100 | 30000 1037.67236328125 101 | }; 102 | \addlegendentry{MPE} 103 | \addplot [semithick, darkorange25512714] 104 | table {% 105 | 1 0.177299976348877 106 | 2433.35131835938 2.66015696525574 107 | 2838.7431640625 4.14012432098389 108 | 3244.13525390625 5.1555871963501 109 | 3649.52709960938 6.77866554260254 110 | 4865.70263671875 10.3733282089233 111 | 10135.796875 33.9152946472168 112 | 10541.189453125 36.8092346191406 113 | 12162.7568359375 46.9539108276367 114 | 12568.1484375 46.0832252502441 115 | 12973.541015625 49.1964225769043 116 | 14189.7158203125 55.5887069702148 117 | 14595.1083984375 62.3101005554199 118 | 15000.5 62.238208770752 119 | 15405.8916015625 63.8216400146484 120 | 15811.2841796875 70.4993515014648 121 | 16216.67578125 70.7533569335938 122 | 16622.068359375 66.4213714599609 123 | 17838.244140625 74.271614074707 124 | 18243.634765625 78.513671875 125 | 19054.41796875 84.3815383911133 126 | 19459.810546875 90.2557373046875 127 | 19865.203125 94.3944473266602 128 | 20270.59375 96.620964050293 129 | 20675.986328125 96.2870635986328 130 | 21081.37890625 101.634246826172 131 | 21486.76953125 103.821815490723 132 | 21892.162109375 106.631042480469 133 | 22297.5546875 111.989837646484 134 | 22702.9453125 114.092712402344 135 | 23108.337890625 122.612533569336 136 | 23919.12109375 130.556610107422 137 | 24729.90625 134.852554321289 138 | 25135.296875 140.009552001953 139 | 25946.08203125 150.927124023438 140 | 26351.47265625 155.179916381836 141 | 26756.865234375 157.005630493164 142 | 27162.255859375 160.784942626953 143 | 27973.041015625 170.05143737793 144 | 28378.431640625 175.843399047852 145 | 29189.216796875 183.919906616211 146 | 29594.607421875 186.492630004883 147 | 30000 193.123809814453 148 | }; 149 | \addlegendentry{VMAS} 150 | \end{axis} 151 | 152 | \draw ({$(current bounding box.south west)!0.5!(current bounding box.south east)$}|-{$(current bounding box.south west)!0.98!(current bounding box.north west)$}) node[ 153 | scale=0.8, 154 | anchor=north, 155 | text=black, 156 | rotate=0.0 157 | ]{VMAS vs MPE}; 158 | \end{tikzpicture} 159 | -------------------------------------------------------------------------------- /mpe_comparison/vmas_vs_mpe_graphs/VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/proroklab/VectorizedMultiAgentSimulator/acd9b7aca3ca5718f58f03a993c5ef3e920c6af9/mpe_comparison/vmas_vs_mpe_graphs/VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti.pdf -------------------------------------------------------------------------------- /mpe_comparison/vmas_vs_mpe_graphs/VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti.tex: -------------------------------------------------------------------------------- 1 | % This file was created with tikzplotlib v0.10.1. 2 | \begin{tikzpicture} 3 | 4 | \definecolor{darkgray176}{RGB}{176,176,176} 5 | \definecolor{darkorange25512714}{RGB}{255,127,14} 6 | \definecolor{lightgray204}{RGB}{204,204,204} 7 | \definecolor{steelblue31119180}{RGB}{31,119,180} 8 | 9 | \begin{axis}[ 10 | legend cell align={left}, 11 | legend style={ 12 | fill opacity=0.8, 13 | draw opacity=1, 14 | text opacity=1, 15 | at={(0.03,0.97)}, 16 | anchor=north west, 17 | draw=lightgray204 18 | }, 19 | tick align=outside, 20 | tick pos=left, 21 | title={Execution time of 'simple\_spread' for 100 steps on NVIDIA GeForce RTX 2080 Ti}, 22 | x grid style={darkgray176}, 23 | xlabel={Number of parallel environments}, 24 | xmin=-1498.95, xmax=31499.95, 25 | xtick style={color=black}, 26 | y grid style={darkgray176}, 27 | ylabel={Seconds}, 28 | ymin=-49.1076600551605, ymax=1032.37547326088, 29 | ytick style={color=black} 30 | ] 31 | \addplot [semithick, steelblue31119180] 32 | table {% 33 | 1 0.0506641864776611 34 | 304.020202636719 12.4890584945679 35 | 607.040405273438 24.3274688720703 36 | 1819.12121582031 73.275390625 37 | 2122.14135742188 86.4136810302734 38 | 2425.16162109375 98.2450256347656 39 | 3031.20190429688 123.702217102051 40 | 3334.22216796875 134.722274780273 41 | 3637.24243164062 146.787094116211 42 | 3940.2626953125 160.27978515625 43 | 4243.28271484375 172.154876708984 44 | 4546.30322265625 185.467636108398 45 | 4849.3232421875 197.332870483398 46 | 5152.34326171875 208.0390625 47 | 5455.36376953125 219.575256347656 48 | 5758.3837890625 232.126647949219 49 | 6061.40380859375 245.520904541016 50 | 6364.42431640625 256.132202148438 51 | 6667.4443359375 294.987060546875 52 | 6970.46484375 407.182586669922 53 | 7273.48486328125 433.452087402344 54 | 7576.5048828125 411.388610839844 55 | 7879.525390625 299.174713134766 56 | 8182.54541015625 382.546813964844 57 | 8485.5654296875 480.603424072266 58 | 8788.5859375 466.927795410156 59 | 9091.6064453125 298.089416503906 60 | 9394.6259765625 309.800048828125 61 | 9697.646484375 318.166839599609 62 | 10000.6669921875 328.176116943359 63 | 10606.70703125 349.943786621094 64 | 10909.7275390625 357.344604492188 65 | 11212.7470703125 385.715270996094 66 | 11515.767578125 661.783813476562 67 | 11818.7880859375 571.394287109375 68 | 12121.8076171875 396.060241699219 69 | 12424.828125 408.32470703125 70 | 12727.8486328125 418.030609130859 71 | 13030.869140625 428.212921142578 72 | 13636.9091796875 446.617492675781 73 | 13939.9296875 459.548797607422 74 | 14242.94921875 466.942230224609 75 | 14848.990234375 487.268859863281 76 | 15152.009765625 496.42529296875 77 | 15455.0302734375 508.369140625 78 | 15758.05078125 520.786437988281 79 | 16364.0908203125 541.401733398438 80 | 16667.111328125 550.4892578125 81 | 16970.130859375 560.962219238281 82 | 17273.15234375 570.659973144531 83 | 17576.171875 579.718994140625 84 | 17879.19140625 587.605041503906 85 | 18182.212890625 597.820129394531 86 | 18788.251953125 615.6982421875 87 | 19091.2734375 628.707885742188 88 | 19394.29296875 636.641906738281 89 | 
19697.3125 645.548767089844 90 | 20000.333984375 659.99365234375 91 | 20303.353515625 667.142395019531 92 | 20606.373046875 678.9521484375 93 | 20909.39453125 687.535705566406 94 | 21212.4140625 700.959594726562 95 | 21515.43359375 708.58935546875 96 | 21818.455078125 715.333129882812 97 | 22121.474609375 723.936096191406 98 | 22424.494140625 734.006469726562 99 | 22727.515625 746.122497558594 100 | 23333.5546875 762.113464355469 101 | 23636.576171875 771.131774902344 102 | 23939.595703125 782.139831542969 103 | 24242.615234375 796.912414550781 104 | 24545.63671875 803.501953125 105 | 24848.65625 818.838317871094 106 | 25151.677734375 826.553833007812 107 | 25454.697265625 831.122192382812 108 | 25757.716796875 841.722229003906 109 | 26060.73828125 853.490783691406 110 | 26969.798828125 882.305541992188 111 | 27575.837890625 906.564453125 112 | 27878.859375 916.021423339844 113 | 28181.87890625 924.682800292969 114 | 29090.939453125 956.894470214844 115 | 29393.958984375 965.8486328125 116 | 29696.98046875 972.555786132812 117 | 30000 983.217163085938 118 | }; 119 | \addlegendentry{MPE} 120 | \addplot [semithick, darkorange25512714] 121 | table {% 122 | 1 0.490082979202271 123 | 7879.525390625 1.09798455238342 124 | 8485.5654296875 1.35930705070496 125 | 9394.6259765625 1.35736846923828 126 | 16667.111328125 3.38233280181885 127 | 18788.251953125 4.18041896820068 128 | 19394.29296875 4.42005729675293 129 | 24545.63671875 6.93449306488037 130 | 30000 10.1494998931885 131 | }; 132 | \addlegendentry{VMAS} 133 | \end{axis} 134 | 135 | \draw ({$(current bounding box.south west)!0.5!(current bounding box.south east)$}|-{$(current bounding box.south west)!0.98!(current bounding box.north west)$}) node[ 136 | scale=0.8, 137 | anchor=north, 138 | text=black, 139 | rotate=0.0 140 | ]{VMAS vs MPE}; 141 | \end{tikzpicture} 142 | -------------------------------------------------------------------------------- /mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz_range_1_100_num_10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/VectorizedMultiAgentSimulator/acd9b7aca3ca5718f58f03a993c5ef3e920c6af9/mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz_range_1_100_num_10.pkl -------------------------------------------------------------------------------- /mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz_range_1_30000_num_75.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/VectorizedMultiAgentSimulator/acd9b7aca3ca5718f58f03a993c5ef3e920c6af9/mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps__intel(r)_xeon(r)_gold_6248r_cpu_@_3.00ghz_range_1_30000_num_75.pkl -------------------------------------------------------------------------------- /mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti_range_1_100_num_10.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/VectorizedMultiAgentSimulator/acd9b7aca3ca5718f58f03a993c5ef3e920c6af9/mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti_range_1_100_num_10.pkl -------------------------------------------------------------------------------- 
/mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti_range_1_30000_num_100.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/VectorizedMultiAgentSimulator/acd9b7aca3ca5718f58f03a993c5ef3e920c6af9/mpe_comparison/vmas_vs_mpe_graphs/pickled/VMAS_vs_MPE_100_steps_nvidia_geforce_rtx_2080_ti_range_1_30000_num_100.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | pyglet<=1.5.27 4 | gym 5 | six 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file=README.md 3 | license_files=LICENSE 4 | 5 | [pep8] 6 | max-line-length = 120 7 | 8 | [flake8] 9 | # note: we ignore all 501s (line too long) anyway as they're taken care of by black 10 | max-line-length = 79 11 | ignore = E203, E402, W503, W504, E501 12 | per-file-ignores = 13 | __init__.py: F401, F403, F405 14 | test_*.py: F841, E731, E266 15 | exclude = venv 16 | extend-select = B901, C401, C408, C409 17 | 18 | [pydocstyle] 19 | ;select = D417 # Missing argument descriptions in the docstring 20 | ;inherit = false 21 | match = .*\.py 22 | ;match_dir = ^(?!(.circlecli|test)).* 23 | convention = google 24 | add-ignore = D100, D104, D105, D107, D102 25 | ignore-decorators = 26 | test_* 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import pathlib 5 | 6 | from setuptools import find_packages, setup 7 | 8 | 9 | def get_version(): 10 | """Gets the vmas version.""" 11 | path = CWD / "vmas" / "__init__.py" 12 | content = path.read_text() 13 | 14 | for line in content.splitlines(): 15 | if line.startswith("__version__"): 16 | return line.strip().split()[-1].strip().strip('"') 17 | raise RuntimeError("bad version data in __init__.py") 18 | 19 | 20 | CWD = pathlib.Path(__file__).absolute().parent 21 | 22 | 23 | setup( 24 | name="vmas", 25 | version=get_version(), 26 | description="Vectorized Multi-Agent Simulator", 27 | url="https://github.com/proroklab/VectorizedMultiAgentSimulator", 28 | license="GPLv3", 29 | author="Matteo Bettini", 30 | author_email="mb2389@cl.cam.ac.uk", 31 | packages=find_packages(), 32 | install_requires=["numpy", "torch", "pyglet<=1.5.27", "gym", "six"], 33 | extras_require={ 34 | "gymnasium": ["gymnasium", "shimmy"], 35 | "rllib": ["ray[rllib]<=2.2"], 36 | "render": ["opencv-python", "moviepy", "matplotlib", "opencv-python"], 37 | "test": ["pytest", "pytest-instafail", "pyyaml", "tqdm"], 38 | }, 39 | include_package_data=True, 40 | ) 41 | -------------------------------------------------------------------------------- /tests/test_lidar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
4 | 5 | import torch 6 | 7 | from vmas import make_env 8 | 9 | 10 | def test_vectorized_lidar(n_envs=12, n_steps=15): 11 | def get_obs(env): 12 | rollout_obs = [] 13 | for _ in range(n_steps): 14 | obs, _, _, _ = env.step(env.get_random_actions()) 15 | obs = torch.stack(obs, dim=-1) 16 | rollout_obs.append(obs) 17 | return torch.stack(rollout_obs, dim=-1) 18 | 19 | env_vec_lidar = make_env( 20 | scenario="pollock", num_envs=n_envs, seed=0, lidar=True, vectorized_lidar=True 21 | ) 22 | obs_vec_lidar = get_obs(env_vec_lidar) 23 | env_non_vec_lidar = make_env( 24 | scenario="pollock", num_envs=n_envs, seed=0, lidar=True, vectorized_lidar=False 25 | ) 26 | obs_non_vec_lidar = get_obs(env_non_vec_lidar) 27 | 28 | assert torch.allclose(obs_vec_lidar, obs_non_vec_lidar) 29 | -------------------------------------------------------------------------------- /tests/test_scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022. Matteo Bettini 2 | # All rights reserved. 3 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_balance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import pytest 6 | import torch 7 | 8 | from vmas import make_env 9 | from vmas.scenarios import balance 10 | 11 | 12 | class TestBalance: 13 | def setup_env( 14 | self, 15 | n_envs, 16 | **kwargs, 17 | ) -> None: 18 | self.n_agents = kwargs.get("n_agents", 4) 19 | 20 | self.continuous_actions = True 21 | self.env = make_env( 22 | scenario="balance", 23 | num_envs=n_envs, 24 | device="cpu", 25 | continuous_actions=self.continuous_actions, 26 | # Environment specific variables 27 | **kwargs, 28 | ) 29 | self.env.seed(0) 30 | 31 | @pytest.mark.parametrize("n_agents", [2, 5]) 32 | def test_heuristic(self, n_agents, n_steps=50, n_envs=4): 33 | self.setup_env( 34 | n_agents=n_agents, random_package_pos_on_line=False, n_envs=n_envs 35 | ) 36 | policy = balance.HeuristicPolicy(self.continuous_actions) 37 | 38 | obs = self.env.reset() 39 | 40 | prev_package_dist_to_goal = obs[0][:, 8:10] 41 | 42 | for _ in range(n_steps): 43 | actions = [] 44 | for i in range(n_agents): 45 | obs_agent = obs[i] 46 | package_dist_to_goal = obs_agent[:, 8:10] 47 | 48 | action_agent = policy.compute_action( 49 | obs_agent, self.env.agents[i].u_range 50 | ) 51 | 52 | actions.append(action_agent) 53 | 54 | obs, new_rews, dones, _ = self.env.step(actions) 55 | 56 | assert ( 57 | torch.linalg.vector_norm(package_dist_to_goal, dim=-1) 58 | <= torch.linalg.vector_norm(prev_package_dist_to_goal, dim=-1) 59 | ).all() 60 | prev_package_dist_to_goal = package_dist_to_goal 61 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_discovery.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
4 | 5 | import pytest 6 | 7 | from vmas import make_env 8 | from vmas.scenarios import discovery 9 | 10 | 11 | class TestDiscovery: 12 | def setup_env( 13 | self, 14 | n_envs, 15 | **kwargs, 16 | ) -> None: 17 | self.env = make_env( 18 | scenario="discovery", 19 | num_envs=n_envs, 20 | device="cpu", 21 | # Environment specific variables 22 | **kwargs, 23 | ) 24 | self.env.seed(0) 25 | 26 | @pytest.mark.parametrize("n_agents", [1, 4]) 27 | @pytest.mark.parametrize("agent_lidar", [True, False]) 28 | def test_heuristic(self, n_agents, agent_lidar, n_steps=50, n_envs=4): 29 | self.setup_env(n_agents=n_agents, n_envs=n_envs, use_agent_lidar=agent_lidar) 30 | policy = discovery.HeuristicPolicy(True) 31 | 32 | obs = self.env.reset() 33 | 34 | for _ in range(n_steps): 35 | actions = [] 36 | for i in range(n_agents): 37 | obs_agent = obs[i] 38 | 39 | action_agent = policy.compute_action( 40 | obs_agent, self.env.agents[i].u_range 41 | ) 42 | 43 | actions.append(action_agent) 44 | 45 | obs, new_rews, dones, _ = self.env.step(actions) 46 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_dispersion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import pytest 6 | import torch 7 | 8 | from vmas import make_env 9 | 10 | 11 | class TestDispersion: 12 | def setup_env( 13 | self, n_agents: int, share_reward: bool, penalise_by_time: bool, n_envs 14 | ) -> None: 15 | self.n_agents = n_agents 16 | self.share_reward = share_reward 17 | self.penalise_by_time = penalise_by_time 18 | 19 | self.continuous_actions = True 20 | self.env = make_env( 21 | scenario="dispersion", 22 | num_envs=n_envs, 23 | device="cpu", 24 | continuous_actions=self.continuous_actions, 25 | # Environment specific variables 26 | n_agents=self.n_agents, 27 | share_reward=self.share_reward, 28 | penalise_by_time=self.penalise_by_time, 29 | ) 30 | self.env.seed(0) 31 | 32 | @pytest.mark.parametrize("n_agents", [1, 5, 10]) 33 | def test_heuristic(self, n_agents, n_envs=4): 34 | self.setup_env( 35 | n_agents=n_agents, 36 | share_reward=False, 37 | penalise_by_time=False, 38 | n_envs=n_envs, 39 | ) 40 | all_done = torch.full((n_envs,), False) 41 | obs = self.env.reset() 42 | total_rew = torch.zeros(self.env.num_envs, n_agents) 43 | while not all_done.all(): 44 | actions = [] 45 | idx = 0 46 | for i in range(n_agents): 47 | obs_agent = obs[i] 48 | obs_idx = 4 + idx 49 | action_agent = torch.clamp( 50 | obs_agent[:, obs_idx : obs_idx + 2], 51 | min=-self.env.agents[i].u_range, 52 | max=self.env.agents[i].u_range, 53 | ) 54 | idx += 3 55 | actions.append(action_agent) 56 | 57 | obs, rews, dones, _ = self.env.step(actions) 58 | for i in range(n_agents): 59 | total_rew[:, i] += rews[i] 60 | if dones.any(): 61 | # Done envs should have exactly sum of rewards equal to num_agents 62 | assert torch.equal( 63 | total_rew[dones].sum(-1).to(torch.long), 64 | torch.full((dones.sum(),), n_agents), 65 | ) 66 | total_rew[dones] = 0 67 | all_done += dones 68 | for env_index, done in enumerate(dones): 69 | if done: 70 | self.env.reset_at(env_index) 71 | 72 | @pytest.mark.parametrize("n_agents", [1, 5, 10, 20]) 73 | def test_heuristic_share_reward(self, n_agents, n_envs=4): 74 | self.setup_env( 75 | n_agents=n_agents, 76 | share_reward=True, 77 | penalise_by_time=False, 78 | n_envs=n_envs, 79 | ) 80 | all_done = torch.full((n_envs,), False) 81 | obs = 
self.env.reset() 82 | total_rew = torch.zeros(self.env.num_envs, n_agents) 83 | while not all_done.all(): 84 | actions = [] 85 | idx = 0 86 | for i in range(n_agents): 87 | obs_agent = obs[i] 88 | obs_idx = 4 + idx 89 | action_agent = torch.clamp( 90 | obs_agent[:, obs_idx : obs_idx + 2], 91 | min=-self.env.agents[i].u_range, 92 | max=self.env.agents[i].u_range, 93 | ) 94 | idx += 3 95 | actions.append(action_agent) 96 | 97 | obs, rews, dones, _ = self.env.step(actions) 98 | for i in range(n_agents): 99 | total_rew[:, i] += rews[i] 100 | if dones.any(): 101 | # Done envs should have exactly sum of rewards equal to num_agents 102 | assert torch.equal( 103 | total_rew[dones], 104 | torch.full((dones.sum(), n_agents), n_agents).to(torch.float), 105 | ) 106 | total_rew[dones] = 0 107 | all_done += dones 108 | for env_index, done in enumerate(dones): 109 | if done: 110 | self.env.reset_at(env_index) 111 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import pytest 6 | import torch 7 | 8 | from vmas import make_env 9 | from vmas.scenarios.dropout import DEFAULT_ENERGY_COEFF 10 | 11 | 12 | class TestDropout: 13 | def setup_env( 14 | self, 15 | n_agents: int, 16 | num_envs: int, 17 | energy_coeff: float = DEFAULT_ENERGY_COEFF, 18 | ) -> None: 19 | self.n_agents = n_agents 20 | self.energy_coeff = energy_coeff 21 | 22 | self.continuous_actions = True 23 | self.n_envs = num_envs 24 | self.env = make_env( 25 | scenario="dropout", 26 | num_envs=num_envs, 27 | device="cpu", 28 | continuous_actions=self.continuous_actions, 29 | # Environment specific variables 30 | n_agents=self.n_agents, 31 | energy_coeff=self.energy_coeff, 32 | ) 33 | self.env.seed(0) 34 | 35 | # Test that one agent can always reach the goal no matter the conditions 36 | @pytest.mark.parametrize("n_agents", [1, 5]) 37 | def test_heuristic(self, n_agents, n_envs=4): 38 | self.setup_env(n_agents=n_agents, num_envs=n_envs) 39 | 40 | obs = self.env.reset() 41 | total_rew = torch.zeros(self.env.num_envs) 42 | 43 | current_min = float("inf") 44 | best_i = None 45 | for i in range(n_agents): 46 | obs_agent = obs[i] 47 | if torch.linalg.vector_norm(obs_agent[:, -3:-1], dim=1)[0] < current_min: 48 | current_min = torch.linalg.vector_norm(obs_agent[:, -3:-1], dim=1)[0] 49 | best_i = i 50 | 51 | done = False 52 | while not done: 53 | obs_agent = obs[best_i] 54 | action_agent = torch.clamp( 55 | obs_agent[:, -3:-1], 56 | min=-self.env.agents[best_i].u_range, 57 | max=self.env.agents[best_i].u_range, 58 | ) 59 | 60 | actions = [] 61 | other_agents_action = torch.zeros(self.env.num_envs, self.env.world.dim_p) 62 | for j in range(self.n_agents): 63 | if best_i != j: 64 | actions.append(other_agents_action) 65 | else: 66 | actions.append(action_agent) 67 | 68 | obs, new_rews, dones, _ = self.env.step(actions) 69 | for j in range(self.n_agents): 70 | assert torch.equal(new_rews[0], new_rews[j]) 71 | total_rew += new_rews[0] 72 | assert (total_rew[dones] > 0).all() 73 | done = dones.any() 74 | 75 | # Test that one agent can always reach the goal no matter the conditions 76 | @pytest.mark.parametrize("n_agents", [1, 5]) 77 | def test_one_random_agent_can_do_it(self, n_agents, n_steps=50, n_envs=4): 78 | self.setup_env(n_agents=n_agents, num_envs=n_envs) 79 | for i in range(self.n_agents): 80 
| obs = self.env.reset() 81 | total_rew = torch.zeros(self.env.num_envs) 82 | for _ in range(n_steps): 83 | obs_agent = obs[i] 84 | action_agent = torch.clamp( 85 | obs_agent[:, -3:-1], 86 | min=-self.env.agents[i].u_range, 87 | max=self.env.agents[i].u_range, 88 | ) 89 | 90 | actions = [] 91 | other_agents_action = torch.zeros( 92 | self.env.num_envs, self.env.world.dim_p 93 | ) 94 | for j in range(self.n_agents): 95 | if i != j: 96 | actions.append(other_agents_action) 97 | else: 98 | actions.append(action_agent) 99 | 100 | obs, new_rews, dones, _ = self.env.step(actions) 101 | for j in range(self.n_agents): 102 | assert torch.equal(new_rews[0], new_rews[j]) 103 | total_rew += new_rews[0] 104 | assert (total_rew[dones] > 0).all() 105 | for env_index, done in enumerate(dones): 106 | if done: 107 | self.env.reset_at(env_index) 108 | 109 | total_rew[dones] = 0 110 | 111 | @pytest.mark.parametrize("n_agents", [5, 10]) 112 | def test_all_agents_cannot_do_it(self, n_agents): 113 | # Test that all agents together cannot reach the goal no matter the conditions (to be sure we do 5+ agents) 114 | assert self.all_agents(DEFAULT_ENERGY_COEFF, n_agents) < 0 115 | # Test that all agents together can reach the goal with no energy penalty 116 | assert self.all_agents(0, n_agents) > 0 117 | 118 | def all_agents(self, energy_coeff: float, n_agents: int, n_steps=100, n_envs=4): 119 | rewards = [] 120 | self.setup_env(n_agents=n_agents, energy_coeff=energy_coeff, num_envs=n_envs) 121 | obs = self.env.reset() 122 | total_rew = torch.zeros(self.env.num_envs) 123 | for _ in range(n_steps): 124 | actions = [] 125 | for i in range(self.n_agents): 126 | obs_i = obs[i] 127 | action_i = torch.clamp( 128 | obs_i[:, -3:-1], 129 | min=-self.env.agents[i].u_range, 130 | max=self.env.agents[i].u_range, 131 | ) 132 | actions.append(action_i) 133 | 134 | obs, new_rews, dones, _ = self.env.step(actions) 135 | for j in range(self.n_agents): 136 | assert torch.equal(new_rews[0], new_rews[j]) 137 | total_rew += new_rews[0] 138 | for env_index, done in enumerate(dones): 139 | if done: 140 | self.env.reset_at(env_index) 141 | if dones.any(): 142 | rewards.append(total_rew[dones].clone()) 143 | total_rew[dones] = 0 144 | return sum([rew.mean().item() for rew in rewards]) / len(rewards) 145 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_flocking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
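# --- Illustrative sketch (example usage, not part of the original files) ---
# The dropout tests above steer agents with a simple proportional rule: the
# observation slice obs[:, -3:-1] is used in those tests as the goal-relative
# position, so clamping it to the control range gives an action that moves the
# agent towards the goal. Finished environments are reset individually with
# env.reset_at(), as in the tests. Sizes below are arbitrary example values.
import torch

from vmas import make_env

env = make_env(
    scenario="dropout", num_envs=4, device="cpu", continuous_actions=True, n_agents=3
)
env.seed(0)
obs = env.reset()
for _ in range(50):
    actions = []
    for i, agent in enumerate(env.agents):
        rel_goal = obs[i][:, -3:-1]
        actions.append(torch.clamp(rel_goal, min=-agent.u_range, max=agent.u_range))
    obs, rews, dones, info = env.step(actions)
    for env_index, done in enumerate(dones):
        if done:
            env.reset_at(env_index)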
4 | 5 | import pytest 6 | 7 | from vmas import make_env 8 | from vmas.scenarios import flocking 9 | 10 | 11 | class TestFlocking: 12 | def setup_env( 13 | self, 14 | n_envs, 15 | **kwargs, 16 | ) -> None: 17 | self.env = make_env( 18 | scenario="flocking", 19 | num_envs=n_envs, 20 | device="cpu", 21 | # Environment specific variables 22 | **kwargs, 23 | ) 24 | self.env.seed(0) 25 | 26 | @pytest.mark.parametrize("n_agents", [1, 5]) 27 | def test_heuristic(self, n_agents, n_steps=50, n_envs=4): 28 | self.setup_env(n_agents=n_agents, n_envs=n_envs) 29 | policy = flocking.HeuristicPolicy(True) 30 | 31 | obs = self.env.reset() 32 | 33 | for _ in range(n_steps): 34 | actions = [] 35 | for i in range(n_agents): 36 | obs_agent = obs[i] 37 | 38 | action_agent = policy.compute_action( 39 | obs_agent, self.env.agents[i].u_range 40 | ) 41 | 42 | actions.append(action_agent) 43 | 44 | obs, new_rews, dones, _ = self.env.step(actions) 45 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_football.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import sys 5 | 6 | import pytest 7 | import torch 8 | from tqdm import tqdm 9 | 10 | from vmas import make_env 11 | 12 | 13 | class TestFootball: 14 | def setup_env(self, n_envs, **kwargs) -> None: 15 | self.continuous_actions = True 16 | 17 | self.env = make_env( 18 | scenario="football", 19 | num_envs=n_envs, 20 | device="cpu", 21 | continuous_actions=True, 22 | # Environment specific variables 23 | **kwargs, 24 | ) 25 | self.env.seed(0) 26 | 27 | @pytest.mark.skipif( 28 | sys.platform.startswith("win32"), reason="Test does not work on windows" 29 | ) 30 | def test_ai_vs_random(self, n_envs=4, n_agents=3, scoring_reward=1): 31 | self.setup_env( 32 | n_red_agents=n_agents, 33 | n_blue_agents=n_agents, 34 | ai_red_agents=True, 35 | ai_blue_agents=False, 36 | dense_reward=False, 37 | n_envs=n_envs, 38 | scoring_reward=scoring_reward, 39 | ) 40 | all_done = torch.full((n_envs,), False) 41 | obs = self.env.reset() 42 | total_rew = torch.zeros(self.env.num_envs, n_agents) 43 | with tqdm(total=n_envs) as pbar: 44 | while not all_done.all(): 45 | pbar.update(all_done.sum().item() - pbar.n) 46 | actions = [] 47 | for _ in range(n_agents): 48 | actions.append(torch.rand(n_envs, 2)) 49 | 50 | obs, rews, dones, _ = self.env.step(actions) 51 | for i in range(n_agents): 52 | total_rew[:, i] += rews[i] 53 | if dones.any(): 54 | # Done envs should have exactly sum of rewards equal to num_agents 55 | actual_rew = -scoring_reward * n_agents 56 | assert torch.equal( 57 | total_rew[dones].sum(-1).to(torch.long), 58 | torch.full((dones.sum(),), actual_rew, dtype=torch.long), 59 | ) 60 | total_rew[dones] = 0 61 | all_done += dones 62 | for env_index, done in enumerate(dones): 63 | if done: 64 | self.env.reset_at(env_index) 65 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_give_way.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
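# --- Illustrative sketch (example usage, not part of the original files) ---
# Compact version of the football setup exercised by TestFootball above: scripted
# AI red agents against randomly acting blue agents. The scenario kwargs mirror the
# test; the step count is an arbitrary example value.
import torch

from vmas import make_env

n_envs, n_agents = 4, 3
env = make_env(
    scenario="football",
    num_envs=n_envs,
    device="cpu",
    continuous_actions=True,
    n_red_agents=n_agents,
    n_blue_agents=n_agents,
    ai_red_agents=True,
    ai_blue_agents=False,
    dense_reward=False,
)
env.seed(0)
obs = env.reset()
for _ in range(100):
    actions = [torch.rand(n_envs, 2) for _ in range(n_agents)]  # random blue actions
    obs, rews, dones, info = env.step(actions)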
4 | 5 | import torch 6 | 7 | from vmas import make_env 8 | 9 | 10 | class TestGiveWay: 11 | def setup_env(self, n_envs, **kwargs) -> None: 12 | self.continuous_actions = True 13 | 14 | self.env = make_env( 15 | scenario="give_way", 16 | num_envs=n_envs, 17 | device="cpu", 18 | continuous_actions=self.continuous_actions, 19 | # Environment specific variables 20 | **kwargs, 21 | ) 22 | self.env.seed(0) 23 | 24 | def test_heuristic(self, n_envs=4): 25 | self.setup_env(mirror_passage=False, n_envs=n_envs) 26 | all_done = torch.full((n_envs,), False) 27 | obs = self.env.reset() 28 | u_range = self.env.agents[0].u_range 29 | total_rew = torch.zeros((n_envs,)) 30 | while not (total_rew > 17).all(): 31 | obs_agent = obs[0] 32 | if (obs[1][:, :1] < 0).all(): 33 | action_1 = torch.tensor([u_range / 2, -u_range]).repeat(n_envs, 1) 34 | else: 35 | action_1 = torch.tensor([u_range / 2, u_range]).repeat(n_envs, 1) 36 | action_2 = torch.tensor([-u_range / 3, 0]).repeat(n_envs, 1) 37 | obs, rews, dones, _ = self.env.step([action_1, action_2]) 38 | for rew in rews: 39 | total_rew += rew 40 | if dones.any(): 41 | # Done envs should have exactly sum of rewards equal to num_agents 42 | all_done += dones 43 | for env_index, done in enumerate(dones): 44 | if done: 45 | self.env.reset_at(env_index) 46 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_navigation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import pytest 5 | import torch 6 | 7 | from vmas import make_env 8 | from vmas.scenarios.navigation import HeuristicPolicy 9 | 10 | 11 | class TestNavigation: 12 | def setUp(self, n_envs, n_agents) -> None: 13 | self.continuous_actions = True 14 | 15 | self.env = make_env( 16 | scenario="navigation", 17 | num_envs=n_envs, 18 | device="cpu", 19 | continuous_actions=self.continuous_actions, 20 | # Environment specific variables 21 | n_agents=n_agents, 22 | ) 23 | self.env.seed(0) 24 | 25 | @pytest.mark.parametrize("n_agents", [1]) 26 | def test_heuristic( 27 | self, 28 | n_agents, 29 | n_envs=5, 30 | ): 31 | self.setUp(n_envs=n_envs, n_agents=n_agents) 32 | 33 | policy = HeuristicPolicy( 34 | continuous_action=self.continuous_actions, clf_epsilon=0.4, clf_slack=100.0 35 | ) 36 | 37 | obs = self.env.reset() 38 | all_done = torch.zeros(n_envs, dtype=torch.bool) 39 | 40 | while not all_done.all(): 41 | actions = [] 42 | for i in range(n_agents): 43 | obs_agent = obs[i] 44 | 45 | action_agent = policy.compute_action( 46 | obs_agent, self.env.agents[i].action.u_range_tensor 47 | ) 48 | 49 | actions.append(action_agent) 50 | 51 | obs, new_rews, dones, _ = self.env.step(actions) 52 | if dones.any(): 53 | all_done += dones 54 | 55 | for env_index, done in enumerate(dones): 56 | if done: 57 | self.env.reset_at(env_index) 58 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_passage.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
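# --- Illustrative sketch (example usage, not part of the original files) ---
# Standalone rollout of the navigation heuristic used in TestNavigation above.
# The clf_epsilon and clf_slack values are copied from that test; environment and
# agent counts are arbitrary example values.
from vmas import make_env
from vmas.scenarios.navigation import HeuristicPolicy

env = make_env(
    scenario="navigation", num_envs=5, device="cpu", continuous_actions=True, n_agents=1
)
env.seed(0)
policy = HeuristicPolicy(continuous_action=True, clf_epsilon=0.4, clf_slack=100.0)

obs = env.reset()
for _ in range(50):
    actions = [
        policy.compute_action(obs[i], env.agents[i].action.u_range_tensor)
        for i in range(len(env.agents))
    ]
    obs, rews, dones, info = env.step(actions)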
4 | 5 | import torch 6 | 7 | from vmas import make_env 8 | 9 | 10 | class TestPassage: 11 | def setup_env( 12 | self, 13 | n_envs, 14 | **kwargs, 15 | ) -> None: 16 | self.n_passages = kwargs.get("n_passages", 4) 17 | 18 | self.continuous_actions = True 19 | self.env = make_env( 20 | scenario="passage", 21 | num_envs=n_envs, 22 | device="cpu", 23 | continuous_actions=self.continuous_actions, 24 | # Environment specific variables 25 | **kwargs, 26 | ) 27 | self.env.seed(0) 28 | 29 | def test_heuristic(self, n_envs=4): 30 | self.setup_env(n_passages=1, shared_reward=True, n_envs=4) 31 | 32 | obs = self.env.reset() 33 | agent_switched = torch.full((5, n_envs), False) 34 | all_done = torch.full((n_envs,), False) 35 | while not all_done.all(): 36 | actions = [] 37 | 38 | for i in range(5): 39 | obs_agent = obs[i] 40 | dist_to_passage = obs_agent[:, 6:8] 41 | dist_to_goal = obs_agent[:, 4:6] 42 | dist_to_passage_is_close = ( 43 | torch.linalg.vector_norm(dist_to_passage, dim=1) <= 0.025 44 | ) 45 | 46 | action_agent = torch.clamp( 47 | 2 * dist_to_passage, 48 | min=-self.env.agents[i].u_range, 49 | max=self.env.agents[i].u_range, 50 | ) 51 | agent_switched[i] += dist_to_passage_is_close 52 | action_agent[agent_switched[i]] = torch.clamp( 53 | 2 * dist_to_goal, 54 | min=-self.env.agents[i].u_range, 55 | max=self.env.agents[i].u_range, 56 | )[agent_switched[i]] 57 | 58 | actions.append(action_agent) 59 | 60 | obs, new_rews, dones, _ = self.env.step(actions) 61 | 62 | if dones.any(): 63 | all_done += dones 64 | for env_index, done in enumerate(dones): 65 | if done: 66 | agent_switched[:, env_index] = False 67 | self.env.reset_at(env_index) 68 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_reverse_transport.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
4 | 5 | import pytest 6 | import torch 7 | 8 | from vmas import make_env 9 | 10 | 11 | class TestReverseTransport: 12 | def setup_env(self, n_envs, **kwargs) -> None: 13 | self.n_agents = kwargs.get("n_agents", 4) 14 | self.package_width = kwargs.get("package_width", 0.6) 15 | self.package_length = kwargs.get("package_length", 0.6) 16 | self.package_mass = kwargs.get("package_mass", 50) 17 | 18 | self.continuous_actions = True 19 | 20 | self.env = make_env( 21 | scenario="reverse_transport", 22 | num_envs=n_envs, 23 | device="cpu", 24 | continuous_actions=self.continuous_actions, 25 | # Environment specific variables 26 | **kwargs, 27 | ) 28 | self.env.seed(0) 29 | 30 | @pytest.mark.parametrize("n_agents", [5]) 31 | def test_heuristic(self, n_agents, n_envs=4): 32 | self.setup_env(n_agents=n_agents, n_envs=n_envs) 33 | obs = self.env.reset() 34 | all_done = torch.full((n_envs,), False) 35 | 36 | while not all_done.all(): 37 | actions = [] 38 | for i in range(n_agents): 39 | obs_agent = obs[i] 40 | action_agent = torch.clamp( 41 | -obs_agent[:, -2:], 42 | min=-self.env.agents[i].u_range, 43 | max=self.env.agents[i].u_range, 44 | ) 45 | actions.append(action_agent) 46 | obs, new_rews, dones, _ = self.env.step(actions) 47 | 48 | if dones.any(): 49 | all_done += dones 50 | 51 | for env_index, done in enumerate(dones): 52 | if done: 53 | self.env.reset_at(env_index) 54 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_transport.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import pytest 5 | import torch 6 | 7 | from vmas import make_env 8 | from vmas.scenarios import transport 9 | 10 | 11 | class TestTransport: 12 | def setup_env(self, n_envs, **kwargs) -> None: 13 | self.n_agents = kwargs.get("n_agents", 4) 14 | self.n_packages = kwargs.get("n_packages", 1) 15 | self.package_width = kwargs.get("package_width", 0.15) 16 | self.package_length = kwargs.get("package_length", 0.15) 17 | self.package_mass = kwargs.get("package_mass", 50) 18 | 19 | self.continuous_actions = True 20 | 21 | self.env = make_env( 22 | scenario="transport", 23 | num_envs=n_envs, 24 | device="cpu", 25 | continuous_actions=self.continuous_actions, 26 | # Environment specific variables 27 | **kwargs 28 | ) 29 | self.env.seed(0) 30 | 31 | def test_not_passing_through_packages(self, n_agents=1, n_envs=4): 32 | self.setup_env(n_agents=n_agents, n_envs=n_envs) 33 | 34 | for _ in range(10): 35 | obs = self.env.reset() 36 | for _ in range(100): 37 | obs_agent = obs[0] 38 | assert ( 39 | torch.linalg.vector_norm(obs_agent[:, 6:8], dim=1) 40 | > self.env.agents[0].shape.radius 41 | ).all() 42 | action_agent = torch.clamp( 43 | obs_agent[:, 6:8], 44 | min=-self.env.agents[0].u_range, 45 | max=self.env.agents[0].u_range, 46 | ) 47 | action_agent /= torch.linalg.vector_norm(action_agent, dim=1).unsqueeze( 48 | -1 49 | ) 50 | action_agent *= self.env.agents[0].u_range 51 | 52 | obs, rews, dones, _ = self.env.step([action_agent]) 53 | 54 | @pytest.mark.parametrize("n_agents", [6]) 55 | def test_heuristic(self, n_agents, n_envs=4): 56 | self.setup_env(n_agents=n_agents, n_envs=n_envs) 57 | policy = transport.HeuristicPolicy(self.continuous_actions) 58 | 59 | obs = self.env.reset() 60 | all_done = torch.zeros(n_envs, dtype=torch.bool) 61 | 62 | while not all_done.all(): 63 | actions = [] 64 | for i in range(n_agents): 65 | obs_agent = 
obs[i] 66 | 67 | action_agent = policy.compute_action( 68 | obs_agent, self.env.agents[i].u_range 69 | ) 70 | 71 | actions.append(action_agent) 72 | 73 | obs, new_rews, dones, _ = self.env.step(actions) 74 | 75 | if dones.any(): 76 | all_done += dones 77 | 78 | for env_index, done in enumerate(dones): 79 | if done: 80 | self.env.reset_at(env_index) 81 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_waterfall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import torch 6 | 7 | from vmas import make_env 8 | 9 | 10 | class TestWaterfall: 11 | def setUp(self, n_envs, n_agents) -> None: 12 | self.continuous_actions = True 13 | 14 | self.env = make_env( 15 | scenario="waterfall", 16 | num_envs=n_envs, 17 | device="cpu", 18 | continuous_actions=self.continuous_actions, 19 | # Environment specific variables 20 | n_agents=n_agents, 21 | ) 22 | self.env.seed(0) 23 | 24 | def test_heuristic(self, n_agents=5, n_envs=4, n_steps=50): 25 | self.setUp(n_envs=n_envs, n_agents=n_agents) 26 | obs = self.env.reset() 27 | for _ in range(n_steps): 28 | actions = [] 29 | for i in range(n_agents): 30 | obs_agent = obs[i] 31 | action_agent = torch.clamp( 32 | obs_agent[:, -2:], 33 | min=-self.env.agents[i].u_range, 34 | max=self.env.agents[i].u_range, 35 | ) 36 | actions.append(action_agent) 37 | obs, new_rews, _, _ = self.env.step(actions) 38 | -------------------------------------------------------------------------------- /tests/test_scenarios/test_wheel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import pytest 6 | 7 | from vmas import make_env 8 | from vmas.scenarios import wheel 9 | 10 | 11 | class TestWheel: 12 | def setup_env( 13 | self, 14 | n_envs, 15 | n_agents, 16 | **kwargs, 17 | ) -> None: 18 | self.desired_velocity = kwargs.get("desired_velocity", 0.1) 19 | 20 | self.continuous_actions = True 21 | self.n_envs = 15 22 | self.env = make_env( 23 | scenario="wheel", 24 | num_envs=n_envs, 25 | device="cpu", 26 | continuous_actions=self.continuous_actions, 27 | # Environment specific variables 28 | n_agents=n_agents, 29 | **kwargs, 30 | ) 31 | self.env.seed(0) 32 | 33 | @pytest.mark.parametrize("n_agents", [2, 10]) 34 | def test_heuristic(self, n_agents, n_steps=50, n_envs=4): 35 | line_length = 2 36 | self.setup_env(n_agents=n_agents, line_length=line_length, n_envs=n_envs) 37 | policy = wheel.HeuristicPolicy(self.continuous_actions) 38 | 39 | obs = self.env.reset() 40 | 41 | for _ in range(n_steps): 42 | actions = [] 43 | 44 | for i in range(n_agents): 45 | obs_agent = obs[i] 46 | 47 | action_agent = policy.compute_action( 48 | obs_agent, self.env.agents[i].u_range 49 | ) 50 | actions.append(action_agent) 51 | 52 | obs, new_rews, dones, _ = self.env.step(actions) 53 | -------------------------------------------------------------------------------- /tests/test_wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
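# --- Illustrative sketch (example usage, not part of the original files) ---
# Several of the tests above (waterfall, transport, reverse_transport) use the same
# hand-written controller: a slice of the observation holds a relative position to
# the current target, and clamping it to the agent's control range yields a simple
# proportional "move towards it" action. Sketch of that pattern on the waterfall
# scenario, with arbitrary example sizes.
import torch

from vmas import make_env

env = make_env(
    scenario="waterfall", num_envs=4, device="cpu", continuous_actions=True, n_agents=5
)
env.seed(0)
obs = env.reset()
for _ in range(50):
    actions = [
        torch.clamp(obs[i][:, -2:], min=-agent.u_range, max=agent.u_range)
        for i, agent in enumerate(env.agents)
    ]
    obs, rews, dones, info = env.step(actions)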
4 | -------------------------------------------------------------------------------- /tests/test_wrappers/test_gym_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import gym 6 | import numpy as np 7 | import pytest 8 | from torch import Tensor 9 | 10 | from vmas import make_env 11 | from vmas.simulator.environment import Environment 12 | 13 | 14 | TEST_SCENARIOS = [ 15 | "balance", 16 | "discovery", 17 | "give_way", 18 | "joint_passage", 19 | "navigation", 20 | "passage", 21 | "transport", 22 | "waterfall", 23 | "simple_world_comm", 24 | ] 25 | 26 | 27 | def _check_obs_type(obss, obs_shapes, dict_space, return_numpy): 28 | if dict_space: 29 | assert isinstance( 30 | obss, dict 31 | ), f"Expected dictionary of observations, got {type(obss)}" 32 | for k, obs in obss.items(): 33 | obs_shape = obs_shapes[k] 34 | assert ( 35 | obs.shape == obs_shape 36 | ), f"Expected shape {obs_shape}, got {obs.shape}" 37 | if return_numpy: 38 | assert isinstance( 39 | obs, np.ndarray 40 | ), f"Expected numpy array, got {type(obs)}" 41 | else: 42 | assert isinstance( 43 | obs, Tensor 44 | ), f"Expected torch tensor, got {type(obs)}" 45 | else: 46 | assert isinstance( 47 | obss, list 48 | ), f"Expected list of observations, got {type(obss)}" 49 | for obs, shape in zip(obss, obs_shapes): 50 | assert obs.shape == shape, f"Expected shape {shape}, got {obs.shape}" 51 | if return_numpy: 52 | assert isinstance( 53 | obs, np.ndarray 54 | ), f"Expected numpy array, got {type(obs)}" 55 | else: 56 | assert isinstance( 57 | obs, Tensor 58 | ), f"Expected torch tensor, got {type(obs)}" 59 | 60 | 61 | @pytest.mark.parametrize("scenario", TEST_SCENARIOS) 62 | @pytest.mark.parametrize("return_numpy", [True, False]) 63 | @pytest.mark.parametrize("continuous_actions", [True, False]) 64 | @pytest.mark.parametrize("dict_space", [True, False]) 65 | def test_gym_wrapper( 66 | scenario, return_numpy, continuous_actions, dict_space, max_steps=10 67 | ): 68 | env = make_env( 69 | scenario=scenario, 70 | num_envs=1, 71 | device="cpu", 72 | continuous_actions=continuous_actions, 73 | dict_spaces=dict_space, 74 | wrapper="gym", 75 | wrapper_kwargs={"return_numpy": return_numpy}, 76 | max_steps=max_steps, 77 | ) 78 | 79 | assert ( 80 | len(env.observation_space) == env.unwrapped.n_agents 81 | ), "Expected one observation per agent" 82 | assert ( 83 | len(env.action_space) == env.unwrapped.n_agents 84 | ), "Expected one action per agent" 85 | if dict_space: 86 | assert isinstance( 87 | env.observation_space, gym.spaces.Dict 88 | ), "Expected Dict observation space" 89 | assert isinstance( 90 | env.action_space, gym.spaces.Dict 91 | ), "Expected Dict action space" 92 | obs_shapes = { 93 | k: obs_space.shape for k, obs_space in env.observation_space.spaces.items() 94 | } 95 | else: 96 | assert isinstance( 97 | env.observation_space, gym.spaces.Tuple 98 | ), "Expected Tuple observation space" 99 | assert isinstance( 100 | env.action_space, gym.spaces.Tuple 101 | ), "Expected Tuple action space" 102 | obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces] 103 | 104 | assert isinstance( 105 | env.unwrapped, Environment 106 | ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment" 107 | 108 | obss = env.reset() 109 | _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) 110 | 111 | for _ in range(max_steps): 112 | actions = [ 113 | 
env.unwrapped.get_random_action(agent).numpy() 114 | for agent in env.unwrapped.agents 115 | ] 116 | obss, rews, done, info = env.step(actions) 117 | _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) 118 | 119 | assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent" 120 | if not dict_space: 121 | assert isinstance( 122 | rews, list 123 | ), f"Expected list of rewards but got {type(rews)}" 124 | 125 | rew_values = rews 126 | else: 127 | assert isinstance( 128 | rews, dict 129 | ), f"Expected dictionary of rewards but got {type(rews)}" 130 | rew_values = list(rews.values()) 131 | assert all( 132 | isinstance(rew, float) for rew in rew_values 133 | ), f"Expected float rewards but got {type(rew_values[0])}" 134 | 135 | assert isinstance(done, bool), f"Expected bool for done but got {type(done)}" 136 | 137 | assert isinstance( 138 | info, dict 139 | ), f"Expected info to be a dictionary but got {type(info)}" 140 | 141 | assert done, "Expected done to be True after 100 steps" 142 | -------------------------------------------------------------------------------- /tests/test_wrappers/test_gymnasium_vec_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import gymnasium as gym 6 | import numpy as np 7 | import pytest 8 | import torch 9 | from vmas import make_env 10 | from vmas.simulator.environment import Environment 11 | 12 | from test_wrappers.test_gym_wrapper import _check_obs_type, TEST_SCENARIOS 13 | 14 | 15 | @pytest.mark.parametrize("scenario", TEST_SCENARIOS) 16 | @pytest.mark.parametrize("return_numpy", [True, False]) 17 | @pytest.mark.parametrize("continuous_actions", [True, False]) 18 | @pytest.mark.parametrize("dict_space", [True, False]) 19 | @pytest.mark.parametrize("num_envs", [1, 10]) 20 | def test_gymnasium_wrapper( 21 | scenario, return_numpy, continuous_actions, dict_space, num_envs, max_steps=10 22 | ): 23 | env = make_env( 24 | scenario=scenario, 25 | num_envs=num_envs, 26 | device="cpu", 27 | continuous_actions=continuous_actions, 28 | dict_spaces=dict_space, 29 | wrapper="gymnasium_vec", 30 | terminated_truncated=True, 31 | wrapper_kwargs={"return_numpy": return_numpy}, 32 | max_steps=max_steps, 33 | ) 34 | 35 | assert isinstance( 36 | env.unwrapped, Environment 37 | ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment" 38 | 39 | assert ( 40 | len(env.observation_space) == env.unwrapped.n_agents 41 | ), "Expected one observation per agent" 42 | assert ( 43 | len(env.action_space) == env.unwrapped.n_agents 44 | ), "Expected one action per agent" 45 | if dict_space: 46 | assert isinstance( 47 | env.observation_space, gym.spaces.Dict 48 | ), "Expected Dict observation space" 49 | assert isinstance( 50 | env.action_space, gym.spaces.Dict 51 | ), "Expected Dict action space" 52 | obs_shapes = { 53 | k: obs_space.shape for k, obs_space in env.observation_space.spaces.items() 54 | } 55 | else: 56 | assert isinstance( 57 | env.observation_space, gym.spaces.Tuple 58 | ), "Expected Tuple observation space" 59 | assert isinstance( 60 | env.action_space, gym.spaces.Tuple 61 | ), "Expected Tuple action space" 62 | obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces] 63 | 64 | obss, info = env.reset() 65 | _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) 66 | assert isinstance( 67 | info, dict 68 | ), f"Expected info to be a 
dictionary but got {type(info)}" 69 | 70 | for _ in range(max_steps): 71 | actions = [ 72 | env.unwrapped.get_random_action(agent).numpy() 73 | for agent in env.unwrapped.agents 74 | ] 75 | obss, rews, terminated, truncated, info = env.step(actions) 76 | _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) 77 | 78 | assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent" 79 | if not dict_space: 80 | assert isinstance( 81 | rews, list 82 | ), f"Expected list of rewards but got {type(rews)}" 83 | 84 | rew_values = rews 85 | else: 86 | assert isinstance( 87 | rews, dict 88 | ), f"Expected dictionary of rewards but got {type(rews)}" 89 | rew_values = list(rews.values()) 90 | if return_numpy: 91 | assert all( 92 | isinstance(rew, np.ndarray) for rew in rew_values 93 | ), f"Expected np.array rewards but got {type(rew_values[0])}" 94 | else: 95 | assert all( 96 | isinstance(rew, torch.Tensor) for rew in rew_values 97 | ), f"Expected torch tensor rewards but got {type(rew_values[0])}" 98 | 99 | if return_numpy: 100 | assert isinstance( 101 | terminated, np.ndarray 102 | ), f"Expected np.array for terminated but got {type(terminated)}" 103 | assert isinstance( 104 | truncated, np.ndarray 105 | ), f"Expected np.array for truncated but got {type(truncated)}" 106 | else: 107 | assert isinstance( 108 | terminated, torch.Tensor 109 | ), f"Expected torch tensor for terminated but got {type(terminated)}" 110 | assert isinstance( 111 | truncated, torch.Tensor 112 | ), f"Expected torch tensor for truncated but got {type(truncated)}" 113 | 114 | assert isinstance( 115 | info, dict 116 | ), f"Expected info to be a dictionary but got {type(info)}" 117 | 118 | assert all(truncated), "Expected done to be True after 100 steps" 119 | -------------------------------------------------------------------------------- /tests/test_wrappers/test_gymnasium_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
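# --- Illustrative sketch (example usage, not part of the original files) ---
# Basic usage of the vectorized Gymnasium wrapper exercised by the test above:
# make_env returns a Gymnasium-style environment whose reset() yields (obs, info)
# and whose step() yields (obs, rews, terminated, truncated, info). Sizes are
# arbitrary example values.
from vmas import make_env

env = make_env(
    scenario="navigation",
    num_envs=10,
    device="cpu",
    wrapper="gymnasium_vec",
    terminated_truncated=True,
    wrapper_kwargs={"return_numpy": True},
    max_steps=100,
)
obs, info = env.reset()
for _ in range(10):
    actions = [
        env.unwrapped.get_random_action(agent).numpy()
        for agent in env.unwrapped.agents
    ]
    obs, rews, terminated, truncated, info = env.step(actions)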
4 | 5 | import gymnasium as gym 6 | import pytest 7 | from vmas import make_env 8 | from vmas.simulator.environment import Environment 9 | 10 | from test_wrappers.test_gym_wrapper import _check_obs_type, TEST_SCENARIOS 11 | 12 | 13 | @pytest.mark.parametrize("scenario", TEST_SCENARIOS) 14 | @pytest.mark.parametrize("return_numpy", [True, False]) 15 | @pytest.mark.parametrize("continuous_actions", [True, False]) 16 | @pytest.mark.parametrize("dict_space", [True, False]) 17 | def test_gymnasium_wrapper( 18 | scenario, return_numpy, continuous_actions, dict_space, max_steps=10 19 | ): 20 | env = make_env( 21 | scenario=scenario, 22 | num_envs=1, 23 | device="cpu", 24 | continuous_actions=continuous_actions, 25 | dict_spaces=dict_space, 26 | wrapper="gymnasium", 27 | terminated_truncated=True, 28 | wrapper_kwargs={"return_numpy": return_numpy}, 29 | max_steps=max_steps, 30 | ) 31 | 32 | assert ( 33 | len(env.observation_space) == env.unwrapped.n_agents 34 | ), "Expected one observation per agent" 35 | assert ( 36 | len(env.action_space) == env.unwrapped.n_agents 37 | ), "Expected one action per agent" 38 | if dict_space: 39 | assert isinstance( 40 | env.observation_space, gym.spaces.Dict 41 | ), "Expected Dict observation space" 42 | assert isinstance( 43 | env.action_space, gym.spaces.Dict 44 | ), "Expected Dict action space" 45 | obs_shapes = { 46 | k: obs_space.shape for k, obs_space in env.observation_space.spaces.items() 47 | } 48 | else: 49 | assert isinstance( 50 | env.observation_space, gym.spaces.Tuple 51 | ), "Expected Tuple observation space" 52 | assert isinstance( 53 | env.action_space, gym.spaces.Tuple 54 | ), "Expected Tuple action space" 55 | obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces] 56 | 57 | assert isinstance( 58 | env.unwrapped, Environment 59 | ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment" 60 | 61 | obss, info = env.reset() 62 | _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) 63 | assert isinstance( 64 | info, dict 65 | ), f"Expected info to be a dictionary but got {type(info)}" 66 | 67 | for _ in range(max_steps): 68 | actions = [ 69 | env.unwrapped.get_random_action(agent).numpy() 70 | for agent in env.unwrapped.agents 71 | ] 72 | obss, rews, terminated, truncated, info = env.step(actions) 73 | _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy) 74 | 75 | assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent" 76 | if not dict_space: 77 | assert isinstance( 78 | rews, list 79 | ), f"Expected list of rewards but got {type(rews)}" 80 | 81 | rew_values = rews 82 | else: 83 | assert isinstance( 84 | rews, dict 85 | ), f"Expected dictionary of rewards but got {type(rews)}" 86 | rew_values = list(rews.values()) 87 | assert all( 88 | isinstance(rew, float) for rew in rew_values 89 | ), f"Expected float rewards but got {type(rew_values[0])}" 90 | 91 | assert isinstance( 92 | terminated, bool 93 | ), f"Expected bool for terminated but got {type(terminated)}" 94 | assert isinstance( 95 | truncated, bool 96 | ), f"Expected bool for truncated but got {type(truncated)}" 97 | 98 | assert isinstance( 99 | info, dict 100 | ), f"Expected info to be a dictionary but got {type(info)}" 101 | 102 | assert truncated, "Expected done to be True after 100 steps" 103 | -------------------------------------------------------------------------------- /vmas/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
2022-2025. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | from vmas.interactive_rendering import render_interactively 6 | from vmas.make_env import make_env 7 | from vmas.simulator.environment import Wrapper 8 | 9 | from vmas.simulator.utils import _init_pyglet_device 10 | 11 | _init_pyglet_device() 12 | 13 | __all__ = [ 14 | "make_env", 15 | "render_interactively", 16 | "scenarios", 17 | "debug_scenarios", 18 | "mpe_scenarios", 19 | ] 20 | 21 | __version__ = "1.5.0" 22 | 23 | scenarios = sorted( 24 | [ 25 | "dropout", 26 | "dispersion", 27 | "transport", 28 | "reverse_transport", 29 | "give_way", 30 | "wheel", 31 | "balance", 32 | "football", 33 | "discovery", 34 | "flocking", 35 | "passage", 36 | "joint_passage_size", 37 | "joint_passage", 38 | "ball_passage", 39 | "ball_trajectory", 40 | "buzz_wire", 41 | "multi_give_way", 42 | "navigation", 43 | "sampling", 44 | "wind_flocking", 45 | "road_traffic", 46 | ] 47 | ) 48 | """List of the vmas scenarios (excluding MPE and debug)""" 49 | 50 | debug_scenarios = sorted( 51 | [ 52 | "asym_joint", 53 | "circle_trajectory", 54 | "goal", 55 | "het_mass", 56 | "line_trajectory", 57 | "vel_control", 58 | "waterfall", 59 | "diff_drive", 60 | "kinematic_bicycle", 61 | "pollock", 62 | "drone", 63 | ] 64 | ) 65 | """List of the vmas debug scenarios """ 66 | 67 | 68 | mpe_scenarios = sorted( 69 | [ 70 | "simple", 71 | "simple_adversary", 72 | "simple_crypto", 73 | "simple_push", 74 | "simple_reference", 75 | "simple_speaker_listener", 76 | "simple_spread", 77 | "simple_tag", 78 | "simple_world_comm", 79 | ] 80 | ) 81 | """List of the vmas MPE scenarios """ 82 | -------------------------------------------------------------------------------- /vmas/examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | -------------------------------------------------------------------------------- /vmas/examples/run_heuristic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
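# --- Illustrative sketch (example usage, not part of the original files) ---
# The lists exported by vmas/__init__.py above (scenarios, debug_scenarios,
# mpe_scenarios) can be used to enumerate the bundled scenarios, e.g. for a quick
# smoke test. Only the main scenario list is iterated here; the batch size is an
# arbitrary example value.
import vmas
from vmas import make_env

print(vmas.__version__)
for name in vmas.scenarios:
    env = make_env(scenario=name, num_envs=2, device="cpu")
    env.reset()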
4 | import time 5 | from typing import Type 6 | 7 | import torch 8 | 9 | from vmas import make_env 10 | from vmas.simulator.heuristic_policy import BaseHeuristicPolicy, RandomPolicy 11 | from vmas.simulator.utils import save_video 12 | 13 | 14 | def run_heuristic( 15 | scenario_name: str, 16 | heuristic: Type[BaseHeuristicPolicy] = RandomPolicy, 17 | n_steps: int = 200, 18 | n_envs: int = 32, 19 | env_kwargs: dict = None, 20 | render: bool = False, 21 | save_render: bool = False, 22 | device: str = "cpu", 23 | ): 24 | assert not (save_render and not render), "To save the video you have to render it" 25 | if env_kwargs is None: 26 | env_kwargs = {} 27 | 28 | # Scenario specific variables 29 | policy = heuristic(continuous_action=True) 30 | 31 | env = make_env( 32 | scenario=scenario_name, 33 | num_envs=n_envs, 34 | device=device, 35 | continuous_actions=True, 36 | wrapper=None, 37 | # Environment specific variables 38 | **env_kwargs, 39 | ) 40 | 41 | frame_list = [] # For creating a gif 42 | init_time = time.time() 43 | step = 0 44 | obs = env.reset() 45 | total_reward = 0 46 | for _ in range(n_steps): 47 | step += 1 48 | actions = [None] * len(obs) 49 | for i in range(len(obs)): 50 | actions[i] = policy.compute_action(obs[i], u_range=env.agents[i].u_range) 51 | obs, rews, dones, info = env.step(actions) 52 | rewards = torch.stack(rews, dim=1) 53 | global_reward = rewards.mean(dim=1) 54 | mean_global_reward = global_reward.mean(dim=0) 55 | total_reward += mean_global_reward 56 | if render: 57 | frame_list.append( 58 | env.render( 59 | mode="rgb_array", 60 | agent_index_focus=None, 61 | visualize_when_rgb=True, 62 | ) 63 | ) 64 | 65 | total_time = time.time() - init_time 66 | if render and save_render: 67 | save_video(scenario_name, frame_list, 1 / env.scenario.world.dt) 68 | 69 | print( 70 | f"It took: {total_time}s for {n_steps} steps of {n_envs} parallel environments on device {device}\n" 71 | f"The average total reward was {total_reward}" 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | from vmas.scenarios.transport import HeuristicPolicy as TransportHeuristic 77 | 78 | run_heuristic( 79 | scenario_name="transport", 80 | heuristic=TransportHeuristic, 81 | n_envs=300, 82 | n_steps=200, 83 | render=True, 84 | save_render=False, 85 | ) 86 | -------------------------------------------------------------------------------- /vmas/examples/use_vmas_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
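# --- Illustrative sketch (example usage, not part of the original files) ---
# run_heuristic() above accepts any BaseHeuristicPolicy subclass; with the default
# RandomPolicy it simply benchmarks random rollouts of a scenario. The arguments
# below are arbitrary example values.
from vmas.examples.run_heuristic import run_heuristic
from vmas.simulator.heuristic_policy import RandomPolicy

run_heuristic(
    scenario_name="flocking",
    heuristic=RandomPolicy,
    n_envs=32,
    n_steps=100,
    render=False,
)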
4 | import random 5 | import time 6 | 7 | import torch 8 | 9 | from vmas import make_env 10 | from vmas.simulator.core import Agent 11 | from vmas.simulator.utils import save_video 12 | 13 | 14 | def _get_deterministic_action(agent: Agent, continuous: bool, env): 15 | if continuous: 16 | action = -agent.action.u_range_tensor.expand(env.batch_dim, agent.action_size) 17 | else: 18 | action = ( 19 | torch.tensor([1], device=env.device, dtype=torch.long) 20 | .unsqueeze(-1) 21 | .expand(env.batch_dim, 1) 22 | ) 23 | return action.clone() 24 | 25 | 26 | def use_vmas_env( 27 | render: bool = False, 28 | save_render: bool = False, 29 | num_envs: int = 32, 30 | n_steps: int = 100, 31 | random_action: bool = False, 32 | device: str = "cpu", 33 | scenario_name: str = "waterfall", 34 | continuous_actions: bool = True, 35 | visualize_render: bool = True, 36 | dict_spaces: bool = True, 37 | **kwargs, 38 | ): 39 | """Example function to use a vmas environment 40 | 41 | Args: 42 | continuous_actions (bool): Whether the agents have continuous or discrete actions 43 | scenario_name (str): Name of scenario 44 | device (str): Torch device to use 45 | render (bool): Whether to render the scenario 46 | save_render (bool): Whether to save render of the scenario 47 | num_envs (int): Number of vectorized environments 48 | n_steps (int): Number of steps before returning done 49 | random_action (bool): Use random actions or have all agents perform the down action 50 | visualize_render (bool, optional): Whether to visualize the render. Defaults to ``True``. 51 | dict_spaces (bool, optional): Weather to return obs, rewards, and infos as dictionaries with agent names. 52 | By default, they are lists of len # of agents 53 | kwargs (dict, optional): Keyword arguments to pass to the scenario 54 | 55 | Returns: 56 | 57 | """ 58 | assert not (save_render and not render), "To save the video you have to render it" 59 | 60 | env = make_env( 61 | scenario=scenario_name, 62 | num_envs=num_envs, 63 | device=device, 64 | continuous_actions=continuous_actions, 65 | dict_spaces=dict_spaces, 66 | wrapper=None, 67 | seed=None, 68 | # Environment specific variables 69 | **kwargs, 70 | ) 71 | 72 | frame_list = [] # For creating a gif 73 | init_time = time.time() 74 | step = 0 75 | 76 | for _ in range(n_steps): 77 | step += 1 78 | print(f"Step {step}") 79 | 80 | # VMAS actions can be either a list of tensors (one per agent) 81 | # or a dict of tensors (one entry per agent with its name as key) 82 | # Both action inputs can be used independently of what type of space its chosen 83 | dict_actions = random.choice([True, False]) 84 | 85 | actions = {} if dict_actions else [] 86 | for agent in env.agents: 87 | if not random_action: 88 | action = _get_deterministic_action(agent, continuous_actions, env) 89 | else: 90 | action = env.get_random_action(agent) 91 | if dict_actions: 92 | actions.update({agent.name: action}) 93 | else: 94 | actions.append(action) 95 | 96 | obs, rews, dones, info = env.step(actions) 97 | 98 | if render: 99 | frame = env.render( 100 | mode="rgb_array", 101 | agent_index_focus=None, # Can give the camera an agent index to focus on 102 | visualize_when_rgb=visualize_render, 103 | ) 104 | if save_render: 105 | frame_list.append(frame) 106 | 107 | total_time = time.time() - init_time 108 | print( 109 | f"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} " 110 | f"for {scenario_name} scenario." 
111 | ) 112 | 113 | if render and save_render: 114 | save_video(scenario_name, frame_list, fps=1 / env.scenario.world.dt) 115 | 116 | 117 | if __name__ == "__main__": 118 | use_vmas_env( 119 | scenario_name="waterfall", 120 | render=True, 121 | save_render=False, 122 | random_action=False, 123 | continuous_actions=False, 124 | # Environment specific 125 | n_agents=4, 126 | ) 127 | -------------------------------------------------------------------------------- /vmas/make_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | from typing import Optional, Union 6 | 7 | from vmas import scenarios 8 | from vmas.simulator.environment import Environment, Wrapper 9 | from vmas.simulator.scenario import BaseScenario 10 | from vmas.simulator.utils import DEVICE_TYPING 11 | 12 | 13 | def make_env( 14 | scenario: Union[str, BaseScenario], 15 | num_envs: int, 16 | device: DEVICE_TYPING = "cpu", 17 | continuous_actions: bool = True, 18 | wrapper: Optional[Union[Wrapper, str]] = None, 19 | max_steps: Optional[int] = None, 20 | seed: Optional[int] = None, 21 | dict_spaces: bool = False, 22 | multidiscrete_actions: bool = False, 23 | clamp_actions: bool = False, 24 | grad_enabled: bool = False, 25 | terminated_truncated: bool = False, 26 | wrapper_kwargs: Optional[dict] = None, 27 | **kwargs, 28 | ): 29 | """Create a vmas environment. 30 | 31 | Args: 32 | scenario (Union[str, BaseScenario]): Scenario to load. 33 | Can be the name of a file in `vmas.scenarios` folder or a :class:`~vmas.simulator.scenario.BaseScenario` class, 34 | num_envs (int): Number of vectorized simulation environments. VMAS performs vectorized simulations using PyTorch. 35 | This argument indicates the number of vectorized environments that should be simulated in a batch. It will also 36 | determine the batch size of the environment. 37 | device (Union[str, int, torch.device], optional): Device for simulation. All the tensors created by VMAS 38 | will be placed on this device. Default is ``"cpu"``, 39 | continuous_actions (bool, optional): Whether to use continuous actions. If ``False``, actions 40 | will be discrete. The number of actions and their size will depend on the chosen scenario. Default is ``True``, 41 | wrapper (Union[Wrapper, str], optional): Wrapper class to use. For example, it can be 42 | ``"rllib"``, ``"gym"``, ``"gymnasium"``, ``"gymnasium_vec"``. Default is ``None``. 43 | max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon). Each VMAS scenario can 44 | be terminating or not. If ``max_steps`` is specified, 45 | the scenario is also terminated whenever this horizon is reached, 46 | seed (int, optional): Seed for the environment. Defaults to ``None``, 47 | dict_spaces (bool, optional): Weather to use dictionaries spaces with format ``{"agent_name": tensor, ...}`` 48 | for obs, rewards, and info instead of tuples. Defaults to ``False``: obs, rewards, info are tuples with length number of agents, 49 | multidiscrete_actions (bool, optional): Whether to use multidiscrete action spaces when ``continuous_actions=False``. 
50 | Default is ``False``: the action space will be ``Discrete``, and it will be the cartesian product of the 51 | discrete action spaces available to an agent, 52 | clamp_actions (bool, optional): Weather to clamp input actions to their range instead of throwing 53 | an error when ``continuous_actions==True`` and actions are out of bounds, 54 | grad_enabled (bool, optional): If ``True`` the simulator will not call ``detach()`` on input actions and gradients can 55 | be taken from the simulator output. Default is ``False``. 56 | terminated_truncated (bool, optional): Weather to use terminated and truncated flags in the output of the step method (or single done). 57 | Default is ``False``. 58 | wrapper_kwargs (dict, optional): Keyword arguments to pass to the wrapper class. Default is ``{}``. 59 | **kwargs (dict, optional): Keyword arguments to pass to the :class:`~vmas.simulator.scenario.BaseScenario` class. 60 | 61 | Examples: 62 | >>> from vmas import make_env 63 | >>> env = make_env( 64 | ... "waterfall", 65 | ... num_envs=3, 66 | ... num_agents=2, 67 | ... ) 68 | >>> print(env.reset()) 69 | 70 | 71 | """ 72 | 73 | # load scenario from script 74 | if isinstance(scenario, str): 75 | if not scenario.endswith(".py"): 76 | scenario += ".py" 77 | scenario = scenarios.load(scenario).Scenario() 78 | 79 | env = Environment( 80 | scenario, 81 | num_envs=num_envs, 82 | device=device, 83 | continuous_actions=continuous_actions, 84 | max_steps=max_steps, 85 | seed=seed, 86 | dict_spaces=dict_spaces, 87 | multidiscrete_actions=multidiscrete_actions, 88 | clamp_actions=clamp_actions, 89 | grad_enabled=grad_enabled, 90 | terminated_truncated=terminated_truncated, 91 | **kwargs, 92 | ) 93 | 94 | if wrapper is not None and isinstance(wrapper, str): 95 | wrapper = Wrapper[wrapper.upper()] 96 | 97 | if wrapper_kwargs is None: 98 | wrapper_kwargs = {} 99 | 100 | return wrapper.get_env(env, **wrapper_kwargs) if wrapper is not None else env 101 | -------------------------------------------------------------------------------- /vmas/scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import importlib 5 | import os 6 | import os.path as osp 7 | from pathlib import Path 8 | 9 | 10 | def load(name: str): 11 | pathname = None 12 | for dirpath, _, filenames in os.walk(osp.dirname(__file__)): 13 | if pathname is None: 14 | for filename in filenames: 15 | if name == filename or Path(name) == Path(dirpath) / Path(filename): 16 | pathname = os.path.join(dirpath, filename) 17 | break 18 | assert pathname is not None, f"{name} scenario not found." 19 | 20 | spec = importlib.util.spec_from_file_location("", pathname) 21 | module = importlib.util.module_from_spec(spec) 22 | spec.loader.exec_module(module) 23 | return module 24 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/diff_drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
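# --- Illustrative sketch (example usage, not part of the original files) ---
# As noted in vmas/examples/use_vmas_env.py above, actions can be passed to
# env.step() either as a list of tensors (one per agent) or as a dict keyed by
# agent name, independently of whether dict_spaces is enabled. Sketch of the dict
# form, with arbitrary example sizes.
from vmas import make_env

env = make_env(
    scenario="waterfall", num_envs=8, device="cpu", dict_spaces=True, n_agents=4
)
obs = env.reset()
for _ in range(10):
    actions = {agent.name: env.get_random_action(agent) for agent in env.agents}
    obs, rews, dones, info = env.step(actions)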
4 | import typing 5 | from typing import List 6 | 7 | import torch 8 | 9 | from vmas import render_interactively 10 | from vmas.simulator.core import Agent, World 11 | from vmas.simulator.dynamics.diff_drive import DiffDrive 12 | from vmas.simulator.dynamics.holonomic_with_rot import HolonomicWithRotation 13 | from vmas.simulator.scenario import BaseScenario 14 | from vmas.simulator.utils import ScenarioUtils 15 | 16 | if typing.TYPE_CHECKING: 17 | from vmas.simulator.rendering import Geom 18 | 19 | 20 | class Scenario(BaseScenario): 21 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 22 | """ 23 | Differential drive example scenario 24 | Run this file to try it out 25 | 26 | The first agent has differential drive dynamics. 27 | You can control its forward input with the LEFT and RIGHT arrows. 28 | You can control its rotation with UP and DOWN. 29 | 30 | The second agent has standard vmas holonomic dynamics. 31 | You can control it with WASD 32 | You can control its rotation with Q and E. 33 | 34 | """ 35 | # T 36 | self.plot_grid = True 37 | self.n_agents = kwargs.pop("n_agents", 2) 38 | ScenarioUtils.check_kwargs_consumed(kwargs) 39 | 40 | # Make world 41 | world = World(batch_dim, device, substeps=10) 42 | 43 | for i in range(self.n_agents): 44 | if i == 0: 45 | agent = Agent( 46 | name=f"diff_drive_{i}", 47 | collide=True, 48 | render_action=True, 49 | u_range=[1, 1], 50 | u_multiplier=[1, 1], 51 | dynamics=DiffDrive(world, integration="rk4"), 52 | ) 53 | else: 54 | agent = Agent( 55 | name=f"holo_rot_{i}", 56 | collide=True, 57 | render_action=True, 58 | u_range=[1, 1, 1], 59 | u_multiplier=[1, 1, 0.001], 60 | dynamics=HolonomicWithRotation(), 61 | ) 62 | 63 | world.add_agent(agent) 64 | 65 | return world 66 | 67 | def reset_world_at(self, env_index: int = None): 68 | ScenarioUtils.spawn_entities_randomly( 69 | self.world.agents, 70 | self.world, 71 | env_index, 72 | min_dist_between_entities=0.1, 73 | x_bounds=(-1, 1), 74 | y_bounds=(-1, 1), 75 | ) 76 | 77 | def reward(self, agent: Agent): 78 | return torch.zeros(self.world.batch_dim) 79 | 80 | def observation(self, agent: Agent): 81 | observations = [ 82 | agent.state.pos, 83 | agent.state.vel, 84 | ] 85 | return torch.cat( 86 | observations, 87 | dim=-1, 88 | ) 89 | 90 | def extra_render(self, env_index: int = 0) -> "List[Geom]": 91 | 92 | geoms: List[Geom] = [] 93 | 94 | # Agent rotation 95 | for agent in self.world.agents: 96 | geoms.append( 97 | ScenarioUtils.plot_entity_rotation(agent, env_index, length=0.1) 98 | ) 99 | 100 | return geoms 101 | 102 | 103 | if __name__ == "__main__": 104 | render_interactively(__file__, control_two_agents=True) 105 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/drone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
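# --- Illustrative sketch (example usage, not part of the original files) ---
# The diff_drive debug scenario above can also be run headless through make_env
# (scenario files under vmas/scenarios/debug are resolved by name like any other).
# Its two agents use different dynamics (differential drive vs. holonomic with
# rotation), so their action sizes differ; env.get_random_action handles that per
# agent. Sizes below are arbitrary example values.
from vmas import make_env

env = make_env(scenario="diff_drive", num_envs=4, device="cpu", n_agents=2)
obs = env.reset()
for _ in range(10):
    actions = [env.get_random_action(agent) for agent in env.agents]
    obs, rews, dones, info = env.step(actions)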
4 | 5 | import typing 6 | from typing import List 7 | 8 | import torch 9 | 10 | from vmas import render_interactively 11 | from vmas.simulator.core import Agent, World 12 | from vmas.simulator.dynamics.drone import Drone 13 | from vmas.simulator.scenario import BaseScenario 14 | from vmas.simulator.utils import ScenarioUtils 15 | 16 | if typing.TYPE_CHECKING: 17 | from vmas.simulator.rendering import Geom 18 | 19 | 20 | class Scenario(BaseScenario): 21 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 22 | """ 23 | Drone example scenario 24 | Run this file to try it out. 25 | 26 | You can control the three input torques using left/right arrows, up/down arrows, and m/n. 27 | """ 28 | self.plot_grid = True 29 | self.n_agents = kwargs.pop("n_agents", 2) 30 | ScenarioUtils.check_kwargs_consumed(kwargs) 31 | 32 | # Make world 33 | world = World(batch_dim, device, substeps=10) 34 | 35 | for i in range(self.n_agents): 36 | agent = Agent( 37 | name=f"drone_{i}", 38 | collide=True, 39 | render_action=True, 40 | u_range=[ 41 | 0.00001, 42 | 0.00001, 43 | 0.00001, 44 | ], # torque_x, torque_y, torque_z 45 | u_multiplier=[1, 1, 1], 46 | action_size=3, # We feed only the torque actions to interactively control the drone in the debug scenario 47 | # In non-debug cases, remove this line and the `process_action` function in this file 48 | dynamics=Drone(world, integration="rk4"), 49 | ) 50 | world.add_agent(agent) 51 | 52 | return world 53 | 54 | def reset_world_at(self, env_index: int = None): 55 | ScenarioUtils.spawn_entities_randomly( 56 | self.world.agents, 57 | self.world, 58 | env_index, 59 | min_dist_between_entities=0.1, 60 | x_bounds=(-1, 1), 61 | y_bounds=(-1, 1), 62 | ) 63 | 64 | def reward(self, agent: Agent): 65 | return torch.zeros(self.world.batch_dim, device=self.world.device) 66 | 67 | def process_action(self, agent: Agent): 68 | torque = agent.action.u 69 | thrust = torch.full( 70 | (self.world.batch_dim, 1), 71 | agent.mass * agent.dynamics.g, 72 | device=self.world.device, 73 | ) # Add a fixed thrust to make sure the agent is not falling 74 | agent.action.u = torch.cat([thrust, torque], dim=-1) 75 | 76 | def observation(self, agent: Agent): 77 | observations = [ 78 | agent.state.pos, 79 | agent.state.vel, 80 | ] 81 | return torch.cat( 82 | observations, 83 | dim=-1, 84 | ) 85 | 86 | def done(self): 87 | return torch.any( 88 | torch.stack( 89 | [agent.dynamics.needs_reset() for agent in self.world.agents], 90 | dim=-1, 91 | ), 92 | dim=-1, 93 | ) 94 | 95 | def extra_render(self, env_index: int = 0) -> "List[Geom]": 96 | 97 | geoms: List[Geom] = [] 98 | 99 | # Agent rotation 100 | for agent in self.world.agents: 101 | geoms.append( 102 | ScenarioUtils.plot_entity_rotation(agent, env_index, length=0.1) 103 | ) 104 | 105 | return geoms 106 | 107 | 108 | if __name__ == "__main__": 109 | render_interactively(__file__, control_two_agents=True) 110 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/het_mass.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
4 | import math 5 | from typing import Dict 6 | 7 | import numpy as np 8 | import torch 9 | from torch import Tensor 10 | 11 | from vmas import render_interactively 12 | from vmas.simulator.core import Agent, World 13 | from vmas.simulator.scenario import BaseScenario 14 | from vmas.simulator.utils import Color, ScenarioUtils, Y 15 | 16 | 17 | class Scenario(BaseScenario): 18 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 19 | self.green_mass = kwargs.pop("green_mass", 4) 20 | self.blue_mass = kwargs.pop("blue_mass", 2) 21 | self.mass_noise = kwargs.pop("mass_noise", 1) 22 | ScenarioUtils.check_kwargs_consumed(kwargs) 23 | self.plot_grid = True 24 | 25 | # Make world 26 | world = World(batch_dim, device) 27 | # Add agents 28 | self.green_agent = Agent( 29 | name="agent 0", 30 | collide=False, 31 | color=Color.GREEN, 32 | render_action=True, 33 | mass=self.green_mass, 34 | f_range=1, 35 | ) 36 | world.add_agent(self.green_agent) 37 | self.blue_agent = Agent( 38 | name="agent 1", collide=False, render_action=True, f_range=1 39 | ) 40 | world.add_agent(self.blue_agent) 41 | 42 | self.max_speed = torch.zeros(batch_dim, device=device) 43 | self.energy_expenditure = self.max_speed.clone() 44 | 45 | return world 46 | 47 | def reset_world_at(self, env_index: int = None): 48 | # Temp 49 | self.blue_agent.mass = self.blue_mass + np.random.uniform( 50 | -self.mass_noise, self.mass_noise 51 | ) 52 | self.green_agent.mass = self.green_mass + np.random.uniform( 53 | -self.mass_noise, self.mass_noise 54 | ) 55 | 56 | for agent in self.world.agents: 57 | agent.set_pos( 58 | torch.zeros( 59 | ( 60 | (1, self.world.dim_p) 61 | if env_index is not None 62 | else (self.world.batch_dim, self.world.dim_p) 63 | ), 64 | device=self.world.device, 65 | dtype=torch.float32, 66 | ).uniform_(-1, 1), 67 | batch_index=env_index, 68 | ) 69 | 70 | def process_action(self, agent: Agent): 71 | agent.action.u[:, Y] = 0 72 | 73 | def reward(self, agent: Agent): 74 | is_first = agent == self.world.agents[0] 75 | 76 | if is_first: 77 | self.max_speed = torch.stack( 78 | [ 79 | torch.linalg.vector_norm(a.state.vel, dim=1) 80 | for a in self.world.agents 81 | ], 82 | dim=1, 83 | ).max(dim=1)[0] 84 | 85 | self.energy_expenditure = ( 86 | -torch.stack( 87 | [ 88 | torch.linalg.vector_norm(a.action.u, dim=-1) 89 | / math.sqrt(self.world.dim_p * (a.f_range**2)) 90 | for a in self.world.agents 91 | ], 92 | dim=1, 93 | ).sum(-1) 94 | * 0.17 95 | ) 96 | 97 | # print(self.max_speed) 98 | # print(self.energy_expenditure) 99 | # self.energy_rew_1 = (self.world.agents[0].action.u[:, X] - 0).abs() 100 | # self.energy_rew_1 += (self.world.agents[0].action.u[:, Y] - 0).abs() 101 | # 102 | # self.energy_rew_2 = (self.world.agents[1].action.u[:, X] - 0).abs() 103 | # self.energy_rew_2 += (self.world.agents[1].action.u[:, Y] - 0).abs() 104 | 105 | return self.max_speed + self.energy_expenditure 106 | 107 | def observation(self, agent: Agent): 108 | return torch.cat( 109 | [agent.state.pos, agent.state.vel], 110 | dim=-1, 111 | ) 112 | 113 | def info(self, agent: Agent) -> Dict[str, Tensor]: 114 | return { 115 | "max_speed": self.max_speed, 116 | "energy_expenditure": self.energy_expenditure, 117 | } 118 | 119 | 120 | if __name__ == "__main__": 121 | render_interactively(__file__, control_two_agents=True) 122 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/kinematic_bicycle.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import typing 6 | from typing import List 7 | 8 | import torch 9 | 10 | from vmas import render_interactively 11 | from vmas.simulator.core import Agent, Box, World 12 | from vmas.simulator.dynamics.holonomic_with_rot import HolonomicWithRotation 13 | from vmas.simulator.dynamics.kinematic_bicycle import KinematicBicycle 14 | from vmas.simulator.scenario import BaseScenario 15 | from vmas.simulator.utils import ScenarioUtils 16 | 17 | if typing.TYPE_CHECKING: 18 | from vmas.simulator.rendering import Geom 19 | 20 | 21 | class Scenario(BaseScenario): 22 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 23 | """ 24 | Kinematic bicycle model example scenario 25 | """ 26 | self.n_agents = kwargs.pop("n_agents", 2) 27 | width = kwargs.pop("width", 0.1) # Agent width 28 | l_f = kwargs.pop( 29 | "l_f", 0.1 30 | ) # Distance between the front axle and the center of gravity 31 | l_r = kwargs.pop( 32 | "l_r", 0.1 33 | ) # Distance between the rear axle and the center of gravity 34 | max_steering_angle = kwargs.pop( 35 | "max_steering_angle", torch.deg2rad(torch.tensor(30.0)) 36 | ) 37 | max_speed = kwargs.pop("max_speed", 1.0) 38 | ScenarioUtils.check_kwargs_consumed(kwargs) 39 | 40 | # Make world 41 | world = World(batch_dim, device, substeps=10, collision_force=500) 42 | 43 | for i in range(self.n_agents): 44 | if i == 0: 45 | # Use the kinematic bicycle model for the first agent 46 | agent = Agent( 47 | name=f"bicycle_{i}", 48 | shape=Box(length=l_f + l_r, width=width), 49 | collide=True, 50 | render_action=True, 51 | u_range=[max_speed, max_steering_angle], 52 | u_multiplier=[1, 1], 53 | max_speed=max_speed, 54 | dynamics=KinematicBicycle( 55 | world, 56 | width=width, 57 | l_f=l_f, 58 | l_r=l_r, 59 | max_steering_angle=max_steering_angle, 60 | integration="euler", # one of "euler", "rk4" 61 | ), 62 | ) 63 | else: 64 | agent = Agent( 65 | name=f"holo_rot_{i}", 66 | shape=Box(length=l_f + l_r, width=width), 67 | collide=True, 68 | render_action=True, 69 | u_range=[1, 1, 1], 70 | u_multiplier=[1, 1, 0.001], 71 | dynamics=HolonomicWithRotation(), 72 | ) 73 | 74 | world.add_agent(agent) 75 | 76 | return world 77 | 78 | def reset_world_at(self, env_index: int = None): 79 | ScenarioUtils.spawn_entities_randomly( 80 | self.world.agents, 81 | self.world, 82 | env_index, 83 | min_dist_between_entities=0.1, 84 | x_bounds=(-1, 1), 85 | y_bounds=(-1, 1), 86 | ) 87 | 88 | def reward(self, agent: Agent): 89 | return torch.zeros(self.world.batch_dim) 90 | 91 | def observation(self, agent: Agent): 92 | observations = [ 93 | agent.state.pos, 94 | agent.state.vel, 95 | ] 96 | return torch.cat( 97 | observations, 98 | dim=-1, 99 | ) 100 | 101 | def extra_render(self, env_index: int = 0) -> "List[Geom]": 102 | 103 | geoms: List[Geom] = [] 104 | 105 | # Agent rotation 106 | for agent in self.world.agents: 107 | geoms.append( 108 | ScenarioUtils.plot_entity_rotation(agent, env_index, length=0.1) 109 | ) 110 | 111 | return geoms 112 | 113 | 114 | # ... and the code to run the simulation. 
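Note the two action layouts above: the bicycle agent is driven with [forward_speed, steering_angle], while the holonomic-with-rotation agents take [f_x, f_y, torque]. A hedged headless sketch that respects those widths is given here (the name "kinematic_bicycle" and the make_env call are assumptions); the interactive entry point follows right after.

import torch

from vmas import make_env

# Illustrative only: scenario name and kwargs mirror this file but are assumptions.
env = make_env(scenario="kinematic_bicycle", num_envs=8, device="cpu", n_agents=2)
obs = env.reset()
actions = [
    torch.zeros(8, 2),  # agent 0: [forward_speed, steering_angle]
    torch.zeros(8, 3),  # agent 1: [f_x, f_y, torque]
]
obs, rews, dones, info = env.step(actions)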
115 | if __name__ == "__main__": 116 | render_interactively( 117 | __file__, 118 | control_two_agents=True, 119 | width=0.1, 120 | l_f=0.1, 121 | l_r=0.1, 122 | display_info=True, 123 | ) 124 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/line_trajectory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | from typing import Dict 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from vmas import render_interactively 10 | from vmas.simulator.controllers.velocity_controller import VelocityController 11 | from vmas.simulator.core import Agent, Sphere, World 12 | from vmas.simulator.scenario import BaseScenario 13 | from vmas.simulator.utils import Color, ScenarioUtils, X, Y 14 | 15 | 16 | class Scenario(BaseScenario): 17 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 18 | self.obs_noise = kwargs.pop("obs_noise", 0) 19 | ScenarioUtils.check_kwargs_consumed(kwargs) 20 | 21 | self.agent_radius = 0.03 22 | self.line_length = 3 23 | 24 | # Make world 25 | world = World(batch_dim, device, drag=0.1) 26 | # Add agents 27 | self.agent = Agent( 28 | name="agent_0", 29 | shape=Sphere(self.agent_radius), 30 | mass=2, 31 | f_range=0.5, 32 | u_range=1, 33 | render_action=True, 34 | ) 35 | self.agent.controller = VelocityController( 36 | self.agent, world, [4, 1.25, 0.001], "standard" 37 | ) 38 | world.add_agent(self.agent) 39 | 40 | self.tangent = torch.zeros((world.batch_dim, world.dim_p), device=world.device) 41 | self.tangent[:, Y] = 1 42 | 43 | self.pos_rew = torch.zeros(batch_dim, device=device) 44 | self.dot_product = self.pos_rew.clone() 45 | self.steady_rew = self.pos_rew.clone() 46 | 47 | return world 48 | 49 | def process_action(self, agent: Agent): 50 | self.vel_action = agent.action.u.clone() 51 | agent.controller.process_force() 52 | 53 | def reset_world_at(self, env_index: int = None): 54 | self.agent.controller.reset(env_index) 55 | self.agent.set_pos( 56 | torch.cat( 57 | [ 58 | torch.zeros( 59 | (1, 1) if env_index is not None else (self.world.batch_dim, 1), 60 | device=self.world.device, 61 | dtype=torch.float32, 62 | ).uniform_( 63 | -1, 64 | 1, 65 | ), 66 | torch.zeros( 67 | (1, 1) if env_index is not None else (self.world.batch_dim, 1), 68 | device=self.world.device, 69 | dtype=torch.float32, 70 | ).uniform_( 71 | -1, 72 | 0, 73 | ), 74 | ], 75 | dim=1, 76 | ), 77 | batch_index=env_index, 78 | ) 79 | 80 | def reward(self, agent: Agent): 81 | closest_point = agent.state.pos.clone() 82 | closest_point[:, X] = 0 83 | self.pos_rew = ( 84 | -(torch.linalg.vector_norm(agent.state.pos - closest_point, dim=1) ** 0.5) 85 | * 1 86 | ) 87 | self.dot_product = torch.einsum("bs,bs->b", self.tangent, agent.state.vel) * 0.5 88 | 89 | normalized_vel = agent.state.vel / torch.linalg.vector_norm( 90 | agent.state.vel, dim=1 91 | ).unsqueeze(-1) 92 | normalized_vel = torch.nan_to_num(normalized_vel) 93 | 94 | normalized_vel_action = self.vel_action / torch.linalg.vector_norm( 95 | self.vel_action, dim=1 96 | ).unsqueeze(-1) 97 | normalized_vel_action = torch.nan_to_num(normalized_vel_action) 98 | 99 | self.steady_rew = ( 100 | torch.einsum("bs,bs->b", normalized_vel, normalized_vel_action) * 0.2 101 | ) 102 | 103 | return self.pos_rew + self.dot_product + self.steady_rew 104 | 105 | def observation(self, agent: Agent): 106 | observations = [agent.state.pos, agent.state.vel, 
agent.state.pos] 107 | for i, obs in enumerate(observations): 108 | noise = torch.zeros(*obs.shape, device=self.world.device,).uniform_( 109 | -self.obs_noise, 110 | self.obs_noise, 111 | ) 112 | observations[i] = obs + noise 113 | return torch.cat( 114 | observations, 115 | dim=-1, 116 | ) 117 | 118 | def done(self): 119 | return self.world.agents[0].state.pos[:, Y] > self.line_length - 1 120 | 121 | def info(self, agent: Agent) -> Dict[str, Tensor]: 122 | return { 123 | "pos_rew": self.pos_rew, 124 | "dot_product": self.dot_product, 125 | "steady_rew": self.steady_rew, 126 | } 127 | 128 | def extra_render(self, env_index: int = 0): 129 | from vmas.simulator import rendering 130 | 131 | geoms = [] 132 | 133 | # Trajectory goal circle 134 | color = Color.BLACK.value 135 | line = rendering.Line( 136 | (0, -1), 137 | (0, -1 + self.line_length), 138 | width=1, 139 | ) 140 | line.set_color(*color) 141 | geoms.append(line) 142 | 143 | return geoms 144 | 145 | 146 | if __name__ == "__main__": 147 | render_interactively(__file__) 148 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/pollock.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import torch 6 | 7 | from vmas import render_interactively 8 | from vmas.simulator.core import Agent, Box, Landmark, Line, Sphere, World 9 | from vmas.simulator.scenario import BaseScenario 10 | 11 | from vmas.simulator.sensors import Lidar 12 | from vmas.simulator.utils import Color, ScenarioUtils 13 | 14 | 15 | class Scenario(BaseScenario): 16 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 17 | self.n_agents = kwargs.pop("n_agents", 15) 18 | self.n_lines = kwargs.pop("n_lines", 15) 19 | self.n_boxes = kwargs.pop("n_boxes", 15) 20 | self.lidar = kwargs.pop("lidar", False) 21 | self.vectorized_lidar = kwargs.pop("vectorized_lidar", True) 22 | ScenarioUtils.check_kwargs_consumed(kwargs) 23 | 24 | self.agent_radius = 0.05 25 | self.line_length = 0.3 26 | self.box_length = 0.2 27 | self.box_width = 0.1 28 | 29 | self.world_semidim = 1 30 | self.min_dist_between_entities = 0.1 31 | 32 | # Make world 33 | world = World( 34 | batch_dim, 35 | device, 36 | dt=0.1, 37 | drag=0.25, 38 | substeps=5, 39 | collision_force=500, 40 | x_semidim=self.world_semidim, 41 | y_semidim=self.world_semidim, 42 | ) 43 | # Add agents 44 | for i in range(self.n_agents): 45 | agent = Agent( 46 | name=f"agent_{i}", 47 | shape=Sphere(radius=self.agent_radius), 48 | u_multiplier=0.7, 49 | rotatable=True, 50 | sensors=[Lidar(world, n_rays=16, max_range=0.5)] if self.lidar else [], 51 | ) 52 | world.add_agent(agent) 53 | 54 | # Add lines 55 | for i in range(self.n_lines): 56 | landmark = Landmark( 57 | name=f"line {i}", 58 | collide=True, 59 | movable=True, 60 | rotatable=True, 61 | shape=Line(length=self.line_length), 62 | color=Color.BLACK, 63 | ) 64 | world.add_landmark(landmark) 65 | for i in range(self.n_boxes): 66 | landmark = Landmark( 67 | name=f"box {i}", 68 | collide=True, 69 | movable=True, 70 | rotatable=True, 71 | shape=Box(length=self.box_length, width=self.box_width), 72 | color=Color.RED, 73 | ) 74 | world.add_landmark(landmark) 75 | 76 | return world 77 | 78 | def reset_world_at(self, env_index: int = None): 79 | # Some things may be spawn on top of each other 80 | ScenarioUtils.spawn_entities_randomly( 81 | self.world.agents + self.world.landmarks, 82 | 
self.world, 83 | env_index, 84 | self.min_dist_between_entities, 85 | (-self.world_semidim, self.world_semidim), 86 | (-self.world_semidim, self.world_semidim), 87 | ) 88 | 89 | def reward(self, agent: Agent): 90 | return torch.zeros(self.world.batch_dim, device=self.world.device) 91 | 92 | def observation(self, agent: Agent): 93 | return ( 94 | torch.zeros(self.world.batch_dim, 1, device=self.world.device) 95 | if not self.lidar 96 | else agent.sensors[0].measure(vectorized=self.vectorized_lidar) 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | render_interactively( 102 | __file__, 103 | control_two_agents=True, 104 | ) 105 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/vel_control.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | from typing import Dict 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from vmas import render_interactively 10 | from vmas.simulator.controllers.velocity_controller import VelocityController 11 | from vmas.simulator.core import Agent, Landmark, World 12 | from vmas.simulator.scenario import BaseScenario 13 | from vmas.simulator.utils import Color, ScenarioUtils, TorchUtils, X 14 | 15 | 16 | class Scenario(BaseScenario): 17 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 18 | self.green_mass = kwargs.pop("green_mass", 1) 19 | ScenarioUtils.check_kwargs_consumed(kwargs) 20 | self.plot_grid = True 21 | 22 | self.agent_radius = 0.16 23 | 24 | controller_params = [2, 6, 0.002] 25 | 26 | linear_friction = 0.1 27 | v_range = 1 28 | a_range = 1 29 | f_range = linear_friction + a_range 30 | 31 | # u_range now represents velocities since we are preprocessing actions using a velocity controller 32 | u_range = v_range 33 | 34 | # Make world 35 | world = World( 36 | batch_dim, 37 | device, 38 | linear_friction=linear_friction, 39 | drag=0, 40 | dt=0.05, 41 | substeps=4, 42 | ) 43 | 44 | null_action = torch.zeros(world.batch_dim, world.dim_p, device=world.device) 45 | self.input_queue = [null_action.clone() for _ in range(2)] 46 | # control delayed by n dts 47 | 48 | # Add agents 49 | agent = Agent( 50 | name="agent 0", 51 | collide=False, 52 | color=Color.GREEN, 53 | render_action=True, 54 | mass=self.green_mass, 55 | f_range=f_range, 56 | u_range=u_range, 57 | ) 58 | agent.controller = VelocityController( 59 | agent, world, controller_params, "standard" 60 | ) 61 | world.add_agent(agent) 62 | agent = Agent( 63 | name="agent 1", 64 | collide=False, 65 | render_action=True, 66 | # f_range=30, 67 | u_range=u_range, 68 | ) 69 | agent.controller = VelocityController( 70 | agent, world, controller_params, "standard" 71 | ) 72 | world.add_agent(agent) 73 | agent = Agent( 74 | name="agent 2", 75 | collide=False, 76 | render_action=True, 77 | f_range=30, 78 | u_range=u_range, 79 | ) 80 | agent.controller = VelocityController( 81 | agent, world, controller_params, "standard" 82 | ) 83 | world.add_agent(agent) 84 | 85 | self.landmark = Landmark("landmark 0", collide=False, movable=True) 86 | world.add_landmark(self.landmark) 87 | 88 | self.energy_expenditure = torch.zeros(batch_dim, device=device) 89 | 90 | return world 91 | 92 | def reset_world_at(self, env_index: int = None): 93 | for agent in self.world.agents: 94 | agent.controller.reset(env_index) 95 | agent.set_pos( 96 | torch.cat( 97 | [ 98 | torch.zeros( 99 | ( 100 | (1, 1) 101 | if 
env_index is not None 102 | else (self.world.batch_dim, 1) 103 | ), 104 | device=self.world.device, 105 | dtype=torch.float32, 106 | ).uniform_( 107 | -1, 108 | -1, 109 | ), 110 | torch.zeros( 111 | ( 112 | (1, 1) 113 | if env_index is not None 114 | else (self.world.batch_dim, 1) 115 | ), 116 | device=self.world.device, 117 | dtype=torch.float32, 118 | ).uniform_( 119 | 0, 120 | 0, 121 | ), 122 | ], 123 | dim=1, 124 | ), 125 | batch_index=env_index, 126 | ) 127 | 128 | def process_action(self, agent: Agent): 129 | # Clamp square to circle 130 | agent.action.u = TorchUtils.clamp_with_norm(agent.action.u, agent.u_range) 131 | 132 | # Zero small input 133 | action_norm = torch.linalg.vector_norm(agent.action.u, dim=1) 134 | agent.action.u[action_norm < 0.08] = 0 135 | 136 | # agent.action.u[:, Y] = 0 137 | if agent == self.world.agents[1]: 138 | max_a = 1 139 | 140 | agent.vel_goal = agent.action.u[:, X] 141 | requested_a = (agent.vel_goal - agent.state.vel[:, X]) / self.world.dt 142 | achievable_a = torch.clamp(requested_a, -max_a, max_a) 143 | agent.action.u[:, X] = (achievable_a * self.world.dt) + agent.state.vel[ 144 | :, X 145 | ] 146 | 147 | agent.controller.process_force() 148 | 149 | def reward(self, agent: Agent): 150 | is_first = agent == self.world.agents[0] 151 | 152 | if is_first: 153 | self.energy_expenditure = ( 154 | -torch.stack( 155 | [ 156 | torch.linalg.vector_norm(a.action.u, dim=-1) 157 | for a in self.world.agents 158 | ], 159 | dim=1, 160 | ).sum(-1) 161 | * 3 162 | ) 163 | 164 | return self.energy_expenditure 165 | 166 | def observation(self, agent: Agent): 167 | return torch.cat( 168 | [agent.state.pos, agent.state.vel], 169 | dim=-1, 170 | ) 171 | 172 | def info(self, agent: Agent) -> Dict[str, Tensor]: 173 | return { 174 | "energy_expenditure": self.energy_expenditure, 175 | } 176 | 177 | 178 | if __name__ == "__main__": 179 | render_interactively(__file__, control_two_agents=True) 180 | -------------------------------------------------------------------------------- /vmas/scenarios/debug/waterfall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
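This debug scenario drops a chain of jointed agents onto tilted boxes and a floor. A minimal headless sketch of driving it with random forces; the name "waterfall" and the make_env call are assumptions, while the kwargs mirror the ones consumed in make_world below.

import torch

from vmas import make_env

# Illustrative sketch; scenario name and env API are assumptions.
env = make_env(scenario="waterfall", num_envs=16, device="cpu", n_agents=5, joints=True)
obs = env.reset()
for _ in range(200):
    # Each agent takes a 2-D force action; sample uniformly in [-1, 1].
    actions = [torch.rand(16, 2) * 2 - 1 for _ in env.agents]
    obs, rews, dones, info = env.step(actions)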
4 | 5 | import torch 6 | 7 | from vmas import render_interactively 8 | from vmas.simulator.core import Agent, Box, Landmark, Line, Sphere, World 9 | from vmas.simulator.joints import Joint 10 | from vmas.simulator.scenario import BaseScenario 11 | from vmas.simulator.utils import Color, ScenarioUtils 12 | 13 | 14 | class Scenario(BaseScenario): 15 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 16 | self.n_agents = kwargs.pop("n_agents", 5) 17 | self.with_joints = kwargs.pop("joints", True) 18 | ScenarioUtils.check_kwargs_consumed(kwargs) 19 | 20 | self.agent_dist = 0.1 21 | self.agent_radius = 0.04 22 | 23 | # Make world 24 | world = World( 25 | batch_dim, 26 | device, 27 | dt=0.1, 28 | drag=0.25, 29 | substeps=5, 30 | collision_force=500, 31 | ) 32 | # Add agents 33 | for i in range(self.n_agents): 34 | agent = Agent( 35 | name=f"agent_{i}", 36 | shape=Sphere(radius=self.agent_radius), 37 | u_multiplier=0.7, 38 | rotatable=True, 39 | ) 40 | world.add_agent(agent) 41 | if self.with_joints: 42 | # Add joints 43 | for i in range(self.n_agents - 1): 44 | joint = Joint( 45 | world.agents[i], 46 | world.agents[i + 1], 47 | anchor_a=(1, 0), 48 | anchor_b=(-1, 0), 49 | dist=self.agent_dist, 50 | rotate_a=True, 51 | rotate_b=True, 52 | collidable=True, 53 | width=0, 54 | mass=1, 55 | ) 56 | world.add_joint(joint) 57 | landmark = Landmark( 58 | name="joined landmark", 59 | collide=True, 60 | movable=True, 61 | rotatable=True, 62 | shape=Box(length=self.agent_radius * 2, width=0.3), 63 | color=Color.GREEN, 64 | ) 65 | world.add_landmark(landmark) 66 | joint = Joint( 67 | world.agents[-1], 68 | landmark, 69 | anchor_a=(1, 0), 70 | anchor_b=(-1, 0), 71 | dist=self.agent_dist, 72 | rotate_a=False, 73 | rotate_b=False, 74 | collidable=True, 75 | width=0, 76 | mass=1, 77 | ) 78 | world.add_joint(joint) 79 | 80 | # Add landmarks 81 | for i in range(5): 82 | landmark = Landmark( 83 | name=f"landmark {i}", 84 | collide=True, 85 | movable=True, 86 | rotatable=True, 87 | shape=Box(length=0.3, width=0.1), 88 | color=Color.RED, 89 | # collision_filter=lambda e: False 90 | # if isinstance(e.shape, Box) and e.name != "joined landmark" 91 | # else True, 92 | ) 93 | world.add_landmark(landmark) 94 | floor = Landmark( 95 | name="floor", 96 | collide=True, 97 | movable=False, 98 | shape=Line(length=2), 99 | color=Color.BLACK, 100 | ) 101 | world.add_landmark(floor) 102 | 103 | return world 104 | 105 | def reset_world_at(self, env_index: int = None): 106 | for i, agent in enumerate( 107 | self.world.agents + [self.world.landmarks[self.n_agents - 1]] 108 | ): 109 | agent.set_pos( 110 | torch.tensor( 111 | [ 112 | -0.2 + (self.agent_dist + 2 * self.agent_radius) * i, 113 | 1.0, 114 | ], 115 | dtype=torch.float32, 116 | device=self.world.device, 117 | ), 118 | batch_index=env_index, 119 | ) 120 | for i, landmark in enumerate( 121 | self.world.landmarks[(self.n_agents + 1) if self.with_joints else 0 : -1] 122 | ): 123 | landmark.set_pos( 124 | torch.tensor( 125 | [0.2 if i % 2 else -0.2, 0.6 - 0.3 * i], 126 | dtype=torch.float32, 127 | device=self.world.device, 128 | ), 129 | batch_index=env_index, 130 | ) 131 | landmark.set_rot( 132 | torch.tensor( 133 | [torch.pi / 4 if i % 2 else -torch.pi / 4], 134 | dtype=torch.float32, 135 | device=self.world.device, 136 | ), 137 | batch_index=env_index, 138 | ) 139 | floor = self.world.landmarks[-1] 140 | floor.set_pos( 141 | torch.tensor( 142 | [0, -1], 143 | dtype=torch.float32, 144 | device=self.world.device, 145 | ), 146 | batch_index=env_index, 147 | ) 
148 | 149 | def reward(self, agent: Agent): 150 | dist2 = torch.linalg.vector_norm( 151 | agent.state.pos - self.world.landmarks[-1].state.pos, dim=1 152 | ) 153 | return -dist2 154 | 155 | def observation(self, agent: Agent): 156 | # get positions of all entities in this agent's reference frame 157 | return torch.cat( 158 | [agent.state.pos, agent.state.vel] 159 | + [ 160 | landmark.state.pos - agent.state.pos 161 | for landmark in self.world.landmarks 162 | ], 163 | dim=-1, 164 | ) 165 | 166 | 167 | if __name__ == "__main__": 168 | render_interactively( 169 | __file__, 170 | control_two_agents=True, 171 | n_agents=5, 172 | joints=True, 173 | ) 174 | -------------------------------------------------------------------------------- /vmas/scenarios/mpe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | -------------------------------------------------------------------------------- /vmas/scenarios/mpe/simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import torch 6 | 7 | from vmas import render_interactively 8 | from vmas.simulator.core import Agent, Landmark, World 9 | from vmas.simulator.scenario import BaseScenario 10 | from vmas.simulator.utils import Color, ScenarioUtils 11 | 12 | 13 | class Scenario(BaseScenario): 14 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 15 | ScenarioUtils.check_kwargs_consumed(kwargs) 16 | # Make world 17 | world = World(batch_dim, device) 18 | # Add agents 19 | for i in range(1): 20 | agent = Agent(name=f"agent_{i}", collide=False, color=Color.GRAY) 21 | world.add_agent(agent) 22 | # Add landmarks 23 | for i in range(1): 24 | landmark = Landmark( 25 | name=f"landmark {i}", 26 | collide=False, 27 | color=Color.RED, 28 | ) 29 | world.add_landmark(landmark) 30 | 31 | return world 32 | 33 | def reset_world_at(self, env_index: int = None): 34 | for agent in self.world.agents: 35 | agent.set_pos( 36 | torch.zeros( 37 | ( 38 | (1, self.world.dim_p) 39 | if env_index is not None 40 | else (self.world.batch_dim, self.world.dim_p) 41 | ), 42 | device=self.world.device, 43 | dtype=torch.float32, 44 | ).uniform_( 45 | -1.0, 46 | 1.0, 47 | ), 48 | batch_index=env_index, 49 | ) 50 | for landmark in self.world.landmarks: 51 | landmark.set_pos( 52 | torch.zeros( 53 | ( 54 | (1, self.world.dim_p) 55 | if env_index is not None 56 | else (self.world.batch_dim, self.world.dim_p) 57 | ), 58 | device=self.world.device, 59 | dtype=torch.float32, 60 | ).uniform_( 61 | -1.0, 62 | 1.0, 63 | ), 64 | batch_index=env_index, 65 | ) 66 | 67 | def reward(self, agent: Agent): 68 | dist2 = torch.sum( 69 | torch.square(agent.state.pos - self.world.landmarks[0].state.pos), 70 | dim=-1, 71 | ) 72 | return -dist2 73 | 74 | def observation(self, agent: Agent): 75 | # get positions of all entities in this agent's reference frame 76 | entity_pos = [] 77 | for entity in self.world.landmarks: 78 | entity_pos.append(entity.state.pos - agent.state.pos) 79 | return torch.cat([agent.state.vel, *entity_pos], dim=-1) 80 | 81 | 82 | if __name__ == "__main__": 83 | render_interactively(__file__) 84 | -------------------------------------------------------------------------------- /vmas/scenarios/mpe/simple_reference.py: 
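The reward in simple.py above is just the negative squared distance between the single agent and its landmark, r = -||p_agent - p_landmark||^2; a toy check with made-up positions:

import torch

# Illustrative numbers only.
agent_pos = torch.tensor([[0.3, -0.2]])      # [batch=1, 2]
landmark_pos = torch.tensor([[0.0, 0.1]])
reward = -torch.sum(torch.square(agent_pos - landmark_pos), dim=-1)
# reward == tensor([-0.1800]): -(0.3**2 + (-0.3)**2)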
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import torch 6 | 7 | from vmas.simulator.core import Agent, Landmark, World 8 | from vmas.simulator.scenario import BaseScenario 9 | from vmas.simulator.utils import ScenarioUtils 10 | 11 | 12 | class Scenario(BaseScenario): 13 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 14 | ScenarioUtils.check_kwargs_consumed(kwargs) 15 | world = World(batch_dim=batch_dim, device=device, dim_c=10) 16 | 17 | n_agents = 2 18 | n_landmarks = 3 19 | 20 | # Add agents 21 | for i in range(n_agents): 22 | agent = Agent(name=f"agent_{i}", collide=False, silent=False) 23 | world.add_agent(agent) 24 | # Add landmarks 25 | for i in range(n_landmarks): 26 | landmark = Landmark( 27 | name=f"landmark {i}", 28 | collide=False, 29 | ) 30 | world.add_landmark(landmark) 31 | 32 | return world 33 | 34 | def reset_world_at(self, env_index: int = None): 35 | if env_index is None: 36 | # assign goals to agents 37 | for agent in self.world.agents: 38 | agent.goal_a = None 39 | agent.goal_b = None 40 | # want other agent to go to the goal landmark 41 | self.world.agents[0].goal_a = self.world.agents[1] 42 | self.world.agents[0].goal_b = self.world.landmarks[ 43 | torch.randint(0, len(self.world.landmarks), (1,)).item() 44 | ] 45 | self.world.agents[1].goal_a = self.world.agents[0] 46 | self.world.agents[1].goal_b = self.world.landmarks[ 47 | torch.randint(0, len(self.world.landmarks), (1,)).item() 48 | ] 49 | # random properties for agents 50 | for agent in self.world.agents: 51 | agent.color = torch.tensor( 52 | [0.25, 0.25, 0.25], 53 | device=self.world.device, 54 | dtype=torch.float32, 55 | ) 56 | # random properties for landmarks 57 | self.world.landmarks[0].color = torch.tensor( 58 | [0.75, 0.25, 0.25], 59 | device=self.world.device, 60 | dtype=torch.float32, 61 | ) 62 | self.world.landmarks[1].color = torch.tensor( 63 | [0.25, 0.75, 0.25], 64 | device=self.world.device, 65 | dtype=torch.float32, 66 | ) 67 | self.world.landmarks[2].color = torch.tensor( 68 | [0.25, 0.25, 0.75], 69 | device=self.world.device, 70 | dtype=torch.float32, 71 | ) 72 | # special colors for goals 73 | self.world.agents[0].goal_a.color = self.world.agents[0].goal_b.color 74 | self.world.agents[1].goal_a.color = self.world.agents[1].goal_b.color 75 | 76 | # set random initial states 77 | for agent in self.world.agents: 78 | agent.set_pos( 79 | torch.zeros( 80 | ( 81 | (1, self.world.dim_p) 82 | if env_index is not None 83 | else (self.world.batch_dim, self.world.dim_p) 84 | ), 85 | device=self.world.device, 86 | dtype=torch.float32, 87 | ).uniform_( 88 | -1.0, 89 | 1.0, 90 | ), 91 | batch_index=env_index, 92 | ) 93 | for landmark in self.world.landmarks: 94 | landmark.set_pos( 95 | torch.zeros( 96 | ( 97 | (1, self.world.dim_p) 98 | if env_index is not None 99 | else (self.world.batch_dim, self.world.dim_p) 100 | ), 101 | device=self.world.device, 102 | dtype=torch.float32, 103 | ).uniform_( 104 | -1.0, 105 | 1.0, 106 | ), 107 | batch_index=env_index, 108 | ) 109 | 110 | def reward(self, agent: Agent): 111 | is_first = agent == self.world.agents[0] 112 | if is_first: 113 | self.rew = torch.zeros(self.world.batch_dim, device=self.world.device) 114 | for a in self.world.agents: 115 | if a.goal_a is None or a.goal_b is None: 116 | return torch.zeros( 117 | self.world.batch_dim, 118 | device=self.world.device, 119 | dtype=torch.float32, 
120 | ) 121 | self.rew += -torch.sqrt( 122 | torch.sum( 123 | torch.square(a.goal_a.state.pos - a.goal_b.state.pos), 124 | dim=-1, 125 | ) 126 | ) 127 | return self.rew 128 | 129 | def observation(self, agent: Agent): 130 | # goal color 131 | goal_color = agent.goal_b.color 132 | 133 | # get positions of all entities in this agent's reference frame 134 | entity_pos = [] 135 | for entity in self.world.landmarks: 136 | entity_pos.append(entity.state.pos - agent.state.pos) 137 | 138 | # communication of all other agents 139 | comm = [] 140 | for other in self.world.agents: 141 | if other is agent: 142 | continue 143 | comm.append(other.state.c) 144 | return torch.cat( 145 | [ 146 | agent.state.vel, 147 | *entity_pos, 148 | goal_color.repeat(self.world.batch_dim, 1), 149 | *comm, 150 | ], 151 | dim=-1, 152 | ) 153 | -------------------------------------------------------------------------------- /vmas/scenarios/mpe/simple_speaker_listener.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import torch 6 | 7 | from vmas.simulator.core import Agent, Landmark, Sphere, World 8 | from vmas.simulator.scenario import BaseScenario 9 | from vmas.simulator.utils import ScenarioUtils 10 | 11 | 12 | class Scenario(BaseScenario): 13 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 14 | ScenarioUtils.check_kwargs_consumed(kwargs) 15 | world = World(batch_dim=batch_dim, device=device, dim_c=3) 16 | # set any world properties first 17 | num_agents = 2 18 | num_landmarks = 3 19 | 20 | # Add agents 21 | for i in range(num_agents): 22 | speaker = True if i == 0 else False 23 | name = "speaker_0" if speaker else "listener_0" 24 | agent = Agent( 25 | name=name, 26 | collide=False, 27 | movable=False if speaker else True, 28 | silent=False if speaker else True, 29 | shape=Sphere(radius=0.075), 30 | ) 31 | world.add_agent(agent) 32 | # Add landmarks 33 | for i in range(num_landmarks): 34 | landmark = Landmark( 35 | name=f"landmark {i}", collide=False, shape=Sphere(radius=0.04) 36 | ) 37 | world.add_landmark(landmark) 38 | 39 | return world 40 | 41 | def reset_world_at(self, env_index: int = None): 42 | if env_index is None: 43 | # assign goals to agents 44 | for agent in self.world.agents: 45 | agent.goal_a = None 46 | agent.goal_b = None 47 | # want listener to go to the goal landmark 48 | self.world.agents[0].goal_a = self.world.agents[1] 49 | self.world.agents[0].goal_b = self.world.landmarks[ 50 | torch.randint(0, len(self.world.landmarks), (1,)).item() 51 | ] 52 | # random properties for agents 53 | for agent in self.world.agents: 54 | agent.color = torch.tensor( 55 | [0.25, 0.25, 0.25], 56 | device=self.world.device, 57 | dtype=torch.float32, 58 | ) 59 | # random properties for landmarks 60 | self.world.landmarks[0].color = torch.tensor( 61 | [0.65, 0.15, 0.15], 62 | device=self.world.device, 63 | dtype=torch.float32, 64 | ) 65 | self.world.landmarks[1].color = torch.tensor( 66 | [0.15, 0.65, 0.15], 67 | device=self.world.device, 68 | dtype=torch.float32, 69 | ) 70 | self.world.landmarks[2].color = torch.tensor( 71 | [0.15, 0.15, 0.65], 72 | device=self.world.device, 73 | dtype=torch.float32, 74 | ) 75 | # special colors for goals 76 | self.world.agents[0].goal_a.color = self.world.agents[ 77 | 0 78 | ].goal_b.color + torch.tensor( 79 | [0.45, 0.45, 0.45], 80 | device=self.world.device, 81 | dtype=torch.float32, 82 | ) 83 | 84 | # set random 
initial states 85 | for agent in self.world.agents: 86 | agent.set_pos( 87 | torch.zeros( 88 | ( 89 | (1, self.world.dim_p) 90 | if env_index is not None 91 | else (self.world.batch_dim, self.world.dim_p) 92 | ), 93 | device=self.world.device, 94 | dtype=torch.float32, 95 | ).uniform_( 96 | -1.0, 97 | 1.0, 98 | ), 99 | batch_index=env_index, 100 | ) 101 | for landmark in self.world.landmarks: 102 | landmark.set_pos( 103 | torch.zeros( 104 | ( 105 | (1, self.world.dim_p) 106 | if env_index is not None 107 | else (self.world.batch_dim, self.world.dim_p) 108 | ), 109 | device=self.world.device, 110 | dtype=torch.float32, 111 | ).uniform_( 112 | -1.0, 113 | 1.0, 114 | ), 115 | batch_index=env_index, 116 | ) 117 | 118 | def reward(self, agent: Agent): 119 | # squared distance from listener to landmark 120 | is_first = agent == self.world.agents[0] 121 | if is_first: 122 | self.rew = torch.zeros(self.world.batch_dim, device=self.world.device) 123 | for _ in self.world.agents: 124 | a = self.world.agents[0] 125 | self.rew += -torch.sqrt( 126 | torch.sum( 127 | torch.square(a.goal_a.state.pos - a.goal_b.state.pos), 128 | dim=-1, 129 | ) 130 | ) 131 | return self.rew 132 | 133 | def observation(self, agent): 134 | # goal color 135 | goal_color = torch.zeros(3, device=self.world.device, dtype=torch.float32) 136 | if agent.goal_b is not None: 137 | goal_color = agent.goal_b.color 138 | 139 | # get positions of all entities in this agent's reference frame 140 | entity_pos = [] 141 | for entity in self.world.landmarks: 142 | entity_pos.append(entity.state.pos - agent.state.pos) 143 | 144 | # communication of all other agents 145 | comm = [] 146 | for other in self.world.agents: 147 | if other is agent or (other.state.c is None): 148 | continue 149 | comm.append(other.state.c) 150 | 151 | # speaker 152 | if not agent.movable: 153 | return goal_color.repeat(self.world.batch_dim, 1) 154 | # listener 155 | if agent.silent: 156 | return torch.cat([agent.state.vel, *entity_pos, *comm], dim=-1) 157 | -------------------------------------------------------------------------------- /vmas/scenarios/mpe/simple_spread.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
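This is the VMAS port of MPE simple_spread. A hedged sketch of creating it with its two kwargs and checking the observation width; the name "simple_spread" and the make_env call are assumptions.

from vmas import make_env

# Illustrative sketch; scenario name and env API are assumptions.
env = make_env(
    scenario="simple_spread",
    num_envs=4,
    device="cpu",
    n_agents=3,
    obs_agents=False,  # drop the other agents' relative positions from observations
)
obs = env.reset()
# pos (2) + vel (2) + 3 landmarks * 2 = 10 dims per agent when obs_agents=False.
assert obs[0].shape == (4, 10)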
4 | 5 | import torch 6 | 7 | from vmas import render_interactively 8 | from vmas.simulator.core import Agent, Landmark, Sphere, World 9 | from vmas.simulator.scenario import BaseScenario 10 | from vmas.simulator.utils import Color, ScenarioUtils 11 | 12 | 13 | class Scenario(BaseScenario): 14 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 15 | num_agents = kwargs.pop("n_agents", 3) 16 | obs_agents = kwargs.pop("obs_agents", True) 17 | ScenarioUtils.check_kwargs_consumed(kwargs) 18 | 19 | self.obs_agents = obs_agents 20 | 21 | world = World(batch_dim=batch_dim, device=device) 22 | # set any world properties first 23 | num_landmarks = num_agents 24 | # Add agents 25 | for i in range(num_agents): 26 | agent = Agent( 27 | name=f"agent_{i}", 28 | collide=True, 29 | shape=Sphere(radius=0.15), 30 | color=Color.BLUE, 31 | ) 32 | world.add_agent(agent) 33 | # Add landmarks 34 | for i in range(num_landmarks): 35 | landmark = Landmark( 36 | name=f"landmark {i}", 37 | collide=False, 38 | color=Color.BLACK, 39 | ) 40 | world.add_landmark(landmark) 41 | 42 | return world 43 | 44 | def reset_world_at(self, env_index: int = None): 45 | for agent in self.world.agents: 46 | agent.set_pos( 47 | torch.zeros( 48 | ( 49 | (1, self.world.dim_p) 50 | if env_index is not None 51 | else (self.world.batch_dim, self.world.dim_p) 52 | ), 53 | device=self.world.device, 54 | dtype=torch.float32, 55 | ).uniform_( 56 | -1.0, 57 | 1.0, 58 | ), 59 | batch_index=env_index, 60 | ) 61 | 62 | for landmark in self.world.landmarks: 63 | landmark.set_pos( 64 | torch.zeros( 65 | ( 66 | (1, self.world.dim_p) 67 | if env_index is not None 68 | else (self.world.batch_dim, self.world.dim_p) 69 | ), 70 | device=self.world.device, 71 | dtype=torch.float32, 72 | ).uniform_( 73 | -1.0, 74 | 1.0, 75 | ), 76 | batch_index=env_index, 77 | ) 78 | 79 | def reward(self, agent: Agent): 80 | is_first = agent == self.world.agents[0] 81 | if is_first: 82 | # Agents are rewarded based on minimum agent distance to each landmark, penalized for collisions 83 | self.rew = torch.zeros( 84 | self.world.batch_dim, 85 | device=self.world.device, 86 | dtype=torch.float32, 87 | ) 88 | for single_agent in self.world.agents: 89 | for landmark in self.world.landmarks: 90 | closest = torch.min( 91 | torch.stack( 92 | [ 93 | torch.linalg.vector_norm( 94 | a.state.pos - landmark.state.pos, dim=1 95 | ) 96 | for a in self.world.agents 97 | ], 98 | dim=-1, 99 | ), 100 | dim=-1, 101 | )[0] 102 | self.rew -= closest 103 | 104 | if single_agent.collide: 105 | for a in self.world.agents: 106 | if a != single_agent: 107 | self.rew[self.world.is_overlapping(a, single_agent)] -= 1 108 | 109 | return self.rew 110 | 111 | def observation(self, agent: Agent): 112 | # get positions of all landmarks in this agent's reference frame 113 | landmark_pos = [] 114 | for landmark in self.world.landmarks: # world.entities: 115 | landmark_pos.append(landmark.state.pos - agent.state.pos) 116 | # distance to all other agents 117 | other_pos = [] 118 | for other in self.world.agents: 119 | if other != agent: 120 | other_pos.append(other.state.pos - agent.state.pos) 121 | return torch.cat( 122 | [ 123 | agent.state.pos, 124 | agent.state.vel, 125 | *landmark_pos, 126 | *(other_pos if self.obs_agents else []), 127 | ], 128 | dim=-1, 129 | ) 130 | 131 | 132 | if __name__ == "__main__": 133 | render_interactively(__file__, control_two_agents=True) 134 | -------------------------------------------------------------------------------- /vmas/scenarios/wheel.py: 
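In simple_spread above, all agents share one reward: for every landmark the distance to the closest agent is subtracted (once per agent in the outer loop), and each pairwise overlap costs an additional -1. A toy check of the per-landmark distance term only:

import torch

# Illustrative numbers: 2 agents, 2 landmarks, a batch of 1 environment.
agents = torch.tensor([[[0.0, 0.0]], [[1.0, 0.0]]])      # [n_agents, batch, 2]
landmarks = torch.tensor([[[0.0, 1.0]], [[1.0, 1.0]]])   # [n_landmarks, batch, 2]

rew = torch.zeros(1)
for landmark in landmarks:
    closest = torch.stack(
        [torch.linalg.vector_norm(a - landmark, dim=1) for a in agents], dim=-1
    ).min(dim=-1).values
    rew -= closest
# Each landmark has an agent exactly 1.0 away, so rew == tensor([-2.])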
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | import torch 6 | 7 | from vmas import render_interactively 8 | from vmas.simulator.core import Agent, Landmark, Line, Sphere, World 9 | from vmas.simulator.heuristic_policy import BaseHeuristicPolicy 10 | from vmas.simulator.scenario import BaseScenario 11 | from vmas.simulator.utils import Color, ScenarioUtils, TorchUtils 12 | 13 | 14 | class Scenario(BaseScenario): 15 | def make_world(self, batch_dim: int, device: torch.device, **kwargs): 16 | n_agents = kwargs.pop("n_agents", 4) 17 | self.line_length = kwargs.pop("line_length", 2) 18 | line_mass = kwargs.pop("line_mass", 30) 19 | self.desired_velocity = kwargs.pop("desired_velocity", 0.05) 20 | ScenarioUtils.check_kwargs_consumed(kwargs) 21 | 22 | # Make world 23 | world = World(batch_dim, device) 24 | # Add agents 25 | for i in range(n_agents): 26 | # Constraint: all agents have same action range and multiplier 27 | agent = Agent(name=f"agent_{i}", u_multiplier=0.6, shape=Sphere(0.03)) 28 | world.add_agent(agent) 29 | # Add landmarks 30 | self.line = Landmark( 31 | name="line", 32 | collide=True, 33 | rotatable=True, 34 | shape=Line(length=self.line_length), 35 | mass=line_mass, 36 | color=Color.BLACK, 37 | ) 38 | world.add_landmark(self.line) 39 | center = Landmark( 40 | name="center", 41 | shape=Sphere(radius=0.02), 42 | collide=False, 43 | color=Color.BLACK, 44 | ) 45 | world.add_landmark(center) 46 | 47 | return world 48 | 49 | def reset_world_at(self, env_index: int = None): 50 | for agent in self.world.agents: 51 | # Random pos between -1 and 1 52 | agent.set_pos( 53 | torch.zeros( 54 | ( 55 | (1, self.world.dim_p) 56 | if env_index is not None 57 | else (self.world.batch_dim, self.world.dim_p) 58 | ), 59 | device=self.world.device, 60 | dtype=torch.float32, 61 | ).uniform_( 62 | -1.0, 63 | 1.0, 64 | ), 65 | batch_index=env_index, 66 | ) 67 | 68 | self.line.set_rot( 69 | torch.zeros( 70 | (1, 1) if env_index is not None else (self.world.batch_dim, 1), 71 | device=self.world.device, 72 | dtype=torch.float32, 73 | ).uniform_( 74 | -torch.pi / 2, 75 | torch.pi / 2, 76 | ), 77 | batch_index=env_index, 78 | ) 79 | 80 | def reward(self, agent: Agent): 81 | is_first = agent == self.world.agents[0] 82 | 83 | if is_first: 84 | self.rew = (self.line.state.ang_vel.abs() - self.desired_velocity).abs() 85 | 86 | return -self.rew 87 | 88 | def observation(self, agent: Agent): 89 | line_end_1 = torch.cat( 90 | [ 91 | (self.line_length / 2) * torch.cos(self.line.state.rot), 92 | (self.line_length / 2) * torch.sin(self.line.state.rot), 93 | ], 94 | dim=1, 95 | ) 96 | line_end_2 = -line_end_1 97 | 98 | return torch.cat( 99 | [ 100 | agent.state.pos, 101 | agent.state.vel, 102 | self.line.state.pos - agent.state.pos, 103 | line_end_1 - agent.state.pos, 104 | line_end_2 - agent.state.pos, 105 | self.line.state.rot % torch.pi, 106 | self.line.state.ang_vel.abs(), 107 | (self.line.state.ang_vel.abs() - self.desired_velocity).abs(), 108 | ], 109 | dim=-1, 110 | ) 111 | 112 | 113 | class HeuristicPolicy(BaseHeuristicPolicy): 114 | def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Tensor: 115 | assert self.continuous_actions is True, "Heuristic for continuous actions only" 116 | 117 | index_line_extrema = 6 118 | 119 | pos_agent = observation[:, :2] 120 | pos_end2_agent = observation[:, index_line_extrema + 2 : index_line_extrema + 4] 121 
| 122 | pos_end2 = pos_end2_agent + pos_agent 123 | 124 | pos_end2_shifted = TorchUtils.rotate_vector( 125 | pos_end2, 126 | torch.tensor(torch.pi / 4, device=observation.device).expand( 127 | pos_end2.shape[0] 128 | ), 129 | ) 130 | 131 | pos_end2_shifted_agent = pos_end2_shifted - pos_agent 132 | 133 | action_agent = torch.clamp( 134 | pos_end2_shifted_agent, 135 | min=-u_range, 136 | max=u_range, 137 | ) 138 | 139 | return action_agent 140 | 141 | 142 | if __name__ == "__main__": 143 | render_interactively( 144 | __file__, 145 | control_two_agents=True, 146 | desired_velocity=0.05, 147 | n_agents=4, 148 | line_length=2, 149 | line_mass=30, 150 | ) 151 | -------------------------------------------------------------------------------- /vmas/simulator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | -------------------------------------------------------------------------------- /vmas/simulator/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | -------------------------------------------------------------------------------- /vmas/simulator/controllers/velocity_controller.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import math 5 | import warnings 6 | from typing import Optional 7 | 8 | import torch 9 | 10 | import vmas.simulator.core 11 | import vmas.simulator.utils 12 | from vmas.simulator.utils import TorchUtils 13 | 14 | 15 | class VelocityController: 16 | """ 17 | Implements PID controller for velocity targets found in agent.action.u. 18 | Two forms of the PID controller are implemented: standard, and parallel. The controller takes 3 params, which 19 | are interpreted differently based on the form. 20 | > Standard form: ctrl_params=[gain, intg_ts, derv_ts] 21 | intg_ts: rise time for integrator (err will be tolerated for this interval) 22 | derv_ts: seek time for derivative (err is predicted over this interval) 23 | These are specified in 1/dt scale (0.5 means 0.5/0.1==5sec) 24 | > Parallel form: ctrl_params=[kP, kI, kD] 25 | kI and kD have no simple physical meaning, but are related to standard form params. 
26 | intg_ts = kP/kI and kD/kP = derv_ts 27 | """ 28 | 29 | def __init__( 30 | self, 31 | agent: vmas.simulator.core.Agent, 32 | world: vmas.simulator.core.World, 33 | ctrl_params=(1, 0, 0), 34 | pid_form="standard", 35 | ): 36 | self.agent = agent 37 | self.world = world 38 | self.dt = world.dt 39 | # controller parameters: standard=[kP, intgTs ,dervTs], parallel=[kP, kI, kD] 40 | # in parallel form, kI = kP/intgTs and kD = kP*dervTs 41 | self.ctrl_gain = ctrl_params[0] # kP 42 | if pid_form == "standard": 43 | self.integralTs = ctrl_params[1] 44 | self.derivativeTs = ctrl_params[2] 45 | elif pid_form == "parallel": 46 | if ctrl_params[1] == 0: 47 | self.integralTs = 0.0 48 | else: 49 | self.integralTs = self.ctrl_gain / ctrl_params[1] 50 | self.derivativeTs = ctrl_params[2] / self.ctrl_gain 51 | else: 52 | raise Exception("PID form is either standard or parallel.") 53 | 54 | # in either form: 55 | if self.integralTs == 0: 56 | self.use_integrator = False 57 | else: 58 | self.use_integrator = True 59 | # set windup limit to 50% of agent's max force 60 | fmax = min( 61 | self.agent.max_f, 62 | self.agent.f_range, 63 | key=lambda x: x if x is not None else math.inf, 64 | ) 65 | 66 | if fmax is not None: 67 | self.integrator_windup_cutoff = ( 68 | 0.5 * fmax * self.integralTs / (self.dt * self.ctrl_gain) 69 | ) 70 | else: 71 | self.integrator_windup_cutoff = None 72 | warnings.warn("Force limits not specified. Integrator can wind up!") 73 | 74 | self.reset() 75 | 76 | def reset(self, index: Optional[int] = None): 77 | if index is None: 78 | self.accum_errs = torch.zeros( 79 | (self.world.batch_dim, self.world.dim_p), 80 | device=self.world.device, 81 | ) 82 | self.prev_err = torch.zeros( 83 | (self.world.batch_dim, self.world.dim_p), 84 | device=self.world.device, 85 | ) 86 | else: 87 | self.accum_errs = TorchUtils.where_from_index(index, 0.0, self.accum_errs) 88 | self.prev_err = TorchUtils.where_from_index(index, 0.0, self.prev_err) 89 | 90 | def integralError(self, err): 91 | if not self.use_integrator: 92 | return 0 93 | # fixed-length history (not recommended): 94 | # if len( self.accum_errs ) > self.integrator_hist-1: 95 | # self.accum_errs.pop(0); 96 | # self.accum_errs.append( err ); 97 | # return (1.0/self.integralTs) * torch.stack( self.accum_errs, dim=1 ).sum(dim=1) * self.dt; 98 | 99 | self.accum_errs += self.dt * err 100 | if self.integrator_windup_cutoff is not None: 101 | self.accum_errs = self.accum_errs.clamp( 102 | -self.integrator_windup_cutoff, self.integrator_windup_cutoff 103 | ) 104 | 105 | return (1.0 / self.integralTs) * self.accum_errs 106 | 107 | def rateError(self, err): 108 | e = self.derivativeTs * (err - self.prev_err) / self.dt 109 | self.prev_err = err 110 | return e 111 | 112 | def process_force(self): 113 | self.accum_errs = self.accum_errs.to(self.world.device) 114 | self.prev_err = self.prev_err.to(self.world.device) 115 | 116 | des_vel = self.agent.action.u 117 | cur_vel = self.agent.state.vel 118 | 119 | # apply control 120 | err = des_vel - cur_vel 121 | u = self.ctrl_gain * (err + self.integralError(err) + self.rateError(err)) 122 | u *= self.agent.mass 123 | 124 | self.agent.action.u = u 125 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
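The two PID parameterisations documented above are related by intg_ts = kP / kI and derv_ts = kD / kP, exactly as VelocityController.__init__ converts them; a small numeric illustration (the gains are made up):

# Illustrative values only.
kP, kI, kD = 2.0, 4.0, 0.01

gain = kP
intg_ts = kP / kI   # 0.5   -> error is tolerated/integrated over this horizon
derv_ts = kD / kP   # 0.005 -> error is predicted over this horizon

# VelocityController(agent, world, (gain, intg_ts, derv_ts), "standard")
# should behave like VelocityController(agent, world, (kP, kI, kD), "parallel").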
4 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import abc 5 | from abc import ABC 6 | from typing import Union 7 | 8 | from torch import Tensor 9 | 10 | 11 | class Dynamics(ABC): 12 | def __init__( 13 | self, 14 | ): 15 | self._agent = None 16 | 17 | def reset(self, index: Union[Tensor, int] = None): 18 | return 19 | 20 | def zero_grad(self): 21 | return 22 | 23 | @property 24 | def agent(self): 25 | if self._agent is None: 26 | raise ValueError( 27 | "You need to add the dynamics to an agent during construction before accessing its properties" 28 | ) 29 | return self._agent 30 | 31 | @agent.setter 32 | def agent(self, value): 33 | if self._agent is not None: 34 | raise ValueError("Agent in dynamics has already been set") 35 | self._agent = value 36 | 37 | def check_and_process_action(self): 38 | action = self.agent.action.u 39 | if action.shape[1] < self.needed_action_size: 40 | raise ValueError( 41 | f"Agent action size {action.shape[1]} is less than the required dynamics action size {self.needed_action_size}" 42 | ) 43 | self.process_action() 44 | 45 | @property 46 | @abc.abstractmethod 47 | def needed_action_size(self) -> int: 48 | raise NotImplementedError 49 | 50 | @abc.abstractmethod 51 | def process_action(self): 52 | raise NotImplementedError 53 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/diff_drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
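The abstract base in common.py above only asks a subclass for needed_action_size and process_action, with check_and_process_action validating the action width. A minimal hypothetical subclass (purely illustrative, not part of the package), modelled on the Forward dynamics further below:

import torch

from vmas.simulator.dynamics.common import Dynamics


class Sideways(Dynamics):
    """Hypothetical example: the single action component pushes along the world y-axis."""

    @property
    def needed_action_size(self) -> int:
        return 1

    def process_action(self):
        force = torch.zeros(
            self.agent.batch_dim, 2, device=self.agent.device, dtype=torch.float
        )
        force[:, 1] = self.agent.action.u[:, 0]
        self.agent.state.force = force

It would be attached the same way the built-in models are, e.g. Agent(name="agent_0", dynamics=Sideways(), action_size=1).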
4 | 5 | 6 | import torch 7 | 8 | import vmas.simulator.core 9 | import vmas.simulator.utils 10 | from vmas.simulator.dynamics.common import Dynamics 11 | 12 | 13 | class DiffDrive(Dynamics): 14 | def __init__( 15 | self, 16 | world: vmas.simulator.core.World, 17 | integration: str = "rk4", # one of "euler", "rk4" 18 | ): 19 | super().__init__() 20 | assert integration == "rk4" or integration == "euler" 21 | 22 | self.dt = world.dt 23 | self.integration = integration 24 | self.world = world 25 | 26 | def f(self, state, u_command, ang_vel_command): 27 | theta = state[:, 2] 28 | dx = u_command * torch.cos(theta) 29 | dy = u_command * torch.sin(theta) 30 | dtheta = ang_vel_command 31 | return torch.stack((dx, dy, dtheta), dim=-1) # [batch_size,3] 32 | 33 | def euler(self, state, u_command, ang_vel_command): 34 | return self.dt * self.f(state, u_command, ang_vel_command) 35 | 36 | def runge_kutta(self, state, u_command, ang_vel_command): 37 | k1 = self.f(state, u_command, ang_vel_command) 38 | k2 = self.f(state + self.dt * k1 / 2, u_command, ang_vel_command) 39 | k3 = self.f(state + self.dt * k2 / 2, u_command, ang_vel_command) 40 | k4 = self.f(state + self.dt * k3, u_command, ang_vel_command) 41 | return (self.dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4) 42 | 43 | @property 44 | def needed_action_size(self) -> int: 45 | return 2 46 | 47 | def process_action(self): 48 | u_command = self.agent.action.u[:, 0] # Forward velocity 49 | ang_vel_command = self.agent.action.u[:, 1] # Angular velocity 50 | 51 | # Current state of the agent 52 | state = torch.cat((self.agent.state.pos, self.agent.state.rot), dim=1) 53 | 54 | v_cur_x = self.agent.state.vel[:, 0] # Current velocity in x-direction 55 | v_cur_y = self.agent.state.vel[:, 1] # Current velocity in y-direction 56 | v_cur_angular = self.agent.state.ang_vel[:, 0] # Current angular velocity 57 | 58 | # Select the integration method to calculate the change in state 59 | if self.integration == "euler": 60 | delta_state = self.euler(state, u_command, ang_vel_command) 61 | else: 62 | delta_state = self.runge_kutta(state, u_command, ang_vel_command) 63 | 64 | # Calculate the accelerations required to achieve the change in state 65 | acceleration_x = (delta_state[:, 0] - v_cur_x * self.dt) / self.dt**2 66 | acceleration_y = (delta_state[:, 1] - v_cur_y * self.dt) / self.dt**2 67 | acceleration_angular = ( 68 | delta_state[:, 2] - v_cur_angular * self.dt 69 | ) / self.dt**2 70 | 71 | # Calculate the forces required for the linear accelerations 72 | force_x = self.agent.mass * acceleration_x 73 | force_y = self.agent.mass * acceleration_y 74 | 75 | # Calculate the torque required for the angular acceleration 76 | torque = self.agent.moment_of_inertia * acceleration_angular 77 | 78 | # Update the physical force and torque required for the user inputs 79 | self.agent.state.force[:, vmas.simulator.utils.X] = force_x 80 | self.agent.state.force[:, vmas.simulator.utils.Y] = force_y 81 | self.agent.state.torque = torque.unsqueeze(-1) 82 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/drone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
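DiffDrive above (and the drone and bicycle dynamics that follow) turn the integrated displacement back into a force via a = (delta_x - v * dt) / dt**2; assuming the world applies a velocity-first (semi-implicit Euler) update, that acceleration reproduces exactly the requested displacement in one step. A toy check with dt = 0.1:

# Illustrative numbers only.
dt = 0.1
v_cur = 0.2      # current x velocity
delta_x = 0.05   # displacement requested by the euler()/runge_kutta() step
mass = 1.0

acc = (delta_x - v_cur * dt) / dt**2   # (0.05 - 0.02) / 0.01 = 3.0
force_x = mass * acc                   # 3.0

# Velocity-first update: v_new = 0.2 + 3.0 * 0.1 = 0.5,
# step displacement = v_new * dt = 0.05 == delta_x.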
4 | 5 | from typing import Union 6 | 7 | import torch 8 | from torch import Tensor 9 | 10 | import vmas.simulator.core 11 | import vmas.simulator.utils 12 | from vmas.simulator.dynamics.common import Dynamics 13 | from vmas.simulator.utils import TorchUtils 14 | 15 | 16 | class Drone(Dynamics): 17 | def __init__( 18 | self, 19 | world: vmas.simulator.core.World, 20 | I_xx: float = 8.1e-3, 21 | I_yy: float = 8.1e-3, 22 | I_zz: float = 14.2e-3, 23 | integration: str = "rk4", 24 | ): 25 | super().__init__() 26 | 27 | assert integration in ( 28 | "rk4", 29 | "euler", 30 | ) 31 | 32 | self.integration = integration 33 | self.I_xx = I_xx 34 | self.I_yy = I_yy 35 | self.I_zz = I_zz 36 | self.world = world 37 | self.g = 9.81 38 | self.dt = world.dt 39 | self.reset() 40 | 41 | def reset(self, index: Union[Tensor, int] = None): 42 | if index is None: 43 | # Drone state: phi(roll), theta (pitch), psi (yaw), 44 | # p (roll_rate), q (pitch_rate), r (yaw_rate), 45 | # x_dot (vel_x), y_dot (vel_y), z_dot (vel_z), 46 | # x (pos_x), y (pos_y), z (pos_z) 47 | self.drone_state = torch.zeros( 48 | self.world.batch_dim, 49 | 12, 50 | device=self.world.device, 51 | ) 52 | else: 53 | self.drone_state = TorchUtils.where_from_index(index, 0.0, self.drone_state) 54 | 55 | def zero_grad(self): 56 | self.drone_state = self.drone_state.detach() 57 | 58 | def f(self, state, thrust_command, torque_command): 59 | phi = state[:, 0] 60 | theta = state[:, 1] 61 | psi = state[:, 2] 62 | p = state[:, 3] 63 | q = state[:, 4] 64 | r = state[:, 5] 65 | x_dot = state[:, 6] 66 | y_dot = state[:, 7] 67 | z_dot = state[:, 8] 68 | 69 | c_phi = torch.cos(phi) 70 | s_phi = torch.sin(phi) 71 | c_theta = torch.cos(theta) 72 | s_theta = torch.sin(theta) 73 | c_psi = torch.cos(psi) 74 | s_psi = torch.sin(psi) 75 | 76 | # Postion Dynamics 77 | x_ddot = ( 78 | (c_phi * s_theta * c_psi + s_phi * s_psi) * thrust_command / self.agent.mass 79 | ) 80 | y_ddot = ( 81 | (c_phi * s_theta * s_psi - s_phi * c_psi) * thrust_command / self.agent.mass 82 | ) 83 | z_ddot = (c_phi * c_theta) * thrust_command / self.agent.mass - self.g 84 | # Angular velocity dynamics 85 | p_dot = (torque_command[:, 0] - (self.I_yy - self.I_zz) * q * r) / self.I_xx 86 | q_dot = (torque_command[:, 1] - (self.I_zz - self.I_xx) * p * r) / self.I_yy 87 | r_dot = (torque_command[:, 2] - (self.I_xx - self.I_yy) * p * q) / self.I_zz 88 | 89 | return torch.stack( 90 | [ 91 | p, 92 | q, 93 | r, 94 | p_dot, 95 | q_dot, 96 | r_dot, 97 | x_ddot, 98 | y_ddot, 99 | z_ddot, 100 | x_dot, 101 | y_dot, 102 | z_dot, 103 | ], 104 | dim=-1, 105 | ) 106 | 107 | def needs_reset(self) -> Tensor: 108 | # Constraint roll and pitch within +-30 degrees 109 | return torch.any(self.drone_state[:, :2].abs() > 30 * (torch.pi / 180), dim=-1) 110 | 111 | def euler(self, state, thrust, torque): 112 | return self.dt * self.f(state, thrust, torque) 113 | 114 | def runge_kutta(self, state, thrust, torque): 115 | k1 = self.f(state, thrust, torque) 116 | k2 = self.f(state + self.dt * k1 / 2, thrust, torque) 117 | k3 = self.f(state + self.dt * k2 / 2, thrust, torque) 118 | k4 = self.f(state + self.dt * k3, thrust, torque) 119 | return (self.dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4) 120 | 121 | @property 122 | def needed_action_size(self) -> int: 123 | return 4 124 | 125 | def process_action(self): 126 | u = self.agent.action.u 127 | thrust = u[:, 0] # Thrust, sum of all propeller thrusts 128 | torque = u[:, 1:4] # Torque in x, y, z direction 129 | 130 | thrust += self.agent.mass * self.g # Ensure the drone is not 
falling 131 | 132 | self.drone_state[:, 9] = self.agent.state.pos[:, 0] # x 133 | self.drone_state[:, 10] = self.agent.state.pos[:, 1] # y 134 | self.drone_state[:, 2] = self.agent.state.rot[:, 0] # psi (yaw) 135 | 136 | if self.integration == "euler": 137 | delta_state = self.euler(self.drone_state, thrust, torque) 138 | else: 139 | delta_state = self.runge_kutta(self.drone_state, thrust, torque) 140 | 141 | # Calculate the change in state 142 | self.drone_state = self.drone_state + delta_state 143 | 144 | v_cur_x = self.agent.state.vel[:, 0] # Current velocity in x-direction 145 | v_cur_y = self.agent.state.vel[:, 1] # Current velocity in y-direction 146 | v_cur_angular = self.agent.state.ang_vel[:, 0] # Current angular velocity 147 | 148 | # Calculate the accelerations required to achieve the change in state 149 | acceleration_x = (delta_state[:, 6] - v_cur_x * self.dt) / self.dt**2 150 | acceleration_y = (delta_state[:, 7] - v_cur_y * self.dt) / self.dt**2 151 | acceleration_angular = ( 152 | delta_state[:, 5] - v_cur_angular * self.dt 153 | ) / self.dt**2 154 | 155 | # Calculate the forces required for the linear accelerations 156 | force_x = self.agent.mass * acceleration_x 157 | force_y = self.agent.mass * acceleration_y 158 | 159 | # Calculate the torque required for the angular acceleration 160 | torque_yaw = self.agent.moment_of_inertia * acceleration_angular 161 | 162 | # Update the physical force and torque required for the user inputs 163 | self.agent.state.force[:, vmas.simulator.utils.X] = force_x 164 | self.agent.state.force[:, vmas.simulator.utils.Y] = force_y 165 | self.agent.state.torque = torque_yaw.unsqueeze(-1) 166 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import torch 5 | 6 | from vmas.simulator.dynamics.common import Dynamics 7 | from vmas.simulator.utils import TorchUtils, X 8 | 9 | 10 | class Forward(Dynamics): 11 | @property 12 | def needed_action_size(self) -> int: 13 | return 1 14 | 15 | def process_action(self): 16 | force = torch.zeros( 17 | self.agent.batch_dim, 2, device=self.agent.device, dtype=torch.float 18 | ) 19 | force[:, X] = self.agent.action.u[:, 0] 20 | self.agent.state.force = TorchUtils.rotate_vector(force, self.agent.state.rot) 21 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/holonomic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | from vmas.simulator.dynamics.common import Dynamics 6 | 7 | 8 | class Holonomic(Dynamics): 9 | @property 10 | def needed_action_size(self) -> int: 11 | return 2 12 | 13 | def process_action(self): 14 | self.agent.state.force = self.agent.action.u[:, : self.needed_action_size] 15 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/holonomic_with_rot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
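Forward above rotates a body-frame push into the world frame with TorchUtils.rotate_vector before writing it to agent.state.force; a tiny illustration of that rotation:

import torch

from vmas.simulator.utils import TorchUtils

force_body = torch.tensor([[1.0, 0.0]])   # forward push in the body frame
heading = torch.tensor([torch.pi / 2])    # agent facing world +y
force_world = TorchUtils.rotate_vector(force_body, heading)
# force_world is approximately tensor([[0.0, 1.0]]): the push now points along world +y.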
4 | 5 | from vmas.simulator.dynamics.common import Dynamics 6 | 7 | 8 | class HolonomicWithRotation(Dynamics): 9 | @property 10 | def needed_action_size(self) -> int: 11 | return 3 12 | 13 | def process_action(self): 14 | self.agent.state.force = self.agent.action.u[:, :2] 15 | self.agent.state.torque = self.agent.action.u[:, 2].unsqueeze(-1) 16 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/kinematic_bicycle.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2025. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | 6 | import torch 7 | 8 | import vmas.simulator.core 9 | import vmas.simulator.utils 10 | from vmas.simulator.dynamics.common import Dynamics 11 | 12 | 13 | class KinematicBicycle(Dynamics): 14 | # For the implementation of the kinematic bicycle model, see the equation (2) of the paper Polack, Philip, et al. "The kinematic bicycle model: A consistent model for planning feasible trajectories for autonomous vehicles?." 2017 IEEE intelligent vehicles symposium (IV). IEEE, 2017. 15 | def __init__( 16 | self, 17 | world: vmas.simulator.core.World, 18 | width: float, 19 | l_f: float, 20 | l_r: float, 21 | max_steering_angle: float, 22 | integration: str = "rk4", # one of "euler", "rk4" 23 | ): 24 | super().__init__() 25 | assert integration in ( 26 | "rk4", 27 | "euler", 28 | ), "Integration method must be 'euler' or 'rk4'." 29 | self.width = width 30 | self.l_f = l_f # Distance between the front axle and the center of gravity 31 | self.l_r = l_r # Distance between the rear axle and the center of gravity 32 | self.max_steering_angle = max_steering_angle 33 | self.dt = world.dt 34 | self.integration = integration 35 | self.world = world 36 | 37 | def f(self, state, steering_command, v_command): 38 | theta = state[:, 2] # Yaw angle 39 | beta = torch.atan2( 40 | torch.tan(steering_command) * self.l_r / (self.l_f + self.l_r), 41 | torch.tensor(1, device=self.world.device), 42 | ) # [-pi, pi] slip angle 43 | dx = v_command * torch.cos(theta + beta) 44 | dy = v_command * torch.sin(theta + beta) 45 | dtheta = ( 46 | v_command 47 | / (self.l_f + self.l_r) 48 | * torch.cos(beta) 49 | * torch.tan(steering_command) 50 | ) 51 | return torch.stack((dx, dy, dtheta), dim=1) # [batch_size,3] 52 | 53 | def euler(self, state, steering_command, v_command): 54 | # Calculate the change in state using Euler's method 55 | # For Euler's method, see https://math.libretexts.org/Bookshelves/Calculus/Book%3A_Active_Calculus_(Boelkins_et_al.)/07%3A_Differential_Equations/7.03%3A_Euler's_Method (the full link may not be recognized properly, please copy and paste in your browser) 56 | return self.dt * self.f(state, steering_command, v_command) 57 | 58 | def runge_kutta(self, state, steering_command, v_command): 59 | # Calculate the change in state using fourth-order Runge-Kutta method 60 | # For Runge-Kutta method, see https://math.libretexts.org/Courses/Monroe_Community_College/MTH_225_Differential_Equations/3%3A_Numerical_Methods/3.3%3A_The_Runge-Kutta_Method 61 | k1 = self.f(state, steering_command, v_command) 62 | k2 = self.f(state + self.dt * k1 / 2, steering_command, v_command) 63 | k3 = self.f(state + self.dt * k2 / 2, steering_command, v_command) 64 | k4 = self.f(state + self.dt * k3, steering_command, v_command) 65 | return (self.dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4) 66 | 67 | @property 68 | def needed_action_size(self) -> int: 69 | return 2 70 | 71 | def 
process_action(self): 72 | # Extracts the velocity and steering angle from the agent's actions and convert them to physical force and torque 73 | v_command = self.agent.action.u[:, 0] 74 | steering_command = self.agent.action.u[:, 1] 75 | # Ensure steering angle is within bounds 76 | steering_command = torch.clamp( 77 | steering_command, -self.max_steering_angle, self.max_steering_angle 78 | ) 79 | 80 | # Current state of the agent 81 | state = torch.cat((self.agent.state.pos, self.agent.state.rot), dim=1) 82 | 83 | v_cur_x = self.agent.state.vel[:, 0] # Current velocity in x-direction 84 | v_cur_y = self.agent.state.vel[:, 1] # Current velocity in y-direction 85 | v_cur_angular = self.agent.state.ang_vel[:, 0] # Current angular velocity 86 | 87 | # Select the integration method to calculate the change in state 88 | if self.integration == "euler": 89 | delta_state = self.euler(state, steering_command, v_command) 90 | else: 91 | delta_state = self.runge_kutta(state, steering_command, v_command) 92 | 93 | # Calculate the accelerations required to achieve the change in state. 94 | acceleration_x = (delta_state[:, 0] - v_cur_x * self.dt) / self.dt**2 95 | acceleration_y = (delta_state[:, 1] - v_cur_y * self.dt) / self.dt**2 96 | acceleration_angular = ( 97 | delta_state[:, 2] - v_cur_angular * self.dt 98 | ) / self.dt**2 99 | 100 | # Calculate the forces required for the linear accelerations 101 | force_x = self.agent.mass * acceleration_x 102 | force_y = self.agent.mass * acceleration_y 103 | 104 | # Calculate the torque required for the angular acceleration 105 | torque = self.agent.moment_of_inertia * acceleration_angular 106 | 107 | # Update the physical force and torque required for the user inputs 108 | self.agent.state.force[:, vmas.simulator.utils.X] = force_x 109 | self.agent.state.force[:, vmas.simulator.utils.Y] = force_y 110 | self.agent.state.torque = torque.unsqueeze(-1) 111 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/roatation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | from vmas.simulator.dynamics.common import Dynamics 6 | 7 | 8 | class Rotation(Dynamics): 9 | @property 10 | def needed_action_size(self) -> int: 11 | return 1 12 | 13 | def process_action(self): 14 | self.agent.state.torque = self.agent.action.u[:, 0].unsqueeze(-1) 15 | -------------------------------------------------------------------------------- /vmas/simulator/dynamics/static.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | from vmas.simulator.dynamics.common import Dynamics 6 | 7 | 8 | class Static(Dynamics): 9 | @property 10 | def needed_action_size(self) -> int: 11 | return 0 12 | 13 | def process_action(self): 14 | pass 15 | -------------------------------------------------------------------------------- /vmas/simulator/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 
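# Usage sketch for the kinematic bicycle dynamics above (illustrative only; the
# parameter values are placeholders and the construction loosely follows
# vmas/scenarios/debug/kinematic_bicycle.py):
#
#   agent = Agent(
#       name="agent_0",
#       shape=Box(length=l_f + l_r, width=width),
#       dynamics=KinematicBicycle(
#           world,
#           width=width,
#           l_f=l_f,
#           l_r=l_r,
#           max_steering_angle=0.52,  # about 30 degrees, in radians
#           integration="rk4",
#       ),
#   )
#
# The agent's two actions are then read as forward speed and steering angle and
# converted into the equivalent force and torque by `process_action`.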
4 | from enum import Enum 5 | 6 | from vmas.simulator.environment.environment import Environment 7 | 8 | 9 | class Wrapper(Enum): 10 | RLLIB = 0 11 | GYM = 1 12 | GYMNASIUM = 2 13 | GYMNASIUM_VEC = 3 14 | 15 | def get_env(self, env: Environment, **kwargs): 16 | if self is self.RLLIB: 17 | from vmas.simulator.environment.rllib import VectorEnvWrapper 18 | 19 | return VectorEnvWrapper(env, **kwargs) 20 | elif self is self.GYM: 21 | from vmas.simulator.environment.gym import GymWrapper 22 | 23 | return GymWrapper(env, **kwargs) 24 | elif self is self.GYMNASIUM: 25 | from vmas.simulator.environment.gym.gymnasium import GymnasiumWrapper 26 | 27 | return GymnasiumWrapper(env, **kwargs) 28 | elif self is self.GYMNASIUM_VEC: 29 | from vmas.simulator.environment.gym.gymnasium_vec import ( 30 | GymnasiumVectorizedWrapper, 31 | ) 32 | 33 | return GymnasiumVectorizedWrapper(env, **kwargs) 34 | -------------------------------------------------------------------------------- /vmas/simulator/environment/gym/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | from .gym import GymWrapper 6 | -------------------------------------------------------------------------------- /vmas/simulator/environment/gym/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | 5 | from abc import ABC, abstractmethod 6 | from collections import namedtuple 7 | from typing import List, Optional 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from vmas.simulator.environment import Environment 13 | 14 | from vmas.simulator.utils import extract_nested_with_index, TorchUtils 15 | 16 | 17 | EnvData = namedtuple( 18 | "EnvData", ["obs", "rews", "terminated", "truncated", "done", "info"] 19 | ) 20 | 21 | 22 | class BaseGymWrapper(ABC): 23 | def __init__(self, env: Environment, return_numpy: bool, vectorized: bool): 24 | self._env = env 25 | self.return_numpy = return_numpy 26 | self.dict_spaces = env.dict_spaces 27 | self.vectorized = vectorized 28 | 29 | @property 30 | def env(self): 31 | return self._env 32 | 33 | def _maybe_to_numpy(self, tensor): 34 | return TorchUtils.to_numpy(tensor) if self.return_numpy else tensor 35 | 36 | def _convert_output(self, data, item: bool = False): 37 | if not self.vectorized: 38 | data = extract_nested_with_index(data, index=0) 39 | if item: 40 | return data.item() 41 | return self._maybe_to_numpy(data) 42 | 43 | def _compress_infos(self, infos): 44 | if isinstance(infos, dict): 45 | return infos 46 | elif isinstance(infos, list): 47 | return {self._env.agents[i].name: info for i, info in enumerate(infos)} 48 | else: 49 | raise ValueError( 50 | f"Expected list or dictionary for infos but got {type(infos)}" 51 | ) 52 | 53 | def _convert_env_data( 54 | self, obs=None, rews=None, info=None, terminated=None, truncated=None, done=None 55 | ): 56 | if self.dict_spaces: 57 | for agent in obs.keys(): 58 | if obs is not None: 59 | obs[agent] = self._convert_output(obs[agent]) 60 | if info is not None: 61 | info[agent] = self._convert_output(info[agent]) 62 | if rews is not None: 63 | rews[agent] = self._convert_output(rews[agent], item=True) 64 | else: 65 | for i in range(self._env.n_agents): 66 | if obs is not None: 67 | obs[i] = self._convert_output(obs[i]) 68 | if info is not None: 69 | info[i] = 
self._convert_output(info[i]) 70 | if rews is not None: 71 | rews[i] = self._convert_output(rews[i], item=True) 72 | terminated = ( 73 | self._convert_output(terminated, item=True) 74 | if terminated is not None 75 | else None 76 | ) 77 | truncated = ( 78 | self._convert_output(truncated, item=True) 79 | if truncated is not None 80 | else None 81 | ) 82 | done = self._convert_output(done, item=True) if done is not None else None 83 | info = self._compress_infos(info) if info is not None else None 84 | return EnvData( 85 | obs=obs, 86 | rews=rews, 87 | terminated=terminated, 88 | truncated=truncated, 89 | done=done, 90 | info=info, 91 | ) 92 | 93 | def _action_list_to_tensor(self, list_in: List) -> List: 94 | assert ( 95 | len(list_in) == self._env.n_agents 96 | ), f"Expecting actions for {self._env.n_agents} agents, got {len(list_in)} actions" 97 | 98 | dtype = torch.float32 if self._env.continuous_actions else torch.long 99 | 100 | return [ 101 | torch.tensor(act, device=self._env.device, dtype=dtype).reshape( 102 | self._env.num_envs, self._env.get_agent_action_size(agent) 103 | ) 104 | if not isinstance(act, torch.Tensor) 105 | else act.to(dtype=dtype, device=self._env.device).reshape( 106 | self._env.num_envs, self._env.get_agent_action_size(agent) 107 | ) 108 | for agent, act in zip(self._env.agents, list_in) 109 | ] 110 | 111 | @abstractmethod 112 | def step(self, action): 113 | raise NotImplementedError 114 | 115 | @abstractmethod 116 | def reset( 117 | self, 118 | *, 119 | seed: Optional[int] = None, 120 | options: Optional[dict] = None, 121 | ): 122 | raise NotImplementedError 123 | 124 | @abstractmethod 125 | def render( 126 | self, 127 | agent_index_focus: Optional[int] = None, 128 | visualize_when_rgb: bool = False, 129 | **kwargs, 130 | ) -> Optional[np.ndarray]: 131 | raise NotImplementedError 132 | -------------------------------------------------------------------------------- /vmas/simulator/environment/gym/gym.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | from typing import Optional 5 | 6 | import gym 7 | import numpy as np 8 | 9 | from vmas.simulator.environment.environment import Environment 10 | from vmas.simulator.environment.gym.base import BaseGymWrapper 11 | 12 | 13 | class GymWrapper(gym.Env, BaseGymWrapper): 14 | metadata = Environment.metadata 15 | 16 | def __init__( 17 | self, 18 | env: Environment, 19 | return_numpy: bool = True, 20 | ): 21 | super().__init__(env, return_numpy=return_numpy, vectorized=False) 22 | assert ( 23 | env.num_envs == 1 24 | ), f"GymEnv wrapper is not vectorised, got env.num_envs: {env.num_envs}" 25 | 26 | assert ( 27 | not self._env.terminated_truncated 28 | ), "GymWrapper is not compatible with termination and truncation flags. Please set `terminated_truncated=False` in the VMAS environment." 
29 | self.observation_space = self._env.observation_space 30 | self.action_space = self._env.action_space 31 | 32 | @property 33 | def unwrapped(self) -> Environment: 34 | return self._env 35 | 36 | def step(self, action): 37 | action = self._action_list_to_tensor(action) 38 | obs, rews, done, info = self._env.step(action) 39 | env_data = self._convert_env_data( 40 | obs=obs, 41 | rews=rews, 42 | info=info, 43 | done=done, 44 | ) 45 | return env_data.obs, env_data.rews, env_data.done, env_data.info 46 | 47 | def reset( 48 | self, 49 | *, 50 | seed: Optional[int] = None, 51 | return_info: bool = False, 52 | options: Optional[dict] = None, 53 | ): 54 | if seed is not None: 55 | self._env.seed(seed) 56 | obs = self._env.reset_at(index=0) 57 | env_data = self._convert_env_data(obs=obs) 58 | return env_data.obs 59 | 60 | def render( 61 | self, 62 | mode="human", 63 | agent_index_focus: Optional[int] = None, 64 | visualize_when_rgb: bool = False, 65 | **kwargs, 66 | ) -> Optional[np.ndarray]: 67 | return self._env.render( 68 | mode=mode, 69 | env_index=0, 70 | agent_index_focus=agent_index_focus, 71 | visualize_when_rgb=visualize_when_rgb, 72 | **kwargs, 73 | ) 74 | -------------------------------------------------------------------------------- /vmas/simulator/environment/gym/gymnasium.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import importlib 5 | from typing import Optional 6 | 7 | import numpy as np 8 | 9 | from vmas.simulator.environment.environment import Environment 10 | from vmas.simulator.environment.gym.base import BaseGymWrapper 11 | 12 | 13 | if ( 14 | importlib.util.find_spec("gymnasium") is not None 15 | and importlib.util.find_spec("shimmy") is not None 16 | ): 17 | import gymnasium as gym 18 | from shimmy.openai_gym_compatibility import _convert_space 19 | else: 20 | raise ImportError( 21 | "Gymnasium or shimmy is not installed. Please install it with `pip install gymnasium shimmy`." 22 | ) 23 | 24 | 25 | class GymnasiumWrapper(gym.Env, BaseGymWrapper): 26 | metadata = Environment.metadata 27 | 28 | def __init__( 29 | self, 30 | env: Environment, 31 | return_numpy: bool = True, 32 | render_mode: str = "human", 33 | ): 34 | super().__init__(env, return_numpy=return_numpy, vectorized=False) 35 | assert ( 36 | env.num_envs == 1 37 | ), "GymnasiumEnv wrapper only supports singleton VMAS environment! For vectorized environments, use vectorized wrapper with `wrapper=gymnasium_vec`." 38 | 39 | assert ( 40 | self._env.terminated_truncated 41 | ), "GymnasiumWrapper is only compatible with termination and truncation flags. Please set `terminated_truncated=True` in the VMAS environment." 
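        # Example construction (illustrative; it assumes `vmas.make_env` forwards
        # `terminated_truncated` to the Environment before applying this wrapper):
        #
        #   from vmas import make_env
        #   from vmas.simulator.environment import Wrapper
        #
        #   env = make_env(
        #       scenario="balance",
        #       num_envs=1,
        #       terminated_truncated=True,
        #       wrapper=Wrapper.GYMNASIUM,
        #   )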
42 | self.observation_space = _convert_space(self._env.observation_space) 43 | self.action_space = _convert_space(self._env.action_space) 44 | self.render_mode = render_mode 45 | 46 | @property 47 | def unwrapped(self) -> Environment: 48 | return self._env 49 | 50 | def step(self, action): 51 | action = self._action_list_to_tensor(action) 52 | obs, rews, terminated, truncated, info = self._env.step(action) 53 | env_data = self._convert_env_data( 54 | obs=obs, rews=rews, info=info, terminated=terminated, truncated=truncated 55 | ) 56 | return ( 57 | env_data.obs, 58 | env_data.rews, 59 | env_data.terminated, 60 | env_data.truncated, 61 | env_data.info, 62 | ) 63 | 64 | def reset( 65 | self, 66 | *, 67 | seed: Optional[int] = None, 68 | options: Optional[dict] = None, 69 | ): 70 | if seed is not None: 71 | self._env.seed(seed) 72 | obs, info = self._env.reset_at(index=0, return_info=True) 73 | env_data = self._convert_env_data(obs=obs, info=info) 74 | return env_data.obs, env_data.info 75 | 76 | def render( 77 | self, 78 | agent_index_focus: Optional[int] = None, 79 | visualize_when_rgb: bool = False, 80 | **kwargs, 81 | ) -> Optional[np.ndarray]: 82 | return self._env.render( 83 | mode=self.render_mode, 84 | env_index=0, 85 | agent_index_focus=agent_index_focus, 86 | visualize_when_rgb=visualize_when_rgb, 87 | **kwargs, 88 | ) 89 | -------------------------------------------------------------------------------- /vmas/simulator/environment/gym/gymnasium_vec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | import importlib 5 | 6 | import warnings 7 | from typing import Optional 8 | 9 | import numpy as np 10 | 11 | from vmas.simulator.environment.environment import Environment 12 | from vmas.simulator.environment.gym.base import BaseGymWrapper 13 | 14 | 15 | if ( 16 | importlib.util.find_spec("gymnasium") is not None 17 | and importlib.util.find_spec("shimmy") is not None 18 | ): 19 | import gymnasium as gym 20 | from gymnasium.vector.utils import batch_space 21 | from shimmy.openai_gym_compatibility import _convert_space 22 | else: 23 | raise ImportError( 24 | "Gymnasium or shimmy is not installed. Please install it with `pip install gymnasium shimmy`." 25 | ) 26 | 27 | 28 | class GymnasiumVectorizedWrapper(gym.Env, BaseGymWrapper): 29 | metadata = Environment.metadata 30 | 31 | def __init__( 32 | self, 33 | env: Environment, 34 | return_numpy: bool = True, 35 | render_mode: str = "human", 36 | ): 37 | super().__init__(env, return_numpy=return_numpy, vectorized=True) 38 | self._num_envs = self._env.num_envs 39 | assert ( 40 | self._env.terminated_truncated 41 | ), "GymnasiumVectorizedWrapper is only compatible with termination and truncation flags. Please set `terminated_truncated=True` in the VMAS environment." 42 | self.single_observation_space = _convert_space(self._env.observation_space) 43 | self.single_action_space = _convert_space(self._env.action_space) 44 | self.observation_space = batch_space( 45 | self.single_observation_space, n=self._num_envs 46 | ) 47 | self.action_space = batch_space(self.single_action_space, n=self._num_envs) 48 | self.render_mode = render_mode 49 | warnings.warn( 50 | "The Gymnasium Vector wrapper currently does not have auto-resets or support partial resets. " 51 | "We warn you that by using this class, individual environments will not be reset when they are done and you " 52 | "will only have access to global resets. 
We strongly suggest using the VMAS API unless your scenario does not implement " 53 | "the `done` function and thus all sub-environments are done at the same time." 54 | ) 55 | 56 | @property 57 | def unwrapped(self) -> Environment: 58 | return self._env 59 | 60 | def step(self, action): 61 | action = self._action_list_to_tensor(action) 62 | obs, rews, terminated, truncated, info = self._env.step(action) 63 | env_data = self._convert_env_data( 64 | obs=obs, rews=rews, info=info, terminated=terminated, truncated=truncated 65 | ) 66 | return ( 67 | env_data.obs, 68 | env_data.rews, 69 | env_data.terminated, 70 | env_data.truncated, 71 | env_data.info, 72 | ) 73 | 74 | def reset( 75 | self, 76 | *, 77 | seed: Optional[int] = None, 78 | options: Optional[dict] = None, 79 | ): 80 | if seed is not None: 81 | self._env.seed(seed) 82 | obs, info = self._env.reset(return_info=True) 83 | env_data = self._convert_env_data(obs=obs, info=info) 84 | return env_data.obs, env_data.info 85 | 86 | def render( 87 | self, 88 | agent_index_focus: Optional[int] = None, 89 | visualize_when_rgb: bool = False, 90 | **kwargs, 91 | ) -> Optional[np.ndarray]: 92 | return self._env.render( 93 | mode=self.render_mode, 94 | agent_index_focus=agent_index_focus, 95 | visualize_when_rgb=visualize_when_rgb, 96 | **kwargs, 97 | ) 98 | -------------------------------------------------------------------------------- /vmas/simulator/heuristic_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved. 4 | from abc import ABC, abstractmethod 5 | 6 | import torch 7 | 8 | 9 | class BaseHeuristicPolicy(ABC): 10 | def __init__(self, continuous_action: bool): 11 | self.continuous_actions = continuous_action 12 | 13 | @abstractmethod 14 | def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Tensor: 15 | raise NotImplementedError 16 | 17 | 18 | class RandomPolicy(BaseHeuristicPolicy): 19 | def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Tensor: 20 | n_envs = observation.shape[0] 21 | return torch.clamp(torch.randn(n_envs, 2), -u_range, u_range) 22 | -------------------------------------------------------------------------------- /vmas/simulator/secrcode.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/VectorizedMultiAgentSimulator/acd9b7aca3ca5718f58f03a993c5ef3e920c6af9/vmas/simulator/secrcode.ttf -------------------------------------------------------------------------------- /vmas/simulator/sensors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024. 2 | # ProrokLab (https://www.proroklab.org/) 3 | # All rights reserved.
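# Usage sketch for the heuristic-policy interface above (illustrative only; the class
# name and behaviour are placeholders):
#
#   class GoRightPolicy(BaseHeuristicPolicy):
#       def compute_action(
#           self, observation: torch.Tensor, u_range: float
#       ) -> torch.Tensor:
#           n_envs = observation.shape[0]
#           action = torch.zeros(n_envs, 2, device=observation.device)
#           action[:, 0] = u_range  # push every environment along +x at full strength
#           return action
#
# The example script vmas/examples/run_heuristic.py rolls out scenarios with policies
# of this kind.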
4 | 5 | from __future__ import annotations 6 | 7 | import typing 8 | from abc import ABC, abstractmethod 9 | from typing import Callable, List, Tuple, Union 10 | 11 | import torch 12 | 13 | import vmas.simulator.core 14 | from vmas.simulator.utils import Color 15 | 16 | if typing.TYPE_CHECKING: 17 | from vmas.simulator.rendering import Geom 18 | 19 | 20 | class Sensor(ABC): 21 | def __init__(self, world: vmas.simulator.core.World): 22 | super().__init__() 23 | self._world = world 24 | self._agent: Union[vmas.simulator.core.Agent, None] = None 25 | 26 | @property 27 | def agent(self) -> Union[vmas.simulator.core.Agent, None]: 28 | return self._agent 29 | 30 | @agent.setter 31 | def agent(self, agent: vmas.simulator.core.Agent): 32 | self._agent = agent 33 | 34 | @abstractmethod 35 | def measure(self): 36 | raise NotImplementedError 37 | 38 | @abstractmethod 39 | def render(self, env_index: int = 0) -> "List[Geom]": 40 | raise NotImplementedError 41 | 42 | def to(self, device: torch.device): 43 | raise NotImplementedError 44 | 45 | 46 | class Lidar(Sensor): 47 | def __init__( 48 | self, 49 | world: vmas.simulator.core.World, 50 | angle_start: float = 0.0, 51 | angle_end: float = 2 * torch.pi, 52 | n_rays: int = 8, 53 | max_range: float = 1.0, 54 | entity_filter: Callable[[vmas.simulator.core.Entity], bool] = lambda _: True, 55 | render_color: Union[Color, Tuple[float, float, float]] = Color.GRAY, 56 | alpha: float = 1.0, 57 | render: bool = True, 58 | ): 59 | super().__init__(world) 60 | if (angle_start - angle_end) % (torch.pi * 2) < 1e-5: 61 | angles = torch.linspace( 62 | angle_start, angle_end, n_rays + 1, device=self._world.device 63 | )[:n_rays] 64 | else: 65 | angles = torch.linspace( 66 | angle_start, angle_end, n_rays, device=self._world.device 67 | ) 68 | 69 | self._angles = angles.repeat(self._world.batch_dim, 1) 70 | self._max_range = max_range 71 | self._last_measurement = None 72 | self._render = render 73 | self._entity_filter = entity_filter 74 | self._render_color = render_color 75 | self._alpha = alpha 76 | 77 | def to(self, device: torch.device): 78 | self._angles = self._angles.to(device) 79 | 80 | @property 81 | def entity_filter(self): 82 | return self._entity_filter 83 | 84 | @entity_filter.setter 85 | def entity_filter( 86 | self, entity_filter: Callable[[vmas.simulator.core.Entity], bool] 87 | ): 88 | self._entity_filter = entity_filter 89 | 90 | @property 91 | def render_color(self): 92 | if isinstance(self._render_color, Color): 93 | return self._render_color.value 94 | return self._render_color 95 | 96 | @property 97 | def alpha(self): 98 | return self._alpha 99 | 100 | def measure(self, vectorized: bool = True): 101 | if not vectorized: 102 | dists = [] 103 | for angle in self._angles.unbind(1): 104 | dists.append( 105 | self._world.cast_ray( 106 | self.agent, 107 | angle + self.agent.state.rot.squeeze(-1), 108 | max_range=self._max_range, 109 | entity_filter=self.entity_filter, 110 | ) 111 | ) 112 | measurement = torch.stack(dists, dim=1) 113 | 114 | else: 115 | measurement = self._world.cast_rays( 116 | self.agent, 117 | self._angles + self.agent.state.rot, 118 | max_range=self._max_range, 119 | entity_filter=self.entity_filter, 120 | ) 121 | self._last_measurement = measurement 122 | return measurement 123 | 124 | def set_render(self, render: bool): 125 | self._render = render 126 | 127 | def render(self, env_index: int = 0) -> "List[Geom]": 128 | if not self._render: 129 | return [] 130 | from vmas.simulator import rendering 131 | 132 | geoms: 
List[rendering.Geom] = [] 133 | if self._last_measurement is not None: 134 | for angle, dist in zip( 135 | self._angles.unbind(1), self._last_measurement.unbind(1) 136 | ): 137 | angle = angle[env_index] + self.agent.state.rot.squeeze(-1)[env_index] 138 | ray = rendering.Line( 139 | (0, 0), 140 | (dist[env_index], 0), 141 | width=0.05, 142 | ) 143 | xform = rendering.Transform() 144 | xform.set_translation(*self.agent.state.pos[env_index]) 145 | xform.set_rotation(angle) 146 | ray.add_attr(xform) 147 | ray.set_color(r=0, g=0, b=0, alpha=self.alpha) 148 | 149 | ray_circ = rendering.make_circle(0.01) 150 | ray_circ.set_color(*self.render_color, alpha=self.alpha) 151 | xform = rendering.Transform() 152 | rot = torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1) 153 | pos_circ = ( 154 | self.agent.state.pos[env_index] + rot * dist.unsqueeze(1)[env_index] 155 | ) 156 | xform.set_translation(*pos_circ) 157 | ray_circ.add_attr(xform) 158 | 159 | geoms.append(ray) 160 | geoms.append(ray_circ) 161 | return geoms 162 | --------------------------------------------------------------------------------
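The sensors, dynamics, and wrappers above come together inside scenarios. Below is a minimal scenario sketch, assuming the BaseScenario interface and the Agent/World constructors shown in this repository; the class name and all numeric values are placeholders. It attaches a Lidar to an agent and feeds the scan into the observation, mirroring the pattern used by scenarios such as discovery and navigation.

import torch

from vmas.simulator.core import Agent, World
from vmas.simulator.scenario import BaseScenario
from vmas.simulator.sensors import Lidar


class LidarScenario(BaseScenario):
    def make_world(self, batch_dim, device, **kwargs):
        world = World(batch_dim, device)
        # Attach a 12-ray lidar that measures distances to surrounding entities
        agent = Agent(
            name="agent_0",
            sensors=[Lidar(world, n_rays=12, max_range=0.5)],
        )
        world.add_agent(agent)
        return world

    def reset_world_at(self, env_index=None):
        for agent in self.world.agents:
            # Place every agent at the origin (vectorized over environments)
            agent.set_pos(
                torch.zeros(
                    (1, self.world.dim_p)
                    if env_index is not None
                    else (self.world.batch_dim, self.world.dim_p),
                    device=self.world.device,
                ),
                batch_index=env_index,
            )

    def observation(self, agent):
        # Position, velocity, and the latest lidar scan (one distance per ray)
        return torch.cat(
            [agent.state.pos, agent.state.vel, agent.sensors[0].measure()],
            dim=-1,
        )

    def reward(self, agent):
        return torch.zeros(self.world.batch_dim, device=self.world.device)

Such a scenario could then be instantiated with, for example, vmas.make_env(scenario=LidarScenario(), num_envs=32), and exposed through the Gym or Gymnasium interfaces via the wrappers above.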