├── .coveragerc ├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── EXTENDING.md ├── LICENSE ├── README.md ├── assets └── equivalence.png ├── dataset-config.zip ├── dataset_generator.py ├── evaluate.py ├── experiment_cfgs ├── gemma-1.1-7b-it-finetune.yaml └── gemma-1.1-7b-it-zero-shot.yaml ├── finetune.py ├── llm_planner.py ├── planetarium ├── ASSUMPTIONS.md ├── __init__.py ├── builder.py ├── domains │ ├── blocksworld.pddl │ ├── floor-tile.pddl │ ├── gripper.pddl │ ├── rover-single.pddl │ └── rover.pddl ├── downward.py ├── evaluate.py ├── graph.py ├── metric.py ├── oracle.py ├── oracles │ ├── __init__.py │ ├── blocksworld.py │ ├── floortile.py │ ├── gripper.py │ ├── oracle.py │ └── rover_single.py └── reduced_graph.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── problem_fixtures.py ├── test_evaluate.py ├── test_graph.py ├── test_metric.py ├── test_oracle.py ├── test_pddl.py └── test_planner.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | planetarium/oracles/oracle.py -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: checkout repository 10 | uses: actions/checkout@v3 11 | 12 | - name: set up python 13 | id: setup-python 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: "3.10" 17 | 18 | - name: Setup apptainer 19 | uses: eWaterCycle/setup-apptainer@v2.0.0 20 | 21 | - name: install poetry 22 | uses: snok/install-poetry@v1 23 | with: 24 | virtualenvs-create: true 25 | virtualenvs-in-project: true 26 | installer-parallel: true 27 | 28 | - name: load cached venv 29 | id: cached-poetry-dependencies 30 | uses: actions/cache@v3 31 | with: 32 | path: .venv 33 | key: venv-${{ runner.os }}-${{ 
steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 34 | 35 | - name: load downward 36 | id: cached-downward 37 | uses: actions/cache@v3 38 | with: 39 | path: tmp/ 40 | key: downward-${{ runner.os }}-${{ hashFiles('**/VAL.zip') }}-${{ hashFiles('fast-downward.sif') }} 41 | 42 | - name: install dependencies 43 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 44 | run: poetry install --no-interaction --no-root 45 | 46 | - name: install downward 47 | if: steps.cached-downward.outputs.cache-hit != 'true' 48 | run: | 49 | mkdir tmp 50 | apptainer pull fast-downward.sif docker://aibasel/downward:latest 51 | mv fast-downward.sif tmp/ 52 | curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip 53 | unzip tmp/VAL.zip -d tmp/ 54 | tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1 55 | 56 | - name: install project 57 | run: poetry install --no-interaction 58 | 59 | - name: lint 60 | run: 61 | source .venv/bin/activate 62 | poetry run black --check planetarium/. tests/. 63 | poetry ruff planetarium/. 64 | poetry mypy planetarium/. 65 | 66 | - name: test 67 | run: | 68 | source .venv/bin/activate 69 | export DOWNWARD=$(pwd)/tmp/fast-downward.sif 70 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd)/tmp/bin 71 | export PATH=$PATH:$(pwd)/tmp/bin 72 | poetry run pytest --cov-fail-under=90 --cov=planetarium --timeout=120 tests/. 
73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | venv 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | test_pddl_generator.py 166 | -------------------------------------------------------------------------------- /EXTENDING.md: -------------------------------------------------------------------------------- 1 | # Extending Planetarium 2 | 3 | If you're looking to evaluate your own domain, here is a guide to help you get started. 4 | 5 | ### 1. Add a domain file 6 | Add a domain PDDL file to the `planetarium/domains/` directory, where the filename is the name of your domain. 7 | 8 | ### 2. Add an Oracle 9 | 10 | Every domain in Planetarium requires an [`Oracle`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L8) object and file tied to it. 
11 | There are three fundamental components to the Oracle: 12 | 13 | - [`.reduce()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L11) function, which takes a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) or [`SceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L391) object and returns a [`ReducedProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L80) or [`ReducedSceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L48) object, respectively. 14 | - [`.inflate()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L65) function, which takes a [`ReducedProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L80) or [`ReducedSceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L48) object and returns a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) or [`SceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L391) object, respectively. 
15 | - [`.fully_specify()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L125) function, which takes a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) and returns either a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) or a [`ReducedProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L80) with all possible predicates added to the goal scene without changing the original definition of the problem. 16 | We refer to these predicates as "trivial predicates" in our paper. 17 | The `fully_specify` function is used to ensure that the problem is fully specified before evaluation. 18 | 19 | To add your domain, you must create a python script under `planetarium/oracles/` that contains your implementation of an Oracle subclass. 20 | This script should contain a class that inherits from `Oracle` and implements the three functions described above. 21 | While we provide a generic `reduce` and `inflate` function in the base `Oracle` class, you will definitely want to override these functions if your domain has any 0-ary, 1-ary, or (3+)-ary predicates (more info below). 22 | 23 | #### 2.1 Remember to add your Oracle script to the `__init__.py` file in the `planetarium/oracles/` directory: 24 | ```python 25 | # planetarium/oracles/__init__.py 26 | __all__ = ["blocksworld", "gripper", "rover_single", "floortile", "YOUR_DOMAIN"] 27 | 28 | from . import oracle 29 | 30 | from . import blocksworld 31 | from . import gripper 32 | from . import rover_single 33 | from . import floortile 34 | from . 
import YOUR_DOMAIN 35 | 36 | ORACLES: dict[str, oracle.Oracle] = { 37 | "blocksworld": blocksworld.BlocksworldOracle, 38 | "gripper": gripper.GripperOracle, 39 | "rover-single": rover_single.RoverSingleOracle, 40 | "floor-tile": floortile.FloorTileOracle, 41 | "YOUR_DOMAIN": YOUR_DOMAIN.YourDomainOracle, 42 | } 43 | ``` 44 | 45 | #### 2.2 Working with non-binary predicates (Using `ReducedNode`s) 46 | Some predicates require special care for reducing and inflating. 47 | The key idea behind our reduced representation is to represent our PDDL problem in a domain-specific manner with as few graph nodes and edges as possible to reduce the search space for our equivalence check. 48 | The reduced representation also allows us to perform higher-level graph manipulations and searches more efficiently. 49 | 50 | **[`ReducedNode`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L9)**: 51 | 52 | A [`ReducedNode`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L9) is a domain-specific set of nodes that will be added to every `ReducedSceneGraph` or `ReducedProblemGraph` object on construction. They help hold metadata for specific types of predicates (non-binary predicates, 0-ary predicates, etc.) that are defined by the domain. 53 | 54 | Here is an example on how to reduce different -ary predicates using `ReducedNode`s: 55 | 56 | **0-ary predicates**: 57 | These predicates can be represented by using a `ReducedNode` with a `name` attribute that matches the predicate name. 58 | If the predicate is true in the scene, one way to handle this is by adding a self-edge on this `ReducedNode` to represent the predicate. 59 | 60 | **1-ary predicates**: 61 | These predicates can be represented by using a `ReducedNode` with a `name` attribute that matches the predicate name. 
62 | If the predicate is true in the scene, we can add an edge from the `ReducedNode` to the node that represents the object in the scene. 63 | 64 | **3+-ary predicates**: 65 | There is no easy way to reduce these predicates, so the best way to keep track of these is to simply add a predicate node to the `ReducedSceneGraph` or `ReducedProblemGraph` object that represents the predicate, and add edges to the nodes that represent the objects in the scene. 66 | Make sure to set the `position=` argument when adding the edge to the reduced graph to ensure you can reverse this action in your `inflate` function. 67 | 68 | **To register your `ReducedNode`**: 69 | At the top of your oracle script, you can call the [`ReducedNode.register`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L12) method, like the following: 70 | 71 | ```python 72 | ReducedNode.register( 73 | { 74 | # ReducedNodeName: corresponding predicate name 75 | "ROOMS": "room", 76 | "BALLS": "ball", 77 | "GRIPPERS": "gripper", 78 | "ROBBY": "at-robby", 79 | "FREE": "free", 80 | }, 81 | "YOUR_DOMAIN", # name of your domain 82 | ) 83 | ``` 84 | 85 | This will let you use your `ReducedNode` like any other enum throughout your oracle script (e.g. `ReducedNode.ROOMS`). 86 | 87 | #### 2.3 Implementing `.plan()` (Optional) 88 | If you would like to evaluate whether or not a problem is _solvable_, you can implement the [`.plan()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L141) function in your `Oracle`. 89 | You will still be able to evaluate whether or not a problem is solvable without implementing this function, but it will rely on running the FastDownward planner to solve your problem, which may be *significantly* slower than using a domain-specific planner. 
90 | (Note that you should try using the `lama-first` alias if possible, as this planner does not look for the optimal plan, just a satisficing plan.) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024 Brown University 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # planetarium🪐 2 | 3 |

4 | 5 | 6 | 7 | 8 |

9 | 10 | Planetarium🪐 is a [dataset](https://huggingface.co/datasets/BatsResearch/planetarium) and benchmark for assessing LLMs in translating natural language descriptions of planning problems into PDDL. We developed a robust method for comparing PDDL problem descriptions using graph isomorphism. 11 | 12 | ## Installation 13 | To install the `planetarium` package, you can use the following command: 14 | ```bash 15 | pip install git+https://github.com/BatsResearch/planetarium.git 16 | ``` 17 | 18 | For development or using our evaluate & finetune scripts, you can clone the repository and install all dependencies using the following commands: 19 | ```bash 20 | git clone https://github.com/BatsResearch/planetarium.git 21 | cd planetarium 22 | poetry install --with all 23 | ``` 24 | 25 | To use `planetarium.downward`, you will need to have the [Fast-Downward](https://www.fast-downward.org/) planner installed, and the [VAL](https://github.com/KCL-Planning/VAL) plan validator. The following commands are one way to install them with minimal overhead: 26 | ```bash 27 | # Fast-Downward via Apptainer 28 | apptainer pull fast-downward.sif docker://aibasel/downward:latest 29 | # VAL download link might not work, follow instructions to download binary at: https://github.com/KCL-Planning/VAL 30 | mkdir tmp 31 | curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip 32 | unzip tmp/VAL.zip -d tmp/ 33 | tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1 34 | # clean up 35 | rm -rf tmp 36 | # Make sure to add fast-downward.sif and VAL to your PATH or make aliases. 37 | ``` 38 | 39 | ## Basic Usage 40 | To evaluate a PDDL problem description, we can use the `planetarium.evaluate` module: 41 | ```python 42 | import planetarium 43 | ... 
44 | planetarium.evaluate(gt_pddl_str, pred_pddl_str) 45 | ``` 46 | The supported domains are `blocksworld` and `gripper` domains. 47 | 48 | ## Dataset 49 | The main page for the dataset can be found [here](https://huggingface.co/datasets/BatsResearch/planetarium). 50 | 51 | Here is an example of how to load the dataset: 52 | ```python 53 | from datasets import load_dataset 54 | 55 | dataset = load_dataset("BatsResearch/planetarium") 56 | ``` 57 | Here, `dataset["test"]` is the main test set used in the paper. You may evaluate on this set to reproduce our results. 58 | 59 | You can reproduce the dataset, the splits, and a report by running the following command: 60 | ```bash 61 | python dataset_generator.py -c dataset_config.yaml 62 | ``` 63 | 64 | By modifying the `dataset_config.yaml` file, you can change the dataset splits, the number of samples, and produce even more examples! 65 | 66 | ### Dataset Report 67 | Here is a summary of the types of PDDL problems in the dataset: 68 | 69 | Total number of problems: $132,037$. 70 | 71 | #### Abstractness Split 72 | | Init | Goal | blocksworld | gripper | 73 | |:---:|:---:|---:|---:| 74 | | abstract | abstract | $23,144$ | $10,632$ | 75 | | abstract | explicit | $23,086$ | $9,518$ | 76 | | explicit | abstract | $23,087$ | $10,313$ | 77 | | explicit | explicit | $23,033$ | $9,224$ | 78 | #### Size Splits (Number of Propositions in Ground Truth) 79 | | Num. of Propositions | blocksworld | gripper | 80 | |:---:|---:|---:| 81 | | $0$ - $20$ | $1,012$ | $379$ | 82 | | $20$ - $40$ | $10,765$ | $2,112$ | 83 | | $40$ - $60$ | $50,793$ | $9,412$ | 84 | | $60$ - $80$ | $26,316$ | $25,346$ | 85 | | $80$ - inf | $3,464$ | $2,438$ | 86 | 87 | ## How it Works 88 | Planetarium🪐 compares two PDDL problem descriptions by first transcribing them into a graph representation. 89 | Graphs help us to better detect and manipulate relationships between certain objects and propositions. 
90 | Next, we build "fully specified" graph representations by adding "trivial" propositions (propositions that do not exist in the problem description but must exist in any state that satisfies such a description). 91 | Finally, we use graph isomorphism to compare the fully specified graph representations of the two PDDL problem descriptions, either comparing the entire problem graph or the individual initial and goal scene graphs. 92 | This lets us check the correctness of the translation of the natural language description into PDDL, without ever needing to run a planner. 93 | 94 | Below is a flowchart providing an overview of the equivalence algorithm: 95 | 96 | ![Equivalence Algorithm Overview](assets/equivalence.png) 97 |

(Left) Two planning problems, in PDDL problem description, real-world scenario, and graph representations. (Center) Fully specified graph representation. (Right) Graph isomorphism.

98 | 99 | The key to this algorithm working is building a specially crafted "fully specify" function, which we build for each domain that we want to support. We provide implementations for the `blocksworld` and `gripper` domains in the `planetarium.oracle` module. 100 | 101 | ## Adding a new domain 102 | For adding a new domain, please take a look at [this guide](EXTENDING.md). -------------------------------------------------------------------------------- /assets/equivalence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BatsResearch/planetarium/f257ae9e68b813aeace00cc683d0c2ad5b57a157/assets/equivalence.png -------------------------------------------------------------------------------- /dataset-config.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BatsResearch/planetarium/f257ae9e68b813aeace00cc683d0c2ad5b57a157/dataset-config.zip -------------------------------------------------------------------------------- /experiment_cfgs/gemma-1.1-7b-it-finetune.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | database_path: dataset.db 3 | splits_path: splits.yaml 4 | splits: 5 | train: 6 | - [blocksworld, random, "0"] 7 | - [gripper, random, "0"] 8 | - [blocksworld, random, "1"] 9 | - [gripper, random, "1"] 10 | - [blocksworld, random, "2"] 11 | - [gripper, random, "2"] 12 | - [blocksworld, random, "3"] 13 | - [gripper, random, "3"] 14 | test: 15 | - [blocksworld, random, "4"] 16 | - [gripper, random, "4"] 17 | prompts: 18 | problem: "Provide me with the complete, valid problem PDDL file that \ 19 | describes the following planning problem directly without further \ 20 | explanations or texts." 
21 | domain: "The domain for the planning problem is:" 22 | train: 23 | lora_config: 24 | r: 16 25 | lora_alpha: 32 26 | lora_dropout: 0.05 27 | bias: none 28 | task_type: CAUSAL_LM 29 | bnb_config: 30 | bits: 4 31 | double_quant: True 32 | quant_type: nf4 33 | training_args: 34 | output_dir: /oscar/scratch/mzuo6/gemma-logs/ 35 | evaluation_strategy: epoch 36 | learning_rate: .00002 37 | num_train_epochs: 1 38 | per_device_train_batch_size: 1 39 | per_device_eval_batch_size: 1 40 | weight_decay: 0.01 41 | report_to: wandb 42 | bf16: True 43 | bf16_full_eval: True 44 | save_path: gemma-1.1-7b-it-random/ 45 | model: 46 | type: hf 47 | tokenizer_name: google/gemma-1.1-7b-it 48 | model_name: google/gemma-1.1-7b-it 49 | response_template: "model\n" 50 | max_seq_length: 1500 -------------------------------------------------------------------------------- /experiment_cfgs/gemma-1.1-7b-it-zero-shot.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | database_path: dataset.db 3 | splits_path: splits.yaml 4 | splits: 5 | test: 6 | - [blocksworld, random, "4"] 7 | - [gripper, random, "4"] 8 | prompts: 9 | problem: "Provide me with the complete, valid problem PDDL file that \ 10 | describes the following planning problem directly without further \ 11 | explanations or texts." 
12 | domain: "The domain for the planning problem is:" 13 | evaluate: 14 | splits: 15 | - test 16 | batch_size: 64 17 | model: 18 | type: hf 19 | tokenizer_name: google/gemma-1.1-7b-it 20 | model_name: google/gemma-1.1-7b-it 21 | response_template: "model\n" 22 | kwargs: 23 | attn_implementation: flash_attention_2 -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from functools import partial 3 | import os 4 | import sqlite3 5 | import yaml 6 | 7 | import dotenv 8 | 9 | dotenv.load_dotenv() 10 | 11 | import torch 12 | from torch import nn 13 | 14 | import bitsandbytes as bnb 15 | from datasets import Dataset 16 | from peft import LoraConfig, get_peft_model 17 | from transformers import ( 18 | AutoTokenizer, 19 | AutoModelForCausalLM, 20 | BitsAndBytesConfig, 21 | TrainingArguments, 22 | PreTrainedTokenizer, 23 | PreTrainedModel, 24 | ) 25 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 26 | import tqdm as tqdm 27 | 28 | import llm_planner as llmp 29 | 30 | from accelerate import Accelerator 31 | 32 | 33 | HF_USER_TOKEN = os.getenv("HF_USER_TOKEN") 34 | 35 | 36 | def load_dataset(config: dict) -> dict[str, Dataset]: 37 | """Load the dataset from the configuration. 38 | 39 | Args: 40 | config (dict): The dataset configuration. 41 | 42 | Returns: 43 | dict[str, Dataset]: The loaded dataset. 
44 | """ 45 | with open(config["splits_path"], "r") as f: 46 | split_ids_cfg = yaml.safe_load(f) 47 | 48 | splits: set[str] = config.get("splits", {}).keys() 49 | dataset = {split: defaultdict(list) for split in splits} 50 | 51 | # Connect to database 52 | conn = sqlite3.connect(config["database_path"]) 53 | c = conn.cursor() 54 | 55 | # load domains 56 | domains = {} 57 | c.execute("SELECT name, domain_pddl FROM domains") 58 | for domain_name, domain_pddl in c.fetchall(): 59 | domains[domain_name] = domain_pddl 60 | 61 | # load problems 62 | for split in splits: 63 | queries = [] 64 | split_keys: list[str] = config["splits"][split] 65 | for split_key in split_keys: 66 | split_ids = split_ids_cfg 67 | for key in split_key: 68 | split_ids = split_ids[key] 69 | 70 | c.execute( 71 | f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})", 72 | split_ids, 73 | ) 74 | queries.extend(c.fetchall()) 75 | 76 | for domain, problem_pddl, natural_language in queries: 77 | dataset[split]["domain"].append(domains[domain]) 78 | dataset[split]["problem"].append(problem_pddl) 79 | dataset[split]["natural_language"].append(natural_language) 80 | 81 | return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()} 82 | 83 | 84 | def find_all_linear_names( 85 | model: nn.Module, 86 | bits: int | None = None, 87 | ) -> list[str]: 88 | """Find names of all linear layers in the model. 89 | 90 | Args: 91 | model (nn.Module): The model to search for linear layers. 
92 | 93 | Returns: 94 | list[str]: The names of all linear layers in the model (excluding LM Head) 95 | """ 96 | match bits: 97 | case 4: 98 | Linear = bnb.nn.Linear4bit 99 | case 8: 100 | Linear = bnb.nn.Linear8bitLt 101 | case _: 102 | Linear = torch.nn.Linear 103 | 104 | lora_module_names = set() 105 | for name, module in model.named_modules(): 106 | if isinstance(module, Linear): 107 | names = name.split(".") 108 | lora_module_names.add(names[-1]) 109 | 110 | if "lm_head" in lora_module_names: # needed for 16-bit 111 | lora_module_names.remove("lm_head") 112 | return list(lora_module_names) 113 | 114 | 115 | def strip(text: str, bos_token: str, eos_token: str) -> str: 116 | return text.removeprefix(bos_token) + eos_token 117 | 118 | 119 | def preprocess( 120 | tokenizer: PreTrainedTokenizer, 121 | examples, 122 | domain_prompt: str = "", 123 | problem_prompt: str = "", 124 | ) -> list[str]: 125 | """Preprocess the examples for training. 126 | 127 | Args: 128 | tokenizer (PreTrainedTokenizer): The tokenizer to use. 129 | examples: The examples to preprocess. 130 | domain_prompt (str, optional): How to prompt the domain. Defaults to "". 131 | problem_prompt (str, optional): How to prompt the problem. Defaults to "". 132 | 133 | Returns: 134 | list[str]: The preprocessed examples. 135 | """ 136 | inputs = [ 137 | strip( 138 | tokenizer.apply_chat_template( 139 | llmp.PlanningProblem(nl, d, p).apply_template( 140 | domain_prompt, 141 | problem_prompt, 142 | ), 143 | tokenize=False, 144 | add_generation_prompt=False, 145 | ), 146 | bos_token=tokenizer.bos_token, 147 | eos_token=tokenizer.eos_token, 148 | ) 149 | for nl, d, p in zip( 150 | examples["natural_language"], 151 | examples["domain"], 152 | examples["problem"], 153 | ) 154 | ] 155 | return inputs 156 | 157 | 158 | def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]: 159 | """Load the model and tokenizer from the configuration. 
160 | 161 | Args: 162 | config (dict): The training config. 163 | 164 | Returns: 165 | tuple[PreTrainedTokenizer, PreTrainedModel]: The tokenizer and model. 166 | """ 167 | tokenizer = AutoTokenizer.from_pretrained( 168 | config["model"]["tokenizer_name"], 169 | token=HF_USER_TOKEN, 170 | ) 171 | if tokenizer.pad_token is None: 172 | tokenizer.pad_token = tokenizer.eos_token 173 | tokenizer.padding_side = "right" 174 | 175 | bnb_config_args: dict = config.get("bnb_config", {}) 176 | if bnb_config_args: 177 | bnb_config = BitsAndBytesConfig( 178 | load_in_4bit=bnb_config_args.get("bits", 16) == 4, 179 | load_in_8bit=bnb_config_args.get("bits", 16) == 8, 180 | bnb_4bit_use_double_quant=bnb_config_args.get("use_double_quant", False), 181 | bnb_4bit_quant_type=bnb_config_args.get("quant_type", "nf4"), 182 | bnb_4bit_compute_dtype=torch.bfloat16, 183 | ) 184 | else: 185 | bnb_config = None 186 | 187 | device_index = Accelerator().process_index 188 | device_map = {"": device_index} 189 | model = AutoModelForCausalLM.from_pretrained( 190 | config["model"]["model_name"], 191 | **config["model"].get("model_kwargs", {}), 192 | token=HF_USER_TOKEN, 193 | torch_dtype=torch.bfloat16, 194 | quantization_config=bnb_config, 195 | device_map=device_map, 196 | ) 197 | 198 | lora_config = LoraConfig( 199 | **config["lora_config"], 200 | target_modules=find_all_linear_names(model, bits=bnb_config_args.get("bits")), 201 | ) 202 | model = get_peft_model(model, lora_config) 203 | 204 | return tokenizer, model 205 | 206 | 207 | def extract_instruct_tokens(tokenizer: PreTrainedTokenizer) -> tuple[str, str]: 208 | """Extract the instruction tokens from the tokenizer. 209 | 210 | Args: 211 | tokenizer (PreTrainedTokenizer): The tokenizer to use. 212 | 213 | Returns: 214 | tuple[str, str]: The templates. 
215 | """ 216 | placeholder = tokenizer.unk_token 217 | 218 | chat_str = tokenizer.apply_chat_template( 219 | [ 220 | {"role": "user", "content": placeholder}, 221 | {"role": "assistant", "content": placeholder}, 222 | ], 223 | tokenize=False, 224 | ) 225 | 226 | if not tokenizer.chat_template: 227 | templates = chat_str.split(f" {placeholder} ") 228 | else: 229 | templates = chat_str.split(placeholder) 230 | templates = [t.replace(" ", "").strip() for t in templates] 231 | 232 | return templates[:2] 233 | 234 | 235 | def main(config_path: str): 236 | """Train a model on a dataset using a given configuration. 237 | 238 | Args: 239 | config_path (str): The path to the configuration file. 240 | """ 241 | # Load configuration 242 | with open(config_path) as f: 243 | config = yaml.safe_load(f) 244 | 245 | # Load dataset 246 | dataset = load_dataset(config["dataset"]) 247 | 248 | train_config: dict = config["train"] 249 | 250 | # Load model 251 | tokenizer, model = load_model(train_config) 252 | 253 | # Create data collator 254 | instr_template, resp_template = extract_instruct_tokens(tokenizer) 255 | data_collator = DataCollatorForCompletionOnlyLM( 256 | response_template=resp_template, 257 | instruction_template=instr_template, 258 | tokenizer=tokenizer, 259 | ) 260 | 261 | # Build training arguments 262 | args_config = train_config.get("training_args", {}) 263 | training_args = TrainingArguments(**args_config) 264 | 265 | # Create trainer 266 | trainer = SFTTrainer( 267 | model, 268 | args=training_args, 269 | train_dataset=dataset["train"], 270 | eval_dataset=dataset["test"], 271 | data_collator=data_collator, 272 | max_seq_length=train_config["model"].get("max_seq_length", 512), 273 | formatting_func=partial( 274 | preprocess, 275 | tokenizer, 276 | problem_prompt=config["dataset"]["prompts"]["problem"], 277 | domain_prompt=config["dataset"]["prompts"]["domain"], 278 | ), 279 | ) 280 | trainer.train() 281 | 282 | trainer.save_model(train_config.get("save_path", 
"ckpt")) 283 | 284 | 285 | if __name__ == "__main__": 286 | import argparse 287 | 288 | parser = argparse.ArgumentParser("Fine-tune a model on PDDL dataset.") 289 | parser.add_argument( 290 | "-c", 291 | "--config", 292 | type=str, 293 | default="config.yaml", 294 | required=True, 295 | help="Path to the configuration file.", 296 | ) 297 | 298 | args = parser.parse_args() 299 | 300 | main(args.config) 301 | -------------------------------------------------------------------------------- /llm_planner.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from openai import OpenAI 4 | import torch 5 | from transformers import ( 6 | AutoTokenizer, 7 | AutoModelForCausalLM, 8 | PreTrainedTokenizer, 9 | PreTrainedModel, 10 | ) 11 | 12 | from vllm import LLM, RequestOutput, SamplingParams 13 | from vllm.lora.request import LoRARequest 14 | 15 | 16 | class PlanningProblem: 17 | def __init__( 18 | self, 19 | natural_language: str, 20 | domain: str, 21 | problem: str, 22 | ): 23 | """Initializes a new PlanningProblem. 24 | 25 | Args: 26 | natural_language (str): The natural language description of the 27 | problem to be solved. 28 | domain (str): A string representation of the domain of the problem. 29 | problem (str): A string representation of the ground truth PDDL. 30 | """ 31 | self.natural_language = natural_language 32 | self.domain = domain 33 | self.problem = problem 34 | 35 | def apply_template( 36 | self, 37 | domain_prompt: str = "", 38 | problem_prompt: str = "", 39 | include_answer: bool = True, 40 | ) -> list[dict[str, str]]: 41 | """Apply problem template to the problem. 42 | 43 | Args: 44 | domain_prompt (str, optional): How to prompt the domain. Defaults to "". 45 | problem_prompt (str, optional): How to prompt the problem. Defaults to "". 46 | include_answer (bool, optional): Whether to include the answer. Defaults to True. 47 | 48 | Returns: 49 | list[dict[str, str]]: Problem prompt. 
50 | """ 51 | return [ 52 | { 53 | "role": "user", 54 | "content": f"{problem_prompt} {self.natural_language} " 55 | + f"{domain_prompt}\n{self.domain}\n", 56 | }, 57 | ] + ( 58 | [ 59 | { 60 | "role": "assistant", 61 | "content": " " + self.problem, 62 | }, 63 | ] 64 | if include_answer 65 | else [] 66 | ) 67 | 68 | 69 | class Planner(abc.ABC): 70 | @abc.abstractmethod 71 | def plan_chat( 72 | self, 73 | messages: list[list[dict[str, str]]], 74 | device=None, 75 | max_new_tokens: int = 8_000, 76 | **kwargs, 77 | ) -> list[str]: 78 | """Passes messages to a model for completion. 79 | 80 | Args: 81 | messages (list[list[dict[str, str]]]): A list of messages to be 82 | passed to the model. 83 | device (optional): The device to run the model on. Defaults to None. 84 | max_new_tokens (int): The maximum number of tokens to generate. 85 | Defaults to 8_000. 86 | 87 | Returns: 88 | list[str]: The message completion. 89 | """ 90 | pass 91 | 92 | 93 | class HFPlanner: 94 | """A class for planning using Huggingface transformers.""" 95 | 96 | def __init__( 97 | self, 98 | model_name: str | None = None, 99 | tokenizer_name: str | None = None, 100 | tokenizer: PreTrainedTokenizer | None = None, 101 | model: PreTrainedModel | None = None, 102 | **kwargs, 103 | ): 104 | """Initializes a new HFPlanner. 105 | 106 | Args: 107 | model_name (str): The name of the model to be used. 108 | tokenizer_name (str, optional): The name of the tokenizer to be used. 109 | Defaults to None, in which case the model_name is used. 110 | kwargs: Additional keyword arguments to be passed to the model. 
111 | """ 112 | if model is not None and tokenizer is not None: 113 | self.model = model 114 | self.tokenizer = tokenizer 115 | else: 116 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name) 117 | self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) 118 | 119 | def plan(self, prompt: str, device=None, **kwargs) -> str: 120 | """Passes the prompt to the model for completion, without applying a 121 | chat template. 122 | 123 | Args: 124 | prompt (str): The prompt to be passed to the model. 125 | device (optional): The device to run the model on. Defaults to None. 126 | 127 | Returns: 128 | str: The message completion. 129 | """ 130 | encoded = self.tokenizer.encode(prompt, return_tensors="pt") 131 | 132 | if device is not None: 133 | encoded = encoded.to(device) 134 | 135 | generate_config = { 136 | "max_new_tokens": 4000, 137 | "temperature": 0.0, 138 | "do_sample": False, 139 | } 140 | generate_config.update(kwargs) 141 | generated_ids = self.model.generate(encoded, **generate_config) 142 | 143 | decoded = self.tokenizer.decode(generated_ids) 144 | 145 | return decoded[0] 146 | 147 | def plan_chat( 148 | self, 149 | messages: list[dict[str, str]], 150 | device=None, 151 | **kwargs, 152 | ) -> list[str]: 153 | """Passes messages to the model for completion, applying a chat template. 154 | 155 | Args: 156 | messages (list[dict[str, str]]): A list of messages to be passed to 157 | the model. 158 | device (optional): The device to run the model on. Defaults to None. 159 | kwargs: Additional keyword arguments to be passed to the model. 160 | 161 | Returns: 162 | str: The message completion. 
163 | """ 164 | encoded = self.tokenizer.apply_chat_template( 165 | messages, 166 | return_tensors="pt", 167 | padding=True, 168 | add_generation_prompt=True, 169 | ) 170 | 171 | if device is not None: 172 | encoded = encoded.to(device) 173 | 174 | generate_config = { # default generate config 175 | "max_new_tokens": 4000, 176 | "do_sample": False, 177 | } 178 | generate_config.update(kwargs) 179 | with torch.no_grad(): 180 | generated_ids = self.model.generate(encoded, **generate_config) 181 | 182 | decoded = [] 183 | for g, e in zip(generated_ids, encoded): 184 | decoded.append( 185 | self.tokenizer.decode( 186 | g[len(e) :], 187 | skip_special_tokens=True, 188 | ) 189 | ) 190 | 191 | return decoded 192 | 193 | 194 | class VLLMPlanner(Planner): 195 | """A class for planning using VLLM models.""" 196 | 197 | def __init__(self, model_name: str, lora: str | None = None, **kwargs): 198 | """Initializes a new VLLMPlanner. 199 | 200 | Args: 201 | model_name (str): The name of the model to be used. 202 | kwargs: Additional keyword arguments to be passed to the model. 203 | """ 204 | self.lora = LoRARequest(lora, 1, lora) if lora else None 205 | self.model = LLM(model_name, enable_lora=bool(lora), **kwargs) 206 | self.tokenizer = self.model.get_tokenizer() 207 | 208 | def plan_chat( 209 | self, 210 | messages: list[list[dict[str, str]]], 211 | device=None, 212 | max_new_tokens: int = 8_000, 213 | **kwargs, 214 | ) -> list[str]: 215 | """Passes messages to the model for completion. 216 | 217 | Args: 218 | messages (list[dict[str, str]]): A list of messages to be passed to 219 | the model. 220 | 221 | Returns: 222 | list[str]: The message completions. 
223 | """ 224 | encoded = self.tokenizer.apply_chat_template( 225 | messages, 226 | add_generation_prompt=True, 227 | tokenize=False, 228 | ) 229 | params = SamplingParams( 230 | max_tokens=max_new_tokens, 231 | temperature=kwargs.get("temperature", 0.0), 232 | top_p=kwargs.get("top_p", 1.0), 233 | top_k=kwargs.get("top_k", -1), 234 | min_p=kwargs.get("min_p", 0.0), 235 | ) 236 | 237 | outputs: list[RequestOutput] = self.model.generate( 238 | encoded, 239 | params, 240 | use_tqdm=False, 241 | lora_request=self.lora, 242 | ) 243 | 244 | return [output.outputs[0].text for output in outputs] 245 | 246 | 247 | class OpenAIPlanner: 248 | """A class for planning using OpenAI models.""" 249 | 250 | def __init__(self, model_name: str, **kwargs): 251 | """Initializes a new OpenAIPlanner. 252 | 253 | Args: 254 | model_name (str): The name of the model to be used. 255 | kwargs: Additional keyword arguments to be passed to the OpenAI 256 | client. 257 | """ 258 | self.client = OpenAI(**kwargs) 259 | self.model_name = model_name 260 | self.is_o1 = model_name.startswith("o1") 261 | 262 | def _plan_chat( 263 | self, 264 | messages: list[dict[str, str]], 265 | max_new_tokens: int | None = None, 266 | device=None, 267 | **kwargs, 268 | ) -> str: 269 | """Passes messages to the model for completion. 270 | 271 | Args: 272 | messages (list[dict[str, str]]): A list of messages to be passed to 273 | the model. 274 | device (optional): The device to run the model on (ignored for OpenAI). 275 | 276 | Returns: 277 | str: The message completion. 
278 | """ 279 | 280 | if self.is_o1: 281 | return ( 282 | self.client.chat.completions.create( 283 | model=self.model_name, 284 | messages=messages, 285 | frequency_penalty=kwargs.get("frequency_penalty", None), 286 | max_completion_tokens=max_new_tokens, 287 | n=1, 288 | presence_penalty=kwargs.get("presence_penalty", None), 289 | temperature=kwargs.get("temperature", 0.0), 290 | top_p=kwargs.get("top_p", None), 291 | ) 292 | .choices[0] 293 | .message.content 294 | ) 295 | else: 296 | return ( 297 | self.client.chat.completions.create( 298 | model=self.model_name, 299 | messages=messages, 300 | frequency_penalty=kwargs.get("frequency_penalty", None), 301 | max_tokens=max_new_tokens, 302 | n=1, 303 | presence_penalty=kwargs.get("presence_penalty", None), 304 | temperature=kwargs.get("temperature", 0.0), 305 | top_p=kwargs.get("top_p", None), 306 | ) 307 | .choices[0] 308 | .message.content 309 | ) 310 | 311 | def plan_chat( 312 | self, 313 | messages: list[list[dict[str, str]]], 314 | max_new_tokens: int | None = None, 315 | device=None, 316 | **kwargs, 317 | ) -> list[str]: 318 | """Passes messages to the model for completion. 319 | 320 | Args: 321 | messages (list[list[dict[str, str]]]): A list of messages to be 322 | passed to the model. 323 | device (optional): The device to run the model on (ignored for OpenAI). 324 | 325 | Returns: 326 | list[str]: The message completions. 327 | """ 328 | return [ 329 | self._plan_chat( 330 | message, 331 | max_new_tokens=max_new_tokens, 332 | device=device, 333 | **kwargs, 334 | ) 335 | for message in messages 336 | ] 337 | -------------------------------------------------------------------------------- /planetarium/ASSUMPTIONS.md: -------------------------------------------------------------------------------- 1 | # Assumptions 2 | 3 | For each domain, the following assumptions are made regarding the types of problems evaluated. 4 | We also assume that problems are solvable via the domain's actions. 
5 | 6 | Our equivalence check is based on the assumption that the problems are solvable already (i.e. our evaluator checks solvability before equivalence). 7 | 8 | ## Blocksworld 9 | 10 | Generally, since Blocks World has reversible actions, essentially all problems are evaluable. 11 | 12 | `:init` conditions: 13 | - No two blocks are on top of a single block 14 | - No block is initialized on top of two blocks 15 | - No loops of blocks are made 16 | - Arm can only hold one block 17 | 18 | ## Grippers 19 | 20 | Generally, since Grippers has reversible actions, essentially all problems are evaluable. 21 | 22 | `:init` conditions: 23 | - No double "typing" of objects (we don't use `:typing`; we use certain immutable predicates instead) 24 | - All balls have only one location 25 | - All grippers have only one location 26 | - All grippers hold up to 1 ball 27 | 28 | ## Rover Single 29 | Rover has the capability of being a much more complex domain, but for the purposes of this benchmark, we work only with a single rover and a single lander. 30 | 31 | `:init` conditions: 32 | - No double `at_*` predicates (rover and lander can only be in one location at a time) 33 | 34 | ## Floortile 35 | 36 | Generally, all valid problems with reachable goal states are evaluable. 37 | 38 | `:init` conditions: 39 | - No robot has two colors (`robot-has`) 40 | -------------------------------------------------------------------------------- /planetarium/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import resources 3 | 4 | __all__ = [ 5 | "builder", 6 | "downward", 7 | "graph", 8 | "metric", 9 | "oracle", 10 | "evaluate", 11 | "DOMAINS", 12 | ] 13 | 14 | from . import builder 15 | from . import downward 16 | from . import graph 17 | from . import metric 18 | from . import oracle 19 | from . 
import domains 20 | 21 | DOMAINS = dict() 22 | 23 | # load domains 24 | for domain in resources.files(domains).iterdir(): 25 | with domain.open() as f: 26 | DOMAINS[os.path.basename(domain).split(".")[0]] = f.read() 27 | 28 | from .evaluate import evaluate 29 | -------------------------------------------------------------------------------- /planetarium/builder.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from collections.abc import Iterable 4 | 5 | from planetarium.graph import ProblemGraph 6 | 7 | from pddl.core import And, Problem 8 | from pddl.logic.predicates import Predicate 9 | from pddl.logic.terms import Constant 10 | from pddl.parser.problem import LenientProblemParser 11 | 12 | 13 | def _constant_to_dict(constant: Constant) -> dict[str, typing.Any]: 14 | """ 15 | Convert a PDDL Constant object to a dictionary representation. 16 | 17 | Parameters: 18 | constant (Constant): The PDDL Constant object. 19 | 20 | Returns: 21 | dict: A dictionary containing the constant name and typing information. 22 | """ 23 | return { 24 | "name": constant.name, 25 | "typing": set(constant.type_tags), 26 | } 27 | 28 | 29 | def _predicate_to_dict(predicate: Predicate) -> dict[str, typing.Any]: 30 | """ 31 | Convert a PDDL Predicate object to a dictionary representation. 32 | 33 | Parameters: 34 | predicate (Predicate): The PDDL Predicate object. 35 | 36 | Returns: 37 | dict: A dictionary containing the predicate name and its parameter names. 38 | """ 39 | return { 40 | "typing": predicate.name, 41 | "parameters": [constant.name for constant in predicate.terms], 42 | } 43 | 44 | 45 | def _build_constants(constants: Iterable[Constant]) -> list[dict[str, typing.Any]]: 46 | """ 47 | Build a list of dictionaries representing PDDL constants. 48 | 49 | Parameters: 50 | constants (Iterable[Constant]): An iterable of PDDL Constant objects. 
51 | 52 | Returns: 53 | list: A list of dictionaries containing constant information. 54 | """ 55 | return [_constant_to_dict(constant) for constant in constants] 56 | 57 | 58 | def _build_predicates( 59 | predicates: Iterable[Predicate], 60 | ) -> list[dict[str, typing.Any]]: 61 | """ 62 | Build a list of dictionaries representing PDDL predicates. 63 | 64 | Parameters: 65 | predicates (Iterable[Predicate]): An iterable of PDDL Predicate objects. 66 | 67 | Returns: 68 | list: A list of dictionaries containing predicate information. 69 | """ 70 | return [_predicate_to_dict(predicate) for predicate in predicates] 71 | 72 | 73 | def build(problem: str) -> ProblemGraph: 74 | """ 75 | Build scene graphs from a PDDL problem description. 76 | 77 | Parameters: 78 | problem (str): A string containing the PDDL problem description. 79 | 80 | Returns: 81 | tuple: Two SceneGraph instances representing the initial state and goal state. 82 | """ 83 | problem: Problem = LenientProblemParser()(problem) 84 | 85 | if isinstance(problem.goal, Predicate): 86 | goal = [problem.goal] 87 | elif isinstance(problem.goal, And): 88 | goal = problem.goal.operands 89 | else: 90 | raise ValueError(f"Unsupported goal type: {type(problem.goal)}") 91 | 92 | return ProblemGraph( 93 | _build_constants(problem.objects), 94 | _build_predicates(problem.init), 95 | _build_predicates(goal), 96 | domain=problem.domain_name, 97 | requirements=[req.name for req in problem.requirements], 98 | ) 99 | -------------------------------------------------------------------------------- /planetarium/domains/blocksworld.pddl: -------------------------------------------------------------------------------- 1 | ;; source: https://github.com/AI-Planning/pddl-generators/blob/main/blocksworld/domain.pddl 2 | ;; same as used in IPC 2023 3 | ;; 4 | (define (domain blocksworld) 5 | 6 | (:requirements :strips) 7 | 8 | (:predicates 9 | (clear ?x) 10 | (on-table ?x) 11 | (arm-empty) 12 | (holding ?x) 13 | (on ?x ?y) 14 | ) 15 | 
16 | (:action pickup 17 | :parameters (?ob) 18 | :precondition (and (clear ?ob) (on-table ?ob) (arm-empty)) 19 | :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob)) 20 | (not (arm-empty))) 21 | ) 22 | 23 | (:action putdown 24 | :parameters (?ob) 25 | :precondition (holding ?ob) 26 | :effect (and (clear ?ob) (arm-empty) (on-table ?ob) 27 | (not (holding ?ob))) 28 | ) 29 | 30 | (:action stack 31 | :parameters (?ob ?underob) 32 | :precondition (and (clear ?underob) (holding ?ob)) 33 | :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob) 34 | (not (clear ?underob)) (not (holding ?ob))) 35 | ) 36 | 37 | (:action unstack 38 | :parameters (?ob ?underob) 39 | :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty)) 40 | :effect (and (holding ?ob) (clear ?underob) 41 | (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty))) 42 | ) 43 | ) -------------------------------------------------------------------------------- /planetarium/domains/floor-tile.pddl: -------------------------------------------------------------------------------- 1 | ;; Modified from: https://github.com/AI-Planning/pddl-generators/blob/main/floortile/domain.pddl 2 | 3 | (define (domain floor-tile) 4 | (:requirements :typing) 5 | (:types 6 | robot tile color - object 7 | ) 8 | 9 | (:predicates 10 | (robot-at ?r - robot ?x - tile) 11 | (up ?x - tile ?y - tile) 12 | (right ?x - tile ?y - tile) 13 | (painted ?x - tile ?c - color) 14 | (robot-has ?r - robot ?c - color) 15 | (available-color ?c - color) 16 | ) 17 | 18 | (:action change-color 19 | :parameters (?r - robot ?c - color ?c2 - color) 20 | :precondition (and (robot-has ?r ?c) (available-color ?c2)) 21 | :effect (and (not (robot-has ?r ?c)) (robot-has ?r ?c2) 22 | ) 23 | ) 24 | 25 | (:action paint-up 26 | :parameters (?r - robot ?y - tile ?x - tile ?c - color) 27 | :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (up ?y ?x)) 28 | :effect (painted ?y ?c) 29 | ) 30 | 31 | (:action paint-down 32 | :parameters (?r - robot ?y 
- tile ?x - tile ?c - color) 33 | :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (up ?x ?y)) 34 | :effect (and (painted ?y ?c) 35 | ) 36 | ) 37 | 38 | (:action paint-right 39 | :parameters (?r - robot ?y - tile ?x - tile ?c - color) 40 | :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (right ?y ?x)) 41 | :effect (and (painted ?y ?c) 42 | ) 43 | ) 44 | 45 | (:action paint-left 46 | :parameters (?r - robot ?y - tile ?x - tile ?c - color) 47 | :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (right ?x ?y)) 48 | :effect (and (painted ?y ?c) 49 | ) 50 | ) 51 | 52 | ; Robot movements 53 | (:action up 54 | :parameters (?r - robot ?x - tile ?y - tile) 55 | :precondition (and (robot-at ?r ?x) (up ?y ?x)) 56 | :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x))) 57 | ) 58 | 59 | (:action down 60 | :parameters (?r - robot ?x - tile ?y - tile) 61 | :precondition (and (robot-at ?r ?x) (up ?x ?y)) 62 | :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x))) 63 | ) 64 | 65 | (:action right 66 | :parameters (?r - robot ?x - tile ?y - tile) 67 | :precondition (and (robot-at ?r ?x) (right ?y ?x)) 68 | :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x))) 69 | ) 70 | 71 | (:action left 72 | :parameters (?r - robot ?x - tile ?y - tile) 73 | :precondition (and (robot-at ?r ?x) (right ?x ?y)) 74 | :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x))) 75 | ) 76 | 77 | ) -------------------------------------------------------------------------------- /planetarium/domains/gripper.pddl: -------------------------------------------------------------------------------- 1 | ;; source: https://github.com/AI-Planning/pddl-generators/blob/main/gripper/domain.pddl 2 | (define (domain gripper) 3 | (:requirements :strips) 4 | (:predicates 5 | (room ?r) 6 | (ball ?b) 7 | (gripper ?g) 8 | (at-robby ?r) 9 | (at ?b ?r) 10 | (free ?g) 11 | (carry ?o ?g) 12 | ) 13 | 14 | (:action move 15 | :parameters (?from ?to) 16 | :precondition (and (room ?from) (room ?to) (at-robby ?from)) 17 | 
:effect (and (at-robby ?to) 18 | (not (at-robby ?from))) 19 | ) 20 | 21 | (:action pick 22 | :parameters (?obj ?room ?gripper) 23 | :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) 24 | (at ?obj ?room) (at-robby ?room) (free ?gripper)) 25 | :effect (and (carry ?obj ?gripper) 26 | (not (at ?obj ?room)) 27 | (not (free ?gripper))) 28 | ) 29 | 30 | (:action drop 31 | :parameters (?obj ?room ?gripper) 32 | :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) 33 | (carry ?obj ?gripper) (at-robby ?room)) 34 | :effect (and (at ?obj ?room) 35 | (free ?gripper) 36 | (not (carry ?obj ?gripper))) 37 | ) 38 | ) -------------------------------------------------------------------------------- /planetarium/domains/rover-single.pddl: -------------------------------------------------------------------------------- 1 | (define (domain rover-single) 2 | (:requirements :strips :typing) 3 | (:types 4 | waypoint camera mode objective 5 | ) 6 | 7 | (:predicates 8 | (at_rover ?y - waypoint) 9 | (at_lander ?y - waypoint) 10 | (can_traverse ?x - waypoint ?y - waypoint) 11 | (have_rock_analysis ?w - waypoint) 12 | (have_soil_analysis ?w - waypoint) 13 | (supports ?c - camera ?m - mode) 14 | (available) 15 | (visible ?w - waypoint ?p - waypoint) 16 | (have_image ?o - objective ?m - mode) 17 | (communicated_soil_data ?w - waypoint) 18 | (communicated_rock_data ?w - waypoint) 19 | (communicated_image_data ?o - objective ?m - mode) 20 | (at_rock_sample ?w - waypoint) 21 | (at_soil_sample ?w - waypoint) 22 | (visible_from ?o - objective ?w - waypoint) 23 | (channel_free) 24 | ) 25 | 26 | (:action navigate 27 | :parameters (?y - waypoint ?z - waypoint) 28 | :precondition (and (can_traverse ?y ?z) (available) (at_rover ?y) 29 | (visible ?y ?z)) 30 | :effect (and (not (at_rover ?y)) (at_rover ?z)) 31 | ) 32 | 33 | (:action sample_soil 34 | :parameters (?p - waypoint) 35 | :precondition (and (at_rover ?p) (at_soil_sample ?p)) 36 | :effect (and (have_soil_analysis ?p)) 37 | ) 
38 | 39 | (:action sample_rock 40 | :parameters (?p - waypoint) 41 | :precondition (and (at_rover ?p) (at_rock_sample ?p)) 42 | :effect (and (have_rock_analysis ?p)) 43 | ) 44 | 45 | (:action take_image 46 | :parameters (?p - waypoint ?o - objective ?i - camera ?m - mode) 47 | :precondition (and (supports ?i ?m) (visible_from ?o ?p) (at_rover ?p)) 48 | :effect (have_image ?o ?m) 49 | ) 50 | 51 | (:action communicate_soil_data 52 | :parameters (?p - waypoint ?x - waypoint ?y - waypoint) 53 | :precondition (and (at_rover ?x) 54 | (at_lander ?y)(have_soil_analysis ?p) 55 | (visible ?x ?y)(available)(channel_free)) 56 | :effect (and (not (available)) 57 | (not (channel_free))(channel_free) 58 | (communicated_soil_data ?p)(available)) 59 | ) 60 | 61 | (:action communicate_rock_data 62 | :parameters (?p - waypoint ?x - waypoint ?y - waypoint) 63 | :precondition (and (at_rover ?x) 64 | (at_lander ?y)(have_rock_analysis ?p) 65 | (visible ?x ?y)(available)(channel_free)) 66 | :effect (and (not (available)) 67 | (not (channel_free)) 68 | (channel_free)(communicated_rock_data ?p)(available)) 69 | ) 70 | 71 | (:action communicate_image_data 72 | :parameters (?o - objective ?m - mode ?x - waypoint ?y - waypoint) 73 | :precondition (and (at_rover ?x) 74 | (at_lander ?y)(have_image ?o ?m) 75 | (visible ?x ?y)(available)(channel_free)) 76 | :effect (and (not (available)) 77 | (not (channel_free))(channel_free) 78 | (communicated_image_data ?o ?m)(available)) 79 | ) 80 | ) -------------------------------------------------------------------------------- /planetarium/domains/rover.pddl: -------------------------------------------------------------------------------- 1 | ;; source: https://github.com/AI-Planning/pddl-generators/blob/main/rovers/domain.pddl 2 | ;; same as used in IPC 2023 3 | ;; 4 | (define (domain rover) 5 | (:requirements :strips :typing) 6 | (:types 7 | rover waypoint store camera mode lander objective 8 | ) 9 | 10 | (:predicates 11 | (at ?x - rover ?y - waypoint) 
12 | (at_lander ?x - lander ?y - waypoint) 13 | (can_traverse ?r - rover ?x - waypoint ?y - waypoint) 14 | (equipped_for_soil_analysis ?r - rover) 15 | (equipped_for_rock_analysis ?r - rover) 16 | (equipped_for_imaging ?r - rover) 17 | (empty ?s - store) 18 | (have_rock_analysis ?r - rover ?w - waypoint) 19 | (have_soil_analysis ?r - rover ?w - waypoint) 20 | (full ?s - store) 21 | (calibrated ?c - camera ?r - rover) 22 | (supports ?c - camera ?m - mode) 23 | (available ?r - rover) 24 | (visible ?w - waypoint ?p - waypoint) 25 | (have_image ?r - rover ?o - objective ?m - mode) 26 | (communicated_soil_data ?w - waypoint) 27 | (communicated_rock_data ?w - waypoint) 28 | (communicated_image_data ?o - objective ?m - mode) 29 | (at_soil_sample ?w - waypoint) 30 | (at_rock_sample ?w - waypoint) 31 | (visible_from ?o - objective ?w - waypoint) 32 | (store_of ?s - store ?r - rover) 33 | (calibration_target ?i - camera ?o - objective) 34 | (on_board ?i - camera ?r - rover) 35 | (channel_free ?l - lander) 36 | ) 37 | 38 | (:action navigate 39 | :parameters (?x - rover ?y - waypoint ?z - waypoint) 40 | :precondition (and (can_traverse ?x ?y ?z) (available ?x) (at ?x ?y) 41 | (visible ?y ?z)) 42 | :effect (and (not (at ?x ?y)) (at ?x ?z)) 43 | ) 44 | 45 | (:action sample_soil 46 | :parameters (?x - rover ?s - store ?p - waypoint) 47 | :precondition (and (at ?x ?p) (at_soil_sample ?p) 48 | (equipped_for_soil_analysis ?x) (store_of ?s ?x) (empty ?s)) 49 | :effect (and (not (empty ?s)) (full ?s) (have_soil_analysis ?x ?p) 50 | (not (at_soil_sample ?p))) 51 | ) 52 | 53 | (:action sample_rock 54 | :parameters (?x - rover ?s - store ?p - waypoint) 55 | :precondition (and (at ?x ?p) (at_rock_sample ?p) 56 | (equipped_for_rock_analysis ?x) (store_of ?s ?x)(empty ?s)) 57 | :effect (and (not (empty ?s)) (full ?s) (have_rock_analysis ?x ?p) 58 | (not (at_rock_sample ?p))) 59 | ) 60 | 61 | (:action drop 62 | :parameters (?x - rover ?y - store) 63 | :precondition (and (store_of ?y ?x) 
(full ?y)) 64 | :effect (and (not (full ?y)) (empty ?y)) 65 | ) 66 | 67 | (:action calibrate 68 | :parameters (?r - rover ?i - camera ?t - objective ?w - waypoint) 69 | :precondition (and (equipped_for_imaging ?r) (calibration_target ?i ?t) 70 | (at ?r ?w) (visible_from ?t ?w)(on_board ?i ?r)) 71 | :effect (calibrated ?i ?r) 72 | ) 73 | 74 | (:action take_image 75 | :parameters (?r - rover ?p - waypoint ?o - objective ?i - camera ?m - mode) 76 | :precondition (and (calibrated ?i ?r) (on_board ?i ?r) (equipped_for_imaging ?r) 77 | (supports ?i ?m) (visible_from ?o ?p) (at ?r ?p)) 78 | :effect (and (have_image ?r ?o ?m) 79 | (not (calibrated ?i ?r))) 80 | ) 81 | 82 | (:action communicate_soil_data 83 | :parameters (?r - rover ?l - lander ?p - waypoint ?x - waypoint ?y - waypoint) 84 | :precondition (and (at ?r ?x) 85 | (at_lander ?l ?y)(have_soil_analysis ?r ?p) 86 | (visible ?x ?y)(available ?r)(channel_free ?l)) 87 | :effect (and (not (available ?r)) 88 | (not (channel_free ?l))(channel_free ?l) 89 | (communicated_soil_data ?p)(available ?r)) 90 | ) 91 | 92 | (:action communicate_rock_data 93 | :parameters (?r - rover ?l - lander ?p - waypoint ?x - waypoint ?y - waypoint) 94 | :precondition (and (at ?r ?x) 95 | (at_lander ?l ?y)(have_rock_analysis ?r ?p) 96 | (visible ?x ?y)(available ?r)(channel_free ?l)) 97 | :effect (and (not (available ?r)) 98 | (not (channel_free ?l)) 99 | (channel_free ?l)(communicated_rock_data ?p)(available ?r)) 100 | ) 101 | 102 | (:action communicate_image_data 103 | :parameters (?r - rover ?l - lander ?o - objective ?m - mode ?x - waypoint ?y - waypoint) 104 | :precondition (and (at ?r ?x) 105 | (at_lander ?l ?y)(have_image ?r ?o ?m) 106 | (visible ?x ?y)(available ?r)(channel_free ?l)) 107 | :effect (and (not (available ?r)) 108 | (not (channel_free ?l))(channel_free ?l) 109 | (communicated_image_data ?o ?m)(available ?r)) 110 | ) 111 | ) -------------------------------------------------------------------------------- 
/planetarium/downward.py: -------------------------------------------------------------------------------- 1 | # FastDownward python wrapper 2 | 3 | import glob 4 | import os 5 | import re 6 | import subprocess 7 | import tempfile 8 | 9 | 10 | def _get_best_plan(plan_filepath: str) -> tuple[str | None, float]: 11 | """Get the best plan from a FastDownward plan file. 12 | 13 | Args: 14 | plan_filepath (str): The path to the plan file. 15 | 16 | Returns: 17 | The best plan and its cost. 18 | """ 19 | 20 | best_cost = float("inf") 21 | best_plan = None 22 | 23 | for plan_fp in glob.glob(f"{plan_filepath}*"): 24 | with open(plan_fp, "r") as f: 25 | *pddl_plan, cost_str = f.readlines() 26 | match = re.search(r"cost = ([-\d\.]+)", cost_str) 27 | if match: 28 | cost = float(match.group(1)) 29 | 30 | if cost < best_cost: 31 | best_cost = cost 32 | best_plan = "\n".join([*pddl_plan, ";"]) 33 | return best_plan, best_cost 34 | 35 | 36 | def plan( 37 | domain: str, 38 | problem: str, 39 | downward: str = "downward", 40 | alias: str = "lama", 41 | **kwargs, 42 | ) -> tuple[str | None, float]: 43 | """Find plan using FastDownward. 44 | 45 | Args: 46 | domain (str): A string containing a PDDL domain definition. 47 | problem (str): A string containing a PDDL task/problem definition. 48 | downward (str, optional): Path to FastDownward. Defaults to "downward". 49 | alias (str, optional): The FastDownward alias to. Defaults to "lama". 50 | 51 | Returns: 52 | Returns the PDDL plan string, or `None` if the planner failed, and the 53 | plan cost. 
def validate(domain: str, problem: str, plan: str, val: str = "validate"):
    """Check a plan against a domain and problem by running VAL.

    The three PDDL strings are written to files inside a temporary directory
    and the executable ``val`` is invoked with the domain, task and plan file
    paths, in that order. Its standard error is discarded.

    Args:
        domain (str): A string containing a PDDL domain definition.
        problem (str): A string containing a PDDL task/problem definition.
        plan (str): A string containing a PDDL plan.
        val (str, optional): Path to VAL. Defaults to "validate".

    Returns:
        bool: True if the text "Plan valid" appears in the standard output of
        the VAL process, False otherwise.
    """
    # file name -> content; insertion order is the argument order for VAL
    documents = {
        "domain.pddl": domain,
        "task.pddl": problem,
        "plan.pddl": plan,
    }

    with tempfile.TemporaryDirectory() as workdir:
        filepaths = []
        for filename, text in documents.items():
            filepath = os.path.join(workdir, filename)
            with open(filepath, "w") as handle:
                handle.write(text)
            filepaths.append(filepath)

        # a non-zero exit status is not an error here; only the output counts
        completed = subprocess.run(
            [val, *filepaths],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            check=False,
        )

    report = completed.stdout.decode("utf-8")
    return "Plan valid" in report
32 | check_solveable (bool, optional): Whether or not to check if the problem 33 | is solveable. Defaults to True. If False, the function will return 34 | False for the solveable element. 35 | 36 | Returns: 37 | tuple: A tuple containing the following boolean elements: 38 | - parseable: Whether or not the target PDDL string is parseable. 39 | - solveable: Whether or not the target PDDL string is solveable. 40 | - equivalent: Whether or not the PDDL strings are equivalent. 41 | """ 42 | parseable = False 43 | solveable = False 44 | equivalent = False 45 | 46 | source_graph = builder.build(source_pddl_str) 47 | 48 | try: 49 | target_graph = builder.build(target_pddl_str) 50 | parseable = True 51 | except Exception as e: 52 | return parseable, solveable, equivalent 53 | 54 | clean_pddl_str = problem_to_string(LenientProblemParser()(target_pddl_str)) 55 | domain_str = domain_str or DOMAINS.get(target_graph.domain) 56 | 57 | if check_solveable and isinstance(domain_str, str): 58 | try: 59 | plan_str = oracle.plan_to_string(oracle.plan(target_graph)) 60 | except (oracle.DomainNotSupportedError, NotImplementedError): 61 | try: 62 | plan_str, _ = downward.plan( 63 | domain_str, 64 | clean_pddl_str, 65 | downward=fast_downward, 66 | **downward_args, 67 | ) 68 | except: 69 | return parseable, solveable, equivalent 70 | except: 71 | return parseable, solveable, equivalent 72 | 73 | try: 74 | if not ( 75 | solveable := downward.validate( 76 | domain_str, 77 | clean_pddl_str, 78 | plan_str, 79 | val=val, 80 | ) 81 | ): 82 | return parseable, solveable, equivalent 83 | except: 84 | return parseable, solveable, equivalent 85 | 86 | if source_graph == target_graph: 87 | equivalent = True 88 | elif not metric.equals(source_graph.init(), target_graph.init()): 89 | equivalent = False 90 | else: 91 | try: 92 | equivalent = metric.equals( 93 | oracle.fully_specify(source_graph, return_reduced=True), 94 | oracle.fully_specify(target_graph, return_reduced=True), 95 | 
class PlanGraphNode:
    """A node of a plan graph: either a constant or a predicate.

    Attributes:
        node: Identifier of the node.
        name: Name of the node.
        label: Whether the node is a constant or a predicate.
        typing: A set of type names, a single string, or None.
        scene: The scene the node belongs to, or None.
    """

    def __init__(
        self,
        node: str,
        name: str,
        label: Label,
        typing: set | str | None = None,
        scene: Scene | None = None,
    ):
        self.node = node
        self.name = name
        self.label = label
        self.typing = typing
        self.scene = scene

    def __eq__(self, other: object) -> bool:
        """Two nodes are equal when all five attributes are equal."""
        return (
            isinstance(other, PlanGraphNode)
            and self.node == other.node
            and self.name == other.name
            and self.label == other.label
            and self.typing == other.typing
            and self.scene == other.scene
        )

    def __hash__(self) -> int:
        """Hash on name, label, typing and scene.

        ``typing`` may be None (the constructor default) or a plain string,
        neither of which should be pushed through ``sorted``: None is not
        iterable and a string would be split into characters. Only real
        collections of type names are turned into a sorted tuple so that the
        hash does not depend on set iteration order.
        """
        typing = self.typing
        if typing is None or isinstance(typing, str):
            typing_key = typing
        else:
            typing_key = tuple(sorted(typing))
        return hash((self.name, self.label, typing_key, self.scene))

    def __repr__(self) -> str:
        return f"PlanGraphNode(node={self.node}, name={self.name}, label={self.label}, typing={self.typing}, scene={self.scene})"

    def __str__(self) -> str:
        # same text as the repr; kept explicit so str() stays stable
        return repr(self)
self, 65 | predicate: str, 66 | position: int | None = None, 67 | scene: Scene | None = None, 68 | ): 69 | self.predicate = predicate 70 | self.position = position 71 | self.scene = scene 72 | 73 | def __eq__(self, other: "PlanGraphEdge") -> bool: 74 | return ( 75 | isinstance(other, PlanGraphEdge) 76 | and self.predicate == other.predicate 77 | and self.position == other.position 78 | and self.scene == other.scene 79 | ) 80 | 81 | def __hash__(self) -> int: 82 | return hash((self.predicate, self.position, self.scene)) 83 | 84 | def __repr__(self) -> str: 85 | return f"PlanGraphEdge(predicate={self.predicate}, position={self.position}, scene={self.scene})" 86 | 87 | def __str__(self) -> str: 88 | return f"PlanGraphEdge(predicate={self.predicate}, position={self.position}, scene={self.scene})" 89 | 90 | 91 | class PlanGraph(metaclass=abc.ABCMeta): 92 | """ 93 | Subclass of rx.PyDiGraph representing a scene graph. 94 | 95 | Attributes: 96 | constants (property): A dictionary of constant nodes in the scene graph. 97 | domain (property): The domain of the scene graph. 98 | """ 99 | 100 | def __init__( 101 | self, 102 | constants: list[dict[str, Any]], 103 | domain: str | None = None, 104 | requirements: tuple[str] = (), 105 | ): 106 | """ 107 | Initialize the SceneGraph instance. 108 | 109 | Parameters: 110 | constants (list): List of dictionaries representing constants. 111 | domain (str, optional): The domain of the scene graph. 112 | Defaults to None. 113 | requirements (list, optional): List of requirements for the scene 114 | graph. 
115 | """ 116 | super().__init__() 117 | 118 | self._constants: list[dict[str, Any]] = [] 119 | self._constant_nodes: list[PlanGraphNode] = [] 120 | self._predicates: list[dict[str, Any]] = [] 121 | self._predicate_nodes: list[PlanGraphNode] = [] 122 | self._node_lookup: dict[str, tuple[int, PlanGraphNode]] = {} 123 | self._nodes: list[PlanGraphNode] = [] 124 | self._edges: set[tuple[int, int, PlanGraphEdge]] = set() 125 | self._domain: str = domain 126 | self._requirements: tuple[str] = requirements 127 | self.graph = rx.PyDiGraph() 128 | 129 | for constant in constants: 130 | self.add_node( 131 | PlanGraphNode( 132 | constant["name"], 133 | name=constant["name"], 134 | label=Label.CONSTANT, 135 | typing=constant["typing"], 136 | ) 137 | ) 138 | 139 | @cached_property 140 | def nodes(self) -> list[PlanGraphNode]: 141 | return self._nodes 142 | 143 | @cached_property 144 | def edges(self) -> set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]]: 145 | return self._edges 146 | 147 | def add_node(self, node: PlanGraphNode): 148 | if node in self.nodes: 149 | raise ValueError(f"Node {node} already exists in the graph.") 150 | index = self.graph.add_node(node) 151 | 152 | if node.label == Label.CONSTANT: 153 | self._constants.append({"name": node.name, "typing": node.typing}) 154 | self._constant_nodes.append(node) 155 | elif node.label == Label.PREDICATE: 156 | self._predicate_nodes.append(node) 157 | 158 | self._nodes.append(node) 159 | self._node_lookup[node.node] = (index, node) 160 | 161 | def has_edge( 162 | self, 163 | u: str | PlanGraphNode, 164 | v: str | PlanGraphNode, 165 | edge: PlanGraphEdge | None = None, 166 | ) -> bool: 167 | if isinstance(u, PlanGraphNode): 168 | u_index = self.nodes.index(u) 169 | else: 170 | u_index, _ = self._node_lookup[u] 171 | 172 | if isinstance(v, PlanGraphNode): 173 | v_index = self.nodes.index(v) 174 | else: 175 | v_index, _ = self._node_lookup[v] 176 | 177 | if edge: 178 | return (u_index, v_index, edge) in 
def _add_predicate(
    self,
    predicate: dict[str, Any],
    scene: Scene | None = None,
):
    """
    Add a predicate to the plan graph.

    The predicate becomes a predicate node that is connected to the constant
    node of each of its parameters by an edge carrying the predicate name,
    the argument position and the scene. A predicate that is already stored
    in ``self.predicates`` is ignored.

    Parameters:
        predicate (dict): A dictionary representing the predicate. The keys
            ``"typing"`` (predicate name) and ``"parameters"`` (argument
            names) are read. When ``scene`` is given, the dictionary is
            tagged in place with a ``"scene"`` entry.
        scene (Scene, optional): The scene in which the predicate occurs.

    Raises:
        ValueError: If a parameter does not name a known constant. The graph
            is left untouched in that case.
    """
    if scene:
        predicate.update({"scene": scene})
    if predicate in self.predicates:
        return

    parameters = predicate["parameters"]

    # Validate every parameter before touching the graph, so that a bad
    # predicate cannot leave a dangling node or a partial set of edges
    # behind. The name set is built once instead of once per parameter.
    constant_names = {node.name for node in self.constant_nodes}
    for parameter_name in parameters:
        if parameter_name not in constant_names:
            raise ValueError(f"Parameter {parameter_name} not found in constants.")

    predicate_name = self._build_unique_predicate_name(
        predicate_name=predicate["typing"],
        argument_names=parameters,
    )
    self.add_node(
        PlanGraphNode(
            predicate_name,
            name=predicate_name,
            label=Label.PREDICATE,
            typing=predicate["typing"],
            scene=scene,
        )
    )

    for position, parameter_name in enumerate(parameters):
        self.add_edge(
            predicate_name,
            parameter_name,
            PlanGraphEdge(
                predicate=predicate["typing"],
                position=position,
                scene=scene,
            ),
        )

    self._predicates.append(predicate)
else: 247 | return self.graph.in_degree(self._node_lookup[node][0]) 248 | 249 | def out_degree(self, node: str | PlanGraphNode) -> int: 250 | if isinstance(node, PlanGraphNode): 251 | return self.graph.out_degree(self.nodes.index(node)) 252 | else: 253 | return self.graph.out_degree(self._node_lookup[node][0]) 254 | 255 | def predecessors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: 256 | if isinstance(node, PlanGraphNode): 257 | preds = self.graph.predecessors(self.nodes.index(node)) 258 | else: 259 | preds = self.graph.predecessors(self._node_lookup[node][0]) 260 | 261 | return preds 262 | 263 | def successors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]: 264 | if isinstance(node, PlanGraphNode): 265 | succs = self.graph.successors(self.nodes.index(node)) 266 | else: 267 | succs = self.graph.successors(self._node_lookup[node][0]) 268 | 269 | return succs 270 | 271 | def in_edges( 272 | self, node: str | PlanGraphNode 273 | ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]: 274 | if isinstance(node, PlanGraphNode): 275 | edges = self.graph.in_edges(self.nodes.index(node)) 276 | else: 277 | edges = self.graph.in_edges(self._node_lookup[node][0]) 278 | 279 | return [(self.nodes[u], edge) for u, _, edge in edges] 280 | 281 | def out_edges( 282 | self, node: str | PlanGraphNode 283 | ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]: 284 | if isinstance(node, PlanGraphNode): 285 | edges = self.graph.out_edges(self.nodes.index(node)) 286 | else: 287 | edges = self.graph.out_edges(self._node_lookup[node][0]) 288 | 289 | return [(self.nodes[v], edge) for _, v, edge in edges] 290 | 291 | @staticmethod 292 | def _build_unique_predicate_name( 293 | predicate_name: str, argument_names: Iterable[str] 294 | ) -> str: 295 | """ 296 | Build a unique name for a predicate based on its name and argument names. 297 | 298 | Parameters: 299 | predicate_name (str): The name of the predicate. 
300 | argument_names (Iterable[str]): Sequence of argument names 301 | for the predicate. 302 | 303 | Returns: 304 | str: The unique name for the predicate. 305 | """ 306 | return "-".join([predicate_name, *argument_names]) 307 | 308 | @cached_property 309 | def domain(self) -> str | None: 310 | """ 311 | Get the domain of the scene graph. 312 | 313 | Returns: 314 | str: The domain of the scene graph. 315 | """ 316 | return self._domain 317 | 318 | @cached_property 319 | def constant_nodes(self) -> list[PlanGraphNode]: 320 | """Get a list of constant nodes in the scene graph. 321 | 322 | Returns: 323 | list[PlanGraphNode]: A list of constant nodes. 324 | """ 325 | return self._constant_nodes 326 | 327 | @property 328 | def constants(self) -> list[dict[str, Any]]: 329 | return self._constants 330 | 331 | @property 332 | def predicate_nodes(self) -> list[PlanGraphNode]: 333 | """Get a list of predicate nodes in the scene graph. 334 | 335 | Returns: 336 | list[PlanGraphNode]: A list of predicate nodes. 337 | """ 338 | return self._predicate_nodes 339 | 340 | @property 341 | def predicates(self) -> list[dict[str, Any]]: 342 | return self._predicates 343 | 344 | def __eq__(self, other: "PlanGraph") -> bool: 345 | """ 346 | Check if two plan graphs are equal. 347 | 348 | Parameters: 349 | other (PlanGraph): The other plan graph to compare. 350 | 351 | Returns: 352 | bool: True if the plan graphs are equal, False otherwise. 353 | """ 354 | return ( 355 | isinstance(other, PlanGraph) 356 | and set(self.nodes) == set(other.nodes) 357 | and set(self.edges) == set(other.edges) 358 | and self.domain == other.domain 359 | and set(self._requirements) == set(other._requirements) 360 | ) 361 | 362 | def plot(self, fig: plt.Figure | None = None) -> plt.Figure: 363 | """Generate a plot of the graph, sorted by topological generation. 364 | 365 | Args: 366 | fig (plt.Figure | None, optional): The figure to plot on. Defaults 367 | to None. 
class SceneGraph(PlanGraph):
    """
    Plan graph whose predicates all belong to one scene.

    Attributes:
        scene: The scene passed to the constructor; every predicate of the
            graph is added under this scene.
    """

    def __init__(
        self,
        constants: list[dict[str, Any]],
        predicates: list[dict[str, Any]],
        domain: str | None = None,
        scene: Scene | None = None,
        requirements: tuple[str] = (),
    ):
        """
        Build the graph from constants and the predicates of one scene.

        Parameters:
            constants (list): Dictionaries describing the constants.
            predicates (list): Dictionaries describing the predicates.
            domain (str, optional): The domain of the scene graph.
                Defaults to None.
            scene (Scene, optional): The scene every predicate is added
                under. Defaults to None.
            requirements (tuple, optional): Requirements of the scene graph.
        """
        super().__init__(constants, domain=domain, requirements=requirements)
        self.scene = scene

        for entry in predicates:
            self._add_predicate(entry, scene=scene)
def __eq__(self, other: object) -> bool:
    """
    Check if two problem graphs are equal.

    On top of the base-class comparison, the initial and the goal predicate
    nodes have to agree as sets.

    Parameters:
        other (object): The object to compare with.

    Returns:
        bool: True if ``other`` is an equal ProblemGraph, False otherwise.
    """
    # Any other kind of plan graph has no init/goal predicate node lists;
    # without this guard a graph that passes the base comparison would raise
    # AttributeError below instead of comparing unequal.
    if not isinstance(other, ProblemGraph):
        return False
    return (
        super().__eq__(other)
        and set(self.init_predicate_nodes) == set(other.init_predicate_nodes)
        and set(self.goal_predicate_nodes) == set(other.goal_predicate_nodes)
    )
def _add_predicate(self, predicate: dict[str, Any], scene: Scene | None = None):
    """
    Add a predicate and record it in the list of its scene.

    Parameters:
        predicate (dict): A dictionary representing the predicate.
        scene (Scene, optional): The scene in which the predicate occurs.
            Predicates of the initial scene are recorded in
            ``init_predicates``, those of the goal scene in
            ``goal_predicates``.
    """
    known_before = len(self.predicates)
    super()._add_predicate(predicate, scene)

    # The base class silently ignores a predicate it already holds. Do not
    # record such a duplicate a second time, so the per-scene lists stay in
    # step with the predicates actually present in the graph.
    if len(self.predicates) == known_before:
        return

    if scene == Scene.INIT:
        self._init_predicates.append(predicate)
    elif scene == Scene.GOAL:
        self._goal_predicates.append(predicate)
@staticmethod
def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph":
    """
    Build a problem graph out of an initial and a goal scene graph.

    Constants, domain and requirements are taken from ``init``; ``goal``
    only contributes its predicates.

    Parameters:
        init (SceneGraph): The initial scene graph.
        goal (SceneGraph): The goal scene graph.

    Returns:
        ProblemGraph: The combined problem graph.
    """
    shared = {
        "constants": init.constants,
        "domain": init.domain,
        "requirements": init._requirements,
    }
    return ProblemGraph(
        init_predicates=init.predicates,
        goal_predicates=goal.predicates,
        **shared,
    )
def _preserves_mapping(
    source: graph.PlanGraphNode,
    target: graph.PlanGraphNode,
    mapping: dict,
) -> bool:
    """
    Check whether a pair of constant nodes respects a name mapping.

    Parameters:
        source (graph.PlanGraphNode): The source node.
        target (graph.PlanGraphNode): The target node.
        mapping (dict): Maps source node names to target node names. The
            source name is looked up with ``mapping[...]``, so it has to be
            present whenever both nodes are constants.

    Returns:
        bool: True if both nodes are constants and the source name maps to
            the target name, False otherwise.
    """
    constant = graph.Label.CONSTANT

    if not (source.label == constant):
        return False
    if not (target.label == constant):
        return False

    expected_name = mapping[source.name]
    return expected_name == target.name
def _edge_matching(
    source: graph.PlanGraphEdge,
    target: graph.PlanGraphEdge,
    attributes: dict[str, str | int | graph.Scene | graph.Label | None] | None = None,
) -> bool:
    """
    Check if two edges match based on their attributes.

    Parameters:
        source (graph.PlanGraphEdge): The source edge.
        target (graph.PlanGraphEdge): The target edge.
        attributes (dict | None): The attributes to match, given as a mapping
            from attribute name to the value assumed for an edge that lacks
            the attribute. With None or an empty mapping nothing is compared.

    Returns:
        bool: True if the edges agree on every listed attribute, False
            otherwise.
    """
    # None replaces the former shared `{}` default argument.
    fallbacks = {} if attributes is None else attributes

    return all(
        getattr(source, attr, fallback) == getattr(target, attr, fallback)
        for attr, fallback in fallbacks.items()
    )
def reduce(
    graph: graph.SceneGraph,
    domain: str | None = None,
) -> ReducedSceneGraph | ReducedProblemGraph:
    """Reduce a scene graph with the oracle registered for its domain.

    Args:
        graph (graph.SceneGraph): The scene graph to reduce.
        domain (str, optional): The domain of the scene graph. When omitted
            (or empty), the graph's own domain is used.

    Returns:
        ReducedGraph: The reduced graph produced by the domain oracle.

    Raises:
        DomainNotSupportedError: If no oracle is registered for the domain.
    """
    if not domain:
        domain = graph.domain

    oracle = ORACLES.get(domain)
    if not oracle:
        raise DomainNotSupportedError(f"Domain {domain} not supported.")

    return oracle.reduce(graph)
def plan(problem: graph.ProblemGraph, domain: str | None = None) -> list[Action]:
    """Plan a sequence of actions that solves a problem.

    The work is delegated to the oracle registered for the domain.

    Args:
        problem (graph.ProblemGraph): The problem to plan for.
        domain (str | None, optional): The domain of the problem. When
            omitted (or empty), the problem's own domain is used.

    Returns:
        list[Action]: The sequence of actions returned by the domain oracle.

    Raises:
        DomainNotSupportedError: If no oracle is registered for the domain.
    """
    domain_name = domain if domain else problem.domain

    oracle = ORACLES.get(domain_name)
    if not oracle:
        raise DomainNotSupportedError(f"Domain {domain_name} not supported.")

    return oracle.plan(problem)
109 | 110 | Returns: 111 | str: The string representation of the actions. 112 | """ 113 | return plan_template.render(actions=actions) 114 | -------------------------------------------------------------------------------- /planetarium/oracles/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["blocksworld", "gripper", "rover_single", "floortile"] 2 | 3 | from . import oracle 4 | 5 | from . import blocksworld 6 | from . import gripper 7 | from . import rover_single 8 | from . import floortile 9 | 10 | ORACLES: dict[str, oracle.Oracle] = { 11 | "blocksworld": blocksworld.BlocksworldOracle, 12 | "gripper": gripper.GripperOracle, 13 | "rover-single": rover_single.RoverSingleOracle, 14 | "floor-tile": floortile.FloorTileOracle, 15 | } 16 | -------------------------------------------------------------------------------- /planetarium/oracles/blocksworld.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from pddl.core import Action 4 | import rustworkx as rx 5 | 6 | from . import oracle 7 | from .. import graph 8 | from ..reduced_graph import ReducedSceneGraph, ReducedProblemGraph, ReducedNode 9 | 10 | 11 | # Add enum for blocksworld domain 12 | ReducedNode.register( 13 | { 14 | "TABLE": "table", 15 | "CLEAR": "clear", 16 | "ARM": "arm", 17 | }, 18 | "blocksworld", 19 | ) 20 | 21 | 22 | class BlocksworldOracle(oracle.Oracle): 23 | 24 | def reduce( 25 | scene: graph.SceneGraph | graph.ProblemGraph, 26 | ) -> ReducedSceneGraph | ReducedProblemGraph: 27 | """Reduces a blocksworld scene graph to a Directed Acyclic Graph. 28 | 29 | Args: 30 | problem (graph.SceneGraph | graph.ProblemGraph): The scene graph to 31 | reduce. 32 | 33 | Returns: 34 | ReducedGraph: The reduced problem graph. 
35 | """ 36 | 37 | nodes = defaultdict(list) 38 | for node in scene.nodes: 39 | nodes[node.label].append(node) 40 | 41 | match scene: 42 | case graph.ProblemGraph( 43 | _constants=constants, 44 | _predicates=predicates, 45 | _domain=domain, 46 | _requirements=requirements, 47 | ): 48 | reduced = ReducedProblemGraph( 49 | constants=constants, 50 | domain=domain, 51 | requirements=requirements, 52 | ) 53 | case graph.SceneGraph( 54 | constants=constants, 55 | _predicates=predicates, 56 | scene=scene, 57 | _domain=domain, 58 | _requirements=requirements, 59 | ): 60 | reduced = ReducedSceneGraph( 61 | constants=constants, 62 | domain=domain, 63 | scene=scene, 64 | requirements=requirements, 65 | ) 66 | case _: 67 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.") 68 | 69 | for predicate in predicates: 70 | params = predicate["parameters"] 71 | reduced_edge = graph.PlanGraphEdge( 72 | predicate=predicate["typing"], 73 | scene=predicate.get("scene"), 74 | ) 75 | match (predicate["typing"], len(params)): 76 | case ("arm-empty", 0): 77 | reduced.add_edge(ReducedNode.CLEAR, ReducedNode.ARM, reduced_edge) 78 | case ("on-table", 1): 79 | reduced.add_edge(params[0], ReducedNode.TABLE, reduced_edge) 80 | case ("clear", 1): 81 | reduced.add_edge(ReducedNode.CLEAR, params[0], reduced_edge) 82 | case ("on", 2): 83 | reduced.add_edge(params[0], params[1], reduced_edge) 84 | case ("holding", 1): 85 | reduced.add_edge(params[0], ReducedNode.ARM, reduced_edge) 86 | return reduced 87 | 88 | def inflate( 89 | scene: ReducedSceneGraph | ReducedProblemGraph, 90 | ) -> graph.SceneGraph: 91 | """Respecify a blocksworld scene graph. 92 | 93 | Args: 94 | scene (ReducedGraph): The reduced SceneGraph of a scene. 95 | 96 | Returns: 97 | graph.SceneGraph: The respecified scene graph. 
98 | """ 99 | constants = [] 100 | predicates = [] 101 | 102 | for node in scene.nodes: 103 | if not isinstance(node.node, ReducedNode): 104 | constants.append({"name": node.node, "typing": node.typing}) 105 | 106 | for u, v, edge in scene.edges: 107 | match (u.node, v.node): 108 | case (ReducedNode.CLEAR, ReducedNode.ARM): 109 | predicates.append( 110 | { 111 | "typing": "arm-empty", 112 | "parameters": [], 113 | "scene": edge.scene, 114 | } 115 | ) 116 | case (ReducedNode.CLEAR, _): 117 | predicates.append( 118 | { 119 | "typing": "clear", 120 | "parameters": [v.node], 121 | "scene": edge.scene, 122 | } 123 | ) 124 | case (_, ReducedNode.TABLE): 125 | predicates.append( 126 | { 127 | "typing": "on-table", 128 | "parameters": [u.node], 129 | "scene": edge.scene, 130 | } 131 | ) 132 | case (_, ReducedNode.ARM): 133 | predicates.append( 134 | { 135 | "typing": "holding", 136 | "parameters": [u.node], 137 | "scene": edge.scene, 138 | } 139 | ) 140 | case (_, _): 141 | predicates.append( 142 | { 143 | "typing": "on", 144 | "parameters": [u.node, v.node], 145 | "scene": edge.scene, 146 | } 147 | ) 148 | 149 | if isinstance(scene, ReducedProblemGraph): 150 | return graph.ProblemGraph( 151 | constants, 152 | [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], 153 | [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], 154 | domain="blocksworld", 155 | requirements=scene._requirements, 156 | ) 157 | else: 158 | return graph.SceneGraph( 159 | constants, 160 | predicates, 161 | domain="blocksworld", 162 | scene=scene.scene, 163 | requirements=scene._requirements, 164 | ) 165 | 166 | @staticmethod 167 | def _blocksworld_underspecified_blocks( 168 | scene: ReducedSceneGraph, 169 | ) -> tuple[set[str], set[str], bool]: 170 | """Finds blocks that are not fully specified. 171 | 172 | Args: 173 | scene (ReducedGraph): The reduced SceneGraph of a scene. 
@staticmethod
def _detached_blocks(
    nodesA: set[str],
    nodesB: set[str],
    scene: ReducedSceneGraph,
) -> tuple[set[str], set[str]]:
    """Filters out nodes that are disconnected from a node of the other set.

    A pair ``(a, b)`` with ``a != b`` is "detached" when the scene graph has
    no path from ``a`` to ``b`` and none from ``b`` to ``a``. Every node that
    takes part in at least one detached pair is removed from its set.

    Args:
        nodesA (set[str]): The set of nodes to check.
        nodesB (set[str]): The set of nodes to check against.
        scene (ReducedGraph): The scene graph to check against.

    Returns:
        tuple[set[str], set[str]]: Copies of ``nodesA`` and ``nodesB`` without
            the nodes found in a detached pair. The inputs are not modified.
    """
    _nodesA = set(nodesA)
    _nodesB = set(nodesB)
    if not nodesA or not nodesB:
        # no pairs to compare, nothing can be detached
        return _nodesA, _nodesB

    # Resolve each node's graph index once instead of once per pair.
    scene_nodes = scene.nodes
    rx_graph = scene.graph
    index_a = {a: scene_nodes.index(a) for a in nodesA}
    index_b = {b: scene_nodes.index(b) for b in nodesB}

    for a, a_index in index_a.items():
        for b, b_index in index_b.items():
            # cheap identity test first, path searches only when needed
            if (
                a != b
                and not rx.has_path(rx_graph, a_index, b_index)
                and not rx.has_path(rx_graph, b_index, a_index)
            ):
                _nodesA.discard(a)
                _nodesB.discard(b)

    return _nodesA, _nodesB
@staticmethod
def plan(problem: graph.ProblemGraph) -> list[Action]:
    """Builds a plan by unstacking the init scene and restacking the goal.

    The problem is fully specified and reduced first. The plan then:
      1. puts down the block held in the init scene, if exactly one is held;
      2. unstacks and puts down every init block that rests on another block;
      3. picks up and stacks every goal block that rests on another block;
      4. picks up the block the goal scene wants held, if exactly one.

    Args:
        problem (graph.ProblemGraph): The problem to plan for.

    Returns:
        list[Action]: The actions, in execution order.
    """
    reduced = BlocksworldOracle.fully_specify(problem, return_reduced=True)
    init, goal = reduced.decompose()
    actions = []
    # A block whose successor is one of these does not sit on another block.
    resting = (ReducedNode.ARM, ReducedNode.TABLE)

    # Process init scene
    # check if arm is empty
    if (
        not init.has_edge(ReducedNode.CLEAR, ReducedNode.ARM)
        and init.in_degree(ReducedNode.ARM) == 1
    ):
        held = init.predecessors(ReducedNode.ARM)[0]
        actions.append(Action("putdown", [held.name]))

    # unstack everything in init
    for idx in rx.topological_sort(init.graph):
        node = init.nodes[idx]
        if isinstance(node.node, ReducedNode):
            continue
        below = init.successors(node)
        # Only the goal scene is fully specified, so an init block may have
        # no successor at all; skip it like the goal loop does rather than
        # failing on below[0].
        if not below or below[0].node in resting:
            continue
        actions.append(Action("unstack", [node.name, below[0].name]))
        actions.append(Action("putdown", [node.name]))

    # Process goal scene
    # stack everything in goal
    for idx in reversed(rx.topological_sort(goal.graph)):
        node = goal.nodes[idx]
        if isinstance(node.node, ReducedNode):
            continue
        below = goal.successors(node)
        # no successor: not defined to be on anything (keep on table)
        if not below or below[0].node in resting:
            continue
        actions.append(Action("pickup", [node.name]))
        actions.append(Action("stack", [node.name, below[0].name]))

    # Check if arm should be holding it
    if (
        not goal.has_edge(ReducedNode.CLEAR, ReducedNode.ARM)
        and goal.in_degree(ReducedNode.ARM) == 1
    ):
        held = goal.predecessors(ReducedNode.ARM)[0]
        actions.append(Action("pickup", [held.name]))

    return actions
import graph 9 | from ..reduced_graph import ReducedSceneGraph, ReducedProblemGraph, ReducedNode 10 | 11 | # Add the ReducedNode enum for the gripper domain 12 | ReducedNode.register( 13 | { 14 | "AVAILABLE": "available-color", 15 | }, 16 | "floor-tile", 17 | ) 18 | 19 | 20 | class FloorTileOracle(oracle.Oracle): 21 | 22 | def reduce( 23 | scene: graph.SceneGraph | graph.ProblemGraph, 24 | ) -> ReducedSceneGraph | ReducedProblemGraph: 25 | """Reduces a floortile scene graph to a reduced scene graph. 26 | 27 | Args: 28 | scene (graph.SceneGraph | graph.ProblemGraph): The scene graph to reduce. 29 | 30 | Returns: 31 | ReducedSceneGraph | ReducedProblemGraph: The reduced scene graph. 32 | """ 33 | match scene: 34 | case graph.ProblemGraph( 35 | _constants=constants, 36 | _predicates=predicates, 37 | _domain=domain, 38 | _requirements=requirements, 39 | ): 40 | reduced = ReducedProblemGraph( 41 | constants=constants, 42 | domain=domain, 43 | requirements=requirements, 44 | ) 45 | case graph.SceneGraph( 46 | constants=constants, 47 | _predicates=predicates, 48 | scene=scene, 49 | _domain=domain, 50 | _requirements=requirements, 51 | ): 52 | reduced = ReducedSceneGraph( 53 | constants=constants, 54 | domain=domain, 55 | scene=scene, 56 | requirements=requirements, 57 | ) 58 | case _: 59 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.") 60 | 61 | for predicate in predicates: 62 | params = predicate["parameters"] 63 | reduced_edge = graph.PlanGraphEdge( 64 | predicate=predicate["typing"], 65 | scene=predicate.get("scene"), 66 | ) 67 | match (predicate["typing"], params): 68 | case (_, [x, y]): 69 | reduced.add_edge(x, y, reduced_edge) 70 | case ("clear", [x]): 71 | reduced.add_edge(ReducedNode.CLEAR, x, reduced_edge) 72 | case ("available-color", [x]): 73 | reduced.add_edge(ReducedNode.AVAILABLE, x, reduced_edge) 74 | 75 | return reduced 76 | 77 | def inflate( 78 | scene: ReducedSceneGraph | ReducedProblemGraph, 79 | ) -> graph.SceneGraph | 
@staticmethod
def _apply_unchangeable_predicates(
    init: ReducedSceneGraph,
    goal: ReducedSceneGraph,
):
    """Copies the init edges listed as unchangeable into the goal scene.

    Every init edge whose predicate is "up", "right", "painted" or
    "available-color" is deep-copied, re-tagged with ``graph.Scene.GOAL`` and
    added to ``goal`` unless ``goal`` already has that edge. ``goal`` is
    modified in place; ``init`` is left untouched.

    Args:
        init (ReducedSceneGraph): The reduced initial scene.
        goal (ReducedSceneGraph): The reduced goal scene to extend.
    """
    fixed_predicates = frozenset(("up", "right", "painted", "available-color"))
    for src, dst, init_edge in init.edges:
        if init_edge.predicate not in fixed_predicates:
            continue
        # work on a copy so the init graph keeps its own scene tag
        goal_edge = copy.deepcopy(init_edge)
        goal_edge.scene = graph.Scene.GOAL
        if goal.has_edge(src, dst, goal_edge):
            continue
        goal.add_edge(src, dst, goal_edge)
_fixed_color_predicates( 156 | init: ReducedSceneGraph, 157 | goal: ReducedSceneGraph, 158 | robots: list[ReducedNode], 159 | ): 160 | # if two or more colors are available, then the robot color cannot be fixed 161 | # (they can finish painting and switch to any of the available colors) 162 | available_colors = init.successors(ReducedNode.AVAILABLE) 163 | if len(available_colors) > 1: 164 | return 165 | 166 | # find the color each robot has 167 | has_color = defaultdict(list) 168 | for robot in robots: 169 | has_color[robot] = [ 170 | v for v, edge in init.out_edges(robot) if edge.predicate == "robot-has" 171 | ] 172 | 173 | # if no colors are available, then all robots must have their original color 174 | if len(available_colors) == 0: 175 | for robot, colors in has_color.items(): 176 | edge = graph.PlanGraphEdge( 177 | predicate="robot-has", 178 | scene=graph.Scene.GOAL, 179 | ) 180 | for color in colors: 181 | if not goal.has_edge(robot, color, edge): 182 | goal.add_edge(robot, color, edge) 183 | # if only one color is available: 184 | elif len(available_colors) == 1: 185 | # if there is only one robot and it either 186 | # a) has the available color 187 | # b) the available color needs to be painted 188 | # 189 | # then the robot color is fixed to the available color 190 | (available_color,) = available_colors 191 | if len(robots) == 1: 192 | (robot,) = robots 193 | robot_color = has_color[robot][0] 194 | if robot_color == available_color: 195 | edge = graph.PlanGraphEdge( 196 | predicate="robot-has", 197 | scene=graph.Scene.GOAL, 198 | ) 199 | if not goal.has_edge(robot, robot_color, edge): 200 | goal.add_edge(robot, robot_color, edge) 201 | else: 202 | painted_colors = set() 203 | for u, v, edge in goal.edges: 204 | if edge.predicate == "painted": 205 | init_edge = copy.deepcopy(edge) 206 | if not init.has_edge(u, v, init_edge): 207 | painted_colors.add(v) 208 | 209 | if available_color in painted_colors: 210 | edge = graph.PlanGraphEdge( 211 | 
@staticmethod
def _fix_possible_positions(
    init: ReducedSceneGraph,
    goal: ReducedSceneGraph,
    robots: list[graph.PlanGraphNode],
):
    """Pins robots that start on a tile without neighbors to that tile.

    For each robot with a "robot-at" edge in ``init`` and none in ``goal``,
    the first init position is inspected. If that tile has no incoming or
    outgoing "up"/"right" edge in ``init``, a goal-scene "robot-at" edge from
    the robot to the tile is added to ``goal`` (unless already present).
    ``goal`` is modified in place.

    Args:
        init (ReducedSceneGraph): The reduced initial scene.
        goal (ReducedSceneGraph): The reduced goal scene to extend.
        robots (list[graph.PlanGraphNode]): The robot nodes to process.
    """
    for robot in robots:
        init_positions = [
            v for v, edge in init.out_edges(robot) if edge.predicate == "robot-at"
        ]
        goal_positions = [
            v for v, edge in goal.out_edges(robot) if edge.predicate == "robot-at"
        ]

        if not init_positions or goal_positions:
            # if no initial position is specified or the goal position is already
            # specified, we can't determine the final robot position
            continue

        init_pos = init_positions[0]
        has_neighbor = any(
            edge.predicate in {"up", "right"}
            for _, edge in (*init.out_edges(init_pos), *init.in_edges(init_pos))
        )
        if has_neighbor:
            continue

        # if the initial position has no neighbors, the robot must end there
        edge = graph.PlanGraphEdge(
            predicate="robot-at",
            scene=graph.Scene.GOAL,
        )
        if not goal.has_edge(robot, init_pos, edge):
            goal.add_edge(robot, init_pos, edge)

    # TODO: if a robot is the only one that can paint a tile, it must end up painting it
    # if a robot ends up on a tile that is disconnected, then it must end up on that tile:
318 | """ 319 | inflated_init, inflated_goal = problem.decompose() 320 | 321 | init: ReducedSceneGraph = FloorTileOracle.reduce(inflated_init) 322 | goal: ReducedSceneGraph = FloorTileOracle.reduce(inflated_goal) 323 | 324 | robots = [r for r in init.nodes if r.typing == {"robot"}] 325 | 326 | FloorTileOracle._apply_unchangeable_predicates(init, goal) 327 | FloorTileOracle._fixed_color_predicates(init, goal, robots) 328 | 329 | # if a robot that starts on a tile with no neighbors, it must also end 330 | # on that tile 331 | FloorTileOracle._fix_possible_positions(init, goal, robots) 332 | 333 | if return_reduced: 334 | return ReducedProblemGraph.join(init, goal) 335 | else: 336 | return graph.ProblemGraph.join(inflated_init, FloorTileOracle.inflate(goal)) 337 | -------------------------------------------------------------------------------- /planetarium/oracles/gripper.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import copy 3 | 4 | from pddl.core import Action 5 | 6 | from . import oracle 7 | from .. import graph 8 | from ..reduced_graph import ReducedSceneGraph, ReducedProblemGraph, ReducedNode 9 | 10 | 11 | # Add the ReducedNode enum for the gripper domain 12 | ReducedNode.register( 13 | { 14 | "ROOMS": "room", 15 | "BALLS": "ball", 16 | "GRIPPERS": "gripper", 17 | "ROBBY": "at-robby", 18 | "FREE": "free", 19 | }, 20 | "gripper", 21 | ) 22 | 23 | 24 | class GripperOracle(oracle.Oracle): 25 | 26 | def reduce( 27 | scene: graph.SceneGraph | graph.ProblemGraph, 28 | ) -> ReducedSceneGraph | ReducedProblemGraph: 29 | """Reduces a gripper scene graph to a Directed Acyclic Graph. 30 | 31 | Args: 32 | scene (graph.SceneGraph): The scene graph to reduce. 33 | 34 | Returns: 35 | ReducedGraph: The reduced problem graph. 
36 | """ 37 | nodes = defaultdict(list) 38 | for node in scene.nodes: 39 | nodes[node.label].append(node) 40 | 41 | match scene: 42 | case graph.ProblemGraph( 43 | _constants=constants, 44 | _predicates=predicates, 45 | _domain=domain, 46 | _requirements=requirements, 47 | ): 48 | reduced = ReducedProblemGraph( 49 | constants=constants, 50 | domain=domain, 51 | requirements=requirements, 52 | ) 53 | case graph.SceneGraph( 54 | constants=constants, 55 | _predicates=predicates, 56 | scene=scene, 57 | _domain=domain, 58 | _requirements=requirements, 59 | ): 60 | reduced = ReducedSceneGraph( 61 | constants=constants, 62 | domain=domain, 63 | scene=scene, 64 | requirements=requirements, 65 | ) 66 | case _: 67 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.") 68 | 69 | for predicate in predicates: 70 | params = predicate["parameters"] 71 | reduced_edge = graph.PlanGraphEdge( 72 | predicate=predicate["typing"], 73 | scene=predicate.get("scene"), 74 | ) 75 | match (predicate["typing"], len(params)): 76 | case ("at-robby", 1): 77 | reduced.add_edge(ReducedNode.ROBBY, params[0], reduced_edge) 78 | case ("free", 1): 79 | reduced.add_edge(ReducedNode.FREE, params[0], reduced_edge) 80 | case ("ball", 1): 81 | reduced.add_edge(ReducedNode.BALLS, params[0], reduced_edge) 82 | case ("gripper", 1): 83 | reduced.add_edge(ReducedNode.GRIPPERS, params[0], reduced_edge) 84 | case ("room", 1): 85 | reduced.add_edge(ReducedNode.ROOMS, params[0], reduced_edge) 86 | case ("at", 2): 87 | reduced.add_edge(params[0], params[1], reduced_edge) 88 | case ("carry", 2): 89 | reduced.add_edge(params[0], params[1], reduced_edge) 90 | 91 | return reduced 92 | 93 | def inflate( 94 | scene: ReducedSceneGraph | ReducedProblemGraph, 95 | ) -> graph.SceneGraph | graph.ProblemGraph: 96 | """Respecify a gripper scene graph. 97 | 98 | Args: 99 | scene (ReducedGraph): The reduced SceneGraph of a scene. 100 | 101 | Returns: 102 | graph.SceneGraph: The respecified scene graph. 
@staticmethod
def _gripper_get_typed_objects(
    scene: ReducedSceneGraph,
) -> dict[ReducedNode, set[graph.PlanGraphNode]]:
    """Collect the nodes pointed to by the ROOMS, BALLS and GRIPPERS nodes.

    Args:
        scene (ReducedGraph): The reduced SceneGraph of a scene.

    Returns:
        dict[ReducedNode, set[graph.PlanGraphNode]]: For each of
            ``ReducedNode.ROOMS``, ``ReducedNode.BALLS`` and
            ``ReducedNode.GRIPPERS``, the set of targets of its out-edges.
    """

    def targets_of(anchor: ReducedNode) -> set[graph.PlanGraphNode]:
        # out_edges yields (target, edge) pairs; only the target is needed
        return {target for target, _ in scene.out_edges(anchor)}

    return {
        ReducedNode.ROOMS: targets_of(ReducedNode.ROOMS),
        ReducedNode.BALLS: targets_of(ReducedNode.BALLS),
        ReducedNode.GRIPPERS: targets_of(ReducedNode.GRIPPERS),
    }
def fully_specify(
    problem: graph.ProblemGraph,
    return_reduced: bool = False,
) -> graph.ProblemGraph | ReducedProblemGraph:
    """Fully specifies the goal of a gripper problem.

    Works on a deep copy of the reduced goal scene. Typing edges (room, ball,
    gripper) found in the init scene are added for objects present in the
    goal, and grippers left unspecified are marked free when no ball is left
    unspecified.

    Args:
        problem (graph.ProblemGraph): The problem graph to fully specify.
        return_reduced (bool, optional): Whether to return a reduced problem graph.
            Defaults to False.

    Returns:
        ProblemGraph | ReducedProblemGraph: The fully specified problem graph.
    """
    inflated_init, inflated_goal = problem.decompose()
    init = GripperOracle.reduce(inflated_init)
    goal = GripperOracle.reduce(inflated_goal)

    specified = copy.deepcopy(goal)

    # bring "typing" predicates from init to goal
    for anchor, members in GripperOracle._gripper_get_typed_objects(init).items():
        for member in members:
            if member not in specified.nodes:
                continue
            goal_edge = graph.PlanGraphEdge(
                predicate=anchor.value, scene=graph.Scene.GOAL
            )
            untagged_edge = graph.PlanGraphEdge(predicate=anchor.value)
            # an edge with or without a scene tag counts as already present
            if specified.has_edge(anchor, member, goal_edge):
                continue
            if specified.has_edge(anchor, member, untagged_edge):
                continue
            specified.add_edge(anchor, member, goal_edge)

    # computed on the untouched goal, not on the copy being extended
    loose_balls, loose_grippers, _ = GripperOracle._gripper_underspecified_blocks(
        init,
        goal,
    )

    if loose_grippers and not loose_balls:
        for gripper in loose_grippers:
            specified.add_edge(
                ReducedNode.FREE,
                gripper,
                graph.PlanGraphEdge(predicate="free", scene=specified.scene),
            )

    if not return_reduced:
        return graph.ProblemGraph.join(
            inflated_init, GripperOracle.inflate(specified)
        )
    return ReducedProblemGraph.join(init, specified)
| # get current room 331 | if init.out_degree(ReducedNode.ROBBY) < 1: 332 | return actions 333 | 334 | current_room = init.successors(ReducedNode.ROBBY)[0] 335 | # move to first room 336 | if current_room != rooms[0]: 337 | actions.append(Action("move", [current_room.name, rooms[0].name])) 338 | 339 | # ensure all grippers are free 340 | for gripper in grippers: 341 | if not init.has_edge(ReducedNode.FREE, gripper): 342 | # get in_edge 343 | ball = [ 344 | b 345 | for b in init.predecessors(gripper) 346 | if b in typed[ReducedNode.BALLS] 347 | ] 348 | if ball: 349 | actions.append( 350 | Action("drop", [ball[0].name, rooms[0].name, gripper.name]) 351 | ) 352 | 353 | # move all balls to first room 354 | for room in rooms: 355 | for obj in init.predecessors(room): 356 | if obj in typed[ReducedNode.BALLS]: 357 | actions.append(Action("move", [rooms[0].name, room.name])) 358 | actions.append( 359 | Action("pick", [obj.name, room.name, grippers[0].name]) 360 | ) 361 | actions.append(Action("move", [room.name, rooms[0].name])) 362 | actions.append( 363 | Action("drop", [obj.name, rooms[0].name, grippers[0].name]) 364 | ) 365 | 366 | # Process goal scene 367 | for room in rooms: 368 | for obj in goal.predecessors(room): 369 | if obj in typed[ReducedNode.BALLS]: 370 | actions.append( 371 | Action("pick", [obj.name, rooms[0].name, grippers[0].name]) 372 | ) 373 | actions.append(Action("move", [rooms[0].name, room.name])) 374 | actions.append( 375 | Action("drop", [obj.name, room.name, grippers[0].name]) 376 | ) 377 | actions.append(Action("move", [room.name, rooms[0].name])) 378 | 379 | # pick up balls in first room tied to grippers 380 | for gripper in grippers: 381 | for ball in typed[ReducedNode.BALLS]: 382 | if goal.has_edge(ball, gripper): 383 | actions.append( 384 | Action("pick", [ball.name, rooms[0].name, gripper.name]) 385 | ) 386 | 387 | # move to room with robby 388 | goal_room = next(iter(goal.successors(ReducedNode.ROBBY)), None) 389 | if goal_room: 390 | 
actions.append(Action("move", [rooms[0].name, goal_room.name])) 391 | 392 | return actions 393 | -------------------------------------------------------------------------------- /planetarium/oracles/oracle.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from pddl.core import Action 4 | 5 | from .. import graph, reduced_graph 6 | 7 | 8 | class Oracle(abc.ABC): 9 | 10 | @staticmethod 11 | def reduce( 12 | scene: graph.SceneGraph | graph.ProblemGraph, 13 | ) -> reduced_graph.ReducedSceneGraph | reduced_graph.ReducedProblemGraph: 14 | """Reduces a scene graph to a reduced scene graph. 15 | 16 | Args: 17 | scene (graph.SceneGraph | graph.ProblemGraph): The scene graph to reduce. 18 | 19 | Returns: 20 | ReducedSceneGraph | ReducedProblemGraph: The reduced scene graph. 21 | """ 22 | match scene: 23 | case graph.ProblemGraph( 24 | _constants=constants, 25 | _predicates=predicates, 26 | _domain=domain, 27 | _requirements=requirements, 28 | ): 29 | reduced = reduced_graph.ReducedProblemGraph( 30 | constants=constants, 31 | domain=domain, 32 | requirements=requirements, 33 | ) 34 | case graph.SceneGraph( 35 | constants=constants, 36 | _predicates=predicates, 37 | scene=scene, 38 | _domain=domain, 39 | _requirements=requirements, 40 | ): 41 | reduced = reduced_graph.ReducedSceneGraph( 42 | constants=constants, 43 | domain=domain, 44 | scene=scene, 45 | requirements=requirements, 46 | ) 47 | case _: 48 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.") 49 | 50 | for predicate in predicates: 51 | params = predicate["parameters"] 52 | reduced_edge = graph.PlanGraphEdge( 53 | predicate=predicate["typing"], 54 | scene=predicate.get("scene"), 55 | ) 56 | match params: 57 | case [x]: 58 | reduced.add_node(x, reduced_edge) 59 | case [x, y]: 60 | reduced.add_edge(x, y, reduced_edge) 61 | case _: 62 | raise ValueError("Predicate parameters must be 1 or 2.") 63 | 64 | @staticmethod 65 | def inflate( 66 | scene: 
reduced_graph.ReducedSceneGraph | reduced_graph.ReducedProblemGraph, 67 | ) -> graph.SceneGraph | graph.ProblemGraph: 68 | """Inflates a reduced scene graph to a scene graph. 69 | 70 | Args: 71 | scene (ReducedSceneGraph | ReducedProblemGraph): The reduced scene graph to inflate. 72 | 73 | Returns: 74 | SceneGraph | ProblemGraph: The inflated scene graph. 75 | """ 76 | constants = [] 77 | predicates = [] 78 | 79 | for node in scene.nodes: 80 | if ( 81 | not isinstance(node.node, reduced_graph.ReducedNode) 82 | and node.label == graph.Label.CONSTANT 83 | ): 84 | # add constants 85 | constants.append({"name": node.node, "typing": node.typing}) 86 | 87 | for u, v, edge in scene.edges: 88 | match u.node: 89 | case pred_node if isinstance(pred_node, reduced_graph.ReducedNode): 90 | predicates.append( 91 | { 92 | "typing": edge.predicate, 93 | "parameters": [v.node], 94 | "scene": edge.scene, 95 | } 96 | ) 97 | case _: 98 | predicates.append( 99 | { 100 | "typing": edge.predicate, 101 | "parameters": [u.node, v.node], 102 | "scene": edge.scene, 103 | } 104 | ) 105 | 106 | if isinstance(scene, reduced_graph.ReducedProblemGraph): 107 | return graph.ProblemGraph( 108 | constants, 109 | [pred for pred in predicates if pred["scene"] == graph.Scene.INIT], 110 | [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL], 111 | domain=scene.domain, 112 | requirements=scene._requirements, 113 | ) 114 | else: 115 | return graph.SceneGraph( 116 | constants, 117 | predicates, 118 | domain=scene.domain, 119 | scene=scene.scene, 120 | requirements=scene._requirements, 121 | ) 122 | 123 | @staticmethod 124 | @abc.abstractmethod 125 | def fully_specify( 126 | problem: graph.ProblemGraph, 127 | return_reduced: bool = False, 128 | ) -> graph.ProblemGraph | reduced_graph.ReducedProblemGraph: 129 | """Fully specifies a goal state. 130 | 131 | Args: 132 | problem (graph.ProblemGraph): The problem graph to fully specify. 
133 | return_reduced (bool, optional): Whether to return a reduced problem graph. 134 | Defaults to False. 135 | 136 | Returns: 137 | ProblemGraph | ReducedProblemGraph: The fully specified problem graph. 138 | """ 139 | 140 | @staticmethod 141 | def plan(problem: graph.ProblemGraph) -> list[Action]: 142 | """Generates a plan for a problem graph. 143 | 144 | Args: 145 | problem (graph.ProblemGraph): The problem graph to plan. 146 | 147 | Returns: 148 | list[Action]: The plan for the problem graph. 149 | """ 150 | raise NotImplementedError("Planning not supported for this oracle.") 151 | -------------------------------------------------------------------------------- /planetarium/reduced_graph.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import copy 4 | import aenum 5 | 6 | from planetarium import graph 7 | 8 | 9 | class ReducedNode(str, aenum.Enum): 10 | 11 | @classmethod 12 | def register( 13 | cls, 14 | attrs: dict[str, str], 15 | domain: str | None = None, 16 | ) -> set["ReducedNode"]: 17 | """Extend the ReducedNode enum with new nodes. 18 | 19 | Args: 20 | attrs (dict[str, str]): The new nodes to add. 21 | domain (str, optional): The domain to extend. Defaults to None. 22 | 23 | Raises: 24 | AttributeError: If the domain is already extended. 25 | 26 | Returns: 27 | set[ReducedNode]: The set of new nodes. 28 | """ 29 | nodes = set() 30 | for name, value in attrs.items(): 31 | if name not in cls.__members__: 32 | aenum.extend_enum(cls, name, value) 33 | nodes.add(cls[name]) 34 | 35 | if isinstance(domain, str): 36 | if domain in ReducedNodes: 37 | raise AttributeError( 38 | f"ReducedNode already extended for domain {domain}." 
class ReducedSceneGraph(graph.PlanGraph):
    """Plan graph for a single scene, seeded with the domain's reduced nodes."""

    def __init__(
        self,
        constants: list[dict[str, Any]],
        domain: str,
        scene: graph.Scene | None = None,
        requirements: tuple[str, ...] = (),
    ):
        """Create the graph and add one predicate node per registered ReducedNode.

        Args:
            constants (list[dict[str, Any]]): Constants forwarded to PlanGraph.
            domain (str): Domain name; also the key used to look up the
                reduced nodes registered in ``ReducedNodes``.
            scene (graph.Scene | None, optional): Scene this graph represents.
                Defaults to None.
            requirements (tuple[str, ...], optional): Requirements forwarded
                to PlanGraph. Defaults to ().
        """
        super().__init__(constants, domain=domain, requirements=requirements)
        self.scene = scene

        # Domains with no registered reduced nodes simply add nothing here.
        for e in ReducedNodes.get(domain, set()):
            predicate = e.value
            self.add_node(
                graph.PlanGraphNode(
                    e,
                    name=predicate,
                    label=graph.Label.PREDICATE,
                    typing=predicate,
                ),
            )

    def _add_predicate(
        self,
        predicate: dict[str, Any],
        scene: graph.Scene | None = None,
    ):
        """Reject direct predicate insertion on a reduced graph.

        Raises:
            AttributeError: Always.
        """
        raise AttributeError(
            "ReducedSceneGraph does not support adding predicates directly."
        )
@staticmethod
def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGraph":
    """Merge an initial and a goal reduced scene into one reduced problem graph.

    Constants, domain and requirements are taken from ``init``. Every edge is
    deep-copied and re-tagged with the scene of the graph it came from, so the
    inputs are left untouched.

    Args:
        init (ReducedSceneGraph): The reduced initial scene.
        goal (ReducedSceneGraph): The reduced goal scene.

    Returns:
        ReducedProblemGraph: The joined reduced problem graph.
    """
    problem = ReducedProblemGraph(
        init.constants,
        domain=init.domain,
        requirements=init._requirements,
    )

    # Only predicate-labelled nodes that are not ReducedNode members are
    # copied over; the constructor call above already added the ReducedNode
    # predicate nodes for this domain.
    for node in (*init.nodes, *goal.nodes):
        if (
            node not in problem.nodes
            and node.label == graph.Label.PREDICATE
            and not isinstance(node.node, ReducedNode)
        ):
            problem.add_node(copy.deepcopy(node))

    # Tag each copied edge *before* inserting it, for both scenes alike, so
    # the stored edge never depends on being mutated after add_edge.
    for source, scene_tag in ((init, graph.Scene.INIT), (goal, graph.Scene.GOAL)):
        for u, v, edge in source.edges:
            edge = copy.deepcopy(edge)
            edge.scene = scene_tag
            problem.add_edge(u, v, edge)

    return problem
Piedrahita-Velez ", 8 | ] 9 | readme = "README.md" 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.10" 13 | networkx = "^3.2.1" 14 | pddl = {git = "https://github.com/maxzuo/pddl.git"} 15 | pyyaml = "^6.0.1" 16 | jinja2 = "^3.1.4" 17 | rustworkx = "^0.14.2" 18 | matplotlib = "^3.9.0" 19 | aenum = "^3.1.15" 20 | 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | ruff = "^0.1.7" 24 | pytest = "^7.4.3" 25 | mypy = "^1.7.1" 26 | pytest-cov = "^4.1.0" 27 | pytest-timeout = "^2.2.0" 28 | pytest-subtests = "^0.12.1" 29 | black = {extras = ["jupyter"], version = "^24.4.2"} 30 | pytest-mock = "^3.14.0" 31 | 32 | [tool.poetry.group.all] 33 | optional = true 34 | 35 | [tool.poetry.group.all.dependencies] 36 | lark = "^1.1.9" 37 | vllm = "^0.5.0.post1" 38 | python-dotenv = "^1.0.1" 39 | datasets = "^2.20.0" 40 | peft = "^0.11.1" 41 | trl = "^0.9.4" 42 | bitsandbytes = "^0.43.1" 43 | openai = "^1.35.3" 44 | 45 | [build-system] 46 | requires = ["poetry-core"] 47 | build-backend = "poetry.core.masonry.api" 48 | 49 | [tool.mypy] 50 | allow_redefinition = true 51 | ignore_missing_imports = true 52 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BatsResearch/planetarium/f257ae9e68b813aeace00cc683d0c2ad5b57a157/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_evaluate.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | import pytest 3 | 4 | import planetarium 5 | 6 | from .problem_fixtures import ( 7 | blocksworld_underspecified, 8 | blocksworld_missing_clears, 9 | blocksworld_missing_ontables, 10 | blocksworld_fully_specified, 11 | blocksworld_invalid_1, 12 | rover_single_line_equiv, 13 | rover_single_line_equiva, 14 | rover_single_line_equiv_1, 15 | rover_single_line_equiv_1a, 
16 | rover_single_line_equiv_1b, 17 | floortile_no_white1, 18 | floortile_no_white1a, 19 | floortile_disconnected_tile1, 20 | floortile_disconnected_tile1a, 21 | floortile_one_color_one_robot1, 22 | floortile_one_color_one_robot1a, 23 | floortile_no_available_colors, 24 | floortile_no_available_colors_a, 25 | ) 26 | 27 | 28 | @pytest.fixture 29 | def blocksworld_wrong_init(): 30 | """ 31 | Fixture providing a fully specified blocksworld problem with wrong init. 32 | """ 33 | return """ 34 | (define (problem staircase) 35 | (:domain blocksworld) 36 | (:objects 37 | b1 b2 b3 b4 b5 b6 38 | ) 39 | (:init 40 | (arm-empty) 41 | (on-table b1) 42 | (clear b1) 43 | (on-table b2) 44 | (clear b2) 45 | (on-table b3) 46 | (clear b3) 47 | (on-table b4) 48 | (clear b4) 49 | (on-table b5) 50 | (on b6 b5) 51 | (clear b6) 52 | ) 53 | (:goal 54 | (and 55 | (on-table b1) 56 | (clear b1) 57 | 58 | (on-table b2) 59 | (on b3 b2) 60 | (clear b3) 61 | 62 | (on-table b4) 63 | (on b5 b4) 64 | (on b6 b5) 65 | (clear b6) 66 | ) 67 | ) 68 | ) 69 | """ 70 | 71 | 72 | @pytest.fixture 73 | def blocksworld_fully_specified_wrong_domain(): 74 | """ 75 | Fixture providing a fully specified blocksworld problem with wrong domain. 76 | """ 77 | return """ 78 | (define (problem staircase) 79 | (:domain blocksworld-wrong) 80 | (:objects 81 | b1 b2 b3 b4 b5 b6 82 | ) 83 | (:init 84 | (arm-empty) 85 | (on-table b1) 86 | (clear b1) 87 | (on-table b2) 88 | (clear b2) 89 | (on-table b3) 90 | (clear b3) 91 | (on-table b4) 92 | (clear b4) 93 | (on-table b5) 94 | (clear b5) 95 | (on-table b6) 96 | (clear b6) 97 | ) 98 | (:goal 99 | (and 100 | (on-table b1) 101 | (clear b1) 102 | 103 | (on-table b2) 104 | (on b3 b2) 105 | (clear b3) 106 | 107 | (on-table b4) 108 | (on b5 b4) 109 | (on b6 b5) 110 | (clear b6) 111 | ) 112 | ) 113 | ) 114 | """ 115 | 116 | 117 | @pytest.fixture 118 | def blocksworld_unsolveable(): 119 | """ 120 | Fixture providing a fully specified blocksworld problem with wrong init. 
class TestEvaluate:
    """
    Checks of ``planetarium.evaluate`` on pairs of PDDL problem descriptions.
    """

    def test_evaluate_equivalent(
        self,
        subtests,
        blocksworld_missing_clears,
        blocksworld_fully_specified,
        blocksworld_missing_ontables,
        blocksworld_underspecified,
    ):
        """
        Every pairing of the three equivalent descriptions evaluates all-true,
        and the underspecified description is equivalent to itself.
        """
        named = [
            ("blocksworld_missing_clears", blocksworld_missing_clears),
            ("blocksworld_fully_specified", blocksworld_fully_specified),
            ("blocksworld_missing_ontables", blocksworld_missing_ontables),
        ]
        for left_name, left in named:
            for right_name, right in named:
                with subtests.test(f"{left_name} equals {right_name}"):
                    assert all(planetarium.evaluate(left, right))

        with subtests.test(
            "blocksworld_underspecified equals blocksworld_underspecified"
        ):
            outcome = planetarium.evaluate(
                blocksworld_underspecified, blocksworld_underspecified
            )
            assert all(outcome)

    def test_evaluate_inequivalent(
        self,
        subtests,
        blocksworld_missing_clears,
        blocksworld_fully_specified,
        blocksworld_missing_ontables,
        blocksworld_underspecified,
        blocksworld_wrong_init,
        blocksworld_unparseable,
        blocksworld_unsolveable,
    ):
        """
        Inequivalent pairings yield the expected result triple for each kind
        of mismatch.
        """
        references = [
            ("blocksworld_missing_clears", blocksworld_missing_clears),
            ("blocksworld_fully_specified", blocksworld_fully_specified),
            ("blocksworld_missing_ontables", blocksworld_missing_ontables),
        ]
        # (candidate name, candidate description, expected evaluate() result)
        candidates = [
            (
                "blocksworld_underspecified",
                blocksworld_underspecified,
                (True, True, False),
            ),
            ("blocksworld_wrong_init", blocksworld_wrong_init, (True, True, False)),
            (
                "blocksworld_unparseable",
                blocksworld_unparseable,
                (False, False, False),
            ),
            (
                "blocksworld_unsolveable",
                blocksworld_unsolveable,
                (True, False, False),
            ),
        ]
        for name, desc in references:
            for other_name, other_desc, expected in candidates:
                with subtests.test(f"{name} not equals {other_name}"):
                    assert planetarium.evaluate(desc, other_desc) == expected

        with subtests.test(
            "blocksworld_underspecified not equals blocksworld_wrong_init"
        ):
            outcome = planetarium.evaluate(
                blocksworld_underspecified, blocksworld_wrong_init
            )
            assert outcome == (True, True, False)

    def test_rover_single_equivalent(
        self,
        subtests,
        rover_single_line_equiv,
        rover_single_line_equiva,
        rover_single_line_equiv_1,
        rover_single_line_equiv_1a,
        rover_single_line_equiv_1b,
    ):
        """
        Equivalent rover-single descriptions evaluate all-true with the
        "lama-first" alias.
        """
        with subtests.test("rover_single_line_equiv equals rover_single_line_equiva"):
            outcome = planetarium.evaluate(
                rover_single_line_equiv,
                rover_single_line_equiva,
                alias="lama-first",
            )
            assert all(outcome)

        named = [
            ("rover_single_line_equiv_1", rover_single_line_equiv_1),
            ("rover_single_line_equiv_1a", rover_single_line_equiv_1a),
            ("rover_single_line_equiv_1b", rover_single_line_equiv_1b),
        ]
        for (name1, desc1), (name2, desc2) in product(named, named):
            with subtests.test(f"{name1} equals {name2}"):
                assert all(planetarium.evaluate(desc1, desc2, alias="lama-first"))

    def test_floortile_equivalent(
        self,
        subtests,
        floortile_no_white1,
        floortile_no_white1a,
        floortile_disconnected_tile1,
        floortile_disconnected_tile1a,
        floortile_one_color_one_robot1,
        floortile_one_color_one_robot1a,
        floortile_no_available_colors,
        floortile_no_available_colors_a,
    ):
        """
        Each floortile description is equivalent to its paired variant with
        the "lama-first" alias.
        """
        # (first name, first description, second name, second description)
        pairs = [
            (
                "floortile_no_white1",
                floortile_no_white1,
                "floortile_no_white1a",
                floortile_no_white1a,
            ),
            (
                "floortile_disconnected_tile1",
                floortile_disconnected_tile1,
                "floortile_disconnected_tile1a",
                floortile_disconnected_tile1a,
            ),
            (
                "floortile_one_color_one_robot1",
                floortile_one_color_one_robot1,
                "floortile_one_color_one_robot1a",
                floortile_one_color_one_robot1a,
            ),
            (
                "floortile_no_available_colors",
                floortile_no_available_colors,
                "floortile_no_available_colors_a",
                floortile_no_available_colors_a,
            ),
        ]

        for n1, d1, n2, d2 in pairs:
            with subtests.test(f"{n2} equals {n1}"):
                assert all(planetarium.evaluate(d1, d2, alias="lama-first"))
class TestGraph:
    """
    Checks on the scene graph built from the PDDL problem fixture.
    """

    def test_constant_node_names(self, sgraph):
        """
        Every constant node carries one of the expected object names.
        """
        expected = {"p0", "p1", "f0", "f1", "f2", "f3"}
        for node in sgraph.constant_nodes:
            assert node.name in expected

    def test_constant_node_size(self, sgraph):
        """
        The graph holds exactly six constant nodes.
        """
        count = len(sgraph.constant_nodes)
        assert count == 6

    def test_predicate_names(self, sgraph):
        """
        Every predicate node id splits on "-" into three parts and is either
        above/origin/destin-<a>-<b> or lift-at-<a>.
        """
        binary_heads = ("above", "origin", "destin")
        for predicate in sgraph.predicate_nodes:
            parts = predicate.node.split("-")
            is_binary = parts[0] in binary_heads
            is_lift_at = parts[:2] == ["lift", "at"]
            assert len(parts) == 3 and (is_binary or is_lift_at)

    def test_plot(self, sgraph):
        """
        Plotting the graph completes without raising.
        """
        sgraph.plot()
        assert True
41 | """ 42 | 43 | @pytest.fixture 44 | def source(self): 45 | """Fixture for a valid source constant.""" 46 | return graph.PlanGraphNode( 47 | "o1", "o1", typing=["t1", "t2"], label=graph.Label.CONSTANT 48 | ) 49 | 50 | @pytest.fixture 51 | def target(self): 52 | """Fixture for a valid target constant.""" 53 | return graph.PlanGraphNode( 54 | "c1", "c1", typing=["t1", "t2"], label=graph.Label.CONSTANT 55 | ) 56 | 57 | @pytest.fixture 58 | def source_incorrect_label(self): 59 | """Fixture for a source constant with an incorrect label.""" 60 | return graph.PlanGraphNode( 61 | "o1", "o1", typing=["t1", "t2"], label=graph.Label.PREDICATE 62 | ) 63 | 64 | @pytest.fixture 65 | def target_incorrect_label(self): 66 | """Fixture for a target constant with an incorrect label.""" 67 | return graph.PlanGraphNode( 68 | "c1", "c1", typing=["t1", "t2"], label=graph.Label.PREDICATE 69 | ) 70 | 71 | @pytest.fixture 72 | def source_incorrect_typing(self): 73 | """Fixture for a source constant with incorrect typing.""" 74 | return graph.PlanGraphNode( 75 | "o1", "o1", typing=["ty1", "ty2"], label=graph.Label.CONSTANT 76 | ) 77 | 78 | @pytest.fixture 79 | def target_incorrect_typing(self): 80 | """Fixture for a target constant with incorrect typing.""" 81 | return graph.PlanGraphNode( 82 | "c1", "c1", typing=["ty1", "ty2"], label=graph.Label.CONSTANT 83 | ) 84 | 85 | @pytest.fixture 86 | def mapping(self): 87 | """Fixture for a valid mapping between source and target constants.""" 88 | return {"o1": "c1"} 89 | 90 | @pytest.fixture 91 | def mapping_incorrect(self): 92 | """Fixture for an incorrect mapping between source and target constants.""" 93 | return {"o1": "o1"} 94 | 95 | def test_correct_matching(self, source, target, mapping): 96 | """Test correct matching between source and target constants.""" 97 | assert metric._node_matching(source, target, None) 98 | assert metric._node_matching(source, target, mapping) 99 | 100 | assert metric._same_typing(source, target) 101 | assert 
metric._preserves_mapping(source, target, mapping) 102 | 103 | def test_incorrect_label( 104 | self, source, target, source_incorrect_label, target_incorrect_label, mapping 105 | ): 106 | """Test incorrect label matching between source and target constants.""" 107 | assert not metric._node_matching(source, target_incorrect_label, None) 108 | assert not metric._node_matching(source_incorrect_label, target, None) 109 | assert not metric._node_matching(source, target_incorrect_label, mapping) 110 | assert not metric._node_matching(source_incorrect_label, target, mapping) 111 | 112 | assert not metric._preserves_mapping(source, target_incorrect_label, mapping) 113 | assert not metric._preserves_mapping(source_incorrect_label, target, mapping) 114 | 115 | assert not metric._same_typing(source, target_incorrect_label) 116 | assert not metric._same_typing(source_incorrect_label, target) 117 | 118 | def test_incorrect_typing( 119 | self, source, target, source_incorrect_typing, target_incorrect_typing, mapping 120 | ): 121 | """Test incorrect typing between source and target constants.""" 122 | assert not metric._node_matching(source, target_incorrect_typing, None) 123 | assert not metric._node_matching(source_incorrect_typing, target, None) 124 | assert not metric._node_matching(source, target_incorrect_typing, mapping) 125 | assert not metric._node_matching(source_incorrect_typing, target, mapping) 126 | 127 | assert metric._preserves_mapping(source, target_incorrect_typing, mapping) 128 | assert metric._preserves_mapping(source_incorrect_typing, target, mapping) 129 | 130 | assert not metric._same_typing(source, target_incorrect_typing) 131 | assert not metric._same_typing(source_incorrect_typing, target) 132 | 133 | def test_incorrect_mapping(self, source, target, mapping_incorrect): 134 | """Test incorrect mapping between source and target constants.""" 135 | assert not metric._node_matching(source, target, mapping_incorrect) 136 | assert not 
metric._preserves_mapping(source, target, mapping_incorrect) 137 | 138 | 139 | class TestPredicateMatching: 140 | """ 141 | Test suite for predicate matching functions in the metric module. 142 | """ 143 | 144 | @pytest.fixture 145 | def source(self): 146 | """Fixture for a valid source predicate node.""" 147 | return graph.PlanGraphNode( 148 | "f-a1-a2", 149 | "f-a1-a2", 150 | typing="f", 151 | label=graph.Label.PREDICATE, 152 | ) 153 | 154 | @pytest.fixture 155 | def target(self): 156 | """Fixture for a valid target predicate node.""" 157 | return graph.PlanGraphNode( 158 | "f-a1-a2", 159 | "f-a1-a2", 160 | typing="f", 161 | label=graph.Label.PREDICATE, 162 | ) 163 | 164 | @pytest.fixture 165 | def source_incorrect_label(self): 166 | """Fixture for a source predicate node with an incorrect label.""" 167 | return graph.PlanGraphNode( 168 | "f-a1-a2", 169 | "f-a1-a2", 170 | typing="f", 171 | label=graph.Label.CONSTANT, 172 | ) 173 | 174 | @pytest.fixture 175 | def target_incorrect_label(self): 176 | """Fixture for a target predicate node with an incorrect label.""" 177 | return graph.PlanGraphNode( 178 | "f-a1-a2", 179 | "f-a1-a2", 180 | typing="f", 181 | label=graph.Label.CONSTANT, 182 | ) 183 | 184 | def test_correct_matching(self, source, target): 185 | """Test correct matching between source and target predicate nodes.""" 186 | assert metric._node_matching(source, target, None) 187 | 188 | def test_incorrect_label( 189 | self, 190 | source, 191 | target, 192 | source_incorrect_label, 193 | target_incorrect_label, 194 | ): 195 | """Test incorrect label matching between source and target predicate nodes.""" 196 | assert not metric._node_matching(source, target_incorrect_label, None) 197 | assert not metric._node_matching(source_incorrect_label, target, None) 198 | 199 | 200 | class TestMetrics: 201 | """ 202 | Test suite for metrics functions in the metric module. 
203 | """ 204 | 205 | def test_map(self, problem_string, two_initial_problem_string): 206 | """Test the mapping function on graph pairs.""" 207 | problem_graph = builder.build(problem_string) 208 | problem_graph2 = builder.build(two_initial_problem_string) 209 | 210 | assert metric.isomorphic(problem_graph, problem_graph) 211 | assert not metric.isomorphic(problem_graph, problem_graph2) 212 | 213 | def test_validate(self, problem_string, two_initial_problem_string): 214 | """Test the validation function on graph pairs.""" 215 | problem_graph = builder.build(problem_string) 216 | problem_graph2 = builder.build(two_initial_problem_string) 217 | 218 | assert metric.equals(problem_graph, problem_graph, is_placeholder=True) 219 | assert not metric.equals( 220 | problem_graph, 221 | problem_graph2, 222 | is_placeholder=True, 223 | ) 224 | 225 | def test_swap(self, swap_problem_string, wrong_swap_problem_string): 226 | """ 227 | Test the distance function on graph pairs. 228 | """ 229 | swap_problem = builder.build(swap_problem_string) 230 | wrong_swap = builder.build(wrong_swap_problem_string) 231 | 232 | # Test validate 233 | assert metric.equals(swap_problem, swap_problem, is_placeholder=False) 234 | assert not metric.equals(swap_problem, wrong_swap, is_placeholder=False) 235 | assert metric.equals(swap_problem, wrong_swap, is_placeholder=True) 236 | 237 | def test_move(self, move_problem_string, wrong_move_problem_string): 238 | """ 239 | Test the distance function on graph pairs. 
240 | """ 241 | move_problem = builder.build(move_problem_string) 242 | wrong_move = builder.build(wrong_move_problem_string) 243 | 244 | # Test validate 245 | assert metric.equals(move_problem, move_problem, is_placeholder=True) 246 | assert not metric.equals(move_problem, wrong_move, is_placeholder=True) 247 | 248 | def test_blocksworld_equivalence( 249 | self, 250 | subtests, 251 | blocksworld_fully_specified, 252 | blocksworld_missing_clears, 253 | blocksworld_missing_ontables, 254 | blocksworld_underspecified, 255 | ): 256 | """Test the equivalence of blocksworld problems.""" 257 | p1 = builder.build(blocksworld_fully_specified) 258 | p2 = builder.build(blocksworld_missing_clears) 259 | p3 = builder.build(blocksworld_missing_ontables) 260 | p4 = builder.build(blocksworld_underspecified) 261 | 262 | p1 = oracle.fully_specify(p1) 263 | p2 = oracle.fully_specify(p2) 264 | p3 = oracle.fully_specify(p3) 265 | p4 = oracle.fully_specify(p4) 266 | 267 | P = ( 268 | ("blocksworld_fully_specified", p1), 269 | ("blocksworld_missing_clears", p2), 270 | ("blocksworld_missing_ontables", p3), 271 | ("blocksworld_underspecified", p4), 272 | ) 273 | 274 | # equivalence to itself 275 | for name, p in P: 276 | with subtests.test(f"{name} equals {name}"): 277 | assert metric.equals(p, p, is_placeholder=True) 278 | assert metric.equals(p, p, is_placeholder=False) 279 | 280 | # check invalid equivalence 281 | 282 | for idx1, idx2 in ( 283 | (0, 3), 284 | (1, 3), 285 | (2, 3), 286 | ): 287 | (name1, p1), (name2, p2) = P[idx1], P[idx2] 288 | with subtests.test(f"{name1} not equals {name2}"): 289 | assert not metric.equals(p1, p2, is_placeholder=True) 290 | assert not metric.equals(p1, p2, is_placeholder=False) 291 | assert not metric.equals(p2, p1, is_placeholder=True) 292 | assert not metric.equals(p2, p1, is_placeholder=False) 293 | 294 | def test_rover_single_eqquivalence( 295 | self, 296 | subtests, 297 | rover_single_line_fully_specified_4, 298 | 
rover_single_line_fully_specified_4a, 299 | ): 300 | """Test the equivalence of rover single line problems.""" 301 | p1 = builder.build(rover_single_line_fully_specified_4) 302 | p2 = builder.build(rover_single_line_fully_specified_4a) 303 | 304 | p1 = oracle.fully_specify(p1) 305 | p2 = oracle.fully_specify(p2) 306 | 307 | # equivalence to itself 308 | assert metric.equals(p1, p1, is_placeholder=True) 309 | assert metric.equals(p2, p2, is_placeholder=False) 310 | 311 | # the two problems must be equivalent to each other, in both argument orders 312 | assert metric.equals(p1, p2, is_placeholder=True) 313 | assert metric.equals(p1, p2, is_placeholder=False) 314 | assert metric.equals(p2, p1, is_placeholder=True) 315 | assert metric.equals(p2, p1, is_placeholder=False) 316 | -------------------------------------------------------------------------------- /tests/test_oracle.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from planetarium import builder, graph, oracle 4 | 5 | from .problem_fixtures import ( 6 | blocksworld_fully_specified, 7 | blocksworld_missing_clears, 8 | blocksworld_missing_ontables, 9 | blocksworld_underspecified, 10 | blocksworld_underspecified_arm, 11 | blocksworld_holding, 12 | gripper_fully_specified, 13 | gripper_no_goal_types, 14 | gripper_fully_specified_not_strict, 15 | gripper_no_robby, 16 | gripper_underspecified_1, 17 | gripper_underspecified_2, 18 | gripper_underspecified_3, 19 | gripper_no_robby_init, 20 | rover_single_line_fully_specified, 21 | rover_single_line_fully_specified_1, 22 | rover_single_line_fully_specified_2, 23 | rover_single_line_fully_specified_3, 24 | rover_single_line_fully_specified_4, 25 | floortile_fully_specified, 26 | floortile_underspecified_directions, 27 | floortile_no_white1, 28 | floortile_no_white1a, 29 | floortile_no_white2, 30 | floortile_no_available_colors, 31 | floortile_disconnected_tile_no_white, 32 | floortile_disconnected_tile1, 33 | floortile_disconnected_tile1a, 34 |
floortile_one_color_one_robot1, 35 | ) 36 | 37 | 38 | def reduce_and_inflate(scene: graph.SceneGraph) -> bool: 39 | """Reduce a scene, inflate it again, and check that the result equals the original. 40 | 41 | Args: 42 | scene (graph.SceneGraph): The scene to test 43 | 44 | Returns: 45 | bool: True if the respecified scene is equal to the original. 46 | """ 47 | reduced = oracle.reduce(scene, domain=scene.domain) 48 | respecified = oracle.inflate(reduced, domain=scene.domain) 49 | return scene == respecified 50 | 51 | 52 | class TestBlocksworldOracle: 53 | """ 54 | Test suite for the blocksworld oracle. 55 | """ 56 | 57 | def test_fully_specified(self, blocksworld_fully_specified): 58 | """ 59 | Test that fully_specify is idempotent on the fully specified blocksworld problem. 60 | """ 61 | problem = builder.build(blocksworld_fully_specified) 62 | full = oracle.fully_specify(problem) 63 | assert oracle.fully_specify(full) == full 64 | 65 | def test_missing_clears(self, blocksworld_missing_clears): 66 | """ 67 | Test that fully_specify is idempotent on the blocksworld problem with missing clears. 68 | """ 69 | problem = builder.build(blocksworld_missing_clears) 70 | full = oracle.fully_specify(problem) 71 | assert oracle.fully_specify(full) == full 72 | 73 | def test_missing_ontables(self, blocksworld_missing_ontables): 74 | """ 75 | Test that fully_specify is idempotent on the blocksworld problem with missing ontables. 76 | """ 77 | problem = builder.build(blocksworld_missing_ontables) 78 | full = oracle.fully_specify(problem) 79 | assert oracle.fully_specify(full) == full 80 | 81 | def test_missing_ontables_and_clears(self, blocksworld_underspecified): 82 | """ 83 | Test that fully_specify is idempotent on the underspecified blocksworld problem (missing ontables and clears).
84 | """ 85 | problem = builder.build(blocksworld_underspecified) 86 | full = oracle.fully_specify(problem) 87 | assert oracle.fully_specify(full) == full 88 | 89 | def test_inflate( 90 | self, 91 | subtests, 92 | blocksworld_fully_specified, 93 | blocksworld_missing_clears, 94 | blocksworld_missing_ontables, 95 | blocksworld_underspecified, 96 | blocksworld_underspecified_arm, 97 | blocksworld_holding, 98 | ): 99 | """ 100 | Test the inflate function. 101 | """ 102 | 103 | for name, desc in { 104 | "blocksworld_fully_specified": blocksworld_fully_specified, 105 | "blocksworld_missing_clears": blocksworld_missing_clears, 106 | "blocksworld_missing_ontables": blocksworld_missing_ontables, 107 | "blocksworld_underspecified": blocksworld_underspecified, 108 | "blocksworld_underspecified_arm": blocksworld_underspecified_arm, 109 | "blocksworld_holding": blocksworld_holding, 110 | }.items(): 111 | problem = builder.build(desc) 112 | init, goal = problem.decompose() 113 | with subtests.test(name): 114 | assert reduce_and_inflate(init) 115 | assert reduce_and_inflate(goal) 116 | assert reduce_and_inflate(problem) 117 | 118 | assert problem == oracle.inflate( 119 | oracle.ReducedProblemGraph.join( 120 | oracle.reduce(init), 121 | oracle.reduce(goal), 122 | ) 123 | ) 124 | 125 | 126 | class TestGripperOracle: 127 | """ 128 | Test suite for the gripper oracle. 129 | """ 130 | 131 | def test_fully_specified( 132 | self, 133 | subtests, 134 | gripper_fully_specified, 135 | gripper_no_goal_types, 136 | gripper_fully_specified_not_strict, 137 | ): 138 | """ 139 | Test the fully specified gripper problem. 
140 | """ 141 | descs = [ 142 | ("gripper_fully_specified", gripper_fully_specified), 143 | ("gripper_no_goal_types", gripper_no_goal_types), 144 | ("gripper_fully_specified_not_strict", gripper_fully_specified_not_strict), 145 | ] 146 | for name, desc in descs: 147 | with subtests.test(name): 148 | problem = builder.build(desc) 149 | full = oracle.fully_specify(problem) 150 | assert oracle.fully_specify(full) == full 151 | 152 | def test_inflate( 153 | self, 154 | subtests, 155 | gripper_fully_specified, 156 | gripper_no_robby, 157 | gripper_underspecified_1, 158 | gripper_underspecified_2, 159 | gripper_underspecified_3, 160 | gripper_no_robby_init, 161 | ): 162 | """ 163 | Test the inflate function. 164 | """ 165 | 166 | descs = [ 167 | ("gripper_fully_specified", gripper_fully_specified), 168 | ("gripper_no_robby", gripper_no_robby), 169 | ("gripper_underspecified_1", gripper_underspecified_1), 170 | ("gripper_underspecified_2", gripper_underspecified_2), 171 | ("gripper_underspecified_3", gripper_underspecified_3), 172 | ("gripper_no_robby_init", gripper_no_robby_init), 173 | ] 174 | 175 | for name, desc in descs: 176 | problem = builder.build(desc) 177 | init, goal = problem.decompose() 178 | with subtests.test(name): 179 | assert reduce_and_inflate(init) 180 | assert reduce_and_inflate(goal) 181 | assert reduce_and_inflate(problem) 182 | 183 | def test_underspecified( 184 | self, 185 | gripper_underspecified_1, 186 | gripper_underspecified_2, 187 | ): 188 | problem = builder.build(gripper_underspecified_1) 189 | full = oracle.fully_specify(problem) 190 | assert oracle.fully_specify(full) == full 191 | 192 | problem = builder.build(gripper_underspecified_2) 193 | full = oracle.fully_specify(problem) 194 | assert oracle.fully_specify(full) == full 195 | 196 | 197 | class TestRoverSingleOracle: 198 | """ 199 | Test suite for the rover oracle. 
200 | """ 201 | 202 | def test_fully_specified( 203 | self, 204 | subtests, 205 | rover_single_line_fully_specified, 206 | rover_single_line_fully_specified_1, 207 | rover_single_line_fully_specified_2, 208 | rover_single_line_fully_specified_3, 209 | rover_single_line_fully_specified_4, 210 | ): 211 | """ 212 | Test the fully specified rover problem. 213 | """ 214 | descs = [ 215 | ("rover_single_line_fully_specified", rover_single_line_fully_specified), 216 | ( 217 | "rover_single_line_fully_specified_1", 218 | rover_single_line_fully_specified_1, 219 | ), 220 | ( 221 | "rover_single_line_fully_specified_2", 222 | rover_single_line_fully_specified_2, 223 | ), 224 | ( 225 | "rover_single_line_fully_specified_3", 226 | rover_single_line_fully_specified_3, 227 | ), 228 | ( 229 | "rover_single_line_fully_specified_4", 230 | rover_single_line_fully_specified_4, 231 | ), 232 | ] 233 | for name, desc in descs: 234 | with subtests.test(name): 235 | problem = builder.build(desc) 236 | full = oracle.fully_specify(problem) 237 | assert full == problem, "fully_specify(problem) == problem" 238 | assert oracle.fully_specify(full) == full, "fully_specify(fully_specify(problem)) == fully_specify(problem)" 239 | 240 | def test_inflate( 241 | self, 242 | subtests, 243 | rover_single_line_fully_specified, 244 | rover_single_line_fully_specified_1, 245 | rover_single_line_fully_specified_2, 246 | rover_single_line_fully_specified_3, 247 | rover_single_line_fully_specified_4, 248 | ): 249 | """ 250 | Test the inflate function. 
251 | """ 252 | descs = [ 253 | ("rover_single_line_fully_specified", rover_single_line_fully_specified), 254 | ( 255 | "rover_single_line_fully_specified_1", 256 | rover_single_line_fully_specified_1, 257 | ), 258 | ( 259 | "rover_single_line_fully_specified_2", 260 | rover_single_line_fully_specified_2, 261 | ), 262 | ( 263 | "rover_single_line_fully_specified_3", 264 | rover_single_line_fully_specified_3, 265 | ), 266 | ( 267 | "rover_single_line_fully_specified_4", 268 | rover_single_line_fully_specified_4, 269 | ), 270 | ] 271 | for name, desc in descs: 272 | problem = builder.build(desc) 273 | init, goal = problem.decompose() 274 | with subtests.test(name): 275 | assert reduce_and_inflate(init) 276 | assert reduce_and_inflate(goal) 277 | assert reduce_and_inflate(problem) 278 | 279 | 280 | class TestFloorTileOracle: 281 | """ 282 | Test suite for the floor tile oracle. 283 | """ 284 | 285 | def test_fully_specified( 286 | self, 287 | subtests, 288 | floortile_fully_specified, 289 | floortile_no_white1, 290 | floortile_no_white2, 291 | floortile_no_available_colors, 292 | floortile_disconnected_tile_no_white, 293 | floortile_disconnected_tile1, 294 | floortile_one_color_one_robot1, 295 | ): 296 | """ 297 | Test the fully specified floor tile problem. 
298 | """ 299 | descs = { 300 | "floortile_fully_specified": floortile_fully_specified, 301 | "floortile_no_white1": floortile_no_white1, 302 | "floortile_no_white2": floortile_no_white2, 303 | "floortile_no_available_colors": floortile_no_available_colors, 304 | "floortile_disconnected_tile_no_white": floortile_disconnected_tile_no_white, 305 | "floortile_disconnected_tile1": floortile_disconnected_tile1, 306 | "floortile_one_color_one_robot1": floortile_one_color_one_robot1, 307 | } 308 | for name, desc in descs.items(): 309 | problem = builder.build(desc) 310 | full = oracle.fully_specify(problem) 311 | with subtests.test(name): 312 | assert full == problem, name 313 | assert oracle.fully_specify(full) == full, name 314 | 315 | def test_under_specified( 316 | self, 317 | subtests, 318 | floortile_underspecified_directions, 319 | floortile_no_white1a, 320 | floortile_disconnected_tile1a, 321 | ): 322 | """ 323 | Test the under specified floor tile problem. 324 | """ 325 | descs = { 326 | "floortile_underspecified_directions": floortile_underspecified_directions, 327 | "floortile_no_white1a": floortile_no_white1a, 328 | "floortile_disconnected_tile1a": floortile_disconnected_tile1a, 329 | } 330 | for name, desc in descs.items(): 331 | problem = builder.build(desc) 332 | full = oracle.fully_specify(problem) 333 | with subtests.test(name): 334 | assert full != problem, name 335 | assert oracle.fully_specify(full) == full, name 336 | 337 | def test_infalte( 338 | self, 339 | subtests, 340 | floortile_fully_specified, 341 | floortile_no_white1, 342 | floortile_no_white2, 343 | floortile_no_available_colors, 344 | floortile_disconnected_tile_no_white, 345 | floortile_disconnected_tile1, 346 | floortile_one_color_one_robot1, 347 | floortile_underspecified_directions, 348 | floortile_no_white1a, 349 | floortile_disconnected_tile1a, 350 | ): 351 | """ 352 | Test the inflate function. 
353 | """ 354 | descs = { 355 | "floortile_fully_specified": floortile_fully_specified, 356 | "floortile_no_white1": floortile_no_white1, 357 | "floortile_no_white2": floortile_no_white2, 358 | "floortile_no_available_colors": floortile_no_available_colors, 359 | "floortile_disconnected_tile_no_white": floortile_disconnected_tile_no_white, 360 | "floortile_disconnected_tile1": floortile_disconnected_tile1, 361 | "floortile_one_color_one_robot1": floortile_one_color_one_robot1, 362 | "floortile_underspecified_directions": floortile_underspecified_directions, 363 | "floortile_no_white1a": floortile_no_white1a, 364 | "floortile_disconnected_tile1a": floortile_disconnected_tile1a, 365 | } 366 | for name, desc in descs.items(): 367 | problem = builder.build(desc) 368 | init, goal = problem.decompose() 369 | with subtests.test(name): 370 | assert reduce_and_inflate(init) 371 | assert reduce_and_inflate(goal) 372 | assert reduce_and_inflate(problem) 373 | 374 | 375 | class TestUnsupportedDomain: 376 | def test_reduce_and_inflate(self, gripper_fully_specified): 377 | problem = builder.build(gripper_fully_specified) 378 | init, goal = problem.decompose() 379 | 380 | with pytest.raises(oracle.DomainNotSupportedError): 381 | oracle.reduce(init, domain="gripper-modified") 382 | with pytest.raises(oracle.DomainNotSupportedError): 383 | reduced = oracle.reduce(goal, domain="gripper") 384 | oracle.inflate(reduced, domain="gripper-modified") 385 | 386 | def test_fully_specify(self, gripper_fully_specified): 387 | problem = builder.build(gripper_fully_specified) 388 | with pytest.raises(oracle.DomainNotSupportedError): 389 | oracle.fully_specify(problem, domain="gripper-modified") 390 | 391 | def test_plan(self, gripper_fully_specified): 392 | problem = builder.build(gripper_fully_specified) 393 | with pytest.raises(oracle.DomainNotSupportedError): 394 | oracle.plan(problem, domain="gripper-modified") 395 | 
-------------------------------------------------------------------------------- /tests/test_pddl.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from planetarium import builder 4 | 5 | from pddl.parser.problem import LenientProblemParser 6 | 7 | 8 | @pytest.fixture 9 | def problem_string(): 10 | """ 11 | Fixture providing a sample PDDL problem definition as a string. 12 | """ 13 | return """ 14 | (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0) 15 | (:domain miconic) 16 | (:objects p0 p1 - passenger 17 | f0 f1 f2 f3 - floor) 18 | 19 | (:init 20 | (above f0 f1) 21 | (above f0 f2) 22 | (above f0 f3) 23 | (above f1 f2) 24 | (above f1 f3) 25 | (above f2 f3) 26 | (origin p0 f3) 27 | (destin p0 f2) 28 | (origin p1 f1) 29 | (destin p1 f3) 30 | (lift-at f0)) 31 | 32 | (:goal (and 33 | (served p0) 34 | (served p1))) 35 | ) 36 | """ 37 | 38 | 39 | @pytest.fixture 40 | def two_initial_problem_string(): 41 | """ 42 | Fixture providing a sample PDDL problem definition as a string. 43 | """ 44 | return """ 45 | (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0) 46 | (:domain miconic) 47 | (:objects p0 p1 - passenger 48 | f0 f1 f2 f3 - floor) 49 | 50 | (:init 51 | (above f0 f1) 52 | (above f0 f2) 53 | (above f0 f3) 54 | (above f1 f2) 55 | (above f1 f3) 56 | (above f2 f3) 57 | (origin p0 f3) 58 | (destin p0 f2) 59 | (origin p1 f1) 60 | (destin p1 f3) 61 | (lift-at f0)) 62 | 63 | (:goal (and 64 | (above f0 f1) 65 | (above f0 f2) 66 | (above f0 f3) 67 | (above f1 f2) 68 | (above f1 f3) 69 | (above f2 f3) 70 | (origin p0 f3) 71 | (destin p0 f2) 72 | (origin p1 f1) 73 | (destin p1 f3) 74 | (lift-at f0))) 75 | ) 76 | """ 77 | 78 | 79 | @pytest.fixture 80 | def renamed_problem_string(): 81 | """ 82 | Fixture providing a sample PDDL problem definition as a string. 
83 | """ 84 | return """ 85 | (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0) 86 | (:domain miconic) 87 | (:objects p1 p0 - passenger 88 | f1 f2 f3 f0 - floor) 89 | 90 | (:init 91 | (above f1 f2) 92 | (above f1 f3) 93 | (above f1 f0) 94 | (above f2 f3) 95 | (above f2 f0) 96 | (above f3 f0) 97 | (origin p1 f0) 98 | (destin p1 f3) 99 | (origin p0 f2) 100 | (destin p0 f0) 101 | (lift-at f1)) 102 | 103 | (:goal (and 104 | (served p1) 105 | (served p0))) 106 | ) 107 | """ 108 | 109 | 110 | @pytest.fixture 111 | def wrong_problem_string(): 112 | """ 113 | Fixture providing a sample PDDL problem definition as a string. 114 | """ 115 | return """ 116 | (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0) 117 | (:domain miconic) 118 | (:objects p1 p0 - passenger 119 | f1 f2 f3 f0 - floor) 120 | 121 | (:init 122 | (above f1 f2) 123 | (above f1 f3) 124 | (above f1 f0) 125 | (above f2 f3) 126 | (above f2 f0) 127 | (above f3 f0) 128 | (origin p1 f0) 129 | (destin p1 f3) 130 | (origin p0 f2) 131 | (destin p0 f0) 132 | (lift-at f1)) 133 | 134 | (:goal (and 135 | (served f3) 136 | (served p0))) 137 | ) 138 | """ 139 | 140 | 141 | @pytest.fixture 142 | def wrong_initial_problem_string(): 143 | """ 144 | Fixture providing a sample PDDL problem definition as a string. 145 | """ 146 | return """ 147 | (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0) 148 | (:domain miconic) 149 | (:objects p1 p0 - passenger 150 | f1 f2 f3 f0 - floor) 151 | 152 | (:init 153 | (above f1 f2) 154 | (above f1 f3) 155 | (above f1 f0) 156 | (above f2 f3) 157 | (above f2 f0) 158 | (above f3 f0) 159 | (destin p1 f3) 160 | (origin p0 f2) 161 | (destin p0 f0) 162 | (lift-at f1)) 163 | 164 | (:goal (and 165 | (served p1) 166 | (served p0))) 167 | ) 168 | """ 169 | 170 | 171 | @pytest.fixture 172 | def swap_problem_string(): 173 | """ 174 | Fixture providing a sample PDDL problem definition as a string. 
175 | """ 176 | return """ 177 | (define (problem swap) 178 | (:domain swap) 179 | (:objects a0 a1 - object 180 | b0 b1 - room) 181 | 182 | (:init 183 | (in a0 b0) 184 | (in a1 b1)) 185 | 186 | (:goal (and 187 | (in a0 b1) 188 | (in a1 b0))) 189 | ) 190 | """ 191 | 192 | 193 | @pytest.fixture 194 | def wrong_swap_problem_string(): 195 | """ 196 | Fixture providing a sample PDDL problem definition as a string. 197 | """ 198 | return """ 199 | (define (problem swap) 200 | (:domain swap) 201 | (:objects a0 a1 - object 202 | b0 b1 - room) 203 | 204 | (:init 205 | (in a0 b0) 206 | (in a1 b1)) 207 | 208 | (:goal (and 209 | (in a0 b0) 210 | (in a1 b1))) 211 | ) 212 | """ 213 | 214 | 215 | @pytest.fixture 216 | def move_problem_string(): 217 | """ 218 | Fixture providing a sample PDDL problem definition as a string. 219 | """ 220 | return """ 221 | (define (problem move) 222 | (:domain move) 223 | (:objects a0 a1 - object 224 | b0 b1 - room) 225 | 226 | (:init 227 | (in a0 b0) 228 | (in a1 b0)) 229 | 230 | (:goal (and 231 | (in a0 b1) 232 | (in a1 b0))) 233 | ) 234 | """ 235 | 236 | 237 | @pytest.fixture 238 | def wrong_move_problem_string(): 239 | """ 240 | Fixture providing a sample PDDL problem definition as a string. 241 | """ 242 | return """ 243 | (define (problem move) 244 | (:domain move) 245 | (:objects a0 a1 - object 246 | b0 b1 - room) 247 | 248 | (:init 249 | (in a0 b0) 250 | (in a1 b1)) 251 | 252 | (:goal (and 253 | (in a0 b0) 254 | (in a1 b1))) 255 | ) 256 | """ 257 | 258 | 259 | @pytest.fixture 260 | def single_predicate_goal(): 261 | """ 262 | Fixture providing a sample PDDL problem definition as a string. 
263 | """ 264 | return """ 265 | (define (problem move) 266 | (:domain move) 267 | (:objects a0 a1 - object 268 | b0 b1 - room) 269 | 270 | (:init 271 | (in a0 b0) 272 | (in a1 b1)) 273 | 274 | (:goal (in a0 b0)) 275 | ) 276 | """ 277 | 278 | 279 | @pytest.fixture 280 | def not_predicate_goal(): 281 | """ 282 | Fixture providing a sample PDDL problem definition as a string. 283 | """ 284 | return """ 285 | (define (problem move) 286 | (:domain move) 287 | (:objects a0 a1 - object 288 | b0 b1 - room) 289 | 290 | (:init 291 | (in a0 b0) 292 | (in a1 b1)) 293 | 294 | (:goal (not 295 | (in a0 b0))) 296 | ) 297 | """ 298 | 299 | 300 | @pytest.fixture 301 | def problem(problem_string): 302 | """ 303 | Fixture providing a parsed PDDL problem object. 304 | """ 305 | return LenientProblemParser()(problem_string) 306 | 307 | 308 | class TestConstantToDict: 309 | """ 310 | Test suite for the _constant_to_dict function. 311 | """ 312 | 313 | def test_constat_name(self, problem): 314 | """ 315 | Test the conversion of a PDDL Constant to a dictionary with the correct name. 316 | """ 317 | constant = list(problem.objects)[0] 318 | assert builder._constant_to_dict(constant)["name"] == str(constant.name) 319 | 320 | def test_constat_type(self, problem): 321 | """ 322 | Test the conversion of a PDDL Constant to a dictionary with the correct typing. 323 | """ 324 | constant = list(problem.objects)[0] 325 | result_dict = builder._constant_to_dict(constant) 326 | assert ( 327 | result_dict["typing"] == constant.type_tags 328 | and type(result_dict["typing"]) == set 329 | ) 330 | 331 | 332 | class TestPredicateToDict: 333 | """ 334 | Test suite for the _predicate_to_dict function. 335 | """ 336 | 337 | def test_predicate_name(self, problem): 338 | """ 339 | Test the conversion of a PDDL Predicate to a dictionary with the correct name. 
340 | """ 341 | predicate = list(problem.init)[0] 342 | assert builder._predicate_to_dict(predicate)["typing"] == str(predicate.name) 343 | 344 | def test_predicate_parameters(self, problem): 345 | """ 346 | Test the conversion of a PDDL Predicate to a dictionary with the correct parameters. 347 | """ 348 | predicate = list(problem.init)[0] 349 | result_dict = builder._predicate_to_dict(predicate) 350 | assert ( 351 | result_dict["parameters"] == [term.name for term in predicate.terms] 352 | and type(result_dict["parameters"]) == list 353 | ) 354 | 355 | 356 | class TestBuildConstants: 357 | """ 358 | Test suite for the _build_constants function. 359 | """ 360 | 361 | def test_size(self, problem): 362 | """ 363 | Test the size of the list of constants built from a PDDL problem. 364 | """ 365 | assert len(builder._build_constants(problem.objects)) == len(problem.objects) 366 | 367 | 368 | class TestBuildPredicates: 369 | """ 370 | Test suite for the _build_predicates function. 371 | """ 372 | 373 | def test_initial_size(self, problem): 374 | """ 375 | Test the size of the list of initial predicates built from a PDDL problem. 376 | """ 377 | assert len(builder._build_predicates(problem.init)) == len(problem.init) 378 | 379 | def test_goal_size(self, problem): 380 | """ 381 | Test the size of the list of goal predicates built from a PDDL problem. 382 | """ 383 | assert len(builder._build_predicates(problem.goal.operands)) == len( 384 | problem.goal.operands 385 | ) 386 | 387 | 388 | class TestBuild: 389 | """ 390 | Test suite for the build function. 391 | """ 392 | 393 | def test_node_size(self, problem_string): 394 | """ 395 | Test the size of nodes in the scene graphs built from a PDDL problem. 
396 | """ 397 | graph_1, graph_2 = builder.build(problem_string).decompose() 398 | assert len(graph_1.nodes) == 17 and len(graph_2.nodes) == 8 399 | 400 | def test_edge_size(self, problem_string): 401 | """ 402 | Test the size of edges in the scene graphs built from a PDDL problem. 403 | """ 404 | graph_1, graph_2 = builder.build(problem_string).decompose() 405 | assert len(graph_1.edges) == 21 and len(graph_2.edges) == 2 406 | 407 | def test_edge_size(self, problem_string): 408 | """ 409 | Test the size of edges in the scene graphs built from a PDDL problem. 410 | """ 411 | modified_problem_string = f"Here is an example of a problem string that is not a PDDL problem. ```pddl\n{problem_string}\n```" 412 | graph_1, graph_2 = builder.build(modified_problem_string).decompose() 413 | assert len(graph_1.edges) == 21 and len(graph_2.edges) == 2 414 | 415 | def test_single_predicate_goal(self, single_predicate_goal): 416 | """ 417 | Test the size of nodes in the scene graphs built from a PDDL problem. 418 | """ 419 | builder.build(single_predicate_goal).decompose() 420 | 421 | 422 | def test_to_pddl_str(self, single_predicate_goal): 423 | """ 424 | Test the size of nodes in the scene graphs built from a PDDL problem. 425 | """ 426 | builder.build(single_predicate_goal).to_pddl_str() 427 | 428 | def test_not_predicate_goal(self, not_predicate_goal): 429 | """ 430 | Test the size of nodes in the scene graphs built from a PDDL problem. 
431 | """ 432 | with pytest.raises(ValueError): 433 | builder.build(not_predicate_goal).decompose() 434 | -------------------------------------------------------------------------------- /tests/test_planner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | VALIDATE = os.getenv("VALIDATE", "Validate") 5 | 6 | from planetarium import builder, downward, oracle 7 | 8 | from .problem_fixtures import ( 9 | blocksworld_fully_specified, 10 | blocksworld_holding, 11 | blocksworld_missing_clears, 12 | blocksworld_missing_ontables, 13 | blocksworld_underspecified, 14 | blocksworld_underspecified_arm, 15 | blocksworld_stack_to_holding, 16 | blocksworld_invalid_1, 17 | blocksworld_invalid_2, 18 | blocksworld_invalid_3, 19 | gripper_fully_specified, 20 | gripper_fully_specified_not_strict, 21 | gripper_inconsistent_typing, 22 | gripper_missing_typing, 23 | gripper_multiple_typing, 24 | gripper_no_goal_types, 25 | gripper_no_robby, 26 | gripper_robby_at_last, 27 | gripper_underspecified_1, 28 | gripper_underspecified_2, 29 | gripper_underspecified_3, 30 | gripper_invalid, 31 | ) 32 | 33 | DOMAINS = { 34 | "blocksworld": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/blocksworld/domain.pddl 35 | ;; same as used in IPC 2023 36 | ;; 37 | (define (domain blocksworld) 38 | 39 | (:requirements :strips) 40 | 41 | (:predicates (clear ?x) 42 | (on-table ?x) 43 | (arm-empty) 44 | (holding ?x) 45 | (on ?x ?y)) 46 | 47 | (:action pickup 48 | :parameters (?ob) 49 | :precondition (and (clear ?ob) (on-table ?ob) (arm-empty)) 50 | :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob)) 51 | (not (arm-empty)))) 52 | 53 | (:action putdown 54 | :parameters (?ob) 55 | :precondition (holding ?ob) 56 | :effect (and (clear ?ob) (arm-empty) (on-table ?ob) 57 | (not (holding ?ob)))) 58 | 59 | (:action stack 60 | :parameters (?ob ?underob) 61 | :precondition (and (clear ?underob) (holding ?ob)) 62 
| :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob) 63 | (not (clear ?underob)) (not (holding ?ob)))) 64 | 65 | (:action unstack 66 | :parameters (?ob ?underob) 67 | :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty)) 68 | :effect (and (holding ?ob) (clear ?underob) 69 | (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty))))) 70 | """, 71 | "gripper": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/gripper/domain.pddl 72 | (define (domain gripper) 73 | (:requirements :strips) 74 | (:predicates (room ?r) 75 | (ball ?b) 76 | (gripper ?g) 77 | (at-robby ?r) 78 | (at ?b ?r) 79 | (free ?g) 80 | (carry ?o ?g)) 81 | 82 | (:action move 83 | :parameters (?from ?to) 84 | :precondition (and (room ?from) (room ?to) (at-robby ?from)) 85 | :effect (and (at-robby ?to) 86 | (not (at-robby ?from)))) 87 | 88 | (:action pick 89 | :parameters (?obj ?room ?gripper) 90 | :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) 91 | (at ?obj ?room) (at-robby ?room) (free ?gripper)) 92 | :effect (and (carry ?obj ?gripper) 93 | (not (at ?obj ?room)) 94 | (not (free ?gripper)))) 95 | 96 | (:action drop 97 | :parameters (?obj ?room ?gripper) 98 | :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) 99 | (carry ?obj ?gripper) (at-robby ?room)) 100 | :effect (and (at ?obj ?room) 101 | (free ?gripper) 102 | (not (carry ?obj ?gripper))))) 103 | """, 104 | } 105 | 106 | 107 | class TestBlocksworldOracle: 108 | """ 109 | Test suite for the blocksworld oracle. 110 | """ 111 | 112 | def test_plan( 113 | self, 114 | subtests, 115 | blocksworld_missing_clears, 116 | blocksworld_fully_specified, 117 | blocksworld_holding, 118 | blocksworld_missing_ontables, 119 | blocksworld_underspecified, 120 | blocksworld_underspecified_arm, 121 | blocksworld_stack_to_holding, 122 | ): 123 | """ 124 | Test if the oracle can plan for a fully specified blocksworld problem. 
125 | """ 126 | for name, desc in { 127 | "blocksworld_fully_specified": blocksworld_fully_specified, 128 | "blocksworld_holding": blocksworld_holding, 129 | "blocksworld_missing_clears": blocksworld_missing_clears, 130 | "blocksworld_missing_ontables": blocksworld_missing_ontables, 131 | "blocksworld_underspecified": blocksworld_underspecified, 132 | "blocksworld_underspecified_arm": blocksworld_underspecified_arm, 133 | "blocksworld_stack_to_holding": blocksworld_stack_to_holding, 134 | }.items(): 135 | plan = oracle.plan(builder.build(desc)) 136 | with subtests.test(name): 137 | assert plan != [], name 138 | 139 | assert downward.validate( 140 | DOMAINS["blocksworld"], 141 | desc, 142 | oracle.plan_to_string(plan), 143 | VALIDATE, 144 | ) 145 | 146 | with subtests.test(name): 147 | assert not downward.validate( 148 | DOMAINS["gripper"], 149 | desc, 150 | oracle.plan_to_string(plan), 151 | VALIDATE, 152 | ) 153 | 154 | def test_invalid_plan( 155 | self, 156 | subtests, 157 | blocksworld_invalid_2, 158 | ): 159 | """ 160 | Test if the oracle can plan for an invalid blocksworld problem. 161 | """ 162 | domain = DOMAINS["blocksworld"] 163 | for name, desc in { 164 | "blocksworld_invalid_2": blocksworld_invalid_2, 165 | }.items(): 166 | with subtests.test(name): 167 | try: 168 | plan = oracle.plan(builder.build(desc)) 169 | except Exception as e: 170 | plan = [] 171 | assert plan == [], f"{name}: {plan}" 172 | 173 | plan_str = oracle.plan_to_string(plan) 174 | assert not downward.validate(domain, desc, plan_str, VALIDATE) 175 | 176 | 177 | class TestGripperOracle: 178 | """ 179 | Test suite for the gripper oracle. 
180 | """ 181 | 182 | def test_plan( 183 | self, 184 | subtests, 185 | gripper_fully_specified, 186 | gripper_fully_specified_not_strict, 187 | gripper_no_goal_types, 188 | gripper_no_robby, 189 | gripper_robby_at_last, 190 | gripper_underspecified_1, 191 | gripper_underspecified_2, 192 | gripper_underspecified_3, 193 | ): 194 | """ 195 | Test if the oracle can plan for a fully specified gripper problem. 196 | """ 197 | domain = DOMAINS["gripper"] 198 | for name, desc in { 199 | "gripper_fully_specified": gripper_fully_specified, 200 | "gripper_fully_specified_not_strict": gripper_fully_specified_not_strict, 201 | "gripper_no_goal_types": gripper_no_goal_types, 202 | "gripper_no_robby": gripper_no_robby, 203 | "gripper_robby_at_last": gripper_robby_at_last, 204 | "gripper_underspecified_1": gripper_underspecified_1, 205 | "gripper_underspecified_2": gripper_underspecified_2, 206 | "gripper_underspecified_3": gripper_underspecified_3, 207 | }.items(): 208 | with subtests.test(name): 209 | plan = oracle.plan(builder.build(desc)) 210 | assert plan != [], name 211 | 212 | assert downward.validate( 213 | domain, 214 | desc, 215 | oracle.plan_to_string(plan), 216 | VALIDATE, 217 | ), name 218 | 219 | with subtests.test(name): 220 | assert not downward.validate( 221 | DOMAINS["blocksworld"], 222 | desc, 223 | oracle.plan_to_string(plan), 224 | VALIDATE, 225 | ) 226 | 227 | 228 | class TestUnsupportedDomain: 229 | """ 230 | Test suite for unsupported domain. 231 | """ 232 | 233 | def test_plan(self, mocker, blocksworld_fully_specified): 234 | """ 235 | Test if the oracle can plan for an unsupported domain. 236 | """ 237 | problem = builder.build(blocksworld_fully_specified) 238 | mocker.patch("planetarium.oracle.fully_specify", return_value=problem) 239 | with pytest.raises(oracle.DomainNotSupportedError): 240 | oracle.plan(problem, domain="unsupported_domain") 241 | --------------------------------------------------------------------------------