├── .flake8 ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── code_checks.sh ├── demonstrations ├── sac_ant_traj_len_10_seed_354516857.pkl ├── sac_ant_traj_len_10_seed_68863226.pkl ├── sac_ant_traj_len_10_seed_827609420.pkl ├── sac_ant_traj_len_1_seed_592603083.pkl ├── sac_ant_traj_len_1_seed_60586324.pkl ├── sac_ant_traj_len_1_seed_702401661.pkl ├── sac_ant_traj_len_50_seed_135895263.pkl ├── sac_ant_traj_len_50_seed_238906182.pkl ├── sac_ant_traj_len_50_seed_276687624.pkl ├── sac_cheetah_bw_traj_len_10_seed_200763839.pkl ├── sac_cheetah_bw_traj_len_10_seed_293462553.pkl ├── sac_cheetah_bw_traj_len_10_seed_779465690.pkl ├── sac_cheetah_bw_traj_len_1_seed_531475214.pkl ├── sac_cheetah_bw_traj_len_1_seed_903840226.pkl ├── sac_cheetah_bw_traj_len_1_seed_905588575.pkl ├── sac_cheetah_bw_traj_len_50_seed_222108450.pkl ├── sac_cheetah_bw_traj_len_50_seed_472880492.pkl ├── sac_cheetah_bw_traj_len_50_seed_890295234.pkl ├── sac_cheetah_fw_traj_len_10_seed_102743744.pkl ├── sac_cheetah_fw_traj_len_10_seed_569580504.pkl ├── sac_cheetah_fw_traj_len_10_seed_874088138.pkl ├── sac_cheetah_fw_traj_len_1_seed_22750069.pkl ├── sac_cheetah_fw_traj_len_1_seed_31823986.pkl ├── sac_cheetah_fw_traj_len_1_seed_686078893.pkl ├── sac_cheetah_fw_traj_len_50_seed_559763593.pkl ├── sac_cheetah_fw_traj_len_50_seed_592117357.pkl ├── sac_cheetah_fw_traj_len_50_seed_96575016.pkl ├── sac_hopper_traj_len_10_seed_647574168.pkl ├── sac_hopper_traj_len_10_seed_700423463.pkl ├── sac_hopper_traj_len_10_seed_750917369.pkl ├── sac_hopper_traj_len_1_seed_125775086.pkl ├── sac_hopper_traj_len_1_seed_582846018.pkl ├── sac_hopper_traj_len_1_seed_809943847.pkl ├── sac_hopper_traj_len_50_seed_264928546.pkl ├── sac_hopper_traj_len_50_seed_371855109.pkl ├── sac_hopper_traj_len_50_seed_802647326.pkl ├── sac_pendulum_traj_len_10_seed_177330974.pkl ├── sac_pendulum_traj_len_10_seed_294267497.pkl ├── sac_pendulum_traj_len_10_seed_610830928.pkl ├── sac_pendulum_traj_len_1_seed_556679729.pkl ├── sac_pendulum_traj_len_1_seed_583612201.pkl ├── sac_pendulum_traj_len_1_seed_728082477.pkl ├── sac_pendulum_traj_len_50_seed_196767060.pkl ├── sac_pendulum_traj_len_50_seed_523507457.pkl ├── sac_pendulum_traj_len_50_seed_960406656.pkl ├── sac_reach_true_traj_len_10_seed_183338627.pkl ├── sac_reach_true_traj_len_10_seed_48907083.pkl ├── sac_reach_true_traj_len_10_seed_880581313.pkl ├── sac_reach_true_traj_len_1_seed_326943823.pkl ├── sac_reach_true_traj_len_1_seed_497025194.pkl ├── sac_reach_true_traj_len_1_seed_571661907.pkl ├── sac_reach_true_traj_len_49_traj_n_1000_seed_85509514.pkl ├── sac_reach_true_traj_len_50_seed_44258750.pkl ├── sac_reach_true_traj_len_50_seed_696002925.pkl └── sac_reach_true_traj_len_50_seed_742173558.pkl ├── docker ├── Dockerfile └── environment.yml ├── policies ├── sac_ant_2e6.mp4 ├── sac_ant_2e6.zip ├── sac_cheetah_bw_2e6.mp4 ├── sac_cheetah_bw_2e6.zip ├── sac_cheetah_fw_2e6.mp4 ├── sac_cheetah_fw_2e6.zip ├── sac_hopper_2e6.mp4 ├── sac_hopper_2e6.zip ├── sac_pendulum_6e4.mp4 ├── sac_pendulum_6e4.zip ├── sac_reach_task_2e6.mp4 ├── sac_reach_task_2e6.zip ├── sac_reach_true_2e6.mp4 └── sac_reach_true_2e6.zip ├── scripts ├── create_demonstrations.py ├── evaluate_policy.py ├── evaluate_result_policies.py ├── mujoco_evaluate_inferred_reward.py ├── plot_learning_curves.py ├── run_gail.py ├── toy_experiments.py ├── train_inverse_dynamics.py ├── train_sac.py └── train_vae.py ├── setup.py ├── skills ├── balancing.mp4 ├── balancing_rollouts.pkl ├── jumping.mp4 └── jumping_rollouts.pkl └── src └── deep_rlsp ├── __init__.py ├── 
ablation_AverageFeatures.py ├── ablation_Waypoints.py ├── envs ├── __init__.py ├── gridworlds │ ├── __init__.py │ ├── apples.py │ ├── apples_spec.py │ ├── basic_room.py │ ├── batteries.py │ ├── batteries_spec.py │ ├── env.py │ ├── gym_envs.py │ ├── one_hot_action_space_wrapper.py │ ├── room.py │ ├── room_spec.py │ ├── tests │ │ ├── apples_test.py │ │ ├── batteries_test.py │ │ ├── common.py │ │ ├── env_test.py │ │ ├── room_test.py │ │ ├── test_observation_spaces.py │ │ └── train_test.py │ ├── train.py │ └── train_spec.py ├── mujoco │ ├── __init__.py │ ├── ant.py │ ├── assets │ │ ├── ant.xml │ │ ├── ant_footsensor.xml │ │ ├── ant_plot.xml │ │ ├── half_cheetah.xml │ │ └── half_cheetah_plot.xml │ └── half_cheetah.py ├── reward_wrapper.py └── robotics │ ├── __init__.py │ ├── assets │ ├── LICENSE.md │ ├── fetch │ │ ├── reach.xml │ │ ├── robot.xml │ │ └── shared.xml │ ├── stls │ │ ├── .get │ │ └── fetch │ │ │ ├── base_link_collision.stl │ │ │ ├── bellows_link_collision.stl │ │ │ ├── elbow_flex_link_collision.stl │ │ │ ├── estop_link.stl │ │ │ ├── forearm_roll_link_collision.stl │ │ │ ├── gripper_link.stl │ │ │ ├── head_pan_link_collision.stl │ │ │ ├── head_tilt_link_collision.stl │ │ │ ├── l_wheel_link_collision.stl │ │ │ ├── laser_link.stl │ │ │ ├── r_wheel_link_collision.stl │ │ │ ├── shoulder_lift_link_collision.stl │ │ │ ├── shoulder_pan_link_collision.stl │ │ │ ├── torso_fixed_link.stl │ │ │ ├── torso_lift_link_collision.stl │ │ │ ├── upperarm_roll_link_collision.stl │ │ │ ├── wrist_flex_link_collision.stl │ │ │ └── wrist_roll_link_collision.stl │ └── textures │ │ ├── block.png │ │ └── block_hidden.png │ └── reach_side_effects.py ├── latent_rlsp.py ├── learn_dynamics_model.py ├── model ├── __init__.py ├── base.py ├── dynamics_mdn.py ├── dynamics_mlp.py ├── exact_dynamics_mujoco.py ├── experience_replay.py ├── gridworlds_feature_space.py ├── inverse_model_env_wrapper.py ├── inverse_policy_mdn.py ├── latent_space.py ├── mujoco_debug_models.py ├── rssm.py ├── state_vae.py └── tabular.py ├── policy_discriminator.py ├── relative_reachability.py ├── rlsp.py ├── run.py ├── run_mujoco.py ├── sampling.py ├── solvers ├── __init__.py ├── ppo.py └── value_iter.py ├── tests └── test_toy_experiments.py └── util ├── __init__.py ├── dist.py ├── helper.py ├── linalg.py ├── mujoco.py ├── parameter_checks.py ├── probs.py ├── results.py ├── timer.py ├── train.py └── video.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | # See https://github.com/PyCQA/pycodestyle/issues/373 4 | extend-ignore = E203, 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.swp 3 | gail_logs/ 4 | output/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | .envrc 112 | results/ 113 | tf_ckpt 114 | mjkey.txt 115 | 116 | # log files 117 | *.out 118 | *.log 119 | 120 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | dist: xenial 3 | cache: pip 4 | python: 5 | - 3.7 6 | before_install: 7 | - python$PY -m pip install --upgrade pip setuptools wheel 8 | install: 9 | # hotfix: ray 0.8.0 causes the jobs to not be properly executed on travis 10 | # (but it works on other machines) 11 | - pip install ray==0.7.6 12 | - pip install black mypy flake8 13 | script: 14 | - bash code_checks.sh 15 | - python setup.py test 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Center for Human-Compatible AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /code_checks.sh: -------------------------------------------------------------------------------- 1 | black --check src 2 | black --check scripts 3 | flake8 src 4 | #flake8 scripts 5 | mypy --ignore-missing-imports src 6 | -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_10_seed_354516857.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_10_seed_354516857.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_10_seed_68863226.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_10_seed_68863226.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_10_seed_827609420.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_10_seed_827609420.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_1_seed_592603083.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_1_seed_592603083.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_1_seed_60586324.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_1_seed_60586324.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_1_seed_702401661.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_1_seed_702401661.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_50_seed_135895263.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_50_seed_135895263.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_50_seed_238906182.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_50_seed_238906182.pkl -------------------------------------------------------------------------------- /demonstrations/sac_ant_traj_len_50_seed_276687624.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_50_seed_276687624.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_10_seed_200763839.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_10_seed_200763839.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_10_seed_293462553.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_10_seed_293462553.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_10_seed_779465690.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_10_seed_779465690.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_1_seed_531475214.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_1_seed_531475214.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_1_seed_903840226.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_1_seed_903840226.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_1_seed_905588575.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_1_seed_905588575.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_50_seed_222108450.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_50_seed_222108450.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_50_seed_472880492.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_50_seed_472880492.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_bw_traj_len_50_seed_890295234.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_50_seed_890295234.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_10_seed_102743744.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_10_seed_102743744.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_10_seed_569580504.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_10_seed_569580504.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_10_seed_874088138.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_10_seed_874088138.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_1_seed_22750069.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_1_seed_22750069.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_1_seed_31823986.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_1_seed_31823986.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_1_seed_686078893.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_1_seed_686078893.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_50_seed_559763593.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_50_seed_559763593.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_50_seed_592117357.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_50_seed_592117357.pkl -------------------------------------------------------------------------------- /demonstrations/sac_cheetah_fw_traj_len_50_seed_96575016.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_50_seed_96575016.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_10_seed_647574168.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_10_seed_647574168.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_10_seed_700423463.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_10_seed_700423463.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_10_seed_750917369.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_10_seed_750917369.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_1_seed_125775086.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_1_seed_125775086.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_1_seed_582846018.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_1_seed_582846018.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_1_seed_809943847.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_1_seed_809943847.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_50_seed_264928546.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_50_seed_264928546.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_50_seed_371855109.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_50_seed_371855109.pkl -------------------------------------------------------------------------------- /demonstrations/sac_hopper_traj_len_50_seed_802647326.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_50_seed_802647326.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_10_seed_177330974.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_10_seed_177330974.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_10_seed_294267497.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_10_seed_294267497.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_10_seed_610830928.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_10_seed_610830928.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_1_seed_556679729.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_1_seed_556679729.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_1_seed_583612201.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_1_seed_583612201.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_1_seed_728082477.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_1_seed_728082477.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_50_seed_196767060.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_50_seed_196767060.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_50_seed_523507457.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_50_seed_523507457.pkl -------------------------------------------------------------------------------- /demonstrations/sac_pendulum_traj_len_50_seed_960406656.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_50_seed_960406656.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_10_seed_183338627.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_10_seed_183338627.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_10_seed_48907083.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_10_seed_48907083.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_10_seed_880581313.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_10_seed_880581313.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_1_seed_326943823.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_1_seed_326943823.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_1_seed_497025194.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_1_seed_497025194.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_1_seed_571661907.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_1_seed_571661907.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_49_traj_n_1000_seed_85509514.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_49_traj_n_1000_seed_85509514.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_50_seed_44258750.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_50_seed_44258750.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_50_seed_696002925.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_50_seed_696002925.pkl -------------------------------------------------------------------------------- /demonstrations/sac_reach_true_traj_len_50_seed_742173558.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_50_seed_742173558.pkl -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | #FROM tensorflow/tensorflow:1.13.2-gpu 2 | FROM tensorflow/tensorflow:1.13.2 3 | 4 | # Install Anaconda 5 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 6 | ENV PATH /opt/conda/bin:$PATH 7 | 8 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \ 9 | libglib2.0-0 libxext6 libsm6 libxrender1 \ 10 | git mercurial subversion 11 | 12 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.2-Linux-x86_64.sh -O ~/anaconda.sh && \ 13 | /bin/bash ~/anaconda.sh -b -p /opt/conda && \ 14 | rm ~/anaconda.sh && \ 15 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 16 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 17 | echo "conda activate base" >> ~/.bashrc 18 | 19 | # Install Mujoco 20 | RUN apt-get update -q \ 21 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 22 | curl \ 23 | git \ 24 | libgl1-mesa-dev \ 25 | libgl1-mesa-glx \ 26 | libglew-dev \ 27 | libosmesa6-dev \ 28 | software-properties-common \ 29 | net-tools \ 30 | unzip \ 31 | vim \ 32 | virtualenv \ 33 | wget \ 34 | xpra \ 35 | xserver-xorg-dev \ 36 | patchelf \ 37 | && apt-get clean \ 38 | && rm -rf /var/lib/apt/lists/* 39 | 40 | RUN mkdir -p /root/.mujoco \ 41 | && wget https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip \ 42 | && unzip mujoco.zip -d /root/.mujoco \ 43 | && mv /root/.mujoco/mujoco200_linux /root/.mujoco/mujoco200 \ 44 | && rm mujoco.zip 45 | COPY ./mjkey.txt /root/.mujoco/ 46 | ENV LD_LIBRARY_PATH /root/.mujoco/mujoco200/bin:${LD_LIBRARY_PATH} 47 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH} 48 | 49 | RUN conda install -c menpo osmesa 50 | 51 | # Set up environment 52 | RUN mkdir /deep-rlsp 53 | COPY environment.yml /deep-rlsp/ 54 | RUN conda env create -f /deep-rlsp/environment.yml 55 | #RUN conda run -n deep-rlsp pip uninstall tensorflow 56 | #RUN conda run -n deep-rlsp pip install tensorflow-gpu==1.13.2 57 | ENV PYTHONPATH "${PYTHONPATH}:/deep-rlsp/src" 58 | -------------------------------------------------------------------------------- /docker/environment.yml: -------------------------------------------------------------------------------- 1 | name: deep-rlsp 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - attrs=19.3.0=py_0 7 | - backcall=0.1.0=py37_0 8 | - bleach=3.1.0=py_0 9 | - ca-certificates=2019.11.27=0 10 | - certifi=2019.11.28=py37_0 11 | - decorator=4.4.1=py_0 12 | - defusedxml=0.6.0=py_0 13 | - gmp=6.1.2=h6c8ec71_1 14 | - importlib_metadata=1.3.0=py37_0 15 | - ipykernel=5.1.3=py37h39e3cac_1 16 | - ipython=7.11.1=py37h39e3cac_0 17 | - ipython_genutils=0.2.0=py37_0 18 | - jedi=0.15.2=py37_0 19 | - jinja2=2.10.3=py_0 20 | - jsonschema=3.2.0=py37_0 21 | - jupyter_client=5.3.4=py37_0 22 | - jupyter_core=4.6.1=py37_0 23 | - 
ld_impl_linux-64=2.33.1=h53a641e_7 24 | - libedit=3.1.20181209=hc058e9b_0 25 | - libffi=3.2.1=hd88cf55_4 26 | - libgcc-ng=9.1.0=hdf63c60_0 27 | - libsodium=1.0.16=h1bed415_0 28 | - libstdcxx-ng=9.1.0=hdf63c60_0 29 | - markupsafe=1.1.1=py37h7b6447c_0 30 | - mistune=0.8.4=py37h7b6447c_0 31 | - nb_conda=2.2.1=py37_0 32 | - nb_conda_kernels=2.2.2=py37_0 33 | - nbconvert=5.6.1=py37_0 34 | - nbformat=4.4.0=py37_0 35 | - ncurses=6.1=he6710b0_1 36 | - notebook=6.0.2=py37_0 37 | - openssl=1.1.1d=h7b6447c_3 38 | - pandoc=2.2.3.2=0 39 | - pandocfilters=1.4.2=py37_1 40 | - parso=0.5.2=py_0 41 | - pip=19.3.1=py37_0 42 | - prometheus_client=0.7.1=py_0 43 | - prompt_toolkit=3.0.2=py_0 44 | - pygments=2.5.2=py_0 45 | - python=3.7.6=h0371630_2 46 | - python-dateutil=2.8.1=py_0 47 | - pyzmq=18.1.0=py37he6710b0_0 48 | - readline=7.0=h7b6447c_5 49 | - send2trash=1.5.0=py37_0 50 | - setuptools=44.0.0=py37_0 51 | - sqlite=3.30.1=h7b6447c_0 52 | - terminado=0.8.3=py37_0 53 | - testpath=0.4.4=py_0 54 | - tk=8.6.8=hbc83047_0 55 | - tornado=6.0.3=py37h7b6447c_0 56 | - traitlets=4.3.3=py37_0 57 | - webencodings=0.5.1=py37_1 58 | - wheel=0.33.6=py37_0 59 | - xz=5.2.4=h14c3975_4 60 | - zeromq=4.3.1=he6710b0_3 61 | - zlib=1.2.11=h7b6447c_3 62 | - pip: 63 | - absl-py==0.9.0 64 | - appdirs==1.4.3 65 | - astor==0.8.1 66 | - atari-py==0.2.6 67 | - black==19.10b0 68 | - cffi==1.13.2 69 | - chardet==3.0.4 70 | - click==7.0 71 | - cloudpickle==1.2.2 72 | - colorama==0.4.3 73 | - cycler==0.10.0 74 | - cython==0.29.14 75 | - docopt==0.6.2 76 | - entrypoints==0.3 77 | - fasteners==0.15 78 | - filelock==3.0.12 79 | - flake8==3.7.9 80 | - funcsigs==1.0.2 81 | - future==0.18.2 82 | - gast==0.2.2 83 | - gitdb2==2.0.6 84 | - gitpython==3.0.5 85 | - glfw==1.10.1 86 | - google-pasta==0.2.0 87 | - grpcio==1.26.0 88 | - gym==0.15.4 89 | - h5py==2.10.0 90 | - idna==2.8 91 | - imageio==2.6.1 92 | - imageio-ffmpeg==0.4.1 93 | - importlib-metadata==1.4.0 94 | - ipdb==0.12.3 95 | - ipython-genutils==0.2.0 96 | - joblib==0.14.1 97 | - jsonpickle==0.9.6 98 | - keras-applications==1.0.8 99 | - keras-preprocessing==1.1.0 100 | - kiwisolver==1.1.0 101 | - markdown==3.1.1 102 | - matplotlib==3.1.2 103 | - mccabe==0.6.1 104 | - mock==3.0.5 105 | - monotonic==1.5 106 | - more-itertools==8.1.0 107 | - munch==2.5.0 108 | - mujoco-py==2.0.2.9 109 | - mypy==0.770 110 | - mypy-extensions==0.4.3 111 | - numpy==1.20.1 112 | - opencv-python==4.1.2.30 113 | - opt-einsum==3.2.1 114 | - packaging==20.0 115 | - pandas==0.25.3 116 | - pathspec==0.7.0 117 | - pexpect==4.8.0 118 | - pickleshare==0.7.5 119 | - pillow==7.0.0 120 | - pluggy==0.13.1 121 | - protobuf==3.11.2 122 | - ptyprocess==0.6.0 123 | - py==1.8.1 124 | - py-cpuinfo==5.0.0 125 | - pycodestyle==2.5.0 126 | - pycparser==2.19 127 | - pyflakes==2.1.1 128 | - pyglet==1.3.1 129 | - pyopengl==3.1.5 130 | - pyparsing==2.4.6 131 | - pyrsistent==0.15.7 132 | - pytest==5.3.4 133 | - pytz==2019.3 134 | - pyyaml==5.3 135 | - ray==0.8.0 136 | - redis==3.3.11 137 | - regex==2020.2.20 138 | - requests==2.22.0 139 | - sacred==0.7.4 140 | - scipy==1.4.1 141 | - seaborn==0.9.0 142 | - six==1.14.0 143 | - smmap2==2.0.5 144 | - stable-baselines==2.9.0 145 | - tensorboard==1.13.1 146 | - tensorflow==1.13.2 147 | - tensorflow-estimator==1.13.0 148 | - tensorflow-probability==0.6.0 149 | - termcolor==1.1.0 150 | - toml==0.10.0 151 | - typed-ast==1.4.1 152 | - typing-extensions==3.7.4.2 153 | - urllib3==1.25.8 154 | - wcwidth==0.1.8 155 | - werkzeug==0.16.0 156 | - wrapt==1.11.2 157 | - zipp==2.0.0 158 | prefix: 
/home/david/anaconda3/envs/deep-rlsp 159 | -------------------------------------------------------------------------------- /policies/sac_ant_2e6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_ant_2e6.mp4 -------------------------------------------------------------------------------- /policies/sac_ant_2e6.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_ant_2e6.zip -------------------------------------------------------------------------------- /policies/sac_cheetah_bw_2e6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_bw_2e6.mp4 -------------------------------------------------------------------------------- /policies/sac_cheetah_bw_2e6.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_bw_2e6.zip -------------------------------------------------------------------------------- /policies/sac_cheetah_fw_2e6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_fw_2e6.mp4 -------------------------------------------------------------------------------- /policies/sac_cheetah_fw_2e6.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_fw_2e6.zip -------------------------------------------------------------------------------- /policies/sac_hopper_2e6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_hopper_2e6.mp4 -------------------------------------------------------------------------------- /policies/sac_hopper_2e6.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_hopper_2e6.zip -------------------------------------------------------------------------------- /policies/sac_pendulum_6e4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_pendulum_6e4.mp4 -------------------------------------------------------------------------------- /policies/sac_pendulum_6e4.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_pendulum_6e4.zip -------------------------------------------------------------------------------- /policies/sac_reach_task_2e6.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_task_2e6.mp4 -------------------------------------------------------------------------------- /policies/sac_reach_task_2e6.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_task_2e6.zip -------------------------------------------------------------------------------- /policies/sac_reach_true_2e6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_true_2e6.mp4 -------------------------------------------------------------------------------- /policies/sac_reach_true_2e6.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_true_2e6.zip -------------------------------------------------------------------------------- /scripts/create_demonstrations.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import numpy as np 3 | import tensorflow as tf 4 | import gym 5 | import pickle 6 | import sys 7 | 8 | from stable_baselines import SAC 9 | 10 | from imitation.data import types 11 | from imitation.data.rollout import unwrap_traj 12 | import deep_rlsp 13 | 14 | 15 | def convert_trajs(filename, traj_len): 16 | with open(filename, "rb") as f: 17 | data = pickle.load(f) 18 | 19 | assert traj_len < len(data["observations"][0]) 20 | obs = np.array(data["observations"][0][: traj_len + 1]) 21 | acts = np.array(data["actions"][0][:traj_len]) 22 | rews = np.array([0 for _ in range(traj_len)]) 23 | infos = [{} for _ in range(traj_len)] 24 | traj = types.TrajectoryWithRew(obs=obs, acts=acts, infos=infos, rews=rews) 25 | return [traj] 26 | 27 | 28 | def rollout_policy(filename, traj_len, seed, env_name, n_trajs=1): 29 | model = SAC.load(filename) 30 | env = gym.make(env_name) 31 | env.seed(seed) 32 | 33 | trajs = [] 34 | for _ in range(int(n_trajs)): 35 | obs_list, acts_list, rews_list = [], [], [] 36 | obs = env.reset() 37 | obs_list.append(obs) 38 | for _ in range(traj_len): 39 | act = model.predict(obs, deterministic=True)[0] 40 | obs, r, done, _ = env.step(act) 41 | # assert not done 42 | acts_list.append(act) 43 | obs_list.append(obs) 44 | rews_list.append(r) 45 | 46 | infos = [{} for _ in range(traj_len)] 47 | traj = types.TrajectoryWithRew( 48 | obs=np.array(obs_list), 49 | acts=np.array(acts_list), 50 | infos=infos, 51 | rews=np.array(rews_list), 52 | ) 53 | trajs.append(traj) 54 | 55 | return trajs 56 | 57 | 58 | def recode_and_save_trajectories(traj_or_policy_file, save_loc, traj_len, seed, args): 59 | if "skills" in traj_or_policy_file: 60 | trajs = convert_trajs(traj_or_policy_file, traj_len, *args) 61 | else: 62 | trajs = rollout_policy(traj_or_policy_file, traj_len, seed, *args) 63 | 64 | # assert len(trajs) == 1 65 | for traj in trajs: 66 | assert traj.obs.shape[0] == traj_len + 1 67 | assert traj.acts.shape[0] == traj_len 68 | trajs = [dataclasses.replace(traj, infos=None) for traj in trajs] 69 | types.save(save_loc, trajs) 70 | 71 | 72 | if __name__ == "__main__": 73 | _, traj_or_policy_file, save_loc, traj_len, seed = sys.argv[:5] 74 | if seed == 
"generate_seed": 75 | seed = np.random.randint(0, 1e9) 76 | else: 77 | seed = int(seed) 78 | save_loc = save_loc.format(traj_len, seed) 79 | np.random.seed(seed) 80 | tf.random.set_random_seed(seed) 81 | traj_len = int(traj_len) 82 | recode_and_save_trajectories( 83 | traj_or_policy_file, save_loc, traj_len, seed, sys.argv[5:] 84 | ) 85 | -------------------------------------------------------------------------------- /scripts/evaluate_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import numpy as np 5 | import gym 6 | 7 | from stable_baselines import PPO2, SAC 8 | 9 | from deep_rlsp.util.video import render_mujoco_from_obs, save_video 10 | from deep_rlsp.model.mujoco_debug_models import MujocoDebugFeatures 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("policy_file", type=str) 16 | parser.add_argument("policy_type", type=str) 17 | parser.add_argument("envname", type=str) 18 | parser.add_argument("--render", action="store_true") 19 | parser.add_argument("--out_video", type=str, default=None) 20 | parser.add_argument( 21 | "--num_rollouts", type=int, default=1, help="Number of expert rollouts" 22 | ) 23 | return parser.parse_args() 24 | 25 | 26 | def main(): 27 | args = parse_args() 28 | env = gym.make(args.envname) 29 | 30 | print("loading and building expert policy") 31 | 32 | if args.policy_type == "ppo": 33 | model = PPO2.load(args.policy_file) 34 | 35 | def get_action(obs): 36 | return model.predict(obs)[0] 37 | 38 | elif args.policy_type == "sac": 39 | model = SAC.load(args.policy_file) 40 | 41 | def get_action(obs): 42 | return model.predict(obs, deterministic=True)[0] 43 | 44 | elif args.policy_type == "gail": 45 | from imitation.policies import serialize 46 | from stable_baselines.common.vec_env import DummyVecEnv 47 | 48 | venv = DummyVecEnv([lambda: env]) 49 | loading_context = serialize.load_policy("ppo2", args.policy_file, venv) 50 | model = loading_context.__enter__() 51 | 52 | def get_action(obs): 53 | return model.step(np.reshape(obs, (1, -1)))[0] 54 | 55 | else: 56 | raise NotImplementedError() 57 | 58 | # # env.unwrapped.viewer.cam.trackbodyid = 0 59 | # env.unwrapped.viewer.cam.fixedcamid = 0 60 | 61 | returns = [] 62 | observations = [] 63 | render_params = {} 64 | # render_params = {"width": 4000, "height":1000, "camera_id": -1} 65 | timesteps = 1000 66 | 67 | for i in range(args.num_rollouts): 68 | print("iter", i) 69 | if args.out_video is not None: 70 | rgbs = [] 71 | obs = env.reset() 72 | done = False 73 | totalr = 0.0 74 | steps = 0 75 | for _ in range(timesteps): 76 | # print(obs) 77 | action = get_action(obs) 78 | action = env.action_space.sample() 79 | observations.append(obs) 80 | 81 | last_done = done 82 | obs, r, done, info = env.step(action) 83 | 84 | if not last_done: 85 | totalr += r 86 | steps += 1 87 | if args.render: 88 | env.render(mode="human", **render_params) 89 | if args.out_video is not None: 90 | rgb = env.render(mode="rgb_array", **render_params) 91 | rgbs.append(rgb) 92 | if steps % 100 == 0: 93 | print("%i/%i" % (steps, env.spec.max_episode_steps)) 94 | if steps >= env.spec.max_episode_steps: 95 | break 96 | print("return", totalr) 97 | returns.append(totalr) 98 | 99 | if args.out_video is not None: 100 | save_video(rgbs, args.out_video, fps=20.0) 101 | 102 | print("returns", returns) 103 | print("mean return", np.mean(returns)) 104 | print("std of return", np.std(returns)) 105 | 106 | 107 | if __name__ 
== "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /scripts/evaluate_result_policies.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import gym 6 | 7 | from stable_baselines import PPO2, SAC 8 | 9 | from deep_rlsp.util.video import render_mujoco_from_obs, save_video 10 | from deep_rlsp.model import StateVAE 11 | from deep_rlsp.util.results import FileExperimentResults 12 | 13 | 14 | def evaluate_policy(policy_file, policy_type, envname, num_rollouts): 15 | if policy_type == "ppo": 16 | model = PPO2.load(policy_file) 17 | 18 | def get_action(obs): 19 | return model.predict(obs)[0] 20 | 21 | elif policy_type == "sac": 22 | model = SAC.load(policy_file) 23 | 24 | def get_action(obs): 25 | return model.predict(obs, deterministic=True)[0] 26 | 27 | else: 28 | raise NotImplementedError() 29 | 30 | env = gym.make(envname) 31 | 32 | returns = [] 33 | for i in range(num_rollouts): 34 | # print("iter", i, end=" ") 35 | obs = env.reset() 36 | done = False 37 | totalr = 0.0 38 | while not done: 39 | action = get_action(obs) 40 | obs, r, done, _ = env.step(action) 41 | totalr += r 42 | returns.append(totalr) 43 | 44 | return np.mean(returns), np.std(returns) 45 | 46 | 47 | def parse_args(): 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("results_folder", type=str) 50 | parser.add_argument("envname", type=str) 51 | parser.add_argument("--policy_type", type=str, default="sac") 52 | parser.add_argument( 53 | "--num_rollouts", type=int, default=1, help="Number of expert rollouts" 54 | ) 55 | return parser.parse_args() 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | 61 | base, root_dirs, _ = next(os.walk(args.results_folder)) 62 | root_dirs = [os.path.join(base, dir) for dir in root_dirs] 63 | for root in root_dirs: 64 | print(root) 65 | _, dirs, files = next(os.walk(root)) 66 | if "policy.zip" in files: 67 | policy_path = os.path.join(root, "policy.zip") 68 | elif any([f.startswith("rlsp_policy") for f in files]): 69 | policy_files = [f for f in files if f.startswith("rlsp_policy")] 70 | policy_numbers = [int(f.split(".")[0].split("_")[2]) for f in policy_files] 71 | # take second to last policy, in case a run crashed while writing 72 | # the last policy 73 | policy_file = f"rlsp_policy_{max(policy_numbers)-1}.zip" 74 | policy_path = policy_path = os.path.join(root, policy_file) 75 | else: 76 | policy_path = None 77 | 78 | if policy_path is not None: 79 | mean, std = evaluate_policy( 80 | policy_path, args.policy_type, args.envname, args.num_rollouts 81 | ) 82 | print(policy_path) 83 | print(f"mean {mean} std {std}") 84 | 85 | 86 | if __name__ == "__main__": 87 | main() 88 | -------------------------------------------------------------------------------- /scripts/mujoco_evaluate_inferred_reward.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate an inferred reward function by using it to train a policy in the original env. 
3 | """ 4 | 5 | import argparse 6 | import datetime 7 | 8 | import cv2 9 | import gym 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import tensorflow as tf 13 | 14 | from sacred import Experiment 15 | from sacred.observers import FileStorageObserver, RunObserver 16 | 17 | from stable_baselines import SAC 18 | from stable_baselines.sac.policies import MlpPolicy as MlpPolicySac 19 | 20 | from deep_rlsp.util.results import Artifact, FileExperimentResults 21 | from deep_rlsp.model import StateVAE 22 | from deep_rlsp.envs.reward_wrapper import LatentSpaceRewardWrapper 23 | from deep_rlsp.util.video import render_mujoco_from_obs 24 | from deep_rlsp.util.helper import get_trajectory, evaluate_policy 25 | from deep_rlsp.model.mujoco_debug_models import MujocoDebugFeatures, PendulumDynamics 26 | from deep_rlsp.solvers import get_sac 27 | 28 | # changes the run _id and thereby the path that the FileStorageObserver 29 | # writes the results 30 | # cf. https://github.com/IDSIA/sacred/issues/174 31 | class SetID(RunObserver): 32 | priority = 50 # very high priority to set id 33 | 34 | def started_event( 35 | self, ex_info, command, host_info, start_time, config, meta_info, _id 36 | ): 37 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 38 | label = config["experiment_folder"].strip("/").split("/")[-1] 39 | custom_id = "{}_{}".format(timestamp, label) 40 | return custom_id # started_event returns the _run._id 41 | 42 | 43 | ex = Experiment("mujoco-eval") 44 | ex.observers = [SetID(), FileStorageObserver.create("results/mujoco/eval")] 45 | 46 | 47 | def print_rollout(env, policy, latent_space, decode=False): 48 | state = env.reset() 49 | done = False 50 | while not done: 51 | a, _ = policy.predict(state, deterministic=False) 52 | state, reward, done, info = env.step(a) 53 | if decode: 54 | obs = latent_space.decoder(state) 55 | else: 56 | obs = state 57 | print("action", a) 58 | print("obs", obs) 59 | print("reward", reward) 60 | 61 | 62 | @ex.config 63 | def config(): 64 | experiment_folder = None # noqa:F841 65 | iteration = -1 # noqa:F841 66 | 67 | 68 | @ex.automain 69 | def main(_run, experiment_folder, iteration, seed): 70 | ex = FileExperimentResults(experiment_folder) 71 | env_id = ex.config["env_id"] 72 | env = gym.make(env_id) 73 | 74 | if env_id == "InvertedPendulum-v2": 75 | iterations = int(6e4) 76 | else: 77 | iterations = int(2e6) 78 | 79 | label = experiment_folder.strip("/").split("/")[-1] 80 | 81 | if ex.config["debug_handcoded_features"]: 82 | latent_space = MujocoDebugFeatures(env) 83 | else: 84 | graph_latent = tf.Graph() 85 | latent_model_path = ex.info["latent_model_checkpoint"] 86 | with graph_latent.as_default(): 87 | latent_space = StateVAE.restore(latent_model_path) 88 | 89 | r_inferred = ex.info["inferred_rewards"][iteration] 90 | r_inferred /= np.linalg.norm(r_inferred) 91 | 92 | print("r_inferred") 93 | print(r_inferred) 94 | if env_id.startswith("Fetch"): 95 | env_has_task_reward = True 96 | inferred_weight = 0.1 97 | else: 98 | env_has_task_reward = False 99 | inferred_weight = None 100 | 101 | env_inferred = LatentSpaceRewardWrapper( 102 | env, 103 | latent_space, 104 | r_inferred, 105 | inferred_weight=inferred_weight, 106 | use_task_reward=env_has_task_reward, 107 | ) 108 | 109 | policy_inferred = get_sac(env_inferred) 110 | policy_inferred.learn(total_timesteps=iterations, log_interval=10) 111 | with Artifact(f"policy.zip", None, _run) as f: 112 | policy_inferred.save(f) 113 | 114 | print_rollout(env_inferred, policy_inferred, 
latent_space) 115 | 116 | N = 10 117 | true_reward_obtained = evaluate_policy(env, policy_inferred, N) 118 | print("Inferred reward policy: true return", true_reward_obtained) 119 | if env_has_task_reward: 120 | env.use_penalty = False 121 | task_reward_obtained = evaluate_policy(env, policy_inferred, N) 122 | print("Inferred reward policy: task return", task_reward_obtained) 123 | env.use_penalty = True 124 | with Artifact(f"video.mp4", None, _run) as f: 125 | inferred_reward_obtained = evaluate_policy( 126 | env_inferred, policy_inferred, N, video_out=f 127 | ) 128 | print("Inferred reward policy: inferred return", inferred_reward_obtained) 129 | 130 | good_policy_path = ex.config["good_policy_path"] 131 | if good_policy_path is not None: 132 | true_reward_policy = SAC.load(good_policy_path) 133 | good_policy_true_reward_obtained = evaluate_policy(env, true_reward_policy, N) 134 | print("True reward policy: true return", good_policy_true_reward_obtained) 135 | if env_has_task_reward: 136 | env.use_penalty = False 137 | good_policy_task_reward_obtained = evaluate_policy( 138 | env, true_reward_policy, N 139 | ) 140 | print("True reward policy: task return", good_policy_task_reward_obtained) 141 | env.use_penalty = True 142 | good_policy_inferred_reward_obtained = evaluate_policy( 143 | env_inferred, true_reward_policy, N 144 | ) 145 | print( 146 | "True reward policy: inferred return", good_policy_inferred_reward_obtained 147 | ) 148 | 149 | random_policy = SAC(MlpPolicySac, env_inferred, verbose=1) 150 | random_policy_true_reward_obtained = evaluate_policy(env, random_policy, N) 151 | print("Random policy: true return", random_policy_true_reward_obtained) 152 | if env_has_task_reward: 153 | env.use_penalty = False 154 | random_policy_task_reward_obtained = evaluate_policy(env, random_policy, N) 155 | print("Random reward policy: task return", random_policy_task_reward_obtained) 156 | env.use_penalty = True 157 | random_policy_inferred_reward_obtained = evaluate_policy( 158 | env_inferred, random_policy, N 159 | ) 160 | print("Random policy: inferred return", random_policy_inferred_reward_obtained) 161 | print() 162 | -------------------------------------------------------------------------------- /scripts/plot_learning_curves.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import matplotlib.pyplot as plt 4 | import matplotlib 5 | import seaborn as sns 6 | import numpy as np 7 | 8 | 9 | def moving_average(a, n=3): 10 | ret = np.cumsum(a, dtype=float) 11 | ret[n:] = ret[n:] - ret[:-n] 12 | return ret[n - 1 :] / n 13 | 14 | 15 | def main(): 16 | n_samples = 1 17 | skill = "balancing" 18 | loss_files = [ 19 | f"discriminator_cheetah_{skill}_{n_samples}_gail.pkl", 20 | f"discriminator_cheetah_{skill}_{n_samples}_rlsp.pkl", 21 | f"discriminator_cheetah_{skill}_{n_samples}_average_features.pkl", 22 | f"discriminator_cheetah_{skill}_{n_samples}_waypoints.pkl", 23 | ] 24 | labels = ["GAIL", "Deep RLSP", "AverageFeatures", "Waypoints"] 25 | outfile = f"{skill}_{n_samples}.pdf" 26 | moving_average_n = 10 27 | 28 | sns.set_context("paper", font_scale=3.6, rc={"lines.linewidth": 3}) 29 | sns.set_style("white") 30 | matplotlib.rc( 31 | "font", 32 | **{ 33 | "family": "serif", 34 | "serif": ["Computer Modern"], 35 | "sans-serif": ["Latin Modern"], 36 | }, 37 | ) 38 | matplotlib.rc("text", usetex=True) 39 | 40 | markers_every = 100 41 | markersize = 10 42 | 43 | colors = [ 44 | "#377eb8", 45 | "#ff7f00", 46 | "#4daf4a", 47 | "#e41a1c", 48 | 
"#dede00", 49 | "#999999", 50 | "#f781bf", 51 | "#a65628", 52 | "#984ea3", 53 | ] 54 | 55 | markers = ["o", "^", "s", "v", "d", "+", "x", "."] 56 | 57 | L = len(loss_files) 58 | colors = colors[:L] 59 | markers = markers[:L] 60 | 61 | # plt.figure(figsize=(20, 6)) 62 | for loss_file, label, color, marker in zip(loss_files, labels, colors, markers): 63 | print(loss_file) 64 | with open(loss_file, "rb") as f: 65 | losses = pickle.load(f) 66 | print(losses.shape) 67 | 68 | if moving_average_n > 1: 69 | losses = moving_average(losses, moving_average_n) 70 | 71 | plt.plot( 72 | range(len(losses)), 73 | losses, 74 | label=label, 75 | color=color, 76 | marker=marker, 77 | markevery=markers_every, 78 | markersize=markersize, 79 | ) 80 | plt.xlabel("iterations") 81 | plt.ylabel("cross-entropy loss") 82 | 83 | # plt.legend() 84 | plt.tight_layout() 85 | plt.savefig(outfile) 86 | # plt.show() 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | -------------------------------------------------------------------------------- /scripts/run_gail.py: -------------------------------------------------------------------------------- 1 | import deep_rlsp 2 | from imitation.scripts.train_adversarial import train_ex 3 | from sacred.observers import FileStorageObserver 4 | import os.path as osp 5 | 6 | 7 | def main_console(): 8 | observer = FileStorageObserver(osp.join("output", "sacred", "train")) 9 | train_ex.observers.append(observer) 10 | train_ex.run_commandline() 11 | 12 | 13 | if __name__ == "__main__": # pragma: no cover 14 | main_console() 15 | -------------------------------------------------------------------------------- /scripts/train_inverse_dynamics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from deep_rlsp.model import ExperienceReplay, InverseDynamicsMLP, InverseDynamicsMDN 8 | from deep_rlsp.util.video import render_mujoco_from_obs, save_video 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("env_id", type=str) 14 | parser.add_argument("--n_rollouts", default=1000, type=int) 15 | parser.add_argument("--learning_rate", default=1e-5, type=float) 16 | parser.add_argument("--n_epochs", default=100, type=int) 17 | parser.add_argument("--batch_size", default=500, type=int) 18 | parser.add_argument("--n_layers", default=5, type=int) 19 | parser.add_argument("--layer_size", default=1024, type=int) 20 | parser.add_argument("--gridworlds", action="store_true") 21 | return parser.parse_args() 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | 27 | env = gym.make(args.env_id) 28 | experience_replay = ExperienceReplay(None) 29 | experience_replay.add_random_rollouts( 30 | env, env.spec.max_episode_steps, args.n_rollouts 31 | ) 32 | 33 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 34 | label = "mlp_{}_{}".format(args.env_id, timestamp) 35 | tensorboard_log = "tf_logs/tf_logs_" + label 36 | checkpoint_folder = "tf_ckpt/tf_ckpt_" + label 37 | 38 | if args.gridworlds: 39 | inverse_dynamics = InverseDynamicsMDN( 40 | env, 41 | experience_replay, 42 | hidden_layer_size=args.layer_size, 43 | n_hidden_layers=args.n_layers, 44 | learning_rate=args.learning_rate, 45 | tensorboard_log=tensorboard_log, 46 | checkpoint_folder=checkpoint_folder, 47 | gauss_stdev=0.1, 48 | n_out_states=3, 49 | ) 50 | else: 51 | inverse_dynamics = InverseDynamicsMLP( 52 | env, 53 | experience_replay, 54 | 
hidden_layer_size=args.layer_size, 55 | n_hidden_layers=args.n_layers, 56 | learning_rate=args.learning_rate, 57 | tensorboard_log=tensorboard_log, 58 | checkpoint_folder=checkpoint_folder, 59 | ) 60 | 61 | inverse_dynamics.learn( 62 | n_epochs=args.n_epochs, 63 | batch_size=args.batch_size, 64 | print_evaluation=False, 65 | verbose=True, 66 | ) 67 | 68 | # Evaluation 69 | 70 | if args.gridworlds: 71 | obs = env.reset() 72 | for _ in range(env.time_horizon): 73 | s = env.obs_to_s(obs) 74 | 75 | print("agent_pos", s.agent_pos) 76 | print("vase_states", s.vase_states) 77 | print(np.reshape(obs, env.obs_shape)[:, :, 0]) 78 | 79 | action = env.action_space.sample() 80 | obs, reward, done, info = env.step(action) 81 | obs_bwd = inverse_dynamics.step(obs, action) 82 | 83 | s = env.obs_to_s(obs_bwd) 84 | print("bwd: agent_pos", s.agent_pos) 85 | print("bwd: vase_states", s.vase_states) 86 | print(np.reshape(obs_bwd, env.obs_shape)[:, :, 0]) 87 | 88 | print() 89 | print("action", action) 90 | print() 91 | else: 92 | obs = env.reset() 93 | for _ in range(env.spec.max_episode_steps): 94 | print("obs", obs) 95 | action = env.action_space.sample() 96 | obs, reward, done, info = env.step(action) 97 | obs_bwd = inverse_dynamics.step(obs, action) 98 | with np.printoptions(suppress=True): 99 | print("action", action) 100 | print("\tbwd", obs_bwd) 101 | 102 | obs = env.reset() 103 | rgbs = [] 104 | for _ in range(env.spec.max_episode_steps): 105 | obs = np.clip(obs, -30, 30) 106 | rgb = render_mujoco_from_obs(env, obs) 107 | rgbs.append(rgb) 108 | action = env.action_space.sample() 109 | obs = inverse_dynamics.step(obs, action, sample=True) 110 | save_video(rgbs, f"train_{args.env_id}.avi", 20) 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /scripts/train_sac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | from deep_rlsp.solvers import get_sac 5 | 6 | # for envs 7 | import deep_rlsp 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("env_id", type=str) 13 | parser.add_argument("policy_out", type=str) 14 | parser.add_argument("--timesteps", default=int(2e6), type=int) 15 | return parser.parse_args() 16 | 17 | 18 | def main(): 19 | args = parse_args() 20 | env = gym.make(args.env_id) 21 | 22 | solver = get_sac(env, learning_starts=1000) 23 | solver.learn(total_timesteps=args.timesteps) 24 | solver.save(args.policy_out) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /scripts/train_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | 4 | import gym 5 | import numpy as np 6 | 7 | from deep_rlsp.model import StateVAE, ExperienceReplay 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("env_id", type=str) 13 | parser.add_argument("--n_rollouts", default=100, type=int) 14 | parser.add_argument("--state_size", default=30, type=int) 15 | parser.add_argument("--learning_rate", default=1e-5, type=float) 16 | parser.add_argument("--n_epochs", default=100, type=int) 17 | parser.add_argument("--batch_size", default=500, type=int) 18 | parser.add_argument("--n_layers", default=3, type=int) 19 | parser.add_argument("--layer_size", default=512, type=int) 20 | 
parser.add_argument("--prior_stdev", default=1, type=float) 21 | parser.add_argument("--divergence_factor", default=0.001, type=float) 22 | return parser.parse_args() 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | 28 | env = gym.make(args.env_id) 29 | experience_replay = ExperienceReplay(None) 30 | experience_replay.add_random_rollouts( 31 | env, env.spec.max_episode_steps, args.n_rollouts 32 | ) 33 | 34 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 35 | label = "vae_{}_{}".format(args.env_id, timestamp) 36 | tensorboard_log = "tf_logs/tf_logs_" + label 37 | checkpoint_folder = "tf_ckpt/tf_ckpt_" + label 38 | 39 | vae = StateVAE( 40 | env.observation_space.shape[0], 41 | args.state_size, 42 | n_layers=args.n_layers, 43 | layer_size=args.layer_size, 44 | learning_rate=args.learning_rate, 45 | prior_stdev=args.prior_stdev, 46 | divergence_factor=args.divergence_factor, 47 | tensorboard_log=tensorboard_log, 48 | checkpoint_folder=checkpoint_folder, 49 | ) 50 | # vae.checkpoint_folder = None 51 | 52 | vae.learn(experience_replay, args.n_epochs, args.batch_size, verbose=True) 53 | 54 | with np.printoptions(suppress=True): 55 | for _ in range(20): 56 | x = experience_replay.sample(1)[0][0] 57 | print("x", x) 58 | z = vae.encoder(x) 59 | print("z", z) 60 | x = vae.decoder(z) 61 | print("x2", x) 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="deep-rlsp", 5 | version=0.1, 6 | description="Learning What To Do by Simulating the Past", 7 | author="David Lindner, Rohin Shah, et al", 8 | python_requires=">=3.7.0", 9 | url="https://github.com/HumanCompatibleAI/deep-rlsp", 10 | packages=find_packages("src"), 11 | package_dir={"": "src"}, 12 | install_requires=[ 13 | "numpy>=1.13", 14 | "scipy>=0.19", 15 | "sacred==0.8.2", 16 | "ray", 17 | "stable_baselines", 18 | "tensorflow==1.13.2", 19 | "tensorflow-probability==0.6.0", 20 | "seaborn", 21 | "gym", 22 | ], 23 | test_suite="nose.collector", 24 | tests_require=["nose", "nose-cover3"], 25 | include_package_data=True, 26 | license="MIT", 27 | classifiers=[ 28 | # Trove classifiers 29 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 30 | "License :: OSI Approved :: MIT License", 31 | "Programming Language :: Python", 32 | "Programming Language :: Python :: 3", 33 | "Programming Language :: Python :: 3.7", 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /skills/balancing.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/balancing.mp4 -------------------------------------------------------------------------------- /skills/balancing_rollouts.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/balancing_rollouts.pkl -------------------------------------------------------------------------------- /skills/jumping.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/jumping.mp4 
-------------------------------------------------------------------------------- /skills/jumping_rollouts.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/jumping_rollouts.pkl -------------------------------------------------------------------------------- /src/deep_rlsp/ablation_AverageFeatures.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from sacred import Experiment 7 | from sacred.observers import FileStorageObserver, RunObserver 8 | 9 | from deep_rlsp.envs.reward_wrapper import LatentSpaceRewardWrapper 10 | from deep_rlsp.model import StateVAE 11 | from deep_rlsp.util.results import Artifact, FileExperimentResults 12 | from deep_rlsp.solvers import get_sac 13 | from deep_rlsp.util.helper import evaluate_policy 14 | 15 | 16 | # changes the run _id and thereby the path that the FileStorageObserver 17 | # writes the results 18 | # cf. https://github.com/IDSIA/sacred/issues/174 19 | class SetID(RunObserver): 20 | priority = 50 # very high priority to set id 21 | 22 | def started_event( 23 | self, ex_info, command, host_info, start_time, config, meta_info, _id 24 | ): 25 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 26 | if config["result_folder"] is not None: 27 | result_folder = config["result_folder"].strip("/").split("/")[-1] 28 | custom_id = f"{timestamp}_ablation_average_features_{result_folder}" 29 | else: 30 | custom_id = f"{timestamp}_ablation_average_features" 31 | return custom_id # started_event returns the _run._id 32 | 33 | 34 | ex = Experiment("mujoco-ablation-average-features") 35 | ex.observers = [ 36 | SetID(), 37 | FileStorageObserver.create("results/mujoco/ablation_average_features"), 38 | ] 39 | 40 | 41 | @ex.config 42 | def config(): 43 | result_folder = None # noqa:F841 44 | 45 | 46 | @ex.automain 47 | def main(_run, result_folder, seed): 48 | ex = FileExperimentResults(result_folder) 49 | env_id = ex.config["env_id"] 50 | latent_model_checkpoint = ex.info["latent_model_checkpoint"] 51 | current_states = ex.info["current_states"] 52 | 53 | if env_id == "InvertedPendulum-v2": 54 | iterations = int(6e4) 55 | else: 56 | iterations = int(2e6) 57 | 58 | env = gym.make(env_id) 59 | 60 | latent_space = StateVAE.restore(latent_model_checkpoint) 61 | 62 | r_vec = sum([latent_space.encoder(obs) for obs in current_states]) 63 | r_vec /= np.linalg.norm(r_vec) 64 | env_inferred = LatentSpaceRewardWrapper(env, latent_space, r_vec) 65 | 66 | solver = get_sac(env_inferred) 67 | solver.learn(iterations) 68 | 69 | with Artifact(f"policy.zip", None, _run) as f: 70 | solver.save(f) 71 | 72 | N = 10 73 | true_reward_obtained = evaluate_policy(env, solver, N) 74 | inferred_reward_obtained = evaluate_policy(env_inferred, solver, N) 75 | print("Policy: true return", true_reward_obtained) 76 | print("Policy: inferred return", inferred_reward_obtained) 77 | -------------------------------------------------------------------------------- /src/deep_rlsp/ablation_Waypoints.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import gym 4 | import numpy as np 5 | 6 | from sacred import Experiment 7 | from sacred.observers import FileStorageObserver, RunObserver 8 | 9 | from deep_rlsp.model import StateVAE 10 | from deep_rlsp.util.results import Artifact, 
FileExperimentResults 11 | from deep_rlsp.solvers import get_sac 12 | from deep_rlsp.util.helper import evaluate_policy 13 | 14 | 15 | # changes the run _id and thereby the path that the FileStorageObserver 16 | # writes the results 17 | # cf. https://github.com/IDSIA/sacred/issues/174 18 | class SetID(RunObserver): 19 | priority = 50 # very high priority to set id 20 | 21 | def started_event( 22 | self, ex_info, command, host_info, start_time, config, meta_info, _id 23 | ): 24 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") 25 | if config["result_folder"] is not None: 26 | result_folder = config["result_folder"].strip("/").split("/")[-1] 27 | custom_id = f"{timestamp}_ablation_waypoints_{result_folder}" 28 | else: 29 | custom_id = f"{timestamp}_ablation_waypoints" 30 | return custom_id # started_event returns the _run._id 31 | 32 | 33 | ex = Experiment("mujoco-ablation-waypoints") 34 | ex.observers = [ 35 | SetID(), 36 | FileStorageObserver.create("results/mujoco/ablation_waypoints"), 37 | ] 38 | 39 | 40 | class LatentSpaceTargetStateRewardWrapper(gym.Wrapper): 41 | def __init__(self, env, latent_space, target_states): 42 | self.env = env 43 | self.latent_space = latent_space 44 | self.target_states = [ts / np.linalg.norm(ts) for ts in target_states] 45 | self.state = None 46 | self.timestep = 0 47 | super().__init__(env) 48 | 49 | def reset(self): 50 | obs = super().reset() 51 | self.state = self.latent_space.encoder(obs) 52 | self.timestep = 0 53 | return obs 54 | 55 | def step(self, action: int): 56 | action = np.clip(action, self.action_space.low, self.action_space.high) 57 | obs, true_reward, done, info = self.env.step(action) 58 | self.state = self.latent_space.encoder(obs) 59 | 60 | waypoints = [np.dot(ts, self.state) for ts in self.target_states] 61 | reward = max(waypoints) 62 | 63 | # remove termination criteria from mujoco environments 64 | self.timestep += 1 65 | done = self.timestep > self.env.spec.max_episode_steps 66 | 67 | return obs, reward, done, info 68 | 69 | 70 | @ex.config 71 | def config(): 72 | result_folder = None # noqa:F841 73 | 74 | 75 | @ex.automain 76 | def main(_run, result_folder, seed): 77 | # result_folder = "results/mujoco/20200706_153755_HalfCheetah-FW-v2_optimal_50" 78 | ex = FileExperimentResults(result_folder) 79 | env_id = ex.config["env_id"] 80 | latent_model_checkpoint = ex.info["latent_model_checkpoint"] 81 | current_states = ex.info["current_states"] 82 | 83 | if env_id == "InvertedPendulum-v2": 84 | iterations = int(6e4) 85 | else: 86 | iterations = int(2e6) 87 | 88 | env = gym.make(env_id) 89 | 90 | latent_space = StateVAE.restore(latent_model_checkpoint) 91 | 92 | target_states = [latent_space.encoder(obs) for obs in current_states] 93 | env_inferred = LatentSpaceTargetStateRewardWrapper(env, latent_space, target_states) 94 | 95 | solver = get_sac(env_inferred) 96 | solver.learn(iterations) 97 | 98 | with Artifact(f"policy.zip", None, _run) as f: 99 | solver.save(f) 100 | 101 | N = 10 102 | true_reward_obtained = evaluate_policy(env, solver, N) 103 | inferred_reward_obtained = evaluate_policy(env_inferred, solver, N) 104 | print("Policy: true return", true_reward_obtained) 105 | print("Policy: inferred return", inferred_reward_obtained) 106 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/__init__.py -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/__init__.py: -------------------------------------------------------------------------------- 1 | from .room import RoomEnv 2 | from .apples import ApplesEnv 3 | from .train import TrainEnv 4 | from .batteries import BatteriesEnv 5 | 6 | from .room_spec import ROOM_PROBLEMS 7 | from .apples_spec import APPLES_PROBLEMS 8 | from .train_spec import TRAIN_PROBLEMS 9 | from .batteries_spec import BATTERIES_PROBLEMS 10 | 11 | TOY_PROBLEMS = { 12 | "room": ROOM_PROBLEMS, 13 | "apples": APPLES_PROBLEMS, 14 | "train": TRAIN_PROBLEMS, 15 | "batteries": BATTERIES_PROBLEMS, 16 | } 17 | 18 | TOY_ENV_CLASSES = { 19 | "room": RoomEnv, 20 | "apples": ApplesEnv, 21 | "train": TrainEnv, 22 | "batteries": BatteriesEnv, 23 | } 24 | 25 | __all__ = [ 26 | "RoomEnv", 27 | "ApplesEnv", 28 | "TrainEnv", 29 | "BatteriesEnv", 30 | ] 31 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/apples_spec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from deep_rlsp.envs.gridworlds.apples import ApplesState 3 | from deep_rlsp.envs.gridworlds.env import Direction 4 | 5 | 6 | class ApplesSpec(object): 7 | def __init__( 8 | self, 9 | height, 10 | width, 11 | init_state, 12 | apple_regen_probability, 13 | bucket_capacity, 14 | include_location_features, 15 | ): 16 | """See ApplesEnv.__init__ in apples.py for details.""" 17 | self.height = height 18 | self.width = width 19 | self.init_state = init_state 20 | self.apple_regen_probability = apple_regen_probability 21 | self.bucket_capacity = bucket_capacity 22 | self.include_location_features = include_location_features 23 | 24 | 25 | # In the diagrams below, T is a tree, B is a bucket, C is a carpet, A is the 26 | # agent. Each tuple is of the form (spec, current state, task R, true R). 27 | 28 | APPLES_PROBLEMS = { 29 | # ----- 30 | # |T T| 31 | # | | 32 | # | B | 33 | # | | 34 | # |A T| 35 | # ----- 36 | # After 11 actions (riuiruuildi), it looks like this: 37 | # ----- 38 | # |T T| 39 | # | A | 40 | # | B | 41 | # | | 42 | # | T| 43 | # ----- 44 | # Where the agent has picked the right trees once and put the fruit in the 45 | # basket. 
46 | "default": ( 47 | ApplesSpec( 48 | 5, 49 | 3, 50 | ApplesState( 51 | agent_pos=(0, 0, 2), 52 | tree_states={(0, 0): True, (2, 0): True, (2, 4): True}, 53 | bucket_states={(1, 2): 0}, 54 | carrying_apple=False, 55 | ), 56 | apple_regen_probability=0.1, 57 | bucket_capacity=10, 58 | include_location_features=True, 59 | ), 60 | ApplesState( 61 | agent_pos=(Direction.get_number_from_direction(Direction.SOUTH), 1, 1), 62 | tree_states={(0, 0): True, (2, 0): False, (2, 4): True}, 63 | bucket_states={(1, 2): 2}, 64 | carrying_apple=False, 65 | ), 66 | np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 67 | np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 68 | ) 69 | } 70 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/basic_room.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import product 3 | 4 | from deep_rlsp.envs.gridworlds.env import Env, Direction, get_grid_representation 5 | 6 | 7 | class BasicRoomEnv(Env): 8 | """ 9 | Basic empty room with stochastic transitions. Used for debugging. 10 | """ 11 | 12 | def __init__(self, prob, use_pixels_as_observations=True): 13 | self.height = 3 14 | self.width = 3 15 | self.init_state = (1, 1) 16 | self.prob = prob 17 | self.nS = self.height * self.width 18 | self.nA = 5 19 | 20 | super().__init__(1, use_pixels_as_observations=use_pixels_as_observations) 21 | 22 | self.num_features = 2 23 | self.default_action = Direction.get_number_from_direction(Direction.STAY) 24 | self.num_features = len(self.s_to_f(self.init_state)) 25 | 26 | self.reset() 27 | 28 | states = self.enumerate_states() 29 | self.make_transition_matrices(states, range(self.nA), self.nS, self.nA) 30 | self.make_f_matrix(self.nS, self.num_features) 31 | 32 | def enumerate_states(self): 33 | return product(range(self.width), range(self.height)) 34 | 35 | def get_num_from_state(self, state): 36 | return np.ravel_multi_index(state, (self.width, self.height)) 37 | 38 | def get_state_from_num(self, num): 39 | return np.unravel_index(num, (self.width, self.height)) 40 | 41 | def s_to_f(self, s): 42 | return s 43 | 44 | def _obs_to_f(self, obs): 45 | return np.unravel_index(obs[0].argmax(), obs[0].shape) 46 | 47 | def _s_to_obs(self, s): 48 | layers = [[s]] 49 | obs = get_grid_representation(self.width, self.height, layers) 50 | return np.array(obs, dtype=np.float32) 51 | 52 | # render_width = 64 53 | # render_height = 64 54 | # x, y = s 55 | # obs = np.zeros((3, render_height, render_width), dtype=np.float32) 56 | # obs[ 57 | # :, 58 | # y * render_height : (y + 1) * render_height, 59 | # x * render_width : (x + 1) * render_width, 60 | # ] = 1 61 | # return obs 62 | 63 | def get_next_states(self, state, action): 64 | # next_states = [] 65 | # for a in range(self.nA): 66 | # next_s = self.get_next_state(state, a) 67 | # p = 1 - self.prob if a == action else self.prob / (self.nA - 1) 68 | # next_states.append((p, next_s, 0)) 69 | 70 | next_s = self.get_next_state(state, action) 71 | next_states = [(self.prob, next_s, 0), (1 - self.prob, state, 0)] 72 | return next_states 73 | 74 | def get_next_state(self, state, action): 75 | """Returns the next state given a state and an action.""" 76 | action = int(action) 77 | 78 | if action == Direction.get_number_from_direction(Direction.STAY): 79 | pass 80 | elif action < len(Direction.ALL_DIRECTIONS): 81 | move_x, move_y = Direction.move_in_direction_number(state, action) 82 | # New position is legal 
83 | if 0 <= move_x < self.width and 0 <= move_y < self.height: 84 | state = move_x, move_y 85 | else: 86 | # Move only changes orientation, which we already handled 87 | pass 88 | else: 89 | raise ValueError("Invalid action {}".format(action)) 90 | 91 | return state 92 | 93 | def s_to_ansi(self, state): 94 | return str(self.s_to_obs(state)) 95 | 96 | 97 | if __name__ == "__main__": 98 | from gym.utils.play import play 99 | 100 | env = BasicRoomEnv(1) 101 | play(env, fps=5) 102 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/batteries_spec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from deep_rlsp.envs.gridworlds.batteries import BatteriesState 3 | 4 | 5 | class BatteriesSpec(object): 6 | def __init__(self, height, width, init_state, feature_locations, train_transition): 7 | """See BatteriesEnv.__init__ in batteries.py for details.""" 8 | self.height = height 9 | self.width = width 10 | self.init_state = init_state 11 | self.feature_locations = feature_locations 12 | self.train_transition = train_transition 13 | 14 | 15 | def get_problem(version): 16 | # In the diagram below, G is a goal location, B is a battery, A is the 17 | # agent, and T is the train. 18 | # Each tuple is of the form (spec, current state, task R, true R). 19 | # ------- 20 | # |B G | 21 | # | TT | 22 | # | TTG| 23 | # | | 24 | # |A B| 25 | # ------- 26 | spec = BatteriesSpec( 27 | 5, 28 | 5, 29 | BatteriesState( 30 | (0, 4), # agent pos 31 | (2, 1), # train pos 32 | 8, # train battery (/life) 33 | {(0, 0): True, (4, 4): True}, # batteries present in environment 34 | False, # agent carrying battery 35 | ), 36 | [(2, 0), (4, 2)], # goals 37 | # train transitions 38 | {(2, 1): (3, 1), (3, 1): (3, 2), (3, 2): (2, 2), (2, 2): (2, 1)}, 39 | ) 40 | final_state = BatteriesState( 41 | (2, 0), # agent pos 42 | (3, 2), # train pos 43 | 8, # train battery (/ life) 44 | {(0, 0): False, (4, 4): True}, # batteries present in environment 45 | False, # agent carrying battery 46 | ) 47 | train_weight = -1 if version == "easy" else 0 48 | task_reward = np.array([0, train_weight, 0, 0, 0, 0, 0, 1]) 49 | true_reward = np.array([0, -1, 0, 0, 0, 0, 0, 1]) 50 | return (spec, final_state, task_reward, true_reward) 51 | 52 | 53 | BATTERIES_PROBLEMS = {"default": get_problem("default"), "easy": get_problem("easy")} 54 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/gym_envs.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from deep_rlsp.envs.gridworlds.one_hot_action_space_wrapper import ( 4 | OneHotActionSpaceWrapper, 5 | ) 6 | from deep_rlsp.envs.gridworlds import TOY_PROBLEMS, TOY_ENV_CLASSES 7 | 8 | 9 | def get_gym_gridworld_id_time_limit(env_name, env_spec): 10 | id = env_name + "_" + env_spec 11 | id = "".join([s.capitalize() for s in id.split("_")]) 12 | id += "-v0" 13 | spec, _, _, _ = TOY_PROBLEMS[env_name][env_spec] 14 | env = TOY_ENV_CLASSES[env_name](spec) 15 | return id, env.time_horizon 16 | 17 | 18 | def make_gym_gridworld(env_name, env_spec): 19 | spec, _, _, _ = TOY_PROBLEMS[env_name][env_spec] 20 | env = TOY_ENV_CLASSES[env_name](spec) 21 | env = OneHotActionSpaceWrapper(env) 22 | return env 23 | 24 | 25 | def get_gym_gridworld(env_name, env_spec): 26 | id, time_horizon = get_gym_gridworld_id_time_limit(env_name, env_spec) 27 | env = gym.make(id) 28 | return env 29 | 
-------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/one_hot_action_space_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | 5 | class OneHotActionSpaceWrapper(gym.ActionWrapper): 6 | def __init__(self, env): 7 | assert isinstance(env.action_space, gym.spaces.Discrete) 8 | super().__init__(env) 9 | self.n_actions = env.action_space.n 10 | self.action_space = gym.spaces.Box(-np.inf, np.inf, (self.n_actions,)) 11 | 12 | def step(self, action, **kwargs): 13 | return self.env.step(self.action(action), **kwargs) 14 | 15 | def action(self, one_hot_action): 16 | assert one_hot_action.shape == (self.n_actions,) 17 | action = np.argmax(one_hot_action) 18 | return action 19 | 20 | def reverse_action(self, action): 21 | return np.arange(self.n_actions) == action 22 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/room_spec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from deep_rlsp.envs.gridworlds.room import RoomState 4 | 5 | 6 | class RoomSpec(object): 7 | def __init__( 8 | self, 9 | height, 10 | width, 11 | init_state, 12 | carpet_locations, 13 | feature_locations, 14 | use_pixels_as_observations, 15 | ): 16 | """See RoomEnv.__init__ in room.py for details.""" 17 | self.height = height 18 | self.width = width 19 | self.init_state = init_state 20 | self.carpet_locations = carpet_locations 21 | self.feature_locations = feature_locations 22 | self.use_pixels_as_observations = use_pixels_as_observations 23 | 24 | 25 | # In the diagrams below, G is a goal location, V is a vase, C is a carpet, A is 26 | # the agent. Each tuple is of the form (spec, current state, task R, true R). 
27 | 28 | ROOM_PROBLEMS = { 29 | # ------- 30 | # | G | 31 | # |GCVC | 32 | # | A | 33 | # ------- 34 | "default": ( 35 | RoomSpec( 36 | 3, 37 | 5, 38 | RoomState((2, 2), {(2, 1): True}), 39 | [(1, 1), (3, 1)], 40 | [(0, 1), (2, 0)], 41 | False, 42 | ), 43 | RoomState((2, 0), {(2, 1): True}), 44 | np.array([0, 0, 1, 0]), 45 | np.array([-1, 0, 1, 0]), 46 | ), 47 | "default_pixel": ( 48 | RoomSpec( 49 | 3, 50 | 5, 51 | RoomState((2, 2), {(2, 1): True}), 52 | [(1, 1), (3, 1)], 53 | [(0, 1), (2, 0)], 54 | True, 55 | ), 56 | RoomState((2, 0), {(2, 1): True}), 57 | np.array([0, 0, 1, 0]), 58 | np.array([-1, 0, 1, 0]), 59 | ), 60 | # ------- 61 | # | G | 62 | # |GCVCA| 63 | # | | 64 | # ------- 65 | "alt": ( 66 | RoomSpec( 67 | 3, 68 | 5, 69 | RoomState((4, 1), {(2, 1): True}), 70 | [(1, 1), (3, 1)], 71 | [(0, 1), (2, 0)], 72 | False, 73 | ), 74 | RoomState((2, 0), {(2, 1): True}), 75 | np.array([0, 0, 1, 0]), 76 | np.array([-1, 0, 1, 0]), 77 | ), 78 | # ------- 79 | # |G VG| 80 | # | | 81 | # |A C | 82 | # ------- 83 | "bad": ( 84 | RoomSpec( 85 | 3, 5, RoomState((0, 2), {(3, 0): True}), [(3, 2)], [(0, 0), (4, 0)], False 86 | ), 87 | RoomState((0, 0), {(3, 0): True}), 88 | np.array([0, 0, 0, 1]), 89 | np.array([-1, 0, 0, 1]), 90 | ), 91 | # ------- 92 | "big": ( 93 | RoomSpec( 94 | 10, 95 | 10, 96 | RoomState( 97 | (0, 2), 98 | { 99 | (0, 5): True, 100 | # (0, 9): True, 101 | # (1, 2): True, 102 | # (2, 4): True, 103 | # (2, 5): True, 104 | # (2, 6): True, 105 | # (3, 1): True, 106 | # (3, 3): True, 107 | # (3, 8): True, 108 | # (4, 2): True, 109 | # (4, 4): True, 110 | # (4, 5): True, 111 | # (4, 6): True, 112 | # (4, 9): True, 113 | # (5, 3): True, 114 | # (5, 5): True, 115 | # (5, 7): True, 116 | # (5, 8): True, 117 | # (6, 1): True, 118 | # (6, 2): True, 119 | # (6, 4): True, 120 | # (6, 7): True, 121 | # (6, 9): True, 122 | # (7, 2): True, 123 | # (7, 5): True, 124 | # (7, 8): True, 125 | # (8, 6): True, 126 | # (9, 0): True, 127 | # (9, 2): True, 128 | # (9, 6): True, 129 | }, 130 | ), 131 | [(1, 1), (1, 3), (2, 0), (6, 6), (7, 3)], 132 | [(0, 4), (0, 7), (5, 1), (8, 7), (9, 3)], 133 | False, 134 | ), 135 | RoomState( 136 | (0, 7), 137 | { 138 | (0, 5): True, 139 | # (0, 9): True, 140 | # (1, 2): True, 141 | # (2, 4): True, 142 | # (2, 5): True, 143 | # (2, 6): True, 144 | # (3, 1): True, 145 | # (3, 3): True, 146 | # (3, 8): True, 147 | # (4, 2): True, 148 | # (4, 4): True, 149 | # (4, 5): True, 150 | # (4, 6): True, 151 | # (4, 9): True, 152 | # (5, 3): True, 153 | # (5, 5): True, 154 | # (5, 7): True, 155 | # (5, 8): True, 156 | # (6, 1): True, 157 | # (6, 2): True, 158 | # (6, 4): True, 159 | # (6, 7): True, 160 | # (6, 9): True, 161 | # (7, 2): True, 162 | # (7, 5): True, 163 | # (7, 8): True, 164 | # (8, 6): True, 165 | # (9, 0): True, 166 | # (9, 2): True, 167 | # (9, 6): True, 168 | }, 169 | ), 170 | np.array([0, 0, 0, 0, 0, 0, 1]), 171 | np.array([-1, 0, 0, 0, 0, 0, 1]), 172 | ), 173 | } 174 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/tests/apples_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from deep_rlsp.envs.gridworlds.apples import ApplesState, ApplesEnv 6 | from deep_rlsp.envs.gridworlds.tests.common import BaseTests, get_directions 7 | 8 | 9 | class TestApplesSpec(object): 10 | def __init__(self): 11 | """Test spec for the Apples environment. 
12 | 13 | T is a tree, B is a bucket, C is a carpet, A is the agent. 14 | ----- 15 | |T T| 16 | | | 17 | |AB | 18 | ----- 19 | """ 20 | self.height = 3 21 | self.width = 5 22 | self.init_state = ApplesState( 23 | agent_pos=(0, 0, 2), 24 | tree_states={(0, 0): True, (2, 0): True}, 25 | bucket_states={(1, 2): 0}, 26 | carrying_apple=False, 27 | ) 28 | # Use a power of 2, to avoid rounding issues 29 | self.apple_regen_probability = 1.0 / 4 30 | self.bucket_capacity = 10 31 | self.include_location_features = True 32 | 33 | 34 | class TestApplesEnv(BaseTests.TestEnv): 35 | def setUp(self): 36 | self.env = ApplesEnv(TestApplesSpec()) 37 | 38 | u, d, l, r, s = get_directions() 39 | i = 5 # interact action 40 | 41 | def make_state(agent_pos, tree1, tree2, bucket, carrying_apple): 42 | tree_states = {(0, 0): tree1, (2, 0): tree2} 43 | bucket_state = {(1, 2): bucket} 44 | return ApplesState(agent_pos, tree_states, bucket_state, carrying_apple) 45 | 46 | self.trajectories = [ 47 | [ 48 | (u, (make_state((u, 0, 1), True, True, 0, False), 1.0)), 49 | (i, (make_state((u, 0, 1), False, True, 0, True), 1.0)), 50 | (r, (make_state((r, 1, 1), False, True, 0, True), 3.0 / 4)), 51 | (d, (make_state((d, 1, 1), False, True, 0, True), 3.0 / 4)), 52 | (i, (make_state((d, 1, 1), False, True, 1, False), 3.0 / 4)), 53 | (u, (make_state((u, 1, 0), False, True, 1, False), 3.0 / 4)), 54 | (r, (make_state((r, 1, 0), False, True, 1, False), 3.0 / 4)), 55 | (i, (make_state((r, 1, 0), False, False, 1, True), 3.0 / 4)), 56 | (d, (make_state((d, 1, 1), False, False, 1, True), 9.0 / 16)), 57 | (i, (make_state((d, 1, 1), True, False, 2, False), 3.0 / 16)), 58 | (s, (make_state((d, 1, 1), True, True, 2, False), 1.0 / 4)), 59 | ] 60 | ] 61 | 62 | 63 | class TestApplesModel(BaseTests.TestTabularTransitionModel): 64 | def setUp(self): 65 | self.env = ApplesEnv(TestApplesSpec()) 66 | self.model_tests = [] 67 | 68 | _, _, _, _, stay = get_directions() 69 | policy_stay = np.zeros((self.env.nS, self.env.nA)) 70 | policy_stay[:, stay] = 1 71 | 72 | def make_state(apple1_present, apple2_present): 73 | return ApplesState( 74 | agent_pos=(0, 1, 1), 75 | tree_states={(0, 0): apple1_present, (2, 0): apple2_present}, 76 | bucket_states={(1, 2): 0}, 77 | carrying_apple=False, 78 | ) 79 | 80 | state_0_0 = make_state(False, False) 81 | state_0_1 = make_state(False, True) 82 | state_1_0 = make_state(True, False) 83 | state_1_1 = make_state(True, True) 84 | 85 | forward_probs = np.zeros(self.env.nS) 86 | forward_probs[self.env.get_num_from_state(state_1_1)] = 1 87 | backward_probs = np.zeros(self.env.nS) 88 | backward_probs[self.env.get_num_from_state(state_0_0)] = 0.04 89 | backward_probs[self.env.get_num_from_state(state_0_1)] = 0.16 90 | backward_probs[self.env.get_num_from_state(state_1_0)] = 0.16 91 | backward_probs[self.env.get_num_from_state(state_1_1)] = 0.64 92 | transitions = [(state_1_1, 1, forward_probs, backward_probs)] 93 | 94 | unif = np.ones(self.env.nS) / self.env.nS 95 | self.model_tests.append( 96 | { 97 | "policy": policy_stay, 98 | "transitions": transitions, 99 | "initial_state_distribution": unif, 100 | } 101 | ) 102 | 103 | 104 | if __name__ == "__main__": 105 | unittest.main() 106 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/tests/batteries_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from deep_rlsp.envs.gridworlds.batteries import BatteriesState, BatteriesEnv 4 | from 
deep_rlsp.envs.gridworlds.tests.common import BaseTests, get_directions 5 | 6 | 7 | class TestBatteriesSpec(object): 8 | def __init__(self): 9 | """Test spec for the Batteries environment. 10 | 11 | G is a goal location, B is a battery, A is the agent, and T is the train. 12 | ------- 13 | |B G | 14 | | TT | 15 | | TTG| 16 | | | 17 | |A B| 18 | ------- 19 | """ 20 | self.height = 5 21 | self.width = 5 22 | self.init_state = BatteriesState( 23 | (0, 4), (2, 1), 8, {(0, 0): True, (4, 4): True}, False 24 | ) 25 | self.feature_locations = [(2, 0), (4, 2)] 26 | self.train_transition = { 27 | (2, 1): (3, 1), 28 | (3, 1): (3, 2), 29 | (3, 2): (2, 2), 30 | (2, 2): (2, 1), 31 | } 32 | 33 | 34 | class TestBatteriesEnv(BaseTests.TestEnv): 35 | def setUp(self): 36 | self.env = BatteriesEnv(TestBatteriesSpec()) 37 | u, d, l, r, s = get_directions() 38 | 39 | def make_state(agent, train, life, battery_vals, carrying_battery): 40 | battery_present = dict(zip([(0, 0), (4, 4)], battery_vals)) 41 | return BatteriesState(agent, train, life, battery_present, carrying_battery) 42 | 43 | self.trajectories = [ 44 | [ 45 | (u, (make_state((0, 3), (3, 1), 7, [True, True], False), 1.0)), 46 | (u, (make_state((0, 2), (3, 2), 6, [True, True], False), 1.0)), 47 | (u, (make_state((0, 1), (2, 2), 5, [True, True], False), 1.0)), 48 | (u, (make_state((0, 0), (2, 1), 4, [False, True], True), 1.0)), 49 | (r, (make_state((1, 0), (3, 1), 3, [False, True], True), 1.0)), 50 | (r, (make_state((2, 0), (3, 2), 2, [False, True], True), 1.0)), 51 | (s, (make_state((2, 0), (2, 2), 1, [False, True], True), 1.0)), 52 | (s, (make_state((2, 0), (2, 1), 0, [False, True], True), 1.0)), 53 | (d, (make_state((2, 1), (3, 1), 9, [False, True], False), 1.0)), 54 | (u, (make_state((2, 0), (3, 2), 8, [False, True], False), 1.0)), 55 | ] 56 | ] 57 | 58 | 59 | @unittest.skip("runs very long") 60 | class TestBatteriesModel(BaseTests.TestTabularTransitionModel): 61 | def setUp(self): 62 | self.env = BatteriesEnv(TestBatteriesSpec()) 63 | self.model_tests = [] 64 | self.setUpDeterministic() 65 | 66 | 67 | if __name__ == "__main__": 68 | unittest.main() 69 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/tests/env_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from deep_rlsp.envs.gridworlds.env import Direction 4 | 5 | 6 | class TestDirection(unittest.TestCase): 7 | def test_direction_number_conversion(self): 8 | all_directions = Direction.ALL_DIRECTIONS 9 | all_numbers = [] 10 | 11 | for direction in Direction.ALL_DIRECTIONS: 12 | number = Direction.get_number_from_direction(direction) 13 | direction_again = Direction.get_direction_from_number(number) 14 | self.assertEqual(direction, direction_again) 15 | all_numbers.append(number) 16 | 17 | # Check that all directions are distinct 18 | num_directions = len(all_directions) 19 | self.assertEqual(len(set(all_directions)), num_directions) 20 | # Check that the numbers are 0, 1, ... 
num_directions - 1 21 | self.assertEqual(set(all_numbers), set(range(num_directions))) 22 | 23 | 24 | if __name__ == "__main__": 25 | unittest.main() 26 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/tests/room_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from deep_rlsp.envs.gridworlds.room import RoomState, RoomEnv 6 | from deep_rlsp.envs.gridworlds.tests.common import ( 7 | BaseTests, 8 | get_directions, 9 | get_two_action_uniform_policy, 10 | ) 11 | 12 | 13 | class TestRoomSpec(object): 14 | def __init__(self): 15 | """Test spec for the Room environment. 16 | 17 | G is a goal location, V is a vase, C is a carpet, A is the agent. 18 | ------- 19 | |G G G| 20 | | CVC | 21 | | A | 22 | ------- 23 | """ 24 | self.height = 3 25 | self.width = 5 26 | self.init_state = RoomState((2, 2), {(2, 1): True}) 27 | self.carpet_locations = [(1, 1), (3, 1)] 28 | self.feature_locations = [(0, 0), (2, 0), (4, 0)] 29 | self.use_pixels_as_observations = False 30 | 31 | 32 | class TestRoomEnv(BaseTests.TestEnv): 33 | def setUp(self): 34 | self.env = RoomEnv(TestRoomSpec()) 35 | u, d, l, r, s = get_directions() 36 | 37 | self.trajectories = [ 38 | [ 39 | (l, (RoomState((1, 2), {(2, 1): True}), 1.0)), 40 | (u, (RoomState((1, 1), {(2, 1): True}), 1.0)), 41 | (u, (RoomState((1, 0), {(2, 1): True}), 1.0)), 42 | (r, (RoomState((2, 0), {(2, 1): True}), 1.0)), 43 | ], 44 | [ 45 | (u, (RoomState((2, 1), {(2, 1): False}), 1.0)), 46 | (u, (RoomState((2, 0), {(2, 1): False}), 1.0)), 47 | ], 48 | [ 49 | (r, (RoomState((3, 2), {(2, 1): True}), 1.0)), 50 | (u, (RoomState((3, 1), {(2, 1): True}), 1.0)), 51 | (l, (RoomState((2, 1), {(2, 1): False}), 1.0)), 52 | (d, (RoomState((2, 2), {(2, 1): False}), 1.0)), 53 | ], 54 | ] 55 | 56 | 57 | class TestRoomModel(BaseTests.TestTabularTransitionModel): 58 | def setUp(self): 59 | self.env = RoomEnv(TestRoomSpec()) 60 | self.model_tests = [] 61 | 62 | u, d, l, r, s = get_directions() 63 | policy_left_right = get_two_action_uniform_policy( 64 | self.env.nS, self.env.nA, l, r 65 | ) 66 | state_middle = RoomState((2, 1), {(2, 1): False}) 67 | state_left_vase = RoomState((1, 1), {(2, 1): True}) 68 | state_right_vase = RoomState((3, 1), {(2, 1): True}) 69 | state_left_novase = RoomState((1, 1), {(2, 1): False}) 70 | state_right_novase = RoomState((3, 1), {(2, 1): False}) 71 | forward_probs = np.zeros(self.env.nS) 72 | forward_probs[self.env.get_num_from_state(state_left_novase)] = 0.5 73 | forward_probs[self.env.get_num_from_state(state_right_novase)] = 0.5 74 | backward_probs = np.zeros(self.env.nS) 75 | backward_probs[self.env.get_num_from_state(state_left_vase)] = 0.25 76 | backward_probs[self.env.get_num_from_state(state_right_vase)] = 0.25 77 | backward_probs[self.env.get_num_from_state(state_left_novase)] = 0.25 78 | backward_probs[self.env.get_num_from_state(state_right_novase)] = 0.25 79 | transitions = [(state_middle, 1, forward_probs, backward_probs)] 80 | unif = np.ones(self.env.nS) / self.env.nS 81 | self.model_tests.append( 82 | { 83 | "policy": policy_left_right, 84 | "transitions": transitions, 85 | "initial_state_distribution": unif, 86 | } 87 | ) 88 | 89 | self.setUpDeterministic() 90 | 91 | 92 | if __name__ == "__main__": 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/tests/test_observation_spaces.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from deep_rlsp.envs.gridworlds import TOY_PROBLEMS 6 | from deep_rlsp.envs.gridworlds.gym_envs import get_gym_gridworld 7 | 8 | 9 | class TestObservationSpaces(unittest.TestCase): 10 | # def test_observation_state_consistency(self): 11 | # for env_name, problems in TOY_PROBLEMS.items(): 12 | # for env_spec in problems.keys(): 13 | # _, cur_state, _, _ = TOY_PROBLEMS[env_name][env_spec] 14 | # env = get_gym_gridworld(env_name, env_spec) 15 | # for s in [env.init_state, cur_state]: 16 | # obs = env.s_to_obs(s) 17 | # s2 = env.obs_to_s(obs) 18 | # # if s != s2: 19 | # # import pdb 20 | # # 21 | # # pdb.set_trace() 22 | # self.assertEqual(s, s2) 23 | # 24 | # def test_no_error_getting_state_from_random_obs(self): 25 | # np.random.seed(29) 26 | # for env_name, problems in TOY_PROBLEMS.items(): 27 | # for env_spec in problems.keys(): 28 | # env = get_gym_gridworld(env_name, env_spec) 29 | # for _ in range(100): 30 | # obs = np.random.random(env.obs_shape) * 4 - 2 31 | # state = env.obs_to_s(obs) 32 | # # if state not in env.state_num: 33 | # # import pdb 34 | # # 35 | # # pdb.set_trace() 36 | # self.assertIn(state, env.state_num) 37 | 38 | def test_obs_state_consistent_in_rollout(self): 39 | np.random.seed(29) 40 | for env_name, problems in TOY_PROBLEMS.items(): 41 | for env_spec in problems.keys(): 42 | env = get_gym_gridworld(env_name, env_spec) 43 | if not env.use_pixels_as_observations: 44 | obs = env.reset() 45 | done = False 46 | while not done: 47 | action = env.action_space.sample() 48 | obs, reward, done, info = env.step(action) 49 | s = env.obs_to_s(obs) 50 | if s != env.s: 51 | import pdb 52 | 53 | pdb.set_trace() 54 | self.assertEqual(s, env.s) 55 | 56 | 57 | if __name__ == "__main__": 58 | unittest.main() 59 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/tests/train_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from deep_rlsp.envs.gridworlds.train import TrainState, TrainEnv 4 | from deep_rlsp.envs.gridworlds.tests.common import BaseTests, get_directions 5 | 6 | 7 | class TestTrainSpec(object): 8 | def __init__(self): 9 | """Test spec for the Train environment. 10 | 11 | G is a goal location, V is a vase, C is a carpet, A is the agent. 
12 | ------- 13 | | G C| 14 | | TT | 15 | | VTTG| 16 | | | 17 | |A | 18 | ------- 19 | """ 20 | self.height = 5 21 | self.width = 5 22 | self.init_state = TrainState((0, 4), {(1, 2): True}, (2, 1), True) 23 | self.carpet_locations = [(4, 0)] 24 | self.feature_locations = ([(2, 0), (4, 2)],) 25 | self.train_transition = { 26 | (2, 1): (3, 1), 27 | (3, 1): (3, 2), 28 | (3, 2): (2, 2), 29 | (2, 2): (2, 1), 30 | } 31 | 32 | 33 | class TestTrainEnv(BaseTests.TestEnv): 34 | def setUp(self): 35 | self.env = TrainEnv(TestTrainSpec()) 36 | u, d, l, r, s = get_directions() 37 | 38 | self.trajectories = [ 39 | [ 40 | (u, (TrainState((0, 3), {(1, 2): True}, (3, 1), True), 1.0)), 41 | (u, (TrainState((0, 2), {(1, 2): True}, (3, 2), True), 1.0)), 42 | (u, (TrainState((0, 1), {(1, 2): True}, (2, 2), True), 1.0)), 43 | (r, (TrainState((1, 1), {(1, 2): True}, (2, 1), True), 1.0)), 44 | (u, (TrainState((1, 0), {(1, 2): True}, (3, 1), True), 1.0)), 45 | (r, (TrainState((2, 0), {(1, 2): True}, (3, 2), True), 1.0)), 46 | (s, (TrainState((2, 0), {(1, 2): True}, (2, 2), True), 1.0)), 47 | (s, (TrainState((2, 0), {(1, 2): True}, (2, 1), True), 1.0)), 48 | ], 49 | [ 50 | (u, (TrainState((0, 3), {(1, 2): True}, (3, 1), True), 1.0)), 51 | (r, (TrainState((1, 3), {(1, 2): True}, (3, 2), True), 1.0)), 52 | (r, (TrainState((2, 3), {(1, 2): True}, (2, 2), True), 1.0)), 53 | ], 54 | [ 55 | (r, (TrainState((1, 4), {(1, 2): True}, (3, 1), True), 1.0)), 56 | (r, (TrainState((2, 4), {(1, 2): True}, (3, 2), True), 1.0)), 57 | (r, (TrainState((3, 4), {(1, 2): True}, (2, 2), True), 1.0)), 58 | (u, (TrainState((3, 3), {(1, 2): True}, (2, 1), True), 1.0)), 59 | (u, (TrainState((3, 2), {(1, 2): True}, (3, 1), True), 1.0)), 60 | (s, (TrainState((3, 2), {(1, 2): True}, (3, 2), False), 1.0)), 61 | (s, (TrainState((3, 2), {(1, 2): True}, (3, 2), False), 1.0)), 62 | (u, (TrainState((3, 1), {(1, 2): True}, (3, 2), False), 1.0)), 63 | (l, (TrainState((2, 1), {(1, 2): True}, (3, 2), False), 1.0)), 64 | ], 65 | ] 66 | 67 | 68 | class TestTrainModel(BaseTests.TestTabularTransitionModel): 69 | def setUp(self): 70 | self.env = TrainEnv(TestTrainSpec()) 71 | self.model_tests = [] 72 | self.setUpDeterministic() 73 | 74 | 75 | if __name__ == "__main__": 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/gridworlds/train_spec.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from deep_rlsp.envs.gridworlds.train import TrainState 4 | 5 | 6 | class TrainSpec(object): 7 | def __init__( 8 | self, 9 | height, 10 | width, 11 | init_state, 12 | carpet_locations, 13 | feature_locations, 14 | train_transition, 15 | ): 16 | """See TrainEnv.__init__ in train.py for details.""" 17 | self.height = height 18 | self.width = width 19 | self.init_state = init_state 20 | self.carpet_locations = carpet_locations 21 | self.feature_locations = feature_locations 22 | self.train_transition = train_transition 23 | 24 | 25 | # In the diagrams below, G is a goal location, V is a vase, C is a carpet, A is 26 | # the agent, and T is the train. 27 | # Each tuple is of the form (spec, current state, task R, true R). 
28 | 29 | TRAIN_PROBLEMS = { 30 | # ------- 31 | # | G C| 32 | # | TT | 33 | # | VTTG| 34 | # | | 35 | # |A | 36 | # ------- 37 | "default": ( 38 | TrainSpec( 39 | 5, 40 | 5, 41 | TrainState((0, 4), {(1, 2): True}, (2, 1), True), # init state 42 | [(4, 0)], # carpet 43 | [(2, 0), (4, 2)], # goals 44 | # train transitions 45 | {(2, 1): (3, 1), (3, 1): (3, 2), (3, 2): (2, 2), (2, 2): (2, 1)}, 46 | ), 47 | TrainState((2, 0), {(1, 2): True}, (2, 2), True), # current state for inference 48 | np.array([0, 0, 0, 0, 0, 0, 0, 0, 1]), # specified reward 49 | np.array([-1, 0, -1, 0, 0, 0, 0, 0, 1]), # true reward 50 | ) 51 | } 52 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/mujoco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/mujoco/__init__.py -------------------------------------------------------------------------------- /src/deep_rlsp/envs/mujoco/assets/ant.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | 19 | 98 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/mujoco/assets/ant_footsensor.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | 19 | 109 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/mujoco/assets/ant_plot.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 | 19 | 98 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/mujoco/assets/half_cheetah.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 114 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/mujoco/assets/half_cheetah_plot.xml: -------------------------------------------------------------------------------- 1 | 16 | 17 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 114 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/mujoco/half_cheetah.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os 20 | 21 | from gym import utils 22 | import numpy as np 23 | from gym.envs.mujoco import mujoco_env 24 | 25 | 26 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): 27 | def __init__( 28 | self, 29 | expose_all_qpos=False, 30 | task="default", 31 | target_velocity=None, 32 | model_path="half_cheetah.xml", 33 | plot=False, 34 | ): 35 | # Settings from 36 | # https://github.com/openai/gym/blob/master/gym/envs/__init__.py 37 | self._expose_all_qpos = expose_all_qpos 38 | self._task = task 39 | self._target_velocity = target_velocity 40 | 41 | xml_path = os.path.join(os.path.dirname(__file__), "assets") 42 | self.model_path = os.path.abspath(os.path.join(xml_path, model_path)) 43 | 44 | mujoco_env.MujocoEnv.__init__(self, self.model_path, 5) 45 | utils.EzPickle.__init__(self) 46 | 47 | def step(self, action): 48 | xposbefore = self.sim.data.qpos[0] 49 | self.do_simulation(action, self.frame_skip) 50 | xposafter = self.sim.data.qpos[0] 51 | xvelafter = self.sim.data.qvel[0] 52 | ob = self._get_obs() 53 | reward_ctrl = -0.1 * np.square(action).sum() 54 | 55 | if self._task == "default": 56 | reward_vel = 0.0 57 | reward_run = (xposafter - xposbefore) / self.dt 58 | reward = reward_ctrl + reward_run 59 | elif self._task == "target_velocity": 60 | reward_vel = -((self._target_velocity - xvelafter) ** 2) 61 | reward = reward_ctrl + reward_vel 62 | elif self._task == "run_back": 63 | reward_vel = 0.0 64 | reward_run = (xposbefore - xposafter) / self.dt 65 | reward = reward_ctrl + reward_run 66 | 67 | done = False 68 | return ( 69 | ob, 70 | reward, 71 | done, 72 | dict(reward_run=reward_run, reward_ctrl=reward_ctrl, reward_vel=reward_vel), 73 | ) 74 | 75 | def _get_obs(self): 76 | if self._expose_all_qpos: 77 | return np.concatenate([self.sim.data.qpos.flat, self.sim.data.qvel.flat]) 78 | return np.concatenate([self.sim.data.qpos.flat[1:], self.sim.data.qvel.flat]) 79 | 80 | def reset_model(self): 81 | qpos = self.init_qpos + self.np_random.uniform( 82 | low=-0.1, high=0.1, size=self.sim.model.nq 83 | ) 84 | qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * 0.1 85 | self.set_state(qpos, qvel) 86 | return self._get_obs() 87 | 88 | def viewer_setup(self): 89 | # camera_id = self.model.camera_name2id("track") 90 | # self.viewer.cam.type = 2 91 | # self.viewer.cam.fixedcamid = camera_id 92 | # self.viewer.cam.distance = self.model.stat.extent * 0.5 93 | # camera_id = self.model.camera_name2id("fixed") 94 | # self.viewer.cam.type = 2 95 | # self.viewer.cam.fixedcamid = camera_id 96 | 97 | self.viewer.cam.fixedcamid = -1 98 | self.viewer.cam.trackbodyid = -1 99 | # how much you "zoom in", model.stat.extent is the max limits of the arena 100 | self.viewer.cam.distance = self.model.stat.extent * 0.4 101 | # self.viewer.cam.lookat[0] -= 4 102 | self.viewer.cam.lookat[1] += 1.1 103 | self.viewer.cam.lookat[2] += 0 104 | # camera rotation around the axis in the plane going through the frame origin 105 | # (if 0 you just see a line) 106 | self.viewer.cam.elevation = -20 107 | # camera rotation around the camera's vertical axis 108 | self.viewer.cam.azimuth = 90 109 | -------------------------------------------------------------------------------- /src/deep_rlsp/envs/reward_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union, Optional 2 | 3 | import numpy as 
np 4 | import gym 5 | 6 | from deep_rlsp.envs.gridworlds.env import Env 7 | from deep_rlsp.model import LatentSpaceModel 8 | from deep_rlsp.util.parameter_checks import check_between 9 | from deep_rlsp.util.helper import init_env_from_obs 10 | 11 | 12 | class RewardWeightWrapper(gym.Wrapper): 13 | def __init__(self, gridworld_env: Env, reward_weights: np.ndarray): 14 | self.gridworld_env = gridworld_env 15 | self.reward_weights = reward_weights 16 | super().__init__(gridworld_env) 17 | 18 | def update_reward_weights(self, reward_weights): 19 | self.reward_weights = reward_weights 20 | 21 | def step(self, action: int) -> Tuple[Union[int, np.ndarray], float, bool, Dict]: 22 | return self.gridworld_env.step(action, r_vec=self.reward_weights) 23 | 24 | 25 | class LatentSpaceRewardWrapper(gym.Wrapper): 26 | def __init__( 27 | self, 28 | env: Env, 29 | latent_space: LatentSpaceModel, 30 | r_inferred: Optional[np.ndarray], 31 | inferred_weight: Optional[float] = None, 32 | init_observations: Optional[list] = None, 33 | time_horizon: Optional[int] = None, 34 | init_prob: float = 0.2, 35 | use_task_reward: bool = False, 36 | reward_action_norm_factor: float = 0, 37 | ): 38 | if inferred_weight is not None: 39 | check_between("inferred_weight", inferred_weight, 0, 1) 40 | self.env = env 41 | self.latent_space = latent_space 42 | self.r_inferred = r_inferred 43 | if inferred_weight is None: 44 | inferred_weight = 1 45 | self.inferred_weight = inferred_weight 46 | self.state = None 47 | self.timestep = 0 48 | self.use_task_reward = use_task_reward 49 | self.reward_action_norm_factor = reward_action_norm_factor 50 | 51 | self.init_prob = init_prob 52 | self.init_observations = init_observations 53 | 54 | if time_horizon is None: 55 | self.time_horizon = self.env.spec.max_episode_steps 56 | else: 57 | self.time_horizon = time_horizon 58 | super().__init__(env) 59 | 60 | def reset(self) -> Tuple[np.ndarray, float, bool, Dict]: 61 | obs = super().reset() 62 | 63 | if self.init_observations is not None: 64 | if np.random.random() < self.init_prob: 65 | idx = np.random.randint(0, len(self.init_observations)) 66 | obs = self.init_observations[idx] 67 | self.env = init_env_from_obs(self.env, obs) 68 | 69 | self.state = self.latent_space.encoder(obs) 70 | self.timestep = 0 71 | return obs 72 | 73 | def _get_reward(self, action, info): 74 | if self.r_inferred is not None: 75 | inferred = np.dot(self.r_inferred, self.state) 76 | info["inferred"] = inferred 77 | else: 78 | inferred = 0 79 | 80 | if self.use_task_reward and "task_reward" in info: 81 | task = info["task_reward"] 82 | else: 83 | task = 0 84 | 85 | action_norm = np.square(action).sum() 86 | 87 | reward = self.inferred_weight * inferred 88 | reward += task 89 | reward += self.reward_action_norm_factor * action_norm 90 | return reward 91 | 92 | def step( 93 | self, action, return_true_reward: bool = False, 94 | ) -> Tuple[np.ndarray, float, bool, Dict]: 95 | action = np.clip(action, self.action_space.low, self.action_space.high) 96 | obs, true_reward, done, info = self.env.step(action) 97 | self.state = self.latent_space.encoder(obs) 98 | 99 | if return_true_reward: 100 | assert "true_reward" in info 101 | reward = info["true_reward"] 102 | else: 103 | reward = self._get_reward(action, info) 104 | 105 | # ignore termination criterion from mujoco environments 106 | # to avoid reward information leak 107 | self.timestep += 1 108 | done = self.timestep > self.time_horizon 109 | 110 | return obs, reward, done, info 111 | 
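LatentSpaceRewardWrapper replaces the environment's reward with the dot product between the inferred reward vector and the VAE encoding of the current observation (optionally mixed with the task reward and an action-norm term), and overrides the done flag with a fixed time horizon. A minimal sketch of how it is wired up, mirroring the ablation scripts above: the checkpoint path, the reward vector, and the output policy filename are illustrative placeholders, the latent size must match the trained VAE's state_size (30 is the default in scripts/train_vae.py), and "HalfCheetah-FW-v2" is one of the repository's custom environment ids, assumed to be registered on import deep_rlsp as in scripts/train_sac.py:

import gym
import numpy as np

import deep_rlsp  # for envs
from deep_rlsp.model import StateVAE
from deep_rlsp.envs.reward_wrapper import LatentSpaceRewardWrapper
from deep_rlsp.solvers import get_sac

env = gym.make("HalfCheetah-FW-v2")

# Placeholder VAE checkpoint; in practice this comes from scripts/train_vae.py.
latent_space = StateVAE.restore("tf_ckpt/tf_ckpt_vae_example")

# Placeholder unit-norm reward vector in latent space; deep RLSP infers this
# from observed states rather than fixing it by hand.
r_inferred = np.ones(30) / np.sqrt(30)

env_inferred = LatentSpaceRewardWrapper(env, latent_space, r_inferred)
solver = get_sac(env_inferred)
solver.learn(total_timesteps=10000)
solver.save("policy_inferred.zip")  # placeholder output path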
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/__init__.py
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/fetch/reach.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model definition; markup not preserved in this text export]
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/fetch/shared.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model definition; markup not preserved in this text export]
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/.get:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/.get
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/base_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/base_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/bellows_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/bellows_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/elbow_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/elbow_flex_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/estop_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/estop_link.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/forearm_roll_link_collision.stl:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/forearm_roll_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/gripper_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/gripper_link.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/head_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/head_pan_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/head_tilt_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/head_tilt_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/l_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/l_wheel_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/laser_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/laser_link.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/r_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/r_wheel_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_lift_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_pan_link_collision.stl 
-------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_fixed_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_fixed_link.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_lift_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/upperarm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/upperarm_roll_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_flex_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_roll_link_collision.stl -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/textures/block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/textures/block.png -------------------------------------------------------------------------------- /src/deep_rlsp/envs/robotics/assets/textures/block_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/textures/block_hidden.png -------------------------------------------------------------------------------- /src/deep_rlsp/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .tabular import TabularTransitionModel 2 | from .latent_space import LatentSpaceModel 3 | from .dynamics_mdn import InverseDynamicsMDN 4 | from .dynamics_mlp import InverseDynamicsMLP 5 | from .inverse_policy_mdn import InversePolicyMDN 6 | from .state_vae import StateVAE 7 | from .experience_replay import ExperienceReplay 8 | 9 | __all__ = [ 10 | "TabularTransitionModel", 11 | "LatentSpaceModel", 12 | "InverseDynamicsMDN", 13 | "InverseDynamicsMLP", 14 | "InversePolicyMDN", 15 | "StateVAE", 16 | 
"ExperienceReplay", 17 | ] 18 | -------------------------------------------------------------------------------- /src/deep_rlsp/model/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class TransitionModel(abc.ABC): 5 | def __init__(self, env): 6 | pass 7 | 8 | @abc.abstractmethod 9 | def models_observations(self): 10 | pass 11 | 12 | @abc.abstractmethod 13 | def forward_sample(self, state): 14 | pass 15 | 16 | @abc.abstractmethod 17 | def backward_sample(self, state, t): 18 | pass 19 | 20 | 21 | class InversePolicy(abc.ABC): 22 | def __init__(self, env, policy): 23 | pass 24 | 25 | @abc.abstractmethod 26 | def step(self, next_state): 27 | pass 28 | 29 | @abc.abstractmethod 30 | def sample(self, next_state): 31 | pass 32 | -------------------------------------------------------------------------------- /src/deep_rlsp/model/exact_dynamics_mujoco.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import math 3 | import numpy as np 4 | 5 | from deep_rlsp.util.helper import init_env_from_obs 6 | 7 | GOLDEN_RATIO = (math.sqrt(5) + 1) / 2 8 | 9 | 10 | def gss(f, a, b, tol=1e-12): 11 | """Golden section search to find the minimum of f on [a,b]. 12 | 13 | This implementation does not reuse function evaluations and assumes the minimum is c 14 | or d (not on the edges at a or b). 15 | Source: https://en.wikipedia.org/wiki/Golden_section_search 16 | 17 | f: a strictly unimodal function on [a,b] 18 | 19 | Example: 20 | >>> f = lambda x: (x-2)**2 21 | >>> x = gss(f, 1, 5) 22 | >>> print("%.15f" % x) 23 | 2.000009644875678 24 | """ 25 | c = b - (b - a) / GOLDEN_RATIO 26 | d = a + (b - a) / GOLDEN_RATIO 27 | while abs(c - d) > tol: 28 | if f(c) < f(d): 29 | b = d 30 | else: 31 | a = c 32 | # We recompute both c and d here to avoid loss of precision 33 | # which may lead to incorrect results or infinite loop 34 | c = b - (b - a) / GOLDEN_RATIO 35 | d = a + (b - a) / GOLDEN_RATIO 36 | return (b + a) / 2 37 | 38 | 39 | def loss(y1, y2): 40 | return np.square(y1 - y2).mean() 41 | 42 | 43 | def invert(f, y, tolerance=1e-12, max_iters=1000): 44 | """ 45 | f: function to invert (for expensive f, make sure to memoize) 46 | y: output to invert 47 | tolerance: Acceptible numerical error 48 | max_iters: Maximum iterations to try 49 | Returns: x' such that f(x') = y up to tolerance or up to amount achieved 50 | in max_iters time. 
51 | """ 52 | x = y 53 | for i in range(max_iters): 54 | dx = np.random.normal(size=x.shape) 55 | 56 | def line_fn(fac): 57 | return loss(f(x + fac * dx), y) 58 | 59 | factor = gss(line_fn, -10.0, 10.0) 60 | x = x + factor * dx 61 | if loss(f(x), y) < tolerance: 62 | # print(f"Took {i} iterations") 63 | return x 64 | print("Max it reached, loss value is {}".format(loss(f(x), y))) 65 | return x 66 | 67 | 68 | class ExactDynamicsMujoco: 69 | def __init__(self, env_id, tolerance=1e-12, max_iters=1000): 70 | self.env = gym.make(env_id) 71 | self.env.reset() 72 | self.tolerance = tolerance 73 | self.max_iters = max_iters 74 | 75 | # @memoize 76 | def dynamics(s, a): 77 | self.env = init_env_from_obs(self.env, s) 78 | s2, _, _, _ = self.env.step(a) 79 | return s2 80 | 81 | self.dynamics = dynamics 82 | 83 | def inverse_dynamics(self, s2, a): 84 | def fn(s): 85 | return self.dynamics(s, a) 86 | 87 | return invert(fn, s2, tolerance=self.tolerance, max_iters=self.max_iters) 88 | 89 | 90 | def main(): 91 | dynamics = ExactDynamicsMujoco("InvertedPendulum-v2") 92 | s = dynamics.env.reset() 93 | for _ in range(5): 94 | a = dynamics.env.action_space.sample() 95 | s2 = dynamics.dynamics(s, a) 96 | inverted_s = dynamics.inverse_dynamics(s2, a) 97 | print(f"s2: {s2}\na: {a}\ns: {s}") 98 | print(f"Inverted: {inverted_s}") 99 | print("RMSE: {}\n".format(np.sqrt(loss(s, inverted_s)))) 100 | s = s2 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /src/deep_rlsp/model/experience_replay.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | 5 | from deep_rlsp.model.exact_dynamics_mujoco import ExactDynamicsMujoco 6 | 7 | 8 | class Normalizer: 9 | def __init__(self, num_inputs): 10 | self.n = np.zeros(num_inputs) 11 | self.mean = np.zeros(num_inputs) 12 | self.mean_diff = np.zeros(num_inputs) 13 | self.std = np.ones(num_inputs) 14 | 15 | def add(self, x): 16 | self.n += 1.0 17 | last_mean = self.mean.copy() 18 | self.mean += (x - self.mean) / self.n 19 | self.mean_diff += (x - last_mean) * (x - self.mean) 20 | var = self.mean_diff / self.n 21 | var[var < 1e-8] = 1 22 | self.std = np.sqrt(var) 23 | 24 | def normalize(self, inputs): 25 | return (inputs - self.mean) / self.std 26 | 27 | def unnormalize(self, inputs): 28 | return inputs * self.std + self.mean 29 | 30 | 31 | class ExperienceReplay: 32 | def __init__(self, capacity): 33 | self.buffer = collections.deque(maxlen=capacity) 34 | self.obs_normalizer = None 35 | self.act_normalizer = None 36 | self.delta_normalizer = None 37 | self.max_delta = None 38 | self.max_delta = None 39 | 40 | def __len__(self): 41 | return len(self.buffer) 42 | 43 | def append(self, obs, action, next_obs): 44 | delta = obs - next_obs 45 | if self.obs_normalizer is None: 46 | self.obs_normalizer = Normalizer(len(obs)) 47 | if self.delta_normalizer is None: 48 | self.delta_normalizer = Normalizer(len(obs)) 49 | self.max_delta = delta 50 | self.min_delta = delta 51 | if self.act_normalizer is None: 52 | self.act_normalizer = Normalizer(len(action)) 53 | self.obs_normalizer.add(obs) 54 | self.obs_normalizer.add(next_obs) 55 | self.act_normalizer.add(action) 56 | self.delta_normalizer.add(delta) 57 | self.buffer.append((obs, action, next_obs)) 58 | self.max_delta = np.maximum(self.max_delta, delta) 59 | self.min_delta = np.minimum(self.min_delta, delta) 60 | 61 | def sample(self, batch_size, normalize=False): 
62 | indices = np.random.choice(len(self.buffer), batch_size, replace=False) 63 | obs, act, next_obs = zip(*[self.buffer[idx] for idx in indices]) 64 | obs, act, next_obs = np.array(obs), np.array(act), np.array(next_obs) 65 | if normalize: 66 | obs = self.normalize_obs(obs) 67 | act = self.normalize_act(act) 68 | next_obs = self.normalize_obs(next_obs) 69 | return obs, act, next_obs 70 | 71 | def clip_delta(self, delta): 72 | return np.clip(delta, self.min_delta, self.max_delta) 73 | 74 | def normalize_obs(self, obs): 75 | return self.obs_normalizer.normalize(obs) 76 | 77 | def unnormalize_obs(self, obs): 78 | return self.obs_normalizer.unnormalize(obs) 79 | 80 | def normalize_delta(self, delta): 81 | return self.delta_normalizer.normalize(delta) 82 | 83 | def unnormalize_delta(self, delta): 84 | return self.delta_normalizer.unnormalize(delta) 85 | 86 | def normalize_act(self, act): 87 | return self.act_normalizer.normalize(act) 88 | 89 | def unnormalize_act(self, act): 90 | return self.act_normalizer.unnormalize(act) 91 | 92 | def add_random_rollouts(self, env, timesteps, n_rollouts): 93 | for _ in range(n_rollouts): 94 | obs = env.reset() 95 | for t in range(timesteps): 96 | action = env.action_space.sample() 97 | next_obs, _, _, _ = env.step(action) 98 | self.append(obs, action, next_obs) 99 | obs = next_obs 100 | 101 | def add_play_data(self, env, play_data): 102 | dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id) 103 | observations, actions = play_data["observations"], play_data["actions"] 104 | n_traj = len(observations) 105 | assert len(actions) == n_traj 106 | for i in range(n_traj): 107 | l_traj = len(observations[i]) 108 | for t in range(l_traj): 109 | obs = observations[i][t] 110 | action = actions[i][t] 111 | next_obs = dynamics.dynamics(obs, action) 112 | self.append(obs, action, next_obs) 113 | 114 | def add_policy_rollouts(self, env, policy, n_rollouts, horizon, eps_greedy=0): 115 | for i in range(n_rollouts): 116 | obs = env.reset() 117 | for t in range(horizon): 118 | if eps_greedy > 0 and np.random.random() < eps_greedy: 119 | action = env.action_space.sample() 120 | else: 121 | action = policy.predict(obs)[0] 122 | next_obs, reward, done, info = env.step(action) 123 | self.append(obs, action, next_obs) 124 | obs = next_obs 125 | -------------------------------------------------------------------------------- /src/deep_rlsp/model/gridworlds_feature_space.py: -------------------------------------------------------------------------------- 1 | class GridworldsFeatureSpace: 2 | def __init__(self, env): 3 | self.env = env 4 | s = self.env.init_state 5 | f = self.env.s_to_f(s) 6 | assert len(f.shape) == 1 7 | self.state_size = f.shape[0] 8 | 9 | def encoder(self, obs): 10 | s = self.env.obs_to_s(obs) 11 | f = self.env.s_to_f(s) 12 | return f 13 | 14 | def decoder(self, state): 15 | raise NotImplementedError() 16 | -------------------------------------------------------------------------------- /src/deep_rlsp/model/inverse_model_env_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict, Union 2 | 3 | import numpy as np 4 | import gym 5 | 6 | from deep_rlsp.model import LatentSpaceModel, InverseDynamicsMDN 7 | from deep_rlsp.util.mujoco import compute_reward_done_from_obs 8 | 9 | MIN_FLOAT64 = np.finfo(np.float64).min 10 | MAX_FLOAT64 = np.finfo(np.float64).max 11 | 12 | 13 | class InverseModelGymEnv(gym.Env): 14 | """ 15 | Allows to treat an inverse dynamics model as an MDP. 
16 | 17 | Used for model evaluation and debugging. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | latent_space: LatentSpaceModel, 23 | inverse_model: InverseDynamicsMDN, 24 | initial_obs: np.ndarray, 25 | time_horizon: int, 26 | ): 27 | self.latent_space = latent_space 28 | self.inverse_model = inverse_model 29 | 30 | self.time_horizon = time_horizon 31 | self.update_initial_state(obs=initial_obs) 32 | self.action_space = self.latent_space.env.action_space 33 | self.observation_space = gym.spaces.Box( 34 | MIN_FLOAT64, MAX_FLOAT64, shape=(self.latent_space.state_size,) 35 | ) 36 | if hasattr(self.latent_space.env, "nA"): 37 | self.nA = self.latent_space.env.nA 38 | self.reset() 39 | 40 | def update_initial_state(self, state=None, obs=None): 41 | if (state is not None and obs is not None) or (state is None and obs is None): 42 | raise ValueError("Exactly one of state and obs should be None.") 43 | if state is None: 44 | obs = np.expand_dims(obs, 0) 45 | state = self.latent_space.encoder(obs)[0] 46 | self.initial_state = state 47 | 48 | def reset(self) -> np.ndarray: 49 | self.timestep = 0 50 | self.state = self.initial_state 51 | return self.state 52 | 53 | def step( 54 | self, action: Union[int, np.ndarray] 55 | ) -> Tuple[np.ndarray, float, bool, Dict]: 56 | last_state = self.state 57 | last_obs = self.latent_space.decoder(last_state) 58 | self.state = self.inverse_model.step(self.state, action) 59 | obs = self.latent_space.decoder(self.state) 60 | # invert transition to get reward 61 | reward, done = compute_reward_done_from_obs( 62 | self.latent_space.env, obs, action, last_obs 63 | ) 64 | return self.state, reward, done, dict() 65 | -------------------------------------------------------------------------------- /src/deep_rlsp/model/mujoco_debug_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Exact dynamics models and handcoded features used for debugging in the pendulum env. 
3 | """ 4 | 5 | import numpy as np 6 | 7 | from deep_rlsp.model.exact_dynamics_mujoco import ExactDynamicsMujoco 8 | 9 | 10 | class IdentityFeatures: 11 | def __init__(self, env): 12 | self.encoder = lambda x: x 13 | self.decoder = lambda x: x 14 | 15 | 16 | class MujocoDebugFeatures: 17 | def __init__(self, env): 18 | self.env_id = env.unwrapped.spec.id 19 | assert env.unwrapped.spec.id in ("InvertedPendulum-v2", "FetchReachStack-v1") 20 | self.env = env 21 | if self.env_id == "InvertedPendulum-v2": 22 | self.state_size = 5 23 | elif self.env_id == "FetchReachStack-v1": 24 | self.state_size = env.observation_space.shape[0] + 1 25 | 26 | def encoder(self, obs): 27 | if self.env_id == "InvertedPendulum-v2": 28 | # feature = np.isfinite(obs).all() and (np.abs(obs[1]) <= 0.2) 29 | feature = np.abs(obs[1]) 30 | elif self.env_id == "FetchReachStack-v1": 31 | feature = int(obs[-4] < 0.5) 32 | obs = np.concatenate([obs, [feature]]) 33 | return obs 34 | 35 | def decoder(self, state): 36 | return state[:-1] 37 | 38 | 39 | class PendulumDynamics: 40 | def __init__(self, latent_space, backward=False): 41 | assert latent_space.env.unwrapped.spec.id == "InvertedPendulum-v2" 42 | self.latent_space = latent_space 43 | self.backward = backward 44 | self.dynamics = ExactDynamicsMujoco( 45 | self.latent_space.env.unwrapped.spec.id, tolerance=1e-2, max_iters=100 46 | ) 47 | self.low = np.array([-1, -np.pi, -10, -10]) 48 | self.high = np.array([1, np.pi, 10, 10]) 49 | 50 | def step(self, state, action, sample=True): 51 | obs = state # self.latent_space.decoder(state) 52 | obs = np.clip(obs, self.low, self.high) 53 | if self.backward: 54 | obs = self.dynamics.inverse_dynamics(obs, action) 55 | else: 56 | obs = self.dynamics.dynamics(obs, action) 57 | state = obs # self.latent_space.encoder(obs) 58 | return state 59 | 60 | def learn(self, *args, return_initial_loss=False, **kwargs): 61 | print("Using exact dynamics...") 62 | if return_initial_loss: 63 | return 0, 0 64 | return 0 65 | -------------------------------------------------------------------------------- /src/deep_rlsp/relative_reachability.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def relative_reachability_penalty(mdp, horizon, start): 5 | """ 6 | Calculates the undiscounted relative reachability penalty for each state in an mdp, 7 | compared to the starting state baseline. 
8 | 9 | Based on the algorithm described in: https://arxiv.org/pdf/1806.01186.pdf 10 | """ 11 | coverage = get_coverage(mdp, horizon) 12 | distributions = baseline_state_distributions(mdp, horizon, start) 13 | 14 | def penalty(state): 15 | return np.sum(np.maximum(coverage[state, :] - coverage, 0), axis=1) 16 | 17 | def penalty_for_baseline_distribution(dist): 18 | return sum( 19 | ( 20 | dist[state] * penalty(state) 21 | for state in range(mdp.nS) 22 | if dist[state] != 0 23 | ) 24 | ) 25 | 26 | r_r = np.array(list(map(penalty_for_baseline_distribution, distributions))) 27 | if np.amax(r_r) == 0: 28 | return np.zeros_like(r_r) 29 | return r_r / np.amax(r_r) 30 | 31 | 32 | def get_coverage(mdp, horizon): 33 | coverage = np.identity(mdp.nS) 34 | for i in range(horizon): 35 | # coverage(s0, sk) = \max_{a0} \sum_{s1} P(s1 | s0, a) * coverage(s1, sk) 36 | action_coverage = mdp.T_matrix.dot(coverage) 37 | action_coverage = action_coverage.reshape((mdp.nS, mdp.nA, mdp.nS)) 38 | coverage = np.amax(action_coverage, axis=1) 39 | return coverage 40 | 41 | 42 | def baseline_state_distributions(mdp, horizon, start): 43 | distribution = np.zeros(mdp.nS) 44 | distribution[start] = 1 45 | distributions = [distribution] 46 | for _ in range(horizon - 1): 47 | distribution = mdp.baseline_matrix_transpose.dot(distribution) 48 | distributions.append(distribution) 49 | return distributions 50 | -------------------------------------------------------------------------------- /src/deep_rlsp/rlsp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import check_grad 3 | 4 | from deep_rlsp.solvers.value_iter import value_iter, evaluate_policy 5 | from deep_rlsp.solvers.ppo import PPOSolver 6 | from deep_rlsp.util.parameter_checks import check_in 7 | 8 | 9 | def compute_g(mdp, policy, p_0, T, d_last_step_list, expected_features_list): 10 | nS, nA, nF = mdp.nS, mdp.nA, mdp.num_features 11 | 12 | # base case 13 | G = np.zeros((nS, nF)) 14 | # recursive case 15 | for t in range(T - 1): 16 | # G(s') = \sum_{s, a} p(a | s) p(s' | s, a) [ p(s) g(s, a) + G_prev[s] ] 17 | # p(s) is given by d_last_step_list[t] 18 | # g(s, a) = f(s) - F(s) + \sum_{s'} p(s' | s, a) F(s') 19 | # Distribute the addition to get three different terms: 20 | # First term: p(s) [f(s') - F(s')] 21 | # Second term: p(s) \sum_{s2} p(s2 | s, a) F(s2) 22 | # Third term: G_prev[s] 23 | g_first = mdp.f_matrix - expected_features_list[t] 24 | g_second = mdp.T_matrix.dot(expected_features_list[t + 1]) 25 | g_second = g_second.reshape((nS, nA, nF)) 26 | g_total = np.expand_dims(g_first, axis=1) + g_second 27 | 28 | prob_s_a = np.expand_dims(d_last_step_list[t].reshape(nS), axis=1) * policy[t] 29 | 30 | G_value = np.expand_dims(prob_s_a, axis=2) * g_total 31 | G_value = mdp.T_matrix_transpose.dot(G_value.reshape((nS * nA, nF))) 32 | 33 | G_recurse = np.expand_dims(policy[t], axis=-1) * np.expand_dims(G, axis=1) 34 | G_recurse = mdp.T_matrix_transpose.dot(G_recurse.reshape((nS * nA, nF))) 35 | 36 | G = G_value + G_recurse 37 | 38 | return G 39 | 40 | 41 | def compute_d_last_step(mdp, policy, p_0, T, gamma=1, verbose=False, return_all=False): 42 | """Computes the last-step occupancy measure""" 43 | D, d_last_step_list = p_0, [p_0] 44 | for t in range(T - 1): 45 | # D(s') = \sum_{s, a} D_prev(s) * p(a | s) * p(s' | s, a) 46 | state_action_prob = np.expand_dims(D, axis=1) * policy[t] 47 | D = mdp.T_matrix_transpose.dot(state_action_prob.flatten()) 48 | 49 | if verbose is True: 50 | 
print(D) 51 | if return_all: 52 | d_last_step_list.append(D) 53 | 54 | return (D, d_last_step_list) if return_all else D 55 | 56 | 57 | def compute_feature_expectations(mdp, policy, p_0, T): 58 | nS, nA, nF = mdp.nS, mdp.nA, mdp.num_features 59 | expected_features = mdp.f_matrix 60 | expected_feature_list = [expected_features] 61 | for t in range(T - 2, -1, -1): 62 | # F(s) = f(s) + \sum_{a, s'} p(a | s) * p(s' | s, a) * F(s') 63 | future_features = mdp.T_matrix.dot(expected_features).reshape((nS, nA, nF)) 64 | future_features = future_features * np.expand_dims(policy[t], axis=-1) 65 | expected_features = mdp.f_matrix + np.sum(future_features, axis=1) 66 | expected_feature_list.append(expected_features) 67 | return expected_features, expected_feature_list[::-1] 68 | 69 | 70 | def rlsp( 71 | _run, 72 | mdp, 73 | s_current, 74 | p_0, 75 | horizon, 76 | temp=1, 77 | epochs=1, 78 | learning_rate=0.2, 79 | r_prior=None, 80 | r_vec=None, 81 | threshold=1e-3, 82 | check_grad_flag=False, 83 | solver="value_iter", 84 | reset_solver=False, 85 | solver_iterations=1000, 86 | ): 87 | """The RLSP algorithm.""" 88 | check_in("solver", solver, ("value_iter", "ppo")) 89 | 90 | def compute_grad(r_vec): 91 | # Compute the Boltzmann rational policy \pi_{s,a} = \exp(Q_{s,a} - V_s) 92 | if solver == "value_iter": 93 | policy = value_iter(mdp, 1, mdp.f_matrix @ r_vec, horizon, temp) 94 | elif solver == "ppo": 95 | policy = ppo.learn( 96 | r_vec, reset_model=reset_solver, total_timesteps=solver_iterations 97 | ) 98 | 99 | _run.log_scalar( 100 | "policy_eval_r_vec", 101 | evaluate_policy( 102 | mdp, 103 | policy, 104 | mdp.get_num_from_state(mdp.init_state), 105 | 1, 106 | mdp.f_matrix @ r_vec, 107 | horizon, 108 | ), 109 | i, 110 | ) 111 | 112 | d_last_step, d_last_step_list = compute_d_last_step( 113 | mdp, policy, p_0, horizon, return_all=True 114 | ) 115 | if d_last_step[s_current] == 0: 116 | print("Error in om_method: No feasible trajectories!") 117 | return r_vec 118 | 119 | expected_features, expected_features_list = compute_feature_expectations( 120 | mdp, policy, p_0, horizon 121 | ) 122 | 123 | G = compute_g( 124 | mdp, policy, p_0, horizon, d_last_step_list, expected_features_list 125 | ) 126 | # Compute the gradient 127 | dL_dr_vec = G[s_current] / d_last_step[s_current] 128 | # Gradient of the prior 129 | if r_prior is not None: 130 | dL_dr_vec += r_prior.logdistr_grad(r_vec) 131 | return dL_dr_vec 132 | 133 | def compute_log_likelihood(r_vec): 134 | policy = value_iter(mdp, 1, mdp.f_matrix @ r_vec, horizon, temp) 135 | d_last_step = compute_d_last_step(mdp, policy, p_0, horizon) 136 | log_likelihood = np.log(d_last_step[s_current]) 137 | if r_prior is not None: 138 | log_likelihood += np.sum(r_prior.logpdf(r_vec)) 139 | return log_likelihood 140 | 141 | def get_grad(_): 142 | """dummy function for use with check_grad()""" 143 | return dL_dr_vec 144 | 145 | if r_vec is None: 146 | r_vec = 0.01 * np.random.randn(mdp.f_matrix.shape[1]) 147 | print("Initial reward vector: {}".format(r_vec)) 148 | 149 | ppo = PPOSolver(mdp, temp) 150 | 151 | if check_grad_flag: 152 | grad_error_list = [] 153 | 154 | for i in range(epochs): 155 | dL_dr_vec = compute_grad(r_vec) 156 | if check_grad_flag: 157 | grad_error_list.append(check_grad(compute_log_likelihood, get_grad, r_vec)) 158 | 159 | # Gradient ascent 160 | r_vec = r_vec + learning_rate * dL_dr_vec 161 | 162 | grad_norm = np.linalg.norm(dL_dr_vec) 163 | 164 | with np.printoptions(precision=4, suppress=True): 165 | print( 166 | "Epoch {}; Reward vector: {}; 
grad norm: {}".format(i, r_vec, grad_norm) 167 | ) 168 | if check_grad_flag: 169 | print("grad error: {}".format(grad_error_list[-1])) 170 | 171 | if grad_norm < threshold: 172 | if check_grad_flag: 173 | print() 174 | print("Max grad error: {}".format(np.amax(np.asarray(grad_error_list)))) 175 | print( 176 | "Median grad error: {}".format( 177 | np.median(np.asarray(grad_error_list)) 178 | ) 179 | ) 180 | break 181 | 182 | return r_vec 183 | -------------------------------------------------------------------------------- /src/deep_rlsp/sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import exp 3 | 4 | from deep_rlsp.solvers.value_iter import value_iter 5 | from deep_rlsp.rlsp import compute_d_last_step 6 | 7 | 8 | def sample_from_posterior( 9 | env, s_current, p_0, h, temp, n_samples, step_size, r_prior, gamma=1, print_level=1 10 | ): 11 | """ 12 | Algorithm similar to BIRL that uses the last-step OM of a Boltzmann rational 13 | policy instead of the BIRL likelihood. Samples the reward from the posterior 14 | p(r | s_T, r_spec) \\propto p(s_T | \\theta) * p(r | r_spec). 15 | 16 | This is Algorithm 1 in Appendix C of the paper. 17 | """ 18 | 19 | def log_last_step_om(policy): 20 | d_last_step = compute_d_last_step(env, policy, p_0, h) 21 | return np.log(d_last_step[s_current]) 22 | 23 | def log_probability(r_vec, verbose=False): 24 | pi = value_iter(env, gamma, env.f_matrix @ r_vec, h, temp) 25 | log_p = log_last_step_om(pi) 26 | 27 | log_prior = 0 28 | if r_prior is not None: 29 | log_prior = np.sum(r_prior.logpdf(r_vec)) 30 | 31 | if verbose: 32 | print( 33 | "Log prior: {}\nLog prob: {}\nTotal: {}".format( 34 | log_prior, log_p, log_p + log_prior 35 | ) 36 | ) 37 | return log_p + log_prior 38 | 39 | times_accepted = 0 40 | samples = [] 41 | 42 | if r_prior is None: 43 | r = 0.01 * np.random.randn(env.num_features) 44 | else: 45 | r = 0.1 * r_prior.rvs() 46 | 47 | if print_level >= 1: 48 | print("Initial reward: {}".format(r)) 49 | 50 | # probability of the initial reward 51 | log_p = log_probability(r, verbose=(print_level >= 1)) 52 | 53 | while len(samples) < n_samples: 54 | verbose = (print_level >= 1) and (len(samples) % 200 == 199) 55 | if verbose: 56 | print("\nGenerating sample {}".format(len(samples) + 1)) 57 | 58 | r_prime = np.random.normal(r, step_size) 59 | log_p_1 = log_probability(r_prime, verbose=verbose) 60 | 61 | # Accept or reject the new sample 62 | # If we reject, the new sample is the previous sample 63 | acceptance_probability = exp(log_p_1 - log_p) 64 | if np.random.uniform() < acceptance_probability: 65 | times_accepted += 1 66 | r, log_p = r_prime, log_p_1 67 | samples.append(r) 68 | 69 | if verbose: 70 | # Acceptance probability should not be very high or very low 71 | print("Acceptance probability is {}".format(acceptance_probability)) 72 | 73 | if print_level >= 1: 74 | print("Done! 
Accepted {} of samples".format(times_accepted / n_samples)) 75 | return samples 76 | -------------------------------------------------------------------------------- /src/deep_rlsp/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | from stable_baselines import SAC, PPO2 2 | from stable_baselines.common.policies import MlpPolicy 3 | from stable_baselines.sac.policies import MlpPolicy as MlpPolicySac 4 | from stable_baselines.sac.policies import FeedForwardPolicy as SACPolicy 5 | 6 | 7 | class CustomSACPolicy(SACPolicy): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs, layers=[256, 256], feature_extraction="mlp") 10 | 11 | 12 | def get_sac(env, **kwargs): 13 | env_id = env.unwrapped.spec.id 14 | if ( 15 | env_id.startswith("Ant") 16 | or env_id.startswith("HalfCheetah") 17 | or env_id.startswith("Swimmer") 18 | or env_id.startswith("Fetch") 19 | ): 20 | sac_kwargs = { 21 | "verbose": 1, 22 | "learning_rate": 3e-4, 23 | "gamma": 0.98, 24 | "tau": 0.01, 25 | "ent_coef": "auto", 26 | "buffer_size": 1000000, 27 | "batch_size": 256, 28 | "learning_starts": 10000, 29 | "train_freq": 1, 30 | "gradient_steps": 1, 31 | } 32 | policy = CustomSACPolicy 33 | elif env_id.startswith("Hopper"): 34 | sac_kwargs = { 35 | "verbose": 1, 36 | "learning_rate": 3e-4, 37 | "ent_coef": 0.01, 38 | "buffer_size": 1000000, 39 | "batch_size": 256, 40 | "learning_starts": 1000, 41 | "train_freq": 1, 42 | "gradient_steps": 1, 43 | } 44 | policy = CustomSACPolicy 45 | else: 46 | sac_kwargs = {"verbose": 1, "learning_starts": 1000} 47 | policy = MlpPolicySac 48 | for key, val in kwargs.items(): 49 | sac_kwargs[key] = val 50 | solver = SAC(policy, env, **sac_kwargs) 51 | return solver 52 | 53 | 54 | def get_ppo(env, **kwargs): 55 | sac_kwargs = {"verbose": 1} 56 | policy = MlpPolicy 57 | for key, val in kwargs.items(): 58 | sac_kwargs[key] = val 59 | solver = PPO2(policy, env, **sac_kwargs) 60 | return solver 61 | -------------------------------------------------------------------------------- /src/deep_rlsp/solvers/ppo.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | 5 | from stable_baselines.common.vec_env import DummyVecEnv 6 | from stable_baselines.common.policies import MlpPolicy 7 | from stable_baselines import PPO2 8 | from stable_baselines.bench.monitor import Monitor 9 | 10 | from deep_rlsp.envs.reward_wrapper import RewardWeightWrapper 11 | 12 | 13 | class PPOSolver: 14 | def __init__( 15 | self, env, temperature: float = 1, tensorboard_log: Optional[str] = None 16 | ): 17 | self.env = RewardWeightWrapper(env, None) 18 | self.temperature = temperature 19 | # Monitor allows PPO to log the reward it achieves 20 | monitored_env = Monitor(self.env, None, allow_early_resets=True) 21 | self.vec_env = DummyVecEnv([lambda: monitored_env]) 22 | self.tensorboard_log = tensorboard_log 23 | self._reset_model() 24 | 25 | def _reset_model(self): 26 | self.model = PPO2( 27 | MlpPolicy, 28 | self.vec_env, 29 | verbose=1, 30 | ent_coef=self.temperature, 31 | tensorboard_log=self.tensorboard_log, 32 | ) 33 | 34 | def _get_tabular_policy(self): 35 | """ 36 | Extracts a tabular policy representation from the PPO2 model. 
37 | """ 38 | policy = np.zeros((self.env.nS, self.env.nA)) 39 | for state_id in range(self.env.nS): 40 | state = self.env.get_state_from_num(state_id) 41 | obs = self.env.s_to_obs(state) 42 | probs = self.model.action_probability(obs) 43 | policy[state_id, :] = probs 44 | 45 | # `action_probability` sometimes returns slightly unnormalized distributions 46 | # (probably numerical issues) Hence, we normalize manually. 47 | policy /= policy.sum(axis=1, keepdims=True) 48 | assert np.allclose(policy.sum(axis=1), 1) 49 | return policy 50 | 51 | def learn( 52 | self, reward_weights, total_timesteps: int = 1000, reset_model: bool = False 53 | ): 54 | """ 55 | Performs the PPO algorithm using the implementation from `stable_baselines`. 56 | 57 | Returns (np.ndarray): 58 | Array of shape (nS, nA) containing the action probabilites for each state. 59 | """ 60 | if reset_model: 61 | self._reset_model() 62 | self.env.update_reward_weights(reward_weights) 63 | self.model.learn(total_timesteps=total_timesteps, log_interval=10) 64 | return self._get_tabular_policy() 65 | -------------------------------------------------------------------------------- /src/deep_rlsp/solvers/value_iter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def value_iter(mdp, gamma, r, horizon, temperature=1, time_dependent_reward=False): 5 | """ 6 | Performs (soft) value iteration to find a (Boltzman-)optimal policy. 7 | 8 | Finds the optimal state and state-action value functions via value iteration with 9 | the "soft" max-ent Bellman backup: 10 | 11 | $$ 12 | Q_{sa} = r_s + gamma * \\sum_{s'} p(s'|s,a)V_{s'} 13 | V'_s = temperature * log(\\sum_a exp(Q_{sa}/temperature)) 14 | $$ 15 | 16 | Then, computes the Boltzmann rational policy 17 | 18 | $$ 19 | \\pi_{s,a} = exp((Q_{s,a} - V_s)/temperature). 20 | $$ 21 | 22 | Args 23 | ---------- 24 | mdp (Env): MDP describing the environment. 25 | gamma (float): Discount factor; 0 <= gamma <= 1. 26 | r (np.ndarray): Reward vector of length mdp.nS. 27 | horizon (int): Time-horizon for the finite-horizon value iteration. 28 | temperature (float): Rationality constant for the soft value iteration equation. 29 | 30 | Returns (list of np.ndarray): 31 | Arrays of shape (mdp.nS, mdp.nA), each value p[t][s,a] is the probability of 32 | taking action a in state s at time t. 33 | """ 34 | nS, nA = mdp.nS, mdp.nA 35 | 36 | if not time_dependent_reward: 37 | r = [r] * horizon # Fast, since we aren't making copies 38 | 39 | policies = [] 40 | V = np.copy(r[horizon - 1]) 41 | for t in range(horizon - 2, -1, -1): 42 | future_values = mdp.T_matrix.dot(V).reshape((nS, nA)) 43 | Q = np.expand_dims(r[t], axis=1) + gamma * future_values 44 | 45 | if temperature == 0: 46 | V = Q.max(axis=1) 47 | # Argmax to find the action number, then index into np.eye to 48 | # one hot encode. Note this will deterministically break ties 49 | # towards the smaller action. 
50 | policy = np.eye(nA)[np.argmax(Q, axis=1)] 51 | else: 52 | # ∀ s: V_s = temperature * log(\\sum_a exp(Q_sa/temperature)) 53 | # ∀ s,a: policy_{s,a} = exp((Q_{s,a} - V_s)/t) 54 | V = softmax(Q, temperature) 55 | V = np.expand_dims(V, axis=1) 56 | policy = np.exp((Q - V) / temperature) 57 | 58 | policies.append(policy) 59 | 60 | if gamma == 1: 61 | # When \\gamma=1, the backup operator is equivariant under adding 62 | # a constant to all entries of V, so we can translate min(V) 63 | # to be 0 at each step of the softmax value iteration without 64 | # changing the policy it converges to, and this fixes the problem 65 | # where log(nA) keep getting added at each iteration. 66 | V = V - np.amin(V) 67 | 68 | return policies[::-1] 69 | 70 | 71 | def evaluate_policy(mdp, policy, start, gamma, r, horizon): 72 | """Expected reward from the policy.""" 73 | policy = np.array(policy) 74 | if len(policy.shape) == 2: 75 | policy = [policy] * horizon 76 | V = r 77 | for t in range(horizon - 2, -1, -1): 78 | future_values = mdp.T_matrix.dot(V).reshape((mdp.nS, mdp.nA)) 79 | Q = np.expand_dims(r, axis=1) + gamma * future_values 80 | V = np.sum(policy[t] * Q, axis=1) 81 | return V[start] 82 | 83 | 84 | def softmax(x, t=1): 85 | """ 86 | Numerically stable computation of t*log(\\sum_j^n exp(x_j / t)) 87 | 88 | If the input is a 1D numpy array, computes it's softmax: 89 | output = t*log(\\sum_j^n exp(x_j / t)). 90 | If the input is a 2D numpy array, computes the softmax of each of the rows: 91 | output_i = t*log(\\sum_j^n exp(x_{ij} / t)) 92 | 93 | Parameters 94 | ---------- 95 | x : 1D or 2D numpy array 96 | 97 | Returns 98 | ------- 99 | 1D numpy array 100 | shape = (n,), where: 101 | n = 1 if x was 1D, or 102 | n is the number of rows (=x.shape[0]) if x was 2D. 103 | """ 104 | assert t >= 0 105 | if len(x.shape) == 1: 106 | x = x.reshape((1, -1)) 107 | if t == 0: 108 | return np.amax(x, axis=1) 109 | if x.shape[1] == 1: 110 | return x 111 | 112 | def softmax_2_arg(x1, x2, t): 113 | """ 114 | Numerically stable computation of t*log(exp(x1/t) + exp(x2/t)) 115 | 116 | Parameters 117 | ---------- 118 | x1 : numpy array of shape (n,1) 119 | x2 : numpy array of shape (n,1) 120 | 121 | Returns 122 | ------- 123 | numpy array of shape (n,1) 124 | Each output_i = t*log(exp(x1_i / t) + exp(x2_i / t)) 125 | """ 126 | 127 | def tlog(x): 128 | return t * np.log(x) 129 | 130 | def expt(x): 131 | return np.exp(x / t) 132 | 133 | max_x = np.amax((x1, x2), axis=0) 134 | min_x = np.amin((x1, x2), axis=0) 135 | return max_x + tlog(1 + expt((min_x - max_x))) 136 | 137 | sm = softmax_2_arg(x[:, 0], x[:, 1], t) 138 | # Use the following property of softmax_2_arg: 139 | # softmax_2_arg(softmax_2_arg(x1,x2),x3) = log(exp(x1) + exp(x2) + exp(x3)) 140 | # which is true since 141 | # log(exp(log(exp(x1) + exp(x2))) + exp(x3)) = log(exp(x1) + exp(x2) + exp(x3)) 142 | for (i, x_i) in enumerate(x.T): 143 | if i > 1: 144 | sm = softmax_2_arg(sm, x_i, t) 145 | return sm 146 | -------------------------------------------------------------------------------- /src/deep_rlsp/tests/test_toy_experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple smoke tests that run some basic experiments and make sure they don't crash 3 | """ 4 | 5 | 6 | import unittest 7 | 8 | from deep_rlsp.run import ex 9 | 10 | ex.observers = [] 11 | 12 | ALL_ALGORITHMS_ON_ROOM = [ 13 | { 14 | "env_name": "room", 15 | "problem_spec": "default", 16 | "combination_algorithm": "additive", 17 | 
"inference_algorithm": "spec", 18 | "horizon": 7, 19 | "evaluation_horizon": 20, 20 | "epochs": 5, 21 | }, 22 | { 23 | "env_name": "room", 24 | "problem_spec": "default", 25 | "combination_algorithm": "additive", 26 | "inference_algorithm": "deviation", 27 | "horizon": 7, 28 | "evaluation_horizon": 20, 29 | "inferred_weight": 0.5, 30 | "epochs": 5, 31 | }, 32 | { 33 | "env_name": "room", 34 | "problem_spec": "default", 35 | "combination_algorithm": "additive", 36 | "inference_algorithm": "reachability", 37 | "horizon": 7, 38 | "evaluation_horizon": 20, 39 | "epochs": 5, 40 | }, 41 | { 42 | "env_name": "room", 43 | "problem_spec": "default", 44 | "combination_algorithm": "additive", 45 | "inference_algorithm": "rlsp", 46 | "horizon": 7, 47 | "evaluation_horizon": 20, 48 | "solver": "value_iter", 49 | "epochs": 5, 50 | }, 51 | { 52 | "env_name": "room", 53 | "problem_spec": "default", 54 | "combination_algorithm": "additive", 55 | "inference_algorithm": "rlsp", 56 | "horizon": 7, 57 | "evaluation_horizon": 20, 58 | "solver": "ppo", 59 | "solver_iterations": 100, 60 | "reset_solver": False, 61 | "epochs": 5, 62 | }, 63 | { 64 | "env_name": "room", 65 | "problem_spec": "default", 66 | "combination_algorithm": "additive", 67 | "inference_algorithm": "rlsp", 68 | "horizon": 7, 69 | "evaluation_horizon": 20, 70 | "solver": "ppo", 71 | "solver_iterations": 100, 72 | "reset_solver": True, 73 | "epochs": 5, 74 | }, 75 | ] 76 | 77 | DEVIATION_ON_ALL_ENVS = [ 78 | { 79 | "env_name": "room", 80 | "problem_spec": "default", 81 | "combination_algorithm": "additive", 82 | "inference_algorithm": "deviation", 83 | "horizon": 7, 84 | "evaluation_horizon": 20, 85 | "inferred_weight": 0.5, 86 | "epochs": 5, 87 | }, 88 | { 89 | "env_name": "train", 90 | "problem_spec": "default", 91 | "combination_algorithm": "additive", 92 | "inference_algorithm": "deviation", 93 | "horizon": 8, 94 | "evaluation_horizon": 20, 95 | "inferred_weight": 0.5, 96 | "epochs": 5, 97 | }, 98 | { 99 | "env_name": "apples", 100 | "problem_spec": "default", 101 | "combination_algorithm": "additive", 102 | "inference_algorithm": "deviation", 103 | "horizon": 11, 104 | "evaluation_horizon": 20, 105 | "inferred_weight": 0.5, 106 | "epochs": 5, 107 | }, 108 | { 109 | "env_name": "batteries", 110 | "problem_spec": "easy", 111 | "combination_algorithm": "additive", 112 | "inference_algorithm": "deviation", 113 | "horizon": 11, 114 | "evaluation_horizon": 20, 115 | "inferred_weight": 0.5, 116 | "epochs": 5, 117 | }, 118 | { 119 | "env_name": "batteries", 120 | "problem_spec": "default", 121 | "combination_algorithm": "additive", 122 | "inference_algorithm": "deviation", 123 | "horizon": 11, 124 | "evaluation_horizon": 20, 125 | "inferred_weight": 0.5, 126 | "epochs": 5, 127 | }, 128 | { 129 | "env_name": "room", 130 | "problem_spec": "bad", 131 | "combination_algorithm": "additive", 132 | "inference_algorithm": "deviation", 133 | "horizon": 5, 134 | "evaluation_horizon": 20, 135 | "inferred_weight": 0.5, 136 | "epochs": 5, 137 | }, 138 | ] 139 | 140 | 141 | class TestToyExperiments(unittest.TestCase): 142 | def test_toy_experiments(self): 143 | """ 144 | Runs experiments sequentially 145 | """ 146 | for config_updates in ALL_ALGORITHMS_ON_ROOM + DEVIATION_ON_ALL_ENVS: 147 | ex.run(config_updates=config_updates) 148 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/util/__init__.py -------------------------------------------------------------------------------- /src/deep_rlsp/util/dist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import norm, laplace 3 | 4 | 5 | class NormalDistribution(object): 6 | def __init__(self, mu, sigma=1): 7 | self.mu = mu 8 | self.sigma = sigma 9 | self.distribution = norm(loc=mu, scale=sigma) 10 | 11 | def rvs(self): 12 | """sample""" 13 | return self.distribution.rvs() 14 | 15 | def pdf(self, x): 16 | return self.distribution.pdf(x) 17 | 18 | def logpdf(self, x): 19 | return self.distribution.logpdf(x) 20 | 21 | def logdistr_grad(self, x): 22 | return (self.mu - x) / (self.sigma ** 2) 23 | 24 | 25 | class LaplaceDistribution(object): 26 | def __init__(self, mu, b=1): 27 | self.mu = mu 28 | self.b = b 29 | self.distribution = laplace(loc=mu, scale=b) 30 | 31 | def rvs(self): 32 | """sample""" 33 | return self.distribution.rvs() 34 | 35 | def pdf(self, x): 36 | return self.distribution.pdf(x) 37 | 38 | def logpdf(self, x): 39 | return self.distribution.logpdf(x) 40 | 41 | def logdistr_grad(self, x): 42 | return (self.mu - x) / (np.fabs(x - self.mu) * self.b) 43 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/helper.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import numpy as np 4 | from stable_baselines import SAC 5 | from copy import deepcopy 6 | 7 | from deep_rlsp.util.video import save_video 8 | from deep_rlsp.util.mujoco import initialize_mujoco_from_obs 9 | 10 | 11 | def load_data(filename): 12 | with open(filename, "rb") as f: 13 | play_data = pickle.load(f) 14 | return play_data 15 | 16 | 17 | def init_env_from_obs(env, obs): 18 | id = env.spec.id 19 | if ( 20 | "InvertedPendulum" in id 21 | or "HalfCheetah" in id 22 | or "Hopper" in id 23 | or "Ant" in id 24 | or "Fetch" in id 25 | ): 26 | return initialize_mujoco_from_obs(env, obs) 27 | else: 28 | # gridworld env 29 | state = env.obs_to_s(obs) 30 | env.reset() 31 | env.unwrapped.s = deepcopy(state) 32 | return env 33 | 34 | 35 | def get_trajectory( 36 | env, 37 | policy, 38 | get_observations=False, 39 | get_rgbs=False, 40 | get_return=False, 41 | print_debug=False, 42 | ): 43 | observations = [] if get_observations else None 44 | trajectory_rgbs = [] if get_rgbs else None 45 | total_reward = 0 if get_return else None 46 | 47 | obs = env.reset() 48 | done = False 49 | while not done: 50 | if isinstance(policy, SAC): 51 | a = policy.predict(np.expand_dims(obs, 0), deterministic=False)[0][0] 52 | else: 53 | a, _ = policy.predict(obs, deterministic=False) 54 | obs, reward, done, info = env.step(a) 55 | if print_debug: 56 | print("action") 57 | print(a) 58 | print("obs") 59 | print(obs) 60 | print("reward") 61 | print(reward) 62 | if get_observations: 63 | observations.append(obs) 64 | if get_rgbs: 65 | rgb = env.render("rgb_array") 66 | trajectory_rgbs.append(rgb) 67 | if get_return: 68 | total_reward += reward 69 | return observations, trajectory_rgbs, total_reward 70 | 71 | 72 | def evaluate_policy(env, policy, n_rollouts, video_out=None, print_debug=False): 73 | total_reward = 0 74 | for i in range(n_rollouts): 75 | get_rgbs = i == 0 and video_out is not None 76 | _, trajectory_rgbs, total_reward_episode = get_trajectory( 77 | env, policy, False, 
get_rgbs, True, print_debug=(print_debug and i == 0) 78 | ) 79 | total_reward += total_reward_episode 80 | if get_rgbs: 81 | save_video(trajectory_rgbs, video_out, fps=20.0) 82 | print("Saved video to", video_out) 83 | return total_reward / n_rollouts 84 | 85 | 86 | def memoize(f): 87 | # Assumes that all inputs to f are 1-D Numpy arrays 88 | memo = {} 89 | 90 | def helper(*args): 91 | key = tuple((tuple(x) for x in args)) 92 | if key not in memo: 93 | memo[key] = f(*args) 94 | return memo[key] 95 | 96 | return helper 97 | 98 | 99 | def sample_obs_from_trajectory(observations, n_samples): 100 | idx = np.random.choice(np.arange(len(observations)), n_samples) 101 | return np.array(observations)[idx] 102 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/linalg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # @memoize # don't memoize for continuous vectors, leads to memory leak 5 | def get_cosine_similarity(vec_a, vec_b): 6 | return np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b)) 7 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/mujoco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MujocoObsClipper: 5 | def __init__(self, env_id): 6 | if env_id.startswith("InvertedPendulum"): 7 | self.low = -300 8 | self.high = 300 9 | elif env_id.startswith("HalfCheetah"): 10 | self.low = -300 11 | self.high = 300 12 | elif env_id.startswith("Hopper"): 13 | self.low = -300 14 | self.high = 300 15 | elif env_id.startswith("Ant"): 16 | self.low = -100 17 | self.high = 100 18 | else: 19 | self.low = -float("inf") 20 | self.high = float("inf") 21 | self.counter = 0 22 | 23 | def clip(self, obs): 24 | obs_c = np.clip(obs, self.low, self.high) 25 | clipped = np.any(obs != obs_c) 26 | if clipped: 27 | self.counter += 1 28 | return obs_c, clipped 29 | 30 | 31 | def initialize_mujoco_from_obs(env, obs): 32 | """ 33 | Initialize a given mujoco environment to a state conditioned on an observation. 34 | 35 | Missing information in the observation (which is usually the torso-coordinates) 36 | is filled with zeros. 37 | """ 38 | env_id = env.unwrapped.spec.id 39 | if env_id == "InvertedPendulum-v2": 40 | nfill = 0 41 | elif env_id in ( 42 | "HalfCheetah-v2", 43 | "HalfCheetah-FW-v2", 44 | "HalfCheetah-BW-v2", 45 | "HalfCheetah-Plot-v2", 46 | "Hopper-v2", 47 | "Hopper-FW-v2", 48 | ): 49 | nfill = 1 50 | elif env_id == "Ant-FW-v2": 51 | nfill = 2 52 | elif env_id == "FetchReachStack-v1": 53 | env.set_state_from_obs(obs) 54 | return env 55 | else: 56 | raise NotImplementedError(f"{env_id} not supported") 57 | 58 | nq = env.model.nq 59 | nv = env.model.nv 60 | obs_qpos = np.zeros(nq) 61 | obs_qpos[nfill:] = obs[: nq - nfill] 62 | obs_qvel = obs[nq - nfill : nq - nfill + nv] 63 | env.set_state(obs_qpos, obs_qvel) 64 | return env 65 | 66 | 67 | def get_reward_done_from_obs_act(env, obs, act): 68 | """ 69 | Returns a reward and done variable from an obs and action in a mujoco environment. 70 | 71 | This is a hacky way to get some information about the reward of a state and whether 72 | the episode is done, if you just have an observation. However, it is only used for 73 | evaluating the latent space model by directly training a policy on this signal. 
74 | """ 75 | env = initialize_mujoco_from_obs(env, obs) 76 | obs, reward, done, info = env.step(act) 77 | return reward, done 78 | 79 | 80 | def compute_reward_done_from_obs(env, last_obs, action, next_obs): 81 | """ 82 | Returns a reward and done variable from a transition in a mujoco environment. 83 | 84 | Both are manually computed from the observation. 85 | """ 86 | env_id = env.unwrapped.spec.id 87 | if env_id == "InvertedPendulum-v2": 88 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/inverted_pendulum.py 89 | reward = 1.0 90 | notdone = np.isfinite(next_obs).all() and (np.abs(next_obs[1]) <= 0.2) 91 | done = not notdone 92 | elif env_id == "HalfCheetah-v2": 93 | # Note: This does not work for the default environment because the xposition 94 | # is being removed from the observation by the environment. 95 | nv = env.model.nv 96 | last_obs_qpos = last_obs[:-nv] 97 | next_obs_qpos = next_obs[:-nv] 98 | xposbefore = last_obs_qpos[0] 99 | xposafter = next_obs_qpos[0] 100 | reward_ctrl = -0.1 * np.square(action).sum() 101 | reward_run = (xposafter - xposbefore) / env.dt 102 | reward = reward_ctrl + reward_run 103 | done = False 104 | else: 105 | raise NotImplementedError("{} not supported".format(env_id)) 106 | return reward, done 107 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/parameter_checks.py: -------------------------------------------------------------------------------- 1 | def check_in(name, value, allowed_values): 2 | if value not in allowed_values: 3 | raise ValueError( 4 | "Invalid value for '{}': {}. Must be in {}.".format( 5 | name, value, allowed_values 6 | ) 7 | ) 8 | 9 | 10 | def check_greater_equal(name, value, greater_equal_value): 11 | if value < greater_equal_value: 12 | raise ValueError( 13 | "Invalid value for '{}': {}. Must be >= {}.".format( 14 | name, value, greater_equal_value 15 | ) 16 | ) 17 | 18 | 19 | def check_less_equal(name, value, less_equal_value): 20 | if value > less_equal_value: 21 | raise ValueError( 22 | "Invalid value for '{}': {}. Must be <= {}.".format( 23 | name, value, less_equal_value 24 | ) 25 | ) 26 | 27 | 28 | def check_between(name, value, lower_bound, upper_bound): 29 | check_greater_equal(name, value, lower_bound) 30 | check_less_equal(name, value, upper_bound) 31 | 32 | 33 | def check_not_none(name, value): 34 | if value is None: 35 | raise ValueError("'{}' cannot be None".format(name)) 36 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/probs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def sample_observation(obs_probs_tuple): 5 | # inverse of zip(observations, probs) 6 | observations, probs = zip(*obs_probs_tuple) 7 | i = np.random.choice(len(probs), p=probs) 8 | return observations[i] 9 | 10 | 11 | def get_out_probs_tuple(out_probs, dtype, shape): 12 | """ 13 | Convert between a dictionary of state-probability pairs to a tuple. 14 | 15 | Returns a tuple consisting of an observation and a probability given a dictionary 16 | with string representations of the observations as keys and probabilities as values. 17 | 18 | Normalizes the probabilities in the process. 
19 | """ 20 | probs_sum = sum(prob for _, prob in out_probs.items()) 21 | out_probs_tuple = ( 22 | (np.fromstring(obs_str, dtype).reshape(shape), prob / probs_sum) 23 | for obs_str, prob in out_probs.items() 24 | ) 25 | return out_probs_tuple 26 | 27 | 28 | def add_obs_prob_to_dict(dictionary, obs, prob): 29 | """ 30 | Updates a dictionary that tracks a probability distribution over observations. 31 | """ 32 | obs_str = obs.tostring() 33 | if obs_str not in dictionary: 34 | dictionary[obs_str] = 0 35 | dictionary[obs_str] += prob 36 | return dictionary 37 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import jsonpickle 3 | 4 | import jsonpickle.ext.numpy as jsonpickle_numpy 5 | 6 | jsonpickle_numpy.register_handlers() 7 | 8 | 9 | class Artifact: 10 | def __init__(self, file_name, method, _run): 11 | self._run = _run 12 | self.file_name = file_name 13 | self.file_path = os.path.join("/tmp", file_name) 14 | if method is not None: 15 | self.file_obj = open(self.file_path, method) 16 | else: 17 | self.file_obj = None 18 | 19 | def __enter__(self): 20 | if self.file_obj is None: 21 | return self.file_path 22 | else: 23 | return self.file_obj 24 | 25 | def __exit__(self, type, value, traceback): 26 | if self.file_obj is not None: 27 | self.file_obj.close() 28 | self._run.add_artifact(self.file_path) 29 | 30 | 31 | class FileExperimentResults: 32 | def __init__(self, result_folder): 33 | self.result_folder = result_folder 34 | self.config = self._read_json(result_folder, "config.json") 35 | self.metrics = self._read_json(result_folder, "metrics.json") 36 | self.run = self._read_json(result_folder, "run.json") 37 | try: 38 | self.info = self._read_json(result_folder, "info.json") 39 | except Exception as e: 40 | print(e) 41 | self.info = None 42 | self.status = self.run["status"] 43 | self.result = self.run["result"] 44 | 45 | def _read_json(self, result_folder, filename): 46 | with open(os.path.join(result_folder, filename), "r") as f: 47 | json_str = f.read() 48 | return jsonpickle.loads(json_str) 49 | 50 | def get_metric(self, name): 51 | metric = self.metrics[name] 52 | return metric["steps"], metric["values"] 53 | 54 | def print_captured_output(self): 55 | with open(os.path.join(self.result_folder, "cout.txt"), "r") as f: 56 | print(f.read()) 57 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/timer.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | 4 | class Timer: 5 | def __init__(self): 6 | self.entries = {} 7 | self.start_times = {} 8 | 9 | def start(self, description): 10 | return self.Instance(self, description) 11 | 12 | def get_average_time(self, description): 13 | times = self.entries[description] 14 | return sum(times) / len(times) 15 | 16 | def get_total_time(self, description): 17 | return sum(self.entries[description]) 18 | 19 | class Instance: 20 | def __init__(self, timer, description): 21 | self.timer = timer 22 | self.description = description 23 | 24 | def __enter__(self): 25 | if self.description in self.timer.start_times: 26 | raise Exception( 27 | "Cannot start timing {}".format(self.description) 28 | + " again before finishing the last invocation" 29 | ) 30 | self.timer.start_times[self.description] = time() 31 | 32 | def __exit__(self, type, value, traceback): 33 | end = time() 34 | 
start = self.timer.start_times[self.description] 35 | 36 | if self.description not in self.timer.entries: 37 | self.timer.entries[self.description] = [] 38 | 39 | self.timer.entries[self.description].append(end - start) 40 | del self.timer.start_times[self.description] 41 | 42 | 43 | def main(): 44 | def fact(n): 45 | if n == 0: 46 | return 1 47 | return n * fact(n - 1) 48 | 49 | my_timer = Timer() 50 | for i in range(10): 51 | with my_timer.start("fact"): 52 | print(fact(i)) 53 | 54 | print(my_timer.get_average_time("fact")) 55 | print(my_timer.get_total_time("fact")) 56 | print(my_timer.entries) 57 | 58 | def bad_fact(n): 59 | if n == 0: 60 | return 1 61 | with my_timer.start("bad_fact"): 62 | return n * fact(n - 1) 63 | 64 | with my_timer.start("bad_fact"): 65 | print(bad_fact(10)) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from deep_rlsp.util.parameter_checks import check_between, check_greater_equal 5 | 6 | 7 | def get_tf_session(): 8 | config = tf.ConfigProto() 9 | config.gpu_options.allow_growth = True 10 | return tf.Session(config=config) 11 | 12 | 13 | def get_learning_rate(initial_learning_rate, decay_steps, decay_rate): 14 | global_step = tf.Variable(0, trainable=False) 15 | if decay_rate == 1: 16 | learning_rate = tf.convert_to_tensor(initial_learning_rate) 17 | else: 18 | check_between("decay_rate", decay_rate, 0, 1) 19 | check_greater_equal("decay_steps", decay_steps, 1) 20 | learning_rate = tf.train.exponential_decay( 21 | initial_learning_rate, 22 | global_step, 23 | decay_steps=decay_steps, 24 | decay_rate=decay_rate, 25 | ) 26 | return learning_rate, global_step 27 | 28 | 29 | def tensorboard_log_gradients(gradients): 30 | for gradient, variable in gradients: 31 | tf.summary.scalar("gradients/" + variable.name, tf.norm(gradient, ord=2)) 32 | tf.summary.scalar("variables/" + variable.name, tf.norm(variable, ord=2)) 33 | 34 | 35 | def get_batch(data, batch, batch_size): 36 | batches = [] 37 | for dataset in data: 38 | batch_array = dataset[batch * batch_size : (batch + 1) * batch_size] 39 | batches.append(batch_array) 40 | return batches 41 | 42 | 43 | def shuffle_data(data): 44 | n_states = len(data[0]) 45 | shuffled = np.arange(n_states) 46 | np.random.shuffle(shuffled) 47 | shuffled_data = [] 48 | for dataset in data: 49 | assert len(dataset) == n_states 50 | shuffled_data.append(np.array(dataset)[shuffled]) 51 | return shuffled_data 52 | -------------------------------------------------------------------------------- /src/deep_rlsp/util/video.py: -------------------------------------------------------------------------------- 1 | from deep_rlsp.util.mujoco import initialize_mujoco_from_obs 2 | 3 | 4 | def save_video(ims, filename, fps=20.0): 5 | import cv2 6 | 7 | # Define the codec and create VideoWriter object 8 | fourcc = cv2.VideoWriter_fourcc(*"MJPG") 9 | (height, width, _) = ims[0].shape 10 | writer = cv2.VideoWriter(filename, fourcc, fps, (width, height)) 11 | for im in ims: 12 | im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) 13 | writer.write(im) 14 | writer.release() 15 | 16 | 17 | def render_mujoco_from_obs(env, obs, **kwargs): 18 | env = initialize_mujoco_from_obs(env, obs) 19 | rgb = env.render(mode="rgb_array", **kwargs) 20 | return rgb 21 | 
--------------------------------------------------------------------------------
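
Usage sketch (not part of the repository): the probs.py helpers above are meant to be used together to accumulate a distribution over observations, convert it back into (observation, probability) pairs, and sample from it. The observation values and probabilities below are made up for illustration; only the three imported helpers come from the module above.

import numpy as np

from deep_rlsp.util.probs import (
    add_obs_prob_to_dict,
    get_out_probs_tuple,
    sample_observation,
)

obs_a = np.array([0.0, 1.0])
obs_b = np.array([2.0, 3.0])

# Accumulate (unnormalized) probability mass per observation, keyed by its raw bytes.
out_probs = {}
out_probs = add_obs_prob_to_dict(out_probs, obs_a, 0.2)
out_probs = add_obs_prob_to_dict(out_probs, obs_b, 0.6)
out_probs = add_obs_prob_to_dict(out_probs, obs_a, 0.2)

# Convert back to (observation, probability) pairs; probabilities are normalized here.
# get_out_probs_tuple returns a generator, so materialize it before reusing it.
obs_probs = list(get_out_probs_tuple(out_probs, np.float64, obs_a.shape))
print(sample_observation(obs_probs))

Similarly, a minimal sketch of how render_mujoco_from_obs and save_video can be combined to turn stored observations into a video. The environment id must be one supported by initialize_mujoco_from_obs; the pickle path and output filename are hypothetical, and the MJPG-encoded output is written to an .avi container.

import pickle

import gym

from deep_rlsp.util.video import render_mujoco_from_obs, save_video

env = gym.make("HalfCheetah-v2")

# Hypothetical file containing a list of observations from a rollout.
with open("rollout_observations.pkl", "rb") as f:
    observations = pickle.load(f)

# Reset the simulator state from each observation and grab an RGB frame.
frames = [render_mujoco_from_obs(env, obs) for obs in observations]
save_video(frames, "rollout.avi", fps=20.0)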