├── .flake8
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── code_checks.sh
├── demonstrations
│   ├── sac_ant_traj_len_10_seed_354516857.pkl
│   ├── sac_ant_traj_len_10_seed_68863226.pkl
│   ├── sac_ant_traj_len_10_seed_827609420.pkl
│   ├── sac_ant_traj_len_1_seed_592603083.pkl
│   ├── sac_ant_traj_len_1_seed_60586324.pkl
│   ├── sac_ant_traj_len_1_seed_702401661.pkl
│   ├── sac_ant_traj_len_50_seed_135895263.pkl
│   ├── sac_ant_traj_len_50_seed_238906182.pkl
│   ├── sac_ant_traj_len_50_seed_276687624.pkl
│   ├── sac_cheetah_bw_traj_len_10_seed_200763839.pkl
│   ├── sac_cheetah_bw_traj_len_10_seed_293462553.pkl
│   ├── sac_cheetah_bw_traj_len_10_seed_779465690.pkl
│   ├── sac_cheetah_bw_traj_len_1_seed_531475214.pkl
│   ├── sac_cheetah_bw_traj_len_1_seed_903840226.pkl
│   ├── sac_cheetah_bw_traj_len_1_seed_905588575.pkl
│   ├── sac_cheetah_bw_traj_len_50_seed_222108450.pkl
│   ├── sac_cheetah_bw_traj_len_50_seed_472880492.pkl
│   ├── sac_cheetah_bw_traj_len_50_seed_890295234.pkl
│   ├── sac_cheetah_fw_traj_len_10_seed_102743744.pkl
│   ├── sac_cheetah_fw_traj_len_10_seed_569580504.pkl
│   ├── sac_cheetah_fw_traj_len_10_seed_874088138.pkl
│   ├── sac_cheetah_fw_traj_len_1_seed_22750069.pkl
│   ├── sac_cheetah_fw_traj_len_1_seed_31823986.pkl
│   ├── sac_cheetah_fw_traj_len_1_seed_686078893.pkl
│   ├── sac_cheetah_fw_traj_len_50_seed_559763593.pkl
│   ├── sac_cheetah_fw_traj_len_50_seed_592117357.pkl
│   ├── sac_cheetah_fw_traj_len_50_seed_96575016.pkl
│   ├── sac_hopper_traj_len_10_seed_647574168.pkl
│   ├── sac_hopper_traj_len_10_seed_700423463.pkl
│   ├── sac_hopper_traj_len_10_seed_750917369.pkl
│   ├── sac_hopper_traj_len_1_seed_125775086.pkl
│   ├── sac_hopper_traj_len_1_seed_582846018.pkl
│   ├── sac_hopper_traj_len_1_seed_809943847.pkl
│   ├── sac_hopper_traj_len_50_seed_264928546.pkl
│   ├── sac_hopper_traj_len_50_seed_371855109.pkl
│   ├── sac_hopper_traj_len_50_seed_802647326.pkl
│   ├── sac_pendulum_traj_len_10_seed_177330974.pkl
│   ├── sac_pendulum_traj_len_10_seed_294267497.pkl
│   ├── sac_pendulum_traj_len_10_seed_610830928.pkl
│   ├── sac_pendulum_traj_len_1_seed_556679729.pkl
│   ├── sac_pendulum_traj_len_1_seed_583612201.pkl
│   ├── sac_pendulum_traj_len_1_seed_728082477.pkl
│   ├── sac_pendulum_traj_len_50_seed_196767060.pkl
│   ├── sac_pendulum_traj_len_50_seed_523507457.pkl
│   ├── sac_pendulum_traj_len_50_seed_960406656.pkl
│   ├── sac_reach_true_traj_len_10_seed_183338627.pkl
│   ├── sac_reach_true_traj_len_10_seed_48907083.pkl
│   ├── sac_reach_true_traj_len_10_seed_880581313.pkl
│   ├── sac_reach_true_traj_len_1_seed_326943823.pkl
│   ├── sac_reach_true_traj_len_1_seed_497025194.pkl
│   ├── sac_reach_true_traj_len_1_seed_571661907.pkl
│   ├── sac_reach_true_traj_len_49_traj_n_1000_seed_85509514.pkl
│   ├── sac_reach_true_traj_len_50_seed_44258750.pkl
│   ├── sac_reach_true_traj_len_50_seed_696002925.pkl
│   └── sac_reach_true_traj_len_50_seed_742173558.pkl
├── docker
│   ├── Dockerfile
│   └── environment.yml
├── policies
│   ├── sac_ant_2e6.mp4
│   ├── sac_ant_2e6.zip
│   ├── sac_cheetah_bw_2e6.mp4
│   ├── sac_cheetah_bw_2e6.zip
│   ├── sac_cheetah_fw_2e6.mp4
│   ├── sac_cheetah_fw_2e6.zip
│   ├── sac_hopper_2e6.mp4
│   ├── sac_hopper_2e6.zip
│   ├── sac_pendulum_6e4.mp4
│   ├── sac_pendulum_6e4.zip
│   ├── sac_reach_task_2e6.mp4
│   ├── sac_reach_task_2e6.zip
│   ├── sac_reach_true_2e6.mp4
│   └── sac_reach_true_2e6.zip
├── scripts
│   ├── create_demonstrations.py
│   ├── evaluate_policy.py
│   ├── evaluate_result_policies.py
│   ├── mujoco_evaluate_inferred_reward.py
│   ├── plot_learning_curves.py
│   ├── run_gail.py
│   ├── toy_experiments.py
│   ├── train_inverse_dynamics.py
│   ├── train_sac.py
│   └── train_vae.py
├── setup.py
├── skills
│   ├── balancing.mp4
│   ├── balancing_rollouts.pkl
│   ├── jumping.mp4
│   └── jumping_rollouts.pkl
└── src
    └── deep_rlsp
        ├── __init__.py
        ├── ablation_AverageFeatures.py
        ├── ablation_Waypoints.py
        ├── envs
        │   ├── __init__.py
        │   ├── gridworlds
        │   │   ├── __init__.py
        │   │   ├── apples.py
        │   │   ├── apples_spec.py
        │   │   ├── basic_room.py
        │   │   ├── batteries.py
        │   │   ├── batteries_spec.py
        │   │   ├── env.py
        │   │   ├── gym_envs.py
        │   │   ├── one_hot_action_space_wrapper.py
        │   │   ├── room.py
        │   │   ├── room_spec.py
        │   │   ├── tests
        │   │   │   ├── apples_test.py
        │   │   │   ├── batteries_test.py
        │   │   │   ├── common.py
        │   │   │   ├── env_test.py
        │   │   │   ├── room_test.py
        │   │   │   ├── test_observation_spaces.py
        │   │   │   └── train_test.py
        │   │   ├── train.py
        │   │   └── train_spec.py
        │   ├── mujoco
        │   │   ├── __init__.py
        │   │   ├── ant.py
        │   │   ├── assets
        │   │   │   ├── ant.xml
        │   │   │   ├── ant_footsensor.xml
        │   │   │   ├── ant_plot.xml
        │   │   │   ├── half_cheetah.xml
        │   │   │   └── half_cheetah_plot.xml
        │   │   └── half_cheetah.py
        │   ├── reward_wrapper.py
        │   └── robotics
        │       ├── __init__.py
        │       ├── assets
        │       │   ├── LICENSE.md
        │       │   ├── fetch
        │       │   │   ├── reach.xml
        │       │   │   ├── robot.xml
        │       │   │   └── shared.xml
        │       │   ├── stls
        │       │   │   ├── .get
        │       │   │   └── fetch
        │       │   │       ├── base_link_collision.stl
        │       │   │       ├── bellows_link_collision.stl
        │       │   │       ├── elbow_flex_link_collision.stl
        │       │   │       ├── estop_link.stl
        │       │   │       ├── forearm_roll_link_collision.stl
        │       │   │       ├── gripper_link.stl
        │       │   │       ├── head_pan_link_collision.stl
        │       │   │       ├── head_tilt_link_collision.stl
        │       │   │       ├── l_wheel_link_collision.stl
        │       │   │       ├── laser_link.stl
        │       │   │       ├── r_wheel_link_collision.stl
        │       │   │       ├── shoulder_lift_link_collision.stl
        │       │   │       ├── shoulder_pan_link_collision.stl
        │       │   │       ├── torso_fixed_link.stl
        │       │   │       ├── torso_lift_link_collision.stl
        │       │   │       ├── upperarm_roll_link_collision.stl
        │       │   │       ├── wrist_flex_link_collision.stl
        │       │   │       └── wrist_roll_link_collision.stl
        │       │   └── textures
        │       │       ├── block.png
        │       │       └── block_hidden.png
        │       └── reach_side_effects.py
        ├── latent_rlsp.py
        ├── learn_dynamics_model.py
        ├── model
        │   ├── __init__.py
        │   ├── base.py
        │   ├── dynamics_mdn.py
        │   ├── dynamics_mlp.py
        │   ├── exact_dynamics_mujoco.py
        │   ├── experience_replay.py
        │   ├── gridworlds_feature_space.py
        │   ├── inverse_model_env_wrapper.py
        │   ├── inverse_policy_mdn.py
        │   ├── latent_space.py
        │   ├── mujoco_debug_models.py
        │   ├── rssm.py
        │   ├── state_vae.py
        │   └── tabular.py
        ├── policy_discriminator.py
        ├── relative_reachability.py
        ├── rlsp.py
        ├── run.py
        ├── run_mujoco.py
        ├── sampling.py
        ├── solvers
        │   ├── __init__.py
        │   ├── ppo.py
        │   └── value_iter.py
        ├── tests
        │   └── test_toy_experiments.py
        └── util
            ├── __init__.py
            ├── dist.py
            ├── helper.py
            ├── linalg.py
            ├── mujoco.py
            ├── parameter_checks.py
            ├── probs.py
            ├── results.py
            ├── timer.py
            ├── train.py
            └── video.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | # See https://github.com/PyCQA/pycodestyle/issues/373
4 | extend-ignore = E203,
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *.swp
3 | gail_logs/
4 | output/
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # pyenv
81 | .python-version
82 |
83 | # celery beat schedule file
84 | celerybeat-schedule
85 |
86 | # SageMath parsed files
87 | *.sage.py
88 |
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 |
102 | # Rope project settings
103 | .ropeproject
104 |
105 | # mkdocs documentation
106 | /site
107 |
108 | # mypy
109 | .mypy_cache/
110 |
111 | .envrc
112 | results/
113 | tf_ckpt
114 | mjkey.txt
115 |
116 | # log files
117 | *.out
118 | *.log
119 |
120 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | dist: xenial
3 | cache: pip
4 | python:
5 | - 3.7
6 | before_install:
7 | - python$PY -m pip install --upgrade pip setuptools wheel
8 | install:
9 | # hotfix: ray 0.8.0 causes the jobs to not be properly executed on travis
10 | # (but it works on other machines)
11 | - pip install ray==0.7.6
12 | - pip install black mypy flake8
13 | script:
14 | - bash code_checks.sh
15 | - python setup.py test
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Center for Human-Compatible AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/code_checks.sh:
--------------------------------------------------------------------------------
1 | black --check src
2 | black --check scripts
3 | flake8 src
4 | #flake8 scripts
5 | mypy --ignore-missing-imports src
6 |
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_10_seed_354516857.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_10_seed_354516857.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_10_seed_68863226.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_10_seed_68863226.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_10_seed_827609420.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_10_seed_827609420.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_1_seed_592603083.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_1_seed_592603083.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_1_seed_60586324.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_1_seed_60586324.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_1_seed_702401661.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_1_seed_702401661.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_50_seed_135895263.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_50_seed_135895263.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_50_seed_238906182.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_50_seed_238906182.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_ant_traj_len_50_seed_276687624.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_ant_traj_len_50_seed_276687624.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_10_seed_200763839.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_10_seed_200763839.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_10_seed_293462553.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_10_seed_293462553.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_10_seed_779465690.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_10_seed_779465690.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_1_seed_531475214.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_1_seed_531475214.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_1_seed_903840226.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_1_seed_903840226.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_1_seed_905588575.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_1_seed_905588575.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_50_seed_222108450.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_50_seed_222108450.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_50_seed_472880492.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_50_seed_472880492.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_bw_traj_len_50_seed_890295234.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_bw_traj_len_50_seed_890295234.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_10_seed_102743744.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_10_seed_102743744.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_10_seed_569580504.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_10_seed_569580504.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_10_seed_874088138.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_10_seed_874088138.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_1_seed_22750069.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_1_seed_22750069.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_1_seed_31823986.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_1_seed_31823986.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_1_seed_686078893.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_1_seed_686078893.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_50_seed_559763593.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_50_seed_559763593.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_50_seed_592117357.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_50_seed_592117357.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_cheetah_fw_traj_len_50_seed_96575016.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_cheetah_fw_traj_len_50_seed_96575016.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_10_seed_647574168.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_10_seed_647574168.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_10_seed_700423463.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_10_seed_700423463.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_10_seed_750917369.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_10_seed_750917369.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_1_seed_125775086.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_1_seed_125775086.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_1_seed_582846018.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_1_seed_582846018.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_1_seed_809943847.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_1_seed_809943847.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_50_seed_264928546.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_50_seed_264928546.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_50_seed_371855109.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_50_seed_371855109.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_hopper_traj_len_50_seed_802647326.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_hopper_traj_len_50_seed_802647326.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_10_seed_177330974.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_10_seed_177330974.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_10_seed_294267497.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_10_seed_294267497.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_10_seed_610830928.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_10_seed_610830928.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_1_seed_556679729.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_1_seed_556679729.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_1_seed_583612201.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_1_seed_583612201.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_1_seed_728082477.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_1_seed_728082477.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_50_seed_196767060.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_50_seed_196767060.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_50_seed_523507457.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_50_seed_523507457.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_pendulum_traj_len_50_seed_960406656.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_pendulum_traj_len_50_seed_960406656.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_10_seed_183338627.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_10_seed_183338627.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_10_seed_48907083.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_10_seed_48907083.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_10_seed_880581313.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_10_seed_880581313.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_1_seed_326943823.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_1_seed_326943823.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_1_seed_497025194.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_1_seed_497025194.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_1_seed_571661907.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_1_seed_571661907.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_49_traj_n_1000_seed_85509514.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_49_traj_n_1000_seed_85509514.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_50_seed_44258750.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_50_seed_44258750.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_50_seed_696002925.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_50_seed_696002925.pkl
--------------------------------------------------------------------------------
/demonstrations/sac_reach_true_traj_len_50_seed_742173558.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/demonstrations/sac_reach_true_traj_len_50_seed_742173558.pkl
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | #FROM tensorflow/tensorflow:1.13.2-gpu
2 | FROM tensorflow/tensorflow:1.13.2
3 |
4 | # Install Miniconda
5 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
6 | ENV PATH /opt/conda/bin:$PATH
7 |
8 | RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \
9 | libglib2.0-0 libxext6 libsm6 libxrender1 \
10 | git mercurial subversion
11 |
12 | RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-py37_4.8.2-Linux-x86_64.sh -O ~/anaconda.sh && \
13 | /bin/bash ~/anaconda.sh -b -p /opt/conda && \
14 | rm ~/anaconda.sh && \
15 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
16 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
17 | echo "conda activate base" >> ~/.bashrc
18 |
19 | # Install Mujoco
20 | RUN apt-get update -q \
21 | && DEBIAN_FRONTEND=noninteractive apt-get install -y \
22 | curl \
23 | git \
24 | libgl1-mesa-dev \
25 | libgl1-mesa-glx \
26 | libglew-dev \
27 | libosmesa6-dev \
28 | software-properties-common \
29 | net-tools \
30 | unzip \
31 | vim \
32 | virtualenv \
33 | wget \
34 | xpra \
35 | xserver-xorg-dev \
36 | patchelf \
37 | && apt-get clean \
38 | && rm -rf /var/lib/apt/lists/*
39 |
40 | RUN mkdir -p /root/.mujoco \
41 | && wget https://www.roboti.us/download/mujoco200_linux.zip -O mujoco.zip \
42 | && unzip mujoco.zip -d /root/.mujoco \
43 | && mv /root/.mujoco/mujoco200_linux /root/.mujoco/mujoco200 \
44 | && rm mujoco.zip
45 | COPY ./mjkey.txt /root/.mujoco/
46 | ENV LD_LIBRARY_PATH /root/.mujoco/mujoco200/bin:${LD_LIBRARY_PATH}
47 | ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
48 |
49 | RUN conda install -c menpo osmesa
50 |
51 | # Set up environment
52 | RUN mkdir /deep-rlsp
53 | COPY environment.yml /deep-rlsp/
54 | RUN conda env create -f /deep-rlsp/environment.yml
55 | #RUN conda run -n deep-rlsp pip uninstall tensorflow
56 | #RUN conda run -n deep-rlsp pip install tensorflow-gpu==1.13.2
57 | ENV PYTHONPATH "${PYTHONPATH}:/deep-rlsp/src"
58 |
--------------------------------------------------------------------------------
/docker/environment.yml:
--------------------------------------------------------------------------------
1 | name: deep-rlsp
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - attrs=19.3.0=py_0
7 | - backcall=0.1.0=py37_0
8 | - bleach=3.1.0=py_0
9 | - ca-certificates=2019.11.27=0
10 | - certifi=2019.11.28=py37_0
11 | - decorator=4.4.1=py_0
12 | - defusedxml=0.6.0=py_0
13 | - gmp=6.1.2=h6c8ec71_1
14 | - importlib_metadata=1.3.0=py37_0
15 | - ipykernel=5.1.3=py37h39e3cac_1
16 | - ipython=7.11.1=py37h39e3cac_0
17 | - ipython_genutils=0.2.0=py37_0
18 | - jedi=0.15.2=py37_0
19 | - jinja2=2.10.3=py_0
20 | - jsonschema=3.2.0=py37_0
21 | - jupyter_client=5.3.4=py37_0
22 | - jupyter_core=4.6.1=py37_0
23 | - ld_impl_linux-64=2.33.1=h53a641e_7
24 | - libedit=3.1.20181209=hc058e9b_0
25 | - libffi=3.2.1=hd88cf55_4
26 | - libgcc-ng=9.1.0=hdf63c60_0
27 | - libsodium=1.0.16=h1bed415_0
28 | - libstdcxx-ng=9.1.0=hdf63c60_0
29 | - markupsafe=1.1.1=py37h7b6447c_0
30 | - mistune=0.8.4=py37h7b6447c_0
31 | - nb_conda=2.2.1=py37_0
32 | - nb_conda_kernels=2.2.2=py37_0
33 | - nbconvert=5.6.1=py37_0
34 | - nbformat=4.4.0=py37_0
35 | - ncurses=6.1=he6710b0_1
36 | - notebook=6.0.2=py37_0
37 | - openssl=1.1.1d=h7b6447c_3
38 | - pandoc=2.2.3.2=0
39 | - pandocfilters=1.4.2=py37_1
40 | - parso=0.5.2=py_0
41 | - pip=19.3.1=py37_0
42 | - prometheus_client=0.7.1=py_0
43 | - prompt_toolkit=3.0.2=py_0
44 | - pygments=2.5.2=py_0
45 | - python=3.7.6=h0371630_2
46 | - python-dateutil=2.8.1=py_0
47 | - pyzmq=18.1.0=py37he6710b0_0
48 | - readline=7.0=h7b6447c_5
49 | - send2trash=1.5.0=py37_0
50 | - setuptools=44.0.0=py37_0
51 | - sqlite=3.30.1=h7b6447c_0
52 | - terminado=0.8.3=py37_0
53 | - testpath=0.4.4=py_0
54 | - tk=8.6.8=hbc83047_0
55 | - tornado=6.0.3=py37h7b6447c_0
56 | - traitlets=4.3.3=py37_0
57 | - webencodings=0.5.1=py37_1
58 | - wheel=0.33.6=py37_0
59 | - xz=5.2.4=h14c3975_4
60 | - zeromq=4.3.1=he6710b0_3
61 | - zlib=1.2.11=h7b6447c_3
62 | - pip:
63 | - absl-py==0.9.0
64 | - appdirs==1.4.3
65 | - astor==0.8.1
66 | - atari-py==0.2.6
67 | - black==19.10b0
68 | - cffi==1.13.2
69 | - chardet==3.0.4
70 | - click==7.0
71 | - cloudpickle==1.2.2
72 | - colorama==0.4.3
73 | - cycler==0.10.0
74 | - cython==0.29.14
75 | - docopt==0.6.2
76 | - entrypoints==0.3
77 | - fasteners==0.15
78 | - filelock==3.0.12
79 | - flake8==3.7.9
80 | - funcsigs==1.0.2
81 | - future==0.18.2
82 | - gast==0.2.2
83 | - gitdb2==2.0.6
84 | - gitpython==3.0.5
85 | - glfw==1.10.1
86 | - google-pasta==0.2.0
87 | - grpcio==1.26.0
88 | - gym==0.15.4
89 | - h5py==2.10.0
90 | - idna==2.8
91 | - imageio==2.6.1
92 | - imageio-ffmpeg==0.4.1
93 | - importlib-metadata==1.4.0
94 | - ipdb==0.12.3
95 | - ipython-genutils==0.2.0
96 | - joblib==0.14.1
97 | - jsonpickle==0.9.6
98 | - keras-applications==1.0.8
99 | - keras-preprocessing==1.1.0
100 | - kiwisolver==1.1.0
101 | - markdown==3.1.1
102 | - matplotlib==3.1.2
103 | - mccabe==0.6.1
104 | - mock==3.0.5
105 | - monotonic==1.5
106 | - more-itertools==8.1.0
107 | - munch==2.5.0
108 | - mujoco-py==2.0.2.9
109 | - mypy==0.770
110 | - mypy-extensions==0.4.3
111 | - numpy==1.20.1
112 | - opencv-python==4.1.2.30
113 | - opt-einsum==3.2.1
114 | - packaging==20.0
115 | - pandas==0.25.3
116 | - pathspec==0.7.0
117 | - pexpect==4.8.0
118 | - pickleshare==0.7.5
119 | - pillow==7.0.0
120 | - pluggy==0.13.1
121 | - protobuf==3.11.2
122 | - ptyprocess==0.6.0
123 | - py==1.8.1
124 | - py-cpuinfo==5.0.0
125 | - pycodestyle==2.5.0
126 | - pycparser==2.19
127 | - pyflakes==2.1.1
128 | - pyglet==1.3.1
129 | - pyopengl==3.1.5
130 | - pyparsing==2.4.6
131 | - pyrsistent==0.15.7
132 | - pytest==5.3.4
133 | - pytz==2019.3
134 | - pyyaml==5.3
135 | - ray==0.8.0
136 | - redis==3.3.11
137 | - regex==2020.2.20
138 | - requests==2.22.0
139 | - sacred==0.7.4
140 | - scipy==1.4.1
141 | - seaborn==0.9.0
142 | - six==1.14.0
143 | - smmap2==2.0.5
144 | - stable-baselines==2.9.0
145 | - tensorboard==1.13.1
146 | - tensorflow==1.13.2
147 | - tensorflow-estimator==1.13.0
148 | - tensorflow-probability==0.6.0
149 | - termcolor==1.1.0
150 | - toml==0.10.0
151 | - typed-ast==1.4.1
152 | - typing-extensions==3.7.4.2
153 | - urllib3==1.25.8
154 | - wcwidth==0.1.8
155 | - werkzeug==0.16.0
156 | - wrapt==1.11.2
157 | - zipp==2.0.0
158 | prefix: /home/david/anaconda3/envs/deep-rlsp
159 |
--------------------------------------------------------------------------------
/policies/sac_ant_2e6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_ant_2e6.mp4
--------------------------------------------------------------------------------
/policies/sac_ant_2e6.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_ant_2e6.zip
--------------------------------------------------------------------------------
/policies/sac_cheetah_bw_2e6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_bw_2e6.mp4
--------------------------------------------------------------------------------
/policies/sac_cheetah_bw_2e6.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_bw_2e6.zip
--------------------------------------------------------------------------------
/policies/sac_cheetah_fw_2e6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_fw_2e6.mp4
--------------------------------------------------------------------------------
/policies/sac_cheetah_fw_2e6.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_cheetah_fw_2e6.zip
--------------------------------------------------------------------------------
/policies/sac_hopper_2e6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_hopper_2e6.mp4
--------------------------------------------------------------------------------
/policies/sac_hopper_2e6.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_hopper_2e6.zip
--------------------------------------------------------------------------------
/policies/sac_pendulum_6e4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_pendulum_6e4.mp4
--------------------------------------------------------------------------------
/policies/sac_pendulum_6e4.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_pendulum_6e4.zip
--------------------------------------------------------------------------------
/policies/sac_reach_task_2e6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_task_2e6.mp4
--------------------------------------------------------------------------------
/policies/sac_reach_task_2e6.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_task_2e6.zip
--------------------------------------------------------------------------------
/policies/sac_reach_true_2e6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_true_2e6.mp4
--------------------------------------------------------------------------------
/policies/sac_reach_true_2e6.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/policies/sac_reach_true_2e6.zip
--------------------------------------------------------------------------------
/scripts/create_demonstrations.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import numpy as np
3 | import tensorflow as tf
4 | import gym
5 | import pickle
6 | import sys
7 |
8 | from stable_baselines import SAC
9 |
10 | from imitation.data import types
11 | from imitation.data.rollout import unwrap_traj
12 | import deep_rlsp
13 |
14 |
15 | def convert_trajs(filename, traj_len):
16 | with open(filename, "rb") as f:
17 | data = pickle.load(f)
18 |
19 | assert traj_len < len(data["observations"][0])
20 | obs = np.array(data["observations"][0][: traj_len + 1])
21 | acts = np.array(data["actions"][0][:traj_len])
22 | rews = np.array([0 for _ in range(traj_len)])
23 | infos = [{} for _ in range(traj_len)]
24 | traj = types.TrajectoryWithRew(obs=obs, acts=acts, infos=infos, rews=rews)
25 | return [traj]
26 |
27 |
28 | def rollout_policy(filename, traj_len, seed, env_name, n_trajs=1):
29 | model = SAC.load(filename)
30 | env = gym.make(env_name)
31 | env.seed(seed)
32 |
33 | trajs = []
34 | for _ in range(int(n_trajs)):
35 | obs_list, acts_list, rews_list = [], [], []
36 | obs = env.reset()
37 | obs_list.append(obs)
38 | for _ in range(traj_len):
39 | act = model.predict(obs, deterministic=True)[0]
40 | obs, r, done, _ = env.step(act)
41 | # assert not done
42 | acts_list.append(act)
43 | obs_list.append(obs)
44 | rews_list.append(r)
45 |
46 | infos = [{} for _ in range(traj_len)]
47 | traj = types.TrajectoryWithRew(
48 | obs=np.array(obs_list),
49 | acts=np.array(acts_list),
50 | infos=infos,
51 | rews=np.array(rews_list),
52 | )
53 | trajs.append(traj)
54 |
55 | return trajs
56 |
57 |
58 | def recode_and_save_trajectories(traj_or_policy_file, save_loc, traj_len, seed, args):
59 | if "skills" in traj_or_policy_file:
60 | trajs = convert_trajs(traj_or_policy_file, traj_len, *args)
61 | else:
62 | trajs = rollout_policy(traj_or_policy_file, traj_len, seed, *args)
63 |
64 | # assert len(trajs) == 1
65 | for traj in trajs:
66 | assert traj.obs.shape[0] == traj_len + 1
67 | assert traj.acts.shape[0] == traj_len
68 | trajs = [dataclasses.replace(traj, infos=None) for traj in trajs]
69 | types.save(save_loc, trajs)
70 |
71 |
72 | if __name__ == "__main__":
73 | _, traj_or_policy_file, save_loc, traj_len, seed = sys.argv[:5]
74 | if seed == "generate_seed":
75 | seed = np.random.randint(0, 1e9)
76 | else:
77 | seed = int(seed)
78 | save_loc = save_loc.format(traj_len, seed)
79 | np.random.seed(seed)
80 | tf.random.set_random_seed(seed)
81 | traj_len = int(traj_len)
82 | recode_and_save_trajectories(
83 | traj_or_policy_file, save_loc, traj_len, seed, sys.argv[5:]
84 | )
85 |
--------------------------------------------------------------------------------
/scripts/evaluate_policy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 |
4 | import numpy as np
5 | import gym
6 |
7 | from stable_baselines import PPO2, SAC
8 |
9 | from deep_rlsp.util.video import render_mujoco_from_obs, save_video
10 | from deep_rlsp.model.mujoco_debug_models import MujocoDebugFeatures
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("policy_file", type=str)
16 | parser.add_argument("policy_type", type=str)
17 | parser.add_argument("envname", type=str)
18 | parser.add_argument("--render", action="store_true")
19 | parser.add_argument("--out_video", type=str, default=None)
20 | parser.add_argument(
21 | "--num_rollouts", type=int, default=1, help="Number of expert rollouts"
22 | )
23 | return parser.parse_args()
24 |
25 |
26 | def main():
27 | args = parse_args()
28 | env = gym.make(args.envname)
29 |
30 | print("loading and building expert policy")
31 |
32 | if args.policy_type == "ppo":
33 | model = PPO2.load(args.policy_file)
34 |
35 | def get_action(obs):
36 | return model.predict(obs)[0]
37 |
38 | elif args.policy_type == "sac":
39 | model = SAC.load(args.policy_file)
40 |
41 | def get_action(obs):
42 | return model.predict(obs, deterministic=True)[0]
43 |
44 | elif args.policy_type == "gail":
45 | from imitation.policies import serialize
46 | from stable_baselines.common.vec_env import DummyVecEnv
47 |
48 | venv = DummyVecEnv([lambda: env])
49 | loading_context = serialize.load_policy("ppo2", args.policy_file, venv)
50 | model = loading_context.__enter__()
51 |
52 | def get_action(obs):
53 | return model.step(np.reshape(obs, (1, -1)))[0]
54 |
55 | else:
56 | raise NotImplementedError()
57 |
58 | # # env.unwrapped.viewer.cam.trackbodyid = 0
59 | # env.unwrapped.viewer.cam.fixedcamid = 0
60 |
61 | returns = []
62 | observations = []
63 | render_params = {}
64 | # render_params = {"width": 4000, "height":1000, "camera_id": -1}
65 | timesteps = 1000
66 |
67 | for i in range(args.num_rollouts):
68 | print("iter", i)
69 | if args.out_video is not None:
70 | rgbs = []
71 | obs = env.reset()
72 | done = False
73 | totalr = 0.0
74 | steps = 0
75 | for _ in range(timesteps):
76 | # print(obs)
77 | action = get_action(obs)
78 | # action = env.action_space.sample()  # debug leftover: would override the loaded policy's action
79 | observations.append(obs)
80 |
81 | last_done = done
82 | obs, r, done, info = env.step(action)
83 |
84 | if not last_done:
85 | totalr += r
86 | steps += 1
87 | if args.render:
88 | env.render(mode="human", **render_params)
89 | if args.out_video is not None:
90 | rgb = env.render(mode="rgb_array", **render_params)
91 | rgbs.append(rgb)
92 | if steps % 100 == 0:
93 | print("%i/%i" % (steps, env.spec.max_episode_steps))
94 | if steps >= env.spec.max_episode_steps:
95 | break
96 | print("return", totalr)
97 | returns.append(totalr)
98 |
99 | if args.out_video is not None:
100 | save_video(rgbs, args.out_video, fps=20.0)
101 |
102 | print("returns", returns)
103 | print("mean return", np.mean(returns))
104 | print("std of return", np.std(returns))
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/scripts/evaluate_result_policies.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import numpy as np
5 | import gym
6 |
7 | from stable_baselines import PPO2, SAC
8 |
9 | from deep_rlsp.util.video import render_mujoco_from_obs, save_video
10 | from deep_rlsp.model import StateVAE
11 | from deep_rlsp.util.results import FileExperimentResults
12 |
13 |
14 | def evaluate_policy(policy_file, policy_type, envname, num_rollouts):
15 | if policy_type == "ppo":
16 | model = PPO2.load(policy_file)
17 |
18 | def get_action(obs):
19 | return model.predict(obs)[0]
20 |
21 | elif policy_type == "sac":
22 | model = SAC.load(policy_file)
23 |
24 | def get_action(obs):
25 | return model.predict(obs, deterministic=True)[0]
26 |
27 | else:
28 | raise NotImplementedError()
29 |
30 | env = gym.make(envname)
31 |
32 | returns = []
33 | for i in range(num_rollouts):
34 | # print("iter", i, end=" ")
35 | obs = env.reset()
36 | done = False
37 | totalr = 0.0
38 | while not done:
39 | action = get_action(obs)
40 | obs, r, done, _ = env.step(action)
41 | totalr += r
42 | returns.append(totalr)
43 |
44 | return np.mean(returns), np.std(returns)
45 |
46 |
47 | def parse_args():
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument("results_folder", type=str)
50 | parser.add_argument("envname", type=str)
51 | parser.add_argument("--policy_type", type=str, default="sac")
52 | parser.add_argument(
53 | "--num_rollouts", type=int, default=1, help="Number of expert rollouts"
54 | )
55 | return parser.parse_args()
56 |
57 |
58 | def main():
59 | args = parse_args()
60 |
61 | base, root_dirs, _ = next(os.walk(args.results_folder))
62 | root_dirs = [os.path.join(base, dir) for dir in root_dirs]
63 | for root in root_dirs:
64 | print(root)
65 | _, dirs, files = next(os.walk(root))
66 | if "policy.zip" in files:
67 | policy_path = os.path.join(root, "policy.zip")
68 | elif any([f.startswith("rlsp_policy") for f in files]):
69 | policy_files = [f for f in files if f.startswith("rlsp_policy")]
70 | policy_numbers = [int(f.split(".")[0].split("_")[2]) for f in policy_files]
71 | # take second to last policy, in case a run crashed while writing
72 | # the last policy
73 | policy_file = f"rlsp_policy_{max(policy_numbers)-1}.zip"
74 | policy_path = os.path.join(root, policy_file)
75 | else:
76 | policy_path = None
77 |
78 | if policy_path is not None:
79 | mean, std = evaluate_policy(
80 | policy_path, args.policy_type, args.envname, args.num_rollouts
81 | )
82 | print(policy_path)
83 | print(f"mean {mean} std {std}")
84 |
85 |
86 | if __name__ == "__main__":
87 | main()
88 |
--------------------------------------------------------------------------------
/scripts/mujoco_evaluate_inferred_reward.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate an inferred reward function by using it to train a policy in the original env.
3 | """
4 |
5 | import argparse
6 | import datetime
7 |
8 | import cv2
9 | import gym
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import tensorflow as tf
13 |
14 | from sacred import Experiment
15 | from sacred.observers import FileStorageObserver, RunObserver
16 |
17 | from stable_baselines import SAC
18 | from stable_baselines.sac.policies import MlpPolicy as MlpPolicySac
19 |
20 | from deep_rlsp.util.results import Artifact, FileExperimentResults
21 | from deep_rlsp.model import StateVAE
22 | from deep_rlsp.envs.reward_wrapper import LatentSpaceRewardWrapper
23 | from deep_rlsp.util.video import render_mujoco_from_obs
24 | from deep_rlsp.util.helper import get_trajectory, evaluate_policy
25 | from deep_rlsp.model.mujoco_debug_models import MujocoDebugFeatures, PendulumDynamics
26 | from deep_rlsp.solvers import get_sac
27 |
28 | # changes the run _id and thereby the path that the FileStorageObserver
29 | # writes the results
30 | # cf. https://github.com/IDSIA/sacred/issues/174
31 | class SetID(RunObserver):
32 | priority = 50 # very high priority to set id
33 |
34 | def started_event(
35 | self, ex_info, command, host_info, start_time, config, meta_info, _id
36 | ):
37 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
38 | label = config["experiment_folder"].strip("/").split("/")[-1]
39 | custom_id = "{}_{}".format(timestamp, label)
40 | return custom_id # started_event returns the _run._id
41 |
42 |
43 | ex = Experiment("mujoco-eval")
44 | ex.observers = [SetID(), FileStorageObserver.create("results/mujoco/eval")]
45 |
46 |
47 | def print_rollout(env, policy, latent_space, decode=False):
48 | state = env.reset()
49 | done = False
50 | while not done:
51 | a, _ = policy.predict(state, deterministic=False)
52 | state, reward, done, info = env.step(a)
53 | if decode:
54 | obs = latent_space.decoder(state)
55 | else:
56 | obs = state
57 | print("action", a)
58 | print("obs", obs)
59 | print("reward", reward)
60 |
61 |
62 | @ex.config
63 | def config():
64 | experiment_folder = None # noqa:F841
65 | iteration = -1 # noqa:F841
66 |
67 |
68 | @ex.automain
69 | def main(_run, experiment_folder, iteration, seed):
70 | ex = FileExperimentResults(experiment_folder)
71 | env_id = ex.config["env_id"]
72 | env = gym.make(env_id)
73 |
74 | if env_id == "InvertedPendulum-v2":
75 | iterations = int(6e4)
76 | else:
77 | iterations = int(2e6)
78 |
79 | label = experiment_folder.strip("/").split("/")[-1]
80 |
81 | if ex.config["debug_handcoded_features"]:
82 | latent_space = MujocoDebugFeatures(env)
83 | else:
84 | graph_latent = tf.Graph()
85 | latent_model_path = ex.info["latent_model_checkpoint"]
86 | with graph_latent.as_default():
87 | latent_space = StateVAE.restore(latent_model_path)
88 |
89 | r_inferred = ex.info["inferred_rewards"][iteration]
90 | r_inferred /= np.linalg.norm(r_inferred)
91 |
92 | print("r_inferred")
93 | print(r_inferred)
94 | if env_id.startswith("Fetch"):
95 | env_has_task_reward = True
96 | inferred_weight = 0.1
97 | else:
98 | env_has_task_reward = False
99 | inferred_weight = None
100 |
101 | env_inferred = LatentSpaceRewardWrapper(
102 | env,
103 | latent_space,
104 | r_inferred,
105 | inferred_weight=inferred_weight,
106 | use_task_reward=env_has_task_reward,
107 | )
108 |
109 | policy_inferred = get_sac(env_inferred)
110 | policy_inferred.learn(total_timesteps=iterations, log_interval=10)
111 | with Artifact(f"policy.zip", None, _run) as f:
112 | policy_inferred.save(f)
113 |
114 | print_rollout(env_inferred, policy_inferred, latent_space)
115 |
116 | N = 10
117 | true_reward_obtained = evaluate_policy(env, policy_inferred, N)
118 | print("Inferred reward policy: true return", true_reward_obtained)
119 | if env_has_task_reward:
120 | env.use_penalty = False
121 | task_reward_obtained = evaluate_policy(env, policy_inferred, N)
122 | print("Inferred reward policy: task return", task_reward_obtained)
123 | env.use_penalty = True
124 | with Artifact(f"video.mp4", None, _run) as f:
125 | inferred_reward_obtained = evaluate_policy(
126 | env_inferred, policy_inferred, N, video_out=f
127 | )
128 | print("Inferred reward policy: inferred return", inferred_reward_obtained)
129 |
130 | good_policy_path = ex.config["good_policy_path"]
131 | if good_policy_path is not None:
132 | true_reward_policy = SAC.load(good_policy_path)
133 | good_policy_true_reward_obtained = evaluate_policy(env, true_reward_policy, N)
134 | print("True reward policy: true return", good_policy_true_reward_obtained)
135 | if env_has_task_reward:
136 | env.use_penalty = False
137 | good_policy_task_reward_obtained = evaluate_policy(
138 | env, true_reward_policy, N
139 | )
140 | print("True reward policy: task return", good_policy_task_reward_obtained)
141 | env.use_penalty = True
142 | good_policy_inferred_reward_obtained = evaluate_policy(
143 | env_inferred, true_reward_policy, N
144 | )
145 | print(
146 | "True reward policy: inferred return", good_policy_inferred_reward_obtained
147 | )
148 |
149 | random_policy = SAC(MlpPolicySac, env_inferred, verbose=1)
150 | random_policy_true_reward_obtained = evaluate_policy(env, random_policy, N)
151 | print("Random policy: true return", random_policy_true_reward_obtained)
152 | if env_has_task_reward:
153 | env.use_penalty = False
154 | random_policy_task_reward_obtained = evaluate_policy(env, random_policy, N)
155 |         print("Random policy: task return", random_policy_task_reward_obtained)
156 | env.use_penalty = True
157 | random_policy_inferred_reward_obtained = evaluate_policy(
158 | env_inferred, random_policy, N
159 | )
160 | print("Random policy: inferred return", random_policy_inferred_reward_obtained)
161 | print()
162 |
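A minimal invocation sketch for this Sacred experiment (the folder name is only illustrative; pass a results folder produced by a previous Deep RLSP run): `python scripts/mujoco_evaluate_inferred_reward.py with experiment_folder=results/mujoco/<run_id> iteration=-1`.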
--------------------------------------------------------------------------------
/scripts/plot_learning_curves.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import matplotlib.pyplot as plt
4 | import matplotlib
5 | import seaborn as sns
6 | import numpy as np
7 |
8 |
9 | def moving_average(a, n=3):
10 | ret = np.cumsum(a, dtype=float)
11 | ret[n:] = ret[n:] - ret[:-n]
12 | return ret[n - 1 :] / n
13 |
14 |
15 | def main():
16 | n_samples = 1
17 | skill = "balancing"
18 | loss_files = [
19 | f"discriminator_cheetah_{skill}_{n_samples}_gail.pkl",
20 | f"discriminator_cheetah_{skill}_{n_samples}_rlsp.pkl",
21 | f"discriminator_cheetah_{skill}_{n_samples}_average_features.pkl",
22 | f"discriminator_cheetah_{skill}_{n_samples}_waypoints.pkl",
23 | ]
24 | labels = ["GAIL", "Deep RLSP", "AverageFeatures", "Waypoints"]
25 | outfile = f"{skill}_{n_samples}.pdf"
26 | moving_average_n = 10
27 |
28 | sns.set_context("paper", font_scale=3.6, rc={"lines.linewidth": 3})
29 | sns.set_style("white")
30 | matplotlib.rc(
31 | "font",
32 | **{
33 | "family": "serif",
34 | "serif": ["Computer Modern"],
35 | "sans-serif": ["Latin Modern"],
36 | },
37 | )
38 | matplotlib.rc("text", usetex=True)
39 |
40 | markers_every = 100
41 | markersize = 10
42 |
43 | colors = [
44 | "#377eb8",
45 | "#ff7f00",
46 | "#4daf4a",
47 | "#e41a1c",
48 | "#dede00",
49 | "#999999",
50 | "#f781bf",
51 | "#a65628",
52 | "#984ea3",
53 | ]
54 |
55 | markers = ["o", "^", "s", "v", "d", "+", "x", "."]
56 |
57 | L = len(loss_files)
58 | colors = colors[:L]
59 | markers = markers[:L]
60 |
61 | # plt.figure(figsize=(20, 6))
62 | for loss_file, label, color, marker in zip(loss_files, labels, colors, markers):
63 | print(loss_file)
64 | with open(loss_file, "rb") as f:
65 | losses = pickle.load(f)
66 | print(losses.shape)
67 |
68 | if moving_average_n > 1:
69 | losses = moving_average(losses, moving_average_n)
70 |
71 | plt.plot(
72 | range(len(losses)),
73 | losses,
74 | label=label,
75 | color=color,
76 | marker=marker,
77 | markevery=markers_every,
78 | markersize=markersize,
79 | )
80 | plt.xlabel("iterations")
81 | plt.ylabel("cross-entropy loss")
82 |
83 | # plt.legend()
84 | plt.tight_layout()
85 | plt.savefig(outfile)
86 | # plt.show()
87 |
88 |
89 | if __name__ == "__main__":
90 | main()
91 |
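For reference, the `moving_average` helper above computes a length-`n` sliding mean via the cumulative-sum trick, so the smoothed curve is `n - 1` points shorter than the raw loss curve. A self-contained numerical check (the helper is repeated here so the snippet runs on its own):

    import numpy as np

    def moving_average(a, n=3):
        # difference of cumulative sums yields each length-n window's sum
        ret = np.cumsum(a, dtype=float)
        ret[n:] = ret[n:] - ret[:-n]
        return ret[n - 1:] / n

    print(moving_average([1, 2, 3, 4, 5], n=3))  # -> [2. 3. 4.]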
--------------------------------------------------------------------------------
/scripts/run_gail.py:
--------------------------------------------------------------------------------
1 | import deep_rlsp  # noqa: F401  (registers the custom gym environments)
2 | from imitation.scripts.train_adversarial import train_ex
3 | from sacred.observers import FileStorageObserver
4 | import os.path as osp
5 |
6 |
7 | def main_console():
8 | observer = FileStorageObserver(osp.join("output", "sacred", "train"))
9 | train_ex.observers.append(observer)
10 | train_ex.run_commandline()
11 |
12 |
13 | if __name__ == "__main__": # pragma: no cover
14 | main_console()
15 |
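This wrapper simply attaches a `FileStorageObserver` to the `train_ex` Sacred experiment from the `imitation` package, so the available configuration keys depend on the installed `imitation` version; Sacred's built-in `print_config` command (`python scripts/run_gail.py print_config`) lists them.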
--------------------------------------------------------------------------------
/scripts/train_inverse_dynamics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 |
4 | import gym
5 | import numpy as np
6 |
7 | from deep_rlsp.model import ExperienceReplay, InverseDynamicsMLP, InverseDynamicsMDN
8 | from deep_rlsp.util.video import render_mujoco_from_obs, save_video
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("env_id", type=str)
14 | parser.add_argument("--n_rollouts", default=1000, type=int)
15 | parser.add_argument("--learning_rate", default=1e-5, type=float)
16 | parser.add_argument("--n_epochs", default=100, type=int)
17 | parser.add_argument("--batch_size", default=500, type=int)
18 | parser.add_argument("--n_layers", default=5, type=int)
19 | parser.add_argument("--layer_size", default=1024, type=int)
20 | parser.add_argument("--gridworlds", action="store_true")
21 | return parser.parse_args()
22 |
23 |
24 | def main():
25 | args = parse_args()
26 |
27 | env = gym.make(args.env_id)
28 | experience_replay = ExperienceReplay(None)
29 | experience_replay.add_random_rollouts(
30 | env, env.spec.max_episode_steps, args.n_rollouts
31 | )
32 |
33 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
34 | label = "mlp_{}_{}".format(args.env_id, timestamp)
35 | tensorboard_log = "tf_logs/tf_logs_" + label
36 | checkpoint_folder = "tf_ckpt/tf_ckpt_" + label
37 |
38 | if args.gridworlds:
39 | inverse_dynamics = InverseDynamicsMDN(
40 | env,
41 | experience_replay,
42 | hidden_layer_size=args.layer_size,
43 | n_hidden_layers=args.n_layers,
44 | learning_rate=args.learning_rate,
45 | tensorboard_log=tensorboard_log,
46 | checkpoint_folder=checkpoint_folder,
47 | gauss_stdev=0.1,
48 | n_out_states=3,
49 | )
50 | else:
51 | inverse_dynamics = InverseDynamicsMLP(
52 | env,
53 | experience_replay,
54 | hidden_layer_size=args.layer_size,
55 | n_hidden_layers=args.n_layers,
56 | learning_rate=args.learning_rate,
57 | tensorboard_log=tensorboard_log,
58 | checkpoint_folder=checkpoint_folder,
59 | )
60 |
61 | inverse_dynamics.learn(
62 | n_epochs=args.n_epochs,
63 | batch_size=args.batch_size,
64 | print_evaluation=False,
65 | verbose=True,
66 | )
67 |
68 | # Evaluation
69 |
70 | if args.gridworlds:
71 | obs = env.reset()
72 | for _ in range(env.time_horizon):
73 | s = env.obs_to_s(obs)
74 |
75 | print("agent_pos", s.agent_pos)
76 | print("vase_states", s.vase_states)
77 | print(np.reshape(obs, env.obs_shape)[:, :, 0])
78 |
79 | action = env.action_space.sample()
80 | obs, reward, done, info = env.step(action)
81 | obs_bwd = inverse_dynamics.step(obs, action)
82 |
83 | s = env.obs_to_s(obs_bwd)
84 | print("bwd: agent_pos", s.agent_pos)
85 | print("bwd: vase_states", s.vase_states)
86 | print(np.reshape(obs_bwd, env.obs_shape)[:, :, 0])
87 |
88 | print()
89 | print("action", action)
90 | print()
91 | else:
92 | obs = env.reset()
93 | for _ in range(env.spec.max_episode_steps):
94 | print("obs", obs)
95 | action = env.action_space.sample()
96 | obs, reward, done, info = env.step(action)
97 | obs_bwd = inverse_dynamics.step(obs, action)
98 | with np.printoptions(suppress=True):
99 | print("action", action)
100 | print("\tbwd", obs_bwd)
101 |
102 | obs = env.reset()
103 | rgbs = []
104 | for _ in range(env.spec.max_episode_steps):
105 | obs = np.clip(obs, -30, 30)
106 | rgb = render_mujoco_from_obs(env, obs)
107 | rgbs.append(rgb)
108 | action = env.action_space.sample()
109 | obs = inverse_dynamics.step(obs, action, sample=True)
110 | save_video(rgbs, f"train_{args.env_id}.avi", 20)
111 |
112 |
113 | if __name__ == "__main__":
114 | main()
115 |
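Usage sketch (the environment id is only an example; any registered gym id with a `max_episode_steps` spec should work): `python scripts/train_inverse_dynamics.py HalfCheetah-FW-v2 --n_rollouts 1000 --n_epochs 100`. Checkpoints and TensorBoard logs are written to the `tf_ckpt/` and `tf_logs/` folders constructed above.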
--------------------------------------------------------------------------------
/scripts/train_sac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 | from deep_rlsp.solvers import get_sac
5 |
6 | # for envs
7 | import deep_rlsp
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("env_id", type=str)
13 | parser.add_argument("policy_out", type=str)
14 | parser.add_argument("--timesteps", default=int(2e6), type=int)
15 | return parser.parse_args()
16 |
17 |
18 | def main():
19 | args = parse_args()
20 | env = gym.make(args.env_id)
21 |
22 | solver = get_sac(env, learning_starts=1000)
23 | solver.learn(total_timesteps=args.timesteps)
24 | solver.save(args.policy_out)
25 |
26 |
27 | if __name__ == "__main__":
28 | main()
29 |
--------------------------------------------------------------------------------
/scripts/train_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 |
4 | import gym
5 | import numpy as np
6 |
7 | from deep_rlsp.model import StateVAE, ExperienceReplay
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("env_id", type=str)
13 | parser.add_argument("--n_rollouts", default=100, type=int)
14 | parser.add_argument("--state_size", default=30, type=int)
15 | parser.add_argument("--learning_rate", default=1e-5, type=float)
16 | parser.add_argument("--n_epochs", default=100, type=int)
17 | parser.add_argument("--batch_size", default=500, type=int)
18 | parser.add_argument("--n_layers", default=3, type=int)
19 | parser.add_argument("--layer_size", default=512, type=int)
20 | parser.add_argument("--prior_stdev", default=1, type=float)
21 | parser.add_argument("--divergence_factor", default=0.001, type=float)
22 | return parser.parse_args()
23 |
24 |
25 | def main():
26 | args = parse_args()
27 |
28 | env = gym.make(args.env_id)
29 | experience_replay = ExperienceReplay(None)
30 | experience_replay.add_random_rollouts(
31 | env, env.spec.max_episode_steps, args.n_rollouts
32 | )
33 |
34 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
35 | label = "vae_{}_{}".format(args.env_id, timestamp)
36 | tensorboard_log = "tf_logs/tf_logs_" + label
37 | checkpoint_folder = "tf_ckpt/tf_ckpt_" + label
38 |
39 | vae = StateVAE(
40 | env.observation_space.shape[0],
41 | args.state_size,
42 | n_layers=args.n_layers,
43 | layer_size=args.layer_size,
44 | learning_rate=args.learning_rate,
45 | prior_stdev=args.prior_stdev,
46 | divergence_factor=args.divergence_factor,
47 | tensorboard_log=tensorboard_log,
48 | checkpoint_folder=checkpoint_folder,
49 | )
50 | # vae.checkpoint_folder = None
51 |
52 | vae.learn(experience_replay, args.n_epochs, args.batch_size, verbose=True)
53 |
54 | with np.printoptions(suppress=True):
55 | for _ in range(20):
56 | x = experience_replay.sample(1)[0][0]
57 | print("x", x)
58 | z = vae.encoder(x)
59 | print("z", z)
60 | x = vae.decoder(z)
61 | print("x2", x)
62 |
63 |
64 | if __name__ == "__main__":
65 | main()
66 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="deep-rlsp",
5 |     version="0.1",
6 | description="Learning What To Do by Simulating the Past",
7 |     author="David Lindner, Rohin Shah, et al.",
8 | python_requires=">=3.7.0",
9 | url="https://github.com/HumanCompatibleAI/deep-rlsp",
10 | packages=find_packages("src"),
11 | package_dir={"": "src"},
12 | install_requires=[
13 | "numpy>=1.13",
14 | "scipy>=0.19",
15 | "sacred==0.8.2",
16 | "ray",
17 | "stable_baselines",
18 | "tensorflow==1.13.2",
19 | "tensorflow-probability==0.6.0",
20 | "seaborn",
21 | "gym",
22 | ],
23 | test_suite="nose.collector",
24 | tests_require=["nose", "nose-cover3"],
25 | include_package_data=True,
26 | license="MIT",
27 | classifiers=[
28 | # Trove classifiers
29 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
30 | "License :: OSI Approved :: MIT License",
31 | "Programming Language :: Python",
32 | "Programming Language :: Python :: 3",
33 | "Programming Language :: Python :: 3.7",
34 | ],
35 | )
36 |
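With the `src/` layout declared above, the package is typically installed in editable mode from the repository root, e.g. `pip install -e .`, optionally inside the conda environment defined in `docker/environment.yml`.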
--------------------------------------------------------------------------------
/skills/balancing.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/balancing.mp4
--------------------------------------------------------------------------------
/skills/balancing_rollouts.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/balancing_rollouts.pkl
--------------------------------------------------------------------------------
/skills/jumping.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/jumping.mp4
--------------------------------------------------------------------------------
/skills/jumping_rollouts.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/skills/jumping_rollouts.pkl
--------------------------------------------------------------------------------
/src/deep_rlsp/ablation_AverageFeatures.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import gym
4 | import numpy as np
5 |
6 | from sacred import Experiment
7 | from sacred.observers import FileStorageObserver, RunObserver
8 |
9 | from deep_rlsp.envs.reward_wrapper import LatentSpaceRewardWrapper
10 | from deep_rlsp.model import StateVAE
11 | from deep_rlsp.util.results import Artifact, FileExperimentResults
12 | from deep_rlsp.solvers import get_sac
13 | from deep_rlsp.util.helper import evaluate_policy
14 |
15 |
16 | # changes the run _id and thereby the path that the FileStorageObserver
17 | # writes the results
18 | # cf. https://github.com/IDSIA/sacred/issues/174
19 | class SetID(RunObserver):
20 | priority = 50 # very high priority to set id
21 |
22 | def started_event(
23 | self, ex_info, command, host_info, start_time, config, meta_info, _id
24 | ):
25 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
26 | if config["result_folder"] is not None:
27 | result_folder = config["result_folder"].strip("/").split("/")[-1]
28 | custom_id = f"{timestamp}_ablation_average_features_{result_folder}"
29 | else:
30 | custom_id = f"{timestamp}_ablation_average_features"
31 | return custom_id # started_event returns the _run._id
32 |
33 |
34 | ex = Experiment("mujoco-ablation-average-features")
35 | ex.observers = [
36 | SetID(),
37 | FileStorageObserver.create("results/mujoco/ablation_average_features"),
38 | ]
39 |
40 |
41 | @ex.config
42 | def config():
43 | result_folder = None # noqa:F841
44 |
45 |
46 | @ex.automain
47 | def main(_run, result_folder, seed):
48 | ex = FileExperimentResults(result_folder)
49 | env_id = ex.config["env_id"]
50 | latent_model_checkpoint = ex.info["latent_model_checkpoint"]
51 | current_states = ex.info["current_states"]
52 |
53 | if env_id == "InvertedPendulum-v2":
54 | iterations = int(6e4)
55 | else:
56 | iterations = int(2e6)
57 |
58 | env = gym.make(env_id)
59 |
60 | latent_space = StateVAE.restore(latent_model_checkpoint)
61 |
62 | r_vec = sum([latent_space.encoder(obs) for obs in current_states])
63 | r_vec /= np.linalg.norm(r_vec)
64 | env_inferred = LatentSpaceRewardWrapper(env, latent_space, r_vec)
65 |
66 | solver = get_sac(env_inferred)
67 | solver.learn(iterations)
68 |
69 | with Artifact(f"policy.zip", None, _run) as f:
70 | solver.save(f)
71 |
72 | N = 10
73 | true_reward_obtained = evaluate_policy(env, solver, N)
74 | inferred_reward_obtained = evaluate_policy(env_inferred, solver, N)
75 | print("Policy: true return", true_reward_obtained)
76 | print("Policy: inferred return", inferred_reward_obtained)
77 |
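Invocation sketch, assuming a finished Deep RLSP run whose folder is passed via Sacred's `with` syntax (the folder name below is illustrative, borrowed from a comment in the Waypoints ablation): `python -m deep_rlsp.ablation_AverageFeatures with result_folder=results/mujoco/20200706_153755_HalfCheetah-FW-v2_optimal_50`.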
--------------------------------------------------------------------------------
/src/deep_rlsp/ablation_Waypoints.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import gym
4 | import numpy as np
5 |
6 | from sacred import Experiment
7 | from sacred.observers import FileStorageObserver, RunObserver
8 |
9 | from deep_rlsp.model import StateVAE
10 | from deep_rlsp.util.results import Artifact, FileExperimentResults
11 | from deep_rlsp.solvers import get_sac
12 | from deep_rlsp.util.helper import evaluate_policy
13 |
14 |
15 | # changes the run _id and thereby the path that the FileStorageObserver
16 | # writes the results
17 | # cf. https://github.com/IDSIA/sacred/issues/174
18 | class SetID(RunObserver):
19 | priority = 50 # very high priority to set id
20 |
21 | def started_event(
22 | self, ex_info, command, host_info, start_time, config, meta_info, _id
23 | ):
24 | timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
25 | if config["result_folder"] is not None:
26 | result_folder = config["result_folder"].strip("/").split("/")[-1]
27 | custom_id = f"{timestamp}_ablation_waypoints_{result_folder}"
28 | else:
29 | custom_id = f"{timestamp}_ablation_waypoints"
30 | return custom_id # started_event returns the _run._id
31 |
32 |
33 | ex = Experiment("mujoco-ablation-waypoints")
34 | ex.observers = [
35 | SetID(),
36 | FileStorageObserver.create("results/mujoco/ablation_waypoints"),
37 | ]
38 |
39 |
40 | class LatentSpaceTargetStateRewardWrapper(gym.Wrapper):
41 | def __init__(self, env, latent_space, target_states):
42 | self.env = env
43 | self.latent_space = latent_space
44 | self.target_states = [ts / np.linalg.norm(ts) for ts in target_states]
45 | self.state = None
46 | self.timestep = 0
47 | super().__init__(env)
48 |
49 | def reset(self):
50 | obs = super().reset()
51 | self.state = self.latent_space.encoder(obs)
52 | self.timestep = 0
53 | return obs
54 |
55 |     def step(self, action: np.ndarray):
56 | action = np.clip(action, self.action_space.low, self.action_space.high)
57 | obs, true_reward, done, info = self.env.step(action)
58 | self.state = self.latent_space.encoder(obs)
59 |
60 | waypoints = [np.dot(ts, self.state) for ts in self.target_states]
61 | reward = max(waypoints)
62 |
63 | # remove termination criteria from mujoco environments
64 | self.timestep += 1
65 | done = self.timestep > self.env.spec.max_episode_steps
66 |
67 | return obs, reward, done, info
68 |
69 |
70 | @ex.config
71 | def config():
72 | result_folder = None # noqa:F841
73 |
74 |
75 | @ex.automain
76 | def main(_run, result_folder, seed):
77 | # result_folder = "results/mujoco/20200706_153755_HalfCheetah-FW-v2_optimal_50"
78 | ex = FileExperimentResults(result_folder)
79 | env_id = ex.config["env_id"]
80 | latent_model_checkpoint = ex.info["latent_model_checkpoint"]
81 | current_states = ex.info["current_states"]
82 |
83 | if env_id == "InvertedPendulum-v2":
84 | iterations = int(6e4)
85 | else:
86 | iterations = int(2e6)
87 |
88 | env = gym.make(env_id)
89 |
90 | latent_space = StateVAE.restore(latent_model_checkpoint)
91 |
92 | target_states = [latent_space.encoder(obs) for obs in current_states]
93 | env_inferred = LatentSpaceTargetStateRewardWrapper(env, latent_space, target_states)
94 |
95 | solver = get_sac(env_inferred)
96 | solver.learn(iterations)
97 |
98 | with Artifact(f"policy.zip", None, _run) as f:
99 | solver.save(f)
100 |
101 | N = 10
102 | true_reward_obtained = evaluate_policy(env, solver, N)
103 | inferred_reward_obtained = evaluate_policy(env_inferred, solver, N)
104 | print("Policy: true return", true_reward_obtained)
105 | print("Policy: inferred return", inferred_reward_obtained)
106 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/__init__.py
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/__init__.py:
--------------------------------------------------------------------------------
1 | from .room import RoomEnv
2 | from .apples import ApplesEnv
3 | from .train import TrainEnv
4 | from .batteries import BatteriesEnv
5 |
6 | from .room_spec import ROOM_PROBLEMS
7 | from .apples_spec import APPLES_PROBLEMS
8 | from .train_spec import TRAIN_PROBLEMS
9 | from .batteries_spec import BATTERIES_PROBLEMS
10 |
11 | TOY_PROBLEMS = {
12 | "room": ROOM_PROBLEMS,
13 | "apples": APPLES_PROBLEMS,
14 | "train": TRAIN_PROBLEMS,
15 | "batteries": BATTERIES_PROBLEMS,
16 | }
17 |
18 | TOY_ENV_CLASSES = {
19 | "room": RoomEnv,
20 | "apples": ApplesEnv,
21 | "train": TrainEnv,
22 | "batteries": BatteriesEnv,
23 | }
24 |
25 | __all__ = [
26 | "RoomEnv",
27 | "ApplesEnv",
28 | "TrainEnv",
29 | "BatteriesEnv",
30 | ]
31 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/apples_spec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from deep_rlsp.envs.gridworlds.apples import ApplesState
3 | from deep_rlsp.envs.gridworlds.env import Direction
4 |
5 |
6 | class ApplesSpec(object):
7 | def __init__(
8 | self,
9 | height,
10 | width,
11 | init_state,
12 | apple_regen_probability,
13 | bucket_capacity,
14 | include_location_features,
15 | ):
16 | """See ApplesEnv.__init__ in apples.py for details."""
17 | self.height = height
18 | self.width = width
19 | self.init_state = init_state
20 | self.apple_regen_probability = apple_regen_probability
21 | self.bucket_capacity = bucket_capacity
22 | self.include_location_features = include_location_features
23 |
24 |
25 | # In the diagrams below, T is a tree, B is a bucket, C is a carpet, A is the
26 | # agent. Each tuple is of the form (spec, current state, task R, true R).
27 |
28 | APPLES_PROBLEMS = {
29 | # -----
30 | # |T T|
31 | # | |
32 | # | B |
33 | # | |
34 | # |A T|
35 | # -----
36 | # After 11 actions (riuiruuildi), it looks like this:
37 | # -----
38 | # |T T|
39 | # | A |
40 | # | B |
41 | # | |
42 | # | T|
43 | # -----
44 | # Where the agent has picked the right trees once and put the fruit in the
45 | # basket.
46 | "default": (
47 | ApplesSpec(
48 | 5,
49 | 3,
50 | ApplesState(
51 | agent_pos=(0, 0, 2),
52 | tree_states={(0, 0): True, (2, 0): True, (2, 4): True},
53 | bucket_states={(1, 2): 0},
54 | carrying_apple=False,
55 | ),
56 | apple_regen_probability=0.1,
57 | bucket_capacity=10,
58 | include_location_features=True,
59 | ),
60 | ApplesState(
61 | agent_pos=(Direction.get_number_from_direction(Direction.SOUTH), 1, 1),
62 | tree_states={(0, 0): True, (2, 0): False, (2, 4): True},
63 | bucket_states={(1, 2): 2},
64 | carrying_apple=False,
65 | ),
66 | np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
67 | np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
68 | )
69 | }
70 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/basic_room.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import product
3 |
4 | from deep_rlsp.envs.gridworlds.env import Env, Direction, get_grid_representation
5 |
6 |
7 | class BasicRoomEnv(Env):
8 | """
9 | Basic empty room with stochastic transitions. Used for debugging.
10 | """
11 |
12 | def __init__(self, prob, use_pixels_as_observations=True):
13 | self.height = 3
14 | self.width = 3
15 | self.init_state = (1, 1)
16 | self.prob = prob
17 | self.nS = self.height * self.width
18 | self.nA = 5
19 |
20 | super().__init__(1, use_pixels_as_observations=use_pixels_as_observations)
21 |
22 | self.num_features = 2
23 | self.default_action = Direction.get_number_from_direction(Direction.STAY)
24 | self.num_features = len(self.s_to_f(self.init_state))
25 |
26 | self.reset()
27 |
28 | states = self.enumerate_states()
29 | self.make_transition_matrices(states, range(self.nA), self.nS, self.nA)
30 | self.make_f_matrix(self.nS, self.num_features)
31 |
32 | def enumerate_states(self):
33 | return product(range(self.width), range(self.height))
34 |
35 | def get_num_from_state(self, state):
36 | return np.ravel_multi_index(state, (self.width, self.height))
37 |
38 | def get_state_from_num(self, num):
39 | return np.unravel_index(num, (self.width, self.height))
40 |
41 | def s_to_f(self, s):
42 | return s
43 |
44 | def _obs_to_f(self, obs):
45 | return np.unravel_index(obs[0].argmax(), obs[0].shape)
46 |
47 | def _s_to_obs(self, s):
48 | layers = [[s]]
49 | obs = get_grid_representation(self.width, self.height, layers)
50 | return np.array(obs, dtype=np.float32)
51 |
52 | # render_width = 64
53 | # render_height = 64
54 | # x, y = s
55 | # obs = np.zeros((3, render_height, render_width), dtype=np.float32)
56 | # obs[
57 | # :,
58 | # y * render_height : (y + 1) * render_height,
59 | # x * render_width : (x + 1) * render_width,
60 | # ] = 1
61 | # return obs
62 |
63 | def get_next_states(self, state, action):
64 | # next_states = []
65 | # for a in range(self.nA):
66 | # next_s = self.get_next_state(state, a)
67 | # p = 1 - self.prob if a == action else self.prob / (self.nA - 1)
68 | # next_states.append((p, next_s, 0))
69 |
70 | next_s = self.get_next_state(state, action)
71 | next_states = [(self.prob, next_s, 0), (1 - self.prob, state, 0)]
72 | return next_states
73 |
74 | def get_next_state(self, state, action):
75 | """Returns the next state given a state and an action."""
76 | action = int(action)
77 |
78 | if action == Direction.get_number_from_direction(Direction.STAY):
79 | pass
80 | elif action < len(Direction.ALL_DIRECTIONS):
81 | move_x, move_y = Direction.move_in_direction_number(state, action)
82 | # New position is legal
83 | if 0 <= move_x < self.width and 0 <= move_y < self.height:
84 | state = move_x, move_y
85 | else:
86 | # Move only changes orientation, which we already handled
87 | pass
88 | else:
89 | raise ValueError("Invalid action {}".format(action))
90 |
91 | return state
92 |
93 | def s_to_ansi(self, state):
94 | return str(self.s_to_obs(state))
95 |
96 |
97 | if __name__ == "__main__":
98 | from gym.utils.play import play
99 |
100 | env = BasicRoomEnv(1)
101 | play(env, fps=5)
102 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/batteries_spec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from deep_rlsp.envs.gridworlds.batteries import BatteriesState
3 |
4 |
5 | class BatteriesSpec(object):
6 | def __init__(self, height, width, init_state, feature_locations, train_transition):
7 | """See BatteriesEnv.__init__ in batteries.py for details."""
8 | self.height = height
9 | self.width = width
10 | self.init_state = init_state
11 | self.feature_locations = feature_locations
12 | self.train_transition = train_transition
13 |
14 |
15 | def get_problem(version):
16 | # In the diagram below, G is a goal location, B is a battery, A is the
17 | # agent, and T is the train.
18 | # Each tuple is of the form (spec, current state, task R, true R).
19 | # -------
20 | # |B G |
21 | # | TT |
22 | # | TTG|
23 | # | |
24 | # |A B|
25 | # -------
26 | spec = BatteriesSpec(
27 | 5,
28 | 5,
29 | BatteriesState(
30 | (0, 4), # agent pos
31 | (2, 1), # train pos
32 | 8, # train battery (/life)
33 | {(0, 0): True, (4, 4): True}, # batteries present in environment
34 | False, # agent carrying battery
35 | ),
36 | [(2, 0), (4, 2)], # goals
37 | # train transitions
38 | {(2, 1): (3, 1), (3, 1): (3, 2), (3, 2): (2, 2), (2, 2): (2, 1)},
39 | )
40 | final_state = BatteriesState(
41 | (2, 0), # agent pos
42 | (3, 2), # train pos
43 | 8, # train battery (/ life)
44 | {(0, 0): False, (4, 4): True}, # batteries present in environment
45 | False, # agent carrying battery
46 | )
47 | train_weight = -1 if version == "easy" else 0
48 | task_reward = np.array([0, train_weight, 0, 0, 0, 0, 0, 1])
49 | true_reward = np.array([0, -1, 0, 0, 0, 0, 0, 1])
50 | return (spec, final_state, task_reward, true_reward)
51 |
52 |
53 | BATTERIES_PROBLEMS = {"default": get_problem("default"), "easy": get_problem("easy")}
54 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/gym_envs.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | from deep_rlsp.envs.gridworlds.one_hot_action_space_wrapper import (
4 | OneHotActionSpaceWrapper,
5 | )
6 | from deep_rlsp.envs.gridworlds import TOY_PROBLEMS, TOY_ENV_CLASSES
7 |
8 |
9 | def get_gym_gridworld_id_time_limit(env_name, env_spec):
10 | id = env_name + "_" + env_spec
11 | id = "".join([s.capitalize() for s in id.split("_")])
12 | id += "-v0"
13 | spec, _, _, _ = TOY_PROBLEMS[env_name][env_spec]
14 | env = TOY_ENV_CLASSES[env_name](spec)
15 | return id, env.time_horizon
16 |
17 |
18 | def make_gym_gridworld(env_name, env_spec):
19 | spec, _, _, _ = TOY_PROBLEMS[env_name][env_spec]
20 | env = TOY_ENV_CLASSES[env_name](spec)
21 | env = OneHotActionSpaceWrapper(env)
22 | return env
23 |
24 |
25 | def get_gym_gridworld(env_name, env_spec):
26 | id, time_horizon = get_gym_gridworld_id_time_limit(env_name, env_spec)
27 | env = gym.make(id)
28 | return env
29 |
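A minimal sketch of how these helpers are used (illustrative only; the spec names come from `TOY_PROBLEMS` above):

    import numpy as np
    from deep_rlsp.envs.gridworlds.gym_envs import make_gym_gridworld

    env = make_gym_gridworld("room", "default")
    obs = env.reset()
    # after OneHotActionSpaceWrapper the action space is a length-nA Box;
    # the wrapper picks the argmax, so a one-hot vector selects that action
    one_hot_action = np.eye(env.action_space.shape[0])[0]
    obs, reward, done, info = env.step(one_hot_action)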
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/one_hot_action_space_wrapper.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 |
4 |
5 | class OneHotActionSpaceWrapper(gym.ActionWrapper):
6 | def __init__(self, env):
7 | assert isinstance(env.action_space, gym.spaces.Discrete)
8 | super().__init__(env)
9 | self.n_actions = env.action_space.n
10 | self.action_space = gym.spaces.Box(-np.inf, np.inf, (self.n_actions,))
11 |
12 | def step(self, action, **kwargs):
13 | return self.env.step(self.action(action), **kwargs)
14 |
15 | def action(self, one_hot_action):
16 | assert one_hot_action.shape == (self.n_actions,)
17 | action = np.argmax(one_hot_action)
18 | return action
19 |
20 | def reverse_action(self, action):
21 | return np.arange(self.n_actions) == action
22 |
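The wrapper above is what lets continuous-control learners (such as the SAC solver used elsewhere in the repository) act in the discrete gridworlds: the learner outputs a real-valued vector, `action()` decodes it with an argmax, and `reverse_action()` re-encodes a discrete action as a one-hot vector. A minimal sketch of that round trip:

    import numpy as np

    n_actions = 5
    discrete_action = 3
    one_hot = np.arange(n_actions) == discrete_action  # what reverse_action() returns
    assert one_hot.shape == (n_actions,)
    assert np.argmax(one_hot) == discrete_action       # what action() recovers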
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/room_spec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from deep_rlsp.envs.gridworlds.room import RoomState
4 |
5 |
6 | class RoomSpec(object):
7 | def __init__(
8 | self,
9 | height,
10 | width,
11 | init_state,
12 | carpet_locations,
13 | feature_locations,
14 | use_pixels_as_observations,
15 | ):
16 | """See RoomEnv.__init__ in room.py for details."""
17 | self.height = height
18 | self.width = width
19 | self.init_state = init_state
20 | self.carpet_locations = carpet_locations
21 | self.feature_locations = feature_locations
22 | self.use_pixels_as_observations = use_pixels_as_observations
23 |
24 |
25 | # In the diagrams below, G is a goal location, V is a vase, C is a carpet, A is
26 | # the agent. Each tuple is of the form (spec, current state, task R, true R).
27 |
28 | ROOM_PROBLEMS = {
29 | # -------
30 | # | G |
31 | # |GCVC |
32 | # | A |
33 | # -------
34 | "default": (
35 | RoomSpec(
36 | 3,
37 | 5,
38 | RoomState((2, 2), {(2, 1): True}),
39 | [(1, 1), (3, 1)],
40 | [(0, 1), (2, 0)],
41 | False,
42 | ),
43 | RoomState((2, 0), {(2, 1): True}),
44 | np.array([0, 0, 1, 0]),
45 | np.array([-1, 0, 1, 0]),
46 | ),
47 | "default_pixel": (
48 | RoomSpec(
49 | 3,
50 | 5,
51 | RoomState((2, 2), {(2, 1): True}),
52 | [(1, 1), (3, 1)],
53 | [(0, 1), (2, 0)],
54 | True,
55 | ),
56 | RoomState((2, 0), {(2, 1): True}),
57 | np.array([0, 0, 1, 0]),
58 | np.array([-1, 0, 1, 0]),
59 | ),
60 | # -------
61 | # | G |
62 | # |GCVCA|
63 | # | |
64 | # -------
65 | "alt": (
66 | RoomSpec(
67 | 3,
68 | 5,
69 | RoomState((4, 1), {(2, 1): True}),
70 | [(1, 1), (3, 1)],
71 | [(0, 1), (2, 0)],
72 | False,
73 | ),
74 | RoomState((2, 0), {(2, 1): True}),
75 | np.array([0, 0, 1, 0]),
76 | np.array([-1, 0, 1, 0]),
77 | ),
78 | # -------
79 | # |G VG|
80 | # | |
81 | # |A C |
82 | # -------
83 | "bad": (
84 | RoomSpec(
85 | 3, 5, RoomState((0, 2), {(3, 0): True}), [(3, 2)], [(0, 0), (4, 0)], False
86 | ),
87 | RoomState((0, 0), {(3, 0): True}),
88 | np.array([0, 0, 0, 1]),
89 | np.array([-1, 0, 0, 1]),
90 | ),
91 | # -------
92 | "big": (
93 | RoomSpec(
94 | 10,
95 | 10,
96 | RoomState(
97 | (0, 2),
98 | {
99 | (0, 5): True,
100 | # (0, 9): True,
101 | # (1, 2): True,
102 | # (2, 4): True,
103 | # (2, 5): True,
104 | # (2, 6): True,
105 | # (3, 1): True,
106 | # (3, 3): True,
107 | # (3, 8): True,
108 | # (4, 2): True,
109 | # (4, 4): True,
110 | # (4, 5): True,
111 | # (4, 6): True,
112 | # (4, 9): True,
113 | # (5, 3): True,
114 | # (5, 5): True,
115 | # (5, 7): True,
116 | # (5, 8): True,
117 | # (6, 1): True,
118 | # (6, 2): True,
119 | # (6, 4): True,
120 | # (6, 7): True,
121 | # (6, 9): True,
122 | # (7, 2): True,
123 | # (7, 5): True,
124 | # (7, 8): True,
125 | # (8, 6): True,
126 | # (9, 0): True,
127 | # (9, 2): True,
128 | # (9, 6): True,
129 | },
130 | ),
131 | [(1, 1), (1, 3), (2, 0), (6, 6), (7, 3)],
132 | [(0, 4), (0, 7), (5, 1), (8, 7), (9, 3)],
133 | False,
134 | ),
135 | RoomState(
136 | (0, 7),
137 | {
138 | (0, 5): True,
139 | # (0, 9): True,
140 | # (1, 2): True,
141 | # (2, 4): True,
142 | # (2, 5): True,
143 | # (2, 6): True,
144 | # (3, 1): True,
145 | # (3, 3): True,
146 | # (3, 8): True,
147 | # (4, 2): True,
148 | # (4, 4): True,
149 | # (4, 5): True,
150 | # (4, 6): True,
151 | # (4, 9): True,
152 | # (5, 3): True,
153 | # (5, 5): True,
154 | # (5, 7): True,
155 | # (5, 8): True,
156 | # (6, 1): True,
157 | # (6, 2): True,
158 | # (6, 4): True,
159 | # (6, 7): True,
160 | # (6, 9): True,
161 | # (7, 2): True,
162 | # (7, 5): True,
163 | # (7, 8): True,
164 | # (8, 6): True,
165 | # (9, 0): True,
166 | # (9, 2): True,
167 | # (9, 6): True,
168 | },
169 | ),
170 | np.array([0, 0, 0, 0, 0, 0, 1]),
171 | np.array([-1, 0, 0, 0, 0, 0, 1]),
172 | ),
173 | }
174 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/tests/apples_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from deep_rlsp.envs.gridworlds.apples import ApplesState, ApplesEnv
6 | from deep_rlsp.envs.gridworlds.tests.common import BaseTests, get_directions
7 |
8 |
9 | class TestApplesSpec(object):
10 | def __init__(self):
11 | """Test spec for the Apples environment.
12 |
13 | T is a tree, B is a bucket, C is a carpet, A is the agent.
14 | -----
15 | |T T|
16 | | |
17 | |AB |
18 | -----
19 | """
20 | self.height = 3
21 | self.width = 5
22 | self.init_state = ApplesState(
23 | agent_pos=(0, 0, 2),
24 | tree_states={(0, 0): True, (2, 0): True},
25 | bucket_states={(1, 2): 0},
26 | carrying_apple=False,
27 | )
28 | # Use a power of 2, to avoid rounding issues
29 | self.apple_regen_probability = 1.0 / 4
30 | self.bucket_capacity = 10
31 | self.include_location_features = True
32 |
33 |
34 | class TestApplesEnv(BaseTests.TestEnv):
35 | def setUp(self):
36 | self.env = ApplesEnv(TestApplesSpec())
37 |
38 | u, d, l, r, s = get_directions()
39 | i = 5 # interact action
40 |
41 | def make_state(agent_pos, tree1, tree2, bucket, carrying_apple):
42 | tree_states = {(0, 0): tree1, (2, 0): tree2}
43 | bucket_state = {(1, 2): bucket}
44 | return ApplesState(agent_pos, tree_states, bucket_state, carrying_apple)
45 |
46 | self.trajectories = [
47 | [
48 | (u, (make_state((u, 0, 1), True, True, 0, False), 1.0)),
49 | (i, (make_state((u, 0, 1), False, True, 0, True), 1.0)),
50 | (r, (make_state((r, 1, 1), False, True, 0, True), 3.0 / 4)),
51 | (d, (make_state((d, 1, 1), False, True, 0, True), 3.0 / 4)),
52 | (i, (make_state((d, 1, 1), False, True, 1, False), 3.0 / 4)),
53 | (u, (make_state((u, 1, 0), False, True, 1, False), 3.0 / 4)),
54 | (r, (make_state((r, 1, 0), False, True, 1, False), 3.0 / 4)),
55 | (i, (make_state((r, 1, 0), False, False, 1, True), 3.0 / 4)),
56 | (d, (make_state((d, 1, 1), False, False, 1, True), 9.0 / 16)),
57 | (i, (make_state((d, 1, 1), True, False, 2, False), 3.0 / 16)),
58 | (s, (make_state((d, 1, 1), True, True, 2, False), 1.0 / 4)),
59 | ]
60 | ]
61 |
62 |
63 | class TestApplesModel(BaseTests.TestTabularTransitionModel):
64 | def setUp(self):
65 | self.env = ApplesEnv(TestApplesSpec())
66 | self.model_tests = []
67 |
68 | _, _, _, _, stay = get_directions()
69 | policy_stay = np.zeros((self.env.nS, self.env.nA))
70 | policy_stay[:, stay] = 1
71 |
72 | def make_state(apple1_present, apple2_present):
73 | return ApplesState(
74 | agent_pos=(0, 1, 1),
75 | tree_states={(0, 0): apple1_present, (2, 0): apple2_present},
76 | bucket_states={(1, 2): 0},
77 | carrying_apple=False,
78 | )
79 |
80 | state_0_0 = make_state(False, False)
81 | state_0_1 = make_state(False, True)
82 | state_1_0 = make_state(True, False)
83 | state_1_1 = make_state(True, True)
84 |
85 | forward_probs = np.zeros(self.env.nS)
86 | forward_probs[self.env.get_num_from_state(state_1_1)] = 1
87 | backward_probs = np.zeros(self.env.nS)
88 | backward_probs[self.env.get_num_from_state(state_0_0)] = 0.04
89 | backward_probs[self.env.get_num_from_state(state_0_1)] = 0.16
90 | backward_probs[self.env.get_num_from_state(state_1_0)] = 0.16
91 | backward_probs[self.env.get_num_from_state(state_1_1)] = 0.64
92 | transitions = [(state_1_1, 1, forward_probs, backward_probs)]
93 |
94 | unif = np.ones(self.env.nS) / self.env.nS
95 | self.model_tests.append(
96 | {
97 | "policy": policy_stay,
98 | "transitions": transitions,
99 | "initial_state_distribution": unif,
100 | }
101 | )
102 |
103 |
104 | if __name__ == "__main__":
105 | unittest.main()
106 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/tests/batteries_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from deep_rlsp.envs.gridworlds.batteries import BatteriesState, BatteriesEnv
4 | from deep_rlsp.envs.gridworlds.tests.common import BaseTests, get_directions
5 |
6 |
7 | class TestBatteriesSpec(object):
8 | def __init__(self):
9 | """Test spec for the Batteries environment.
10 |
11 | G is a goal location, B is a battery, A is the agent, and T is the train.
12 | -------
13 | |B G |
14 | | TT |
15 | | TTG|
16 | | |
17 | |A B|
18 | -------
19 | """
20 | self.height = 5
21 | self.width = 5
22 | self.init_state = BatteriesState(
23 | (0, 4), (2, 1), 8, {(0, 0): True, (4, 4): True}, False
24 | )
25 | self.feature_locations = [(2, 0), (4, 2)]
26 | self.train_transition = {
27 | (2, 1): (3, 1),
28 | (3, 1): (3, 2),
29 | (3, 2): (2, 2),
30 | (2, 2): (2, 1),
31 | }
32 |
33 |
34 | class TestBatteriesEnv(BaseTests.TestEnv):
35 | def setUp(self):
36 | self.env = BatteriesEnv(TestBatteriesSpec())
37 | u, d, l, r, s = get_directions()
38 |
39 | def make_state(agent, train, life, battery_vals, carrying_battery):
40 | battery_present = dict(zip([(0, 0), (4, 4)], battery_vals))
41 | return BatteriesState(agent, train, life, battery_present, carrying_battery)
42 |
43 | self.trajectories = [
44 | [
45 | (u, (make_state((0, 3), (3, 1), 7, [True, True], False), 1.0)),
46 | (u, (make_state((0, 2), (3, 2), 6, [True, True], False), 1.0)),
47 | (u, (make_state((0, 1), (2, 2), 5, [True, True], False), 1.0)),
48 | (u, (make_state((0, 0), (2, 1), 4, [False, True], True), 1.0)),
49 | (r, (make_state((1, 0), (3, 1), 3, [False, True], True), 1.0)),
50 | (r, (make_state((2, 0), (3, 2), 2, [False, True], True), 1.0)),
51 | (s, (make_state((2, 0), (2, 2), 1, [False, True], True), 1.0)),
52 | (s, (make_state((2, 0), (2, 1), 0, [False, True], True), 1.0)),
53 | (d, (make_state((2, 1), (3, 1), 9, [False, True], False), 1.0)),
54 | (u, (make_state((2, 0), (3, 2), 8, [False, True], False), 1.0)),
55 | ]
56 | ]
57 |
58 |
59 | @unittest.skip("runs very long")
60 | class TestBatteriesModel(BaseTests.TestTabularTransitionModel):
61 | def setUp(self):
62 | self.env = BatteriesEnv(TestBatteriesSpec())
63 | self.model_tests = []
64 | self.setUpDeterministic()
65 |
66 |
67 | if __name__ == "__main__":
68 | unittest.main()
69 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/tests/env_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from deep_rlsp.envs.gridworlds.env import Direction
4 |
5 |
6 | class TestDirection(unittest.TestCase):
7 | def test_direction_number_conversion(self):
8 | all_directions = Direction.ALL_DIRECTIONS
9 | all_numbers = []
10 |
11 | for direction in Direction.ALL_DIRECTIONS:
12 | number = Direction.get_number_from_direction(direction)
13 | direction_again = Direction.get_direction_from_number(number)
14 | self.assertEqual(direction, direction_again)
15 | all_numbers.append(number)
16 |
17 | # Check that all directions are distinct
18 | num_directions = len(all_directions)
19 | self.assertEqual(len(set(all_directions)), num_directions)
20 | # Check that the numbers are 0, 1, ... num_directions - 1
21 | self.assertEqual(set(all_numbers), set(range(num_directions)))
22 |
23 |
24 | if __name__ == "__main__":
25 | unittest.main()
26 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/tests/room_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from deep_rlsp.envs.gridworlds.room import RoomState, RoomEnv
6 | from deep_rlsp.envs.gridworlds.tests.common import (
7 | BaseTests,
8 | get_directions,
9 | get_two_action_uniform_policy,
10 | )
11 |
12 |
13 | class TestRoomSpec(object):
14 | def __init__(self):
15 | """Test spec for the Room environment.
16 |
17 | G is a goal location, V is a vase, C is a carpet, A is the agent.
18 | -------
19 | |G G G|
20 | | CVC |
21 | | A |
22 | -------
23 | """
24 | self.height = 3
25 | self.width = 5
26 | self.init_state = RoomState((2, 2), {(2, 1): True})
27 | self.carpet_locations = [(1, 1), (3, 1)]
28 | self.feature_locations = [(0, 0), (2, 0), (4, 0)]
29 | self.use_pixels_as_observations = False
30 |
31 |
32 | class TestRoomEnv(BaseTests.TestEnv):
33 | def setUp(self):
34 | self.env = RoomEnv(TestRoomSpec())
35 | u, d, l, r, s = get_directions()
36 |
37 | self.trajectories = [
38 | [
39 | (l, (RoomState((1, 2), {(2, 1): True}), 1.0)),
40 | (u, (RoomState((1, 1), {(2, 1): True}), 1.0)),
41 | (u, (RoomState((1, 0), {(2, 1): True}), 1.0)),
42 | (r, (RoomState((2, 0), {(2, 1): True}), 1.0)),
43 | ],
44 | [
45 | (u, (RoomState((2, 1), {(2, 1): False}), 1.0)),
46 | (u, (RoomState((2, 0), {(2, 1): False}), 1.0)),
47 | ],
48 | [
49 | (r, (RoomState((3, 2), {(2, 1): True}), 1.0)),
50 | (u, (RoomState((3, 1), {(2, 1): True}), 1.0)),
51 | (l, (RoomState((2, 1), {(2, 1): False}), 1.0)),
52 | (d, (RoomState((2, 2), {(2, 1): False}), 1.0)),
53 | ],
54 | ]
55 |
56 |
57 | class TestRoomModel(BaseTests.TestTabularTransitionModel):
58 | def setUp(self):
59 | self.env = RoomEnv(TestRoomSpec())
60 | self.model_tests = []
61 |
62 | u, d, l, r, s = get_directions()
63 | policy_left_right = get_two_action_uniform_policy(
64 | self.env.nS, self.env.nA, l, r
65 | )
66 | state_middle = RoomState((2, 1), {(2, 1): False})
67 | state_left_vase = RoomState((1, 1), {(2, 1): True})
68 | state_right_vase = RoomState((3, 1), {(2, 1): True})
69 | state_left_novase = RoomState((1, 1), {(2, 1): False})
70 | state_right_novase = RoomState((3, 1), {(2, 1): False})
71 | forward_probs = np.zeros(self.env.nS)
72 | forward_probs[self.env.get_num_from_state(state_left_novase)] = 0.5
73 | forward_probs[self.env.get_num_from_state(state_right_novase)] = 0.5
74 | backward_probs = np.zeros(self.env.nS)
75 | backward_probs[self.env.get_num_from_state(state_left_vase)] = 0.25
76 | backward_probs[self.env.get_num_from_state(state_right_vase)] = 0.25
77 | backward_probs[self.env.get_num_from_state(state_left_novase)] = 0.25
78 | backward_probs[self.env.get_num_from_state(state_right_novase)] = 0.25
79 | transitions = [(state_middle, 1, forward_probs, backward_probs)]
80 | unif = np.ones(self.env.nS) / self.env.nS
81 | self.model_tests.append(
82 | {
83 | "policy": policy_left_right,
84 | "transitions": transitions,
85 | "initial_state_distribution": unif,
86 | }
87 | )
88 |
89 | self.setUpDeterministic()
90 |
91 |
92 | if __name__ == "__main__":
93 | unittest.main()
94 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/tests/test_observation_spaces.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from deep_rlsp.envs.gridworlds import TOY_PROBLEMS
6 | from deep_rlsp.envs.gridworlds.gym_envs import get_gym_gridworld
7 |
8 |
9 | class TestObservationSpaces(unittest.TestCase):
10 | # def test_observation_state_consistency(self):
11 | # for env_name, problems in TOY_PROBLEMS.items():
12 | # for env_spec in problems.keys():
13 | # _, cur_state, _, _ = TOY_PROBLEMS[env_name][env_spec]
14 | # env = get_gym_gridworld(env_name, env_spec)
15 | # for s in [env.init_state, cur_state]:
16 | # obs = env.s_to_obs(s)
17 | # s2 = env.obs_to_s(obs)
18 | # # if s != s2:
19 | # # import pdb
20 | # #
21 | # # pdb.set_trace()
22 | # self.assertEqual(s, s2)
23 | #
24 | # def test_no_error_getting_state_from_random_obs(self):
25 | # np.random.seed(29)
26 | # for env_name, problems in TOY_PROBLEMS.items():
27 | # for env_spec in problems.keys():
28 | # env = get_gym_gridworld(env_name, env_spec)
29 | # for _ in range(100):
30 | # obs = np.random.random(env.obs_shape) * 4 - 2
31 | # state = env.obs_to_s(obs)
32 | # # if state not in env.state_num:
33 | # # import pdb
34 | # #
35 | # # pdb.set_trace()
36 | # self.assertIn(state, env.state_num)
37 |
38 | def test_obs_state_consistent_in_rollout(self):
39 | np.random.seed(29)
40 | for env_name, problems in TOY_PROBLEMS.items():
41 | for env_spec in problems.keys():
42 | env = get_gym_gridworld(env_name, env_spec)
43 | if not env.use_pixels_as_observations:
44 | obs = env.reset()
45 | done = False
46 | while not done:
47 | action = env.action_space.sample()
48 | obs, reward, done, info = env.step(action)
49 | s = env.obs_to_s(obs)
50 | if s != env.s:
51 | import pdb
52 |
53 | pdb.set_trace()
54 | self.assertEqual(s, env.s)
55 |
56 |
57 | if __name__ == "__main__":
58 | unittest.main()
59 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/tests/train_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from deep_rlsp.envs.gridworlds.train import TrainState, TrainEnv
4 | from deep_rlsp.envs.gridworlds.tests.common import BaseTests, get_directions
5 |
6 |
7 | class TestTrainSpec(object):
8 | def __init__(self):
9 | """Test spec for the Train environment.
10 |
11 | G is a goal location, V is a vase, C is a carpet, A is the agent.
12 | -------
13 | | G C|
14 | | TT |
15 | | VTTG|
16 | | |
17 | |A |
18 | -------
19 | """
20 | self.height = 5
21 | self.width = 5
22 | self.init_state = TrainState((0, 4), {(1, 2): True}, (2, 1), True)
23 | self.carpet_locations = [(4, 0)]
24 |         self.feature_locations = [(2, 0), (4, 2)]
25 | self.train_transition = {
26 | (2, 1): (3, 1),
27 | (3, 1): (3, 2),
28 | (3, 2): (2, 2),
29 | (2, 2): (2, 1),
30 | }
31 |
32 |
33 | class TestTrainEnv(BaseTests.TestEnv):
34 | def setUp(self):
35 | self.env = TrainEnv(TestTrainSpec())
36 | u, d, l, r, s = get_directions()
37 |
38 | self.trajectories = [
39 | [
40 | (u, (TrainState((0, 3), {(1, 2): True}, (3, 1), True), 1.0)),
41 | (u, (TrainState((0, 2), {(1, 2): True}, (3, 2), True), 1.0)),
42 | (u, (TrainState((0, 1), {(1, 2): True}, (2, 2), True), 1.0)),
43 | (r, (TrainState((1, 1), {(1, 2): True}, (2, 1), True), 1.0)),
44 | (u, (TrainState((1, 0), {(1, 2): True}, (3, 1), True), 1.0)),
45 | (r, (TrainState((2, 0), {(1, 2): True}, (3, 2), True), 1.0)),
46 | (s, (TrainState((2, 0), {(1, 2): True}, (2, 2), True), 1.0)),
47 | (s, (TrainState((2, 0), {(1, 2): True}, (2, 1), True), 1.0)),
48 | ],
49 | [
50 | (u, (TrainState((0, 3), {(1, 2): True}, (3, 1), True), 1.0)),
51 | (r, (TrainState((1, 3), {(1, 2): True}, (3, 2), True), 1.0)),
52 | (r, (TrainState((2, 3), {(1, 2): True}, (2, 2), True), 1.0)),
53 | ],
54 | [
55 | (r, (TrainState((1, 4), {(1, 2): True}, (3, 1), True), 1.0)),
56 | (r, (TrainState((2, 4), {(1, 2): True}, (3, 2), True), 1.0)),
57 | (r, (TrainState((3, 4), {(1, 2): True}, (2, 2), True), 1.0)),
58 | (u, (TrainState((3, 3), {(1, 2): True}, (2, 1), True), 1.0)),
59 | (u, (TrainState((3, 2), {(1, 2): True}, (3, 1), True), 1.0)),
60 | (s, (TrainState((3, 2), {(1, 2): True}, (3, 2), False), 1.0)),
61 | (s, (TrainState((3, 2), {(1, 2): True}, (3, 2), False), 1.0)),
62 | (u, (TrainState((3, 1), {(1, 2): True}, (3, 2), False), 1.0)),
63 | (l, (TrainState((2, 1), {(1, 2): True}, (3, 2), False), 1.0)),
64 | ],
65 | ]
66 |
67 |
68 | class TestTrainModel(BaseTests.TestTabularTransitionModel):
69 | def setUp(self):
70 | self.env = TrainEnv(TestTrainSpec())
71 | self.model_tests = []
72 | self.setUpDeterministic()
73 |
74 |
75 | if __name__ == "__main__":
76 | unittest.main()
77 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/gridworlds/train_spec.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from deep_rlsp.envs.gridworlds.train import TrainState
4 |
5 |
6 | class TrainSpec(object):
7 | def __init__(
8 | self,
9 | height,
10 | width,
11 | init_state,
12 | carpet_locations,
13 | feature_locations,
14 | train_transition,
15 | ):
16 | """See TrainEnv.__init__ in train.py for details."""
17 | self.height = height
18 | self.width = width
19 | self.init_state = init_state
20 | self.carpet_locations = carpet_locations
21 | self.feature_locations = feature_locations
22 | self.train_transition = train_transition
23 |
24 |
25 | # In the diagrams below, G is a goal location, V is a vase, C is a carpet, A is
26 | # the agent, and T is the train.
27 | # Each tuple is of the form (spec, current state, task R, true R).
28 |
29 | TRAIN_PROBLEMS = {
30 | # -------
31 | # | G C|
32 | # | TT |
33 | # | VTTG|
34 | # | |
35 | # |A |
36 | # -------
37 | "default": (
38 | TrainSpec(
39 | 5,
40 | 5,
41 | TrainState((0, 4), {(1, 2): True}, (2, 1), True), # init state
42 | [(4, 0)], # carpet
43 | [(2, 0), (4, 2)], # goals
44 | # train transitions
45 | {(2, 1): (3, 1), (3, 1): (3, 2), (3, 2): (2, 2), (2, 2): (2, 1)},
46 | ),
47 | TrainState((2, 0), {(1, 2): True}, (2, 2), True), # current state for inference
48 | np.array([0, 0, 0, 0, 0, 0, 0, 0, 1]), # specified reward
49 | np.array([-1, 0, -1, 0, 0, 0, 0, 0, 1]), # true reward
50 | )
51 | }
52 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/mujoco/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/mujoco/__init__.py
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/mujoco/assets/ant.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/mujoco/assets/ant.xml
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/mujoco/assets/ant_footsensor.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/mujoco/assets/ant_footsensor.xml
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/mujoco/assets/ant_plot.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/mujoco/assets/ant_plot.xml
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/mujoco/assets/half_cheetah.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/mujoco/assets/half_cheetah.xml
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/mujoco/assets/half_cheetah_plot.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/mujoco/assets/half_cheetah_plot.xml
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/mujoco/half_cheetah.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 |
19 | import os
20 |
21 | from gym import utils
22 | import numpy as np
23 | from gym.envs.mujoco import mujoco_env
24 |
25 |
26 | class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
27 | def __init__(
28 | self,
29 | expose_all_qpos=False,
30 | task="default",
31 | target_velocity=None,
32 | model_path="half_cheetah.xml",
33 | plot=False,
34 | ):
35 | # Settings from
36 | # https://github.com/openai/gym/blob/master/gym/envs/__init__.py
37 | self._expose_all_qpos = expose_all_qpos
38 | self._task = task
39 | self._target_velocity = target_velocity
40 |
41 | xml_path = os.path.join(os.path.dirname(__file__), "assets")
42 | self.model_path = os.path.abspath(os.path.join(xml_path, model_path))
43 |
44 | mujoco_env.MujocoEnv.__init__(self, self.model_path, 5)
45 | utils.EzPickle.__init__(self)
46 |
47 | def step(self, action):
48 | xposbefore = self.sim.data.qpos[0]
49 | self.do_simulation(action, self.frame_skip)
50 | xposafter = self.sim.data.qpos[0]
51 | xvelafter = self.sim.data.qvel[0]
52 | ob = self._get_obs()
53 | reward_ctrl = -0.1 * np.square(action).sum()
54 |
55 |         # default both reward components so the info dict below is always defined
56 |         reward_vel = reward_run = 0.0
57 |         if self._task == "default":
58 |             reward_run = (xposafter - xposbefore) / self.dt
59 |             reward = reward_ctrl + reward_run
60 |         elif self._task == "target_velocity":
61 |             reward_vel = -((self._target_velocity - xvelafter) ** 2)
62 |             reward = reward_ctrl + reward_vel
63 |         elif self._task == "run_back":
64 |             reward_run = (xposbefore - xposafter) / self.dt
65 |             reward = reward_ctrl + reward_run
66 |
67 | done = False
68 | return (
69 | ob,
70 | reward,
71 | done,
72 | dict(reward_run=reward_run, reward_ctrl=reward_ctrl, reward_vel=reward_vel),
73 | )
74 |
75 | def _get_obs(self):
76 | if self._expose_all_qpos:
77 | return np.concatenate([self.sim.data.qpos.flat, self.sim.data.qvel.flat])
78 | return np.concatenate([self.sim.data.qpos.flat[1:], self.sim.data.qvel.flat])
79 |
80 | def reset_model(self):
81 | qpos = self.init_qpos + self.np_random.uniform(
82 | low=-0.1, high=0.1, size=self.sim.model.nq
83 | )
84 | qvel = self.init_qvel + self.np_random.randn(self.sim.model.nv) * 0.1
85 | self.set_state(qpos, qvel)
86 | return self._get_obs()
87 |
88 | def viewer_setup(self):
89 | # camera_id = self.model.camera_name2id("track")
90 | # self.viewer.cam.type = 2
91 | # self.viewer.cam.fixedcamid = camera_id
92 | # self.viewer.cam.distance = self.model.stat.extent * 0.5
93 | # camera_id = self.model.camera_name2id("fixed")
94 | # self.viewer.cam.type = 2
95 | # self.viewer.cam.fixedcamid = camera_id
96 |
97 | self.viewer.cam.fixedcamid = -1
98 | self.viewer.cam.trackbodyid = -1
99 | # how much you "zoom in", model.stat.extent is the max limits of the arena
100 | self.viewer.cam.distance = self.model.stat.extent * 0.4
101 | # self.viewer.cam.lookat[0] -= 4
102 | self.viewer.cam.lookat[1] += 1.1
103 | self.viewer.cam.lookat[2] += 0
104 | # camera rotation around the axis in the plane going through the frame origin
105 | # (if 0 you just see a line)
106 | self.viewer.cam.elevation = -20
107 | # camera rotation around the camera's vertical axis
108 | self.viewer.cam.azimuth = 90
109 |
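
A minimal usage sketch for this environment class (assumes a working MuJoCo / mujoco_py installation; the rollout length is arbitrary):

    from deep_rlsp.envs.mujoco.half_cheetah import HalfCheetahEnv

    env = HalfCheetahEnv(task="run_back")  # reward for running backwards
    obs = env.reset()
    for _ in range(10):
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
    # info contains the reward components computed in step()
    print(obs.shape, reward, info["reward_run"], info["reward_ctrl"])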
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/reward_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Union, Optional
2 |
3 | import numpy as np
4 | import gym
5 |
6 | from deep_rlsp.envs.gridworlds.env import Env
7 | from deep_rlsp.model import LatentSpaceModel
8 | from deep_rlsp.util.parameter_checks import check_between
9 | from deep_rlsp.util.helper import init_env_from_obs
10 |
11 |
12 | class RewardWeightWrapper(gym.Wrapper):
13 | def __init__(self, gridworld_env: Env, reward_weights: np.ndarray):
14 | self.gridworld_env = gridworld_env
15 | self.reward_weights = reward_weights
16 | super().__init__(gridworld_env)
17 |
18 | def update_reward_weights(self, reward_weights):
19 | self.reward_weights = reward_weights
20 |
21 | def step(self, action: int) -> Tuple[Union[int, np.ndarray], float, bool, Dict]:
22 | return self.gridworld_env.step(action, r_vec=self.reward_weights)
23 |
24 |
25 | class LatentSpaceRewardWrapper(gym.Wrapper):
26 | def __init__(
27 | self,
28 | env: Env,
29 | latent_space: LatentSpaceModel,
30 | r_inferred: Optional[np.ndarray],
31 | inferred_weight: Optional[float] = None,
32 | init_observations: Optional[list] = None,
33 | time_horizon: Optional[int] = None,
34 | init_prob: float = 0.2,
35 | use_task_reward: bool = False,
36 | reward_action_norm_factor: float = 0,
37 | ):
38 | if inferred_weight is not None:
39 | check_between("inferred_weight", inferred_weight, 0, 1)
40 | self.env = env
41 | self.latent_space = latent_space
42 | self.r_inferred = r_inferred
43 | if inferred_weight is None:
44 | inferred_weight = 1
45 | self.inferred_weight = inferred_weight
46 | self.state = None
47 | self.timestep = 0
48 | self.use_task_reward = use_task_reward
49 | self.reward_action_norm_factor = reward_action_norm_factor
50 |
51 | self.init_prob = init_prob
52 | self.init_observations = init_observations
53 |
54 | if time_horizon is None:
55 | self.time_horizon = self.env.spec.max_episode_steps
56 | else:
57 | self.time_horizon = time_horizon
58 | super().__init__(env)
59 |
60 |     def reset(self) -> np.ndarray:
61 | obs = super().reset()
62 |
63 | if self.init_observations is not None:
64 | if np.random.random() < self.init_prob:
65 | idx = np.random.randint(0, len(self.init_observations))
66 | obs = self.init_observations[idx]
67 | self.env = init_env_from_obs(self.env, obs)
68 |
69 | self.state = self.latent_space.encoder(obs)
70 | self.timestep = 0
71 | return obs
72 |
73 | def _get_reward(self, action, info):
74 | if self.r_inferred is not None:
75 | inferred = np.dot(self.r_inferred, self.state)
76 | info["inferred"] = inferred
77 | else:
78 | inferred = 0
79 |
80 | if self.use_task_reward and "task_reward" in info:
81 | task = info["task_reward"]
82 | else:
83 | task = 0
84 |
85 | action_norm = np.square(action).sum()
86 |
87 | reward = self.inferred_weight * inferred
88 | reward += task
89 | reward += self.reward_action_norm_factor * action_norm
90 | return reward
91 |
92 | def step(
93 | self, action, return_true_reward: bool = False,
94 | ) -> Tuple[np.ndarray, float, bool, Dict]:
95 | action = np.clip(action, self.action_space.low, self.action_space.high)
96 | obs, true_reward, done, info = self.env.step(action)
97 | self.state = self.latent_space.encoder(obs)
98 |
99 | if return_true_reward:
100 | assert "true_reward" in info
101 | reward = info["true_reward"]
102 | else:
103 | reward = self._get_reward(action, info)
104 |
105 | # ignore termination criterion from mujoco environments
106 | # to avoid reward information leak
107 | self.timestep += 1
108 | done = self.timestep > self.time_horizon
109 |
110 | return obs, reward, done, info
111 |
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/__init__.py
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/fetch/reach.xml:
--------------------------------------------------------------------------------
[XML content not captured in this export.]
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/fetch/shared.xml:
--------------------------------------------------------------------------------
[XML content not captured in this export.]
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/.get:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/.get
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/base_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/base_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/bellows_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/bellows_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/elbow_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/elbow_flex_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/estop_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/estop_link.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/forearm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/forearm_roll_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/gripper_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/gripper_link.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/head_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/head_pan_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/head_tilt_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/head_tilt_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/l_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/l_wheel_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/laser_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/laser_link.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/r_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/r_wheel_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_lift_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/shoulder_pan_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_fixed_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_fixed_link.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/torso_lift_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/upperarm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/upperarm_roll_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_flex_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/stls/fetch/wrist_roll_link_collision.stl
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/textures/block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/textures/block.png
--------------------------------------------------------------------------------
/src/deep_rlsp/envs/robotics/assets/textures/block_hidden.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/envs/robotics/assets/textures/block_hidden.png
--------------------------------------------------------------------------------
/src/deep_rlsp/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .tabular import TabularTransitionModel
2 | from .latent_space import LatentSpaceModel
3 | from .dynamics_mdn import InverseDynamicsMDN
4 | from .dynamics_mlp import InverseDynamicsMLP
5 | from .inverse_policy_mdn import InversePolicyMDN
6 | from .state_vae import StateVAE
7 | from .experience_replay import ExperienceReplay
8 |
9 | __all__ = [
10 | "TabularTransitionModel",
11 | "LatentSpaceModel",
12 | "InverseDynamicsMDN",
13 | "InverseDynamicsMLP",
14 | "InversePolicyMDN",
15 | "StateVAE",
16 | "ExperienceReplay",
17 | ]
18 |
--------------------------------------------------------------------------------
/src/deep_rlsp/model/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class TransitionModel(abc.ABC):
5 | def __init__(self, env):
6 | pass
7 |
8 | @abc.abstractmethod
9 | def models_observations(self):
10 | pass
11 |
12 | @abc.abstractmethod
13 | def forward_sample(self, state):
14 | pass
15 |
16 | @abc.abstractmethod
17 | def backward_sample(self, state, t):
18 | pass
19 |
20 |
21 | class InversePolicy(abc.ABC):
22 | def __init__(self, env, policy):
23 | pass
24 |
25 | @abc.abstractmethod
26 | def step(self, next_state):
27 | pass
28 |
29 | @abc.abstractmethod
30 | def sample(self, next_state):
31 | pass
32 |
--------------------------------------------------------------------------------
/src/deep_rlsp/model/exact_dynamics_mujoco.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import math
3 | import numpy as np
4 |
5 | from deep_rlsp.util.helper import init_env_from_obs
6 |
7 | GOLDEN_RATIO = (math.sqrt(5) + 1) / 2
8 |
9 |
10 | def gss(f, a, b, tol=1e-12):
11 | """Golden section search to find the minimum of f on [a,b].
12 |
13 | This implementation does not reuse function evaluations and assumes the minimum is c
14 | or d (not on the edges at a or b).
15 | Source: https://en.wikipedia.org/wiki/Golden_section_search
16 |
17 | f: a strictly unimodal function on [a,b]
18 |
19 | Example:
20 | >>> f = lambda x: (x-2)**2
21 | >>> x = gss(f, 1, 5)
22 | >>> print("%.15f" % x)
23 | 2.000009644875678
24 | """
25 | c = b - (b - a) / GOLDEN_RATIO
26 | d = a + (b - a) / GOLDEN_RATIO
27 | while abs(c - d) > tol:
28 | if f(c) < f(d):
29 | b = d
30 | else:
31 | a = c
32 | # We recompute both c and d here to avoid loss of precision
33 | # which may lead to incorrect results or infinite loop
34 | c = b - (b - a) / GOLDEN_RATIO
35 | d = a + (b - a) / GOLDEN_RATIO
36 | return (b + a) / 2
37 |
38 |
39 | def loss(y1, y2):
40 | return np.square(y1 - y2).mean()
41 |
42 |
43 | def invert(f, y, tolerance=1e-12, max_iters=1000):
44 | """
45 | f: function to invert (for expensive f, make sure to memoize)
46 | y: output to invert
47 |     tolerance: Acceptable numerical error
48 | max_iters: Maximum iterations to try
49 | Returns: x' such that f(x') = y up to tolerance or up to amount achieved
50 | in max_iters time.
51 | """
52 | x = y
53 | for i in range(max_iters):
54 | dx = np.random.normal(size=x.shape)
55 |
56 | def line_fn(fac):
57 | return loss(f(x + fac * dx), y)
58 |
59 | factor = gss(line_fn, -10.0, 10.0)
60 | x = x + factor * dx
61 | if loss(f(x), y) < tolerance:
62 | # print(f"Took {i} iterations")
63 | return x
64 | print("Max it reached, loss value is {}".format(loss(f(x), y)))
65 | return x
66 |
67 |
68 | class ExactDynamicsMujoco:
69 | def __init__(self, env_id, tolerance=1e-12, max_iters=1000):
70 | self.env = gym.make(env_id)
71 | self.env.reset()
72 | self.tolerance = tolerance
73 | self.max_iters = max_iters
74 |
75 | # @memoize
76 | def dynamics(s, a):
77 | self.env = init_env_from_obs(self.env, s)
78 | s2, _, _, _ = self.env.step(a)
79 | return s2
80 |
81 | self.dynamics = dynamics
82 |
83 | def inverse_dynamics(self, s2, a):
84 | def fn(s):
85 | return self.dynamics(s, a)
86 |
87 | return invert(fn, s2, tolerance=self.tolerance, max_iters=self.max_iters)
88 |
89 |
90 | def main():
91 | dynamics = ExactDynamicsMujoco("InvertedPendulum-v2")
92 | s = dynamics.env.reset()
93 | for _ in range(5):
94 | a = dynamics.env.action_space.sample()
95 | s2 = dynamics.dynamics(s, a)
96 | inverted_s = dynamics.inverse_dynamics(s2, a)
97 | print(f"s2: {s2}\na: {a}\ns: {s}")
98 | print(f"Inverted: {inverted_s}")
99 | print("RMSE: {}\n".format(np.sqrt(loss(s, inverted_s))))
100 | s = s2
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
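
Besides the main() demo above, invert can be exercised on any cheap invertible map without MuJoCo. A minimal sketch (the linear map below is a stand-in for the environment dynamics, chosen so each line search stays unimodal):

    import numpy as np
    from deep_rlsp.model.exact_dynamics_mujoco import gss, invert

    # gss finds the minimiser of a unimodal 1-D function.
    print(round(gss(lambda z: (z - 2.0) ** 2, 1.0, 5.0), 6))  # ~2.0

    # invert recovers an input of f by random line searches on the output error.
    A = np.array([[2.0, 0.5], [0.0, 1.5]])

    def f(x):
        return A.dot(x) + 1.0

    x_true = np.array([0.3, -0.7])
    x_rec = invert(f, f(x_true), tolerance=1e-10, max_iters=500)
    print(np.allclose(x_rec, x_true, atol=1e-3))  # expected: True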
--------------------------------------------------------------------------------
/src/deep_rlsp/model/experience_replay.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import numpy as np
4 |
5 | from deep_rlsp.model.exact_dynamics_mujoco import ExactDynamicsMujoco
6 |
7 |
8 | class Normalizer:
9 | def __init__(self, num_inputs):
10 | self.n = np.zeros(num_inputs)
11 | self.mean = np.zeros(num_inputs)
12 | self.mean_diff = np.zeros(num_inputs)
13 | self.std = np.ones(num_inputs)
14 |
15 | def add(self, x):
16 | self.n += 1.0
17 | last_mean = self.mean.copy()
18 | self.mean += (x - self.mean) / self.n
19 | self.mean_diff += (x - last_mean) * (x - self.mean)
20 | var = self.mean_diff / self.n
21 | var[var < 1e-8] = 1
22 | self.std = np.sqrt(var)
23 |
24 | def normalize(self, inputs):
25 | return (inputs - self.mean) / self.std
26 |
27 | def unnormalize(self, inputs):
28 | return inputs * self.std + self.mean
29 |
30 |
31 | class ExperienceReplay:
32 | def __init__(self, capacity):
33 | self.buffer = collections.deque(maxlen=capacity)
34 | self.obs_normalizer = None
35 | self.act_normalizer = None
36 | self.delta_normalizer = None
37 | self.max_delta = None
38 |         self.min_delta = None
39 |
40 | def __len__(self):
41 | return len(self.buffer)
42 |
43 | def append(self, obs, action, next_obs):
44 | delta = obs - next_obs
45 | if self.obs_normalizer is None:
46 | self.obs_normalizer = Normalizer(len(obs))
47 | if self.delta_normalizer is None:
48 | self.delta_normalizer = Normalizer(len(obs))
49 | self.max_delta = delta
50 | self.min_delta = delta
51 | if self.act_normalizer is None:
52 | self.act_normalizer = Normalizer(len(action))
53 | self.obs_normalizer.add(obs)
54 | self.obs_normalizer.add(next_obs)
55 | self.act_normalizer.add(action)
56 | self.delta_normalizer.add(delta)
57 | self.buffer.append((obs, action, next_obs))
58 | self.max_delta = np.maximum(self.max_delta, delta)
59 | self.min_delta = np.minimum(self.min_delta, delta)
60 |
61 | def sample(self, batch_size, normalize=False):
62 | indices = np.random.choice(len(self.buffer), batch_size, replace=False)
63 | obs, act, next_obs = zip(*[self.buffer[idx] for idx in indices])
64 | obs, act, next_obs = np.array(obs), np.array(act), np.array(next_obs)
65 | if normalize:
66 | obs = self.normalize_obs(obs)
67 | act = self.normalize_act(act)
68 | next_obs = self.normalize_obs(next_obs)
69 | return obs, act, next_obs
70 |
71 | def clip_delta(self, delta):
72 | return np.clip(delta, self.min_delta, self.max_delta)
73 |
74 | def normalize_obs(self, obs):
75 | return self.obs_normalizer.normalize(obs)
76 |
77 | def unnormalize_obs(self, obs):
78 | return self.obs_normalizer.unnormalize(obs)
79 |
80 | def normalize_delta(self, delta):
81 | return self.delta_normalizer.normalize(delta)
82 |
83 | def unnormalize_delta(self, delta):
84 | return self.delta_normalizer.unnormalize(delta)
85 |
86 | def normalize_act(self, act):
87 | return self.act_normalizer.normalize(act)
88 |
89 | def unnormalize_act(self, act):
90 | return self.act_normalizer.unnormalize(act)
91 |
92 | def add_random_rollouts(self, env, timesteps, n_rollouts):
93 | for _ in range(n_rollouts):
94 | obs = env.reset()
95 | for t in range(timesteps):
96 | action = env.action_space.sample()
97 | next_obs, _, _, _ = env.step(action)
98 | self.append(obs, action, next_obs)
99 | obs = next_obs
100 |
101 | def add_play_data(self, env, play_data):
102 | dynamics = ExactDynamicsMujoco(env.unwrapped.spec.id)
103 | observations, actions = play_data["observations"], play_data["actions"]
104 | n_traj = len(observations)
105 | assert len(actions) == n_traj
106 | for i in range(n_traj):
107 | l_traj = len(observations[i])
108 | for t in range(l_traj):
109 | obs = observations[i][t]
110 | action = actions[i][t]
111 | next_obs = dynamics.dynamics(obs, action)
112 | self.append(obs, action, next_obs)
113 |
114 | def add_policy_rollouts(self, env, policy, n_rollouts, horizon, eps_greedy=0):
115 | for i in range(n_rollouts):
116 | obs = env.reset()
117 | for t in range(horizon):
118 | if eps_greedy > 0 and np.random.random() < eps_greedy:
119 | action = env.action_space.sample()
120 | else:
121 | action = policy.predict(obs)[0]
122 | next_obs, reward, done, info = env.step(action)
123 | self.append(obs, action, next_obs)
124 | obs = next_obs
125 |
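
A quick sanity check of the running Normalizer on synthetic data (assumes the package and its dependencies are importable; the shapes below are illustrative). The Welford-style updates in add() track the batch mean and standard deviation:

    import numpy as np
    from deep_rlsp.model.experience_replay import Normalizer

    rng = np.random.RandomState(0)
    data = rng.normal(loc=3.0, scale=2.0, size=(1000, 4))

    norm = Normalizer(num_inputs=4)
    for x in data:
        norm.add(x)

    print(np.allclose(norm.mean, data.mean(axis=0)))  # running mean matches batch mean
    print(np.allclose(norm.std, data.std(axis=0)))    # running std matches batch std
    # normalize and unnormalize are inverses of each other
    print(np.allclose(norm.unnormalize(norm.normalize(data[0])), data[0]))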
--------------------------------------------------------------------------------
/src/deep_rlsp/model/gridworlds_feature_space.py:
--------------------------------------------------------------------------------
1 | class GridworldsFeatureSpace:
2 | def __init__(self, env):
3 | self.env = env
4 | s = self.env.init_state
5 | f = self.env.s_to_f(s)
6 | assert len(f.shape) == 1
7 | self.state_size = f.shape[0]
8 |
9 | def encoder(self, obs):
10 | s = self.env.obs_to_s(obs)
11 | f = self.env.s_to_f(s)
12 | return f
13 |
14 | def decoder(self, state):
15 | raise NotImplementedError()
16 |
--------------------------------------------------------------------------------
/src/deep_rlsp/model/inverse_model_env_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Dict, Union
2 |
3 | import numpy as np
4 | import gym
5 |
6 | from deep_rlsp.model import LatentSpaceModel, InverseDynamicsMDN
7 | from deep_rlsp.util.mujoco import compute_reward_done_from_obs
8 |
9 | MIN_FLOAT64 = np.finfo(np.float64).min
10 | MAX_FLOAT64 = np.finfo(np.float64).max
11 |
12 |
13 | class InverseModelGymEnv(gym.Env):
14 | """
15 |     Allows treating an inverse dynamics model as an MDP.
16 |
17 | Used for model evaluation and debugging.
18 | """
19 |
20 | def __init__(
21 | self,
22 | latent_space: LatentSpaceModel,
23 | inverse_model: InverseDynamicsMDN,
24 | initial_obs: np.ndarray,
25 | time_horizon: int,
26 | ):
27 | self.latent_space = latent_space
28 | self.inverse_model = inverse_model
29 |
30 | self.time_horizon = time_horizon
31 | self.update_initial_state(obs=initial_obs)
32 | self.action_space = self.latent_space.env.action_space
33 | self.observation_space = gym.spaces.Box(
34 | MIN_FLOAT64, MAX_FLOAT64, shape=(self.latent_space.state_size,)
35 | )
36 | if hasattr(self.latent_space.env, "nA"):
37 | self.nA = self.latent_space.env.nA
38 | self.reset()
39 |
40 | def update_initial_state(self, state=None, obs=None):
41 | if (state is not None and obs is not None) or (state is None and obs is None):
42 | raise ValueError("Exactly one of state and obs should be None.")
43 | if state is None:
44 | obs = np.expand_dims(obs, 0)
45 | state = self.latent_space.encoder(obs)[0]
46 | self.initial_state = state
47 |
48 | def reset(self) -> np.ndarray:
49 | self.timestep = 0
50 | self.state = self.initial_state
51 | return self.state
52 |
53 | def step(
54 | self, action: Union[int, np.ndarray]
55 | ) -> Tuple[np.ndarray, float, bool, Dict]:
56 | last_state = self.state
57 | last_obs = self.latent_space.decoder(last_state)
58 | self.state = self.inverse_model.step(self.state, action)
59 | obs = self.latent_space.decoder(self.state)
60 | # invert transition to get reward
61 | reward, done = compute_reward_done_from_obs(
62 | self.latent_space.env, obs, action, last_obs
63 | )
64 | return self.state, reward, done, dict()
65 |
--------------------------------------------------------------------------------
/src/deep_rlsp/model/mujoco_debug_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Exact dynamics models and handcoded features used for debugging in the pendulum env.
3 | """
4 |
5 | import numpy as np
6 |
7 | from deep_rlsp.model.exact_dynamics_mujoco import ExactDynamicsMujoco
8 |
9 |
10 | class IdentityFeatures:
11 | def __init__(self, env):
12 | self.encoder = lambda x: x
13 | self.decoder = lambda x: x
14 |
15 |
16 | class MujocoDebugFeatures:
17 | def __init__(self, env):
18 | self.env_id = env.unwrapped.spec.id
19 | assert env.unwrapped.spec.id in ("InvertedPendulum-v2", "FetchReachStack-v1")
20 | self.env = env
21 | if self.env_id == "InvertedPendulum-v2":
22 | self.state_size = 5
23 | elif self.env_id == "FetchReachStack-v1":
24 | self.state_size = env.observation_space.shape[0] + 1
25 |
26 | def encoder(self, obs):
27 | if self.env_id == "InvertedPendulum-v2":
28 | # feature = np.isfinite(obs).all() and (np.abs(obs[1]) <= 0.2)
29 | feature = np.abs(obs[1])
30 | elif self.env_id == "FetchReachStack-v1":
31 | feature = int(obs[-4] < 0.5)
32 | obs = np.concatenate([obs, [feature]])
33 | return obs
34 |
35 | def decoder(self, state):
36 | return state[:-1]
37 |
38 |
39 | class PendulumDynamics:
40 | def __init__(self, latent_space, backward=False):
41 | assert latent_space.env.unwrapped.spec.id == "InvertedPendulum-v2"
42 | self.latent_space = latent_space
43 | self.backward = backward
44 | self.dynamics = ExactDynamicsMujoco(
45 | self.latent_space.env.unwrapped.spec.id, tolerance=1e-2, max_iters=100
46 | )
47 | self.low = np.array([-1, -np.pi, -10, -10])
48 | self.high = np.array([1, np.pi, 10, 10])
49 |
50 | def step(self, state, action, sample=True):
51 | obs = state # self.latent_space.decoder(state)
52 | obs = np.clip(obs, self.low, self.high)
53 | if self.backward:
54 | obs = self.dynamics.inverse_dynamics(obs, action)
55 | else:
56 | obs = self.dynamics.dynamics(obs, action)
57 | state = obs # self.latent_space.encoder(obs)
58 | return state
59 |
60 | def learn(self, *args, return_initial_loss=False, **kwargs):
61 | print("Using exact dynamics...")
62 | if return_initial_loss:
63 | return 0, 0
64 | return 0
65 |
--------------------------------------------------------------------------------
/src/deep_rlsp/relative_reachability.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def relative_reachability_penalty(mdp, horizon, start):
5 | """
6 | Calculates the undiscounted relative reachability penalty for each state in an mdp,
7 | compared to the starting state baseline.
8 |
9 | Based on the algorithm described in: https://arxiv.org/pdf/1806.01186.pdf
10 | """
11 | coverage = get_coverage(mdp, horizon)
12 | distributions = baseline_state_distributions(mdp, horizon, start)
13 |
14 | def penalty(state):
15 | return np.sum(np.maximum(coverage[state, :] - coverage, 0), axis=1)
16 |
17 | def penalty_for_baseline_distribution(dist):
18 | return sum(
19 | (
20 | dist[state] * penalty(state)
21 | for state in range(mdp.nS)
22 | if dist[state] != 0
23 | )
24 | )
25 |
26 | r_r = np.array(list(map(penalty_for_baseline_distribution, distributions)))
27 | if np.amax(r_r) == 0:
28 | return np.zeros_like(r_r)
29 | return r_r / np.amax(r_r)
30 |
31 |
32 | def get_coverage(mdp, horizon):
33 | coverage = np.identity(mdp.nS)
34 | for i in range(horizon):
35 | # coverage(s0, sk) = \max_{a0} \sum_{s1} P(s1 | s0, a) * coverage(s1, sk)
36 | action_coverage = mdp.T_matrix.dot(coverage)
37 | action_coverage = action_coverage.reshape((mdp.nS, mdp.nA, mdp.nS))
38 | coverage = np.amax(action_coverage, axis=1)
39 | return coverage
40 |
41 |
42 | def baseline_state_distributions(mdp, horizon, start):
43 | distribution = np.zeros(mdp.nS)
44 | distribution[start] = 1
45 | distributions = [distribution]
46 | for _ in range(horizon - 1):
47 | distribution = mdp.baseline_matrix_transpose.dot(distribution)
48 | distributions.append(distribution)
49 | return distributions
50 |
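
A minimal sketch on a hand-built MDP, to make the expected shapes concrete (ChainMDP is a hypothetical stub; the gridworld Env objects used elsewhere in the package are expected to provide nS, nA, T_matrix and baseline_matrix_transpose themselves):

    import numpy as np
    from deep_rlsp.relative_reachability import relative_reachability_penalty

    # Deterministic 3-state chain: action 0 stays put, action 1 moves right,
    # and the right-most state is absorbing. T has shape (nS, nA, nS).
    nS, nA = 3, 2
    T = np.zeros((nS, nA, nS))
    for s in range(nS):
        T[s, 0, s] = 1.0
        T[s, 1, min(s + 1, nS - 1)] = 1.0

    class ChainMDP:
        def __init__(self):
            self.nS, self.nA = nS, nA
            self.T_matrix = T.reshape(nS * nA, nS)
            # Baseline policy "do nothing": the baseline distribution stays at the start.
            self.baseline_matrix_transpose = np.eye(nS)

    # One penalty vector per time step; moving right loses reachability of earlier states.
    penalties = relative_reachability_penalty(ChainMDP(), horizon=3, start=0)
    print(penalties)  # shape (horizon, nS), normalized to a maximum of 1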
--------------------------------------------------------------------------------
/src/deep_rlsp/rlsp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import check_grad
3 |
4 | from deep_rlsp.solvers.value_iter import value_iter, evaluate_policy
5 | from deep_rlsp.solvers.ppo import PPOSolver
6 | from deep_rlsp.util.parameter_checks import check_in
7 |
8 |
9 | def compute_g(mdp, policy, p_0, T, d_last_step_list, expected_features_list):
10 | nS, nA, nF = mdp.nS, mdp.nA, mdp.num_features
11 |
12 | # base case
13 | G = np.zeros((nS, nF))
14 | # recursive case
15 | for t in range(T - 1):
16 | # G(s') = \sum_{s, a} p(a | s) p(s' | s, a) [ p(s) g(s, a) + G_prev[s] ]
17 | # p(s) is given by d_last_step_list[t]
18 | # g(s, a) = f(s) - F(s) + \sum_{s'} p(s' | s, a) F(s')
19 | # Distribute the addition to get three different terms:
20 |         # First term: p(s) [f(s) - F(s)]
21 | # Second term: p(s) \sum_{s2} p(s2 | s, a) F(s2)
22 | # Third term: G_prev[s]
23 | g_first = mdp.f_matrix - expected_features_list[t]
24 | g_second = mdp.T_matrix.dot(expected_features_list[t + 1])
25 | g_second = g_second.reshape((nS, nA, nF))
26 | g_total = np.expand_dims(g_first, axis=1) + g_second
27 |
28 | prob_s_a = np.expand_dims(d_last_step_list[t].reshape(nS), axis=1) * policy[t]
29 |
30 | G_value = np.expand_dims(prob_s_a, axis=2) * g_total
31 | G_value = mdp.T_matrix_transpose.dot(G_value.reshape((nS * nA, nF)))
32 |
33 | G_recurse = np.expand_dims(policy[t], axis=-1) * np.expand_dims(G, axis=1)
34 | G_recurse = mdp.T_matrix_transpose.dot(G_recurse.reshape((nS * nA, nF)))
35 |
36 | G = G_value + G_recurse
37 |
38 | return G
39 |
40 |
41 | def compute_d_last_step(mdp, policy, p_0, T, gamma=1, verbose=False, return_all=False):
42 | """Computes the last-step occupancy measure"""
43 | D, d_last_step_list = p_0, [p_0]
44 | for t in range(T - 1):
45 | # D(s') = \sum_{s, a} D_prev(s) * p(a | s) * p(s' | s, a)
46 | state_action_prob = np.expand_dims(D, axis=1) * policy[t]
47 | D = mdp.T_matrix_transpose.dot(state_action_prob.flatten())
48 |
49 | if verbose is True:
50 | print(D)
51 | if return_all:
52 | d_last_step_list.append(D)
53 |
54 | return (D, d_last_step_list) if return_all else D
55 |
56 |
57 | def compute_feature_expectations(mdp, policy, p_0, T):
58 | nS, nA, nF = mdp.nS, mdp.nA, mdp.num_features
59 | expected_features = mdp.f_matrix
60 | expected_feature_list = [expected_features]
61 | for t in range(T - 2, -1, -1):
62 | # F(s) = f(s) + \sum_{a, s'} p(a | s) * p(s' | s, a) * F(s')
63 | future_features = mdp.T_matrix.dot(expected_features).reshape((nS, nA, nF))
64 | future_features = future_features * np.expand_dims(policy[t], axis=-1)
65 | expected_features = mdp.f_matrix + np.sum(future_features, axis=1)
66 | expected_feature_list.append(expected_features)
67 | return expected_features, expected_feature_list[::-1]
68 |
69 |
70 | def rlsp(
71 | _run,
72 | mdp,
73 | s_current,
74 | p_0,
75 | horizon,
76 | temp=1,
77 | epochs=1,
78 | learning_rate=0.2,
79 | r_prior=None,
80 | r_vec=None,
81 | threshold=1e-3,
82 | check_grad_flag=False,
83 | solver="value_iter",
84 | reset_solver=False,
85 | solver_iterations=1000,
86 | ):
87 | """The RLSP algorithm."""
88 | check_in("solver", solver, ("value_iter", "ppo"))
89 |
90 | def compute_grad(r_vec):
91 | # Compute the Boltzmann rational policy \pi_{s,a} = \exp(Q_{s,a} - V_s)
92 | if solver == "value_iter":
93 | policy = value_iter(mdp, 1, mdp.f_matrix @ r_vec, horizon, temp)
94 | elif solver == "ppo":
95 | policy = ppo.learn(
96 | r_vec, reset_model=reset_solver, total_timesteps=solver_iterations
97 | )
98 |
99 | _run.log_scalar(
100 | "policy_eval_r_vec",
101 | evaluate_policy(
102 | mdp,
103 | policy,
104 | mdp.get_num_from_state(mdp.init_state),
105 | 1,
106 | mdp.f_matrix @ r_vec,
107 | horizon,
108 | ),
109 | i,
110 | )
111 |
112 | d_last_step, d_last_step_list = compute_d_last_step(
113 | mdp, policy, p_0, horizon, return_all=True
114 | )
115 | if d_last_step[s_current] == 0:
116 |             print("Error in rlsp: No feasible trajectories!")
117 | return r_vec
118 |
119 | expected_features, expected_features_list = compute_feature_expectations(
120 | mdp, policy, p_0, horizon
121 | )
122 |
123 | G = compute_g(
124 | mdp, policy, p_0, horizon, d_last_step_list, expected_features_list
125 | )
126 | # Compute the gradient
127 | dL_dr_vec = G[s_current] / d_last_step[s_current]
128 | # Gradient of the prior
129 | if r_prior is not None:
130 | dL_dr_vec += r_prior.logdistr_grad(r_vec)
131 | return dL_dr_vec
132 |
133 | def compute_log_likelihood(r_vec):
134 | policy = value_iter(mdp, 1, mdp.f_matrix @ r_vec, horizon, temp)
135 | d_last_step = compute_d_last_step(mdp, policy, p_0, horizon)
136 | log_likelihood = np.log(d_last_step[s_current])
137 | if r_prior is not None:
138 | log_likelihood += np.sum(r_prior.logpdf(r_vec))
139 | return log_likelihood
140 |
141 | def get_grad(_):
142 | """dummy function for use with check_grad()"""
143 | return dL_dr_vec
144 |
145 | if r_vec is None:
146 | r_vec = 0.01 * np.random.randn(mdp.f_matrix.shape[1])
147 | print("Initial reward vector: {}".format(r_vec))
148 |
149 | ppo = PPOSolver(mdp, temp)
150 |
151 | if check_grad_flag:
152 | grad_error_list = []
153 |
154 | for i in range(epochs):
155 | dL_dr_vec = compute_grad(r_vec)
156 | if check_grad_flag:
157 | grad_error_list.append(check_grad(compute_log_likelihood, get_grad, r_vec))
158 |
159 | # Gradient ascent
160 | r_vec = r_vec + learning_rate * dL_dr_vec
161 |
162 | grad_norm = np.linalg.norm(dL_dr_vec)
163 |
164 | with np.printoptions(precision=4, suppress=True):
165 | print(
166 | "Epoch {}; Reward vector: {}; grad norm: {}".format(i, r_vec, grad_norm)
167 | )
168 | if check_grad_flag:
169 | print("grad error: {}".format(grad_error_list[-1]))
170 |
171 | if grad_norm < threshold:
172 | if check_grad_flag:
173 | print()
174 | print("Max grad error: {}".format(np.amax(np.asarray(grad_error_list))))
175 | print(
176 | "Median grad error: {}".format(
177 | np.median(np.asarray(grad_error_list))
178 | )
179 | )
180 | break
181 |
182 | return r_vec
183 |
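
In equation form, the loop above is gradient ascent on the log posterior of the reward vector given the observed state s_current. Writing d_T for the last-step occupancy of the Boltzmann policy induced by r (compute_d_last_step) and G for the quantity returned by compute_g,

$$
\mathcal{L}(r) = \log d_T(s_{current}) + \log p_{prior}(r), \qquad
\nabla_r \mathcal{L}(r) = \frac{G[s_{current}]}{d_T(s_{current})} + \nabla_r \log p_{prior}(r),
$$

and each epoch performs r <- r + learning_rate * \nabla_r \mathcal{L}(r), stopping once the gradient norm falls below threshold.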
--------------------------------------------------------------------------------
/src/deep_rlsp/sampling.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import exp
3 |
4 | from deep_rlsp.solvers.value_iter import value_iter
5 | from deep_rlsp.rlsp import compute_d_last_step
6 |
7 |
8 | def sample_from_posterior(
9 | env, s_current, p_0, h, temp, n_samples, step_size, r_prior, gamma=1, print_level=1
10 | ):
11 | """
12 | Algorithm similar to BIRL that uses the last-step OM of a Boltzmann rational
13 | policy instead of the BIRL likelihood. Samples the reward from the posterior
14 | p(r | s_T, r_spec) \\propto p(s_T | \\theta) * p(r | r_spec).
15 |
16 | This is Algorithm 1 in Appendix C of the paper.
17 | """
18 |
19 | def log_last_step_om(policy):
20 | d_last_step = compute_d_last_step(env, policy, p_0, h)
21 | return np.log(d_last_step[s_current])
22 |
23 | def log_probability(r_vec, verbose=False):
24 | pi = value_iter(env, gamma, env.f_matrix @ r_vec, h, temp)
25 | log_p = log_last_step_om(pi)
26 |
27 | log_prior = 0
28 | if r_prior is not None:
29 | log_prior = np.sum(r_prior.logpdf(r_vec))
30 |
31 | if verbose:
32 | print(
33 | "Log prior: {}\nLog prob: {}\nTotal: {}".format(
34 | log_prior, log_p, log_p + log_prior
35 | )
36 | )
37 | return log_p + log_prior
38 |
39 | times_accepted = 0
40 | samples = []
41 |
42 | if r_prior is None:
43 | r = 0.01 * np.random.randn(env.num_features)
44 | else:
45 | r = 0.1 * r_prior.rvs()
46 |
47 | if print_level >= 1:
48 | print("Initial reward: {}".format(r))
49 |
50 | # probability of the initial reward
51 | log_p = log_probability(r, verbose=(print_level >= 1))
52 |
53 | while len(samples) < n_samples:
54 | verbose = (print_level >= 1) and (len(samples) % 200 == 199)
55 | if verbose:
56 | print("\nGenerating sample {}".format(len(samples) + 1))
57 |
58 | r_prime = np.random.normal(r, step_size)
59 | log_p_1 = log_probability(r_prime, verbose=verbose)
60 |
61 | # Accept or reject the new sample
62 | # If we reject, the new sample is the previous sample
63 | acceptance_probability = exp(log_p_1 - log_p)
64 | if np.random.uniform() < acceptance_probability:
65 | times_accepted += 1
66 | r, log_p = r_prime, log_p_1
67 | samples.append(r)
68 |
69 | if verbose:
70 | # Acceptance probability should not be very high or very low
71 | print("Acceptance probability is {}".format(acceptance_probability))
72 |
73 | if print_level >= 1:
74 |         print("Done! Accepted {:.1%} of samples".format(times_accepted / n_samples))
75 | return samples
76 |
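
The accept/reject step above is the standard Metropolis rule for a symmetric Gaussian proposal centered at r with standard deviation step_size: the proposal is accepted with probability

$$
a(r \to r') = \min\left(1, \exp\big(\log p(r' \mid s_T) - \log p(r \mid s_T)\big)\right),
$$

which matches the comparison of exp(log_p_1 - log_p) against a uniform draw in the code (values above 1 are always accepted).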
--------------------------------------------------------------------------------
/src/deep_rlsp/solvers/__init__.py:
--------------------------------------------------------------------------------
1 | from stable_baselines import SAC, PPO2
2 | from stable_baselines.common.policies import MlpPolicy
3 | from stable_baselines.sac.policies import MlpPolicy as MlpPolicySac
4 | from stable_baselines.sac.policies import FeedForwardPolicy as SACPolicy
5 |
6 |
7 | class CustomSACPolicy(SACPolicy):
8 | def __init__(self, *args, **kwargs):
9 | super().__init__(*args, **kwargs, layers=[256, 256], feature_extraction="mlp")
10 |
11 |
12 | def get_sac(env, **kwargs):
13 | env_id = env.unwrapped.spec.id
14 | if (
15 | env_id.startswith("Ant")
16 | or env_id.startswith("HalfCheetah")
17 | or env_id.startswith("Swimmer")
18 | or env_id.startswith("Fetch")
19 | ):
20 | sac_kwargs = {
21 | "verbose": 1,
22 | "learning_rate": 3e-4,
23 | "gamma": 0.98,
24 | "tau": 0.01,
25 | "ent_coef": "auto",
26 | "buffer_size": 1000000,
27 | "batch_size": 256,
28 | "learning_starts": 10000,
29 | "train_freq": 1,
30 | "gradient_steps": 1,
31 | }
32 | policy = CustomSACPolicy
33 | elif env_id.startswith("Hopper"):
34 | sac_kwargs = {
35 | "verbose": 1,
36 | "learning_rate": 3e-4,
37 | "ent_coef": 0.01,
38 | "buffer_size": 1000000,
39 | "batch_size": 256,
40 | "learning_starts": 1000,
41 | "train_freq": 1,
42 | "gradient_steps": 1,
43 | }
44 | policy = CustomSACPolicy
45 | else:
46 | sac_kwargs = {"verbose": 1, "learning_starts": 1000}
47 | policy = MlpPolicySac
48 | for key, val in kwargs.items():
49 | sac_kwargs[key] = val
50 | solver = SAC(policy, env, **sac_kwargs)
51 | return solver
52 |
53 |
54 | def get_ppo(env, **kwargs):
55 |     ppo_kwargs = {"verbose": 1}
56 |     policy = MlpPolicy
57 |     for key, val in kwargs.items():
58 |         ppo_kwargs[key] = val
59 |     solver = PPO2(policy, env, **ppo_kwargs)
60 | return solver
61 |
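
A minimal sketch of how these factory helpers are meant to be called (Pendulum-v0 is used only as a small continuous-control example; it falls into the default MlpPolicySac branch):

    import gym
    from deep_rlsp.solvers import get_sac

    env = gym.make("Pendulum-v0")
    solver = get_sac(env, verbose=0, learning_starts=100)
    solver.learn(total_timesteps=2000)
    action, _ = solver.predict(env.reset())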
--------------------------------------------------------------------------------
/src/deep_rlsp/solvers/ppo.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 |
5 | from stable_baselines.common.vec_env import DummyVecEnv
6 | from stable_baselines.common.policies import MlpPolicy
7 | from stable_baselines import PPO2
8 | from stable_baselines.bench.monitor import Monitor
9 |
10 | from deep_rlsp.envs.reward_wrapper import RewardWeightWrapper
11 |
12 |
13 | class PPOSolver:
14 | def __init__(
15 | self, env, temperature: float = 1, tensorboard_log: Optional[str] = None
16 | ):
17 | self.env = RewardWeightWrapper(env, None)
18 | self.temperature = temperature
19 | # Monitor allows PPO to log the reward it achieves
20 | monitored_env = Monitor(self.env, None, allow_early_resets=True)
21 | self.vec_env = DummyVecEnv([lambda: monitored_env])
22 | self.tensorboard_log = tensorboard_log
23 | self._reset_model()
24 |
25 | def _reset_model(self):
26 | self.model = PPO2(
27 | MlpPolicy,
28 | self.vec_env,
29 | verbose=1,
30 | ent_coef=self.temperature,
31 | tensorboard_log=self.tensorboard_log,
32 | )
33 |
34 | def _get_tabular_policy(self):
35 | """
36 | Extracts a tabular policy representation from the PPO2 model.
37 | """
38 | policy = np.zeros((self.env.nS, self.env.nA))
39 | for state_id in range(self.env.nS):
40 | state = self.env.get_state_from_num(state_id)
41 | obs = self.env.s_to_obs(state)
42 | probs = self.model.action_probability(obs)
43 | policy[state_id, :] = probs
44 |
45 | # `action_probability` sometimes returns slightly unnormalized distributions
46 |         # (probably numerical issues). Hence, we normalize manually.
47 | policy /= policy.sum(axis=1, keepdims=True)
48 | assert np.allclose(policy.sum(axis=1), 1)
49 | return policy
50 |
51 | def learn(
52 | self, reward_weights, total_timesteps: int = 1000, reset_model: bool = False
53 | ):
54 | """
55 | Performs the PPO algorithm using the implementation from `stable_baselines`.
56 |
57 | Returns (np.ndarray):
58 |             Array of shape (nS, nA) containing the action probabilities for each state.
59 | """
60 | if reset_model:
61 | self._reset_model()
62 | self.env.update_reward_weights(reward_weights)
63 | self.model.learn(total_timesteps=total_timesteps, log_interval=10)
64 | return self._get_tabular_policy()
65 |
--------------------------------------------------------------------------------
/src/deep_rlsp/solvers/value_iter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def value_iter(mdp, gamma, r, horizon, temperature=1, time_dependent_reward=False):
5 | """
6 |     Performs (soft) value iteration to find a (Boltzmann-)optimal policy.
7 |
8 | Finds the optimal state and state-action value functions via value iteration with
9 | the "soft" max-ent Bellman backup:
10 |
11 | $$
12 | Q_{sa} = r_s + gamma * \\sum_{s'} p(s'|s,a)V_{s'}
13 | V'_s = temperature * log(\\sum_a exp(Q_{sa}/temperature))
14 | $$
15 |
16 | Then, computes the Boltzmann rational policy
17 |
18 | $$
19 | \\pi_{s,a} = exp((Q_{s,a} - V_s)/temperature).
20 | $$
21 |
22 | Args
23 | ----------
24 | mdp (Env): MDP describing the environment.
25 | gamma (float): Discount factor; 0 <= gamma <= 1.
26 | r (np.ndarray): Reward vector of length mdp.nS.
27 | horizon (int): Time-horizon for the finite-horizon value iteration.
28 | temperature (float): Rationality constant for the soft value iteration equation.
29 |
30 | Returns (list of np.ndarray):
31 | Arrays of shape (mdp.nS, mdp.nA), each value p[t][s,a] is the probability of
32 | taking action a in state s at time t.
33 | """
34 | nS, nA = mdp.nS, mdp.nA
35 |
36 | if not time_dependent_reward:
37 | r = [r] * horizon # Fast, since we aren't making copies
38 |
39 | policies = []
40 | V = np.copy(r[horizon - 1])
41 | for t in range(horizon - 2, -1, -1):
42 | future_values = mdp.T_matrix.dot(V).reshape((nS, nA))
43 | Q = np.expand_dims(r[t], axis=1) + gamma * future_values
44 |
45 | if temperature == 0:
46 | V = Q.max(axis=1)
47 | # Argmax to find the action number, then index into np.eye to
48 | # one hot encode. Note this will deterministically break ties
49 | # towards the smaller action.
50 | policy = np.eye(nA)[np.argmax(Q, axis=1)]
51 | else:
52 | # ∀ s: V_s = temperature * log(\\sum_a exp(Q_sa/temperature))
53 | # ∀ s,a: policy_{s,a} = exp((Q_{s,a} - V_s)/t)
54 | V = softmax(Q, temperature)
55 | V = np.expand_dims(V, axis=1)
56 | policy = np.exp((Q - V) / temperature)
57 |
58 | policies.append(policy)
59 |
60 | if gamma == 1:
61 | # When \\gamma=1, the backup operator is equivariant under adding
62 | # a constant to all entries of V, so we can translate min(V)
63 | # to be 0 at each step of the softmax value iteration without
64 | # changing the policy it converges to, and this fixes the problem
65 |             # where log(nA) keeps getting added at each iteration.
66 | V = V - np.amin(V)
67 |
68 | return policies[::-1]
69 |
70 |
71 | def evaluate_policy(mdp, policy, start, gamma, r, horizon):
72 | """Expected reward from the policy."""
73 | policy = np.array(policy)
74 | if len(policy.shape) == 2:
75 | policy = [policy] * horizon
76 | V = r
77 | for t in range(horizon - 2, -1, -1):
78 | future_values = mdp.T_matrix.dot(V).reshape((mdp.nS, mdp.nA))
79 | Q = np.expand_dims(r, axis=1) + gamma * future_values
80 | V = np.sum(policy[t] * Q, axis=1)
81 | return V[start]
82 |
83 |
84 | def softmax(x, t=1):
85 | """
86 | Numerically stable computation of t*log(\\sum_j^n exp(x_j / t))
87 |
88 |     If the input is a 1D numpy array, computes its softmax:
89 | output = t*log(\\sum_j^n exp(x_j / t)).
90 | If the input is a 2D numpy array, computes the softmax of each of the rows:
91 | output_i = t*log(\\sum_j^n exp(x_{ij} / t))
92 |
93 | Parameters
94 | ----------
95 | x : 1D or 2D numpy array
96 |
97 | Returns
98 | -------
99 | 1D numpy array
100 | shape = (n,), where:
101 | n = 1 if x was 1D, or
102 | n is the number of rows (=x.shape[0]) if x was 2D.
103 | """
104 | assert t >= 0
105 | if len(x.shape) == 1:
106 | x = x.reshape((1, -1))
107 | if t == 0:
108 | return np.amax(x, axis=1)
109 | if x.shape[1] == 1:
110 | return x
111 |
112 | def softmax_2_arg(x1, x2, t):
113 | """
114 | Numerically stable computation of t*log(exp(x1/t) + exp(x2/t))
115 |
116 | Parameters
117 | ----------
118 | x1 : numpy array of shape (n,1)
119 | x2 : numpy array of shape (n,1)
120 |
121 | Returns
122 | -------
123 | numpy array of shape (n,1)
124 | Each output_i = t*log(exp(x1_i / t) + exp(x2_i / t))
125 | """
126 |
127 | def tlog(x):
128 | return t * np.log(x)
129 |
130 | def expt(x):
131 | return np.exp(x / t)
132 |
133 | max_x = np.amax((x1, x2), axis=0)
134 | min_x = np.amin((x1, x2), axis=0)
135 | return max_x + tlog(1 + expt((min_x - max_x)))
136 |
137 | sm = softmax_2_arg(x[:, 0], x[:, 1], t)
138 | # Use the following property of softmax_2_arg:
139 | # softmax_2_arg(softmax_2_arg(x1,x2),x3) = log(exp(x1) + exp(x2) + exp(x3))
140 | # which is true since
141 | # log(exp(log(exp(x1) + exp(x2))) + exp(x3)) = log(exp(x1) + exp(x2) + exp(x3))
142 | for (i, x_i) in enumerate(x.T):
143 | if i > 1:
144 | sm = softmax_2_arg(sm, x_i, t)
145 | return sm
146 |
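
A minimal sketch of value_iter on a hand-built MDP (ChainMDP is a hypothetical stub exposing only the attributes value_iter reads: nS, nA and T_matrix):

    import numpy as np
    from deep_rlsp.solvers.value_iter import value_iter

    # Deterministic 3-state chain: action 0 stays put, action 1 moves right.
    nS, nA = 3, 2
    T = np.zeros((nS, nA, nS))
    for s in range(nS):
        T[s, 0, s] = 1.0
        T[s, 1, min(s + 1, nS - 1)] = 1.0

    class ChainMDP:
        def __init__(self, T_matrix, nS, nA):
            self.T_matrix, self.nS, self.nA = T_matrix, nS, nA

    mdp = ChainMDP(T.reshape(nS * nA, nS), nS, nA)
    r = np.array([0.0, 0.0, 1.0])  # only the right-most state is rewarding

    policy = value_iter(mdp, gamma=1, r=r, horizon=5, temperature=0.1)
    print(len(policy), policy[0].shape)  # horizon - 1 arrays of shape (nS, nA)
    print(policy[0])  # early on, "move right" dominates in states 0 and 1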
--------------------------------------------------------------------------------
/src/deep_rlsp/tests/test_toy_experiments.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple smoke tests that run some basic experiments and make sure they don't crash
3 | """
4 |
5 |
6 | import unittest
7 |
8 | from deep_rlsp.run import ex
9 |
10 | ex.observers = []
11 |
12 | ALL_ALGORITHMS_ON_ROOM = [
13 | {
14 | "env_name": "room",
15 | "problem_spec": "default",
16 | "combination_algorithm": "additive",
17 | "inference_algorithm": "spec",
18 | "horizon": 7,
19 | "evaluation_horizon": 20,
20 | "epochs": 5,
21 | },
22 | {
23 | "env_name": "room",
24 | "problem_spec": "default",
25 | "combination_algorithm": "additive",
26 | "inference_algorithm": "deviation",
27 | "horizon": 7,
28 | "evaluation_horizon": 20,
29 | "inferred_weight": 0.5,
30 | "epochs": 5,
31 | },
32 | {
33 | "env_name": "room",
34 | "problem_spec": "default",
35 | "combination_algorithm": "additive",
36 | "inference_algorithm": "reachability",
37 | "horizon": 7,
38 | "evaluation_horizon": 20,
39 | "epochs": 5,
40 | },
41 | {
42 | "env_name": "room",
43 | "problem_spec": "default",
44 | "combination_algorithm": "additive",
45 | "inference_algorithm": "rlsp",
46 | "horizon": 7,
47 | "evaluation_horizon": 20,
48 | "solver": "value_iter",
49 | "epochs": 5,
50 | },
51 | {
52 | "env_name": "room",
53 | "problem_spec": "default",
54 | "combination_algorithm": "additive",
55 | "inference_algorithm": "rlsp",
56 | "horizon": 7,
57 | "evaluation_horizon": 20,
58 | "solver": "ppo",
59 | "solver_iterations": 100,
60 | "reset_solver": False,
61 | "epochs": 5,
62 | },
63 | {
64 | "env_name": "room",
65 | "problem_spec": "default",
66 | "combination_algorithm": "additive",
67 | "inference_algorithm": "rlsp",
68 | "horizon": 7,
69 | "evaluation_horizon": 20,
70 | "solver": "ppo",
71 | "solver_iterations": 100,
72 | "reset_solver": True,
73 | "epochs": 5,
74 | },
75 | ]
76 |
77 | DEVIATION_ON_ALL_ENVS = [
78 | {
79 | "env_name": "room",
80 | "problem_spec": "default",
81 | "combination_algorithm": "additive",
82 | "inference_algorithm": "deviation",
83 | "horizon": 7,
84 | "evaluation_horizon": 20,
85 | "inferred_weight": 0.5,
86 | "epochs": 5,
87 | },
88 | {
89 | "env_name": "train",
90 | "problem_spec": "default",
91 | "combination_algorithm": "additive",
92 | "inference_algorithm": "deviation",
93 | "horizon": 8,
94 | "evaluation_horizon": 20,
95 | "inferred_weight": 0.5,
96 | "epochs": 5,
97 | },
98 | {
99 | "env_name": "apples",
100 | "problem_spec": "default",
101 | "combination_algorithm": "additive",
102 | "inference_algorithm": "deviation",
103 | "horizon": 11,
104 | "evaluation_horizon": 20,
105 | "inferred_weight": 0.5,
106 | "epochs": 5,
107 | },
108 | {
109 | "env_name": "batteries",
110 | "problem_spec": "easy",
111 | "combination_algorithm": "additive",
112 | "inference_algorithm": "deviation",
113 | "horizon": 11,
114 | "evaluation_horizon": 20,
115 | "inferred_weight": 0.5,
116 | "epochs": 5,
117 | },
118 | {
119 | "env_name": "batteries",
120 | "problem_spec": "default",
121 | "combination_algorithm": "additive",
122 | "inference_algorithm": "deviation",
123 | "horizon": 11,
124 | "evaluation_horizon": 20,
125 | "inferred_weight": 0.5,
126 | "epochs": 5,
127 | },
128 | {
129 | "env_name": "room",
130 | "problem_spec": "bad",
131 | "combination_algorithm": "additive",
132 | "inference_algorithm": "deviation",
133 | "horizon": 5,
134 | "evaluation_horizon": 20,
135 | "inferred_weight": 0.5,
136 | "epochs": 5,
137 | },
138 | ]
139 |
140 |
141 | class TestToyExperiments(unittest.TestCase):
142 | def test_toy_experiments(self):
143 | """
144 | Runs experiments sequentially
145 | """
146 | for config_updates in ALL_ALGORITHMS_ON_ROOM + DEVIATION_ON_ALL_ENVS:
147 | ex.run(config_updates=config_updates)
148 |
--------------------------------------------------------------------------------
/src/deep_rlsp/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/deep-rlsp/81941693aba2aa9157ca96e96567f4e3cb95fbc3/src/deep_rlsp/util/__init__.py
--------------------------------------------------------------------------------
/src/deep_rlsp/util/dist.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.stats import norm, laplace
3 |
4 |
5 | class NormalDistribution(object):
6 | def __init__(self, mu, sigma=1):
7 | self.mu = mu
8 | self.sigma = sigma
9 | self.distribution = norm(loc=mu, scale=sigma)
10 |
11 | def rvs(self):
12 | """sample"""
13 | return self.distribution.rvs()
14 |
15 | def pdf(self, x):
16 | return self.distribution.pdf(x)
17 |
18 | def logpdf(self, x):
19 | return self.distribution.logpdf(x)
20 |
21 | def logdistr_grad(self, x):
22 | return (self.mu - x) / (self.sigma ** 2)
23 |
24 |
25 | class LaplaceDistribution(object):
26 | def __init__(self, mu, b=1):
27 | self.mu = mu
28 | self.b = b
29 | self.distribution = laplace(loc=mu, scale=b)
30 |
31 | def rvs(self):
32 | """sample"""
33 | return self.distribution.rvs()
34 |
35 | def pdf(self, x):
36 | return self.distribution.pdf(x)
37 |
38 | def logpdf(self, x):
39 | return self.distribution.logpdf(x)
40 |
41 | def logdistr_grad(self, x):
42 | return (self.mu - x) / (np.fabs(x - self.mu) * self.b)
43 |
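
The closed-form gradients in logdistr_grad can be checked against a finite-difference derivative of logpdf; a quick sketch:

    import numpy as np
    from deep_rlsp.util.dist import NormalDistribution, LaplaceDistribution

    x = np.array([-1.0, 0.0, 1.3])
    eps = 1e-6
    for dist in (NormalDistribution(mu=0.5, sigma=2.0), LaplaceDistribution(mu=0.5, b=2.0)):
        numerical = (dist.logpdf(x + eps) - dist.logpdf(x - eps)) / (2 * eps)
        print(np.allclose(dist.logdistr_grad(x), numerical, atol=1e-5))  # True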
--------------------------------------------------------------------------------
/src/deep_rlsp/util/helper.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import numpy as np
4 | from stable_baselines import SAC
5 | from copy import deepcopy
6 |
7 | from deep_rlsp.util.video import save_video
8 | from deep_rlsp.util.mujoco import initialize_mujoco_from_obs
9 |
10 |
11 | def load_data(filename):
12 | with open(filename, "rb") as f:
13 | play_data = pickle.load(f)
14 | return play_data
15 |
16 |
17 | def init_env_from_obs(env, obs):
18 | id = env.spec.id
19 | if (
20 | "InvertedPendulum" in id
21 | or "HalfCheetah" in id
22 | or "Hopper" in id
23 | or "Ant" in id
24 | or "Fetch" in id
25 | ):
26 | return initialize_mujoco_from_obs(env, obs)
27 | else:
28 | # gridworld env
29 | state = env.obs_to_s(obs)
30 | env.reset()
31 | env.unwrapped.s = deepcopy(state)
32 | return env
33 |
34 |
35 | def get_trajectory(
36 | env,
37 | policy,
38 | get_observations=False,
39 | get_rgbs=False,
40 | get_return=False,
41 | print_debug=False,
42 | ):
43 | observations = [] if get_observations else None
44 | trajectory_rgbs = [] if get_rgbs else None
45 | total_reward = 0 if get_return else None
46 |
47 | obs = env.reset()
48 | done = False
49 | while not done:
50 | if isinstance(policy, SAC):
51 | a = policy.predict(np.expand_dims(obs, 0), deterministic=False)[0][0]
52 | else:
53 | a, _ = policy.predict(obs, deterministic=False)
54 | obs, reward, done, info = env.step(a)
55 | if print_debug:
56 | print("action")
57 | print(a)
58 | print("obs")
59 | print(obs)
60 | print("reward")
61 | print(reward)
62 | if get_observations:
63 | observations.append(obs)
64 | if get_rgbs:
65 | rgb = env.render("rgb_array")
66 | trajectory_rgbs.append(rgb)
67 | if get_return:
68 | total_reward += reward
69 | return observations, trajectory_rgbs, total_reward
70 |
71 |
72 | def evaluate_policy(env, policy, n_rollouts, video_out=None, print_debug=False):
73 | total_reward = 0
74 | for i in range(n_rollouts):
75 | get_rgbs = i == 0 and video_out is not None
76 | _, trajectory_rgbs, total_reward_episode = get_trajectory(
77 | env, policy, False, get_rgbs, True, print_debug=(print_debug and i == 0)
78 | )
79 | total_reward += total_reward_episode
80 | if get_rgbs:
81 | save_video(trajectory_rgbs, video_out, fps=20.0)
82 | print("Saved video to", video_out)
83 | return total_reward / n_rollouts
84 |
85 |
86 | def memoize(f):
87 | # Assumes that all inputs to f are 1-D Numpy arrays
88 | memo = {}
89 |
90 | def helper(*args):
91 | key = tuple((tuple(x) for x in args))
92 | if key not in memo:
93 | memo[key] = f(*args)
94 | return memo[key]
95 |
96 | return helper
97 |
98 |
99 | def sample_obs_from_trajectory(observations, n_samples):
100 | idx = np.random.choice(np.arange(len(observations)), n_samples)
101 | return np.array(observations)[idx]
102 |
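Usage sketch (editorial, not repository code): how evaluate_policy might be called to
score a trained SAC policy and save a video of the first rollout. The environment id
and checkpoint path are placeholders; a working MuJoCo installation is assumed.

    import gym
    from stable_baselines import SAC

    from deep_rlsp.util.helper import evaluate_policy

    env = gym.make("InvertedPendulum-v2")
    policy = SAC.load("path/to/sac_checkpoint.zip")  # placeholder path
    mean_return = evaluate_policy(env, policy, n_rollouts=5, video_out="rollout.avi")
    print("Average return over 5 rollouts:", mean_return)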
--------------------------------------------------------------------------------
/src/deep_rlsp/util/linalg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | # @memoize # don't memoize for continuous vectors, leads to memory leak
5 | def get_cosine_similarity(vec_a, vec_b):
6 | return np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))
7 |
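For reference (editorial example): orthogonal vectors give similarity 0, parallel
vectors give similarity close to 1.

    import numpy as np
    from deep_rlsp.util.linalg import get_cosine_similarity

    print(get_cosine_similarity(np.array([1.0, 0.0]), np.array([0.0, 1.0])))  # 0.0
    print(get_cosine_similarity(np.array([1.0, 2.0]), np.array([2.0, 4.0])))  # ~1.0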
--------------------------------------------------------------------------------
/src/deep_rlsp/util/mujoco.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class MujocoObsClipper:
5 | def __init__(self, env_id):
6 | if env_id.startswith("InvertedPendulum"):
7 | self.low = -300
8 | self.high = 300
9 | elif env_id.startswith("HalfCheetah"):
10 | self.low = -300
11 | self.high = 300
12 | elif env_id.startswith("Hopper"):
13 | self.low = -300
14 | self.high = 300
15 | elif env_id.startswith("Ant"):
16 | self.low = -100
17 | self.high = 100
18 | else:
19 | self.low = -float("inf")
20 | self.high = float("inf")
21 | self.counter = 0
22 |
23 | def clip(self, obs):
24 | obs_c = np.clip(obs, self.low, self.high)
25 | clipped = np.any(obs != obs_c)
26 | if clipped:
27 | self.counter += 1
28 | return obs_c, clipped
29 |
30 |
31 | def initialize_mujoco_from_obs(env, obs):
32 | """
33 | Initialize a given mujoco environment to a state conditioned on an observation.
34 |
35 |     Missing information in the observation (usually the torso coordinates) is
36 |     filled with zeros.
37 | """
38 | env_id = env.unwrapped.spec.id
39 | if env_id == "InvertedPendulum-v2":
40 | nfill = 0
41 | elif env_id in (
42 | "HalfCheetah-v2",
43 | "HalfCheetah-FW-v2",
44 | "HalfCheetah-BW-v2",
45 | "HalfCheetah-Plot-v2",
46 | "Hopper-v2",
47 | "Hopper-FW-v2",
48 | ):
49 | nfill = 1
50 | elif env_id == "Ant-FW-v2":
51 | nfill = 2
52 | elif env_id == "FetchReachStack-v1":
53 | env.set_state_from_obs(obs)
54 | return env
55 | else:
56 | raise NotImplementedError(f"{env_id} not supported")
57 |
58 | nq = env.model.nq
59 | nv = env.model.nv
60 | obs_qpos = np.zeros(nq)
61 | obs_qpos[nfill:] = obs[: nq - nfill]
62 | obs_qvel = obs[nq - nfill : nq - nfill + nv]
63 | env.set_state(obs_qpos, obs_qvel)
64 | return env
65 |
66 |
67 | def get_reward_done_from_obs_act(env, obs, act):
68 | """
69 | Returns a reward and done variable from an obs and action in a mujoco environment.
70 |
71 |     This is a hacky way to approximate the reward of a state and whether the episode
72 |     is done when only an observation is available. It is only used for evaluating
73 |     the latent space model by directly training a policy on this signal.
74 | """
75 | env = initialize_mujoco_from_obs(env, obs)
76 | obs, reward, done, info = env.step(act)
77 | return reward, done
78 |
79 |
80 | def compute_reward_done_from_obs(env, last_obs, action, next_obs):
81 | """
82 | Returns a reward and done variable from a transition in a mujoco environment.
83 |
84 | Both are manually computed from the observation.
85 | """
86 | env_id = env.unwrapped.spec.id
87 | if env_id == "InvertedPendulum-v2":
88 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/inverted_pendulum.py
89 | reward = 1.0
90 | notdone = np.isfinite(next_obs).all() and (np.abs(next_obs[1]) <= 0.2)
91 | done = not notdone
92 | elif env_id == "HalfCheetah-v2":
93 | # Note: This does not work for the default environment because the xposition
94 | # is being removed from the observation by the environment.
95 | nv = env.model.nv
96 | last_obs_qpos = last_obs[:-nv]
97 | next_obs_qpos = next_obs[:-nv]
98 | xposbefore = last_obs_qpos[0]
99 | xposafter = next_obs_qpos[0]
100 | reward_ctrl = -0.1 * np.square(action).sum()
101 | reward_run = (xposafter - xposbefore) / env.dt
102 | reward = reward_ctrl + reward_run
103 | done = False
104 | else:
105 | raise NotImplementedError("{} not supported".format(env_id))
106 | return reward, done
107 |
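Usage sketch (editorial, not repository code): resetting a MuJoCo environment to match
an observation and clipping observations. HalfCheetah-v2 is only an illustration; as
the docstrings above note, its default observation drops the x-position, so that entry
of qpos is filled with zero. A working MuJoCo installation is assumed.

    import gym

    from deep_rlsp.util.mujoco import MujocoObsClipper, initialize_mujoco_from_obs

    env = gym.make("HalfCheetah-v2")
    clipper = MujocoObsClipper("HalfCheetah-v2")

    obs = env.reset()
    obs, was_clipped = clipper.clip(obs)
    env = initialize_mujoco_from_obs(env, obs)  # qpos[0] (x-position) is set to 0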
--------------------------------------------------------------------------------
/src/deep_rlsp/util/parameter_checks.py:
--------------------------------------------------------------------------------
1 | def check_in(name, value, allowed_values):
2 | if value not in allowed_values:
3 | raise ValueError(
4 | "Invalid value for '{}': {}. Must be in {}.".format(
5 | name, value, allowed_values
6 | )
7 | )
8 |
9 |
10 | def check_greater_equal(name, value, greater_equal_value):
11 | if value < greater_equal_value:
12 | raise ValueError(
13 | "Invalid value for '{}': {}. Must be >= {}.".format(
14 | name, value, greater_equal_value
15 | )
16 | )
17 |
18 |
19 | def check_less_equal(name, value, less_equal_value):
20 | if value > less_equal_value:
21 | raise ValueError(
22 | "Invalid value for '{}': {}. Must be <= {}.".format(
23 | name, value, less_equal_value
24 | )
25 | )
26 |
27 |
28 | def check_between(name, value, lower_bound, upper_bound):
29 | check_greater_equal(name, value, lower_bound)
30 | check_less_equal(name, value, upper_bound)
31 |
32 |
33 | def check_not_none(name, value):
34 | if value is None:
35 | raise ValueError("'{}' cannot be None".format(name))
36 |
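For illustration (editorial example): the checks either pass silently or raise a
ValueError with a descriptive message.

    from deep_rlsp.util.parameter_checks import check_between, check_not_none

    check_between("decay_rate", 0.99, 0, 1)  # passes silently
    check_not_none("model_path", "out/model.ckpt")  # passes silently
    try:
        check_between("decay_rate", 1.5, 0, 1)
    except ValueError as e:
        print(e)  # Invalid value for 'decay_rate': 1.5. Must be <= 1.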
--------------------------------------------------------------------------------
/src/deep_rlsp/util/probs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def sample_observation(obs_probs_tuple):
5 | # inverse of zip(observations, probs)
6 | observations, probs = zip(*obs_probs_tuple)
7 | i = np.random.choice(len(probs), p=probs)
8 | return observations[i]
9 |
10 |
11 | def get_out_probs_tuple(out_probs, dtype, shape):
12 |     """
13 |     Convert a dictionary of observation-probability pairs into a generator of tuples.
14 |
15 |     Given a dictionary with byte-string representations of observations as keys and
16 |     probabilities as values, returns a generator of (observation, probability) pairs.
17 |
18 |     Normalizes the probabilities in the process.
19 |     """
20 | probs_sum = sum(prob for _, prob in out_probs.items())
21 | out_probs_tuple = (
22 | (np.fromstring(obs_str, dtype).reshape(shape), prob / probs_sum)
23 | for obs_str, prob in out_probs.items()
24 | )
25 | return out_probs_tuple
26 |
27 |
28 | def add_obs_prob_to_dict(dictionary, obs, prob):
29 | """
30 | Updates a dictionary that tracks a probability distribution over observations.
31 | """
32 | obs_str = obs.tostring()
33 | if obs_str not in dictionary:
34 | dictionary[obs_str] = 0
35 | dictionary[obs_str] += prob
36 | return dictionary
37 |
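Usage sketch (editorial, not repository code): accumulate probability mass per
observation, normalize, and sample. The observations and probabilities are made up.

    import numpy as np

    from deep_rlsp.util.probs import (
        add_obs_prob_to_dict,
        get_out_probs_tuple,
        sample_observation,
    )

    obs_a = np.array([0.0, 1.0])
    obs_b = np.array([1.0, 0.0])

    out_probs = {}
    add_obs_prob_to_dict(out_probs, obs_a, 0.2)
    add_obs_prob_to_dict(out_probs, obs_b, 0.6)
    add_obs_prob_to_dict(out_probs, obs_a, 0.2)  # obs_a now has total mass 0.4

    obs_probs = get_out_probs_tuple(out_probs, obs_a.dtype, obs_a.shape)
    sampled = sample_observation(obs_probs)  # obs_a with p=0.4, obs_b with p=0.6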
--------------------------------------------------------------------------------
/src/deep_rlsp/util/results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import jsonpickle
3 |
4 | import jsonpickle.ext.numpy as jsonpickle_numpy
5 |
6 | jsonpickle_numpy.register_handlers()
7 |
8 |
9 | class Artifact:
10 | def __init__(self, file_name, method, _run):
11 | self._run = _run
12 | self.file_name = file_name
13 | self.file_path = os.path.join("/tmp", file_name)
14 | if method is not None:
15 | self.file_obj = open(self.file_path, method)
16 | else:
17 | self.file_obj = None
18 |
19 | def __enter__(self):
20 | if self.file_obj is None:
21 | return self.file_path
22 | else:
23 | return self.file_obj
24 |
25 | def __exit__(self, type, value, traceback):
26 | if self.file_obj is not None:
27 | self.file_obj.close()
28 | self._run.add_artifact(self.file_path)
29 |
30 |
31 | class FileExperimentResults:
32 | def __init__(self, result_folder):
33 | self.result_folder = result_folder
34 | self.config = self._read_json(result_folder, "config.json")
35 | self.metrics = self._read_json(result_folder, "metrics.json")
36 | self.run = self._read_json(result_folder, "run.json")
37 | try:
38 | self.info = self._read_json(result_folder, "info.json")
39 | except Exception as e:
40 | print(e)
41 | self.info = None
42 | self.status = self.run["status"]
43 | self.result = self.run["result"]
44 |
45 | def _read_json(self, result_folder, filename):
46 | with open(os.path.join(result_folder, filename), "r") as f:
47 | json_str = f.read()
48 | return jsonpickle.loads(json_str)
49 |
50 | def get_metric(self, name):
51 | metric = self.metrics[name]
52 | return metric["steps"], metric["values"]
53 |
54 | def print_captured_output(self):
55 | with open(os.path.join(self.result_folder, "cout.txt"), "r") as f:
56 | print(f.read())
57 |
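Usage sketch (editorial, not repository code): reading back a Sacred run directory with
FileExperimentResults. The folder path and metric name are placeholders.

    from deep_rlsp.util.results import FileExperimentResults

    ex = FileExperimentResults("results/42")  # placeholder Sacred run folder
    print(ex.status, ex.result)
    steps, values = ex.get_metric("some_metric")  # placeholder metric name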
--------------------------------------------------------------------------------
/src/deep_rlsp/util/timer.py:
--------------------------------------------------------------------------------
1 | from time import time
2 |
3 |
4 | class Timer:
5 | def __init__(self):
6 | self.entries = {}
7 | self.start_times = {}
8 |
9 | def start(self, description):
10 | return self.Instance(self, description)
11 |
12 | def get_average_time(self, description):
13 | times = self.entries[description]
14 | return sum(times) / len(times)
15 |
16 | def get_total_time(self, description):
17 | return sum(self.entries[description])
18 |
19 | class Instance:
20 | def __init__(self, timer, description):
21 | self.timer = timer
22 | self.description = description
23 |
24 | def __enter__(self):
25 | if self.description in self.timer.start_times:
26 | raise Exception(
27 | "Cannot start timing {}".format(self.description)
28 | + " again before finishing the last invocation"
29 | )
30 | self.timer.start_times[self.description] = time()
31 |
32 | def __exit__(self, type, value, traceback):
33 | end = time()
34 | start = self.timer.start_times[self.description]
35 |
36 | if self.description not in self.timer.entries:
37 | self.timer.entries[self.description] = []
38 |
39 | self.timer.entries[self.description].append(end - start)
40 | del self.timer.start_times[self.description]
41 |
42 |
43 | def main():
44 | def fact(n):
45 | if n == 0:
46 | return 1
47 | return n * fact(n - 1)
48 |
49 | my_timer = Timer()
50 | for i in range(10):
51 | with my_timer.start("fact"):
52 | print(fact(i))
53 |
54 | print(my_timer.get_average_time("fact"))
55 | print(my_timer.get_total_time("fact"))
56 | print(my_timer.entries)
57 |
58 | def bad_fact(n):
59 | if n == 0:
60 | return 1
61 |         with my_timer.start("bad_fact"):  # raises inside the outer "bad_fact" block
62 | return n * fact(n - 1)
63 |
64 | with my_timer.start("bad_fact"):
65 | print(bad_fact(10))
66 |
67 |
68 | if __name__ == "__main__":
69 | main()
70 |
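Note (editorial): as the bad_fact demo above shows, starting the same description again
before the previous block finishes raises an Exception; nested sections therefore need
distinct descriptions, for example:

    from deep_rlsp.util.timer import Timer

    timer = Timer()
    with timer.start("outer"):
        with timer.start("inner"):  # distinct description, so nesting is fine
            sum(range(1000))
    print(timer.get_total_time("outer"), timer.get_total_time("inner"))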
--------------------------------------------------------------------------------
/src/deep_rlsp/util/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from deep_rlsp.util.parameter_checks import check_between, check_greater_equal
5 |
6 |
7 | def get_tf_session():
8 | config = tf.ConfigProto()
9 | config.gpu_options.allow_growth = True
10 | return tf.Session(config=config)
11 |
12 |
13 | def get_learning_rate(initial_learning_rate, decay_steps, decay_rate):
14 | global_step = tf.Variable(0, trainable=False)
15 | if decay_rate == 1:
16 | learning_rate = tf.convert_to_tensor(initial_learning_rate)
17 | else:
18 | check_between("decay_rate", decay_rate, 0, 1)
19 | check_greater_equal("decay_steps", decay_steps, 1)
20 | learning_rate = tf.train.exponential_decay(
21 | initial_learning_rate,
22 | global_step,
23 | decay_steps=decay_steps,
24 | decay_rate=decay_rate,
25 | )
26 | return learning_rate, global_step
27 |
28 |
29 | def tensorboard_log_gradients(gradients):
30 | for gradient, variable in gradients:
31 | tf.summary.scalar("gradients/" + variable.name, tf.norm(gradient, ord=2))
32 | tf.summary.scalar("variables/" + variable.name, tf.norm(variable, ord=2))
33 |
34 |
35 | def get_batch(data, batch, batch_size):
36 | batches = []
37 | for dataset in data:
38 | batch_array = dataset[batch * batch_size : (batch + 1) * batch_size]
39 | batches.append(batch_array)
40 | return batches
41 |
42 |
43 | def shuffle_data(data):
44 | n_states = len(data[0])
45 | shuffled = np.arange(n_states)
46 | np.random.shuffle(shuffled)
47 | shuffled_data = []
48 | for dataset in data:
49 | assert len(dataset) == n_states
50 | shuffled_data.append(np.array(dataset)[shuffled])
51 | return shuffled_data
52 |
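Usage sketch (editorial, not repository code): shuffle_data and get_batch driving a
simple mini-batch loop. The arrays and batch size are made up; the actual training step
is left as a comment.

    import numpy as np

    from deep_rlsp.util.train import get_batch, shuffle_data

    inputs = np.random.randn(100, 4)
    targets = np.random.randn(100, 1)
    batch_size = 32
    n_batches = int(np.ceil(len(inputs) / batch_size))

    inputs, targets = shuffle_data([inputs, targets])
    for batch in range(n_batches):
        batch_inputs, batch_targets = get_batch([inputs, targets], batch, batch_size)
        # run one optimization step on (batch_inputs, batch_targets) here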
--------------------------------------------------------------------------------
/src/deep_rlsp/util/video.py:
--------------------------------------------------------------------------------
1 | from deep_rlsp.util.mujoco import initialize_mujoco_from_obs
2 |
3 |
4 | def save_video(ims, filename, fps=20.0):
5 | import cv2
6 |
7 | # Define the codec and create VideoWriter object
8 | fourcc = cv2.VideoWriter_fourcc(*"MJPG")
9 | (height, width, _) = ims[0].shape
10 | writer = cv2.VideoWriter(filename, fourcc, fps, (width, height))
11 | for im in ims:
12 | im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
13 | writer.write(im)
14 | writer.release()
15 |
16 |
17 | def render_mujoco_from_obs(env, obs, **kwargs):
18 | env = initialize_mujoco_from_obs(env, obs)
19 | rgb = env.render(mode="rgb_array", **kwargs)
20 | return rgb
21 |
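Usage sketch (editorial, not repository code): writing a short dummy clip. Requires
opencv-python; frames are expected as uint8 RGB arrays of shape (height, width, 3).

    import numpy as np

    from deep_rlsp.util.video import save_video

    frames = [
        np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) for _ in range(40)
    ]
    save_video(frames, "noise.avi", fps=20.0)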
--------------------------------------------------------------------------------