├── .gitignore ├── envs ├── credit_assign │ ├── __init__.py │ ├── catch.py │ └── key_to_door │ │ ├── common.py │ │ ├── env.py │ │ ├── game.py │ │ ├── key_to_door.py │ │ ├── objects.py │ │ ├── readme.md │ │ └── tvt_wrapper.py ├── dmc │ ├── __init__.py │ └── dmc_env.py ├── make_pomdp_env.py ├── memory_envs │ ├── __init__.py │ ├── configs │ │ ├── keytodoor.py │ │ ├── terminal_fns.py │ │ ├── tmaze_active.py │ │ ├── tmaze_passive.py │ │ └── visual_match.py │ ├── key_to_door │ │ ├── common.py │ │ ├── env.py │ │ ├── game.py │ │ ├── key_to_door.py │ │ ├── objects.py │ │ ├── readme.md │ │ ├── tvt_wrapper.py │ │ └── visual_match.py │ ├── make_env.py │ ├── readme.md │ └── tmaze.py ├── meta │ ├── __init__.py │ ├── dynamics_meta_env_wrapper.py │ ├── example_env.py │ ├── make_env.py │ ├── mujoco │ │ ├── ant.py │ │ ├── ant_dir.py │ │ ├── ant_goal.py │ │ ├── ant_multitask_base.py │ │ ├── assets │ │ │ ├── ant.xml │ │ │ └── low_gear_ratio_ant.xml │ │ ├── core │ │ │ ├── __init__.py │ │ │ └── serializable.py │ │ ├── half_cheetah.py │ │ ├── half_cheetah_dir.py │ │ ├── half_cheetah_vel.py │ │ ├── humanoid_dir.py │ │ └── mujoco_env.py │ ├── readme.md │ ├── toy_navigation │ │ ├── gridworld.py │ │ ├── point_robot.py │ │ └── wind.py │ └── wrappers.py ├── pomdp │ ├── __init__.py │ ├── readme.md │ └── wrappers.py ├── pomdp_config.py ├── readme.md ├── rl_generalization │ ├── .gitignore │ ├── LICENSE │ ├── README.md │ ├── setup.py │ ├── sunblaze_envs │ │ ├── README.md │ │ ├── __init__.py │ │ ├── assets │ │ │ └── vizdoom │ │ │ │ ├── .gitignore │ │ │ │ ├── basic.wad │ │ │ │ ├── basic_floor_ceiling_flipped.wad │ │ │ │ ├── basic_torches.wad │ │ │ │ ├── navigation.wad │ │ │ │ ├── navigation_floor_ceiling_flipped.wad │ │ │ │ ├── navigation_new_layout.wad │ │ │ │ ├── navigation_torches.wad │ │ │ │ ├── texture_set_a.txt │ │ │ │ ├── texture_set_b.txt │ │ │ │ ├── thing_set_a.txt │ │ │ │ └── thing_set_b.txt │ │ ├── base.py │ │ ├── breakout.py │ │ ├── classic_control.py │ │ ├── monitor.py │ │ ├── mujoco.py │ │ ├── physical_world.py │ │ ├── registration.py │ │ ├── space_invaders.py │ │ ├── time_limit.py │ │ ├── vizdoom.py │ │ └── wrappers.py │ └── test.py ├── torchkit │ ├── constant.py │ ├── core.py │ ├── distributions.py │ ├── modules.py │ ├── networks.py │ ├── policies_base.py │ ├── pytorch_utils.py │ └── serializable.py ├── utils │ ├── evaluation.py │ ├── helpers.py │ ├── logger.py │ └── system.py └── yang_domains │ ├── __init__.py │ ├── ant_reacher_top.py │ ├── assets │ ├── ant_reacher.xml │ ├── box_1d.xml │ ├── bump_1d.xml │ ├── bump_1d_model.xml │ ├── bumps_1d_model.xml │ ├── clockwise.png │ ├── grid_markers.xml │ ├── gripah │ │ ├── gripah.urdf │ │ ├── narrow_finger.STL │ │ ├── narrow_finger_rescaled.STL │ │ ├── narrow_finger_tip.STL │ │ ├── wide_finger.STL │ │ ├── wide_finger_rescaled.STL │ │ ├── wide_finger_tip.STL │ │ ├── wrist.STL │ │ └── wrist_rescaled.STL │ ├── gripah_asset.xml │ ├── gripah_body.xml │ ├── gripah_contact.xml │ ├── muj_gripper │ │ ├── c_base.stl │ │ ├── c_forearm.stl │ │ ├── c_robotiq_85_gripper_joint_3_L.stl │ │ ├── c_robotiq_85_gripper_joint_3_R.stl │ │ ├── c_shoulder.stl │ │ ├── c_upperarm.stl │ │ ├── c_wrist1.stl │ │ ├── c_wrist2.stl │ │ ├── c_wrist3.stl │ │ ├── glass_cup.stl │ │ ├── glass_cup_2.stl │ │ ├── glass_cup_3.stl │ │ ├── inner_finger_coarse.stl │ │ ├── inner_finger_fine.stl │ │ ├── inner_knuckle_coarse.stl │ │ ├── inner_knuckle_fine.stl │ │ ├── new_solo_cup.stl │ │ ├── outer_finger_coarse.stl │ │ ├── outer_finger_fine.stl │ │ ├── outer_knuckle_coarse.stl │ │ ├── outer_knuckle_fine.stl │ │ 
├── red_solo_cup.stl │ │ ├── robotiq_85_base_link_coarse.stl │ │ ├── robotiq_85_base_link_fine.stl │ │ ├── smaller_solo_cup.stl │ │ ├── solo_cup.stl │ │ ├── upd_solo_cup.stl │ │ ├── v_base.stl │ │ ├── v_forearm.stl │ │ ├── v_robotiq_85_gripper_joint_3_L.stl │ │ ├── v_robotiq_85_gripper_joint_3_R.stl │ │ ├── v_shoulder.stl │ │ ├── v_upperarm.stl │ │ ├── v_wrist1.stl │ │ ├── v_wrist2.stl │ │ └── v_wrist3.stl │ ├── objects │ │ ├── bump_40_mujoco.STL │ │ ├── bump_50_mujoco.STL │ │ ├── bump_80_mujoco.STL │ │ ├── plate_half.STL │ │ └── plate_whole.STL │ ├── topplate_model.xml │ └── ur5_reacher.xml │ ├── box_top.py │ ├── bump_mdp.py │ ├── bump_top.py │ ├── car.py │ ├── car_episodic.py │ ├── car_top.py │ ├── car_top_narrow.py │ ├── car_top_relative.py │ ├── cartpole_balance.py │ ├── cartpole_var_len.py │ ├── dmc_acrobot.py │ ├── dmc_cart2pole.py │ ├── dmc_cart3pole.py │ ├── dmc_cartpole_b.py │ ├── dmc_cartpole_su.py │ ├── dmc_pendulum_su.py │ ├── dmc_walker_walk.py │ ├── pendulum_swingup_from_vrm.py │ ├── pendulum_var_len.py │ ├── pybullet_ant.py │ ├── pybullet_halfcheetah.py │ ├── reacher.py │ ├── robot_envs │ ├── __init__.py │ ├── assets │ │ ├── bumps │ │ │ ├── bump_40.stl │ │ │ ├── bump_40_blue.urdf │ │ │ ├── bump_40_red.urdf │ │ │ ├── bump_40_virtual.urdf │ │ │ ├── bump_50.stl │ │ │ └── bump_50.urdf │ │ ├── cup │ │ │ ├── cup.stl │ │ │ └── cup.urdf │ │ ├── plane │ │ │ ├── checker_blue.png │ │ │ ├── plane.mtl │ │ │ ├── plane.urdf │ │ │ └── plane100.obj │ │ ├── plate │ │ │ ├── plate.stl │ │ │ ├── plate.urdf │ │ │ ├── plate_half.urdf │ │ │ ├── plate_holder.stl │ │ │ ├── plate_holder.urdf │ │ │ ├── plate_lower_half.stl │ │ │ └── plate_upper_half.stl │ │ ├── shelf │ │ │ ├── shelf_back_board.stl │ │ │ ├── shelf_back_board.urdf │ │ │ ├── shelf_horizontal_board.stl │ │ │ ├── shelf_horizontal_board.urdf │ │ │ ├── shelf_side_board.stl │ │ │ └── shelf_side_board.urdf │ │ └── workspace │ │ │ ├── grid_mark.urdf │ │ │ ├── plane.obj │ │ │ ├── rail.urdf │ │ │ └── workspace.urdf │ ├── bump_target.py │ ├── bumps_diff.py │ ├── bumps_env.py │ ├── bumps_norm.py │ ├── bumps_norm_minmax.py │ ├── bumps_norm_punish.py │ ├── bumps_norm_real.py │ ├── bumps_norm_test.py │ ├── env.py │ └── top_plate.py │ ├── robots │ ├── __init__.py │ ├── assets │ │ ├── jaco │ │ │ ├── j2s7s300_gym.urdf │ │ │ └── meshes │ │ │ │ ├── arm.SLDPRT │ │ │ │ ├── arm.STL │ │ │ │ ├── arm.dae │ │ │ │ ├── arm_half_1.STL │ │ │ │ ├── arm_half_1.dae │ │ │ │ ├── arm_half_2.STL │ │ │ │ ├── arm_half_2.dae │ │ │ │ ├── arm_mico.STL │ │ │ │ ├── arm_mico.dae │ │ │ │ ├── base.STL │ │ │ │ ├── base.dae │ │ │ │ ├── finger_distal.STL │ │ │ │ ├── finger_distal.dae │ │ │ │ ├── finger_proximal.STL │ │ │ │ ├── finger_proximal.dae │ │ │ │ ├── forearm.STL │ │ │ │ ├── forearm.dae │ │ │ │ ├── forearm_mico.STL │ │ │ │ ├── forearm_mico.dae │ │ │ │ ├── hand_2finger.STL │ │ │ │ ├── hand_2finger.dae │ │ │ │ ├── hand_3finger.STL │ │ │ │ ├── hand_3finger.dae │ │ │ │ ├── ring_big.STL │ │ │ │ ├── ring_big.dae │ │ │ │ ├── ring_small.STL │ │ │ │ ├── ring_small.dae │ │ │ │ ├── shoulder.STL │ │ │ │ ├── shoulder.dae │ │ │ │ ├── wrist.STL │ │ │ │ ├── wrist.dae │ │ │ │ ├── wrist_spherical_1.STL │ │ │ │ ├── wrist_spherical_1.dae │ │ │ │ ├── wrist_spherical_2.STL │ │ │ │ └── wrist_spherical_2.dae │ │ ├── rdda │ │ │ ├── meshes │ │ │ │ ├── narrow_finger.STL │ │ │ │ ├── wide_finger.STL │ │ │ │ └── wrist.STL │ │ │ └── rdda.urdf │ │ ├── robotiq │ │ │ ├── README.md │ │ │ ├── meshes │ │ │ │ ├── robotiq-2f-base.mtl │ │ │ │ ├── robotiq-2f-base.obj │ │ │ │ ├── robotiq-2f-base.stl │ │ │ │ ├── 
robotiq-2f-coupler.mtl │ │ │ │ ├── robotiq-2f-coupler.obj │ │ │ │ ├── robotiq-2f-coupler.stl │ │ │ │ ├── robotiq-2f-driver.mtl │ │ │ │ ├── robotiq-2f-driver.obj │ │ │ │ ├── robotiq-2f-driver.stl │ │ │ │ ├── robotiq-2f-follower.mtl │ │ │ │ ├── robotiq-2f-follower.obj │ │ │ │ ├── robotiq-2f-follower.stl │ │ │ │ ├── robotiq-2f-pad.stl │ │ │ │ ├── robotiq-2f-spring_link.mtl │ │ │ │ ├── robotiq-2f-spring_link.obj │ │ │ │ └── robotiq-2f-spring_link.stl │ │ │ ├── robotiq_2f_85.urdf │ │ │ └── textures │ │ │ │ ├── gripper-2f_BaseColor.jpg │ │ │ │ ├── gripper-2f_Metallic.jpg │ │ │ │ ├── gripper-2f_Normal.jpg │ │ │ │ └── gripper-2f_Roughness.jpg │ │ ├── shovel │ │ │ ├── meshes │ │ │ │ ├── shovel_base.STL │ │ │ │ └── shovel_blade.STL │ │ │ └── shovel.urdf │ │ ├── spatula │ │ │ ├── meshes │ │ │ │ └── base.obj │ │ │ └── spatula-base.urdf │ │ ├── suction │ │ │ ├── meshes │ │ │ │ ├── base.obj │ │ │ │ ├── head.obj │ │ │ │ ├── mid.obj │ │ │ │ └── tip.obj │ │ │ ├── suction-base.urdf │ │ │ └── suction-head.urdf │ │ └── ur5 │ │ │ ├── collision │ │ │ ├── base.stl │ │ │ ├── forearm.stl │ │ │ ├── shoulder.stl │ │ │ ├── upperarm.stl │ │ │ ├── wrist1.stl │ │ │ ├── wrist2.stl │ │ │ └── wrist3.stl │ │ │ ├── license.txt │ │ │ ├── ur5.urdf │ │ │ └── visual │ │ │ ├── base.stl │ │ │ ├── forearm.stl │ │ │ ├── shoulder.stl │ │ │ ├── upperarm.stl │ │ │ ├── wrist1.stl │ │ │ ├── wrist2.stl │ │ │ └── wrist3.stl │ ├── end_effector.py │ ├── gripper.py │ ├── jaco.py │ ├── rdda.py │ ├── robot.py │ ├── robotiq.py │ ├── shovel.py │ ├── spatula.py │ ├── suction.py │ └── ur5.py │ ├── ur5_mdp_top.py │ ├── ur5_top.py │ ├── utils │ ├── __init__.py │ └── utils.py │ ├── water_maze.py │ ├── water_maze_dense.py │ ├── water_maze_simple.py │ ├── wrappers.py │ └── wrappers_for_image.py ├── gen_tmuxp_gpt_mujoco.py ├── gen_tmuxp_gpt_pomdp.py ├── gen_tmuxp_mamba_dmcontrol.py ├── gen_tmuxp_mamba_dynamics_rnd.py ├── gen_tmuxp_mamba_meta.py ├── gen_tmuxp_mamba_mujoco.py ├── gen_tmuxp_mamba_pomdp.py ├── main.py ├── offpolicy_rnn ├── __init__.py ├── algorithm │ ├── sac.py │ ├── sac_full_length_rnn_ensembleQ.py │ ├── sac_full_length_rnn_ensembleQ_sep_optim.py │ ├── sac_full_length_rnn_redq.py │ ├── sac_full_length_rnn_redq_sep_optim.py │ ├── sac_mlp.py │ ├── sac_mlp_redq.py │ ├── sac_mlp_redq_ensemble_q.py │ ├── sac_rnn_slice.py │ ├── td3_full_length_rnn_ensembleQ.py │ ├── td3_full_length_rnn_redq.py │ └── td3_full_length_rnn_redq_sep_optim.py ├── buffers │ ├── replay_memory.py │ ├── replay_memory_tail_padding.py │ └── transition_buffer │ │ ├── nested_replay_memory.py │ │ ├── nested_replay_memory_sub_traj.py │ │ └── replay_memory.py ├── config │ ├── common_config.yaml │ ├── experiment_config.yaml │ └── load_config.py ├── env_utils │ └── make_env.py ├── models │ ├── RNNHidden.py │ ├── contextual_model.py │ ├── conv1d │ │ ├── conv1d.py │ │ └── econv1d.py │ ├── ensemble_linear_model.py │ ├── flash_attention │ │ ├── TransformerFlashAttention.py │ │ └── gpt.py │ ├── gilr │ │ ├── egilr.py │ │ ├── gilr.py │ │ └── scan_triton │ │ │ ├── real_rnn_fast_pscan.py │ │ │ ├── real_rnn_tie_input_gate.py │ │ │ └── real_rnn_tie_input_gate_cpu.py │ ├── gilr_lstm │ │ ├── egilr_lstm.py │ │ ├── gilr_lstm.py │ │ └── scan_triton │ │ │ ├── real_rnn_fast_pscan.py │ │ │ ├── real_rnn_tie_input_gate.py │ │ │ └── real_rnn_tie_input_gate_cpu.py │ ├── lru │ │ ├── elru.py │ │ ├── lru.py │ │ └── scan_triton │ │ │ ├── complex_rnn.py │ │ │ ├── complex_rnn_2.py │ │ │ ├── complex_rnn_cpu.py │ │ │ ├── complex_rnn_jax.py │ │ │ └── pscan.py │ ├── mlp_base.py │ ├── multi_ensemble_linear_model.py 
│ ├── readme.md │ ├── rnn_base.py │ ├── s6 │ │ ├── mamba.py │ │ ├── mamba_no_conv.py │ │ └── selective_scan │ │ │ ├── cpu_scan.py │ │ │ └── triton_scan.py │ ├── smamba │ │ ├── mamba.py │ │ └── mamba_ssm │ │ │ └── ops │ │ │ ├── selective_scan_interface_new.py │ │ │ └── triton │ │ │ ├── layernorm.py │ │ │ ├── layernorm_cpu.py │ │ │ └── selective_state_update.py │ └── torch_utility.py ├── parameter │ └── ParameterSAC.py ├── policy_value_models │ ├── contextual_sac_discrete_policy.py │ ├── contextual_sac_discrete_value.py │ ├── contextual_sac_policy.py │ ├── contextual_sac_policy_double_head.py │ ├── contextual_sac_policy_single_head.py │ ├── contextual_sac_value.py │ ├── contextual_td3_policy.py │ ├── contextual_td3_value.py │ ├── make_models.py │ └── utils.py └── utility │ ├── ValueScheduler.py │ ├── alg_init.py │ ├── count_parameters.py │ ├── q_value_guard.py │ ├── sample_utility.py │ └── timer.py ├── readme.md ├── readme_cn.md ├── requirement.txt ├── results.md └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | build 3 | dist 4 | *.egg-info 5 | __pycache__ 6 | logfile 7 | logfile_bk 8 | .pytest_cache 9 | .DS_Store 10 | run_all.json 11 | -------------------------------------------------------------------------------- /envs/credit_assign/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import collections 3 | if sys.version_info.major == 3 and sys.version_info.minor >= 10 and not hasattr(collections, "Mapping"): 4 | setattr(collections, "Mapping", collections.abc.Mapping) 5 | setattr(collections, "Sequence", collections.abc.Sequence) 6 | 7 | from gym.envs.registration import register 8 | from .key_to_door import key_to_door 9 | 10 | delay_fn = lambda runs: (runs - 1) * 7 + 6 11 | 12 | for runs in [1, 2, 5, 10, 20, 40]: 13 | register( 14 | f"Catch-{runs}-v0", 15 | entry_point="envs.credit_assign.catch:DelayedCatch", 16 | kwargs=dict( 17 | delay=delay_fn(runs), 18 | flatten_img=True, 19 | one_hot_actions=False, 20 | ), 21 | max_episode_steps=delay_fn(runs), 22 | ) 23 | 24 | # optimal expected return: 1.0 * (~23) + 5.0 = 28. due to unknown number of respawned apples 25 | register( 26 | "KeytoDoor-SR-v0", 27 | entry_point="envs.credit_assign.key_to_door.tvt_wrapper:KeyToDoor", 28 | kwargs=dict( 29 | flatten_img=True, 30 | one_hot_actions=False, 31 | apple_reward=1.0, 32 | final_reward=5.0, 33 | respawn_every=20, # apple respawn after 20 steps 34 | REWARD_GRID=key_to_door.REWARD_GRID_SR, 35 | max_frames=key_to_door.MAX_FRAMES_PER_PHASE_SR, 36 | ), 37 | max_episode_steps=sum(key_to_door.MAX_FRAMES_PER_PHASE_SR.values()), 38 | ) 39 | -------------------------------------------------------------------------------- /envs/credit_assign/key_to_door/game.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Pycolab Game interface.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import abc 23 | import six 24 | 25 | 26 | @six.add_metaclass(abc.ABCMeta) 27 | class AbstractGame(object): 28 | """Abstract base class for Pycolab games.""" 29 | 30 | @abc.abstractmethod 31 | def __init__(self, rng, **settings): 32 | """Initialize the game.""" 33 | 34 | @abc.abstractproperty 35 | def num_actions(self): 36 | """Number of possible actions in the game.""" 37 | 38 | @abc.abstractproperty 39 | def colours(self): 40 | """Symbol to colour map for the game.""" 41 | 42 | @abc.abstractmethod 43 | def make_episode(self): 44 | """Factory method for generating new episodes of the game.""" 45 | -------------------------------------------------------------------------------- /envs/credit_assign/key_to_door/readme.md: -------------------------------------------------------------------------------- 1 | # Key-to-Door Environment 2 | source code: https://github.com/deepmind/deepmind-research/tree/master/tvt/pycolab 3 | 4 | We adapt the environment framework and change some configurations to match the descriptions in [Synthetic Returns for Long-Term Credit Assignment](https://arxiv.org/abs/2102.12425) with environment name `"KeytoDoor-SR-v0"`. 5 | -------------------------------------------------------------------------------- /envs/credit_assign/key_to_door/tvt_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from envs.credit_assign.key_to_door import env, key_to_door 4 | 5 | 6 | class KeyToDoor(gym.Env): 7 | def __init__( 8 | self, 9 | num_apples=10, 10 | apple_reward=1.0, 11 | fix_apple_reward_in_episode=True, 12 | final_reward=10.0, 13 | default_reward=0, 14 | respawn_every=20, 15 | REWARD_GRID=key_to_door.REWARD_GRID_SR, 16 | max_frames=key_to_door.MAX_FRAMES_PER_PHASE_SR, 17 | crop=True, 18 | flatten_img=True, 19 | one_hot_actions=False, 20 | ): 21 | super().__init__() 22 | self.pycolab_env = env.PycolabEnvironment( 23 | game="key_to_door", 24 | num_apples=num_apples, 25 | apple_reward=apple_reward, 26 | fix_apple_reward_in_episode=fix_apple_reward_in_episode, 27 | final_reward=final_reward, 28 | respawn_every=respawn_every, 29 | crop=crop, 30 | default_reward=default_reward, 31 | REWARD_GRID=REWARD_GRID, 32 | max_frames=max_frames, 33 | ) 34 | 35 | self.action_space = gym.spaces.Discrete(4) # 4 directions 36 | self.one_hot_actions = one_hot_actions 37 | 38 | # original agent uses HWC size, but pytorch uses CHW size, so we transpose below 39 | self.img_size = (3, 5, 5) 40 | self.image_space = gym.spaces.Box( 41 | shape=self.img_size, low=0, high=255, dtype=np.uint8 42 | ) 43 | # the pixel normalization should be done in image encoder, not here 44 | 45 | self.flatten_img = flatten_img 46 | if flatten_img: 47 | self.observation_space = gym.spaces.Box( 48 | shape=(np.array(self.img_size).prod(),), low=0, high=255, dtype=np.uint8 49 | ) 50 | else: 51 | self.observation_space = self.image_space 52 | 53 | def _convert_obs(self, obs): 54 | new_obs = np.transpose(obs, (-1, 0, 1)) # (H,W,C) -> (C,H,W) 55 | if self.flatten_img: 56 | new_obs = new_obs.flatten() # -> (C*H*W) 57 | return new_obs 58 | 59 | def step(self, action): 60 | if 
self.one_hot_actions: 61 | action = np.argmax(action) 62 | obs, r = self.pycolab_env.step(action) 63 | self._ret += r 64 | 65 | info = {} 66 | 67 | if self.pycolab_env._episode.game_over: 68 | done = True 69 | info["success"] = self.pycolab_env.last_phase_reward() > 0.0 70 | else: 71 | done = False 72 | 73 | return self._convert_obs(obs), r, done, info 74 | 75 | def reset(self): 76 | obs, _ = self.pycolab_env.reset() 77 | self._ret = 0.0 78 | 79 | return self._convert_obs(obs) 80 | 81 | 82 | if __name__ == "__main__": 83 | env = KeyToDoor() 84 | obs = env.reset() 85 | done = False 86 | t = 0 87 | while not done: 88 | t += 1 89 | obs, rew, done, info = env.step(env.action_space.sample()) 90 | print(t, rew, info) 91 | -------------------------------------------------------------------------------- /envs/dmc/dmc_env.py: -------------------------------------------------------------------------------- 1 | # dmc environments 2 | import gym 3 | import dmc2gym 4 | import numpy as np 5 | 6 | # 7 | class DMCWrapper(gym.Env): 8 | def __init__(self, domain_name, task_name, seed=None): 9 | # print(f"domain name: {domain_name}, task name: {task_name}, seed: {seed}") 10 | self.env = dmc2gym.make(domain_name=domain_name, task_name=task_name, seed=seed) 11 | self.observation_space = self.env.observation_space 12 | self.action_space = self.env.action_space 13 | self._max_episode_steps = self.env.unwrapped._step_limit 14 | self.seed(seed) 15 | 16 | def seed(self, seed): 17 | self.env.seed(seed) 18 | self.env.action_space.seed(seed) 19 | 20 | def reset(self): 21 | return self.env.reset() 22 | 23 | def step(self, action): 24 | action = np.clip(action, self.action_space.low, self.action_space.high) 25 | return self.env.step(action) 26 | -------------------------------------------------------------------------------- /envs/memory_envs/configs/keytodoor.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from gym.envs.registration import register 3 | from .terminal_fns import finite_horizon_terminal 4 | from ..key_to_door import key_to_door 5 | import copy 6 | 7 | def create_fn(env_name, continuous_act_space=False): 8 | if continuous_act_space: 9 | env_name_fn = lambda distract: f"Mem-SR-{distract}-cont-act-v0" 10 | else: 11 | env_name_fn = lambda distract: f"Mem-SR-{distract}-v0" 12 | env_name_ = env_name 13 | env_name = env_name_fn(env_name) 14 | MAX_FRAMES_PER_PHASE_SR = copy.deepcopy(key_to_door.MAX_FRAMES_PER_PHASE_SR) 15 | 16 | MAX_FRAMES_PER_PHASE_SR.update({"distractor": env_name_}) 17 | # optimal expected return: 1.0 * (~23) + 5.0 = 28. 
due to unknown number of respawned apples 18 | register( 19 | env_name, 20 | entry_point="envs.memory_envs.key_to_door.tvt_wrapper:KeyToDoor", 21 | kwargs=dict( 22 | flatten_img=True, 23 | one_hot_actions=False, 24 | apple_reward=1.0, 25 | final_reward=5.0, 26 | respawn_every=20, # apple respawn after 20 steps 27 | REWARD_GRID=copy.deepcopy(key_to_door.REWARD_GRID_SR), 28 | max_frames=copy.deepcopy(MAX_FRAMES_PER_PHASE_SR), 29 | continuous_act_space=continuous_act_space 30 | ), 31 | max_episode_steps=sum(MAX_FRAMES_PER_PHASE_SR.values()), # no info if max episode steps = sum(MAX_FRAMES_PER_PHASE_SR.values()) 32 | ) 33 | return env_name 34 | 35 | 36 | # 37 | # def get_config(): 38 | # config = ConfigDict() 39 | # config.create_fn = create_fn 40 | # 41 | # config.env_type = "key_to_door" 42 | # config.terminal_fn = finite_horizon_terminal 43 | # 44 | # config.eval_interval = 50 45 | # config.save_interval = 50 46 | # config.eval_episodes = 20 47 | # 48 | # config.env_name = 60 49 | # 50 | # return config 51 | -------------------------------------------------------------------------------- /envs/memory_envs/configs/terminal_fns.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | def finite_horizon_terminal(env: gym.Env, done: bool, info: dict) -> bool: 5 | return done 6 | 7 | 8 | def infinite_horizon_terminal(env: gym.Env, done: bool, info: dict) -> bool: 9 | if not done or "TimeLimit.truncated" in info: 10 | terminal = False 11 | else: 12 | terminal = True 13 | return terminal 14 | -------------------------------------------------------------------------------- /envs/memory_envs/configs/tmaze_active.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from gym.envs.registration import register 3 | from .terminal_fns import finite_horizon_terminal 4 | 5 | env_name_fn = lambda l: f"Mem-T-active-{l}-v0" 6 | env_name_cont_fn = lambda l: f"Mem-T-active-{l}-cont-act-v0" 7 | 8 | def create_fn(env_name, continuous_act_space=False): 9 | length = env_name 10 | env_name = env_name_cont_fn(length) if continuous_act_space else env_name_fn(length) 11 | register( 12 | env_name, 13 | entry_point="envs.memory_envs.tmaze:TMazeClassicActive", 14 | kwargs=dict( 15 | corridor_length=length, 16 | penalty=-1.0 / length, # NOTE: \sum_{t=1}^T -1/T = -1 17 | distract_reward=0.0, 18 | continuous_act_space=continuous_act_space 19 | ), 20 | max_episode_steps=length + 2 * 1 + 1, # NOTE: has to define it here 21 | ) 22 | 23 | return env_name 24 | -------------------------------------------------------------------------------- /envs/memory_envs/configs/tmaze_passive.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from gym.envs.registration import register 3 | from .terminal_fns import finite_horizon_terminal 4 | 5 | env_name_fn = lambda l: f"Mem-T-passive-{l}-v0" 6 | env_name_cont_fn = lambda l: f"Mem-T-passive-{l}-cont-act-v0" 7 | 8 | 9 | def create_fn(env_name, continuous_act_space=False): 10 | length = env_name 11 | env_name = env_name_cont_fn(length) if continuous_act_space else env_name_fn(length) 12 | register( 13 | env_name, 14 | entry_point="envs.memory_envs.tmaze:TMazeClassicPassive", 15 | kwargs=dict( 16 | corridor_length=length, 17 | penalty=-1.0 / length, # NOTE: \sum_{t=1}^T -1/T = -1 18 | distract_reward=0.0, 19 | continuous_act_space=continuous_act_space 20 | ), 21 | max_episode_steps=length + 1, # NOTE: has to 
define it here 22 | ) 23 | 24 | return env_name 25 | 26 | 27 | # def get_config(): 28 | # config = ConfigDict() 29 | # config.create_fn = create_fn 30 | # 31 | # config.env_type = "tmaze_passive" 32 | # config.terminal_fn = finite_horizon_terminal 33 | # 34 | # config.eval_interval = 10 35 | # config.save_interval = 10 36 | # config.eval_episodes = 10 37 | # 38 | # # [1, 2, 5, 10, 30, 50, 100, 300, 500, 1000] 39 | # config.env_name = 10 40 | # config.distract_reward = 0.0 41 | # 42 | # return config 43 | -------------------------------------------------------------------------------- /envs/memory_envs/configs/visual_match.py: -------------------------------------------------------------------------------- 1 | from ml_collections import ConfigDict 2 | from typing import Tuple 3 | from gym.envs.registration import register 4 | from .terminal_fns import finite_horizon_terminal 5 | from ..key_to_door import visual_match 6 | 7 | 8 | def create_fn(config: ConfigDict) -> Tuple[ConfigDict, str]: 9 | env_name_fn = lambda distract: f"passive-visual-{distract}-v0" 10 | env_name = env_name_fn(config.env_name) 11 | MAX_FRAMES_PER_PHASE = visual_match.MAX_FRAMES_PER_PHASE 12 | MAX_FRAMES_PER_PHASE.update({"distractor": config.env_name}) 13 | 14 | # optimal expected return: 1.0 * (~23) + 5.0 = 28. due to unknown number of respawned apples 15 | register( 16 | env_name, 17 | entry_point="envs.key_to_door.tvt_wrapper:VisualMatch", 18 | kwargs=dict( 19 | flatten_img=True, 20 | one_hot_actions=False, 21 | apple_reward=1.0, 22 | final_reward=5.0, 23 | respawn_every=20, # apple respawn after 20 steps 24 | passive=True, 25 | max_frames=MAX_FRAMES_PER_PHASE, 26 | ), 27 | max_episode_steps=sum(MAX_FRAMES_PER_PHASE.values()), 28 | ) 29 | 30 | del config.create_fn 31 | return config, env_name 32 | 33 | 34 | def get_config(): 35 | config = ConfigDict() 36 | config.create_fn = create_fn 37 | 38 | config.env_type = "visual_match" 39 | config.terminal_fn = finite_horizon_terminal 40 | 41 | config.eval_interval = 50 42 | config.save_interval = 50 43 | config.eval_episodes = 20 44 | 45 | config.env_name = 60 46 | 47 | return config 48 | -------------------------------------------------------------------------------- /envs/memory_envs/key_to_door/game.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=g-bad-file-header 2 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================ 16 | """Pycolab Game interface.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import abc 23 | import six 24 | 25 | 26 | @six.add_metaclass(abc.ABCMeta) 27 | class AbstractGame(object): 28 | """Abstract base class for Pycolab games.""" 29 | 30 | @abc.abstractmethod 31 | def __init__(self, rng, **settings): 32 | """Initialize the game.""" 33 | 34 | @abc.abstractproperty 35 | def num_actions(self): 36 | """Number of possible actions in the game.""" 37 | 38 | @abc.abstractproperty 39 | def colours(self): 40 | """Symbol to colour map for the game.""" 41 | 42 | @abc.abstractmethod 43 | def make_episode(self): 44 | """Factory method for generating new episodes of the game.""" 45 | -------------------------------------------------------------------------------- /envs/memory_envs/key_to_door/readme.md: -------------------------------------------------------------------------------- 1 | # Key-to-Door Environments 2 | source code: https://github.com/deepmind/deepmind-research/tree/master/tvt/pycolab 3 | 4 | Used in https://arxiv.org/pdf/2307.03864.pdf 5 | -------------------------------------------------------------------------------- /envs/memory_envs/make_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.wrappers import RescaleAction 3 | 4 | 5 | def make_env( 6 | env_name: str, 7 | seed: int, 8 | ) -> gym.Env: 9 | # Check if the env is in gym. 10 | env = gym.make(env_name) 11 | 12 | env.max_episode_steps = getattr( 13 | env, "max_episode_steps", env.spec.max_episode_steps 14 | ) 15 | 16 | if isinstance(env.action_space, gym.spaces.Box): 17 | env = RescaleAction(env, -1.0, 1.0) 18 | 19 | env.seed(seed) 20 | env.action_space.seed(seed) 21 | env.observation_space.seed(seed) 22 | return env 23 | -------------------------------------------------------------------------------- /envs/memory_envs/readme.md: -------------------------------------------------------------------------------- 1 | # Key-to-Door Environments 2 | Used in https://arxiv.org/pdf/2307.03864.pdf 3 | 4 | 1. Passive T-Maze (pure long memory) shows: "Transformer-based agent can solve long-term memory low-dimensional tasks with 5 | good sample efficiency." 6 | 2. Passive Visual Match shows: "Transformer-based agent is more sample-efficient in long-term memory high-dimensional tasks." 7 | 3. `Key to Door` and `Active T-Maze` (credit assignment) show: "Transformer-based RL improves temporal credit assignment compared to LSTM-based RL, but its advantage diminishes in long-term scenarios." -------------------------------------------------------------------------------- /envs/meta/example_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class ExampleEnv(gym.Env): 5 | def __init__(self): 6 | super(ExampleEnv, self).__init__() 7 | 8 | def get_task(self): 9 | """ 10 | Return a task description, such as goal position or target velocity. 11 | """ 12 | pass 13 | 14 | def set_goal(self, goal): 15 | """ 16 | Sets goal manually. Mainly used for reward relabelling. 17 | """ 18 | pass 19 | 20 | def reset_task(self, task=None): 21 | """ 22 | Reset the task, either at random (if task=None) or the given task. 23 | """ 24 | pass 25 | 26 | def step(self, action): 27 | """ 28 | Execute one step in the environment.
29 | Should return: state, reward, done, info 30 | where info has to include a field 'task'. 31 | """ 32 | pass 33 | 34 | def reward(self, state, action): 35 | """ 36 | Computes the reward function of the task. 37 | Returns the reward. 38 | """ 39 | pass 40 | 41 | def reset(self): 42 | """ 43 | Reset the environment. This should *NOT* reset the task! 44 | Resetting the task is handled in the varibad wrapper (see wrappers.py). 45 | """ 46 | pass 47 | -------------------------------------------------------------------------------- /envs/meta/make_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from .wrappers import VariBadWrapper 4 | 5 | # In VariBAD, they use on-policy PPO with vectorized envs. 6 | # In BOReL, they use off-policy SAC with a single env. 7 | import numpy as np 8 | import contextlib 9 | import random 10 | 11 | @contextlib.contextmanager 12 | def fixed_seed(seed): 13 | """Context manager that fixes the seeds of both random and numpy.random, restoring their previous states on exit.""" 14 | state_np = np.random.get_state() 15 | state_random = random.getstate() 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | try: 19 | yield 20 | finally: 21 | np.random.set_state(state_np) 22 | random.setstate(state_random) 23 | 24 | def make_env(env_id, episodes_per_task, seed=None, oracle=False, **kwargs): 25 | """ 26 | kwargs: include n_tasks=num_tasks 27 | """ 28 | with fixed_seed(seed): 29 | env = gym.make(env_id, **kwargs) 30 | if seed is not None: 31 | env.seed(seed) 32 | env.action_space.np_random.seed(seed) 33 | env = VariBadWrapper( 34 | env=env, 35 | episodes_per_task=episodes_per_task, 36 | oracle=oracle, 37 | ) 38 | return env 39 | -------------------------------------------------------------------------------- /envs/meta/mujoco/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .mujoco_env import MujocoEnv 4 | 5 | 6 | class AntEnv(MujocoEnv): 7 | def __init__(self, use_low_gear_ratio=False): 8 | # self.init_serialization(locals()) 9 | if use_low_gear_ratio: 10 | xml_path = "low_gear_ratio_ant.xml" 11 | else: 12 | xml_path = "ant.xml" 13 | super().__init__( 14 | xml_path, 15 | frame_skip=5, 16 | automatically_set_obs_and_action_space=True, 17 | ) 18 | 19 | def step(self, a): 20 | torso_xyz_before = self.get_body_com("torso") 21 | self.do_simulation(a, self.frame_skip) 22 | torso_xyz_after = self.get_body_com("torso") 23 | torso_velocity = torso_xyz_after - torso_xyz_before 24 | forward_reward = torso_velocity[0] / self.dt 25 | ctrl_cost = 0.0 # .5 * np.square(a).sum() 26 | contact_cost = ( 27 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 28 | ) 29 | survive_reward = 0.0 # 1.0 30 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 31 | state = self.state_vector() 32 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0 33 | done = not notdone 34 | ob = self._get_obs() 35 | return ( 36 | ob, 37 | reward, 38 | done, 39 | dict( 40 | reward_forward=forward_reward, 41 | reward_ctrl=-ctrl_cost, 42 | reward_contact=-contact_cost, 43 | reward_survive=survive_reward, 44 | torso_velocity=torso_velocity, 45 | ), 46 | ) 47 | 48 | def _get_obs(self): 49 | # this is gym ant obs, should use rllab?
50 | # if position is needed, override this in subclasses 51 | return np.concatenate( 52 | [ 53 | self.sim.data.qpos.flat[2:], 54 | self.sim.data.qvel.flat, 55 | ] 56 | ) 57 | 58 | def reset_model(self): 59 | qpos = self.init_qpos + self.np_random.uniform( 60 | size=self.model.nq, low=-0.1, high=0.1 61 | ) 62 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 63 | self.set_state(qpos, qvel) 64 | return self._get_obs() 65 | 66 | def viewer_setup(self): 67 | self.viewer.cam.distance = self.model.stat.extent * 0.5 68 | -------------------------------------------------------------------------------- /envs/meta/mujoco/ant_dir.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .ant_multitask_base import MultitaskAntEnv 4 | 5 | 6 | class AntDirEnv(MultitaskAntEnv): 7 | """ 8 | AntDir: forward_backward=True (unlimited tasks) from on-policy varibad code 9 | AntDir2D: forward_backward=False (limited tasks) from off-policy varibad code 10 | """ 11 | 12 | def __init__( 13 | self, 14 | task={}, 15 | n_tasks=None, 16 | max_episode_steps=200, 17 | forward_backward=True, 18 | **kwargs 19 | ): 20 | self.forward_backward = forward_backward 21 | self._max_episode_steps = max_episode_steps 22 | 23 | super(AntDirEnv, self).__init__(task, n_tasks, **kwargs) 24 | 25 | def step(self, action): 26 | torso_xyz_before = np.array(self.get_body_com("torso")) 27 | 28 | direct = (np.cos(self._goal), np.sin(self._goal)) 29 | 30 | self.do_simulation(action, self.frame_skip) 31 | torso_xyz_after = np.array(self.get_body_com("torso")) 32 | torso_velocity = torso_xyz_after - torso_xyz_before 33 | forward_reward = np.dot((torso_velocity[:2] / self.dt), direct) 34 | 35 | ctrl_cost = 0.5 * np.square(action).sum() 36 | contact_cost = ( 37 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 38 | ) 39 | survive_reward = 1.0 40 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 41 | state = self.state_vector() 42 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0 43 | done = not notdone 44 | ob = self._get_obs() 45 | return ( 46 | ob, 47 | reward, 48 | done, 49 | dict( 50 | reward_forward=forward_reward, 51 | reward_ctrl=-ctrl_cost, 52 | reward_contact=-contact_cost, 53 | reward_survive=survive_reward, 54 | torso_velocity=torso_velocity, 55 | ), 56 | ) 57 | 58 | def sample_tasks(self, num_tasks: int): 59 | assert self.forward_backward == False 60 | velocities = np.random.uniform(0.0, 2.0 * np.pi, size=(num_tasks,)) 61 | tasks = [{"goal": velocity} for velocity in velocities] 62 | return tasks 63 | 64 | def _sample_raw_task(self): 65 | assert self.forward_backward == True 66 | velocity = np.random.choice([-1.0, 1.0]) # not 180 degree 67 | task = {"goal": velocity} 68 | return task 69 | -------------------------------------------------------------------------------- /envs/meta/mujoco/ant_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .ant_multitask_base import MultitaskAntEnv 4 | 5 | 6 | class AntGoalEnv(MultitaskAntEnv): 7 | def __init__(self, task={}, n_tasks=2, max_episode_steps=200, **kwargs): 8 | super(AntGoalEnv, self).__init__(task, n_tasks, **kwargs) 9 | self._max_episode_steps = max_episode_steps 10 | 11 | def step(self, action): 12 | self.do_simulation(action, self.frame_skip) 13 | xposafter = np.array(self.get_body_com("torso")) 14 | 15 | goal_reward = -np.sum( 16 | np.abs(xposafter[:2] - 
self._goal) 17 | ) # make it happy, not suicidal 18 | 19 | ctrl_cost = 0.1 * np.square(action).sum() 20 | contact_cost = ( 21 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 22 | ) 23 | survive_reward = 0.0 24 | reward = goal_reward - ctrl_cost - contact_cost + survive_reward 25 | state = self.state_vector() 26 | done = False 27 | ob = self._get_obs() 28 | return ( 29 | ob, 30 | reward, 31 | done, 32 | dict( 33 | goal_forward=goal_reward, 34 | reward_ctrl=-ctrl_cost, 35 | reward_contact=-contact_cost, 36 | reward_survive=survive_reward, 37 | ), 38 | ) 39 | 40 | def sample_tasks(self, num_tasks): 41 | a = np.random.random(num_tasks) * 2 * np.pi 42 | r = 3 * np.random.random(num_tasks) ** 0.5 43 | goals = np.stack((r * np.cos(a), r * np.sin(a)), axis=-1) 44 | tasks = [{"goal": goal} for goal in goals] 45 | return tasks 46 | 47 | def _get_obs(self): 48 | return np.concatenate( 49 | [ 50 | self.sim.data.qpos.flat, 51 | self.sim.data.qvel.flat, 52 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 53 | ] 54 | ) 55 | -------------------------------------------------------------------------------- /envs/meta/mujoco/ant_multitask_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .ant import AntEnv 4 | 5 | # from gym.envs.mujoco.ant import AntEnv 6 | 7 | 8 | class MultitaskAntEnv(AntEnv): 9 | def __init__(self, task={}, n_tasks=2, **kwargs): 10 | self._task = task 11 | self.n_tasks = n_tasks 12 | if n_tasks is None: 13 | self._goal = self._sample_raw_task()["goal"] 14 | else: 15 | self.tasks = self.sample_tasks(n_tasks) 16 | self._goal = self.tasks[0]["goal"] 17 | super(MultitaskAntEnv, self).__init__() 18 | 19 | def get_current_task(self): 20 | # for multi-task MDP 21 | return np.array([self._goal]) 22 | 23 | def get_all_task_idx(self): 24 | return range(len(self.tasks)) 25 | 26 | def reset_task(self, task_info): 27 | if self.n_tasks is None: # unlimited tasks 28 | assert task_info is None 29 | self._task = self._sample_raw_task() # sample here 30 | else: # limited tasks 31 | self._task = self.tasks[task_info] # as idx 32 | self._goal = self._task[ 33 | "goal" 34 | ] # assume parameterization of task by single vector 35 | self.reset() 36 | -------------------------------------------------------------------------------- /envs/meta/mujoco/core/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | General classes, functions, utilities that are used throughout rlkit. 
3 | """ 4 | -------------------------------------------------------------------------------- /envs/meta/mujoco/core/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | 4 | https://github.com/rll/rllab 5 | """ 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | class Serializable(object): 12 | def __init__(self, *args, **kwargs): 13 | self.__args = args 14 | self.__kwargs = kwargs 15 | 16 | def quick_init(self, locals_): 17 | if getattr(self, "_serializable_initialized", False): 18 | return 19 | if sys.version_info >= (3, 0): 20 | spec = inspect.getfullargspec(self.__init__) 21 | # Exclude the first "self" parameter 22 | if spec.varkw: 23 | kwargs = locals_[spec.varkw].copy() 24 | else: 25 | kwargs = dict() 26 | if spec.kwonlyargs: 27 | for key in spec.kwonlyargs: 28 | kwargs[key] = locals_[key] 29 | else: 30 | spec = inspect.getargspec(self.__init__) 31 | if spec.keywords: 32 | kwargs = locals_[spec.keywords] 33 | else: 34 | kwargs = dict() 35 | if spec.varargs: 36 | varargs = locals_[spec.varargs] 37 | else: 38 | varargs = tuple() 39 | in_order_args = [locals_[arg] for arg in spec.args][1:] 40 | self.__args = tuple(in_order_args) + varargs 41 | self.__kwargs = kwargs 42 | setattr(self, "_serializable_initialized", True) 43 | 44 | def __getstate__(self): 45 | return {"__args": self.__args, "__kwargs": self.__kwargs} 46 | 47 | def __setstate__(self, d): 48 | # convert all __args to keyword-based arguments 49 | if sys.version_info >= (3, 0): 50 | spec = inspect.getfullargspec(self.__init__) 51 | else: 52 | spec = inspect.getargspec(self.__init__) 53 | in_order_args = spec.args[1:] 54 | out = type(self)(**dict(zip(in_order_args, d["__args"]), **d["__kwargs"])) 55 | self.__dict__.update(out.__dict__) 56 | 57 | @classmethod 58 | def clone(cls, obj, **kwargs): 59 | assert isinstance(obj, Serializable) 60 | d = obj.__getstate__() 61 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 62 | out = type(obj).__new__(type(obj)) 63 | out.__setstate__(d) 64 | return out 65 | -------------------------------------------------------------------------------- /envs/meta/mujoco/mujoco_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import mujoco_py 5 | import numpy as np 6 | from gym.envs.mujoco import mujoco_env 7 | 8 | from .core.serializable import Serializable 9 | 10 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets") 11 | 12 | 13 | class MujocoEnv(mujoco_env.MujocoEnv, Serializable): 14 | """ 15 | My own wrapper around MujocoEnv. 16 | 17 | The caller needs to declare 18 | """ 19 | 20 | def __init__( 21 | self, 22 | model_path, 23 | frame_skip=1, 24 | model_path_is_local=True, 25 | automatically_set_obs_and_action_space=False, 26 | ): 27 | if model_path_is_local: 28 | model_path = get_asset_xml(model_path) 29 | if automatically_set_obs_and_action_space: 30 | mujoco_env.MujocoEnv.__init__(self, model_path, frame_skip) 31 | else: 32 | """ 33 | Code below is copy/pasted from MujocoEnv's __init__ function. 
34 | """ 35 | if model_path.startswith("/"): 36 | fullpath = model_path 37 | else: 38 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 39 | if not path.exists(fullpath): 40 | raise IOError("File %s does not exist" % fullpath) 41 | self.frame_skip = frame_skip 42 | self.model = mujoco_py.MjModel(fullpath) 43 | self.data = self.model.data 44 | self.viewer = None 45 | 46 | self.metadata = { 47 | "render.modes": ["human", "rgb_array"], 48 | "video.frames_per_second": int(np.round(1.0 / self.dt)), 49 | } 50 | 51 | self.init_qpos = self.model.data.qpos.ravel().copy() 52 | self.init_qvel = self.model.data.qvel.ravel().copy() 53 | self._seed() 54 | 55 | def init_serialization(self, locals): 56 | Serializable.quick_init(self, locals) 57 | 58 | def log_diagnostics(self, paths): 59 | pass 60 | 61 | 62 | def get_asset_xml(xml_name): 63 | return os.path.join(ENV_ASSET_DIR, xml_name) 64 | -------------------------------------------------------------------------------- /envs/meta/readme.md: -------------------------------------------------------------------------------- 1 | # Meta RL Environments 2 | Based on the code https://github.com/Rondorf/BOReL and https://github.com/lmzintgraf/varibad 3 | -------------------------------------------------------------------------------- /envs/pomdp/readme.md: -------------------------------------------------------------------------------- 1 | # "Standard" POMDP Environments 2 | Based on the code https://github.com/oist-cnru/Variational-Recurrent-Models 3 | -------------------------------------------------------------------------------- /envs/pomdp/wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | 5 | 6 | class POMDPWrapper(gym.Wrapper): 7 | def __init__(self, env, partially_obs_dims: list): 8 | super().__init__(env) 9 | self.partially_obs_dims = partially_obs_dims 10 | # can equal to the fully-observed env 11 | assert 0 < len(self.partially_obs_dims) <= self.observation_space.shape[0] 12 | 13 | self.observation_space = spaces.Box( 14 | low=self.observation_space.low[self.partially_obs_dims], 15 | high=self.observation_space.high[self.partially_obs_dims], 16 | dtype=np.float32, 17 | ) 18 | 19 | if self.env.action_space.__class__.__name__ == "Box": 20 | self.act_continuous = True 21 | # if continuous actions, make sure in [-1, 1] 22 | # NOTE: policy won't use action_space.low/high, just set [-1,1] 23 | # this is a bad practice... 
24 | else: 25 | self.act_continuous = False 26 | self.true_state = None 27 | 28 | def get_obs(self, state): 29 | return state[self.partially_obs_dims].copy() 30 | 31 | def get_unobservable(self): 32 | unobserved_dims = [i for i in range(self.true_state.shape[0]) if i not in self.partially_obs_dims] 33 | return self.true_state[unobserved_dims].copy() 34 | 35 | def reset(self): 36 | state = self.env.reset() # no kwargs 37 | self.true_state = state 38 | return self.get_obs(state) 39 | 40 | def step(self, action): 41 | if self.act_continuous: 42 | # recover the action 43 | action = np.clip(action, -1, 1) # first clip into [-1, 1] 44 | lb = self.env.action_space.low 45 | ub = self.env.action_space.high 46 | action = lb + (action + 1.0) * 0.5 * (ub - lb) 47 | action = np.clip(action, lb, ub) 48 | 49 | state, reward, done, info = self.env.step(action) 50 | self.true_state = state 51 | return self.get_obs(state), reward, done, info 52 | 53 | 54 | if __name__ == "__main__": 55 | import envs 56 | 57 | env = gym.make("HopperBLT-F-v0") 58 | obs = env.reset() 59 | done = False 60 | step = 0 61 | while not done: 62 | next_obs, rew, done, info = env.step(env.action_space.sample()) 63 | step += 1 64 | print(step, done, info) 65 | -------------------------------------------------------------------------------- /envs/readme.md: -------------------------------------------------------------------------------- 1 | # POMDP Environments 2 | The POMDP environments are mainly adapted from [POMDP](https://github.com/twni2016/pomdp-baselines) and [ESCP](https://github.com/FanmingL/ESCP). 3 | ## Overview 4 | - `meta/`: Meta RL environments 5 | - `pomdp/`: "standard" POMDP environments 6 | - `rl_generalization/`: Generalization in RL and Robust RL environments 7 | 8 | ## Normalized Action Space 9 | We make sure every environment exposes a continuous action space of [-1, 1]^|A| to the policy. The policy should not use `self.action_space.high` or `self.action_space.low`. 10 | 11 | In Meta RL and "standard" POMDP environments, we use the following snippet to normalize the action space: 12 | ```python 13 | class EnvWrapper(gym.Wrapper): 14 | def step(self, action): 15 | action = np.clip(action, -1, 1) # first clip into [-1, 1] 16 | lb = self.env.action_space.low 17 | ub = self.env.action_space.high 18 | action = lb + (action + 1.) * 0.5 * (ub - lb) # recover the original action space 19 | action = np.clip(action, lb, ub) 20 | ... 21 | ``` 22 | 23 | ## Reproducibility Issue in Gym Environments 24 | In the current gym version (v0.21), we need to seed both the env and its action space to ensure reproducibility: 25 | ```python 26 | env.seed(seed) 27 | env.action_space.np_random.seed(seed) 28 | ``` 29 | However, I have not figured out how to do this for the old gym version (v0.10) used by the SunBlaze envs; a possible (untested) sketch is given below. Please kindly open a pull request if you know a better way.
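A possible direction for the old-gym case (a sketch only, not a verified fix): old gym releases (roughly <= v0.11, including v0.10) draw space samples such as `env.action_space.sample()` from the module-level RNG in `gym.spaces.prng`, which was removed in later gym versions. Seeding that module in addition to the env itself may therefore be enough; whether the SunBlaze parameter randomization is also covered by this depends on its implementation.
```python
import gym
# NOTE: `gym.spaces.prng` only exists in old gym versions (roughly <= v0.11);
# on gym v0.21 this import fails because spaces carry their own np_random.
from gym.spaces import prng


def seed_old_gym(env: gym.Env, seed: int) -> gym.Env:
    # Seed the environment's own RNG, same as in newer gym versions.
    env.seed(seed)
    # Old gym samples from action/observation spaces via a shared
    # module-level RandomState, so seed it as well.
    prng.seed(seed)
    return env
```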
30 | -------------------------------------------------------------------------------- /envs/rl_generalization/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | *.pdf 4 | *.png 5 | *.swp 6 | *.swo 7 | __pycache__ 8 | videos 9 | -------------------------------------------------------------------------------- /envs/rl_generalization/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 UC Berkeley, Intel Labs, and other contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /envs/rl_generalization/README.md: -------------------------------------------------------------------------------- 1 | # POMDP Environments for Evaluating Generalization and Robustness in RL 2 | Based on the code https://github.com/sunblaze-ucb/rl-generalization 3 | 4 | ## Using the Generalization Environments 5 | From original SunBlaze repo, 6 | 7 | ```python 8 | import gym 9 | import sunblaze_envs 10 | 11 | # Deterministic: the default version with fixed parameters 12 | fixed_env = sunblaze_envs.make('SunblazeCartPole-v0') 13 | 14 | # Random: parameters are sampled from a range nearby the default settings 15 | random_env = sunblaze_envs.make('SunblazeCartPoleRandomNormal-v0') 16 | 17 | # Extreme: parameters are sampled from an `extreme' range 18 | extreme_env = sunblaze_envs.make('SunblazeCartPoleRandomExtreme-v0') 19 | ``` 20 | In the case of CartPole, RandomNormal and RandomExtreme will vary the strength of each actions, the mass of the pole, and the length of the pole: 21 | 22 | Specific ranges for each environment setting are listed [here](sunblaze_envs#environment-details). See the code in [examples](/examples) for usage with example algorithms from OpenAI Baselines. 
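To check which parameters a randomized instance actually drew, the environments expose a `parameters` property (defined in `sunblaze_envs/base.py`, included below in this repo). A minimal sketch; the exact keys are environment-specific, and whether values are re-sampled on every `reset()` depends on the particular environment:

```python
import sunblaze_envs

env = sunblaze_envs.make('SunblazeCartPoleRandomNormal-v0')
env.reset()
# `parameters` always contains the env id; randomized variants are expected
# to add the values they sampled (e.g. force / length / mass for CartPole).
print(env.unwrapped.parameters)
```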
23 | 24 | ## Using the Robust RL Environments 25 | Similarly, we adopt the environments from [MRPO paper](https://proceedings.mlr.press/v139/jiang21c.html), for example 26 | 27 | ```python 28 | import gym 29 | import sunblaze_envs 30 | 31 | MRPO_walker_env = sunblaze_envs.make('MRPOWalker2dRandomNormal-v0') 32 | 33 | ``` 34 | -------------------------------------------------------------------------------- /envs/rl_generalization/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | import sys 4 | 5 | 6 | if sys.version_info.major != 3: 7 | print( 8 | "This Python is only compatible with Python 3, but you are running " 9 | "Python {}. The installation will likely fail.".format(sys.version_info.major) 10 | ) 11 | 12 | package_name = "sunblaze_envs" 13 | authors = ["UC Berkeley", "Intel Labs", "and other contributors"] 14 | url = "https://github.com/sunblaze-ucb/rl-generalization" 15 | description = "Modifiable OpenAI Gym environments for studying generalization in RL" 16 | 17 | setup_py_dir = os.path.dirname(os.path.realpath(__file__)) 18 | package_dir = os.path.join(setup_py_dir, package_name) 19 | 20 | # Discover asset files. 21 | ASSET_EXTENSIONS = ["png", "wad", "txt"] 22 | assets = [] 23 | for root, dirs, files in os.walk(package_dir): 24 | for filename in files: 25 | extension = os.path.splitext(filename)[1][1:] 26 | if extension and extension in ASSET_EXTENSIONS: 27 | filename = os.path.join(root, filename) 28 | assets.append(filename[1 + len(package_dir) :]) 29 | 30 | setup( 31 | name=package_name, 32 | version="0.1.0", 33 | description=description, 34 | author=", ".join(authors), 35 | # maintainer_email="", 36 | url=url, 37 | packages=find_packages(exclude=("examples",)), 38 | package_data={"": assets}, 39 | dependency_links=( 40 | # "git+https://github.com/kostko/omgifol.git@master#egg=omgifol-0.1.0", 41 | "git+https://github.com/openai/gym.git@094e6b8e6a102644667d53d9dac6f2245bf80c6f#egg=gym-0.10.8r1", 42 | "git+https://github.com/openai/baselines.git@2b0283b9db18c768f8e9fa29fbedc5e48499acc6#egg=baselines-0.1.5r1", 43 | ), 44 | install_requires=[ 45 | "gym==0.10.8r1", 46 | #'gym==0.10.5', 47 | "Box2D==2.3.2", 48 | "cocos2d==0.6.5", 49 | "numpy==1.14.2", 50 | "scipy==1.0.0", 51 | #'vizdoom==1.1.2', 52 | #'omgifol>=0.1.0', 53 | ], 54 | extras_require={ 55 | "examples": [ 56 | "baselines==0.1.5r1", # use dep_link for specific commit 57 | "PyYAML==3.12", 58 | "opencv-python==3.4.0.12", 59 | "cloudpickle==0.4.1", 60 | "natsort==5.1.0", 61 | "chainer==3.3.0", 62 | "chainerrl==0.3.0", 63 | ], 64 | }, 65 | ) 66 | -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/README.md: -------------------------------------------------------------------------------- 1 | ## Overview of Generalization Environments 2 | 3 | There are six environments, built on top of the corresponding OpenAI Gym and Roboschool implementations: 4 | * CartPole 5 | * MountainCar 6 | * Acrobot 7 | * Pendulum 8 | * HalfCheetah 9 | * Hopper 10 | 11 | Each has three versions: 12 | * **D**: Environment parameters are set to the default values in Gym and Roboschool. Access by Sunblaze*Environment*-v0, e.g. SunblazeCartPole-v0. 13 | * **R**: Environment parameters are randomly sampled from intervals containing their default values. Access by Sunblaze*Environment*RandomNormal-v0. 
14 | * **E**: Environment parameters are randomly sampled from intervals outside those in **R**, containing more extreme values. Access by Sunblaze*Environment*RandomExtreme-v0. 15 | 16 | ## Environment details 17 | 18 | Ranges of parameters for each version of each environment, using set notation. 19 | 20 | | Environment | Parameter | D | R | E | 21 | | --- | --- | --- | --- | --- | 22 | | CartPole | Force | 10 | [5,15] | [1,5] U [15,20] | 23 | | | Length | 0.5 | [0.25,0.75] | [0.05,0.25] U [0.75,1.0] | 24 | | | Mass | 0.1 | [0.05,0.5] | [0.01,0.05] U [0.5,1.0] | 25 | | MountainCar | Force | 0.001 | [0.0005,0.005] | [0.0001,0.0005] U [0.005,0.01] | 26 | | | Mass | 0.0025 | [0.001,0.005] | [0.0005,0.001] U [0.005,0.01] | 27 | | Acrobot | Length | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] | 28 | | | Mass | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] | 29 | | | MOI | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] | 30 | | Pendulum | Length | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] | 31 | | | Mass | 1 | [0.75,1.25] | [0.5,0.75] U [1.25,1.5] | 32 | | HalfCheetah | Power | 0.90 | [0.70,1.10] | [0.50,0.70] U [1.10,1.30] | 33 | | | Density | 1000 | [750,1250] | [500,750] U [1250,1500] | 34 | | | Friction | 0.8 | [0.5,1.1] | [0.2,0.5] U [1.1,1.4] | 35 | | Hopper | Power | 0.75 | [0.60,0.90] | [0.40,0.60] U [0.90,1.10] | 36 | | | Density | 1000 | [750,1250] | [500,750] U [1250,1500] | 37 | | | Friction | 0.8 | [0.5,1.1] | [0.2,0.5] U [1.1,1.4] | 38 | -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/.gitignore: -------------------------------------------------------------------------------- 1 | *.bak 2 | -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic.wad -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_floor_ceiling_flipped.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_floor_ceiling_flipped.wad -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_torches.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/basic_torches.wad -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation.wad -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_floor_ceiling_flipped.wad: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_floor_ceiling_flipped.wad -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_new_layout.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_new_layout.wad -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_torches.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/rl_generalization/sunblaze_envs/assets/vizdoom/navigation_torches.wad -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/texture_set_b.txt: -------------------------------------------------------------------------------- 1 | TANROCK2 2 | TANROCK3 3 | TANROCK4 4 | TANROCK5 5 | TANROCK7 6 | TANROCK8 7 | TEKBRON1 8 | TEKBRON2 9 | TEKGREN1 10 | TEKGREN2 11 | TEKGREN3 12 | TEKGREN4 13 | TEKGREN5 14 | TEKLITE 15 | TEKLITE2 16 | TEKWALL1 17 | TEKWALL2 18 | TEKWALL3 19 | TEKWALL4 20 | TEKWALL5 21 | TEKWALL6 22 | WOOD1 23 | WOOD10 24 | WOOD12 25 | WOOD3 26 | WOOD4 27 | WOOD5 28 | WOOD6 29 | WOOD7 30 | WOOD8 31 | WOOD9 32 | WOODGARG 33 | WOODMET1 34 | WOODMET2 35 | WOODMET3 36 | WOODMET4 37 | WOODSKUL 38 | WOODVERT 39 | ZDOORB1 40 | ZDOORF1 41 | ZELDOOR 42 | ZIMMER1 43 | ZIMMER2 44 | ZIMMER3 45 | ZIMMER4 46 | ZIMMER5 47 | ZIMMER7 48 | ZIMMER8 49 | ZZWOLF1 50 | ZZWOLF10 51 | ZZWOLF11 52 | ZZWOLF12 53 | ZZWOLF13 54 | ZZWOLF2 55 | ZZWOLF3 56 | ZZWOLF4 57 | ZZWOLF5 58 | ZZWOLF6 59 | ZZWOLF7 60 | ZZWOLF9 61 | ZZZFACE1 62 | ZZZFACE2 63 | ZZZFACE3 64 | ZZZFACE4 65 | ZZZFACE5 66 | ZZZFACE6 67 | ZZZFACE7 68 | ZZZFACE8 69 | ZZZFACE9 -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/thing_set_a.txt: -------------------------------------------------------------------------------- 1 | 5 2 | 6 3 | 13 4 | 30 5 | 31 6 | 34 7 | 35 8 | 38 9 | 39 10 | 40 11 | 44 12 | 45 13 | 46 -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/assets/vizdoom/thing_set_b.txt: -------------------------------------------------------------------------------- 1 | 32 2 | 48 3 | 55 4 | 56 5 | 57 6 | 2028 7 | 5050 -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import gym 4 | 5 | 6 | class BaseGymEnvironment(gym.Env): 7 | """Base class for all Gym environments.""" 8 | 9 | @property 10 | def parameters(self): 11 | """Return environment parameters.""" 12 | return { 13 | "id": self.spec.id, 14 | } 15 | 16 | 17 | class EnvBinarySuccessMixin(ABC): 18 | """Adds binary success metric to environment.""" 19 | 20 | @abstractmethod 21 | def is_success(self): 22 | """Returns True is 
current state indicates success, False otherwise""" 23 | pass 24 | -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/monitor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import gym 5 | 6 | 7 | class MonitorParameters(gym.Wrapper): 8 | """Environment wrapper which records all environment parameters.""" 9 | 10 | current_parameters = None 11 | 12 | def __init__(self, env, output_filename): 13 | """ 14 | Construct parameter monitor wrapper. 15 | 16 | :param env: Wrapped environment 17 | :param output_filename: Output log filename 18 | """ 19 | self._output_filename = output_filename 20 | with open(output_filename, "w"): 21 | # Truncate output file. 22 | pass 23 | 24 | super(MonitorParameters, self).__init__(env) 25 | 26 | def step(self, action): 27 | result = self.env.step(action) 28 | self.record_parameters() 29 | return result 30 | 31 | def reset(self): 32 | result = self.env.reset() 33 | self.record_parameters() 34 | return result 35 | 36 | def record_parameters(self): 37 | """Record current environment parameters.""" 38 | if not hasattr(self.env.unwrapped, "parameters"): 39 | return 40 | if self.env.unwrapped.parameters == self.current_parameters: 41 | return 42 | 43 | # Record parameter set in output file. 44 | self.current_parameters = self.env.unwrapped.parameters 45 | with open(self._output_filename, "a") as output_file: 46 | output_file.write(json.dumps(self.current_parameters)) 47 | output_file.write("\n") 48 | -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/time_limit.py: -------------------------------------------------------------------------------- 1 | from gym.wrappers.time_limit import TimeLimit as TimeLimitBase 2 | import time 3 | 4 | 5 | class TimeLimit(TimeLimitBase): 6 | """Updated to support reset() with reset_params flag for Adaptive""" 7 | 8 | def reset(self, reset_params=True): 9 | self._episode_started_at = time.time() 10 | self._elapsed_steps = 0 11 | return self.env.reset(reset_params) 12 | -------------------------------------------------------------------------------- /envs/rl_generalization/sunblaze_envs/wrappers.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import gym 4 | from gym.envs.registration import load 5 | import numpy as np 6 | 7 | 8 | def ActionDelayWrapper(delay_range_start, delay_range_end): 9 | """Create an action delay wrapper. 
10 | 11 | :param delay_range_start: Minimum delay 12 | :param delay_range_end: Maximum delay 13 | """ 14 | 15 | class ActionDelayWrapper(gym.Wrapper): 16 | def _step(self, action): 17 | self._action_buffer.append(action) 18 | action = self._action_buffer.popleft() 19 | return self.env.step(action) 20 | 21 | def _reset(self): 22 | self._action_delay = np.random.randint(delay_range_start, delay_range_end) 23 | self._action_buffer = collections.deque( 24 | [0 for _ in range(self._action_delay)] 25 | ) 26 | return self.env.reset() 27 | 28 | return ActionDelayWrapper 29 | 30 | 31 | def wrap_environment(wrapped_class, wrappers=None, **kwargs): 32 | """Helper for wrapping environment classes.""" 33 | if wrappers is None: 34 | wrappers = [] 35 | 36 | env_class = load(wrapped_class) 37 | env = env_class(**kwargs) 38 | for wrapper, wrapper_kwargs in wrappers: 39 | wrapper_class = load(wrapper) 40 | wrapper = wrapper_class(**wrapper_kwargs) 41 | env = wrapper(env) 42 | 43 | return env 44 | -------------------------------------------------------------------------------- /envs/rl_generalization/test.py: -------------------------------------------------------------------------------- 1 | import sunblaze_envs 2 | 3 | env = sunblaze_envs.make("SunblazeHalfCheetahRandomNormal-v0") # (26, 6) 4 | # env = sunblaze_envs.make('SunblazeHopperRandomNormal-v0') # (15, 3) 5 | print(env.observation_space, env.action_space) 6 | print(env.action_space.high, env.action_space.low) # [-1,1] 7 | obs = env.reset() 8 | print(obs) 9 | print(env._max_episode_steps) 10 | done = False 11 | step = 0 12 | while not done: 13 | step += 1 14 | obs, rew, done, info = env.step(env.action_space.sample()) 15 | # there exists early failure done=True 16 | # env.unwrapped.is_success() measures the z >= 20m 17 | print(step, obs, rew, done, info, env.unwrapped.is_success()) 18 | -------------------------------------------------------------------------------- /envs/torchkit/constant.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | LSTM_name = "lstm" 4 | GRU_name = "gru" 5 | RNNs = { 6 | LSTM_name: nn.LSTM, 7 | GRU_name: nn.GRU, 8 | } 9 | -------------------------------------------------------------------------------- /envs/torchkit/modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contain some self-contained modules. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class HuberLoss(nn.Module): 9 | def __init__(self, delta=1): 10 | super().__init__() 11 | self.huber_loss_delta1 = nn.SmoothL1Loss() 12 | self.delta = delta 13 | 14 | def forward(self, x, x_hat): 15 | loss = self.huber_loss_delta1(x / self.delta, x_hat / self.delta) 16 | return loss * self.delta * self.delta 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | """ 21 | Simple 1D LayerNorm. 
22 | """ 23 | 24 | def __init__(self, features, center=True, scale=False, eps=1e-6): 25 | super().__init__() 26 | self.center = center 27 | self.scale = scale 28 | self.eps = eps 29 | if self.scale: 30 | self.scale_param = nn.Parameter(torch.ones(features)) 31 | else: 32 | self.scale_param = None 33 | if self.center: 34 | self.center_param = nn.Parameter(torch.zeros(features)) 35 | else: 36 | self.center_param = None 37 | 38 | def forward(self, x): 39 | mean = x.mean(-1, keepdim=True) 40 | std = x.std(-1, keepdim=True) 41 | output = (x - mean) / (std + self.eps) 42 | if self.scale: 43 | output = output * self.scale_param 44 | if self.center: 45 | output = output + self.center_param 46 | return output 47 | -------------------------------------------------------------------------------- /envs/torchkit/policies_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(object, metaclass=abc.ABCMeta): 5 | """ 6 | General policy interface. 7 | """ 8 | 9 | @abc.abstractmethod 10 | def get_action(self, observation): 11 | """ 12 | :param observation: 13 | :return: action, debug_dictionary 14 | """ 15 | pass 16 | 17 | def reset(self): 18 | pass 19 | 20 | 21 | class ExplorationPolicy(Policy, metaclass=abc.ABCMeta): 22 | def set_num_steps_total(self, t): 23 | pass 24 | 25 | 26 | class SerializablePolicy(Policy, metaclass=abc.ABCMeta): 27 | """ 28 | Policy that can be serialized. 29 | """ 30 | 31 | def get_param_values(self): 32 | return None 33 | 34 | def set_param_values(self, values): 35 | pass 36 | 37 | """ 38 | Parameters should be passed as np arrays in the two functions below. 39 | """ 40 | 41 | def get_param_values_np(self): 42 | return None 43 | 44 | def set_param_values_np(self, values): 45 | pass 46 | -------------------------------------------------------------------------------- /envs/torchkit/serializable.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on rllab's serializable.py file 3 | 4 | https://github.com/rll/rllab 5 | """ 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | class Serializable(object): 12 | def __init__(self, *args, **kwargs): 13 | self.__args = args 14 | self.__kwargs = kwargs 15 | 16 | def quick_init(self, locals_): 17 | if getattr(self, "_serializable_initialized", False): 18 | return 19 | if sys.version_info >= (3, 0): 20 | spec = inspect.getfullargspec(self.__init__) 21 | # Exclude the first "self" parameter 22 | if spec.varkw: 23 | kwargs = locals_[spec.varkw].copy() 24 | else: 25 | kwargs = dict() 26 | if spec.kwonlyargs: 27 | for key in spec.kwonlyargs: 28 | kwargs[key] = locals_[key] 29 | else: 30 | spec = inspect.getargspec(self.__init__) 31 | if spec.keywords: 32 | kwargs = locals_[spec.keywords] 33 | else: 34 | kwargs = dict() 35 | if spec.varargs: 36 | varargs = locals_[spec.varargs] 37 | else: 38 | varargs = tuple() 39 | in_order_args = [locals_[arg] for arg in spec.args][1:] 40 | self.__args = tuple(in_order_args) + varargs 41 | self.__kwargs = kwargs 42 | setattr(self, "_serializable_initialized", True) 43 | 44 | def __getstate__(self): 45 | return {"__args": self.__args, "__kwargs": self.__kwargs} 46 | 47 | def __setstate__(self, d): 48 | # convert all __args to keyword-based arguments 49 | if sys.version_info >= (3, 0): 50 | spec = inspect.getfullargspec(self.__init__) 51 | else: 52 | spec = inspect.getargspec(self.__init__) 53 | in_order_args = spec.args[1:] 54 | out = type(self)(**dict(zip(in_order_args, d["__args"]), 
**d["__kwargs"])) 55 | self.__dict__.update(out.__dict__) 56 | 57 | @classmethod 58 | def clone(cls, obj, **kwargs): 59 | assert isinstance(obj, Serializable) 60 | d = obj.__getstate__() 61 | d["__kwargs"] = dict(d["__kwargs"], **kwargs) 62 | out = type(obj).__new__(type(obj)) 63 | out.__setstate__(d) 64 | return out 65 | -------------------------------------------------------------------------------- /envs/utils/system.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import datetime 5 | import dateutil.tz 6 | 7 | 8 | def reproduce(seed): 9 | """ 10 | This can only fix the randomness of numpy and torch 11 | To fix the environment's, please use 12 | env.seed(seed) 13 | env.action_space.np_random.seed(seed) 14 | We have add these in our training script 15 | """ 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | torch.manual_seed(seed) 19 | if torch.cuda.is_available(): 20 | torch.cuda.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | 23 | 24 | def now_str(): 25 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 26 | return now.strftime("%m-%d:%H-%M:%S.%f")[:-4] 27 | -------------------------------------------------------------------------------- /envs/yang_domains/assets/box_1d.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 36 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /envs/yang_domains/assets/clockwise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/clockwise.png -------------------------------------------------------------------------------- /envs/yang_domains/assets/grid_markers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah/narrow_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/narrow_finger.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah/narrow_finger_rescaled.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/narrow_finger_rescaled.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah/narrow_finger_tip.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/narrow_finger_tip.STL -------------------------------------------------------------------------------- 
/envs/yang_domains/assets/gripah/wide_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wide_finger.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah/wide_finger_rescaled.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wide_finger_rescaled.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah/wide_finger_tip.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wide_finger_tip.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah/wrist.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wrist.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah/wrist_rescaled.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/gripah/wrist_rescaled.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah_body.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 16 | 18 | 20 | 22 | 23 | 24 | 25 | 26 | 28 | 30 | 32 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /envs/yang_domains/assets/gripah_contact.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_base.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_forearm.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_L.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_L.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_R.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_robotiq_85_gripper_joint_3_R.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_shoulder.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_upperarm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_upperarm.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_wrist1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_wrist1.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_wrist2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_wrist2.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/c_wrist3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/c_wrist3.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/glass_cup.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/glass_cup.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/glass_cup_2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/glass_cup_2.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/glass_cup_3.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/glass_cup_3.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/inner_finger_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_finger_coarse.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/inner_finger_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_finger_fine.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/inner_knuckle_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_knuckle_coarse.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/inner_knuckle_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/inner_knuckle_fine.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/new_solo_cup.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/new_solo_cup.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/outer_finger_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_finger_coarse.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/outer_finger_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_finger_fine.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/outer_knuckle_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_knuckle_coarse.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/outer_knuckle_fine.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/outer_knuckle_fine.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/red_solo_cup.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/red_solo_cup.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_coarse.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_fine.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/robotiq_85_base_link_fine.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/smaller_solo_cup.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/smaller_solo_cup.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/solo_cup.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/solo_cup.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/upd_solo_cup.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/upd_solo_cup.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_base.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_forearm.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_L.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_L.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_R.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_robotiq_85_gripper_joint_3_R.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_shoulder.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_upperarm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_upperarm.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_wrist1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_wrist1.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_wrist2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_wrist2.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/muj_gripper/v_wrist3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/muj_gripper/v_wrist3.stl -------------------------------------------------------------------------------- /envs/yang_domains/assets/objects/bump_40_mujoco.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/bump_40_mujoco.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/objects/bump_50_mujoco.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/bump_50_mujoco.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/objects/bump_80_mujoco.STL: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/bump_80_mujoco.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/objects/plate_half.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/plate_half.STL -------------------------------------------------------------------------------- /envs/yang_domains/assets/objects/plate_whole.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/assets/objects/plate_whole.STL -------------------------------------------------------------------------------- /envs/yang_domains/dmc_acrobot.py: -------------------------------------------------------------------------------- 1 | import dmc2gym 2 | from .wrappers import ConcatObs 3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages 4 | 5 | 6 | def mdp(): 7 | return dmc2gym.make(domain_name="acrobot", task_name="swingup", keys_to_exclude=[], track_prev_action=False, frame_skip=5) 8 | 9 | 10 | def p(): 11 | return dmc2gym.make(domain_name="acrobot", task_name="swingup", keys_to_exclude=['velocity'], track_prev_action=False, frame_skip=5) 12 | 13 | 14 | def va(): 15 | return dmc2gym.make(domain_name="acrobot", task_name="swingup", keys_to_exclude=['orientations'], track_prev_action=True, frame_skip=5) 16 | 17 | 18 | def p_concat5(): 19 | return ConcatObs(p(), 5) 20 | 21 | 22 | def va_concat10(): 23 | return ConcatObs(va(), 10) 24 | -------------------------------------------------------------------------------- /envs/yang_domains/dmc_cart2pole.py: -------------------------------------------------------------------------------- 1 | import dmc2gym 2 | from .wrappers import ConcatObs 3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages 4 | 5 | 6 | def mdp(): 7 | return dmc2gym.make(domain_name="cartpole", task_name="two_poles", keys_to_exclude=[], frame_skip=5) 8 | 9 | 10 | # def p(): 11 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['velocity'], frame_skip=5) 12 | # 13 | # 14 | # def va(): 15 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['position'], frame_skip=5) 16 | # 17 | # 18 | # def p_concat5(): 19 | # return ConcatObs(p(), 5) 20 | # 21 | # def va_concat10(): 22 | # return ConcatObs(va(), 10) 23 | -------------------------------------------------------------------------------- /envs/yang_domains/dmc_cart3pole.py: -------------------------------------------------------------------------------- 1 | import dmc2gym 2 | from .wrappers import ConcatObs 3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages 4 | 5 | 6 | def mdp(): 7 | return dmc2gym.make(domain_name="cartpole", task_name="three_poles", keys_to_exclude=[], frame_skip=5) 8 | 9 | 10 | # def p(): 11 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['velocity'], frame_skip=5) 12 | # 13 | # 14 | # def va(): 15 | # return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['position'], frame_skip=5) 16 | # 17 | # 18 | # def p_concat5(): 19 | # return 
ConcatObs(p(), 5) 20 | # 21 | # def va_concat10(): 22 | # return ConcatObs(va(), 10) 23 | -------------------------------------------------------------------------------- /envs/yang_domains/dmc_cartpole_b.py: -------------------------------------------------------------------------------- 1 | import dmc2gym 2 | from .wrappers import ConcatObs 3 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages 4 | 5 | 6 | def mdp(): 7 | return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=[], frame_skip=5, track_prev_action=False) 8 | 9 | 10 | def p(): 11 | return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['velocity'], frame_skip=5, track_prev_action=False) 12 | 13 | 14 | def va(): 15 | return dmc2gym.make(domain_name="cartpole", task_name="balance", keys_to_exclude=['position'], track_prev_action=True, frame_skip=5) 16 | 17 | 18 | def p_concat5(): 19 | return ConcatObs(p(), 5) 20 | 21 | def va_concat10(): 22 | return ConcatObs(va(), 10) 23 | 24 | 25 | def pomdp_img(): 26 | """frame skip follows from the Dreamer benchmark""" 27 | raw_env = dmc2gym.make( 28 | domain_name="cartpole", 29 | task_name="balance", 30 | keys_to_exclude=[], 31 | visualize_reward=False, 32 | from_pixels=True, 33 | frame_skip=2 34 | ) 35 | return GrayscaleImage(Normalize255Image(raw_env)) 36 | 37 | 38 | def mdp_img_concat3(): 39 | return ConcatImages(pomdp_img(), 3) 40 | -------------------------------------------------------------------------------- /envs/yang_domains/dmc_cartpole_su.py: -------------------------------------------------------------------------------- 1 | import dmc2gym 2 | from .wrappers import ConcatObs 3 | 4 | 5 | def mdp(): 6 | return dmc2gym.make(domain_name="cartpole", task_name="swingup", keys_to_exclude=[], frame_skip=5, track_prev_action=False) 7 | 8 | 9 | def p(): 10 | return dmc2gym.make(domain_name="cartpole", task_name="swingup", keys_to_exclude=['velocity'], frame_skip=5, track_prev_action=False) 11 | 12 | 13 | def va(): 14 | return dmc2gym.make(domain_name="cartpole", task_name="swingup", keys_to_exclude=['position'], frame_skip=5, track_prev_action=True) 15 | 16 | 17 | def p_concat5(): 18 | return ConcatObs(p(), 5) 19 | 20 | def va_concat10(): 21 | return ConcatObs(va(), 10) 22 | -------------------------------------------------------------------------------- /envs/yang_domains/dmc_pendulum_su.py: -------------------------------------------------------------------------------- 1 | import dmc2gym 2 | 3 | 4 | def mdp(): 5 | return dmc2gym.make(domain_name="pendulum", task_name="swingup", keys_to_exclude=[], frame_skip=5, track_prev_action=False) 6 | -------------------------------------------------------------------------------- /envs/yang_domains/dmc_walker_walk.py: -------------------------------------------------------------------------------- 1 | import dmc2gym 2 | from .wrappers_for_image import Normalize255Image, GrayscaleImage, ConcatImages 3 | 4 | 5 | def pomdp_img(): 6 | """frame skip follows from the Dreamer benchmark""" 7 | raw_env = dmc2gym.make( 8 | domain_name="walker", 9 | task_name="walk", 10 | keys_to_exclude=[], 11 | visualize_reward=False, 12 | from_pixels=True, 13 | frame_skip=2 14 | ) 15 | return GrayscaleImage(Normalize255Image(raw_env)) 16 | 17 | 18 | def mdp_img_concat3(): 19 | return ConcatImages(pomdp_img(), 3) 20 | -------------------------------------------------------------------------------- /envs/yang_domains/pybullet_ant.py: 
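The `dmc_*` modules above all follow the same pattern: `dmc2gym.make` builds the DeepMind Control task, `keys_to_exclude` hides part of the state to turn it into a POMDP, and `ConcatObs` stacks recent observations as a simple memory baseline. A minimal usage sketch, assuming `dmc2gym` is installed and the modules are imported as part of their package (they use relative imports):

```python
from envs.yang_domains import dmc_cartpole_su  # assumed package path from the repo root

env_mdp = dmc_cartpole_su.mdp()       # full state observation (MDP)
env_p = dmc_cartpole_su.p()           # velocities excluded -> POMDP
env_p5 = dmc_cartpole_su.p_concat5()  # same POMDP with 5 stacked observations

print(env_mdp.observation_space, env_p.observation_space, env_p5.observation_space)
```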
-------------------------------------------------------------------------------- 1 | import gym 2 | import pybullet_envs 3 | from .wrappers import FilterObsByIndex 4 | 5 | 6 | def p(): 7 | return FilterObsByIndex( 8 | gym.make("AntBulletEnv-v0"), 9 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [8, 10, 12, 14, 16, 18, 20, 22] + [24, 25, 26, 27] 10 | ) 11 | 12 | 13 | def v(): 14 | return FilterObsByIndex( 15 | gym.make("AntBulletEnv-v0"), 16 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [9, 11, 13, 15, 17, 19, 21, 23] + [24, 25, 26, 27] 17 | ) 18 | -------------------------------------------------------------------------------- /envs/yang_domains/pybullet_halfcheetah.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import pybullet_envs 3 | from .wrappers import FilterObsByIndex 4 | 5 | 6 | def p(): 7 | return FilterObsByIndex( 8 | gym.make("HalfCheetahBulletEnv-v0"), 9 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [8, 10, 12, 14, 16, 18] + [20, 21, 22, 23, 24, 25] 10 | ) 11 | 12 | 13 | def v(): 14 | return FilterObsByIndex( 15 | gym.make("HalfCheetahBulletEnv-v0"), 16 | indices_to_keep=[0, 1, 2, 3, 4, 5, 6, 7] + [9, 11, 13, 15, 17, 19] + [20, 21, 22, 23, 24, 25] 17 | ) 18 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Ravens Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
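The PyBullet variants above (`pybullet_ant.py`, `pybullet_halfcheetah.py`) create POMDPs differently: `FilterObsByIndex` keeps either the position-like or the velocity-like entries of the full observation vector. A minimal sketch, under the same import-path assumption as above and with `pybullet_envs` installed:

```python
from envs.yang_domains import pybullet_ant  # assumed package path from the repo root

env_p = pybullet_ant.p()  # joint velocities filtered out of the observation
env_v = pybullet_ant.v()  # joint positions filtered out of the observation

obs = env_p.reset()
print(env_p.observation_space.shape, env_v.observation_space.shape)
```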
15 | 16 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/bumps/bump_40.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/bumps/bump_40.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/bumps/bump_40_blue.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/bumps/bump_40_red.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/bumps/bump_40_virtual.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/bumps/bump_50.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/bumps/bump_50.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/bumps/bump_50.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/cup/cup.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/cup/cup.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/cup/cup.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plane/checker_blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plane/checker_blue.png -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plane/plane.mtl: 
-------------------------------------------------------------------------------- 1 | newmtl Material 2 | Ns 10.0000 3 | Ni 1.5000 4 | d 1.0000 5 | Tr 0.0000 6 | Tf 1.0000 1.0000 1.0000 7 | illum 2 8 | Ka 0.0000 0.0000 0.0000 9 | Kd 0.5880 0.5880 0.5880 10 | Ks 0.0000 0.0000 0.0000 11 | Ke 0.0000 0.0000 0.0000 12 | map_Ka cube.tga 13 | map_Kd checker_blue.png 14 | 15 | 16 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plane/plane.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plane/plane100.obj: -------------------------------------------------------------------------------- 1 | # Blender v2.66 (sub 1) OBJ File: '' 2 | # www.blender.org 3 | mtllib plane.mtl 4 | o Plane 5 | v 100.000000 -100.000000 0.000000 6 | v 100.000000 100.000000 0.000000 7 | v -100.000000 100.000000 0.000000 8 | v -100.000000 -100.000000 0.000000 9 | 10 | vt 100.000000 0.000000 11 | vt 100.000000 100.000000 12 | vt 0.000000 100.000000 13 | vt 0.000000 0.000000 14 | 15 | 16 | 17 | usemtl Material 18 | s off 19 | f 1/1 2/2 3/3 20 | f 1/1 3/3 4/4 21 | 22 | 23 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plate/plate.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plate/plate.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plate/plate_half.urdf: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plate/plate_holder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate_holder.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plate/plate_holder.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 
13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plate/plate_lower_half.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate_lower_half.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/plate/plate_upper_half.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/plate/plate_upper_half.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/shelf/shelf_back_board.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/shelf/shelf_back_board.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/shelf/shelf_back_board.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/shelf/shelf_horizontal_board.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/shelf/shelf_horizontal_board.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/shelf/shelf_horizontal_board.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/shelf/shelf_side_board.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robot_envs/assets/shelf/shelf_side_board.stl -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/shelf/shelf_side_board.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/workspace/grid_mark.urdf: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/workspace/plane.obj: -------------------------------------------------------------------------------- 1 | # Blender v2.66 (sub 1) OBJ File: '' 2 | # www.blender.org 3 | mtllib plane.mtl 4 | o Plane 5 | v 15.000000 -15.000000 0.000000 6 | v 15.000000 15.000000 0.000000 7 | v -15.000000 15.000000 0.000000 8 | v -15.000000 -15.000000 0.000000 9 | 10 | vt 15.000000 0.000000 11 | vt 15.000000 15.000000 12 | vt 0.000000 15.000000 13 | vt 0.000000 0.000000 14 | 15 | usemtl Material 16 | s off 17 | f 1/1 2/2 3/3 18 | f 1/1 3/3 4/4 19 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/workspace/rail.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /envs/yang_domains/robot_envs/assets/workspace/workspace.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/__init__.py -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/arm.SLDPRT: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm.SLDPRT -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/arm.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/arm_half_1.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm_half_1.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/arm_half_2.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm_half_2.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/arm_mico.STL: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/arm_mico.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/base.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/base.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/finger_distal.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/finger_distal.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/finger_proximal.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/finger_proximal.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/forearm.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/forearm.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/forearm_mico.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/forearm_mico.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/hand_2finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/hand_2finger.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/hand_3finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/hand_3finger.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/ring_big.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/ring_big.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/ring_small.STL: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/ring_small.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/shoulder.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/shoulder.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/wrist.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/wrist.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_1.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_1.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_2.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/jaco/meshes/wrist_spherical_2.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/rdda/meshes/narrow_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/rdda/meshes/narrow_finger.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/rdda/meshes/wide_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/rdda/meshes/wide_finger.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/rdda/meshes/wrist.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/rdda/meshes/wrist.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/README.md: -------------------------------------------------------------------------------- 1 | ## Robotiq 2F 85 gripper 2 | For this gripper, the following Github repo can be used as a reference: https://github.com/Shreeyak/robotiq.git 3 | 4 | ### mimic tag in URDF 5 | This gripper is developed for ROS and uses the `mimic` tag within the URDF files to make the gripper move. From our research `mimic` tag within URDF is not supported by pybullet. 
To overcome this, one can use the `createConstraint` function. Please refer to [this](https://github.com/bulletphysics/bullet3/blob/master/examples/pybullet/examples/mimicJointConstraint.py) example from the bullet3 repo to see how to replicate a `mimic` joint: 6 | 7 | ```python 8 | #a mimic joint can act as a gear between two joints 9 | #you can control the gear ratio in magnitude and sign (>0 reverses direction) 10 | 11 | import pybullet as p 12 | import time 13 | p.connect(p.GUI) 14 | p.loadURDF("plane.urdf",0,0,-2) 15 | wheelA = p.loadURDF("differential/diff_ring.urdf",[0,0,0]) 16 | for i in range(p.getNumJoints(wheelA)): 17 | print(p.getJointInfo(wheelA,i)) 18 | p.setJointMotorControl2(wheelA,i,p.VELOCITY_CONTROL,targetVelocity=0,force=0) 19 | 20 | 21 | c = p.createConstraint(wheelA,1,wheelA,3,jointType=p.JOINT_GEAR,jointAxis =[0,1,0],parentFramePosition=[0,0,0],childFramePosition=[0,0,0]) 22 | p.changeConstraint(c,gearRatio=1, maxForce=10000) 23 | 24 | c = p.createConstraint(wheelA,2,wheelA,4,jointType=p.JOINT_GEAR,jointAxis =[0,1,0],parentFramePosition=[0,0,0],childFramePosition=[0,0,0]) 25 | p.changeConstraint(c,gearRatio=-1, maxForce=10000) 26 | 27 | c = p.createConstraint(wheelA,1,wheelA,4,jointType=p.JOINT_GEAR,jointAxis =[0,1,0],parentFramePosition=[0,0,0],childFramePosition=[0,0,0]) 28 | p.changeConstraint(c,gearRatio=-1, maxForce=10000) 29 | 30 | 31 | p.setRealTimeSimulation(1) 32 | while(1): 33 | p.setGravity(0,0,-10) 34 | time.sleep(0.01) 35 | #p.removeConstraint(c) 36 | 37 | ``` 38 | 39 | 40 | Details on `createConstraint` can be found in the pybullet [getting started](https://docs.google.com/document/d/10sXEhzFRSnvFcl3XxNGhnD4N2SedqwdAvK3dsihxVUA/edit#heading=h.fq749wu22x4c) guide. 41 | 42 | ### Files in folder 43 | Since parameters like gear ratio and direction are required, one can find the `robotiq_2f_85_mimic_joints.urdf` which contains the mimic tags as in original URDF, which can be used as a reference. 
It was generated from `robotiq/robotiq_2f_robot/robot/simple_rq2f85_pybullet.urdf.xacro` as so: 44 | ``` 45 | rosrun xacro xacro --inorder simple_rq2f85_pybullet.urdf.xacro 46 | adaptive_transmission:="true" > robotiq_2f_85_mimic_joints.urdf 47 | ``` 48 | 49 | The URDF meant for use in pybullet is `robotiq_2f_85.urdf` and it is generated in a similar manner as above by running: 50 | ``` 51 | rosrun xacro xacro --inorder simple_rq2f85_pybullet.urdf.xacro > robotiq_2f_85.urdf 52 | ``` 53 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-base.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'gripper-2f.blend' 2 | # Material Count: 1 3 | 4 | newmtl Default 5 | Ns 96.078431 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.640000 0.640000 0.640000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.000000 11 | d 1.000000 12 | illum 2 13 | map_Kd textures/gripper-2f_BaseColor.jpg 14 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-base.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-coupler.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'gripper-2f.blend' 2 | # Material Count: 1 3 | 4 | newmtl Default 5 | Ns 96.078431 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.640000 0.640000 0.640000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.000000 11 | d 1.000000 12 | illum 2 13 | map_Kd textures/gripper-2f_BaseColor.jpg 14 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-coupler.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-coupler.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-driver.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'gripper-2f.blend' 2 | # Material Count: 1 3 | 4 | newmtl Default 5 | Ns 96.078431 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.640000 0.640000 0.640000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.000000 11 | d 1.000000 12 | illum 2 13 | map_Kd textures/gripper-2f_BaseColor.jpg 14 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-driver.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-driver.stl -------------------------------------------------------------------------------- 
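The robotiq README above shows the general gear-constraint trick; the snippet below is a minimal, hypothetical sketch (not a file in this repository) of how the generated `robotiq_2f_85.urdf` might be loaded in PyBullet with its follower joints coupled to the driver joint via `p.createConstraint`. The joint indices and gear ratios here are placeholder assumptions and would have to be read from `p.getJointInfo` for the actual model.

```python
# Hypothetical sketch: emulate the URDF `mimic` tag for the 2F-85 gripper in PyBullet
# with JOINT_GEAR constraints. Joint indices and gear ratios are placeholders.
import pybullet as p

p.connect(p.DIRECT)
gripper = p.loadURDF("robotiq_2f_85.urdf", [0, 0, 0], useFixedBase=True)

# Print the joint table once to identify the driver joint and the follower joints.
for i in range(p.getNumJoints(gripper)):
    print(p.getJointInfo(gripper, i))

driver_joint = 1               # assumed index of the actuated knuckle joint
mimic_joints = {3: 1, 5: -1}   # assumed follower joint index -> gear ratio

for joint, ratio in mimic_joints.items():
    # Disable the default velocity motor so the gear constraint alone drives the joint.
    p.setJointMotorControl2(gripper, joint, p.VELOCITY_CONTROL, targetVelocity=0, force=0)
    c = p.createConstraint(gripper, driver_joint, gripper, joint,
                           jointType=p.JOINT_GEAR, jointAxis=[0, 1, 0],
                           parentFramePosition=[0, 0, 0], childFramePosition=[0, 0, 0])
    p.changeConstraint(c, gearRatio=ratio, maxForce=10000)
```

With the constraints in place, commanding only `driver_joint` (for example with `p.setJointMotorControl2` in position-control mode) should move the coupled joints together, which is the behaviour the `mimic` tag provides under ROS.
--------------------------------------------------------------------------------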
/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-follower.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'gripper-2f.blend' 2 | # Material Count: 1 3 | 4 | newmtl Default 5 | Ns 96.078431 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.640000 0.640000 0.640000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.000000 11 | d 1.000000 12 | illum 2 13 | map_Kd textures/gripper-2f_BaseColor.jpg 14 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-follower.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-follower.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-pad.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-pad.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-spring_link.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'gripper-2f.blend' 2 | # Material Count: 1 3 | 4 | newmtl Default 5 | Ns 96.078431 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.640000 0.640000 0.640000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.000000 11 | d 1.000000 12 | illum 2 13 | map_Kd textures/gripper-2f_BaseColor.jpg 14 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-spring_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/meshes/robotiq-2f-spring_link.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_BaseColor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_BaseColor.jpg -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Metallic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Metallic.jpg -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Normal.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Normal.jpg -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Roughness.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/robotiq/textures/gripper-2f_Roughness.jpg -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/shovel/meshes/shovel_base.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/shovel/meshes/shovel_base.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/shovel/meshes/shovel_blade.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/shovel/meshes/shovel_blade.STL -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/shovel/shovel.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/spatula/spatula-base.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/suction/suction-base.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/collision/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/base.stl 
-------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/collision/forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/forearm.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/collision/shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/shoulder.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/collision/upperarm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/upperarm.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/collision/wrist1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/wrist1.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/collision/wrist2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/wrist2.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/collision/wrist3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/collision/wrist3.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/license.txt: -------------------------------------------------------------------------------- 1 | Original work Copyright 2018 ROS Industrial (https://rosindustrial.org/) 2 | Modified work Copyright 2018 Virtana, Inc (www.virtanatech.com) 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 13 | or implied. See the License for the specific language governing 14 | permissions and limitations under the License. 15 | 16 | Changenotes: 17 | 18 | 2018-05-08: Vijay Pradeep (vijay@virtanatech.com) 19 | The visual/ mesh files were generated from the dae files in the 20 | ros-industrial universal robot repo [1]. 
Since the collada pyBullet 21 | parser is somewhat limited, it is unable to parse the UR collada mesh 22 | files. Thus, we imported these collada files into Blender and 23 | converted them into STL files. We lost material definitions during 24 | the conversion, but that's ok. 25 | 26 | The URDF was generated by running the xacro xml preprocessor on the 27 | URDF included in the ur_description repo already mentioned here. 28 | Additional manual tweaking was required to update resource paths and 29 | to remove errors caused by missing inertia elements. Various Gazebo 30 | plugin tags were also removed. 31 | 32 | [1] - https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_description/meshes/ur5/visual 33 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/visual/base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/base.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/visual/forearm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/forearm.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/visual/shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/shoulder.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/visual/upperarm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/upperarm.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/visual/wrist1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/wrist1.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/visual/wrist2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/wrist2.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/assets/ur5/visual/wrist3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FanmingL/Recurrent-Offpolicy-RL/2f0c0121a664737af91b48aef50999d8836c2e61/envs/yang_domains/robots/assets/ur5/visual/wrist3.stl -------------------------------------------------------------------------------- /envs/yang_domains/robots/end_effector.py:
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pybullet as p 3 | 4 | ASSETS_PATH = Path(__file__).resolve().parent / 'assets' 5 | 6 | 7 | class EndEffector: 8 | """ 9 | A base class for UR5 end effectors. 10 | """ 11 | 12 | def __init__(self, urdf_path, load_position, load_orientation, ur5_install_joints, ee_tip_idx): 13 | """ 14 | The initialization of the UR5 end effector. 15 | 16 | :param urdf_path: the path of the urdf that describes the end effector 17 | :param load_position: the position to load the end effector 18 | :param load_orientation: the orientation to load the end effector 19 | :param ur5_install_joints: the home joints of the UR5 robot when installing this end effector 20 | :param ee_tip_idx: the link index of the tip of this end effector 21 | """ 22 | 23 | self.urdf_path = urdf_path 24 | self.load_position = load_position 25 | self.load_orientation = load_orientation 26 | self.ur5_install_joints = ur5_install_joints 27 | self.ee_tip_idx = ee_tip_idx 28 | 29 | # Loads the model. 30 | self.body_id = p.loadURDF(str(urdf_path), load_position, load_orientation) 31 | 32 | def get_body_id(self): 33 | """ 34 | Gets the body id of this end effector in PyBullet. 35 | 36 | :return: the body id 37 | """ 38 | 39 | return self.body_id 40 | 41 | def get_position_offset(self): 42 | """ 43 | Gets the position offset of this end effector. 44 | 45 | :return: the position offset 46 | """ 47 | 48 | return 0, 0, 0 49 | 50 | def get_ur5_install_joints(self): 51 | """ 52 | Gets the home joints of the UR5 robot when installing this end effector. 53 | 54 | :return: the UR5 home joints for this end effector 55 | """ 56 | 57 | return self.ur5_install_joints 58 | 59 | def get_base_pose(self): 60 | """ 61 | Gets the base position and orientation of this end effector. 62 | 63 | :return: the position and orientation of the base 64 | """ 65 | 66 | return p.getBasePositionAndOrientation(self.body_id) 67 | 68 | def get_tip_pose(self): 69 | """ 70 | Gets the tip position and orientation of this end effector. 71 | 72 | :return: the position and orientation of the tip 73 | """ 74 | 75 | state = p.getLinkState(self.body_id, self.ee_tip_idx) 76 | 77 | return state[4], state[5] 78 | 79 | def reset(self, reset_base=False): 80 | """ 81 | Resets this end effector. 82 | 83 | :param reset_base: True if resetting the base pose, False otherwise 84 | """ 85 | 86 | if reset_base: 87 | p.resetBasePositionAndOrientation(bodyUniqueId=self.body_id, 88 | posObj=self.load_position, 89 | ornObj=self.load_orientation) 90 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/rdda.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pybullet as p 3 | 4 | from .gripper import Gripper, ASSETS_PATH 5 | 6 | RDDA_URDF_PATH = ASSETS_PATH / 'rdda' / 'rdda.urdf' 7 | 8 | 9 | class Rdda(Gripper): 10 | """ 11 | A class for the RDDA gripper. 12 | """ 13 | 14 | def __init__(self, ur5_id, ee_id): 15 | """ 16 | The initialization of the RDDA gripper.
17 | 18 | :param ur5_id: the body id of the UR5 robot to install this RDDA 19 | :param ee_id: the link index of the UR5 robot to install this RDDA 20 | """ 21 | 22 | super().__init__(urdf_path=RDDA_URDF_PATH, 23 | load_position=(0.4746, 0.1092, 0.4195), 24 | load_orientation=p.getQuaternionFromEuler((np.pi / 2, np.pi / 2, 0)), 25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, 0, 0.5, -1]) * np.pi, 26 | ee_tip_idx=1) # The ee tip of the RDDA gripper is the tip of the wide finger. 27 | 28 | # Installs the RDDA gripper on the UR5. 29 | self.offset = -0.025 30 | constraint_id = p.createConstraint( 31 | parentBodyUniqueId=ur5_id, 32 | parentLinkIndex=ee_id, 33 | childBodyUniqueId=self.body_id, 34 | childLinkIndex=-1, 35 | jointType=p.JOINT_FIXED, 36 | jointAxis=(0, 0, 0), 37 | parentFramePosition=(0, 0, 0), 38 | parentFrameOrientation=p.getQuaternionFromEuler((np.pi / 2, 0, np.pi / 2)), 39 | childFramePosition=(0, self.offset, 0)) 40 | p.changeConstraint(constraint_id, maxForce=50) 41 | 42 | def get_position_offset(self): 43 | """ 44 | Gets the position offset of this RDDA gripper. 45 | 46 | :return: the position offset 47 | """ 48 | 49 | return self.offset, 0, 0 50 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/robotiq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pybullet as p 3 | 4 | from ..robots.gripper import Gripper, ASSETS_PATH 5 | 6 | Robotiq_URDF_PATH = ASSETS_PATH / 'robotiq' / 'robotiq_2f_85.urdf' 7 | 8 | 9 | class Robotiq(Gripper): 10 | """ 11 | A class for the Robotiq gripper. 12 | """ 13 | 14 | def __init__(self, ur5_id, ee_id): 15 | """ 16 | The initialization of the Robotiq gripper. 17 | 18 | :param ur5_id: the body id of the UR5 robot to install this Robotiq 19 | :param ee_id: the link index of the UR5 robot to install this Robotiq 20 | """ 21 | 22 | super().__init__(urdf_path=Robotiq_URDF_PATH, 23 | load_position=(0.4868, 0.1093, 0.431594), 24 | load_orientation=p.getQuaternionFromEuler((np.pi, 0, 0)), 25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, -0.5, -0.5, 0]) * np.pi, 26 | ee_tip_idx=0) # FIXME 27 | 28 | # Installs the Robotiq gripper on the UR5. 29 | constraint_id = p.createConstraint( 30 | parentBodyUniqueId=ur5_id, 31 | parentLinkIndex=ee_id, 32 | childBodyUniqueId=self.body_id, 33 | childLinkIndex=-1, 34 | jointType=p.JOINT_FIXED, 35 | jointAxis=(0, 0, 0), 36 | parentFramePosition=(0, 0, 0), 37 | parentFrameOrientation=p.getQuaternionFromEuler((0, 0, np.pi / 2)), 38 | childFramePosition=(0, 0, 0)) 39 | p.changeConstraint(constraint_id, maxForce=50) 40 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/shovel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pybullet as p 3 | 4 | from .gripper import Gripper, ASSETS_PATH 5 | 6 | SHOVEL_URDF_PATH = ASSETS_PATH / 'shovel' / 'shovel.urdf' 7 | 8 | 9 | class Shovel(Gripper): 10 | """ 11 | A class for the shovel gripper. 12 | """ 13 | 14 | def __init__(self, ur5_id, ee_id): 15 | """ 16 | The initialization of the shovel gripper. 
17 | 18 | :param ur5_id: the body id of the UR5 robot to install this shovel 19 | :param ee_id: the link index of the UR5 robot to install this shovel 20 | """ 21 | 22 | super().__init__(urdf_path=SHOVEL_URDF_PATH, 23 | load_position=(0.487, 0.109, 0.438), 24 | load_orientation=p.getQuaternionFromEuler((np.pi, 0, np.pi / 2)), 25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, -0.5, -0.5, 0]) * np.pi, 26 | ee_tip_idx=1) 27 | 28 | # Installs the shovel gripper on the UR5. 29 | constraint_id = p.createConstraint( 30 | parentBodyUniqueId=ur5_id, 31 | parentLinkIndex=ee_id, 32 | childBodyUniqueId=self.body_id, 33 | childLinkIndex=-1, 34 | jointType=p.JOINT_FIXED, 35 | jointAxis=(0, 0, 0), 36 | parentFramePosition=(0, 0, 0), 37 | childFramePosition=(0, 0, 0.01)) 38 | p.changeConstraint(constraint_id, maxForce=50) 39 | -------------------------------------------------------------------------------- /envs/yang_domains/robots/spatula.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pybullet as p 3 | 4 | from .end_effector import EndEffector, ASSETS_PATH 5 | 6 | SPATULA_URDF_PATH = ASSETS_PATH / 'spatula' / 'spatula-base.urdf' 7 | 8 | 9 | class Spatula(EndEffector): 10 | """ 11 | A class for a simple spatula for pushing. 12 | """ 13 | 14 | def __init__(self, ur5_id, ee_id): 15 | """ 16 | The initialization of the spatula. 17 | 18 | :param ur5_id: the body id of the UR5 robot to install this spatula 19 | :param ee_id: the link index of the UR5 robot to install this spatula 20 | """ 21 | 22 | super().__init__(urdf_path=SPATULA_URDF_PATH, 23 | load_position=(0.487, 0.109, 0.438), 24 | load_orientation=p.getQuaternionFromEuler((np.pi, 0, 0)), 25 | ur5_install_joints=np.float32([-1, -0.5, 0.5, -0.5, -0.5, 0]) * np.pi, 26 | ee_tip_idx=0) 27 | 28 | # Installs the spatula on the UR5. 29 | constraint_id = p.createConstraint( 30 | parentBodyUniqueId=ur5_id, 31 | parentLinkIndex=ee_id, 32 | childBodyUniqueId=self.body_id, 33 | childLinkIndex=-1, 34 | jointType=p.JOINT_FIXED, 35 | jointAxis=(0, 0, 0), 36 | parentFramePosition=(0, 0, 0), 37 | parentFrameOrientation=p.getQuaternionFromEuler((0, 0, np.pi / 2)), 38 | childFramePosition=(0, 0, 0.01)) 39 | p.changeConstraint(constraint_id, maxForce=50) 40 | -------------------------------------------------------------------------------- /envs/yang_domains/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Ravens Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /envs/yang_domains/wrappers.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import numpy as np 3 | import gym 4 | import gym.spaces as spaces 5 | 6 | # FYI 7 | # - By default, wrappers directly copy the observation and action space from their wrappees. 8 | # - The observation method in ObservationWrappers will get used by both env.reset and env.step. 9 | # - In general, observation space is used only for getting observation shape for networks. 10 | 11 | 12 | class FilterObsByIndex(gym.ObservationWrapper): 13 | 14 | def _filter(self, array: np.array) -> np.array: 15 | return np.array( 16 | [x for i, x in enumerate(array) if i in self.indices_to_keep] 17 | ) 18 | 19 | def __init__(self, env, indices_to_keep: list): 20 | 21 | super().__init__(env) 22 | self.indices_to_keep = indices_to_keep 23 | 24 | new_high, new_low = self._filter(self.env.observation_space.high), self._filter(self.env.observation_space.low) 25 | self.observation_space = spaces.Box(low=new_low, high=new_high) 26 | 27 | def observation(self, observation): 28 | return self._filter(observation) 29 | 30 | 31 | class ConcatObs(gym.ObservationWrapper): 32 | 33 | def __init__(self, env, window_size: int): 34 | 35 | super().__init__(env) 36 | 37 | # get info on old observation space 38 | old_obs_space = env.observation_space 39 | old_obs_space_dim = old_obs_space.shape[0] 40 | old_obs_space_low, old_obs_space_high = old_obs_space.low, old_obs_space.high 41 | 42 | # change observation space 43 | self.observation_space = spaces.Box( 44 | low=np.array(list(old_obs_space_low) * window_size), 45 | high=np.array(list(old_obs_space_high) * window_size) 46 | ) 47 | 48 | self.window = deque(maxlen=window_size) 49 | for i in range(window_size - 1): 50 | self.window.append(np.zeros((old_obs_space_dim, ))) # append some dummy observations first 51 | 52 | self.window_size = window_size 53 | self.old_obs_space_dim = old_obs_space_dim 54 | 55 | def observation(self, obs: np.array) -> np.array: 56 | self.window.append(obs) 57 | return np.concatenate(self.window) 58 | 59 | def reset(self): 60 | for i in range(self.window_size - 1): 61 | self.window.append(np.zeros((self.old_obs_space_dim, ))) # append some dummy observations first 62 | observation = self.env.reset() 63 | return self.observation(observation) 64 | -------------------------------------------------------------------------------- /envs/yang_domains/wrappers_for_image.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym.spaces as spaces 3 | import numpy as np 4 | from collections import deque 5 | 6 | 7 | class GrayscaleImage(gym.ObservationWrapper): 8 | 9 | def __init__(self, env): 10 | super().__init__(env) 11 | old_observation_space = env.observation_space 12 | self.observation_space = spaces.Box( 13 | low=np.expand_dims(old_observation_space.low[0], axis=0), 14 | high=np.expand_dims(old_observation_space.high[0], axis=0), 15 | ) 16 | # https://www.kite.com/python/answers/how-to-convert-an-image-from-rgb-to-grayscale-in-python 17 | self.rgb_weights = np.array([0.2989, 0.5870, 0.1140]).reshape(3, 1) 18 | 19 | def observation(self, observation): 20 | observation = np.moveaxis(observation, 0, 2) 21 | observation = observation @ self.rgb_weights 22 | return np.moveaxis(observation, 2, 0) 23 | 24 | 25 | class Normalize255Image(gym.ObservationWrapper): 26 | 27 | def 
__init__(self, env): 28 | super().__init__(env) 29 | old_observation_space = env.observation_space 30 | self.observation_space = spaces.Box( 31 | low=old_observation_space.low / 255, 32 | high=old_observation_space.high / 255 33 | ) 34 | 35 | def observation(self, observation): 36 | return observation / 255 # uint8 automatically get converts to float64 37 | 38 | 39 | class ConcatImages(gym.ObservationWrapper): 40 | 41 | def __init__(self, env, window_size: int): 42 | super().__init__(env) 43 | 44 | self.window_size = window_size 45 | self.old_obs_space_shape = env.observation_space.shape 46 | 47 | old_obs_space = env.observation_space 48 | old_obs_low = old_obs_space.low 49 | old_obs_high = old_obs_space.high 50 | 51 | self.observation_space = spaces.Box( 52 | low=np.concatenate([old_obs_low, old_obs_low, old_obs_low]), 53 | high=np.concatenate([old_obs_high, old_obs_high, old_obs_high]), 54 | dtype=np.uint8 # a must for images to work with SB3 55 | ) 56 | 57 | self.window = deque(maxlen=window_size) 58 | for i in range(self.window_size - 1): 59 | self.window.append(np.zeros(self.old_obs_space_shape)) 60 | 61 | def observation(self, observation): 62 | self.window.append(observation) 63 | return np.concatenate(self.window) 64 | 65 | def reset(self): 66 | for i in range(self.window_size - 1): 67 | self.window.append(np.zeros(self.old_obs_space_shape)) 68 | observation = self.env.reset() 69 | return self.observation(observation) 70 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore", category=UserWarning, module="gym.*") 3 | warnings.filterwarnings("ignore", ".*The DISPLAY environment variable is missing") # dmc 4 | warnings.filterwarnings("ignore", category=FutureWarning) # dmc 5 | import multiprocessing 6 | from offpolicy_rnn import init_smart_logger, Parameter, alg_init 7 | 8 | 9 | def main(): 10 | if not multiprocessing.get_start_method(allow_none=True) == 'spawn': 11 | multiprocessing.set_start_method('spawn', force=True) 12 | init_smart_logger() 13 | parameter = Parameter() 14 | sac = alg_init(parameter) 15 | sac.train() 16 | 17 | 18 | if __name__ == '__main__': 19 | main() -------------------------------------------------------------------------------- /offpolicy_rnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .algorithm.sac import SAC 2 | from .algorithm.sac_mlp import SAC_MLP 3 | from .config.load_config import init_smart_logger 4 | from .parameter.ParameterSAC import Parameter 5 | from .algorithm.sac_rnn_slice import SACRNNSlice 6 | from .algorithm.sac_mlp_redq import SAC_MLP_REDQ 7 | from .algorithm.sac_full_length_rnn_ensembleQ import SACFullLengthRNNEnsembleQ 8 | from .algorithm.sac_full_length_rnn_redq_sep_optim import SACFullLengthRNNREDQ_SEP_OPTIM 9 | from .algorithm.sac_full_length_rnn_ensembleQ_sep_optim import SACFullLengthRNNENSEMBLEQ_SEP_OPTIM 10 | from .algorithm.sac_full_length_rnn_redq import SACFullLengthRNNREDQ 11 | from .utility.alg_init import alg_init -------------------------------------------------------------------------------- /offpolicy_rnn/buffers/replay_memory_tail_padding.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import namedtuple 3 | import numpy as np 4 | import pickle 5 | import time 6 | from typing import List, Union, Tuple, Dict, Optional, 
Any 7 | from .replay_memory import Transition, tuplenames, MemoryArray 8 | 9 | 10 | class MemoryArrayTailZeroPadding(MemoryArray): 11 | def __init__(self, max_trajectory_num: int=1000, max_traj_step: int=1000, rnn_slice_length: int=1, fixed_length: int=32): 12 | super().__init__(max_trajectory_num, max_traj_step, rnn_slice_length) 13 | 14 | self.fixed_length = fixed_length 15 | assert self.fixed_length >= 1 16 | self.last_sampled_batch = None 17 | 18 | def reset(self): 19 | super().reset() 20 | self.last_sampled_batch = None 21 | 22 | def sample_fix_length_sub_trajs(self, batch_size, fix_length): 23 | list_ind = np.random.randint(0, self.transition_count, batch_size) 24 | res = [self.transition_buffer[ind] for ind in list_ind] 25 | if self.last_sampled_batch is None or not self.last_sampled_batch.shape[0] == batch_size or not self.last_sampled_batch.shape[1] == fix_length: 26 | trajs = [self.memory_buffer[traj_ind, point_ind:point_ind+self.fixed_length]for traj_ind, point_ind in res] 27 | trajs = np.array(trajs, copy=True) 28 | self.last_sampled_batch = trajs 29 | else: 30 | # reuse last sampled batch, avoid frequently allocating memory 31 | for ind, (traj_ind, point_ind) in enumerate(res): 32 | self.last_sampled_batch[ind, :, :] = self.memory_buffer[traj_ind, 33 | point_ind: point_ind+self.fixed_length, :] 34 | 35 | res = self.array_to_transition(self.last_sampled_batch) 36 | return res 37 | 38 | def _make_buffer(self, dim): 39 | self.memory_buffer = np.zeros((self.max_trajectory_num, self.max_traj_step + self.fixed_length - 1, dim)) 40 | print(f'Tail zero padding buffer init done!') 41 | 42 | 43 | -------------------------------------------------------------------------------- /offpolicy_rnn/config/common_config.yaml: -------------------------------------------------------------------------------- 1 | BACKUP_IGNORE_HEAD: 2 | - __p 3 | - . 
4 | BACKUP_IGNORE_KEY: 5 | - logfile 6 | - logfile_bk 7 | BACKUP_IGNORE_TAIL: 8 | - stl 9 | - dae 10 | - STL 11 | - urdf 12 | - pkl 13 | - pdf 14 | BASE_PATH: null 15 | LOG_DIR_BACKING_NAME: OffpolicyRNN0302 16 | LOG_FOLDER_NAME: logfile 17 | LOG_FOLDER_NAME_BK: logfile_bk 18 | MAIN_MACHINE_IP: 19 | - 127.0.0.1 20 | MAIN_MACHINE_LOG_PATH: /path/to/remote/OffpolicyRNN0302 21 | MAIN_MACHINE_PASSWD: passwd 22 | MAIN_MACHINE_PORT: 23 | - 22 24 | MAIN_MACHINE_USER: username 25 | WORKSPACE_PATH: /path/to/remote/OffpolicyRNN0302 26 | -------------------------------------------------------------------------------- /offpolicy_rnn/config/experiment_config.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT_COMMON_PARAMETERS: 2 | MAX_TRAJ_STEP: 1000 3 | EXPERIMENT_TARGET: "Offpolicy RNN Test" 4 | SHORT_NAME_SUFFIX: INIT_TEST 5 | IMPORTANT_CONFIGS: 6 | - env_name 7 | - alg_name 8 | - seed 9 | 10 | -------------------------------------------------------------------------------- /offpolicy_rnn/config/load_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import smart_logger 3 | 4 | 5 | def init_smart_logger(): 6 | base_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | config_path = os.path.dirname(os.path.abspath(__file__)) 8 | config_relative_path = os.path.relpath(config_path, base_path) 9 | smart_logger.init_config(os.path.join(config_relative_path, 'common_config.yaml'), 10 | os.path.join(config_relative_path, 'experiment_config.yaml'), 11 | base_path) 12 | -------------------------------------------------------------------------------- /offpolicy_rnn/env_utils/make_env.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from envs.make_pomdp_env import make_pomdp_env 3 | from envs.pomdp_config import env_config 4 | import gym 5 | from typing import Dict 6 | import numpy as np 7 | import contextlib 8 | import random 9 | try: 10 | import envs.dmc 11 | has_dm_control = True 12 | except: 13 | has_dm_control = False 14 | 15 | 16 | @contextlib.contextmanager 17 | def fixed_seed(seed): 18 | """Context manager that fixes the seeds of random and numpy.random at the same time.""" 19 | state_np = np.random.get_state() 20 | state_random = random.getstate() 21 | torch_random = torch.get_rng_state() 22 | if torch.cuda.is_available(): 23 | # TODO: if not reproducible, check torch.get_rng_state 24 | torch_cuda_random = torch.cuda.get_rng_state() 25 | torch_cuda_random_all = torch.cuda.get_rng_state_all() 26 | np.random.seed(seed) 27 | random.seed(seed) 28 | torch.manual_seed(seed) 29 | 30 | try: 31 | yield 32 | finally: 33 | np.random.set_state(state_np) 34 | random.setstate(state_random) 35 | torch.set_rng_state(torch_random) 36 | if torch.cuda.is_available(): 37 | torch.cuda.set_rng_state(torch_cuda_random) 38 | torch.cuda.set_rng_state_all(torch_cuda_random_all) 39 | 40 | 41 | def make_env(env_name: str, seed: int) -> Dict: 42 | if env_name in env_config: 43 | with fixed_seed(seed): 44 | result = make_pomdp_env(env_name, seed) 45 | result['seed'] = seed 46 | return result 47 | else: 48 | with fixed_seed(seed): 49 | if env_name.startswith('dmc'): 50 | env = gym.make(env_name, seed=seed) 51 | max_episode_steps = env.unwrapped._max_episode_steps 52 | else: 53 | env = gym.make(env_name) 54 | max_episode_steps = env._max_episode_steps 55 | env.seed(seed) 56 | env.action_space.seed(seed+1) 57 | env.observation_space.seed(seed+2) 58 | result = { 59 | 'train_env':
env, 60 | 'eval_env': env, 61 | 'train_tasks': [], 62 | 'eval_tasks': [None] * 20, 63 | 'max_rollouts_per_task': 1, 64 | 'max_trajectory_len': max_episode_steps, 65 | 'obs_dim': env.observation_space.shape[0], 66 | 'act_dim': env.action_space.shape[0] if isinstance(env.action_space, gym.spaces.Box) else env.action_space.n, 67 | 'act_continuous': isinstance(env.action_space, gym.spaces.Box), 68 | 'seed': seed, 69 | 'multiagent': False 70 | } 71 | 72 | return result -------------------------------------------------------------------------------- /offpolicy_rnn/models/conv1d/conv1d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Conv1d(torch.nn.Module): 6 | def __init__(self, in_channels, out_channels, d_conv=4, bias=True, ff=True): 7 | super().__init__() 8 | assert in_channels == out_channels 9 | self.in_channels = in_channels 10 | self.out_channels = out_channels 11 | self.d_conv = d_conv 12 | self.conv1d = torch.nn.Conv1d( 13 | in_channels=in_channels, 14 | out_channels=out_channels, 15 | bias=bias, 16 | kernel_size=d_conv, 17 | groups=in_channels, 18 | padding=0, 19 | ) 20 | self.desired_hidden_dim = self.in_channels * (self.d_conv - 1) 21 | self.use_ff = ff 22 | if ff: 23 | self.ff = PositionWiseFeedForward(out_channels, 0.0) 24 | 25 | 26 | def conv1d_func(self, x, hidden, mask): 27 | (b, l, d) = x.shape 28 | if mask is not None: 29 | x = x * mask 30 | x_input = torch.cat((hidden, x), dim=-2) 31 | x = x_input.transpose(-2, -1) 32 | x = self.conv1d(x)[:, :, :l] 33 | x = x.transpose(-2, -1) 34 | hidden = x_input[:, -(self.d_conv-1):, :] 35 | return x, hidden 36 | 37 | def forward(self, x, hidden=None, mask=None): 38 | batch_size = x.shape[0] 39 | if hidden is None: 40 | hidden = torch.zeros((batch_size, self.d_conv - 1, self.in_channels), device=x.device, dtype=x.dtype) 41 | else: 42 | hidden = hidden.reshape((batch_size, self.d_conv - 1, self.in_channels)) 43 | 44 | x, hidden = self.conv1d_func(x, hidden, mask) 45 | hidden = hidden.reshape((batch_size, 1, -1)) 46 | if self.use_ff: 47 | x = self.ff(x) 48 | return x, hidden 49 | 50 | 51 | 52 | 53 | class PositionWiseFeedForward(nn.Module): 54 | def __init__(self, d_model, dropout=0.0): 55 | super().__init__() 56 | self.w_1 = nn.Linear(d_model, d_model) 57 | self.w_2 = nn.Linear(d_model, d_model) 58 | self.activation = nn.GELU() 59 | self.dropout = nn.Dropout(dropout) 60 | self.layer_norm = nn.LayerNorm(d_model) 61 | 62 | def forward(self, x): 63 | x_ = self.dropout(self.activation(self.w_1(x))) 64 | return self.layer_norm(self.dropout(self.w_2(x_)) + x) 65 | 66 | 67 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/ensemble_linear_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from typing import Dict, List, Union, Tuple, Optional 6 | 7 | 8 | class EnsembleLinear(nn.Module): 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | output_dim: int, 13 | num_ensemble: int, 14 | bias: bool=True, 15 | desire_ndim: int=None 16 | ) -> None: 17 | super().__init__() 18 | self.use_bias = bias 19 | self.desire_ndim = desire_ndim 20 | self.num_ensemble = num_ensemble 21 | 22 | self.register_parameter("weight", nn.Parameter(torch.zeros(num_ensemble, input_dim, output_dim))) 23 | if self.use_bias: 24 | self.register_parameter("bias", 
nn.Parameter(torch.zeros(num_ensemble, 1, output_dim))) 25 | 26 | nn.init.trunc_normal_(self.weight, std=1/(2*input_dim**0.5)) 27 | self.device = torch.device('cpu') 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | weight = self.weight 31 | if self.use_bias: 32 | bias = self.bias 33 | else: 34 | bias = None 35 | if len(x.shape) == 2: 36 | x = torch.einsum('ij,bjk->bik', x, weight) 37 | elif len(x.shape) == 3: 38 | # TODO: judge the shape carefully according to your application 39 | if (self.desire_ndim is None or self.desire_ndim == 3) and x.shape[0] == weight.data.shape[0]: 40 | x = torch.einsum('bij,bjk->bik', x, weight) 41 | else: 42 | x = torch.einsum('cij,bjk->bcik', x, weight) 43 | elif len(x.shape) == 4: 44 | if (self.desire_ndim is None or self.desire_ndim == 4) and x.shape[0] == weight.data.shape[0]: 45 | x = torch.einsum('cbij,cjk->cbik', x, weight) 46 | else: 47 | x = torch.einsum('cdij,bjk->bcdik', x, weight) 48 | elif len(x.shape) == 5: 49 | x = torch.einsum('bcdij,bjk->bcdik', x, weight) 50 | if self.use_bias: 51 | assert x.shape[0] == bias.shape[0] and x.shape[-1] == bias.shape[-1] 52 | if len(x.shape) == 4: 53 | bias = bias.unsqueeze(1) 54 | elif len(x.shape) == 5: 55 | bias = bias.unsqueeze(1) 56 | bias = bias.unsqueeze(1) 57 | 58 | x = x + bias 59 | 60 | return x 61 | 62 | def to(self, device): 63 | if not device == self.device: 64 | self.device = device 65 | super().to(device) 66 | self.weight = self.weight.to(self.device) 67 | if self.use_bias: 68 | self.bias = self.bias.to(self.device) 69 | 70 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/flash_attention/gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flash_attn.models.gpt_rl import GPTModel, GPT2Config 3 | from torch.nn import Module 4 | from flash_attn.utils.generation import InferenceParams 5 | 6 | 7 | # GPT2Config={ 8 | # "activation_function": "gelu_new", 9 | # "attn_pdrop": 0.1, 10 | # "bos_token_id": 50256, 11 | # "embd_pdrop": 0.1, 12 | # "eos_token_id": 50256, 13 | # "initializer_range": 0.02, 14 | # "layer_norm_epsilon": 1e-05, 15 | # "model_type": "gpt2", 16 | # "n_embd": 768, 17 | # "n_head": 12, 18 | # "n_inner": null, 19 | # "n_layer": 12, 20 | # "n_positions": 1024, 21 | # "reorder_and_upcast_attn": false, 22 | # "resid_pdrop": 0.1, 23 | # "scale_attn_by_inverse_layer_idx": false, 24 | # "scale_attn_weights": true, 25 | # "summary_activation": null, 26 | # "summary_first_dropout": 0.1, 27 | # "summary_proj_to_labels": true, 28 | # "summary_type": "cls_index", 29 | # "summary_use_proj": true, 30 | # "transformers_version": "4.39.3", 31 | # "use_cache": true, 32 | # "vocab_size": 50257 33 | # } 34 | 35 | 36 | class GPTLayer(Module): 37 | def __init__(self, ndim=768, nhead=12, nlayer=12, pdrop=0.1, norm_epsilon=1e-5): 38 | super().__init__() 39 | config = GPT2Config(n_embd=ndim, n_head=nhead, nlayer=nlayer, attn_pdrop=pdrop, 40 | layer_norm_epsilon=norm_epsilon, 41 | resid_pdrop=pdrop, embd_pdrop=pdrop, 42 | ) 43 | 44 | config.use_flash_attn = True 45 | config.fused_bias_fc = True 46 | config.fused_mlp = True 47 | config.fused_dropout_add_ln = True 48 | config.residual_in_fp32 = True 49 | config.rms_norm = True 50 | 51 | config.n_positions = 2048 52 | config.rotary_emb_fraction = 0.0 53 | config.rotary_emb_base = 0 54 | config.use_alibi = True 55 | # config.n_positions = 0 56 | # config.rotary_emb_fraction = 1.0 57 | # config.rotary_emb_base = 10000 58 | 59 | 
self.model = GPTModel(config, dtype=torch.float32) 60 | self.flash_attn_flag = True 61 | 62 | def make_init_hidden(self, max_seqlen, max_batch_size): 63 | return InferenceParams(max_seqlen, max_batch_size) 64 | 65 | def forward(self, hidden_states, attention_mask_in_length=None, inference_params=None): 66 | original_dtype = hidden_states.dtype 67 | with torch.cuda.amp.autocast(dtype=torch.bfloat16): 68 | result = self.model.forward(hidden_states, attention_mask_in_length, inference_params) 69 | result = result.to(original_dtype) 70 | return result 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/gilr/gilr.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | try: 7 | from .scan_triton.real_rnn_tie_input_gate import real_scan_tie_input_gate, real_scan_tie_input_gate_fused 8 | from .scan_triton.real_rnn_fast_pscan import pscan_fast 9 | except Exception as _: 10 | pass 11 | from .scan_triton.real_rnn_tie_input_gate_cpu import scan_cpu, scan_cpu_fuse 12 | from ..ensemble_linear_model import EnsembleLinear 13 | class GILRLayer(nn.Module): 14 | def __init__( 15 | self, 16 | input_dim, 17 | output_dim, 18 | factor=1, 19 | dropout=0.0, 20 | use_ff=True, 21 | batch_first=True, 22 | ): 23 | super().__init__() 24 | assert batch_first 25 | self.d_model = output_dim 26 | self.in_proj = EnsembleLinear(input_dim, self.d_model * factor, 2, desire_ndim=4) 27 | self.out_proj = torch.nn.Linear(self.d_model * factor, self.d_model * factor) 28 | self.dropout = nn.Dropout(dropout) 29 | self.layer_norm = nn.LayerNorm(factor * self.d_model) 30 | self.swish = nn.SiLU() 31 | self.use_ff = use_ff 32 | if self.use_ff: 33 | self.ff = PositionWiseFeedForward(self.d_model, dropout) 34 | self.device = torch.device('cpu') 35 | 36 | def rnn_parameters(self): 37 | return list(self.parameters(True)) 38 | 39 | def to(self, device): 40 | if not self.device == device: 41 | super().to(device) 42 | self.device = device 43 | 44 | def forward(self, x, hidden=None, rnn_start=None): 45 | u = self.in_proj(x) 46 | v = u[0] 47 | f = u[1] 48 | if hidden is None: 49 | hidden = torch.zeros((v.shape[0], 1, self.d_model * 2), device=v.device) 50 | else: 51 | hidden = hidden.transpose(0, 1) 52 | hidden_pre = hidden 53 | f = torch.sigmoid(f) 54 | v = torch.tanh(v) 55 | if rnn_start is not None: 56 | f = f * (1 - rnn_start) 57 | if torch.all(hidden_pre == 0) and not self.device == torch.device('cpu'): 58 | v = real_scan_tie_input_gate(v.contiguous(), f.contiguous()) 59 | hidden_pre = v[:, -1:, :] 60 | else: 61 | v, hidden_pre = scan_cpu(v, f, hidden_pre) 62 | out = self.out_proj(v) 63 | hidden = hidden_pre 64 | hidden = hidden.transpose(0, 1) 65 | if self.use_ff: 66 | out = self.ff(out) 67 | return out, hidden 68 | 69 | 70 | class PositionWiseFeedForward(nn.Module): 71 | def __init__(self, d_model, dropout=0.1): 72 | super().__init__() 73 | self.w_1 = nn.Linear(d_model, d_model) 74 | self.w_2 = nn.Linear(d_model, d_model) 75 | self.activation = nn.GELU() 76 | self.dropout = nn.Dropout(dropout) 77 | self.layer_norm = nn.LayerNorm(d_model) 78 | 79 | def forward(self, x): 80 | x_ = self.dropout(self.activation(self.w_1(x))) 81 | return self.layer_norm(self.dropout(self.w_2(x_)) + x) 82 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/gilr/scan_triton/real_rnn_fast_pscan.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | 3 | import triton 4 | import triton.language as tl 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | # from fastpscan.cuda_v3 import fn as pscan_cuda_fn 9 | # from fastpscan.cuda_v4 import fn as pscan_cuda_fn 10 | from fastpscan.triton_v2 import fn as pscan_cuda_fn 11 | 12 | 13 | def pscan_fast(v, f): 14 | B, L, C = v.shape 15 | desire_B = max(B, 4) 16 | h = torch.zeros((desire_B, C), device=v.device, dtype=v.dtype) 17 | 18 | v = v * (1 - f) 19 | if L < 1024: 20 | v = torch.cat((v, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2) 21 | f = torch.cat((f, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2) 22 | if desire_B > B: 23 | v = torch.cat((v, torch.zeros((desire_B - B, 1024, C), device=v.device, dtype=v.dtype)), dim=0) 24 | f = torch.cat((f, torch.zeros((desire_B - B, 1024, C), device=f.device, dtype=f.dtype)), dim=0) 25 | 26 | Y = pscan_cuda_fn(f, v, h) 27 | Y = Y[:B, :L, :] 28 | return Y 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/gilr/scan_triton/real_rnn_tie_input_gate_cpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def real_scan_tie_input_gate_no_grad(v, f, h=None): 5 | B, L, C = v.shape 6 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256' 7 | output = torch.zeros_like(v) 8 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h 9 | for l in range(L): 10 | f_item = f[:, l:l+1, :] 11 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item) 12 | output[:, l:l+1, :] = h_new 13 | h = h_new 14 | return output, h 15 | 16 | def real_scan_tie_input_gate_no_grad_fuse(v, f, h=None): 17 | B, L, C = v.shape 18 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256' 19 | output = torch.zeros_like(v) 20 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h 21 | for l in range(L): 22 | f_item = torch.sigmoid(f[:, l:l+1, :]) 23 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item) 24 | output[:, l:l+1, :] = h_new 25 | h = h_new 26 | return output, h 27 | 28 | scan_cpu = real_scan_tie_input_gate_no_grad 29 | scan_cpu_fuse = real_scan_tie_input_gate_no_grad_fuse 30 | 31 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/gilr_lstm/gilr_lstm.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | try: 7 | from .scan_triton.real_rnn_tie_input_gate import real_scan_tie_input_gate, real_scan_tie_input_gate_fused 8 | from .scan_triton.real_rnn_fast_pscan import pscan_fast 9 | except Exception as _: 10 | pass 11 | from .scan_triton.real_rnn_tie_input_gate_cpu import scan_cpu, scan_cpu_fuse 12 | from ..ensemble_linear_model import EnsembleLinear 13 | class GILRLSTMLayer(nn.Module): 14 | def __init__( 15 | self, 16 | input_dim, 17 | output_dim, 18 | factor=1, 19 | dropout=0.2, 20 | batch_first=True, 21 | ): 22 | super().__init__() 23 | assert batch_first 24 | self.d_model = output_dim 25 | self.in_proj = EnsembleLinear(input_dim, self.d_model * factor, 2, desire_ndim=4) 26 | self.middle_proj = EnsembleLinear(self.d_model * factor, self.d_model * factor, 4, desire_ndim=4) 27 | self.out_proj = 
torch.nn.Linear(self.d_model * factor, self.d_model * factor) 28 | self.layer_norm = nn.LayerNorm(factor * self.d_model) 29 | self.swish = nn.SiLU() 30 | self.device = torch.device('cpu') 31 | 32 | def rnn_parameters(self): 33 | return list(self.parameters(True)) 34 | 35 | def to(self, device): 36 | if not self.device == device: 37 | super().to(device) 38 | self.device = device 39 | 40 | def forward(self, x, hidden=None, rnn_start=None): 41 | u = self.in_proj(x) 42 | v = u[0] 43 | f = u[1] 44 | if hidden is None: 45 | hidden = torch.zeros((v.shape[0], 1, self.d_model * 2), device=v.device) 46 | else: 47 | hidden = hidden.transpose(0, 1) 48 | hidden_pre, hidden_middle = torch.chunk(hidden, 2, -1) 49 | f = torch.sigmoid(f) 50 | v = torch.tanh(v) 51 | if rnn_start is not None: 52 | f = f * (1 - rnn_start) 53 | if torch.all(hidden_pre == 0) and not self.device == torch.device('cpu'): 54 | v = real_scan_tie_input_gate(v.contiguous(), f.contiguous()) 55 | hidden_pre = v[:, -1:, :] 56 | else: 57 | v, hidden_pre = scan_cpu(v, f, hidden_pre) 58 | u = self.middle_proj.forward(v) 59 | f = torch.sigmoid(u[0]) 60 | i = torch.sigmoid(u[1]) 61 | o = torch.sigmoid(u[2]) 62 | z = torch.tanh(u[3]) 63 | if rnn_start is not None: 64 | f = f * (1 - rnn_start) 65 | if torch.all(hidden_middle == 0) and not self.device == torch.device('cpu'): 66 | v = i * z 67 | out = real_scan_tie_input_gate(v.contiguous(), f.contiguous()) 68 | hidden_middle = out[:, -1:, :] 69 | else: 70 | out, hidden_middle = scan_cpu(i * z, f, hidden_middle) 71 | out = out * o 72 | out = self.out_proj(out) 73 | hidden = torch.cat((hidden_pre, hidden_middle), dim=-1) 74 | hidden = hidden.transpose(0, 1) 75 | return out, hidden 76 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/gilr_lstm/scan_triton/real_rnn_fast_pscan.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import triton 4 | import triton.language as tl 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | # from fastpscan.cuda_v3 import fn as pscan_cuda_fn 9 | # from fastpscan.cuda_v4 import fn as pscan_cuda_fn 10 | from fastpscan.triton_v2 import fn as pscan_cuda_fn 11 | 12 | 13 | def pscan_fast(v, f): 14 | B, L, C = v.shape 15 | desire_B = max(B, 4) 16 | h = torch.zeros((desire_B, C), device=v.device, dtype=v.dtype) 17 | 18 | v = v * (1 - f) 19 | if L < 1024: 20 | v = torch.cat((v, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2) 21 | f = torch.cat((f, torch.zeros((B, 1024 - L, C), device=v.device, dtype=v.dtype)), dim=-2) 22 | if desire_B > B: 23 | v = torch.cat((v, torch.zeros((desire_B - B, 1024, C), device=v.device, dtype=v.dtype)), dim=0) 24 | f = torch.cat((f, torch.zeros((desire_B - B, 1024, C), device=f.device, dtype=f.dtype)), dim=0) 25 | 26 | Y = pscan_cuda_fn(f, v, h) 27 | Y = Y[:B, :L, :] 28 | return Y 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/gilr_lstm/scan_triton/real_rnn_tie_input_gate_cpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def real_scan_tie_input_gate_no_grad(v, f, h=None): 5 | B, L, C = v.shape 6 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256' 7 | output = torch.zeros_like(v) 8 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h 9 | for l in range(L): 10 
| f_item = f[:, l:l+1, :] 11 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item) 12 | output[:, l:l+1, :] = h_new 13 | h = h_new 14 | return output, h 15 | 16 | def real_scan_tie_input_gate_no_grad_fuse(v, f, h=None): 17 | B, L, C = v.shape 18 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256' 19 | output = torch.zeros_like(v) 20 | h = torch.zeros((B, 1, v.shape[-1]), device=v.device, dtype=v.dtype) if h is None else h 21 | for l in range(L): 22 | f_item = torch.sigmoid(f[:, l:l+1, :]) 23 | h_new = h * f_item + v[:, l:l+1, :] * (1-f_item) 24 | output[:, l:l+1, :] = h_new 25 | h = h_new 26 | return output, h 27 | 28 | scan_cpu = real_scan_tie_input_gate_no_grad 29 | scan_cpu_fuse = real_scan_tie_input_gate_no_grad_fuse 30 | 31 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/lru/scan_triton/complex_rnn_cpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def forward_sequential_scan_complex_no_grad(v_real, v_imag, f_real, f_imag, h_real=None, h_imag=None): 5 | B, L, C = v_real.shape 6 | # assert C % 256 == 0, 'Hidden dimension must be a multiple of 256' 7 | 8 | hidden_real = torch.zeros_like(v_real) 9 | hidden_imag = torch.zeros_like(v_imag) 10 | 11 | h_real = torch.zeros((B, 1, v_real.shape[-1]), device=v_real.device, dtype=v_real.dtype) if h_real is None else h_real 12 | h_imag = torch.zeros((B, 1, v_imag.shape[-1]), device=v_imag.device, dtype=v_imag.dtype) if h_imag is None else h_imag 13 | 14 | for l in range(L): 15 | f_real_item = f_real[:, l:l+1, :] 16 | f_imag_item = f_imag[:, l:l+1, :] 17 | h_real_new = h_real * f_real_item - h_imag * f_imag_item + v_real[:, l:l+1, :] 18 | h_imag_new = h_real * f_imag_item + h_imag * f_real_item + v_imag[:, l:l+1, :] 19 | 20 | hidden_real[:, l:l+1, :] = h_real_new 21 | hidden_imag[:, l:l+1, :] = h_imag_new 22 | 23 | h_real = h_real_new 24 | h_imag = h_imag_new 25 | 26 | return hidden_real, hidden_imag, h_real, h_imag 27 | 28 | complex_scan_cpu = forward_sequential_scan_complex_no_grad 29 | 30 | 31 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/mlp_base.py: -------------------------------------------------------------------------------- 1 | from .rnn_base import RNNBase 2 | from .RNNHidden import RNNHidden 3 | class MLPBase(RNNBase): 4 | def __init__(self, input_size, output_size, hidden_size_list, activation): 5 | super().__init__(input_size, output_size, hidden_size_list, activation, ['fc'] * len(activation)) 6 | self.empty_hidden_state = RNNHidden(0, []) 7 | 8 | def meta_forward(self, x, h=None, require_full_hidden=False): 9 | return super(MLPBase, self).meta_forward(x, self.empty_hidden_state, False)[0] 10 | 11 | def forward(self, x): 12 | return self.meta_forward(x) 13 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/multi_ensemble_linear_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | from typing import Dict, List, Union, Tuple, Optional 6 | 7 | 8 | class MultiEnsembleLinear(nn.Module): 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | output_dim: int, 13 | num_ensemble: int, 14 | multi_num: int, 15 | bias: bool=True, 16 | desire_ndim: int=None 17 | ) -> None: 18 | super().__init__() 19 | self.use_bias = bias 20 | self.desire_ndim = 
desire_ndim 21 | self.num_ensemble = num_ensemble 22 | 23 | self.register_parameter("weight", nn.Parameter(torch.zeros(multi_num, num_ensemble, input_dim, output_dim))) 24 | if self.use_bias: 25 | self.register_parameter("bias", nn.Parameter(torch.zeros(multi_num, num_ensemble, 1, output_dim))) 26 | 27 | nn.init.trunc_normal_(self.weight, std=1/(2*input_dim**0.5)) 28 | self.device = torch.device('cpu') 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | assert self.desire_ndim is not None, 'for MLP network (data shape is (B, C)), desire_ndim should be 3, for RNN network, (data shape is (B, L, C)), desire_ndim should be 4, while got None' 32 | 33 | weight = self.weight 34 | if self.use_bias: 35 | bias = self.bias 36 | else: 37 | bias = None 38 | 39 | if len(x.shape) == 2: 40 | assert self.desire_ndim == 3 41 | x = torch.einsum('ij,cbjk->cbik', x, weight) 42 | elif len(x.shape) == 3: 43 | if self.desire_ndim == 3: 44 | # target is 45 | x = torch.einsum('bij,cbjk->cbik', x, weight) 46 | elif self.desire_ndim == 4: 47 | x = torch.einsum('hij,cbjk->cbhik', x, weight) 48 | else: 49 | raise NotImplementedError 50 | elif len(x.shape) == 4: 51 | if self.desire_ndim == 4: 52 | x = torch.einsum('bhij,cbjk->cbhik', x, weight) 53 | else: 54 | raise NotImplementedError 55 | elif len(x.shape) == 5: 56 | assert self.desire_ndim == 4 57 | x = torch.einsum('cbhij,cbjk->cbhik', x, weight) 58 | if bias is not None: 59 | if self.desire_ndim == 3: 60 | x = x + bias 61 | elif self.desire_ndim == 4: 62 | x = x + bias.unsqueeze(2) 63 | else: 64 | raise NotImplementedError 65 | return x 66 | 67 | def to(self, device): 68 | if not device == self.device: 69 | self.device = device 70 | super().to(device) 71 | self.weight = self.weight.to(self.device) 72 | if self.use_bias: 73 | self.bias = self.bias.to(self.device) 74 | 75 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/readme.md: -------------------------------------------------------------------------------- 1 | # Implementations of Advanced RNNs 2 | 3 | Note: Most of these implementations are sourced from GitHub. We have standardized all `forward` functions in these models to conform to the standard RNN API. 4 | 5 | - **conv1d**: We wrap the 1D-convolutional layer as a special kind of RNN. 6 | - **gilr**: A linear RNN implemented with Triton. 7 | - **gilr_lstm**: An LSTM architecture simulated using `gilr`. 8 | - **lru**: Linear Recurrent Unit, implemented with Triton. 9 | - **s6**: Mamba, implemented with Triton. 10 | - **smamba**: Mamba, implemented with officially released CUDA code. An additional `start` variable is introduced as an input, resetting the hidden state to 0 when the `start` flag is true. 11 | - **flash_attention**: GPT implemented with flash_attention. 12 | -------------------------------------------------------------------------------- /offpolicy_rnn/models/s6/selective_scan/cpu_scan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, rearrange, repeat 3 | 4 | 5 | # credit: https://github.com/johnma2006/mamba-minimal/blob/master/model.py#L275 6 | def selective_scan_cpu(u, delta, A, B, C, D, start, initial_state=None): 7 | """Does selective scan algorithm. 
See: 8 | - Section 2 State Space Models in the Mamba paper [1] 9 | - Algorithm 2 in Section 3.2 in the Mamba paper [1] 10 | - run_SSM(A, B, C, u) in The Annotated S4 [2] 11 | 12 | This is the classic discrete state space formula: 13 | x(t + 1) = Ax(t) + Bu(t) 14 | y(t) = Cx(t) + Du(t) 15 | except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t). 16 | 17 | Args: 18 | u: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...) 19 | delta: shape (b, l, d_in) 20 | A: shape (d_in, n) 21 | B: shape (b, l, n) 22 | C: shape (b, l, n) 23 | D: shape (d_in,) 24 | start: (b, l, 1) 25 | 26 | Returns: 27 | output: shape (b, l, d_in) 28 | 29 | Official Implementation: 30 | selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86 31 | Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly. 32 | 33 | """ 34 | original_dtype = u.dtype 35 | u, delta, A, B, C, D = map(lambda x: x.float(), (u, delta, A, B, C, D)) 36 | (b, l, d_in) = u.shape 37 | n = A.shape[1] 38 | start = start.squeeze(-1) 39 | # Discretize continuous parameters (A, B) 40 | # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1]) 41 | # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: 42 | # "A is the more important term and the performance doesn't change much with the simplification on B" 43 | deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n')) 44 | deltaA = einsum(deltaA, 1 - start, 'b l d_in n, b l -> b l d_in n') 45 | deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n') 46 | 47 | # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) 48 | # Note that the below is sequential, while the official implementation does a much faster parallel scan that 49 | # is additionally hardware-aware (like FlashAttention). 
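    # The loop below carries the hidden state x sequentially through time:
    #     x_t = deltaA[:, t] * x_{t-1} + deltaB_u[:, t]
    #     y_t = sum_n (x_t * C[:, t])   (contraction over the state dimension n)
    # Because deltaA has already been multiplied by (1 - start) above, x is
    # reset to zero at every position where start == 1, so the recurrence
    # restarts whenever a new episode begins inside the sequence.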
50 | x = torch.zeros((b, d_in, n), device=deltaA.device) 51 | if initial_state is not None: 52 | x += initial_state 53 | ys = [] 54 | for i in range(l): 55 | x = deltaA[:, i] * x + deltaB_u[:, i] 56 | y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') 57 | ys.append(y) 58 | y = torch.stack(ys, dim=1) # shape (b, l, d_in) 59 | 60 | y = y + u * D[None, None, :] 61 | 62 | return y.to(original_dtype), x -------------------------------------------------------------------------------- /offpolicy_rnn/models/smamba/mamba_ssm/ops/triton/layernorm_cpu.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False): 7 | dtype = x.dtype 8 | if residual_in_fp32: 9 | weight = weight.float() 10 | bias = bias.float() if bias is not None else None 11 | if residual_in_fp32: 12 | x = x.float() 13 | residual = residual.float() if residual is not None else residual 14 | if residual is not None: 15 | x = (x + residual).to(x.dtype) 16 | out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( 17 | dtype 18 | ) 19 | return out if not prenorm else (out, x) 20 | 21 | 22 | def rms_norm_fn(x, weight, bias, residual=None, eps=1e-6, prenorm=False, residual_in_fp32=False): 23 | dtype = x.dtype 24 | if residual_in_fp32: 25 | weight = weight.float() 26 | bias = bias.float() if bias is not None else None 27 | if residual_in_fp32: 28 | x = x.float() 29 | residual = residual.float() if residual is not None else residual 30 | if residual is not None: 31 | x = (x + residual).to(x.dtype) 32 | rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) 33 | out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) 34 | out = out.to(dtype) 35 | return out if not prenorm else (out, x) 36 | 37 | class RMSNorm(torch.nn.Module): 38 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 39 | factory_kwargs = {"device": device, "dtype": dtype} 40 | super().__init__() 41 | self.eps = eps 42 | self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) 43 | self.register_parameter("bias", None) 44 | self.reset_parameters() 45 | 46 | def reset_parameters(self): 47 | torch.nn.init.ones_(self.weight) 48 | 49 | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): 50 | return rms_norm_fn( 51 | x, 52 | self.weight, 53 | self.bias, 54 | residual=residual, 55 | eps=self.eps, 56 | prenorm=prenorm, 57 | residual_in_fp32=residual_in_fp32, 58 | ) 59 | 60 | -------------------------------------------------------------------------------- /offpolicy_rnn/policy_value_models/contextual_sac_policy.py: -------------------------------------------------------------------------------- 1 | from .contextual_sac_policy_single_head import ContextualSACPolicySingleHead 2 | from .contextual_sac_policy_double_head import ContextualSACPolicyDoubleHead 3 | 4 | class ContextualSACPolicy(ContextualSACPolicySingleHead): 5 | # class ContextualSACPolicy(ContextualSACPolicyDoubleHead): 6 | def __init__(self, state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations, 7 | embedding_layer_type, uni_model_hidden, 8 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim: int=0, 9 | reward_input=False, last_action_input=True, last_state_input=False, separate_encoder=False, 10 | output_logstd=True, 
name='ContextualSACPolicy'): 11 | super().__init__(state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations, 12 | embedding_layer_type, uni_model_hidden, 13 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim, reward_input, 14 | last_action_input, last_state_input, separate_encoder, output_logstd, 15 | name=name) 16 | -------------------------------------------------------------------------------- /offpolicy_rnn/policy_value_models/contextual_td3_policy.py: -------------------------------------------------------------------------------- 1 | from .contextual_sac_policy import ContextualSACPolicy 2 | from typing import List, Union, Tuple, Dict, Optional 3 | from ..models.RNNHidden import RNNHidden 4 | import torch 5 | 6 | class ContextualTD3Policy(ContextualSACPolicy): 7 | def __init__(self, state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations, 8 | embedding_layer_type, uni_model_hidden, 9 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim: int=0, 10 | reward_input=False, last_action_input=True, last_state_input=False, separate_encoder=False, sample_std=0.1): 11 | super().__init__(state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations, 12 | embedding_layer_type, uni_model_hidden, 13 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim, reward_input, 14 | last_action_input, last_state_input, separate_encoder, output_logstd=False, name='ContextualTD3Policy') 15 | self.sample_std = sample_std 16 | 17 | 18 | def forward(self, state: torch.Tensor, lst_state: torch.Tensor, lst_action: torch.Tensor, rnn_memory: Optional[RNNHidden], reward: Optional[torch.Tensor]=None, detach_embedding: bool=False) -> Tuple[ 19 | torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, RNNHidden, Optional[RNNHidden] 20 | ]: 21 | """ 22 | :param state: (seq_len, dim) or (batch_size, seq_len, dim) 23 | :param lst_state: (seq_len, dim) or (batch_size, seq_len, dim) 24 | :param lst_action: (seq_len, dim) or (batch_size, seq_len, dim) 25 | :param reward: (seq_len, 1) or (batch_size, seq_len, dim) 26 | :param rnn_memory: 27 | :return: 28 | """ 29 | embedding_input = self.get_embedding_input(state, lst_state, lst_action, reward) 30 | model_output, rnn_memory, embedding_output, full_rnn_memory = self.meta_forward(embedding_input, state, rnn_memory, detach_embedding) 31 | action_mean = torch.tanh(model_output) 32 | action_sample = torch.tanh(model_output) + torch.randn_like(model_output) * self.sample_std 33 | action_sample = torch.clamp(action_sample, -1, 1) 34 | # not used in TD3 35 | log_prob = torch.zeros_like(action_sample) 36 | return action_mean, embedding_output, action_sample, log_prob, rnn_memory, full_rnn_memory 37 | -------------------------------------------------------------------------------- /offpolicy_rnn/policy_value_models/contextual_td3_value.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..models.contextual_model import ContextualModel 3 | import torch 4 | import os 5 | from ..models.RNNHidden import RNNHidden 6 | from typing import List, Union, Tuple, Dict, Optional 7 | from .utils import nearest_power_of_two, nearest_power_of_two_half 8 | from .contextual_sac_value import ContextualSACValue 9 | class ContextualTD3Value(ContextualSACValue): 10 | def __init__(self, state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations, 11 | 
embedding_layer_type, uni_model_hidden, 12 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim: int=0, 13 | reward_input=False, last_action_input=True, last_state_input=False, separate_encoder=False): 14 | super(ContextualTD3Value, self).__init__(state_dim, action_dim, embedding_size, embedding_hidden, embedding_activations, 15 | embedding_layer_type, uni_model_hidden, 16 | uni_model_activations, uni_model_layer_type, fix_rnn_length, uni_model_input_mapping_dim, reward_input, 17 | last_action_input, last_state_input, separate_encoder, name='ContextualTD3Value') 18 | -------------------------------------------------------------------------------- /offpolicy_rnn/policy_value_models/make_models.py: -------------------------------------------------------------------------------- 1 | from .contextual_sac_policy import ContextualSACPolicy 2 | from .contextual_td3_policy import ContextualTD3Policy 3 | from .contextual_sac_value import ContextualSACValue 4 | from .contextual_td3_value import ContextualTD3Value 5 | from .contextual_sac_discrete_policy import ContextualSACDiscretePolicy 6 | from .contextual_sac_discrete_value import ContextualSACDiscreteValue 7 | import torch 8 | from typing import Optional, Union 9 | 10 | def make_policy_model(policy_args, base_alg_name, discrete) -> Union[ContextualTD3Policy, ContextualSACPolicy, ContextualSACDiscretePolicy]: 11 | if base_alg_name == 'sac': 12 | if discrete: 13 | policy = ContextualSACDiscretePolicy(**policy_args) 14 | else: 15 | policy = ContextualSACPolicy(**policy_args) 16 | elif base_alg_name == 'td3': 17 | policy = ContextualTD3Policy(**policy_args) 18 | return policy 19 | 20 | def make_value_model(value_args, base_alg_name, discrete) -> Union[ContextualTD3Value, ContextualSACValue, ContextualSACDiscreteValue]: 21 | if base_alg_name == 'sac': 22 | if discrete: 23 | value = ContextualSACDiscreteValue(**value_args) 24 | else: 25 | value = ContextualSACValue(**value_args) 26 | elif base_alg_name == 'td3': 27 | value = ContextualTD3Value(**value_args) 28 | return value 29 | -------------------------------------------------------------------------------- /offpolicy_rnn/policy_value_models/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def nearest_power_of_two_half(x): 4 | # compute 0.5 * x 5 | target = 0.5 * x 6 | # take the base-2 logarithm and round to the nearest integer 7 | nearest_exp = round(math.log(target, 2)) 8 | # compute 2 to the power of nearest_exp 9 | if nearest_exp < 0: 10 | nearest_exp = 0 11 | nearest_power = int(math.ceil(2 ** nearest_exp)) 12 | 13 | return nearest_power 14 | 15 | def nearest_power_of_two(x): 16 | # the target is x itself (no halving here) 17 | target = x 18 | # take the base-2 logarithm and round up to the next integer 19 | nearest_exp = int(math.ceil(math.log(target, 2))) 20 | # compute 2 to the power of nearest_exp 21 | if nearest_exp < 0: 22 | nearest_exp = 0 23 | nearest_power = int(math.ceil(2 ** nearest_exp)) 24 | return nearest_power -------------------------------------------------------------------------------- /offpolicy_rnn/utility/ValueScheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | class CosineScheduler: 4 | def __init__(self, total_steps, initial_value, min_value): 5 | """ 6 | Initialize the cosine value scheduler (e.g. for a learning rate). 7 | :param total_steps: total number of steps 8 | :param initial_value: initial value 9 | :param min_value: minimum value 10 | """ 11 | self.total_steps = total_steps 12 | self.initial_value = initial_value 13 | self.min_value = min_value 14 | self.current_step = 0 15 | self._value = None 16 | self._last_value = None 17 |
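    # step() below advances current_step and evaluates the cosine curve
    #     value = min_value + 0.5 * (initial_value - min_value) * (1 + cos(pi * current_step / total_steps)),
    # while validate() clamps the result to min_value and uses the cached
    # _last_value as an additional upper bound on the scheduled value.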
18 | def validate(self): 19 | if self._value < self.min_value: 20 | self._value = self.min_value 21 | 22 | if self._last_value is not None: 23 | if self._last_value < self._value: 24 | self._value = self._last_value 25 | else: 26 | self._last_value = self._value 27 | 28 | def step(self): 29 | """ 30 | Advance one step and compute the current value. 31 | """ 32 | self.current_step += 1 33 | 34 | self._value = self.min_value + 0.5 * (self.initial_value - self.min_value) * (1 + math.cos(self.current_step / self.total_steps * math.pi)) 35 | self.validate() 36 | return self._value 37 | 38 | def get_value(self): 39 | """ 40 | Return the current value. 41 | """ 42 | if self.current_step == 0: 43 | return self.initial_value 44 | else: 45 | return self._value 46 | 47 | class LinearScheduler: 48 | def __init__(self, total_steps, initial_value, end_value): 49 | """ 50 | Initialize the linear value scheduler. 51 | :param total_steps: total number of steps 52 | :param initial_value: initial value 53 | :param end_value: final value 54 | """ 55 | self.total_steps = total_steps 56 | self.initial_value = initial_value 57 | self.end_value = end_value 58 | self.current_step = 0 59 | self._value = self.initial_value 60 | 61 | def validate(self): 62 | if self.current_step >= self.total_steps: 63 | self._value = self.end_value 64 | 65 | def step(self): 66 | """ 67 | Advance one step and compute the current value. 68 | """ 69 | self.current_step += 1 70 | self._value = self.current_step / self.total_steps * (self.end_value - self.initial_value) + self.initial_value 71 | self.validate() 72 | return self._value 73 | 74 | def get_value(self): 75 | """ 76 | Return the current value. 77 | """ 78 | if self.current_step == 0: 79 | return self.initial_value 80 | else: 81 | return self._value -------------------------------------------------------------------------------- /offpolicy_rnn/utility/alg_init.py: -------------------------------------------------------------------------------- 1 | from ..algorithm.sac import SAC 2 | from ..algorithm.sac_mlp import SAC_MLP 3 | from ..parameter.ParameterSAC import Parameter 4 | from ..algorithm.sac_rnn_slice import SACRNNSlice 5 | from ..algorithm.sac_mlp_redq import SAC_MLP_REDQ 6 | from ..algorithm.sac_full_length_rnn_ensembleQ import SACFullLengthRNNEnsembleQ 7 | from ..algorithm.sac_full_length_rnn_redq import SACFullLengthRNNREDQ 8 | from ..algorithm.sac_mlp_redq_ensemble_q import SAC_MLP_REDQ_EnsembleQ 9 | from ..algorithm.sac_full_length_rnn_redq_sep_optim import SACFullLengthRNNREDQ_SEP_OPTIM 10 | from ..algorithm.sac_full_length_rnn_ensembleQ_sep_optim import SACFullLengthRNNENSEMBLEQ_SEP_OPTIM 11 | from ..algorithm.td3_full_length_rnn_ensembleQ import TD3FullLengthRNNEnsembleQ 12 | from ..algorithm.td3_full_length_rnn_redq import TD3FullLengthRNNREDQ 13 | from ..algorithm.td3_full_length_rnn_redq_sep_optim import TD3FullLengthRNNREDQ_SEP_OPTIM 14 | 15 | 16 | def alg_init(parameter: Parameter) -> SAC: 17 | if parameter.alg_name == 'sac_no_train': 18 | sac = SAC(parameter) 19 | elif parameter.alg_name == 'sac_mlp': 20 | sac = SAC_MLP(parameter) 21 | elif parameter.alg_name == 'sac_mlp_redq': 22 | sac = SAC_MLP_REDQ(parameter) 23 | elif parameter.alg_name == 'sac_rnn_slice': 24 | assert parameter.rnn_slice_length > 0 25 | sac = SACRNNSlice(parameter) 26 | elif parameter.alg_name == 'sac_rnn_full_horizon_ensembleQ': 27 | sac = SACFullLengthRNNEnsembleQ(parameter) 28 | elif parameter.alg_name == 'sac_rnn_full_horizon_redQ': 29 | sac = SACFullLengthRNNREDQ(parameter) 30 | elif parameter.alg_name == 'sac_rnn_full_horizon_redQ_sep_optim': # STAR it!!!!
31 | sac = SACFullLengthRNNREDQ_SEP_OPTIM(parameter) 32 | elif parameter.alg_name == 'td3_rnn_full_horizon_redQ_sep_optim': # STAR it!!!! 33 | parameter.base_algorithm = 'td3' 34 | sac = TD3FullLengthRNNREDQ_SEP_OPTIM(parameter) 35 | elif parameter.alg_name == 'sac_rnn_full_horizon_ensemble_q_sep_optim': 36 | sac = SACFullLengthRNNENSEMBLEQ_SEP_OPTIM(parameter) 37 | elif parameter.alg_name == 'sac_mlp_redq_ensemble_q': 38 | sac = SAC_MLP_REDQ_EnsembleQ(parameter) 39 | elif parameter.alg_name == 'td3_rnn_full_horizon_ensembleQ': 40 | parameter.base_algorithm = 'td3' 41 | sac = TD3FullLengthRNNEnsembleQ(parameter) 42 | elif parameter.alg_name == 'td3_rnn_full_horizon_redQ': 43 | parameter.base_algorithm = 'td3' 44 | sac = TD3FullLengthRNNREDQ(parameter) 45 | else: 46 | raise NotImplementedError(f'Algorithm {parameter.alg_name} has not been implemented!') 47 | return sac 48 | -------------------------------------------------------------------------------- /offpolicy_rnn/utility/count_parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | def count_parameters(model): 4 | """ 5 | Counts the number of trainable parameters in a PyTorch model. 6 | 7 | Args: 8 | model (nn.Module): The PyTorch model to count parameters for. 9 | 10 | Returns: 11 | int: The total number of trainable parameters. 12 | """ 13 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 14 | -------------------------------------------------------------------------------- /offpolicy_rnn/utility/q_value_guard.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class QValueGuard: 5 | """ 6 | avoid Q loss from diverging 7 | usage: 8 | target_Q = r + gamma * q_value_guard.clamp(next_q) 9 | q_value_guard.update(target_Q) 10 | """ 11 | def __init__(self, guard_min=True, guard_max=True, decay_ratio=1.0): 12 | self._min = 1000000 if guard_min else None 13 | self._max = -1000000 if guard_max else None 14 | self._init_flag = True 15 | self._decay_ratio = decay_ratio 16 | 17 | def reset(self): 18 | self._min = 1000000 if self._min is not None else None 19 | self._max = -1000000 if self._max is not None else None 20 | self._init_flag = True 21 | 22 | def clamp(self, value: torch.Tensor): 23 | if self._init_flag: 24 | self._min = value.min().item() if self._min is not None else None 25 | self._max = value.max().item() if self._max is not None else None 26 | self._init_flag = False 27 | return value.clamp(min=self._min, max=self._max) 28 | 29 | def update(self, value): 30 | value_min = value.min().item() 31 | value_max = value.max().item() 32 | self._min = min(self._min, value_min if self._min is not None else None) 33 | self._max = max(self._max, value_max if self._max is not None else None) 34 | if self._decay_ratio < 1: 35 | if self._min is not None: 36 | self._min = self._decay_ratio * self._min + (1-self._decay_ratio) * value_min 37 | if self._max is not None: 38 | self._max = self._decay_ratio * self._max + (1-self._decay_ratio) * value_max 39 | 40 | def get_min(self) -> float: 41 | return self._min 42 | 43 | def get_max(self) -> float: 44 | return self._max 45 | 46 | -------------------------------------------------------------------------------- /offpolicy_rnn/utility/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import inspect 3 | import numpy as np 4 | 5 | class Timer: 6 | def __init__(self): 7 | 
self.check_points = {} 8 | self.points_time = {} 9 | self.need_summary = {} 10 | self.init_time = time.time() 11 | 12 | def reset(self): 13 | self.check_points = {} 14 | self.points_time = {} 15 | self.need_summary = {} 16 | 17 | @staticmethod 18 | def file_func_line(stack=1): 19 | frame = inspect.stack()[stack][0] 20 | info = inspect.getframeinfo(frame) 21 | return info.filename, info.function, info.lineno 22 | 23 | @staticmethod 24 | def line(stack=2, short=False): 25 | file, func, lineo = Timer.file_func_line(stack) 26 | if short: 27 | return f"line_{lineo}_func_{func}" 28 | return f"line: {lineo}, func: {func}, file: {file}" 29 | 30 | def register_point(self, tag=None, stack=3, short=True, need_summary=True, level=0): 31 | if tag is None: 32 | tag = self.line(stack, short) 33 | if False and not tag.startswith('__'): 34 | print(f'arrive {tag}, time: {time.time() - self.init_time}, level: {level}') 35 | if level not in self.check_points: 36 | self.check_points[level] = [] 37 | self.points_time[level] = [] 38 | self.need_summary[level] = set() 39 | self.check_points[level].append(tag) 40 | self.points_time[level].append(time.time()) 41 | if need_summary: 42 | self.need_summary[level].add(tag) 43 | 44 | def register_end(self, stack=4, level=0): 45 | self.register_point('__timer_end_unique', stack, need_summary=False, level=level) 46 | 47 | def summary(self, summation=False): 48 | if len(self.check_points) == 0: 49 | return dict() 50 | res = {} 51 | for level in self.check_points: 52 | self.register_point('__timer_finale_unique', level=level) 53 | res_tmp = {} 54 | for ind, item in enumerate(self.check_points[level][:-1]): 55 | time_now = self.points_time[level][ind] 56 | time_next = self.points_time[level][ind + 1] 57 | if item in res_tmp: 58 | res_tmp[item].append(time_next - time_now) 59 | else: 60 | res_tmp[item] = [time_next - time_now] 61 | for k, v in res_tmp.items(): 62 | if k in self.need_summary[level]: 63 | res['period_' + k] = np.mean(v) 64 | if summation: 65 | res['sum_' + k] = np.sum(v) 66 | self.reset() 67 | return res 68 | 69 | 70 | def test_timer(): 71 | timer = Timer() 72 | for i in range(4): 73 | timer.register_point() 74 | time.sleep(1) 75 | for k, v in timer.summary().items(): 76 | print(f'{k}, {v}') 77 | 78 | if __name__ == '__main__': 79 | test_timer() -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | Box2D==2.3.10 2 | causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git#egg=causal_conv1d 3 | cocos==0.2.2 4 | dm_control==1.0.16 5 | dmc2gym @ git+https://github.com/denisyarats/dmc2gym.git#egg=dmc2gym 6 | einops==0.8.0 7 | flash_attn @ git+https://github.com/Dao-AILab/flash-attention.git#egg=flash_attn 8 | gin==0.1.006 9 | gym==0.21.0 10 | jax==0.4.30 11 | matplotlib==3.7.4 12 | meshcat==0.3.2 13 | ml_collections==0.1.1 14 | mujoco_py==2.1.2.14 15 | numpy==1.24.0 16 | omg==1.1.3 17 | packaging==24.1 18 | pandas==2.0.3 19 | psutil==5.9.7 20 | pybullet==3.2.5 21 | pycolab==1.2 22 | pygame==2.5.2 23 | pyglet==2.0.15 24 | pytest==7.4.4 25 | python_dateutil==2.8.2 26 | PyYAML==6.0.1 27 | PyYAML==6.0.1 28 | roboschool==1.0.34 29 | scikit_learn==1.5.1 30 | scipy==1.14.0 31 | seaborn==0.13.2 32 | setuptools==66.0.0 33 | six==1.16.0 34 | smart_logger @ git+https://github.com/FanmingL/SmartLogger#egg=smart_logger 35 | tensorboardX==2.6.2.2 36 | tensorboardX==2.6.2.2 37 | tensorflow_cpu==2.13.1 38 | torch==2.1.2+cu121.with.pypi.cudnn 
39 | tqdm==4.66.1 40 | transformers==4.39.3 41 | transforms3d==0.4.2 42 | mamba_ssm @ git+https://github.com/state-spaces/mamba#egg=mamba_ssm 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='offpolicy_rnn', 5 | version='0.0.1', 6 | author="Fan-Ming Luo", 7 | author_email="luofm@lamda.nju.edu.cn", 8 | description="An implementation of RNN policy and value function in off-policy RL", 9 | packages=find_packages(), 10 | install_requires=[ 11 | ], 12 | ) 13 | --------------------------------------------------------------------------------
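A minimal usage sketch for the `QValueGuard` and `CosineScheduler` utilities above, assuming the `offpolicy_rnn` package is installed via the included `setup.py`; the tensors and hyper-parameters here are illustrative placeholders, not values used by the repository:

import torch
from offpolicy_rnn.utility.q_value_guard import QValueGuard
from offpolicy_rnn.utility.ValueScheduler import CosineScheduler

# Placeholder batch of rewards and bootstrapped next-state Q-values.
reward = torch.randn(256, 1)
next_q = torch.randn(256, 1)
gamma = 0.99

# Clamp the bootstrap target to the Q-value range observed so far, then refresh the recorded range.
guard = QValueGuard(guard_min=True, guard_max=True)
target_q = reward + gamma * guard.clamp(next_q)
guard.update(target_q)
print(guard.get_min(), guard.get_max())

# Cosine decay of a coefficient from 0.2 down to 0.01 over 1000 steps.
scheduler = CosineScheduler(total_steps=1000, initial_value=0.2, min_value=0.01)
values = [scheduler.step() for _ in range(1000)]
print(values[0], values[-1])  # starts near initial_value, ends at min_value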