├── .coveragerc
├── .github
└── workflows
│ └── python-package.yml
├── .gitignore
├── EXTENDING.md
├── LICENSE
├── README.md
├── assets
└── equivalence.png
├── dataset-config.zip
├── dataset_generator.py
├── evaluate.py
├── experiment_cfgs
├── gemma-1.1-7b-it-finetune.yaml
└── gemma-1.1-7b-it-zero-shot.yaml
├── finetune.py
├── llm_planner.py
├── planetarium
├── ASSUMPTIONS.md
├── __init__.py
├── builder.py
├── domains
│ ├── blocksworld.pddl
│ ├── floor-tile.pddl
│ ├── gripper.pddl
│ ├── rover-single.pddl
│ └── rover.pddl
├── downward.py
├── evaluate.py
├── graph.py
├── metric.py
├── oracle.py
├── oracles
│ ├── __init__.py
│ ├── blocksworld.py
│ ├── floortile.py
│ ├── gripper.py
│ ├── oracle.py
│ └── rover_single.py
└── reduced_graph.py
├── poetry.lock
├── pyproject.toml
└── tests
├── __init__.py
├── problem_fixtures.py
├── test_evaluate.py
├── test_graph.py
├── test_metric.py
├── test_oracle.py
├── test_pddl.py
└── test_planner.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit =
3 | planetarium/oracles/oracle.py
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 |
3 | on: [push]
4 |
5 | jobs:
6 | test:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - name: checkout repository
10 | uses: actions/checkout@v3
11 |
12 | - name: set up python
13 | id: setup-python
14 | uses: actions/setup-python@v4
15 | with:
16 | python-version: "3.10"
17 |
18 | - name: Setup apptainer
19 | uses: eWaterCycle/setup-apptainer@v2.0.0
20 |
21 | - name: install poetry
22 | uses: snok/install-poetry@v1
23 | with:
24 | virtualenvs-create: true
25 | virtualenvs-in-project: true
26 | installer-parallel: true
27 |
28 | - name: load cached venv
29 | id: cached-poetry-dependencies
30 | uses: actions/cache@v3
31 | with:
32 | path: .venv
33 | key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
34 |
35 | - name: load downward
36 | id: cached-downward
37 | uses: actions/cache@v3
38 | with:
39 | path: tmp/
40 | key: downward-${{ runner.os }}-${{ hashFiles('**/VAL.zip') }}-${{ hashFiles('fast-downward.sif') }}
41 |
42 | - name: install dependencies
43 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
44 | run: poetry install --no-interaction --no-root
45 |
46 | - name: install downward
47 | if: steps.cached-downward.outputs.cache-hit != 'true'
48 | run: |
49 | mkdir tmp
50 | apptainer pull fast-downward.sif docker://aibasel/downward:latest
51 | mv fast-downward.sif tmp/
52 | curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip
53 | unzip tmp/VAL.zip -d tmp/
54 | tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1
55 |
56 | - name: install project
57 | run: poetry install --no-interaction
58 |
59 | - name: lint
60 | run:
61 | source .venv/bin/activate
62 | poetry run black --check planetarium/. tests/.
63 | poetry ruff planetarium/.
64 | poetry mypy planetarium/.
65 |
66 | - name: test
67 | run: |
68 | source .venv/bin/activate
69 | export DOWNWARD=$(pwd)/tmp/fast-downward.sif
70 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$(pwd)/tmp/bin
71 | export PATH=$PATH:$(pwd)/tmp/bin
72 | poetry run pytest --cov-fail-under=90 --cov=planetarium --timeout=120 tests/.
73 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | venv
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
112 | .pdm.toml
113 | .pdm-python
114 | .pdm-build/
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 | test_pddl_generator.py
166 |
--------------------------------------------------------------------------------
/EXTENDING.md:
--------------------------------------------------------------------------------
1 | # Extending Planetarium
2 |
3 | If you're looking to evaluate your own domain, here is a guide to help you get started.
4 |
5 | ### 1. Add a domain file
6 | Add a domain PDDL file to the `planetarium/domains/` directory, where the filename is the name of your domain.
7 |
8 | ### 2. Add an Oracle
9 |
10 | Every domain in Planetarium requires an [`Oracle`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L8) object and file tied to it.
11 | There are three fundamental components to the Oracle:
12 |
13 | - [`.reduce()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L11) function, which takes a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) or [`SceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L391) object and returns a [`ReducedProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L80) or [`ReducedSceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L48) object, respectively.
14 | - [`.inflate()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L65) function, which takes a [`ReducedProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L80) or [`ReducedSceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L48) object and returns a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) or [`SceneGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L391) object, respectively.
15 | - [`.fully_specify()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L125) function, which takes a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) and returns either a [`ProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/graph.py#L430) or a [`ReducedProblemGraph`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L80) with all possible predicates added to the goal scene without changing the original definition of the problem.
16 | We refer to these predicates as "trivial predicates" in our paper.
17 | The `fully_specify` function is used to ensure that the problem is fully specified before evaluation.
18 |
19 | To add your domain, you must create a python script under `planetarium/oracles/` that contains your implementation of an Oracle subclass.
20 | This script should contain a class that inherits from `Oracle` and implements the three functions described above.
21 | While we provide generic `reduce` and `inflate` functions in the base `Oracle` class, you will almost certainly want to override them if your domain has any 0-ary, 1-ary, or (3+)-ary predicates (more info below).
22 |
23 | #### 2.1 Remember to add your Oracle script to the `__init__.py` file in the `planetarium/oracles/` directory:
24 | ```python
25 | # planetarium/oracles/__init__.py
26 | __all__ = ["blocksworld", "gripper", "rover_single", "floortile", "YOUR_DOMAIN"]
27 |
28 | from . import oracle
29 |
30 | from . import blocksworld
31 | from . import gripper
32 | from . import rover_single
33 | from . import floortile
34 | from . import YOUR_DOMAIN
35 |
36 | ORACLES: dict[str, oracle.Oracle] = {
37 | "blocksworld": blocksworld.BlocksworldOracle,
38 | "gripper": gripper.GripperOracle,
39 | "rover-single": rover_single.RoverSingleOracle,
40 | "floor-tile": floortile.FloorTileOracle,
41 | "YOUR_DOMAIN": YOUR_DOMAIN.YourDomainOracle,
42 | }
43 | ```
44 |
45 | #### 2.2 Working with non-binary predicates (Using `ReducedNode`s)
46 | Some predicates require special care for reducing and inflating.
47 | The key idea behind our reduced representation is to represent our PDDL problem in a domain-specific manner with as few graph nodes and edges as possible to reduce the search space for our equivalence check.
48 | The reduced representation also allows us to perform higher-level graph manipulations and searches more efficiently.
49 |
50 | **[`ReducedNode`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L9)**:
51 |
52 | A [`ReducedNode`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L9) is a domain-specific set of nodes that will be added to every `ReducedSceneGraph` or `ReducedProblemGraph` object on construction. They help hold metadata for specific types of predicates (non-binary predicates, 0-ary predicates, etc.) that are defined by the domain.
53 |
54 | Here is an example on how to reduce different -ary predicates using `ReducedNode`s:
55 |
56 | **0-ary predicates**:
57 | These predicates can be represented by using a `ReducedNode` with a `name` attribute that matches the predicate name.
58 | If the predicate is true in the scene, one way to handle this is by adding a self-edge on this `ReducedNode` to represent the predicate.
59 |
60 | **1-ary predicates**:
61 | These predicates can be represented by using a `ReducedNode` with a `name` attribute that matches the predicate name.
62 | If the predicate is true in the scene, we can add an edge from the `ReducedNode` to the node that represents the object in the scene.
63 |
64 | **3+-ary predicates**:
65 | There is no easy way to reduce these predicates, so the best way to keep track of these is to simply add a predicate node to the `ReducedSceneGraph` or `ReducedProblemGraph` object that represents the predicate, and add edges to the nodes that represent the objects in the scene.
66 | Make sure to set the `position=` argument when adding the edge to the reduced graph to ensure you can reverse this action in your `inflate` function.
67 |
68 | **To register your `ReducedNode`**:
69 | At the top of your oracle script, you can call the [`ReducedNode.register`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/reduced_graph.py#L12) method, like the following:
70 |
71 | ```python
72 | ReducedNode.register(
73 | {
74 | # ReducedNodeName: corresponding predicate name
75 | "ROOMS": "room",
76 | "BALLS": "ball",
77 | "GRIPPERS": "gripper",
78 | "ROBBY": "at-robby",
79 | "FREE": "free",
80 | },
81 | "YOUR_DOMAIN", # name of your domain
82 | )
83 | ```
84 |
85 | This will let you use your `ReducedNode` like any other enum throughout your oracle script (e.g. `ReducedNode.ROOMS`).
86 |
87 | #### 2.3 Implementing `.plan()` (Optional)
88 | If you would like to evaluate whether or not a problem is _solvable_, you can implement the [`.plan()`](https://github.com/BatsResearch/planetarium/blob/4ca530982baec33ebe332f62b48a622f24b0dfb2/planetarium/oracles/oracle.py#L141) function in your `Oracle`.
89 | You will still be able to evaluate whether or not a problem is solvable without implementing this function, but it will rely on running the FastDownward planner to solve your problem, which may be *significantly* slower than using a domain-specific planner.
90 | (Note that you should try using the `lama-first` alias if possible, as this planner does not look for the optimal plan, just a satisficing plan.)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024 Brown University
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # planetarium🪐
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 | Planetarium🪐 is a [dataset](https://huggingface.co/datasets/BatsResearch/planetarium) and benchmark for assessing LLMs in translating natural language descriptions of planning problems into PDDL. We developed a robust method for comparing PDDL problem descriptions using graph isomorphism.
11 |
12 | ## Installation
13 | To install the `planetarium` package, you can use the following command:
14 | ```bash
15 | pip install git+https://github.com/BatsResearch/planetarium.git
16 | ```
17 |
18 | For development or using our evaluate & finetune scripts, you can clone the repository and install all dependencies using the following commands:
19 | ```bash
20 | git clone https://github.com/BatsResearch/planetarium.git
21 | cd planetarium
22 | poetry install --with all
23 | ```
24 |
25 | To use `planetarium.downward`, you will need to have the [Fast-Downward](https://www.fast-downward.org/) planner and the [VAL](https://github.com/KCL-Planning/VAL) plan validator installed. The following commands are one way to install them with minimal overhead:
26 | ```bash
27 | # Fast-Downward via Apptainer
28 | apptainer pull fast-downward.sif docker://aibasel/downward:latest
29 | # VAL download link might not work, follow instructions to download binary at: https://github.com/KCL-Planning/VAL
30 | mkdir tmp
31 | curl -o tmp/VAL.zip https://dev.azure.com/schlumberger/4e6bcb11-cd68-40fe-98a2-e3777bfec0a6/_apis/build/builds/77/artifacts?artifactName=linux64\&api-version=7.1\&%24format=zip
32 | unzip tmp/VAL.zip -d tmp/
33 | tar -xzvf tmp/linux64/*.tar.gz -C tmp/ --strip-components=1
34 | # clean up the downloaded archives; the extracted VAL binaries stay in tmp/bin
35 | rm -rf tmp/VAL.zip tmp/linux64
36 | # Make sure to add fast-downward.sif and VAL (tmp/bin) to your PATH or make aliases.
37 | ```
38 |
39 | ## Basic Usage
40 | To evaluate a PDDL problem description, we can use the `planetarium.evaluate` module:
41 | ```python
42 | import planetarium
43 | ...
44 | planetarium.evaluate(gt_pddl_str, pred_pddl_str)
45 | ```
46 | The supported domains are `blocksworld`, `gripper`, `rover-single`, and `floor-tile`.
47 |
48 | ## Dataset
49 | The main page for the dataset can be found [here](https://huggingface.co/datasets/BatsResearch/planetarium).
50 |
51 | Here is an example of how to load the dataset:
52 | ```python
53 | from datasets import load_dataset
54 |
55 | dataset = load_dataset("BatsResearch/planetarium")
56 | ```
57 | Here, `dataset["test"]` is the main test set used in the paper. You may evaluate on this set to reproduce our results.
58 |
59 | You can reproduce the dataset, the splits, and a report by running the following command:
60 | ```bash
61 | python dataset_generator.py -c dataset_config.yaml
62 | ```
63 |
64 | By modifying the `dataset_config.yaml` file, you can change the dataset splits, the number of samples, and produce even more examples!
65 |
66 | ### Dataset Report
67 | Here is a summary of the types of PDDL problems in the dataset:
68 |
69 | Total number of problems: $132,037$.
70 |
71 | #### Abstractness Split
72 | | Init | Goal | blocksworld | gripper |
73 | |:---:|:---:|---:|---:|
74 | | abstract | abstract | $23,144$ | $10,632$ |
75 | | abstract | explicit | $23,086$ | $9,518$ |
76 | | explicit | abstract | $23,087$ | $10,313$ |
77 | | explicit | explicit | $23,033$ | $9,224$ |
78 | #### Size Splits (Number of Propositions in Ground Truth)
79 | | Num. of Propositions | blocksworld | gripper |
80 | |:---:|---:|---:|
81 | | $0$ - $20$ | $1,012$ | $379$ |
82 | | $20$ - $40$ | $10,765$ | $2,112$ |
83 | | $40$ - $60$ | $50,793$ | $9,412$ |
84 | | $60$ - $80$ | $26,316$ | $25,346$ |
85 | | $80$ - inf | $3,464$ | $2,438$ |
86 |
87 | ## How it Works
88 | Planetarium🪐 compares two PDDL problem descriptions by first transcribing them into a graph representation.
89 | Graphs help us to better detect and manipulate relationships between certain objects and propositions.
90 | Next, we build "fully specified" graph representations by adding "trivial" propositions (propositions that do not exist in the problem description but must exist in any state that satisfies such description).
91 | Finally, we use graph isomorphism to compare the fully specified graph representations of the two PDDL problem descriptions, either comparing the entire problem graph or the individual initial and goal scene graphs.
92 | This lets us check the correctness of the translation from the natural language description into PDDL without ever needing to run a planner.
93 |
94 | Below is a flowchart providing an overview of the equivalence algorithm:
95 |
96 | 
97 | (Left) Two planning problems, in PDDL problem description, real-world scenario, and graph representations. (Center) Fully specified graph representation. (Right) Graph isomorphism.
98 |
99 | The key to this algorithm is a specially crafted "fully specify" function, which we build for each domain that we want to support. We provide implementations for the `blocksworld`, `gripper`, `rover-single`, and `floor-tile` domains in the `planetarium.oracles` package.
100 |
101 | ## Adding a new domain
102 | For adding a new domain, please take a look at [this guide](EXTENDING.md).
--------------------------------------------------------------------------------
/assets/equivalence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BatsResearch/planetarium/f257ae9e68b813aeace00cc683d0c2ad5b57a157/assets/equivalence.png
--------------------------------------------------------------------------------
/dataset-config.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BatsResearch/planetarium/f257ae9e68b813aeace00cc683d0c2ad5b57a157/dataset-config.zip
--------------------------------------------------------------------------------
/experiment_cfgs/gemma-1.1-7b-it-finetune.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | database_path: dataset.db
3 | splits_path: splits.yaml
4 | splits:
5 | train:
6 | - [blocksworld, random, "0"]
7 | - [gripper, random, "0"]
8 | - [blocksworld, random, "1"]
9 | - [gripper, random, "1"]
10 | - [blocksworld, random, "2"]
11 | - [gripper, random, "2"]
12 | - [blocksworld, random, "3"]
13 | - [gripper, random, "3"]
14 | test:
15 | - [blocksworld, random, "4"]
16 | - [gripper, random, "4"]
17 | prompts:
18 | problem: "Provide me with the complete, valid problem PDDL file that \
19 | describes the following planning problem directly without further \
20 | explanations or texts."
21 | domain: "The domain for the planning problem is:"
22 | train:
23 | lora_config:
24 | r: 16
25 | lora_alpha: 32
26 | lora_dropout: 0.05
27 | bias: none
28 | task_type: CAUSAL_LM
29 | bnb_config:
30 | bits: 4
31 | double_quant: True
32 | quant_type: nf4
33 | training_args:
34 | output_dir: /oscar/scratch/mzuo6/gemma-logs/
35 | evaluation_strategy: epoch
36 | learning_rate: .00002
37 | num_train_epochs: 1
38 | per_device_train_batch_size: 1
39 | per_device_eval_batch_size: 1
40 | weight_decay: 0.01
41 | report_to: wandb
42 | bf16: True
43 | bf16_full_eval: True
44 | save_path: gemma-1.1-7b-it-random/
45 | model:
46 | type: hf
47 | tokenizer_name: google/gemma-1.1-7b-it
48 | model_name: google/gemma-1.1-7b-it
49 | response_template: "model\n"
50 | max_seq_length: 1500
--------------------------------------------------------------------------------
/experiment_cfgs/gemma-1.1-7b-it-zero-shot.yaml:
--------------------------------------------------------------------------------
1 | dataset:
2 | database_path: dataset.db
3 | splits_path: splits.yaml
4 | splits:
5 | test:
6 | - [blocksworld, random, "4"]
7 | - [gripper, random, "4"]
8 | prompts:
9 | problem: "Provide me with the complete, valid problem PDDL file that \
10 | describes the following planning problem directly without further \
11 | explanations or texts."
12 | domain: "The domain for the planning problem is:"
13 | evaluate:
14 | splits:
15 | - test
16 | batch_size: 64
17 | model:
18 | type: hf
19 | tokenizer_name: google/gemma-1.1-7b-it
20 | model_name: google/gemma-1.1-7b-it
21 | response_template: "model\n"
22 | kwargs:
23 | attn_implementation: flash_attention_2
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from functools import partial
3 | import os
4 | import sqlite3
5 | import yaml
6 |
7 | import dotenv
8 |
9 | dotenv.load_dotenv()
10 |
11 | import torch
12 | from torch import nn
13 |
14 | import bitsandbytes as bnb
15 | from datasets import Dataset
16 | from peft import LoraConfig, get_peft_model
17 | from transformers import (
18 | AutoTokenizer,
19 | AutoModelForCausalLM,
20 | BitsAndBytesConfig,
21 | TrainingArguments,
22 | PreTrainedTokenizer,
23 | PreTrainedModel,
24 | )
25 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
26 | import tqdm as tqdm
27 |
28 | import llm_planner as llmp
29 |
30 | from accelerate import Accelerator
31 |
32 |
33 | HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")
34 |
35 |
# SQLite limits the number of "?" parameters in a single statement
# (SQLITE_MAX_VARIABLE_NUMBER: 999 before SQLite 3.32.0, 32766 after), so
# id lists are queried in chunks that stay under the lowest limit.
_SQLITE_MAX_PARAMS = 900


def _fetch_problems(cursor: sqlite3.Cursor, problem_ids: list) -> list[tuple]:
    """Return ``(domain, problem_pddl, natural_language)`` rows for ``problem_ids``.

    The ids are bound in chunks of at most ``_SQLITE_MAX_PARAMS`` so large
    splits do not exceed SQLite's bound-parameter limit.
    """
    rows: list[tuple] = []
    for start in range(0, len(problem_ids), _SQLITE_MAX_PARAMS):
        chunk = list(problem_ids[start : start + _SQLITE_MAX_PARAMS])
        placeholders = ", ".join(["?"] * len(chunk))
        cursor.execute(
            "SELECT domain, problem_pddl, natural_language FROM problems "
            f"WHERE id in ({placeholders})",
            chunk,
        )
        rows.extend(cursor.fetchall())
    return rows


def load_dataset(config: dict) -> dict[str, Dataset]:
    """Load the dataset splits described by the configuration.

    The split definition file (``config["splits_path"]``) is a nested YAML
    mapping. Each entry of ``config["splits"][split]`` is a list of keys that
    is followed down that mapping to reach a list of problem ids. The matching
    problems, together with their domain PDDL, are read from the SQLite
    database at ``config["database_path"]``.

    Args:
        config (dict): The dataset configuration. Must contain
            ``splits_path`` and ``database_path``; ``splits`` maps each split
            name to a list of key paths into the splits file.

    Returns:
        dict[str, Dataset]: One ``Dataset`` per split, with the columns
            ``domain``, ``problem`` and ``natural_language``.
    """
    with open(config["splits_path"], "r") as f:
        split_ids_cfg = yaml.safe_load(f)

    split_keys_by_name: dict = config.get("splits", {})
    dataset = {split: defaultdict(list) for split in split_keys_by_name}

    conn = sqlite3.connect(config["database_path"])
    try:
        c = conn.cursor()

        # domain name -> domain PDDL
        c.execute("SELECT name, domain_pddl FROM domains")
        domains = dict(c.fetchall())

        for split, split_keys in split_keys_by_name.items():
            queries = []
            for split_key in split_keys:
                # Each split key is a path of nested keys into the splits file,
                # e.g. [blocksworld, random, "0"], ending at a list of ids.
                split_ids = split_ids_cfg
                for key in split_key:
                    split_ids = split_ids[key]
                queries.extend(_fetch_problems(c, split_ids))

            for domain, problem_pddl, natural_language in queries:
                dataset[split]["domain"].append(domains[domain])
                dataset[split]["problem"].append(problem_pddl)
                dataset[split]["natural_language"].append(natural_language)
    finally:
        # Always release the database handle, even if a lookup above fails.
        conn.close()

    return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()}
82 |
83 |
84 | def find_all_linear_names(
85 | model: nn.Module,
86 | bits: int | None = None,
87 | ) -> list[str]:
88 | """Find names of all linear layers in the model.
89 |
90 | Args:
91 | model (nn.Module): The model to search for linear layers.
92 |
93 | Returns:
94 | list[str]: The names of all linear layers in the model (excluding LM Head)
95 | """
96 | match bits:
97 | case 4:
98 | Linear = bnb.nn.Linear4bit
99 | case 8:
100 | Linear = bnb.nn.Linear8bitLt
101 | case _:
102 | Linear = torch.nn.Linear
103 |
104 | lora_module_names = set()
105 | for name, module in model.named_modules():
106 | if isinstance(module, Linear):
107 | names = name.split(".")
108 | lora_module_names.add(names[-1])
109 |
110 | if "lm_head" in lora_module_names: # needed for 16-bit
111 | lora_module_names.remove("lm_head")
112 | return list(lora_module_names)
113 |
114 |
115 | def strip(text: str, bos_token: str, eos_token: str) -> str:
116 | return text.removeprefix(bos_token) + eos_token
117 |
118 |
def preprocess(
    tokenizer: PreTrainedTokenizer,
    examples,
    domain_prompt: str = "",
    problem_prompt: str = "",
) -> list[str]:
    """Render a batch of examples into training strings.

    Each (natural_language, domain, problem) triple is wrapped in an
    ``llmp.PlanningProblem`` and expanded with ``apply_template``. The result
    is rendered through the tokenizer's chat template (untokenized, without a
    generation prompt) and passed to ``strip``, which removes a leading BOS
    token and appends the EOS token.

    Args:
        tokenizer (PreTrainedTokenizer): Tokenizer whose chat template and
            BOS/EOS tokens are used.
        examples: Batch with parallel ``natural_language``, ``domain`` and
            ``problem`` columns (presumably a ``datasets`` batch dict).
        domain_prompt (str, optional): Prompt text introducing the domain.
            Defaults to "".
        problem_prompt (str, optional): Prompt text introducing the problem.
            Defaults to "".

    Returns:
        list[str]: One rendered string per example, in input order.
    """
    rendered: list[str] = []
    rows = zip(
        examples["natural_language"],
        examples["domain"],
        examples["problem"],
    )
    for natural_language, domain, problem in rows:
        conversation = llmp.PlanningProblem(
            natural_language, domain, problem
        ).apply_template(
            domain_prompt,
            problem_prompt,
        )
        chat_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        rendered.append(
            strip(
                chat_text,
                bos_token=tokenizer.bos_token,
                eos_token=tokenizer.eos_token,
            )
        )
    return rendered
156 |
157 |
def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
    """Load the tokenizer and a LoRA-wrapped causal LM from the configuration.

    Args:
        config (dict): The training config. Reads ``model.tokenizer_name``,
            ``model.model_name``, optional ``model.model_kwargs``, optional
            ``bnb_config`` (``bits``, ``use_double_quant`` or ``double_quant``,
            ``quant_type``) and ``lora_config``.

    Returns:
        tuple[PreTrainedTokenizer, PreTrainedModel]: The tokenizer and the
            PEFT-wrapped model.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        config["model"]["tokenizer_name"],
        token=HF_USER_TOKEN,
    )
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # `or {}` also covers an explicit `bnb_config: null` in the YAML file.
    bnb_config_args: dict = config.get("bnb_config") or {}
    if bnb_config_args:
        bits = bnb_config_args.get("bits", 16)
        # The experiment configs spell this key `double_quant`; accept it as
        # well as `use_double_quant` (which takes precedence when both exist),
        # otherwise double quantization is silently never enabled.
        double_quant = bnb_config_args.get(
            "use_double_quant", bnb_config_args.get("double_quant", False)
        )
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=bits == 4,
            load_in_8bit=bits == 8,
            bnb_4bit_use_double_quant=double_quant,
            bnb_4bit_quant_type=bnb_config_args.get("quant_type", "nf4"),
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        bnb_config = None

    # Pin the whole model to this process's device (one copy per process).
    device_index = Accelerator().process_index
    device_map = {"": device_index}
    model = AutoModelForCausalLM.from_pretrained(
        config["model"]["model_name"],
        **config["model"].get("model_kwargs", {}),
        token=HF_USER_TOKEN,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
        device_map=device_map,
    )

    lora_config = LoraConfig(
        **config["lora_config"],
        target_modules=find_all_linear_names(model, bits=bnb_config_args.get("bits")),
    )
    model = get_peft_model(model, lora_config)

    return tokenizer, model
205 |
206 |
def extract_instruct_tokens(tokenizer: PreTrainedTokenizer) -> tuple[str, str]:
    """Extract the instruction and response markers from the chat template.

    Renders a one-turn user/assistant conversation whose contents are a
    placeholder string, then splits the rendered text on that placeholder:
    the piece before the user content is the instruction template and the
    piece between the two contents is the response template.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer whose chat template is used.

    Returns:
        tuple[str, str]: The instruction template and the response template.

    Raises:
        ValueError: If the placeholder cannot be located in the rendered chat.
    """
    # Any string that does not occur in the template works as a split marker.
    # Fall back to a fixed sentinel when there is no unk token: otherwise the
    # content would render as "None" and `str.split(None)` would split on
    # whitespace, yielding meaningless templates.
    placeholder = tokenizer.unk_token or "<<|placeholder|>>"

    chat_str = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": placeholder},
            {"role": "assistant", "content": placeholder},
        ],
        tokenize=False,
    )

    if not tokenizer.chat_template:
        # No explicit template on the tokenizer; presumably the fallback
        # template surrounds the content with single spaces — TODO confirm.
        templates = chat_str.split(f" {placeholder} ")
    else:
        templates = chat_str.split(placeholder)
        templates = [t.replace(" ", "").strip() for t in templates]

    if len(templates) < 2:
        raise ValueError(
            "Could not locate the placeholder in the rendered chat template; "
            "cannot derive instruction/response templates."
        )

    return templates[0], templates[1]
233 |
234 |
def main(config_path: str):
    """Fine-tune a causal LM on the PDDL dataset described by a YAML config.

    The config's ``dataset`` section is handed to ``load_dataset`` and also
    supplies the ``prompts`` used to format examples; the ``train`` section
    configures the model, LoRA and ``TrainingArguments``. The trained model
    is saved to ``train.save_path`` (default ``"ckpt"``).

    Args:
        config_path (str): The path to the YAML configuration file.
    """
    with open(config_path) as config_file:
        cfg = yaml.safe_load(config_file)

    splits = load_dataset(cfg["dataset"])
    train_cfg: dict = cfg["train"]

    tokenizer, model = load_model(train_cfg)

    # TRL's completion-only collator needs the markers that open the user
    # and assistant turns of the chat template.
    instruction_marker, response_marker = extract_instruct_tokens(tokenizer)
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_marker,
        instruction_template=instruction_marker,
        tokenizer=tokenizer,
    )

    trainer_args = TrainingArguments(**train_cfg.get("training_args", {}))

    train_split = splits["train"]
    eval_split = splits["test"]
    seq_limit = train_cfg["model"].get("max_seq_length", 512)
    prompts = cfg["dataset"]["prompts"]
    formatter = partial(
        preprocess,
        tokenizer,
        problem_prompt=prompts["problem"],
        domain_prompt=prompts["domain"],
    )

    trainer = SFTTrainer(
        model,
        args=trainer_args,
        train_dataset=train_split,
        eval_dataset=eval_split,
        data_collator=collator,
        max_seq_length=seq_limit,
        formatting_func=formatter,
    )
    trainer.train()
    trainer.save_model(train_cfg.get("save_path", "ckpt"))
283 |
284 |
if __name__ == "__main__":
    import argparse

    # The first positional argument of ArgumentParser is `prog` (the program
    # name shown in usage), so the text must be passed as `description`.
    parser = argparse.ArgumentParser(
        description="Fine-tune a model on PDDL dataset."
    )
    # `required=True` makes any default unreachable, so none is given.
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="Path to the configuration file.",
    )

    args = parser.parse_args()

    main(args.config)
301 |
--------------------------------------------------------------------------------
/llm_planner.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from openai import OpenAI
4 | import torch
5 | from transformers import (
6 | AutoTokenizer,
7 | AutoModelForCausalLM,
8 | PreTrainedTokenizer,
9 | PreTrainedModel,
10 | )
11 |
12 | from vllm import LLM, RequestOutput, SamplingParams
13 | from vllm.lora.request import LoRARequest
14 |
15 |
class PlanningProblem:
    """A planning task: a natural-language description, its PDDL domain and
    the ground-truth PDDL problem."""

    def __init__(
        self,
        natural_language: str,
        domain: str,
        problem: str,
    ):
        """Store the three textual parts of a planning task.

        Args:
            natural_language (str): Natural-language description of the
                problem to be solved.
            domain (str): The PDDL domain definition.
            problem (str): The ground-truth PDDL problem.
        """
        self.natural_language, self.domain, self.problem = (
            natural_language,
            domain,
            problem,
        )

    def apply_template(
        self,
        domain_prompt: str = "",
        problem_prompt: str = "",
        include_answer: bool = True,
    ) -> list[dict[str, str]]:
        """Render the task as a list of chat messages.

        The user turn holds the problem prompt, the natural-language text,
        the domain prompt and the domain; when ``include_answer`` is true an
        assistant turn with the ground-truth problem (preceded by one space)
        is appended.

        Args:
            domain_prompt (str, optional): Text placed before the domain. Defaults to "".
            problem_prompt (str, optional): Text placed before the description. Defaults to "".
            include_answer (bool, optional): Whether to append the answer turn. Defaults to True.

        Returns:
            list[dict[str, str]]: The chat messages.
        """
        user_text = (
            f"{problem_prompt} {self.natural_language} "
            f"{domain_prompt}\n{self.domain}\n"
        )
        conversation = [{"role": "user", "content": user_text}]
        if include_answer:
            conversation.append({"role": "assistant", "content": " " + self.problem})
        return conversation
67 |
68 |
class Planner(abc.ABC):
    """Abstract interface for planners that complete batches of chats."""

    @abc.abstractmethod
    def plan_chat(
        self,
        messages: list[list[dict[str, str]]],
        device=None,
        max_new_tokens: int = 8_000,
        **kwargs,
    ) -> list[str]:
        """Complete each conversation in ``messages`` with the model.

        Args:
            messages (list[list[dict[str, str]]]): One chat (a list of
                role/content dicts) per requested completion.
            device (optional): Device to run on, if the backend uses one.
                Defaults to None.
            max_new_tokens (int): Upper bound on generated tokens per
                completion. Defaults to 8_000.

        Returns:
            list[str]: One completion per conversation.
        """
91 |
92 |
class HFPlanner:
    """A class for planning using Huggingface transformers."""

    def __init__(
        self,
        model_name: str | None = None,
        tokenizer_name: str | None = None,
        tokenizer: PreTrainedTokenizer | None = None,
        model: PreTrainedModel | None = None,
        **kwargs,
    ):
        """Initializes a new HFPlanner.

        A ready-made ``tokenizer`` and/or ``model`` is used as given; whichever
        one is missing is loaded with ``from_pretrained``.

        Args:
            model_name (str, optional): Name/path of the model to load when
                ``model`` is not given; also the fallback tokenizer name.
            tokenizer_name (str, optional): Name/path of the tokenizer to load
                when ``tokenizer`` is not given. Defaults to ``model_name``.
            tokenizer (PreTrainedTokenizer, optional): A ready tokenizer.
            model (PreTrainedModel, optional): A ready model.
            kwargs: Additional keyword arguments passed to
                ``AutoModelForCausalLM.from_pretrained`` when loading the model.

        Raises:
            ValueError: If a component must be loaded but no name is given.
        """
        if tokenizer is None:
            tokenizer_source = tokenizer_name or model_name
            if tokenizer_source is None:
                raise ValueError(
                    "Provide `tokenizer`, `tokenizer_name` or `model_name`."
                )
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
        if model is None:
            if model_name is None:
                raise ValueError("Provide `model` or `model_name`.")
            model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)

        self.tokenizer = tokenizer
        self.model = model

    def plan(self, prompt: str, device=None, **kwargs) -> str:
        """Passes the prompt to the model for completion, without applying a
        chat template.

        Args:
            prompt (str): The prompt to be passed to the model.
            device (optional): The device to run the model on. Defaults to None.
            kwargs: Overrides for the ``generate`` configuration (greedy
                decoding with ``max_new_tokens=4000`` by default).

        Returns:
            str: The generated continuation (prompt excluded, special tokens
            skipped), matching what ``plan_chat`` returns per conversation.
        """
        encoded = self.tokenizer.encode(prompt, return_tensors="pt")

        if device is not None:
            encoded = encoded.to(device)

        # `temperature` is not forced here: it is ignored for greedy decoding,
        # and a 0.0 default would break callers that enable `do_sample`.
        generate_config = {
            "max_new_tokens": 4000,
            "do_sample": False,
        }
        generate_config.update(kwargs)
        with torch.no_grad():
            generated_ids = self.model.generate(encoded, **generate_config)

        # A single prompt gives a batch of one row; `generate` echoes the
        # prompt tokens, so decode only what follows them.
        return self.tokenizer.decode(
            generated_ids[0][encoded.shape[-1] :],
            skip_special_tokens=True,
        )

    def plan_chat(
        self,
        messages: list[dict[str, str]],
        device=None,
        **kwargs,
    ) -> list[str]:
        """Passes messages to the model for completion, applying a chat template.

        Args:
            messages (list[dict[str, str]]): The chat(s) to be passed to the
                model; the tokenizer pads them into one batch.
            device (optional): The device to run the model on. Defaults to None.
            kwargs: Additional keyword arguments to be passed to ``generate``.

        Returns:
            list[str]: One completion (prompt excluded) per batch row.
        """
        encoded = self.tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            padding=True,
            add_generation_prompt=True,
        )

        if device is not None:
            encoded = encoded.to(device)

        generate_config = {  # default generate config
            "max_new_tokens": 4000,
            "do_sample": False,
        }
        generate_config.update(kwargs)
        with torch.no_grad():
            generated_ids = self.model.generate(encoded, **generate_config)

        decoded = []
        for g, e in zip(generated_ids, encoded):
            # Strip the (padded) prompt tokens that `generate` echoes back.
            decoded.append(
                self.tokenizer.decode(
                    g[len(e) :],
                    skip_special_tokens=True,
                )
            )

        return decoded
192 |
193 |
class VLLMPlanner(Planner):
    """Planner backed by a vLLM engine, optionally serving a LoRA adapter."""

    def __init__(self, model_name: str, lora: str | None = None, **kwargs):
        """Create the vLLM engine and fetch its tokenizer.

        Args:
            model_name (str): The name of the model to be used.
            lora (str, optional): Path of a LoRA adapter. When given, LoRA
                support is enabled on the engine and the adapter is attached
                to every generation request.
            kwargs: Additional keyword arguments to be passed to ``LLM``.
        """
        if lora:
            self.lora = LoRARequest(lora, 1, lora)
        else:
            self.lora = None
        self.model = LLM(model_name, enable_lora=bool(lora), **kwargs)
        self.tokenizer = self.model.get_tokenizer()

    def plan_chat(
        self,
        messages: list[list[dict[str, str]]],
        device=None,
        max_new_tokens: int = 8_000,
        **kwargs,
    ) -> list[str]:
        """Complete a batch of chats with the vLLM engine.

        Args:
            messages (list[list[dict[str, str]]]): One chat per completion.
            device (optional): Unused by this backend.
            max_new_tokens (int): Maximum tokens generated per completion.
            kwargs: Optional sampling overrides ``temperature`` (0.0),
                ``top_p`` (1.0), ``top_k`` (-1) and ``min_p`` (0.0).

        Returns:
            list[str]: The first completion text for each chat.
        """
        prompts = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )
        sampling_defaults = {
            "temperature": 0.0,
            "top_p": 1.0,
            "top_k": -1,
            "min_p": 0.0,
        }
        sampling = SamplingParams(
            max_tokens=max_new_tokens,
            **{
                name: kwargs.get(name, fallback)
                for name, fallback in sampling_defaults.items()
            },
        )

        results: list[RequestOutput] = self.model.generate(
            prompts,
            sampling,
            use_tqdm=False,
            lora_request=self.lora,
        )
        return [result.outputs[0].text for result in results]
245 |
246 |
247 | class OpenAIPlanner:
248 | """A class for planning using OpenAI models."""
249 |
250 | def __init__(self, model_name: str, **kwargs):
251 | """Initializes a new OpenAIPlanner.
252 |
253 | Args:
254 | model_name (str): The name of the model to be used.
255 | kwargs: Additional keyword arguments to be passed to the OpenAI
256 | client.
257 | """
258 | self.client = OpenAI(**kwargs)
259 | self.model_name = model_name
260 | self.is_o1 = model_name.startswith("o1")
261 |
262 | def _plan_chat(
263 | self,
264 | messages: list[dict[str, str]],
265 | max_new_tokens: int | None = None,
266 | device=None,
267 | **kwargs,
268 | ) -> str:
269 | """Passes messages to the model for completion.
270 |
271 | Args:
272 | messages (list[dict[str, str]]): A list of messages to be passed to
273 | the model.
274 | device (optional): The device to run the model on (ignored for OpenAI).
275 |
276 | Returns:
277 | str: The message completion.
278 | """
279 |
280 | if self.is_o1:
281 | return (
282 | self.client.chat.completions.create(
283 | model=self.model_name,
284 | messages=messages,
285 | frequency_penalty=kwargs.get("frequency_penalty", None),
286 | max_completion_tokens=max_new_tokens,
287 | n=1,
288 | presence_penalty=kwargs.get("presence_penalty", None),
289 | temperature=kwargs.get("temperature", 0.0),
290 | top_p=kwargs.get("top_p", None),
291 | )
292 | .choices[0]
293 | .message.content
294 | )
295 | else:
296 | return (
297 | self.client.chat.completions.create(
298 | model=self.model_name,
299 | messages=messages,
300 | frequency_penalty=kwargs.get("frequency_penalty", None),
301 | max_tokens=max_new_tokens,
302 | n=1,
303 | presence_penalty=kwargs.get("presence_penalty", None),
304 | temperature=kwargs.get("temperature", 0.0),
305 | top_p=kwargs.get("top_p", None),
306 | )
307 | .choices[0]
308 | .message.content
309 | )
310 |
311 | def plan_chat(
312 | self,
313 | messages: list[list[dict[str, str]]],
314 | max_new_tokens: int | None = None,
315 | device=None,
316 | **kwargs,
317 | ) -> list[str]:
318 | """Passes messages to the model for completion.
319 |
320 | Args:
321 | messages (list[list[dict[str, str]]]): A list of messages to be
322 | passed to the model.
323 | device (optional): The device to run the model on (ignored for OpenAI).
324 |
325 | Returns:
326 | list[str]: The message completions.
327 | """
328 | return [
329 | self._plan_chat(
330 | message,
331 | max_new_tokens=max_new_tokens,
332 | device=device,
333 | **kwargs,
334 | )
335 | for message in messages
336 | ]
337 |
--------------------------------------------------------------------------------
/planetarium/ASSUMPTIONS.md:
--------------------------------------------------------------------------------
1 | # Assumptions
2 |
3 | For each domain, the following assumptions are made regarding the types of problems evaluated.
4 | We also assume that problems are solvable via the domain's actions.
5 |
6 | Our equivalence check is based on the assumption that the problems are solvable already (i.e. our evaluator checks solvability before equivalence).
7 |
8 | ## Blocksworld
9 |
10 | Generally, since Blocks World has reversible actions, essentially all problems are evaluable.
11 |
12 | `:init` conditions:
13 | - No two blocks are on top of a single block
14 | - No block is initialized on top of two blocks
15 | - No loops of blocks are made
16 | - Arm can only hold one block
17 |
18 | ## Grippers
19 |
20 | Generally, since Grippers has reversible actions, essentially all problems are evaluable.
21 |
22 | `:init` conditions:
23 | - No double "typing" of objects (we don't use `:typing`; instead, we use certain immutable predicates)
24 | - All balls have only one location
25 | - All grippers have only one location
26 | - All grippers hold up to 1 ball
27 |
28 | ## Rover Single
29 | Rover has the capability of being a much more complex domain, but for the purposes of this benchmark, we work only with a single rover and a single lander.
30 |
31 | `:init` conditions:
32 | - No double `at_*` predicates (rover and lander can only be in one location at a time)
33 |
34 | ## Floortile
35 |
36 | Generally, all valid problems with reachable goal states are evaluable.
37 |
38 | `:init` conditions:
39 | - No robot has two colors (`robot-has`)
40 |
--------------------------------------------------------------------------------
/planetarium/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import resources
3 |
# Public API: the submodules below, the DOMAINS mapping (domain name -> PDDL
# text, filled from the bundled `domains` resources) and `evaluate`, which is
# re-bound to the function imported from `.evaluate` at the end of this module.
__all__ = [
    "builder",
    "downward",
    "graph",
    "metric",
    "oracle",
    "evaluate",
    "DOMAINS",
]
13 |
14 | from . import builder
15 | from . import downward
16 | from . import graph
17 | from . import metric
18 | from . import oracle
19 | from . import domains
20 |
# Maps a domain name (the resource file name up to its first ".", e.g.
# "floor-tile.pddl" -> "floor-tile") to the PDDL text of that domain.
DOMAINS = dict()

# load domains
for domain in resources.files(domains).iterdir():
    # Only bundled PDDL files are domains; skipping everything else avoids
    # trying to open directories such as `__pycache__`. `Traversable.name`
    # works for any resource container, unlike `os.path.basename`, which
    # needs a filesystem path.
    if not domain.is_file() or not domain.name.endswith(".pddl"):
        continue
    DOMAINS[domain.name.split(".")[0]] = domain.read_text(encoding="utf-8")
27 |
28 | from .evaluate import evaluate
29 |
--------------------------------------------------------------------------------
/planetarium/builder.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from collections.abc import Iterable
4 |
5 | from planetarium.graph import ProblemGraph
6 |
7 | from pddl.core import And, Problem
8 | from pddl.logic.predicates import Predicate
9 | from pddl.logic.terms import Constant
10 | from pddl.parser.problem import LenientProblemParser
11 |
12 |
def _constant_to_dict(constant: Constant) -> dict[str, typing.Any]:
    """
    Describe a PDDL constant as a plain dictionary.

    Parameters:
        constant (Constant): The PDDL constant (problem object) to convert.

    Returns:
        dict: ``{"name": <constant name>, "typing": <set of its type tags>}``.
    """
    name = constant.name
    type_tags = set(constant.type_tags)
    return dict(name=name, typing=type_tags)
27 |
28 |
def _predicate_to_dict(predicate: Predicate) -> dict[str, typing.Any]:
    """
    Describe a grounded PDDL predicate as a plain dictionary.

    Parameters:
        predicate (Predicate): The PDDL predicate to convert.

    Returns:
        dict: The predicate's name under the ``"typing"`` key and the names
        of its terms, in order, under ``"parameters"``.
    """
    label = predicate.name
    parameter_names = [term.name for term in predicate.terms]
    return {"typing": label, "parameters": parameter_names}
43 |
44 |
def _build_constants(constants: Iterable[Constant]) -> list[dict[str, typing.Any]]:
    """
    Convert every PDDL constant to its dictionary form, preserving order.

    Parameters:
        constants (Iterable[Constant]): The PDDL constants to convert.

    Returns:
        list: One ``_constant_to_dict`` result per constant.
    """
    return list(map(_constant_to_dict, constants))
56 |
57 |
def _build_predicates(
    predicates: Iterable[Predicate],
) -> list[dict[str, typing.Any]]:
    """
    Convert every PDDL predicate to its dictionary form, preserving order.

    Parameters:
        predicates (Iterable[Predicate]): The PDDL predicates to convert.

    Returns:
        list: One ``_predicate_to_dict`` result per predicate.
    """
    return list(map(_predicate_to_dict, predicates))
71 |
72 |
def build(problem: str) -> ProblemGraph:
    """
    Build a problem graph from a PDDL problem description.

    Parameters:
        problem (str): A string containing the PDDL problem description.

    Returns:
        ProblemGraph: A graph constructed from the problem's objects, its
        initial-state predicates and its goal predicates, together with the
        domain name and requirement names.

    Raises:
        ValueError: If the goal is not a single predicate or a conjunction
            of predicates.
    """
    parsed: Problem = LenientProblemParser()(problem)

    if isinstance(parsed.goal, Predicate):
        goal = [parsed.goal]
    elif isinstance(parsed.goal, And):
        goal = list(parsed.goal.operands)
    else:
        raise ValueError(f"Unsupported goal type: {type(parsed.goal)}")

    # Each conjunct must itself be a predicate (e.g. no negations); anything
    # else would otherwise fail later with an opaque AttributeError.
    for operand in goal:
        if not isinstance(operand, Predicate):
            raise ValueError(f"Unsupported goal type: {type(operand)}")

    return ProblemGraph(
        _build_constants(parsed.objects),
        _build_predicates(parsed.init),
        _build_predicates(goal),
        domain=parsed.domain_name,
        requirements=[req.name for req in parsed.requirements],
    )
99 |
--------------------------------------------------------------------------------
/planetarium/domains/blocksworld.pddl:
--------------------------------------------------------------------------------
;; source: https://github.com/AI-Planning/pddl-generators/blob/main/blocksworld/domain.pddl
;; same as used in IPC 2023
;;
;; Blocks World with a single arm: blocks rest on the table or on another
;; block, and the arm moves one block at a time.
(define (domain blocksworld)

  (:requirements :strips)

  (:predicates
    (clear ?x)     ; nothing is stacked on ?x and ?x is not held
    (on-table ?x)  ; ?x rests directly on the table
    (arm-empty)    ; the arm holds nothing
    (holding ?x)   ; the arm holds ?x
    (on ?x ?y)     ; ?x is stacked on ?y
  )

  ; Lift a clear block off the table with the empty arm.
  (:action pickup
    :parameters (?ob)
    :precondition (and (clear ?ob) (on-table ?ob) (arm-empty))
    :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob))
                 (not (arm-empty)))
  )

  ; Put the held block down on the table.
  (:action putdown
    :parameters (?ob)
    :precondition (holding ?ob)
    :effect (and (clear ?ob) (arm-empty) (on-table ?ob)
                 (not (holding ?ob)))
  )

  ; Place the held block on top of a clear block.
  (:action stack
    :parameters (?ob ?underob)
    :precondition (and (clear ?underob) (holding ?ob))
    :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob)
                 (not (clear ?underob)) (not (holding ?ob)))
  )

  ; Lift a clear block off the block beneath it with the empty arm.
  (:action unstack
    :parameters (?ob ?underob)
    :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty))
    :effect (and (holding ?ob) (clear ?underob)
                 (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty)))
  )
)
--------------------------------------------------------------------------------
/planetarium/domains/floor-tile.pddl:
--------------------------------------------------------------------------------
;; Modified from: https://github.com/AI-Planning/pddl-generators/blob/main/floortile/domain.pddl
;;
;; Robots move over a grid of tiles and paint neighbouring tiles with the
;; color they currently hold.
;; NOTE(review): no precondition in this file checks whether a tile is
;; already painted or occupied, so robots may move onto (and repaint) painted
;; tiles; presumably an intentional simplification — confirm.

(define (domain floor-tile)
  (:requirements :typing)
  (:types
    robot tile color - object
  )

  (:predicates
    (robot-at ?r - robot ?x - tile)
    (up ?x - tile ?y - tile)          ; ?x lies above ?y (`up` moves from ?y to ?x)
    (right ?x - tile ?y - tile)       ; ?x lies right of ?y (`right` moves from ?y to ?x)
    (painted ?x - tile ?c - color)
    (robot-has ?r - robot ?c - color) ; the color robot ?r currently paints with
    (available-color ?c - color)      ; ?c may be taken with change-color
  )

  ; Swap the robot's current color ?c for an available color ?c2.
  (:action change-color
    :parameters (?r - robot ?c - color ?c2 - color)
    :precondition (and (robot-has ?r ?c) (available-color ?c2))
    :effect (and (not (robot-has ?r ?c)) (robot-has ?r ?c2)
    )
  )

  ; Paint the tile above the robot's tile.
  (:action paint-up
    :parameters (?r - robot ?y - tile ?x - tile ?c - color)
    :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (up ?y ?x))
    :effect (painted ?y ?c)
  )

  ; Paint the tile below the robot's tile.
  (:action paint-down
    :parameters (?r - robot ?y - tile ?x - tile ?c - color)
    :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (up ?x ?y))
    :effect (and (painted ?y ?c)
    )
  )

  ; Paint the tile to the right of the robot's tile.
  (:action paint-right
    :parameters (?r - robot ?y - tile ?x - tile ?c - color)
    :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (right ?y ?x))
    :effect (and (painted ?y ?c)
    )
  )

  ; Paint the tile to the left of the robot's tile.
  (:action paint-left
    :parameters (?r - robot ?y - tile ?x - tile ?c - color)
    :precondition (and (robot-has ?r ?c) (robot-at ?r ?x) (right ?x ?y))
    :effect (and (painted ?y ?c)
    )
  )

  ; Robot movements
  (:action up
    :parameters (?r - robot ?x - tile ?y - tile)
    :precondition (and (robot-at ?r ?x) (up ?y ?x))
    :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x)))
  )

  (:action down
    :parameters (?r - robot ?x - tile ?y - tile)
    :precondition (and (robot-at ?r ?x) (up ?x ?y))
    :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x)))
  )

  (:action right
    :parameters (?r - robot ?x - tile ?y - tile)
    :precondition (and (robot-at ?r ?x) (right ?y ?x))
    :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x)))
  )

  (:action left
    :parameters (?r - robot ?x - tile ?y - tile)
    :precondition (and (robot-at ?r ?x) (right ?x ?y))
    :effect (and (robot-at ?r ?y) (not (robot-at ?r ?x)))
  )

)
--------------------------------------------------------------------------------
/planetarium/domains/gripper.pddl:
--------------------------------------------------------------------------------
;; source: https://github.com/AI-Planning/pddl-generators/blob/main/gripper/domain.pddl
;; A robot ("robby") with grippers carries balls between rooms. Object kinds
;; are given by the static predicates room/ball/gripper (no action changes
;; them) instead of PDDL types.
(define (domain gripper)
  (:requirements :strips)
  (:predicates
    (room ?r)
    (ball ?b)
    (gripper ?g)
    (at-robby ?r)  ; the robot is in room ?r
    (at ?b ?r)     ; ball ?b lies in room ?r
    (free ?g)      ; gripper ?g holds nothing
    (carry ?o ?g)  ; gripper ?g holds ?o
  )

  ; Move the robot from one room to another.
  (:action move
    :parameters (?from ?to)
    :precondition (and (room ?from) (room ?to) (at-robby ?from))
    :effect (and (at-robby ?to)
                 (not (at-robby ?from)))
  )

  ; Pick up a ball in the robot's room with a free gripper.
  (:action pick
    :parameters (?obj ?room ?gripper)
    :precondition (and (ball ?obj) (room ?room) (gripper ?gripper)
                       (at ?obj ?room) (at-robby ?room) (free ?gripper))
    :effect (and (carry ?obj ?gripper)
                 (not (at ?obj ?room))
                 (not (free ?gripper)))
  )

  ; Drop a carried ball in the robot's current room.
  (:action drop
    :parameters (?obj ?room ?gripper)
    :precondition (and (ball ?obj) (room ?room) (gripper ?gripper)
                       (carry ?obj ?gripper) (at-robby ?room))
    :effect (and (at ?obj ?room)
                 (free ?gripper)
                 (not (carry ?obj ?gripper)))
  )
)
--------------------------------------------------------------------------------
/planetarium/domains/rover-single.pddl:
--------------------------------------------------------------------------------
;; Single-rover, single-lander variant of the rover domain: the rover and the
;; lander are implicit, so at_rover/at_lander take only a waypoint. There are
;; no stores and no camera calibration.
(define (domain rover-single)
  (:requirements :strips :typing)
  (:types
    waypoint camera mode objective
  )

  (:predicates
    (at_rover ?y - waypoint)
    (at_lander ?y - waypoint)
    (can_traverse ?x - waypoint ?y - waypoint)
    (have_rock_analysis ?w - waypoint)
    (have_soil_analysis ?w - waypoint)
    (supports ?c - camera ?m - mode)
    (available)
    (visible ?w - waypoint ?p - waypoint)
    (have_image ?o - objective ?m - mode)
    (communicated_soil_data ?w - waypoint)
    (communicated_rock_data ?w - waypoint)
    (communicated_image_data ?o - objective ?m - mode)
    (at_rock_sample ?w - waypoint)
    (at_soil_sample ?w - waypoint)
    (visible_from ?o - objective ?w - waypoint)
    (channel_free)
  )

  ; Move between waypoints that are both traversable and visible.
  (:action navigate
    :parameters (?y - waypoint ?z - waypoint)
    :precondition (and (can_traverse ?y ?z) (available) (at_rover ?y)
                       (visible ?y ?z))
    :effect (and (not (at_rover ?y)) (at_rover ?z))
  )

  ; Analyse soil at the rover's waypoint; the sample is not consumed.
  (:action sample_soil
    :parameters (?p - waypoint)
    :precondition (and (at_rover ?p) (at_soil_sample ?p))
    :effect (and (have_soil_analysis ?p))
  )

  ; Analyse rock at the rover's waypoint; the sample is not consumed.
  (:action sample_rock
    :parameters (?p - waypoint)
    :precondition (and (at_rover ?p) (at_rock_sample ?p))
    :effect (and (have_rock_analysis ?p))
  )

  ; Image an objective visible from the rover's waypoint with any camera
  ; supporting the mode (no on-board or calibration check).
  (:action take_image
    :parameters (?p - waypoint ?o - objective ?i - camera ?m - mode)
    :precondition (and (supports ?i ?m) (visible_from ?o ?p) (at_rover ?p))
    :effect (have_image ?o ?m)
  )

  ;; The communicate_* actions send data to the lander while it is visible
  ;; from the rover's waypoint. Each deletes and re-adds (available) and
  ;; (channel_free) in one effect; under standard PDDL semantics (adds are
  ;; applied after deletes) both stay true.
  (:action communicate_soil_data
    :parameters (?p - waypoint ?x - waypoint ?y - waypoint)
    :precondition (and (at_rover ?x)
                       (at_lander ?y)(have_soil_analysis ?p)
                       (visible ?x ?y)(available)(channel_free))
    :effect (and (not (available))
                 (not (channel_free))(channel_free)
                 (communicated_soil_data ?p)(available))
  )

  (:action communicate_rock_data
    :parameters (?p - waypoint ?x - waypoint ?y - waypoint)
    :precondition (and (at_rover ?x)
                       (at_lander ?y)(have_rock_analysis ?p)
                       (visible ?x ?y)(available)(channel_free))
    :effect (and (not (available))
                 (not (channel_free))
                 (channel_free)(communicated_rock_data ?p)(available))
  )

  (:action communicate_image_data
    :parameters (?o - objective ?m - mode ?x - waypoint ?y - waypoint)
    :precondition (and (at_rover ?x)
                       (at_lander ?y)(have_image ?o ?m)
                       (visible ?x ?y)(available)(channel_free))
    :effect (and (not (available))
                 (not (channel_free))(channel_free)
                 (communicated_image_data ?o ?m)(available))
  )
)
--------------------------------------------------------------------------------
/planetarium/domains/rover.pddl:
--------------------------------------------------------------------------------
;; source: https://github.com/AI-Planning/pddl-generators/blob/main/rovers/domain.pddl
;; same as used in IPC 2023
;;
;; Rovers navigate between waypoints, collect soil/rock samples into their
;; stores, take images with calibrated cameras, and transmit the data to a
;; lander visible from their position.
(define (domain rover)
  (:requirements :strips :typing)
  (:types
    rover waypoint store camera mode lander objective
  )

  (:predicates
    (at ?x - rover ?y - waypoint)
    (at_lander ?x - lander ?y - waypoint)
    (can_traverse ?r - rover ?x - waypoint ?y - waypoint)
    (equipped_for_soil_analysis ?r - rover)
    (equipped_for_rock_analysis ?r - rover)
    (equipped_for_imaging ?r - rover)
    (empty ?s - store)
    (have_rock_analysis ?r - rover ?w - waypoint)
    (have_soil_analysis ?r - rover ?w - waypoint)
    (full ?s - store)
    (calibrated ?c - camera ?r - rover)
    (supports ?c - camera ?m - mode)
    (available ?r - rover)
    (visible ?w - waypoint ?p - waypoint)
    (have_image ?r - rover ?o - objective ?m - mode)
    (communicated_soil_data ?w - waypoint)
    (communicated_rock_data ?w - waypoint)
    (communicated_image_data ?o - objective ?m - mode)
    (at_soil_sample ?w - waypoint)
    (at_rock_sample ?w - waypoint)
    (visible_from ?o - objective ?w - waypoint)
    (store_of ?s - store ?r - rover)
    (calibration_target ?i - camera ?o - objective)
    (on_board ?i - camera ?r - rover)
    (channel_free ?l - lander)
  )

  ; Move a rover between waypoints that are traversable for it and visible.
  (:action navigate
    :parameters (?x - rover ?y - waypoint ?z - waypoint)
    :precondition (and (can_traverse ?x ?y ?z) (available ?x) (at ?x ?y)
                       (visible ?y ?z))
    :effect (and (not (at ?x ?y)) (at ?x ?z))
  )

  ; Take the soil sample at the rover's waypoint into its empty store
  ; (the sample is consumed).
  (:action sample_soil
    :parameters (?x - rover ?s - store ?p - waypoint)
    :precondition (and (at ?x ?p) (at_soil_sample ?p)
                       (equipped_for_soil_analysis ?x) (store_of ?s ?x) (empty ?s))
    :effect (and (not (empty ?s)) (full ?s) (have_soil_analysis ?x ?p)
                 (not (at_soil_sample ?p)))
  )

  ; Take the rock sample at the rover's waypoint into its empty store
  ; (the sample is consumed).
  (:action sample_rock
    :parameters (?x - rover ?s - store ?p - waypoint)
    :precondition (and (at ?x ?p) (at_rock_sample ?p)
                       (equipped_for_rock_analysis ?x) (store_of ?s ?x)(empty ?s))
    :effect (and (not (empty ?s)) (full ?s) (have_rock_analysis ?x ?p)
                 (not (at_rock_sample ?p)))
  )

  ; Empty a full store; the analysis already obtained is kept.
  (:action drop
    :parameters (?x - rover ?y - store)
    :precondition (and (store_of ?y ?x) (full ?y))
    :effect (and (not (full ?y)) (empty ?y))
  )

  ; Calibrate an on-board camera while its calibration target is visible.
  (:action calibrate
    :parameters (?r - rover ?i - camera ?t - objective ?w - waypoint)
    :precondition (and (equipped_for_imaging ?r) (calibration_target ?i ?t)
                       (at ?r ?w) (visible_from ?t ?w)(on_board ?i ?r))
    :effect (calibrated ?i ?r)
  )

  ; Image an objective; this uses up the camera's calibration.
  (:action take_image
    :parameters (?r - rover ?p - waypoint ?o - objective ?i - camera ?m - mode)
    :precondition (and (calibrated ?i ?r) (on_board ?i ?r) (equipped_for_imaging ?r)
                       (supports ?i ?m) (visible_from ?o ?p) (at ?r ?p))
    :effect (and (have_image ?r ?o ?m)
                 (not (calibrated ?i ?r)))
  )

  ;; The communicate_* actions send data to a lander visible from the rover's
  ;; waypoint. Each deletes and re-adds (available ?r) and (channel_free ?l)
  ;; in one effect; under standard PDDL semantics (adds are applied after
  ;; deletes) both stay true.
  (:action communicate_soil_data
    :parameters (?r - rover ?l - lander ?p - waypoint ?x - waypoint ?y - waypoint)
    :precondition (and (at ?r ?x)
                       (at_lander ?l ?y)(have_soil_analysis ?r ?p)
                       (visible ?x ?y)(available ?r)(channel_free ?l))
    :effect (and (not (available ?r))
                 (not (channel_free ?l))(channel_free ?l)
                 (communicated_soil_data ?p)(available ?r))
  )

  (:action communicate_rock_data
    :parameters (?r - rover ?l - lander ?p - waypoint ?x - waypoint ?y - waypoint)
    :precondition (and (at ?r ?x)
                       (at_lander ?l ?y)(have_rock_analysis ?r ?p)
                       (visible ?x ?y)(available ?r)(channel_free ?l))
    :effect (and (not (available ?r))
                 (not (channel_free ?l))
                 (channel_free ?l)(communicated_rock_data ?p)(available ?r))
  )

  (:action communicate_image_data
    :parameters (?r - rover ?l - lander ?o - objective ?m - mode ?x - waypoint ?y - waypoint)
    :precondition (and (at ?r ?x)
                       (at_lander ?l ?y)(have_image ?r ?o ?m)
                       (visible ?x ?y)(available ?r)(channel_free ?l))
    :effect (and (not (available ?r))
                 (not (channel_free ?l))(channel_free ?l)
                 (communicated_image_data ?o ?m)(available ?r))
  )
)
--------------------------------------------------------------------------------
/planetarium/downward.py:
--------------------------------------------------------------------------------
1 | # FastDownward python wrapper
2 |
3 | import glob
4 | import os
5 | import re
6 | import subprocess
7 | import tempfile
8 |
9 |
10 | def _get_best_plan(plan_filepath: str) -> tuple[str | None, float]:
11 | """Get the best plan from a FastDownward plan file.
12 |
13 | Args:
14 | plan_filepath (str): The path to the plan file.
15 |
16 | Returns:
17 | The best plan and its cost.
18 | """
19 |
20 | best_cost = float("inf")
21 | best_plan = None
22 |
23 | for plan_fp in glob.glob(f"{plan_filepath}*"):
24 | with open(plan_fp, "r") as f:
25 | *pddl_plan, cost_str = f.readlines()
26 | match = re.search(r"cost = ([-\d\.]+)", cost_str)
27 | if match:
28 | cost = float(match.group(1))
29 |
30 | if cost < best_cost:
31 | best_cost = cost
32 | best_plan = "\n".join([*pddl_plan, ";"])
33 | return best_plan, best_cost
34 |
35 |
def plan(
    domain: str,
    problem: str,
    downward: str = "downward",
    alias: str = "lama",
    **kwargs,
) -> tuple[str | None, float]:
    """Find plan using FastDownward.

    Extra keyword arguments become FastDownward command-line options:
    underscores in the name become hyphens (``search_time_limit=10`` ->
    ``--search-time-limit 10``), list/tuple values expand to several
    arguments, ``True`` emits a bare flag, and ``False``/``None`` omit the
    option. ``plan-file``, ``sas-file`` and ``alias`` are always set here.

    Args:
        domain (str): A string containing a PDDL domain definition.
        problem (str): A string containing a PDDL task/problem definition.
        downward (str, optional): Path to FastDownward. Defaults to "downward".
        alias (str, optional): The FastDownward alias to use. Defaults to "lama".

    Returns:
        Returns the PDDL plan string, or `None` if the planner failed, and the
        plan cost.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        domain_filepath = os.path.join(tmpdir, "domain.pddl")
        task_filepath = os.path.join(tmpdir, "task.pddl")

        plan_filepath = os.path.join(tmpdir, "plan.pddl")
        sas_filepath = os.path.join(tmpdir, "output.sas")

        # build temporary domain and task files
        with open(domain_filepath, "w") as f:
            f.write(domain)
        with open(task_filepath, "w") as f:
            f.write(problem)

        # FastDownward options are hyphenated; Python keyword names cannot be.
        options = {name.replace("_", "-"): value for name, value in kwargs.items()}
        options["plan-file"] = plan_filepath
        options["sas-file"] = sas_filepath
        options["alias"] = alias

        # build FastDownward arguments
        downward_args = []
        for name, value in options.items():
            if value is None or value is False:
                continue
            downward_args.append(f"--{name}")
            if value is True:
                continue  # bare flag, e.g. --keep-sas-file
            if isinstance(value, (list, tuple)):
                downward_args.extend(str(v_i) for v_i in value)
            else:
                downward_args.append(str(value))

        # call FastDownward; failures surface as a missing plan below
        subprocess.run(
            [downward, *downward_args, domain_filepath, task_filepath],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            check=False,
        )

        # get results before the temporary directory is removed
        best_plan, best_cost = _get_best_plan(plan_filepath)

    return best_plan, best_cost
94 |
95 |
def validate(domain: str, problem: str, plan: str, val: str = "validate") -> bool:
    """Validate a plan using VAL.

    The domain, problem and plan are written into a temporary directory and
    VAL is run on them. The plan counts as valid when VAL's standard output
    contains "Plan valid".

    Args:
        domain (str): A string containing a PDDL domain definition.
        problem (str): A string containing a PDDL task/problem definition.
        plan (str): A string containing a PDDL plan.
        val (str, optional): Path to VAL. Defaults to "validate".

    Returns:
        bool: True if VAL reported the plan as valid, False otherwise. This
        includes the case where VAL exits with an error, since its exit
        status is not checked.

    Raises:
        OSError: If the ``val`` executable cannot be started (e.g.
            FileNotFoundError).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        domain_filepath = os.path.join(tmpdir, "domain.pddl")
        task_filepath = os.path.join(tmpdir, "task.pddl")
        plan_filepath = os.path.join(tmpdir, "plan.pddl")

        # build temporary domain, task, and plan files
        for filepath, contents in (
            (domain_filepath, domain),
            (task_filepath, problem),
            (plan_filepath, plan),
        ):
            with open(filepath, "w") as f:
                f.write(contents)

        # call VAL; only its report on stdout is inspected
        res = subprocess.run(
            [val, domain_filepath, task_filepath, plan_filepath],
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            check=False,
        )

        # stray non-UTF-8 bytes in VAL's output must not raise
        # UnicodeDecodeError instead of giving a verdict
        return "Plan valid" in res.stdout.decode("utf-8", errors="replace")
127 |
--------------------------------------------------------------------------------
/planetarium/evaluate.py:
--------------------------------------------------------------------------------
1 | import importlib.resources as resources
2 | import os
3 |
4 | from pddl.parser.problem import LenientProblemParser
5 | from pddl.formatter import problem_to_string
6 |
7 | from planetarium import builder, oracle, metric, downward, DOMAINS
8 |
9 |
# Executables used for plan validation (VAL) and planning (FastDownward).
# Each can be overridden by the environment variable of the same name.
VALIDATE = os.environ.get("VALIDATE", "Validate")
DOWNWARD = os.environ.get("DOWNWARD", "downward")
12 |
13 |
def evaluate(
    source_pddl_str: str,
    target_pddl_str: str,
    domain_str: str | None = None,
    is_placeholder: bool = False,
    check_solveable: bool = True,
    val: str = VALIDATE,
    fast_downward: str = DOWNWARD,
    **downward_args,
) -> tuple[bool, bool, bool]:
    """Evaluate two PDDL problem descriptions for equivalence.

    Args:
        source_pddl_str (str): The ground truth problem PDDL string. Errors
            while building its graph are raised to the caller.
        target_pddl_str (str): The second problem PDDL string.
        domain_str (str | None, optional): The domain PDDL string. If None,
            it is looked up in ``DOMAINS`` by the target problem's domain.
        is_placeholder (bool, optional): Whether or not to treat the ground truth
            as a "placeholder" description. Defaults to False.
        check_solveable (bool, optional): Whether or not to check if the problem
            is solveable. Defaults to True. If False, or if no domain string
            is available, the solveable element is False and the equivalence
            check still runs.
        val (str, optional): Path to the VAL executable used to validate plans.
        fast_downward (str, optional): Path to FastDownward. It is used when
            the oracle planner raises DomainNotSupportedError or
            NotImplementedError.
        **downward_args: Extra options forwarded to ``downward.plan``.

    Returns:
        tuple: A tuple containing the following boolean elements:
            - parseable: Whether or not the target PDDL string is parseable.
            - solveable: Whether or not the target PDDL string is solveable.
            - equivalent: Whether or not the PDDL strings are equivalent.
    """
    parseable = False
    solveable = False
    equivalent = False

    source_graph = builder.build(source_pddl_str)

    try:
        target_graph = builder.build(target_pddl_str)
        parseable = True
    except Exception:
        return parseable, solveable, equivalent

    domain_str = domain_str or DOMAINS.get(target_graph.domain)

    if check_solveable and isinstance(domain_str, str):
        try:
            # normalised problem text handed to FastDownward and VAL
            clean_pddl_str = problem_to_string(
                LenientProblemParser()(target_pddl_str)
            )
        except Exception:
            return parseable, solveable, equivalent

        try:
            plan_str = oracle.plan_to_string(oracle.plan(target_graph))
        except (oracle.DomainNotSupportedError, NotImplementedError):
            # no oracle planner for this domain: fall back to FastDownward
            try:
                plan_str, _ = downward.plan(
                    domain_str,
                    clean_pddl_str,
                    downward=fast_downward,
                    **downward_args,
                )
            except Exception:
                return parseable, solveable, equivalent
        except Exception:
            return parseable, solveable, equivalent

        if plan_str is None:
            # downward.plan returns None when the planner produced no plan
            return parseable, solveable, equivalent

        try:
            solveable = downward.validate(
                domain_str,
                clean_pddl_str,
                plan_str,
                val=val,
            )
        except Exception:
            return parseable, solveable, equivalent
        if not solveable:
            return parseable, solveable, equivalent

    if source_graph == target_graph:
        equivalent = True
    elif not metric.equals(source_graph.init(), target_graph.init()):
        equivalent = False
    else:
        try:
            equivalent = metric.equals(
                oracle.fully_specify(source_graph, return_reduced=True),
                oracle.fully_specify(target_graph, return_reduced=True),
                is_placeholder=is_placeholder,
            )
        except Exception:
            # best effort: e.g. DomainNotSupportedError from fully_specify
            # leaves `equivalent` as False
            pass

    return parseable, solveable, equivalent
101 |
--------------------------------------------------------------------------------
/planetarium/graph.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterable
2 |
3 | import abc
4 | import enum
5 | from functools import cached_property
6 |
7 | import matplotlib.pyplot as plt
8 | import networkx as nx
9 | import rustworkx as rx
10 |
11 | from pddl.core import And, Problem, Domain
12 | from pddl.logic.predicates import Predicate
13 | from pddl.logic.terms import Constant
14 | from pddl.formatter import problem_to_string
15 |
16 |
class Label(str, enum.Enum):
    """Kind of a plan-graph node: a constant (object) or a predicate."""

    CONSTANT = "constant"  # node created from a problem constant
    PREDICATE = "predicate"  # node created from a predicate instance
20 |
21 |
class Scene(str, enum.Enum):
    """Part of a planning problem that a node or edge belongs to."""

    INIT = "init"  # initial state
    GOAL = "goal"  # goal condition
25 |
26 |
class PlanGraphNode:
    """A node of a plan graph: either a constant or a predicate.

    Attributes:
        node (str): Identifier of the node, used as its lookup key in a graph.
        name (str): Name of the node.
        label (Label): Whether this is a constant or a predicate node.
        typing (set | str | None): The type(s) of a constant, or the
            predicate name for a predicate node.
        scene (Scene | None): The scene the node belongs to, if any.
    """

    def __init__(
        self,
        node: str,
        name: str,
        label: Label,
        typing: set | str | None = None,
        scene: Scene | None = None,
    ):
        self.node = node
        self.name = name
        self.label = label
        self.typing = typing
        self.scene = scene

    def __eq__(self, other: "PlanGraphNode") -> bool:
        return (
            isinstance(other, PlanGraphNode)
            and self.node == other.node
            and self.name == other.name
            and self.label == other.label
            and self.typing == other.typing
            and self.scene == other.scene
        )

    def __hash__(self) -> int:
        # `node` is not hashed; that is fine because equal nodes still hash
        # equally. `typing` defaults to None, which `sorted` cannot handle.
        typing_key = () if self.typing is None else tuple(sorted(self.typing))
        return hash((self.name, self.label, typing_key, self.scene))

    def __repr__(self) -> str:
        return f"PlanGraphNode(node={self.node}, name={self.name}, label={self.label}, typing={self.typing}, scene={self.scene})"

    __str__ = __repr__
60 |
61 |
class PlanGraphEdge:
    """Edge from a predicate node to one of its argument nodes.

    Attributes:
        predicate (str): Name of the predicate the edge belongs to.
        position (int | None): Index of the argument within the predicate.
        scene (Scene | None): The scene the edge belongs to, if any.
    """

    def __init__(
        self,
        predicate: str,
        position: int | None = None,
        scene: Scene | None = None,
    ):
        self.predicate, self.position, self.scene = predicate, position, scene

    def __eq__(self, other: "PlanGraphEdge") -> bool:
        if not isinstance(other, PlanGraphEdge):
            return False
        return (
            self.predicate == other.predicate
            and self.position == other.position
            and self.scene == other.scene
        )

    def __hash__(self) -> int:
        key = (self.predicate, self.position, self.scene)
        return hash(key)

    def __repr__(self) -> str:
        return (
            "PlanGraphEdge("
            f"predicate={self.predicate}, "
            f"position={self.position}, "
            f"scene={self.scene})"
        )

    __str__ = __repr__
89 |
90 |
class PlanGraph(metaclass=abc.ABCMeta):
    """
    Base class for graphs representing a PDDL planning problem or scene.

    The graph is stored in ``self.graph``, a ``rx.PyDiGraph``. Its node
    payloads are ``PlanGraphNode`` objects and its edge payloads are
    ``PlanGraphEdge`` objects. Every constant becomes a constant node. Every
    predicate becomes a predicate node with one outgoing edge per argument,
    pointing at that argument's constant node.

    Attributes:
        constants (property): A list of dictionaries describing the constants.
        domain (property): The domain of the graph.
    """

    def __init__(
        self,
        constants: list[dict[str, Any]],
        domain: str | None = None,
        requirements: tuple[str] = (),
    ):
        """
        Initialize the PlanGraph instance.

        Parameters:
            constants (list): List of dictionaries representing constants,
                each with a "name" and a "typing" entry.
            domain (str, optional): The domain of the graph.
                Defaults to None.
            requirements (tuple, optional): Requirements of the graph; they
                take part in equality checks.
        """
        super().__init__()

        self._constants: list[dict[str, Any]] = []
        self._constant_nodes: list[PlanGraphNode] = []
        self._predicates: list[dict[str, Any]] = []
        self._predicate_nodes: list[PlanGraphNode] = []
        # PlanGraphNode.node -> (rustworkx index, node)
        self._node_lookup: dict[str, tuple[int, PlanGraphNode]] = {}
        self._nodes: list[PlanGraphNode] = []
        # (source node, target node, edge payload) triples
        self._edges: set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]] = set()
        self._domain: str | None = domain
        self._requirements: tuple[str] = requirements
        self.graph = rx.PyDiGraph()

        for constant in constants:
            self.add_node(
                PlanGraphNode(
                    constant["name"],
                    name=constant["name"],
                    label=Label.CONSTANT,
                    typing=constant["typing"],
                )
            )

    @cached_property
    def nodes(self) -> list[PlanGraphNode]:
        """All nodes in insertion order (the internal list, not a copy)."""
        return self._nodes

    @cached_property
    def edges(self) -> set[tuple[PlanGraphNode, PlanGraphNode, PlanGraphEdge]]:
        """All (source node, target node, edge) triples (the internal set)."""
        return self._edges

    def add_node(self, node: PlanGraphNode):
        """Add a node to the graph.

        Raises:
            ValueError: If an equal node is already in the graph.
        """
        if node in self.nodes:
            raise ValueError(f"Node {node} already exists in the graph.")
        index = self.graph.add_node(node)

        if node.label == Label.CONSTANT:
            self._constants.append({"name": node.name, "typing": node.typing})
            self._constant_nodes.append(node)
        elif node.label == Label.PREDICATE:
            self._predicate_nodes.append(node)

        self._nodes.append(node)
        self._node_lookup[node.node] = (index, node)

    def _resolve_node(self, node: str | PlanGraphNode) -> tuple[int, PlanGraphNode]:
        """Return the rustworkx index of ``node`` (a node or its id) and the node.

        For a PlanGraphNode, the index is its position in ``nodes``. This
        assumes rustworkx indices follow insertion order (no node is ever
        removed here), and the node is returned as given. A string is
        resolved through the lookup table to the stored node.
        """
        if isinstance(node, PlanGraphNode):
            return self.nodes.index(node), node
        return self._node_lookup[node]

    def has_edge(
        self,
        u: str | PlanGraphNode,
        v: str | PlanGraphNode,
        edge: PlanGraphEdge | None = None,
    ) -> bool:
        """Check for an edge u -> v, optionally with exactly this payload."""
        u_index, _ = self._resolve_node(u)
        v_index, _ = self._resolve_node(v)

        if edge:
            return (u_index, v_index, edge) in self.graph.edge_index_map().values()
        return self.graph.has_edge(u_index, v_index)

    def add_edge(
        self, u: str | PlanGraphNode, v: str | PlanGraphNode, edge: PlanGraphEdge
    ):
        """Add an edge u -> v carrying ``edge`` as payload."""
        u_index, u = self._resolve_node(u)
        v_index, v = self._resolve_node(v)

        self.graph.add_edge(u_index, v_index, edge)
        self._edges.add((u, v, edge))

    def _add_predicate(
        self,
        predicate: dict[str, Any],
        scene: Scene | None = None,
    ):
        """
        Add a predicate to the plan graph.

        The predicate becomes a predicate node with one edge per parameter,
        pointing at that parameter's constant node. A predicate that is
        already present is ignored.

        NOTE: when ``scene`` is given, ``predicate`` is updated in place
        with a "scene" entry.

        Parameters:
            predicate (dict): A dictionary representing the predicate, with
                its name under "typing" and argument names under "parameters".
            scene (Scene, optional): The scene in which the predicate occurs.

        Raises:
            ValueError: If a parameter is not the name of a constant node.
        """
        if scene:
            predicate.update({"scene": scene})
        if predicate in self.predicates:
            return

        # Validate every parameter before touching the graph, so a bad
        # predicate does not leave a dangling predicate node behind.
        constant_names = {node.name for node in self.constant_nodes}
        for parameter_name in predicate["parameters"]:
            if parameter_name not in constant_names:
                raise ValueError(f"Parameter {parameter_name} not found in constants.")

        predicate_name = self._build_unique_predicate_name(
            predicate_name=predicate["typing"],
            argument_names=predicate["parameters"],
        )
        self.add_node(
            PlanGraphNode(
                predicate_name,
                name=predicate_name,
                label=Label.PREDICATE,
                typing=predicate["typing"],
                scene=scene,
            )
        )

        for position, parameter_name in enumerate(predicate["parameters"]):
            self.add_edge(
                predicate_name,
                parameter_name,
                PlanGraphEdge(
                    predicate=predicate["typing"],
                    position=position,
                    scene=scene,
                ),
            )

        self._predicates.append(predicate)

    def in_degree(self, node: str | PlanGraphNode) -> int:
        """Number of edges entering ``node``."""
        return self.graph.in_degree(self._resolve_node(node)[0])

    def out_degree(self, node: str | PlanGraphNode) -> int:
        """Number of edges leaving ``node``."""
        return self.graph.out_degree(self._resolve_node(node)[0])

    def predecessors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]:
        """Nodes with an edge into ``node``."""
        return self.graph.predecessors(self._resolve_node(node)[0])

    def successors(self, node: str | PlanGraphNode) -> list[PlanGraphNode]:
        """Nodes that ``node`` has an edge to."""
        return self.graph.successors(self._resolve_node(node)[0])

    def in_edges(
        self, node: str | PlanGraphNode
    ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]:
        """(source node, edge) pairs for every edge entering ``node``."""
        edges = self.graph.in_edges(self._resolve_node(node)[0])
        return [(self.nodes[u], edge) for u, _, edge in edges]

    def out_edges(
        self, node: str | PlanGraphNode
    ) -> list[tuple[PlanGraphNode, PlanGraphEdge]]:
        """(target node, edge) pairs for every edge leaving ``node``."""
        edges = self.graph.out_edges(self._resolve_node(node)[0])
        return [(self.nodes[v], edge) for _, v, edge in edges]

    @staticmethod
    def _build_unique_predicate_name(
        predicate_name: str, argument_names: Iterable[str]
    ) -> str:
        """
        Build a unique name for a predicate based on its name and argument names.

        Parameters:
            predicate_name (str): The name of the predicate.
            argument_names (Iterable[str]): Sequence of argument names
                for the predicate.

        Returns:
            str: The predicate name and argument names joined with "-".
        """
        return "-".join([predicate_name, *argument_names])

    @cached_property
    def domain(self) -> str | None:
        """
        Get the domain of the graph.

        Returns:
            str | None: The domain of the graph.
        """
        return self._domain

    @cached_property
    def constant_nodes(self) -> list[PlanGraphNode]:
        """Get a list of constant nodes in the graph.

        Returns:
            list[PlanGraphNode]: A list of constant nodes.
        """
        return self._constant_nodes

    @property
    def constants(self) -> list[dict[str, Any]]:
        """Constants as {"name": ..., "typing": ...} dictionaries."""
        return self._constants

    @property
    def predicate_nodes(self) -> list[PlanGraphNode]:
        """Get a list of predicate nodes in the graph.

        Returns:
            list[PlanGraphNode]: A list of predicate nodes.
        """
        return self._predicate_nodes

    @property
    def predicates(self) -> list[dict[str, Any]]:
        """The predicate dictionaries that were added to the graph."""
        return self._predicates

    def __eq__(self, other: "PlanGraph") -> bool:
        """
        Check if two plan graphs are equal.

        Parameters:
            other (PlanGraph): The other plan graph to compare.

        Returns:
            bool: True if the node sets, edge sets, domains and requirements
            are equal, False otherwise.
        """
        return (
            isinstance(other, PlanGraph)
            and set(self.nodes) == set(other.nodes)
            and set(self.edges) == set(other.edges)
            and self.domain == other.domain
            and set(self._requirements) == set(other._requirements)
        )

    def plot(self, fig: plt.Figure | None = None) -> plt.Figure:
        """Generate a plot of the graph, sorted by topological generation.

        Args:
            fig (plt.Figure | None, optional): The figure to plot on. Defaults
                to None.

        Returns:
            plt.Figure: The figure containing the plot.
        """
        # rx has no plotting functionality
        nx_graph = nx.MultiDiGraph()
        # add every node first so constants without any edges are drawn too
        nx_graph.add_nodes_from(node.node for node in self.nodes)
        nx_graph.add_edges_from(
            [(u.node, v.node, {"data": edge}) for u, v, edge in self.edges]
        )

        for layer, nodes in enumerate(nx.topological_generations(nx_graph)):
            for node in nodes:
                nx_graph.nodes[node]["layer"] = layer

        pos = nx.multipartite_layout(
            nx_graph,
            align="horizontal",
            subset_key="layer",
            scale=-1,
        )

        fig = fig or plt.figure()

        nx.draw(nx_graph, pos=pos, ax=fig.gca(), with_labels=True)

        return fig
394 |
395 |
class SceneGraph(PlanGraph):
    """
    A PlanGraph describing one scene of a planning problem.

    Attributes:
        scene (Scene | None): The scene every predicate of this graph is
            tagged with.
        constants (property): A list of dictionaries describing the constants.
        domain (property): The domain of the scene graph.
    """

    def __init__(
        self,
        constants: list[dict[str, Any]],
        predicates: list[dict[str, Any]],
        domain: str | None = None,
        scene: Scene | None = None,
        requirements: tuple[str] = (),
    ):
        """
        Build a scene graph from constants and predicates.

        Parameters:
            constants (list): Dictionaries describing the constants.
            predicates (list): Dictionaries describing the predicates; each is
                added under ``scene``.
            domain (str, optional): The domain of the scene graph.
                Defaults to None.
            scene (Scene, optional): The scene of the scene graph.
                Defaults to None.
            requirements (tuple, optional): Requirements of the scene graph.
        """
        super().__init__(constants, domain=domain, requirements=requirements)
        self.scene = scene

        for pred in predicates:
            self._add_predicate(pred, scene=self.scene)
433 |
434 |
class ProblemGraph(PlanGraph):
    """
    PlanGraph holding both the initial and the goal scene of a problem.

    Predicates of the initial scene are tagged with ``Scene.INIT`` and those
    of the goal with ``Scene.GOAL``. Both scenes share the same constant
    nodes.

    Attributes:
        constants (property): A list of dictionaries describing the constants.
        init_predicates (property): A list of predicate dictionaries in the initial scene.
        goal_predicates (property): A list of predicate dictionaries in the goal scene.
        init_predicate_nodes (property): Predicate nodes of the initial scene.
        goal_predicate_nodes (property): Predicate nodes of the goal scene.
    """

    def __init__(
        self,
        constants: list[dict[str, Any]],
        init_predicates: list[dict[str, Any]],
        goal_predicates: list[dict[str, Any]],
        domain: str | None = None,
        requirements: tuple[str] = (),
    ):
        """
        Initialize the ProblemGraph instance.

        Parameters:
            constants (list): List of dictionaries representing constants.
            init_predicates (list): List of dictionaries representing predicates
                in the initial scene.
            goal_predicates (list): List of dictionaries representing predicates
                in the goal scene.
            domain (str, optional): The domain of the problem graph.
                Defaults to None.
            requirements (tuple, optional): Requirements of the problem.
        """
        # the base initializer adds constant nodes only, so the per-scene
        # lists below are not needed by add_node yet
        super().__init__(constants, domain=domain, requirements=requirements)

        self._init_predicates: list[dict[str, Any]] = []
        self._init_predicate_nodes: list[PlanGraphNode] = []
        self._goal_predicates: list[dict[str, Any]] = []
        self._goal_predicate_nodes: list[PlanGraphNode] = []

        for predicate in init_predicates:
            self._add_predicate(predicate, scene=Scene.INIT)

        for predicate in goal_predicates:
            self._add_predicate(predicate, scene=Scene.GOAL)

    def __eq__(self, other: "ProblemGraph") -> bool:
        return (
            super().__eq__(other)
            and set(self.init_predicate_nodes) == set(other.init_predicate_nodes)
            and set(self.goal_predicate_nodes) == set(other.goal_predicate_nodes)
        )

    def add_node(self, node: PlanGraphNode):
        """Add a node, also recording predicate nodes under their scene."""
        super().add_node(node)
        if node.label == Label.PREDICATE:
            if node.scene == Scene.INIT:
                self._init_predicate_nodes.append(node)
            elif node.scene == Scene.GOAL:
                self._goal_predicate_nodes.append(node)

    def _add_predicate(self, predicate: dict[str, Any], scene: Scene | None = None):
        """Add a predicate and record it under its scene.

        A predicate that the base class skips as a duplicate is not recorded
        again, so ``init_predicates``/``goal_predicates`` stay consistent with
        ``predicates``.
        """
        num_predicates = len(self.predicates)
        super()._add_predicate(predicate, scene)
        if len(self.predicates) == num_predicates:
            # duplicate: the base class ignored it
            return

        if scene == Scene.INIT:
            self._init_predicates.append(predicate)
        elif scene == Scene.GOAL:
            self._goal_predicates.append(predicate)

    @property
    def init_predicate_nodes(self) -> list[PlanGraphNode]:
        """Get a list of predicate nodes in the initial scene.

        Returns:
            list[PlanGraphNode]: A list of predicate nodes in the initial scene.
        """
        return self._init_predicate_nodes

    @property
    def goal_predicate_nodes(self) -> list[PlanGraphNode]:
        """Get a list of predicate nodes in the goal scene.

        Returns:
            list[PlanGraphNode]: A list of predicate nodes in the goal scene.
        """
        return self._goal_predicate_nodes

    @property
    def init_predicates(self) -> list[dict[str, Any]]:
        """Predicate dictionaries of the initial scene."""
        return self._init_predicates

    @property
    def goal_predicates(self) -> list[dict[str, Any]]:
        """Predicate dictionaries of the goal scene."""
        return self._goal_predicates

    def init(self) -> SceneGraph:
        """Return the initial scene graph.

        Returns:
            SceneGraph: The initial scene graph.
        """
        return SceneGraph(
            constants=self.constants,
            predicates=self.init_predicates,
            domain=self.domain,
            scene=Scene.INIT,
            requirements=self._requirements,
        )

    def goal(self) -> SceneGraph:
        """Return the goal scene graph.

        Returns:
            SceneGraph: The goal scene graph.
        """
        return SceneGraph(
            constants=self.constants,
            predicates=self.goal_predicates,
            domain=self.domain,
            scene=Scene.GOAL,
            requirements=self._requirements,
        )

    def decompose(self) -> tuple[SceneGraph, SceneGraph]:
        """Decompose the problem graph into initial and goal scene graphs.

        Returns:
            tuple[SceneGraph, SceneGraph]: A tuple containing the initial and goal scene graphs.
        """

        init_scene = self.init()
        goal_scene = self.goal()

        return init_scene, goal_scene

    @staticmethod
    def join(init: SceneGraph, goal: SceneGraph) -> "ProblemGraph":
        """
        Combine initial and goal scene graphs into a problem graph.

        Constants, domain and requirements are taken from ``init``.

        Parameters:
            init (SceneGraph): The initial scene graph.
            goal (SceneGraph): The goal scene graph.

        Returns:
            ProblemGraph: The combined problem graph.
        """
        return ProblemGraph(
            constants=init.constants,
            init_predicates=init.predicates,
            goal_predicates=goal.predicates,
            domain=init.domain,
            requirements=init._requirements,
        )

    def to_pddl_str(self) -> str:
        """
        Convert this ProblemGraph to a PDDL problem description string.

        NOTE: REQUIREMENTS ARE NOT SUPPORTED YET. Object types are not
        written out, and the problem is always named "name".

        Returns:
            str: A string containing the PDDL problem description.
        """
        constant_objs = [Constant(name=n.name) for n in self.constant_nodes]

        init_predicates = []
        goal_predicates = []
        for n in self.predicate_nodes:
            # arguments in the order of their position in the predicate
            args: list[PlanGraphNode] = [
                v for v, _ in sorted(self.out_edges(n), key=lambda e: e[1].position)
            ]
            pddl_args = [Constant(name=v.name) for v in args]

            predicate = Predicate(n.typing, *pddl_args)
            if n.scene == Scene.INIT:
                init_predicates.append(predicate)
            elif n.scene == Scene.GOAL:
                goal_predicates.append(predicate)

        return problem_to_string(
            Problem(
                name="name",
                domain=Domain(name=self.domain, requirements=[]),
                objects=constant_objs,
                init=sorted(init_predicates),
                goal=And(*sorted(goal_predicates)),
                requirements=[],
            )
        )
632 |
--------------------------------------------------------------------------------
/planetarium/metric.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import rustworkx as rx
3 |
4 | from planetarium import graph
5 |
6 |
def _preserves_mapping(
    source: graph.PlanGraphNode,
    target: graph.PlanGraphNode,
    mapping: dict,
) -> bool:
    """
    Check that ``mapping`` sends the source constant to the target constant.

    Parameters:
        source (graph.PlanGraphNode): The source node.
        target (graph.PlanGraphNode): The target node.
        mapping (dict): Maps source node names to target node names.

    Returns:
        bool: True if both nodes are constants and ``mapping[source.name]``
        equals ``target.name``, False otherwise.

    Raises:
        KeyError: If both nodes are constants and ``source.name`` is not a
            key of ``mapping``.
    """
    constant = graph.Label.CONSTANT
    if source.label != constant or target.label != constant:
        return False
    return mapping[source.name] == target.name
28 |
29 |
def _same_typing(source: graph.PlanGraphNode, target: graph.PlanGraphNode) -> bool:
    """
    Check that two constant nodes carry the same typing.

    Parameters:
        source (graph.PlanGraphNode): The source node.
        target (graph.PlanGraphNode): The target node.

    Returns:
        bool: True if both nodes are constants with equal typing, False
        otherwise.
    """
    both_constants = all(
        node.label == graph.Label.CONSTANT for node in (source, target)
    )
    return both_constants and source.typing == target.typing
46 |
47 |
def _node_matching(
    source: graph.PlanGraphNode,
    target: graph.PlanGraphNode,
    mapping: dict | None,
) -> bool:
    """
    Decide whether two nodes may be matched by an isomorphism.

    Two constants match when their typings agree and, if a non-empty
    ``mapping`` is given, the mapping sends one to the other. Two predicates
    match when they are the same predicate. A constant never matches a
    predicate.

    Parameters:
        source (graph.PlanGraphNode): The source node.
        target (graph.PlanGraphNode): The target node.
        mapping (dict | None): Optional mapping between node names.

    Returns:
        bool: True if the nodes match, False otherwise.
    """
    constant = graph.Label.CONSTANT
    predicate = graph.Label.PREDICATE

    if source.label == constant and target.label == constant:
        if not _same_typing(source, target):
            return False
        return _preserves_mapping(source, target, mapping) if mapping else True

    if source.label == predicate and target.label == predicate:
        # a predicate node's typing is its predicate name
        return source.typing == target.typing

    return False
74 |
75 |
76 | def _edge_matching(
77 | source: graph.PlanGraphEdge,
78 | target: graph.PlanGraphEdge,
79 | attributes: dict[str, str | int | graph.Scene, graph.Label] = {},
80 | ) -> bool:
81 | """
82 | Check if two edges match based on their attributes.
83 |
84 | Parameters:
85 | source (graph.PlanGraphEdge): The source edge.
86 | target (graph.PlanGraphEdge): The target edge.
87 | attributes (dict): The attributes to match.
88 |
89 | Returns:
90 | bool: True if edges match, False otherwise.
91 | """
92 |
93 | def _getattr(obj, attr):
94 | v = getattr(obj, attr, attributes.get(attr))
95 | return v
96 |
97 | return all(_getattr(source, attr) == _getattr(target, attr) for attr in attributes)
98 |
99 |
def isomorphic(
    source: graph.ProblemGraph | graph.SceneGraph,
    target: graph.ProblemGraph | graph.SceneGraph,
    mapping: dict | None = None,
) -> bool:
    """
    Check whether two plan graphs are isomorphic.

    Nodes must agree in label and typing; constants must also respect
    ``mapping`` when one is given. Edges must agree in predicate, argument
    position and scene.

    Parameters:
        source (ProblemGraph | SceneGraph): The first graph.
        target (ProblemGraph | SceneGraph): The second graph.
        mapping (dict | None): Optional mapping from source node names to
            target node names that matched constants must respect.

    Returns:
        bool: True if such an isomorphism exists, False otherwise.
    """
    # fallback values for edges that lack one of the compared attributes
    edge_defaults = {"position": -1, "predicate": "", "scene": None}

    return rx.is_isomorphic(
        source.graph,
        target.graph,
        node_matcher=lambda s, t: _node_matching(s, t, mapping=mapping),
        edge_matcher=lambda s, t: _edge_matching(s, t, attributes=edge_defaults),
    )
128 |
129 |
def equals(
    source: graph.ProblemGraph,
    target: graph.ProblemGraph,
    is_placeholder: bool = False,
) -> bool:
    """
    Check whether two problem graphs are equivalent up to isomorphism.

    Parameters:
        source (ProblemGraph | SceneGraph): The first graph.
        target (ProblemGraph | SceneGraph): The second graph.
        is_placeholder (bool): If False, the graphs are compared as a whole
            (initial and goal scenes together). If True, the initial scenes
            and the goal scenes are compared separately.

    Returns:
        bool: True if there is a valid mapping, False otherwise.
    """
    if source == target:
        return True

    if is_placeholder:
        source_init, source_goal = source.decompose()
        target_init, target_goal = target.decompose()

        if source_init == target_init and source_goal == target_goal:
            return True

        # both checks always run
        init_ok = isomorphic(source_init, target_init)
        goal_ok = isomorphic(source_goal, target_goal)
        return init_ok and goal_ok

    return isomorphic(source, target)
164 |
--------------------------------------------------------------------------------
/planetarium/oracle.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import jinja2 as jinja
4 | from pddl.core import Action
5 |
6 | from planetarium import graph
7 |
8 | from .reduced_graph import ReducedSceneGraph, ReducedProblemGraph
9 | from .oracles import ORACLES
10 |
11 |
# Renders a list of actions as plan text: one parenthesised line per action,
# holding the action name followed by its parameters joined with ", ".
plan_template = jinja.Template(
    "\n"
    "{%- for action in actions -%}\n"
    '({{ action.name }} {{ action.parameters | join(", ") }})\n'
    "{% endfor %}\n"
)
19 |
20 |
class DomainNotSupportedError(Exception):
    """Raised when no oracle is registered for the requested domain."""
23 |
24 |
def reduce(
    graph: graph.SceneGraph,
    domain: str | None = None,
) -> ReducedSceneGraph | ReducedProblemGraph:
    """Reduce a graph using the oracle registered for its domain.

    Args:
        graph (graph.SceneGraph): The graph to reduce.
        domain (str, optional): Domain whose oracle to use. Defaults to
            ``graph.domain``.

    Returns:
        ReducedSceneGraph | ReducedProblemGraph: The reduced graph.

    Raises:
        DomainNotSupportedError: If no oracle exists for the domain.
    """
    domain = domain or graph.domain
    oracle = ORACLES.get(domain)
    if not oracle:
        raise DomainNotSupportedError(f"Domain {domain} not supported.")
    return oracle.reduce(graph)
42 |
43 |
def inflate(
    scene: ReducedSceneGraph | ReducedProblemGraph,
    domain: str | None = None,
) -> graph.SceneGraph:
    """Inflate a reduced graph back into a full graph via its domain's oracle.

    Args:
        scene (ReducedSceneGraph | ReducedProblemGraph): The reduced graph.
        domain (str | None, optional): Domain whose oracle to use. Defaults
            to the reduced graph's ``_domain``.

    Returns:
        graph.SceneGraph: The inflated graph.

    Raises:
        DomainNotSupportedError: If no oracle exists for the domain.
    """
    domain = domain or scene._domain
    oracle = ORACLES.get(domain)
    if oracle:
        return oracle.inflate(scene)
    raise DomainNotSupportedError(f"Domain {domain} not supported.")
62 |
63 |
def fully_specify(
    problem: graph.ProblemGraph,
    domain: str | None = None,
    return_reduced: bool = False,
) -> graph.ProblemGraph | ReducedProblemGraph:
    """Complete an underspecified goal using the domain's oracle.

    Args:
        problem (graph.ProblemGraph): The problem whose goal state should be
            fully specified.
        domain (str | None, optional): Domain name used to select the oracle.
            Defaults to None, in which case ``problem.domain`` is used.
        return_reduced (bool, optional): Forwarded to the oracle; when True a
            reduced problem graph is requested. Defaults to False.

    Returns:
        graph.ProblemGraph | ReducedProblemGraph: The fully specified problem,
            in reduced form if ``return_reduced`` is True.

    Raises:
        DomainNotSupportedError: If no oracle is registered for the domain.
    """
    selected_domain = domain or problem.domain
    oracle = ORACLES.get(selected_domain)
    if not oracle:
        raise DomainNotSupportedError(f"Domain {selected_domain} not supported.")
    return oracle.fully_specify(problem, return_reduced=return_reduced)
87 |
88 |
def plan(problem: graph.ProblemGraph, domain: str | None = None) -> list[Action]:
    """Plans a sequence of actions to solve a problem.

    Args:
        problem (graph.ProblemGraph): The problem to plan for.
        domain (str | None, optional): Domain name used to select the oracle.
            Defaults to None, in which case ``problem.domain`` is used.

    Returns:
        list[Action]: The sequence of actions solving the problem (use
            ``plan_to_string`` to render it as text).

    Raises:
        DomainNotSupportedError: If no oracle is registered for the domain.
    """
    domain = domain or problem.domain
    if oracle := ORACLES.get(domain):
        return oracle.plan(problem)
    raise DomainNotSupportedError(f"Domain {domain} not supported.")
102 |
103 |
def plan_to_string(actions: list[Action]) -> str:
    """Render a list of actions as plan text.

    Args:
        actions (list[Action]): The actions to render, in execution order.

    Returns:
        str: The text produced by ``plan_template``, one parenthesised action
            per line.
    """
    rendered = plan_template.render(actions=actions)
    return rendered
114 |
--------------------------------------------------------------------------------
/planetarium/oracles/__init__.py:
--------------------------------------------------------------------------------
__all__ = ["blocksworld", "gripper", "rover_single", "floortile"]

from . import oracle

# NOTE(review): FloorTileOracle.reduce uses ReducedNode.CLEAR, which the
# blocksworld module registers; keep blocksworld imported before floortile
# unless CLEAR is guaranteed elsewhere -- confirm.
from . import blocksworld
from . import gripper
from . import rover_single
from . import floortile

# Maps a PDDL domain name to the oracle *class* handling it. The classes
# themselves (not instances) are stored, and callers invoke their methods
# directly on the class, e.g. ``ORACLES[domain].reduce(graph)``.
ORACLES: dict[str, type[oracle.Oracle]] = {
    "blocksworld": blocksworld.BlocksworldOracle,
    "gripper": gripper.GripperOracle,
    "rover-single": rover_single.RoverSingleOracle,
    "floor-tile": floortile.FloorTileOracle,
}
16 |
--------------------------------------------------------------------------------
/planetarium/oracles/blocksworld.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 |
3 | from pddl.core import Action
4 | import rustworkx as rx
5 |
6 | from . import oracle
7 | from .. import graph
8 | from ..reduced_graph import ReducedSceneGraph, ReducedProblemGraph, ReducedNode
9 |
10 |
# Register the special (non-object) nodes of the blocksworld reduced graphs.
# After this call they are available as ReducedNode.TABLE, ReducedNode.CLEAR
# and ReducedNode.ARM; the second argument names the owning domain.
ReducedNode.register(
    {
        "TABLE": "table",
        "CLEAR": "clear",
        "ARM": "arm",
    },
    "blocksworld",
)
20 |
21 |
22 | class BlocksworldOracle(oracle.Oracle):
23 |
24 | def reduce(
25 | scene: graph.SceneGraph | graph.ProblemGraph,
26 | ) -> ReducedSceneGraph | ReducedProblemGraph:
27 | """Reduces a blocksworld scene graph to a Directed Acyclic Graph.
28 |
29 | Args:
30 | problem (graph.SceneGraph | graph.ProblemGraph): The scene graph to
31 | reduce.
32 |
33 | Returns:
34 | ReducedGraph: The reduced problem graph.
35 | """
36 |
37 | nodes = defaultdict(list)
38 | for node in scene.nodes:
39 | nodes[node.label].append(node)
40 |
41 | match scene:
42 | case graph.ProblemGraph(
43 | _constants=constants,
44 | _predicates=predicates,
45 | _domain=domain,
46 | _requirements=requirements,
47 | ):
48 | reduced = ReducedProblemGraph(
49 | constants=constants,
50 | domain=domain,
51 | requirements=requirements,
52 | )
53 | case graph.SceneGraph(
54 | constants=constants,
55 | _predicates=predicates,
56 | scene=scene,
57 | _domain=domain,
58 | _requirements=requirements,
59 | ):
60 | reduced = ReducedSceneGraph(
61 | constants=constants,
62 | domain=domain,
63 | scene=scene,
64 | requirements=requirements,
65 | )
66 | case _:
67 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.")
68 |
69 | for predicate in predicates:
70 | params = predicate["parameters"]
71 | reduced_edge = graph.PlanGraphEdge(
72 | predicate=predicate["typing"],
73 | scene=predicate.get("scene"),
74 | )
75 | match (predicate["typing"], len(params)):
76 | case ("arm-empty", 0):
77 | reduced.add_edge(ReducedNode.CLEAR, ReducedNode.ARM, reduced_edge)
78 | case ("on-table", 1):
79 | reduced.add_edge(params[0], ReducedNode.TABLE, reduced_edge)
80 | case ("clear", 1):
81 | reduced.add_edge(ReducedNode.CLEAR, params[0], reduced_edge)
82 | case ("on", 2):
83 | reduced.add_edge(params[0], params[1], reduced_edge)
84 | case ("holding", 1):
85 | reduced.add_edge(params[0], ReducedNode.ARM, reduced_edge)
86 | return reduced
87 |
88 | def inflate(
89 | scene: ReducedSceneGraph | ReducedProblemGraph,
90 | ) -> graph.SceneGraph:
91 | """Respecify a blocksworld scene graph.
92 |
93 | Args:
94 | scene (ReducedGraph): The reduced SceneGraph of a scene.
95 |
96 | Returns:
97 | graph.SceneGraph: The respecified scene graph.
98 | """
99 | constants = []
100 | predicates = []
101 |
102 | for node in scene.nodes:
103 | if not isinstance(node.node, ReducedNode):
104 | constants.append({"name": node.node, "typing": node.typing})
105 |
106 | for u, v, edge in scene.edges:
107 | match (u.node, v.node):
108 | case (ReducedNode.CLEAR, ReducedNode.ARM):
109 | predicates.append(
110 | {
111 | "typing": "arm-empty",
112 | "parameters": [],
113 | "scene": edge.scene,
114 | }
115 | )
116 | case (ReducedNode.CLEAR, _):
117 | predicates.append(
118 | {
119 | "typing": "clear",
120 | "parameters": [v.node],
121 | "scene": edge.scene,
122 | }
123 | )
124 | case (_, ReducedNode.TABLE):
125 | predicates.append(
126 | {
127 | "typing": "on-table",
128 | "parameters": [u.node],
129 | "scene": edge.scene,
130 | }
131 | )
132 | case (_, ReducedNode.ARM):
133 | predicates.append(
134 | {
135 | "typing": "holding",
136 | "parameters": [u.node],
137 | "scene": edge.scene,
138 | }
139 | )
140 | case (_, _):
141 | predicates.append(
142 | {
143 | "typing": "on",
144 | "parameters": [u.node, v.node],
145 | "scene": edge.scene,
146 | }
147 | )
148 |
149 | if isinstance(scene, ReducedProblemGraph):
150 | return graph.ProblemGraph(
151 | constants,
152 | [pred for pred in predicates if pred["scene"] == graph.Scene.INIT],
153 | [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL],
154 | domain="blocksworld",
155 | requirements=scene._requirements,
156 | )
157 | else:
158 | return graph.SceneGraph(
159 | constants,
160 | predicates,
161 | domain="blocksworld",
162 | scene=scene.scene,
163 | requirements=scene._requirements,
164 | )
165 |
166 | @staticmethod
167 | def _blocksworld_underspecified_blocks(
168 | scene: ReducedSceneGraph,
169 | ) -> tuple[set[str], set[str], bool]:
170 | """Finds blocks that are not fully specified.
171 |
172 | Args:
173 | scene (ReducedGraph): The reduced SceneGraph of a scene.
174 |
175 | Returns:
176 | tuple[set[str], set[str], bool]: The set of blocks that are not fully
177 | specified.
178 | - blocks that do not specify what is on top.
179 | - blocks that do not specify what is on the bottom.
180 | """
181 | top_blocks = set()
182 | bottom_blocks = set()
183 | arm_behavior_defined = scene.in_degree(ReducedNode.ARM) > 0
184 | held_block = (
185 | scene.predecessors(ReducedNode.ARM)[0] if arm_behavior_defined else None
186 | )
187 | for node in scene.nodes:
188 | if node.label == graph.Label.CONSTANT:
189 | if not scene.in_edges(node) and node != held_block:
190 | top_blocks.add(node)
191 | if not scene.out_edges(node):
192 | bottom_blocks.add(node)
193 | return top_blocks, bottom_blocks, not arm_behavior_defined
194 |
195 | @staticmethod
196 | def _detached_blocks(
197 | nodesA: set[str],
198 | nodesB: set[str],
199 | scene: ReducedSceneGraph,
200 | ) -> tuple[set[str], set[str]]:
201 | """Finds nodes that are not connected to the rest of the scene graph.
202 |
203 | Args:
204 | nodesA (set[str]): The set of nodes to check.
205 | nodesB (set[str]): The set of nodes to check against.
206 | scene (ReducedGraph): The scene graph to check against.
207 |
208 | Returns:
209 | tuple[set[str], set[str]]: The set of nodes that are not connected to
210 | the rest of the scene graph.
211 | """
212 | _nodesA = set(nodesA)
213 | _nodesB = set(nodesB)
214 |
215 | for a in nodesA:
216 | for b in nodesB:
217 | a_index = scene.nodes.index(a)
218 | b_index = scene.nodes.index(b)
219 | if (
220 | not rx.has_path(scene.graph, a_index, b_index)
221 | and not rx.has_path(scene.graph, b_index, a_index)
222 | and a != b
223 | ):
224 | _nodesA.discard(a)
225 | _nodesB.discard(b)
226 |
227 | return _nodesA, _nodesB
228 |
229 | def fully_specify(
230 | problem: graph.ProblemGraph,
231 | return_reduced: bool = False,
232 | ) -> graph.ProblemGraph | ReducedProblemGraph:
233 | """Fully specifies a blocksworld scene graph.
234 |
235 | Adds any missing edges to fully specify the scene graph, without adding
236 | edges that change the problem represented by the graph.
237 |
238 | Args:
239 | problem (graph.ProblemGraph): The problem graph to fully specify.
240 | return_reduced (bool, optional): Whether to return a reduced problem graph.
241 | Defaults to False.
242 |
243 | Returns:
244 | ProblemGraph | ReducedProblemGraph: The fully specified problem graph.
245 | """
246 | inflated_init, inflated_goal = problem.decompose()
247 | scene = BlocksworldOracle.reduce(inflated_goal)
248 | top_blocks, bottom_blocks, arm_empty = (
249 | BlocksworldOracle._blocksworld_underspecified_blocks(scene)
250 | )
251 | top_blocks_, bottom_blocks_ = BlocksworldOracle._detached_blocks(
252 | top_blocks,
253 | bottom_blocks,
254 | scene,
255 | )
256 |
257 | for block in top_blocks_:
258 | scene.add_edge(
259 | ReducedNode.CLEAR,
260 | block,
261 | graph.PlanGraphEdge(predicate="clear", scene=scene.scene),
262 | )
263 | for block in bottom_blocks_:
264 | scene.add_edge(
265 | block,
266 | ReducedNode.TABLE,
267 | graph.PlanGraphEdge(predicate="on-table", scene=scene.scene),
268 | )
269 |
270 | # handle arm
271 | if arm_empty and not (top_blocks & bottom_blocks):
272 | scene.add_edge(
273 | ReducedNode.CLEAR,
274 | ReducedNode.ARM,
275 | graph.PlanGraphEdge(predicate="arm-empty", scene=scene.scene),
276 | )
277 |
278 | if return_reduced:
279 | return ReducedProblemGraph.join(
280 | BlocksworldOracle.reduce(inflated_init), scene
281 | )
282 | else:
283 | return graph.ProblemGraph.join(
284 | inflated_init, BlocksworldOracle.inflate(scene)
285 | )
286 |
287 | def plan(problem: graph.ProblemGraph) -> list[Action]:
288 | problem = BlocksworldOracle.fully_specify(problem, return_reduced=True)
289 | init, goal = problem.decompose()
290 | actions = []
291 |
292 | # Process init scene
293 | # check if arm is empty
294 | if (
295 | not init.has_edge(ReducedNode.CLEAR, ReducedNode.ARM)
296 | and init.in_degree(ReducedNode.ARM) == 1
297 | ):
298 | obj = init.predecessors(ReducedNode.ARM)[0]
299 | actions.append(Action("putdown", [obj.name]))
300 |
301 | # unstack everything in init
302 | for idx in rx.topological_sort(init.graph):
303 | node = init.nodes[idx]
304 | if isinstance(node.node, ReducedNode):
305 | continue
306 | elif init.successors(node)[0].name in (ReducedNode.ARM, ReducedNode.TABLE):
307 | # if the block is on the table or in the arm, ignore it
308 | continue
309 | else:
310 | actions.append(
311 | Action("unstack", [node.name, init.successors(node)[0].name])
312 | )
313 | actions.append(Action("putdown", [node.name]))
314 |
315 | # Process goal scene
316 | # stack everything in goal
317 | for idx in reversed(rx.topological_sort(goal.graph)):
318 | node = goal.nodes[idx]
319 | if isinstance(node.node, ReducedNode):
320 | continue
321 | elif goal.out_degree(node.node) == 0:
322 | # isn't defined to be on anything (keep on table)
323 | continue
324 | elif goal.successors(node)[0].node in (ReducedNode.ARM, ReducedNode.TABLE):
325 | # if the block is on the table or in the arm, ignore it
326 | continue
327 | else:
328 | actions.append(Action("pickup", [node.name]))
329 | actions.append(
330 | Action("stack", [node.name, goal.successors(node)[0].name])
331 | )
332 |
333 | # Check if arm should be holding it
334 | if (
335 | not goal.has_edge(ReducedNode.CLEAR, ReducedNode.ARM)
336 | and goal.in_degree(ReducedNode.ARM) == 1
337 | ):
338 | obj = goal.predecessors(ReducedNode.ARM)[0]
339 | actions.append(Action("pickup", [obj.name]))
340 |
341 | return actions
342 |
--------------------------------------------------------------------------------
/planetarium/oracles/floortile.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import copy
3 |
4 | from pddl.core import Action
5 | import rustworkx as rx
6 |
7 | from . import oracle
8 | from .. import graph
9 | from ..reduced_graph import ReducedSceneGraph, ReducedProblemGraph, ReducedNode
10 |
# Register the special ReducedNode used by the floor-tile domain:
# ReducedNode.AVAILABLE is the source of every ``available-color`` edge in the
# reduced graphs built by FloorTileOracle.
# NOTE(review): FloorTileOracle.reduce also uses ReducedNode.CLEAR, which is
# not registered here (the blocksworld module registers a "CLEAR" member);
# confirm it is always available when this domain is used.
ReducedNode.register(
    {
        "AVAILABLE": "available-color",
    },
    "floor-tile",
)
18 |
19 |
20 | class FloorTileOracle(oracle.Oracle):
21 |
22 | def reduce(
23 | scene: graph.SceneGraph | graph.ProblemGraph,
24 | ) -> ReducedSceneGraph | ReducedProblemGraph:
25 | """Reduces a floortile scene graph to a reduced scene graph.
26 |
27 | Args:
28 | scene (graph.SceneGraph | graph.ProblemGraph): The scene graph to reduce.
29 |
30 | Returns:
31 | ReducedSceneGraph | ReducedProblemGraph: The reduced scene graph.
32 | """
33 | match scene:
34 | case graph.ProblemGraph(
35 | _constants=constants,
36 | _predicates=predicates,
37 | _domain=domain,
38 | _requirements=requirements,
39 | ):
40 | reduced = ReducedProblemGraph(
41 | constants=constants,
42 | domain=domain,
43 | requirements=requirements,
44 | )
45 | case graph.SceneGraph(
46 | constants=constants,
47 | _predicates=predicates,
48 | scene=scene,
49 | _domain=domain,
50 | _requirements=requirements,
51 | ):
52 | reduced = ReducedSceneGraph(
53 | constants=constants,
54 | domain=domain,
55 | scene=scene,
56 | requirements=requirements,
57 | )
58 | case _:
59 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.")
60 |
61 | for predicate in predicates:
62 | params = predicate["parameters"]
63 | reduced_edge = graph.PlanGraphEdge(
64 | predicate=predicate["typing"],
65 | scene=predicate.get("scene"),
66 | )
67 | match (predicate["typing"], params):
68 | case (_, [x, y]):
69 | reduced.add_edge(x, y, reduced_edge)
70 | case ("clear", [x]):
71 | reduced.add_edge(ReducedNode.CLEAR, x, reduced_edge)
72 | case ("available-color", [x]):
73 | reduced.add_edge(ReducedNode.AVAILABLE, x, reduced_edge)
74 |
75 | return reduced
76 |
77 | def inflate(
78 | scene: ReducedSceneGraph | ReducedProblemGraph,
79 | ) -> graph.SceneGraph | graph.ProblemGraph:
80 | """Respecify a reduced floortile scene graph to a full scene graph.
81 |
82 | Args:
83 | scene (ReducedSceneGraph | ReducedProblemGraph): The reduced scene graph
84 | to inflate.
85 |
86 | Returns:
87 | graph.SceneGraph | graph.ProblemGraph: The inflated scene graph.
88 | """
89 | constants = []
90 | predicates = []
91 |
92 | for node in scene.nodes:
93 | if (
94 | not isinstance(node.node, ReducedNode)
95 | and node.label == graph.Label.CONSTANT
96 | ):
97 | # add constants
98 | constants.append({"name": node.node, "typing": node.typing})
99 |
100 | for u, v, edge in scene.edges:
101 | match u.node:
102 | case pred_node if isinstance(pred_node, ReducedNode):
103 | predicates.append(
104 | {
105 | "typing": edge.predicate,
106 | "parameters": [v.node],
107 | "scene": edge.scene,
108 | }
109 | )
110 | case _:
111 | predicates.append(
112 | {
113 | "typing": edge.predicate,
114 | "parameters": [u.node, v.node],
115 | "scene": edge.scene,
116 | }
117 | )
118 |
119 | if isinstance(scene, ReducedProblemGraph):
120 | return graph.ProblemGraph(
121 | constants,
122 | [pred for pred in predicates if pred["scene"] == graph.Scene.INIT],
123 | [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL],
124 | domain=scene.domain,
125 | requirements=scene._requirements,
126 | )
127 | else:
128 | return graph.SceneGraph(
129 | constants,
130 | predicates,
131 | domain=scene.domain,
132 | scene=scene.scene,
133 | requirements=scene._requirements,
134 | )
135 |
136 | @staticmethod
137 | def _apply_unchangeable_predicates(
138 | init: ReducedSceneGraph,
139 | goal: ReducedSceneGraph,
140 | ):
141 | unchangeable = {
142 | "up",
143 | "right",
144 | "painted",
145 | "available-color",
146 | }
147 | for u, v, edge in init.edges:
148 | if edge.predicate in unchangeable:
149 | edge = copy.deepcopy(edge)
150 | edge.scene = graph.Scene.GOAL
151 | if not goal.has_edge(u, v, edge):
152 | goal.add_edge(u, v, edge)
153 |
154 | @staticmethod
155 | def _fixed_color_predicates(
156 | init: ReducedSceneGraph,
157 | goal: ReducedSceneGraph,
158 | robots: list[ReducedNode],
159 | ):
160 | # if two or more colors are available, then the robot color cannot be fixed
161 | # (they can finish painting and switch to any of the available colors)
162 | available_colors = init.successors(ReducedNode.AVAILABLE)
163 | if len(available_colors) > 1:
164 | return
165 |
166 | # find the color each robot has
167 | has_color = defaultdict(list)
168 | for robot in robots:
169 | has_color[robot] = [
170 | v for v, edge in init.out_edges(robot) if edge.predicate == "robot-has"
171 | ]
172 |
173 | # if no colors are available, then all robots must have their original color
174 | if len(available_colors) == 0:
175 | for robot, colors in has_color.items():
176 | edge = graph.PlanGraphEdge(
177 | predicate="robot-has",
178 | scene=graph.Scene.GOAL,
179 | )
180 | for color in colors:
181 | if not goal.has_edge(robot, color, edge):
182 | goal.add_edge(robot, color, edge)
183 | # if only one color is available:
184 | elif len(available_colors) == 1:
185 | # if there is only one robot and it either
186 | # a) has the available color
187 | # b) the available color needs to be painted
188 | #
189 | # then the robot color is fixed to the available color
190 | (available_color,) = available_colors
191 | if len(robots) == 1:
192 | (robot,) = robots
193 | robot_color = has_color[robot][0]
194 | if robot_color == available_color:
195 | edge = graph.PlanGraphEdge(
196 | predicate="robot-has",
197 | scene=graph.Scene.GOAL,
198 | )
199 | if not goal.has_edge(robot, robot_color, edge):
200 | goal.add_edge(robot, robot_color, edge)
201 | else:
202 | painted_colors = set()
203 | for u, v, edge in goal.edges:
204 | if edge.predicate == "painted":
205 | init_edge = copy.deepcopy(edge)
206 | if not init.has_edge(u, v, init_edge):
207 | painted_colors.add(v)
208 |
209 | if available_color in painted_colors:
210 | edge = graph.PlanGraphEdge(
211 | predicate="robot-has",
212 | scene=graph.Scene.GOAL,
213 | )
214 | if not goal.has_edge(robot, available_color, edge):
215 | goal.add_edge(robot, available_color, edge)
216 | else:
217 | # if there are more than one robot, every robot that already has that
218 | # color must keep it
219 | for robot, colors in has_color.items():
220 | if available_color in colors:
221 | edge = graph.PlanGraphEdge(
222 | predicate="robot-has",
223 | scene=graph.Scene.GOAL,
224 | )
225 | if not goal.has_edge(robot, available_color, edge):
226 | goal.add_edge(robot, available_color, edge)
227 |
228 | # for each node that needs to be painted the available_color,
229 | # if there is only one robot that can paint it, then that robot
230 | # must paint end with that color
231 | painted_nodes = []
232 | node_reachable_by: list[list] = []
233 |
234 | subgraph_nodes = [
235 | i
236 | for i, n in enumerate(goal.nodes)
237 | if n.typing in ({"tile"}, {"robot"})
238 | ]
239 | subgraph = init.graph.subgraph(subgraph_nodes).to_undirected()
240 |
241 | for u, v, edge in goal.edges:
242 | if edge.predicate == "painted":
243 | painted_nodes.append(u)
244 | reachable = []
245 | # find all robots that can reach this node
246 | for robot in robots:
247 | robot_idx = subgraph.nodes().index(robot)
248 | u_idx = subgraph.nodes().index(u)
249 |
250 | if rx.has_path(subgraph, robot_idx, u_idx):
251 | reachable.append(robot)
252 | node_reachable_by.append(reachable)
253 |
254 | for r in node_reachable_by:
255 | # if there's only one robot that can paint it, then it must end up
256 | # painting it
257 | if len(r) == 1:
258 | robot = r[0]
259 | # assign the robot the only available color
260 | # (notice branch above ensures there's only one available color)
261 | edge = graph.PlanGraphEdge(
262 | predicate="robot-has",
263 | scene=graph.Scene.GOAL,
264 | )
265 | if not goal.has_edge(robot, available_color, edge):
266 | goal.add_edge(robot, available_color, edge)
267 |
268 | def _fix_possible_positions(
269 | init: ReducedSceneGraph,
270 | goal: ReducedSceneGraph,
271 | robots: list[graph.PlanGraphNode],
272 | ):
273 | for robot in robots:
274 | init_pos = [
275 | v for v, edge in init.out_edges(robot) if edge.predicate == "robot-at"
276 | ]
277 | goal_pos = [
278 | v for v, edge in goal.out_edges(robot) if edge.predicate == "robot-at"
279 | ]
280 |
281 | if not init_pos or goal_pos:
282 | # if no initial position is specified or the goal position is already
283 | # specified, we can't determine the final robot position
284 | continue
285 |
286 | init_pos = init_pos[0]
287 | pos_neighbors = [
288 | v
289 | for v, edge in (*init.out_edges(init_pos), *init.in_edges(init_pos))
290 | if edge.predicate in {"up", "right"}
291 | ]
292 |
293 | if not pos_neighbors:
294 | # if the initial position has no neighbors, the robot must end there
295 | edge = graph.PlanGraphEdge(
296 | predicate="robot-at",
297 | scene=graph.Scene.GOAL,
298 | )
299 | if not goal.has_edge(robot, init_pos, edge):
300 | goal.add_edge(robot, init_pos, edge)
301 |
302 | # TODO: if a robot is the only one that can paint a tile, it must end up painting it
303 | # if a robot ends up on a tile that is disconnected, then it must end up on that tile:
304 |
305 | def fully_specify(
306 | problem: graph.ProblemGraph,
307 | return_reduced: bool = False,
308 | ) -> graph.ProblemGraph | ReducedProblemGraph:
309 | """Fully specifies a floortile scene graph.
310 |
311 | Args:
312 | problem (graph.ProblemGraph): The problem graph to fully specify.
313 | return_reduced (bool, optional): Whether to return a reduced problem graph.
314 | Defaults to False.
315 |
316 | Returns:
317 | graph.ProblemGraph | ReducedProblemGraph: The fully specified problem graph.
318 | """
319 | inflated_init, inflated_goal = problem.decompose()
320 |
321 | init: ReducedSceneGraph = FloorTileOracle.reduce(inflated_init)
322 | goal: ReducedSceneGraph = FloorTileOracle.reduce(inflated_goal)
323 |
324 | robots = [r for r in init.nodes if r.typing == {"robot"}]
325 |
326 | FloorTileOracle._apply_unchangeable_predicates(init, goal)
327 | FloorTileOracle._fixed_color_predicates(init, goal, robots)
328 |
329 | # if a robot that starts on a tile with no neighbors, it must also end
330 | # on that tile
331 | FloorTileOracle._fix_possible_positions(init, goal, robots)
332 |
333 | if return_reduced:
334 | return ReducedProblemGraph.join(init, goal)
335 | else:
336 | return graph.ProblemGraph.join(inflated_init, FloorTileOracle.inflate(goal))
337 |
--------------------------------------------------------------------------------
/planetarium/oracles/gripper.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import copy
3 |
4 | from pddl.core import Action
5 |
6 | from . import oracle
7 | from .. import graph
8 | from ..reduced_graph import ReducedSceneGraph, ReducedProblemGraph, ReducedNode
9 |
10 |
# Register the special ReducedNode members used by the gripper domain. Each one
# stands for a unary predicate and is the source of that predicate's edges in
# the reduced graph (e.g. ReducedNode.ROBBY -> room for ``at-robby``,
# ReducedNode.FREE -> gripper for ``free``).
ReducedNode.register(
    {
        "ROOMS": "room",
        "BALLS": "ball",
        "GRIPPERS": "gripper",
        "ROBBY": "at-robby",
        "FREE": "free",
    },
    "gripper",
)
22 |
23 |
24 | class GripperOracle(oracle.Oracle):
25 |
26 | def reduce(
27 | scene: graph.SceneGraph | graph.ProblemGraph,
28 | ) -> ReducedSceneGraph | ReducedProblemGraph:
29 | """Reduces a gripper scene graph to a Directed Acyclic Graph.
30 |
31 | Args:
32 | scene (graph.SceneGraph): The scene graph to reduce.
33 |
34 | Returns:
35 | ReducedGraph: The reduced problem graph.
36 | """
37 | nodes = defaultdict(list)
38 | for node in scene.nodes:
39 | nodes[node.label].append(node)
40 |
41 | match scene:
42 | case graph.ProblemGraph(
43 | _constants=constants,
44 | _predicates=predicates,
45 | _domain=domain,
46 | _requirements=requirements,
47 | ):
48 | reduced = ReducedProblemGraph(
49 | constants=constants,
50 | domain=domain,
51 | requirements=requirements,
52 | )
53 | case graph.SceneGraph(
54 | constants=constants,
55 | _predicates=predicates,
56 | scene=scene,
57 | _domain=domain,
58 | _requirements=requirements,
59 | ):
60 | reduced = ReducedSceneGraph(
61 | constants=constants,
62 | domain=domain,
63 | scene=scene,
64 | requirements=requirements,
65 | )
66 | case _:
67 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.")
68 |
69 | for predicate in predicates:
70 | params = predicate["parameters"]
71 | reduced_edge = graph.PlanGraphEdge(
72 | predicate=predicate["typing"],
73 | scene=predicate.get("scene"),
74 | )
75 | match (predicate["typing"], len(params)):
76 | case ("at-robby", 1):
77 | reduced.add_edge(ReducedNode.ROBBY, params[0], reduced_edge)
78 | case ("free", 1):
79 | reduced.add_edge(ReducedNode.FREE, params[0], reduced_edge)
80 | case ("ball", 1):
81 | reduced.add_edge(ReducedNode.BALLS, params[0], reduced_edge)
82 | case ("gripper", 1):
83 | reduced.add_edge(ReducedNode.GRIPPERS, params[0], reduced_edge)
84 | case ("room", 1):
85 | reduced.add_edge(ReducedNode.ROOMS, params[0], reduced_edge)
86 | case ("at", 2):
87 | reduced.add_edge(params[0], params[1], reduced_edge)
88 | case ("carry", 2):
89 | reduced.add_edge(params[0], params[1], reduced_edge)
90 |
91 | return reduced
92 |
93 | def inflate(
94 | scene: ReducedSceneGraph | ReducedProblemGraph,
95 | ) -> graph.SceneGraph | graph.ProblemGraph:
96 | """Respecify a gripper scene graph.
97 |
98 | Args:
99 | scene (ReducedGraph): The reduced SceneGraph of a scene.
100 |
101 | Returns:
102 | graph.SceneGraph: The respecified scene graph.
103 | """
104 | constants = []
105 | predicates = []
106 |
107 | for node in scene.nodes:
108 | if not isinstance(node.node, ReducedNode):
109 | constants.append({"name": node.node, "typing": node.typing})
110 |
111 | for u, v, edge in scene.edges:
112 | match (u.node, v.node):
113 | case (ReducedNode.ROBBY, _):
114 | predicates.append(
115 | {
116 | "typing": "at-robby",
117 | "parameters": [v.node],
118 | "scene": edge.scene,
119 | }
120 | )
121 | case (ReducedNode.FREE, _):
122 | predicates.append(
123 | {
124 | "typing": "free",
125 | "parameters": [v.node],
126 | "scene": edge.scene,
127 | }
128 | )
129 | case (ReducedNode.BALLS, _):
130 | predicates.append(
131 | {
132 | "typing": "ball",
133 | "parameters": [v.node],
134 | "scene": edge.scene,
135 | }
136 | )
137 | case (ReducedNode.GRIPPERS, _):
138 | predicates.append(
139 | {
140 | "typing": "gripper",
141 | "parameters": [v.node],
142 | "scene": edge.scene,
143 | }
144 | )
145 | case (ReducedNode.ROOMS, _):
146 | predicates.append(
147 | {
148 | "typing": "room",
149 | "parameters": [v.node],
150 | "scene": edge.scene,
151 | }
152 | )
153 | case (_, _):
154 | predicates.append(
155 | {
156 | "typing": edge.predicate,
157 | "parameters": [u.node, v.node],
158 | "scene": edge.scene,
159 | }
160 | )
161 |
162 | if isinstance(scene, ReducedProblemGraph):
163 | return graph.ProblemGraph(
164 | constants,
165 | [pred for pred in predicates if pred["scene"] == graph.Scene.INIT],
166 | [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL],
167 | domain="gripper",
168 | requirements=scene._requirements,
169 | )
170 | else:
171 | return graph.SceneGraph(
172 | constants,
173 | predicates,
174 | domain="gripper",
175 | scene=scene.scene,
176 | requirements=scene._requirements,
177 | )
178 |
179 | @staticmethod
180 | def _gripper_get_typed_objects(
181 | scene: ReducedSceneGraph,
182 | ) -> dict[ReducedNode, set[graph.PlanGraphNode]]:
183 | """Get the typed objects in a gripper scene graph.
184 |
185 | Args:
186 | scene (ReducedGraph): The reduced SceneGraph of a scene.
187 |
188 | Returns:
189 | dict[ReducedNode, set[graph.PlanGraphNode]]: The typed objects in the
190 | scene graph.
191 | """
192 | rooms = set()
193 | balls = set()
194 | grippers = set()
195 |
196 | for node, _ in scene.out_edges(ReducedNode.ROOMS):
197 | rooms.add(node)
198 | for node, _ in scene.out_edges(ReducedNode.BALLS):
199 | balls.add(node)
200 | for node, _ in scene.out_edges(ReducedNode.GRIPPERS):
201 | grippers.add(node)
202 |
203 | return {
204 | ReducedNode.ROOMS: rooms,
205 | ReducedNode.BALLS: balls,
206 | ReducedNode.GRIPPERS: grippers,
207 | }
208 |
209 | @staticmethod
210 | def _gripper_underspecified_blocks(
211 | init: ReducedSceneGraph,
212 | goal: ReducedSceneGraph,
213 | ) -> tuple[set[str], set[str], bool]:
214 | """Finds blocks that are not fully specified.
215 |
216 | Args:
217 | init (ReducedGraph): The reduced SceneGraph of the initial scene.
218 | goal (ReducedGraph): The reduced SceneGraph of the goal scene.
219 |
220 | Returns:
221 | tuple[set[str], set[str]]: The set of blocks that are not fully
222 | specified.
223 | - balls that do not specify being carried or being in a room.
224 | - grippers that do not specify being free or carrying a ball.
225 | - whether robby is not in a room.
226 | """
227 |
228 | typed = GripperOracle._gripper_get_typed_objects(init)
229 |
230 | underspecified_balls = set()
231 | underspecified_grippers = set()
232 |
233 | for ball in typed[ReducedNode.BALLS]:
234 | ball_edges = [
235 | node
236 | for node, _ in goal.out_edges(ball)
237 | if not isinstance(node, ReducedNode)
238 | ]
239 | if not ball_edges:
240 | underspecified_balls.add(ball)
241 | for gripper in typed[ReducedNode.GRIPPERS]:
242 | gripper_edges = [
243 | node
244 | for node, _ in goal.in_edges(gripper)
245 | if node == ReducedNode.FREE or not isinstance(node, ReducedNode)
246 | ]
247 | if not gripper_edges:
248 | underspecified_grippers.add(gripper)
249 |
250 | return (
251 | underspecified_balls,
252 | underspecified_grippers,
253 | goal.out_degree(ReducedNode.ROBBY) == 0,
254 | )
255 |
256 | def fully_specify(
257 | problem: graph.ProblemGraph,
258 | return_reduced: bool = False,
259 | ) -> graph.ProblemGraph | ReducedProblemGraph:
260 | """Fully specifies a gripper scene graph.
261 |
262 | Adds any missing edges to fully specify the scene graph, without adding
263 | edges that change the problem represented by the graph.
264 |
265 | Args:
266 | problem (graph.ProblemGraph): The problem graph to fully specify.
267 | return_reduced (bool, optional): Whether to return a reduced problem graph.
268 | Defaults to False.
269 |
270 | Returns:
271 | ProblemGraph | ReducedProblemGraph: The fully specified problem graph.
272 | """
273 | inflated_init, inflated_goal = problem.decompose()
274 |
275 | init: ReducedSceneGraph = GripperOracle.reduce(inflated_init)
276 | goal: ReducedSceneGraph = GripperOracle.reduce(inflated_goal)
277 |
278 | scene = copy.deepcopy(goal)
279 |
280 | # bring "typing" predicates from init to goal
281 | typed_objects = GripperOracle._gripper_get_typed_objects(init)
282 | for typing, objects in typed_objects.items():
283 | for obj in objects:
284 | edge = graph.PlanGraphEdge(
285 | predicate=typing.value, scene=graph.Scene.GOAL
286 | )
287 | edge_ = graph.PlanGraphEdge(predicate=typing.value)
288 | if obj in scene.nodes and not (
289 | scene.has_edge(typing, obj, edge)
290 | or scene.has_edge(typing, obj, edge_)
291 | ):
292 | scene.add_edge(typing, obj, edge)
293 |
294 | underspecified_balls, underspecified_grippers, _ = (
295 | GripperOracle._gripper_underspecified_blocks(
296 | init,
297 | goal,
298 | )
299 | )
300 |
301 | if underspecified_grippers and not underspecified_balls:
302 | for gripper in underspecified_grippers:
303 | scene.add_edge(
304 | ReducedNode.FREE,
305 | gripper,
306 | graph.PlanGraphEdge(predicate="free", scene=scene.scene),
307 | )
308 |
309 | if return_reduced:
310 | return ReducedProblemGraph.join(init, scene)
311 | else:
312 | return graph.ProblemGraph.join(inflated_init, GripperOracle.inflate(scene))
313 |
314 | def plan(problem: graph.ProblemGraph) -> list[Action]:
315 | # TODO: this function is not "complete": it does not handle all cases
316 | # - multiple "types" per object
317 | # - robby not at a room (can be valid in a few cases)
318 | # - balls not in rooms
319 | # - objects without typing
320 | problem = GripperOracle.fully_specify(problem, return_reduced=True)
321 |
322 | init, goal = problem.decompose()
323 | actions = []
324 |
325 | # Process init scene
326 | typed = GripperOracle._gripper_get_typed_objects(init)
327 | rooms = list(typed[ReducedNode.ROOMS])
328 | grippers = list(typed[ReducedNode.GRIPPERS])
329 |
330 | # get current room
331 | if init.out_degree(ReducedNode.ROBBY) < 1:
332 | return actions
333 |
334 | current_room = init.successors(ReducedNode.ROBBY)[0]
335 | # move to first room
336 | if current_room != rooms[0]:
337 | actions.append(Action("move", [current_room.name, rooms[0].name]))
338 |
339 | # ensure all grippers are free
340 | for gripper in grippers:
341 | if not init.has_edge(ReducedNode.FREE, gripper):
342 | # get in_edge
343 | ball = [
344 | b
345 | for b in init.predecessors(gripper)
346 | if b in typed[ReducedNode.BALLS]
347 | ]
348 | if ball:
349 | actions.append(
350 | Action("drop", [ball[0].name, rooms[0].name, gripper.name])
351 | )
352 |
353 | # move all balls to first room
354 | for room in rooms:
355 | for obj in init.predecessors(room):
356 | if obj in typed[ReducedNode.BALLS]:
357 | actions.append(Action("move", [rooms[0].name, room.name]))
358 | actions.append(
359 | Action("pick", [obj.name, room.name, grippers[0].name])
360 | )
361 | actions.append(Action("move", [room.name, rooms[0].name]))
362 | actions.append(
363 | Action("drop", [obj.name, rooms[0].name, grippers[0].name])
364 | )
365 |
366 | # Process goal scene
367 | for room in rooms:
368 | for obj in goal.predecessors(room):
369 | if obj in typed[ReducedNode.BALLS]:
370 | actions.append(
371 | Action("pick", [obj.name, rooms[0].name, grippers[0].name])
372 | )
373 | actions.append(Action("move", [rooms[0].name, room.name]))
374 | actions.append(
375 | Action("drop", [obj.name, room.name, grippers[0].name])
376 | )
377 | actions.append(Action("move", [room.name, rooms[0].name]))
378 |
379 | # pick up balls in first room tied to grippers
380 | for gripper in grippers:
381 | for ball in typed[ReducedNode.BALLS]:
382 | if goal.has_edge(ball, gripper):
383 | actions.append(
384 | Action("pick", [ball.name, rooms[0].name, gripper.name])
385 | )
386 |
387 | # move to room with robby
388 | goal_room = next(iter(goal.successors(ReducedNode.ROBBY)), None)
389 | if goal_room:
390 | actions.append(Action("move", [rooms[0].name, goal_room.name]))
391 |
392 | return actions
393 |
--------------------------------------------------------------------------------
/planetarium/oracles/oracle.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | from pddl.core import Action
4 |
5 | from .. import graph, reduced_graph
6 |
7 |
8 | class Oracle(abc.ABC):
9 |
10 | @staticmethod
11 | def reduce(
12 | scene: graph.SceneGraph | graph.ProblemGraph,
13 | ) -> reduced_graph.ReducedSceneGraph | reduced_graph.ReducedProblemGraph:
14 | """Reduces a scene graph to a reduced scene graph.
15 |
16 | Args:
17 | scene (graph.SceneGraph | graph.ProblemGraph): The scene graph to reduce.
18 |
19 | Returns:
20 | ReducedSceneGraph | ReducedProblemGraph: The reduced scene graph.
21 | """
22 | match scene:
23 | case graph.ProblemGraph(
24 | _constants=constants,
25 | _predicates=predicates,
26 | _domain=domain,
27 | _requirements=requirements,
28 | ):
29 | reduced = reduced_graph.ReducedProblemGraph(
30 | constants=constants,
31 | domain=domain,
32 | requirements=requirements,
33 | )
34 | case graph.SceneGraph(
35 | constants=constants,
36 | _predicates=predicates,
37 | scene=scene,
38 | _domain=domain,
39 | _requirements=requirements,
40 | ):
41 | reduced = reduced_graph.ReducedSceneGraph(
42 | constants=constants,
43 | domain=domain,
44 | scene=scene,
45 | requirements=requirements,
46 | )
47 | case _:
48 | raise ValueError("Scene must be a SceneGraph or ProblemGraph.")
49 |
50 | for predicate in predicates:
51 | params = predicate["parameters"]
52 | reduced_edge = graph.PlanGraphEdge(
53 | predicate=predicate["typing"],
54 | scene=predicate.get("scene"),
55 | )
56 | match params:
57 | case [x]:
58 | reduced.add_node(x, reduced_edge)
59 | case [x, y]:
60 | reduced.add_edge(x, y, reduced_edge)
61 | case _:
62 | raise ValueError("Predicate parameters must be 1 or 2.")
63 |
64 | @staticmethod
65 | def inflate(
66 | scene: reduced_graph.ReducedSceneGraph | reduced_graph.ReducedProblemGraph,
67 | ) -> graph.SceneGraph | graph.ProblemGraph:
68 | """Inflates a reduced scene graph to a scene graph.
69 |
70 | Args:
71 | scene (ReducedSceneGraph | ReducedProblemGraph): The reduced scene graph to inflate.
72 |
73 | Returns:
74 | SceneGraph | ProblemGraph: The inflated scene graph.
75 | """
76 | constants = []
77 | predicates = []
78 |
79 | for node in scene.nodes:
80 | if (
81 | not isinstance(node.node, reduced_graph.ReducedNode)
82 | and node.label == graph.Label.CONSTANT
83 | ):
84 | # add constants
85 | constants.append({"name": node.node, "typing": node.typing})
86 |
87 | for u, v, edge in scene.edges:
88 | match u.node:
89 | case pred_node if isinstance(pred_node, reduced_graph.ReducedNode):
90 | predicates.append(
91 | {
92 | "typing": edge.predicate,
93 | "parameters": [v.node],
94 | "scene": edge.scene,
95 | }
96 | )
97 | case _:
98 | predicates.append(
99 | {
100 | "typing": edge.predicate,
101 | "parameters": [u.node, v.node],
102 | "scene": edge.scene,
103 | }
104 | )
105 |
106 | if isinstance(scene, reduced_graph.ReducedProblemGraph):
107 | return graph.ProblemGraph(
108 | constants,
109 | [pred for pred in predicates if pred["scene"] == graph.Scene.INIT],
110 | [pred for pred in predicates if pred["scene"] == graph.Scene.GOAL],
111 | domain=scene.domain,
112 | requirements=scene._requirements,
113 | )
114 | else:
115 | return graph.SceneGraph(
116 | constants,
117 | predicates,
118 | domain=scene.domain,
119 | scene=scene.scene,
120 | requirements=scene._requirements,
121 | )
122 |
123 | @staticmethod
124 | @abc.abstractmethod
125 | def fully_specify(
126 | problem: graph.ProblemGraph,
127 | return_reduced: bool = False,
128 | ) -> graph.ProblemGraph | reduced_graph.ReducedProblemGraph:
129 | """Fully specifies a goal state.
130 |
131 | Args:
132 | problem (graph.ProblemGraph): The problem graph to fully specify.
133 | return_reduced (bool, optional): Whether to return a reduced problem graph.
134 | Defaults to False.
135 |
136 | Returns:
137 | ProblemGraph | ReducedProblemGraph: The fully specified problem graph.
138 | """
139 |
140 | @staticmethod
141 | def plan(problem: graph.ProblemGraph) -> list[Action]:
142 | """Generates a plan for a problem graph.
143 |
144 | Args:
145 | problem (graph.ProblemGraph): The problem graph to plan.
146 |
147 | Returns:
148 | list[Action]: The plan for the problem graph.
149 | """
150 | raise NotImplementedError("Planning not supported for this oracle.")
151 |
--------------------------------------------------------------------------------
/planetarium/reduced_graph.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import copy
4 | import aenum
5 |
6 | from planetarium import graph
7 |
8 |
class ReducedNode(str, aenum.Enum):
    """String-valued enum of the special nodes used in reduced graphs.

    The enum starts empty here; members are added at runtime via ``register``.
    """

    @classmethod
    def register(
        cls,
        attrs: dict[str, str],
        domain: str | None = None,
    ) -> set["ReducedNode"]:
        """Extend the ReducedNode enum with new nodes.

        Args:
            attrs (dict[str, str]): Member name -> member value for the new
                nodes. A name that already exists is reused as-is; its value
                is not compared with the one given here.
            domain (str, optional): If given, the resulting set of nodes is
                recorded under this name in ``ReducedNodes``. Defaults to None.

        Raises:
            AttributeError: If the domain is already extended. The check runs
                before the enum is modified, so a rejected call leaves the
                enum untouched.

        Returns:
            set[ReducedNode]: The members named in ``attrs`` (new or pre-existing).
        """
        # Validate first so a duplicate registration cannot partially extend
        # the enum before failing.
        if isinstance(domain, str) and domain in ReducedNodes:
            raise AttributeError(
                f"ReducedNode already extended for domain {domain}."
            )

        nodes = set()
        for name, value in attrs.items():
            if name not in cls.__members__:
                aenum.extend_enum(cls, name, value)
            nodes.add(cls[name])

        if isinstance(domain, str):
            ReducedNodes[domain] = nodes

        return nodes
43 |
44 |
# Registry: domain name -> the ReducedNode members registered for it through
# ReducedNode.register(..., domain=...).
ReducedNodes: dict[str, set[ReducedNode]] = dict()
46 |
47 |
class ReducedSceneGraph(graph.PlanGraph):
    """Reduced graph for a single scene.

    On construction, one PREDICATE-labelled node is created for every
    ReducedNode registered for ``domain``.
    """

    def __init__(
        self,
        constants: list[dict[str, Any]],
        domain: str,
        scene: graph.Scene | None = None,
        requirements: tuple[str] = (),
    ):
        super().__init__(constants, domain=domain, requirements=requirements)
        self.scene = scene

        registered = ReducedNodes.get(domain, set())
        for special in registered:
            text = special.value
            special_node = graph.PlanGraphNode(
                special,
                name=text,
                label=graph.Label.PREDICATE,
                typing=text,
            )
            self.add_node(special_node)

    def _add_predicate(
        self,
        predicate: dict[str, Any],
        scene: graph.Scene | None = None,
    ):
        """Always raise: predicates cannot be added directly to a reduced graph."""
        raise AttributeError(
            "ReducedSceneGraph does not support adding predicates directly."
        )
78 |
79 |
class ReducedProblemGraph(graph.PlanGraph):
    """Reduced graph holding both the init and the goal scene of a problem.

    Each edge records which scene it belongs to in its ``scene`` attribute.
    """

    def __init__(
        self,
        constants: list[dict[str, Any]],
        domain: str,
        requirements: tuple[str] = (),
    ):
        super().__init__(constants, domain=domain, requirements=requirements)

        # One predicate node per ReducedNode registered for this domain
        # (same empty-set default as ReducedSceneGraph).
        for e in ReducedNodes.get(domain, set()):
            predicate = e.value
            self.add_node(
                graph.PlanGraphNode(
                    e,
                    name=predicate,
                    label=graph.Label.PREDICATE,
                    typing=predicate,
                ),
            )

    def decompose(self) -> tuple[ReducedSceneGraph, ReducedSceneGraph]:
        """Split the problem into separate init and goal scene graphs.

        Returns:
            tuple[ReducedSceneGraph, ReducedSceneGraph]: ``(init, goal)``.
                Edges are deep-copied. An edge whose scene is neither INIT
                nor GOAL is placed in neither graph.
        """
        init = ReducedSceneGraph(
            self.constants,
            self.domain,
            scene=graph.Scene.INIT,
            requirements=self._requirements,
        )
        goal = ReducedSceneGraph(
            self.constants,
            self.domain,
            scene=graph.Scene.GOAL,
            requirements=self._requirements,
        )

        for u, v, edge in self.edges:
            edge = copy.deepcopy(edge)
            if edge.scene == graph.Scene.INIT:
                init.add_edge(u, v, edge)
            elif edge.scene == graph.Scene.GOAL:
                goal.add_edge(u, v, edge)

        return init, goal

    @staticmethod
    def join(init: ReducedSceneGraph, goal: ReducedSceneGraph) -> "ReducedProblemGraph":
        """Merge an init scene and a goal scene into one problem graph.

        Args:
            init (ReducedSceneGraph): Initial scene; also supplies the
                constants, domain and requirements of the result.
            goal (ReducedSceneGraph): Goal scene.

        Returns:
            ReducedProblemGraph: A graph containing deep copies of every edge
                from both scenes. Each copy is tagged INIT or GOAL according
                to its source before it is inserted.
        """
        problem = ReducedProblemGraph(
            init.constants,
            domain=init.domain,
            requirements=init._requirements,
        )

        # Carry over predicate nodes that are not special ReducedNodes (the
        # constructor already adds the registered ReducedNodes).
        for node in (*init.nodes, *goal.nodes):
            if (
                node not in problem.nodes
                and node.label == graph.Label.PREDICATE
                and not isinstance(node.node, ReducedNode)
            ):
                problem.add_node(copy.deepcopy(node))

        # Tag the scene *before* insertion, as for goal edges, so the stored
        # edge is correct even if add_edge copies or inspects it.
        for u, v, edge in init.edges:
            edge = copy.deepcopy(edge)
            edge.scene = graph.Scene.INIT
            problem.add_edge(u, v, edge)
        for u, v, edge in goal.edges:
            edge = copy.deepcopy(edge)
            edge.scene = graph.Scene.GOAL
            problem.add_edge(u, v, edge)

        return problem
150 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "planetarium"
3 | version = "0.1.0"
4 | description = "Benchmark framework for evaluating LLM performance in the context of PDDL code generation."
5 | authors = [
6 | "Max Zuo ",
7 | "Francisco J. Piedrahita-Velez ",
8 | ]
9 | readme = "README.md"
10 |
11 | [tool.poetry.dependencies]
12 | python = "^3.10"
13 | networkx = "^3.2.1"
14 | pddl = {git = "https://github.com/maxzuo/pddl.git"}
15 | pyyaml = "^6.0.1"
16 | jinja2 = "^3.1.4"
17 | rustworkx = "^0.14.2"
18 | matplotlib = "^3.9.0"
19 | aenum = "^3.1.15"
20 |
21 |
22 | [tool.poetry.group.dev.dependencies]
23 | ruff = "^0.1.7"
24 | pytest = "^7.4.3"
25 | mypy = "^1.7.1"
26 | pytest-cov = "^4.1.0"
27 | pytest-timeout = "^2.2.0"
28 | pytest-subtests = "^0.12.1"
29 | black = {extras = ["jupyter"], version = "^24.4.2"}
30 | pytest-mock = "^3.14.0"
31 |
32 | [tool.poetry.group.all]
33 | optional = true
34 |
35 | [tool.poetry.group.all.dependencies]
36 | lark = "^1.1.9"
37 | vllm = "^0.5.0.post1"
38 | python-dotenv = "^1.0.1"
39 | datasets = "^2.20.0"
40 | peft = "^0.11.1"
41 | trl = "^0.9.4"
42 | bitsandbytes = "^0.43.1"
43 | openai = "^1.35.3"
44 |
45 | [build-system]
46 | requires = ["poetry-core"]
47 | build-backend = "poetry.core.masonry.api"
48 |
49 | [tool.mypy]
50 | allow_redefinition = true
51 | ignore_missing_imports = true
52 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BatsResearch/planetarium/f257ae9e68b813aeace00cc683d0c2ad5b57a157/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_evaluate.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | import pytest
3 |
4 | import planetarium
5 |
6 | from .problem_fixtures import (
7 | blocksworld_underspecified,
8 | blocksworld_missing_clears,
9 | blocksworld_missing_ontables,
10 | blocksworld_fully_specified,
11 | blocksworld_invalid_1,
12 | rover_single_line_equiv,
13 | rover_single_line_equiva,
14 | rover_single_line_equiv_1,
15 | rover_single_line_equiv_1a,
16 | rover_single_line_equiv_1b,
17 | floortile_no_white1,
18 | floortile_no_white1a,
19 | floortile_disconnected_tile1,
20 | floortile_disconnected_tile1a,
21 | floortile_one_color_one_robot1,
22 | floortile_one_color_one_robot1a,
23 | floortile_no_available_colors,
24 | floortile_no_available_colors_a,
25 | )
26 |
27 |
28 | @pytest.fixture
29 | def blocksworld_wrong_init():
30 | """
31 | Fixture providing a fully specified blocksworld problem with a wrong init: b6 starts stacked on b5, so b5 is not clear initially.
32 | """
33 | return """
34 | (define (problem staircase)
35 | (:domain blocksworld)
36 | (:objects
37 | b1 b2 b3 b4 b5 b6
38 | )
39 | (:init
40 | (arm-empty)
41 | (on-table b1)
42 | (clear b1)
43 | (on-table b2)
44 | (clear b2)
45 | (on-table b3)
46 | (clear b3)
47 | (on-table b4)
48 | (clear b4)
49 | (on-table b5)
50 | (on b6 b5)
51 | (clear b6)
52 | )
53 | (:goal
54 | (and
55 | (on-table b1)
56 | (clear b1)
57 |
58 | (on-table b2)
59 | (on b3 b2)
60 | (clear b3)
61 |
62 | (on-table b4)
63 | (on b5 b4)
64 | (on b6 b5)
65 | (clear b6)
66 | )
67 | )
68 | )
69 | """
70 |
71 |
72 | @pytest.fixture
73 | def blocksworld_fully_specified_wrong_domain():
74 | """
75 | Fixture providing a fully specified blocksworld problem that declares the unsupported domain name `blocksworld-wrong`.
76 | """
77 | return """
78 | (define (problem staircase)
79 | (:domain blocksworld-wrong)
80 | (:objects
81 | b1 b2 b3 b4 b5 b6
82 | )
83 | (:init
84 | (arm-empty)
85 | (on-table b1)
86 | (clear b1)
87 | (on-table b2)
88 | (clear b2)
89 | (on-table b3)
90 | (clear b3)
91 | (on-table b4)
92 | (clear b4)
93 | (on-table b5)
94 | (clear b5)
95 | (on-table b6)
96 | (clear b6)
97 | )
98 | (:goal
99 | (and
100 | (on-table b1)
101 | (clear b1)
102 |
103 | (on-table b2)
104 | (on b3 b2)
105 | (clear b3)
106 |
107 | (on-table b4)
108 | (on b5 b4)
109 | (on b6 b5)
110 | (clear b6)
111 | )
112 | )
113 | )
114 | """
115 |
116 |
117 | @pytest.fixture
118 | def blocksworld_unsolveable():
119 | """
120 | Fixture providing a parseable blocksworld problem whose goal is unsatisfiable: it requires both (clear b5) and (on b6 b5).
121 | """
122 | return """
123 | (define (problem staircase)
124 | (:domain blocksworld)
125 | (:objects
126 | b1 b2 b3 b4 b5 b6
127 | )
128 | (:init
129 | (arm-empty)
130 | (on-table b1)
131 | (clear b1)
132 | (on-table b2)
133 | (clear b2)
134 | (on-table b3)
135 | (clear b3)
136 | (on-table b4)
137 | (clear b4)
138 | (on-table b5)
139 | (on b6 b5)
140 | (clear b6)
141 | )
142 | (:goal
143 | (and
144 | (on-table b1)
145 | (clear b1)
146 |
147 | (on-table b2)
148 | (on b3 b2)
149 | (clear b3)
150 |
151 | (on-table b4)
152 | (on b5 b4)
153 | (clear b5)
154 | (on b6 b5)
155 | (clear b6)
156 | )
157 | )
158 | )
159 | """
160 |
161 |
162 | @pytest.fixture
163 | def blocksworld_unparseable():
164 | """
165 | Fixture providing an unparseable blocksworld problem: the :objects section ends with an extra ")", which closes the define form before :init and :goal.
166 | """
167 | return """
168 | (define (problem staircase)
169 | (:domain blocksworld)
170 | (:objects
171 | b1 b2 b3 b4 b5 b6
172 | ))
173 | (:init
174 | (arm-empty)
175 | (on-table b1)
176 | (clear b1)
177 | (on-table b2)
178 | (clear b2)
179 | (on-table b3)
180 | (clear b3)
181 | (on-table b4)
182 | (clear b4)
183 | (on-table b5)
184 | (on b6 b5)
185 | (clear b6)
186 | )
187 | (:goal
188 | (and
189 | (on-table b1)
190 | (clear b1)
191 |
192 | (on-table b2)
193 | (on b3 b2)
194 | (clear b3)
195 |
196 | (on-table b4)
197 | (on b5 b4)
198 | (on b6 b5)
199 | (clear b6)
200 | )
201 | )
202 | )
203 | """
204 |
205 |
class TestEvaluate:
    """
    Checks planetarium.evaluate on pairs of equivalent and non-equivalent
    PDDL problem descriptions.
    """

    def test_evaluate_equivalent(
        self,
        subtests,
        blocksworld_missing_clears,
        blocksworld_fully_specified,
        blocksworld_missing_ontables,
        blocksworld_underspecified,
    ):
        """
        Every ordered pair of interchangeable blocksworld descriptions must
        evaluate as all-True, as must the underspecified one against itself.
        """
        interchangeable = {
            "blocksworld_missing_clears": blocksworld_missing_clears,
            "blocksworld_fully_specified": blocksworld_fully_specified,
            "blocksworld_missing_ontables": blocksworld_missing_ontables,
        }
        for left, right in product(interchangeable, repeat=2):
            with subtests.test(f"{left} equals {right}"):
                verdict = planetarium.evaluate(
                    interchangeable[left], interchangeable[right]
                )
                assert all(verdict)

        with subtests.test(
            "blocksworld_underspecified equals blocksworld_underspecified"
        ):
            verdict = planetarium.evaluate(
                blocksworld_underspecified, blocksworld_underspecified
            )
            assert all(verdict)

    def test_evaluate_inequivalent(
        self,
        subtests,
        blocksworld_missing_clears,
        blocksworld_fully_specified,
        blocksworld_missing_ontables,
        blocksworld_underspecified,
        blocksworld_wrong_init,
        blocksworld_unparseable,
        blocksworld_unsolveable,
    ):
        """
        Each reference description compared with a flawed candidate must
        yield the exact evaluate() triple expected for that flaw.
        """
        references = {
            "blocksworld_missing_clears": blocksworld_missing_clears,
            "blocksworld_fully_specified": blocksworld_fully_specified,
            "blocksworld_missing_ontables": blocksworld_missing_ontables,
        }
        # (candidate name, candidate description, expected evaluate() result)
        flawed = [
            ("blocksworld_underspecified", blocksworld_underspecified, (True, True, False)),
            ("blocksworld_wrong_init", blocksworld_wrong_init, (True, True, False)),
            ("blocksworld_unparseable", blocksworld_unparseable, (False, False, False)),
            ("blocksworld_unsolveable", blocksworld_unsolveable, (True, False, False)),
        ]
        for ref_name, ref_desc in references.items():
            for cand_name, cand_desc, expected in flawed:
                with subtests.test(f"{ref_name} not equals {cand_name}"):
                    assert planetarium.evaluate(ref_desc, cand_desc) == expected

        with subtests.test(
            "blocksworld_underspecified not equals blocksworld_wrong_init"
        ):
            outcome = planetarium.evaluate(
                blocksworld_underspecified, blocksworld_wrong_init
            )
            assert outcome == (True, True, False)

    def test_rover_single_equivalent(
        self,
        subtests,
        rover_single_line_equiv,
        rover_single_line_equiva,
        rover_single_line_equiv_1,
        rover_single_line_equiv_1a,
        rover_single_line_equiv_1b,
    ):
        """
        Equivalent rover-single descriptions must evaluate as all-True when
        planned with the lama-first alias.
        """
        with subtests.test("rover_single_line_equiv equals rover_single_line_equiva"):
            verdict = planetarium.evaluate(
                rover_single_line_equiv,
                rover_single_line_equiva,
                alias="lama-first",
            )
            assert all(verdict)

        line_variants = {
            "rover_single_line_equiv_1": rover_single_line_equiv_1,
            "rover_single_line_equiv_1a": rover_single_line_equiv_1a,
            "rover_single_line_equiv_1b": rover_single_line_equiv_1b,
        }
        for (left, left_desc), (right, right_desc) in product(
            line_variants.items(), repeat=2
        ):
            with subtests.test(f"{left} equals {right}"):
                verdict = planetarium.evaluate(
                    left_desc, right_desc, alias="lama-first"
                )
                assert all(verdict)

    def test_floortile_equivalent(
        self,
        subtests,
        floortile_no_white1,
        floortile_no_white1a,
        floortile_disconnected_tile1,
        floortile_disconnected_tile1a,
        floortile_one_color_one_robot1,
        floortile_one_color_one_robot1a,
        floortile_no_available_colors,
        floortile_no_available_colors_a,
    ):
        """
        Each floortile description must evaluate as all-True against its
        rewritten counterpart when planned with the lama-first alias.
        """
        # (base name, base description, variant name, variant description)
        pairs = [
            ("floortile_no_white1", floortile_no_white1,
             "floortile_no_white1a", floortile_no_white1a),
            ("floortile_disconnected_tile1", floortile_disconnected_tile1,
             "floortile_disconnected_tile1a", floortile_disconnected_tile1a),
            ("floortile_one_color_one_robot1", floortile_one_color_one_robot1,
             "floortile_one_color_one_robot1a", floortile_one_color_one_robot1a),
            ("floortile_no_available_colors", floortile_no_available_colors,
             "floortile_no_available_colors_a", floortile_no_available_colors_a),
        ]
        for base_name, base_desc, variant_name, variant_desc in pairs:
            with subtests.test(f"{variant_name} equals {base_name}"):
                verdict = planetarium.evaluate(
                    base_desc, variant_desc, alias="lama-first"
                )
                assert all(verdict)
366 |
367 |
class TestUnsupportedDomain:
    """
    Evaluation of a candidate problem that names an unsupported domain.
    """

    def test_plan(
        self, blocksworld_fully_specified, blocksworld_fully_specified_wrong_domain
    ):
        """
        evaluate() must report (True, False, False) when the candidate
        declares an unsupported domain.
        """
        outcome = planetarium.evaluate(
            blocksworld_fully_specified, blocksworld_fully_specified_wrong_domain
        )
        assert outcome == (True, False, False)
386 |
--------------------------------------------------------------------------------
/tests/test_graph.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from .test_pddl import problem_string
4 |
5 | from planetarium import builder
6 |
7 |
@pytest.fixture
def sgraph(problem_string):
    """
    Initial-scene graph of the problem built from the shared PDDL string.
    """
    init_scene, _goal_scene = builder.build(problem_string).decompose()
    return init_scene
14 |
15 |
16 | class TestGraph:
17 | """
18 | Test suite for the SGraph instance built from a PDDL problem.
19 | """
20 |
21 | def test_constant_node_names(self, sgraph):
22 | """
23 | Test if the names of constant nodes in the graph match the expected set.
24 | """
25 | names = set(["p0", "p1", "f0", "f1", "f2", "f3"])
26 | assert all([(node.name in names) for node in sgraph.constant_nodes])
27 |
28 | def test_constant_node_size(self, sgraph):
29 | """
30 | Test if the number of constant nodes in the graph matches the expected count.
31 | """
32 | assert len(sgraph.constant_nodes) == 6
33 |
34 | def test_predicate_names(self, sgraph):
35 | """
36 | Test if the names of predicate nodes in the graph match expected patterns.
37 | """
38 | for predicate in sgraph.predicate_nodes:
39 | match predicate.node.split("-"):
40 | case ["above", _, _]:
41 | assert True
42 | case ["origin", _, _]:
43 | assert True
44 | case ["destin", _, _]:
45 | assert True
46 | case ["lift", "at", _]:
47 | assert True
48 | case _:
49 | assert False
50 |
51 | def test_plot(self, sgraph):
52 | """
53 | Test if the graph can be plotted.
54 | """
55 | sgraph.plot()
56 | assert True
57 |
--------------------------------------------------------------------------------
/tests/test_metric.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from planetarium import builder, graph, metric, oracle
4 |
5 | # pylint: disable=unused-import
6 | from .test_pddl import (
7 | problem_string,
8 | two_initial_problem_string,
9 | renamed_problem_string,
10 | wrong_problem_string,
11 | swap_problem_string,
12 | wrong_swap_problem_string,
13 | move_problem_string,
14 | wrong_move_problem_string,
15 | wrong_initial_problem_string,
16 | )
17 |
18 | from .problem_fixtures import (
19 | blocksworld_underspecified,
20 | blocksworld_missing_clears,
21 | blocksworld_missing_ontables,
22 | blocksworld_fully_specified,
23 | gripper_fully_specified,
24 | gripper_no_robby,
25 | rover_single_line_fully_specified_4,
26 | rover_single_line_fully_specified_4a,
27 | )
28 |
29 | from pddl import requirements
30 |
31 |
def problem_states(
    problem_init: graph.SceneGraph, problem_goals: set[graph.SceneGraph]
) -> set[graph.ProblemGraph]:
    """Join one initial scene with each goal scene; return the resulting problems as a set."""
    return {graph.ProblemGraph.join(problem_init, goal) for goal in problem_goals}
36 |
37 |
class TestConstantMatching:
    """
    Tests for the constant-node matching helpers in metric
    (_node_matching, _same_typing, _preserves_mapping).
    """

    @staticmethod
    def _constant(name, typing, label=graph.Label.CONSTANT):
        """Build a PlanGraphNode whose node id and name are both ``name``."""
        return graph.PlanGraphNode(name, name, typing=typing, label=label)

    @pytest.fixture
    def source(self):
        """Fixture for a valid source constant."""
        return self._constant("o1", ["t1", "t2"])

    @pytest.fixture
    def target(self):
        """Fixture for a valid target constant."""
        return self._constant("c1", ["t1", "t2"])

    @pytest.fixture
    def source_incorrect_label(self):
        """Fixture for a source constant with an incorrect label."""
        return self._constant("o1", ["t1", "t2"], graph.Label.PREDICATE)

    @pytest.fixture
    def target_incorrect_label(self):
        """Fixture for a target constant with an incorrect label."""
        return self._constant("c1", ["t1", "t2"], graph.Label.PREDICATE)

    @pytest.fixture
    def source_incorrect_typing(self):
        """Fixture for a source constant with incorrect typing."""
        return self._constant("o1", ["ty1", "ty2"])

    @pytest.fixture
    def target_incorrect_typing(self):
        """Fixture for a target constant with incorrect typing."""
        return self._constant("c1", ["ty1", "ty2"])

    @pytest.fixture
    def mapping(self):
        """Fixture for a valid mapping between source and target constants."""
        return dict(o1="c1")

    @pytest.fixture
    def mapping_incorrect(self):
        """Fixture for an incorrect mapping between source and target constants."""
        return dict(o1="o1")

    def test_correct_matching(self, source, target, mapping):
        """Matching constants agree with and without a mapping."""
        for maybe_mapping in (None, mapping):
            assert metric._node_matching(source, target, maybe_mapping)

        assert metric._same_typing(source, target)
        assert metric._preserves_mapping(source, target, mapping)

    def test_incorrect_label(
        self, source, target, source_incorrect_label, target_incorrect_label, mapping
    ):
        """A label mismatch on either side defeats every check."""
        mismatched = [
            (source, target_incorrect_label),
            (source_incorrect_label, target),
        ]
        for maybe_mapping in (None, mapping):
            for src, tgt in mismatched:
                assert not metric._node_matching(src, tgt, maybe_mapping)

        for src, tgt in mismatched:
            assert not metric._preserves_mapping(src, tgt, mapping)

        for src, tgt in mismatched:
            assert not metric._same_typing(src, tgt)

    def test_incorrect_typing(
        self, source, target, source_incorrect_typing, target_incorrect_typing, mapping
    ):
        """A typing mismatch defeats matching but not mapping preservation."""
        mismatched = [
            (source, target_incorrect_typing),
            (source_incorrect_typing, target),
        ]
        for maybe_mapping in (None, mapping):
            for src, tgt in mismatched:
                assert not metric._node_matching(src, tgt, maybe_mapping)

        for src, tgt in mismatched:
            assert metric._preserves_mapping(src, tgt, mapping)

        for src, tgt in mismatched:
            assert not metric._same_typing(src, tgt)

    def test_incorrect_mapping(self, source, target, mapping_incorrect):
        """A mapping that sends o1 elsewhere defeats matching and preservation."""
        assert not metric._node_matching(source, target, mapping_incorrect)
        assert not metric._preserves_mapping(source, target, mapping_incorrect)
137 |
138 |
class TestPredicateMatching:
    """
    Tests for predicate-node matching via metric._node_matching.
    """

    @staticmethod
    def _predicate(label):
        """Build the "f-a1-a2" node (typing "f") with the given label."""
        return graph.PlanGraphNode(
            "f-a1-a2",
            "f-a1-a2",
            typing="f",
            label=label,
        )

    @pytest.fixture
    def source(self):
        """Fixture for a valid source predicate node."""
        return self._predicate(graph.Label.PREDICATE)

    @pytest.fixture
    def target(self):
        """Fixture for a valid target predicate node."""
        return self._predicate(graph.Label.PREDICATE)

    @pytest.fixture
    def source_incorrect_label(self):
        """Fixture for a source predicate node with an incorrect label."""
        return self._predicate(graph.Label.CONSTANT)

    @pytest.fixture
    def target_incorrect_label(self):
        """Fixture for a target predicate node with an incorrect label."""
        return self._predicate(graph.Label.CONSTANT)

    def test_correct_matching(self, source, target):
        """Identical predicate nodes match."""
        assert metric._node_matching(source, target, None)

    def test_incorrect_label(
        self,
        source,
        target,
        source_incorrect_label,
        target_incorrect_label,
    ):
        """A label mismatch on either side prevents matching."""
        for src, tgt in (
            (source, target_incorrect_label),
            (source_incorrect_label, target),
        ):
            assert not metric._node_matching(src, tgt, None)
198 |
199 |
class TestMetrics:
    """
    Test suite for metrics functions in the metric module.
    """

    def test_map(self, problem_string, two_initial_problem_string):
        """Test ``metric.isomorphic`` on graph pairs."""
        problem_graph = builder.build(problem_string)
        problem_graph2 = builder.build(two_initial_problem_string)

        assert metric.isomorphic(problem_graph, problem_graph)
        assert not metric.isomorphic(problem_graph, problem_graph2)

    def test_validate(self, problem_string, two_initial_problem_string):
        """Test ``metric.equals`` (with ``is_placeholder=True``) on graph pairs."""
        problem_graph = builder.build(problem_string)
        problem_graph2 = builder.build(two_initial_problem_string)

        assert metric.equals(problem_graph, problem_graph, is_placeholder=True)
        assert not metric.equals(
            problem_graph,
            problem_graph2,
            is_placeholder=True,
        )

    def test_swap(self, swap_problem_string, wrong_swap_problem_string):
        """
        Test ``metric.equals`` on the swap problem and its wrong variant.

        The wrong variant is rejected when ``is_placeholder=False`` but is
        accepted when ``is_placeholder=True``.
        """
        swap_problem = builder.build(swap_problem_string)
        wrong_swap = builder.build(wrong_swap_problem_string)

        assert metric.equals(swap_problem, swap_problem, is_placeholder=False)
        assert not metric.equals(swap_problem, wrong_swap, is_placeholder=False)
        assert metric.equals(swap_problem, wrong_swap, is_placeholder=True)

    def test_move(self, move_problem_string, wrong_move_problem_string):
        """
        Test ``metric.equals`` (with ``is_placeholder=True``) on the move
        problem and its wrong variant.
        """
        move_problem = builder.build(move_problem_string)
        wrong_move = builder.build(wrong_move_problem_string)

        assert metric.equals(move_problem, move_problem, is_placeholder=True)
        assert not metric.equals(move_problem, wrong_move, is_placeholder=True)

    def test_blocksworld_equivalence(
        self,
        subtests,
        blocksworld_fully_specified,
        blocksworld_missing_clears,
        blocksworld_missing_ontables,
        blocksworld_underspecified,
    ):
        """Test the equivalence of fully specified blocksworld problems."""
        p1 = builder.build(blocksworld_fully_specified)
        p2 = builder.build(blocksworld_missing_clears)
        p3 = builder.build(blocksworld_missing_ontables)
        p4 = builder.build(blocksworld_underspecified)

        p1 = oracle.fully_specify(p1)
        p2 = oracle.fully_specify(p2)
        p3 = oracle.fully_specify(p3)
        p4 = oracle.fully_specify(p4)

        problems = (
            ("blocksworld_fully_specified", p1),
            ("blocksworld_missing_clears", p2),
            ("blocksworld_missing_ontables", p3),
            ("blocksworld_underspecified", p4),
        )

        # Every problem must equal itself in both comparison modes.
        for name, problem in problems:
            with subtests.test(f"{name} equals {name}"):
                assert metric.equals(problem, problem, is_placeholder=True)
                assert metric.equals(problem, problem, is_placeholder=False)

        # The underspecified problem (index 3) must differ from each of the
        # others, in both directions and in both comparison modes.
        for idx1, idx2 in (
            (0, 3),
            (1, 3),
            (2, 3),
        ):
            (name1, lhs), (name2, rhs) = problems[idx1], problems[idx2]
            with subtests.test(f"{name1} not equals {name2}"):
                assert not metric.equals(lhs, rhs, is_placeholder=True)
                assert not metric.equals(lhs, rhs, is_placeholder=False)
                assert not metric.equals(rhs, lhs, is_placeholder=True)
                assert not metric.equals(rhs, lhs, is_placeholder=False)

    def test_rover_single_eqquivalence(
        self,
        subtests,
        rover_single_line_fully_specified_4,
        rover_single_line_fully_specified_4a,
    ):
        """
        Test that two equivalent rover single line problems compare equal to
        themselves and to each other, in both comparison modes.
        """
        p1 = oracle.fully_specify(builder.build(rover_single_line_fully_specified_4))
        p2 = oracle.fully_specify(builder.build(rover_single_line_fully_specified_4a))

        for is_placeholder in (True, False):
            with subtests.test(f"is_placeholder={is_placeholder}"):
                # equivalence to itself
                assert metric.equals(p1, p1, is_placeholder=is_placeholder)
                assert metric.equals(p2, p2, is_placeholder=is_placeholder)

                # the two variants are equivalent, checked in both directions
                assert metric.equals(p1, p2, is_placeholder=is_placeholder)
                assert metric.equals(p2, p1, is_placeholder=is_placeholder)
316 |
--------------------------------------------------------------------------------
/tests/test_oracle.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from planetarium import builder, graph, oracle
4 |
5 | from .problem_fixtures import (
6 | blocksworld_fully_specified,
7 | blocksworld_missing_clears,
8 | blocksworld_missing_ontables,
9 | blocksworld_underspecified,
10 | blocksworld_underspecified_arm,
11 | blocksworld_holding,
12 | gripper_fully_specified,
13 | gripper_no_goal_types,
14 | gripper_fully_specified_not_strict,
15 | gripper_no_robby,
16 | gripper_underspecified_1,
17 | gripper_underspecified_2,
18 | gripper_underspecified_3,
19 | gripper_no_robby_init,
20 | rover_single_line_fully_specified,
21 | rover_single_line_fully_specified_1,
22 | rover_single_line_fully_specified_2,
23 | rover_single_line_fully_specified_3,
24 | rover_single_line_fully_specified_4,
25 | floortile_fully_specified,
26 | floortile_underspecified_directions,
27 | floortile_no_white1,
28 | floortile_no_white1a,
29 | floortile_no_white2,
30 | floortile_no_available_colors,
31 | floortile_disconnected_tile_no_white,
32 | floortile_disconnected_tile1,
33 | floortile_disconnected_tile1a,
34 | floortile_one_color_one_robot1,
35 | )
36 |
37 |
def reduce_and_inflate(scene: graph.SceneGraph) -> bool:
    """Round-trip a scene through the oracle and compare it with the original.

    The scene is reduced with ``oracle.reduce`` and expanded again with
    ``oracle.inflate``; both calls use the scene's own ``domain``.

    Args:
        scene (graph.SceneGraph): The scene to round-trip.

    Returns:
        bool: True when the round-tripped scene equals ``scene``.
    """
    domain = scene.domain
    round_tripped = oracle.inflate(oracle.reduce(scene, domain=domain), domain=domain)
    return scene == round_tripped
50 |
51 |
class TestBlocksworldOracle:
    """
    Test suite for the blocksworld oracle.
    """

    @staticmethod
    def _assert_fully_specify_idempotent(desc):
        """Build ``desc`` and check that fully specifying it twice changes nothing."""
        full = oracle.fully_specify(builder.build(desc))
        assert oracle.fully_specify(full) == full

    def test_fully_specified(self, blocksworld_fully_specified):
        """
        Test the fully specified blocksworld problem.
        """
        self._assert_fully_specify_idempotent(blocksworld_fully_specified)

    def test_missing_clears(self, blocksworld_missing_clears):
        """
        Test the blocksworld problem with missing ``clear`` predicates.
        """
        self._assert_fully_specify_idempotent(blocksworld_missing_clears)

    def test_missing_ontables(self, blocksworld_missing_ontables):
        """
        Test the blocksworld problem with missing ``on-table`` predicates.
        """
        self._assert_fully_specify_idempotent(blocksworld_missing_ontables)

    def test_missing_ontables_and_clears(self, blocksworld_underspecified):
        """
        Test the blocksworld problem missing both ``on-table`` and ``clear``
        predicates.
        """
        self._assert_fully_specify_idempotent(blocksworld_underspecified)

    def test_inflate(
        self,
        subtests,
        blocksworld_fully_specified,
        blocksworld_missing_clears,
        blocksworld_missing_ontables,
        blocksworld_underspecified,
        blocksworld_underspecified_arm,
        blocksworld_holding,
    ):
        """
        Test the inflate function.
        """
        descs = {
            "blocksworld_fully_specified": blocksworld_fully_specified,
            "blocksworld_missing_clears": blocksworld_missing_clears,
            "blocksworld_missing_ontables": blocksworld_missing_ontables,
            "blocksworld_underspecified": blocksworld_underspecified,
            "blocksworld_underspecified_arm": blocksworld_underspecified_arm,
            "blocksworld_holding": blocksworld_holding,
        }
        for name, desc in descs.items():
            # Everything, including building, runs inside the subtest so a
            # failure is reported against this problem and the loop continues.
            with subtests.test(name):
                problem = builder.build(desc)
                init, goal = problem.decompose()
                assert reduce_and_inflate(init)
                assert reduce_and_inflate(goal)
                assert reduce_and_inflate(problem)

                # Reducing init and goal separately and joining the results
                # must inflate back to the original problem.
                assert problem == oracle.inflate(
                    oracle.ReducedProblemGraph.join(
                        oracle.reduce(init),
                        oracle.reduce(goal),
                    )
                )
124 |
125 |
class TestGripperOracle:
    """
    Test suite for the gripper oracle.
    """

    def test_fully_specified(
        self,
        subtests,
        gripper_fully_specified,
        gripper_no_goal_types,
        gripper_fully_specified_not_strict,
    ):
        """
        Test that fully specifying gripper problems is idempotent.
        """
        descs = [
            ("gripper_fully_specified", gripper_fully_specified),
            ("gripper_no_goal_types", gripper_no_goal_types),
            ("gripper_fully_specified_not_strict", gripper_fully_specified_not_strict),
        ]
        for name, desc in descs:
            with subtests.test(name):
                problem = builder.build(desc)
                full = oracle.fully_specify(problem)
                assert oracle.fully_specify(full) == full

    def test_inflate(
        self,
        subtests,
        gripper_fully_specified,
        gripper_no_robby,
        gripper_underspecified_1,
        gripper_underspecified_2,
        gripper_underspecified_3,
        gripper_no_robby_init,
    ):
        """
        Test the inflate function.
        """
        descs = [
            ("gripper_fully_specified", gripper_fully_specified),
            ("gripper_no_robby", gripper_no_robby),
            ("gripper_underspecified_1", gripper_underspecified_1),
            ("gripper_underspecified_2", gripper_underspecified_2),
            ("gripper_underspecified_3", gripper_underspecified_3),
            ("gripper_no_robby_init", gripper_no_robby_init),
        ]

        for name, desc in descs:
            # Build inside the subtest so a build failure is attributed to
            # this problem instead of aborting the remaining cases.
            with subtests.test(name):
                problem = builder.build(desc)
                init, goal = problem.decompose()
                assert reduce_and_inflate(init)
                assert reduce_and_inflate(goal)
                assert reduce_and_inflate(problem)

    def test_underspecified(
        self,
        gripper_underspecified_1,
        gripper_underspecified_2,
    ):
        """
        Test that fully specifying underspecified gripper problems is idempotent.
        """
        for desc in (gripper_underspecified_1, gripper_underspecified_2):
            full = oracle.fully_specify(builder.build(desc))
            assert oracle.fully_specify(full) == full
195 |
196 |
class TestRoverSingleOracle:
    """
    Test suite for the rover oracle.
    """

    def test_fully_specified(
        self,
        subtests,
        rover_single_line_fully_specified,
        rover_single_line_fully_specified_1,
        rover_single_line_fully_specified_2,
        rover_single_line_fully_specified_3,
        rover_single_line_fully_specified_4,
    ):
        """
        Test that already fully specified rover problems are unchanged by
        ``oracle.fully_specify`` and that the operation is idempotent.
        """
        descs = [
            ("rover_single_line_fully_specified", rover_single_line_fully_specified),
            (
                "rover_single_line_fully_specified_1",
                rover_single_line_fully_specified_1,
            ),
            (
                "rover_single_line_fully_specified_2",
                rover_single_line_fully_specified_2,
            ),
            (
                "rover_single_line_fully_specified_3",
                rover_single_line_fully_specified_3,
            ),
            (
                "rover_single_line_fully_specified_4",
                rover_single_line_fully_specified_4,
            ),
        ]
        for name, desc in descs:
            with subtests.test(name):
                problem = builder.build(desc)
                full = oracle.fully_specify(problem)
                assert full == problem, "fully_specify(problem) == problem"
                assert oracle.fully_specify(full) == full, (
                    "fully_specify(fully_specify(problem)) == fully_specify(problem)"
                )

    def test_inflate(
        self,
        subtests,
        rover_single_line_fully_specified,
        rover_single_line_fully_specified_1,
        rover_single_line_fully_specified_2,
        rover_single_line_fully_specified_3,
        rover_single_line_fully_specified_4,
    ):
        """
        Test the inflate function.
        """
        descs = [
            ("rover_single_line_fully_specified", rover_single_line_fully_specified),
            (
                "rover_single_line_fully_specified_1",
                rover_single_line_fully_specified_1,
            ),
            (
                "rover_single_line_fully_specified_2",
                rover_single_line_fully_specified_2,
            ),
            (
                "rover_single_line_fully_specified_3",
                rover_single_line_fully_specified_3,
            ),
            (
                "rover_single_line_fully_specified_4",
                rover_single_line_fully_specified_4,
            ),
        ]
        for name, desc in descs:
            # Build inside the subtest so a build failure is attributed to
            # this problem instead of aborting the remaining cases.
            with subtests.test(name):
                problem = builder.build(desc)
                init, goal = problem.decompose()
                assert reduce_and_inflate(init)
                assert reduce_and_inflate(goal)
                assert reduce_and_inflate(problem)
278 |
279 |
class TestFloorTileOracle:
    """
    Test suite for the floor tile oracle.
    """

    def test_fully_specified(
        self,
        subtests,
        floortile_fully_specified,
        floortile_no_white1,
        floortile_no_white2,
        floortile_no_available_colors,
        floortile_disconnected_tile_no_white,
        floortile_disconnected_tile1,
        floortile_one_color_one_robot1,
    ):
        """
        Test that already fully specified floor tile problems are unchanged by
        ``oracle.fully_specify`` and that the operation is idempotent.
        """
        descs = {
            "floortile_fully_specified": floortile_fully_specified,
            "floortile_no_white1": floortile_no_white1,
            "floortile_no_white2": floortile_no_white2,
            "floortile_no_available_colors": floortile_no_available_colors,
            "floortile_disconnected_tile_no_white": floortile_disconnected_tile_no_white,
            "floortile_disconnected_tile1": floortile_disconnected_tile1,
            "floortile_one_color_one_robot1": floortile_one_color_one_robot1,
        }
        for name, desc in descs.items():
            # Build and specify inside the subtest so one failing problem does
            # not abort the remaining cases.
            with subtests.test(name):
                problem = builder.build(desc)
                full = oracle.fully_specify(problem)
                assert full == problem, name
                assert oracle.fully_specify(full) == full, name

    def test_under_specified(
        self,
        subtests,
        floortile_underspecified_directions,
        floortile_no_white1a,
        floortile_disconnected_tile1a,
    ):
        """
        Test that underspecified floor tile problems are changed by
        ``oracle.fully_specify`` and that the result is idempotent.
        """
        descs = {
            "floortile_underspecified_directions": floortile_underspecified_directions,
            "floortile_no_white1a": floortile_no_white1a,
            "floortile_disconnected_tile1a": floortile_disconnected_tile1a,
        }
        for name, desc in descs.items():
            with subtests.test(name):
                problem = builder.build(desc)
                full = oracle.fully_specify(problem)
                assert full != problem, name
                assert oracle.fully_specify(full) == full, name

    def test_infalte(
        self,
        subtests,
        floortile_fully_specified,
        floortile_no_white1,
        floortile_no_white2,
        floortile_no_available_colors,
        floortile_disconnected_tile_no_white,
        floortile_disconnected_tile1,
        floortile_one_color_one_robot1,
        floortile_underspecified_directions,
        floortile_no_white1a,
        floortile_disconnected_tile1a,
    ):
        """
        Test the inflate function.
        """
        descs = {
            "floortile_fully_specified": floortile_fully_specified,
            "floortile_no_white1": floortile_no_white1,
            "floortile_no_white2": floortile_no_white2,
            "floortile_no_available_colors": floortile_no_available_colors,
            "floortile_disconnected_tile_no_white": floortile_disconnected_tile_no_white,
            "floortile_disconnected_tile1": floortile_disconnected_tile1,
            "floortile_one_color_one_robot1": floortile_one_color_one_robot1,
            "floortile_underspecified_directions": floortile_underspecified_directions,
            "floortile_no_white1a": floortile_no_white1a,
            "floortile_disconnected_tile1a": floortile_disconnected_tile1a,
        }
        for name, desc in descs.items():
            with subtests.test(name):
                problem = builder.build(desc)
                init, goal = problem.decompose()
                assert reduce_and_inflate(init)
                assert reduce_and_inflate(goal)
                assert reduce_and_inflate(problem)
373 |
374 |
class TestUnsupportedDomain:
    """
    Oracle entry points must raise ``oracle.DomainNotSupportedError`` when
    given an unknown domain name.
    """

    def test_reduce_and_inflate(self, gripper_fully_specified):
        """``oracle.reduce`` and ``oracle.inflate`` reject an unsupported domain."""
        problem = builder.build(gripper_fully_specified)
        init, goal = problem.decompose()

        with pytest.raises(oracle.DomainNotSupportedError):
            oracle.reduce(init, domain="gripper-modified")

        # Reduce with a supported domain OUTSIDE the ``raises`` block: if this
        # call raised, it would otherwise satisfy the expectation below and
        # mask a missing check in ``inflate``.
        reduced = oracle.reduce(goal, domain="gripper")
        with pytest.raises(oracle.DomainNotSupportedError):
            oracle.inflate(reduced, domain="gripper-modified")

    def test_fully_specify(self, gripper_fully_specified):
        """``oracle.fully_specify`` rejects an unsupported domain."""
        problem = builder.build(gripper_fully_specified)
        with pytest.raises(oracle.DomainNotSupportedError):
            oracle.fully_specify(problem, domain="gripper-modified")

    def test_plan(self, gripper_fully_specified):
        """``oracle.plan`` rejects an unsupported domain."""
        problem = builder.build(gripper_fully_specified)
        with pytest.raises(oracle.DomainNotSupportedError):
            oracle.plan(problem, domain="gripper-modified")
395 |
--------------------------------------------------------------------------------
/tests/test_pddl.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from planetarium import builder
4 |
5 | from pddl.parser.problem import LenientProblemParser
6 |
7 |
@pytest.fixture
def problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    A ``miconic`` problem with passengers ``p0``/``p1`` and floors
    ``f0``..``f3``. The lift starts at ``f0``, and the goal is for both
    passengers to be served.
    """
    return """
    (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0)
        (:domain miconic)
        (:objects p0 p1 - passenger
                  f0 f1 f2 f3 - floor)

        (:init
            (above f0 f1)
            (above f0 f2)
            (above f0 f3)
            (above f1 f2)
            (above f1 f3)
            (above f2 f3)
            (origin p0 f3)
            (destin p0 f2)
            (origin p1 f1)
            (destin p1 f3)
            (lift-at f0))

        (:goal (and
            (served p0)
            (served p1)))
    )
    """
37 |
38 |
@pytest.fixture
def two_initial_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    Same objects and initial state as ``problem_string``, but the goal
    restates the initial-state facts instead of requiring the passengers to
    be served.
    """
    return """
    (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0)
        (:domain miconic)
        (:objects p0 p1 - passenger
                  f0 f1 f2 f3 - floor)

        (:init
            (above f0 f1)
            (above f0 f2)
            (above f0 f3)
            (above f1 f2)
            (above f1 f3)
            (above f2 f3)
            (origin p0 f3)
            (destin p0 f2)
            (origin p1 f1)
            (destin p1 f3)
            (lift-at f0))

        (:goal (and
            (above f0 f1)
            (above f0 f2)
            (above f0 f3)
            (above f1 f2)
            (above f1 f3)
            (above f2 f3)
            (origin p0 f3)
            (destin p0 f2)
            (origin p1 f1)
            (destin p1 f3)
            (lift-at f0)))
    )
    """
77 |
78 |
@pytest.fixture
def renamed_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    ``problem_string`` with its objects consistently renamed
    (p0<->p1, f0->f1, f1->f2, f2->f3, f3->f0), so both strings describe the
    same problem up to renaming.
    """
    return """
    (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0)
        (:domain miconic)
        (:objects p1 p0 - passenger
                  f1 f2 f3 f0 - floor)

        (:init
            (above f1 f2)
            (above f1 f3)
            (above f1 f0)
            (above f2 f3)
            (above f2 f0)
            (above f3 f0)
            (origin p1 f0)
            (destin p1 f3)
            (origin p0 f2)
            (destin p0 f0)
            (lift-at f1))

        (:goal (and
            (served p1)
            (served p0)))
    )
    """
108 |
109 |
@pytest.fixture
def wrong_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    Like ``renamed_problem_string``, except that the first goal is
    ``(served f3)``: ``served`` is applied to a floor rather than a passenger.
    """
    return """
    (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0)
        (:domain miconic)
        (:objects p1 p0 - passenger
                  f1 f2 f3 f0 - floor)

        (:init
            (above f1 f2)
            (above f1 f3)
            (above f1 f0)
            (above f2 f3)
            (above f2 f0)
            (above f3 f0)
            (origin p1 f0)
            (destin p1 f3)
            (origin p0 f2)
            (destin p0 f0)
            (lift-at f1))

        (:goal (and
            (served f3)
            (served p0)))
    )
    """
139 |
140 |
@pytest.fixture
def wrong_initial_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    Like ``renamed_problem_string``, except that the initial state omits
    ``(origin p1 f0)``.
    """
    return """
    (define (problem mixed-f4-p2-u0-v0-g0-a0-n0-A0-B0-N0-F0-r0)
        (:domain miconic)
        (:objects p1 p0 - passenger
                  f1 f2 f3 f0 - floor)

        (:init
            (above f1 f2)
            (above f1 f3)
            (above f1 f0)
            (above f2 f3)
            (above f2 f0)
            (above f3 f0)
            (destin p1 f3)
            (origin p0 f2)
            (destin p0 f0)
            (lift-at f1))

        (:goal (and
            (served p1)
            (served p0)))
    )
    """
169 |
170 |
@pytest.fixture
def swap_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    A ``swap`` problem: ``a0`` starts in ``b0`` and ``a1`` in ``b1``, and
    the goal exchanges their rooms.
    """
    return """
    (define (problem swap)
        (:domain swap)
        (:objects a0 a1 - object
                  b0 b1 - room)

        (:init
            (in a0 b0)
            (in a1 b1))

        (:goal (and
            (in a0 b1)
            (in a1 b0)))
    )
    """
191 |
192 |
@pytest.fixture
def wrong_swap_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    Same initial state as ``swap_problem_string``, but the goal is identical
    to that initial state, so nothing is swapped.
    """
    return """
    (define (problem swap)
        (:domain swap)
        (:objects a0 a1 - object
                  b0 b1 - room)

        (:init
            (in a0 b0)
            (in a1 b1))

        (:goal (and
            (in a0 b0)
            (in a1 b1)))
    )
    """
213 |
214 |
@pytest.fixture
def move_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    A ``move`` problem: ``a0`` and ``a1`` both start in ``b0``. The goal
    moves ``a0`` to ``b1`` and keeps ``a1`` in ``b0``.
    """
    return """
    (define (problem move)
        (:domain move)
        (:objects a0 a1 - object
                  b0 b1 - room)

        (:init
            (in a0 b0)
            (in a1 b0))

        (:goal (and
            (in a0 b1)
            (in a1 b0)))
    )
    """
235 |
236 |
@pytest.fixture
def wrong_move_problem_string():
    """
    Fixture providing a sample PDDL problem definition as a string.

    A ``move`` problem whose initial state puts ``a0`` in ``b0`` and ``a1``
    in ``b1``, with a goal identical to that initial state.
    """
    return """
    (define (problem move)
        (:domain move)
        (:objects a0 a1 - object
                  b0 b1 - room)

        (:init
            (in a0 b0)
            (in a1 b1))

        (:goal (and
            (in a0 b0)
            (in a1 b1)))
    )
    """
257 |
258 |
@pytest.fixture
def single_predicate_goal():
    """
    Fixture providing a sample PDDL problem definition as a string.

    A ``move`` problem whose goal is the single atom ``(in a0 b0)`` rather
    than an ``(and ...)`` conjunction.
    """
    return """
    (define (problem move)
        (:domain move)
        (:objects a0 a1 - object
                  b0 b1 - room)

        (:init
            (in a0 b0)
            (in a1 b1))

        (:goal (in a0 b0))
    )
    """
277 |
278 |
@pytest.fixture
def not_predicate_goal():
    """
    Fixture providing a sample PDDL problem definition as a string.

    A ``move`` problem whose goal is the negated atom ``(not (in a0 b0))``.
    ``TestBuild.test_not_predicate_goal`` expects such a goal to raise
    ``ValueError``.
    """
    return """
    (define (problem move)
        (:domain move)
        (:objects a0 a1 - object
                  b0 b1 - room)

        (:init
            (in a0 b0)
            (in a1 b1))

        (:goal (not
            (in a0 b0)))
    )
    """
298 |
299 |
@pytest.fixture
def problem(problem_string):
    """
    Fixture providing ``problem_string`` parsed into a PDDL problem object
    with ``LenientProblemParser``.
    """
    return LenientProblemParser()(problem_string)
306 |
307 |
class TestConstantToDict:
    """
    Test suite for the _constant_to_dict function.
    """

    def test_constat_name(self, problem):
        """
        Test that a converted PDDL constant keeps its name under ``"name"``.
        """
        constant = next(iter(problem.objects))
        assert builder._constant_to_dict(constant)["name"] == str(constant.name)

    def test_constat_type(self, problem):
        """
        Test that a converted PDDL constant stores its type tags, as a set,
        under ``"typing"``.
        """
        constant = next(iter(problem.objects))
        typing = builder._constant_to_dict(constant)["typing"]
        # Separate asserts so a failure shows which property is wrong.
        assert isinstance(typing, set)
        assert typing == constant.type_tags
330 |
331 |
class TestPredicateToDict:
    """
    Test suite for the _predicate_to_dict function.
    """

    def test_predicate_name(self, problem):
        """
        Test that a converted PDDL predicate stores its name under ``"typing"``.
        """
        predicate = next(iter(problem.init))
        assert builder._predicate_to_dict(predicate)["typing"] == str(predicate.name)

    def test_predicate_parameters(self, problem):
        """
        Test that a converted PDDL predicate lists its term names, in order,
        under ``"parameters"``.
        """
        predicate = next(iter(problem.init))
        parameters = builder._predicate_to_dict(predicate)["parameters"]
        # Separate asserts so a failure shows which property is wrong.
        assert isinstance(parameters, list)
        assert parameters == [term.name for term in predicate.terms]
354 |
355 |
class TestBuildConstants:
    """
    Tests for ``builder._build_constants``.
    """

    def test_size(self, problem):
        """
        Every PDDL object in the problem yields exactly one constant.
        """
        objects = problem.objects
        built = builder._build_constants(objects)
        assert len(built) == len(objects)
366 |
367 |
class TestBuildPredicates:
    """
    Tests for ``builder._build_predicates``.
    """

    def test_initial_size(self, problem):
        """
        Every initial-state atom yields exactly one predicate.
        """
        atoms = problem.init
        built = builder._build_predicates(atoms)
        assert len(built) == len(atoms)

    def test_goal_size(self, problem):
        """
        Every operand of the goal conjunction yields exactly one predicate.
        """
        operands = problem.goal.operands
        built = builder._build_predicates(operands)
        assert len(built) == len(operands)
386 |
387 |
class TestBuild:
    """
    Test suite for the build function.
    """

    def test_node_size(self, problem_string):
        """
        Test the number of nodes in the init and goal scene graphs built from
        a PDDL problem.
        """
        graph_1, graph_2 = builder.build(problem_string).decompose()
        assert len(graph_1.nodes) == 17
        assert len(graph_2.nodes) == 8

    def test_edge_size(self, problem_string):
        """
        Test the number of edges in the init and goal scene graphs built from
        a PDDL problem.
        """
        graph_1, graph_2 = builder.build(problem_string).decompose()
        assert len(graph_1.edges) == 21
        assert len(graph_2.edges) == 2

    # Previously also named ``test_edge_size``, which silently replaced the
    # method above so the plain-PDDL case never ran.
    def test_edge_size_markdown_wrapped(self, problem_string):
        """
        Test that a problem embedded in surrounding prose and a fenced
        ```pddl block builds into graphs with the same edge counts.
        """
        modified_problem_string = (
            "Here is an example of a problem string that is not a PDDL problem. "
            f"```pddl\n{problem_string}\n```"
        )
        graph_1, graph_2 = builder.build(modified_problem_string).decompose()
        assert len(graph_1.edges) == 21
        assert len(graph_2.edges) == 2

    def test_single_predicate_goal(self, single_predicate_goal):
        """
        Test that a goal made of a single predicate (no ``and``) builds and
        decomposes without raising.
        """
        builder.build(single_predicate_goal).decompose()

    def test_to_pddl_str(self, single_predicate_goal):
        """
        Test that a problem with a single-predicate goal can be serialized
        back to a PDDL string without raising.
        """
        builder.build(single_predicate_goal).to_pddl_str()

    def test_not_predicate_goal(self, not_predicate_goal):
        """
        Test that building and decomposing a problem whose goal is a negated
        predicate raises ``ValueError``.
        """
        with pytest.raises(ValueError):
            builder.build(not_predicate_goal).decompose()
434 |
--------------------------------------------------------------------------------
/tests/test_planner.py:
--------------------------------------------------------------------------------
import os

import pytest

from planetarium import builder, downward, oracle

from .problem_fixtures import (
    blocksworld_fully_specified,
    blocksworld_holding,
    blocksworld_missing_clears,
    blocksworld_missing_ontables,
    blocksworld_underspecified,
    blocksworld_underspecified_arm,
    blocksworld_stack_to_holding,
    blocksworld_invalid_1,
    blocksworld_invalid_2,
    blocksworld_invalid_3,
    gripper_fully_specified,
    gripper_fully_specified_not_strict,
    gripper_inconsistent_typing,
    gripper_missing_typing,
    gripper_multiple_typing,
    gripper_no_goal_types,
    gripper_no_robby,
    gripper_robby_at_last,
    gripper_underspecified_1,
    gripper_underspecified_2,
    gripper_underspecified_3,
    gripper_invalid,
)

# Validator command passed to ``downward.validate``; override it with the
# VALIDATE environment variable (presumably VAL's ``Validate`` binary).
# Defined after the imports so every import stays at the top of the module.
VALIDATE = os.getenv("VALIDATE", "Validate")
32 |
# PDDL domain definitions keyed by domain name. The tests below pass them to
# ``downward.validate`` as the domain against which a plan is checked.
DOMAINS = {
    "blocksworld": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/blocksworld/domain.pddl
;; same as used in IPC 2023
;;
(define (domain blocksworld)

(:requirements :strips)

(:predicates (clear ?x)
             (on-table ?x)
             (arm-empty)
             (holding ?x)
             (on ?x ?y))

(:action pickup
  :parameters (?ob)
  :precondition (and (clear ?ob) (on-table ?ob) (arm-empty))
  :effect (and (holding ?ob) (not (clear ?ob)) (not (on-table ?ob))
               (not (arm-empty))))

(:action putdown
  :parameters  (?ob)
  :precondition (holding ?ob)
  :effect (and (clear ?ob) (arm-empty) (on-table ?ob)
               (not (holding ?ob))))

(:action stack
  :parameters  (?ob ?underob)
  :precondition (and (clear ?underob) (holding ?ob))
  :effect (and (arm-empty) (clear ?ob) (on ?ob ?underob)
               (not (clear ?underob)) (not (holding ?ob))))

(:action unstack
  :parameters  (?ob ?underob)
  :precondition (and (on ?ob ?underob) (clear ?ob) (arm-empty))
  :effect (and (holding ?ob) (clear ?underob)
               (not (on ?ob ?underob)) (not (clear ?ob)) (not (arm-empty)))))
""",
    "gripper": """;; source: https://github.com/AI-Planning/pddl-generators/blob/main/gripper/domain.pddl
(define (domain gripper)
   (:requirements :strips)
   (:predicates (room ?r)
        (ball ?b)
        (gripper ?g)
        (at-robby ?r)
        (at ?b ?r)
        (free ?g)
        (carry ?o ?g))

   (:action move
       :parameters  (?from ?to)
       :precondition (and  (room ?from) (room ?to) (at-robby ?from))
       :effect (and  (at-robby ?to)
             (not (at-robby ?from))))

   (:action pick
       :parameters (?obj ?room ?gripper)
       :precondition  (and  (ball ?obj) (room ?room) (gripper ?gripper)
                (at ?obj ?room) (at-robby ?room) (free ?gripper))
       :effect (and (carry ?obj ?gripper)
            (not (at ?obj ?room))
            (not (free ?gripper))))

   (:action drop
       :parameters  (?obj  ?room ?gripper)
       :precondition  (and  (ball ?obj) (room ?room) (gripper ?gripper)
                (carry ?obj ?gripper) (at-robby ?room))
       :effect (and (at ?obj ?room)
            (free ?gripper)
            (not (carry ?obj ?gripper)))))
""",
}
105 |
106 |
class TestBlocksworldOracle:
    """
    Test suite for the blocksworld oracle.
    """

    def test_plan(
        self,
        subtests,
        blocksworld_missing_clears,
        blocksworld_fully_specified,
        blocksworld_holding,
        blocksworld_missing_ontables,
        blocksworld_underspecified,
        blocksworld_underspecified_arm,
        blocksworld_stack_to_holding,
    ):
        """
        Test that the oracle produces a non-empty plan for each blocksworld
        problem. ``downward.validate`` must accept the plan under the
        blocksworld domain and reject it under the gripper domain.
        """
        for name, desc in {
            "blocksworld_fully_specified": blocksworld_fully_specified,
            "blocksworld_holding": blocksworld_holding,
            "blocksworld_missing_clears": blocksworld_missing_clears,
            "blocksworld_missing_ontables": blocksworld_missing_ontables,
            "blocksworld_underspecified": blocksworld_underspecified,
            "blocksworld_underspecified_arm": blocksworld_underspecified_arm,
            "blocksworld_stack_to_holding": blocksworld_stack_to_holding,
        }.items():
            plan = oracle.plan(builder.build(desc))

            # Distinct subtest names, so the two checks are reported separately.
            with subtests.test(f"{name} (blocksworld domain)"):
                assert plan != [], name

                assert downward.validate(
                    DOMAINS["blocksworld"],
                    desc,
                    oracle.plan_to_string(plan),
                    VALIDATE,
                )

            # The same plan must not validate against an unrelated domain.
            with subtests.test(f"{name} (gripper domain)"):
                assert not downward.validate(
                    DOMAINS["gripper"],
                    desc,
                    oracle.plan_to_string(plan),
                    VALIDATE,
                )

    def test_invalid_plan(
        self,
        subtests,
        blocksworld_invalid_2,
    ):
        """
        Test that the oracle finds no plan for an invalid blocksworld problem,
        and that the resulting empty plan does not validate.
        """
        domain = DOMAINS["blocksworld"]
        for name, desc in {
            "blocksworld_invalid_2": blocksworld_invalid_2,
        }.items():
            with subtests.test(name):
                try:
                    plan = oracle.plan(builder.build(desc))
                except Exception:
                    # Building or planning an invalid problem may raise; treat
                    # that the same as "no plan found".
                    plan = []
                assert plan == [], f"{name}: {plan}"

                plan_str = oracle.plan_to_string(plan)
                assert not downward.validate(domain, desc, plan_str, VALIDATE)
176 |
class TestGripperOracle:
    """
    Test suite for the gripper oracle.
    """

    def test_plan(
        self,
        subtests,
        gripper_fully_specified,
        gripper_fully_specified_not_strict,
        gripper_no_goal_types,
        gripper_no_robby,
        gripper_robby_at_last,
        gripper_underspecified_1,
        gripper_underspecified_2,
        gripper_underspecified_3,
    ):
        """
        Test that the oracle produces a non-empty plan for each gripper
        problem fixture, that the plan validates against the gripper domain,
        and that it does NOT validate against the blocksworld domain.
        """
        domain = DOMAINS["gripper"]
        for name, desc in {
            "gripper_fully_specified": gripper_fully_specified,
            "gripper_fully_specified_not_strict": gripper_fully_specified_not_strict,
            "gripper_no_goal_types": gripper_no_goal_types,
            "gripper_no_robby": gripper_no_robby,
            "gripper_robby_at_last": gripper_robby_at_last,
            "gripper_underspecified_1": gripper_underspecified_1,
            "gripper_underspecified_2": gripper_underspecified_2,
            "gripper_underspecified_3": gripper_underspecified_3,
        }.items():
            # Reset per fixture: the subtest context reports and swallows a
            # failure, so without this reset the negative check below would
            # silently reuse the previous fixture's plan (or raise NameError
            # on the first fixture).
            plan = None
            with subtests.test(name):
                plan = oracle.plan(builder.build(desc))
                assert plan != [], name

                assert downward.validate(
                    domain,
                    desc,
                    oracle.plan_to_string(plan),
                    VALIDATE,
                ), name

            if plan is None:
                # Planning failed (already reported); nothing to check against
                # the wrong domain.
                continue

            # Negative check: a gripper plan must not validate against an
            # unrelated domain.
            with subtests.test(f"{name} (blocksworld domain)"):
                assert not downward.validate(
                    DOMAINS["blocksworld"],
                    desc,
                    oracle.plan_to_string(plan),
                    VALIDATE,
                ), name
226 |
227 |
class TestUnsupportedDomain:
    """
    Tests for requesting a plan in a domain the oracle does not support.
    """

    def test_plan(self, mocker, blocksworld_fully_specified):
        """
        Asking the oracle to plan for an unknown domain name must raise
        ``oracle.DomainNotSupportedError``.
        """
        built = builder.build(blocksworld_fully_specified)
        # Stub fully_specify so it simply hands back the built problem.
        # Presumably this keeps the call from failing before the domain
        # lookup is reached — confirm against planetarium.oracle.
        mocker.patch(
            "planetarium.oracle.fully_specify",
            return_value=built,
        )
        with pytest.raises(oracle.DomainNotSupportedError):
            oracle.plan(built, domain="unsupported_domain")
241 |
--------------------------------------------------------------------------------