├── .coveragerc ├── .github ├── CODEOWNERS └── workflows │ └── predicators.yml ├── .gitignore ├── .isort.cfg ├── .predicators_pylintrc ├── .style.yapf ├── LICENSE.md ├── README.md ├── behavior.md ├── mypy.ini ├── predicators ├── __init__.py ├── approaches │ ├── __init__.py │ ├── base_approach.py │ ├── bilevel_planning_approach.py │ ├── gnn_action_policy_approach.py │ ├── gnn_approach.py │ ├── gnn_metacontroller_approach.py │ ├── gnn_option_policy_approach.py │ ├── grammar_search_invention_approach.py │ ├── initialized_pg3_approach.py │ ├── interactive_learning_approach.py │ ├── llm_base_renaming_approach.py │ ├── llm_bilevel_planning_approach.py │ ├── llm_open_loop_approach.py │ ├── llm_option_renaming_approach.py │ ├── llm_predicate_renaming_approach.py │ ├── llm_syntax_renaming_approach.py │ ├── nsrt_learning_approach.py │ ├── nsrt_metacontroller_approach.py │ ├── nsrt_rl_approach.py │ ├── online_nsrt_learning_approach.py │ ├── online_pg3_approach.py │ ├── oracle_approach.py │ ├── pg3_approach.py │ ├── pg4_approach.py │ ├── random_actions_approach.py │ ├── random_options_approach.py │ └── refinement_estimation_approach.py ├── args.py ├── behavior_utils │ ├── __init__.py │ ├── behavior_utils.py │ ├── motion_planner_fns.py │ ├── option_fns.py │ ├── option_model_fns.py │ ├── task_to_broken_inst_ids.json │ └── task_to_preselected_scenes.json ├── datasets │ ├── __init__.py │ ├── demo_only.py │ ├── demo_replay.py │ └── ground_atom_data.py ├── envs │ ├── __init__.py │ ├── assets │ │ ├── pddl │ │ │ ├── blocks │ │ │ │ ├── domain.pddl │ │ │ │ ├── task1.pddl │ │ │ │ ├── task10.pddl │ │ │ │ ├── task11.pddl │ │ │ │ ├── task12.pddl │ │ │ │ ├── task13.pddl │ │ │ │ ├── task14.pddl │ │ │ │ ├── task15.pddl │ │ │ │ ├── task16.pddl │ │ │ │ ├── task17.pddl │ │ │ │ ├── task18.pddl │ │ │ │ ├── task19.pddl │ │ │ │ ├── task2.pddl │ │ │ │ ├── task20.pddl │ │ │ │ ├── task21.pddl │ │ │ │ ├── task22.pddl │ │ │ │ ├── task23.pddl │ │ │ │ ├── task24.pddl │ │ │ │ ├── task25.pddl │ │ │ │ ├── 
task26.pddl │ │ │ │ ├── task27.pddl │ │ │ │ ├── task28.pddl │ │ │ │ ├── task29.pddl │ │ │ │ ├── task3.pddl │ │ │ │ ├── task30.pddl │ │ │ │ ├── task31.pddl │ │ │ │ ├── task32.pddl │ │ │ │ ├── task33.pddl │ │ │ │ ├── task34.pddl │ │ │ │ ├── task35.pddl │ │ │ │ ├── task4.pddl │ │ │ │ ├── task5.pddl │ │ │ │ ├── task6.pddl │ │ │ │ ├── task7.pddl │ │ │ │ ├── task8.pddl │ │ │ │ └── task9.pddl │ │ │ ├── delivery │ │ │ │ └── domain.pddl │ │ │ ├── ferry │ │ │ │ └── domain.pddl │ │ │ ├── forest │ │ │ │ └── domain.pddl │ │ │ ├── gripper │ │ │ │ ├── domain.pddl │ │ │ │ └── prefixed_domain.pddl │ │ │ ├── miconic │ │ │ │ └── domain.pddl │ │ │ └── spannerlearning │ │ │ │ └── domain.pddl │ │ └── urdf │ │ │ ├── fetch_description │ │ │ ├── meshes │ │ │ │ ├── base_link.dae │ │ │ │ ├── base_link_collision.STL │ │ │ │ ├── base_link_uv.png │ │ │ │ ├── bellows_link.STL │ │ │ │ ├── bellows_link_collision.STL │ │ │ │ ├── elbow_flex_link.dae │ │ │ │ ├── elbow_flex_link_collision.STL │ │ │ │ ├── elbow_flex_uv.png │ │ │ │ ├── estop_link.STL │ │ │ │ ├── forearm_roll_link.dae │ │ │ │ ├── forearm_roll_link_collision.STL │ │ │ │ ├── forearm_roll_uv.png │ │ │ │ ├── gripper_link.STL │ │ │ │ ├── gripper_link.dae │ │ │ │ ├── gripper_uv.png │ │ │ │ ├── head_pan_link.dae │ │ │ │ ├── head_pan_link_collision.STL │ │ │ │ ├── head_pan_uv.png │ │ │ │ ├── head_tilt_link.dae │ │ │ │ ├── head_tilt_link_collision.STL │ │ │ │ ├── head_tilt_uv.png │ │ │ │ ├── l_gripper_finger_link.STL │ │ │ │ ├── l_wheel_link.STL │ │ │ │ ├── l_wheel_link_collision.STL │ │ │ │ ├── laser_link.STL │ │ │ │ ├── r_gripper_finger_link.STL │ │ │ │ ├── r_wheel_link.STL │ │ │ │ ├── r_wheel_link_collision.STL │ │ │ │ ├── shoulder_lift_link.dae │ │ │ │ ├── shoulder_lift_link_collision.STL │ │ │ │ ├── shoulder_lift_uv.png │ │ │ │ ├── shoulder_pan_link.dae │ │ │ │ ├── shoulder_pan_link_collision.STL │ │ │ │ ├── shoulder_pan_uv.png │ │ │ │ ├── torso_fixed_link.STL │ │ │ │ ├── torso_fixed_link.dae │ │ │ │ ├── torso_fixed_uv.png │ │ │ │ ├── 
torso_lift_link.dae │ │ │ │ ├── torso_lift_link_collision.STL │ │ │ │ ├── torso_lift_uv.png │ │ │ │ ├── upperarm_roll_link.dae │ │ │ │ ├── upperarm_roll_link_collision.STL │ │ │ │ ├── upperarm_roll_uv.png │ │ │ │ ├── wrist_flex_link.dae │ │ │ │ ├── wrist_flex_link_collision.STL │ │ │ │ ├── wrist_flex_uv.png │ │ │ │ ├── wrist_roll_link.dae │ │ │ │ ├── wrist_roll_link_collision.STL │ │ │ │ └── wrist_roll_uv.png │ │ │ └── robots │ │ │ │ └── fetch.urdf │ │ │ ├── franka_description │ │ │ ├── CMakeLists.txt │ │ │ ├── mainpage.dox │ │ │ ├── meshes │ │ │ │ ├── collision │ │ │ │ │ ├── finger.stl │ │ │ │ │ ├── hand.stl │ │ │ │ │ ├── link0.stl │ │ │ │ │ ├── link1.stl │ │ │ │ │ ├── link2.stl │ │ │ │ │ ├── link3.stl │ │ │ │ │ ├── link4.stl │ │ │ │ │ ├── link5.stl │ │ │ │ │ ├── link6.stl │ │ │ │ │ └── link7.stl │ │ │ │ └── visual │ │ │ │ │ ├── finger.dae │ │ │ │ │ ├── hand.dae │ │ │ │ │ ├── link0.dae │ │ │ │ │ ├── link1.dae │ │ │ │ │ ├── link2.dae │ │ │ │ │ ├── link3.dae │ │ │ │ │ ├── link4.dae │ │ │ │ │ ├── link5.dae │ │ │ │ │ ├── link6.dae │ │ │ │ │ └── link7.dae │ │ │ ├── package.xml │ │ │ ├── robots │ │ │ │ ├── hand.dae │ │ │ │ ├── hand.urdf │ │ │ │ ├── hand.urdf.xacro │ │ │ │ ├── hand.xacro │ │ │ │ ├── panda_arm.backup.dae │ │ │ │ ├── panda_arm.dae │ │ │ │ ├── panda_arm.urdf │ │ │ │ ├── panda_arm.urdf.xacro │ │ │ │ ├── panda_arm.xacro │ │ │ │ ├── panda_arm_hand.backup.dae │ │ │ │ ├── panda_arm_hand.dae │ │ │ │ ├── panda_arm_hand.urdf │ │ │ │ └── panda_arm_hand.urdf.xacro │ │ │ └── rosdoc.yaml │ │ │ ├── plane.obj │ │ │ ├── plane.urdf │ │ │ ├── table.obj │ │ │ ├── table.png │ │ │ └── table.urdf │ ├── base_env.py │ ├── behavior.py │ ├── blocks.py │ ├── cluttered_table.py │ ├── coffee.py │ ├── cover.py │ ├── doors.py │ ├── narrow_passage.py │ ├── painting.py │ ├── pddl_env.py │ ├── pddl_procedural_generation.py │ ├── playroom.py │ ├── pybullet_blocks.py │ ├── pybullet_cover.py │ ├── pybullet_env.py │ ├── repeated_nextto.py │ ├── repeated_nextto_painting.py │ ├── satellites.py │ 
├── screws.py │ ├── stick_button.py │ ├── tools.py │ └── touch_point.py ├── explorers │ ├── __init__.py │ ├── base_explorer.py │ ├── bilevel_planning_explorer.py │ ├── exploit_bilevel_planning_explorer.py │ ├── glib_explorer.py │ ├── greedy_lookahead_explorer.py │ ├── no_explore_explorer.py │ ├── random_actions_explorer.py │ └── random_options_explorer.py ├── gnn │ ├── __init__.py │ ├── gnn.py │ └── gnn_utils.py ├── ground_truth_nsrts.py ├── llm_interface.py ├── main.py ├── ml_models.py ├── nsrt_learning │ ├── __init__.py │ ├── nsrt_learning_main.py │ ├── option_learning.py │ ├── sampler_learning.py │ ├── segmentation.py │ └── strips_learning │ │ ├── __init__.py │ │ ├── base_strips_learner.py │ │ ├── clustering_learner.py │ │ ├── gen_to_spec_learner.py │ │ ├── oracle_learner.py │ │ └── pnad_search_learner.py ├── option_model.py ├── planning.py ├── predicate_search_score_functions.py ├── pybullet_helpers │ ├── __init__.py │ ├── controllers.py │ ├── geometry.py │ ├── ikfast │ │ ├── __init__.py │ │ ├── load.py │ │ └── utils.py │ ├── inverse_kinematics.py │ ├── joint.py │ ├── link.py │ ├── motion_planning.py │ └── robots │ │ ├── __init__.py │ │ ├── fetch.py │ │ ├── panda.py │ │ └── single_arm.py ├── refinement_estimators │ ├── __init__.py │ ├── base_refinement_estimator.py │ └── oracle_refinement_estimator.py ├── settings.py ├── structs.py ├── teacher.py ├── third_party │ ├── fast_downward_translator │ │ ├── README.md │ │ ├── axiom_rules.py │ │ ├── build_model.py │ │ ├── constraints.py │ │ ├── fact_groups.py │ │ ├── graph.py │ │ ├── greedy_join.py │ │ ├── instantiate.py │ │ ├── invariant_finder.py │ │ ├── invariants.py │ │ ├── normalize.py │ │ ├── pddl │ │ │ ├── __init__.py │ │ │ ├── actions.py │ │ │ ├── axioms.py │ │ │ ├── conditions.py │ │ │ ├── effects.py │ │ │ ├── f_expression.py │ │ │ ├── functions.py │ │ │ ├── pddl_types.py │ │ │ ├── predicates.py │ │ │ └── tasks.py │ │ ├── pddl_parser │ │ │ ├── __init__.py │ │ │ ├── lisp_parser.py │ │ │ ├── parsing_functions.py 
│ │ │ └── pddl_file.py │ │ ├── pddl_to_prolog.py │ │ ├── sas_tasks.py │ │ ├── sccs.py │ │ ├── simplify.py │ │ ├── split_rules.py │ │ ├── timers.py │ │ ├── tools.py │ │ ├── translate.py │ │ └── variable_order.py │ └── ikfast │ │ ├── __init__.py │ │ ├── compile.py │ │ ├── ikfast.h │ │ └── panda_arm │ │ ├── __init__.py │ │ ├── compile.py │ │ ├── ikfast.h │ │ ├── ikfast_panda_arm.cpp │ │ └── setup.py └── utils.py ├── run_autoformat.sh ├── scripts ├── __init__.py ├── analyze_results_directory.py ├── cluster_utils.py ├── configs │ ├── backchaining_predicate_invention.yaml │ ├── behavior_20_evaluation.yaml │ ├── behavior_all_tasks_oracle.yaml │ ├── behavior_example.yaml │ ├── behavior_pick_place_data_collection.yaml │ ├── behavior_pick_place_evaluation.yaml │ ├── behavior_pick_place_learning.yaml │ ├── example_basic.yaml │ ├── example_multiple_combinations.yaml │ ├── full_pipeline.yaml │ ├── interactive_learning.yaml │ ├── llm_pddl.yaml │ ├── llm_pddl_ablations.yaml │ ├── nightly.yaml │ ├── pg3_offline.yaml │ ├── pg3_online.yaml │ └── pg4.yaml ├── eval_trajectory_to_lisdf.py ├── evaluate_interactive_approach_classifiers.py ├── evaluate_interactive_approach_entropy.py ├── find_unused_functions.py ├── grammar_search_analysis.py ├── launch_slack_bot.py ├── lisdf_plan_to_reset.py ├── lisdf_pybullet_visualizer.py ├── local │ ├── launch.py │ └── run_behavior_tests.py ├── openstack │ ├── README.md │ ├── __init__.py │ ├── download.py │ ├── kill_all.py │ └── launch.py ├── plotting │ ├── create_bar_plots.py │ ├── create_classification_plots.py │ ├── create_ignore_effects_lineplots.py │ ├── create_interactive_predicate_learning_plots.py │ ├── create_latex_tables.py │ ├── create_num_demos_plots.py │ ├── create_option_learning_plots.py │ ├── create_per_task_histograms.py │ ├── create_per_task_nodes_stripplot.py │ └── create_pnadsearch_lineplot.py ├── realsense_helpers.py ├── run_blocks_perception.py ├── run_blocks_real.sh ├── run_checks.sh ├── skeleton_score_analysis.py └── supercloud 
│ ├── __init__.py │ ├── download.py │ ├── kill_all.py │ ├── launch.py │ ├── run_ignore_effects_experiments.sh │ ├── run_loft_experiments.sh │ ├── run_option_learning_experiments.sh │ ├── run_predicators_evalonly_experiments.sh │ ├── run_predicators_main_experiments.sh │ ├── run_predicators_num_demos_experiments.sh │ └── submit_supercloud_job.py ├── setup.py ├── supercloud.md └── tests ├── __init__.py ├── approaches ├── test_base_approach.py ├── test_gnn_action_policy_approach.py ├── test_gnn_metacontroller_approach.py ├── test_gnn_option_policy_approach.py ├── test_grammar_search_invention_approach.py ├── test_initialized_pg3_approach.py ├── test_interactive_approach.py ├── test_llm_base_renaming_approach.py ├── test_llm_bilevel_planning_approach.py ├── test_llm_open_loop_approach.py ├── test_llm_option_renaming_approach.py ├── test_llm_predicate_renaming_approach.py ├── test_llm_syntax_renaming_approach.py ├── test_nsrt_learning_approach.py ├── test_nsrt_rl_approach.py ├── test_online_nsrt_learning_approach.py ├── test_online_pg3_approach.py ├── test_oracle_approach.py ├── test_pg3_approach.py ├── test_pg4_approach.py ├── test_random_actions_approach.py ├── test_random_options_approach.py └── test_refinement_estimation_approach.py ├── behavior_utils └── test_behavior_utils.py ├── conftest.py ├── datasets └── test_datasets.py ├── envs ├── test_base_env.py ├── test_blocks.py ├── test_cluttered_table.py ├── test_coffee.py ├── test_cover.py ├── test_doors_env.py ├── test_narrow_passage.py ├── test_painting.py ├── test_pddl_env.py ├── test_pddl_procedural_generation.py ├── test_playroom.py ├── test_pybullet_blocks.py ├── test_pybullet_cover.py ├── test_repeated_nextto.py ├── test_repeated_nextto_painting.py ├── test_satellites.py ├── test_screws.py ├── test_stick_button.py ├── test_tools.py └── test_touch_point.py ├── explorers ├── test_base_explorer.py ├── test_exploit_bilevel_planning_explorer.py ├── test_glib_explorer.py ├── test_greedy_lookahead_explorer.py ├── 
test_no_explore_explorer.py ├── test_online_learning.py ├── test_random_actions_explorer.py └── test_random_options_explorer.py ├── nsrt_learning ├── strips_learning │ ├── test_backchaining_based_learners.py │ ├── test_base_strips_learner.py │ ├── test_clustering_learner.py │ └── test_oracle_learner.py ├── test_nsrt_learning_main.py ├── test_option_learning.py ├── test_sampler_learning.py └── test_segmentation.py ├── pybullet_helpers ├── conftest.py ├── ikfast │ ├── test_load.py │ └── test_utils.py ├── robots │ └── test_panda.py ├── test_geometry.py ├── test_joint.py ├── test_link.py └── test_pybullet_robots.py ├── refinement_estimators ├── test_base_refinement_estimator.py └── test_oracle_refinement_estimator.py ├── test_args.py ├── test_llm_interface.py ├── test_main.py ├── test_ml_models.py ├── test_option_model.py ├── test_planning.py ├── test_predicate_search_score_functions.py ├── test_settings.py ├── test_structs.py ├── test_teacher.py └── test_utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | predicators/envs/behavior.py 4 | predicators/behavior_utils/** 5 | predicators/third_party/** 6 | 7 | [report] 8 | # Regexes for lines to exclude from consideration 9 | exclude_lines = 10 | # Have to re-enable the standard pragma 11 | # per https://coverage.readthedocs.io/en/latest/config.html#syntax 12 | pragma: no cover 13 | 14 | # Don't complain about abstract methods, they aren't run 15 | @abstractmethod 16 | @abc.abstractmethod 17 | 18 | # Don't complain about TYPE_CHECKING imports. 19 | if TYPE_CHECKING: 20 | 21 | # Don't complain about longrun tests. 22 | @longrun 23 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Default owners for all files. 
2 | * @NishanthJKumar @wmcclinton 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | .DS_Store 4 | *.egg-info 5 | *.pkl 6 | *~ 7 | *.python-version 8 | *.csv 9 | *.zip 10 | *.so 11 | build 12 | dist 13 | pylint_recursive.py 14 | .coverage 15 | results 16 | eval_trajectories 17 | videos 18 | logs 19 | saved_approaches 20 | saved_datasets 21 | scripts/results 22 | tmp_behavior_states 23 | llm_cache 24 | machines.txt 25 | *_vision_data 26 | tests/_fake_trajs 27 | tests/_fake_results 28 | *.mp4 29 | video_frames/ 30 | 31 | # Jetbrains IDEs 32 | .idea/ 33 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | multi_line_output = 2 3 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | split_before_logical_operator = true 4 | column_limit = 79 5 | spaces_before_comment = 2 -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rohan Chitnis and Tom Silver 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | strict_equality = True 3 | disallow_untyped_calls = True 4 | warn_unreachable = True 5 | exclude = (predicators/envs/assets|predicators/third_party) 6 | 7 | [mypy-predicators.*] 8 | disallow_untyped_defs = True 9 | 10 | [mypy-scripts.*] 11 | disallow_untyped_defs = True 12 | 13 | [mypy-predicators.tests.*] 14 | ignore_missing_imports = True 15 | 16 | [mypy-predicators.third_party.*] 17 | ignore_missing_imports = True 18 | 19 | [mypy-setuptools.*] 20 | ignore_missing_imports = True 21 | 22 | [mypy-gym.spaces.*] 23 | ignore_missing_imports = True 24 | 25 | [mypy-imageio.*] 26 | ignore_missing_imports = True 27 | 28 | [mypy-matplotlib.*] 29 | ignore_missing_imports = True 30 | 31 | [mypy-scipy.*] 32 | ignore_missing_imports = True 33 | 34 | [mypy-bddl.*] 35 | ignore_missing_imports = True 36 | 37 | [mypy-igibson.*] 38 | ignore_missing_imports = True 39 | 40 | [mypy-tabulate.*] 41 | ignore_missing_imports = True 42 | 43 | [mypy-pandas.*] 44 | ignore_missing_imports = True 45 | 46 | [mypy-dill.*] 47 | ignore_missing_imports = True 48 | 49 | [mypy-pyperplan.*] 50 | ignore_missing_imports = True 51 | 52 | [mypy-pybullet.*] 53 | ignore_missing_imports = True 54 | 55 | [mypy-pathos.*] 56 | 
ignore_missing_imports = True 57 | 58 | [mypy-requests.*] 59 | ignore_missing_imports = True 60 | 61 | [mypy-sklearn.*] 62 | ignore_missing_imports = True 63 | 64 | [mypy-graphlib.*] 65 | ignore_missing_imports = True 66 | 67 | [mypy-pybullet_utils.*] 68 | # utils provided by pybullet itself don't have typing stubs 69 | ignore_missing_imports = True 70 | 71 | [mypy-seaborn.*] 72 | ignore_missing_imports = True 73 | 74 | [mypy-open3d.*] 75 | ignore_missing_imports = True 76 | 77 | [mypy-smepy.*] 78 | ignore_missing_imports = True 79 | 80 | [mypy-pyrealsense2.*] 81 | ignore_missing_imports = True 82 | 83 | [mypy-panda_robot_client.*] 84 | ignore_missing_imports = True 85 | 86 | [mypy-gym.*] 87 | ignore_missing_imports = True 88 | 89 | [mypy-gym_sokoban.*] 90 | ignore_missing_imports = True 91 | 92 | [mypy-tqdm.*] 93 | ignore_missing_imports = True 94 | -------------------------------------------------------------------------------- /predicators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/__init__.py -------------------------------------------------------------------------------- /predicators/approaches/__init__.py: -------------------------------------------------------------------------------- 1 | """Handle creation of approaches.""" 2 | 3 | import importlib 4 | import pkgutil 5 | from typing import TYPE_CHECKING, List, Set 6 | 7 | from gym.spaces import Box 8 | 9 | from predicators import utils 10 | from predicators.approaches.base_approach import ApproachFailure, \ 11 | ApproachTimeout, BaseApproach 12 | from predicators.structs import ParameterizedOption, Predicate, Task, Type 13 | 14 | __all__ = ["BaseApproach", "ApproachTimeout", "ApproachFailure"] 15 | 16 | if not TYPE_CHECKING: 17 | # Load all modules so that utils.get_all_subclasses() works. 
18 | for _, module_name, _ in pkgutil.walk_packages(__path__): 19 | if "__init__" not in module_name: 20 | # Important! We use an absolute import here to avoid issues 21 | # with isinstance checking when using relative imports. 22 | importlib.import_module(f"{__name__}.{module_name}") 23 | 24 | 25 | def create_approach(name: str, initial_predicates: Set[Predicate], 26 | initial_options: Set[ParameterizedOption], 27 | types: Set[Type], action_space: Box, 28 | train_tasks: List[Task]) -> BaseApproach: 29 | """Create an approach given its name.""" 30 | for cls in utils.get_all_subclasses(BaseApproach): 31 | if not cls.__abstractmethods__ and cls.get_name() == name: 32 | approach = cls(initial_predicates, initial_options, types, 33 | action_space, train_tasks) 34 | break 35 | else: 36 | raise NotImplementedError(f"Unknown approach: {name}") 37 | return approach 38 | -------------------------------------------------------------------------------- /predicators/approaches/llm_option_renaming_approach.py: -------------------------------------------------------------------------------- 1 | """Open-loop large language model (LLM) meta-controller approach with prompt 2 | modification where option names are replaced with random strings. 
3 | 4 | Example command line: 5 | export OPENAI_API_KEY= 6 | python predicators/main.py --approach llm_option_renaming --seed 0 \ 7 | --strips_learner oracle \ 8 | --env pddl_blocks_procedural_tasks \ 9 | --num_train_tasks 3 \ 10 | --num_test_tasks 1 \ 11 | --debug 12 | """ 13 | import string 14 | from typing import Dict, List 15 | 16 | from predicators import utils 17 | from predicators.approaches.llm_base_renaming_approach import \ 18 | LLMBaseRenamingApproach 19 | 20 | 21 | class LLMOptionRenamingApproach(LLMBaseRenamingApproach): 22 | """LLMOptionRenamingApproach definition.""" 23 | 24 | @classmethod 25 | def get_name(cls) -> str: 26 | return "llm_option_renaming" 27 | 28 | @property 29 | def _renaming_prefixes(self) -> List[str]: 30 | # Options start with either a new line or a white space. 31 | return [" ", "\n"] 32 | 33 | @property 34 | def _renaming_suffixes(self) -> List[str]: 35 | # Option names end with a left parenthesis. 36 | return ["("] 37 | 38 | def _create_replacements(self) -> Dict[str, str]: 39 | return { 40 | o.name: utils.generate_random_string(len(o.name), 41 | list(string.ascii_lowercase), 42 | self._rng) 43 | for o in self._initial_options 44 | } 45 | -------------------------------------------------------------------------------- /predicators/approaches/llm_predicate_renaming_approach.py: -------------------------------------------------------------------------------- 1 | """Open-loop large language model (LLM) meta-controller approach with prompt 2 | modification where predicate names are replaced with random strings. 
3 | 4 | Example command line: 5 | export OPENAI_API_KEY= 6 | python predicators/main.py --approach llm_predicate_renaming --seed 0 \ 7 | --strips_learner oracle \ 8 | --env pddl_blocks_procedural_tasks \ 9 | --num_train_tasks 3 \ 10 | --num_test_tasks 1 \ 11 | --debug 12 | """ 13 | import string 14 | from typing import Dict, List 15 | 16 | from predicators import utils 17 | from predicators.approaches.llm_base_renaming_approach import \ 18 | LLMBaseRenamingApproach 19 | 20 | 21 | class LLMPredicateRenamingApproach(LLMBaseRenamingApproach): 22 | """LLMPredicateRenamingApproach definition.""" 23 | 24 | @classmethod 25 | def get_name(cls) -> str: 26 | return "llm_predicate_renaming" 27 | 28 | @property 29 | def _renaming_prefixes(self) -> List[str]: 30 | # Predicates start with either a new line or a white space. 31 | return [" ", "\n"] 32 | 33 | @property 34 | def _renaming_suffixes(self) -> List[str]: 35 | # Predicate names end with a left parenthesis. 36 | return ["("] 37 | 38 | def _create_replacements(self) -> Dict[str, str]: 39 | return { 40 | p.name: utils.generate_random_string(len(p.name), 41 | list(string.ascii_lowercase), 42 | self._rng) 43 | for p in self._get_current_predicates() 44 | } 45 | -------------------------------------------------------------------------------- /predicators/approaches/llm_syntax_renaming_approach.py: -------------------------------------------------------------------------------- 1 | """Open-loop large language model (LLM) meta-controller approach with prompt 2 | modification where certain PDDL syntax is replaced with random characters. 
3 | 4 | Example command line: 5 | export OPENAI_API_KEY= 6 | python predicators/main.py --approach llm_syntax_renaming --seed 0 \ 7 | --strips_learner oracle \ 8 | --env pddl_blocks_procedural_tasks \ 9 | --num_train_tasks 3 \ 10 | --num_test_tasks 1 \ 11 | --debug 12 | """ 13 | from typing import Dict, List 14 | 15 | from predicators.approaches.llm_base_renaming_approach import \ 16 | LLMBaseRenamingApproach 17 | 18 | ORIGINAL_CHARS = ['(', ')', ':'] 19 | REPLACEMENT_CHARS = ['^', '$', '#', '!', '*'] 20 | 21 | 22 | class LLMSyntaxRenamingApproach(LLMBaseRenamingApproach): 23 | """LLMSyntaxRenamingApproach definition.""" 24 | 25 | @classmethod 26 | def get_name(cls) -> str: 27 | return "llm_syntax_renaming" 28 | 29 | @property 30 | def _renaming_prefixes(self) -> List[str]: 31 | # Since we're replacing single characters, we don't need to worry about 32 | # the possibility that one string is a substring of another. 33 | return [""] 34 | 35 | @property 36 | def _renaming_suffixes(self) -> List[str]: 37 | # Since we're replacing single characters, we don't need to worry about 38 | # the possibility that one string is a substring of another. 39 | return [""] 40 | 41 | def _create_replacements(self) -> Dict[str, str]: 42 | # Without replacement because if multiple original characters mapped 43 | # to the same replacement character, the inverse substitution would be 44 | # not well defined and the parsing of the option plan would fail. 45 | replacement_chars = self._rng.choice(REPLACEMENT_CHARS, 46 | size=len(ORIGINAL_CHARS), 47 | replace=False) 48 | return dict(zip(ORIGINAL_CHARS, replacement_chars)) 49 | -------------------------------------------------------------------------------- /predicators/approaches/online_pg3_approach.py: -------------------------------------------------------------------------------- 1 | """Online learning of generalized policies via PG3. 
2 | 3 | Example command line: 4 | python predicators/main.py --approach online_pg3 --seed 0 \ 5 | --env pddl_easy_delivery_procedural_tasks \ 6 | --explorer random_options \ 7 | --max_initial_demos 1 \ 8 | --num_train_tasks 1000 \ 9 | --num_test_tasks 10 \ 10 | --max_num_steps_interaction_request 10 \ 11 | --min_data_for_nsrt 10 12 | """ 13 | from __future__ import annotations 14 | 15 | from typing import List, Sequence, Set 16 | 17 | from predicators.approaches.online_nsrt_learning_approach import \ 18 | OnlineNSRTLearningApproach 19 | from predicators.approaches.pg3_approach import PG3Approach 20 | from predicators.structs import Box, Dataset, InteractionResult, \ 21 | ParameterizedOption, Predicate, Task, Type 22 | 23 | 24 | class OnlinePG3Approach(PG3Approach, OnlineNSRTLearningApproach): 25 | """OnlinePG3Approach implementation.""" 26 | 27 | def __init__(self, initial_predicates: Set[Predicate], 28 | initial_options: Set[ParameterizedOption], types: Set[Type], 29 | action_space: Box, train_tasks: List[Task]) -> None: 30 | # Initializes the generalized policy. 31 | PG3Approach.__init__(self, initial_predicates, initial_options, types, 32 | action_space, train_tasks) 33 | # Initializes the cumulative dataset. 34 | OnlineNSRTLearningApproach.__init__(self, initial_predicates, 35 | initial_options, types, 36 | action_space, train_tasks) 37 | 38 | @classmethod 39 | def get_name(cls) -> str: 40 | return "online_pg3" 41 | 42 | def learn_from_offline_dataset(self, dataset: Dataset) -> None: 43 | # Update the dataset with the offline data. 44 | self._dataset = Dataset(dataset.trajectories) 45 | # Learn NSRTs and generalized policy. 46 | return PG3Approach.learn_from_offline_dataset(self, dataset) 47 | 48 | def learn_from_interaction_results( 49 | self, results: Sequence[InteractionResult]) -> None: 50 | # This does three things: adds data to self._dataset, re-learns NSRTs, 51 | # and advances the online learning cycle counter. 
52 | old_nsrts = self._nsrts 53 | OnlineNSRTLearningApproach.learn_from_interaction_results( 54 | self, results) 55 | save_cycle = self._online_learning_cycle - 1 56 | # Then, relearn the generalized policy, but only if the NSRTs have 57 | # changed, because LDL learning is only a function of the NSRTs. 58 | if old_nsrts != self._nsrts: 59 | self._learn_ldl(online_learning_cycle=save_cycle) 60 | -------------------------------------------------------------------------------- /predicators/approaches/oracle_approach.py: -------------------------------------------------------------------------------- 1 | """A bilevel planning approach that uses hand-specified NSRTs. 2 | 3 | The approach is aware of the initial predicates and options. Predicates 4 | that are not in the initial predicates are excluded from the ground 5 | truth NSRTs. If an NSRT's option is not included, that NSRT will not be 6 | generated at all. 7 | """ 8 | 9 | from typing import List, Set 10 | 11 | from gym.spaces import Box 12 | 13 | from predicators import utils 14 | from predicators.approaches.bilevel_planning_approach import \ 15 | BilevelPlanningApproach 16 | from predicators.envs import get_or_create_env 17 | from predicators.ground_truth_nsrts import get_gt_nsrts 18 | from predicators.settings import CFG 19 | from predicators.structs import NSRT, ParameterizedOption, Predicate, Task, \ 20 | Type 21 | 22 | 23 | class OracleApproach(BilevelPlanningApproach): 24 | """A bilevel planning approach that uses hand-specified NSRTs.""" 25 | 26 | def __init__(self, 27 | initial_predicates: Set[Predicate], 28 | initial_options: Set[ParameterizedOption], 29 | types: Set[Type], 30 | action_space: Box, 31 | train_tasks: List[Task], 32 | task_planning_heuristic: str = "default", 33 | max_skeletons_optimized: int = -1) -> None: 34 | super().__init__(initial_predicates, initial_options, types, 35 | action_space, train_tasks, task_planning_heuristic, 36 | max_skeletons_optimized) 37 | self._nsrts = 
get_gt_nsrts(CFG.env, self._initial_predicates, 38 | self._initial_options) 39 | 40 | @classmethod 41 | def get_name(cls) -> str: 42 | return "oracle" 43 | 44 | @property 45 | def is_learning_based(self) -> bool: 46 | return False 47 | 48 | def _get_current_predicates(self) -> Set[Predicate]: 49 | # If the env is BEHAVIOR, the predicates might change from one 50 | # task to another, so we need to recompute them. 51 | if CFG.env == "behavior": # pragma: no cover 52 | env = get_or_create_env("behavior") 53 | self._initial_predicates, _ = \ 54 | utils.parse_config_excluded_predicates(env) 55 | return self._initial_predicates 56 | 57 | def _get_current_nsrts(self) -> Set[NSRT]: 58 | if CFG.env == "behavior": # pragma: no cover 59 | # If the env is BEHAVIOR the types and therefore 60 | # initial_predicates and initial_options could 61 | # have changed. 62 | env = get_or_create_env("behavior") 63 | preds = self._get_current_predicates() 64 | self._initial_options = env.options 65 | self._nsrts = get_gt_nsrts(env.get_name(), preds, 66 | self._initial_options) 67 | return self._nsrts 68 | -------------------------------------------------------------------------------- /predicators/approaches/pg4_approach.py: -------------------------------------------------------------------------------- 1 | """Policy-guided planning for generalized policy generation for planning 2 | guidance (PG4). 3 | 4 | PG4 requires known STRIPS operators. The command below uses oracle operators, 5 | but it is also possible to use this approach with operators learned from 6 | demonstrations. 
class PG4Approach(PG3Approach):
    """Policy-guided planning for generalized policy generation for planning
    guidance (PG4).

    Planning is delegated to BilevelPlanningApproach, with the current LDL
    policy of this approach passed along as the abstract policy.
    """

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this approach."""
        return "pg4"

    def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
        """Solve the task with BilevelPlanningApproach's implementation."""
        # This approach descends from BilevelPlanningApproach, but not as a
        # direct child, so super() cannot be used to reach that
        # implementation; it is called through the class instead.
        bilevel_solve = BilevelPlanningApproach._solve  # pylint: disable=protected-access
        return bilevel_solve(self, task, timeout)

    def _run_sesame_plan(
            self, task: Task, nsrts: Set[NSRT], preds: Set[Predicate],
            timeout: float, seed: int,
            **kwargs: Any) -> Tuple[List[_Option], Metrics, List[State]]:
        """Run the parent's planner with policy guidance switched on.

        The abstract policy queries the LDL stored on this approach, and
        the rollout limit is taken from CFG.pg3_max_policy_guided_rollout.
        """

        def _query_current_ldl(a, o, g):  # type: ignore
            # self._current_ldl is looked up on every call, not captured.
            return utils.query_ldl(self._current_ldl, a, o, g)

        guidance = {
            "abstract_policy": _query_current_ldl,
            "max_policy_guided_rollout": CFG.pg3_max_policy_guided_rollout,
        }
        return super()._run_sesame_plan(task, nsrts, preds, timeout, seed,
                                        **guidance, **kwargs)
class RandomOptionsApproach(BaseApproach):
    """An approach whose policy comes from
    utils.create_random_option_policy(), built over the initial options."""

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this approach."""
        return "random_options"

    @property
    def is_learning_based(self) -> bool:
        """This approach does not learn anything."""
        return False

    def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
        """Build the random option policy; task and timeout are not used."""

        def _fail_on_fallback(state: State) -> Action:
            # Passed as the fallback policy: it never produces an action and
            # always signals failure instead.
            del state  # unused
            raise ApproachFailure("Random option sampling failed!")

        random_policy = utils.create_random_option_policy(
            self._initial_options, self._rng, _fail_on_fallback)
        return random_policy
class RefinementEstimationApproach(OracleApproach):
    """A bilevel planning approach that uses a refinement cost estimator.

    The estimator named by CFG.refinement_estimator is created once at
    construction time and handed to the parent's planner on every call.
    """

    def __init__(self,
                 initial_predicates: Set[Predicate],
                 initial_options: Set[ParameterizedOption],
                 types: Set[Type],
                 action_space: Box,
                 train_tasks: List[Task],
                 task_planning_heuristic: str = "default",
                 max_skeletons_optimized: int = -1) -> None:
        """Initialize the parent approach, check the skeleton settings, and
        create the refinement estimator.

        All arguments are forwarded unchanged to OracleApproach.__init__().
        Raises an AssertionError if
        CFG.refinement_estimation_num_skeletons_generated exceeds
        CFG.sesame_max_skeletons_optimized.
        """
        super().__init__(initial_predicates, initial_options, types,
                         action_space, train_tasks, task_planning_heuristic,
                         max_skeletons_optimized)
        # The two literals below are concatenated, so the first one must end
        # with a space to keep the words of the message apart.
        assert (CFG.refinement_estimation_num_skeletons_generated <=
                CFG.sesame_max_skeletons_optimized), \
            "refinement_estimation_num_skeletons_generated should not be " \
            "greater than sesame_max_skeletons_optimized"
        self._refinement_estimator = create_refinement_estimator(
            CFG.refinement_estimator)

    @classmethod
    def get_name(cls) -> str:
        """Return the name of this approach."""
        return "refinement_estimation"

    def _run_sesame_plan(
            self, task: Task, nsrts: Set[NSRT], preds: Set[Predicate],
            timeout: float, seed: int,
            **kwargs: Any) -> Tuple[List[_Option], Metrics, List[State]]:
        """Generates a plan choosing the best skeletons based on a given
        refinement cost estimator.

        Delegates to the parent implementation, adding this approach's
        estimator as the refinement_estimator keyword argument.
        """
        result = super()._run_sesame_plan(
            task,
            nsrts,
            preds,
            timeout,
            seed,
            refinement_estimator=self._refinement_estimator,
            **kwargs)
        return result
def create_arg_parser(env_required: bool = True,
                      approach_required: bool = True,
                      seed_required: bool = True) -> argparse.ArgumentParser:
    """Build the parser for the command line arguments that vary per run.

    The three keyword arguments control whether --env, --approach and
    --seed are mandatory; every other argument is optional and has a
    default.
    """
    switch = {"action": "store_true"}

    def _text(default: str) -> dict:
        # Keyword arguments for an optional string-valued argument.
        return {"default": default, "type": str}

    # Arguments are registered in exactly the order listed here.
    arg_specs = [
        ("--env", {"required": env_required, "type": str}),
        ("--approach", {"required": approach_required, "type": str}),
        ("--excluded_predicates", _text("")),
        ("--included_options", _text("")),
        ("--seed", {"required": seed_required, "type": int}),
        ("--option_learner", {"type": str, "default": "no_learning"}),
        ("--explorer", {"type": str, "default": "no_explore"}),
        # NOTE: this timeout affects both data generation and evaluation.
        # If you want to change only the data generation timeout,
        # modify offline_data_planning_timeout.
        ("--timeout", {"default": 10, "type": float}),
        ("--make_test_videos", switch),
        ("--make_failure_videos", switch),
        ("--make_interaction_videos", switch),
        ("--make_demo_videos", switch),
        ("--load_approach", switch),
        ("--load_data", switch),
        ("--load_atoms", switch),
        ("--skip_until_cycle", {"default": -1, "type": int}),
        ("--experiment_id", _text("")),
        ("--load_experiment_id", _text("")),
        ("--log_file", _text("")),
        ("--use_gui", switch),
        # --debug stores its value under "loglevel": DEBUG when the flag is
        # given, INFO otherwise.
        ('--debug', {
            "action": "store_const",
            "dest": "loglevel",
            "const": logging.DEBUG,
            "default": logging.INFO,
        }),
        ("--crash_on_failure", switch),
    ]

    parser = argparse.ArgumentParser()
    for option_string, options in arg_specs:
        parser.add_argument(option_string, **options)
    return parser
def create_dataset(env: BaseEnv, train_tasks: List[Task],
                   known_options: Set[ParameterizedOption]) -> Dataset:
    """Create the offline training dataset selected by
    CFG.offline_data_method for the given environment and training tasks.

    Some or all of this data may be loaded from disk. Raises
    NotImplementedError when the configured method is not recognized.
    """
    method = CFG.offline_data_method
    if method == "demo":
        dataset = create_demo_data(env, train_tasks, known_options)
    elif method == "demo+replay":
        dataset = create_demo_replay_data(env, train_tasks, known_options)
    elif method == "demo+ground_atoms":
        # The demonstrations are created first; the ground atom data is then
        # built on top of them for the excluded predicates.
        demos = create_demo_data(env, train_tasks, known_options)
        excluded_preds = utils.parse_config_excluded_predicates(env)[1]
        num_examples = int(CFG.teacher_dataset_num_examples)
        assert num_examples >= 1, \
            "Must have at least 1 example of each predicate"
        dataset = create_ground_atom_data(env, demos, excluded_preds,
                                          num_examples)
    elif method == "empty":
        dataset = Dataset([])
    else:
        raise NotImplementedError("Unrecognized dataset method.")
    return dataset
def get_or_create_env(name: str) -> BaseEnv:
    """Return the most recently cached env instance with the given name,
    creating and caching one through create_new_env() on a cache miss.

    If you use this function, you should NOT be doing anything that
    relies on the environment's internal state (i.e., you should not
    call reset() or step()).

    Also note that the GUI is always turned off for environments that are
    newly created by this function. If you want to use the GUI, you should
    create the environment explicitly through create_new_env().
    """
    if name in _MOST_RECENT_ENV_INSTANCE:
        return _MOST_RECENT_ENV_INSTANCE[name]
    cache_miss_message = (
        "WARNING: you called get_or_create_env, but I couldn't "
        f"find {name} in the cache. Making a new instance.")
    logging.warning(cache_miss_message)
    # create_new_env() stores the new instance in the cache under `name`.
    create_new_env(name, do_cache=True, use_gui=False)
    return _MOST_RECENT_ENV_INSTANCE[name]
(ON D C) (ON C B) (ON B A))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task10.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-7-0) 2 | (:domain BLOCKS) 3 | (:objects C F A B G D E - block) 4 | (:INIT (CLEAR E) (ONTABLE D) (ON E G) (ON G B) (ON B A) (ON A F) (ON F C) 5 | (ON C D) (HANDEMPTY)) 6 | (:goal (AND (ON A G) (ON G D) (ON D B) (ON B C) (ON C F) (ON F E))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task11.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-7-1) 2 | (:domain BLOCKS) 3 | (:objects E B D F G C A - block) 4 | (:INIT (CLEAR A) (CLEAR C) (ONTABLE G) (ONTABLE F) (ON A G) (ON C D) (ON D B) 5 | (ON B E) (ON E F) (HANDEMPTY)) 6 | (:goal (AND (ON A E) (ON E B) (ON B F) (ON F G) (ON G C) (ON C D))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task12.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-7-2) 2 | (:domain BLOCKS) 3 | (:objects E G C D F A B - block) 4 | (:INIT (CLEAR B) (CLEAR A) (ONTABLE F) (ONTABLE D) (ON B C) (ON C G) (ON G E) 5 | (ON E F) (ON A D) (HANDEMPTY)) 6 | (:goal (AND (ON E B) (ON B F) (ON F D) (ON D A) (ON A C) (ON C G))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task13.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-8-0) 2 | (:domain BLOCKS) 3 | (:objects H G F E C B D A - block) 4 | (:INIT (CLEAR A) (CLEAR D) (CLEAR B) (CLEAR C) (ONTABLE E) (ONTABLE F) 5 | (ONTABLE B) (ONTABLE C) (ON A G) (ON G E) (ON D H) (ON H F) (HANDEMPTY)) 6 | (:goal (AND (ON D F) 
(ON F E) (ON E H) (ON H C) (ON C A) (ON A G) (ON G B))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task14.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-8-1) 2 | (:domain BLOCKS) 3 | (:objects B A G C F D H E - block) 4 | (:INIT (CLEAR E) (CLEAR H) (CLEAR D) (CLEAR F) (ONTABLE C) (ONTABLE G) 5 | (ONTABLE D) (ONTABLE F) (ON E C) (ON H A) (ON A B) (ON B G) (HANDEMPTY)) 6 | (:goal (AND (ON C D) (ON D B) (ON B G) (ON G F) (ON F H) (ON H A) (ON A E))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task15.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-8-2) 2 | (:domain BLOCKS) 3 | (:objects F B G C H E A D - block) 4 | (:INIT (CLEAR D) (CLEAR A) (CLEAR E) (CLEAR H) (CLEAR C) (ONTABLE G) 5 | (ONTABLE A) (ONTABLE E) (ONTABLE H) (ONTABLE C) (ON D B) (ON B F) (ON F G) 6 | (HANDEMPTY)) 7 | (:goal (AND (ON C B) (ON B E) (ON E G) (ON G F) (ON F A) (ON A D) (ON D H))) 8 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task16.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-9-0) 2 | (:domain BLOCKS) 3 | (:objects H D I A E G B F C - block) 4 | (:INIT (CLEAR C) (CLEAR F) (ONTABLE C) (ONTABLE B) (ON F G) (ON G E) (ON E A) 5 | (ON A I) (ON I D) (ON D H) (ON H B) (HANDEMPTY)) 6 | (:goal (AND (ON G D) (ON D B) (ON B C) (ON C A) (ON A I) (ON I F) (ON F E) 7 | (ON E H))) 8 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task17.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-9-1) 2 | (:domain BLOCKS) 3 | (:objects H G I C D 
B E A F - block) 4 | (:INIT (CLEAR F) (ONTABLE A) (ON F E) (ON E B) (ON B D) (ON D C) (ON C I) 5 | (ON I G) (ON G H) (ON H A) (HANDEMPTY)) 6 | (:goal (AND (ON D I) (ON I A) (ON A B) (ON B H) (ON H G) (ON G F) (ON F E) 7 | (ON E C))) 8 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task18.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-9-2) 2 | (:domain BLOCKS) 3 | (:objects B I C E D A G F H - block) 4 | (:INIT (CLEAR H) (CLEAR F) (ONTABLE G) (ONTABLE F) (ON H A) (ON A D) (ON D E) 5 | (ON E C) (ON C I) (ON I B) (ON B G) (HANDEMPTY)) 6 | (:goal (AND (ON F G) (ON G H) (ON H D) (ON D I) (ON I E) (ON E B) (ON B C) 7 | (ON C A))) 8 | ) 9 | -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task19.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-10-0) 2 | (:domain BLOCKS) 3 | (:objects D A H G B J E I F C - block) 4 | (:INIT (CLEAR C) (CLEAR F) (ONTABLE I) (ONTABLE F) (ON C E) (ON E J) (ON J B) 5 | (ON B G) (ON G H) (ON H A) (ON A D) (ON D I) (HANDEMPTY)) 6 | (:goal (AND (ON D C) (ON C F) (ON F J) (ON J E) (ON E H) (ON H B) (ON B A) 7 | (ON A G) (ON G I))) 8 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task2.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-4-1) 2 | (:domain BLOCKS) 3 | (:objects A C D B - block) 4 | (:INIT (CLEAR B) (ONTABLE D) (ON B C) (ON C A) (ON A D) (HANDEMPTY)) 5 | (:goal (AND (ON D C) (ON C A) (ON A B))) 6 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task20.pddl: -------------------------------------------------------------------------------- 1 | (define 
(problem BLOCKS-10-1) 2 | (:domain BLOCKS) 3 | (:objects D A J I E G H B F C - block) 4 | (:INIT (CLEAR C) (CLEAR F) (ONTABLE B) (ONTABLE H) (ON C G) (ON G E) (ON E I) 5 | (ON I J) (ON J A) (ON A B) (ON F D) (ON D H) (HANDEMPTY)) 6 | (:goal (AND (ON C B) (ON B D) (ON D F) (ON F I) (ON I A) (ON A E) (ON E H) 7 | (ON H G) (ON G J))) 8 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task21.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-10-2) 2 | (:domain BLOCKS) 3 | (:objects B G E D F H I A C J - block) 4 | (:INIT (CLEAR J) (CLEAR C) (ONTABLE A) (ONTABLE C) (ON J I) (ON I H) (ON H F) 5 | (ON F D) (ON D E) (ON E G) (ON G B) (ON B A) (HANDEMPTY)) 6 | (:goal (AND (ON B E) (ON E I) (ON I G) (ON G H) (ON H C) (ON C A) (ON A F) 7 | (ON F J) (ON J D))) 8 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task22.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-11-0) 2 | (:domain BLOCKS) 3 | (:objects F A K H G E D I C J B - block) 4 | (:INIT (CLEAR B) (CLEAR J) (CLEAR C) (ONTABLE I) (ONTABLE D) (ONTABLE E) 5 | (ON B G) (ON G H) (ON H K) (ON K A) (ON A F) (ON F I) (ON J D) (ON C E) 6 | (HANDEMPTY)) 7 | (:goal (AND (ON A J) (ON J D) (ON D B) (ON B H) (ON H K) (ON K I) (ON I F) 8 | (ON F E) (ON E G) (ON G C))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task23.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-11-1) 2 | (:domain BLOCKS) 3 | (:objects B C E A H K I G D F J - block) 4 | (:INIT (CLEAR J) (CLEAR F) (CLEAR D) (CLEAR G) (ONTABLE I) (ONTABLE K) 5 | (ONTABLE H) (ONTABLE A) (ON J I) (ON F E) (ON E K) (ON D C) (ON C H) (ON G B) 6 | (ON B A) (HANDEMPTY)) 7 
| (:goal (AND (ON B D) (ON D J) (ON J K) (ON K H) (ON H A) (ON A C) (ON C F) 8 | (ON F G) (ON G I) (ON I E))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task24.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-11-2) 2 | (:domain BLOCKS) 3 | (:objects E J D C F K H G A I B - block) 4 | (:INIT (CLEAR B) (CLEAR I) (ONTABLE A) (ONTABLE G) (ON B H) (ON H K) (ON K F) 5 | (ON F C) (ON C D) (ON D J) (ON J A) (ON I E) (ON E G) (HANDEMPTY)) 6 | (:goal (AND (ON I G) (ON G C) (ON C D) (ON D E) (ON E J) (ON J B) (ON B H) 7 | (ON H A) (ON A F) (ON F K))) 8 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task25.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-12-0) 2 | (:domain BLOCKS) 3 | (:objects I D B E K G A F C J L H - block) 4 | (:INIT (CLEAR H) (CLEAR L) (CLEAR J) (ONTABLE C) (ONTABLE F) (ONTABLE J) 5 | (ON H A) (ON A G) (ON G K) (ON K E) (ON E B) (ON B D) (ON D I) (ON I C) 6 | (ON L F) (HANDEMPTY)) 7 | (:goal (AND (ON I C) (ON C B) (ON B L) (ON L D) (ON D J) (ON J E) (ON E K) 8 | (ON K F) (ON F A) (ON A H) (ON H G))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task26.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-12-1) 2 | (:domain BLOCKS) 3 | (:objects E L A B F I H G D J K C - block) 4 | (:INIT (CLEAR C) (CLEAR K) (ONTABLE J) (ONTABLE D) (ON C G) (ON G H) (ON H I) 5 | (ON I F) (ON F B) (ON B A) (ON A L) (ON L E) (ON E J) (ON K D) (HANDEMPTY)) 6 | (:goal (AND (ON J C) (ON C E) (ON E K) (ON K H) (ON H A) (ON A F) (ON F L) 7 | (ON L G) (ON G B) (ON B I) (ON I D))) 8 | ) 
-------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task27.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-13-0) 2 | (:domain BLOCKS) 3 | (:objects L H E A J C D F G K M I B - block) 4 | (:INIT (CLEAR B) (CLEAR I) (CLEAR M) (ONTABLE K) (ONTABLE G) (ONTABLE M) 5 | (ON B F) (ON F D) (ON D C) (ON C J) (ON J A) (ON A E) (ON E H) (ON H L) 6 | (ON L K) (ON I G) (HANDEMPTY)) 7 | (:goal (AND (ON G I) (ON I C) (ON C D) (ON D F) (ON F A) (ON A M) (ON M H) 8 | (ON H E) (ON E L) (ON L J) (ON J B) (ON B K))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task28.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-13-1) 2 | (:domain BLOCKS) 3 | (:objects I M G H L A C D E K F B J - block) 4 | (:INIT (CLEAR J) (CLEAR B) (ONTABLE F) (ONTABLE K) (ON J E) (ON E D) (ON D C) 5 | (ON C A) (ON A L) (ON L H) (ON H G) (ON G M) (ON M I) (ON I F) (ON B K) 6 | (HANDEMPTY)) 7 | (:goal (AND (ON D A) (ON A E) (ON E L) (ON L M) (ON M C) (ON C J) (ON J F) 8 | (ON F K) (ON K G) (ON G H) (ON H I) (ON I B))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task29.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-14-0) 2 | (:domain BLOCKS) 3 | (:objects I D B L C K M H J N E F G A - block) 4 | (:INIT (CLEAR A) (CLEAR G) (CLEAR F) (ONTABLE E) (ONTABLE N) (ONTABLE F) 5 | (ON A J) (ON J H) (ON H M) (ON M K) (ON K C) (ON C L) (ON L B) (ON B E) 6 | (ON G D) (ON D I) (ON I N) (HANDEMPTY)) 7 | (:goal (AND (ON E L) (ON L F) (ON F B) (ON B J) (ON J I) (ON I N) (ON N C) 8 | (ON C K) (ON K G) (ON G D) (ON D M) (ON M A) (ON A H))) 9 | ) 
-------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task3.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-4-2) 2 | (:domain BLOCKS) 3 | (:objects B D C A - block) 4 | (:INIT (CLEAR A) (CLEAR C) (CLEAR D) (ONTABLE A) (ONTABLE B) (ONTABLE D) 5 | (ON C B) (HANDEMPTY)) 6 | (:goal (AND (ON A B) (ON B C) (ON C D))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task30.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-14-1) 2 | (:domain BLOCKS) 3 | (:objects K A F L D B M E J N H I C G - block) 4 | (:INIT (CLEAR G) (CLEAR C) (CLEAR I) (CLEAR H) (CLEAR N) (ONTABLE J) 5 | (ONTABLE E) (ONTABLE M) (ONTABLE B) (ONTABLE N) (ON G J) (ON C E) (ON I D) 6 | (ON D L) (ON L M) (ON H F) (ON F A) (ON A K) (ON K B) (HANDEMPTY)) 7 | (:goal (AND (ON J D) (ON D B) (ON B H) (ON H M) (ON M K) (ON K F) (ON F G) 8 | (ON G A) (ON A I) (ON I E) (ON E L) (ON L N) (ON N C))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task31.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-15-0) 2 | (:domain BLOCKS) 3 | (:objects A C L D J H K O N G I F B M E - block) 4 | (:INIT (CLEAR E) (CLEAR M) (CLEAR B) (CLEAR F) (CLEAR I) (ONTABLE G) 5 | (ONTABLE N) (ONTABLE O) (ONTABLE K) (ONTABLE H) (ON E J) (ON J D) (ON D L) 6 | (ON L C) (ON C G) (ON M N) (ON B A) (ON A O) (ON F K) (ON I H) (HANDEMPTY)) 7 | (:goal (AND (ON G O) (ON O H) (ON H K) (ON K M) (ON M F) (ON F E) (ON E A) 8 | (ON A B) (ON B L) (ON L J) (ON J D) (ON D N) (ON N I) (ON I C))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task32.pddl: 
-------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-15-1) 2 | (:domain BLOCKS) 3 | (:objects J B K A D H E N C F L M I O G - block) 4 | (:INIT (CLEAR G) (CLEAR O) (ONTABLE I) (ONTABLE M) (ON G L) (ON L F) (ON F C) 5 | (ON C N) (ON N E) (ON E H) (ON H D) (ON D A) (ON A K) (ON K B) (ON B J) 6 | (ON J I) (ON O M) (HANDEMPTY)) 7 | (:goal (AND (ON D G) (ON G F) (ON F K) (ON K J) (ON J E) (ON E M) (ON M A) 8 | (ON A B) (ON B C) (ON C N) (ON N O) (ON O I) (ON I L) (ON L H))) 9 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task33.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-16-1) 2 | (:domain BLOCKS) 3 | (:objects K C D B I N P J M L G E A O H F - block) 4 | (:INIT (CLEAR F) (CLEAR H) (CLEAR O) (ONTABLE A) (ONTABLE E) (ONTABLE G) 5 | (ON F L) (ON L M) (ON M J) (ON J P) (ON P N) (ON N I) (ON I B) (ON B D) 6 | (ON D C) (ON C K) (ON K A) (ON H E) (ON O G) (HANDEMPTY)) 7 | (:goal (AND (ON D B) (ON B P) (ON P F) (ON F G) (ON G K) (ON K I) (ON I L) 8 | (ON L J) (ON J H) (ON H A) (ON A N) (ON N E) (ON E M) (ON M C) 9 | (ON C O))) 10 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task34.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-16-2) 2 | (:domain BLOCKS) 3 | (:objects K I G N P A D M C B H F O J L E - block) 4 | (:INIT (CLEAR E) (CLEAR L) (ONTABLE J) (ONTABLE O) (ON E F) (ON F H) (ON H B) 5 | (ON B C) (ON C M) (ON M D) (ON D A) (ON A P) (ON P N) (ON N G) (ON G I) 6 | (ON I K) (ON K J) (ON L O) (HANDEMPTY)) 7 | (:goal (AND (ON I D) (ON D H) (ON H F) (ON F B) (ON B K) (ON K J) (ON J G) 8 | (ON G E) (ON E C) (ON C L) (ON L M) (ON M N) (ON N A) (ON A P) 9 | (ON P O))) 10 | ) 
-------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task35.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-17-0) 2 | (:domain BLOCKS) 3 | (:objects C D E F B I J A N O K M P H G L Q - block) 4 | (:INIT (CLEAR Q) (CLEAR L) (CLEAR G) (CLEAR H) (CLEAR P) (ONTABLE M) 5 | (ONTABLE K) (ONTABLE O) (ONTABLE N) (ONTABLE P) (ON Q A) (ON A J) (ON J I) 6 | (ON I B) (ON B M) (ON L F) (ON F E) (ON E K) (ON G D) (ON D C) (ON C O) 7 | (ON H N) (HANDEMPTY)) 8 | (:goal (AND (ON Q N) (ON N L) (ON L O) (ON O J) (ON J H) (ON H C) (ON C E) 9 | (ON E M) (ON M P) (ON P A) (ON A G) (ON G B) (ON B I) (ON I K) 10 | (ON K F) (ON F D))) 11 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task4.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-5-0) 2 | (:domain BLOCKS) 3 | (:objects B E A C D - block) 4 | (:INIT (CLEAR D) (CLEAR C) (ONTABLE D) (ONTABLE A) (ON C E) (ON E B) (ON B A) 5 | (HANDEMPTY)) 6 | (:goal (AND (ON A E) (ON E B) (ON B D) (ON D C))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task5.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-5-1) 2 | (:domain BLOCKS) 3 | (:objects A D C E B - block) 4 | (:INIT (CLEAR B) (CLEAR E) (CLEAR C) (ONTABLE D) (ONTABLE E) (ONTABLE C) 5 | (ON B A) (ON A D) (HANDEMPTY)) 6 | (:goal (AND (ON D C) (ON C B) (ON B A) (ON A E))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task6.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-5-2) 2 | (:domain BLOCKS) 3 | (:objects A C E B D - block) 4 | (:INIT (CLEAR 
D) (ONTABLE B) (ON D E) (ON E C) (ON C A) (ON A B) (HANDEMPTY)) 5 | (:goal (AND (ON D C) (ON C B) (ON B E) (ON E A))) 6 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task7.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-6-0) 2 | (:domain BLOCKS) 3 | (:objects E A B C F D - block) 4 | (:INIT (CLEAR D) (CLEAR F) (ONTABLE C) (ONTABLE B) (ON D A) (ON A C) (ON F E) 5 | (ON E B) (HANDEMPTY)) 6 | (:goal (AND (ON C B) (ON B A) (ON A E) (ON E F) (ON F D))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task8.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-6-1) 2 | (:domain BLOCKS) 3 | (:objects F D C E B A - block) 4 | (:INIT (CLEAR A) (CLEAR B) (CLEAR E) (CLEAR C) (CLEAR D) (ONTABLE F) 5 | (ONTABLE B) (ONTABLE E) (ONTABLE C) (ONTABLE D) (ON A F) (HANDEMPTY)) 6 | (:goal (AND (ON E F) (ON F C) (ON C B) (ON B A) (ON A D))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/blocks/task9.pddl: -------------------------------------------------------------------------------- 1 | (define (problem BLOCKS-6-2) 2 | (:domain BLOCKS) 3 | (:objects E F B D C A - block) 4 | (:INIT (CLEAR A) (ONTABLE C) (ON A D) (ON D B) (ON B F) (ON F E) (ON E C) 5 | (HANDEMPTY)) 6 | (:goal (AND (ON E F) (ON F A) (ON A B) (ON B C) (ON C D))) 7 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/delivery/domain.pddl: -------------------------------------------------------------------------------- 1 | (define (domain delivery) 2 | (:requirements :strips :typing) 3 | (:types loc paper) 4 | (:predicates 5 | (at ?loc - loc) 6 | (isHomeBase ?loc - loc) 7 | (satisfied ?loc - loc) 8 | (wantsPaper ?loc - 
loc) 9 | (safe ?loc - loc) 10 | (unpacked ?paper - paper) 11 | (carrying ?paper - paper) 12 | ) 13 | 14 | (:action pick-up 15 | :parameters (?paper - paper ?loc - loc) 16 | :precondition (and 17 | (at ?loc) 18 | (isHomeBase ?loc) 19 | (unpacked ?paper) 20 | ) 21 | :effect (and 22 | (not (unpacked ?paper)) 23 | (carrying ?paper) 24 | ) 25 | ) 26 | 27 | (:action move 28 | :parameters (?from - loc ?to - loc) 29 | :precondition (and 30 | (at ?from) 31 | (safe ?from) 32 | ) 33 | :effect (and 34 | (not (at ?from)) 35 | (at ?to) 36 | ) 37 | ) 38 | 39 | (:action deliver 40 | :parameters (?paper - paper ?loc - loc) 41 | :precondition (and 42 | (at ?loc) 43 | (carrying ?paper) 44 | ) 45 | :effect (and 46 | (not (carrying ?paper)) 47 | (not (wantsPaper ?loc)) 48 | (satisfied ?loc) 49 | ) 50 | ) 51 | 52 | ) 53 | -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/ferry/domain.pddl: -------------------------------------------------------------------------------- 1 | (define (domain ferry) 2 | (:predicates (not-eq ?x ?y) 3 | (car ?c) 4 | (location ?l) 5 | (at-ferry ?l) 6 | (at ?c ?l) 7 | (empty-ferry) 8 | (on ?c)) 9 | 10 | (:action sail 11 | :parameters (?from ?to) 12 | :precondition (and (not-eq ?from ?to) 13 | (location ?from) (location ?to) (at-ferry ?from)) 14 | :effect (and (at-ferry ?to) 15 | (not (at-ferry ?from)))) 16 | 17 | 18 | (:action board 19 | :parameters (?car ?loc) 20 | :precondition (and (car ?car) (location ?loc) 21 | (at ?car ?loc) (at-ferry ?loc) (empty-ferry)) 22 | :effect (and (on ?car) 23 | (not (at ?car ?loc)) 24 | (not (empty-ferry)))) 25 | 26 | (:action debark 27 | :parameters (?car ?loc) 28 | :precondition (and (car ?car) (location ?loc) 29 | (on ?car) (at-ferry ?loc)) 30 | :effect (and (at ?car ?loc) 31 | (empty-ferry) 32 | (not (on ?car))))) 33 | -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/forest/domain.pddl: 
-------------------------------------------------------------------------------- 1 | (define (domain forest) 2 | (:requirements :strips :typing) 3 | (:types loc) 4 | 5 | (:predicates 6 | (at ?loc - loc) 7 | (isNotWater ?loc - loc) 8 | (isHill ?loc - loc) 9 | (isNotHill ?loc - loc) 10 | (adjacent ?loc1 - loc ?loc2 - loc) 11 | (onTrail ?from - loc ?to - loc) 12 | ) 13 | 14 | (:action walk 15 | :parameters (?from - loc ?to - loc) 16 | :precondition (and 17 | (isNotHill ?to) 18 | (at ?from) 19 | (adjacent ?from ?to) 20 | (isNotWater ?from)) 21 | :effect (and (at ?to) (not (at ?from))) 22 | ) 23 | 24 | (:action climb 25 | :parameters (?from - loc ?to - loc) 26 | :precondition (and 27 | (isHill ?to) 28 | (at ?from) 29 | (adjacent ?from ?to) 30 | (isNotWater ?from)) 31 | :effect (and (at ?to) (not (at ?from))) 32 | ) 33 | 34 | 35 | ) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/gripper/domain.pddl: -------------------------------------------------------------------------------- 1 | 2 | (define (domain gripper) 3 | (:predicates (room ?r) 4 | (ball ?b) 5 | (gripper ?g) 6 | (at-robby ?r) 7 | (at ?b ?r) 8 | (free ?g) 9 | (carry ?o ?g)) 10 | 11 | (:action move 12 | :parameters (?from ?to) 13 | :precondition (and (room ?from) (room ?to) (at-robby ?from)) 14 | :effect (and (at-robby ?to) 15 | (not (at-robby ?from)))) 16 | 17 | 18 | 19 | (:action pick 20 | :parameters (?obj ?room ?gripper) 21 | :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) 22 | (at ?obj ?room) (at-robby ?room) (free ?gripper)) 23 | :effect (and (carry ?obj ?gripper) 24 | (not (at ?obj ?room)) 25 | (not (free ?gripper)))) 26 | 27 | 28 | (:action drop 29 | :parameters (?obj ?room ?gripper) 30 | :precondition (and (ball ?obj) (room ?room) (gripper ?gripper) 31 | (carry ?obj ?gripper) (at-robby ?room)) 32 | :effect (and (at ?obj ?room) 33 | (free ?gripper) 34 | (not (carry ?obj ?gripper))))) 
-------------------------------------------------------------------------------- /predicators/envs/assets/pddl/gripper/prefixed_domain.pddl: -------------------------------------------------------------------------------- 1 | (define (domain pregripper) 2 | (:predicates (preroom ?r) 3 | (preball ?b) 4 | (pregripper ?g) 5 | (preat-robby ?r) 6 | (preat ?b ?r) 7 | (prefree ?g) 8 | (precarry ?o ?g)) 9 | 10 | (:action move 11 | :parameters (?from ?to) 12 | :precondition (and (preroom ?from) (preroom ?to) (preat-robby ?from)) 13 | :effect (and (preat-robby ?to) 14 | (not (preat-robby ?from)))) 15 | 16 | 17 | 18 | (:action pick 19 | :parameters (?obj ?preroom ?pregripper) 20 | :precondition (and (preball ?obj) (preroom ?preroom) (pregripper ?pregripper) 21 | (preat ?obj ?preroom) (preat-robby ?preroom) (prefree ?pregripper)) 22 | :effect (and (precarry ?obj ?pregripper) 23 | (not (preat ?obj ?preroom)) 24 | (not (prefree ?pregripper)))) 25 | 26 | 27 | (:action drop 28 | :parameters (?obj ?preroom ?pregripper) 29 | :precondition (and (preball ?obj) (preroom ?preroom) (pregripper ?pregripper) 30 | (precarry ?obj ?pregripper) (preat-robby ?preroom)) 31 | :effect (and (preat ?obj ?preroom) 32 | (prefree ?pregripper) 33 | (not (precarry ?obj ?pregripper))))) -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/miconic/domain.pddl: -------------------------------------------------------------------------------- 1 | (define (domain miconic) 2 | (:requirements :strips) 3 | (:types passenger - object 4 | floor - object 5 | ) 6 | 7 | (:predicates 8 | (origin ?person - passenger ?floor - floor) 9 | (destin ?person - passenger ?floor - floor) 10 | (above ?floor1 - floor ?floor2 - floor) 11 | (boarded ?person - passenger) 12 | (served ?person - passenger) 13 | (lift-at ?floor - floor) 14 | ) 15 | 16 | (:action board 17 | :parameters (?f - floor ?p - passenger) 18 | :precondition (and (lift-at ?f) 19 | (origin ?p ?f) 20 | ) 21 | 
:effect (and (boarded ?p) 22 | ) 23 | ) 24 | 25 | (:action depart 26 | :parameters (?f - floor ?p - passenger) 27 | :precondition (and (lift-at ?f) 28 | (destin ?p ?f) 29 | (boarded ?p) 30 | ) 31 | :effect (and (not (boarded ?p)) 32 | (served ?p) 33 | ) 34 | ) 35 | 36 | (:action up 37 | :parameters (?f1 - floor ?f2 - floor) 38 | :precondition (and (lift-at ?f1) 39 | (above ?f1 ?f2) 40 | ) 41 | :effect (and (lift-at ?f2) 42 | (not (lift-at ?f1)) 43 | ) 44 | ) 45 | 46 | (:action down 47 | :parameters (?f1 - floor ?f2 - floor) 48 | :precondition (and (lift-at ?f1) 49 | (above ?f2 ?f1) 50 | ) 51 | :effect (and (lift-at ?f2) 52 | (not (lift-at ?f1)) 53 | ) 54 | ) 55 | ) 56 | -------------------------------------------------------------------------------- /predicators/envs/assets/pddl/spannerlearning/domain.pddl: -------------------------------------------------------------------------------- 1 | (define (domain spanner) 2 | (:requirements :typing :strips) 3 | (:types 4 | location locatable - object 5 | man nut spanner - locatable 6 | ) 7 | 8 | (:predicates 9 | (at ?m - locatable ?l - location) 10 | (carrying ?m - man ?s - spanner) 11 | (useable ?s - spanner) 12 | (link ?l1 - location ?l2 - location) 13 | (tightened ?n - nut) 14 | (loose ?n - nut)) 15 | 16 | (:action walk 17 | :parameters (?start - location ?end - location ?m - man) 18 | :precondition (and (at ?m ?start) 19 | (link ?start ?end)) 20 | :effect (and (not (at ?m ?start)) (at ?m ?end))) 21 | 22 | (:action pickup_spanner 23 | :parameters (?l - location ?s - spanner ?m - man) 24 | :precondition (and (at ?m ?l) 25 | (at ?s ?l)) 26 | :effect (and (not (at ?s ?l)) 27 | (carrying ?m ?s))) 28 | 29 | (:action tighten_nut 30 | :parameters (?l - location ?s - spanner ?m - man ?n - nut) 31 | :precondition (and (at ?m ?l) 32 | (at ?n ?l) 33 | (carrying ?m ?s) 34 | (useable ?s) 35 | (loose ?n)) 36 | :effect (and (not (loose ?n))(not (useable ?s)) (tightened ?n))) 37 | ) 38 | 
-------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/base_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/base_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/base_link_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/base_link_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/bellows_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/bellows_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/bellows_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/bellows_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/elbow_flex_link_collision.STL: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/elbow_flex_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/elbow_flex_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/elbow_flex_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/estop_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/estop_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/forearm_roll_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/forearm_roll_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/forearm_roll_uv.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/forearm_roll_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/gripper_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/gripper_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/gripper_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/gripper_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/head_pan_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/head_pan_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/head_pan_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/head_pan_uv.png 
-------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/head_tilt_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/head_tilt_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/head_tilt_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/head_tilt_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/l_gripper_finger_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/l_gripper_finger_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/l_wheel_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/l_wheel_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/l_wheel_link_collision.STL: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/l_wheel_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/laser_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/laser_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/r_gripper_finger_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/r_gripper_finger_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/r_wheel_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/r_wheel_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/r_wheel_link_collision.STL: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/r_wheel_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/shoulder_lift_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/shoulder_lift_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/shoulder_lift_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/shoulder_lift_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/shoulder_pan_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/shoulder_pan_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/shoulder_pan_uv.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/shoulder_pan_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/torso_fixed_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/torso_fixed_link.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/torso_fixed_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/torso_fixed_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/torso_lift_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/torso_lift_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/torso_lift_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/torso_lift_uv.png 
-------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/upperarm_roll_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/upperarm_roll_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/upperarm_roll_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/upperarm_roll_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/wrist_flex_link_collision.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/wrist_flex_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/wrist_flex_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/wrist_flex_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/wrist_roll_link_collision.STL: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/wrist_roll_link_collision.STL -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/fetch_description/meshes/wrist_roll_uv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/fetch_description/meshes/wrist_roll_uv.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(franka_description) 3 | 4 | find_package(catkin REQUIRED) 5 | catkin_package(CATKIN_DEPENDS xacro) 6 | 7 | install(DIRECTORY meshes 8 | DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} 9 | ) 10 | install(DIRECTORY robots 11 | DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} 12 | ) 13 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/mainpage.dox: -------------------------------------------------------------------------------- 1 | /** 2 | * @mainpage 3 | * @htmlinclude "manifest.html" 4 | * 5 | * Overview page for Franka Emika research robots: https://frankaemika.github.io 6 | */ 7 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/finger.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/finger.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/hand.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link0.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link0.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link1.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link2.stl 
-------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link3.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link4.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link5.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link6.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/meshes/collision/link7.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/franka_description/meshes/collision/link7.stl -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | franka_description 4 | 0.7.0 5 | franka_description contains URDF files and meshes of Franka Emika robots 6 | Franka Emika GmbH 7 | Apache 2.0 8 | 9 | http://wiki.ros.org/franka_description 10 | https://github.com/frankaemika/franka_ros 11 | https://github.com/frankaemika/franka_ros/issues 12 | Franka Emika GmbH 13 | 14 | catkin 15 | 16 | xacro 17 | 18 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/robots/hand.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/robots/hand.urdf.xacro: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/robots/hand.xacro: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 
17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/robots/panda_arm.urdf.xacro: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/robots/panda_arm_hand.urdf.xacro: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/franka_description/rosdoc.yaml: -------------------------------------------------------------------------------- 1 | - builder: doxygen 2 | javadoc_autobrief: YES 3 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/plane.obj: -------------------------------------------------------------------------------- 1 | # Blender v2.66 (sub 1) OBJ File: '' 2 | # www.blender.org 3 | mtllib plane.mtl 4 | o Plane 5 | v 15.000000 -15.000000 0.000000 6 | v 15.000000 15.000000 0.000000 7 | v -15.000000 15.000000 0.000000 8 | v -15.000000 -15.000000 0.000000 9 | 10 | vt 15.000000 0.000000 11 | vt 15.000000 15.000000 12 | vt 0.000000 15.000000 13 | vt 0.000000 0.000000 14 | 15 | usemtl Material 16 | s off 17 | f 1/1 2/2 3/3 18 | f 1/1 3/3 4/4 19 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/plane.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 
21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/table.obj: -------------------------------------------------------------------------------- 1 | # table.obj 2 | # 3 | 4 | o table 5 | mtllib table.mtl 6 | 7 | v -0.500000 -0.500000 0.500000 8 | v 0.500000 -0.500000 0.500000 9 | v -0.500000 0.500000 0.500000 10 | v 0.500000 0.500000 0.500000 11 | v -0.500000 0.500000 -0.500000 12 | v 0.500000 0.500000 -0.500000 13 | v -0.500000 -0.500000 -0.500000 14 | v 0.500000 -0.500000 -0.500000 15 | 16 | vt 0.000000 0.000000 17 | vt 1.000000 0.000000 18 | vt 0.000000 1.000000 19 | vt 1.000000 1.000000 20 | 21 | vn 0.000000 0.000000 1.000000 22 | vn 0.000000 1.000000 0.000000 23 | vn 0.000000 0.000000 -1.000000 24 | vn 0.000000 -1.000000 0.000000 25 | vn 1.000000 0.000000 0.000000 26 | vn -1.000000 0.000000 0.000000 27 | 28 | g table 29 | usemtl table 30 | s 1 31 | f 1/1/1 2/2/1 3/3/1 32 | f 3/3/1 2/2/1 4/4/1 33 | s 2 34 | f 3/1/2 4/2/2 5/3/2 35 | f 5/3/2 4/2/2 6/4/2 36 | s 3 37 | f 5/4/3 6/3/3 7/2/3 38 | f 7/2/3 6/3/3 8/1/3 39 | s 4 40 | f 7/1/4 8/2/4 1/3/4 41 | f 1/3/4 8/2/4 2/4/4 42 | s 5 43 | f 2/1/5 8/2/5 4/3/5 44 | f 4/3/5 8/2/5 6/4/5 45 | s 6 46 | f 7/1/6 1/2/6 5/3/6 47 | f 5/3/6 1/2/6 3/4/6 48 | 49 | -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/envs/assets/urdf/table.png -------------------------------------------------------------------------------- /predicators/envs/assets/urdf/table.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 
class BaseExplorer(abc.ABC):
    """Abstract source of exploration strategies for training tasks.

    An explorer is rebuilt at the start of each interaction cycle, so it
    always holds the most recent predicates and options.
    """

    def __init__(self, predicates: Set[Predicate],
                 options: Set[ParameterizedOption], types: Set[Type],
                 action_space: Box, train_tasks: List[Task]) -> None:
        # Everything handed in is simply stored for use by subclasses.
        self._predicates = predicates
        self._options = options
        self._types = types
        self._action_space = action_space
        self._train_tasks = train_tasks
        # Initializes self._seed and self._rng from the global config seed.
        self._set_seed(CFG.seed)

    @classmethod
    @abc.abstractmethod
    def get_name(cls) -> str:
        """Return the unique string identifying this explorer."""
        raise NotImplementedError("Override me!")

    @abc.abstractmethod
    def get_exploration_strategy(
        self,
        train_task_idx: int,
        timeout: int,
    ) -> ExplorationStrategy:
        """Build an ExplorationStrategy for the train task at the given
        index.

        The strategy is a (policy, termination function) pair.
        """
        raise NotImplementedError("Override me!")

    def _set_seed(self, seed: int) -> None:
        """Store `seed` and recreate the random generator from it."""
        self._seed = seed
        self._rng = np.random.default_rng(seed)
class BilevelPlanningExplorer(BaseExplorer):
    """Abstract explorer that plans with NSRTs through sesame_plan().

    Only the shared `_solve` helper lives here. Subclasses decide when to
    invoke it and how to react when planning does not succeed.
    """

    def __init__(self, predicates: Set[Predicate],
                 options: Set[ParameterizedOption], types: Set[Type],
                 action_space: Box, train_tasks: List[Task], nsrts: Set[NSRT],
                 option_model: _OptionModelBase) -> None:
        super().__init__(predicates, options, types, action_space, train_tasks)
        # Number of _solve invocations so far; folded into the planning seed.
        self._num_calls = 0
        self._option_model = option_model
        self._nsrts = nsrts

    def _solve(self, task: Task, timeout: int) -> ExplorationStrategy:
        """Plan for `task` and return (plan-following policy, goal check).

        PlanningFailure and PlanningTimeout are deliberately not caught
        here; subclasses are responsible for handling them.
        """
        # A different seed per call keeps successive plans from being
        # identical.
        self._num_calls += 1
        call_seed = self._seed + self._num_calls
        plan, _, _ = sesame_plan(
            task,
            self._option_model,
            self._nsrts,
            self._predicates,
            self._types,
            timeout,
            call_seed,
            CFG.sesame_task_planning_heuristic,
            CFG.sesame_max_skeletons_optimized,
            max_horizon=CFG.horizon,
            allow_noops=CFG.sesame_allow_noops,
            use_visited_state_set=CFG.sesame_use_visited_state_set)
        # Exploration ends as soon as the task goal holds.
        return utils.option_plan_to_policy(plan), task.goal_holds
class NoExploreExplorer(BaseExplorer):
    """Explorer whose strategy ends before any action is taken."""

    @classmethod
    def get_name(cls) -> str:
        return "no_explore"

    def get_exploration_strategy(self, train_task_idx: int,
                                 timeout: int) -> ExplorationStrategy:
        """Return a strategy that terminates on the very first check."""

        def _unused_policy(_: State) -> Action:
            # Termination is immediate, so querying the policy is an error.
            raise RuntimeError("The policy for no-explore shouldn't be used.")

        def _always_terminate(_: State) -> bool:
            return True

        return _unused_policy, _always_terminate
19 | 20 | # Note that this fallback policy is different from the one in 21 | # RandomOptionsApproach because explorers should raise 22 | # RequestActPolicyFailure instead of ApproachFailure. 23 | def fallback_policy(state: State) -> Action: 24 | del state # unused 25 | raise utils.RequestActPolicyFailure( 26 | "Random option sampling failed!") 27 | 28 | policy = utils.create_random_option_policy(self._options, self._rng, 29 | fallback_policy) 30 | # Never terminate (until the interaction budget is exceeded). 31 | termination_function = lambda _: False 32 | return policy, termination_function 33 | -------------------------------------------------------------------------------- /predicators/gnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/gnn/__init__.py -------------------------------------------------------------------------------- /predicators/nsrt_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/predicators/nsrt_learning/__init__.py -------------------------------------------------------------------------------- /predicators/nsrt_learning/strips_learning/__init__.py: -------------------------------------------------------------------------------- 1 | """This directory contains algorithms for STRIPS operator learning.""" 2 | 3 | import importlib 4 | import pkgutil 5 | from typing import TYPE_CHECKING, List, Set 6 | 7 | from predicators import utils 8 | from predicators.nsrt_learning.strips_learning.base_strips_learner import \ 9 | BaseSTRIPSLearner 10 | from predicators.settings import CFG 11 | from predicators.structs import PNAD, LowLevelTrajectory, Predicate, Segment, \ 12 | Task 13 | 14 | 
class OracleSTRIPSLearner(BaseSTRIPSLearner):
    """STRIPS learner that takes its operators from the ground-truth NSRTs."""

    def _learn(self) -> List[PNAD]:
        """Build one PNAD per ground-truth NSRT and keep those with data."""
        env = get_or_create_env(CFG.env)
        gt_nsrts = get_gt_nsrts(env.get_name(), env.predicates, env.options)
        candidates: List[PNAD] = []
        for nsrt in gt_nsrts:
            # With option_learner == "no_learning" the NSRT's own option is
            # used; otherwise a dummy option spec stands in for it.
            if CFG.option_learner == "no_learning":
                spec = (nsrt.option, list(nsrt.option_vars))
            else:
                spec = (DummyOption.parent, [])
            empty_datastore: Datastore = []
            candidates.append(PNAD(nsrt.op, empty_datastore, spec))
        self._recompute_datastores_from_segments(candidates)
        # A PNAD can end up with no data when non-standard settings make an
        # operator unnecessary. For example, in painting with
        # `--painting_goal_receptacles box`, picking from the side is never
        # needed, so no demo data covers it. Such PNADs are dropped.
        kept: List[PNAD] = []
        for pnad in candidates:
            if pnad.datastore:
                kept.append(pnad)
            else:
                logging.warning(f"Discarding PNAD with no data: {pnad}")
        return kept

    @classmethod
    def get_name(cls) -> str:
        return "oracle"
class LinkState(NamedTuple):
    """Named view over the tuple returned by PyBullet's getLinkState.

    A NamedTuple is used because it still allows integer indexing, which
    keeps it close to the raw PyBullet return value.
    """
    # Used by `com_pose`.
    linkWorldPosition: Pose3D
    linkWorldOrientation: Quaternion
    localInertialFramePosition: Pose3D
    localInertialFrameOrientation: Quaternion
    # Used by `pose`.
    worldLinkFramePosition: Pose3D
    worldLinkFrameOrientation: Quaternion

    @property
    def com_pose(self) -> Pose:
        """Pose of the link's center of mass (COM)."""
        position = self.linkWorldPosition
        orientation = self.linkWorldOrientation
        return Pose(position, orientation)

    @property
    def pose(self) -> Pose:
        """World-frame pose of the link itself."""
        position = self.worldLinkFramePosition
        orientation = self.worldLinkFrameOrientation
        return Pose(position, orientation)
def get_link_pose(body: int, link: int, physics_client_id: int) -> Pose:
    """Return the pose of `link` on `body`.

    BASE_LINK is answered with get_pose() on the body itself; every other
    link goes through get_link_state().
    """
    if link != BASE_LINK:
        return get_link_state(body, link, physics_client_id).pose
    return get_pose(body, physics_client_id)
def create_single_arm_pybullet_robot(
    robot_name: str,
    physics_client_id: int,
    ee_home_pose: Pose3D = (1.35, 0.6, 0.7),
) -> SingleArmPyBulletRobot:
    """Construct the single-arm PyBullet robot registered as `robot_name`.

    The end-effector orientation is read from
    CFG.pybullet_robot_ee_orns for the current CFG.env, and the base pose
    from the static _ROBOT_TO_BASE_POSE table.

    Raises NotImplementedError when the name has no registered class or no
    end-effector orientation for the current environment.
    """
    ee_orns_for_env = CFG.pybullet_robot_ee_orns[CFG.env]
    recognized = (robot_name in _ROBOT_TO_CLS
                  and robot_name in ee_orns_for_env)
    if not recognized:
        raise NotImplementedError(f"Unrecognized robot name: {robot_name}.")
    assert robot_name in _ROBOT_TO_BASE_POSE, \
        f"Base pose not specified for robot {robot_name}."
    robot_cls = _ROBOT_TO_CLS[robot_name]
    return robot_cls(ee_home_pose,
                     ee_orns_for_env[robot_name],
                     physics_client_id,
                     base_pose=_ROBOT_TO_BASE_POSE[robot_name])
class PandaPyBulletRobot(SingleArmPyBulletRobot):
    """Franka Emika Panda arm, assumed to be mounted on a fixed base."""

    @classmethod
    def get_name(cls) -> str:
        return "panda"

    @classmethod
    def urdf_path(cls) -> str:
        relative_path = "urdf/franka_description/robots/panda_arm_hand.urdf"
        return utils.get_env_asset_path(relative_path)

    @property
    def end_effector_name(self) -> str:
        """Name of the tool joint.

        This joint is offset from the last arm joint so that it sits at
        the center point between the two fingertips (the tips only, not
        the whole fingers). It is not the same as the "panda_hand" joint,
        which marks the center of the gripper including part of its body.
        """
        return "tool_joint"

    @property
    def tool_link_name(self) -> str:
        return "tool_link"

    @property
    def left_finger_joint_name(self) -> str:
        return "panda_finger_joint1"

    @property
    def right_finger_joint_name(self) -> str:
        return "panda_finger_joint2"

    @property
    def open_fingers(self) -> float:
        return 0.04

    @property
    def closed_fingers(self) -> float:
        return 0.03

    @classmethod
    def ikfast_info(cls) -> Optional[IKFastInfo]:
        """Return the IKFast module description for the Panda arm."""
        info = IKFastInfo(
            module_dir="panda_arm",
            module_name="ikfast_panda_arm",
            base_link="panda_link0",
            ee_link="panda_link8",
            free_joints=["panda_joint7"],
        )
        return info
def create_refinement_estimator(name: str) -> BaseRefinementEstimator:
    """Instantiate the refinement cost estimator registered under `name`.

    Only concrete (non-abstract) subclasses of BaseRefinementEstimator are
    considered. Raises NotImplementedError if none of them reports `name`
    from get_name().
    """
    for candidate in utils.get_all_subclasses(BaseRefinementEstimator):
        if candidate.__abstractmethods__:
            continue
        if candidate.get_name() == name:
            return candidate()
    raise NotImplementedError(f"Unknown refinement cost estimator: {name}")
def narrow_passage_oracle_estimator(
    skeleton: List[_GroundNSRT],
    atoms_sequence: List[Set[GroundAtom]],
) -> float:
    """Hand-written refinement cost for the narrow_passage env.

    The cost is a sum over the skeleton: opening the door costs 1, and
    each MoveToTarget costs 1 if a MoveAndOpenDoor step came earlier in the
    skeleton, or 3 otherwise. Steps with any other name add nothing.
    """
    del atoms_sequence  # unused

    # Hard-coded estimates of the samples needed to refine each operator.
    open_door_cost = 1
    via_door_cost = 1
    via_passage_cost = 3

    total = 0
    door_is_open = False
    for step in skeleton:
        step_name = step.name
        if step_name == "MoveToTarget":
            total += via_door_cost if door_is_open else via_passage_cost
        elif step_name == "MoveAndOpenDoor":
            total += open_door_cost
            door_is_open = True
    return total
class Graph:
    """Undirected graph over a fixed collection of nodes."""

    def __init__(self, nodes):
        """Create a graph with the given nodes and no edges."""
        self.nodes = nodes
        self.neighbours = {u: set() for u in nodes}

    def connect(self, u, v):
        """Add an undirected edge between existing nodes `u` and `v`."""
        self.neighbours[u].add(v)
        self.neighbours[v].add(u)

    def connected_components(self):
        """Return the connected components as a sorted list of sorted lists.

        The traversal uses an explicit stack rather than recursion, so a
        long chain of nodes cannot exceed the interpreter's recursion
        limit.
        """
        unvisited = set(self.nodes)
        components = []
        while unvisited:
            start = unvisited.pop()
            component = []
            stack = [start]
            while stack:
                node = stack.pop()
                component.append(node)
                for neighbour in self.neighbours[node]:
                    # Nodes leave `unvisited` when pushed, so each node is
                    # pushed at most once.
                    if neighbour in unvisited:
                        unvisited.remove(neighbour)
                        stack.append(neighbour)
            component.sort()
            components.append(component)
        return sorted(components)
35 | result = set(pairs) 36 | nodes = {u for (u, v) in pairs} | {v for (u, v) in pairs} 37 | for k in nodes: 38 | for i in nodes: 39 | for j in nodes: 40 | if (i, j) not in result and (i, k) in result and (k, 41 | j) in result: 42 | result.add((i, j)) 43 | return sorted(result) 44 | -------------------------------------------------------------------------------- /predicators/third_party/fast_downward_translator/pddl/__init__.py: -------------------------------------------------------------------------------- 1 | from .actions import Action, PropositionalAction 2 | from .axioms import Axiom, PropositionalAxiom 3 | from .conditions import Atom, Conjunction, Disjunction, ExistentialCondition, \ 4 | Falsity, Literal, NegatedAtom, Truth, UniversalCondition 5 | from .effects import ConditionalEffect, ConjunctiveEffect, CostEffect, \ 6 | Effect, SimpleEffect, UniversalEffect 7 | from .f_expression import Assign, Increase, NumericConstant, \ 8 | PrimitiveNumericExpression 9 | from .functions import Function 10 | from .pddl_types import Type, TypedObject 11 | from .predicates import Predicate 12 | from .tasks import Requirements, Task 13 | -------------------------------------------------------------------------------- /predicators/third_party/fast_downward_translator/pddl/functions.py: -------------------------------------------------------------------------------- 1 | class Function: 2 | 3 | def __init__(self, name, arguments, type_name): 4 | self.name = name 5 | self.arguments = arguments 6 | if type_name != "number": 7 | raise SystemExit("Error: object fluents not supported\n" + 8 | "(function %s has type %s)" % (name, type_name)) 9 | self.type_name = type_name 10 | 11 | def __str__(self): 12 | result = "%s(%s)" % (self.name, ", ".join(map(str, self.arguments))) 13 | if self.type_name: 14 | result += ": %s" % self.type_name 15 | return result 16 | -------------------------------------------------------------------------------- 
# Renamed from types.py to avoid clash with stdlib module.
# In the future, use explicitly relative imports or absolute
# imports as a better solution.

import itertools


def _get_type_predicate_name(type_name):
    """Return the internal predicate name that stands for a PDDL type."""
    # PDDL allows mixing types and predicates, but some PDDL files
    # have name collisions between types and predicates. We want to
    # support both the case where such name collisions occur and the
    # case where types are used as predicates.
    #
    # We internally give types predicate names that cannot be confused
    # with non-type predicates. When the input uses a PDDL type as a
    # predicate, we automatically map it to this internal name.
    return "type@%s" % type_name


class Type:
    """A PDDL type, identified by name, with an optional base type name."""

    def __init__(self, name, basetype_name=None):
        self.name = name
        self.basetype_name = basetype_name

    def __str__(self):
        return self.name

    def __repr__(self):
        return "Type(%s, %s)" % (self.name, self.basetype_name)

    def get_predicate_name(self):
        """Return the internal predicate name used for this type."""
        return _get_type_predicate_name(self.name)


class TypedObject:
    """A named object (or parameter) together with the name of its type.

    Equality and hashing are based on the ``(name, type_name)`` pair.
    """

    def __init__(self, name, type_name):
        self.name = name
        self.type_name = type_name

    def __hash__(self):
        return hash((self.name, self.type_name))

    def __eq__(self, other):
        return self.name == other.name and self.type_name == other.type_name

    def __ne__(self, other):
        return not self == other

    def __str__(self):
        return "%s: %s" % (self.name, self.type_name)

    def __repr__(self):
        # The format string must contain one placeholder per argument; an
        # empty template makes the % operator raise TypeError.
        return "<TypedObject %s: %s>" % (self.name, self.type_name)

    def uniquify_name(self, type_map, renamings):
        """Return an object whose name does not clash with *type_map*.

        If this object's name is not yet a key of *type_map*, it is
        registered there and ``self`` is returned. Otherwise the first free
        name of the form ``name + str(counter)`` (counter starting at 1) is
        registered, the renaming is recorded in *renamings*, and a new
        ``TypedObject`` with that name is returned. Both mappings are
        mutated in place.
        """
        if self.name not in type_map:
            type_map[self.name] = self.type_name
            return self
        for counter in itertools.count(1):
            new_name = self.name + str(counter)
            if new_name not in type_map:
                renamings[self.name] = new_name
                type_map[new_name] = self.type_name
                return TypedObject(new_name, self.type_name)

    def get_atom(self):
        """Return the unary atom stating that this object has its type."""
        # TODO: Resolve cyclic import differently.
        from . import conditions
        predicate_name = _get_type_predicate_name(self.type_name)
        return conditions.Atom(predicate_name, [self.name])
class Task:
    """A parsed PDDL task: domain-level and problem-level data together."""

    def __init__(self, domain_name, task_name, requirements, types, objects,
                 predicates, functions, init, goal, actions, axioms,
                 use_metric):
        # Identification.
        self.domain_name, self.task_name = domain_name, task_name
        self.requirements = requirements
        # Vocabulary.
        self.types, self.objects = types, objects
        self.predicates, self.functions = predicates, functions
        # Problem statement.
        self.init, self.goal = init, goal
        # Dynamics.
        self.actions, self.axioms = actions, axioms
        # Next free index for names generated by add_axiom().
        self.axiom_counter = 0
        self.use_min_cost_metric = use_metric

    def add_axiom(self, parameters, condition):
        """Create, register and return a fresh axiom.

        The axiom receives a generated name; a predicate with the same name
        and parameters is appended to ``self.predicates`` and the axiom to
        ``self.axioms``.
        """
        index = self.axiom_counter
        self.axiom_counter = index + 1
        label = "new-axiom@%d" % index
        new_axiom = axioms.Axiom(label, parameters, len(parameters),
                                 condition)
        self.predicates.append(predicates.Predicate(label, parameters))
        self.axioms.append(new_axiom)
        return new_axiom

    def dump(self):
        """Print a human-readable summary of the task to stdout."""
        header = (self.domain_name, self.task_name, self.requirements)
        print("Problem %s: %s [%s]" % header)
        # NOTE(review): the entry prefix below is reproduced as it appears
        # in the reviewed text; confirm its width against the file.
        listings = (
            ("Types:", self.types),
            ("Objects:", self.objects),
            ("Predicates:", self.predicates),
            ("Functions:", self.functions),
            ("Init:", self.init),
        )
        for title, entries in listings:
            print(title)
            for entry in entries:
                print(" %s" % entry)
        print("Goal:")
        self.goal.dump()
        print("Actions:")
        for action in self.actions:
            action.dump()
        if self.axioms:
            print("Axioms:")
            for axiom in self.axioms:
                axiom.dump()
__all__ = ["ParseError", "parse_nested_list"]


class ParseError(Exception):
    """Raised when PDDL (Lisp-style) text cannot be tokenized or parsed."""

    def __init__(self, value):
        # Initialise the base class so that ``args`` is populated.
        super().__init__(value)
        self.value = value

    def __str__(self):
        return self.value


# Basic functions for parsing PDDL (Lisp) files.
def parse_nested_list(input_file):
    """Parse one parenthesised expression into nested Python lists.

    *input_file* is the text to parse (a string, despite the name). Tokens
    are lower-cased and ``;`` comments are ignored.

    Raises:
        ParseError: if the text is empty, does not start with ``(``, has an
            unbalanced ``(``, contains tokens after the closing ``)``, or
            contains a non-ASCII character outside a comment.
    """
    tokens = tokenize(input_file)
    # A default avoids leaking a bare StopIteration on empty input.
    next_token = next(tokens, None)
    if next_token is None:
        raise ParseError("Expected '(', got end of input.")
    if next_token != "(":
        raise ParseError("Expected '(', got %s." % next_token)
    result = list(parse_list_aux(tokens))
    # Check that generator is exhausted.
    extra_token = next(tokens, None)
    if extra_token is not None:
        raise ParseError("Unexpected token: %s." % extra_token)
    return result


def tokenize(input):
    """Yield the lower-cased tokens of *input*, skipping ``;`` comments."""
    for line in input.split("\n"):
        line = line.split(";", 1)[0]  # Strip comments.
        try:
            line.encode("ascii")
        except UnicodeEncodeError:
            # Lines come from split("\n") and carry no trailing newline, so
            # report the whole line rather than dropping its last character.
            raise ParseError("Non-ASCII character outside comment: %s" %
                             line) from None
        line = line.replace("(", " ( ").replace(")", " ) ").replace("?", " ?")
        for token in line.split():
            yield token.lower()


def parse_list_aux(tokenstream):
    """Yield the elements of a list whose opening ``(`` was already read.

    Nested lists are yielded as Python lists; the matching ``)`` ends the
    generator. Raises ParseError if the stream ends first.
    """
    # Leading "(" has already been swallowed.
    while True:
        try:
            token = next(tokenstream)
        except StopIteration:
            raise ParseError("Missing ')'") from None
        if token == ")":
            return
        elif token == "(":
            yield list(parse_list_aux(tokenstream))
        else:
            yield token
def cartesian_product(sequences):
    """Yield every concatenation of one item drawn from each sequence.

    Not a proper cartesian product: the chosen items are joined with ``+``
    (so list items are concatenated) instead of being collected into a
    tuple of atomic elements. An empty *sequences* yields a single empty
    list.
    """
    # TODO: Rename this. It's not good that we have two functions
    # called "product" and "cartesian_product", of which "product"
    # computes cartesian products, while "cartesian_product" does not.
    #
    # Something like map(itertools.chain, product(*sequences)) looks
    # similar but does not produce the same results.
    if not sequences:
        yield []
        return
    # Materialise the combinations of the remaining sequences once, then
    # prepend each item of the first sequence to every one of them.
    suffixes = list(cartesian_product(sequences[1:]))
    yield from (prefix + suffix
                for prefix in sequences[0]
                for suffix in suffixes)
def compile_ikfast(module_name, cpp_filename, remove_build=False):
    """Build an IKFast C++ extension and verify that it can be imported.

    Runs ``setup()`` for a single extension built from *cpp_filename*,
    copies the first directory matching ``*build/lib*`` under the current
    working directory into the working directory, optionally deletes the
    ``build`` directory, and finally imports *module_name*.

    Returns True on success; re-raises the ImportError if the freshly
    built module cannot be imported.
    """
    extension = Extension(module_name, sources=[cpp_filename])
    setup(name=module_name,
          version='1.0',
          description="ikfast module {}".format(module_name),
          ext_modules=[extension])

    working_dir = os.getcwd()
    lib_pattern = os.path.join(working_dir, "*build", "lib*")
    # First directory in the walk whose path matches the build-lib pattern.
    build_lib_path = next(
        (root for root, _dirs, _files in os.walk(working_dir)
         if fnmatch.fnmatch(root, lib_pattern)), None)
    assert build_lib_path

    copy_tree(build_lib_path, working_dir)
    if remove_build:
        # TODO: error when compiling multiple arms for python2
        # error: unable to open output file 'build/temp.macosx-10.15-x86_64-2.7/movo_right_arm_ik.o': 'No such file or directory'
        shutil.rmtree(os.path.join(working_dir, 'build'))

    try:
        importlib.import_module(module_name)
    except ImportError as e:
        print('\nikfast module {} imported failed'.format(module_name))
        raise e
    print('\nikfast module {} imported successful'.format(module_name))
    return True
def main():
    """Build the IKFast extension module for the Panda arm."""
    robot_name = 'panda_arm'
    # The module and its C++ source share the 'ikfast_' + robot-name stem.
    module_name = 'ikfast_{}'.format(robot_name)
    # Replace whatever arguments were given with the single 'build'
    # command, keeping only the program name.
    sys.argv[1:] = ['build']
    compile_ikfast(module_name=module_name,
                   cpp_filename='{}.cpp'.format(module_name))
FLAGS: {} 24 | START_SEED: 456 25 | NUM_SEEDS: 10 26 | ... 27 | -------------------------------------------------------------------------------- /scripts/configs/behavior_example.yaml: -------------------------------------------------------------------------------- 1 | # An example configuration file. 2 | --- 3 | APPROACHES: 4 | my-oracle: # used in constructing the experiment ID 5 | NAME: "oracle" 6 | FLAGS: 7 | offline_data_planning_timeout: 500.0 8 | timeout: 500.0 9 | plan_only_eval: True 10 | sesame_task_planner: fdopt 11 | backchaining: 12 | NAME: "nsrt_learning" 13 | FLAGS: 14 | offline_data_planning_timeout: 500.0 15 | timeout: 500.0 16 | sampler_learner: neural 17 | strips_learner: backchaining 18 | plan_only_eval: True 19 | sesame_task_planner: fdopt 20 | ENVS: 21 | opening-packages-Pomaria_2_int: 22 | NAME: "behavior" 23 | FLAGS: 24 | behavior_scene_name: Pomaria_2_int 25 | behavior_task_list: "[opening_packages]" 26 | behavior_option_model_eval: True 27 | ARGS: {} 28 | FLAGS: # general flags 29 | num_train_tasks: 10 30 | num_test_tasks: 10 31 | START_SEED: 456 32 | NUM_SEEDS: 3 33 | USE_GPU: False 34 | -------------------------------------------------------------------------------- /scripts/configs/behavior_pick_place_data_collection.yaml: -------------------------------------------------------------------------------- 1 | # A configuration file for 4 representative pick and place BEHAVIOR tasks. 
2 | # python scripts/supercloud/launch.py \ 3 | # --user wmcclinton \ 4 | # --config behavior_pick_place_data_collection.yaml 5 | --- 6 | APPROACHES: 7 | pnad-search: 8 | NAME: "nsrt_learning" 9 | FLAGS: 10 | offline_data_planning_timeout: 1000.0 11 | timeout: 1000.0 12 | sampler_learner: neural 13 | strips_learner: pnad_search 14 | plan_only_eval: True 15 | sesame_task_planner: fdopt 16 | create_training_dataset: True 17 | max_demo_attempts: 10 18 | ENVS: 19 | collecting-aluminum-cans-Ihlen_1_int: 20 | NAME: "behavior" 21 | FLAGS: 22 | behavior_train_scene_name: Ihlen_1_int 23 | behavior_test_scene_name: Ihlen_1_int 24 | behavior_task_list: "[collecting_aluminum_cans]" 25 | behavior_option_model_eval: True 26 | opening-presents-Pomaria_2_int: 27 | NAME: "behavior" 28 | FLAGS: 29 | behavior_train_scene_name: Pomaria_2_int 30 | behavior_test_scene_name: Pomaria_2_int 31 | behavior_task_list: "[opening_presents]" 32 | behavior_option_model_eval: True 33 | locking-every-window-Merom_1_int: 34 | NAME: "behavior" 35 | FLAGS: 36 | behavior_train_scene_name: Merom_1_int 37 | behavior_test_scene_name: Merom_1_int 38 | behavior_task_list: "[locking_every_window]" 39 | behavior_option_model_eval: True 40 | sorting-books-Pomaria_1_int: 41 | NAME: "behavior" 42 | FLAGS: 43 | behavior_train_scene_name: Pomaria_1_int 44 | behavior_test_scene_name: Pomaria_1_int 45 | behavior_task_list: "[sorting_books]" 46 | behavior_option_model_eval: True 47 | ARGS: {} 48 | FLAGS: # general flags 49 | num_train_tasks: 10 50 | num_test_tasks: 0 51 | START_SEED: 456 52 | NUM_SEEDS: 1 53 | USE_GPU: False 54 | -------------------------------------------------------------------------------- /scripts/configs/example_basic.yaml: -------------------------------------------------------------------------------- 1 | # An example configuration file. 
2 | --- 3 | APPROACHES: 4 | my-oracle: # used in constructing the experiment ID 5 | NAME: "oracle" 6 | nsrt-learning: 7 | NAME: "nsrt_learning" 8 | FLAGS: 9 | disable_harmlessness_check: True # just an example 10 | ENVS: 11 | cover-default-settings: # used in constructing the experiment ID 12 | NAME: "cover" 13 | cover-single-block: 14 | NAME: "cover" 15 | FLAGS: 16 | cover_num_blocks: 1 17 | cover_block_widths: [0.1] 18 | ARGS: 19 | - "debug" 20 | FLAGS: # general flags 21 | num_train_tasks: 20 22 | num_test_tasks: 10 23 | START_SEED: 456 24 | NUM_SEEDS: 2 25 | -------------------------------------------------------------------------------- /scripts/configs/example_multiple_combinations.yaml: -------------------------------------------------------------------------------- 1 | # An example configuration file showing how to define multiple approach/env 2 | # combinations, i.e., not just one cross product. In this example, the oracle 3 | # approach is run in cover, and both the oracle approach and nsrt_learning are 4 | # run in blocks. 5 | --- 6 | APPROACHES: 7 | oracle: 8 | NAME: "oracle" 9 | ENVS: 10 | cover: 11 | NAME: "cover" 12 | ARGS: [] 13 | FLAGS: {} 14 | START_SEED: 456 15 | NUM_SEEDS: 3 16 | ... 17 | --- 18 | APPROACHES: 19 | oracle: 20 | NAME: "oracle" 21 | nsrt-learning: 22 | NAME: "nsrt_learning" 23 | ENVS: 24 | blocks: 25 | NAME: "blocks" 26 | ARGS: [] 27 | FLAGS: {} 28 | START_SEED: 456 29 | NUM_SEEDS: 3 30 | -------------------------------------------------------------------------------- /scripts/configs/full_pipeline.yaml: -------------------------------------------------------------------------------- 1 | # Learning operators, samplers, policies, and predicates. 
2 | --- 3 | APPROACHES: 4 | learn_all: 5 | NAME: "grammar_search_invention" 6 | FLAGS: 7 | min_perc_data_for_nsrt: 1 8 | segmenter: "contacts" 9 | neural_gaus_regressor_max_itr: 50000 10 | option_learner: direct_bc 11 | sesame_max_samples_per_step: 100 # mainly to improve demo collection 12 | ENVS: 13 | cover: 14 | NAME: "cover_multistep_options" 15 | ARGS: [] 16 | FLAGS: 17 | num_train_tasks: 1000 18 | excluded_predicates: "all" 19 | timeout: 300 20 | START_SEED: 456 21 | NUM_SEEDS: 10 22 | ... 23 | -------------------------------------------------------------------------------- /scripts/configs/llm_pddl_ablations.yaml: -------------------------------------------------------------------------------- 1 | # Varun's experiments on predicate, operator, and syntax renaming. 2 | --- 3 | APPROACHES: 4 | llm_predicate_renaming: 5 | NAME: "llm_predicate_renaming" 6 | FLAGS: 7 | llm_num_completions: 1 8 | llm_option_renaming: 9 | NAME: "llm_option_renaming" 10 | FLAGS: 11 | llm_num_completions: 1 12 | llm_syntax_renaming: 13 | NAME: "llm_syntax_renaming" 14 | FLAGS: 15 | llm_num_completions: 1 16 | ENVS: 17 | easy_delivery: 18 | NAME: "pddl_delivery_procedural_tasks" 19 | FLAGS: 20 | pddl_delivery_procedural_train_min_num_locs: 3 21 | pddl_delivery_procedural_train_max_num_locs: 5 22 | pddl_delivery_procedural_train_min_want_locs: 1 23 | pddl_delivery_procedural_train_max_want_locs: 2 24 | pddl_delivery_procedural_test_min_num_locs: 4 25 | pddl_delivery_procedural_test_max_num_locs: 6 26 | pddl_delivery_procedural_test_min_want_locs: 2 27 | pddl_delivery_procedural_test_max_want_locs: 3 28 | pddl_delivery_procedural_test_max_extra_newspapers: 1 29 | ARGS: 30 | - "debug" 31 | FLAGS: 32 | num_train_tasks: 5 33 | num_test_tasks: 10 34 | strips_learner: "oracle" 35 | timeout: 100 36 | llm_model_name: "text-davinci-002" 37 | llm_use_cache_only: False # change to true to rerun 38 | START_SEED: 456 39 | NUM_SEEDS: 5 40 | 
-------------------------------------------------------------------------------- /scripts/configs/nightly.yaml: -------------------------------------------------------------------------------- 1 | # NOTE: all these experiments should get >45 / 50 in the NUM_SOLVED column of 2 | # the table printed out by `python scripts/analyze_results_directory.py`, 3 | # except stick_button_option_learning, which should get ~40-45 / 50. 4 | 5 | # oracle, NSRT learning, and predicate invention experiments 6 | --- 7 | APPROACHES: 8 | oracle: 9 | NAME: "oracle" 10 | nsrt_learning: 11 | NAME: "nsrt_learning" 12 | invent_allexclude: 13 | NAME: "grammar_search_invention" 14 | FLAGS: 15 | excluded_predicates: "all" 16 | ENVS: 17 | cover: 18 | NAME: "cover" 19 | pybullet_blocks: 20 | NAME: "pybullet_blocks" 21 | painting: 22 | NAME: "painting" 23 | tools: 24 | NAME: "tools" 25 | FLAGS: 26 | num_train_tasks: 200 # requires more data (default is 50) 27 | pybullet_cover: 28 | NAME: "pybullet_cover" 29 | ARGS: [] 30 | FLAGS: {} 31 | START_SEED: 456 32 | NUM_SEEDS: 5 33 | ... 34 | 35 | # other environments with oracle and NSRT learning 36 | --- 37 | APPROACHES: 38 | oracle: 39 | NAME: "oracle" 40 | nsrt_learning: 41 | NAME: "nsrt_learning" 42 | ENVS: 43 | playroom: 44 | NAME: "playroom" 45 | stick_button: 46 | NAME: "stick_button" 47 | FLAGS: 48 | timeout: 300 # requires longer timeout (default is 10) 49 | num_train_tasks: 500 # requires more data (default is 50) 50 | min_perc_data_for_nsrt: 1 # requires filtering (default is 0) 51 | ARGS: [] 52 | FLAGS: {} 53 | START_SEED: 456 54 | NUM_SEEDS: 5 55 | ... 
56 | 57 | # option learning 58 | --- 59 | APPROACHES: 60 | option_learning: 61 | NAME: "nsrt_learning" 62 | FLAGS: 63 | option_learner: "direct_bc" 64 | min_perc_data_for_nsrt: 1 65 | segmenter: "contacts" 66 | neural_gaus_regressor_max_itr: 50000 67 | ENVS: 68 | cover: 69 | NAME: "cover_multistep_options" 70 | doors: 71 | NAME: "doors" 72 | FLAGS: 73 | included_options: "MoveToDoor,MoveThroughDoor" 74 | stick_button: 75 | NAME: "stick_button" 76 | coffee: 77 | NAME: "coffee" 78 | ARGS: [] 79 | FLAGS: 80 | num_train_tasks: 1000 # more train tasks (default is 50) 81 | timeout: 300 # more time (default is 10) 82 | START_SEED: 456 83 | NUM_SEEDS: 5 84 | ... 85 | -------------------------------------------------------------------------------- /scripts/configs/pg3_offline.yaml: -------------------------------------------------------------------------------- 1 | # Offline PG3 Configuration File 2 | --- 3 | APPROACHES: 4 | pg3: 5 | NAME: "pg3" 6 | pg3-given-ops: 7 | NAME: "pg3" 8 | FLAGS: 9 | strips_learner: oracle 10 | plan-only: 11 | NAME: "oracle" 12 | ENVS: 13 | delivery: 14 | NAME: "pddl_easy_delivery_procedural_tasks" 15 | spanner: 16 | NAME: "pddl_spanner_procedural_tasks" 17 | forest: 18 | NAME: "pddl_forest_procedural_tasks" 19 | ARGS: [] 20 | FLAGS: 21 | num_train_tasks: 50 22 | START_SEED: 456 23 | NUM_SEEDS: 10 24 | -------------------------------------------------------------------------------- /scripts/configs/pg3_online.yaml: -------------------------------------------------------------------------------- 1 | # Online PG3 Configuration File 2 | --- 3 | APPROACHES: 4 | pg3: 5 | NAME: "online_pg3" 6 | pg3-random-options: 7 | NAME: "online_pg3" 8 | FLAGS: 9 | explorer: random_options 10 | pg3-glib: 11 | NAME: "online_pg3" 12 | FLAGS: 13 | explorer: glib 14 | pg3-given-ops: 15 | NAME: "online_pg3" 16 | FLAGS: 17 | strips_learner: oracle 18 | plan-only: 19 | NAME: "oracle" 20 | ENVS: 21 | delivery: 22 | NAME: "pddl_easy_delivery_procedural_tasks" 23 | spanner: 
24 | NAME: "pddl_spanner_procedural_tasks" 25 | forest: 26 | NAME: "pddl_forest_procedural_tasks" 27 | ARGS: [] 28 | FLAGS: 29 | num_train_tasks: 50 30 | num_test_tasks: 10 31 | max_initial_demos: 0 32 | max_num_steps_interaction_request: 10 33 | START_SEED: 456 34 | NUM_SEEDS: 10 35 | -------------------------------------------------------------------------------- /scripts/configs/pg4.yaml: -------------------------------------------------------------------------------- 1 | # PG4 experiments. 2 | --- 3 | APPROACHES: 4 | pg3: 5 | NAME: "pg3" 6 | pg4: 7 | NAME: "pg4" 8 | plan-only: 9 | NAME: "oracle" 10 | ENVS: 11 | cover-easy: # downward refinable. PG3 should succeed 12 | NAME: "cover" 13 | FLAGS: 14 | cover_initial_holding_prob: 0.0 15 | cover: # PG3 should sometimes fail, but PG4 should succeed 16 | NAME: "cover" 17 | painting-no-holding-shelf-only: 18 | NAME: "painting" 19 | FLAGS: 20 | painting_initial_holding_prob: 0.0 21 | painting_goal_receptacles: "shelf" 22 | painting-no-lid-no-holding-box-only: 23 | NAME: "painting" 24 | FLAGS: 25 | painting_lid_open_prob: 1.0 26 | painting_initial_holding_prob: 0.0 27 | painting_goal_receptacles: "box" 28 | painting-lid-no-holding-box-only: 29 | NAME: "painting" 30 | FLAGS: 31 | painting_initial_holding_prob: 0.0 32 | painting_goal_receptacles: "box" 33 | painting-lid-no-holding-boxandshelf: 34 | NAME: "painting" 35 | FLAGS: 36 | painting_initial_holding_prob: 0.0 37 | screws: 38 | NAME: "screws" 39 | repeated-nextto: 40 | NAME: "repeated_nextto" 41 | cluttered-table: 42 | NAME: "cluttered_table" 43 | coffee-easy: 44 | NAME: "coffee" 45 | FLAGS: 46 | coffee_jug_init_rot_amt: 0 47 | coffee-hard: 48 | NAME: "coffee" 49 | ARGS: [] 50 | FLAGS: 51 | strips_learner: oracle 52 | sampler_learner: oracle 53 | num_train_tasks: 50 54 | START_SEED: 456 55 | NUM_SEEDS: 10 56 | -------------------------------------------------------------------------------- /scripts/find_unused_functions.py: 
def _main() -> None:
    """Prepare the repo, then run every configured experiment in order.

    Command-line flags:
        --config: name of the experiment config file (required).
        --branch: branch passed to the repo-preparation commands.

    Each experiment's stdout is redirected into its own file under
    ``logs/``. Commands run sequentially through the shell, and a failing
    command does not stop the remaining ones.
    """
    # Set up argparse.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, type=str)
    parser.add_argument("--branch", type=str, default=DEFAULT_BRANCH)
    args = parser.parse_args()
    # Prepare the repo.
    for cmd in get_cmds_to_prep_repo(args.branch, False):
        subprocess.run(cmd, shell=True, check=False)
    # The shell redirect below fails if logs/ is missing, and that failure
    # would go unnoticed because check=False. Create the directory up
    # front instead of relying on the preparation commands.
    os.makedirs("logs", exist_ok=True)
    # Create the run commands.
    cmds = []
    for cfg in generate_run_configs(args.config):
        cmd_flags = config_to_cmd_flags(cfg)
        logfile = os.path.join("logs", config_to_logfile(cfg))
        cmds.append(f"python predicators/main.py {cmd_flags} > {logfile}")
    # Run the commands in order.
    num_cmds = len(cmds)
    for i, cmd in enumerate(cmds):
        print(f"********* RUNNING COMMAND {i+1} of {num_cmds} *********")
        subprocess.run(cmd, shell=True, check=False)
def _run_behavior_pickplaceopen_tests() -> None:
    """Run the oracle approach on every preselected scene of selected tasks.

    Reads the task -> scenes mapping from
    predicators/behavior_utils/task_to_preselected_scenes.json, keeps the
    tasks listed in OPEN_PICK_PLACE_TASKS, and builds one
    ``predicators/main.py`` command per (task, scene) pair. The results
    folder of each pair is deleted and recreated while the commands are
    being built, i.e. before any of them is executed. The commands are then
    run sequentially.
    """
    path_to_file = "predicators/behavior_utils/task_to_preselected_scenes.json"
    # Use a context manager so the handle is closed even if parsing fails.
    with open(path_to_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # A set gives constant-time membership checks in the loop below.
    tasks_to_test = set(OPEN_PICK_PLACE_TASKS)

    # Create commands to run
    cmds = []
    for task, scenes in data.items():
        if task not in tasks_to_test:
            continue
        for scene in scenes:
            logfolder = os.path.join(
                "logs", f"{task}_{scene}_{SEED}"
                f"_{NUM_TEST}_{TIMEOUT}/")
            # Start from an empty results folder for this (task, scene).
            if os.path.exists(logfolder):
                shutil.rmtree(logfolder)
            os.makedirs(logfolder)

            cmds.append("python predicators/main.py "
                        "--env behavior "
                        "--approach oracle "
                        "--behavior_mode headless "
                        "--option_model_name oracle_behavior "
                        "--num_train_tasks 1 "
                        f"--num_test_tasks {NUM_TEST} "
                        f"--behavior_scene_name {scene} "
                        f"--behavior_task_list \"[{task}]\" "
                        f"--seed {SEED} "
                        f"--offline_data_planning_timeout {TIMEOUT} "
                        f"--timeout {TIMEOUT} "
                        "--behavior_option_model_eval True "
                        "--plan_only_eval True "
                        f"--results_dir {logfolder}")

    # Run the commands in order.
    num_cmds = len(cmds)
    for i, cmd in enumerate(cmds):
        print(f"********* RUNNING COMMAND {i+1} of {num_cmds} *********")
        # Drain stdout until the command finishes; the output itself is
        # discarded. The context manager closes the pipe afterwards.
        with os.popen(cmd) as proc:
            proc.read()
Make sure you launch enough instances so that each (environment, approach, seed) can have one instance. 6 | 4. Create a file (e.g. `machines.txt`) that lists your instance IP addresses, one per line. 7 | 5. Create an experiment yaml file (see `scripts/openstack/configs/example_basic.yaml` for an example). 8 | 6. Run `python scripts/openstack/launch.py --config CONFIG_FILE --machines MACHINES_FILE --sshkey SSH_KEY_FILE` to launch your experiments. 9 | 7. Wait for your experiments to complete. 10 | 8. Run `python scripts/openstack/download.py --dir DOWNLOAD_DIR --machines MACHINES_FILE --sshkey SSH_KEY_FILE` to download the results. 11 | -------------------------------------------------------------------------------- /scripts/openstack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/scripts/openstack/__init__.py -------------------------------------------------------------------------------- /scripts/openstack/download.py: -------------------------------------------------------------------------------- 1 | """Download results from openstack experiments. 2 | 3 | Requires a file that contains a list of IP addresses for instances that are: 4 | - Turned on 5 | - Accessible via ssh for the user of this file 6 | - Configured with a predicators image 7 | - Sufficient in number to run all of the experiments in the config file 8 | 9 | The dir flag should point to a directory where the results, logs, and saved_* 10 | subdirectories will be downloaded. 11 | 12 | Usage example: 13 | python scripts/openstack/download.py --dir "$PWD" --machines machines.txt \ 14 | --sshkey ~/.ssh/cloud.key 15 | """ 16 | 17 | import argparse 18 | import os 19 | 20 | from scripts.cluster_utils import SAVE_DIRS 21 | 22 | 23 | def _main() -> None: 24 | # Set up argparse.
25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--dir", required=True, type=str) 27 | parser.add_argument("--machines", required=True, type=str) 28 | parser.add_argument("--sshkey", required=True, type=str) 29 | args = parser.parse_args() 30 | openstack_dir = os.path.dirname(os.path.realpath(__file__)) 31 | # Load the machine IPs. 32 | machine_file = os.path.join(openstack_dir, args.machines) 33 | with open(machine_file, "r", encoding="utf-8") as f: 34 | machines = f.read().splitlines() 35 | # Make sure that the ssh key exists. 36 | assert os.path.exists(args.sshkey) 37 | # Create the download directory if it doesn't exist. 38 | os.makedirs(args.dir, exist_ok=True) 39 | # Loop over machines. 40 | for machine in machines: 41 | _download_from_machine(machine, args.dir, args.sshkey) 42 | 43 | 44 | def _download_from_machine(machine: str, download_dir: str, 45 | ssh_key: str) -> None: 46 | print(f"Downloading from machine {machine}") 47 | for save_dir in SAVE_DIRS: 48 | local_save_dir = os.path.join(download_dir, save_dir) 49 | os.makedirs(local_save_dir, exist_ok=True) 50 | cmd = f"scp -r -i {ssh_key} -o StrictHostKeyChecking=no " + \ 51 | f"ubuntu@{machine}:~/predicators/{save_dir}/* {local_save_dir}" 52 | retcode = os.system(cmd) 53 | if retcode != 0: 54 | print(f"WARNING: command failed: {cmd}") 55 | 56 | 57 | if __name__ == "__main__": 58 | _main() 59 | -------------------------------------------------------------------------------- /scripts/openstack/kill_all.py: -------------------------------------------------------------------------------- 1 | """Script for killing all active openstack predicator experiments. 2 | 3 | Analogous to scancel on supercloud. 4 | 5 | WARNING: any other python3.8 processes running on the machine will also be 6 | killed (but there typically shouldn't be any). 7 | 8 | See launch.py for information about the format of machines.txt. 
9 | 10 | Usage example: 11 | python scripts/openstack/kill_all.py --machines machines.txt \ 12 | --sshkey ~/.ssh/cloud.key 13 | """ 14 | 15 | import argparse 16 | import os 17 | 18 | from scripts.cluster_utils import run_cmds_on_machine 19 | 20 | 21 | def _main() -> None: 22 | # Set up argparse. 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--machines", required=True, type=str) 25 | parser.add_argument("--sshkey", required=True, type=str) 26 | args = parser.parse_args() 27 | openstack_dir = os.path.dirname(os.path.realpath(__file__)) 28 | # Load the machine IPs. 29 | machine_file = os.path.join(openstack_dir, args.machines) 30 | with open(machine_file, "r", encoding="utf-8") as f: 31 | machines = f.read().splitlines() 32 | # Make sure that the ssh key exists. 33 | assert os.path.exists(args.sshkey) 34 | # Loop through each machine and kill the python3.8 process. 35 | kill_cmd = "pkill -9 python3.8" 36 | for machine in machines: 37 | print(f"Killing machine {machine}") 38 | # Allow return code 1, meaning that no process was found to kill. 39 | run_cmds_on_machine([kill_cmd], 40 | "ubuntu", 41 | machine, 42 | ssh_key=args.sshkey, 43 | allowed_return_codes=(0, 1)) 44 | 45 | 46 | if __name__ == "__main__": 47 | _main() 48 | -------------------------------------------------------------------------------- /scripts/openstack/launch.py: -------------------------------------------------------------------------------- 1 | """Launch script for openstack experiments. 2 | 3 | Requires a file that contains a list of IP addresses for instances that are: 4 | - Turned on 5 | - Accessible via ssh for the user of this file 6 | - Configured with a predicators image 7 | - Sufficient in number to run all of the experiments in the config file 8 | 9 | Usage example: 10 | python scripts/openstack/launch.py --config example_basic.yaml \ 11 | --machines machines.txt --sshkey ~/.ssh/cloud.key 12 | 13 | The default branch can be overridden with the --branch flag. 
14 | """ 15 | 16 | import argparse 17 | import os 18 | 19 | from scripts.cluster_utils import DEFAULT_BRANCH, SingleSeedRunConfig, \ 20 | config_to_cmd_flags, config_to_logfile, generate_run_configs, \ 21 | get_cmds_to_prep_repo, run_cmds_on_machine 22 | 23 | 24 | def _main() -> None: 25 | # Set up argparse. 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--config", required=True, type=str) 28 | parser.add_argument("--machines", required=True, type=str) 29 | parser.add_argument("--sshkey", required=True, type=str) 30 | parser.add_argument("--branch", type=str, default=DEFAULT_BRANCH) 31 | args = parser.parse_args() 32 | openstack_dir = os.path.dirname(os.path.realpath(__file__)) 33 | # Load the machine IPs. 34 | machine_file = os.path.join(openstack_dir, args.machines) 35 | with open(machine_file, "r", encoding="utf-8") as f: 36 | machines = f.read().splitlines() 37 | # Make sure that the ssh key exists. 38 | assert os.path.exists(args.sshkey) 39 | # Generate all of the run configs and make sure that we have enough 40 | # machines to run them all. 41 | run_configs = list(generate_run_configs(args.config)) 42 | num_machines = len(machines) 43 | assert num_machines >= len(run_configs) 44 | # Launch the runs. 45 | for machine, cfg in zip(machines, run_configs): 46 | assert isinstance(cfg, SingleSeedRunConfig) 47 | logfile = os.path.join("logs", config_to_logfile(cfg)) 48 | cmd_flags = config_to_cmd_flags(cfg) 49 | cmd = f"python3.8 predicators/main.py {cmd_flags}" 50 | _launch_experiment(cmd, machine, logfile, args.sshkey, args.branch) 51 | 52 | 53 | def _launch_experiment(cmd: str, machine: str, logfile: str, ssh_key: str, 54 | branch: str) -> None: 55 | print(f"Launching on machine {machine}: {cmd}") 56 | # Enter the repo. 57 | server_cmds = ["cd ~/predicators"] 58 | # Prepare the repo. 59 | server_cmds.extend(get_cmds_to_prep_repo(branch, False)) 60 | # Run the main command. 
def _main() -> None:
    """Write per-task solve/execution time histograms for each experiment.

    Loads every file found in ``CFG.results_dir``, groups the values stored
    under keys of the form ``PER_TASK_task<N>_solve_time`` and
    ``PER_TASK_task<N>_exec_time`` by experiment ID, and saves one
    ``<experiment_id>__per_task.png`` per experiment into a ``results``
    directory next to this script.

    Raises:
        ValueError: if no per-task data is found in any results file.
    """
    outdir = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                          "results")
    os.makedirs(outdir, exist_ok=True)

    experiment_ids = set()
    solve_data = defaultdict(list)
    exec_data = defaultdict(list)
    for filepath in sorted(glob.glob(f"{CFG.results_dir}/*")):
        # NOTE: unpickling can execute arbitrary code, so only point
        # results_dir at files produced by trusted runs.
        with open(filepath, "rb") as f:
            outdata = pkl.load(f)
        config = outdata["config"].__dict__.copy()
        run_data_defaultdict = outdata["results"]
        assert not set(config.keys()) & set(run_data_defaultdict.keys())
        run_data_defaultdict.update(config)
        # The experiment ID is the sixth of the seven "__"-separated fields
        # of the file name. Strip the directory and extension explicitly
        # rather than with fixed offsets, so that any results_dir works.
        run_name = os.path.splitext(os.path.basename(filepath))[0]
        _, _, _, _, _, experiment_id, _ = run_name.split("__")
        experiment_ids.add(experiment_id)
        run_data = dict(
            run_data_defaultdict)  # want to crash if key not found!
        run_data.update({"experiment_id": experiment_id})
        for key in run_data:
            if not key.startswith("PER_TASK_"):
                continue
            match = re.match(r"PER_TASK_task\d+_(solve|exec)_time", key)
            assert match is not None
            solve_or_exec = match.groups()[0]
            if solve_or_exec == "solve":
                solve_data[experiment_id].append(run_data[key])
            else:
                exec_data[experiment_id].append(run_data[key])
    if not solve_data and not exec_data:
        raise ValueError(f"No per-task data found in {CFG.results_dir}/")
    print("Found the following experiment IDs:")
    # Sorted so that the output order does not depend on set iteration.
    for experiment_id in sorted(experiment_ids):
        print(experiment_id)
        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.hist(solve_data[experiment_id])
        ax2.hist(exec_data[experiment_id])
        ax1.set_title("Per-task solve time histogram")
        ax2.set_title("Per-task execution time histogram")
        outfile = os.path.join(outdir, f"{experiment_id}__per_task.png")
        fig.savefig(outfile, dpi=DPI)
        # Close the figure so figures do not pile up across experiments.
        plt.close(fig)
        print(f"\tFound {len(solve_data[experiment_id])} task solve times and "
              f"{len(exec_data[experiment_id])} task execution times")
        print(f"\tWrote out to {outfile}")
20 | mkdir -p $TASK_DIR 21 | 22 | echo "Capturing images." 23 | python scripts/realsense_helpers.py \ 24 | --rgb $DATA_DIR/color-$TASK_NUM-$IMG_SUFFIX \ 25 | --depth $DATA_DIR/depth-$TASK_NUM-$IMG_SUFFIX 26 | 27 | echo "Running perception." 28 | python scripts/run_blocks_perception.py \ 29 | --rgb $DATA_DIR/color-$TASK_NUM-$IMG_SUFFIX \ 30 | --depth $DATA_DIR/depth-$TASK_NUM-$IMG_SUFFIX \ 31 | --goal $DATA_DIR/goal-$TASK_NUM.json \ 32 | --extrinsics $DATA_DIR/extrinsics.json \ 33 | --intrinsics $DATA_DIR/intrinsics.json \ 34 | --output $TASK_FILE # --debug_viz 35 | 36 | echo "Running planning with oracle models." 37 | python predicators/main.py --env pybullet_blocks --approach oracle \ 38 | --seed $SEED --num_test_tasks 1 \ 39 | --test_task_json_dir $TASK_DIR \ 40 | --pybullet_robot panda \ 41 | --option_model_use_gui $VIZ_PLANNING \ 42 | --option_model_name oracle --option_model_terminate_on_repeat False \ 43 | --blocks_block_size 0.0505 \ 44 | --sesame_check_static_object_changes True \ 45 | --crash_on_failure \ 46 | --timeout 100 \ 47 | --blocks_num_blocks_test [15] # just needs to be an upper bound 48 | 49 | echo "Converting plan to LISDF." 50 | python scripts/eval_trajectory_to_lisdf.py \ 51 | --input $EVAL_TRAJ_FILE \ 52 | --output $LISDF_PLAN_FILE 53 | 54 | echo "Planning to reset the robot." 55 | python scripts/lisdf_plan_to_reset.py \ 56 | --lisdf $LISDF_PLAN_FILE \ 57 | --output $FINAL_PLAN_FILE 58 | 59 | echo "Visualizing LISDF plan." 60 | python scripts/lisdf_pybullet_visualizer.py --lisdf $FINAL_PLAN_FILE 61 | 62 | echo "To execute the LISDF plan on the real robot, run this command:" 63 | echo "panda-client execute_lisdf_plan ${FINAL_PLAN_FILE}" 64 | -------------------------------------------------------------------------------- /scripts/run_checks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Running autoformatting." 
# Auto-format the code base in place.
yapf -i -r --style .style.yapf --exclude '**/third_party' predicators
yapf -i -r --style .style.yapf scripts
yapf -i -r --style .style.yapf tests
docformatter -i -r . --exclude venv predicators/third_party
isort .
echo "Autoformatting complete."

# Each stage below stops the script on failure. The explicit `exit 1`
# matters: a bare `exit` returns the status of the last command run, which
# would be the preceding `echo` (0), so callers would see a failed check as
# a success.
echo "Running type checking."
mypy . --config-file mypy.ini
if [ $? -eq 0 ]; then
    echo "Type checking passed."
else
    echo "Type checking failed! Terminating check script early."
    exit 1
fi

echo "Running linting."
pytest . --pylint -m pylint --pylint-rcfile=.predicators_pylintrc
if [ $? -eq 0 ]; then
    echo "Linting passed."
else
    echo "Linting failed! Terminating check script early."
    exit 1
fi

echo "Running unit tests."
pytest -s tests/ --cov-config=.coveragerc --cov=predicators/ --cov=tests/ --cov-fail-under=100 --cov-report=term-missing:skip-covered --durations=10
if [ $? -eq 0 ]; then
    echo "Unit tests passed."
else
    echo "Unit tests failed!"
    exit 1
fi

echo "All checks passed!"
Example: 9 | 10 | python scripts/supercloud/download.py --dir "$PWD" --user njk \ 11 | --supercloud_dir "~/GitHub/research/predicators" 12 | """ 13 | 14 | import argparse 15 | import os 16 | 17 | from scripts.cluster_utils import SAVE_DIRS, SUPERCLOUD_IP 18 | 19 | 20 | def _main() -> None: 21 | # Set up argparse. 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--dir", required=True, type=str) 24 | parser.add_argument("--user", required=True, type=str) 25 | parser.add_argument("--supercloud_dir", default="~/predicators", type=str) 26 | parser.add_argument("--download_behavior_states", action="store_true") 27 | args = parser.parse_args() 28 | # Create the download directory if it doesn't exist. 29 | os.makedirs(args.dir, exist_ok=True) 30 | # Download the results. 31 | print(f"Downloading results from supercloud for user {args.user}") 32 | host = f"{args.user}@{SUPERCLOUD_IP}" 33 | save_dirs = SAVE_DIRS 34 | if args.download_behavior_states: 35 | save_dirs = save_dirs + ["tmp_behavior_states"] 36 | for save_dir in save_dirs: 37 | local_save_dir = os.path.join(args.dir, save_dir) 38 | os.makedirs(local_save_dir, exist_ok=True) 39 | cmd = "rsync -avzhe ssh " + \ 40 | f"{host}:{args.supercloud_dir}/{save_dir}/* {local_save_dir}" 41 | retcode = os.system(cmd) 42 | if retcode != 0: 43 | print(f"WARNING: command failed: {cmd}") 44 | 45 | 46 | if __name__ == "__main__": 47 | _main() 48 | -------------------------------------------------------------------------------- /scripts/supercloud/kill_all.py: -------------------------------------------------------------------------------- 1 | """Script to kill all experiments running on supercloud for a given user. 2 | 3 | Runs scancel -u on supercloud. 4 | 5 | Usage example: 6 | 7 | python scripts/supercloud/kill_all.py --user tslvr 8 | """ 9 | 10 | import argparse 11 | 12 | from scripts.cluster_utils import SUPERCLOUD_IP, run_cmds_on_machine 13 | 14 | 15 | def _main() -> None: 16 | # Set up argparse. 
17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--user", required=True, type=str) 19 | args = parser.parse_args() 20 | print(f"Killing all jobs on supercloud for user {args.user}") 21 | kill_cmd = f"scancel -u {args.user}" 22 | run_cmds_on_machine([kill_cmd], args.user, SUPERCLOUD_IP) 23 | 24 | 25 | if __name__ == "__main__": 26 | _main() 27 | -------------------------------------------------------------------------------- /scripts/supercloud/run_predicators_evalonly_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="scripts/supercloud/submit_supercloud_job.py" 4 | NUM_TRAIN_TASKS="200" 5 | ALL_ENVS=( 6 | "cover" 7 | "pybullet_blocks" 8 | "painting" 9 | "tools" 10 | ) 11 | 12 | # We want this to crash if backup_results already exists, because overwriting it would be bad. 13 | mkdir backup_results && mv results/* backup_results 14 | 15 | if [ $? -ne 0 ]; then 16 | echo "backup_results/ already exists, exiting" 17 | exit 1 18 | fi 19 | 20 | for ENV in ${ALL_ENVS[@]}; do 21 | # Downrefeval ablation. 22 | echo python $FILE --experiment_id ${ENV}_main_${NUM_TRAIN_TASKS}demo --env $ENV --approach grammar_search_invention --excluded_predicates all --num_train_tasks $NUM_TRAIN_TASKS --sesame_max_skeletons_optimized 1 --load_approach --load_data 23 | 24 | # GNN model-free baseline. 
25 | echo python $FILE --experiment_id ${ENV}_gnn_shooting_${NUM_TRAIN_TASKS}demo --env $ENV --approach gnn_option_policy --excluded_predicates all --num_train_tasks $NUM_TRAIN_TASKS --gnn_option_policy_solve_with_shooting False --load_approach --load_data 26 | done 27 | 28 | # Commands to run after all jobs are finished: 29 | # for i in results/*main* ; do mv "$i" "${i/main/downrefeval}" ; done 30 | # for i in results/*gnn* ; do mv "$i" "${i/gnn_shooting/gnn_modelfree}" ; done 31 | # mv backup_results/* results/ 32 | # rm -r backup_results 33 | -------------------------------------------------------------------------------- /scripts/supercloud/run_predicators_num_demos_experiments.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FILE="scripts/supercloud/submit_supercloud_job.py" 4 | ALL_NUM_TRAIN_TASKS=( 5 | "25" 6 | "50" 7 | "75" 8 | "100" 9 | "125" 10 | "150" 11 | "175" 12 | ) 13 | ALL_ENVS=( 14 | "cover" 15 | "pybullet_blocks" 16 | "painting" 17 | "tools" 18 | ) 19 | 20 | for ENV in ${ALL_ENVS[@]}; do 21 | for NUM_TRAIN_TASKS in ${ALL_NUM_TRAIN_TASKS[@]}; do 22 | # Main approach. 23 | python $FILE --experiment_id ${ENV}_main_${NUM_TRAIN_TASKS}demo --env $ENV --approach grammar_search_invention --excluded_predicates all --num_train_tasks $NUM_TRAIN_TASKS 24 | # GNN option policy approach. 
def submit_supercloud_job(job_name: str,
                          log_dir: str,
                          logfile_prefix: str,
                          args_and_flags_str: str,
                          start_seed: int,
                          num_seeds: int,
                          use_gpu: bool = False) -> None:
    """Launch the supercloud job.

    Writes a temporary shell script that runs ``predicators/main.py`` with
    ``args_and_flags_str`` and ``--seed $SLURM_ARRAY_TASK_ID``, then submits
    it with ``sbatch`` as an array job covering the seeds ``start_seed``
    through ``start_seed + num_seeds - 1``.

    Args:
        job_name: value passed to ``--job-name``.
        log_dir: directory for log files; created if missing.
        logfile_prefix: log file name prefix. Together with ``log_dir`` it
            must contain the text "None" exactly once; that occurrence is
            replaced by ``%a`` (presumably it marks the unset seed --
            confirm against the caller).
        args_and_flags_str: arguments forwarded verbatim to main.py.
        start_seed: first array index / seed.
        num_seeds: number of consecutive seeds to run.
        use_gpu: request the GPU partition instead of the CPU one.

    Raises:
        Exception: if the shell reports that ``sbatch`` was not found.
    """
    os.makedirs(log_dir, exist_ok=True)
    logfile_pattern = os.path.join(log_dir, f"{logfile_prefix}__%j.log")
    assert logfile_pattern.count("None") == 1
    logfile_pattern = logfile_pattern.replace("None", "%a")
    mystr = (f"#!/bin/bash\npython predicators/main.py {args_and_flags_str} "
             f"--seed $SLURM_ARRAY_TASK_ID")
    temp_run_file = "temp_run_file.sh"
    # A leftover script means an earlier submission did not clean up (or
    # another one is in flight); refuse to overwrite it.
    assert not os.path.exists(temp_run_file)
    cmd = "sbatch --time=99:00:00 "
    if use_gpu:
        cmd += "--partition=xeon-g6-volta --gres=gpu:volta:1 "
    else:
        cmd += "--partition=xeon-p8 "
    cmd += ("--nodes=1 --exclusive "
            f"--job-name={job_name} "
            f"--array={start_seed}-{start_seed+num_seeds-1} "
            f"-o {logfile_pattern} {temp_run_file}")
    try:
        with open(temp_run_file, "w", encoding="utf-8") as f:
            f.write(mystr)
        print(f"Running command: {cmd}")
        output = subprocess.getoutput(cmd)
    finally:
        # Always remove the script, even if writing or submitting raised,
        # so that the existence check above cannot block the next run.
        if os.path.exists(temp_run_file):
            os.remove(temp_run_file)
    if "command not found" in output:
        raise Exception("Are you logged into supercloud?")
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Learning-and-Intelligent-Systems/predicators_behavior/365bc9e38e8a38395980deb8ee04febc54540d7b/tests/__init__.py -------------------------------------------------------------------------------- /tests/approaches/test_gnn_action_policy_approach.py: -------------------------------------------------------------------------------- 1 | """Test cases for the GNN action policy approach.""" 2 | 3 | import pytest 4 | 5 | from predicators import utils 6 | from predicators.approaches import create_approach 7 | from predicators.datasets import create_dataset 8 | from predicators.envs import create_new_env 9 | from predicators.settings import CFG 10 | 11 | 12 | def test_gnn_action_policy_approach(): 13 | """Tests for GNNActionPolicyApproach class.""" 14 | utils.reset_config({ 15 | "env": "cover", 16 | # Include replay data for coverage. It will be ignored. 17 | "offline_data_method": "demo+replay", 18 | "num_train_tasks": 3, 19 | "num_test_tasks": 3, 20 | "gnn_num_epochs": 20, 21 | "gnn_do_normalization": True, 22 | "horizon": 10 23 | }) 24 | env = create_new_env("cover") 25 | train_tasks = env.get_train_tasks() 26 | approach = create_approach("gnn_action_policy", env.predicates, 27 | env.options, env.types, env.action_space, 28 | train_tasks) 29 | dataset = create_dataset(env, train_tasks, env.options) 30 | assert approach.is_learning_based 31 | task = env.get_test_tasks()[0] 32 | with pytest.raises(AssertionError): # haven't learned yet! 33 | approach.solve(task, timeout=CFG.timeout) 34 | approach.learn_from_offline_dataset(dataset) 35 | policy = approach.solve(task, timeout=CFG.timeout) 36 | act = policy(task.init) 37 | assert env.action_space.contains(act.arr) 38 | # Test predictions by executing policy. 
39 | utils.run_policy_with_simulator(policy, 40 | env.simulate, 41 | task.init, 42 | task.goal_holds, 43 | max_num_steps=CFG.horizon) 44 | # Test loading. 45 | approach2 = create_approach("gnn_action_policy", env.predicates, 46 | env.options, env.types, env.action_space, 47 | train_tasks) 48 | approach2.load(online_learning_cycle=None) 49 | -------------------------------------------------------------------------------- /tests/approaches/test_llm_option_renaming_approach.py: -------------------------------------------------------------------------------- 1 | """Test cases for the option-renaming open-loop LLM approach.""" 2 | 3 | from predicators import utils 4 | from predicators.approaches.llm_option_renaming_approach import \ 5 | LLMOptionRenamingApproach 6 | from predicators.envs import create_new_env 7 | 8 | 9 | def test_llm_option_renaming_approach(): 10 | """Tests for LLMOptionRenamingApproach().""" 11 | env_name = "pddl_easy_delivery_procedural_tasks" 12 | utils.reset_config({ 13 | "env": env_name, 14 | "approach": "llm_option_renaming", 15 | "num_train_tasks": 1, 16 | "num_test_tasks": 1, 17 | "strips_learner": "oracle", 18 | }) 19 | env = create_new_env(env_name) 20 | train_tasks = env.get_train_tasks() 21 | approach = LLMOptionRenamingApproach(env.predicates, env.options, 22 | env.types, env.action_space, 23 | train_tasks) 24 | assert approach.get_name() == "llm_option_renaming" 25 | assert approach._renaming_prefixes == [" ", "\n"] # pylint: disable=protected-access 26 | assert approach._renaming_suffixes == ["("] # pylint: disable=protected-access 27 | subs = approach._orig_to_replace # pylint: disable=protected-access 28 | assert set(subs) == {o.name for o in env.options} 29 | assert all(len(k) == len(v) for k, v in subs.items()) 30 | -------------------------------------------------------------------------------- /tests/approaches/test_llm_predicate_renaming_approach.py: -------------------------------------------------------------------------------- 
1 | """Test cases for the predicate-renaming open-loop LLM approach.""" 2 | 3 | from predicators import utils 4 | from predicators.approaches.llm_predicate_renaming_approach import \ 5 | LLMPredicateRenamingApproach 6 | from predicators.envs import create_new_env 7 | 8 | 9 | def test_llm_predicate_renaming_approach(): 10 | """Tests for LLMPredicateRenamingApproach().""" 11 | env_name = "pddl_easy_delivery_procedural_tasks" 12 | utils.reset_config({ 13 | "env": env_name, 14 | "approach": "llm_predicate_renaming", 15 | "num_train_tasks": 1, 16 | "num_test_tasks": 1, 17 | "strips_learner": "oracle", 18 | }) 19 | env = create_new_env(env_name) 20 | train_tasks = env.get_train_tasks() 21 | approach = LLMPredicateRenamingApproach(env.predicates, env.options, 22 | env.types, env.action_space, 23 | train_tasks) 24 | assert approach.get_name() == "llm_predicate_renaming" 25 | assert approach._renaming_prefixes == [" ", "\n"] # pylint: disable=protected-access 26 | assert approach._renaming_suffixes == ["("] # pylint: disable=protected-access 27 | subs = approach._orig_to_replace # pylint: disable=protected-access 28 | assert set(subs) == {p.name for p in env.predicates} 29 | assert all(len(k) == len(v) for k, v in subs.items()) 30 | -------------------------------------------------------------------------------- /tests/approaches/test_llm_syntax_renaming_approach.py: -------------------------------------------------------------------------------- 1 | """Test cases for the syntax-renaming open-loop LLM approach.""" 2 | 3 | from predicators import utils 4 | from predicators.approaches.llm_syntax_renaming_approach import \ 5 | ORIGINAL_CHARS, REPLACEMENT_CHARS, LLMSyntaxRenamingApproach 6 | from predicators.envs import create_new_env 7 | 8 | 9 | def test_llm_syntax_renaming_approach(): 10 | """Tests for LLMSyntaxRenamingApproach().""" 11 | env_name = "pddl_easy_delivery_procedural_tasks" 12 | utils.reset_config({ 13 | "env": env_name, 14 | "approach": "llm_syntax_renaming", 15 
| "num_train_tasks": 1, 16 | "num_test_tasks": 1, 17 | "strips_learner": "oracle", 18 | }) 19 | env = create_new_env(env_name) 20 | train_tasks = env.get_train_tasks() 21 | approach = LLMSyntaxRenamingApproach(env.predicates, env.options, 22 | env.types, env.action_space, 23 | train_tasks) 24 | assert approach.get_name() == "llm_syntax_renaming" 25 | assert approach._renaming_prefixes == [""] # pylint: disable=protected-access 26 | assert approach._renaming_suffixes == [""] # pylint: disable=protected-access 27 | subs = approach._orig_to_replace # pylint: disable=protected-access 28 | assert set(subs) == set(ORIGINAL_CHARS) 29 | assert all(v in REPLACEMENT_CHARS for v in subs.values()) 30 | assert len(set(subs.values())) == len(subs) 31 | assert all(len(k) == len(v) for k, v in subs.items()) 32 | -------------------------------------------------------------------------------- /tests/approaches/test_online_pg3_approach.py: -------------------------------------------------------------------------------- 1 | """Test cases for the online PG3 approach.""" 2 | 3 | import pytest 4 | 5 | from predicators import utils 6 | from predicators.approaches import ApproachFailure, ApproachTimeout 7 | from predicators.approaches.online_pg3_approach import OnlinePG3Approach 8 | from predicators.datasets import create_dataset 9 | from predicators.envs.cover import CoverEnv 10 | from predicators.main import _generate_interaction_results 11 | from predicators.settings import CFG 12 | from predicators.structs import Dataset 13 | from predicators.teacher import Teacher 14 | 15 | 16 | def test_online_pg3_approach(): 17 | """Test for OnlinePG3Approach class, entire pipeline.""" 18 | utils.reset_config({ 19 | "env": "cover", 20 | "approach": "online_pg3", 21 | "timeout": 10, 22 | "sampler_mlp_classifier_max_itr": 10, 23 | "neural_gaus_regressor_max_itr": 10, 24 | "num_online_learning_cycles": 1, 25 | "online_nsrt_learning_requests_per_cycle": 1, 26 | "num_train_tasks": 3, 27 | 
"num_test_tasks": 3, 28 | "explorer": "random_options", 29 | "pg3_heuristic": "demo_plan_comparison", # faster for tests 30 | "pg3_search_method": "gbfs", 31 | "pg3_gbfs_max_expansions": 1 32 | }) 33 | env = CoverEnv() 34 | train_tasks = env.get_train_tasks() 35 | approach = OnlinePG3Approach(env.predicates, env.options, env.types, 36 | env.action_space, train_tasks) 37 | dataset = create_dataset(env, train_tasks, env.options) 38 | assert approach.is_learning_based 39 | # Learning with an empty dataset should not crash. 40 | approach.learn_from_offline_dataset(Dataset([])) 41 | # Learning with the actual dataset. 42 | approach.learn_from_offline_dataset(dataset) 43 | approach.load(online_learning_cycle=None) 44 | interaction_requests = approach.get_interaction_requests() 45 | teacher = Teacher(train_tasks) 46 | interaction_results, _ = _generate_interaction_results( 47 | env, teacher, interaction_requests) 48 | approach.learn_from_interaction_results(interaction_results) 49 | approach.load(online_learning_cycle=0) 50 | with pytest.raises(FileNotFoundError): 51 | approach.load(online_learning_cycle=1) 52 | for task in env.get_test_tasks(): 53 | try: 54 | approach.solve(task, timeout=CFG.timeout) 55 | except (ApproachTimeout, ApproachFailure): # pragma: no cover 56 | pass 57 | # We won't check the policy here because we don't want unit tests to 58 | # have to train very good models, since that would be slow. 
59 | -------------------------------------------------------------------------------- /tests/approaches/test_random_actions_approach.py: -------------------------------------------------------------------------------- 1 | """Test cases for the random actions approach class.""" 2 | 3 | from predicators import utils 4 | from predicators.approaches.random_actions_approach import \ 5 | RandomActionsApproach 6 | from predicators.envs.cover import CoverEnv 7 | 8 | 9 | def test_random_actions_approach(): 10 | """Tests for RandomActionsApproach class.""" 11 | utils.reset_config({ 12 | "env": "cover", 13 | "approach": "random_actions", 14 | }) 15 | env = CoverEnv() 16 | train_tasks = env.get_train_tasks() 17 | task = train_tasks[0] 18 | approach = RandomActionsApproach(env.predicates, env.options, env.types, 19 | env.action_space, train_tasks) 20 | assert not approach.is_learning_based 21 | policy = approach.solve(task, 500) 22 | actions = [] 23 | for _ in range(10): 24 | act = policy(task.init) 25 | actions.append(act) 26 | assert env.action_space.contains(act.arr) 27 | # Test reproducibility 28 | assert str(actions) == "[Action(_arr=array([0.6823519], dtype=float32)), Action(_arr=array([0.05382102], dtype=float32)), Action(_arr=array([0.22035988], dtype=float32)), Action(_arr=array([0.18437181], dtype=float32)), Action(_arr=array([0.1759059], dtype=float32)), Action(_arr=array([0.8120945], dtype=float32)), Action(_arr=array([0.92334497], dtype=float32)), Action(_arr=array([0.2765744], dtype=float32)), Action(_arr=array([0.81975454], dtype=float32)), Action(_arr=array([0.8898927], dtype=float32))]" # pylint: disable=line-too-long 29 | -------------------------------------------------------------------------------- /tests/approaches/test_refinement_estimation_approach.py: -------------------------------------------------------------------------------- 1 | """Test cases for the refinement cost estimation--based approach class.""" 2 | 3 | from predicators import utils 4 | 
from predicators.approaches.refinement_estimation_approach import \ 5 | RefinementEstimationApproach 6 | from predicators.envs.narrow_passage import NarrowPassageEnv 7 | from predicators.settings import CFG 8 | 9 | 10 | def _policy_solves_task(policy, task, simulator): 11 | """Helper method used in this file, copied from test_oracle_approach.py.""" 12 | traj = utils.run_policy_with_simulator(policy, 13 | simulator, 14 | task.init, 15 | task.goal_holds, 16 | max_num_steps=CFG.horizon) 17 | return task.goal_holds(traj.states[-1]) 18 | 19 | 20 | def test_refinement_estimation_approach(): 21 | """Tests for RefinementEstimationApproach class.""" 22 | args = { 23 | "env": "narrow_passage", 24 | "refinement_estimator": "oracle", 25 | } 26 | # Default to 2 train and test tasks, but allow them to be specified in 27 | # the extra args too. 28 | if "num_train_tasks" not in args: 29 | args["num_train_tasks"] = 2 30 | if "num_test_tasks" not in args: 31 | args["num_test_tasks"] = 2 32 | utils.reset_config(args) 33 | env = NarrowPassageEnv(use_gui=False) 34 | train_tasks = env.get_train_tasks() 35 | test_tasks = env.get_test_tasks() 36 | approach = RefinementEstimationApproach(env.predicates, env.options, 37 | env.types, env.action_space, 38 | train_tasks) 39 | assert approach.get_name() == "refinement_estimation" 40 | assert not approach.is_learning_based 41 | for task in train_tasks: 42 | policy = approach.solve(task, timeout=500) 43 | assert _policy_solves_task(policy, task, env.simulate) 44 | for task in test_tasks: 45 | policy = approach.solve(task, timeout=500) 46 | assert _policy_solves_task(policy, task, env.simulate) 47 | -------------------------------------------------------------------------------- /tests/behavior_utils/test_behavior_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for behavior_utils.""" 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from predicators.behavior_utils import behavior_utils as 
utils 7 | 8 | 9 | def test_aabb_volume(): 10 | """Tests for get_aabb_volume().""" 11 | lo = np.array([1.0, 1.5, -1.0]) 12 | hi = np.array([2.0, 2.5, 0.0]) 13 | # Test zero volume calculation 14 | assert utils.get_aabb_volume(lo, lo) == 0.0 15 | # Test ordinary calculation 16 | assert utils.get_aabb_volume(lo, hi) == 1.0 17 | with pytest.raises(AssertionError): 18 | # Test assertion error when lower bound is 19 | # greater than upper bound 20 | lo1 = np.array([10.0, 12.5, 10.0]) 21 | hi1 = np.array([-10.0, -12.5, -10.0]) 22 | assert utils.get_aabb_volume(lo1, hi1) 23 | 24 | 25 | def test_aabb_closest_point(): 26 | """Tests for get_closest_point_on_aabb().""" 27 | # Test ordinary usage 28 | xyz = [1.5, 3.0, -2.5] 29 | lo = np.array([1.0, 1.5, -1.0]) 30 | hi = np.array([2.0, 2.5, 0.0]) 31 | assert utils.get_closest_point_on_aabb(xyz, lo, hi) == [1.5, 2.5, -1.0] 32 | with pytest.raises(AssertionError): 33 | # Test error where lower bound is greater than upper bound. 34 | lo1 = np.array([10.0, 12.5, 10.0]) 35 | hi1 = np.array([-10.0, -12.5, -10.0]) 36 | utils.get_closest_point_on_aabb(xyz, lo1, hi1) 37 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Shared configurations for pytest. 2 | 3 | See https://docs.pytest.org/en/6.2.x/fixture.html. 
4 | """ 5 | 6 | 7 | def pytest_addoption(parser): 8 | """Enable a command line flag for running tests decorated with @longrun.""" 9 | parser.addoption('--longrun', 10 | action='store_true', 11 | dest="longrun", 12 | default=False, 13 | help="enable tests decorated with @longrun") 14 | -------------------------------------------------------------------------------- /tests/envs/test_base_env.py: -------------------------------------------------------------------------------- 1 | """Test cases for the base environment class.""" 2 | 3 | import pytest 4 | from test_oracle_approach import ENV_NAME_AND_CLS # type: ignore 5 | 6 | from predicators import utils 7 | from predicators.envs import BaseEnv, create_new_env, get_or_create_env 8 | 9 | 10 | def test_env_creation(): 11 | """Tests for create_new_env() and get_or_create_env().""" 12 | utils.reset_config({"num_train_tasks": 5, "num_test_tasks": 5}) 13 | for name, _ in ENV_NAME_AND_CLS: 14 | env = create_new_env(name, do_cache=True, use_gui=False) 15 | assert isinstance(env, BaseEnv) 16 | other_env = get_or_create_env(name) 17 | assert env is other_env 18 | train_tasks = env.get_train_tasks() 19 | for idx, train_task in enumerate(train_tasks): 20 | task = env.get_task("train", idx) 21 | assert train_task.init.allclose(task.init) 22 | assert train_task.goal == task.goal 23 | test_tasks = env.get_test_tasks() 24 | for idx, test_task in enumerate(test_tasks): 25 | task = env.get_task("test", idx) 26 | assert test_task.init.allclose(task.init) 27 | assert test_task.goal == task.goal 28 | with pytest.raises(ValueError): 29 | env.get_task("not a real task category", 0) 30 | with pytest.raises(NotImplementedError): 31 | create_new_env("Not a real env") 32 | -------------------------------------------------------------------------------- /tests/explorers/test_exploit_bilevel_planning_explorer.py: -------------------------------------------------------------------------------- 1 | """Test cases for the exploit bilevel planning 
explorer class.""" 2 | 3 | import pytest 4 | 5 | from predicators import utils 6 | from predicators.envs.cover import CoverEnv 7 | from predicators.explorers import BaseExplorer, create_explorer 8 | from predicators.ground_truth_nsrts import get_gt_nsrts 9 | from predicators.option_model import _OracleOptionModel 10 | 11 | 12 | def test_exploit_bilevel_planning_explorer(): 13 | """Tests for ExploitBilevelPlanningExplorer class.""" 14 | utils.reset_config({ 15 | "env": "cover", 16 | "explorer": "exploit_planning", 17 | }) 18 | env = CoverEnv() 19 | nsrts = get_gt_nsrts(env.get_name(), env.predicates, env.options) 20 | option_model = _OracleOptionModel(env) 21 | train_tasks = env.get_train_tasks() 22 | explorer = create_explorer("exploit_planning", env.predicates, env.options, 23 | env.types, env.action_space, train_tasks, nsrts, 24 | option_model) 25 | task_idx = 0 26 | task = train_tasks[task_idx] 27 | policy, termination_function = explorer.get_exploration_strategy( 28 | task_idx, 500) 29 | traj, _ = utils.run_policy( 30 | policy, 31 | env, 32 | "train", 33 | task_idx, 34 | termination_function, 35 | max_num_steps=1000, 36 | ) 37 | final_state = traj.states[-1] 38 | assert termination_function(final_state) 39 | assert task.goal_holds(final_state) 40 | 41 | # Test timeout. Should fall back. 
42 | 43 | class _DummyExplorer(BaseExplorer): 44 | 45 | @classmethod 46 | def get_name(cls): 47 | return "dummy" 48 | 49 | def get_exploration_strategy(self, train_task_idx, timeout): 50 | raise NotImplementedError("Dummy explorer called") 51 | 52 | dummy_explorer = _DummyExplorer(env.predicates, env.options, env.types, 53 | env.action_space, train_tasks) 54 | assert dummy_explorer.get_name() == "dummy" 55 | 56 | explorer._fallback_explorer = dummy_explorer # pylint: disable=protected-access 57 | 58 | with pytest.raises(NotImplementedError) as e: 59 | explorer.get_exploration_strategy(task_idx, -1) 60 | assert "Dummy explorer called" in str(e) 61 | -------------------------------------------------------------------------------- /tests/explorers/test_no_explore_explorer.py: -------------------------------------------------------------------------------- 1 | """Test cases for the no explore explorer class.""" 2 | 3 | import pytest 4 | 5 | from predicators import utils 6 | from predicators.envs.cover import CoverEnv 7 | from predicators.explorers import create_explorer 8 | 9 | 10 | def test_no_explore_explorer(): 11 | """Tests for NoExploreExplorer class.""" 12 | utils.reset_config({ 13 | "env": "cover", 14 | "explorer": "no_explore", 15 | }) 16 | env = CoverEnv() 17 | train_tasks = env.get_train_tasks() 18 | task_idx = 0 19 | task = train_tasks[task_idx] 20 | explorer = create_explorer("no_explore", env.predicates, env.options, 21 | env.types, env.action_space, train_tasks) 22 | policy, termination_function = explorer.get_exploration_strategy( 23 | task_idx, 500) 24 | assert termination_function(task.init) 25 | with pytest.raises(RuntimeError) as e: 26 | policy(task.init) 27 | assert "The policy for no-explore shouldn't be used." 
in str(e) 28 | -------------------------------------------------------------------------------- /tests/explorers/test_random_actions_explorer.py: -------------------------------------------------------------------------------- 1 | """Test cases for the random actions explorer class.""" 2 | 3 | from predicators import utils 4 | from predicators.envs.cover import CoverEnv 5 | from predicators.explorers import create_explorer 6 | 7 | 8 | def test_random_actions_explorer(): 9 | """Tests for RandomActionsExplorer class.""" 10 | utils.reset_config({ 11 | "env": "cover", 12 | "explorer": "random_actions", 13 | }) 14 | env = CoverEnv() 15 | train_tasks = env.get_train_tasks() 16 | task_idx = 0 17 | task = train_tasks[task_idx] 18 | explorer = create_explorer("random_actions", env.predicates, env.options, 19 | env.types, env.action_space, train_tasks) 20 | policy, termination_function = explorer.get_exploration_strategy( 21 | task_idx, 500) 22 | assert not termination_function(task.init) 23 | for _ in range(10): 24 | act = policy(task.init) 25 | assert env.action_space.contains(act.arr) 26 | -------------------------------------------------------------------------------- /tests/explorers/test_random_options_explorer.py: -------------------------------------------------------------------------------- 1 | """Test cases for the random options explorer class.""" 2 | 3 | import pytest 4 | 5 | from predicators import utils 6 | from predicators.envs.cover import CoverEnv 7 | from predicators.explorers import create_explorer 8 | from predicators.structs import ParameterizedOption 9 | 10 | 11 | def test_random_options_explorer(): 12 | """Tests for RandomOptionsExplorer class.""" 13 | utils.reset_config({ 14 | "env": "cover", 15 | "explorer": "random_options", 16 | }) 17 | env = CoverEnv() 18 | train_tasks = env.get_train_tasks() 19 | task = train_tasks[0] 20 | explorer = create_explorer("random_options", env.predicates, env.options, 21 | env.types, env.action_space, train_tasks) 22 
| policy, termination_function = explorer.get_exploration_strategy(task, 500) 23 | assert not termination_function(task.init) 24 | for _ in range(10): 25 | act = policy(task.init) 26 | assert env.action_space.contains(act.arr) 27 | # Test case where no applicable option can be found. 28 | opt = sorted(env.options)[0] 29 | dummy_opt = ParameterizedOption(opt.name, opt.types, opt.params_space, 30 | opt.policy, lambda _1, _2, _3, _4: False, 31 | opt.terminal) 32 | explorer = create_explorer("random_options", env.predicates, {dummy_opt}, 33 | env.types, env.action_space, train_tasks) 34 | policy, _ = explorer.get_exploration_strategy(task, 500) 35 | with pytest.raises(utils.RequestActPolicyFailure) as e: 36 | policy(task.init) 37 | assert "Random option sampling failed!" in str(e) 38 | -------------------------------------------------------------------------------- /tests/nsrt_learning/strips_learning/test_base_strips_learner.py: -------------------------------------------------------------------------------- 1 | """Tests for methods in the BaseSTRIPSLearner class.""" 2 | 3 | import pytest 4 | 5 | from predicators.nsrt_learning.strips_learning.base_strips_learner import \ 6 | BaseSTRIPSLearner 7 | from predicators.structs import PNAD, LowLevelTrajectory, Predicate, Segment, \ 8 | State, STRIPSOperator, Task, Type 9 | from predicators.utils import SingletonParameterizedOption 10 | 11 | 12 | class _MockBaseSTRIPSLearner(BaseSTRIPSLearner): 13 | """Mock class that exposes private methods for testing.""" 14 | 15 | def recompute_datastores_from_segments(self, pnads): 16 | """Exposed for testing.""" 17 | return self._recompute_datastores_from_segments(pnads) 18 | 19 | def _learn(self): 20 | raise Exception("Can't use this") 21 | 22 | @classmethod 23 | def get_name(cls) -> str: 24 | return "dummy_mock_base_strips_learner" 25 | 26 | 27 | def test_recompute_datastores_from_segments(): 28 | """Tests for recompute_datastores_from_segments().""" 29 | obj_type = Type("obj_type", 
["feat"]) 30 | Pred = Predicate("Pred", [obj_type], lambda s, o: s[o[0]][0] > 0.5) 31 | opt_name_to_opt = { 32 | "Act": SingletonParameterizedOption("Act", lambda s, m, o, p: None) 33 | } 34 | obj = obj_type("obj") 35 | var = obj_type("?obj") 36 | state = State({obj: [1.0]}) 37 | act = opt_name_to_opt["Act"].ground([], []) 38 | op1 = STRIPSOperator("Op1", [var], set(), {Pred([var])}, set(), set()) 39 | pnad1 = PNAD(op1, [], (act.parent, [])) 40 | op2 = STRIPSOperator("Op2", [], set(), set(), set(), set()) 41 | pnad2 = PNAD(op2, [], (act.parent, [])) 42 | traj = LowLevelTrajectory([state, state], [act], True, 0) 43 | task = Task(state, set()) 44 | segment = Segment(traj, {Pred([obj])}, {Pred([obj])}, act) 45 | learner = _MockBaseSTRIPSLearner([traj], [task], {Pred}, [[segment]], 46 | verify_harmlessness=True) 47 | with pytest.raises(Exception) as e: 48 | learner.learn() 49 | assert "Can't use this" in str(e) 50 | learner.recompute_datastores_from_segments([pnad1, pnad2]) 51 | assert len(pnad1.datastore) == 0 52 | assert len(pnad2.datastore) == 1 53 | -------------------------------------------------------------------------------- /tests/pybullet_helpers/conftest.py: -------------------------------------------------------------------------------- 1 | """Common fixtures for the pybullet_helpers tests.""" 2 | 3 | import pybullet as p 4 | import pytest 5 | 6 | 7 | @pytest.fixture(scope="function", name="physics_client_id") 8 | def _connect_to_pybullet(): 9 | """Direct connect to PyBullet physics server, and disconnect when we're 10 | done. 11 | 12 | This fixture automatically disconnects the physics server, so we 13 | don't forget to do it ourselves. 
14 | """ 15 | physics_client_id = p.connect(p.DIRECT) 16 | yield physics_client_id 17 | p.disconnect(physics_client_id) 18 | -------------------------------------------------------------------------------- /tests/pybullet_helpers/test_geometry.py: -------------------------------------------------------------------------------- 1 | """Tests for geometry PyBullet helper utilities.""" 2 | 3 | import numpy as np 4 | import pybullet as p 5 | 6 | from predicators.pybullet_helpers.geometry import Pose, get_pose, \ 7 | matrix_from_quat 8 | 9 | 10 | def test_pose(): 11 | """Tests for Pose().""" 12 | position = (5.0, 0.5, 1.0) 13 | orientation = (0.0, 0.0, 0.0, 1.0) 14 | pose = Pose(position, orientation) 15 | rpy = pose.rpy 16 | reconstructed_pose = Pose.from_rpy(position, rpy) 17 | assert pose.allclose(reconstructed_pose) 18 | unit_pose = Pose.identity() 19 | assert not pose.allclose(unit_pose) 20 | multiplied_pose = pose.multiply(unit_pose, unit_pose, unit_pose) 21 | assert pose.allclose(multiplied_pose) 22 | inverted_pose = pose.invert() 23 | assert not pose.allclose(inverted_pose) 24 | assert pose.allclose(inverted_pose.invert()) 25 | 26 | 27 | def test_matrix_from_quat(): 28 | """Tests for matrix_from_quat().""" 29 | mat = matrix_from_quat((0.0, 0.0, 0.0, 1.0)) 30 | assert np.allclose(mat, np.eye(3)) 31 | mat = matrix_from_quat((0.0, 0.0, 0.0, -1.0)) 32 | assert np.allclose(mat, np.eye(3)) 33 | mat = matrix_from_quat((1.0, 0.0, 0.0, 1.0)) 34 | expected_mat = np.array([ 35 | [1.0, 0.0, 0.0], 36 | [0.0, 0.0, -1.0], 37 | [0.0, 1.0, 0.0], 38 | ]) 39 | assert np.allclose(mat, expected_mat) 40 | 41 | 42 | def test_get_pose(physics_client_id): 43 | """Tests for get_pose().""" 44 | collision_id = p.createCollisionShape(p.GEOM_BOX, 45 | halfExtents=[1, 1, 1], 46 | physicsClientId=physics_client_id) 47 | mass = 0 48 | position = (1.0, 0.0, 3.0) 49 | orientation = (0.0, 1.0, 0.0, 0.0) 50 | expected_pose = Pose(position, orientation) 51 | body = p.createMultiBody(mass, 52 | 
collision_id, 53 | basePosition=position, 54 | baseOrientation=orientation, 55 | physicsClientId=physics_client_id) 56 | pose = get_pose(body, physics_client_id) 57 | assert pose.allclose(expected_pose) 58 | -------------------------------------------------------------------------------- /tests/pybullet_helpers/test_joint.py: -------------------------------------------------------------------------------- 1 | """Tests for joint PyBullet helper utilities.""" 2 | 3 | import pybullet as p 4 | 5 | from predicators.pybullet_helpers.joint import JointInfo 6 | 7 | 8 | def test_joint_info(): 9 | """Tests for JointInfo().""" 10 | 11 | fixed_joint_info = JointInfo(jointIndex=0, 12 | jointName="fake-fixed-joint", 13 | jointType=p.JOINT_FIXED, 14 | qIndex=0, 15 | uIndex=0, 16 | flags=0, 17 | jointDamping=0.1, 18 | jointFriction=0.1, 19 | jointLowerLimit=0.0, 20 | jointUpperLimit=1.0, 21 | jointMaxForce=1.0, 22 | jointMaxVelocity=1.0, 23 | linkName="fake-link", 24 | jointAxis=(0.0, 0.0, 0.0), 25 | parentFramePos=(0.0, 0.0, 0.0), 26 | parentFrameOrn=(0.0, 0.0, 0.0, 1.0), 27 | parentIndex=-1) 28 | 29 | assert fixed_joint_info.is_fixed 30 | assert not fixed_joint_info.is_circular 31 | assert not fixed_joint_info.is_movable 32 | assert not fixed_joint_info.violates_limit(0.5) 33 | assert fixed_joint_info.violates_limit(1.1) 34 | 35 | circular_joint_info = JointInfo(jointIndex=0, 36 | jointName="fake-circular-joint", 37 | jointType=p.JOINT_REVOLUTE, 38 | qIndex=0, 39 | uIndex=0, 40 | flags=0, 41 | jointDamping=0.1, 42 | jointFriction=0.1, 43 | jointLowerLimit=1.0, 44 | jointUpperLimit=0.0, 45 | jointMaxForce=1.0, 46 | jointMaxVelocity=1.0, 47 | linkName="fake-link", 48 | jointAxis=(0.0, 0.0, 0.0), 49 | parentFramePos=(0.0, 0.0, 0.0), 50 | parentFrameOrn=(0.0, 0.0, 0.0, 1.0), 51 | parentIndex=-1) 52 | 53 | assert not circular_joint_info.is_fixed 54 | assert circular_joint_info.is_circular 55 | assert circular_joint_info.is_movable 56 | assert not 
circular_joint_info.violates_limit(9999.0) 57 | assert not circular_joint_info.violates_limit(0.0) 58 | -------------------------------------------------------------------------------- /tests/pybullet_helpers/test_link.py: -------------------------------------------------------------------------------- 1 | """Tests for link PyBullet helper utilities.""" 2 | from unittest.mock import call, patch 3 | 4 | import pytest 5 | 6 | import predicators.pybullet_helpers.link 7 | from predicators.pybullet_helpers.geometry import Pose, multiply_poses 8 | from predicators.pybullet_helpers.link import LinkState, get_relative_link_pose 9 | 10 | _MODULE_PATH = predicators.pybullet_helpers.link.__name__ 11 | 12 | 13 | def test_link_state(): 14 | """Tests for LinkState().""" 15 | link_state = LinkState( 16 | linkWorldPosition=(0.0, 1.0, 0.0), 17 | linkWorldOrientation=(0.0, 1.0, 0.0, 1.0), 18 | localInertialFramePosition=(0.0, 0.0, 0.0), 19 | localInertialFrameOrientation=(0.0, 0.0, 0.0, 1.0), 20 | worldLinkFramePosition=(0.0, 0.0, 0.0), 21 | worldLinkFrameOrientation=(0.0, 0.0, 0.0, 1.0), 22 | ) 23 | com_pose = link_state.com_pose 24 | assert com_pose.position == (0.0, 1.0, 0.0) 25 | assert com_pose.orientation == (0.0, 1.0, 0.0, 1.0) 26 | pose = link_state.pose 27 | assert pose.position == (0.0, 0.0, 0.0) 28 | assert pose.orientation == (0.0, 0.0, 0.0, 1.0) 29 | 30 | 31 | @pytest.mark.parametrize("body, link1, link2, physics_client_id", 32 | [(0, 1, 2, 0), (1, 2, 8, 2)]) 33 | def test_get_relative_link_pose(body, link1, link2, physics_client_id): 34 | """Tests for get_relative_link_pose().""" 35 | world_from_link1 = Pose(position=(0.5, 0.5, 0.5), 36 | orientation=(0.0, 0.0, 0.0, 1.0)) 37 | world_from_link2 = Pose(position=(1.0, 1.0, 1.0), 38 | orientation=(0.0, 1.0, 0.0, 1.0)) 39 | link2_from_link1 = multiply_poses(world_from_link2.invert(), 40 | world_from_link1) 41 | 42 | with patch(f"{_MODULE_PATH}.get_link_pose") as mock_get_link_pose: 43 | mock_get_link_pose.side_effect = 
[world_from_link1, world_from_link2] 44 | relative_link_pose = get_relative_link_pose(body, link1, link2, 45 | physics_client_id) 46 | assert relative_link_pose == link2_from_link1 47 | 48 | assert mock_get_link_pose.call_count == 2 49 | mock_get_link_pose.assert_has_calls([ 50 | call(body, link1, physics_client_id), 51 | call(body, link2, physics_client_id) 52 | ]) 53 | -------------------------------------------------------------------------------- /tests/refinement_estimators/test_base_refinement_estimator.py: -------------------------------------------------------------------------------- 1 | """Test cases for the base refinement estimator class.""" 2 | 3 | import pytest 4 | 5 | from predicators.refinement_estimators import BaseRefinementEstimator, \ 6 | create_refinement_estimator 7 | 8 | ESTIMATOR_NAMES = ["oracle"] 9 | 10 | 11 | def test_refinement_estimator_creation(): 12 | """Tests for create_refinement_estimator().""" 13 | for est_name in ESTIMATOR_NAMES: 14 | estimator = create_refinement_estimator(est_name) 15 | assert isinstance(estimator, BaseRefinementEstimator) 16 | assert estimator.get_name() == est_name 17 | with pytest.raises(NotImplementedError): 18 | create_refinement_estimator("non-existent refinement estimator") 19 | -------------------------------------------------------------------------------- /tests/refinement_estimators/test_oracle_refinement_estimator.py: -------------------------------------------------------------------------------- 1 | """Test cases for the oracle refinement cost estimator.""" 2 | 3 | import pytest 4 | 5 | from predicators import utils 6 | from predicators.envs.narrow_passage import NarrowPassageEnv 7 | from predicators.ground_truth_nsrts import get_gt_nsrts 8 | from predicators.refinement_estimators.oracle_refinement_estimator import \ 9 | OracleRefinementEstimator 10 | from predicators.settings import CFG 11 | 12 | 13 | def test_oracle_refinement_estimator(): 14 | """Test general properties of oracle refinement 
cost estimator.""" 15 | utils.reset_config({"env": "non-existent env"}) 16 | estimator = OracleRefinementEstimator() 17 | assert estimator.get_name() == "oracle" 18 | with pytest.raises(NotImplementedError): 19 | estimator.get_cost([], []) 20 | 21 | 22 | def test_narrow_passage_oracle_refinement_estimator(): 23 | """Test oracle refinement cost estimator for narrow_passage env.""" 24 | utils.reset_config({"env": "narrow_passage"}) 25 | estimator = OracleRefinementEstimator() 26 | 27 | # Get env objects and NSRTs 28 | env = NarrowPassageEnv() 29 | door_type, _, robot_type, target_type, _ = sorted(env.types) 30 | sample_state = env.get_train_tasks()[0].init 31 | door, = sample_state.get_objects(door_type) 32 | robot, = sample_state.get_objects(robot_type) 33 | target, = sample_state.get_objects(target_type) 34 | gt_nsrts = get_gt_nsrts(CFG.env, env.predicates, env.options) 35 | move_and_open_door_nsrt, move_to_target_nsrt = sorted(gt_nsrts) 36 | 37 | # Ground NSRTs using objects 38 | ground_move_and_open_door = move_and_open_door_nsrt.ground([robot, door]) 39 | ground_move_to_target = move_to_target_nsrt.ground([robot, target]) 40 | 41 | # Test direct MoveToTarget skeleton 42 | move_direct_skeleton = [ground_move_to_target] 43 | move_direct_cost = estimator.get_cost(move_direct_skeleton, []) 44 | assert move_direct_cost == 3 45 | 46 | # Test open door then move skeleton 47 | move_through_door_skeleton = [ 48 | ground_move_and_open_door, 49 | ground_move_to_target, 50 | ] 51 | move_through_door_cost = estimator.get_cost(move_through_door_skeleton, []) 52 | assert move_through_door_cost == 1 + 1 53 | 54 | # Test open door multiple times then move skeleton 55 | move_door_multiple_skeleton = [ 56 | ground_move_and_open_door, 57 | ground_move_and_open_door, 58 | ground_move_and_open_door, 59 | ground_move_to_target, 60 | ] 61 | move_door_multiple_cost = estimator.get_cost(move_door_multiple_skeleton, 62 | []) 63 | assert move_door_multiple_cost == 4 64 | 65 | # Make sure 
that sorting the costs makes sense 66 | assert sorted([ 67 | move_direct_cost, move_through_door_cost, move_door_multiple_cost 68 | ]) == [move_through_door_cost, move_direct_cost, move_door_multiple_cost] 69 | -------------------------------------------------------------------------------- /tests/test_args.py: -------------------------------------------------------------------------------- 1 | """Tests for args.py.""" 2 | 3 | import sys 4 | 5 | from predicators import utils 6 | 7 | 8 | def test_args(): 9 | """Tests for args.py.""" 10 | sys.argv = [ 11 | "dummy", "--env", "my_env", "--approach", "my_approach", "--seed", 12 | "123" 13 | ] 14 | args = utils.parse_args() 15 | assert args["env"] == "my_env" 16 | assert args["approach"] == "my_approach" 17 | assert args["seed"] == 123 18 | -------------------------------------------------------------------------------- /tests/test_settings.py: -------------------------------------------------------------------------------- 1 | """Test cases for some parts of the settings.py file.""" 2 | 3 | from predicators import utils 4 | from predicators.settings import get_allowed_query_type_names 5 | 6 | 7 | def test_get_allowed_query_type_names(): 8 | """Test the get_allowed_query_type_names method.""" 9 | utils.reset_config() 10 | assert get_allowed_query_type_names() == set() 11 | utils.reset_config({ 12 | "option_learner": "direct_bc", 13 | }) 14 | assert get_allowed_query_type_names() == {"PathToStateQuery"} 15 | utils.reset_config({ 16 | "option_learner": "no_learning", 17 | "approach": "interactive_learning" 18 | }) 19 | assert get_allowed_query_type_names() == {"GroundAtomsHoldQuery"} 20 | utils.reset_config({ 21 | "option_learner": "no_learning", 22 | "approach": "unittest" 23 | }) 24 | assert get_allowed_query_type_names() == { 25 | "GroundAtomsHoldQuery", "DemonstrationQuery", "PathToStateQuery", 26 | "_MockQuery" 27 | } 28 | --------------------------------------------------------------------------------