├── .gitignore ├── .pre-commit-config.yaml ├── ACKNOWLEDGEMENTS.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE_DATA ├── README.md ├── assets └── space_figure.png ├── requirements.txt └── space ├── __init__.py ├── agents ├── __init__.py ├── basenav_agent.py ├── cswm_agent.py ├── dmnav_agent.py ├── egonav_agent.py ├── mct_agent.py └── qa_agent.py ├── configs ├── __init__.py ├── claude.py ├── gpt.py ├── llama.py ├── mistral.py ├── phi.py └── yi.py ├── envs ├── __init__.py ├── cswm.py ├── mct.py ├── nav_dm.py └── nav_ego.py ├── evaluate_cswm.py ├── evaluate_dmnav.py ├── evaluate_egonav.py ├── evaluate_mct.py ├── evaluate_qas.py ├── registry.py └── utils ├── claude_api.py ├── common.py ├── habitat.py ├── model.py ├── openai_api.py ├── visualizations.py └── vllm_api.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | # Local ignores 177 | data 178 | experiments 179 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: check-merge-conflict 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | - repo: https://github.com/astral-sh/ruff-pre-commit 10 | rev: v0.5.5 11 | hooks: 12 | - id: ruff 13 | args: [ --fix, --show-fixes] 14 | - id: ruff-format 15 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS.md: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this code base may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | _____________________ 6 | Habitat-Sim 7 | 8 | ``` 9 | MIT License 10 | 11 | Copyright (c) Meta Platforms, Inc. and its affiliates. 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy 14 | of this software and associated documentation files (the "Software"), to deal 15 | in the Software without restriction, including without limitation the rights 16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | copies of the Software, and to permit persons to whom the Software is 18 | furnished to do so, subject to the following conditions: 19 | 20 | The above copyright notice and this permission notice shall be included in all 21 | copies or substantial portions of the Software. 
22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | SOFTWARE. 30 | ``` 31 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). 
All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. 
Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH SPACE BENCHMARK: 43 | 44 | The SPACE benchmark software includes subcomponents with separate copyright notices 45 | and license terms - please see the file ACKNOWLEDGEMENTS.md. 46 | ------------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /LICENSE_DATA: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 2 | International Public License 3 | 4 | By exercising the Licensed Rights (defined below), You accept and agree 5 | to be bound by the terms and conditions of this Creative Commons 6 | Attribution-NonCommercial-NoDerivatives 4.0 International Public 7 | License ("Public License"). To the extent this Public License may be 8 | interpreted as a contract, You are granted the Licensed Rights in 9 | consideration of Your acceptance of these terms and conditions, and the 10 | Licensor grants You such rights in consideration of benefits the 11 | Licensor receives from making the Licensed Material available under 12 | these terms and conditions. 13 | 14 | 15 | Section 1 -- Definitions. 16 | 17 | a. Adapted Material means material subject to Copyright and Similar 18 | Rights that is derived from or based upon the Licensed Material 19 | and in which the Licensed Material is translated, altered, 20 | arranged, transformed, or otherwise modified in a manner requiring 21 | permission under the Copyright and Similar Rights held by the 22 | Licensor. For purposes of this Public License, where the Licensed 23 | Material is a musical work, performance, or sound recording, 24 | Adapted Material is always produced where the Licensed Material is 25 | synched in timed relation with a moving image. 26 | 27 | b. Copyright and Similar Rights means copyright and/or similar rights 28 | closely related to copyright including, without limitation, 29 | performance, broadcast, sound recording, and Sui Generis Database 30 | Rights, without regard to how the rights are labeled or 31 | categorized. 
For purposes of this Public License, the rights 32 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 33 | Rights. 34 | 35 | c. Effective Technological Measures means those measures that, in the 36 | absence of proper authority, may not be circumvented under laws 37 | fulfilling obligations under Article 11 of the WIPO Copyright 38 | Treaty adopted on December 20, 1996, and/or similar international 39 | agreements. 40 | 41 | d. Exceptions and Limitations means fair use, fair dealing, and/or 42 | any other exception or limitation to Copyright and Similar Rights 43 | that applies to Your use of the Licensed Material. 44 | 45 | e. Licensed Material means the artistic or literary work, database, 46 | or other material to which the Licensor applied this Public 47 | License. 48 | 49 | f. Licensed Rights means the rights granted to You subject to the 50 | terms and conditions of this Public License, which are limited to 51 | all Copyright and Similar Rights that apply to Your use of the 52 | Licensed Material and that the Licensor has authority to license. 53 | 54 | g. Licensor means the individual(s) or entity(ies) granting rights 55 | under this Public License. 56 | 57 | h. NonCommercial means not primarily intended for or directed towards 58 | commercial advantage or monetary compensation. For purposes of 59 | this Public License, the exchange of the Licensed Material for 60 | other material subject to Copyright and Similar Rights by digital 61 | file-sharing or similar means is NonCommercial provided there is 62 | no payment of monetary compensation in connection with the 63 | exchange. 64 | 65 | i. Share means to provide material to the public by any means or 66 | process that requires permission under the Licensed Rights, such 67 | as reproduction, public display, public performance, distribution, 68 | dissemination, communication, or importation, and to make material 69 | available to the public including in ways that members of the 70 | public may access the material from a place and at a time 71 | individually chosen by them. 72 | 73 | j. Sui Generis Database Rights means rights other than copyright 74 | resulting from Directive 96/9/EC of the European Parliament and of 75 | the Council of 11 March 1996 on the legal protection of databases, 76 | as amended and/or succeeded, as well as other essentially 77 | equivalent rights anywhere in the world. 78 | 79 | k. You means the individual or entity exercising the Licensed Rights 80 | under this Public License. Your has a corresponding meaning. 81 | 82 | 83 | Section 2 -- Scope. 84 | 85 | a. License grant. 86 | 87 | 1. Subject to the terms and conditions of this Public License, 88 | the Licensor hereby grants You a worldwide, royalty-free, 89 | non-sublicensable, non-exclusive, irrevocable license to 90 | exercise the Licensed Rights in the Licensed Material to: 91 | 92 | a. reproduce and Share the Licensed Material, in whole or 93 | in part, for NonCommercial purposes only; and 94 | 95 | b. produce and reproduce, but not Share, Adapted Material 96 | for NonCommercial purposes only. 97 | 98 | 2. Exceptions and Limitations. For the avoidance of doubt, where 99 | Exceptions and Limitations apply to Your use, this Public 100 | License does not apply, and You do not need to comply with 101 | its terms and conditions. 102 | 103 | 3. Term. The term of this Public License is specified in Section 104 | 6(a). 105 | 106 | 4. Media and formats; technical modifications allowed. 
The 107 | Licensor authorizes You to exercise the Licensed Rights in 108 | all media and formats whether now known or hereafter created, 109 | and to make technical modifications necessary to do so. The 110 | Licensor waives and/or agrees not to assert any right or 111 | authority to forbid You from making technical modifications 112 | necessary to exercise the Licensed Rights, including 113 | technical modifications necessary to circumvent Effective 114 | Technological Measures. For purposes of this Public License, 115 | simply making modifications authorized by this Section 2(a) 116 | (4) never produces Adapted Material. 117 | 118 | 5. Downstream recipients. 119 | 120 | a. Offer from the Licensor -- Licensed Material. Every 121 | recipient of the Licensed Material automatically 122 | receives an offer from the Licensor to exercise the 123 | Licensed Rights under the terms and conditions of this 124 | Public License. 125 | 126 | b. No downstream restrictions. You may not offer or impose 127 | any additional or different terms or conditions on, or 128 | apply any Effective Technological Measures to, the 129 | Licensed Material if doing so restricts exercise of the 130 | Licensed Rights by any recipient of the Licensed 131 | Material. 132 | 133 | 6. No endorsement. Nothing in this Public License constitutes or 134 | may be construed as permission to assert or imply that You 135 | are, or that Your use of the Licensed Material is, connected 136 | with, or sponsored, endorsed, or granted official status by, 137 | the Licensor or others designated to receive attribution as 138 | provided in Section 3(a)(1)(A)(i). 139 | 140 | b. Other rights. 141 | 142 | 1. Moral rights, such as the right of integrity, are not 143 | licensed under this Public License, nor are publicity, 144 | privacy, and/or other similar personality rights; however, to 145 | the extent possible, the Licensor waives and/or agrees not to 146 | assert any such rights held by the Licensor to the limited 147 | extent necessary to allow You to exercise the Licensed 148 | Rights, but not otherwise. 149 | 150 | 2. Patent and trademark rights are not licensed under this 151 | Public License. 152 | 153 | 3. To the extent possible, the Licensor waives any right to 154 | collect royalties from You for the exercise of the Licensed 155 | Rights, whether directly or through a collecting society 156 | under any voluntary or waivable statutory or compulsory 157 | licensing scheme. In all other cases the Licensor expressly 158 | reserves any right to collect such royalties, including when 159 | the Licensed Material is used other than for NonCommercial 160 | purposes. 161 | 162 | 163 | Section 3 -- License Conditions. 164 | 165 | Your exercise of the Licensed Rights is expressly made subject to the 166 | following conditions. 167 | 168 | a. Attribution. 169 | 170 | 1. If You Share the Licensed Material, You must: 171 | 172 | a. retain the following if it is supplied by the Licensor 173 | with the Licensed Material: 174 | 175 | i. identification of the creator(s) of the Licensed 176 | Material and any others designated to receive 177 | attribution, in any reasonable manner requested by 178 | the Licensor (including by pseudonym if 179 | designated); 180 | 181 | ii. a copyright notice; 182 | 183 | iii. a notice that refers to this Public License; 184 | 185 | iv. a notice that refers to the disclaimer of 186 | warranties; 187 | 188 | v. a URI or hyperlink to the Licensed Material to the 189 | extent reasonably practicable; 190 | 191 | b. 
indicate if You modified the Licensed Material and 192 | retain an indication of any previous modifications; and 193 | 194 | c. indicate the Licensed Material is licensed under this 195 | Public License, and include the text of, or the URI or 196 | hyperlink to, this Public License. 197 | 198 | For the avoidance of doubt, You do not have permission under 199 | this Public License to Share Adapted Material. 200 | 201 | 2. You may satisfy the conditions in Section 3(a)(1) in any 202 | reasonable manner based on the medium, means, and context in 203 | which You Share the Licensed Material. For example, it may be 204 | reasonable to satisfy the conditions by providing a URI or 205 | hyperlink to a resource that includes the required 206 | information. 207 | 208 | 3. If requested by the Licensor, You must remove any of the 209 | information required by Section 3(a)(1)(A) to the extent 210 | reasonably practicable. 211 | 212 | 213 | Section 4 -- Sui Generis Database Rights. 214 | 215 | Where the Licensed Rights include Sui Generis Database Rights that 216 | apply to Your use of the Licensed Material: 217 | 218 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 219 | to extract, reuse, reproduce, and Share all or a substantial 220 | portion of the contents of the database for NonCommercial purposes 221 | only and provided You do not Share Adapted Material; 222 | 223 | b. if You include all or a substantial portion of the database 224 | contents in a database in which You have Sui Generis Database 225 | Rights, then the database in which You have Sui Generis Database 226 | Rights (but not its individual contents) is Adapted Material; and 227 | 228 | c. You must comply with the conditions in Section 3(a) if You Share 229 | all or a substantial portion of the contents of the database. 230 | 231 | For the avoidance of doubt, this Section 4 supplements and does not 232 | replace Your obligations under this Public License where the Licensed 233 | Rights include other Copyright and Similar Rights. 234 | 235 | 236 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 237 | 238 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 239 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 240 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 241 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 242 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 243 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 244 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 245 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 246 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 247 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 248 | 249 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 250 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 251 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 252 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 253 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 254 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 255 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 256 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 257 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 258 | 259 | c. 
The disclaimer of warranties and limitation of liability provided 260 | above shall be interpreted in a manner that, to the extent 261 | possible, most closely approximates an absolute disclaimer and 262 | waiver of all liability. 263 | 264 | 265 | Section 6 -- Term and Termination. 266 | 267 | a. This Public License applies for the term of the Copyright and 268 | Similar Rights licensed here. However, if You fail to comply with 269 | this Public License, then Your rights under this Public License 270 | terminate automatically. 271 | 272 | b. Where Your right to use the Licensed Material has terminated under 273 | Section 6(a), it reinstates: 274 | 275 | 1. automatically as of the date the violation is cured, provided 276 | it is cured within 30 days of Your discovery of the 277 | violation; or 278 | 279 | 2. upon express reinstatement by the Licensor. 280 | 281 | For the avoidance of doubt, this Section 6(b) does not affect any 282 | right the Licensor may have to seek remedies for Your violations 283 | of this Public License. 284 | 285 | c. For the avoidance of doubt, the Licensor may also offer the 286 | Licensed Material under separate terms or conditions or stop 287 | distributing the Licensed Material at any time; however, doing so 288 | will not terminate this Public License. 289 | 290 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 291 | License. 292 | 293 | 294 | Section 7 -- Other Terms and Conditions. 295 | 296 | a. The Licensor shall not be bound by any additional or different 297 | terms or conditions communicated by You unless expressly agreed. 298 | 299 | b. Any arrangements, understandings, or agreements regarding the 300 | Licensed Material not stated herein are separate from and 301 | independent of the terms and conditions of this Public License. 302 | 303 | 304 | Section 8 -- Interpretation. 305 | 306 | a. For the avoidance of doubt, this Public License does not, and 307 | shall not be interpreted to, reduce, limit, restrict, or impose 308 | conditions on any use of the Licensed Material that could lawfully 309 | be made without permission under this Public License. 310 | 311 | b. To the extent possible, if any provision of this Public License is 312 | deemed unenforceable, it shall be automatically reformed to the 313 | minimum extent necessary to make it enforceable. If the provision 314 | cannot be reformed, it shall be severed from this Public License 315 | without affecting the enforceability of the remaining terms and 316 | conditions. 317 | 318 | c. No term or condition of this Public License will be waived and no 319 | failure to comply consented to unless expressly agreed to by the 320 | Licensor. 321 | 322 | d. Nothing in this Public License constitutes or may be interpreted 323 | as a limitation upon, or waiver of, any privileges and immunities 324 | that apply to the Licensor or You, including from the legal 325 | processes of any jurisdiction or authority. 326 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPACE benchmark 2 | 3 | This software project and dataset accompanies the research paper: **[Does Spatial Cognition Emerge in Frontier Models?](https://arxiv.org/pdf/2410.06468)**, *Santhosh Kumar Ramakrishnan, Erik Wijmans, Philipp Krähenbühl, Vladlen Koltun*. 
4 | Published in ICLR 2025 5 | 6 | ![](assets/space_figure.png) 7 | 8 | We present SPACE, a benchmark that systematically evaluates spatial cognition in frontier models. Our benchmark builds on decades of research in cognitive science. It evaluates large-scale mapping abilities that are brought to bear when an organism traverses physical environments, smaller-scale reasoning about object shapes and layouts, and cognitive infrastructure such as spatial attention and memory. For many tasks, we instantiate parallel presentations via text and images, allowing us to benchmark both large language models and large multimodal models. Results suggest that contemporary frontier models fall short of the spatial intelligence of animals, performing near chance level on a number of classic tests of animal cognition. 9 | 10 | ## Installation instructions 11 | 12 | 1. Install mamba following instructions from [Miniforge](https://github.com/conda-forge/miniforge). Create a mamba environment. 13 | ``` 14 | mamba create -n space-benchmark python=3.9 cmake=3.14.0 -y 15 | mamba activate space-benchmark 16 | ``` 17 | 2. Clone this repo and install requirements. 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | 3. Install habitat for large-scale cognition navigation experiments. 22 | ``` 23 | mamba install habitat-sim=0.3.0 headless -c conda-forge -c aihabitat 24 | ``` 25 | 4. Generate a security token from huggingface and set the environment variables. 26 | ``` 27 | export HF_TOKEN= 28 | export HF_HUB_ENABLE_HF_TRANSFER=1 29 | ``` 30 | 5. Set OpenAI and Anthropic API keys for evaluating GPT and Claude models. 31 | ``` 32 | export OPENAI_API_KEY= 33 | export ANTHROPIC_API_KEY= 34 | ``` 35 | 36 | ## Downloading SPACE dataset 37 | The SPACE dataset is available [here](https://ml-site.cdn-apple.com/datasets/space/space.tar.gz). Download it to `/data/SPACE_data_release`. 38 | ``` 39 | mkdir /data 40 | cd /data 41 | 42 | wget https://ml-site.cdn-apple.com/datasets/space/space.tar.gz 43 | tar -xvzf space.tar.gz 44 | rm space.tar.gz 45 | ``` 46 | 47 | ## Evaluating models 48 | 49 | ### SPACE QA tasks 50 | **Command:** 51 | ``` 52 | python -m space.evaluate_qas \ 53 | --model_name _qa \ 54 | --data_path data/SPACE_data_release//qas.json \ 55 | --save_dir experiments/ 56 | ``` 57 | **QA task names:** `CBTT_text`, `CBTT_vision`, `DirectionEstimationBEVImage`, `DirectionEstimationBEVText`, `DirectionEstimationEgo`, `DistanceEstimationBEVImage`, `DistanceEstimationBEVText`, `DistanceEstimationEgo`, `JLO_text`, `JLO_vision`, `MPFB_text`, `MPFB_vision`, `MRT_text`, `MRT_vision`, `MapSketchingBEVImage`, `MapSketchingBEVText`, `MapSketchingEgo`, `PTT_text`, `PTT_vision`, `SAdd_text`, `SAdd_vision`, `SAtt_text`, `SAtt_vision`, `WLT_vision`
58 | **Models supported for multimodal presentation:** `claude35sonnet`, `gpt4o`, `gpt4v`, `phi35vision`, `pixtral12b`
59 | **Models supported for text-only presentation:** `claude35sonnet`, `gpt4o`, `gpt4v`, `llama3_8b`, `llama3_70b`, `mixtral8x7b`, `mixtral8x22b`, `mistral123b`, `yi15_9b`, `yi15_34b`
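For example, to evaluate GPT-4o on the vision-based mental rotation test (`MRT_vision`), the command template above can be filled in along these lines (a sketch, assuming agent names are formed by prefixing a supported model name to the `_qa` suffix and that the dataset unpacks into one sub-folder per QA task):
```
python -m space.evaluate_qas \
    --model_name gpt4o_qa \
    --data_path data/SPACE_data_release/MRT_vision/qas.json \
    --save_dir experiments/MRT_vision
```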
60 | 61 | ### SPACE navigation tasks 62 | This evaluation requires habitat-sim to be installed (see step 3 of the installation instructions above).<br>
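A quick way to confirm the habitat-sim install before launching long evaluation runs is to import it inside the `space-benchmark` environment (a minimal sanity check, assuming the mamba install from step 3 succeeded):
```
mamba activate space-benchmark
python -c "import habitat_sim; print(habitat_sim.__version__)"
```
This should print `0.3.0` if the package installed in step 3 is available.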
63 | 64 | **Command for egocentric navigation:** 65 | ``` 66 | python -m space.evaluate_egonav \ 67 | --model_name _egonav \ 68 | --envs_dir data/SPACE_data_release/3D_scenes/ \ 69 | --save_dir experiments/RouteRetracingEgo \ 70 | --walkthrough_key shortestpath \ 71 | --max_steps 250 72 | ``` 73 | 74 | **Command for discrete map image navigation:** 75 | ``` 76 | python -m space.evaluate_dmnav \ 77 | --model_name _dminav \ 78 | --envs_dir data/SPACE_data_release/2D_scenes \ 79 | --walkthroughs_dir data/SPACE_data_release/BEV_image_walkthroughs/ \ 80 | --obs_type image \ 81 | --save_dir experiments/RouteRetracingDiscreteMapImage \ 82 | --walkthrough_key shortestpath 83 | ``` 84 | 85 | **Command for discrete map text navigation:** 86 | ``` 87 | python -m space.evaluate_dmnav \ 88 | --model_name _dmtnav \ 89 | --envs_dir data/SPACE_data_release/2D_scenes \ 90 | --walkthroughs_dir data/SPACE_data_release/BEV_text_walkthroughs/ \ 91 | --obs_type text \ 92 | --save_dir experiments/RouteRetracingDiscreteMapText \ 93 | --walkthrough_key shortestpath 94 | ``` 95 | 96 | **Notes:** 97 | * These commands evaluate on route retracing. Set `--walkthrough_key walkthrough` for evaluating on shortcut discovery.
98 | * Models supported for multimodal presentation: `claude35sonnet`, `gpt4o`, `gpt4v`
99 | * Models supported for text-only presentation: `claude35sonnet`, `gpt4o`, `gpt4v`, `llama3_8b`, `llama3_70b`, `mixtral8x7b`, `mixtral8x22b`, `mistral123b`, `yi15_9b`, `yi15_34b` 100 | 101 | ### SPACE CSWM task 102 | **Command for multimodal presentation:** 103 | ``` 104 | python -m space.evaluate_cswm \ 105 | --model_name _cswm_vision \ 106 | --envs_dir data/SPACE_data_release/CSWM_vision/ \ 107 | --save_dir experiments/CSWM_vision \ 108 | --game_mode vision 109 | ``` 110 | **Command for text-only presentation:** 111 | ``` 112 | python -m space.evaluate_cswm \ 113 | --model_name _cswm_text \ 114 | --envs_dir data/SPACE_data_release/CSWM_text/ \ 115 | --save_dir experiments/CSWM_text \ 116 | --game_mode text 117 | ``` 118 | **Models supported for multimodal presentation:** `claude35sonnet`, `gpt4o`, `gpt4v`
119 | **Models supported for text-only presentation:** `claude35sonnet`, `gpt4o`, `gpt4v`, `llama3_8b`, `llama3_70b`, `mixtral8x7b`, `mixtral8x22b`, `mistral123b`, `yi15_9b`, `yi15_34b` 120 | 121 | ### SPACE maze completion task 122 | **Command for multimodal presentation:** 123 | ``` 124 | python -m space.evaluate_mct \ 125 | --model_name _mct_vision \ 126 | --envs_dir data/SPACE_data_release/MCT_vision/envs \ 127 | --save_dir experiments/MCT_vision 128 | ``` 129 | **Command for text-only presentation:** 130 | ``` 131 | python -m space.evaluate_mct \ 132 | --model_name _mct_text \ 133 | --envs_dir data/SPACE_data_release/MCT_text/envs \ 134 | --save_dir experiments/MCT_text 135 | ``` 136 | **Models supported for multimodal presentation:** `claude35sonnet`, `gpt4o`, `gpt4v`
137 | **Models supported for text-only presentation:** `claude35sonnet`, `gpt4o`, `gpt4v`, `llama3_8b`, `llama3_70b`, `mixtral8x7b`, `mixtral8x22b`, `mistral123b`, `yi15_9b`, `yi15_34b` 138 | 139 | 140 | ## Citation 141 | ``` 142 | @inproceedings{ramakrishnan2025space, 143 | title={Does Spatial Cognition Emerge in Frontier Models?}, 144 | author={Ramakrishnan, Santhosh Kumar and Wijmans, Erik and Kraehenbuehl, Philipp and Koltun, Vladlen}, 145 | booktitle={International Conference on Learning Representations}, 146 | year={2025}, 147 | url={https://openreview.net/forum?id=WK6K1FMEQ1} 148 | } 149 | ``` 150 | 151 | ## License 152 | This project's code is released under the Apple Sample Code License (see [LICENSE](LICENSE)). This project's data is released under the CC-BY-NC-ND license (see [LICENSE_DATA](LICENSE_DATA)). 153 | 154 | ## Acknowledgements 155 | Our codebase is built using opensource contributions, please see [Acknowledgements](ACKNOWLEDGEMENTS.md) for more details. 156 | 157 | Please check the paper for a complete list of references and datasets used in this work. 158 | -------------------------------------------------------------------------------- /assets/space_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-space-benchmark/564e43932adc84543800dd56b99cee37efaeabd8/assets/space_figure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | anthropic 3 | opencv-python 4 | fire 5 | func-timeout 6 | huggingface_hub[hf_transfer] 7 | imageio 8 | mdutils 9 | matplotlib 10 | mistral_common 11 | networkx 12 | numpy 13 | openai 14 | Pillow 15 | scipy 16 | tqdm 17 | transformers 18 | vllm 19 | -------------------------------------------------------------------------------- /space/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from space.agents import get_agent as get_agent 4 | from space.configs import get_config as get_config 5 | from space.envs import get_env as get_env 6 | -------------------------------------------------------------------------------- /space/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from typing import Any 4 | from space.registry import AGENTS_REGISTRY 5 | import space.agents.cswm_agent # noqa 6 | import space.agents.dmnav_agent # noqa 7 | import space.agents.egonav_agent # noqa 8 | import space.agents.mct_agent # noqa 9 | import space.agents.qa_agent # noqa 10 | 11 | 12 | def get_agent(agent_name: str, agent_cfg: dict[str, Any]): 13 | assert agent_name in AGENTS_REGISTRY 14 | agent_cls = AGENTS_REGISTRY[agent_name] 15 | agent = agent_cls(**agent_cfg) 16 | return agent 17 | -------------------------------------------------------------------------------- /space/agents/basenav_agent.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | import os 4 | import numpy as np 5 | import time 6 | 7 | from abc import ABC 8 | from typing import Any, Optional, Union 9 | from mdutils.mdutils import MdUtils 10 | 11 | from space.utils.openai_api import ( 12 | Dialog, 13 | setup_llm_client, 14 | ) 15 | 16 | from space.utils.claude_api import ( 17 | Dialog as DialogClaude, 18 | setup_llm_client as setup_llm_client_claude, 19 | ) 20 | from space.utils.model import get_model_response 21 | from space.utils.common import count_images_in_query 22 | 23 | 24 | class Base_Navigation_Agent(ABC): 25 | def __init__( 26 | self, 27 | model_name: str, 28 | host_port: str, 29 | save_dir: str, 30 | image_detail: str = "low", 31 | max_new_tokens: int = 2048, 32 | max_history_length: int = -1, 33 | completion_cost_per_mil: float = 0.0, 34 | prompt_cost_per_mil: float = 0.0, 35 | supports_system_prompt: bool = True, 36 | max_context_tokens: int = 8192, 37 | context_truncation_factor: float = 0.9, 38 | subsampling_factor: int = 1, 39 | max_images_per_query: int = -1, 40 | **kwargs, 41 | ): 42 | self.model_name = model_name 43 | self.host_port = host_port 44 | self.save_dir = save_dir 45 | self.image_detail = image_detail 46 | self.max_new_tokens = max_new_tokens 47 | self.max_history_length = max_history_length 48 | self.completion_cost_per_mil = completion_cost_per_mil 49 | self.prompt_cost_per_mil = prompt_cost_per_mil 50 | self.supports_system_prompt = supports_system_prompt 51 | self.max_context_tokens = max_context_tokens 52 | self.context_truncation_factor = context_truncation_factor 53 | self.subsampling_factor = subsampling_factor 54 | self.max_images_per_query = max_images_per_query 55 | self.writer = None 56 | self.walkthrough_key = None 57 | self.dialog = None 58 | self.completion_tokens = None 59 | self.prompt_tokens = None 60 | 61 | if self.model_name.startswith("claude"): 62 | self.client = setup_llm_client_claude(self.model_name) 63 | else: 64 | self.client = setup_llm_client(self.model_name, self.host_port) 65 | self.task_prompt = None 66 | 67 | def reset(self, walkthrough_key: str): 68 | assert walkthrough_key in ["shortestpath", "walkthrough"] 69 | self.walkthrough_key = walkthrough_key 70 | self.set_task_prompt(walkthrough_key) 71 | os.makedirs(self.save_dir, exist_ok=True) 72 | self.writer = MdUtils( 73 | file_name=os.path.join(self.save_dir, "transcript"), 74 | title="SPACE navigation", 75 | ) 76 | if self.model_name.startswith("claude"): 77 | self.dialog = DialogClaude(self.writer) 78 | else: 79 | self.dialog = Dialog(self.writer) 80 | self.completion_tokens = 0 81 | self.prompt_tokens = 0 82 | self.necessary_context_len = 0 83 | # Setup system prompt 84 | if self.supports_system_prompt: 85 | self.dialog.add_system_message(content=self.task_prompt) 86 | self.necessary_context_len += 1 87 | else: 88 | self.dialog.add_user_message(content=self.task_prompt) 89 | self.dialog.add_assistant_message( 90 | content="Okay, I understand. Let's begin!" 
91 | ) 92 | self.necessary_context_len += 2 93 | 94 | def perform_model_response_loop(self, max_retries: int = 10, sleep_sec: int = 60): 95 | n_retries = 0 96 | response = None 97 | while n_retries <= max_retries: 98 | response = get_model_response( 99 | self.client, 100 | self.dialog.dialog, 101 | self.model_name, 102 | max_tokens=self.max_new_tokens, 103 | max_retries=1, 104 | ) 105 | excptn_type = response.get("exception_type", None) 106 | excptn_code = response.get("exception_code", None) 107 | excptn_message = response.get("exception_message", None) 108 | # Check for out-of-context error 109 | ooc_error = ( 110 | excptn_type == "BadRequestError" 111 | and excptn_code == 400 112 | and "maximum context length" in excptn_message 113 | ) 114 | if ooc_error: 115 | # Out-of-context error => restrict dialog and retry 116 | self.restrict_dialog_history() 117 | elif excptn_message is None and response["text"] is not None: 118 | # No errors => finish 119 | break 120 | else: 121 | # Some other error (e.g., rate limits) => retry after sleep_sec 122 | time.sleep(sleep_sec) 123 | n_retries += 1 124 | if n_retries >= max_retries and response is None: 125 | print( 126 | f"Failed after {n_retries} retries due to following error: {response['excptn_message']}" 127 | ) 128 | 129 | return response 130 | 131 | def handle_model_response(self, content: Union[list[Any], str]): 132 | n_images = count_images_in_query(content) 133 | if self.max_images_per_query >= 1 and n_images > self.max_images_per_query: 134 | c_subset = [] 135 | n_images_subset = 0 136 | for c in content: 137 | if isinstance(c, dict) and c.get("type", None) == "image": 138 | n_images_subset += 1 139 | c_subset.append(c) 140 | if n_images_subset == self.max_images_per_query: 141 | c_subset.append( 142 | "More information will be provided in the next image. Do not say anything. Please wait before responding." 143 | ) 144 | self.dialog.add_user_message(content=c_subset) 145 | self.dialog.add_assistant_message( 146 | content="I understand. I'll continue to wait for your next message before providing any analysis or response." 147 | ) 148 | # Reset 149 | n_images_subset = 0 150 | c_subset = [] 151 | c_subset.append( 152 | "This marks the end of my message to you. Please respond now." 
153 | ) 154 | self.dialog.add_user_message(content=c_subset) 155 | else: 156 | self.dialog.add_user_message(content=content) 157 | 158 | response = self.perform_model_response_loop() 159 | ########################################################## 160 | prompt_tokens = response["prompt_tokens"] 161 | completion_tokens = response["completion_tokens"] 162 | ########################################################## 163 | response_txt = self.postprocess_response(response["text"]) 164 | self.dialog.add_assistant_message(content=response_txt) 165 | self.dialog.write_dialog() 166 | ############################################################################################ 167 | # Smart context handling 168 | if response is not None: 169 | request_tokens = prompt_tokens + completion_tokens 170 | if ( 171 | request_tokens + self.max_new_tokens 172 | >= self.context_truncation_factor * self.max_context_tokens 173 | ): 174 | # Restrict dialog history 175 | self.restrict_dialog_history() 176 | ############################################################################################ 177 | # Truncate history if needed 178 | if self.max_history_length > 0: 179 | history_len = len(self.dialog.history) - self.necessary_context_len 180 | if history_len > self.max_history_length: 181 | n_steps = (history_len - self.max_history_length) // 2 182 | for _ in range(n_steps): 183 | self.restrict_dialog_history() 184 | ############################################################################################ 185 | # Track token usage 186 | self.completion_tokens += completion_tokens 187 | self.prompt_tokens += prompt_tokens 188 | total_cost = ( 189 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 190 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 191 | ) 192 | self.dialog.log_token_usage( 193 | self.prompt_tokens, self.completion_tokens, total_cost 194 | ) 195 | # Log time taken 196 | self.dialog.log_response_time(response["response_time"]) 197 | self.writer.create_md_file() 198 | return response_txt 199 | 200 | def postprocess_response(self, text: str): 201 | if self.model_name.startswith("mistralai"): 202 | text = text.split("[/INST]")[-1].strip() 203 | return text 204 | 205 | def initialize_with_walkthrough( 206 | self, walkthrough_obs: Union[list[np.ndarray], list[str]] 207 | ): 208 | if self.subsampling_factor > 1: 209 | walkthrough_obs = ( 210 | [walkthrough_obs[0]] 211 | + walkthrough_obs[1 : -1 : self.subsampling_factor] 212 | + [walkthrough_obs[-1]] 213 | ) 214 | prompt = self.get_walkthrough_prompt(walkthrough_obs) 215 | _ = self.handle_model_response(prompt) 216 | self.necessary_context_len += 2 217 | 218 | def initialize_with_goal( 219 | self, goal_desc: str, goal_img: Optional[np.ndarray] = None 220 | ): 221 | message = self.get_goal_prompt(goal_desc, goal_img) 222 | self.dialog.add_user_message(content=message) 223 | self.necessary_context_len += 1 224 | 225 | def set_task_prompt(self, walkthrough_key: str): 226 | raise NotImplementedError 227 | 228 | def get_goal_prompt(self, goal_desc: str, goal_img: Optional[np.ndarray] = None): 229 | raise NotImplementedError 230 | 231 | def get_action(self, obs: Union[np.ndarray, str]): 232 | raise NotImplementedError 233 | 234 | def get_walkthrough_prompt( 235 | self, walkthrough_obs: Union[list[np.ndarray], list[str]] 236 | ): 237 | raise NotImplementedError 238 | 239 | def restrict_dialog_history(self): 240 | if len(self.dialog.history) <= 1: 241 | return 242 | 243 | task_context = self.dialog.history[: 
self.necessary_context_len] 244 | dialog_history = self.dialog.history[self.necessary_context_len :] 245 | self.dialog.history = task_context + dialog_history[2:] 246 | 247 | def get_eval_cost(self): 248 | total_cost = ( 249 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 250 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 251 | ) 252 | return { 253 | "total_cost": total_cost, 254 | "completion_tokens": self.completion_tokens, 255 | "prompt_tokens": self.prompt_tokens, 256 | } 257 | -------------------------------------------------------------------------------- /space/agents/cswm_agent.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | import os 4 | import re 5 | import time 6 | import numpy as np 7 | 8 | from typing import Any, Union 9 | from mdutils.mdutils import MdUtils 10 | from space.registry import register_agent 11 | 12 | from space.utils.common import get_image_as_message 13 | from space.utils.openai_api import ( 14 | Dialog, 15 | setup_llm_client, 16 | ) 17 | 18 | from space.utils.claude_api import ( 19 | Dialog as DialogClaude, 20 | setup_llm_client as setup_llm_client_claude, 21 | ) 22 | from space.utils.model import get_model_response 23 | from space.utils.common import count_images_in_query 24 | 25 | 26 | VISION_TASK_PROMPT = """You are playing the Cambridge Spatial Working Memory game. You will be shown a screen with blue boxes. A treasure is hidden in one of the blue boxes. You must identify the box containing the treasure, which is shown as an yellow square. Once you find a treasure, it will be collected and placed in the "Treasures collected" section shown below the image. A new treasure will be hidden in one of the other boxes where the treasure did not appear before. You must again find the new treasure. This process is repeated till you find all treasures placed in each of the blue boxes once. Note: The treasure will never appear in a box where it had already been placed. 27 | 28 | Each turn, there are randomly selected numbers associated with each box. These numbers are meant to aid you with communication, i.e., specify what box you want to open in that turn. However, these numbers will change after every turn. So do NOT associate boxes with numbers over the long term. The number identity of a box can change any time. Therefore, you must remember the boxes based on their spatial positions and not the numbers. 29 | 30 | RESPONSE FORMAT: 31 | Think step-by-step about where the treasure might be based on your past actions. After that, indicate the box you want to open in the following json format: 32 | ``` 33 | { 34 | "action": 35 | } 36 | ```""" 37 | 38 | TEXT_TASK_PROMPT = """You are playing the Cambridge Spatial Working Memory game. You will be shown an array with integers. 0 represents empty locations. Locations numbered 1 - 9 represent boxes. A treasure is hidden in one of the boxes. You must identify the box containing the treasure. Once you find a treasure, the location will be momentarily shown as a "T" indicating that the treasure was found. The treasure is then collected and a new treasure will be hidden in one of the other boxes where the treasure did not appear before. You must then find the new treasure. This process is repeated till you find all treasures placed in each of the boxes once. Note: The treasure will never appear in a box where it had already been placed. 
39 | 40 | While the boxes are represented using integers from 1 - 9, the true identity of the box is its location (row, column) in the array. The box location is always fixed (i.e., the boxes will not move and the number of boxes will not change). However, each turn, the integer id associated with the box will change randomly. These integer ids are meant to aid you with communication, i.e., specify what box you want to open in that turn. However, these numbers will change after every turn. So do NOT associate boxes with numbers over the long term. The number identity of a box can change any time. Therefore, you must remember the boxes based on their spatial positions and not the numbers. 41 | 42 | RESPONSE FORMAT: 43 | Think step-by-step about where the treasure might be based on your past actions. After that, indicate the box you want to open in the following json format: 44 | ``` 45 | { 46 | "action": 47 | } 48 | ```""" 49 | 50 | 51 | @register_agent 52 | class CSWM_Agent(object): 53 | def __init__( 54 | self, 55 | model_name: str, 56 | host_port: str, 57 | save_dir: str, 58 | task_mode: str, 59 | image_detail: str = "low", 60 | max_new_tokens: int = 2048, 61 | max_history_length: int = -1, 62 | completion_cost_per_mil: float = 0.0, 63 | prompt_cost_per_mil: float = 0.0, 64 | supports_system_prompt: bool = True, 65 | max_context_tokens: int = 8192, 66 | context_truncation_factor: float = 0.9, 67 | max_images_per_query: int = -1, 68 | **kwargs, 69 | ): 70 | assert task_mode in ["vision", "text"] 71 | self.model_name = model_name 72 | self.host_port = host_port 73 | self.save_dir = save_dir 74 | self.task_mode = task_mode 75 | self.image_detail = image_detail 76 | self.max_new_tokens = max_new_tokens 77 | self.max_history_length = max_history_length 78 | self.completion_cost_per_mil = completion_cost_per_mil 79 | self.prompt_cost_per_mil = prompt_cost_per_mil 80 | self.supports_system_prompt = supports_system_prompt 81 | self.max_context_tokens = max_context_tokens 82 | self.context_truncation_factor = context_truncation_factor 83 | self.max_images_per_query = max_images_per_query 84 | self.writer = None 85 | self.dialog = None 86 | self.completion_tokens = None 87 | self.prompt_tokens = None 88 | 89 | if self.model_name.startswith("claude"): 90 | self.client = setup_llm_client_claude(self.model_name) 91 | else: 92 | self.client = setup_llm_client(self.model_name, self.host_port) 93 | 94 | if self.task_mode == "vision": 95 | self.task_prompt = VISION_TASK_PROMPT 96 | else: 97 | self.task_prompt = TEXT_TASK_PROMPT 98 | 99 | def reset(self): 100 | os.makedirs(self.save_dir, exist_ok=True) 101 | self.writer = MdUtils( 102 | file_name=os.path.join(self.save_dir, "transcript"), 103 | title="SPACE CSWM", 104 | ) 105 | if self.model_name.startswith("claude"): 106 | self.dialog = DialogClaude(self.writer) 107 | else: 108 | self.dialog = Dialog(self.writer) 109 | self.completion_tokens = 0 110 | self.prompt_tokens = 0 111 | # Setup system prompt 112 | if self.supports_system_prompt: 113 | self.dialog.add_system_message(content=self.task_prompt) 114 | else: 115 | self.dialog.add_user_message(content=self.task_prompt) 116 | self.dialog.add_assistant_message( 117 | content="Okay, I understand. Let's begin!" 
118 | ) 119 | 120 | def get_action(self, obs: Union[np.ndarray, str]): 121 | content = self.get_user_message(obs) 122 | response_txt = self.handle_model_response(content) 123 | pred_action = self.parse_answer_from_response(response_txt) 124 | return pred_action 125 | 126 | def get_user_message(self, obs: Union[np.ndarray, str]): 127 | if self.task_mode == "vision": 128 | assert isinstance(obs, np.ndarray) 129 | content = [ 130 | "Here is the current state of the game. You must find the next treasure. Note that the numbers of the boxes have changed, but the box locations are fixed. Decide which box you want to open next, and then use the number associated with the box as the action.", 131 | get_image_as_message( 132 | image=obs, 133 | model_name=self.model_name, 134 | image_detail=self.image_detail, 135 | ), 136 | ] 137 | else: 138 | assert isinstance(obs, str) 139 | content = obs 140 | return content 141 | 142 | def postprocess_response(self, text: str): 143 | if self.model_name.startswith("mistralai"): 144 | text = text.split("[/INST]")[-1].strip() 145 | return text 146 | 147 | def perform_model_response_loop(self, max_retries: int = 10, sleep_sec: int = 60): 148 | n_retries = 0 149 | response = None 150 | while n_retries <= max_retries: 151 | response = get_model_response( 152 | self.client, 153 | self.dialog.dialog, 154 | self.model_name, 155 | max_tokens=self.max_new_tokens, 156 | max_retries=1, 157 | ) 158 | excptn_type = response.get("exception_type", None) 159 | excptn_code = response.get("exception_code", None) 160 | excptn_message = response.get("exception_message", None) 161 | # Check for out-of-context error 162 | ooc_error = ( 163 | excptn_type == "BadRequestError" 164 | and excptn_code == 400 165 | and "maximum context length" in excptn_message 166 | ) 167 | if ooc_error: 168 | # Out-of-context error => restrict dialog and retry 169 | self.restrict_dialog_history() 170 | elif excptn_message is None and response["text"] is not None: 171 | # No errors => finish 172 | break 173 | else: 174 | # Some other error (e.g., rate limits) => retry after sleep_sec 175 | time.sleep(sleep_sec) 176 | n_retries += 1 177 | if n_retries >= max_retries and response is None: 178 | print( 179 | f"Failed after {n_retries} retries due to following error: {response['excptn_message']}" 180 | ) 181 | 182 | return response 183 | 184 | def handle_model_response(self, content: Union[list[Any], str]): 185 | n_images = count_images_in_query(content) 186 | if self.max_images_per_query >= 1 and n_images > self.max_images_per_query: 187 | c_subset = [] 188 | n_images_subset = 0 189 | for c in content: 190 | if isinstance(c, dict) and c.get("type", None) == "image": 191 | n_images_subset += 1 192 | c_subset.append(c) 193 | if n_images_subset == self.max_images_per_query: 194 | c_subset.append( 195 | "More information will be provided in the next image. Do not say anything. Please wait before responding." 196 | ) 197 | self.dialog.add_user_message(content=c_subset) 198 | self.dialog.add_assistant_message( 199 | content="I understand. I'll continue to wait for your next message before providing any analysis or response." 200 | ) 201 | # Reset 202 | n_images_subset = 0 203 | c_subset = [] 204 | c_subset.append( 205 | "This marks the end of my message to you. Please respond now." 
206 | ) 207 | self.dialog.add_user_message(content=c_subset) 208 | else: 209 | self.dialog.add_user_message(content=content) 210 | 211 | response = self.perform_model_response_loop() 212 | ########################################################## 213 | prompt_tokens = response["prompt_tokens"] 214 | completion_tokens = response["completion_tokens"] 215 | ########################################################## 216 | response_txt = self.postprocess_response(response["text"]) 217 | self.dialog.add_assistant_message(content=response_txt) 218 | self.dialog.write_dialog() 219 | ############################################################################################ 220 | # Smart context handling 221 | if response is not None: 222 | request_tokens = prompt_tokens + completion_tokens 223 | if ( 224 | request_tokens + self.max_new_tokens 225 | >= self.context_truncation_factor * self.max_context_tokens 226 | ): 227 | # Restrict dialog history 228 | self.restrict_dialog_history() 229 | ############################################################################################ 230 | # Truncate history if needed 231 | if self.max_history_length > 0: 232 | history_len = len(self.dialog.history) 233 | if self.supports_system_prompt: 234 | history_len -= 1 235 | else: 236 | history_len -= 2 237 | if history_len > self.max_history_length: 238 | n_steps = (history_len - self.max_history_length) // 2 239 | for _ in range(n_steps): 240 | self.restrict_dialog_history() 241 | ############################################################################################ 242 | # Track token usage 243 | self.completion_tokens += completion_tokens 244 | self.prompt_tokens += prompt_tokens 245 | total_cost = ( 246 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 247 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 248 | ) 249 | self.dialog.log_token_usage( 250 | self.prompt_tokens, self.completion_tokens, total_cost 251 | ) 252 | # Log time taken 253 | self.dialog.log_response_time(response["response_time"]) 254 | return response_txt 255 | 256 | def parse_answer_from_response(self, response_txt: str): 257 | try: 258 | re_out = re.search(r'{\s*"action":\s*(.*)\s*}', response_txt) 259 | pred_action = int(re_out.groups()[-1]) 260 | except Exception as e: 261 | print( 262 | f"Unable to parse predictions from '{response_txt}'. Got exception {e}" 263 | ) 264 | pred_action = 0 if self.task_mode == "vision" else 1 265 | return pred_action 266 | 267 | def restrict_dialog_history(self): 268 | if len(self.dialog.history) <= 1: 269 | return 270 | if self.supports_system_prompt: 271 | task_context = self.dialog.history[:1] 272 | dialog_history = self.dialog.history[1:] 273 | else: 274 | task_context = self.dialog.history[:2] 275 | dialog_history = self.dialog.history[2:] 276 | self.dialog.history = task_context + dialog_history[2:] 277 | 278 | def get_eval_cost(self): 279 | total_cost = ( 280 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 281 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 282 | ) 283 | return { 284 | "total_cost": total_cost, 285 | "completion_tokens": self.completion_tokens, 286 | "prompt_tokens": self.prompt_tokens, 287 | } 288 | -------------------------------------------------------------------------------- /space/agents/egonav_agent.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
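For reference, a minimal sketch of how CSWM_Agent.parse_answer_from_response extracts the chosen box id from a model reply; the reply text and box number below are made-up examples, not outputs from this repository.
```
import re

reply = 'I have not opened the top-left box yet, so I will try it next.\n{\n    "action": 4\n}'
match = re.search(r'{\s*"action":\s*(.*)\s*}', reply)
pred_action = int(match.groups()[-1])  # -> 4
```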
3 | import re 4 | 5 | from typing import Any 6 | import numpy as np 7 | 8 | from space.utils.common import get_image_as_message 9 | 10 | from space.registry import register_agent 11 | from space.agents.basenav_agent import Base_Navigation_Agent 12 | 13 | 14 | TASK_PROMPT_ROUTE_FOLLOWING = """You are a sentient living creature capable of navigating in environments, building internal spatial representations of environments, and finding goals in them. You will be shown a video of the shortest route from the initial position to the goal. You must look at the video and understand the environment structure and the route taken. Then, you will be placed in the environment at the same initial position. You must navigate from the initial position to the goal using the same route shown in the video, as quickly as possible. Below, you will find sections highlighting more details about the task. You can refer to these for more information. 15 | 16 | OBSERVATIONS: 17 | The images are recorded from a perspective viewpoint (i.e., egocentric or first-person). This means that you are likely to see objects from different angles, resulting in a skewed appearance of the underlying 3D objects. It is important for you to look past this skew in the appearance and perceive the true shape of the object in 3D. 18 | 19 | GOAL: 20 | You will be provided an object goal using a text description and an image of the object. You must find the goal object in the environment by repeating the path shown in the video walkthrough. Once you find it, move close to the location of the goal and re-orient yourself to face the object. 21 | 22 | ACTIONS: 23 | You have four actions available. 24 | 25 | move_forward: move forward by 0.25m along the current heading direction. It does not change the heading angle. 26 | turn_left: decrease your heading angle by 30 degrees. It does not change the (x, y) position. 27 | turn_right: increase your heading angle by 30 degrees. It does not change the (x, y) position. 28 | stop: ends the current task. Issue this action only if you think you have reached the goal. If you haven't reached the goal, this action will result in a navigation failure that cannot be recovered from. 29 | 30 | STUCK IN PLACE BEHAVIOR: 31 | Avoid getting stuck in one place, i.e., do not alternate between left and right turns without going anywhere. You must try and move around consistently without being stuck in one place. 32 | 33 | STOP CRITERIA: 34 | Before executing stop, you must ensure that you've "reached" the goal correctly. To reach a goal, you have to move close enough to the wall where you see the goal, and see the object clearly in your observation in front of you. 35 | 36 | RESPONSE FORMAT: 37 | Respond in the following format: 38 | 39 | Reasoning: 40 | Intent: 41 | Then provide the final action to take in a JSON-formatted string. 42 | ``` 43 | { 44 | "action": 45 | } 46 | ``` 47 | """ 48 | 49 | TASK_PROMPT_NOVEL_SHORTCUTS = """You are a sentient living creature capable of navigating in environments, building internal spatial representations of environments, and finding goals in them. You will be shown a video of some route from the initial position to the goal. You must look at the video and understand the environment structure, and remember the locations of the start and the goal. The video may show a long-winded route from the start to the goal with unnecessary detours. Based on the environment structure, you must identify a faster route to the goal.
Then, you will be placed in the environment at the same initial position. You must navigate to the goal using your identified shortest route as quickly as possible. Below, you will find sections highlighting more details about the task. You can refer to these for more information. 50 | 51 | OBSERVATIONS: 52 | The images are recorded from a perspective viewpoint (i.e., egocentric or first-person). This means that you are likely to see objects from different angles, resulting in a skewed appearance of the underlying 3D objects. It is important for you to look past this skew in the appearance and perceive the true shape of the object in 3D. 53 | 54 | GOAL: 55 | You will be provided an object goal using a text description and an image of the object. You must find the goal object in the environment by identifying the shortest route based on your experience from the video. Once you find the goal, move close to its location and re-orient yourself to face the object. 56 | 57 | ACTIONS: 58 | You have four actions available. 59 | 60 | move_forward: move forward by 0.25m along the current heading direction. It does not change the heading angle. 61 | turn_left: decrease your heading angle by 30 degrees. It does not change the (x, y) position. 62 | turn_right: increase your heading angle by 30 degrees. It does not change the (x, y) position. 63 | stop: ends the current task. Issue this action only if you think you have reached the goal. If you haven't reached the goal, this action will result in a navigation failure that cannot be recovered from. 64 | 65 | STUCK IN PLACE BEHAVIOR: 66 | Avoid getting stuck in one place, i.e., do not alternate between left and right turns without going anywhere. You must try and move around consistently without being stuck in one place. 67 | 68 | STOP CRITERIA: 69 | Before executing stop, you must ensure that you've "reached" the goal correctly. To reach a goal, you have to move close enough to the wall where you see the goal, and see the object clearly in your observation in front of you. 70 | 71 | RESPONSE FORMAT: 72 | Respond in the following format: 73 | 74 | Reasoning: 75 | Intent: 76 | Then provide the final action to take in a JSON-formatted string. 77 | ``` 78 | { 79 | "action": 80 | } 81 | ``` 82 | """ 83 | 84 | 85 | @register_agent 86 | class EgoNav_Agent(Base_Navigation_Agent): 87 | def set_task_prompt(self, walkthrough_key: str): 88 | assert walkthrough_key in ["shortestpath", "walkthrough"] 89 | if walkthrough_key == "shortestpath": 90 | # Route following experiment 91 | self.task_prompt = TASK_PROMPT_ROUTE_FOLLOWING 92 | else: 93 | # Novel shortcuts experiment 94 | self.task_prompt = TASK_PROMPT_NOVEL_SHORTCUTS 95 | 96 | def get_goal_prompt(self, goal_desc: str, goal_img: np.ndarray): 97 | message = [ 98 | f"Now, you must navigate to the goal. Here is the goal description and the image: {goal_desc}", 99 | get_image_as_message( 100 | image=goal_img, 101 | model_name=self.model_name, 102 | image_detail=self.image_detail, 103 | ), 104 | ] 105 | return message 106 | 107 | def get_action(self, obs: np.ndarray): 108 | obs_encoded = get_image_as_message( 109 | image=obs, model_name=self.model_name, image_detail=self.image_detail 110 | ) 111 | new_message = self.get_observation_prompt(obs_encoded) + [ 112 | "What action do you select next? The available actions are move_forward, turn_left, turn_right and stop.
Recall that each turn action is only 30 degrees and each forward step is only 0.25m, so you may have to execute several actions to notice substantial changes in your viewpoints. Be patient and persist with your actions over a longer time horizon.", 113 | ] 114 | response_txt = self.handle_model_response(new_message) 115 | ############################################################################################ 116 | # Clean-up context history 117 | # Remove explanatory text content of last user message "What action do you select next? ...." 118 | assert self.dialog.history[-2]["role"] == "user" 119 | self.dialog.history[-2] = { 120 | "role": "user", 121 | "content": self.get_clean_observation_prompt(obs_encoded), 122 | } 123 | ############################################################################################ 124 | action = self.convert_response_to_action(response_txt) 125 | if action is None: 126 | # Turn left if no action was parseable from the model 127 | action = "turn_left" 128 | 129 | return action 130 | 131 | def get_observation_prompt(self, image_encoded: dict[str, Any]) -> list[str]: 132 | return [ 133 | "Here is the current observation. If you are stuck very close to the same wall for several steps, it means that you are colliding and need to turn around and search elsewhere.", 134 | image_encoded, 135 | ] 136 | 137 | def get_clean_observation_prompt(self, image_encoded: str): 138 | return ["Here is the current observation.", image_encoded] 139 | 140 | def get_walkthrough_prompt(self, walkthrough_frames: np.ndarray) -> list[str]: 141 | if self.walkthrough_key == "shortestpath": 142 | output = [ 143 | "Here is the sequence of frames from the walkthrough video demonstrating the route you need to take. Analyze the walkthrough to understand the movements and the maze structure. Take a note of all the details needed to help you repeat this route when navigating next. Think step by step." 144 | ] 145 | elif self.walkthrough_key == "walkthrough": 146 | output = [ 147 | "Here is the sequence of frames from the walkthrough video demonstrating a suboptimal route from the start to some goal location. Analyze the walkthrough to understand the movements and the environment structure. Keep track of the start and goal locations, and the current location in the environment as you watch the walkthrough. Then plan a shortcut route that takes you to the goal while avoiding unnecessary detours. Think step by step." 148 | ] 149 | else: 150 | raise ValueError( 151 | f"Unable to process walkthrough_key = {self.walkthrough_key}" 152 | ) 153 | for frame in walkthrough_frames: 154 | frame_encoded = get_image_as_message( 155 | image=frame, model_name=self.model_name, image_detail=self.image_detail 156 | ) 157 | output.append(frame_encoded) 158 | return output 159 | 160 | def convert_response_to_action(self, response_txt: str) -> str: 161 | try: 162 | re_out = re.search(r'{\s+"action":\s*"(.*)"\s+}', response_txt) 163 | pred_action = re_out.groups()[0] 164 | except Exception as e: 165 | print(f"Unable to parse predictions. Got exception {e}") 166 | pred_action = None 167 | return pred_action 168 | -------------------------------------------------------------------------------- /space/agents/mct_agent.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
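A minimal sketch of the reply that EgoNav_Agent.convert_response_to_action is written to parse, using a hypothetical model response; the pattern expects the quoted action name inside the final JSON block.
```
import re

reply = 'Reasoning: the goal appears directly ahead.\nIntent: keep moving toward it.\n{\n    "action": "move_forward"\n}'
re_out = re.search(r'{\s+"action":\s*"(.*)"\s+}', reply)
print(re_out.groups()[0])  # move_forward
```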
3 | import os 4 | import re 5 | import time 6 | from abc import ABC 7 | from typing import Any, Union 8 | 9 | import numpy as np 10 | from mdutils.mdutils import MdUtils 11 | 12 | from space.registry import register_agent 13 | from space.utils.common import get_image_as_message 14 | from space.utils.openai_api import ( 15 | Dialog, 16 | setup_llm_client, 17 | ) 18 | 19 | from space.utils.claude_api import ( 20 | Dialog as DialogClaude, 21 | setup_llm_client as setup_llm_client_claude, 22 | ) 23 | from space.utils.model import get_model_response 24 | from space.utils.common import count_images_in_query 25 | 26 | 27 | VISION_TASK_PROMPT = """You are a sentient living creature capable of navigating in mazes, planning, and spatial reasoning. You are playing a Pacman-style maze game. You start at some random position in the maze. You must escape the maze as quickly as possible to reach the goal. You are given the game screen that shows the following: 28 | * maze structure - blue is obstacle space, black is navigable space. You can only move on black spaces. You cannot move through blue spaces. 29 | * your current position - yellow square 30 | * goal position - red circle 31 | 32 | Below the screen, a status message might appear indicating that you collided with a wall after your previous action. 33 | 34 | Actions available: You can take five possible actions. 35 | * left - move left from your current position by one step 36 | * right - move right from your current position by one step 37 | * up - move up from your current position by one step 38 | * down - move down from your current position by one step 39 | * stop - issue this action only after you have reached the goal position. If you execute it prematurely, you will fail. If you do not execute it after reaching the goal, you will again fail. 40 | 41 | Response format: Respond in the following format. 42 | 43 | 44 | 45 | ``` 46 | { 47 | "action": "" 48 | } 49 | ``` 50 | """ 51 | 52 | TEXT_TASK_PROMPT = """You are a sentient living creature capable of navigating in mazes, planning, and spatial reasoning. You are playing a text-based maze game. You start at some random position in the maze. You must escape the maze as quickly as possible to reach the goal. You are given a 2D array representing the maze, which contains the following: 53 | * maze structure - 0 is obstacle space, 1 is navigable space. You can only move on 1s (i.e., navigable spaces). You cannot move through 0s (i.e., obstacles). 54 | * your current position - marked as A 55 | * goal position - marked as G 56 | 57 | Goal and current positions are always navigable spaces. 58 | 59 | Actions available: You can take five possible actions. 60 | * left - move left from your current position by one step 61 | * right - move right from your current position by one step 62 | * up - move up from your current position by one step 63 | * down - move down from your current position by one step 64 | * stop - issue this action only after you have reached the goal position. If you execute it prematurely, you will fail. If you do not execute it after reaching the goal, you will again fail. 65 | 66 | Response format: Respond in the following format.
67 | 68 | 69 | 70 | ``` 71 | { 72 | "action": "" 73 | } 74 | ``` 75 | """ 76 | 77 | 78 | @register_agent 79 | class MCT_Agent(ABC): 80 | def __init__( 81 | self, 82 | model_name: str, 83 | host_port: str, 84 | save_dir: str, 85 | description_type: str, 86 | image_detail: str = "low", 87 | max_new_tokens: int = 2048, 88 | max_history_length: int = 20, 89 | completion_cost_per_mil: float = 0.0, 90 | prompt_cost_per_mil: float = 0.0, 91 | supports_system_prompt: bool = True, 92 | max_context_tokens: int = 8192, 93 | context_truncation_factor: float = 0.9, 94 | max_images_per_query: int = -1, 95 | **kwargs, 96 | ): 97 | assert description_type in ["image", "text"] 98 | self.model_name = model_name 99 | self.host_port = host_port 100 | self.save_dir = save_dir 101 | self.description_type = description_type 102 | self.image_detail = image_detail 103 | self.max_new_tokens = max_new_tokens 104 | self.max_history_length = max_history_length 105 | self.completion_cost_per_mil = completion_cost_per_mil 106 | self.prompt_cost_per_mil = prompt_cost_per_mil 107 | self.supports_system_prompt = supports_system_prompt 108 | self.max_context_tokens = max_context_tokens 109 | self.context_truncation_factor = context_truncation_factor 110 | self.max_images_per_query = max_images_per_query 111 | self.writer = None 112 | self.dialog = None 113 | self.completion_tokens = None 114 | self.prompt_tokens = None 115 | 116 | if self.model_name.startswith("claude"): 117 | self.client = setup_llm_client_claude(self.model_name) 118 | else: 119 | self.client = setup_llm_client(self.model_name, self.host_port) 120 | 121 | if self.description_type == "image": 122 | self.task_prompt = VISION_TASK_PROMPT 123 | else: 124 | self.task_prompt = TEXT_TASK_PROMPT 125 | 126 | def reset(self): 127 | os.makedirs(self.save_dir, exist_ok=True) 128 | self.writer = MdUtils( 129 | file_name=os.path.join(self.save_dir, "transcript"), 130 | title="SPACE MCT", 131 | ) 132 | if self.model_name.startswith("claude"): 133 | self.dialog = DialogClaude(self.writer) 134 | else: 135 | self.dialog = Dialog(self.writer) 136 | self.completion_tokens = 0 137 | self.prompt_tokens = 0 138 | # Setup system prompt 139 | if self.supports_system_prompt: 140 | self.dialog.add_system_message(content=self.task_prompt) 141 | else: 142 | self.dialog.add_user_message(content=self.task_prompt) 143 | self.dialog.add_assistant_message( 144 | content="Okay, I understand. Let's begin!" 145 | ) 146 | 147 | def get_action(self, obs: Union[str, np.ndarray]): 148 | content = self.get_user_message(obs) 149 | response_txt = self.handle_model_response(content) 150 | pred_action = self.parse_answer_from_response(response_txt) 151 | return pred_action 152 | 153 | def get_user_message(self, obs: Union[np.ndarray, str]): 154 | if self.description_type == "image": 155 | assert isinstance(obs, np.ndarray) 156 | content = [ 157 | "Here is the current state of the maze.", 158 | get_image_as_message( 159 | image=obs, 160 | model_name=self.model_name, 161 | image_detail=self.image_detail, 162 | ), 163 | ] 164 | content.append( 165 | "Think step-by-step about how to reach the goal. What action do you take next?" 166 | ) 167 | else: 168 | assert isinstance(obs, str) 169 | content = obs 170 | content += "Think step-by-step about how to reach the goal. What action do you take next?" 
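# In text mode, `obs` is the maze description supplied by the environment, and the follow-up question is appended so that the observation and the question form a single user message. An illustrative (made-up) text observation could look like: [[0, 0, 0, 0], [0, 'A', 1, 0], [0, 1, 'G', 0], [0, 0, 0, 0]].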
171 | return content 172 | 173 | def postprocess_response(self, text: str): 174 | if self.model_name.startswith("mistralai"): 175 | text = text.split("[/INST]")[-1].strip() 176 | return text 177 | 178 | def perform_model_response_loop(self, max_retries: int = 10, sleep_sec: int = 60): 179 | n_retries = 0 180 | response = None 181 | while n_retries <= max_retries: 182 | response = get_model_response( 183 | self.client, 184 | self.dialog.dialog, 185 | self.model_name, 186 | max_tokens=self.max_new_tokens, 187 | max_retries=1, 188 | ) 189 | excptn_type = response.get("exception_type", None) 190 | excptn_code = response.get("exception_code", None) 191 | excptn_message = response.get("exception_message", None) 192 | # Check for out-of-context error 193 | ooc_error = ( 194 | excptn_type == "BadRequestError" 195 | and excptn_code == 400 196 | and "maximum context length" in excptn_message 197 | ) 198 | if ooc_error: 199 | # Out-of-context error => restrict dialog and retry 200 | self.restrict_dialog_history() 201 | elif excptn_message is None and response["text"] is not None: 202 | # No errors => finish 203 | break 204 | else: 205 | # Some other error (e.g., rate limits) => retry after sleep_sec 206 | time.sleep(sleep_sec) 207 | n_retries += 1 208 | if n_retries >= max_retries and response is None: 209 | print( 210 | f"Failed after {n_retries} retries due to following error: {response['excptn_message']}" 211 | ) 212 | 213 | return response 214 | 215 | def handle_model_response(self, content: Union[list[Any], str]): 216 | n_images = count_images_in_query(content) 217 | if self.max_images_per_query >= 1 and n_images > self.max_images_per_query: 218 | c_subset = [] 219 | n_images_subset = 0 220 | for c in content: 221 | if isinstance(c, dict) and c.get("type", None) == "image": 222 | n_images_subset += 1 223 | c_subset.append(c) 224 | if n_images_subset == self.max_images_per_query: 225 | c_subset.append( 226 | "More information will be provided in the next image. Do not say anything. Please wait before responding." 227 | ) 228 | self.dialog.add_user_message(content=c_subset) 229 | self.dialog.add_assistant_message( 230 | content="I understand. I'll continue to wait for your next message before providing any analysis or response." 231 | ) 232 | # Reset 233 | n_images_subset = 0 234 | c_subset = [] 235 | c_subset.append( 236 | "This marks the end of my message to you. Please respond now." 
237 | ) 238 | self.dialog.add_user_message(content=c_subset) 239 | else: 240 | self.dialog.add_user_message(content=content) 241 | response = self.perform_model_response_loop() 242 | ########################################################## 243 | prompt_tokens = response["prompt_tokens"] 244 | completion_tokens = response["completion_tokens"] 245 | ########################################################## 246 | response_txt = self.postprocess_response(response["text"]) 247 | self.dialog.add_assistant_message(content=response_txt) 248 | self.dialog.write_dialog() 249 | ############################################################################################ 250 | # Smart context handling 251 | if response is not None: 252 | request_tokens = prompt_tokens + completion_tokens 253 | if ( 254 | request_tokens + self.max_new_tokens 255 | >= self.context_truncation_factor * self.max_context_tokens 256 | ): 257 | # Restrict dialog history 258 | self.restrict_dialog_history() 259 | ############################################################################################ 260 | # Truncate history if needed 261 | if self.max_history_length > 0: 262 | history_len = len(self.dialog.history) 263 | if self.supports_system_prompt: 264 | history_len -= 1 265 | else: 266 | history_len -= 2 267 | if history_len > self.max_history_length: 268 | n_steps = (history_len - self.max_history_length) // 2 269 | for _ in range(n_steps): 270 | self.restrict_dialog_history() 271 | ############################################################################################ 272 | # Track token usage 273 | self.completion_tokens += completion_tokens 274 | self.prompt_tokens += prompt_tokens 275 | total_cost = ( 276 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 277 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 278 | ) 279 | self.dialog.log_token_usage( 280 | self.prompt_tokens, self.completion_tokens, total_cost 281 | ) 282 | # Log time taken 283 | self.dialog.log_response_time(response["response_time"]) 284 | return response_txt 285 | 286 | def parse_answer_from_response(self, response_txt: str): 287 | try: 288 | re_out = re.search(r'{\s*"action":\s*"(.*)"\s*}', response_txt) 289 | pred_action = re_out.groups()[0] 290 | if pred_action not in ["up", "left", "right", "down", "stop"]: 291 | print( 292 | f".... Invalid action predicted: `{pred_action}`. Replacing it with action `up`." 293 | ) 294 | pred_action = "up" 295 | except Exception as e: 296 | print( 297 | f"Unable to parse predictions from '{response_txt}'. 
Got exception {e}" 298 | ) 299 | pred_action = "up" 300 | return pred_action 301 | 302 | def restrict_dialog_history(self): 303 | if len(self.dialog.history) <= 1: 304 | return 305 | if self.supports_system_prompt: 306 | task_context = self.dialog.history[:1] 307 | dialog_history = self.dialog.history[1:] 308 | else: 309 | task_context = self.dialog.history[:2] 310 | dialog_history = self.dialog.history[2:] 311 | self.dialog.history = task_context + dialog_history[2:] 312 | 313 | def get_eval_cost(self): 314 | total_cost = ( 315 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 316 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 317 | ) 318 | return { 319 | "total_cost": total_cost, 320 | "completion_tokens": self.completion_tokens, 321 | "prompt_tokens": self.prompt_tokens, 322 | } 323 | -------------------------------------------------------------------------------- /space/agents/qa_agent.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | import json 4 | import os 5 | import re 6 | from typing import Any, Union 7 | 8 | from mdutils.mdutils import MdUtils 9 | from space.registry import register_agent 10 | 11 | from space.utils.openai_api import ( 12 | Dialog, 13 | setup_llm_client, 14 | ) 15 | 16 | from space.utils.claude_api import ( 17 | Dialog as DialogClaude, 18 | setup_llm_client as setup_llm_client_claude, 19 | ) 20 | from space.utils.model import get_model_response 21 | from space.utils.common import count_images_in_query 22 | 23 | 24 | def is_int(text: str): 25 | try: 26 | int(text) 27 | except Exception: 28 | return False 29 | else: 30 | return True 31 | 32 | 33 | @register_agent 34 | class QA_Agent(object): 35 | def __init__( 36 | self, 37 | model_name: str, 38 | host_port: str, 39 | save_dir: str, 40 | image_detail: str = "low", 41 | max_new_tokens: int = 2048, 42 | completion_cost_per_mil: float = 0.0, 43 | prompt_cost_per_mil: float = 0.0, 44 | subsampling_factor: int = 1, 45 | max_images_per_query: int = -1, 46 | **kwargs, 47 | ): 48 | self.model_name = model_name 49 | self.host_port = host_port 50 | self.save_dir = save_dir 51 | self.image_detail = image_detail 52 | self.max_new_tokens = max_new_tokens 53 | self.completion_cost_per_mil = completion_cost_per_mil 54 | self.prompt_cost_per_mil = prompt_cost_per_mil 55 | self.subsampling_factor = subsampling_factor 56 | self.max_images_per_query = max_images_per_query 57 | self.writer = None 58 | self.dialog = None 59 | self.completion_tokens = None 60 | self.prompt_tokens = None 61 | if self.model_name.startswith("claude"): 62 | self.client = setup_llm_client_claude(self.model_name) 63 | else: 64 | self.client = setup_llm_client(self.model_name, self.host_port) 65 | 66 | def reset(self): 67 | os.makedirs(self.save_dir, exist_ok=True) 68 | self.writer = MdUtils( 69 | file_name=os.path.join(self.save_dir, "transcript"), 70 | title="Question-answering task", 71 | ) 72 | if self.model_name.startswith("claude"): 73 | self.dialog = DialogClaude(self.writer) 74 | else: 75 | self.dialog = Dialog(self.writer) 76 | self.completion_tokens = 0 77 | self.prompt_tokens = 0 78 | 79 | def update_save_dir(self, save_dir: str): 80 | self.save_dir = save_dir 81 | 82 | def get_prediction(self, question_content: Union[list[Any], str], answer: Any): 83 | question_content = self.preprocess_question(question_content) 84 | n_images = count_images_in_query(question_content) 85 | if 
self.max_images_per_query >= 1 and n_images > self.max_images_per_query: 86 | qc_subset = [] 87 | n_images_subset = 0 88 | for q in question_content: 89 | if isinstance(q, dict) and q.get("type", None) == "image": 90 | n_images_subset += 1 91 | qc_subset.append(q) 92 | if n_images_subset == self.max_images_per_query: 93 | qc_subset.append( 94 | "More information will be provided in the next image. Do not say anything. Please wait before responding." 95 | ) 96 | self.dialog.add_user_message(content=qc_subset) 97 | self.dialog.add_assistant_message( 98 | content="I understand. I'll continue to wait for your next message before providing any analysis or response." 99 | ) 100 | # Reset 101 | n_images_subset = 0 102 | qc_subset = [] 103 | qc_subset.append( 104 | "This marks the end of my message to you. Please respond now." 105 | ) 106 | self.dialog.add_user_message(content=qc_subset) 107 | else: 108 | self.dialog.add_user_message(content=question_content) 109 | 110 | response = get_model_response( 111 | self.client, 112 | self.dialog.dialog, 113 | self.model_name, 114 | max_tokens=self.max_new_tokens, 115 | ) 116 | response_txt = self.postprocess_response(response["text"]) 117 | self.dialog.add_assistant_message(content=response_txt) 118 | pred = self.parse_answer_from_response(response_txt) 119 | # Add GT answer and prediction for reference 120 | self.dialog.log_writer.write( 121 | f"\n\nGround-truth answer: {answer}, prediction: {pred}" 122 | ) 123 | 124 | ############################################################################################ 125 | # Log token usage 126 | self.completion_tokens += response["completion_tokens"] 127 | self.prompt_tokens += response["prompt_tokens"] 128 | total_cost = ( 129 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 130 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 131 | ) 132 | self.dialog.log_token_usage( 133 | self.prompt_tokens, self.completion_tokens, total_cost 134 | ) 135 | # Log time taken 136 | self.dialog.log_response_time(response["response_time"]) 137 | ############################################################################################ 138 | 139 | ############################################################################################ 140 | # Remove prior messages from history 141 | self.dialog.delete_last_message() 142 | self.dialog.delete_last_message() 143 | ############################################################################################ 144 | return pred 145 | 146 | def preprocess_question(self, question_content: Union[list[Any], str]): 147 | assert isinstance(question_content, str) or isinstance(question_content, list) 148 | if self.model_name.startswith("mistralai/Pixtral"): 149 | if isinstance(question_content, str): 150 | question_content = [{"type": "text", "text": question_content}] 151 | else: 152 | question_content_p = [] 153 | for q in question_content: 154 | if isinstance(q, str): 155 | question_content_p.append({"type": "text", "text": q}) 156 | elif isinstance(q, dict): 157 | question_content_p.append(q) 158 | else: 159 | raise ValueError(f"Unable to preprocess question content: {q}") 160 | question_content = question_content_p 161 | return question_content 162 | 163 | def postprocess_response(self, response_txt: str): 164 | if self.model_name.startswith("mistralai"): 165 | response_txt = response_txt.split("[/INST]")[-1].strip() 166 | return response_txt 167 | 168 | def parse_answer_from_response(self, text: str): 169 | outputs = re.findall(r"({.*?})", text, re.DOTALL) 
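# re.DOTALL lets each {...} candidate span multiple lines; all candidates are collected and the last one in the response is treated as the final answer below.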
170 | if len(outputs) == 0: 171 | print(f"Unable to parse answer from text:\n{text}") 172 | return None 173 | output = outputs[-1].strip() 174 | try: 175 | answer_dict = json.loads(output) 176 | except ValueError: 177 | print(f"Unable to decode json from text:\n{text}") 178 | answer = None 179 | else: 180 | if "answer" in answer_dict: 181 | answer = {"answer": answer_dict["answer"]} 182 | elif len(answer_dict) == 1: 183 | answer = {"answer": list(answer_dict.values())[0]} 184 | else: 185 | print( 186 | f"Ambiguous response in text (dict without key `answer`):\n{text}" 187 | ) 188 | answer = None 189 | if answer is not None and answer["answer"] is None: 190 | answer = None 191 | if answer is not None and is_int(answer["answer"]): 192 | answer["answer"] = int(answer["answer"]) 193 | 194 | if answer is not None: 195 | answer = answer["answer"] 196 | return answer 197 | 198 | def get_eval_cost(self): 199 | total_cost = ( 200 | self.completion_tokens * self.completion_cost_per_mil / 1.0e6 201 | + self.prompt_tokens * self.prompt_cost_per_mil / 1.0e6 202 | ) 203 | return { 204 | "total_cost": total_cost, 205 | "completion_tokens": self.completion_tokens, 206 | "prompt_tokens": self.prompt_tokens, 207 | } 208 | -------------------------------------------------------------------------------- /space/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from space.registry import CONFIGS_REGISTRY 4 | from dataclasses import asdict 5 | 6 | import space.configs.claude # noqa 7 | import space.configs.gpt # noqa 8 | import space.configs.llama # noqa 9 | import space.configs.mistral # noqa 10 | import space.configs.phi # noqa 11 | import space.configs.yi # noqa 12 | 13 | 14 | def get_config(name: str): 15 | assert name in CONFIGS_REGISTRY 16 | cfg = CONFIGS_REGISTRY[name]() 17 | return asdict(cfg) 18 | -------------------------------------------------------------------------------- /space/configs/claude.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
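A minimal usage sketch for the config registry defined in space/configs/__init__.py: get_config looks up a dataclass registered under a name (the names come from the @register_config decorators in the config files below) and returns it as a plain dict via asdict.
```
from space.configs import get_config

cfg = get_config("claude35sonnet_qa")
print(cfg["model_name"])  # claude-3-5-sonnet-20240620
print(cfg["agent_name"])  # QA_Agent
```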
3 | from typing import Optional 4 | from dataclasses import dataclass 5 | 6 | from space.registry import register_config 7 | 8 | 9 | @dataclass 10 | class Claude35Sonnet_Base: 11 | model_name: str = "claude-3-5-sonnet-20240620" 12 | max_context_tokens: int = 200000 13 | completion_cost_per_mil: float = 15.0 14 | prompt_cost_per_mil: float = 3.0 15 | supports_system_prompt: bool = False 16 | max_images_per_query: int = 10 17 | use_vllm: bool = False 18 | 19 | 20 | @register_config("claude35sonnet_qa") 21 | @dataclass 22 | class Claude35Sonnet_QA(Claude35Sonnet_Base): 23 | agent_name: str = "QA_Agent" 24 | max_new_tokens: int = 4096 25 | host_port: Optional[str] = None 26 | save_dir: Optional[str] = None 27 | image_detail: str = "low" 28 | subsampling_factor: int = 1 29 | 30 | 31 | @register_config("claude35sonnet_egoqa") 32 | @dataclass 33 | class Claude35Sonnet_EgoQA(Claude35Sonnet_Base): 34 | agent_name: str = "QA_Agent" 35 | max_new_tokens: int = 4096 36 | host_port: Optional[str] = None 37 | save_dir: Optional[str] = None 38 | image_detail: str = "low" 39 | subsampling_factor: int = 2 40 | 41 | 42 | @register_config("claude35sonnet_egonav") 43 | @dataclass 44 | class Claude35Sonnet_EgoNav(Claude35Sonnet_Base): 45 | agent_name: str = "EgoNav_Agent" 46 | max_new_tokens: int = 2048 47 | max_history_length: int = 10 48 | host_port: Optional[str] = None 49 | save_dir: Optional[str] = None 50 | image_detail: str = "low" 51 | context_truncation_factor: float = 0.9 52 | subsampling_factor: int = 2 53 | 54 | 55 | @register_config("claude35sonnet_dminav") 56 | @dataclass 57 | class Claude35Sonnet_DiscreteMapImageNav(Claude35Sonnet_Base): 58 | agent_name: str = "DiscreteMapImage_Nav_Agent" 59 | max_new_tokens: int = 2048 60 | max_history_length: int = 50 61 | host_port: Optional[str] = None 62 | save_dir: Optional[str] = None 63 | image_detail: str = "low" 64 | context_truncation_factor: float = 0.9 65 | subsampling_factor: int = 1 66 | 67 | 68 | @register_config("claude35sonnet_dmtnav") 69 | @dataclass 70 | class Claude35Sonnet_DiscreteMapTextNav(Claude35Sonnet_Base): 71 | agent_name: str = "DiscreteMapText_Nav_Agent" 72 | max_new_tokens: int = 2048 73 | max_history_length: int = 50 74 | host_port: Optional[str] = None 75 | save_dir: Optional[str] = None 76 | image_detail: str = "low" 77 | context_truncation_factor: float = 0.9 78 | subsampling_factor: int = 1 79 | 80 | 81 | @register_config("claude35sonnet_cswm_vision") 82 | @dataclass 83 | class Claude35Sonnet_CSWM_Vision(Claude35Sonnet_Base): 84 | agent_name: str = "CSWM_Agent" 85 | task_mode: str = "vision" 86 | max_new_tokens: int = 4096 87 | max_history_length: int = -1 88 | host_port: Optional[str] = None 89 | save_dir: Optional[str] = None 90 | image_detail: str = "low" 91 | context_truncation_factor: float = 0.9 92 | 93 | 94 | @register_config("claude35sonnet_cswm_text") 95 | @dataclass 96 | class Claude35Sonnet_CSWM_Text(Claude35Sonnet_Base): 97 | agent_name: str = "CSWM_Agent" 98 | task_mode: str = "text" 99 | max_new_tokens: int = 4096 100 | max_history_length: int = -1 101 | host_port: Optional[str] = None 102 | save_dir: Optional[str] = None 103 | image_detail: str = "low" 104 | context_truncation_factor: float = 0.9 105 | 106 | 107 | @register_config("claude35sonnet_mct_vision") 108 | @dataclass 109 | class Claude35Sonnet_MCT_Vision(Claude35Sonnet_Base): 110 | agent_name: str = "MCT_Agent" 111 | description_type: str = "image" 112 | max_new_tokens: int = 4096 113 | max_history_length: int = 20 114 | host_port: Optional[str] = None 
115 | save_dir: Optional[str] = None 116 | image_detail: str = "high" 117 | context_truncation_factor: float = 0.9 118 | 119 | 120 | @register_config("claude35sonnet_mct_text") 121 | @dataclass 122 | class Claude35Sonnet_MCT_Text(Claude35Sonnet_Base): 123 | agent_name: str = "MCT_Agent" 124 | description_type: str = "text" 125 | max_new_tokens: int = 4096 126 | max_history_length: int = 20 127 | host_port: Optional[str] = None 128 | save_dir: Optional[str] = None 129 | context_truncation_factor: float = 0.9 130 | -------------------------------------------------------------------------------- /space/configs/gpt.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from typing import Optional 4 | 5 | from dataclasses import dataclass 6 | from space.registry import register_config 7 | 8 | 9 | @dataclass 10 | class GPT4V_Base: 11 | model_name: str = "gpt-4-turbo-2024-04-09" 12 | max_context_tokens: int = 128000 13 | completion_cost_per_mil: float = 30.0 14 | prompt_cost_per_mil: float = 10.0 15 | supports_system_prompt: bool = True 16 | use_vllm: bool = False 17 | 18 | 19 | @dataclass 20 | class GPT4O_Base: 21 | model_name: str = "gpt-4o-2024-05-13" 22 | max_context_tokens: int = 128000 23 | completion_cost_per_mil: float = 15.0 24 | prompt_cost_per_mil: float = 5.0 25 | supports_system_prompt: bool = True 26 | use_vllm: bool = False 27 | 28 | 29 | @register_config("gpt4v_qa") 30 | @dataclass 31 | class GPT4V_QA(GPT4V_Base): 32 | agent_name: str = "QA_Agent" 33 | max_new_tokens: int = 4096 34 | host_port: Optional[str] = None 35 | save_dir: Optional[str] = None 36 | image_detail: str = "low" 37 | subsampling_factor: int = 1 38 | 39 | 40 | @register_config("gpt4o_qa") 41 | @dataclass 42 | class GPT4O_QA(GPT4O_Base): 43 | agent_name: str = "QA_Agent" 44 | max_new_tokens: int = 4096 45 | host_port: Optional[str] = None 46 | save_dir: Optional[str] = None 47 | image_detail: str = "low" 48 | subsampling_factor: int = 1 49 | 50 | 51 | @register_config("gpt4v_egonav") 52 | @dataclass 53 | class GPT4V_EgoNav(GPT4V_Base): 54 | agent_name: str = "EgoNav_Agent" 55 | max_new_tokens: int = 2048 56 | max_history_length: int = 20 57 | host_port: Optional[str] = None 58 | save_dir: Optional[str] = None 59 | image_detail: str = "low" 60 | context_truncation_factor: float = 0.9 61 | subsampling_factor: int = 1 62 | 63 | 64 | @register_config("gpt4o_egonav") 65 | @dataclass 66 | class GPT4O_EgoNav(GPT4V_Base): 67 | agent_name: str = "EgoNav_Agent" 68 | max_new_tokens: int = 2048 69 | max_history_length: int = 20 70 | host_port: Optional[str] = None 71 | save_dir: Optional[str] = None 72 | image_detail: str = "low" 73 | context_truncation_factor: float = 0.9 74 | subsampling_factor: int = 1 75 | 76 | 77 | @register_config("gpt4v_dminav") 78 | @dataclass 79 | class GPT4V_DiscreteMapImageNav(GPT4V_Base): 80 | agent_name: str = "DiscreteMapImage_Nav_Agent" 81 | max_new_tokens: int = 2048 82 | max_history_length: int = 50 83 | host_port: Optional[str] = None 84 | save_dir: Optional[str] = None 85 | image_detail: str = "low" 86 | context_truncation_factor: float = 0.9 87 | subsampling_factor: int = 1 88 | 89 | 90 | @register_config("gpt4o_dminav") 91 | @dataclass 92 | class GPT4O_DiscreteMapImageNav(GPT4O_Base): 93 | agent_name: str = "DiscreteMapImage_Nav_Agent" 94 | max_new_tokens: int = 2048 95 | max_history_length: int = 50 96 | host_port: Optional[str] = None 97 | 
save_dir: Optional[str] = None 98 | image_detail: str = "low" 99 | context_truncation_factor: float = 0.9 100 | subsampling_factor: int = 1 101 | 102 | 103 | @register_config("gpt4v_dmtnav") 104 | @dataclass 105 | class GPT4V_DiscreteMapTextNav(GPT4V_Base): 106 | agent_name: str = "DiscreteMapText_Nav_Agent" 107 | max_new_tokens: int = 2048 108 | max_history_length: int = 50 109 | host_port: Optional[str] = None 110 | save_dir: Optional[str] = None 111 | image_detail: str = "low" 112 | context_truncation_factor: float = 0.9 113 | subsampling_factor: int = 1 114 | 115 | 116 | @register_config("gpt4o_dmtnav") 117 | @dataclass 118 | class GPT4O_DiscreteMapTextNav(GPT4O_Base): 119 | agent_name: str = "DiscreteMapText_Nav_Agent" 120 | max_new_tokens: int = 2048 121 | max_history_length: int = 50 122 | host_port: Optional[str] = None 123 | save_dir: Optional[str] = None 124 | image_detail: str = "low" 125 | context_truncation_factor: float = 0.9 126 | subsampling_factor: int = 1 127 | 128 | 129 | @register_config("gpt4v_cswm_vision") 130 | @dataclass 131 | class GPT4V_CSWM_Vision(GPT4V_Base): 132 | agent_name: str = "CSWM_Agent" 133 | task_mode: str = "vision" 134 | max_new_tokens: int = 4096 135 | max_history_length: int = -1 136 | host_port: Optional[str] = None 137 | save_dir: Optional[str] = None 138 | image_detail: str = "low" 139 | context_truncation_factor: float = 0.9 140 | 141 | 142 | @register_config("gpt4v_cswm_text") 143 | @dataclass 144 | class GPT4V_CSWM_Text(GPT4V_Base): 145 | agent_name: str = "CSWM_Agent" 146 | task_mode: str = "text" 147 | max_new_tokens: int = 4096 148 | max_history_length: int = -1 149 | host_port: Optional[str] = None 150 | save_dir: Optional[str] = None 151 | image_detail: str = "low" 152 | context_truncation_factor: float = 0.9 153 | 154 | 155 | @register_config("gpt4o_cswm_vision") 156 | @dataclass 157 | class GPT4O_CSWM_Vision(GPT4O_Base): 158 | agent_name: str = "CSWM_Agent" 159 | task_mode: str = "vision" 160 | max_new_tokens: int = 4096 161 | max_history_length: int = -1 162 | host_port: Optional[str] = None 163 | save_dir: Optional[str] = None 164 | image_detail: str = "low" 165 | context_truncation_factor: float = 0.9 166 | 167 | 168 | @register_config("gpt4o_cswm_text") 169 | @dataclass 170 | class GPT4O_CSWM_Text(GPT4O_Base): 171 | agent_name: str = "CSWM_Agent" 172 | task_mode: str = "text" 173 | max_new_tokens: int = 4096 174 | max_history_length: int = -1 175 | host_port: Optional[str] = None 176 | save_dir: Optional[str] = None 177 | image_detail: str = "low" 178 | context_truncation_factor: float = 0.9 179 | 180 | 181 | @register_config("gpt4v_mct_vision") 182 | @dataclass 183 | class GPT4V_MCT_Vision(GPT4V_Base): 184 | agent_name: str = "MCT_Agent" 185 | description_type: str = "image" 186 | max_new_tokens: int = 4096 187 | max_history_length: int = 20 188 | host_port: Optional[str] = None 189 | save_dir: Optional[str] = None 190 | image_detail: str = "high" 191 | context_truncation_factor: float = 0.9 192 | 193 | 194 | @register_config("gpt4v_mct_text") 195 | @dataclass 196 | class GPT4V_MCT_Text(GPT4V_Base): 197 | agent_name: str = "MCT_Agent" 198 | description_type: str = "text" 199 | max_new_tokens: int = 4096 200 | max_history_length: int = 20 201 | host_port: Optional[str] = None 202 | save_dir: Optional[str] = None 203 | context_truncation_factor: float = 0.9 204 | 205 | 206 | @register_config("gpt4o_mct_vision") 207 | @dataclass 208 | class GPT4O_MCT_Vision(GPT4O_Base): 209 | agent_name: str = "MCT_Agent" 210 | description_type: str 
= "image" 211 | max_new_tokens: int = 4096 212 | max_history_length: int = 20 213 | host_port: Optional[str] = None 214 | save_dir: Optional[str] = None 215 | image_detail: str = "high" 216 | context_truncation_factor: float = 0.9 217 | 218 | 219 | @register_config("gpt4o_mct_text") 220 | @dataclass 221 | class GPT4O_MCT_Text(GPT4O_Base): 222 | agent_name: str = "MCT_Agent" 223 | description_type: str = "text" 224 | max_new_tokens: int = 4096 225 | max_history_length: int = 20 226 | host_port: Optional[str] = None 227 | save_dir: Optional[str] = None 228 | context_truncation_factor: float = 0.9 229 | -------------------------------------------------------------------------------- /space/configs/llama.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from typing import Any, Optional 4 | import copy 5 | from dataclasses import dataclass, field 6 | from space.registry import register_config 7 | 8 | 9 | VLLM_CFG = { 10 | "dtype": "auto", 11 | "trust_remote_code": True, 12 | "enable_prefix_caching": True, 13 | "tensor_parallel_size": None, 14 | } 15 | 16 | 17 | @dataclass 18 | class Llama3_8b_Base: 19 | model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct" 20 | max_context_tokens: int = 8192 21 | completion_cost_per_mil: float = 0.0 22 | prompt_cost_per_mil: float = 0.0 23 | supports_system_prompt: bool = True 24 | use_vllm: bool = True 25 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 26 | 27 | 28 | @dataclass 29 | class Llama3_70b_Base: 30 | model_name: str = "meta-llama/Meta-Llama-3-70B-Instruct" 31 | max_context_tokens: int = 8192 32 | completion_cost_per_mil: float = 0.0 33 | prompt_cost_per_mil: float = 0.0 34 | supports_system_prompt: bool = True 35 | use_vllm: bool = True 36 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 37 | 38 | 39 | @register_config("llama3_8b_qa") 40 | @dataclass 41 | class Llama3_8b_QA(Llama3_8b_Base): 42 | agent_name: str = "QA_Agent" 43 | max_new_tokens: int = 2048 44 | host_port: str = "8001" 45 | save_dir: Optional[str] = None 46 | subsampling_factor: int = 1 47 | 48 | 49 | @register_config("llama3_70b_qa") 50 | @dataclass 51 | class Llama3_70b_QA(Llama3_70b_Base): 52 | agent_name: str = "QA_Agent" 53 | max_new_tokens: int = 2048 54 | host_port: str = "8001" 55 | save_dir: Optional[str] = None 56 | subsampling_factor: int = 1 57 | 58 | 59 | @register_config("llama3_8b_dmtnav") 60 | @dataclass 61 | class Llama3_8b_DiscreteMapTextNav(Llama3_8b_Base): 62 | agent_name: str = "DiscreteMapText_Nav_Agent" 63 | max_new_tokens: int = 1024 64 | max_history_length: int = 50 65 | host_port: str = "8001" 66 | save_dir: Optional[str] = None 67 | context_truncation_factor: float = 0.9 68 | subsampling_factor: int = 1 69 | 70 | 71 | @register_config("llama3_70b_dmtnav") 72 | @dataclass 73 | class Llama3_70b_DiscreteMapTextNav(Llama3_70b_Base): 74 | agent_name: str = "DiscreteMapText_Nav_Agent" 75 | max_new_tokens: int = 1024 76 | max_history_length: int = 50 77 | host_port: str = "8001" 78 | save_dir: Optional[str] = None 79 | context_truncation_factor: float = 0.9 80 | subsampling_factor: int = 1 81 | 82 | 83 | @register_config("llama3_8b_cswm_text") 84 | @dataclass 85 | class Llama3_8b_CSWM_Text(Llama3_8b_Base): 86 | agent_name: str = "CSWM_Agent" 87 | task_mode: str = "text" 88 | max_new_tokens: int = 2048 89 | max_history_length: int = -1 90 | host_port: 
str = "8001" 91 | save_dir: Optional[str] = None 92 | context_truncation_factor: float = 0.9 93 | 94 | 95 | @register_config("llama3_70b_cswm_text") 96 | @dataclass 97 | class Llama3_70b_CSWM_Text(Llama3_70b_Base): 98 | agent_name: str = "CSWM_Agent" 99 | task_mode: str = "text" 100 | max_new_tokens: int = 2048 101 | max_history_length: int = -1 102 | host_port: str = "8001" 103 | save_dir: Optional[str] = None 104 | context_truncation_factor: float = 0.9 105 | 106 | 107 | @register_config("llama3_8b_mct_text") 108 | @dataclass 109 | class Llama3_8b_MCT_Text(Llama3_8b_Base): 110 | agent_name: str = "MCT_Agent" 111 | description_type: str = "text" 112 | max_new_tokens: int = 2048 113 | max_history_length: int = 20 114 | host_port: str = "8001" 115 | save_dir: Optional[str] = None 116 | context_truncation_factor: float = 0.8 117 | 118 | 119 | @register_config("llama3_70b_mct_text") 120 | @dataclass 121 | class Llama3_70b_MCT_Text(Llama3_70b_Base): 122 | agent_name: str = "MCT_Agent" 123 | description_type: str = "text" 124 | max_new_tokens: int = 2048 125 | max_history_length: int = 20 126 | host_port: str = "8001" 127 | save_dir: Optional[str] = None 128 | context_truncation_factor: float = 0.8 129 | -------------------------------------------------------------------------------- /space/configs/mistral.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from typing import Any, Optional 4 | import copy 5 | from dataclasses import dataclass, field 6 | from space.registry import register_config 7 | 8 | 9 | VLLM_CFG = { 10 | "dtype": "auto", 11 | "trust_remote_code": True, 12 | "enable_prefix_caching": True, 13 | "config_format": "mistral", 14 | "load_format": "mistral", 15 | "tokenizer_mode": "mistral", 16 | "tensor_parallel_size": None, 17 | } 18 | 19 | PIXTRAL_VLLM_CFG = { 20 | "dtype": "auto", 21 | "trust_remote_code": True, 22 | "max_model_len": 32768, 23 | "enable_prefix_caching": True, 24 | "config_format": "mistral", 25 | "load_format": "mistral", 26 | "tokenizer_mode": "mistral", 27 | "limit_mm_per_prompt": "image=20", 28 | "tensor_parallel_size": None, 29 | } 30 | 31 | 32 | @dataclass 33 | class Mixtral8x7b_Base: 34 | model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1" 35 | max_context_tokens: int = 32768 36 | completion_cost_per_mil: float = 0.0 37 | prompt_cost_per_mil: float = 0.0 38 | supports_system_prompt: bool = False 39 | use_vllm: bool = True 40 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 41 | 42 | 43 | @dataclass 44 | class Mixtral8x22b_Base: 45 | model_name: str = "mistralai/Mixtral-8x22B-Instruct-v0.1" 46 | max_context_tokens: int = 32768 47 | completion_cost_per_mil: float = 0.0 48 | prompt_cost_per_mil: float = 0.0 49 | supports_system_prompt: bool = False 50 | use_vllm: bool = True 51 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 52 | 53 | 54 | @dataclass 55 | class MistralLarge2407_Base: 56 | model_name: str = "mistralai/Mistral-Large-Instruct-2407" 57 | max_context_tokens: int = 128000 58 | completion_cost_per_mil: float = 0.0 59 | prompt_cost_per_mil: float = 0.0 60 | supports_system_prompt: bool = False 61 | use_vllm: bool = True 62 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 63 | 64 | 65 | @dataclass 66 | class Pixtral12b_Base: 67 | model_name: str = "mistralai/Pixtral-12B-2409" 68 | 
max_context_tokens: int = 32768 69 | completion_cost_per_mil: float = 0.0 70 | prompt_cost_per_mil: float = 0.0 71 | supports_system_prompt: bool = False 72 | use_vllm: bool = True 73 | vllm_cfg: dict[str, Any] = field( 74 | default_factory=lambda: copy.deepcopy(PIXTRAL_VLLM_CFG) 75 | ) 76 | 77 | 78 | @register_config("mixtral8x7b_qa") 79 | @dataclass 80 | class Mixtral8x7b_QA(Mixtral8x7b_Base): 81 | agent_name: str = "QA_Agent" 82 | max_new_tokens: int = 4096 83 | host_port: str = "8001" 84 | save_dir: Optional[str] = None 85 | subsampling_factor: int = 1 86 | 87 | 88 | @register_config("mixtral8x22b_qa") 89 | @dataclass 90 | class Mixtral8x22b_QA(Mixtral8x22b_Base): 91 | agent_name: str = "QA_Agent" 92 | max_new_tokens: int = 4096 93 | host_port: str = "8001" 94 | save_dir: Optional[str] = None 95 | subsampling_factor: int = 1 96 | 97 | 98 | @register_config("mistral123b_qa") 99 | @dataclass 100 | class MistralLarge2407_QA(MistralLarge2407_Base): 101 | agent_name: str = "QA_Agent" 102 | max_new_tokens: int = 4096 103 | host_port: str = "8001" 104 | save_dir: Optional[str] = None 105 | subsampling_factor: int = 1 106 | 107 | 108 | @register_config("pixtral12b_qa") 109 | @dataclass 110 | class Pixtral12b_QA(Pixtral12b_Base): 111 | agent_name: str = "QA_Agent" 112 | max_new_tokens: int = 4096 113 | host_port: str = "8001" 114 | save_dir: Optional[str] = None 115 | subsampling_factor: int = 1 116 | 117 | 118 | @register_config("mixtral8x7b_dmtnav") 119 | @dataclass 120 | class Mixtral8x7b_DiscreteMapTextNav(Mixtral8x7b_Base): 121 | agent_name: str = "DiscreteMapText_Nav_Agent" 122 | max_new_tokens: int = 2048 123 | max_history_length: int = 50 124 | host_port: str = "8001" 125 | save_dir: Optional[str] = None 126 | context_truncation_factor: float = 0.9 127 | subsampling_factor: int = 1 128 | 129 | 130 | @register_config("mixtral8x22b_dmtnav") 131 | @dataclass 132 | class Mixtral8x22b_DiscreteMapTextNav(Mixtral8x22b_Base): 133 | agent_name: str = "DiscreteMapText_Nav_Agent" 134 | max_new_tokens: int = 2048 135 | max_history_length: int = 50 136 | host_port: str = "8001" 137 | save_dir: Optional[str] = None 138 | context_truncation_factor: float = 0.9 139 | subsampling_factor: int = 1 140 | 141 | 142 | @register_config("mistral123b_dmtnav") 143 | @dataclass 144 | class MistralLarge2407_DiscreteMapTextNav(MistralLarge2407_Base): 145 | agent_name: str = "DiscreteMapText_Nav_Agent" 146 | max_new_tokens: int = 4096 147 | max_history_length: int = 50 148 | host_port: str = "8001" 149 | save_dir: Optional[str] = None 150 | context_truncation_factor: float = 0.9 151 | subsampling_factor: int = 1 152 | 153 | 154 | @register_config("mixtral8x7b_cswm_text") 155 | @dataclass 156 | class Mixtral8x7b_CSWM_Text(Mixtral8x7b_Base): 157 | agent_name: str = "CSWM_Agent" 158 | task_mode: str = "text" 159 | max_new_tokens: int = 2048 160 | max_history_length: int = -1 161 | host_port: str = "8001" 162 | save_dir: Optional[str] = None 163 | context_truncation_factor: float = 0.9 164 | 165 | 166 | @register_config("mixtral8x22b_cswm_text") 167 | @dataclass 168 | class Mixtral8x22b_CSWM_Text(Mixtral8x22b_Base): 169 | agent_name: str = "CSWM_Agent" 170 | task_mode: str = "text" 171 | max_new_tokens: int = 2048 172 | max_history_length: int = -1 173 | host_port: str = "8001" 174 | save_dir: Optional[str] = None 175 | context_truncation_factor: float = 0.9 176 | 177 | 178 | @register_config("mistral123b_cswm_text") 179 | @dataclass 180 | class MistralLarge2407_CSWM_Text(MistralLarge2407_Base): 181 | agent_name: str = 
"CSWM_Agent" 182 | task_mode: str = "text" 183 | max_new_tokens: int = 4096 184 | max_history_length: int = -1 185 | host_port: str = "8001" 186 | save_dir: Optional[str] = None 187 | context_truncation_factor: float = 0.9 188 | 189 | 190 | @register_config("mixtral8x7b_mct_text") 191 | @dataclass 192 | class Mixtral8x7b_MCT_Text(Mixtral8x7b_Base): 193 | agent_name: str = "MCT_Agent" 194 | description_type: str = "text" 195 | max_new_tokens: int = 2048 196 | max_history_length: int = 20 197 | host_port: str = "8001" 198 | save_dir: Optional[str] = None 199 | context_truncation_factor: float = 0.9 200 | 201 | 202 | @register_config("mixtral8x22b_mct_text") 203 | @dataclass 204 | class Mixtral8x22b_MCT_Text(Mixtral8x22b_Base): 205 | agent_name: str = "MCT_Agent" 206 | description_type: str = "text" 207 | max_new_tokens: int = 2048 208 | max_history_length: int = 20 209 | host_port: str = "8001" 210 | save_dir: Optional[str] = None 211 | context_truncation_factor: float = 0.9 212 | 213 | 214 | @register_config("mistral123b_mct_text") 215 | @dataclass 216 | class MistralLarge2407_MCT_Text(MistralLarge2407_Base): 217 | agent_name: str = "MCT_Agent" 218 | description_type: str = "text" 219 | max_new_tokens: int = 4096 220 | max_history_length: int = 20 221 | host_port: str = "8001" 222 | save_dir: Optional[str] = None 223 | context_truncation_factor: float = 0.9 224 | -------------------------------------------------------------------------------- /space/configs/phi.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from typing import Any, Optional 4 | import copy 5 | from dataclasses import dataclass, field 6 | from space.registry import register_config 7 | 8 | 9 | VLLM_CFG = { 10 | "dtype": "auto", 11 | "trust_remote_code": True, 12 | "limit_mm_per_prompt": "image=20", 13 | "tensor_parallel_size": None, 14 | } 15 | 16 | 17 | @dataclass 18 | class Phi35Vision_Base: 19 | model_name: str = "microsoft/Phi-3.5-vision-instruct" 20 | max_context_tokens: int = 128000 21 | completion_cost_per_mil: float = 0.0 22 | prompt_cost_per_mil: float = 0.0 23 | supports_system_prompt: bool = True 24 | use_vllm: bool = True 25 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 26 | 27 | 28 | @register_config("phi35vision_qa") 29 | @dataclass 30 | class Phi35Vision_QA(Phi35Vision_Base): 31 | agent_name: str = "QA_Agent" 32 | max_new_tokens: int = 4096 33 | host_port: str = "8001" 34 | save_dir: Optional[str] = None 35 | subsampling_factor: int = 1 36 | -------------------------------------------------------------------------------- /space/configs/yi.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | from typing import Any, Optional 4 | import copy 5 | from dataclasses import dataclass, field 6 | from space.registry import register_config 7 | 8 | 9 | VLLM_CFG = { 10 | "dtype": "auto", 11 | "trust_remote_code": True, 12 | "enable_prefix_caching": True, 13 | "tensor_parallel_size": None, 14 | } 15 | 16 | 17 | @dataclass 18 | class Yi15_9b_Base: 19 | model_name: str = "01-ai/Yi-1.5-9B-Chat-16K" 20 | max_context_tokens: int = 16000 21 | completion_cost_per_mil: float = 0.0 22 | prompt_cost_per_mil: float = 0.0 23 | supports_system_prompt: bool = True 24 | use_vllm: bool = True 25 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 26 | 27 | 28 | @dataclass 29 | class Yi15_34b_Base: 30 | model_name: str = "01-ai/Yi-1.5-34B-Chat-16K" 31 | max_context_tokens: int = 16000 32 | completion_cost_per_mil: float = 0.0 33 | prompt_cost_per_mil: float = 0.0 34 | supports_system_prompt: bool = True 35 | use_vllm: bool = True 36 | vllm_cfg: dict[str, Any] = field(default_factory=lambda: copy.deepcopy(VLLM_CFG)) 37 | 38 | 39 | @register_config("yi15_9b_qa") 40 | @dataclass 41 | class Yi15_9b_QA(Yi15_9b_Base): 42 | agent_name: str = "QA_Agent" 43 | max_new_tokens: int = 2048 44 | host_port: str = "8001" 45 | save_dir: Optional[str] = None 46 | subsampling_factor: int = 1 47 | 48 | 49 | @register_config("yi15_34b_qa") 50 | @dataclass 51 | class Yi15_34b_QA(Yi15_34b_Base): 52 | agent_name: str = "QA_Agent" 53 | max_new_tokens: int = 2048 54 | host_port: str = "8001" 55 | save_dir: Optional[str] = None 56 | subsampling_factor: int = 1 57 | 58 | 59 | @register_config("yi15_9b_dmtnav") 60 | @dataclass 61 | class Yi15_9b_DiscreteMapTextNav(Yi15_9b_Base): 62 | agent_name: str = "DiscreteMapText_Nav_Agent" 63 | max_new_tokens: int = 1024 64 | max_history_length: int = 50 65 | host_port: str = "8001" 66 | save_dir: Optional[str] = None 67 | context_truncation_factor: float = 0.9 68 | subsampling_factor: int = 1 69 | 70 | 71 | @register_config("yi15_34b_dmtnav") 72 | @dataclass 73 | class Yi15_34b_DiscreteMapTextNav(Yi15_34b_Base): 74 | agent_name: str = "DiscreteMapText_Nav_Agent" 75 | max_new_tokens: int = 1024 76 | max_history_length: int = 50 77 | host_port: str = "8001" 78 | save_dir: Optional[str] = None 79 | context_truncation_factor: float = 0.9 80 | subsampling_factor: int = 1 81 | 82 | 83 | @register_config("yi15_9b_cswm_text") 84 | @dataclass 85 | class Yi15_9b_CSWM_Text(Yi15_9b_Base): 86 | agent_name: str = "CSWM_Agent" 87 | task_mode: str = "text" 88 | max_new_tokens: int = 2048 89 | max_history_length: int = 50 90 | host_port: str = "8001" 91 | save_dir: Optional[str] = None 92 | context_truncation_factor: float = 0.9 93 | 94 | 95 | @register_config("yi15_34b_cswm_text") 96 | @dataclass 97 | class Yi15_34b_CSWM_Text(Yi15_34b_Base): 98 | agent_name: str = "CSWM_Agent" 99 | task_mode: str = "text" 100 | max_new_tokens: int = 2048 101 | max_history_length: int = 50 102 | host_port: str = "8001" 103 | save_dir: Optional[str] = None 104 | context_truncation_factor: float = 0.9 105 | 106 | 107 | @register_config("yi15_9b_mct_text") 108 | @dataclass 109 | class Yi15_9b_MCT_Text(Yi15_9b_Base): 110 | agent_name: str = "MCT_Agent" 111 | description_type: str = "text" 112 | max_new_tokens: int = 2048 113 | max_history_length: int = 20 114 | host_port: str = "8001" 115 | save_dir: Optional[str] = None 116 | context_truncation_factor: float = 0.9 117 | 118 | 119 | @register_config("yi15_34b_mct_text") 120 | @dataclass 121 | class Yi15_34b_MCT_Text(Yi15_34b_Base): 122 | 
agent_name: str = "MCT_Agent" 123 | description_type: str = "text" 124 | max_new_tokens: int = 2048 125 | max_history_length: int = 20 126 | host_port: str = "8001" 127 | save_dir: Optional[str] = None 128 | context_truncation_factor: float = 0.9 129 | -------------------------------------------------------------------------------- /space/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from space.registry import ENVS_REGISTRY 4 | 5 | import space.envs.cswm # noqa 6 | import space.envs.mct # noqa 7 | import space.envs.nav_dm # noqa 8 | import space.envs.nav_ego # noqa 9 | 10 | 11 | def get_env(name: str, *args, **kwargs): 12 | assert name in ENVS_REGISTRY 13 | env = ENVS_REGISTRY[name](*args, **kwargs) 14 | return env 15 | -------------------------------------------------------------------------------- /space/envs/cswm.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | import json 4 | import os.path as osp 5 | import random 6 | from dataclasses import dataclass 7 | 8 | from abc import ABC 9 | import cv2 10 | import imageio 11 | import numpy as np 12 | 13 | from space.registry import register_env 14 | 15 | N_BOXES_TO_MAX_STEPS = {3: 8, 4: 12, 5: 18, 6: 25, 7: 36} 16 | 17 | 18 | @dataclass 19 | class Rect: 20 | x: int 21 | y: int 22 | w: int 23 | h: int 24 | 25 | def get_extents(self): 26 | return (self.x, self.y, self.x + self.w - 1, self.y + self.h - 1) 27 | 28 | 29 | class Base_CSWM_Env(ABC): 30 | r"""Cambridge Spatial Working Memory test: 31 | N treasures are hidden in N boxes one at a time. Search for the currently hidden treasure by 32 | selecting a box and opening it. If the treasure is hidden in this box, it is collected and the 33 | next treasure is hidden in a new box (where the treasure was never hidden before). The goal is 34 | to collect all the treasures. 35 | 36 | Reference: https://cambridgecognition.com/spatial-working-memory-swm/ 37 | """ 38 | 39 | def __init__(self, load_dir: str): 40 | # Check for text vs. 
visual version of game 41 | self.game_mode = ( 42 | "vision" if osp.isfile(osp.join(load_dir, "board.png")) else "text" 43 | ) 44 | with open(osp.join(load_dir, "states.json"), "r") as fp: 45 | states = json.load(fp) 46 | if self.game_mode == "vision": 47 | self.image = imageio.imread(osp.join(load_dir, "board.png")) 48 | self.rects = [Rect(x, y, w, h) for x, y, w, h in states["rects"]] 49 | self.treasure_boxes = [ 50 | Rect(x, y, w, h) for x, y, w, h in states["treasure_boxes"] 51 | ] 52 | else: 53 | self.board_array = np.array(states["board_array"]) 54 | self.rects = states["rects"] 55 | self.treasures = states["treasures"] 56 | self.max_steps = N_BOXES_TO_MAX_STEPS[len(self.rects)] 57 | self.curr_treasure_idx = None 58 | self.selected_idx = None 59 | self.n_collected = None 60 | self.n_steps_taken = None 61 | self.finished = None 62 | self.t2r_mapping = None 63 | self.r2t_mapping = None 64 | 65 | def reset(self): 66 | self.n_collected = 0 67 | self.n_steps_taken = 0 68 | self.curr_treasure_idx = self.treasures[self.n_collected] 69 | self.selected_idx = None 70 | self.finished = False 71 | self.t2r_mapping, self.r2t_mapping = self.sample_random_mapping() 72 | return self.render(None) 73 | 74 | def sample_random_mapping(self): 75 | n = len(self.rects) 76 | mapping = {i: j for i, j in enumerate(np.random.permutation(n).tolist())} 77 | inv_mapping = {j: i for i, j in mapping.items()} 78 | return mapping, inv_mapping 79 | 80 | def step(self, action: int): 81 | assert not self.finished 82 | assert self.is_valid_action(action) 83 | action = self.apply_r2t_mapping(action) 84 | self.selected_idx = action 85 | done = False 86 | treasure_collected = None 87 | if action == self.curr_treasure_idx: 88 | treasure_collected = self.curr_treasure_idx 89 | self.n_collected += 1 90 | if self.n_collected < len(self.rects): 91 | self.curr_treasure_idx = self.treasures[self.n_collected] 92 | else: 93 | done = True 94 | self.t2r_mapping, self.r2t_mapping = self.sample_random_mapping() 95 | obs = self.render(treasure_collected) 96 | self.n_steps_taken += 1 97 | if self.n_steps_taken >= self.max_steps: 98 | done = True 99 | if done: 100 | self.finished = True 101 | return obs, 0.0, done, {} 102 | 103 | def render(self, treasure_collected: int): 104 | raise NotImplementedError 105 | 106 | def apply_r2t_mapping(self, act: int): 107 | raise NotImplementedError 108 | 109 | def is_valid_action(self, act: int): 110 | raise NotImplementedError 111 | 112 | def sample_random_action(self): 113 | raise NotImplementedError 114 | 115 | 116 | @register_env 117 | class Vision_CSWM_Env(Base_CSWM_Env): 118 | def render(self, treasure_collected: int): 119 | font = cv2.FONT_HERSHEY_TRIPLEX 120 | fontScale = np.ceil(self.image.shape[0] / 600.0).item() 121 | thickness = int(np.ceil(self.image.shape[0] / 600.0).item()) 122 | image = np.copy(self.image) 123 | for i, rect in enumerate(self.rects): 124 | if i == treasure_collected: 125 | sx = rect.x + 5 126 | sy = rect.y + 5 127 | ex = rect.x + rect.w - 6 128 | ey = rect.y + rect.h - 6 129 | image = cv2.rectangle(image, (sx, sy), (ex, ey), (255, 191, 0), -1) 130 | text = f"{self.t2r_mapping[i]}" 131 | textsize, _ = cv2.getTextSize(text, font, fontScale, thickness) 132 | text_x = rect.x + (rect.w - textsize[0]) // 2 133 | text_y = rect.y + rect.h - 1 - (rect.h - textsize[1]) // 2 134 | cv2.putText( 135 | image, text, (text_x, text_y), font, fontScale, (0, 0, 0), thickness 136 | ) 137 | # Fill treasure boxes 138 | for i in range(self.n_collected): 139 | r = self.treasure_boxes[i] 140 | 
image = cv2.rectangle( 141 | image, (r.x, r.y), (r.x + r.w - 1, r.y + r.h - 1), (255, 191, 0), -1 142 | ) 143 | return image 144 | 145 | def apply_r2t_mapping(self, act): 146 | return self.r2t_mapping[act] 147 | 148 | def is_valid_action(self, act): 149 | return act >= 0 and act < len(self.rects) 150 | 151 | def sample_random_action(self): 152 | return random.randint(0, len(self.rects) - 1) 153 | 154 | 155 | @register_env 156 | class Text_CSWM_Env(Base_CSWM_Env): 157 | def render(self, treasure_collected: int): 158 | board_array = np.copy(self.board_array) 159 | for i, (x, y) in enumerate(self.rects): 160 | if i == treasure_collected: 161 | board_array[y, x] = "T" 162 | else: 163 | board_array[y, x] = f"{self.t2r_mapping[i] + 1}" 164 | obs = "Here is the current view of the board. You must find the next treasure. Note that the numbers of the boxes have changed, but the box locations are fixed. Decide which box location you want to open next. Then provide the number associated with the box as the action.\n\n" 165 | obs += self.convert_array_to_str(board_array) + "\n\n" 166 | obs += ( 167 | f"Number of treasures collected: {self.n_collected} / {len(self.treasures)}" 168 | ) 169 | return obs 170 | 171 | def convert_array_to_str(self, array: np.ndarray): 172 | array_str = [] 173 | for r in array: 174 | array_str.append(",".join(r.tolist())) 175 | array_str = "\n".join(array_str) 176 | return array_str 177 | 178 | def apply_r2t_mapping(self, act): 179 | return self.r2t_mapping[act - 1] 180 | 181 | def is_valid_action(self, act): 182 | return act > 0 and act <= len(self.rects) 183 | 184 | def sample_random_action(self): 185 | return random.randint(1, len(self.rects)) 186 | -------------------------------------------------------------------------------- /space/envs/mct.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | import json 4 | import os 5 | 6 | import cv2 7 | import networkx as nx 8 | import numpy as np 9 | 10 | from space.registry import register_env 11 | from space.envs.nav_dm import evaluate_path_efficiency as evaluate_path_efficiency 12 | 13 | 14 | @register_env 15 | class MCT_Env: 16 | valid_actions: list[str] = ["up", "down", "left", "right"] 17 | valid_description_types: list[str] = ["image", "text"] 18 | 19 | def __init__( 20 | self, 21 | env_dir: str, 22 | description_type: str, 23 | num_pixels_per_cell: int = 100, 24 | ): 25 | assert description_type in self.valid_description_types 26 | self.description_type = description_type 27 | self.num_pixels_per_cell = num_pixels_per_cell 28 | 29 | with open(os.path.join(env_dir, "info.json")) as fp: 30 | info = json.load(fp) 31 | self.start = tuple(info["start"]) # (r, c) 32 | self.goal = tuple(info["goal"]) # (r, c) 33 | self.maze = np.load(os.path.join(env_dir, "maze.npz"))["maze"] 34 | 35 | # Create graph for shortest-path planning 36 | self.create_graph_from_maze() 37 | self.shortest_path = self.get_shortest_path_from_nodes(self.start, self.goal) 38 | 39 | # Navigation state maintenance 40 | self.current = None 41 | self.steps_taken = None 42 | self.actions_taken = None 43 | self.collided = None 44 | self.path_taken = None 45 | 46 | def reset(self): 47 | self.current = self.start 48 | self.steps_taken = 0 49 | self.actions_taken = [] 50 | self.collided = False 51 | self.path_taken = [self.current] 52 | obs = self.get_observation() 53 | return obs, {"has_collided": self.collided} 54 | 55 | def step(self, action: str): 56 | self.collided = self._update_step(action) 57 | self.actions_taken.append(action) 58 | obs = self.get_observation() 59 | return obs, {"has_collided": self.collided} 60 | 61 | def _update_step(self, action: str): 62 | assert action in self.valid_actions 63 | r, c = self.current 64 | if action == "left": 65 | next_pos = (r, c - 1) 66 | elif action == "right": 67 | next_pos = (r, c + 1) 68 | elif action == "up": 69 | next_pos = (r - 1, c) 70 | else: 71 | next_pos = (r + 1, c) 72 | if next_pos in self.nodes: 73 | self.current = next_pos 74 | has_collided = False 75 | else: 76 | has_collided = True 77 | self.steps_taken += 1 78 | return has_collided 79 | 80 | def create_graph_from_maze(self): 81 | H, W = self.maze.shape[:2] 82 | self.graph = nx.Graph() 83 | self.nodes = set() 84 | ## Add nodes 85 | for r in range(H): 86 | for c in range(W): 87 | if self.maze[r, c] == 1: 88 | self.graph.add_node((r, c)) 89 | self.nodes.add((r, c)) 90 | ## Add edges 91 | for r in range(H): 92 | for c in range(W): 93 | if self.maze[r, c] == 1: 94 | # Check neighbors (only forward looking) 95 | nbs = [(r + 1, c), (r, c + 1)] 96 | for r_, c_ in nbs: 97 | if not (r_ >= 0 and r_ < H and c_ >= 0 and c_ < W): 98 | continue 99 | if self.maze[r_, c_] == 1: 100 | self.graph.add_edge((r, c), (r_, c_)) 101 | 102 | def get_shortest_path_from_nodes( 103 | self, start_node: tuple[int, int], goal_node: tuple[int, int] 104 | ): 105 | assert start_node in self.nodes 106 | assert goal_node in self.nodes 107 | return nx.shortest_path(self.graph, start_node, goal_node) 108 | 109 | def get_observation(self): 110 | if self.description_type == "image": 111 | return self.render_visual_observation(self.maze, self.current, self.goal) 112 | elif self.description_type == "text": 113 | return self.render_textual_observation(self.maze, self.current, self.goal) 114 | else: 115 | raise ValueError( 116 | f"get_observation() is not defined for description type: 
{self.description_type}" 117 | ) 118 | 119 | def render_visual_observation( 120 | self, maze: np.ndarray, current_loc: tuple[int, int], goal_loc: tuple[int, int] 121 | ): 122 | N = self.num_pixels_per_cell 123 | H, W = maze.shape 124 | image = np.zeros((N * H, N * W, 3), dtype=np.uint8) 125 | ############################################################################################ 126 | # Apply Pacman-style coloring 127 | # --------------------------- 128 | # Blue - walls 129 | # Black - free spaces 130 | # Yellow - current position 131 | # Red - goal position 132 | ############################################################################################ 133 | # Color walls 134 | for r in range(H): 135 | for c in range(W): 136 | if maze[r, c] == 0: 137 | image[N * r : N * (r + 1), N * c : N * (c + 1), :] = np.array( 138 | [0, 0, 255] 139 | ) 140 | # Color current position 141 | r, c = current_loc 142 | start_x = int(c * N + N * 0.2) 143 | start_y = int(r * N + N * 0.2) 144 | end_x = int(c * N + N * 0.8) 145 | end_y = int(r * N + N * 0.8) 146 | image = cv2.rectangle( 147 | image, (start_x, start_y), (end_x, end_y), (255, 255, 0), -1 148 | ) 149 | # Color goal position (if visible) 150 | r, c = goal_loc 151 | if r >= 0 and r < H and c >= 0 and c < W: 152 | center_x = c * N + N // 2 153 | center_y = r * N + N // 2 154 | image = cv2.circle( 155 | image, (center_x, center_y), int(N * 0.25), (255, 0, 0), -1 156 | ) 157 | ############################################################################################ 158 | # Add status message below image: Did the agent collide after the previous action? 159 | ############################################################################################ 160 | status_size = int(0.1 * N * W) 161 | status_image = np.full((status_size, N * W, 3), 128, dtype=np.uint8) 162 | if self.collided: 163 | text = "The previous action caused you to collide into a wall." 164 | font = cv2.FONT_HERSHEY_TRIPLEX 165 | fontScale = np.ceil(status_size / 200.0).item() 166 | thickness = int(np.ceil(status_size / 200.0).item()) 167 | textsize, _ = cv2.getTextSize(text, font, fontScale, thickness) 168 | textX = (status_image.shape[1] - textsize[0]) // 2 169 | textY = (status_image.shape[0] + textsize[1]) // 2 170 | status_image = cv2.putText( 171 | status_image, 172 | text, 173 | (textX, textY), 174 | font, 175 | fontScale, 176 | (255, 255, 255), 177 | thickness, 178 | ) 179 | image = np.concatenate([image, status_image], axis=0) 180 | 181 | return image 182 | 183 | def render_textual_observation( 184 | self, 185 | maze: np.ndarray, 186 | current_loc: tuple[int, int], 187 | goal_loc: tuple[int, int], 188 | add_positional_information: bool = True, 189 | ): 190 | desc = "" 191 | if self.collided: 192 | desc += "Collision alert: The previous action caused you to collide into a wall.\n\n" 193 | desc += "Here is the current view of the maze.\n\n" 194 | # Array-like description of maze 195 | maze_str = np.array([[str(int(col)) for col in row] for row in maze]) 196 | maze_str[current_loc[0], current_loc[1]] = "A" 197 | maze_str[goal_loc[0], goal_loc[1]] = "G" 198 | for row in maze_str: 199 | desc += ",".join(row.tolist()) + "\n" 200 | desc += "\n\n" 201 | desc += "0 represents obstacles. 1 represents free spaces. G is the goal. 
A is your current position in the maze.\n" 202 | if add_positional_information: 203 | desc += ( 204 | "Your current location in the maze is row, column = ({}, {}).\n".format( 205 | *current_loc 206 | ) 207 | ) 208 | desc += "The goal location is row, column = ({}, {}).\n".format(*goal_loc) 209 | 210 | return desc 211 | -------------------------------------------------------------------------------- /space/envs/nav_ego.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from space.registry import register_env 4 | from dataclasses import field 5 | 6 | import os 7 | import cv2 8 | from typing import Any 9 | from pathlib import Path 10 | 11 | try: 12 | import habitat_sim 13 | from space.utils.habitat import load_sim, compute_quaternion_from_heading 14 | except ImportError as e: 15 | print(f"WARNING: Failed to import habitat_sim. Error: {str(e)}") 16 | import numpy as np 17 | import json 18 | 19 | 20 | GOAL_MAPPING = { 21 | "CogSci_Env_1_00000": "Painting of a french horn", 22 | "CogSci_Env_1_00001": "Painting of an aeroplane", 23 | "CogSci_Env_1_00002": "Painting of a power drill", 24 | "CogSci_Env_2_00000": "Painting of a turtle", 25 | "CogSci_Env_2_00001": "Painting of a stove", 26 | "CogSci_Env_2_00002": "Painting of a cradle", 27 | "CogSci_Env_3_00000": "Painting of a dog", 28 | "CogSci_Env_3_00001": "Painting of a camel", 29 | "CogSci_Env_3_00002": "Painting of a dog", 30 | "CogSci_Env_4_00000": "Painting of a fish", 31 | "CogSci_Env_4_00001": "Painting of a meerkat", 32 | "CogSci_Env_4_00002": "Painting of a guinea pig", 33 | "CogSci_Env_5_00000": "Painting of a horse cart", 34 | "CogSci_Env_5_00001": "Painting of a volcano", 35 | "CogSci_Env_5_00002": "Painting of an aeroplane", 36 | "CogSci_Env_6_00000": "Painting of a slot machine", 37 | "CogSci_Env_6_00001": "Painting of a boat", 38 | "CogSci_Env_6_00002": "Painting of a padlock", 39 | "CogSci_Env_7_00000": "Painting of a bike", 40 | "CogSci_Env_7_00001": "Painting of a zebra", 41 | "CogSci_Env_7_00002": "Painting of a turtle", 42 | "CogSci_Env_8_00000": "Painting of an ambulance", 43 | "CogSci_Env_8_00001": "Painting of a soccer ball", 44 | "CogSci_Env_8_00002": "Painting of a hammer", 45 | "CogSci_Env_9_00000": "Painting of a hatchet", 46 | "CogSci_Env_9_00001": "Painting of a bird", 47 | "CogSci_Env_9_00002": "Painting of a typewriter", 48 | "CogSci_Env_10_00000": "Painting of a soccer ball", 49 | "CogSci_Env_10_00001": "Painting of a couch", 50 | "CogSci_Env_10_00002": "Painting of a fish", 51 | } 52 | 53 | 54 | @register_env 55 | class NavEgoEnv: 56 | def __init__( 57 | self, 58 | env_dir: str, 59 | habitat_kwargs: dict[str, Any] = field(default_factory=dict), 60 | image_downscaling: float = 4.0, 61 | ): 62 | """ 63 | Arguments: 64 | env_dir: Path to directory with environment information 65 | habitat_kwargs: Keyword args for loading habitat environment 66 | image_downscaling: Factor to downscale image after rendering 67 | """ 68 | 69 | self._is_sim_initialized = False 70 | self.scene_name = Path(env_dir).name 71 | scene_path = os.path.join(env_dir, "scene/scene.glb") 72 | scene_dataset_config_path = os.path.join( 73 | env_dir, "scene/scene_dataset_config.json" 74 | ) 75 | with open(os.path.join(env_dir, "walkthrough_info.json")) as fp: 76 | info = json.load(fp) 77 | self.info = info 78 | self.image_downscaling = image_downscaling 79 | self.sim = load_sim(scene_path, 
scene_dataset_config_path, **habitat_kwargs) 80 | 81 | def get_task_info(self): 82 | start_position = np.array(self.info["walkthrough_info"]["positions"][0]) 83 | start_heading = self.info["walkthrough_info"]["headings"][0] 84 | goal_position = np.array(self.info["walkthrough_info"]["positions"][-1]) 85 | goal_desc = GOAL_MAPPING[self.scene_name] 86 | return { 87 | "start_position": start_position, 88 | "start_heading": start_heading, 89 | "goal_position": goal_position, 90 | "goal_desc": goal_desc, 91 | } 92 | 93 | def initialize_sim(self, position: np.ndarray, heading_deg: float): 94 | """ 95 | Arguments: 96 | position: (x, y, z) array in meters 97 | heading_deg: heading angle in degrees 98 | """ 99 | rotation = compute_quaternion_from_heading(heading_deg) 100 | state = habitat_sim.AgentState(position, rotation) 101 | self.sim.reset() 102 | self.sim.initialize_agent(0, state) 103 | self._is_sim_initialized = True 104 | 105 | def reset(self): 106 | position = np.array(self.info["walkthrough_info"]["positions"][0]) 107 | heading_deg = self.info["walkthrough_info"]["headings"][0] 108 | self.initialize_sim(position, heading_deg) 109 | obs = self.get_observation() 110 | return obs 111 | 112 | def get_observation(self): 113 | """ 114 | Get RGB observation from current agent state 115 | """ 116 | assert self._is_sim_initialized, "Simulator is not initialized." 117 | obs = self.sim.get_sensor_observations(0) 118 | return cv2.resize( 119 | obs["rgb"][..., :3], 120 | None, 121 | fx=1.0 / self.image_downscaling, 122 | fy=1.0 / self.image_downscaling, 123 | ) 124 | 125 | def step(self, act: str): 126 | assert self._is_sim_initialized, "Simulator is not initialized." 127 | _ = self.sim.step(act) 128 | obs = self.get_observation() 129 | return obs 130 | 131 | def close(self): 132 | if self.sim is not None: 133 | self.sim.close() 134 | 135 | def get_sim_state(self): 136 | assert self._is_sim_initialized 137 | state = self.sim.agents[0].state 138 | position = np.array(state.position) 139 | rotation = state.rotation 140 | return position, rotation 141 | -------------------------------------------------------------------------------- /space/evaluate_cswm.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | import glob 4 | import json 5 | import logging 6 | import multiprocessing as mp 7 | import os 8 | from collections import defaultdict 9 | from copy import deepcopy 10 | from typing import Any 11 | 12 | import fire 13 | import imageio 14 | import numpy as np 15 | import tqdm 16 | 17 | from space import get_config, get_agent, get_env 18 | from space.utils.common import get_datetimestr 19 | from space.utils.vllm_api import start_vllm_server 20 | 21 | logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) 22 | 23 | 24 | def play_game( 25 | agent_name: str, 26 | agent_cfg: dict[str, Any], 27 | game_dir: str, 28 | game_name: str, 29 | game_mode: str, 30 | ): 31 | save_dir = agent_cfg["save_dir"] 32 | os.makedirs(save_dir, exist_ok=True) 33 | 34 | # Setup environment 35 | env_name = "Vision_CSWM_Env" if game_mode == "vision" else "Text_CSWM_Env" 36 | env = get_env(env_name, game_dir) 37 | 38 | # Setup agent 39 | agent = get_agent(agent_name, agent_cfg) 40 | agent.reset() 41 | 42 | obs = env.reset() 43 | n_boxes = len(env.rects) 44 | # Setup logging 45 | os.makedirs(save_dir, exist_ok=True) 46 | # Begin evaluation 47 | n_steps_taken = 0 48 | actions = [] 49 | actions_true = [] 50 | observations = [obs] 51 | for _ in range(env.max_steps): 52 | act = agent.get_action(obs) 53 | if not env.is_valid_action(act): 54 | act = env.sample_random_action() 55 | act_true = env.apply_r2t_mapping(act) 56 | obs, _, done, _ = env.step(act) 57 | n_steps_taken += 1 58 | observations.append(obs) 59 | actions.append(act) 60 | actions_true.append(act_true) 61 | if done: 62 | break 63 | if game_mode == "vision": 64 | with imageio.get_writer( 65 | os.path.join(save_dir, "video.mp4"), fps=2 66 | ) as video_writer: 67 | for obs in observations: 68 | video_writer.append_data(obs) 69 | elif game_mode == "text": 70 | with open(os.path.join(save_dir, "game_log.json"), "w") as fp: 71 | json.dump(observations, fp) 72 | else: 73 | raise ValueError(f"Undefined game mode: {game_mode}") 74 | 75 | n_treasures_found = env.n_collected 76 | success = float(env.n_collected == len(env.treasures)) * 100.0 77 | metrics = { 78 | "success": success, 79 | "n_steps_taken": n_steps_taken, 80 | "n_treasures_found": n_treasures_found, 81 | } 82 | 83 | with open(os.path.join(save_dir, "info.json"), "w") as fp: 84 | json.dump( 85 | { 86 | "actions": actions, 87 | "actions_true": actions_true, 88 | "metrics": metrics, 89 | "n_boxes": n_boxes, 90 | }, 91 | fp, 92 | ) 93 | 94 | # Calculate experiment cost 95 | eval_cost = agent.get_eval_cost() 96 | 97 | return actions, actions_true, metrics, eval_cost, game_name, n_boxes 98 | 99 | 100 | def _mp_helper(inputs: dict[str, Any]): 101 | return play_game(**inputs) 102 | 103 | 104 | def main( 105 | model_name: str, 106 | envs_dir: str, 107 | save_dir: str, 108 | n_workers: int = 8, 109 | game_mode: str = "vision", 110 | ): 111 | # Sanity checks 112 | assert game_mode in ["vision", "text"] 113 | 114 | agent_cfg = get_config(model_name) 115 | agent_name = agent_cfg["agent_name"] 116 | del agent_cfg["agent_name"] 117 | 118 | if agent_cfg["use_vllm"]: 119 | start_vllm_server( 120 | agent_cfg["model_name"], agent_cfg["host_port"], agent_cfg["vllm_cfg"] 121 | ) 122 | 123 | save_dir = os.path.join(save_dir, model_name, get_datetimestr()) 124 | os.makedirs(save_dir, exist_ok=True) 125 | 126 | # Load game paths 127 | game_dirs = sorted(glob.glob(os.path.join(envs_dir, "*"))) 128 | 129 | mp_inputs = [] 130 | for game_dir in game_dirs: 131 | agent_cfg_m = deepcopy(agent_cfg) 132 | game_name = 
os.path.basename(game_dir) 133 | agent_cfg_m["save_dir"] = os.path.join(save_dir, game_name) 134 | mp_inputs.append( 135 | { 136 | "agent_name": agent_name, 137 | "agent_cfg": agent_cfg_m, 138 | "game_dir": game_dir, 139 | "game_name": game_name, 140 | "game_mode": game_mode, 141 | } 142 | ) 143 | 144 | all_outputs = [] 145 | pbar = tqdm.tqdm(total=len(mp_inputs), desc="Evaluating on Space CSWM task") 146 | with mp.Pool(n_workers, maxtasksperchild=1) as pool: 147 | for ( 148 | actions, 149 | actions_true, 150 | metrics, 151 | eval_cost, 152 | game_name, 153 | n_boxes, 154 | ) in pool.imap_unordered(_mp_helper, mp_inputs): 155 | all_outputs.append( 156 | { 157 | "actions": actions, 158 | "actions_true": actions_true, 159 | "metrics": metrics, 160 | "eval_cost": eval_cost, 161 | "game_name": game_name, 162 | "n_boxes": n_boxes, 163 | } 164 | ) 165 | pbar.update() 166 | 167 | all_outputs = sorted(all_outputs, key=lambda x: x["game_name"]) 168 | all_metrics = defaultdict(list) 169 | all_actions = [] 170 | all_actions_true = [] 171 | all_n_boxes = [] 172 | total_experiment_cost = defaultdict(int) 173 | for output in all_outputs: 174 | for k, v in output["metrics"].items(): 175 | all_metrics[k].append(v) 176 | all_actions.append(output["actions"]) 177 | all_actions_true.append(output["actions_true"]) 178 | all_n_boxes.append(output["n_boxes"]) 179 | for k, v in output["eval_cost"].items(): 180 | total_experiment_cost[k] += v 181 | 182 | mean_metrics = {k: np.mean(v).item() for k, v in all_metrics.items()} 183 | for k, v in mean_metrics.items(): 184 | print(f"{k:20s}: {v:6.3f}") 185 | for k, v in total_experiment_cost.items(): 186 | print(f"{k:20s}: {v:6.3f}") 187 | 188 | with open(os.path.join(save_dir, "results.json"), "w") as fp: 189 | json.dump( 190 | { 191 | "all_metrics": all_metrics, 192 | "mean_metrics": mean_metrics, 193 | "all_actions": all_actions, 194 | "all_actions_true": all_actions_true, 195 | "all_n_boxes": all_n_boxes, 196 | "total_experiment_cost": total_experiment_cost, 197 | }, 198 | fp, 199 | ) 200 | 201 | 202 | if __name__ == "__main__": 203 | fire.Fire(main) 204 | -------------------------------------------------------------------------------- /space/evaluate_dmnav.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | import glob 4 | import json 5 | import multiprocessing as mp 6 | import os 7 | from copy import deepcopy 8 | from typing import Any 9 | 10 | import fire 11 | import imageio 12 | import numpy as np 13 | import tqdm 14 | 15 | from space import get_config, get_agent, get_env 16 | from space.utils.common import get_datetimestr 17 | from space.envs.nav_dm import evaluate_path_efficiency 18 | from space.utils.vllm_api import start_vllm_server 19 | 20 | 21 | LOCAL_CONTEXT = 5 22 | 23 | 24 | def evaluate_on_env( 25 | agent_name: str, 26 | agent_cfg: dict[str, Any], 27 | env_dir: str, 28 | obs_type: str, 29 | walkthrough_dir: str, 30 | env_name: str, 31 | walkthrough_key: str, 32 | max_steps: int = 100, 33 | ): 34 | save_dir = agent_cfg["save_dir"] 35 | os.makedirs(save_dir, exist_ok=True) 36 | 37 | # Setup environment 38 | env = get_env( 39 | "NavDiscreteMapEnv", os.path.join(env_dir, "info.json"), LOCAL_CONTEXT, obs_type 40 | ) 41 | 42 | # Load walkthrough from path 43 | if obs_type == "image": 44 | walkthrough_obs = [] 45 | with imageio.get_reader( 46 | os.path.join(walkthrough_dir, f"{walkthrough_key}_obs.mp4") 47 | ) as reader: 48 | for f in reader: 49 | walkthrough_obs.append(f) 50 | elif obs_type == "text": 51 | with open( 52 | os.path.join(walkthrough_dir, f"{walkthrough_key}_obs.json"), "r" 53 | ) as fp: 54 | walkthrough_obs = json.load(fp) 55 | else: 56 | raise ValueError(f"Observation type {obs_type} is not defined!") 57 | 58 | # Get ground-truth shortest path 59 | with open(os.path.join(walkthrough_dir, "shortestpath_info.json"), "r") as fp: 60 | gt_path = json.load(fp)["positions"] 61 | 62 | # Setup agent 63 | agent = get_agent(agent_name, agent_cfg) 64 | agent.reset(walkthrough_key) 65 | 66 | # Provide walkthrough to agent 67 | agent.initialize_with_walkthrough(walkthrough_obs) 68 | 69 | # Provide goal to agent 70 | goal_desc = f"landmark {env.goal_name}" 71 | agent.initialize_with_goal(goal_desc) 72 | 73 | # Start navigation 74 | # Initialize environment at walkthrough start state 75 | obs = env.reset() 76 | stop_issued = False 77 | pred_path = [env.current_location] 78 | actions_taken = [] 79 | if obs_type == "image": 80 | writer = imageio.get_writer(os.path.join(save_dir, "observations.mp4")) 81 | else: 82 | writer = open(os.path.join(save_dir, "observations.txt"), "w") 83 | 84 | def _log_observation(obs): 85 | if obs_type == "image": 86 | writer.append_data(obs) 87 | else: 88 | writer.write(obs) 89 | writer.write("\n\n") 90 | writer.write("-" * 25) 91 | writer.write("\n\n") 92 | 93 | _log_observation(obs) 94 | for _ in range(max_steps): 95 | act = agent.get_action(obs) 96 | if act == "stop": 97 | stop_issued = True 98 | actions_taken.append(act) 99 | break 100 | obs = env.step(act) 101 | pred_path.append(env.current_location) 102 | actions_taken.append(act) 103 | _log_observation(obs) 104 | 105 | writer.close() 106 | 107 | # Calculate metrics 108 | ## evaluate_path_efficiency assumes (r, c) inputs for positions 109 | grid = np.array([[0 if col == "0" else 1 for col in row] for row in env.textmap]) 110 | metrics = evaluate_path_efficiency( 111 | [(r, c) for c, r in gt_path], 112 | [(r, c) for c, r in pred_path], 113 | grid, 114 | stop_issued, 115 | dist_thresh=1.5, 116 | ) 117 | 118 | # Calculate experiment cost 119 | eval_cost = agent.get_eval_cost() 120 | 121 | # Save information 122 | with open(os.path.join(save_dir, "results.json"), "w") as fp: 123 | json.dump( 124 | { 125 | "metrics": metrics, 126 | "pred_path": pred_path, 127 | "actions_taken": actions_taken, 128 | 
"eval_cost": eval_cost, 129 | }, 130 | fp, 131 | ) 132 | 133 | return metrics, eval_cost, env_name 134 | 135 | 136 | def _mp_helper(inputs: dict[str, Any]): 137 | return evaluate_on_env(**inputs) 138 | 139 | 140 | def main( 141 | model_name: str, 142 | envs_dir: str, 143 | walkthroughs_dir: str, 144 | obs_type: str, 145 | save_dir: str, 146 | walkthrough_key: str, 147 | max_steps: int = 100, 148 | n_workers: int = 8, 149 | ): 150 | # Sanity checks 151 | assert obs_type in ["text", "image"] 152 | assert walkthrough_key in ["shortestpath", "walkthrough"] 153 | 154 | agent_cfg = get_config(model_name) 155 | agent_name = agent_cfg["agent_name"] 156 | del agent_cfg["agent_name"] 157 | 158 | if agent_cfg["use_vllm"]: 159 | start_vllm_server( 160 | agent_cfg["model_name"], agent_cfg["host_port"], agent_cfg["vllm_cfg"] 161 | ) 162 | 163 | save_dir = os.path.join(save_dir, model_name, get_datetimestr()) 164 | os.makedirs(save_dir, exist_ok=True) 165 | 166 | # Load maze paths 167 | env_dirs = sorted(glob.glob(os.path.join(envs_dir, "*"))) 168 | 169 | mp_inputs = [] 170 | for env_dir in env_dirs: 171 | agent_cfg_m = deepcopy(agent_cfg) 172 | env_name = os.path.basename(env_dir) 173 | if "save_dir" in agent_cfg: 174 | agent_cfg_m["save_dir"] = os.path.join(save_dir, env_name) 175 | mp_inputs.append( 176 | { 177 | "agent_name": agent_name, 178 | "agent_cfg": agent_cfg_m, 179 | "env_dir": env_dir, 180 | "obs_type": obs_type, 181 | "walkthrough_dir": os.path.join(walkthroughs_dir, env_name), 182 | "env_name": env_name, 183 | "walkthrough_key": walkthrough_key, 184 | "max_steps": max_steps, 185 | } 186 | ) 187 | 188 | all_outputs = [] 189 | pbar = tqdm.tqdm(total=len(mp_inputs), desc="Evaluating on SPACE navigation") 190 | with mp.Pool(n_workers, maxtasksperchild=1) as pool: 191 | for metrics, eval_cost, env_name in pool.imap_unordered(_mp_helper, mp_inputs): 192 | all_outputs.append( 193 | {"metrics": metrics, "eval_cost": eval_cost, "env_name": env_name} 194 | ) 195 | pbar.update() 196 | pbar.close() 197 | 198 | all_outputs = sorted(all_outputs, key=lambda x: x["env_name"]) 199 | all_metrics = [] 200 | total_experiment_cost = {} 201 | for outputs in all_outputs: 202 | all_metrics.append(outputs["metrics"]) 203 | for k, v in eval_cost.items(): 204 | if k not in total_experiment_cost: 205 | total_experiment_cost[k] = 0.0 206 | total_experiment_cost[k] += v 207 | 208 | mean_metrics = { 209 | k: np.mean([m[k] for m in all_metrics]).item() for k in all_metrics[0].keys() 210 | } 211 | for k, v in mean_metrics.items(): 212 | print(f"{k:20s}: {v:6.3f}") 213 | for k, v in total_experiment_cost.items(): 214 | print(f"{k:20s}: {v:6.3f}") 215 | 216 | with open(os.path.join(save_dir, "results.json"), "w") as fp: 217 | json.dump( 218 | { 219 | "all_metrics": all_metrics, 220 | "total_experiment_cost": total_experiment_cost, 221 | }, 222 | fp, 223 | ) 224 | 225 | 226 | if __name__ == "__main__": 227 | fire.Fire(main) 228 | -------------------------------------------------------------------------------- /space/evaluate_egonav.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | import glob 4 | import json 5 | import multiprocessing as mp 6 | import os 7 | from copy import deepcopy 8 | from typing import Any 9 | 10 | import fire 11 | import imageio 12 | import numpy as np 13 | import tqdm 14 | 15 | os.environ["MAGNUM_LOG"] = "quiet" 16 | os.environ["HABITAT_SIM_LOG"] = "quiet" 17 | 18 | 19 | from space import get_config, get_agent, get_env 20 | from space.utils.common import get_datetimestr 21 | from space.utils.habitat import ( 22 | compute_heading_from_quaternion, 23 | SPL, 24 | DistanceToGoal, 25 | Success, 26 | ) 27 | from space.utils.visualizations import add_goal_to_obs 28 | from space.utils.vllm_api import start_vllm_server 29 | 30 | IMAGE_DOWNSCALING = 4 31 | HABITAT_CONFIG = { 32 | "resolution": [512 * IMAGE_DOWNSCALING, 512 * IMAGE_DOWNSCALING], 33 | "forward_amount": 0.25, 34 | "turn_amount": 30, 35 | } 36 | 37 | 38 | def evaluate_on_env( 39 | agent_name: str, 40 | agent_cfg: dict[str, Any], 41 | env_dir: str, 42 | env_name: str, 43 | walkthrough_key: str, 44 | max_steps: int = 250, 45 | ): 46 | save_dir = agent_cfg["save_dir"] 47 | os.makedirs(save_dir, exist_ok=True) 48 | env = get_env( 49 | "NavEgoEnv", 50 | env_dir, 51 | habitat_kwargs=HABITAT_CONFIG, 52 | image_downscaling=IMAGE_DOWNSCALING, 53 | ) 54 | 55 | # Load walkthrough from path 56 | walkthrough_video_frames = [] 57 | with imageio.get_reader(os.path.join(env_dir, f"{walkthrough_key}.mp4")) as reader: 58 | for f in reader: 59 | walkthrough_video_frames.append(f) 60 | 61 | # Setup metrics 62 | task_info = env.get_task_info() 63 | d2g_metric = DistanceToGoal(env.sim, task_info["goal_position"]) 64 | success_metric = Success(env.sim, task_info["goal_position"]) 65 | spl_metric = SPL(env.sim, task_info["start_position"], task_info["goal_position"]) 66 | 67 | # Setup agent 68 | agent = get_agent(agent_name, agent_cfg) 69 | agent.reset(walkthrough_key) 70 | 71 | # Provide walkthrough to agent 72 | agent.initialize_with_walkthrough(walkthrough_video_frames) 73 | 74 | # Provide goal to agent 75 | goal_desc = task_info["goal_desc"] 76 | goal_image = walkthrough_video_frames[-1] 77 | agent.initialize_with_goal(goal_desc, goal_image) 78 | 79 | # Start navigation 80 | # Initialize environment at walkthrough start state 81 | obs = env.reset() 82 | stop_issued = False 83 | pos, rot = env.get_sim_state() 84 | trajectory_positions = [pos] 85 | trajectory_rotations = [rot] 86 | actions_taken = [] 87 | video_writer = imageio.get_writer(os.path.join(save_dir, "video.mp4")) 88 | vis_img = add_goal_to_obs(obs, goal_image) 89 | video_writer.append_data(vis_img) 90 | for _ in range(max_steps): 91 | act = agent.get_action(obs) 92 | if act not in ["move_forward", "turn_left", "turn_right", "stop"]: 93 | print( 94 | f"Obtained invalid action `{act}` from system. Replacing it with `turn_left`." 
95 | ) 96 | act = "turn_left" 97 | if act == "stop": 98 | stop_issued = True 99 | actions_taken.append(act) 100 | break 101 | obs = env.step(act) 102 | pos, rot = env.get_sim_state() 103 | trajectory_positions.append(pos) 104 | trajectory_rotations.append(rot) 105 | actions_taken.append(act) 106 | vis_img = add_goal_to_obs(obs, goal_image) 107 | video_writer.append_data(vis_img) 108 | video_writer.close() 109 | 110 | d2g = d2g_metric(trajectory_positions) 111 | success = success_metric(stop_issued, trajectory_positions) 112 | spl = spl_metric(stop_issued, trajectory_positions) 113 | metrics = { 114 | "distance_to_goal": d2g, 115 | "success": success, 116 | "spl": spl, 117 | } 118 | trajectory = { 119 | "positions": [t.tolist() for t in trajectory_positions], 120 | "headings": [compute_heading_from_quaternion(r) for r in trajectory_rotations], 121 | "actions": actions_taken, 122 | } 123 | 124 | env.close() 125 | 126 | # Calculate experiment cost 127 | eval_cost = agent.get_eval_cost() 128 | 129 | # Save information 130 | with open(os.path.join(save_dir, "results.json"), "w") as fp: 131 | json.dump( 132 | { 133 | "metrics": metrics, 134 | "trajectory": trajectory, 135 | "eval_cost": eval_cost, 136 | }, 137 | fp, 138 | ) 139 | 140 | return metrics, eval_cost, env_name 141 | 142 | 143 | def _mp_helper(inputs: dict[str, Any]): 144 | return evaluate_on_env(**inputs) 145 | 146 | 147 | def main( 148 | model_name: str, 149 | envs_dir: str, 150 | save_dir: str, 151 | walkthrough_key: str, 152 | max_steps: int = 250, 153 | n_workers: int = 8, 154 | ): 155 | agent_cfg = get_config(model_name) 156 | agent_name = agent_cfg["agent_name"] 157 | del agent_cfg["agent_name"] 158 | 159 | if agent_cfg["use_vllm"]: 160 | start_vllm_server( 161 | agent_cfg["model_name"], agent_cfg["host_port"], agent_cfg["vllm_cfg"] 162 | ) 163 | 164 | save_dir = os.path.join(save_dir, model_name, get_datetimestr()) 165 | os.makedirs(save_dir, exist_ok=True) 166 | 167 | # Load environment paths 168 | env_dirs = sorted(glob.glob(os.path.join(envs_dir, "*"))) 169 | 170 | mp_inputs = [] 171 | for env_dir in env_dirs: 172 | agent_cfg_m = deepcopy(agent_cfg) 173 | env_name = os.path.basename(env_dir) 174 | if "save_dir" in agent_cfg: 175 | agent_cfg_m["save_dir"] = os.path.join(save_dir, env_name) 176 | mp_inputs.append( 177 | { 178 | "agent_name": agent_name, 179 | "agent_cfg": agent_cfg_m, 180 | "env_dir": env_dir, 181 | "env_name": env_name, 182 | "walkthrough_key": walkthrough_key, 183 | "max_steps": max_steps, 184 | } 185 | ) 186 | 187 | all_outputs = [] 188 | pbar = tqdm.tqdm(total=len(mp_inputs), desc="Evaluating on SPACE navigation") 189 | with mp.Pool(n_workers, maxtasksperchild=1) as pool: 190 | for metrics, eval_cost, env_name in pool.imap_unordered(_mp_helper, mp_inputs): 191 | all_outputs.append( 192 | {"metrics": metrics, "eval_cost": eval_cost, "env_name": env_name} 193 | ) 194 | pbar.update() 195 | pbar.close() 196 | 197 | all_outputs = sorted(all_outputs, key=lambda x: x["env_name"]) 198 | all_metrics = [] 199 | total_experiment_cost = {} 200 | for outputs in all_outputs: 201 | all_metrics.append(outputs["metrics"]) 202 | for k, v in outputs["eval_cost"].items(): 203 | if k not in total_experiment_cost: 204 | total_experiment_cost[k] = 0.0 205 | total_experiment_cost[k] += v 206 | 207 | mean_metrics = { 208 | k: np.mean([m[k] for m in all_metrics]).item() for k in all_metrics[0].keys() 209 | } 210 | for k, v in mean_metrics.items(): 211 | print(f"{k:20s}: {v:6.3f}") 212 | for k, v in total_experiment_cost.items(): 213 | 
print(f"{k:20s}: {v:6.3f}") 214 | 215 | with open(os.path.join(save_dir, "results.json"), "w") as fp: 216 | json.dump( 217 | { 218 | "all_metrics": all_metrics, 219 | "total_experiment_cost": total_experiment_cost, 220 | }, 221 | fp, 222 | ) 223 | 224 | 225 | if __name__ == "__main__": 226 | fire.Fire(main) 227 | -------------------------------------------------------------------------------- /space/evaluate_mct.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | import glob 4 | import json 5 | import multiprocessing as mp 6 | import os 7 | from copy import deepcopy 8 | from typing import Any 9 | 10 | import fire 11 | import numpy as np 12 | import tqdm 13 | 14 | from space import get_config, get_agent, get_env 15 | from space.utils.common import get_datetimestr 16 | from space.envs.mct import evaluate_path_efficiency 17 | from space.utils.vllm_api import start_vllm_server 18 | 19 | 20 | def evaluate_on_maze( 21 | agent_name: str, 22 | agent_cfg: dict[str, Any], 23 | maze_dir: str, 24 | maze_name: str, 25 | max_steps: int = 250, 26 | ): 27 | save_dir = agent_cfg["save_dir"] 28 | os.makedirs(save_dir, exist_ok=True) 29 | 30 | # Setup environment 31 | env = get_env("MCT_Env", maze_dir, agent_cfg["description_type"]) 32 | 33 | # Setup agent 34 | agent = get_agent(agent_name, agent_cfg) 35 | agent.reset() 36 | 37 | # Get ground-truth shortest path 38 | gt_path = env.shortest_path 39 | obs, info = env.reset() 40 | pred_path = [env.current] 41 | actions = [] 42 | issued_stop = False 43 | for _ in range(max_steps): 44 | act = agent.get_action(obs) 45 | if act not in ["up", "down", "left", "right", "stop"]: 46 | print( 47 | f"Obtained invalid action `{act}` from system. Replacing it with `up`." 
48 | ) 49 | act = "up" 50 | if act == "stop": 51 | actions.append(act) 52 | issued_stop = True 53 | break 54 | obs, info = env.step(act) 55 | pred_path.append(env.current) 56 | actions.append(act) 57 | 58 | # Calculate metrics 59 | metrics = evaluate_path_efficiency( 60 | gt_path, pred_path, env.maze, issued_stop=issued_stop 61 | ) 62 | # Calculate experiment cost 63 | eval_cost = agent.get_eval_cost() 64 | 65 | return metrics, pred_path, actions, eval_cost, maze_name 66 | 67 | 68 | def _mp_helper(inputs: dict[str, Any]): 69 | return evaluate_on_maze(**inputs) 70 | 71 | 72 | def main( 73 | model_name: str, 74 | envs_dir: str, 75 | save_dir: str, 76 | n_workers: int = 8, 77 | ): 78 | agent_cfg = get_config(model_name) 79 | agent_name = agent_cfg["agent_name"] 80 | del agent_cfg["agent_name"] 81 | 82 | if agent_cfg["use_vllm"]: 83 | start_vllm_server( 84 | agent_cfg["model_name"], agent_cfg["host_port"], agent_cfg["vllm_cfg"] 85 | ) 86 | 87 | save_dir = os.path.join(save_dir, model_name, get_datetimestr()) 88 | os.makedirs(save_dir, exist_ok=True) 89 | 90 | # Load maze paths 91 | maze_dirs = sorted(glob.glob(os.path.join(envs_dir, "*"))) 92 | 93 | mp_inputs = [] 94 | for maze_dir in maze_dirs: 95 | agent_cfg_m = deepcopy(agent_cfg) 96 | maze_name = os.path.basename(maze_dir) 97 | if "save_dir" in agent_cfg: 98 | agent_cfg_m["save_dir"] = os.path.join(save_dir, maze_name) 99 | mp_inputs.append( 100 | { 101 | "agent_name": agent_name, 102 | "agent_cfg": agent_cfg_m, 103 | "maze_dir": maze_dir, 104 | "maze_name": maze_name, 105 | } 106 | ) 107 | 108 | all_outputs = [] 109 | pbar = tqdm.tqdm(total=len(mp_inputs), desc="Evaluating on SPACE MCT") 110 | with mp.Pool(n_workers, maxtasksperchild=1) as pool: 111 | for ( 112 | metrics, 113 | path_taken, 114 | actions, 115 | eval_cost, 116 | maze_name, 117 | ) in pool.imap_unordered(_mp_helper, mp_inputs): 118 | all_outputs.append( 119 | { 120 | "metrics": metrics, 121 | "path_taken": path_taken, 122 | "actions": actions, 123 | "eval_cost": eval_cost, 124 | "maze_name": maze_name, 125 | } 126 | ) 127 | # Save info for episode 128 | log_dir_i = os.path.join(save_dir, maze_name) 129 | with open(os.path.join(log_dir_i, "metrics.json"), "w") as fp: 130 | json.dump( 131 | { 132 | "metrics": metrics, 133 | "actions": actions, 134 | "path_taken": path_taken, 135 | }, 136 | fp, 137 | ) 138 | pbar.update() 139 | 140 | all_outputs = sorted(all_outputs, key=lambda x: x["maze_name"]) 141 | all_metrics = [] 142 | all_paths_taken = [] 143 | all_actions = [] 144 | total_experiment_cost = {} 145 | for outputs in all_outputs: 146 | all_metrics.append(outputs["metrics"]) 147 | all_paths_taken.append(outputs["path_taken"]) 148 | all_actions.append(outputs["actions"]) 149 | for k, v in outputs["eval_cost"].items(): 150 | if k not in total_experiment_cost: 151 | total_experiment_cost[k] = 0.0 152 | total_experiment_cost[k] += v 153 | 154 | mean_metrics = { 155 | k: np.mean([m[k] for m in all_metrics]).item() for k in all_metrics[0].keys() 156 | } 157 | for k, v in mean_metrics.items(): 158 | print(f"{k:20s}: {v:6.3f}") 159 | for k, v in total_experiment_cost.items(): 160 | print(f"{k:20s}: {v:6.3f}") 161 | 162 | with open(os.path.join(save_dir, "results.json"), "w") as fp: 163 | json.dump( 164 | { 165 | "all_metrics": all_metrics, 166 | "all_paths_taken": all_paths_taken, 167 | "all_actions": all_actions, 168 | "total_experiment_cost": total_experiment_cost, 169 | }, 170 | fp, 171 | ) 172 | 173 | 174 | if __name__ == "__main__": 175 | fire.Fire(main) 176 | 
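The evaluation entry points above (evaluate_cswm.py, evaluate_dmnav.py, evaluate_egonav.py, and evaluate_mct.py) all follow the same registry-driven pattern: look up a registered config by name, split off `agent_name`, construct the matching environment, and roll the agent out under a step budget, optionally behind a local vLLM server. Below is a minimal single-episode sketch of that wiring for the MCT task, without the multiprocessing and logging machinery; the config name is one registered earlier, while the maze directory and step budget are illustrative placeholders, and any vLLM server required by the config is assumed to be running already.

```python
# Minimal sketch (not part of the repo): one MCT rollout using the registry helpers.
# Assumptions: the maze directory below exists on disk, and the vLLM server for the
# chosen config (when use_vllm is True) has already been started.
from space import get_agent, get_config, get_env

cfg = get_config("yi15_9b_mct_text")  # registered in space/configs/yi.py
agent_name = cfg["agent_name"]        # "MCT_Agent"
del cfg["agent_name"]

env = get_env("MCT_Env", "data/mct/maze_00000", cfg["description_type"])  # hypothetical path
agent = get_agent(agent_name, cfg)
agent.reset()

obs, info = env.reset()
for _ in range(50):  # illustrative step budget
    act = agent.get_action(obs)
    if act == "stop":
        break
    obs, info = env.step(act)

print("stopped at", env.current, "goal is", env.goal)
```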
-------------------------------------------------------------------------------- /space/evaluate_qas.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | import json 4 | import multiprocessing as mp 5 | import os 6 | from collections import defaultdict 7 | from copy import deepcopy 8 | from typing import Any 9 | 10 | import fire 11 | import numpy as np 12 | import tqdm 13 | 14 | from space import get_config, get_agent 15 | from space.utils.common import get_datetimestr 16 | from space.utils.common import get_image_as_message, get_video_as_messages 17 | from space.utils.vllm_api import start_vllm_server 18 | 19 | 20 | def evaluate_on_qa( 21 | agent_name: str, 22 | agent_cfg: dict[str, Any], 23 | qa: dict[str, Any], 24 | ): 25 | # Setup agent 26 | agent = get_agent(agent_name, agent_cfg) 27 | agent.reset() 28 | question = qa["question"] 29 | answer = qa["answer"] 30 | if isinstance(question, list): 31 | question_content = [] 32 | for q in question: 33 | if q.startswith("IMAGE:"): 34 | image_path = q[len("IMAGE:") :] 35 | message = get_image_as_message( 36 | image_path=image_path, 37 | model_name=agent.model_name, 38 | image_detail=agent.image_detail, 39 | ) 40 | question_content.append(message) 41 | elif q.startswith("VIDEO:"): 42 | video_path = q[len("VIDEO:") :] 43 | video_messages = get_video_as_messages( 44 | video_path, 45 | model_name=agent.model_name, 46 | subsampling_factor=agent.subsampling_factor, 47 | image_detail=agent.image_detail, 48 | ) 49 | question_content.extend(video_messages) 50 | else: 51 | question_content.append(q) 52 | elif isinstance(question, str): 53 | question_content = question 54 | else: 55 | raise ValueError( 56 | f"Unable to parse question_content with type: {type(question)}" 57 | ) 58 | 59 | P = agent.get_prediction(question_content, answer) 60 | metrics = {"accuracy": float(P == qa["answer"]) * 100.0} 61 | 62 | # Calculate experiment cost 63 | eval_cost = agent.get_eval_cost() 64 | 65 | return metrics, P, eval_cost 66 | 67 | 68 | def _mp_helper(inputs: dict[str, Any]): 69 | return evaluate_on_qa(**inputs) 70 | 71 | 72 | def main( 73 | model_name: str, 74 | data_path: str, 75 | save_dir: str, 76 | n_workers: int = 8, 77 | ): 78 | agent_cfg = get_config(model_name) 79 | agent_name = agent_cfg["agent_name"] 80 | del agent_cfg["agent_name"] 81 | 82 | if agent_cfg["use_vllm"]: 83 | start_vllm_server( 84 | agent_cfg["model_name"], agent_cfg["host_port"], agent_cfg["vllm_cfg"] 85 | ) 86 | 87 | save_dir = os.path.join(save_dir, model_name, get_datetimestr()) 88 | os.makedirs(save_dir, exist_ok=True) 89 | 90 | with open(data_path, "r") as fp: 91 | dataset = json.load(fp) 92 | 93 | mp_inputs = [] 94 | for i, qa in enumerate(dataset): 95 | agent_cfg_i = deepcopy(agent_cfg) 96 | if "save_dir" in agent_cfg_i: 97 | agent_cfg_i["save_dir"] = os.path.join(save_dir, f"qa_{i:05d}") 98 | mp_inputs.append( 99 | { 100 | "agent_name": agent_name, 101 | "agent_cfg": agent_cfg_i, 102 | "qa": qa, 103 | } 104 | ) 105 | 106 | all_metrics = [] 107 | all_predictions = [] 108 | total_experiment_cost = {} 109 | 110 | pbar = tqdm.tqdm(total=len(mp_inputs), desc="Evaluating on QAs") 111 | with mp.Pool(n_workers, maxtasksperchild=1) as pool: 112 | for metrics, P, eval_cost in pool.imap(_mp_helper, mp_inputs): 113 | all_metrics.append(metrics) 114 | all_predictions.append(P) 115 | for k, v in eval_cost.items(): 116 | if k not in total_experiment_cost: 
117 | total_experiment_cost[k] = 0.0 118 | total_experiment_cost[k] += v 119 | pbar.update() 120 | 121 | metrics = defaultdict(list) 122 | for m in all_metrics: 123 | for k, v in m.items(): 124 | metrics[k].append(v) 125 | mean_metrics = {} 126 | for k, v_list in metrics.items(): 127 | v_mean = np.mean(v_list).item() 128 | mean_metrics[k] = v_mean 129 | print(f"{k:<20s} | {v_mean:>7.3f}") 130 | 131 | with open(os.path.join(save_dir, "results.json"), "w") as fp: 132 | json.dump( 133 | { 134 | "all_metrics": all_metrics, 135 | "all_predictions": all_predictions, 136 | "total_experiment_cost": total_experiment_cost, 137 | "mean_metrics": mean_metrics, 138 | }, 139 | fp, 140 | ) 141 | 142 | 143 | if __name__ == "__main__": 144 | fire.Fire(main) 145 | -------------------------------------------------------------------------------- /space/registry.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | AGENTS_REGISTRY = {} 4 | CONFIGS_REGISTRY = {} 5 | ENVS_REGISTRY = {} 6 | 7 | 8 | def register_agent(cls): 9 | global AGENTS_REGISTRY 10 | AGENTS_REGISTRY[cls.__name__] = cls 11 | return cls 12 | 13 | 14 | def register_config(name: str): 15 | def decorator(func): 16 | global CONFIGS_REGISTRY 17 | assert name not in CONFIGS_REGISTRY 18 | CONFIGS_REGISTRY[name] = func 19 | return func 20 | 21 | return decorator 22 | 23 | 24 | def register_env(cls): 25 | global ENVS_REGISTRY 26 | ENVS_REGISTRY[cls.__name__] = cls 27 | return cls 28 | -------------------------------------------------------------------------------- /space/utils/claude_api.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | import copy 4 | import json 5 | import os 6 | import time 7 | from typing import Any 8 | 9 | import anthropic 10 | from anthropic import Anthropic 11 | 12 | from space.utils.common import convert_content_to_str 13 | 14 | 15 | def setup_llm_client(model_name: str): 16 | assert model_name.startswith("claude-") 17 | assert "ANTHROPIC_API_KEY" in os.environ 18 | api_key = os.environ["ANTHROPIC_API_KEY"] 19 | client = Anthropic(api_key=api_key) 20 | return client 21 | 22 | 23 | class Dialog: 24 | def __init__(self, log_writer=None): 25 | self.history = [] 26 | self.log_writer = log_writer 27 | self.images_save_dir = None 28 | self.log_dir = None 29 | if log_writer is not None: 30 | self.log_dir = os.path.dirname(self.log_writer.file_name) 31 | self.log_count = 0 32 | self.images_save_dir = os.path.dirname(self.log_writer.file_name) 33 | os.makedirs(os.path.join(self.images_save_dir, "images"), exist_ok=True) 34 | 35 | def add_system_message(self, **kwargs): 36 | self.history.append({"role": "system", **kwargs}) 37 | self.log_to_file( 38 | "system", convert_content_to_str(kwargs["content"], self.images_save_dir) 39 | ) 40 | 41 | def add_user_message(self, **kwargs): 42 | self.history.append({"role": "user", **kwargs}) 43 | self.log_to_file( 44 | "user", convert_content_to_str(kwargs["content"], self.images_save_dir) 45 | ) 46 | 47 | def add_assistant_message(self, **kwargs): 48 | self.history.append({"role": "assistant", **kwargs}) 49 | self.log_to_file( 50 | "assistant", 51 | convert_content_to_str(kwargs["content"], self.images_save_dir), 52 | ) 53 | 54 | def add_inner_thoughts(self, **kwargs): 55 | if kwargs["content"] is not None: 56 | self.log_to_file( 57 | "assistant (inner thoughts)", 58 | convert_content_to_str(kwargs["content"], self.images_save_dir), 59 | bold_italics_code="bi", 60 | ) 61 | 62 | def log_to_file(self, role, content_str, bold_italics_code="b"): 63 | if self.log_writer is not None: 64 | self.log_writer.new_paragraph( 65 | f"{role.upper()}", bold_italics_code=bold_italics_code 66 | ) 67 | self.log_writer.new_paragraph("\n" + content_str + "\n") 68 | self.log_writer.create_md_file() 69 | 70 | def log_token_usage( 71 | self, prompt_tokens: int, completion_tokens: int, total_cost: int = None 72 | ): 73 | if self.log_writer is not None: 74 | self.log_writer.write("\n") 75 | self.log_writer.write( 76 | "====> Token usage: prompt tokens: {}, completion_tokens: {}".format( 77 | prompt_tokens, completion_tokens 78 | ) 79 | ) 80 | if total_cost is not None: 81 | self.log_writer.write("\n") 82 | self.log_writer.write(f"====> Total cost: ${total_cost:.3f}") 83 | self.log_writer.write("\n\n") 84 | self.log_writer.create_md_file() 85 | 86 | def log_response_time(self, time_taken: float): 87 | if self.log_writer is not None: 88 | self.log_writer.write("\n") 89 | self.log_writer.write(f"====> Response time (sec): {time_taken:.3f}") 90 | self.log_writer.write("\n\n") 91 | self.log_writer.create_md_file() 92 | 93 | @property 94 | def dialog(self): 95 | return copy.deepcopy(self.history) 96 | 97 | def delete_last_message(self): 98 | del self.history[-1] 99 | 100 | def clear_history(self): 101 | self.history = [] 102 | 103 | def clone(self): 104 | dialog_clone = Dialog(self.log_writer) 105 | dialog_clone.history = copy.deepcopy(self.history) 106 | return dialog_clone 107 | 108 | def write_dialog(self): 109 | if self.log_writer is not None: 110 | save_path = os.path.join(self.log_dir, f"dialog_{self.log_count:05d}.json") 111 | with open(save_path, "w") as fp: 112 | json.dump(self.history, 
fp) 113 | save_path = os.path.join(self.log_dir, f"dialog_{self.log_count:05d}.txt") 114 | with open(save_path, "w") as fp: 115 | for h in self.history: 116 | content_str = convert_content_to_str( 117 | h["content"], save_dir=None, ignore_images=True 118 | ) 119 | role_str = h["role"].upper() 120 | fp.write(f"{role_str}: {content_str}\n\n") 121 | self.log_count += 1 122 | else: 123 | print( 124 | "WARNING: Dialog object does not have a log_writer, so write_dialog() failed..." 125 | ) 126 | 127 | 128 | # Function to get response from model 129 | def get_model_response( 130 | client: Any, 131 | dialog: list[Any], 132 | model_name: str = "claude-3-5-sonnet-20240620", 133 | temperature: float = 0.5, 134 | max_tokens: int = 1000, 135 | max_retries: int = 10, 136 | sleep_secs: int = 60, 137 | verbose: bool = True, 138 | **kwargs, 139 | ): 140 | num_retries = 0 141 | start_time = time.time() 142 | excptn_info = {} 143 | # Process dialog 144 | dialog_proc = [] 145 | for d in dialog: 146 | assert len(d.keys()) == 2 147 | dc = d["content"] 148 | if isinstance(dc, str): 149 | dialog_proc.append({"role": d["role"], "content": dc}) 150 | elif isinstance(dc, list): 151 | dc_new = [] 152 | for dc_ in dc: 153 | if isinstance(dc_, dict): 154 | dc_new.append(dc_) 155 | elif isinstance(dc_, str): 156 | dc_new.append({"type": "text", "text": dc_}) 157 | else: 158 | raise ValueError("Cannot process non-dict and non-str content") 159 | dialog_proc.append({"role": d["role"], "content": dc_new}) 160 | else: 161 | raise ValueError("Cannot process non-str and non-list content") 162 | 163 | while num_retries < max_retries: 164 | try: 165 | response = client.messages.create( 166 | model=model_name, 167 | messages=dialog_proc, 168 | temperature=temperature, 169 | max_tokens=max_tokens, 170 | **kwargs, 171 | ) 172 | response_txt = response.content[0].text 173 | token_counts = { 174 | "completion_tokens": response.usage.output_tokens, 175 | "prompt_tokens": response.usage.input_tokens, 176 | } 177 | break 178 | except anthropic.APIConnectionError: 179 | excptn_info = { 180 | "exception_message": "The server could not be reached.", 181 | "exception_type": "APIConnectionError", 182 | "exception_code": "N/A", 183 | } 184 | response = None 185 | break 186 | except anthropic.RateLimitError as excptn: 187 | excptn_info = { 188 | "exception_message": "Received a rate limiting error.", 189 | "exception_type": "RateLimitError", 190 | "exception_code": "429", 191 | } 192 | response = None 193 | num_retries += 1 194 | if num_retries >= max_retries: 195 | if verbose: 196 | print(excptn) 197 | time.sleep(sleep_secs) 198 | except anthropic.BadRequestError as excptn: 199 | excptn_info = { 200 | "exception_message": excptn.message, 201 | "exception_type": "BadRequestError", 202 | "exception_code": str(excptn.status_code), 203 | } 204 | # Bad requests are not retryable, so fail immediately instead of looping. 205 | response = None 206 | break 207 | except anthropic.APIStatusError as excptn: 208 | excptn_info = { 209 | "exception_message": "Received an API Status Error.", 210 | "exception_type": "APIStatusError", 211 | "exception_code": str(excptn.status_code), 212 | } 213 | response = None 214 | break 215 | 216 | if response is None or num_retries >= max_retries: 217 | if verbose: 218 | print(f"===> Failed after {max_retries} retries") 219 | print(excptn_info) 220 | response_txt = None 221 | token_counts = {} 222 | 223 | time_taken = time.time() - start_time 224 | output = { 225 | "text": response_txt, 226 | **token_counts, 227 | "response_time": time_taken, 228 | **excptn_info, 229 | } 230 | return output 231 | 
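A minimal usage sketch for the Claude helpers above (the model id, prompt, and the choice to omit `log_writer` are illustrative rather than taken from the repo's configs; assumes `ANTHROPIC_API_KEY` is set in the environment):

```python
# Hypothetical wiring of setup_llm_client, Dialog, and get_model_response.
from space.utils.claude_api import Dialog, get_model_response, setup_llm_client

client = setup_llm_client("claude-3-5-sonnet-20240620")
dialog = Dialog()  # no log_writer: history is kept in memory only
dialog.add_user_message(
    content=[{"type": "text", "text": "Describe the room layout in one sentence."}]
)

output = get_model_response(
    client, dialog.dialog, model_name="claude-3-5-sonnet-20240620", max_tokens=200
)
if output["text"] is not None:
    dialog.add_assistant_message(content=output["text"])
    print(output["text"], output["prompt_tokens"], output["completion_tokens"])
```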
-------------------------------------------------------------------------------- /space/utils/common.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 3 | from typing import Any, Optional, Union 4 | 5 | import os 6 | import cv2 7 | import base64 8 | import time 9 | import random 10 | import string 11 | import imageio 12 | import numpy as np 13 | from PIL import Image 14 | from datetime import datetime 15 | from mdutils.tools import Html 16 | from io import BytesIO 17 | 18 | 19 | def get_datetimestr() -> str: 20 | time_now = datetime.now() 21 | time_now_str = time_now.strftime("%Y%m%d_%H%M%S") 22 | return time_now_str 23 | 24 | 25 | def get_pid() -> int: 26 | return os.getpid() 27 | 28 | 29 | def get_random_string(str_len: int = 10) -> str: 30 | return "".join(random.choices(string.ascii_uppercase + string.digits, k=str_len)) 31 | 32 | 33 | def encode_image(image: np.ndarray): 34 | buffered = BytesIO() 35 | image = Image.fromarray(image) 36 | image.save(buffered, format="JPEG") 37 | output = base64.b64encode(buffered.getvalue()).decode("utf-8") 38 | return output 39 | 40 | 41 | def decode_image(base64_string): 42 | # Step 1: Decode the base64 string 43 | image_data = base64.b64decode(base64_string) 44 | # Step 2: Convert to a bytes-like object 45 | image_bytes = BytesIO(image_data) 46 | # Step 3: Open the image using PIL 47 | image = Image.open(image_bytes) 48 | # Step 4: Convert the image to a numpy array 49 | image_array = np.array(image) 50 | return image_array 51 | 52 | 53 | def get_image_as_message( 54 | image: Optional[np.ndarray] = None, 55 | image_path: Optional[str] = None, 56 | model_name: Optional[str] = None, 57 | image_detail: str = "low", 58 | ): 59 | # Sanity checks 60 | assert image is None or image_path is None 61 | assert not (image is not None and image_path is not None) 62 | mode = ( 63 | "claude" 64 | if model_name is not None and model_name.startswith("claude") 65 | else "openai" 66 | ) 67 | 68 | if image is None: 69 | image = imageio.imread(image_path, pilmode="RGB") 70 | image_encoded = encode_image(image) 71 | img_format = "jpeg" 72 | if mode == "openai" and model_name.startswith("gpt-"): 73 | message = { 74 | "type": "image_url", 75 | "image_url": { 76 | "url": f"data:image/{img_format};base64,{image_encoded}", 77 | "detail": image_detail, 78 | }, 79 | } 80 | elif mode == "openai": 81 | message = { 82 | "type": "image_url", 83 | "image_url": {"url": f"data:image/{img_format};base64,{image_encoded}"}, 84 | } 85 | else: 86 | message = { 87 | "type": "image", 88 | "source": { 89 | "type": "base64", 90 | "media_type": f"image/{img_format}", 91 | "data": image_encoded, 92 | }, 93 | } 94 | 95 | return message 96 | 97 | 98 | def get_video_as_messages( 99 | video_path: str, 100 | model_name: Optional[str] = None, 101 | subsampling_factor: int = 1, 102 | image_width: int = None, 103 | image_detail: str = "low", 104 | ): 105 | messages = [] 106 | with imageio.get_reader(video_path) as reader: 107 | for i, f in enumerate(reader): 108 | if i % subsampling_factor != 0: 109 | continue 110 | if image_width is not None: 111 | # Resize to fixed width 112 | h, w = f.shape[:2] 113 | image_height = int(float(h) / w * image_width) 114 | f = cv2.resize(f, (image_width, image_height)) 115 | message = get_image_as_message( 116 | image=f, model_name=model_name, image_detail=image_detail 117 | ) 118 | messages.append(message) 119 | 120 | return 
messages 121 | 122 | 123 | def convert_content_to_str( 124 | content: Any, save_dir: str, ignore_images: bool = False 125 | ) -> str: 126 | content_str = "" 127 | if isinstance(content, str): 128 | content_str = content + "\n" 129 | elif isinstance(content, list): 130 | for c in content: 131 | if isinstance(c, str): 132 | content_str += c + "\n" 133 | elif isinstance(c, dict) and c["type"] == "text": 134 | content_str += c["text"] + "\n" 135 | elif isinstance(c, dict) and c["type"] == "image": 136 | if not ignore_images: 137 | # Get file path 138 | time.sleep(1) 139 | time_now = datetime.now() 140 | time_now_str = time_now.strftime("%Y%m%d_%H%M%S") 141 | if save_dir is not None: 142 | img_save_path = f"{save_dir}/images/image_{time_now_str}.jpg" 143 | # Decode base64 string to image 144 | img_encoded = c["source"]["data"] 145 | img_decoded = decode_image(img_encoded) 146 | if img_decoded.shape[2] == 4: 147 | mask = img_decoded[..., 3] == 0 148 | img_decoded = img_decoded[..., :3] 149 | img_decoded[mask, :] = np.array([255, 255, 255]) 150 | imageio.imwrite(img_save_path, img_decoded) 151 | content_str += ( 152 | "\n\n" 153 | + Html.image( 154 | path=f"images/image_{time_now_str}.jpg", size="x300" 155 | ) 156 | + "\n\n" 157 | ) 158 | elif isinstance(c, dict) and c["type"] == "image_url": 159 | if not ignore_images: 160 | # Get file path 161 | time.sleep(1) 162 | time_now = datetime.now() 163 | time_now_str = time_now.strftime("%Y%m%d_%H%M%S") 164 | if save_dir is not None: 165 | img_save_path = f"{save_dir}/images/image_{time_now_str}.jpg" 166 | # Decode base64 string to image 167 | img_encoded = c["image_url"]["url"].split(";base64,")[1] 168 | img_decoded = decode_image(img_encoded) 169 | if img_decoded.shape[2] == 4: 170 | mask = img_decoded[..., 3] == 0 171 | img_decoded = img_decoded[..., :3] 172 | img_decoded[mask, :] = np.array([255, 255, 255]) 173 | imageio.imwrite(img_save_path, img_decoded) 174 | content_str += ( 175 | "\n\n" 176 | + Html.image( 177 | path=f"images/image_{time_now_str}.jpg", size="x300" 178 | ) 179 | + "\n\n" 180 | ) 181 | else: 182 | content_str += "\n" 183 | else: 184 | content_str += "\n" + str(c) + "\n" 185 | else: 186 | content_str = str(content) 187 | return content_str 188 | 189 | 190 | def count_images_in_query(content: Union[str, list[Any]]): 191 | n_images = 0 192 | assert isinstance(content, str) or isinstance(content, list) 193 | if isinstance(content, list): 194 | for c in content: 195 | if isinstance(c, dict) and c["type"] in ["image", "image_url"]: 196 | n_images += 1 197 | return n_images 198 | -------------------------------------------------------------------------------- /space/utils/habitat.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
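# habitat-sim helpers: camera sensor and agent configuration, simulator loading with
# optional custom lighting, geodesic-distance and shortest-path queries, quaternion and
# heading conversions, and the DistanceToGoal, Success, and SPL navigation metrics.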
3 | import habitat_sim 4 | import numpy as np 5 | import quaternion as qt 6 | from habitat_sim.gfx import LightInfo 7 | 8 | EPSILON = 1e-8 9 | 10 | 11 | def get_rgb_cfg( 12 | resolution: list[int], agent_height: float 13 | ) -> habitat_sim.CameraSensorSpec: 14 | rgb_cfg = habitat_sim.CameraSensorSpec() 15 | rgb_cfg.uuid = "rgb" 16 | rgb_cfg.sensor_type = habitat_sim.SensorType.COLOR 17 | rgb_cfg.resolution = resolution 18 | rgb_cfg.position = np.array([0.0, max(agent_height - 0.1, 0.0), 0.0]) 19 | return rgb_cfg 20 | 21 | 22 | def get_depth_cfg( 23 | resolution: list[int], agent_height: float 24 | ) -> habitat_sim.CameraSensorSpec: 25 | depth_cfg = habitat_sim.CameraSensorSpec() 26 | depth_cfg.uuid = "depth" 27 | depth_cfg.sensor_type = habitat_sim.SensorType.DEPTH 28 | depth_cfg.resolution = resolution 29 | depth_cfg.position = np.array([0.0, max(agent_height - 0.1, 0.0), 0.0]) 30 | return depth_cfg 31 | 32 | 33 | def make_habitat_configuration( 34 | scene_id: str, 35 | scene_dataset_config_file: str, 36 | resolution: list[int], 37 | agent_height: float = 0.5, 38 | agent_radius: float = 0.1, 39 | forward_amount: float = 0.25, 40 | turn_amount: float = 10.0, 41 | enable_physics: bool = False, 42 | ): 43 | backend_cfg = habitat_sim.SimulatorConfiguration() 44 | backend_cfg.scene_id = scene_id 45 | backend_cfg.scene_dataset_config_file = scene_dataset_config_file 46 | backend_cfg.enable_physics = enable_physics 47 | 48 | agent_cfg = habitat_sim.agent.AgentConfiguration() 49 | agent_cfg.sensor_specifications = [ 50 | get_rgb_cfg(resolution, agent_height), 51 | get_depth_cfg(resolution, agent_height), 52 | ] 53 | 54 | agent_cfg.height = agent_height 55 | agent_cfg.radius = agent_radius 56 | agent_cfg.action_space["move_forward"].actuation.amount = forward_amount 57 | agent_cfg.action_space["turn_left"].actuation.amount = turn_amount 58 | agent_cfg.action_space["turn_right"].actuation.amount = turn_amount 59 | 60 | sim_cfg = habitat_sim.Configuration(backend_cfg, [agent_cfg]) 61 | 62 | return sim_cfg 63 | 64 | 65 | def load_sim( 66 | scene_id: str, 67 | scene_dataset_config_file, 68 | lights: list[LightInfo] = None, 69 | agent_height: float = 0.5, 70 | agent_radius: float = 0.1, 71 | **kwargs, 72 | ) -> habitat_sim.Simulator: 73 | sim_cfg = make_habitat_configuration( 74 | scene_id, 75 | scene_dataset_config_file, 76 | agent_height=agent_height, 77 | agent_radius=agent_radius, 78 | **kwargs, 79 | ) 80 | if lights is not None: 81 | sim_cfg.sim_cfg.scene_light_setup = "custom_scene_lighting" 82 | sim_cfg.sim_cfg.override_scene_light_defaults = True 83 | sim = habitat_sim.Simulator(sim_cfg) 84 | 85 | # Reload navmesh with appropriate height, radius 86 | navmesh_settings = habitat_sim.NavMeshSettings() 87 | navmesh_settings.set_defaults() 88 | navmesh_settings.agent_radius = agent_radius 89 | navmesh_settings.agent_height = agent_height 90 | sim.recompute_navmesh(sim.pathfinder, navmesh_settings) 91 | 92 | # Add lights if available 93 | if lights is not None: 94 | sim.set_light_setup(lights, "custom_scene_lighting") 95 | sim.reconfigure(sim_cfg) 96 | 97 | return sim 98 | 99 | 100 | def calculate_geodesic_distance( 101 | sim: habitat_sim.Simulator, start_position: np.ndarray, goal_position: np.ndarray 102 | ) -> tuple[float, bool]: 103 | """Calculate geodesic distance b/w two points 104 | 105 | Args: 106 | sim: habitat simulator instance 107 | start_position: start of path 108 | goal_position: end of path 109 | """ 110 | path = habitat_sim.ShortestPath() 111 | path.requested_start = start_position 
112 | path.requested_end = goal_position 113 | found_path = sim.pathfinder.find_path(path) 114 | distance = path.geodesic_distance 115 | return distance, found_path 116 | 117 | 118 | def calculate_shortest_path( 119 | sim: habitat_sim.Simulator, start_position: np.ndarray, goal_position: np.ndarray 120 | ) -> tuple[list[np.ndarray], bool]: 121 | """Calculate shortest path b/w two points 122 | 123 | Args: 124 | sim: habitat simulator instance 125 | start_position: start of path 126 | goal_position: end of path 127 | """ 128 | path = habitat_sim.ShortestPath() 129 | path.requested_start = start_position 130 | path.requested_end = goal_position 131 | found_path = sim.pathfinder.find_path(path) 132 | return path.points, found_path 133 | 134 | 135 | def quaternion_from_two_vectors(v0: np.ndarray, v1: np.ndarray) -> qt.quaternion: 136 | r"""Computes the quaternion representation of v1 using v0 as the origin.""" 137 | v0 = v0 / np.linalg.norm(v0) 138 | v1 = v1 / np.linalg.norm(v1) 139 | c = v0.dot(v1) 140 | # Epsilon prevents issues at poles. 141 | if c < (-1 + EPSILON): 142 | c = max(c, -1) 143 | m = np.stack([v0, v1], 0) 144 | _, _, vh = np.linalg.svd(m, full_matrices=True) 145 | axis = vh.T[:, 2] 146 | w2 = (1 + c) * 0.5 147 | w = np.sqrt(w2) 148 | axis = axis * np.sqrt(1 - w2) 149 | return qt.quaternion(w, *axis) 150 | 151 | axis = np.cross(v0, v1) 152 | s = np.sqrt((1 + c) * 2) 153 | return qt.quaternion(s * 0.5, *(axis / s)) 154 | 155 | 156 | def quaternion_rotate_vector(quat: qt.quaternion, v: np.ndarray) -> np.ndarray: 157 | r"""Rotates a vector by a quaternion 158 | Args: 159 | quaternion: The quaternion to rotate by 160 | v: The vector to rotate 161 | Returns: 162 | np.ndarray: The rotated vector 163 | """ 164 | vq = qt.quaternion(0, 0, 0, 0) 165 | vq.imag = v 166 | return (quat * vq * quat.inverse()).imag 167 | 168 | 169 | def compute_heading_from_quaternion(r) -> float: 170 | """ 171 | r - rotation quaternion 172 | 173 | Computes clockwise rotation about Y. 174 | """ 175 | # quaternion - np.quaternion unit quaternion 176 | # Real world rotation 177 | direction_vector = np.array([0, 0, -1]) # Forward vector 178 | heading_vector = quaternion_rotate_vector(r.inverse(), direction_vector) 179 | 180 | phi = -np.arctan2(heading_vector[0], -heading_vector[2]).item() 181 | return phi 182 | 183 | 184 | def compute_quaternion_from_heading(h_deg: float) -> qt.quaternion: 185 | """Calculates quaternion corresponding to heading. 186 | 187 | Args: 188 | h_deg: Clockwise rotation about Y in degrees. 
189 | """ 190 | h = np.deg2rad(h_deg) 191 | fwd_dir = np.array([0.0, 0.0, -1.0]) 192 | head_dir = np.array([np.sin(h), 0.0, -np.cos(h)]) 193 | quat = quaternion_from_two_vectors(fwd_dir, head_dir) 194 | return quat 195 | 196 | 197 | def quaternion_to_list(q: qt.quaternion): 198 | return q.imag.tolist() + [q.real] 199 | 200 | 201 | class DistanceToGoal: 202 | def __init__(self, sim: habitat_sim.Simulator, goal_position: np.ndarray): 203 | """ 204 | Class to compute distance to goal metric 205 | 206 | Arguments: 207 | sim: habitat simulator instance 208 | goal_position: (x, y, z) location of goal (in meters) 209 | """ 210 | self.sim = sim 211 | self.goal_position = goal_position 212 | 213 | def __call__(self, trajectory_positions: list[np.ndarray]) -> float: 214 | last_position = trajectory_positions[-1] 215 | distance, found_path = calculate_geodesic_distance( 216 | self.sim, last_position, self.goal_position 217 | ) 218 | assert found_path, "Could not find a path in DistanceToGoal" 219 | return distance 220 | 221 | 222 | class Success: 223 | def __init__( 224 | self, 225 | sim: habitat_sim.Simulator, 226 | goal_position: np.ndarray, 227 | dist_thresh: float = 1.0, 228 | ): 229 | """ 230 | Class to compute success metric 231 | 232 | Arguments: 233 | sim: habitat simulator instance 234 | goal_position: (x, y, z) location of goal viewpoint (in meters) 235 | dist_thresh: geodesic distance threshold to determine success (in meters) 236 | """ 237 | self.sim = sim 238 | self.goal_position = goal_position 239 | self.dist_thresh = dist_thresh 240 | self._d2g = DistanceToGoal(sim, goal_position) 241 | 242 | def __call__( 243 | self, 244 | last_action_was_stop: bool, 245 | trajectory_positions: list[np.ndarray], 246 | ) -> float: 247 | # If last action called was not STOP, then success is 0 by definition 248 | if not last_action_was_stop: 249 | success = 0.0 250 | else: 251 | d2g = self._d2g(trajectory_positions) 252 | if d2g <= self.dist_thresh: 253 | success = 1.0 254 | else: 255 | success = 0.0 256 | return success 257 | 258 | 259 | class SPL: 260 | def __init__( 261 | self, 262 | sim: habitat_sim.Simulator, 263 | start_position: np.ndarray, 264 | goal_position: np.ndarray, 265 | dist_thresh: float = 1.0, 266 | ): 267 | """ 268 | Class to compute Success weighted by Path Length 269 | 270 | Reference: https://arxiv.org/pdf/1807.06757.pdf 271 | 272 | Arguments: 273 | sim: habitat simulator instance 274 | start_position: (x, y, z) location of start (in meters) 275 | goal_position: (x, y, z) location of goal viewpoint (in meters) 276 | dist_thresh: geodesic distance threshold to determine success (in meters) 277 | """ 278 | self.sim = sim 279 | self._success = Success(sim, goal_position, dist_thresh) 280 | # Calculate shortest path length 281 | distance, found_path = calculate_geodesic_distance( 282 | sim, start_position, goal_position 283 | ) 284 | assert found_path, "Could not find a path in SPL.__init__()" 285 | self._shortest_path_length = distance 286 | 287 | def __call__( 288 | self, last_action_was_stop: bool, trajectory_positions: list[np.ndarray] 289 | ): 290 | success = self._success(last_action_was_stop, trajectory_positions) 291 | if success == 0.0: 292 | spl = 0.0 293 | else: 294 | current_path_length = self.calculate_path_length(trajectory_positions) 295 | spl = max(self._shortest_path_length, EPSILON) / max( 296 | current_path_length, self._shortest_path_length, EPSILON 297 | ) 298 | return spl 299 | 300 | def calculate_path_length(self, trajectory_positions: list[np.ndarray]): 301 | 
path_length = 0.0 302 | for p1, p2 in zip(trajectory_positions[:-1], trajectory_positions[1:]): 303 | distance, found_path = calculate_geodesic_distance(self.sim, p1, p2) 304 | assert found_path, "Could not find a path in SPL.calculate_path_length()" 305 | path_length += distance 306 | return path_length 307 | -------------------------------------------------------------------------------- /space/utils/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from func_timeout import FunctionTimedOut, func_timeout 3 | from space.utils.claude_api import get_model_response as get_model_response_claude 4 | from space.utils.openai_api import get_model_response as get_model_response_openai 5 | 6 | 7 | def get_model_response( 8 | client: Any, 9 | dialog: list[Any], 10 | model_name: str, 11 | *args, 12 | max_retries: int = 10, 13 | max_response_wait_secs_per_retry: float = 300.0, 14 | **kwargs, 15 | ): 16 | max_response_wait_secs = max_retries * max_response_wait_secs_per_retry 17 | try: 18 | if model_name.startswith("claude"): 19 | response = func_timeout( 20 | max_response_wait_secs, 21 | get_model_response_claude, 22 | args=(client, dialog, model_name, *args), 23 | kwargs=kwargs, 24 | ) 25 | else: 26 | response = func_timeout( 27 | max_response_wait_secs, 28 | get_model_response_openai, 29 | args=(client, dialog, model_name, *args), 30 | kwargs=kwargs, 31 | ) 32 | except FunctionTimedOut: 33 | print(f"get_model_response() timed out after {max_response_wait_secs} secs!") 34 | response = { 35 | "text": "timed out", 36 | "completion_tokens": 0, 37 | "prompt_tokens": 0, 38 | "response_time": max_response_wait_secs, 39 | } 40 | return response 41 | -------------------------------------------------------------------------------- /space/utils/openai_api.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
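# OpenAI-compatible chat helpers: setup_llm_client() targets the OpenAI API for "gpt-*"
# models and a local vLLM endpoint otherwise; Dialog mirrors the logging wrapper in
# claude_api.py; get_model_response() wraps chat.completions.create() with retries.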
3 | import copy 4 | import json 5 | import os 6 | import time 7 | from typing import Any 8 | 9 | from openai import OpenAI 10 | from space.utils.common import convert_content_to_str 11 | 12 | 13 | def setup_llm_client(model_name, host_port: str = "8000"): 14 | if model_name.startswith("gpt-"): 15 | assert "OPENAI_API_KEY" in os.environ 16 | api_key = os.environ["OPENAI_API_KEY"] 17 | base_url = None 18 | else: 19 | api_key = "EMPTY" 20 | base_url = f"http://localhost:{host_port}/v1" 21 | client = OpenAI(api_key=api_key, base_url=base_url) 22 | return client 23 | 24 | 25 | class Dialog: 26 | def __init__(self, log_writer=None): 27 | self.history = [] 28 | self.log_writer = log_writer 29 | self.images_save_dir = None 30 | self.log_dir = None 31 | if log_writer is not None: 32 | self.log_dir = os.path.dirname(self.log_writer.file_name) 33 | self.log_count = 0 34 | self.images_save_dir = os.path.dirname(self.log_writer.file_name) 35 | os.makedirs(os.path.join(self.images_save_dir, "images"), exist_ok=True) 36 | 37 | def add_system_message(self, **kwargs): 38 | self.history.append({"role": "system", **kwargs}) 39 | self.log_to_file( 40 | "system", convert_content_to_str(kwargs["content"], self.images_save_dir) 41 | ) 42 | 43 | def add_user_message(self, **kwargs): 44 | self.history.append({"role": "user", **kwargs}) 45 | self.log_to_file( 46 | "user", convert_content_to_str(kwargs["content"], self.images_save_dir) 47 | ) 48 | 49 | def add_assistant_message(self, **kwargs): 50 | self.history.append({"role": "assistant", **kwargs}) 51 | self.log_to_file( 52 | "assistant", 53 | convert_content_to_str(kwargs["content"], self.images_save_dir), 54 | ) 55 | 56 | def add_inner_thoughts(self, **kwargs): 57 | if kwargs["content"] is not None: 58 | self.log_to_file( 59 | "assistant (inner thoughts)", 60 | convert_content_to_str(kwargs["content"], self.images_save_dir), 61 | bold_italics_code="bi", 62 | ) 63 | 64 | def log_to_file(self, role, content_str, bold_italics_code="b"): 65 | if self.log_writer is not None: 66 | self.log_writer.new_paragraph( 67 | f"{role.upper()}", bold_italics_code=bold_italics_code 68 | ) 69 | self.log_writer.new_paragraph("\n" + content_str + "\n") 70 | self.log_writer.create_md_file() 71 | 72 | def log_token_usage( 73 | self, prompt_tokens: int, completion_tokens: int, total_cost: int = None 74 | ): 75 | if self.log_writer is not None: 76 | self.log_writer.write("\n") 77 | self.log_writer.write( 78 | "====> Token usage: prompt tokens: {}, completion_tokens: {}".format( 79 | prompt_tokens, completion_tokens 80 | ) 81 | ) 82 | if total_cost is not None: 83 | self.log_writer.write("\n") 84 | self.log_writer.write(f"====> Total cost: ${total_cost:.3f}") 85 | self.log_writer.write("\n\n") 86 | self.log_writer.create_md_file() 87 | 88 | def log_response_time(self, time_taken: float): 89 | if self.log_writer is not None: 90 | self.log_writer.write("\n") 91 | self.log_writer.write(f"====> Response time (sec): {time_taken:.3f}") 92 | self.log_writer.write("\n\n") 93 | self.log_writer.create_md_file() 94 | 95 | @property 96 | def dialog(self): 97 | return copy.deepcopy(self.history) 98 | 99 | def delete_last_message(self): 100 | del self.history[-1] 101 | 102 | def clear_history(self): 103 | self.history = [] 104 | 105 | def clone(self): 106 | dialog_clone = Dialog(self.log_writer) 107 | dialog_clone.history = copy.deepcopy(self.history) 108 | return dialog_clone 109 | 110 | def write_dialog(self): 111 | if self.log_writer is not None: 112 | save_path = os.path.join(self.log_dir, 
f"dialog_{self.log_count:05d}.json") 113 | with open(save_path, "w") as fp: 114 | json.dump(self.history, fp) 115 | save_path = os.path.join(self.log_dir, f"dialog_{self.log_count:05d}.txt") 116 | with open(save_path, "w") as fp: 117 | for h in self.history: 118 | content_str = convert_content_to_str( 119 | h["content"], save_dir=None, ignore_images=True 120 | ) 121 | role_str = h["role"].upper() 122 | fp.write(f"{role_str}: {content_str}\n\n") 123 | self.log_count += 1 124 | else: 125 | print( 126 | "WARNING: Dialog object does not have a log_writer, so write_dialog() failed..." 127 | ) 128 | 129 | 130 | # Function to get response from model 131 | def get_model_response( 132 | client: Any, 133 | dialog: list[Any], 134 | model_name: str = "gpt-4-vision-preview", 135 | temperature: float = 0.5, 136 | max_tokens: int = 1000, 137 | frequency_penalty: float = 0.0, 138 | max_retries: int = 10, 139 | sleep_secs: int = 60, 140 | verbose: bool = True, 141 | **kwargs, 142 | ): 143 | num_retries = 0 144 | start_time = time.time() 145 | excptn_info = {} 146 | while num_retries < max_retries: 147 | try: 148 | response = client.chat.completions.create( 149 | model=model_name, 150 | messages=dialog, 151 | temperature=temperature, 152 | max_tokens=max_tokens, 153 | frequency_penalty=frequency_penalty, 154 | **kwargs, 155 | ) 156 | response_txt = response.choices[0].message.content 157 | token_counts = { 158 | "completion_tokens": response.usage.completion_tokens, 159 | "prompt_tokens": response.usage.prompt_tokens, 160 | } 161 | break 162 | except Exception as excptn: 163 | num_retries += 1 164 | if num_retries >= max_retries: 165 | if verbose: 166 | print(excptn) 167 | excptn_info = { 168 | "exception_message": excptn.message, 169 | "exception_type": excptn.type, 170 | "exception_code": excptn.code, 171 | } 172 | time.sleep(sleep_secs) 173 | 174 | if num_retries >= max_retries: 175 | if verbose: 176 | print(f"===> Failed after {max_retries} retries") 177 | response_txt = None 178 | token_counts = {} 179 | 180 | time_taken = time.time() - start_time 181 | output = { 182 | "text": response_txt, 183 | **token_counts, 184 | "response_time": time_taken, 185 | **excptn_info, 186 | } 187 | return output 188 | -------------------------------------------------------------------------------- /space/utils/visualizations.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def add_goal_to_obs(obs_img: np.ndarray, goal_img: np.ndarray) -> np.ndarray: 8 | H, W, _ = obs_img.shape 9 | 10 | goal_img = np.copy(goal_img) 11 | goal_img = cv2.resize(goal_img, (W // 6, H // 6)) 12 | 13 | font = cv2.FONT_HERSHEY_SIMPLEX 14 | 15 | # org 16 | org = (5, 25) 17 | 18 | # fontScale 19 | fontScale = 0.7 20 | 21 | color = (0, 0, 0) 22 | 23 | # Line thickness of 2 px 24 | thickness = 1 25 | 26 | # Using cv2.putText() method 27 | goal_img = cv2.putText( 28 | goal_img, "Goal", org, font, fontScale, color, thickness, cv2.LINE_AA 29 | ) 30 | 31 | # Draw border 32 | goal_img = cv2.rectangle( 33 | goal_img, (0, 0), (goal_img.shape[1] - 1, goal_img.shape[0] - 1), color, 2 34 | ) 35 | 36 | obs_img = np.copy(obs_img) 37 | obs_img[5 : 5 + goal_img.shape[0], W - 5 - goal_img.shape[1] : W - 5] = goal_img 38 | 39 | return obs_img 40 | 41 | 42 | def add_text_to_image( 43 | img: np.ndarray, 44 | text: str, 45 | origin: list[int] = None, 46 | add_background: bool = False, 47 | font_scale: int = 1, 48 | color=(255, 255, 255), 49 | thickness=2, 50 | ): 51 | img = np.copy(img) 52 | 53 | font = cv2.FONT_HERSHEY_SIMPLEX 54 | # Using cv2.putText() method 55 | textsize = cv2.getTextSize(text, font, font_scale, thickness)[0] 56 | if origin is None: 57 | org = ((img.shape[1] - textsize[0]) // 2, textsize[1] + 20) 58 | else: 59 | org = tuple(origin) 60 | # Add black background if needed 61 | if add_background: 62 | start_x = org[0] - 5 63 | end_x = start_x + textsize[0] + 10 64 | start_y = org[1] - textsize[1] - 5 65 | end_y = start_y + textsize[1] + 10 66 | img_crop = img[start_y:end_y, start_x:end_x] 67 | img_blend = cv2.addWeighted(img_crop, 0.3, np.zeros_like(img_crop), 0.7, 1.0) 68 | img[start_y:end_y, start_x:end_x] = img_blend 69 | 70 | img = cv2.putText(img, text, org, font, font_scale, color, thickness, cv2.LINE_AA) 71 | 72 | return img 73 | -------------------------------------------------------------------------------- /space/utils/vllm_api.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
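# vLLM helpers: poll a local OpenAI-compatible endpoint until the requested model is
# listed, and launch a `vllm serve` subprocess with options taken from the vllm_cfg dict.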
3 | from typing import Any 4 | 5 | import subprocess as sp 6 | import time 7 | import torch 8 | import openai 9 | from openai import OpenAI 10 | 11 | 12 | def is_server_ready(model_name: str, host_port: str, vllm_cfg: dict[str, Any]): 13 | try: 14 | client = OpenAI(base_url=f"http://localhost:{host_port}/v1", api_key="EMPTY") 15 | models_available = [i.id for i in client.models.list().data] 16 | if model_name in models_available: 17 | return True 18 | else: 19 | return False 20 | except openai.APIConnectionError: 21 | return False 22 | 23 | 24 | def start_vllm_server(model_name: str, host_port: str, vllm_cfg: dict[str, Any]): 25 | # Check if server is already available 26 | if is_server_ready(model_name, host_port, vllm_cfg): 27 | return 28 | 29 | # Host server 30 | if "tensor_parallel_size" in vllm_cfg and vllm_cfg["tensor_parallel_size"] is None: 31 | vllm_cfg["tensor_parallel_size"] = torch.cuda.device_count() 32 | command = ["vllm", "serve", model_name, "--port", host_port] 33 | for k, v in vllm_cfg.items(): 34 | if k in ["enable_prefix_caching", "trust_remote_code"] and v: 35 | command.append(f"--{k}") 36 | else: 37 | command.extend([f"--{k}", f"{v}"]) 38 | sp.Popen(command, stderr=sp.DEVNULL, stdout=sp.DEVNULL) 39 | 40 | # Wait till vllm server is ready 41 | while not is_server_ready(model_name, host_port, vllm_cfg): 42 | print( 43 | f"Model {model_name} not yet available on vLLM. Waiting for server to be up..." 44 | ) 45 | time.sleep(5) 46 | --------------------------------------------------------------------------------
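A minimal sketch of the local-model path through these utilities (the model id, port, and `vllm_cfg` values are illustrative and not taken from the repo's configs; assumes a GPU host with vLLM installed):

```python
# Hypothetical example: serve an open-weights model with vLLM, then query it through the
# OpenAI-compatible client and the shared get_model_response() dispatcher in model.py.
from space.utils.model import get_model_response
from space.utils.openai_api import setup_llm_client
from space.utils.vllm_api import start_vllm_server

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative model id
start_vllm_server(model_name, host_port="8000", vllm_cfg={"tensor_parallel_size": None})

client = setup_llm_client(model_name, host_port="8000")  # non "gpt-" names use the local server
dialog = [{"role": "user", "content": "List three landmarks you would use to navigate a kitchen."}]
output = get_model_response(client, dialog, model_name, max_tokens=200)
print(output["text"])
```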