├── .gitignore
├── envs
├── credit_assign
│ ├── __init__.py
│ ├── catch.py
│ └── key_to_door
│ │ ├── common.py
│ │ ├── env.py
│ │ ├── game.py
│ │ ├── key_to_door.py
│ │ ├── objects.py
│ │ ├── readme.md
│ │ └── tvt_wrapper.py
├── dmc
│ ├── __init__.py
│ └── dmc_env.py
├── make_pomdp_env.py
├── memory_envs
│ ├── __init__.py
│ ├── configs
│ │ ├── keytodoor.py
│ │ ├── terminal_fns.py
│ │ ├── tmaze_active.py
│ │ ├── tmaze_passive.py
│ │ └── visual_match.py
│ ├── key_to_door
│ │ ├── common.py
│ │ ├── env.py
│ │ ├── game.py
│ │ ├── key_to_door.py
│ │ ├── objects.py
│ │ ├── readme.md
│ │ ├── tvt_wrapper.py
│ │ └── visual_match.py
│ ├── make_env.py
│ ├── readme.md
│ └── tmaze.py
├── meta
│ ├── __init__.py
│ ├── dynamics_meta_env_wrapper.py
│ ├── example_env.py
│ ├── make_env.py
│ ├── mujoco
│ │ ├── ant.py
│ │ ├── ant_dir.py
│ │ ├── ant_goal.py
│ │ ├── ant_multitask_base.py
│ │ ├── assets
│ │ │ ├── ant.xml
│ │ │ └── low_gear_ratio_ant.xml
│ │ ├── core
│ │ │ ├── __init__.py
│ │ │ └── serializable.py
│ │ ├── half_cheetah.py
│ │ ├── half_cheetah_dir.py
│ │ ├── half_cheetah_vel.py
│ │ ├── humanoid_dir.py
│ │ └── mujoco_env.py
│ ├── readme.md
│ ├── toy_navigation
│ │ ├── gridworld.py
│ │ ├── point_robot.py
│ │ └── wind.py
│ └── wrappers.py
├── pomdp
│ ├── __init__.py
│ ├── readme.md
│ └── wrappers.py
├── pomdp_config.py
├── readme.md
├── rl_generalization
│ ├── .gitignore
│ ├── LICENSE
│ ├── README.md
│ ├── setup.py
│ ├── sunblaze_envs
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── assets
│ │ │ └── vizdoom
│ │ │ │ ├── .gitignore
│ │ │ │ ├── basic.wad
│ │ │ │ ├── basic_floor_ceiling_flipped.wad
│ │ │ │ ├── basic_torches.wad
│ │ │ │ ├── navigation.wad
│ │ │ │ ├── navigation_floor_ceiling_flipped.wad
│ │ │ │ ├── navigation_new_layout.wad
│ │ │ │ ├── navigation_torches.wad
│ │ │ │ ├── texture_set_a.txt
│ │ │ │ ├── texture_set_b.txt
│ │ │ │ ├── thing_set_a.txt
│ │ │ │ └── thing_set_b.txt
│ │ ├── base.py
│ │ ├── breakout.py
│ │ ├── classic_control.py
│ │ ├── monitor.py
│ │ ├── mujoco.py
│ │ ├── physical_world.py
│ │ ├── registration.py
│ │ ├── space_invaders.py
│ │ ├── time_limit.py
│ │ ├── vizdoom.py
│ │ └── wrappers.py
│ └── test.py
├── torchkit
│ ├── constant.py
│ ├── core.py
│ ├── distributions.py
│ ├── modules.py
│ ├── networks.py
│ ├── policies_base.py
│ ├── pytorch_utils.py
│ └── serializable.py
├── utils
│ ├── evaluation.py
│ ├── helpers.py
│ ├── logger.py
│ └── system.py
└── yang_domains
│ ├── __init__.py
│ ├── ant_reacher_top.py
│ ├── assets
│ ├── ant_reacher.xml
│ ├── box_1d.xml
│ ├── bump_1d.xml
│ ├── bump_1d_model.xml
│ ├── bumps_1d_model.xml
│ ├── clockwise.png
│ ├── grid_markers.xml
│ ├── gripah
│ │ ├── gripah.urdf
│ │ ├── narrow_finger.STL
│ │ ├── narrow_finger_rescaled.STL
│ │ ├── narrow_finger_tip.STL
│ │ ├── wide_finger.STL
│ │ ├── wide_finger_rescaled.STL
│ │ ├── wide_finger_tip.STL
│ │ ├── wrist.STL
│ │ └── wrist_rescaled.STL
│ ├── gripah_asset.xml
│ ├── gripah_body.xml
│ ├── gripah_contact.xml
│ ├── muj_gripper
│ │ ├── c_base.stl
│ │ ├── c_forearm.stl
│ │ ├── c_robotiq_85_gripper_joint_3_L.stl
│ │ ├── c_robotiq_85_gripper_joint_3_R.stl
│ │ ├── c_shoulder.stl
│ │ ├── c_upperarm.stl
│ │ ├── c_wrist1.stl
│ │ ├── c_wrist2.stl
│ │ ├── c_wrist3.stl
│ │ ├── glass_cup.stl
│ │ ├── glass_cup_2.stl
│ │ ├── glass_cup_3.stl
│ │ ├── inner_finger_coarse.stl
│ │ ├── inner_finger_fine.stl
│ │ ├── inner_knuckle_coarse.stl
│ │ ├── inner_knuckle_fine.stl
│ │ ├── new_solo_cup.stl
│ │ ├── outer_finger_coarse.stl
│ │ ├── outer_finger_fine.stl
│ │ ├── outer_knuckle_coarse.stl
│ │ ├── outer_knuckle_fine.stl
│ │ ├── red_solo_cup.stl
│ │ ├── robotiq_85_base_link_coarse.stl
│ │ ├── robotiq_85_base_link_fine.stl
│ │ ├── smaller_solo_cup.stl
│ │ ├── solo_cup.stl
│ │ ├── upd_solo_cup.stl
│ │ ├── v_base.stl
│ │ ├── v_forearm.stl
│ │ ├── v_robotiq_85_gripper_joint_3_L.stl
│ │ ├── v_robotiq_85_gripper_joint_3_R.stl
│ │ ├── v_shoulder.stl
│ │ ├── v_upperarm.stl
│ │ ├── v_wrist1.stl
│ │ ├── v_wrist2.stl
│ │ └── v_wrist3.stl
│ ├── objects
│ │ ├── bump_40_mujoco.STL
│ │ ├── bump_50_mujoco.STL
│ │ ├── bump_80_mujoco.STL
│ │ ├── plate_half.STL
│ │ └── plate_whole.STL
│ ├── topplate_model.xml
│ └── ur5_reacher.xml
│ ├── box_top.py
│ ├── bump_mdp.py
│ ├── bump_top.py
│ ├── car.py
│ ├── car_episodic.py
│ ├── car_top.py
│ ├── car_top_narrow.py
│ ├── car_top_relative.py
│ ├── cartpole_balance.py
│ ├── cartpole_var_len.py
│ ├── dmc_acrobot.py
│ ├── dmc_cart2pole.py
│ ├── dmc_cart3pole.py
│ ├── dmc_cartpole_b.py
│ ├── dmc_cartpole_su.py
│ ├── dmc_pendulum_su.py
│ ├── dmc_walker_walk.py
│ ├── pendulum_swingup_from_vrm.py
│ ├── pendulum_var_len.py
│ ├── pybullet_ant.py
│ ├── pybullet_halfcheetah.py
│ ├── reacher.py
│ ├── robot_envs
│ ├── __init__.py
│ ├── assets
│ │ ├── bumps
│ │ │ ├── bump_40.stl
│ │ │ ├── bump_40_blue.urdf
│ │ │ ├── bump_40_red.urdf
│ │ │ ├── bump_40_virtual.urdf
│ │ │ ├── bump_50.stl
│ │ │ └── bump_50.urdf
│ │ ├── cup
│ │ │ ├── cup.stl
│ │ │ └── cup.urdf
│ │ ├── plane
│ │ │ ├── checker_blue.png
│ │ │ ├── plane.mtl
│ │ │ ├── plane.urdf
│ │ │ └── plane100.obj
│ │ ├── plate
│ │ │ ├── plate.stl
│ │ │ ├── plate.urdf
│ │ │ ├── plate_half.urdf
│ │ │ ├── plate_holder.stl
│ │ │ ├── plate_holder.urdf
│ │ │ ├── plate_lower_half.stl
│ │ │ └── plate_upper_half.stl
│ │ ├── shelf
│ │ │ ├── shelf_back_board.stl
│ │ │ ├── shelf_back_board.urdf
│ │ │ ├── shelf_horizontal_board.stl
│ │ │ ├── shelf_horizontal_board.urdf
│ │ │ ├── shelf_side_board.stl
│ │ │ └── shelf_side_board.urdf
│ │ └── workspace
│ │ │ ├── grid_mark.urdf
│ │ │ ├── plane.obj
│ │ │ ├── rail.urdf
│ │ │ └── workspace.urdf
│ ├── bump_target.py
│ ├── bumps_diff.py
│ ├── bumps_env.py
│ ├── bumps_norm.py
│ ├── bumps_norm_minmax.py
│ ├── bumps_norm_punish.py
│ ├── bumps_norm_real.py
│ ├── bumps_norm_test.py
│ ├── env.py
│ └── top_plate.py
│ ├── robots
│ ├── __init__.py
│ ├── assets
│ │ ├── jaco
│ │ │ ├── j2s7s300_gym.urdf
│ │ │ └── meshes
│ │ │ │ ├── arm.SLDPRT
│ │ │ │ ├── arm.STL
│ │ │ │ ├── arm.dae
│ │ │ │ ├── arm_half_1.STL
│ │ │ │ ├── arm_half_1.dae
│ │ │ │ ├── arm_half_2.STL
│ │ │ │ ├── arm_half_2.dae
│ │ │ │ ├── arm_mico.STL
│ │ │ │ ├── arm_mico.dae
│ │ │ │ ├── base.STL
│ │ │ │ ├── base.dae
│ │ │ │ ├── finger_distal.STL
│ │ │ │ ├── finger_distal.dae
│ │ │ │ ├── finger_proximal.STL
│ │ │ │ ├── finger_proximal.dae
│ │ │ │ ├── forearm.STL
│ │ │ │ ├── forearm.dae
│ │ │ │ ├── forearm_mico.STL
│ │ │ │ ├── forearm_mico.dae
│ │ │ │ ├── hand_2finger.STL
│ │ │ │ ├── hand_2finger.dae
│ │ │ │ ├── hand_3finger.STL
│ │ │ │ ├── hand_3finger.dae
│ │ │ │ ├── ring_big.STL
│ │ │ │ ├── ring_big.dae
│ │ │ │ ├── ring_small.STL
│ │ │ │ ├── ring_small.dae
│ │ │ │ ├── shoulder.STL
│ │ │ │ ├── shoulder.dae
│ │ │ │ ├── wrist.STL
│ │ │ │ ├── wrist.dae
│ │ │ │ ├── wrist_spherical_1.STL
│ │ │ │ ├── wrist_spherical_1.dae
│ │ │ │ ├── wrist_spherical_2.STL
│ │ │ │ └── wrist_spherical_2.dae
│ │ ├── rdda
│ │ │ ├── meshes
│ │ │ │ ├── narrow_finger.STL
│ │ │ │ ├── wide_finger.STL
│ │ │ │ └── wrist.STL
│ │ │ └── rdda.urdf
│ │ ├── robotiq
│ │ │ ├── README.md
│ │ │ ├── meshes
│ │ │ │ ├── robotiq-2f-base.mtl
│ │ │ │ ├── robotiq-2f-base.obj
│ │ │ │ ├── robotiq-2f-base.stl
│ │ │ │ ├── robotiq-2f-coupler.mtl
│ │ │ │ ├── robotiq-2f-coupler.obj
│ │ │ │ ├── robotiq-2f-coupler.stl
│ │ │ │ ├── robotiq-2f-driver.mtl
│ │ │ │ ├── robotiq-2f-driver.obj
│ │ │ │ ├── robotiq-2f-driver.stl
│ │ │ │ ├── robotiq-2f-follower.mtl
│ │ │ │ ├── robotiq-2f-follower.obj
│ │ │ │ ├── robotiq-2f-follower.stl
│ │ │ │ ├── robotiq-2f-pad.stl
│ │ │ │ ├── robotiq-2f-spring_link.mtl
│ │ │ │ ├── robotiq-2f-spring_link.obj
│ │ │ │ └── robotiq-2f-spring_link.stl
│ │ │ ├── robotiq_2f_85.urdf
│ │ │ └── textures
│ │ │ │ ├── gripper-2f_BaseColor.jpg
│ │ │ │ ├── gripper-2f_Metallic.jpg
│ │ │ │ ├── gripper-2f_Normal.jpg
│ │ │ │ └── gripper-2f_Roughness.jpg
│ │ ├── shovel
│ │ │ ├── meshes
│ │ │ │ ├── shovel_base.STL
│ │ │ │ └── shovel_blade.STL
│ │ │ └── shovel.urdf
│ │ ├── spatula
│ │ │ ├── meshes
│ │ │ │ └── base.obj
│ │ │ └── spatula-base.urdf
│ │ ├── suction
│ │ │ ├── meshes
│ │ │ │ ├── base.obj
│ │ │ │ ├── head.obj
│ │ │ │ ├── mid.obj
│ │ │ │ └── tip.obj
│ │ │ ├── suction-base.urdf
│ │ │ └── suction-head.urdf
│ │ └── ur5
│ │ │ ├── collision
│ │ │ ├── base.stl
│ │ │ ├── forearm.stl
│ │ │ ├── shoulder.stl
│ │ │ ├── upperarm.stl
│ │ │ ├── wrist1.stl
│ │ │ ├── wrist2.stl
│ │ │ └── wrist3.stl
│ │ │ ├── license.txt
│ │ │ ├── ur5.urdf
│ │ │ └── visual
│ │ │ ├── base.stl
│ │ │ ├── forearm.stl
│ │ │ ├── shoulder.stl
│ │ │ ├── upperarm.stl
│ │ │ ├── wrist1.stl
│ │ │ ├── wrist2.stl
│ │ │ └── wrist3.stl
│ ├── end_effector.py
│ ├── gripper.py
│ ├── jaco.py
│ ├── rdda.py
│ ├── robot.py
│ ├── robotiq.py
│ ├── shovel.py
│ ├── spatula.py
│ ├── suction.py
│ └── ur5.py
│ ├── ur5_mdp_top.py
│ ├── ur5_top.py
│ ├── utils
│ ├── __init__.py
│ └── utils.py
│ ├── water_maze.py
│ ├── water_maze_dense.py
│ ├── water_maze_simple.py
│ ├── wrappers.py
│ └── wrappers_for_image.py
├── gen_tmuxp_gpt_mujoco.py
├── gen_tmuxp_gpt_pomdp.py
├── gen_tmuxp_mamba_dmcontrol.py
├── gen_tmuxp_mamba_dynamics_rnd.py
├── gen_tmuxp_mamba_meta.py
├── gen_tmuxp_mamba_mujoco.py
├── gen_tmuxp_mamba_pomdp.py
├── main.py
├── offpolicy_rnn
├── __init__.py
├── algorithm
│ ├── sac.py
│ ├── sac_full_length_rnn_ensembleQ.py
│ ├── sac_full_length_rnn_ensembleQ_sep_optim.py
│ ├── sac_full_length_rnn_redq.py
│ ├── sac_full_length_rnn_redq_sep_optim.py
│ ├── sac_mlp.py
│ ├── sac_mlp_redq.py
│ ├── sac_mlp_redq_ensemble_q.py
│ ├── sac_rnn_slice.py
│ ├── td3_full_length_rnn_ensembleQ.py
│ ├── td3_full_length_rnn_redq.py
│ └── td3_full_length_rnn_redq_sep_optim.py
├── buffers
│ ├── replay_memory.py
│ ├── replay_memory_tail_padding.py
│ └── transition_buffer
│ │ ├── nested_replay_memory.py
│ │ ├── nested_replay_memory_sub_traj.py
│ │ └── replay_memory.py
├── config
│ ├── common_config.yaml
│ ├── experiment_config.yaml
│ └── load_config.py
├── env_utils
│ └── make_env.py
├── models
│ ├── RNNHidden.py
│ ├── contextual_model.py
│ ├── conv1d
│ │ ├── conv1d.py
│ │ └── econv1d.py
│ ├── ensemble_linear_model.py
│ ├── flash_attention
│ │ ├── TransformerFlashAttention.py
│ │ └── gpt.py
│ ├── gilr
│ │ ├── egilr.py
│ │ ├── gilr.py
│ │ └── scan_triton
│ │ │ ├── real_rnn_fast_pscan.py
│ │ │ ├── real_rnn_tie_input_gate.py
│ │ │ └── real_rnn_tie_input_gate_cpu.py
│ ├── gilr_lstm
│ │ ├── egilr_lstm.py
│ │ ├── gilr_lstm.py
│ │ └── scan_triton
│ │ │ ├── real_rnn_fast_pscan.py
│ │ │ ├── real_rnn_tie_input_gate.py
│ │ │ └── real_rnn_tie_input_gate_cpu.py
│ ├── lru
│ │ ├── elru.py
│ │ ├── lru.py
│ │ └── scan_triton
│ │ │ ├── complex_rnn.py
│ │ │ ├── complex_rnn_2.py
│ │ │ ├── complex_rnn_cpu.py
│ │ │ ├── complex_rnn_jax.py
│ │ │ └── pscan.py
│ ├── mlp_base.py
│ ├── multi_ensemble_linear_model.py
│ ├── readme.md
│ ├── rnn_base.py
│ ├── s6
│ │ ├── mamba.py
│ │ ├── mamba_no_conv.py
│ │ └── selective_scan
│ │ │ ├── cpu_scan.py
│ │ │ └── triton_scan.py
│ ├── smamba
│ │ ├── mamba.py
│ │ └── mamba_ssm
│ │ │ └── ops
│ │ │ ├── selective_scan_interface_new.py
│ │ │ └── triton
│ │ │ ├── layernorm.py
│ │ │ ├── layernorm_cpu.py
│ │ │ └── selective_state_update.py
│ └── torch_utility.py
├── parameter
│ └── ParameterSAC.py
├── policy_value_models
│ ├── contextual_sac_discrete_policy.py
│ ├── contextual_sac_discrete_value.py
│ ├── contextual_sac_policy.py
│ ├── contextual_sac_policy_double_head.py
│ ├── contextual_sac_policy_single_head.py
│ ├── contextual_sac_value.py
│ ├── contextual_td3_policy.py
│ ├── contextual_td3_value.py
│ ├── make_models.py
│ └── utils.py
└── utility
│ ├── ValueScheduler.py
│ ├── alg_init.py
│ ├── count_parameters.py
│ ├── q_value_guard.py
│ ├── sample_utility.py
│ └── timer.py
├── readme.md
├── readme_cn.md
├── requirement.txt
├── results.md
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | build
3 | dist
4 | *.egg-info
5 | __pycache__
6 | logfile
7 | logfile_bk
8 | .pytest_cache
9 | .DS_Store
10 | run_all.json
11 |
--------------------------------------------------------------------------------
/envs/credit_assign/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import collections
3 | if sys.version_info.major == 3 and sys.version_info.minor >= 10 and not hasattr(collections, "Mapping"):
4 | setattr(collections, "Mapping", collections.abc.Mapping)
5 | setattr(collections, "Sequence", collections.abc.Sequence)
6 |
7 | from gym.envs.registration import register
8 | from .key_to_door import key_to_door
9 |
10 | delay_fn = lambda runs: (runs - 1) * 7 + 6
11 |
12 | for runs in [1, 2, 5, 10, 20, 40]:
13 | register(
14 | f"Catch-{runs}-v0",
15 | entry_point="envs.credit_assign.catch:DelayedCatch",
16 | kwargs=dict(
17 | delay=delay_fn(runs),
18 | flatten_img=True,
19 | one_hot_actions=False,
20 | ),
21 | max_episode_steps=delay_fn(runs),
22 | )
23 |
24 | # optimal expected return: 1.0 * (~23) + 5.0 = 28. due to unknown number of respawned apples
25 | register(
26 | "KeytoDoor-SR-v0",
27 | entry_point="envs.credit_assign.key_to_door.tvt_wrapper:KeyToDoor",
28 | kwargs=dict(
29 | flatten_img=True,
30 | one_hot_actions=False,
31 | apple_reward=1.0,
32 | final_reward=5.0,
33 | respawn_every=20, # apple respawn after 20 steps
34 | REWARD_GRID=key_to_door.REWARD_GRID_SR,
35 | max_frames=key_to_door.MAX_FRAMES_PER_PHASE_SR,
36 | ),
37 | max_episode_steps=sum(key_to_door.MAX_FRAMES_PER_PHASE_SR.values()),
38 | )
39 |
--------------------------------------------------------------------------------
/envs/credit_assign/key_to_door/game.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Pycolab Game interface."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | import six
24 |
25 |
26 | @six.add_metaclass(abc.ABCMeta)
27 | class AbstractGame(object):
28 | """Abstract base class for Pycolab games."""
29 |
30 | @abc.abstractmethod
31 | def __init__(self, rng, **settings):
32 | """Initialize the game."""
33 |
34 | @abc.abstractproperty
35 | def num_actions(self):
36 | """Number of possible actions in the game."""
37 |
38 | @abc.abstractproperty
39 | def colours(self):
40 | """Symbol to colour map for the game."""
41 |
42 | @abc.abstractmethod
43 | def make_episode(self):
44 | """Factory method for generating new episodes of the game."""
45 |
--------------------------------------------------------------------------------
/envs/credit_assign/key_to_door/readme.md:
--------------------------------------------------------------------------------
1 | # Key-to-Door Environment
2 | source code: https://github.com/deepmind/deepmind-research/tree/master/tvt/pycolab
3 |
4 | We adapt the environment framework and change some configurations to match the descriptions in [Synthetic Returns for Long-Term Credit Assignment](https://arxiv.org/abs/2102.12425) with environment name `"KeytoDoor-SR-v0"`.
5 |
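6 | A minimal usage sketch, assuming `envs.credit_assign` (its `__init__.py` above registers the ids with gym) has been imported:
7 | 
8 | ```python
9 | import gym
10 | import envs.credit_assign  # registers the Catch and KeytoDoor-SR environments
11 | 
12 | env = gym.make("KeytoDoor-SR-v0")
13 | obs = env.reset()
14 | obs, rew, done, info = env.step(env.action_space.sample())
15 | ```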
--------------------------------------------------------------------------------
/envs/credit_assign/key_to_door/tvt_wrapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 | from envs.credit_assign.key_to_door import env, key_to_door
4 |
5 |
6 | class KeyToDoor(gym.Env):
7 | def __init__(
8 | self,
9 | num_apples=10,
10 | apple_reward=1.0,
11 | fix_apple_reward_in_episode=True,
12 | final_reward=10.0,
13 | default_reward=0,
14 | respawn_every=20,
15 | REWARD_GRID=key_to_door.REWARD_GRID_SR,
16 | max_frames=key_to_door.MAX_FRAMES_PER_PHASE_SR,
17 | crop=True,
18 | flatten_img=True,
19 | one_hot_actions=False,
20 | ):
21 | super().__init__()
22 | self.pycolab_env = env.PycolabEnvironment(
23 | game="key_to_door",
24 | num_apples=num_apples,
25 | apple_reward=apple_reward,
26 | fix_apple_reward_in_episode=fix_apple_reward_in_episode,
27 | final_reward=final_reward,
28 | respawn_every=respawn_every,
29 | crop=crop,
30 | default_reward=default_reward,
31 | REWARD_GRID=REWARD_GRID,
32 | max_frames=max_frames,
33 | )
34 |
35 | self.action_space = gym.spaces.Discrete(4) # 4 directions
36 | self.one_hot_actions = one_hot_actions
37 |
38 | # original agent uses HWC size, but pytorch uses CHW size, so we transpose below
39 | self.img_size = (3, 5, 5)
40 | self.image_space = gym.spaces.Box(
41 | shape=self.img_size, low=0, high=255, dtype=np.uint8
42 | )
43 | # the pixel normalization should be done in image encoder, not here
44 |
45 | self.flatten_img = flatten_img
46 | if flatten_img:
47 | self.observation_space = gym.spaces.Box(
48 | shape=(np.array(self.img_size).prod(),), low=0, high=255, dtype=np.uint8
49 | )
50 | else:
51 | self.observation_space = self.image_space
52 |
53 | def _convert_obs(self, obs):
54 | new_obs = np.transpose(obs, (-1, 0, 1)) # (H,W,C) -> (C,H,W)
55 | if self.flatten_img:
56 | new_obs = new_obs.flatten() # -> (C*H*W)
57 | return new_obs
58 |
59 | def step(self, action):
60 | if self.one_hot_actions:
61 | action = np.argmax(action)
62 | obs, r = self.pycolab_env.step(action)
63 | self._ret += r
64 |
65 | info = {}
66 |
67 | if self.pycolab_env._episode.game_over:
68 | done = True
69 | info["success"] = self.pycolab_env.last_phase_reward() > 0.0
70 | else:
71 | done = False
72 |
73 | return self._convert_obs(obs), r, done, info
74 |
75 | def reset(self):
76 | obs, _ = self.pycolab_env.reset()
77 | self._ret = 0.0
78 |
79 | return self._convert_obs(obs)
80 |
81 |
82 | if __name__ == "__main__":
83 | env = KeyToDoor()
84 | obs = env.reset()
85 | done = False
86 | t = 0
87 | while not done:
88 | t += 1
89 | obs, rew, done, info = env.step(env.action_space.sample())
90 | print(t, rew, info)
91 |
--------------------------------------------------------------------------------
/envs/dmc/dmc_env.py:
--------------------------------------------------------------------------------
1 | # dmc environments
2 | import gym
3 | import dmc2gym
4 | import numpy as np
5 |
6 | #
7 | class DMCWrapper(gym.Env):
8 | def __init__(self, domain_name, task_name, seed=None):
9 | # print(f"domain name: {domain_name}, task name: {task_name}, seed: {seed}")
10 | self.env = dmc2gym.make(domain_name=domain_name, task_name=task_name, seed=seed)
11 | self.observation_space = self.env.observation_space
12 | self.action_space = self.env.action_space
13 | self._max_episode_steps = self.env.unwrapped._step_limit
14 | self.seed(seed)
15 |
16 | def seed(self, seed):
17 | self.env.seed(seed)
18 | self.env.action_space.seed(seed)
19 |
20 | def reset(self):
21 | return self.env.reset()
22 |
23 | def step(self, action):
24 | action = np.clip(action, self.action_space.low, self.action_space.high)
25 | return self.env.step(action)
26 |
--------------------------------------------------------------------------------
/envs/memory_envs/configs/keytodoor.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from gym.envs.registration import register
3 | from .terminal_fns import finite_horizon_terminal
4 | from ..key_to_door import key_to_door
5 | import copy
6 |
7 | def create_fn(env_name, continuous_act_space=False):
8 | if continuous_act_space:
9 | env_name_fn = lambda distract: f"Mem-SR-{distract}-cont-act-v0"
10 | else:
11 | env_name_fn = lambda distract: f"Mem-SR-{distract}-v0"
12 | env_name_ = env_name
13 | env_name = env_name_fn(env_name)
14 | MAX_FRAMES_PER_PHASE_SR = copy.deepcopy(key_to_door.MAX_FRAMES_PER_PHASE_SR)
15 |
16 | MAX_FRAMES_PER_PHASE_SR.update({"distractor": env_name_})
17 | # optimal expected return: 1.0 * (~23) + 5.0 = 28. due to unknown number of respawned apples
18 | register(
19 | env_name,
20 | entry_point="envs.memory_envs.key_to_door.tvt_wrapper:KeyToDoor",
21 | kwargs=dict(
22 | flatten_img=True,
23 | one_hot_actions=False,
24 | apple_reward=1.0,
25 | final_reward=5.0,
26 | respawn_every=20, # apple respawn after 20 steps
27 | REWARD_GRID=copy.deepcopy(key_to_door.REWARD_GRID_SR),
28 | max_frames=copy.deepcopy(MAX_FRAMES_PER_PHASE_SR),
29 | continuous_act_space=continuous_act_space
30 | ),
31 | max_episode_steps=sum(MAX_FRAMES_PER_PHASE_SR.values()), # no info if max episode steps = sum(MAX_FRAMES_PER_PHASE_SR.values())
32 | )
33 | return env_name
34 |
35 |
36 | #
37 | # def get_config():
38 | # config = ConfigDict()
39 | # config.create_fn = create_fn
40 | #
41 | # config.env_type = "key_to_door"
42 | # config.terminal_fn = finite_horizon_terminal
43 | #
44 | # config.eval_interval = 50
45 | # config.save_interval = 50
46 | # config.eval_episodes = 20
47 | #
48 | # config.env_name = 60
49 | #
50 | # return config
51 |
--------------------------------------------------------------------------------
/envs/memory_envs/configs/terminal_fns.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 |
4 | def finite_horizon_terminal(env: gym.Env, done: bool, info: dict) -> bool:
5 | return done
6 |
7 |
8 | def infinite_horizon_terminal(env: gym.Env, done: bool, info: dict) -> bool:
9 | if not done or "TimeLimit.truncated" in info:
10 | terminal = False
11 | else:
12 | terminal = True
13 | return terminal
14 |
--------------------------------------------------------------------------------
/envs/memory_envs/configs/tmaze_active.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from gym.envs.registration import register
3 | from .terminal_fns import finite_horizon_terminal
4 |
5 | env_name_fn = lambda l: f"Mem-T-active-{l}-v0"
6 | env_name_cont_fn = lambda l: f"Mem-T-active-{l}-cont-act-v0"
7 |
8 | def create_fn(env_name, continuous_act_space=False):
9 | length = env_name
10 | env_name = env_name_cont_fn(length) if continuous_act_space else env_name_fn(length)
11 | register(
12 | env_name,
13 | entry_point="envs.memory_envs.tmaze:TMazeClassicActive",
14 | kwargs=dict(
15 | corridor_length=length,
16 | penalty=-1.0 / length, # NOTE: \sum_{t=1}^T -1/T = -1
17 | distract_reward=0.0,
18 | continuous_act_space=continuous_act_space
19 | ),
20 | max_episode_steps=length + 2 * 1 + 1, # NOTE: has to define it here
21 | )
22 |
23 | return env_name
24 |
--------------------------------------------------------------------------------
/envs/memory_envs/configs/tmaze_passive.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from gym.envs.registration import register
3 | from .terminal_fns import finite_horizon_terminal
4 |
5 | env_name_fn = lambda l: f"Mem-T-passive-{l}-v0"
6 | env_name_cont_fn = lambda l: f"Mem-T-passive-{l}-cont-act-v0"
7 |
8 |
9 | def create_fn(env_name, continuous_act_space=False):
10 | length = env_name
11 | env_name = env_name_cont_fn(length) if continuous_act_space else env_name_fn(length)
12 | register(
13 | env_name,
14 | entry_point="envs.memory_envs.tmaze:TMazeClassicPassive",
15 | kwargs=dict(
16 | corridor_length=length,
17 | penalty=-1.0 / length, # NOTE: \sum_{t=1}^T -1/T = -1
18 | distract_reward=0.0,
19 | continuous_act_space=continuous_act_space
20 | ),
21 | max_episode_steps=length + 1, # NOTE: has to define it here
22 | )
23 |
24 | return env_name
25 |
26 |
27 | # def get_config():
28 | # config = ConfigDict()
29 | # config.create_fn = create_fn
30 | #
31 | # config.env_type = "tmaze_passive"
32 | # config.terminal_fn = finite_horizon_terminal
33 | #
34 | # config.eval_interval = 10
35 | # config.save_interval = 10
36 | # config.eval_episodes = 10
37 | #
38 | # # [1, 2, 5, 10, 30, 50, 100, 300, 500, 1000]
39 | # config.env_name = 10
40 | # config.distract_reward = 0.0
41 | #
42 | # return config
43 |
--------------------------------------------------------------------------------
/envs/memory_envs/configs/visual_match.py:
--------------------------------------------------------------------------------
1 | from ml_collections import ConfigDict
2 | from typing import Tuple
3 | from gym.envs.registration import register
4 | from .terminal_fns import finite_horizon_terminal
5 | from ..key_to_door import visual_match
6 |
7 |
8 | def create_fn(config: ConfigDict) -> Tuple[ConfigDict, str]:
9 | env_name_fn = lambda distract: f"passive-visual-{distract}-v0"
10 | env_name = env_name_fn(config.env_name)
11 | MAX_FRAMES_PER_PHASE = visual_match.MAX_FRAMES_PER_PHASE
12 | MAX_FRAMES_PER_PHASE.update({"distractor": config.env_name})
13 |
14 | # optimal expected return: 1.0 * (~23) + 5.0 = 28. due to unknown number of respawned apples
15 | register(
16 | env_name,
17 | entry_point="envs.key_to_door.tvt_wrapper:VisualMatch",
18 | kwargs=dict(
19 | flatten_img=True,
20 | one_hot_actions=False,
21 | apple_reward=1.0,
22 | final_reward=5.0,
23 | respawn_every=20, # apple respawn after 20 steps
24 | passive=True,
25 | max_frames=MAX_FRAMES_PER_PHASE,
26 | ),
27 | max_episode_steps=sum(MAX_FRAMES_PER_PHASE.values()),
28 | )
29 |
30 | del config.create_fn
31 | return config, env_name
32 |
33 |
34 | def get_config():
35 | config = ConfigDict()
36 | config.create_fn = create_fn
37 |
38 | config.env_type = "visual_match"
39 | config.terminal_fn = finite_horizon_terminal
40 |
41 | config.eval_interval = 50
42 | config.save_interval = 50
43 | config.eval_episodes = 20
44 |
45 | config.env_name = 60
46 |
47 | return config
48 |
--------------------------------------------------------------------------------
/envs/memory_envs/key_to_door/game.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=g-bad-file-header
2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================
16 | """Pycolab Game interface."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import abc
23 | import six
24 |
25 |
26 | @six.add_metaclass(abc.ABCMeta)
27 | class AbstractGame(object):
28 | """Abstract base class for Pycolab games."""
29 |
30 | @abc.abstractmethod
31 | def __init__(self, rng, **settings):
32 | """Initialize the game."""
33 |
34 | @abc.abstractproperty
35 | def num_actions(self):
36 | """Number of possible actions in the game."""
37 |
38 | @abc.abstractproperty
39 | def colours(self):
40 | """Symbol to colour map for the game."""
41 |
42 | @abc.abstractmethod
43 | def make_episode(self):
44 | """Factory method for generating new episodes of the game."""
45 |
--------------------------------------------------------------------------------
/envs/memory_envs/key_to_door/readme.md:
--------------------------------------------------------------------------------
1 | # Key-to-Door Environments
2 | source code: https://github.com/deepmind/deepmind-research/tree/master/tvt/pycolab
3 |
4 | Used in https://arxiv.org/pdf/2307.03864.pdf
5 |
--------------------------------------------------------------------------------
/envs/memory_envs/make_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym.wrappers import RescaleAction
3 |
4 |
5 | def make_env(
6 | env_name: str,
7 | seed: int,
8 | ) -> gym.Env:
9 | # Check if the env is in gym.
10 | env = gym.make(env_name)
11 |
12 | env.max_episode_steps = getattr(
13 | env, "max_episode_steps", env.spec.max_episode_steps
14 | )
15 |
16 | if isinstance(env.action_space, gym.spaces.Box):
17 | env = RescaleAction(env, -1.0, 1.0)
18 |
19 | env.seed(seed)
20 | env.action_space.seed(seed)
21 | env.observation_space.seed(seed)
22 | return env
23 |
--------------------------------------------------------------------------------
/envs/memory_envs/readme.md:
--------------------------------------------------------------------------------
1 | # Key-to-Door Environments
2 | Used in https://arxiv.org/pdf/2307.03864.pdf
3 |
4 | 1. Passive T-Maze (pure long memory) shows "Transformer-based agent can solve long-term memory low-dimensional tasks with
5 | good sample efficiency."
6 | 2. Passive Visual Match shows: "Transformer-based agent is more sample-efficient in long-term memory high-dimensional tasks."
7 | 3. `Key to Door` and `Active T-Maze` (credit assignment) show: "Transformer-based RL improves temporal credit assignment compared to LSTM-based RL, but its advantage diminishes in long-term scenarios"
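8 | 
9 | A minimal sketch of building one of these environments with the helpers in `configs/` and `make_env.py`; the corridor length 50 is illustrative:
10 | 
11 | ```python
12 | from envs.memory_envs.configs.tmaze_passive import create_fn
13 | from envs.memory_envs.make_env import make_env
14 | 
15 | env_name = create_fn(50)          # registers and returns "Mem-T-passive-50-v0"
16 | env = make_env(env_name, seed=0)  # gym.make + action rescaling + seeding
17 | obs = env.reset()
18 | ```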
--------------------------------------------------------------------------------
/envs/meta/example_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 |
4 | class ExampleEnv(gym.Env):
5 | def __init__(self):
6 | super(ExampleEnv, self).__init__()
7 |
8 | def get_task(self):
9 | """
10 | Return a task description, such as goal position or target velocity.
11 | """
12 | pass
13 |
14 | def set_goal(self, goal):
15 | """
16 | Sets goal manually. Mainly used for reward relabelling.
17 | """
18 | pass
19 |
20 | def reset_task(self, task=None):
21 | """
22 | Reset the task, either at random (if task=None) or the given task.
23 | """
24 | pass
25 |
26 | def step(self, action):
27 | """
28 | Execute one step in the environment.
29 | Should return: state, reward, done, info
30 | where info has to include a field 'task'.
31 | """
32 | pass
33 |
34 | def reward(self, state, action):
35 | """
36 | Computes reward function of task.
37 | Returns the reward
38 | """
39 | pass
40 |
41 | def reset(self):
42 | """
43 | Reset the environment. This should *NOT* reset the task!
44 | Resetting the task is handled in the varibad wrapper (see wrappers.py).
45 | """
46 | pass
47 |
--------------------------------------------------------------------------------
/envs/meta/make_env.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | from .wrappers import VariBadWrapper
4 |
5 | # In VariBAD, they use on-policy PPO by vectorized env.
6 | # In BOReL, they use off-policy SAC by single env.
7 | import numpy as np
8 | import contextlib
9 | import random
10 |
11 | @contextlib.contextmanager
12 | def fixed_seed(seed):
13 | """上下文管理器,用于同时固定random和numpy.random的种子"""
14 | state_np = np.random.get_state()
15 | state_random = random.getstate()
16 | np.random.seed(seed)
17 | random.seed(seed)
18 | try:
19 | yield
20 | finally:
21 | np.random.set_state(state_np)
22 | random.setstate(state_random)
23 |
24 | def make_env(env_id, episodes_per_task, seed=None, oracle=False, **kwargs):
25 | """
26 | kwargs: include n_tasks=num_tasks
27 | """
28 | with fixed_seed(seed):
29 | env = gym.make(env_id, **kwargs)
30 | if seed is not None:
31 | env.seed(seed)
32 | env.action_space.np_random.seed(seed)
33 | env = VariBadWrapper(
34 | env=env,
35 | episodes_per_task=episodes_per_task,
36 | oracle=oracle,
37 | )
38 | return env
39 |
--------------------------------------------------------------------------------
/envs/meta/mujoco/ant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .mujoco_env import MujocoEnv
4 |
5 |
6 | class AntEnv(MujocoEnv):
7 | def __init__(self, use_low_gear_ratio=False):
8 | # self.init_serialization(locals())
9 | if use_low_gear_ratio:
10 | xml_path = "low_gear_ratio_ant.xml"
11 | else:
12 | xml_path = "ant.xml"
13 | super().__init__(
14 | xml_path,
15 | frame_skip=5,
16 | automatically_set_obs_and_action_space=True,
17 | )
18 |
19 | def step(self, a):
20 | torso_xyz_before = self.get_body_com("torso")
21 | self.do_simulation(a, self.frame_skip)
22 | torso_xyz_after = self.get_body_com("torso")
23 | torso_velocity = torso_xyz_after - torso_xyz_before
24 | forward_reward = torso_velocity[0] / self.dt
25 | ctrl_cost = 0.0 # .5 * np.square(a).sum()
26 | contact_cost = (
27 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
28 | )
29 | survive_reward = 0.0 # 1.0
30 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
31 | state = self.state_vector()
32 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
33 | done = not notdone
34 | ob = self._get_obs()
35 | return (
36 | ob,
37 | reward,
38 | done,
39 | dict(
40 | reward_forward=forward_reward,
41 | reward_ctrl=-ctrl_cost,
42 | reward_contact=-contact_cost,
43 | reward_survive=survive_reward,
44 | torso_velocity=torso_velocity,
45 | ),
46 | )
47 |
48 | def _get_obs(self):
49 | # this is gym ant obs, should use rllab?
50 | # if position is needed, override this in subclasses
51 | return np.concatenate(
52 | [
53 | self.sim.data.qpos.flat[2:],
54 | self.sim.data.qvel.flat,
55 | ]
56 | )
57 |
58 | def reset_model(self):
59 | qpos = self.init_qpos + self.np_random.uniform(
60 | size=self.model.nq, low=-0.1, high=0.1
61 | )
62 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
63 | self.set_state(qpos, qvel)
64 | return self._get_obs()
65 |
66 | def viewer_setup(self):
67 | self.viewer.cam.distance = self.model.stat.extent * 0.5
68 |
--------------------------------------------------------------------------------
/envs/meta/mujoco/ant_dir.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .ant_multitask_base import MultitaskAntEnv
4 |
5 |
6 | class AntDirEnv(MultitaskAntEnv):
7 | """
8 | AntDir: forward_backward=True (unlimited tasks) from on-policy varibad code
9 | AntDir2D: forward_backward=False (limited tasks) from off-policy varibad code
10 | """
11 |
12 | def __init__(
13 | self,
14 | task={},
15 | n_tasks=None,
16 | max_episode_steps=200,
17 | forward_backward=True,
18 | **kwargs
19 | ):
20 | self.forward_backward = forward_backward
21 | self._max_episode_steps = max_episode_steps
22 |
23 | super(AntDirEnv, self).__init__(task, n_tasks, **kwargs)
24 |
25 | def step(self, action):
26 | torso_xyz_before = np.array(self.get_body_com("torso"))
27 |
28 | direct = (np.cos(self._goal), np.sin(self._goal))
29 |
30 | self.do_simulation(action, self.frame_skip)
31 | torso_xyz_after = np.array(self.get_body_com("torso"))
32 | torso_velocity = torso_xyz_after - torso_xyz_before
33 | forward_reward = np.dot((torso_velocity[:2] / self.dt), direct)
34 |
35 | ctrl_cost = 0.5 * np.square(action).sum()
36 | contact_cost = (
37 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
38 | )
39 | survive_reward = 1.0
40 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
41 | state = self.state_vector()
42 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
43 | done = not notdone
44 | ob = self._get_obs()
45 | return (
46 | ob,
47 | reward,
48 | done,
49 | dict(
50 | reward_forward=forward_reward,
51 | reward_ctrl=-ctrl_cost,
52 | reward_contact=-contact_cost,
53 | reward_survive=survive_reward,
54 | torso_velocity=torso_velocity,
55 | ),
56 | )
57 |
58 | def sample_tasks(self, num_tasks: int):
59 | assert self.forward_backward == False
60 | velocities = np.random.uniform(0.0, 2.0 * np.pi, size=(num_tasks,))
61 | tasks = [{"goal": velocity} for velocity in velocities]
62 | return tasks
63 |
64 | def _sample_raw_task(self):
65 | assert self.forward_backward == True
66 | velocity = np.random.choice([-1.0, 1.0]) # not 180 degree
67 | task = {"goal": velocity}
68 | return task
69 |
--------------------------------------------------------------------------------
/envs/meta/mujoco/ant_goal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .ant_multitask_base import MultitaskAntEnv
4 |
5 |
6 | class AntGoalEnv(MultitaskAntEnv):
7 | def __init__(self, task={}, n_tasks=2, max_episode_steps=200, **kwargs):
8 | super(AntGoalEnv, self).__init__(task, n_tasks, **kwargs)
9 | self._max_episode_steps = max_episode_steps
10 |
11 | def step(self, action):
12 | self.do_simulation(action, self.frame_skip)
13 | xposafter = np.array(self.get_body_com("torso"))
14 |
15 | goal_reward = -np.sum(
16 | np.abs(xposafter[:2] - self._goal)
17 | ) # make it happy, not suicidal
18 |
19 | ctrl_cost = 0.1 * np.square(action).sum()
20 | contact_cost = (
21 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
22 | )
23 | survive_reward = 0.0
24 | reward = goal_reward - ctrl_cost - contact_cost + survive_reward
25 | state = self.state_vector()
26 | done = False
27 | ob = self._get_obs()
28 | return (
29 | ob,
30 | reward,
31 | done,
32 | dict(
33 | goal_forward=goal_reward,
34 | reward_ctrl=-ctrl_cost,
35 | reward_contact=-contact_cost,
36 | reward_survive=survive_reward,
37 | ),
38 | )
39 |
40 | def sample_tasks(self, num_tasks):
41 | a = np.random.random(num_tasks) * 2 * np.pi
42 | r = 3 * np.random.random(num_tasks) ** 0.5
43 | goals = np.stack((r * np.cos(a), r * np.sin(a)), axis=-1)
44 | tasks = [{"goal": goal} for goal in goals]
45 | return tasks
46 |
47 | def _get_obs(self):
48 | return np.concatenate(
49 | [
50 | self.sim.data.qpos.flat,
51 | self.sim.data.qvel.flat,
52 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
53 | ]
54 | )
55 |
--------------------------------------------------------------------------------
/envs/meta/mujoco/ant_multitask_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .ant import AntEnv
4 |
5 | # from gym.envs.mujoco.ant import AntEnv
6 |
7 |
8 | class MultitaskAntEnv(AntEnv):
9 | def __init__(self, task={}, n_tasks=2, **kwargs):
10 | self._task = task
11 | self.n_tasks = n_tasks
12 | if n_tasks is None:
13 | self._goal = self._sample_raw_task()["goal"]
14 | else:
15 | self.tasks = self.sample_tasks(n_tasks)
16 | self._goal = self.tasks[0]["goal"]
17 | super(MultitaskAntEnv, self).__init__()
18 |
19 | def get_current_task(self):
20 | # for multi-task MDP
21 | return np.array([self._goal])
22 |
23 | def get_all_task_idx(self):
24 | return range(len(self.tasks))
25 |
26 | def reset_task(self, task_info):
27 | if self.n_tasks is None: # unlimited tasks
28 | assert task_info is None
29 | self._task = self._sample_raw_task() # sample here
30 | else: # limited tasks
31 | self._task = self.tasks[task_info] # as idx
32 | self._goal = self._task[
33 | "goal"
34 | ] # assume parameterization of task by single vector
35 | self.reset()
36 |
--------------------------------------------------------------------------------
/envs/meta/mujoco/core/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | General classes, functions, utilities that are used throughout rlkit.
3 | """
4 |
--------------------------------------------------------------------------------
/envs/meta/mujoco/core/serializable.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on rllab's serializable.py file
3 |
4 | https://github.com/rll/rllab
5 | """
6 |
7 | import inspect
8 | import sys
9 |
10 |
11 | class Serializable(object):
12 | def __init__(self, *args, **kwargs):
13 | self.__args = args
14 | self.__kwargs = kwargs
15 |
16 | def quick_init(self, locals_):
17 | if getattr(self, "_serializable_initialized", False):
18 | return
19 | if sys.version_info >= (3, 0):
20 | spec = inspect.getfullargspec(self.__init__)
21 | # Exclude the first "self" parameter
22 | if spec.varkw:
23 | kwargs = locals_[spec.varkw].copy()
24 | else:
25 | kwargs = dict()
26 | if spec.kwonlyargs:
27 | for key in spec.kwonlyargs:
28 | kwargs[key] = locals_[key]
29 | else:
30 | spec = inspect.getargspec(self.__init__)
31 | if spec.keywords:
32 | kwargs = locals_[spec.keywords]
33 | else:
34 | kwargs = dict()
35 | if spec.varargs:
36 | varargs = locals_[spec.varargs]
37 | else:
38 | varargs = tuple()
39 | in_order_args = [locals_[arg] for arg in spec.args][1:]
40 | self.__args = tuple(in_order_args) + varargs
41 | self.__kwargs = kwargs
42 | setattr(self, "_serializable_initialized", True)
43 |
44 | def __getstate__(self):
45 | return {"__args": self.__args, "__kwargs": self.__kwargs}
46 |
47 | def __setstate__(self, d):
48 | # convert all __args to keyword-based arguments
49 | if sys.version_info >= (3, 0):
50 | spec = inspect.getfullargspec(self.__init__)
51 | else:
52 | spec = inspect.getargspec(self.__init__)
53 | in_order_args = spec.args[1:]
54 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"]))
55 | self.__dict__.update(out.__dict__)
56 |
57 | @classmethod
58 | def clone(cls, obj, **kwargs):
59 | assert isinstance(obj, Serializable)
60 | d = obj.__getstate__()
61 | d["__kwargs"] = dict(d["__kwargs"], **kwargs)
62 | out = type(obj).__new__(type(obj))
63 | out.__setstate__(d)
64 | return out
65 |
--------------------------------------------------------------------------------
/envs/meta/mujoco/mujoco_env.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | import mujoco_py
5 | import numpy as np
6 | from gym.envs.mujoco import mujoco_env
7 |
8 | from .core.serializable import Serializable
9 |
10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets")
11 |
12 |
13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable):
14 | """
15 | My own wrapper around MujocoEnv.
16 |
17 | The caller needs to declare
18 | """
19 |
20 | def __init__(
21 | self,
22 | model_path,
23 | frame_skip=1,
24 | model_path_is_local=True,
25 | automatically_set_obs_and_action_space=False,
26 | ):
27 | if model_path_is_local:
28 | model_path = get_asset_xml(model_path)
29 | if automatically_set_obs_and_action_space:
30 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip)
31 | else:
32 | """
33 | Code below is copy/pasted from MujocoEnv's __init__ function.
34 | """
35 | if model_path.startswith("/"):
36 | fullpath = model_path
37 | else:
38 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
39 | if not path.exists(fullpath):
40 | raise IOError("File %s does not exist" % fullpath)
41 | self.frame_skip = frame_skip
42 | self.model = mujoco_py.MjModel(fullpath)
43 | self.data = self.model.data
44 | self.viewer = None
45 |
46 | self.metadata = {
47 | "render.modes": ["human", "rgb_array"],
48 | "video.frames_per_second": int(np.round(1.0 / self.dt)),
49 | }
50 |
51 | self.init_qpos = self.model.data.qpos.ravel().copy()
52 | self.init_qvel = self.model.data.qvel.ravel().copy()
53 | self._seed()
54 |
55 | def init_serialization(self, locals):
56 | Serializable.quick_init(self, locals)
57 |
58 | def log_diagnostics(self, paths):
59 | pass
60 |
61 |
62 | def get_asset_xml(xml_name):
63 | return os.path.join(ENV_ASSET_DIR, xml_name)
64 |
--------------------------------------------------------------------------------
/envs/meta/readme.md:
--------------------------------------------------------------------------------
1 | # Meta RL Environments
2 | Based on the code https://github.com/Rondorf/BOReL and https://github.com/lmzintgraf/varibad
3 |
--------------------------------------------------------------------------------
/envs/pomdp/readme.md:
--------------------------------------------------------------------------------
1 | # "Standard" POMDP Environments
2 | Based on the code https://github.com/oist-cnru/Variational-Recurrent-Models
3 |
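4 | A minimal usage sketch of the `POMDPWrapper` defined in `wrappers.py`; the wrapped env and the occluded dimensions below are illustrative:
5 | 
6 | ```python
7 | import gym
8 | from envs.pomdp.wrappers import POMDPWrapper
9 | 
10 | # Expose only the angle features (dims 0, 1) of Pendulum-v1 and hide the angular velocity.
11 | env = POMDPWrapper(gym.make("Pendulum-v1"), partially_obs_dims=[0, 1])
12 | obs = env.reset()  # returns only the observable dims
13 | obs, rew, done, info = env.step(env.action_space.sample())
14 | ```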
--------------------------------------------------------------------------------
/envs/pomdp/wrappers.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import spaces
3 | import numpy as np
4 |
5 |
6 | class POMDPWrapper(gym.Wrapper):
7 | def __init__(self, env, partially_obs_dims: list):
8 | super().__init__(env)
9 | self.partially_obs_dims = partially_obs_dims
10 | # can equal to the fully-observed env
11 | assert 0 < len(self.partially_obs_dims) <= self.observation_space.shape[0]
12 |
13 | self.observation_space = spaces.Box(
14 | low=self.observation_space.low[self.partially_obs_dims],
15 | high=self.observation_space.high[self.partially_obs_dims],
16 | dtype=np.float32,
17 | )
18 |
19 | if self.env.action_space.__class__.__name__ == "Box":
20 | self.act_continuous = True
21 | # if continuous actions, make sure in [-1, 1]
22 | # NOTE: policy won't use action_space.low/high, just set [-1,1]
23 | # this is a bad practice...
24 | else:
25 | self.act_continuous = False
26 | self.true_state = None
27 |
28 | def get_obs(self, state):
29 | return state[self.partially_obs_dims].copy()
30 |
31 | def get_unobservable(self):
32 | unobserved_dims = [i for i in range(self.true_state.shape[0]) if i not in self.partially_obs_dims]
33 | return self.true_state[unobserved_dims].copy()
34 |
35 | def reset(self):
36 | state = self.env.reset() # no kwargs
37 | self.true_state = state
38 | return self.get_obs(state)
39 |
40 | def step(self, action):
41 | if self.act_continuous:
42 | # recover the action
43 | action = np.clip(action, -1, 1) # first clip into [-1, 1]
44 | lb = self.env.action_space.low
45 | ub = self.env.action_space.high
46 | action = lb + (action + 1.0) * 0.5 * (ub - lb)
47 | action = np.clip(action, lb, ub)
48 |
49 | state, reward, done, info = self.env.step(action)
50 | self.true_state = state
51 | return self.get_obs(state), reward, done, info
52 |
53 |
54 | if __name__ == "__main__":
55 | import envs
56 |
57 | env = gym.make("HopperBLT-F-v0")
58 | obs = env.reset()
59 | done = False
60 | step = 0
61 | while not done:
62 | next_obs, rew, done, info = env.step(env.action_space.sample())
63 | step += 1
64 | print(step, done, info)
65 |
--------------------------------------------------------------------------------
/envs/readme.md:
--------------------------------------------------------------------------------
1 | # POMDP Environments
2 | The POMDP environments are mainly adapted from [POMDP](https://github.com/twni2016/pomdp-baselines) and [ESCP](https://github.com/FanmingL/ESCP)
3 | ## Overview
4 | - `meta/`: Meta RL environments
5 | - `pomdp/`: "standard" POMDP environments
6 | - `rl_generalization/`: Generalization in RL and Robust RL environments
7 |
8 | ## Normalized Action Space
9 | We make sure every environment exposes a continuous action space of [-1, 1]^|A| to the policy. The policy should not use `self.action_space.high` or `self.action_space.low`.
10 | 
11 | In Meta RL and "standard" POMDP environments, we use the following snippet to normalize the action space:
12 | ```python
13 | class EnvWrapper(gym.Wrapper):
14 | def step(self, action):
15 | action = np.clip(action, -1, 1) # first clip into [-1, 1]
16 | lb = self.env.action_space.low
17 | ub = self.env.action_space.high
18 | action = lb + (action + 1.) * 0.5 * (ub - lb) # recover the original action space
19 | action = np.clip(action, lb, ub)
20 | ...
21 | ```
22 |
23 | ## Reproducibility Issue in Gym Environments
24 | In the current gym version (v0.21), we need to seed both the env and its action space to ensure reproducibility:
25 | ```python
26 | env.seed(seed)
27 | env.action_space.np_random.seed(seed)
28 | ```
29 | However, I have not figured out how to do this for the old gym version (v0.10) used by the SunBlaze envs. Please open a pull request if you know how.
30 |
--------------------------------------------------------------------------------
/envs/rl_generalization/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | *.pdf
4 | *.png
5 | *.swp
6 | *.swo
7 | __pycache__
8 | videos
9 |
--------------------------------------------------------------------------------
/envs/rl_generalization/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 UC Berkeley, Intel Labs, and other contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/envs/rl_generalization/README.md:
--------------------------------------------------------------------------------
1 | # POMDP Environments for Evaluating Generalization and Robustness in RL
2 | Based on the code https://github.com/sunblaze-ucb/rl-generalization
3 |
4 | ## Using the Generalization Environments
5 | From original SunBlaze repo,
6 |
7 | ```python
8 | import gym
9 | import sunblaze_envs
10 |
11 | # Deterministic: the default version with fixed parameters
12 | fixed_env = sunblaze_envs.make('SunblazeCartPole-v0')
13 |
14 | # Random: parameters are sampled from a range nearby the default settings
15 | random_env = sunblaze_envs.make('SunblazeCartPoleRandomNormal-v0')
16 |
17 | # Extreme: parameters are sampled from an `extreme' range
18 | extreme_env = sunblaze_envs.make('SunblazeCartPoleRandomExtreme-v0')
19 | ```
20 | In the case of CartPole, RandomNormal and RandomExtreme will vary the strength of each action, the mass of the pole, and the length of the pole.
21 |
22 | Specific ranges for each environment setting are listed [here](sunblaze_envs#environment-details). See the code in [examples](/examples) for usage with example algorithms from OpenAI Baselines.
23 |
24 | ## Using the Robust RL Environments
25 | Similarly, we adopt the environments from [MRPO paper](https://proceedings.mlr.press/v139/jiang21c.html), for example
26 |
27 | ```python
28 | import gym
29 | import sunblaze_envs
30 |
31 | MRPO_walker_env = sunblaze_envs.make('MRPOWalker2dRandomNormal-v0')
32 |
33 | ```
34 |
--------------------------------------------------------------------------------
/envs/rl_generalization/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup, find_packages
3 | import sys
4 |
5 |
6 | if sys.version_info.major != 3:
7 | print(
8 | "This Python is only compatible with Python 3, but you are running "
9 | "Python {}. The installation will likely fail.".format(sys.version_info.major)
10 | )
11 |
12 | package_name = "sunblaze_envs"
13 | authors = ["UC Berkeley", "Intel Labs", "and other contributors"]
14 | url = "https://github.com/sunblaze-ucb/rl-generalization"
15 | description = "Modifiable OpenAI Gym environments for studying generalization in RL"
16 |
17 | setup_py_dir = os.path.dirname(os.path.realpath(__file__))
18 | package_dir = os.path.join(setup_py_dir, package_name)
19 |
20 | # Discover asset files.
21 | ASSET_EXTENSIONS = ["png", "wad", "txt"]
22 | assets = []
23 | for root, dirs, files in os.walk(package_dir):
24 | for filename in files:
25 | extension = os.path.splitext(filename)[1][1:]
26 | if extension and extension in ASSET_EXTENSIONS:
27 | filename = os.path.join(root, filename)
28 | assets.append(filename[1 + len(package_dir) :])
29 |
30 | setup(
31 | name=package_name,
32 | version="0.1.0",
33 | description=description,
34 | author=", ".join(authors),
35 | # maintainer_email="",
36 | url=url,
37 | packages=find_packages(exclude=("examples",)),
38 | package_data={"": assets},
39 | dependency_links=(
40 | # "git+https://github.com/kostko/omgifol.git@master#egg=omgifol-0.1.0",
41 | "git+https://github.com/openai/gym.git@094e6b8e6a102644667d53d9dac6f2245bf80c6f#egg=gym-0.10.8r1",
42 | "git+https://github.com/openai/baselines.git@2b0283b9db18c768f8e9fa29fbedc5e48499acc6#egg=baselines-0.1.5r1",
43 | ),
44 | install_requires=[
45 | "gym==0.10.8r1",
46 | #'gym==0.10.5',
47 | "Box2D==2.3.2",
48 | "cocos2d==0.6.5",
49 | "numpy==1.14.2",
50 | "scipy==1.0.0",
51 | #'vizdoom==1.1.2',
52 | #'omgifol>=0.1.0',
53 | ],
54 | extras_require={
55 | "examples": [
56 | "baselines==0.1.5r1", # use dep_link for specific commit
57 | "PyYAML==3.12",
58 | "opencv-python==3.4.0.12",
59 | "cloudpickle==0.4.1",
60 | "natsort==5.1.0",
61 | "chainer==3.3.0",
62 | "chainerrl==0.3.0",
63 | ],
64 | },
65 | )
66 |
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/README.md:
--------------------------------------------------------------------------------
1 | ## Overview of Generalization Environments
2 |
3 | There are six environments, built on top of the corresponding OpenAI Gym and Roboschool implementations:
4 | * CartPole
5 | * MountainCar
6 | * Acrobot
7 | * Pendulum
8 | * HalfCheetah
9 | * Hopper
10 |
11 | Each has three versions:
12 | * **D**: Environment parameters are set to the default values in Gym and Roboschool. Access by Sunblaze*Environment*-v0, e.g. SunblazeCartPole-v0.
13 | * **R**: Environment parameters are randomly sampled from intervals containing their default values. Access by Sunblaze*Environment*RandomNormal-v0.
14 | * **E**: Environment parameters are randomly sampled from intervals outside those in **R**, containing more extreme values. Access by Sunblaze*Environment*RandomExtreme-v0.
15 |
16 | ## Environment details
17 |
18 | Ranges of parameters for each version of each environment, using set notation.
19 |
20 | | Environment | Parameter | D | R | E |
21 | | --- | --- | --- | --- | --- |
22 | | CartPole | Force | 10 | [5,15] | [1,5] U [15,20] |
23 | | | Length | 0.5 | [0.25,0.75] | [0.05,0.25] U [0.75,1.0] |
24 | | | Mass | 0.1 | [0.05,0.5] | [0.01,0.05] U [0.5,1.0] |
25 | | MountainCar | Force | 0.001 | [0.0005,0.005] | [0.0001,0.0005] U [0.005,0.01] |
26 | | | Mass | 0.0025 | [0.001,0.005] | [0.0005,0.001] U [0.005,0.01] |
27 | | Acrobot | Length | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] |
28 | | | Mass | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] |
29 | | | MOI | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] |
30 | | Pendulum | Length | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] |
31 | | | Mass | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] |
32 | | HalfCheetah | Power | 0.90 | [0.70,1.10] | [0.50,0.70] U [1.10,1.30] |
33 | | | Density | 1000 | [750,1250] | [500,750] U [1250,1500] |
34 | | | Friction | 0.8 | [0.5,1.1] | [0.2,0.5] U [1.1,1.4] |
35 | | Hopper | Power | 0.75 | [0.60,0.90] | [0.40,0.60] U [0.90,1.10] |
36 | | | Density | 1000 | [750,1250] | [500,750] U [1250,1500] |
37 | | | Friction | 0.8 | [0.5,1.1] | [0.2,0.5] U [1.1,1.4] |
38 |
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/.gitignore:
--------------------------------------------------------------------------------
1 | *.bak
2 |
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic.wad
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_floor_ceiling_flipped.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_floor_ceiling_flipped.wad
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_torches.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_torches.wad
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation.wad
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_floor_ceiling_flipped.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_floor_ceiling_flipped.wad
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_new_layout.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_new_layout.wad
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_torches.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_torches.wad
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/texture_set_b.txt:
--------------------------------------------------------------------------------
1 | TANROCK2
2 | TANROCK3
3 | TANROCK4
4 | TANROCK5
5 | TANROCK7
6 | TANROCK8
7 | TEKBRON1
8 | TEKBRON2
9 | TEKGREN1
10 | TEKGREN2
11 | TEKGREN3
12 | TEKGREN4
13 | TEKGREN5
14 | TEKLITE
15 | TEKLITE2
16 | TEKWALL1
17 | TEKWALL2
18 | TEKWALL3
19 | TEKWALL4
20 | TEKWALL5
21 | TEKWALL6
22 | WOOD1
23 | WOOD10
24 | WOOD12
25 | WOOD3
26 | WOOD4
27 | WOOD5
28 | WOOD6
29 | WOOD7
30 | WOOD8
31 | WOOD9
32 | WOODGARG
33 | WOODMET1
34 | WOODMET2
35 | WOODMET3
36 | WOODMET4
37 | WOODSKUL
38 | WOODVERT
39 | ZDOORB1
40 | ZDOORF1
41 | ZELDOOR
42 | ZIMMER1
43 | ZIMMER2
44 | ZIMMER3
45 | ZIMMER4
46 | ZIMMER5
47 | ZIMMER7
48 | ZIMMER8
49 | ZZWOLF1
50 | ZZWOLF10
51 | ZZWOLF11
52 | ZZWOLF12
53 | ZZWOLF13
54 | ZZWOLF2
55 | ZZWOLF3
56 | ZZWOLF4
57 | ZZWOLF5
58 | ZZWOLF6
59 | ZZWOLF7
60 | ZZWOLF9
61 | ZZZFACE1
62 | ZZZFACE2
63 | ZZZFACE3
64 | ZZZFACE4
65 | ZZZFACE5
66 | ZZZFACE6
67 | ZZZFACE7
68 | ZZZFACE8
69 | ZZZFACE9
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/thing_set_a.txt:
--------------------------------------------------------------------------------
1 | 5
2 | 6
3 | 13
4 | 30
5 | 31
6 | 34
7 | 35
8 | 38
9 | 39
10 | 40
11 | 44
12 | 45
13 | 46
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/assets/vizdoom/thing_set_b.txt:
--------------------------------------------------------------------------------
1 | 32
2 | 48
3 | 55
4 | 56
5 | 57
6 | 2028
7 | 5050
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import gym
4 |
5 |
6 | class BaseGymEnvironment(gym.Env):
7 | """Base class for all Gym environments."""
8 |
9 | @property
10 | def parameters(self):
11 | """Return environment parameters."""
12 | return {
13 | "id": self.spec.id,
14 | }
15 |
16 |
17 | class EnvBinarySuccessMixin(ABC):
18 | """Adds binary success metric to environment."""
19 |
20 | @abstractmethod
21 | def is_success(self):
22 | """Returns True is current state indicates success, False otherwise"""
23 | pass
24 |
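A hypothetical sketch of how these two bases are meant to be combined; `ToyReacher`, its spaces, and the success threshold are illustrative only, not part of the repo. It uses `BaseGymEnvironment` and `EnvBinarySuccessMixin` as defined above:

```python
import gym
import numpy as np

class ToyReacher(BaseGymEnvironment, EnvBinarySuccessMixin):
    """Toy 1-D reaching task reporting parameters and a binary success signal."""

    def __init__(self, goal=1.0):
        self.goal = goal
        self.pos = 0.0
        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,), dtype=np.float32)
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(1,), dtype=np.float32)

    @property
    def parameters(self):
        # Environment-specific parameters; a registered env would also report its spec id.
        return {"goal": self.goal}

    def reset(self):
        self.pos = 0.0
        return np.array([self.pos], dtype=np.float32)

    def step(self, action):
        self.pos += float(action[0])
        reward = -abs(self.goal - self.pos)
        return np.array([self.pos], dtype=np.float32), reward, self.is_success(), {}

    def is_success(self):
        # Binary success: within 5 cm of the goal.
        return abs(self.goal - self.pos) < 0.05
```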
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/monitor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import gym
5 |
6 |
7 | class MonitorParameters(gym.Wrapper):
8 | """Environment wrapper which records all environment parameters."""
9 |
10 | current_parameters = None
11 |
12 | def __init__(self, env, output_filename):
13 | """
14 | Construct parameter monitor wrapper.
15 |
16 | :param env: Wrapped environment
17 | :param output_filename: Output log filename
18 | """
19 | self._output_filename = output_filename
20 | with open(output_filename, "w"):
21 | # Truncate output file.
22 | pass
23 |
24 | super(MonitorParameters, self).__init__(env)
25 |
26 | def step(self, action):
27 | result = self.env.step(action)
28 | self.record_parameters()
29 | return result
30 |
31 | def reset(self):
32 | result = self.env.reset()
33 | self.record_parameters()
34 | return result
35 |
36 | def record_parameters(self):
37 | """Record current environment parameters."""
38 | if not hasattr(self.env.unwrapped, "parameters"):
39 | return
40 | if self.env.unwrapped.parameters == self.current_parameters:
41 | return
42 |
43 | # Record parameter set in output file.
44 | self.current_parameters = self.env.unwrapped.parameters
45 | with open(self._output_filename, "a") as output_file:
46 | output_file.write(json.dumps(self.current_parameters))
47 | output_file.write("\n")
48 |
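A usage sketch, assuming `MonitorParameters` is importable from `sunblaze_envs.monitor` and that the wrapped environment exposes `unwrapped.parameters` (as the Sunblaze environments above do):

```python
import json
import sunblaze_envs
from sunblaze_envs.monitor import MonitorParameters  # import path assumed

env = MonitorParameters(
    sunblaze_envs.make("SunblazeCartPoleRandomNormal-v0"),
    output_filename="params.jsonl",
)
env.reset()
for _ in range(5):
    env.step(env.action_space.sample())

# One JSON line is appended each time the recorded parameter set changes.
with open("params.jsonl") as f:
    for line in f:
        print(json.loads(line))
```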
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/time_limit.py:
--------------------------------------------------------------------------------
1 | from gym.wrappers.time_limit import TimeLimit as TimeLimitBase
2 | import time
3 |
4 |
5 | class TimeLimit(TimeLimitBase):
6 | """Updated to support reset() with reset_params flag for Adaptive"""
7 |
8 | def reset(self, reset_params=True):
9 | self._episode_started_at = time.time()
10 | self._elapsed_steps = 0
11 | return self.env.reset(reset_params)
12 |
--------------------------------------------------------------------------------
/envs/rl_generalization/sunblaze_envs/wrappers.py:
--------------------------------------------------------------------------------
1 | import collections
2 |
3 | import gym
4 | from gym.envs.registration import load
5 | import numpy as np
6 |
7 |
8 | def ActionDelayWrapper(delay_range_start, delay_range_end):
9 | """Create an action delay wrapper.
10 |
11 | :param delay_range_start: Minimum delay
12 | :param delay_range_end: Maximum delay
13 | """
14 |
15 | class ActionDelayWrapper(gym.Wrapper):
16 | def _step(self, action):
17 | self._action_buffer.append(action)
18 | action = self._action_buffer.popleft()
19 | return self.env.step(action)
20 |
21 | def _reset(self):
22 | self._action_delay = np.random.randint(delay_range_start, delay_range_end)
23 | self._action_buffer = collections.deque(
24 | [0 for _ in range(self._action_delay)]
25 | )
26 | return self.env.reset()
27 |
28 | return ActionDelayWrapper
29 |
30 |
31 | def wrap_environment(wrapped_class, wrappers=None, **kwargs):
32 | """Helper for wrapping environment classes."""
33 | if wrappers is None:
34 | wrappers = []
35 |
36 | env_class = load(wrapped_class)
37 | env = env_class(**kwargs)
38 | for wrapper, wrapper_kwargs in wrappers:
39 | wrapper_class = load(wrapper)
40 | wrapper = wrapper_class(**wrapper_kwargs)
41 | env = wrapper(env)
42 |
43 | return env
44 |
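A composition sketch: `ActionDelayWrapper` is a factory that returns a wrapper class, and `wrap_environment` instantiates both the environment and the wrappers from entry-point strings via gym's `load`. The entry-point strings below are assumptions and must match wherever these classes actually live:

```python
env = wrap_environment(
    "gym.envs.classic_control:CartPoleEnv",               # wrapped_class (assumed path)
    wrappers=[
        (
            "sunblaze_envs.wrappers:ActionDelayWrapper",   # factory defined above (assumed path)
            {"delay_range_start": 1, "delay_range_end": 4},
        ),
    ],
)
obs = env.reset()
```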
--------------------------------------------------------------------------------
/envs/rl_generalization/test.py:
--------------------------------------------------------------------------------
1 | import sunblaze_envs
2 |
3 | env = sunblaze_envs.make("SunblazeHalfCheetahRandomNormal-v0") # (26, 6)
4 | # env = sunblaze_envs.make('SunblazeHopperRandomNormal-v0') # (15, 3)
5 | print(env.observation_space, env.action_space)
6 | print(env.action_space.high, env.action_space.low) # [-1,1]
7 | obs = env.reset()
8 | print(obs)
9 | print(env._max_episode_steps)
10 | done = False
11 | step = 0
12 | while not done:
13 | step += 1
14 | obs, rew, done, info = env.step(env.action_space.sample())
15 | # early failure is possible, i.e., done=True before the step limit
16 | # env.unwrapped.is_success() checks whether z >= 20 m
17 | print(step, obs, rew, done, info, env.unwrapped.is_success())
18 |
--------------------------------------------------------------------------------
/envs/torchkit/constant.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | LSTM_name = "lstm"
4 | GRU_name = "gru"
5 | RNNs = {
6 | LSTM_name: nn.LSTM,
7 | GRU_name: nn.GRU,
8 | }
9 |
--------------------------------------------------------------------------------
/envs/torchkit/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains some self-contained modules.
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class HuberLoss(nn.Module):
9 | def __init__(self, delta=1):
10 | super().__init__()
11 | self.huber_loss_delta1 = nn.SmoothL1Loss()
12 | self.delta = delta
13 |
14 | def forward(self, x, x_hat):
15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta)
16 | return loss * self.delta * self.delta
17 |
18 |
19 | class LayerNorm(nn.Module):
20 | """
21 | Simple 1D LayerNorm.
22 | """
23 |
24 | def __init__(self, features, center=True, scale=False, eps=1e-6):
25 | super().__init__()
26 | self.center = center
27 | self.scale = scale
28 | self.eps = eps
29 | if self.scale:
30 | self.scale_param = nn.Parameter(torch.ones(features))
31 | else:
32 | self.scale_param = None
33 | if self.center:
34 | self.center_param = nn.Parameter(torch.zeros(features))
35 | else:
36 | self.center_param = None
37 |
38 | def forward(self, x):
39 | mean = x.mean(-1, keepdim=True)
40 | std = x.std(-1, keepdim=True)
41 | output = (x - mean) / (std + self.eps)
42 | if self.scale:
43 | output = output * self.scale_param
44 | if self.center:
45 | output = output + self.center_param
46 | return output
47 |
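A quick sanity-check sketch of the two modules above:

```python
import torch

x = torch.randn(8, 16)
target = torch.randn(8, 16)

# Huber loss with a custom threshold: inputs are rescaled by delta before SmoothL1.
huber = HuberLoss(delta=2.0)
print(huber(x, target))

# 1-D LayerNorm over the last dimension, with learnable shift and scale enabled.
ln = LayerNorm(16, center=True, scale=True)
print(ln(x).shape)  # torch.Size([8, 16])
```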
--------------------------------------------------------------------------------
/envs/torchkit/policies_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Policy(object, metaclass=abc.ABCMeta):
5 | """
6 | General policy interface.
7 | """
8 |
9 | @abc.abstractmethod
10 | def get_action(self, observation):
11 | """
12 | :param observation:
13 | :return: action, debug_dictionary
14 | """
15 | pass
16 |
17 | def reset(self):
18 | pass
19 |
20 |
21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta):
22 | def set_num_steps_total(self, t):
23 | pass
24 |
25 |
26 | class SerializablePolicy(Policy, metaclass=abc.ABCMeta):
27 | """
28 | Policy that can be serialized.
29 | """
30 |
31 | def get_param_values(self):
32 | return None
33 |
34 | def set_param_values(self, values):
35 | pass
36 |
37 | """
38 | Parameters should be passed as np arrays in the two functions below.
39 | """
40 |
41 | def get_param_values_np(self):
42 | return None
43 |
44 | def set_param_values_np(self, values):
45 | pass
46 |
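A hypothetical minimal policy satisfying this interface (the class and action bounds are illustrative only, not part of the repo):

```python
import numpy as np

class UniformRandomPolicy(SerializablePolicy, ExplorationPolicy):
    """Samples actions uniformly in [-1, 1]; ignores the observation."""

    def __init__(self, action_dim):
        self.action_dim = action_dim

    def get_action(self, observation):
        # Policy.get_action returns (action, debug_dictionary).
        return np.random.uniform(-1.0, 1.0, size=self.action_dim), {}

policy = UniformRandomPolicy(action_dim=3)
action, info = policy.get_action(observation=np.zeros(5))
```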
--------------------------------------------------------------------------------
/envs/torchkit/serializable.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on rllab's serializable.py file
3 |
4 | https://github.com/rll/rllab
5 | """
6 |
7 | import inspect
8 | import sys
9 |
10 |
11 | class Serializable(object):
12 | def __init__(self, *args, **kwargs):
13 | self.__args = args
14 | self.__kwargs = kwargs
15 |
16 | def quick_init(self, locals_):
17 | if getattr(self, "_serializable_initialized", False):
18 | return
19 | if sys.version_info >= (3, 0):
20 | spec = inspect.getfullargspec(self.__init__)
21 | # Exclude the first "self" parameter
22 | if spec.varkw:
23 | kwargs = locals_[spec.varkw].copy()
24 | else:
25 | kwargs = dict()
26 | if spec.kwonlyargs:
27 | for key in spec.kwonlyargs:
28 | kwargs[key] = locals_[key]
29 | else:
30 | spec = inspect.getargspec(self.__init__)
31 | if spec.keywords:
32 | kwargs = locals_[spec.keywords]
33 | else:
34 | kwargs = dict()
35 | if spec.varargs:
36 | varargs = locals_[spec.varargs]
37 | else:
38 | varargs = tuple()
39 | in_order_args = [locals_[arg] for arg in spec.args][1:]
40 | self.__args = tuple(in_order_args) + varargs
41 | self.__kwargs = kwargs
42 | setattr(self, "_serializable_initialized", True)
43 |
44 | def __getstate__(self):
45 | return {"__args": self.__args, "__kwargs": self.__kwargs}
46 |
47 | def __setstate__(self, d):
48 | # convert all __args to keyword-based arguments
49 | if sys.version_info >= (3, 0):
50 | spec = inspect.getfullargspec(self.__init__)
51 | else:
52 | spec = inspect.getargspec(self.__init__)
53 | in_order_args = spec.args[1:]
54 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"]))
55 | self.__dict__.update(out.__dict__)
56 |
57 | @classmethod
58 | def clone(cls, obj, **kwargs):
59 | assert isinstance(obj, Serializable)
60 | d = obj.__getstate__()
61 | d["__kwargs"] = dict(d["__kwargs"], **kwargs)
62 | out = type(obj).__new__(type(obj))
63 | out.__setstate__(d)
64 | return out
65 |
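A sketch of the intended `quick_init(locals())` pattern; `GaussianNoise` is a made-up example class, not part of the repo:

```python
import pickle

class GaussianNoise(Serializable):
    def __init__(self, mean=0.0, std=1.0):
        # Capture the constructor arguments so the object can be pickled and cloned.
        Serializable.quick_init(self, locals())
        self.mean = mean
        self.std = std

noise = GaussianNoise(mean=0.5, std=2.0)
restored = pickle.loads(pickle.dumps(noise))  # rebuilt via __getstate__/__setstate__
copy = Serializable.clone(noise, std=3.0)     # same constructor args, std overridden
```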
--------------------------------------------------------------------------------
/envs/utils/system.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import torch
4 | import datetime
5 | import dateutil.tz
6 |
7 |
8 | def reproduce(seed):
9 | """
10 | This can only fix the randomness of numpy and torch.
11 | To fix the environment's randomness, please use
12 | env.seed(seed)
13 | env.action_space.np_random.seed(seed)
14 | We have added these in our training script.
15 | """
16 | np.random.seed(seed)
17 | random.seed(seed)
18 | torch.manual_seed(seed)
19 | if torch.cuda.is_available():
20 | torch.cuda.manual_seed(seed)
21 | torch.cuda.manual_seed_all(seed)
22 |
23 |
24 | def now_str():
25 | now = datetime.datetime.now(dateutil.tz.tzlocal())
26 | return now.strftime("%m-%d:%H-%M:%S.%f")[:-4]
27 |
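Typical use, as a brief sketch: call `reproduce` once at startup, before environments and networks are built.

```python
reproduce(seed=42)
print(now_str())  # e.g. "06-01:12-30:59.12" (month-day:hour-min:sec.fraction)
```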
--------------------------------------------------------------------------------
/envs/yang_domains/assets/box_1d.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
9 |
10 |
11 |
12 |
13 |
14 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
36 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/envs/yang_domains/assets/clockwise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/clockwise.png
--------------------------------------------------------------------------------
/envs/yang_domains/assets/grid_markers.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/narrow_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/narrow_finger.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/narrow_finger_rescaled.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/narrow_finger_rescaled.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/narrow_finger_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/narrow_finger_tip.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/wide_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wide_finger.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/wide_finger_rescaled.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wide_finger_rescaled.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/wide_finger_tip.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wide_finger_tip.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/wrist.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wrist.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah/wrist_rescaled.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wrist_rescaled.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah_asset.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah_body.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
16 |
18 |
20 |
22 |
23 |
24 |
25 |
26 |
28 |
30 |
32 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/envs/yang_domains/assets/gripah_contact.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_base.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_forearm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_forearm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_L.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_R.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_shoulder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_shoulder.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_upperarm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_upperarm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_wrist1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_wrist1.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_wrist2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_wrist2.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/c_wrist3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_wrist3.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/glass_cup.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/glass_cup.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/glass_cup_2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/glass_cup_2.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/glass_cup_3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/glass_cup_3.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/inner_finger_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_finger_coarse.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/inner_finger_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_finger_fine.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/inner_knuckle_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_knuckle_coarse.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/inner_knuckle_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_knuckle_fine.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/new_solo_cup.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/new_solo_cup.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/outer_finger_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_finger_coarse.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/outer_finger_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_finger_fine.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/outer_knuckle_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_knuckle_coarse.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/outer_knuckle_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_knuckle_fine.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/red_solo_cup.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/red_solo_cup.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_coarse.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_fine.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_fine.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/smaller_solo_cup.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/smaller_solo_cup.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/solo_cup.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/solo_cup.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/upd_solo_cup.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/upd_solo_cup.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_base.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_forearm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_forearm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_L.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_L.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_R.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_R.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_shoulder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_shoulder.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_upperarm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_upperarm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_wrist1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_wrist1.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_wrist2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_wrist2.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/muj_gripper/v_wrist3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_wrist3.stl
--------------------------------------------------------------------------------
/envs/yang_domains/assets/objects/bump_40_mujoco.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/bump_40_mujoco.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/objects/bump_50_mujoco.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/bump_50_mujoco.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/objects/bump_80_mujoco.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/bump_80_mujoco.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/objects/plate_half.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/plate_half.STL
--------------------------------------------------------------------------------
/envs/yang_domains/assets/objects/plate_whole.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/plate_whole.STL
--------------------------------------------------------------------------------
/envs/yang_domains/dmc_acrobot.py:
--------------------------------------------------------------------------------
1 | import dmc2gym
2 | from .wrappers import ConcatObs
3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages
4 |
5 |
6 | def mdp():
7 | return dmc2gym.make(domain_name="acrobot", task_name="swingup", keys_to_exclude=[], track_prev_action=False, frame_skip=5)
8 |
9 |
10 | def p():
11 | return dmc2gym.make(domain_name="acrobot", task_name="swingup", keys_to_exclude=['velocity'], track_prev_action=False, frame_skip=5)
12 |
13 |
14 | def va():
15 | return dmc2gym.make(domain_name="acrobot", task_name="swingup", keys_to_exclude=['orientations'], track_prev_action=True, frame_skip=5)
16 |
17 |
18 | def p_concat5():
19 | return ConcatObs(p(), 5)
20 |
21 |
22 | def va_concat10():
23 | return ConcatObs(va(), 10)
24 |
--------------------------------------------------------------------------------
/envs/yang_domains/dmc_cart2pole.py:
--------------------------------------------------------------------------------
1 | import dmc2gym
2 | from .wrappers import ConcatObs
3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages
4 |
5 |
6 | def mdp():
7 | return dmc2gym.make(domain_name="cartpole", task_name="two_poles", keys_to_exclude=[], frame_skip=5)
8 |
9 |
10 | # def p():
11 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['velocity'], frame_skip=5)
12 | #
13 | #
14 | # def va():
15 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['position'], frame_skip=5)
16 | #
17 | #
18 | # def p_concat5():
19 | # return ConcatObs(p(), 5)
20 | #
21 | # def va_concat10():
22 | # return ConcatObs(va(), 10)
23 |
--------------------------------------------------------------------------------
/envs/yang_domains/dmc_cart3pole.py:
--------------------------------------------------------------------------------
1 | import dmc2gym
2 | from .wrappers import ConcatObs
3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages
4 |
5 |
6 | def mdp():
7 | return dmc2gym.make(domain_name="cartpole", task_name="three_poles", keys_to_exclude=[], frame_skip=5)
8 |
9 |
10 | # def p():
11 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['velocity'], frame_skip=5)
12 | #
13 | #
14 | # def va():
15 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['position'], frame_skip=5)
16 | #
17 | #
18 | # def p_concat5():
19 | # return ConcatObs(p(), 5)
20 | #
21 | # def va_concat10():
22 | # return ConcatObs(va(), 10)
23 |
--------------------------------------------------------------------------------
/envs/yang_domains/dmc_cartpole_b.py:
--------------------------------------------------------------------------------
1 | import dmc2gym
2 | from .wrappers import ConcatObs
3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages
4 |
5 |
6 | def mdp():
7 | return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=[], frame_skip=5, track_prev_action=False)
8 |
9 |
10 | def p():
11 | return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['velocity'], frame_skip=5, track_prev_action=False)
12 |
13 |
14 | def va():
15 | return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['position'], track_prev_action=True, frame_skip=5)
16 |
17 |
18 | def p_concat5():
19 | return ConcatObs(p(), 5)
20 |
21 | def va_concat10():
22 | return ConcatObs(va(), 10)
23 |
24 |
25 | def pomdp_img():
26 | """frame skip follows from the Dreamer benchmark"""
27 | raw_env = dmc2gym.make(
28 | domain_name="cartpole",
29 | task_name="balance",
30 | keys_to_exclude=[],
31 | visualize_reward=False,
32 | from_pixels=True,
33 | frame_skip=2
34 | )
35 | return GrayscaleImage(Normalize255Image(raw_env))
36 |
37 |
38 | def mdp_img_concat3():
39 | return ConcatImages(pomdp_img(), 3)
40 |
--------------------------------------------------------------------------------
/envs/yang_domains/dmc_cartpole_su.py:
--------------------------------------------------------------------------------
1 | import dmc2gym
2 | from .wrappers import ConcatObs
3 |
4 |
5 | def mdp():
6 | return dmc2gym.make(domain_name="cartpole", task_name="swingup", keys_to_exclude=[], frame_skip=5, track_prev_action=False)
7 |
8 |
9 | def p():
10 | return dmc2gym.make(domain_name="cartpole", task_name="swingup", keys_to_exclude=['velocity'], frame_skip=5, track_prev_action=False)
11 |
12 |
13 | def va():
14 | return dmc2gym.make(domain_name="cartpole", task_name="swingup", keys_to_exclude=['position'], frame_skip=5, track_prev_action=True)
15 |
16 |
17 | def p_concat5():
18 | return ConcatObs(p(), 5)
19 |
20 | def va_concat10():
21 | return ConcatObs(va(), 10)
22 |
--------------------------------------------------------------------------------
/envs/yang_domains/dmc_pendulum_su.py:
--------------------------------------------------------------------------------
1 | import dmc2gym
2 |
3 |
4 | def mdp():
5 | return dmc2gym.make(domain_name="pendulum", task_name="swingup", keys_to_exclude=[], frame_skip=5, track_prev_action=False)
6 |
--------------------------------------------------------------------------------
/envs/yang_domains/dmc_walker_walk.py:
--------------------------------------------------------------------------------
1 | import dmc2gym
2 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages
3 |
4 |
5 | def pomdp_img():
6 | """frame skip follows from the Dreamer benchmark"""
7 | raw_env = dmc2gym.make(
8 | domain_name="walker",
9 | task_name="walk",
10 | keys_to_exclude=[],
11 | visualize_reward=False,
12 | from_pixels=True,
13 | frame_skip=2
14 | )
15 | return GrayscaleImage(Normalize255Image(raw_env))
16 |
17 |
18 | def mdp_img_concat3():
19 | return ConcatImages(pomdp_img(), 3)
20 |
--------------------------------------------------------------------------------
/envs/yang_domains/pybullet_ant.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import pybullet_envs
3 | from .wrappers import FilterObsByIndex
4 |
5 |
6 | def p():
7 | return FilterObsByIndex(
8 | gym.make("AntBulletEnv-v0"),
9 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [8, 10, 12, 14, 16, 18, 20, 22] + [24, 25, 26, 27]
10 | )
11 |
12 |
13 | def v():
14 | return FilterObsByIndex(
15 | gym.make("AntBulletEnv-v0"),
16 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [9, 11, 13, 15, 17, 19, 21, 23] + [24, 25, 26, 27]
17 | )
18 |
--------------------------------------------------------------------------------
/envs/yang_domains/pybullet_halfcheetah.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import pybullet_envs
3 | from .wrappers import FilterObsByIndex
4 |
5 |
6 | def p():
7 | return FilterObsByIndex(
8 | gym.make("HalfCheetahBulletEnv-v0"),
9 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [8, 10, 12, 14, 16, 18] + [20, 21, 22, 23, 24, 25]
10 | )
11 |
12 |
13 | def v():
14 | return FilterObsByIndex(
15 | gym.make("HalfCheetahBulletEnv-v0"),
16 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [9, 11, 13, 15, 17, 19] + [20, 21, 22, 23, 24, 25]
17 | )
18 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Ravens Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/bumps/bump_40.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/bumps/bump_40.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/bumps/bump_40_blue.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/bumps/bump_40_red.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/bumps/bump_40_virtual.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/bumps/bump_50.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/bumps/bump_50.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/bumps/bump_50.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/cup/cup.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/cup/cup.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/cup/cup.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plane/checker_blue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plane/checker_blue.png
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plane/plane.mtl:
--------------------------------------------------------------------------------
1 | newmtl Material
2 | Ns 10.0000
3 | Ni 1.5000
4 | d 1.0000
5 | Tr 0.0000
6 | Tf 1.0000 1.0000 1.0000
7 | illum 2
8 | Ka 0.0000 0.0000 0.0000
9 | Kd 0.5880 0.5880 0.5880
10 | Ks 0.0000 0.0000 0.0000
11 | Ke 0.0000 0.0000 0.0000
12 | map_Ka cube.tga
13 | map_Kd checker_blue.png
14 |
15 |
16 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plane/plane.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plane/plane100.obj:
--------------------------------------------------------------------------------
1 | # Blender v2.66 (sub 1) OBJ File: ''
2 | # www.blender.org
3 | mtllib plane.mtl
4 | o Plane
5 | v 100.000000 -100.000000 0.000000
6 | v 100.000000 100.000000 0.000000
7 | v -100.000000 100.000000 0.000000
8 | v -100.000000 -100.000000 0.000000
9 |
10 | vt 100.000000 0.000000
11 | vt 100.000000 100.000000
12 | vt 0.000000 100.000000
13 | vt 0.000000 0.000000
14 |
15 |
16 |
17 | usemtl Material
18 | s off
19 | f 1/1 2/2 3/3
20 | f 1/1 3/3 4/4
21 |
22 |
23 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plate/plate.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plate/plate.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plate/plate_half.urdf:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plate/plate_holder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate_holder.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plate/plate_holder.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plate/plate_lower_half.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate_lower_half.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/plate/plate_upper_half.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate_upper_half.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/shelf/shelf_back_board.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/shelf/shelf_back_board.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/shelf/shelf_back_board.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/shelf/shelf_horizontal_board.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/shelf/shelf_horizontal_board.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/shelf/shelf_horizontal_board.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/shelf/shelf_side_board.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/shelf/shelf_side_board.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/shelf/shelf_side_board.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/workspace/grid_mark.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/workspace/plane.obj:
--------------------------------------------------------------------------------
1 | # Blender v2.66 (sub 1) OBJ File: ''
2 | # www.blender.org
3 | mtllib plane.mtl
4 | o Plane
5 | v 15.000000 -15.000000 0.000000
6 | v 15.000000 15.000000 0.000000
7 | v -15.000000 15.000000 0.000000
8 | v -15.000000 -15.000000 0.000000
9 |
10 | vt 15.000000 0.000000
11 | vt 15.000000 15.000000
12 | vt 0.000000 15.000000
13 | vt 0.000000 0.000000
14 |
15 | usemtl Material
16 | s off
17 | f 1/1 2/2 3/3
18 | f 1/1 3/3 4/4
19 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/workspace/rail.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/envs/yang_domains/robot_envs/assets/workspace/workspace.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/__init__.py
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/arm.SLDPRT:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm.SLDPRT
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/arm.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/arm_half_1.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm_half_1.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/arm_half_2.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm_half_2.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/arm_mico.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm_mico.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/base.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/base.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/finger_distal.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/finger_distal.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/finger_proximal.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/finger_proximal.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/forearm.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/forearm.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/forearm_mico.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/forearm_mico.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/hand_2finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/hand_2finger.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/hand_3finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/hand_3finger.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/ring_big.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/ring_big.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/ring_small.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/ring_small.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/shoulder.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/shoulder.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/wrist.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/wrist.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_1.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_1.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_2.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_2.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/rdda/meshes/narrow_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/rdda/meshes/narrow_finger.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/rdda/meshes/wide_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/rdda/meshes/wide_finger.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/rdda/meshes/wrist.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/rdda/meshes/wrist.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/README.md:
--------------------------------------------------------------------------------
1 | ## Robotiq 2F 85 gripper
2 | For this gripper, the following Github repo can be used as a reference: https://github.com/Shreeyak/robotiq.git
3 |
4 | ### mimic tag in URDF
5 | This gripper is developed for ROS and uses the `mimic` tag within its URDF files to make the gripper move. As far as we could determine, the `mimic` tag in URDF is not supported by pybullet. To work around this, one can use the `createConstraint` function. Please refer to [this](https://github.com/bulletphysics/bullet3/blob/master/examples/pybullet/examples/mimicJointConstraint.py) example from the bullet3 repo to see how to replicate a `mimic` joint:
6 |
7 | ```python
8 | #a mimic joint can act as a gear between two joints
9 | #you can control the gear ratio in magnitude and sign (>0 reverses direction)
10 |
11 | import pybullet as p
12 | import time
13 | p.connect(p.GUI)
14 | p.loadURDF("plane.urdf",0,0,-2)
15 | wheelA = p.loadURDF("differential/diff_ring.urdf",[0,0,0])
16 | for i in range(p.getNumJoints(wheelA)):
17 | print(p.getJointInfo(wheelA,i))
18 | p.setJointMotorControl2(wheelA,i,p.VELOCITY_CONTROL,targetVelocity=0,force=0)
19 |
20 |
21 | c = p.createConstraint(wheelA,1,wheelA,3,jointType=p.JOINT_GEAR,jointAxis =[0,1,0],parentFramePosition=[0,0,0],childFramePosition=[0,0,0])
22 | p.changeConstraint(c,gearRatio=1, maxForce=10000)
23 |
24 | c = p.createConstraint(wheelA,2,wheelA,4,jointType=p.JOINT_GEAR,jointAxis =[0,1,0],parentFramePosition=[0,0,0],childFramePosition=[0,0,0])
25 | p.changeConstraint(c,gearRatio=-1, maxForce=10000)
26 |
27 | c = p.createConstraint(wheelA,1,wheelA,4,jointType=p.JOINT_GEAR,jointAxis =[0,1,0],parentFramePosition=[0,0,0],childFramePosition=[0,0,0])
28 | p.changeConstraint(c,gearRatio=-1, maxForce=10000)
29 |
30 |
31 | p.setRealTimeSimulation(1)
32 | while(1):
33 | p.setGravity(0,0,-10)
34 | time.sleep(0.01)
35 | #p.removeConstraint(c)
36 |
37 | ```
38 |
39 |
40 | Details on `createConstraint` can be found in the pybullet [getting started](https://docs.google.com/document/d/10sXEhzFRSnvFcl3XxNGhnD4N2SedqwdAvK3dsihxVUA/edit#heading=h.fq749wu22x4c) guide.
41 |
42 | ### Files in folder
43 | Since parameters such as the gear ratio and direction are required, `robotiq_2f_85_mimic_joints.urdf` is provided; it keeps the `mimic` tags from the original URDF and can be used as a reference. It was generated from `robotiq/robotiq_2f_robot/robot/simple_rq2f85_pybullet.urdf.xacro` as follows:
44 | ```
45 | rosrun xacro xacro --inorder simple_rq2f85_pybullet.urdf.xacro
46 | adaptive_transmission:="true" > robotiq_2f_85_mimic_joints.urdf
47 | ```
48 |
49 | The URDF meant for use in pybullet is `robotiq_2f_85.urdf` and it is generated in a similar manner as above by running:
50 | ```
51 | rosrun xacro xacro --inorder simple_rq2f85_pybullet.urdf.xacro > robotiq_2f_85.urdf
52 | ```
53 |
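54 | ### Loading the gripper in pybullet (sketch)
55 | As a rough, untested sketch of how the pieces above fit together, one can load `robotiq_2f_85.urdf`, print its joint table, and then mirror the `mimic` relationships from `robotiq_2f_85_mimic_joints.urdf` with `JOINT_GEAR` constraints. The joint indices and gear ratio below are placeholders and should be replaced with the values read from your URDF:
56 | ```python
57 | import pybullet as p
58 | 
59 | p.connect(p.DIRECT)
60 | gripper = p.loadURDF("robotiq_2f_85.urdf")
61 | 
62 | # Print (index, name) for every joint to identify the driver joint and the joints that should mimic it.
63 | for i in range(p.getNumJoints(gripper)):
64 |     print(p.getJointInfo(gripper, i)[0:2])
65 | 
66 | driver_joint, mimic_joint = 1, 3  # placeholders: read the real indices from the printout above
67 | c = p.createConstraint(gripper, driver_joint, gripper, mimic_joint,
68 |                        jointType=p.JOINT_GEAR, jointAxis=[0, 1, 0],
69 |                        parentFramePosition=[0, 0, 0], childFramePosition=[0, 0, 0])
70 | p.changeConstraint(c, gearRatio=-1, maxForce=10000)
71 | ```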
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-base.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'gripper-2f.blend'
2 | # Material Count: 1
3 |
4 | newmtl Default
5 | Ns 96.078431
6 | Ka 1.000000 1.000000 1.000000
7 | Kd 0.640000 0.640000 0.640000
8 | Ks 0.500000 0.500000 0.500000
9 | Ke 0.000000 0.000000 0.000000
10 | Ni 1.000000
11 | d 1.000000
12 | illum 2
13 | map_Kd textures/gripper-2f_BaseColor.jpg
14 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-base.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-coupler.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'gripper-2f.blend'
2 | # Material Count: 1
3 |
4 | newmtl Default
5 | Ns 96.078431
6 | Ka 1.000000 1.000000 1.000000
7 | Kd 0.640000 0.640000 0.640000
8 | Ks 0.500000 0.500000 0.500000
9 | Ke 0.000000 0.000000 0.000000
10 | Ni 1.000000
11 | d 1.000000
12 | illum 2
13 | map_Kd textures/gripper-2f_BaseColor.jpg
14 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-coupler.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-coupler.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-driver.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'gripper-2f.blend'
2 | # Material Count: 1
3 |
4 | newmtl Default
5 | Ns 96.078431
6 | Ka 1.000000 1.000000 1.000000
7 | Kd 0.640000 0.640000 0.640000
8 | Ks 0.500000 0.500000 0.500000
9 | Ke 0.000000 0.000000 0.000000
10 | Ni 1.000000
11 | d 1.000000
12 | illum 2
13 | map_Kd textures/gripper-2f_BaseColor.jpg
14 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-driver.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-driver.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-follower.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'gripper-2f.blend'
2 | # Material Count: 1
3 |
4 | newmtl Default
5 | Ns 96.078431
6 | Ka 1.000000 1.000000 1.000000
7 | Kd 0.640000 0.640000 0.640000
8 | Ks 0.500000 0.500000 0.500000
9 | Ke 0.000000 0.000000 0.000000
10 | Ni 1.000000
11 | d 1.000000
12 | illum 2
13 | map_Kd textures/gripper-2f_BaseColor.jpg
14 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-follower.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-follower.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-pad.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-pad.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-spring_link.mtl:
--------------------------------------------------------------------------------
1 | # Blender MTL File: 'gripper-2f.blend'
2 | # Material Count: 1
3 |
4 | newmtl Default
5 | Ns 96.078431
6 | Ka 1.000000 1.000000 1.000000
7 | Kd 0.640000 0.640000 0.640000
8 | Ks 0.500000 0.500000 0.500000
9 | Ke 0.000000 0.000000 0.000000
10 | Ni 1.000000
11 | d 1.000000
12 | illum 2
13 | map_Kd textures/gripper-2f_BaseColor.jpg
14 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-spring_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-spring_link.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_BaseColor.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_BaseColor.jpg
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Metallic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Metallic.jpg
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Normal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Normal.jpg
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Roughness.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Roughness.jpg
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/shovel/meshes/shovel_base.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/shovel/meshes/shovel_base.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/shovel/meshes/shovel_blade.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/shovel/meshes/shovel_blade.STL
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/shovel/shovel.urdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/shovel/shovel.urdf
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/spatula/spatula-base.urdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/spatula/spatula-base.urdf
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/suction/suction-base.urdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/suction/suction-base.urdf
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/collision/base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/base.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/collision/forearm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/forearm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/collision/shoulder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/shoulder.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/collision/upperarm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/upperarm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/collision/wrist1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/wrist1.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/collision/wrist2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/wrist2.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/collision/wrist3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/wrist3.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/license.txt:
--------------------------------------------------------------------------------
1 | Original work Copyright 2018 ROS Industrial (https://rosindustrial.org/)
2 | Modified work Copyright 2018 Virtana, Inc (www.virtanatech.com)
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
13 | or implied. See the License for the specific language governing
14 | permissions and limitations under the License.
15 |
16 | Changenotes:
17 |
18 | 2018-05-08: Vijay Pradeep (vijay@virtanatech.com)
19 | The visual/ mesh files were generated from the dae files in the
20 | ros-industrial universal robot repo [1]. Since the collada pyBullet
21 | parser is somewhat limited, it is unable to parse the UR collada mesh
22 | files. Thus, we imported these collada files into blender and
23 | converted them into STL files. We lost material definitions during
24 | the conversion, but that's ok.
25 |
26 | The URDF was generated by running the xacro xml preprocessor on the
27 | URDF included in the ur_description repo already mentioned here.
28 | Additional manual tweaking was required to update resource paths and
29 | to remove errors caused by missing inertia elements. Various Gazebo
30 | plugin tags were also removed.
31 |
32 | [1] - https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_description/meshes/ur5/visual
33 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/visual/base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/base.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/visual/forearm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/forearm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/visual/shoulder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/shoulder.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/visual/upperarm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/upperarm.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/visual/wrist1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/wrist1.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/visual/wrist2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/wrist2.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/assets/ur5/visual/wrist3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/wrist3.stl
--------------------------------------------------------------------------------
/envs/yang_domains/robots/end_effector.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import pybullet as p
3 |
4 | ASSETS_PATH = Path(__file__).resolve().parent / 'assets'
5 |
6 |
7 | class EndEffector:
8 | """
9 | A base class for UR5 end effectors.
10 | """
11 |
12 | def __init__(self, urdf_path, load_position, load_orientation, ur5_install_joints, ee_tip_idx):
13 | """
14 | The initialization of the UR5 end effector.
15 |
16 | :param urdf_path: the path of the urdf that describes the end effector
17 | :param load_position: the position to load the end effector
18 | :param load_orientation: the orientation to load the end effector
19 | :param ur5_install_joints: the home joints of the UR5 robot when installing this end effector
20 |         :param ee_tip_idx: the link index of this end effector's tip
21 | """
22 |
23 | self.urdf_path = urdf_path
24 | self.load_position = load_position
25 | self.load_orientation = load_orientation
26 | self.ur5_install_joints = ur5_install_joints
27 | self.ee_tip_idx = ee_tip_idx
28 |
29 | # Loads the model.
30 | self.body_id = p.loadURDF(str(urdf_path), load_position, load_orientation)
31 |
32 | def get_body_id(self):
33 | """
34 | Gets the body id of this end effector in PyBullet.
35 |
36 | :return: the body id
37 | """
38 |
39 | return self.body_id
40 |
41 | def get_position_offset(self):
42 | """
43 | Gets the position offset of this end effector.
44 |
45 | :return: the position offset
46 | """
47 |
48 | return 0, 0, 0
49 |
50 | def get_ur5_install_joints(self):
51 | """
52 |         Gets the home joints of the UR5 robot when installing this end effector.
53 |
54 | :return: the UR5 home joints for this end effector
55 | """
56 |
57 | return self.ur5_install_joints
58 |
59 | def get_base_pose(self):
60 | """
61 | Gets the base position and orientation of this end effector.
62 |
63 | :return: the position and orientation of the base
64 | """
65 |
66 | return p.getBasePositionAndOrientation(self.body_id)
67 |
68 | def get_tip_pose(self):
69 | """
70 | Gets the tip position and orientation of this end effector.
71 |
72 | :return: the position and orientation of the tip
73 | """
74 |
75 | state = p.getLinkState(self.body_id, self.ee_tip_idx)
76 |
77 | return state[4], state[5]
78 |
79 | def reset(self, reset_base=False):
80 | """
81 | Resets this end effector.
82 |
83 | :param reset_base: True if resetting the base pose, False otherwise
84 | """
85 |
86 | if reset_base:
87 | p.resetBasePositionAndOrientation(bodyUniqueId=self.body_id,
88 | posObj=self.load_position,
89 | ornObj=self.load_orientation)
90 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/rdda.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pybullet as p
3 |
4 | from .gripper import Gripper, ASSETS_PATH
5 |
6 | RDDA_URDF_PATH = ASSETS_PATH / 'rdda' / 'rdda.urdf'
7 |
8 |
9 | class Rdda(Gripper):
10 | """
11 | A class for the RDDA gripper.
12 | """
13 |
14 | def __init__(self, ur5_id, ee_id):
15 | """
16 | The initialization of the RDDA gripper.
17 |
18 | :param ur5_id: the body id of the UR5 robot to install this RDDA
19 | :param ee_id: the link index of the UR5 robot to install this RDDA
20 | """
21 |
22 | super().__init__(urdf_path=RDDA_URDF_PATH,
23 | load_position=(0.4746, 0.1092, 0.4195),
24 | load_orientation=p.getQuaternionFromEuler((np.pi / 2, np.pi / 2, 0)),
25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, 0, 0.5, -1]) * np.pi,
26 | ee_tip_idx=1) # The ee tip of the RDDA gripper is the tip of the wide finger.
27 |
28 | # Installs the RDDA gripper on the UR5.
29 | self.offset = -0.025
30 | constraint_id = p.createConstraint(
31 | parentBodyUniqueId=ur5_id,
32 | parentLinkIndex=ee_id,
33 | childBodyUniqueId=self.body_id,
34 | childLinkIndex=-1,
35 | jointType=p.JOINT_FIXED,
36 | jointAxis=(0, 0, 0),
37 | parentFramePosition=(0, 0, 0),
38 | parentFrameOrientation=p.getQuaternionFromEuler((np.pi / 2, 0, np.pi / 2)),
39 | childFramePosition=(0, self.offset, 0))
40 | p.changeConstraint(constraint_id, maxForce=50)
41 |
42 | def get_position_offset(self):
43 | """
44 | Gets the position offset of this RDDA gripper.
45 |
46 | :return: the position offset
47 | """
48 |
49 | return self.offset, 0, 0
50 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/robotiq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pybullet as p
3 |
4 | from ..robots.gripper import Gripper, ASSETS_PATH
5 |
6 | Robotiq_URDF_PATH = ASSETS_PATH / 'robotiq' / 'robotiq_2f_85.urdf'
7 |
8 |
9 | class Robotiq(Gripper):
10 | """
11 | A class for the Robotiq gripper.
12 | """
13 |
14 | def __init__(self, ur5_id, ee_id):
15 | """
16 | The initialization of the Robotiq gripper.
17 |
18 | :param ur5_id: the body id of the UR5 robot to install this Robotiq
19 | :param ee_id: the link index of the UR5 robot to install this Robotiq
20 | """
21 |
22 | super().__init__(urdf_path=Robotiq_URDF_PATH,
23 | load_position=(0.4868, 0.1093, 0.431594),
24 | load_orientation=p.getQuaternionFromEuler((np.pi, 0, 0)),
25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, -0.5, -0.5, 0]) * np.pi,
26 | ee_tip_idx=0) # FIXME
27 |
28 | # Installs the Robotiq gripper on the UR5.
29 | constraint_id = p.createConstraint(
30 | parentBodyUniqueId=ur5_id,
31 | parentLinkIndex=ee_id,
32 | childBodyUniqueId=self.body_id,
33 | childLinkIndex=-1,
34 | jointType=p.JOINT_FIXED,
35 | jointAxis=(0, 0, 0),
36 | parentFramePosition=(0, 0, 0),
37 | parentFrameOrientation=p.getQuaternionFromEuler((0, 0, np.pi / 2)),
38 | childFramePosition=(0, 0, 0))
39 | p.changeConstraint(constraint_id, maxForce=50)
40 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/shovel.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pybullet as p
3 |
4 | from .gripper import Gripper, ASSETS_PATH
5 |
6 | SHOVEL_URDF_PATH = ASSETS_PATH / 'shovel' / 'shovel.urdf'
7 |
8 |
9 | class Shovel(Gripper):
10 | """
11 | A class for the shovel gripper.
12 | """
13 |
14 | def __init__(self, ur5_id, ee_id):
15 | """
16 | The initialization of the shovel gripper.
17 |
18 | :param ur5_id: the body id of the UR5 robot to install this shovel
19 | :param ee_id: the link index of the UR5 robot to install this shovel
20 | """
21 |
22 | super().__init__(urdf_path=SHOVEL_URDF_PATH,
23 | load_position=(0.487, 0.109, 0.438),
24 | load_orientation=p.getQuaternionFromEuler((np.pi, 0, np.pi / 2)),
25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, -0.5, -0.5, 0]) * np.pi,
26 | ee_tip_idx=1)
27 |
28 | # Installs the shovel gripper on the UR5.
29 | constraint_id = p.createConstraint(
30 | parentBodyUniqueId=ur5_id,
31 | parentLinkIndex=ee_id,
32 | childBodyUniqueId=self.body_id,
33 | childLinkIndex=-1,
34 | jointType=p.JOINT_FIXED,
35 | jointAxis=(0, 0, 0),
36 | parentFramePosition=(0, 0, 0),
37 | childFramePosition=(0, 0, 0.01))
38 | p.changeConstraint(constraint_id, maxForce=50)
39 |
--------------------------------------------------------------------------------
/envs/yang_domains/robots/spatula.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pybullet as p
3 |
4 | from .end_effector import EndEffector, ASSETS_PATH
5 |
6 | SPATULA_URDF_PATH = ASSETS_PATH / 'spatula' / 'spatula-base.urdf'
7 |
8 |
9 | class Spatula(EndEffector):
10 | """
11 | A class for a simple spatula for pushing.
12 | """
13 |
14 | def __init__(self, ur5_id, ee_id):
15 | """
16 | The initialization of the spatula.
17 |
18 | :param ur5_id: the body id of the UR5 robot to install this spatula
19 | :param ee_id: the link index of the UR5 robot to install this spatula
20 | """
21 |
22 | super().__init__(urdf_path=SPATULA_URDF_PATH,
23 | load_position=(0.487, 0.109, 0.438),
24 | load_orientation=p.getQuaternionFromEuler((np.pi, 0, 0)),
25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, -0.5, -0.5, 0]) * np.pi,
26 | ee_tip_idx=0)
27 |
28 | # Installs the spatula on the UR5.
29 | constraint_id = p.createConstraint(
30 | parentBodyUniqueId=ur5_id,
31 | parentLinkIndex=ee_id,
32 | childBodyUniqueId=self.body_id,
33 | childLinkIndex=-1,
34 | jointType=p.JOINT_FIXED,
35 | jointAxis=(0, 0, 0),
36 | parentFramePosition=(0, 0, 0),
37 | parentFrameOrientation=p.getQuaternionFromEuler((0, 0, np.pi / 2)),
38 | childFramePosition=(0, 0, 0.01))
39 | p.changeConstraint(constraint_id, maxForce=50)
40 |
--------------------------------------------------------------------------------
/envs/yang_domains/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Ravens Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/envs/yang_domains/wrappers.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | import numpy as np
3 | import gym
4 | import gym.spaces as spaces
5 |
6 | # FYI
7 | # - By default, wrappers directly copy the observation and action space from their wrappees.
8 | # - The observation method in ObservationWrappers will get used by both env.reset and env.step.
9 | # - In general, observation space is used only for getting observation shape for networks.
10 |
11 |
12 | class FilterObsByIndex(gym.ObservationWrapper):
13 |
14 |     def _filter(self, array: np.ndarray) -> np.ndarray:
15 | return np.array(
16 | [x for i, x in enumerate(array) if i in self.indices_to_keep]
17 | )
18 |
19 | def __init__(self, env, indices_to_keep: list):
20 |
21 | super().__init__(env)
22 | self.indices_to_keep = indices_to_keep
23 |
24 | new_high, new_low = self._filter(self.env.observation_space.high), self._filter(self.env.observation_space.low)
25 | self.observation_space = spaces.Box(low=new_low, high=new_high)
26 |
27 | def observation(self, observation):
28 | return self._filter(observation)
29 |
30 |
31 | class ConcatObs(gym.ObservationWrapper):
32 |
33 | def __init__(self, env, window_size: int):
34 |
35 | super().__init__(env)
36 |
37 | # get info on old observation space
38 | old_obs_space = env.observation_space
39 | old_obs_space_dim = old_obs_space.shape[0]
40 | old_obs_space_low, old_obs_space_high = old_obs_space.low, old_obs_space.high
41 |
42 | # change observation space
43 | self.observation_space = spaces.Box(
44 | low=np.array(list(old_obs_space_low) * window_size),
45 | high=np.array(list(old_obs_space_high) * window_size)
46 | )
47 |
48 | self.window = deque(maxlen=window_size)
49 | for i in range(window_size - 1):
50 | self.window.append(np.zeros((old_obs_space_dim, ))) # append some dummy observations first
51 |
52 | self.window_size = window_size
53 | self.old_obs_space_dim = old_obs_space_dim
54 |
55 |     def observation(self, obs: np.ndarray) -> np.ndarray:
56 | self.window.append(obs)
57 | return np.concatenate(self.window)
58 |
59 | def reset(self):
60 | for i in range(self.window_size - 1):
61 | self.window.append(np.zeros((self.old_obs_space_dim, ))) # append some dummy observations first
62 | observation = self.env.reset()
63 | return self.observation(observation)
64 |
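65 | 
66 | if __name__ == "__main__":
67 |     # Minimal usage sketch (illustrative only; "Pendulum-v1" and the index list are placeholders):
68 |     # keep two observation components, then stack the last 4 filtered observations.
69 |     env = ConcatObs(FilterObsByIndex(gym.make("Pendulum-v1"), indices_to_keep=[0, 2]), window_size=4)
70 |     obs = env.reset()
71 |     print(obs.shape)  # (8,): 2 kept components x 4 stacked frames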
--------------------------------------------------------------------------------
/envs/yang_domains/wrappers_for_image.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import gym.spaces as spaces
3 | import numpy as np
4 | from collections import deque
5 |
6 |
7 | class GrayscaleImage(gym.ObservationWrapper):
8 |
9 | def __init__(self, env):
10 | super().__init__(env)
11 | old_observation_space = env.observation_space
12 | self.observation_space = spaces.Box(
13 | low=np.expand_dims(old_observation_space.low[0], axis=0),
14 | high=np.expand_dims(old_observation_space.high[0], axis=0),
15 | )
16 | # https://www.kite.com/python/answers/how-to-convert-an-image-from-rgb-to-grayscale-in-python
17 | self.rgb_weights = np.array([0.2989, 0.5870, 0.1140]).reshape(3, 1)
18 |
19 | def observation(self, observation):
20 | observation = np.moveaxis(observation, 0, 2)
21 | observation = observation @ self.rgb_weights
22 | return np.moveaxis(observation, 2, 0)
23 |
24 |
25 | class Normalize255Image(gym.ObservationWrapper):
26 |
27 | def __init__(self, env):
28 | super().__init__(env)
29 | old_observation_space = env.observation_space
30 | self.observation_space = spaces.Box(
31 | low=old_observation_space.low / 255,
32 | high=old_observation_space.high / 255
33 | )
34 |
35 | def observation(self, observation):
36 |         return observation / 255  # uint8 automatically gets converted to float64 by the division
37 |
38 |
39 | class ConcatImages(gym.ObservationWrapper):
40 |
41 | def __init__(self, env, window_size: int):
42 | super().__init__(env)
43 |
44 | self.window_size = window_size
45 | self.old_obs_space_shape = env.observation_space.shape
46 |
47 | old_obs_space = env.observation_space
48 | old_obs_low = old_obs_space.low
49 | old_obs_high = old_obs_space.high
50 |
51 | self.observation_space = spaces.Box(
52 |             low=np.concatenate([old_obs_low] * window_size),
53 |             high=np.concatenate([old_obs_high] * window_size),
54 | dtype=np.uint8 # a must for images to work with SB3
55 | )
56 |
57 | self.window = deque(maxlen=window_size)
58 | for i in range(self.window_size - 1):
59 | self.window.append(np.zeros(self.old_obs_space_shape))
60 |
61 | def observation(self, observation):
62 | self.window.append(observation)
63 | return np.concatenate(self.window)
64 |
65 | def reset(self):
66 | for i in range(self.window_size - 1):
67 | self.window.append(np.zeros(self.old_obs_space_shape))
68 | observation = self.env.reset()
69 | return self.observation(observation)
70 |
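71 | 
72 | if __name__ == "__main__":
73 |     # Illustrative shape check (not part of the original module), using a tiny dummy
74 |     # channel-first image env: GrayscaleImage maps (3, H, W) uint8 frames to (1, H, W).
75 |     class _DummyImageEnv(gym.Env):
76 |         observation_space = spaces.Box(low=0, high=255, shape=(3, 64, 64), dtype=np.uint8)
77 |         action_space = spaces.Discrete(1)
78 | 
79 |         def reset(self):
80 |             return self.observation_space.sample()
81 | 
82 |         def step(self, action):
83 |             return self.observation_space.sample(), 0.0, False, {}
84 | 
85 |     env = GrayscaleImage(_DummyImageEnv())
86 |     print(env.reset().shape)  # (1, 64, 64)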
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | warnings.filterwarnings("ignore", category=UserWarning, module="gym.*")
3 | warnings.filterwarnings("ignore", ".*The DISPLAY environment variable is missing") # dmc
4 | warnings.filterwarnings("ignore", category=FutureWarning) # dmc
5 | import multiprocessing
6 | from offpolicy_rnn import init_smart_logger, Parameter, alg_init
7 |
8 |
9 | def main():
10 | if not multiprocessing.get_start_method(allow_none=True) == 'spawn':
11 | multiprocessing.set_start_method('spawn', force=True)
12 | init_smart_logger()
13 | parameter = Parameter()
14 | sac = alg_init(parameter)
15 | sac.train()
16 |
17 |
18 | if __name__ == '__main__':
19 | main()
--------------------------------------------------------------------------------
/offpolicy_rnn/__init__.py:
--------------------------------------------------------------------------------
1 | from .algorithm.sac import SAC
2 | from .algorithm.sac_mlp import SAC_MLP
3 | from .config.load_config import init_smart_logger
4 | from .parameter.ParameterSAC import Parameter
5 | from .algorithm.sac_rnn_slice import SACRNNSlice
6 | from .algorithm.sac_mlp_redq import SAC_MLP_REDQ
7 | from .algorithm.sac_full_length_rnn_ensembleQ import SACFullLengthRNNEnsembleQ
8 | from .algorithm.sac_full_length_rnn_redq_sep_optim import SACFullLengthRNNREDQ_SEP_OPTIM
9 | from .algorithm.sac_full_length_rnn_ensembleQ_sep_optim import SACFullLengthRNNENSEMBLEQ_SEP_OPTIM
10 | from .algorithm.sac_full_length_rnn_redq import SACFullLengthRNNREDQ
11 | from .utility.alg_init import alg_init
--------------------------------------------------------------------------------
/offpolicy_rnn/buffers/replay_memory_tail_padding.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import namedtuple
3 | import numpy as np
4 | import pickle
5 | import time
6 | from typing import List, Union, Tuple, Dict, Optional, Any
7 | from .replay_memory import Transition, tuplenames, MemoryArray
8 |
9 |
10 | class MemoryArrayTailZeroPadding(MemoryArray):
11 | def __init__(self, max_trajectory_num: int=1000, max_traj_step: int=1000, rnn_slice_length: int=1, fixed_length: int=32):
12 | super().__init__(max_trajectory_num, max_traj_step, rnn_slice_length)
13 |
14 | self.fixed_length = fixed_length
15 | assert self.fixed_length >= 1
16 | self.last_sampled_batch = None
17 |
18 | def reset(self):
19 | super().reset()
20 | self.last_sampled_batch = None
21 |
22 | def sample_fix_length_sub_trajs(self, batch_size, fix_length):
23 | list_ind = np.random.randint(0, self.transition_count, batch_size)
24 | res = [self.transition_buffer[ind] for ind in list_ind]
25 | if self.last_sampled_batch is None or not self.last_sampled_batch.shape[0] == batch_size or not self.last_sampled_batch.shape[1] == fix_length:
26 |             trajs = [self.memory_buffer[traj_ind, point_ind:point_ind+self.fixed_length] for traj_ind, point_ind in res]
27 | trajs = np.array(trajs, copy=True)
28 | self.last_sampled_batch = trajs
29 | else:
30 | # reuse last sampled batch, avoid frequently allocating memory
31 | for ind, (traj_ind, point_ind) in enumerate(res):
32 | self.last_sampled_batch[ind, :, :] = self.memory_buffer[traj_ind,
33 | point_ind: point_ind+self.fixed_length, :]
34 |
35 | res = self.array_to_transition(self.last_sampled_batch)
36 | return res
37 |
38 | def _make_buffer(self, dim):
39 |         self.memory_buffer = np.zeros((self.max_trajectory_num, self.max_traj_step + self.fixed_length - 1, dim))  # extra (fixed_length - 1) tail steps so a fixed-length slice starting at any valid index stays in bounds
40 |         print('Tail zero padding buffer init done!')
41 |
42 |
43 |
--------------------------------------------------------------------------------
/offpolicy_rnn/config/common_config.yaml:
--------------------------------------------------------------------------------
1 | BACKUP_IGNORE_HEAD:
2 | - __p
3 | - .
4 | BACKUP_IGNORE_KEY:
5 | - logfile
6 | - logfile_bk
7 | BACKUP_IGNORE_TAIL:
8 | - stl
9 | - dae
10 | - STL
11 | - urdf
12 | - pkl
13 | - pdf
14 | BASE_PATH: null
15 | LOG_DIR_BACKING_NAME: OffpolicyRNN0302
16 | LOG_FOLDER_NAME: logfile
17 | LOG_FOLDER_NAME_BK: logfile_bk
18 | MAIN_MACHINE_IP:
19 | - 127.0.0.1
20 | MAIN_MACHINE_LOG_PATH: /path/to/remote/OffpolicyRNN0302
21 | MAIN_MACHINE_PASSWD: passwd
22 | MAIN_MACHINE_PORT:
23 | - 22
24 | MAIN_MACHINE_USER: username
25 | WORKSPACE_PATH: /path/to/remote/OffpolicyRNN0302
26 |
--------------------------------------------------------------------------------
/offpolicy_rnn/config/experiment_config.yaml:
--------------------------------------------------------------------------------
1 | EXPERIMENT_COMMON_PARAMETERS:
2 | MAX_TRAJ_STEP: 1000
3 | EXPERIMENT_TARGET: "Offpolicy RNN Test"
4 | SHORT_NAME_SUFFIX: INIT_TEST
5 | IMPORTANT_CONFIGS:
6 | - env_name
7 | - alg_name
8 | - seed
9 |
10 |
--------------------------------------------------------------------------------
/offpolicy_rnn/config/load_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import smart_logger
3 |
4 |
5 | def init_smart_logger():
6 | base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 | config_path = os.path.dirname(os.path.abspath(__file__))
8 | config_relative_path = os.path.relpath(config_path, base_path)
9 | smart_logger.init_config(os.path.join(config_relative_path, 'common_config.yaml'),
10 | os.path.join(config_relative_path, 'experiment_config.yaml'),
11 | base_path)
12 |
--------------------------------------------------------------------------------
/offpolicy_rnn/env_utils/make_env.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from envs.make_pomdp_env import make_pomdp_env
3 | from envs.pomdp_config import env_config
4 | import gym
5 | from typing import Dict
6 | import numpy as np
7 | import contextlib
8 | import random
9 | try:
10 | import envs.dmc
11 | has_dm_control = True
12 | except Exception:
13 | has_dm_control = False
14 |
15 |
16 | @contextlib.contextmanager
17 | def fixed_seed(seed):
18 |     """Context manager that temporarily fixes the random seeds of `random`, `numpy.random`, and `torch`."""
19 | state_np = np.random.get_state()
20 | state_random = random.getstate()
21 | torch_random = torch.get_rng_state()
22 | if torch.cuda.is_available():
23 |         # TODO: if results are not reproducible, check torch.get_rng_state
24 | torch_cuda_random = torch.cuda.get_rng_state()
25 | torch_cuda_random_all = torch.cuda.get_rng_state_all()
26 | np.random.seed(seed)
27 | random.seed(seed)
28 | torch.manual_seed(seed)
29 |
30 | try:
31 | yield
32 | finally:
33 | np.random.set_state(state_np)
34 | random.setstate(state_random)
35 | torch.set_rng_state(torch_random)
36 | if torch.cuda.is_available():
37 | torch.cuda.set_rng_state(torch_cuda_random)
38 | torch.cuda.set_rng_state_all(torch_cuda_random_all)
39 |
40 |
41 | def make_env(env_name: str, seed: int) -> Dict:
42 | if env_name in env_config:
43 | with fixed_seed(seed):
44 | result = make_pomdp_env(env_name, seed)
45 | result['seed'] = seed
46 | return result
47 | else:
48 | with fixed_seed(seed):
49 | if env_name.startswith('dmc'):
50 | env = gym.make(env_name, seed=seed)
51 | max_episode_steps = env.unwrapped._max_episode_steps
52 | else:
53 | env = gym.make(env_name)
54 | max_episode_steps = env._max_episode_steps
55 | env.seed(seed)
56 | env.action_space.seed(seed+1)
57 | env.observation_space.seed(seed+2)
58 | result = {
59 | 'train_env': env,
60 | 'eval_env': env,
61 | 'train_tasks': [],
62 | 'eval_tasks': [None] * 20,
63 | 'max_rollouts_per_task': 1,
64 | 'max_trajectory_len': max_episode_steps,
65 | 'obs_dim': env.observation_space.shape[0],
66 | 'act_dim': env.action_space.shape[0] if isinstance(env.action_space, gym.spaces.Box) else env.action_space.n,
67 | 'act_continuous': isinstance(env.action_space, gym.spaces.Box),
68 | 'seed': seed,
69 | 'multiagent': False
70 | }
71 |
72 | return result
--------------------------------------------------------------------------------
/offpolicy_rnn/models/conv1d/conv1d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class Conv1d(torch.nn.Module):
6 | def __init__(self, in_channels, out_channels, d_conv=4, bias=True, ff=True):
7 | super().__init__()
8 | assert in_channels == out_channels
9 | self.in_channels = in_channels
10 | self.out_channels = out_channels
11 | self.d_conv = d_conv
12 | self.conv1d = torch.nn.Conv1d(
13 | in_channels=in_channels,
14 | out_channels=out_channels,
15 | bias=bias,
16 | kernel_size=d_conv,
17 | groups=in_channels,
18 | padding=0,
19 | )
20 | self.desired_hidden_dim = self.in_channels * (self.d_conv - 1)
21 | self.use_ff = ff
22 | if ff:
23 | self.ff = PositionWiseFeedForward(out_channels, 0.0)
24 |
25 |
26 | def conv1d_func(self, x, hidden, mask):
27 | (b, l, d) = x.shape
28 | if mask is not None:
29 | x = x * mask
30 | x_input = torch.cat((hidden, x), dim=-2)
31 | x = x_input.transpose(-2, -1)
32 | x = self.conv1d(x)[:, :, :l]
33 | x = x.transpose(-2, -1)
34 | hidden = x_input[:, -(self.d_conv-1):, :]
35 | return x, hidden
36 |
37 | def forward(self, x, hidden=None, mask=None):
38 | batch_size = x.shape[0]
39 | if hidden is None:
40 | hidden = torch.zeros((batch_size, self.d_conv - 1, self.in_channels), device=x.device, dtype=x.dtype)
41 | else:
42 | hidden = hidden.reshape((batch_size, self.d_conv - 1, self.in_channels))
43 |
44 | x, hidden = self.conv1d_func(x, hidden, mask)
45 | hidden = hidden.reshape((batch_size, 1, -1))
46 | if self.use_ff:
47 | x = self.ff(x)
48 | return x, hidden
49 |
50 |
51 |
52 |
53 | class PositionWiseFeedForward(nn.Module):
54 | def __init__(self, d_model, dropout=0.0):
55 | super().__init__()
56 | self.w_1 = nn.Linear(d_model, d_model)
57 | self.w_2 = nn.Linear(d_model, d_model)
58 | self.activation = nn.GELU()
59 | self.dropout = nn.Dropout(dropout)
60 | self.layer_norm = nn.LayerNorm(d_model)
61 |
62 | def forward(self, x):
63 | x_ = self.dropout(self.activation(self.w_1(x)))
64 | return self.layer_norm(self.dropout(self.w_2(x_)) + x)
65 |
66 |
67 |
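68 | 
69 | if __name__ == "__main__":
70 |     # Quick shape check (illustrative only, not part of the original module). The layer is a
71 |     # causal depthwise conv; `hidden` caches the last (d_conv - 1) inputs so sequences can be
72 |     # processed in chunks without losing context at the boundary.
73 |     layer = Conv1d(in_channels=8, out_channels=8, d_conv=4)
74 |     x = torch.randn(2, 16, 8)        # (batch, seq_len, channels)
75 |     y, h = layer(x)                  # y: (2, 16, 8), h: (2, 1, 8 * (4 - 1))
76 |     y2, h2 = layer(x, hidden=h)      # continue from the cached state
77 |     print(y.shape, h.shape, y2.shape)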
--------------------------------------------------------------------------------
/offpolicy_rnn/models/ensemble_linear_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from typing import Dict, List, Union, Tuple, Optional
6 |
7 |
8 | class EnsembleLinear(nn.Module):
9 | def __init__(
10 | self,
11 | input_dim: int,
12 | output_dim: int,
13 | num_ensemble: int,
14 | bias: bool=True,
15 | desire_ndim: int=None
16 | ) -> None:
17 | super().__init__()
18 | self.use_bias = bias
19 | self.desire_ndim = desire_ndim
20 | self.num_ensemble = num_ensemble
21 |
22 | self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim)))
23 | if self.use_bias:
24 | self.register_parameter("bias", nn.Parameter(torch.zeros(num_ensemble, 1, output_dim)))
25 |
26 | nn.init.trunc_normal_(self.weight, std=1/(2*input_dim**0.5))
27 | self.device = torch.device('cpu')
28 |
29 | def forward(self, x: torch.Tensor) -> torch.Tensor:
30 | weight = self.weight
31 | if self.use_bias:
32 | bias = self.bias
33 | else:
34 | bias = None
35 | if len(x.shape) == 2:
36 | x = torch.einsum('ij,bjk->bik', x, weight)
37 | elif len(x.shape) == 3:
38 | # TODO: judge the shape carefully according to your application
39 | if (self.desire_ndim is None or self.desire_ndim == 3) and x.shape[0] == weight.data.shape[0]:
40 | x = torch.einsum('bij,bjk->bik', x, weight)
41 | else:
42 | x = torch.einsum('cij,bjk->bcik', x, weight)
43 | elif len(x.shape) == 4:
44 | if (self.desire_ndim is None or self.desire_ndim == 4) and x.shape[0] == weight.data.shape[0]:
45 | x = torch.einsum('cbij,cjk->cbik', x, weight)
46 | else:
47 | x = torch.einsum('cdij,bjk->bcdik', x, weight)
48 | elif len(x.shape) == 5:
49 | x = torch.einsum('bcdij,bjk->bcdik', x, weight)
50 | if self.use_bias:
51 | assert x.shape[0] == bias.shape[0] and x.shape[-1] == bias.shape[-1]
52 | if len(x.shape) == 4:
53 | bias = bias.unsqueeze(1)
54 | elif len(x.shape) == 5:
55 | bias = bias.unsqueeze(1)
56 | bias = bias.unsqueeze(1)
57 |
58 | x = x + bias
59 |
60 | return x
61 |
62 | def to(self, device):
63 | if not device == self.device:
64 | self.device = device
65 | super().to(device)
66 | self.weight = self.weight.to(self.device)
67 | if self.use_bias:
68 | self.bias = self.bias.to(self.device)
69 |
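70 | 
71 | if __name__ == "__main__":
72 |     # Shape sketch (illustrative only, not part of the original module): a 2-D input is shared
73 |     # across the ensemble, so each of the num_ensemble weight matrices maps the same batch.
74 |     layer = EnsembleLinear(input_dim=4, output_dim=3, num_ensemble=5)
75 |     print(layer(torch.randn(32, 4)).shape)     # torch.Size([5, 32, 3])
76 |     print(layer(torch.randn(5, 32, 4)).shape)  # ensemble-first input -> torch.Size([5, 32, 3])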
70 |
--------------------------------------------------------------------------------
/offpolicy_rnn/models/flash_attention/gpt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from flash_attn.models.gpt_rl import GPTModel, GPT2Config
3 | from torch.nn import Module
4 | from flash_attn.utils.generation import InferenceParams
5 |
6 |
7 | # GPT2Config={
8 | # "activation_function": "gelu_new",
9 | # "attn_pdrop": 0.1,
10 | # "bos_token_id": 50256,
11 | # "embd_pdrop": 0.1,
12 | # "eos_token_id": 50256,
13 | # "initializer_range": 0.02,
14 | # "layer_norm_epsilon": 1e-05,
15 | # "model_type": "gpt2",
16 | # "n_embd": 768,
17 | # "n_head": 12,
18 | # "n_inner": null,
19 | # "n_layer": 12,
20 | # "n_positions": 1024,
21 | # "reorder_and_upcast_attn": false,
22 | # "resid_pdrop": 0.1,
23 | # "scale_attn_by_inverse_layer_idx": false,
24 | # "scale_attn_weights": true,
25 | # "summary_activation": null,
26 | # "summary_first_dropout": 0.1,
27 | # "summary_proj_to_labels": true,
28 | # "summary_type": "cls_index",
29 | # "summary_use_proj": true,
30 | # "transformers_version": "4.39.3",
31 | # "use_cache": true,
32 | # "vocab_size": 50257
33 | # }
34 |
35 |
36 | class GPTLayer(Module):
37 | def __init__(self, ndim=768, nhead=12, nlayer=12, pdrop=0.1, norm_epsilon=1e-5):
38 | super().__init__()
39 | config = GPT2Config(n_embd=ndim, n_head=nhead, nlayer=nlayer, attn_pdrop=pdrop,
40 | layer_norm_epsilon=norm_epsilon,
41 | resid_pdrop=pdrop, embd_pdrop=pdrop,
42 | )
43 |
44 | config.use_flash_attn = True
45 | config.fused_bias_fc = True
46 | config.fused_mlp = True
47 | config.fused_dropout_add_ln = True
48 | config.residual_in_fp32 = True
49 | config.rms_norm = True
50 |
51 | config.n_positions = 2048
52 | config.rotary_emb_fraction = 0.0
53 | config.rotary_emb_base = 0
54 | config.use_alibi = True
55 | # config.n_positions = 0
56 | # config.rotary_emb_fraction = 1.0
57 | # config.rotary_emb_base = 10000
58 |
59 | self.model = GPTModel(config, dtype=torch.float32)
60 | self.flash_attn_flag = True
61 |
62 | def make_init_hidden(self, max_seqlen, max_batch_size):
63 | return InferenceParams(max_seqlen, max_batch_size)
64 |
65 | def forward(self, hidden_states, attention_mask_in_length=None, inference_params=None):
66 | original_dtype = hidden_states.dtype
67 | with torch.cuda.amp.autocast(dtype=torch.bfloat16):
68 | result = self.model.forward(hidden_states, attention_mask_in_length, inference_params)
69 | result = result.to(original_dtype)
70 | return result
71 |
72 |
73 |
74 |
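A minimal usage sketch for GPTLayer, illustrative only: it assumes a CUDA device and the flash_attn fork that provides `flash_attn.models.gpt_rl` accepting pre-embedded hidden states, since the fused kernels and the bf16 autocast in `forward` do not run on CPU.

    import torch
    # from offpolicy_rnn.models.flash_attention.gpt import GPTLayer  # import path per the repository layout

    layer = GPTLayer(ndim=256, nhead=8, nlayer=4).cuda()   # ndim must be divisible by nhead
    x = torch.randn(2, 128, 256, device='cuda')            # (batch, seq_len, n_embd), already embedded
    y = layer(x)                                           # full-sequence (training) mode
    params = layer.make_init_hidden(max_seqlen=128, max_batch_size=2)  # KV-cache holder for step-wise inference
    print(y.shape)                                         # (2, 128, 256)
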
--------------------------------------------------------------------------------
/offpolicy_rnn/models/gilr/gilr.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | try:
7 | from .scan_triton.real_rnn_tie_input_gate import real_scan_tie_input_gate, real_scan_tie_input_gate_fused
8 | from .scan_triton.real_rnn_fast_pscan import pscan_fast
9 | except Exception as _:
10 | pass
11 | from .scan_triton.real_rnn_tie_input_gate_cpu import scan_cpu, scan_cpu_fuse
12 | from ..ensemble_linear_model import EnsembleLinear
13 | class GILRLayer(nn.Module):
14 | def __init__(
15 | self,
16 | input_dim,
17 | output_dim,
18 | factor=1,
19 | dropout=0.0,
20 | use_ff=True,
21 | batch_first=True,
22 | ):
23 | super().__init__()
24 | assert batch_first
25 | self.d_model = output_dim
26 | self.in_proj = EnsembleLinear(input_dim, self.d_model * factor, 2, desire_ndim=4)
27 | self.out_proj = torch.nn.Linear(self.d_model * factor, self.d_model * factor)
28 | self.dropout = nn.Dropout(dropout)
29 | self.layer_norm = nn.LayerNorm(factor * self.d_model)
30 | self.swish = nn.SiLU()
31 | self.use_ff = use_ff
32 | if self.use_ff:
33 | self.ff = PositionWiseFeedForward(self.d_model, dropout)
34 | self.device = torch.device('cpu')
35 |
36 | def rnn_parameters(self):
37 | return list(self.parameters(True))
38 |
39 | def to(self, device):
40 | if not self.device == device:
41 | super().to(device)
42 | self.device = device
43 |
44 | def forward(self, x, hidden=None, rnn_start=None):
45 | u = self.in_proj(x)
46 | v = u[0]
47 | f = u[1]
48 | if hidden is None:
49 |             hidden = torch.zeros((v.shape[0], 1, v.shape[-1]), device=v.device)  # match the scan's channel width (d_model * factor)
50 | else:
51 | hidden = hidden.transpose(0, 1)
52 | hidden_pre = hidden
53 | f = torch.sigmoid(f)
54 | v = torch.tanh(v)
55 | if rnn_start is not None:
56 | f = f * (1 - rnn_start)
57 | if torch.all(hidden_pre == 0) and not self.device == torch.device('cpu'):
58 | v = real_scan_tie_input_gate(v.contiguous(), f.contiguous())
59 | hidden_pre = v[:, -1:, :]
60 | else:
61 | v, hidden_pre = scan_cpu(v, f, hidden_pre)
62 | out = self.out_proj(v)
63 | hidden = hidden_pre
64 | hidden = hidden.transpose(0, 1)
65 | if self.use_ff:
66 | out = self.ff(out)
67 | return out, hidden
68 |
69 |
70 | class PositionWiseFeedForward(nn.Module):
71 | def __init__(self, d_model, dropout=0.1):
72 | super().__init__()
73 | self.w_1 = nn.Linear(d_model, d_model)
74 | self.w_2 = nn.Linear(d_model, d_model)
75 | self.activation = nn.GELU()
76 | self.dropout = nn.Dropout(dropout)
77 | self.layer_norm = nn.LayerNorm(d_model)
78 |
79 | def forward(self, x):
80 | x_ = self.dropout(self.activation(self.w_1(x)))
81 | return self.layer_norm(self.dropout(self.w_2(x_)) + x)
82 |
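A small shape-oriented sketch of GILRLayer, illustrative only (on CPU it falls back to the sequential scan; on GPU it requires the Triton kernel imported above):

    import torch
    # from offpolicy_rnn.models.gilr.gilr import GILRLayer  # import path per the repository layout

    layer = GILRLayer(input_dim=16, output_dim=32)          # factor=1
    x = torch.randn(4, 10, 16)                              # (batch, seq_len, input_dim)
    rnn_start = torch.zeros(4, 10, 1)
    rnn_start[:, 0] = 1.0                                   # 1 marks an episode start: the gate is zeroed, resetting the state
    out, hidden = layer(x, hidden=None, rnn_start=rnn_start)
    print(out.shape, hidden.shape)                          # (4, 10, 32) and (1, 4, 32)
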
--------------------------------------------------------------------------------
/offpolicy_rnn/models/gilr/scan_triton/real_rnn_fast_pscan.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import triton
4 | import triton.language as tl
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Function
8 | # from fastpscan.cuda_v3 import fn as pscan_cuda_fn
9 | # from fastpscan.cuda_v4 import fn as pscan_cuda_fn
10 | from fastpscan.triton_v2 import fn as pscan_cuda_fn
11 |
12 |
13 | def pscan_fast(v, f):
14 | B, L, C = v.shape
15 | desire_B = max(B, 4)
16 | h = torch.zeros((desire_B, C), device=v.device, dtype=v.dtype)
17 |
18 | v = v * (1 - f)
19 | if L < 1024:
20 | v = torch.cat((v, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2)
21 | f = torch.cat((f, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2)
22 | if desire_B > B:
23 | v = torch.cat((v, torch.zeros((desire_B - B, 1024, C), device=v.device, dtype=v.dtype)), dim=0)
24 | f = torch.cat((f, torch.zeros((desire_B - B, 1024, C), device=f.device, dtype=f.dtype)), dim=0)
25 |
26 | Y = pscan_cuda_fn(f, v, h)
27 | Y = Y[:B, :L, :]
28 | return Y
29 |
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/offpolicy_rnn/models/gilr/scan_triton/real_rnn_tie_input_gate_cpu.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def real_scan_tie_input_gate_no_grad(v, f, h=None):
5 | B, L, C = v.shape
6 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256'
7 | output = torch.zeros_like(v)
8 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h
9 | for l in range(L):
10 | f_item = f[:, l:l+1, :]
11 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item)
12 | output[:, l:l+1, :] = h_new
13 | h = h_new
14 | return output, h
15 |
16 | def real_scan_tie_input_gate_no_grad_fuse(v, f, h=None):
17 | B, L, C = v.shape
18 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256'
19 | output = torch.zeros_like(v)
20 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h
21 | for l in range(L):
22 | f_item = torch.sigmoid(f[:, l:l+1, :])
23 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item)
24 | output[:, l:l+1, :] = h_new
25 | h = h_new
26 | return output, h
27 |
28 | scan_cpu = real_scan_tie_input_gate_no_grad
29 | scan_cpu_fuse = real_scan_tie_input_gate_no_grad_fuse
30 |
31 |
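The loop above implements the per-channel leaky-integration recurrence h_t = f_t * h_{t-1} + (1 - f_t) * v_t. A tiny sanity-check sketch (illustrative only):

    import torch
    # from offpolicy_rnn.models.gilr.scan_triton.real_rnn_tie_input_gate_cpu import scan_cpu  # per repo layout

    v = torch.randn(1, 2, 3)                  # (batch, seq_len=2, channels)
    f = torch.rand(1, 2, 3)                   # gates already in (0, 1)
    out, h = scan_cpu(v, f)                   # hidden state starts at zero when not provided
    h1 = (1 - f[:, 0:1]) * v[:, 0:1]                       # step 1
    h2 = f[:, 1:2] * h1 + (1 - f[:, 1:2]) * v[:, 1:2]      # step 2
    assert torch.allclose(out[:, 0:1], h1) and torch.allclose(out[:, 1:2], h2)
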
--------------------------------------------------------------------------------
/offpolicy_rnn/models/gilr_lstm/gilr_lstm.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | try:
7 | from .scan_triton.real_rnn_tie_input_gate import real_scan_tie_input_gate, real_scan_tie_input_gate_fused
8 | from .scan_triton.real_rnn_fast_pscan import pscan_fast
9 | except Exception as _:
10 | pass
11 | from .scan_triton.real_rnn_tie_input_gate_cpu import scan_cpu, scan_cpu_fuse
12 | from ..ensemble_linear_model import EnsembleLinear
13 | class GILRLSTMLayer(nn.Module):
14 | def __init__(
15 | self,
16 | input_dim,
17 | output_dim,
18 | factor=1,
19 | dropout=0.2,
20 | batch_first=True,
21 | ):
22 | super().__init__()
23 | assert batch_first
24 | self.d_model = output_dim
25 | self.in_proj = EnsembleLinear(input_dim, self.d_model * factor, 2, desire_ndim=4)
26 | self.middle_proj = EnsembleLinear(self.d_model * factor, self.d_model * factor, 4, desire_ndim=4)
27 | self.out_proj = torch.nn.Linear(self.d_model * factor, self.d_model * factor)
28 | self.layer_norm = nn.LayerNorm(factor * self.d_model)
29 | self.swish = nn.SiLU()
30 | self.device = torch.device('cpu')
31 |
32 | def rnn_parameters(self):
33 | return list(self.parameters(True))
34 |
35 | def to(self, device):
36 | if not self.device == device:
37 | super().to(device)
38 | self.device = device
39 |
40 | def forward(self, x, hidden=None, rnn_start=None):
41 | u = self.in_proj(x)
42 | v = u[0]
43 | f = u[1]
44 | if hidden is None:
45 | hidden = torch.zeros((v.shape[0], 1, self.d_model * 2), device=v.device)
46 | else:
47 | hidden = hidden.transpose(0, 1)
48 | hidden_pre, hidden_middle = torch.chunk(hidden, 2, -1)
49 | f = torch.sigmoid(f)
50 | v = torch.tanh(v)
51 | if rnn_start is not None:
52 | f = f * (1 - rnn_start)
53 | if torch.all(hidden_pre == 0) and not self.device == torch.device('cpu'):
54 | v = real_scan_tie_input_gate(v.contiguous(), f.contiguous())
55 | hidden_pre = v[:, -1:, :]
56 | else:
57 | v, hidden_pre = scan_cpu(v, f, hidden_pre)
58 | u = self.middle_proj.forward(v)
59 | f = torch.sigmoid(u[0])
60 | i = torch.sigmoid(u[1])
61 | o = torch.sigmoid(u[2])
62 | z = torch.tanh(u[3])
63 | if rnn_start is not None:
64 | f = f * (1 - rnn_start)
65 | if torch.all(hidden_middle == 0) and not self.device == torch.device('cpu'):
66 | v = i * z
67 | out = real_scan_tie_input_gate(v.contiguous(), f.contiguous())
68 | hidden_middle = out[:, -1:, :]
69 | else:
70 | out, hidden_middle = scan_cpu(i * z, f, hidden_middle)
71 | out = out * o
72 | out = self.out_proj(out)
73 | hidden = torch.cat((hidden_pre, hidden_middle), dim=-1)
74 | hidden = hidden.transpose(0, 1)
75 | return out, hidden
76 |
--------------------------------------------------------------------------------
/offpolicy_rnn/models/gilr_lstm/scan_triton/real_rnn_fast_pscan.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import triton
4 | import triton.language as tl
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Function
8 | # from fastpscan.cuda_v3 import fn as pscan_cuda_fn
9 | # from fastpscan.cuda_v4 import fn as pscan_cuda_fn
10 | from fastpscan.triton_v2 import fn as pscan_cuda_fn
11 |
12 |
13 | def pscan_fast(v, f):
14 | B, L, C = v.shape
15 | desire_B = max(B, 4)
16 | h = torch.zeros((desire_B, C), device=v.device, dtype=v.dtype)
17 |
18 | v = v * (1 - f)
19 | if L < 1024:
20 | v = torch.cat((v, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2)
21 | f = torch.cat((f, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2)
22 | if desire_B > B:
23 | v = torch.cat((v, torch.zeros((desire_B - B, 1024, C), device=v.device, dtype=v.dtype)), dim=0)
24 | f = torch.cat((f, torch.zeros((desire_B - B, 1024, C), device=f.device, dtype=f.dtype)), dim=0)
25 |
26 | Y = pscan_cuda_fn(f, v, h)
27 | Y = Y[:B, :L, :]
28 | return Y
29 |
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/offpolicy_rnn/models/gilr_lstm/scan_triton/real_rnn_tie_input_gate_cpu.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def real_scan_tie_input_gate_no_grad(v, f, h=None):
5 | B, L, C = v.shape
6 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256'
7 | output = torch.zeros_like(v)
8 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h
9 | for l in range(L):
10 | f_item = f[:, l:l+1, :]
11 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item)
12 | output[:, l:l+1, :] = h_new
13 | h = h_new
14 | return output, h
15 |
16 | def real_scan_tie_input_gate_no_grad_fuse(v, f, h=None):
17 | B, L, C = v.shape
18 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256'
19 | output = torch.zeros_like(v)
20 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h
21 | for l in range(L):
22 | f_item = torch.sigmoid(f[:, l:l+1, :])
23 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item)
24 | output[:, l:l+1, :] = h_new
25 | h = h_new
26 | return output, h
27 |
28 | scan_cpu = real_scan_tie_input_gate_no_grad
29 | scan_cpu_fuse = real_scan_tie_input_gate_no_grad_fuse
30 |
31 |
--------------------------------------------------------------------------------
/offpolicy_rnn/models/lru/scan_triton/complex_rnn_cpu.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def forward_sequential_scan_complex_no_grad(v_real, v_imag, f_real, f_imag, h_real=None, h_imag=None):
5 | B, L, C = v_real.shape
6 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256'
7 |
8 | hidden_real = torch.zeros_like(v_real)
9 | hidden_imag = torch.zeros_like(v_imag)
10 |
11 | h_real = torch.zeros((B, 1, v_real.shape[-1]), device=v_real.device, dtype=v_real.dtype) if h_real is None else h_real
12 | h_imag = torch.zeros((B, 1, v_imag.shape[-1]), device=v_imag.device, dtype=v_imag.dtype) if h_imag is None else h_imag
13 |
14 | for l in range(L):
15 | f_real_item = f_real[:, l:l+1, :]
16 | f_imag_item = f_imag[:, l:l+1, :]
17 | h_real_new = h_real * f_real_item - h_imag * f_imag_item + v_real[:, l:l+1, :]
18 | h_imag_new = h_real * f_imag_item + h_imag * f_real_item + v_imag[:, l:l+1, :]
19 |
20 | hidden_real[:, l:l+1, :] = h_real_new
21 | hidden_imag[:, l:l+1, :] = h_imag_new
22 |
23 | h_real = h_real_new
24 | h_imag = h_imag_new
25 |
26 | return hidden_real, hidden_imag, h_real, h_imag
27 |
28 | complex_scan_cpu = forward_sequential_scan_complex_no_grad
29 |
30 |
31 |
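This is the complex-valued analogue of the gated scan: per channel, h_t = f_t * h_{t-1} + v_t with complex f_t and v_t, computed on separate real and imaginary parts. A short equivalence check against torch's native complex arithmetic (illustrative only):

    import torch
    # from offpolicy_rnn.models.lru.scan_triton.complex_rnn_cpu import complex_scan_cpu  # per repo layout

    B, L, C = 1, 4, 2
    v = torch.complex(torch.randn(B, L, C), torch.randn(B, L, C))
    f = torch.complex(0.8 * torch.rand(B, L, C), 0.1 * torch.rand(B, L, C))   # decaying complex gates
    hr, hi, _, _ = complex_scan_cpu(v.real, v.imag, f.real, f.imag)

    h = torch.zeros(B, 1, C, dtype=torch.cfloat)
    for l in range(L):
        h = h * f[:, l:l+1] + v[:, l:l+1]
        assert torch.allclose(hr[:, l:l+1], h.real, atol=1e-5)
        assert torch.allclose(hi[:, l:l+1], h.imag, atol=1e-5)
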
--------------------------------------------------------------------------------
/offpolicy_rnn/models/mlp_base.py:
--------------------------------------------------------------------------------
1 | from .rnn_base import RNNBase
2 | from .RNNHidden import RNNHidden
3 | class MLPBase(RNNBase):
4 | def __init__(self, input_size, output_size, hidden_size_list, activation):
5 | super().__init__(input_size, output_size, hidden_size_list, activation, ['fc'] * len(activation))
6 | self.empty_hidden_state = RNNHidden(0, [])
7 |
8 | def meta_forward(self, x, h=None, require_full_hidden=False):
9 | return super(MLPBase, self).meta_forward(x, self.empty_hidden_state, False)[0]
10 |
11 | def forward(self, x):
12 | return self.meta_forward(x)
13 |
--------------------------------------------------------------------------------
/offpolicy_rnn/models/multi_ensemble_linear_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from typing import Dict, List, Union, Tuple, Optional
6 |
7 |
8 | class MultiEnsembleLinear(nn.Module):
9 | def __init__(
10 | self,
11 | input_dim: int,
12 | output_dim: int,
13 | num_ensemble: int,
14 | multi_num: int,
15 | bias: bool=True,
16 | desire_ndim: int=None
17 | ) -> None:
18 | super().__init__()
19 | self.use_bias = bias
20 | self.desire_ndim = desire_ndim
21 | self.num_ensemble = num_ensemble
22 |
23 | self.register_parameter("weight", nn.Parameter(torch.zeros(multi_num, num_ensemble, input_dim, output_dim)))
24 | if self.use_bias:
25 | self.register_parameter("bias", nn.Parameter(torch.zeros(multi_num, num_ensemble, 1, output_dim)))
26 |
27 | nn.init.trunc_normal_(self.weight, std=1/(2*input_dim**0.5))
28 | self.device = torch.device('cpu')
29 |
30 | def forward(self, x: torch.Tensor) -> torch.Tensor:
31 |         assert self.desire_ndim is not None, 'for an MLP network (data shape (B, C)) desire_ndim should be 3; for an RNN network (data shape (B, L, C)) desire_ndim should be 4; got None'
32 |
33 | weight = self.weight
34 | if self.use_bias:
35 | bias = self.bias
36 | else:
37 | bias = None
38 |
39 | if len(x.shape) == 2:
40 | assert self.desire_ndim == 3
41 | x = torch.einsum('ij,cbjk->cbik', x, weight)
42 | elif len(x.shape) == 3:
43 | if self.desire_ndim == 3:
44 |                 # output shape: (multi_num, num_ensemble, batch, output_dim)
45 | x = torch.einsum('bij,cbjk->cbik', x, weight)
46 | elif self.desire_ndim == 4:
47 | x = torch.einsum('hij,cbjk->cbhik', x, weight)
48 | else:
49 | raise NotImplementedError
50 | elif len(x.shape) == 4:
51 | if self.desire_ndim == 4:
52 | x = torch.einsum('bhij,cbjk->cbhik', x, weight)
53 | else:
54 | raise NotImplementedError
55 | elif len(x.shape) == 5:
56 | assert self.desire_ndim == 4
57 | x = torch.einsum('cbhij,cbjk->cbhik', x, weight)
58 | if bias is not None:
59 | if self.desire_ndim == 3:
60 | x = x + bias
61 | elif self.desire_ndim == 4:
62 | x = x + bias.unsqueeze(2)
63 | else:
64 | raise NotImplementedError
65 | return x
66 |
67 | def to(self, device):
68 | if not device == self.device:
69 | self.device = device
70 | super().to(device)
71 | self.weight = self.weight.to(self.device)
72 | if self.use_bias:
73 | self.bias = self.bias.to(self.device)
74 |
75 |
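A shape-check sketch for MultiEnsembleLinear in MLP mode (desire_ndim=3), illustrative only:

    import torch
    # from offpolicy_rnn.models.multi_ensemble_linear_model import MultiEnsembleLinear  # per repo layout

    layer = MultiEnsembleLinear(input_dim=4, output_dim=8, num_ensemble=3, multi_num=2, desire_ndim=3)
    x = torch.randn(5, 4)        # (batch, input_dim)
    y = layer(x)                 # einsum 'ij,cbjk->cbik' plus broadcast bias
    print(y.shape)               # (2, 3, 5, 8) = (multi_num, num_ensemble, batch, output_dim)
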
--------------------------------------------------------------------------------
/offpolicy_rnn/models/readme.md:
--------------------------------------------------------------------------------
1 | # Implementations of Advanced RNNs
2 |
3 | Note: Most of these implementations are sourced from GitHub. We have standardized all `forward` functions in these models to conform to the standard RNN API.
4 |
5 | - **conv1d**: We wrap the 1D-convolutional layer as a special kind of RNN.
6 | - **gilr**: A linear RNN implemented with Triton.
7 | - **gilr_lstm**: An LSTM architecture simulated using `gilr`.
8 | - **lru**: Linear Recurrent Unit, implemented with Triton.
9 | - **s6**: Mamba, implemented with Triton.
10 | - **smamba**: Mamba, implemented with officially released CUDA code. An additional `start` variable is introduced as an input, resetting the hidden state to 0 when the `start` flag is true.
11 | - **flash_attention**: GPT implemented with flash_attention.
12 |
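For reference, the standardized calling convention looks roughly like the sketch below (illustrative, using the GILR layer as an example; the exact hidden-state layout depends on the model): pass `hidden=None` at the start of an episode and feed the returned hidden state back in on subsequent calls.

    import torch
    # from offpolicy_rnn.models.gilr.gilr import GILRLayer  # import path per the repository layout

    layer = GILRLayer(input_dim=8, output_dim=16)
    hidden = None                                 # None means "start from a zero state"
    for t in range(5):                            # act step by step, carrying the hidden state forward
        x_t = torch.randn(1, 1, 8)                # (batch, seq_len=1, input_dim)
        out, hidden = layer(x_t, hidden=hidden)   # out: (1, 1, 16)
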
--------------------------------------------------------------------------------
/offpolicy_rnn/models/s6/selective_scan/cpu_scan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import einsum, rearrange, repeat
3 |
4 |
5 | # credit: https://github.com/johnma2006/mamba-minimal/blob/master/model.py#L275
6 | def selective_scan_cpu(u, delta, A, B, C, D, start, initial_state=None):
7 | """Does selective scan algorithm. See:
8 | - Section 2 State Space Models in the Mamba paper [1]
9 | - Algorithm 2 in Section 3.2 in the Mamba paper [1]
10 | - run_SSM(A, B, C, u) in The Annotated S4 [2]
11 |
12 | This is the classic discrete state space formula:
13 | x(t + 1) = Ax(t) + Bu(t)
14 | y(t) = Cx(t) + Du(t)
15 |     except B and C (and the step size delta, which is used for discretization) are dependent on the input u(t).
16 |
17 | Args:
18 | u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)
19 | delta: shape (b, l, d_in)
20 | A: shape (d_in, n)
21 | B: shape (b, l, n)
22 | C: shape (b, l, n)
23 | D: shape (d_in,)
24 | start: (b, l, 1)
25 |
26 | Returns:
27 | output: shape (b, l, d_in)
28 |
29 | Official Implementation:
30 | selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
31 |         Note: I refactored some parts out of `selective_scan_ref`, so the functionality doesn't match exactly.
32 |
33 | """
34 | original_dtype = u.dtype
35 | u, delta, A, B, C, D = map(lambda x: x.float(), (u, delta, A, B, C, D))
36 | (b, l, d_in) = u.shape
37 | n = A.shape[1]
38 | start = start.squeeze(-1)
39 | # Discretize continuous parameters (A, B)
40 | # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
41 | # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
42 | # "A is the more important term and the performance doesn't change much with the simplification on B"
43 | deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
44 | deltaA = einsum(deltaA, 1 - start, 'b l d_in n, b l -> b l d_in n')
45 | deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
46 |
47 | # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
48 | # Note that the below is sequential, while the official implementation does a much faster parallel scan that
49 | # is additionally hardware-aware (like FlashAttention).
50 | x = torch.zeros((b, d_in, n), device=deltaA.device)
51 | if initial_state is not None:
52 | x += initial_state
53 | ys = []
54 | for i in range(l):
55 | x = deltaA[:, i] * x + deltaB_u[:, i]
56 | y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
57 | ys.append(y)
58 | y = torch.stack(ys, dim=1) # shape (b, l, d_in)
59 |
60 | y = y + u * D[None, None, :]
61 |
62 | return y.to(original_dtype), x
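
A small shape-oriented sketch of calling selective_scan_cpu, illustrative only (random values):

    import torch
    # from offpolicy_rnn.models.s6.selective_scan.cpu_scan import selective_scan_cpu  # per repo layout

    b, l, d_in, n = 2, 6, 4, 3
    u     = torch.randn(b, l, d_in)
    delta = torch.rand(b, l, d_in)
    A     = -torch.rand(d_in, n)                  # negative entries keep exp(delta * A) a decay
    B     = torch.randn(b, l, n)
    C     = torch.randn(b, l, n)
    D     = torch.randn(d_in)
    start = torch.zeros(b, l, 1)
    start[:, 0] = 1.0                             # reset the SSM state at the first step
    y, last_state = selective_scan_cpu(u, delta, A, B, C, D, start)
    print(y.shape, last_state.shape)              # (b, l, d_in) and (b, d_in, n)
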
--------------------------------------------------------------------------------
/offpolicy_rnn/models/smamba/mamba_ssm/ops/triton/layernorm_cpu.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False):
7 | dtype = x.dtype
8 | if residual_in_fp32:
9 | weight = weight.float()
10 | bias = bias.float() if bias is not None else None
11 | if residual_in_fp32:
12 | x = x.float()
13 | residual = residual.float() if residual is not None else residual
14 | if residual is not None:
15 | x = (x + residual).to(x.dtype)
16 | out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
17 | dtype
18 | )
19 | return out if not prenorm else (out, x)
20 |
21 |
22 | def rms_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False):
23 | dtype = x.dtype
24 | if residual_in_fp32:
25 | weight = weight.float()
26 | bias = bias.float() if bias is not None else None
27 | if residual_in_fp32:
28 | x = x.float()
29 | residual = residual.float() if residual is not None else residual
30 | if residual is not None:
31 | x = (x + residual).to(x.dtype)
32 | rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
33 | out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
34 | out = out.to(dtype)
35 | return out if not prenorm else (out, x)
36 |
37 | class RMSNorm(torch.nn.Module):
38 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
39 | factory_kwargs = {"device": device, "dtype": dtype}
40 | super().__init__()
41 | self.eps = eps
42 | self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
43 | self.register_parameter("bias", None)
44 | self.reset_parameters()
45 |
46 | def reset_parameters(self):
47 | torch.nn.init.ones_(self.weight)
48 |
49 | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
50 | return rms_norm_fn(
51 | x,
52 | self.weight,
53 | self.bias,
54 | residual=residual,
55 | eps=self.eps,
56 | prenorm=prenorm,
57 | residual_in_fp32=residual_in_fp32,
58 | )
59 |
60 |
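rms_norm_fn above normalizes by the root mean square of the last dimension, out = x / sqrt(mean(x^2) + eps) * weight (plus bias when present). A quick check against that formula (illustrative only):

    import torch
    # from offpolicy_rnn.models.smamba.mamba_ssm.ops.triton.layernorm_cpu import RMSNorm  # per repo layout

    norm = RMSNorm(hidden_size=8)
    x = torch.randn(2, 3, 8)
    y = norm(x)
    ref = x / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + norm.eps) * norm.weight
    assert torch.allclose(y, ref, atol=1e-6)
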
--------------------------------------------------------------------------------
/offpolicy_rnn/policy_value_models/contextual_sac_policy.py:
--------------------------------------------------------------------------------
1 | from .contextual_sac_policy_single_head import ContextualSACPolicySingleHead
2 | from .contextual_sac_policy_double_head import ContextualSACPolicyDoubleHead
3 |
4 | class ContextualSACPolicy(ContextualSACPolicySingleHead):
5 | # class ContextualSACPolicy(ContextualSACPolicyDoubleHead):
6 | def __init__(self, state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations,
7 | embedding_layer_type, uni_model_hidden,
8 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim: int=0,
9 | reward_input=False, last_action_input=True, last_state_input=False, separate_encoder=False,
10 | output_logstd=True, name='ContextualSACPolicy'):
11 | super().__init__(state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations,
12 | embedding_layer_type, uni_model_hidden,
13 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim, reward_input,
14 | last_action_input, last_state_input, separate_encoder, output_logstd,
15 | name=name)
16 |
--------------------------------------------------------------------------------
/offpolicy_rnn/policy_value_models/contextual_td3_policy.py:
--------------------------------------------------------------------------------
1 | from .contextual_sac_policy import ContextualSACPolicy
2 | from typing import List, Union, Tuple, Dict, Optional
3 | from ..models.RNNHidden import RNNHidden
4 | import torch
5 |
6 | class ContextualTD3Policy(ContextualSACPolicy):
7 | def __init__(self, state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations,
8 | embedding_layer_type, uni_model_hidden,
9 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim: int=0,
10 | reward_input=False, last_action_input=True, last_state_input=False, separate_encoder=False, sample_std=0.1):
11 | super().__init__(state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations,
12 | embedding_layer_type, uni_model_hidden,
13 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim, reward_input,
14 | last_action_input, last_state_input, separate_encoder, output_logstd=False, name='ContextualTD3Policy')
15 | self.sample_std = sample_std
16 |
17 |
18 | def forward(self, state: torch.Tensor, lst_state: torch.Tensor, lst_action: torch.Tensor, rnn_memory: Optional[RNNHidden], reward: Optional[torch.Tensor]=None, detach_embedding: bool=False) -> Tuple[
19 | torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, RNNHidden, Optional[RNNHidden]
20 | ]:
21 | """
22 | :param state: (seq_len, dim) or (batch_size, seq_len, dim)
23 | :param lst_state: (seq_len, dim) or (batch_size, seq_len, dim)
24 | :param lst_action: (seq_len, dim) or (batch_size, seq_len, dim)
25 |         :param reward: (seq_len, 1) or (batch_size, seq_len, 1)
26 | :param rnn_memory:
27 | :return:
28 | """
29 | embedding_input = self.get_embedding_input(state, lst_state, lst_action, reward)
30 | model_output, rnn_memory, embedding_output, full_rnn_memory = self.meta_forward(embedding_input, state, rnn_memory, detach_embedding)
31 | action_mean = torch.tanh(model_output)
32 | action_sample = torch.tanh(model_output) + torch.randn_like(model_output) * self.sample_std
33 | action_sample = torch.clamp(action_sample, -1, 1)
34 | # not used in TD3
35 | log_prob = torch.zeros_like(action_sample)
36 | return action_mean, embedding_output, action_sample, log_prob, rnn_memory, full_rnn_memory
37 |
--------------------------------------------------------------------------------
/offpolicy_rnn/policy_value_models/contextual_td3_value.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from ..models.contextual_model import ContextualModel
3 | import torch
4 | import os
5 | from ..models.RNNHidden import RNNHidden
6 | from typing import List, Union, Tuple, Dict, Optional
7 | from .utils import nearest_power_of_two, nearest_power_of_two_half
8 | from .contextual_sac_value import ContextualSACValue
9 | class ContextualTD3Value(ContextualSACValue):
10 | def __init__(self, state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations,
11 | embedding_layer_type, uni_model_hidden,
12 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim: int=0,
13 | reward_input=False, last_action_input=True, last_state_input=False, separate_encoder=False):
14 | super(ContextualTD3Value, self).__init__(state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations,
15 | embedding_layer_type, uni_model_hidden,
16 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim, reward_input,
17 | last_action_input, last_state_input, separate_encoder, name='ContextualTD3Value')
18 |
--------------------------------------------------------------------------------
/offpolicy_rnn/policy_value_models/make_models.py:
--------------------------------------------------------------------------------
1 | from .contextual_sac_policy import ContextualSACPolicy
2 | from .contextual_td3_policy import ContextualTD3Policy
3 | from .contextual_sac_value import ContextualSACValue
4 | from .contextual_td3_value import ContextualTD3Value
5 | from .contextual_sac_discrete_policy import ContextualSACDiscretePolicy
6 | from .contextual_sac_discrete_value import ContextualSACDiscreteValue
7 | import torch
8 | from typing import Optional, Union
9 |
10 | def make_policy_model(policy_args, base_alg_name, discrete) -> Union[ContextualTD3Policy, ContextualSACPolicy, ContextualSACDiscretePolicy]:
11 | if base_alg_name == 'sac':
12 | if discrete:
13 | policy = ContextualSACDiscretePolicy(**policy_args)
14 | else:
15 | policy = ContextualSACPolicy(**policy_args)
16 | elif base_alg_name == 'td3':
17 | policy = ContextualTD3Policy(**policy_args)
18 | return policy
19 |
20 | def make_value_model(value_args, base_alg_name, discrete) -> Union[ContextualTD3Value, ContextualSACValue, ContextualSACDiscreteValue]:
21 | if base_alg_name == 'sac':
22 | if discrete:
23 | value = ContextualSACDiscreteValue(**value_args)
24 | else:
25 | value = ContextualSACValue(**value_args)
26 | elif base_alg_name == 'td3':
27 | value = ContextualTD3Value(**value_args)
28 | return value
29 |
--------------------------------------------------------------------------------
/offpolicy_rnn/policy_value_models/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | def nearest_power_of_two_half(x):
4 |     # target is 0.5 * x
5 | target = 0.5 * x
6 |     # take log2 and round to the nearest integer
7 | nearest_exp = round(math.log(target, 2))
8 |     # compute 2 ** nearest_exp (never below 2 ** 0)
9 | if nearest_exp < 0:
10 | nearest_exp = 0
11 | nearest_power = int(math.ceil(2 ** nearest_exp))
12 |
13 | return nearest_power
14 |
15 | def nearest_power_of_two(x):
16 |     # target is x itself
17 | target = x
18 |     # take log2 and round up to the next integer
19 | nearest_exp = int(math.ceil(math.log(target, 2)))
20 |     # compute 2 ** nearest_exp (never below 2 ** 0)
21 | if nearest_exp < 0:
22 | nearest_exp = 0
23 | nearest_power = int(math.ceil(2 ** nearest_exp))
24 | return nearest_power
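
Worked examples (illustrative): nearest_power_of_two rounds up to the next power of two, while nearest_power_of_two_half targets roughly half the input.

    # from offpolicy_rnn.policy_value_models.utils import nearest_power_of_two, nearest_power_of_two_half  # per repo layout

    assert nearest_power_of_two(100) == 128        # ceil(log2(100)) = 7  -> 2 ** 7
    assert nearest_power_of_two_half(100) == 64    # round(log2(50)) = 6  -> 2 ** 6
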
--------------------------------------------------------------------------------
/offpolicy_rnn/utility/ValueScheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | class CosineScheduler:
4 | def __init__(self, total_steps, initial_value, min_value):
5 | """
6 |         Initialize the cosine value scheduler.
7 |         :param total_steps: total number of steps
8 |         :param initial_value: initial value
9 |         :param min_value: minimum value
10 | """
11 | self.total_steps = total_steps
12 | self.initial_value = initial_value
13 | self.min_value = min_value
14 | self.current_step = 0
15 | self._value = None
16 | self._last_value = None
17 |
18 | def validate(self):
19 | if self._value < self.min_value:
20 | self._value = self.min_value
21 |
22 | if self._last_value is not None:
23 | if self._last_value < self._value:
24 | self._value = self._last_value
25 | else:
26 | self._last_value = self._value
27 |
28 | def step(self):
29 | """
30 |         Advance one step and compute the current value.
31 | """
32 | self.current_step += 1
33 |
34 | self._value = self.min_value + 0.5 * (self.initial_value - self.min_value) * (1 + math.cos(self.current_step / self.total_steps * math.pi))
35 | self.validate()
36 | return self._value
37 |
38 | def get_value(self):
39 | """
40 |         Return the current value.
41 | """
42 | if self.current_step == 0:
43 | return self.initial_value
44 | else:
45 | return self._value
46 |
47 | class LinearScheduler:
48 | def __init__(self, total_steps, initial_value, end_value):
49 | """
50 |         Initialize the linear value scheduler.
51 |         :param total_steps: total number of steps
52 |         :param initial_value: initial value
53 |         :param end_value: final value
54 | """
55 | self.total_steps = total_steps
56 | self.initial_value = initial_value
57 | self.end_value = end_value
58 | self.current_step = 0
59 | self._value = self.initial_value
60 |
61 | def validate(self):
62 | if self.current_step >= self.total_steps:
63 | self._value = self.end_value
64 |
65 | def step(self):
66 | """
67 |         Advance one step and compute the current value.
68 | """
69 | self.current_step += 1
70 | self._value = self.current_step / self.total_steps * (self.end_value - self.initial_value) + self.initial_value
71 | self.validate()
72 | return self._value
73 |
74 | def get_value(self):
75 | """
76 |         Return the current value.
77 | """
78 | if self.current_step == 0:
79 | return self.initial_value
80 | else:
81 | return self._value
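
A short usage sketch for the schedulers (illustrative only): call step() once per iteration and read off the decayed value.

    # from offpolicy_rnn.utility.ValueScheduler import CosineScheduler, LinearScheduler  # per repo layout

    sched = CosineScheduler(total_steps=100, initial_value=1e-3, min_value=1e-5)
    for _ in range(100):
        value = sched.step()                 # decays from 1e-3 towards 1e-5 along a half cosine
    print(sched.get_value())                 # 1e-05 once total_steps is reached

    lin = LinearScheduler(total_steps=100, initial_value=1.0, end_value=0.0)
    print([round(lin.step(), 2) for _ in range(3)])   # [0.99, 0.98, 0.97]
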
--------------------------------------------------------------------------------
/offpolicy_rnn/utility/alg_init.py:
--------------------------------------------------------------------------------
1 | from ..algorithm.sac import SAC
2 | from ..algorithm.sac_mlp import SAC_MLP
3 | from ..parameter.ParameterSAC import Parameter
4 | from ..algorithm.sac_rnn_slice import SACRNNSlice
5 | from ..algorithm.sac_mlp_redq import SAC_MLP_REDQ
6 | from ..algorithm.sac_full_length_rnn_ensembleQ import SACFullLengthRNNEnsembleQ
7 | from ..algorithm.sac_full_length_rnn_redq import SACFullLengthRNNREDQ
8 | from ..algorithm.sac_mlp_redq_ensemble_q import SAC_MLP_REDQ_EnsembleQ
9 | from ..algorithm.sac_full_length_rnn_redq_sep_optim import SACFullLengthRNNREDQ_SEP_OPTIM
10 | from ..algorithm.sac_full_length_rnn_ensembleQ_sep_optim import SACFullLengthRNNENSEMBLEQ_SEP_OPTIM
11 | from ..algorithm.td3_full_length_rnn_ensembleQ import TD3FullLengthRNNEnsembleQ
12 | from ..algorithm.td3_full_length_rnn_redq import TD3FullLengthRNNREDQ
13 | from ..algorithm.td3_full_length_rnn_redq_sep_optim import TD3FullLengthRNNREDQ_SEP_OPTIM
14 |
15 |
16 | def alg_init(parameter: Parameter) -> SAC:
17 | if parameter.alg_name == 'sac_no_train':
18 | sac = SAC(parameter)
19 | elif parameter.alg_name == 'sac_mlp':
20 | sac = SAC_MLP(parameter)
21 | elif parameter.alg_name == 'sac_mlp_redq':
22 | sac = SAC_MLP_REDQ(parameter)
23 | elif parameter.alg_name == 'sac_rnn_slice':
24 | assert parameter.rnn_slice_length > 0
25 | sac = SACRNNSlice(parameter)
26 | elif parameter.alg_name == 'sac_rnn_full_horizon_ensembleQ':
27 | sac = SACFullLengthRNNEnsembleQ(parameter)
28 | elif parameter.alg_name == 'sac_rnn_full_horizon_redQ':
29 | sac = SACFullLengthRNNREDQ(parameter)
30 | elif parameter.alg_name == 'sac_rnn_full_horizon_redQ_sep_optim': # STAR it!!!!
31 | sac = SACFullLengthRNNREDQ_SEP_OPTIM(parameter)
32 | elif parameter.alg_name == 'td3_rnn_full_horizon_redQ_sep_optim': # STAR it!!!!
33 | parameter.base_algorithm = 'td3'
34 | sac = TD3FullLengthRNNREDQ_SEP_OPTIM(parameter)
35 | elif parameter.alg_name == 'sac_rnn_full_horizon_ensemble_q_sep_optim':
36 | sac = SACFullLengthRNNENSEMBLEQ_SEP_OPTIM(parameter)
37 | elif parameter.alg_name == 'sac_mlp_redq_ensemble_q':
38 | sac = SAC_MLP_REDQ_EnsembleQ(parameter)
39 | elif parameter.alg_name == 'td3_rnn_full_horizon_ensembleQ':
40 | parameter.base_algorithm = 'td3'
41 | sac = TD3FullLengthRNNEnsembleQ(parameter)
42 | elif parameter.alg_name == 'td3_rnn_full_horizon_redQ':
43 | parameter.base_algorithm = 'td3'
44 | sac = TD3FullLengthRNNREDQ(parameter)
45 | else:
46 | raise NotImplementedError(f'Algorithm {parameter.alg_name} has not been implemented!')
47 | return sac
48 |
--------------------------------------------------------------------------------
/offpolicy_rnn/utility/count_parameters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | def count_parameters(model):
4 | """
5 | Counts the number of trainable parameters in a PyTorch model.
6 |
7 | Args:
8 | model (nn.Module): The PyTorch model to count parameters for.
9 |
10 | Returns:
11 | int: The total number of trainable parameters.
12 | """
13 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
14 |
--------------------------------------------------------------------------------
/offpolicy_rnn/utility/q_value_guard.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class QValueGuard:
5 | """
6 | avoid Q loss from diverging
7 | usage:
8 | target_Q = r + gamma * q_value_guard.clamp(next_q)
9 | q_value_guard.update(target_Q)
10 | """
11 | def __init__(self, guard_min=True, guard_max=True, decay_ratio=1.0):
12 | self._min = 1000000 if guard_min else None
13 | self._max = -1000000 if guard_max else None
14 | self._init_flag = True
15 | self._decay_ratio = decay_ratio
16 |
17 | def reset(self):
18 | self._min = 1000000 if self._min is not None else None
19 | self._max = -1000000 if self._max is not None else None
20 | self._init_flag = True
21 |
22 | def clamp(self, value: torch.Tensor):
23 | if self._init_flag:
24 | self._min = value.min().item() if self._min is not None else None
25 | self._max = value.max().item() if self._max is not None else None
26 | self._init_flag = False
27 | return value.clamp(min=self._min, max=self._max)
28 |
29 | def update(self, value):
30 | value_min = value.min().item()
31 | value_max = value.max().item()
32 |         self._min = min(self._min, value_min) if self._min is not None else None
33 |         self._max = max(self._max, value_max) if self._max is not None else None
34 | if self._decay_ratio < 1:
35 | if self._min is not None:
36 | self._min = self._decay_ratio * self._min + (1-self._decay_ratio) * value_min
37 | if self._max is not None:
38 | self._max = self._decay_ratio * self._max + (1-self._decay_ratio) * value_max
39 |
40 | def get_min(self) -> float:
41 | return self._min
42 |
43 | def get_max(self) -> float:
44 | return self._max
45 |
46 |
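A usage sketch matching the docstring above (illustrative only): clamp the bootstrapped target to the value range seen so far, then widen that range with the new targets.

    import torch
    # from offpolicy_rnn.utility.q_value_guard import QValueGuard  # per repo layout

    guard = QValueGuard(guard_min=True, guard_max=True)
    reward, gamma = torch.randn(32, 1), 0.99
    next_q = torch.randn(32, 1)
    target_q = reward + gamma * guard.clamp(next_q)   # first call initializes the bounds from next_q
    guard.update(target_q)                            # widen the bounds with the new targets
    print(guard.get_min(), guard.get_max())
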
--------------------------------------------------------------------------------
/offpolicy_rnn/utility/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import inspect
3 | import numpy as np
4 |
5 | class Timer:
6 | def __init__(self):
7 | self.check_points = {}
8 | self.points_time = {}
9 | self.need_summary = {}
10 | self.init_time = time.time()
11 |
12 | def reset(self):
13 | self.check_points = {}
14 | self.points_time = {}
15 | self.need_summary = {}
16 |
17 | @staticmethod
18 | def file_func_line(stack=1):
19 | frame = inspect.stack()[stack][0]
20 | info = inspect.getframeinfo(frame)
21 | return info.filename, info.function, info.lineno
22 |
23 | @staticmethod
24 | def line(stack=2, short=False):
25 | file, func, lineo = Timer.file_func_line(stack)
26 | if short:
27 | return f"line_{lineo}_func_{func}"
28 | return f"line: {lineo}, func: {func}, file: {file}"
29 |
30 | def register_point(self, tag=None, stack=3, short=True, need_summary=True, level=0):
31 | if tag is None:
32 | tag = self.line(stack, short)
33 | if False and not tag.startswith('__'):
34 | print(f'arrive {tag}, time: {time.time() - self.init_time}, level: {level}')
35 | if level not in self.check_points:
36 | self.check_points[level] = []
37 | self.points_time[level] = []
38 | self.need_summary[level] = set()
39 | self.check_points[level].append(tag)
40 | self.points_time[level].append(time.time())
41 | if need_summary:
42 | self.need_summary[level].add(tag)
43 |
44 | def register_end(self, stack=4, level=0):
45 | self.register_point('__timer_end_unique', stack, need_summary=False, level=level)
46 |
47 | def summary(self, summation=False):
48 | if len(self.check_points) == 0:
49 | return dict()
50 | res = {}
51 | for level in self.check_points:
52 | self.register_point('__timer_finale_unique', level=level)
53 | res_tmp = {}
54 | for ind, item in enumerate(self.check_points[level][:-1]):
55 | time_now = self.points_time[level][ind]
56 | time_next = self.points_time[level][ind + 1]
57 | if item in res_tmp:
58 | res_tmp[item].append(time_next - time_now)
59 | else:
60 | res_tmp[item] = [time_next - time_now]
61 | for k, v in res_tmp.items():
62 | if k in self.need_summary[level]:
63 | res['period_' + k] = np.mean(v)
64 | if summation:
65 | res['sum_' + k] = np.sum(v)
66 | self.reset()
67 | return res
68 |
69 |
70 | def test_timer():
71 | timer = Timer()
72 | for i in range(4):
73 | timer.register_point()
74 | time.sleep(1)
75 | for k, v in timer.summary().items():
76 | print(f'{k}, {v}')
77 |
78 | if __name__ == '__main__':
79 | test_timer()
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | Box2D==2.3.10
2 | causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git#egg=causal_conv1d
3 | cocos==0.2.2
4 | dm_control==1.0.16
5 | dmc2gym @ git+https://github.com/denisyarats/dmc2gym.git#egg=dmc2gym
6 | einops==0.8.0
7 | flash_attn @ git+https://github.com/Dao-AILab/flash-attention.git#egg=flash_attn
8 | gin==0.1.006
9 | gym==0.21.0
10 | jax==0.4.30
11 | matplotlib==3.7.4
12 | meshcat==0.3.2
13 | ml_collections==0.1.1
14 | mujoco_py==2.1.2.14
15 | numpy==1.24.0
16 | omg==1.1.3
17 | packaging==24.1
18 | pandas==2.0.3
19 | psutil==5.9.7
20 | pybullet==3.2.5
21 | pycolab==1.2
22 | pygame==2.5.2
23 | pyglet==2.0.15
24 | pytest==7.4.4
25 | python_dateutil==2.8.2
26 | PyYAML==6.0.1
27 | roboschool==1.0.34
28 | scikit_learn==1.5.1
29 | scipy==1.14.0
30 | seaborn==0.13.2
31 | setuptools==66.0.0
32 | six==1.16.0
33 | smart_logger @ git+https://github.com/FanmingL/SmartLogger#egg=smart_logger
34 | tensorboardX==2.6.2.2
35 | tensorflow_cpu==2.13.1
36 | torch==2.1.2+cu121.with.pypi.cudnn
37 | tqdm==4.66.1
38 | transformers==4.39.3
39 | transforms3d==0.4.2
40 | mamba_ssm @ git+https://github.com/state-spaces/mamba#egg=mamba_ssm
41 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='offpolicy_rnn',
5 | version='0.0.1',
6 | author="Fan-Ming Luo",
7 | author_email="luofm@lamda.nju.edu.cn",
8 | description="An implementation of RNN policy and value function in off-policy RL",
9 | packages=find_packages(),
10 | install_requires=[
11 | ],
12 | )
13 |
--------------------------------------------------------------------------------