├── .gitignore ├── .pylintrc ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── collect_demos.py ├── configs ├── actor │ ├── big_deepset.yaml │ ├── deepset.yaml │ ├── mlp.yaml │ ├── padded_mlp.yaml │ ├── simple_deepset.yaml │ └── transformer.yaml ├── bc_experiment │ ├── big_transformer.yaml │ ├── deepset.yaml │ └── mlp.yaml ├── bc_setup │ ├── 2s2p.yaml │ ├── 3p.yaml │ └── 3s.yaml ├── critic │ ├── big_deepset.yaml │ ├── deepset.yaml │ ├── mlp.yaml │ ├── padded_mlp.yaml │ ├── simple_deepset.yaml │ └── transformer.yaml ├── experiment │ ├── big_deepset.yaml │ ├── big_transformer.yaml │ ├── deepset.yaml │ ├── deepset_policyonly.yaml │ ├── med_transformer.yaml │ ├── mlp.yaml │ ├── padded_mlp.yaml │ ├── simple_deepset.yaml │ ├── transformer.yaml │ └── transformer_policyonly.yaml ├── hydra │ └── launcher │ │ ├── local.yaml │ │ ├── slurm.yaml │ │ ├── slurm_bc.yaml │ │ └── slurm_cpu.yaml ├── main.yaml ├── main_bc.yaml └── setup │ ├── 1p.yaml │ ├── 1pdense.yaml │ ├── 1s.yaml │ ├── 1s1pdense.yaml │ ├── 1s_or_1pdense.yaml │ ├── 2pdense.yaml │ ├── 2s.yaml │ ├── 2s2pdense.yaml │ ├── 2s2pdense_fastexp.yaml │ ├── 2s_or_2pdense.yaml │ ├── 2s_or_2pdense_fastexp.yaml │ ├── 2stackdense.yaml │ ├── 3p.yaml │ ├── 3pdense.yaml │ ├── 3pdense_fastexp.yaml │ ├── 3s.yaml │ ├── 3sdense.yaml │ ├── demos_1pdense.yaml │ ├── demos_3p.yaml │ ├── demos_3pdense.yaml │ └── demos_3s.yaml ├── env.yml ├── envs ├── LICENSE.md ├── README.md ├── __init__.py ├── assets │ ├── LICENSE.md │ ├── fetch │ │ ├── pick_and_place.xml │ │ ├── push1.xml │ │ ├── push1_collide.xml │ │ ├── push2.xml │ │ ├── push2_collide.xml │ │ ├── push3.xml │ │ ├── push3_collide.xml │ │ ├── push4.xml │ │ ├── push4_collide.xml │ │ ├── push5.xml │ │ ├── push5_collide.xml │ │ ├── push6.xml │ │ ├── push6_collide.xml │ │ ├── reach.xml │ │ ├── robot.xml │ │ ├── shared.xml │ │ ├── switch1.xml │ │ ├── switch1push1.xml │ │ ├── switch1push1_collide.xml │ │ ├── switch2.xml │ │ ├── switch2push2.xml │ │ ├── switch2push2_collide.xml │ │ ├── switch3.xml │ │ ├── switch3push3.xml │ │ ├── switch3push3_collide.xml │ │ ├── switch4.xml │ │ ├── switch5.xml │ │ └── switch6.xml │ ├── hand │ │ ├── manipulate_block.xml │ │ ├── manipulate_block_touch_sensors.xml │ │ ├── manipulate_egg.xml │ │ ├── manipulate_egg_touch_sensors.xml │ │ ├── manipulate_pen.xml │ │ ├── manipulate_pen_touch_sensors.xml │ │ ├── reach.xml │ │ ├── robot.xml │ │ ├── robot_touch_sensors_92.xml │ │ ├── shared.xml │ │ ├── shared_asset.xml │ │ └── shared_touch_sensors_92.xml │ ├── stls │ │ ├── fetch │ │ │ ├── base_link_collision.stl │ │ │ ├── bellows_link_collision.stl │ │ │ ├── elbow_flex_link_collision.stl │ │ │ ├── estop_link.stl │ │ │ ├── forearm_roll_link_collision.stl │ │ │ ├── gripper_link.stl │ │ │ ├── head_pan_link_collision.stl │ │ │ ├── head_tilt_link_collision.stl │ │ │ ├── l_wheel_link_collision.stl │ │ │ ├── laser_link.stl │ │ │ ├── r_wheel_link_collision.stl │ │ │ ├── shoulder_lift_link_collision.stl │ │ │ ├── shoulder_pan_link_collision.stl │ │ │ ├── torso_fixed_link.stl │ │ │ ├── torso_lift_link_collision.stl │ │ │ ├── upperarm_roll_link_collision.stl │ │ │ ├── wrist_flex_link_collision.stl │ │ │ └── wrist_roll_link_collision.stl │ │ ├── hand │ │ │ ├── F1.stl │ │ │ ├── F2.stl │ │ │ ├── F3.stl │ │ │ ├── TH1_z.stl │ │ │ ├── TH2_z.stl │ │ │ ├── TH3_z.stl │ │ │ ├── forearm_electric.stl │ │ │ ├── forearm_electric_cvx.stl │ │ │ ├── knuckle.stl │ │ │ ├── lfmetacarpal.stl │ │ │ ├── palm.stl │ │ │ └── wrist.stl │ │ └── switch │ │ │ ├── lightswitch.stl │ │ │ └── lightswitchbase.stl │ └── 
textures │ │ ├── block.png │ │ └── block_hidden.png └── fetch_push_multi.py ├── evaluate.py ├── her_modules ├── __init__.py └── her.py ├── launch.py ├── launch_bc.py ├── render_viz.py ├── rl_modules ├── __init__.py ├── ddpg_agent.py ├── models.py ├── normalizer.py └── replay_buffer.py ├── scripts ├── eval.sh ├── make_data.sh ├── make_vids.sh └── oracle.sh ├── train.py ├── train_bc.py └── vec_env ├── LICENSE.md ├── __init__.py ├── base_vec_env.py └── subproc_vec_env.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # DS Store 107 | .DS_Store 108 | 109 | #saved_model 110 | *.pth 111 | 112 | *.pt 113 | 114 | *.log 115 | 116 | .vscode 117 | 118 | outputs/ 119 | multirun/ 120 | notebooks/**/*.png 121 | notebooks/**/*.svg 122 | output_videos/ 123 | data 124 | weights 125 | wandb -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [TYPECHECK] 2 | generated-members=numpy.*,torch.*,wandb.*,pytorch_lightning.*,pl.* 3 | 4 | [MESSAGES CONTROL] 5 | disable=invalid-name 6 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and 
orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at <opensource-conduct@fb.com>. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership.
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to entity-factored-rl 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: <https://code.facebook.com/cla> 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to entity-factored-rl, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Entity Factored RL 2 | This repository contains code for running the experiments in the paper *Policy Architectures for Compositional Generalization in Control*. 3 | 4 | ## Imitation Learning 5 | ### Data Collection 6 | 7 | Download the weights for generating BC data from [this link](https://drive.google.com/file/d/1XRYpFEHX__SWTyRQjvNG2MHRfSZ94L_N/view?usp=sharing). Unzip the file into `./weights`. 8 | Then generate behavior cloning data by running: 9 | ```bash 10 | ./scripts/make_data.sh 11 | ``` 12 | 13 | ### BC Training 14 | Sweep over different environments and architectures: 15 | ```bash 16 | python launch_bc.py -m +bc_experiment=big_transformer,deepset,mlp +bc_setup=3p,3s,2s2p 17 | ``` 18 | 19 | ## Reinforcement Learning 20 | In general, `setup` specifies the environment and exploration schedule and `experiment` specifies the architecture. Some examples: 21 | ```bash 22 | # 3-push for the transformer and the MLP (padded for extrapolation eval) 23 | python launch.py -m +experiment=transformer,padded_mlp +setup=3pdense seed="range(5)" 24 | 25 | # 3-push for the deepset, which uses a faster exploration schedule 26 | python launch.py -m +experiment=deepset +setup=3pdense_fastexp seed="range(5)" 27 | ``` 28 | 29 | ## License 30 | The majority of the code for "Entity Factored RL" is licensed under CC-BY-NC; however, portions of the project are available under separate license terms.
31 | * Portions based on [Hindsight Experience Replay](https://github.com/TianhongDai/hindsight-experience-replay), [Stable Baselines 3](https://github.com/DLR-RM/stable-baselines3), and [Gym](https://github.com/openai/gym) are licensed under the MIT license. 32 | -------------------------------------------------------------------------------- /collect_demos.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import torch 10 | import numpy as np 11 | import gym 12 | 13 | import envs 14 | from evaluate import preproc_inputs 15 | 16 | 17 | ENV_2_CKPT_RL = { 18 | 'Fetch3Push-v1': './weights/3p-3uhwocsy.pt', 19 | 'Fetch3Switch-v1': './weights/3s-2mb98s47.pt', 20 | 'Fetch2Switch2Push-v1': './weights/2s2p-338rpvyu.pt', 21 | } 22 | 23 | 24 | @torch.no_grad() 25 | def main(args): 26 | if args.chain: 27 | if args.env in ['FetchStack2Stage3-v1', 'FetchStack2StitchOnlyStack-v1', 'FetchStack2Stage1-v1']: 28 | model_paths = { 29 | 'push': './weights/stack-1p-1i0ac7fq.pt', 30 | 'stack': './weights/stack-1s-2q3z23zd.pt', 31 | } 32 | networks = {} 33 | for name, model_path in model_paths.items(): 34 | x_norm, actor_network = torch.load(model_path, map_location=lambda storage, loc: storage) 35 | actor_network.eval() 36 | networks[name] = (x_norm, actor_network) 37 | def policy(observation): 38 | next_obj_idx = observation['next_object_idx'] 39 | if args.env == 'FetchStack2Stage1-v1': 40 | name = 'push' 41 | else: 42 | name = 'push' if next_obj_idx == 0 else 'stack' 43 | grip = observation['gripper_arr'] 44 | obj = observation['object_arr'] 45 | g = observation['desired_goal_arr'] 46 | x_norm, actor_network = networks[name] 47 | inputs = preproc_inputs(grip, obj, g, x_norm) 48 | pi = actor_network(*inputs) 49 | return pi.numpy().squeeze() 50 | else: 51 | model_paths = {} 52 | if 'Push' in args.env: 53 | model_paths['push'] = './weights/1p-3n2mu1hy.pt' 54 | if 'Switch' in args.env: 55 | model_paths['switch'] = './weights/1s-1gde5bpj.pt' 56 | networks = {} 57 | for name, model_path in model_paths.items(): 58 | x_norm, actor_network = torch.load(model_path, map_location=lambda storage, loc: storage) 59 | actor_network.eval() 60 | networks[name] = (x_norm, actor_network) 61 | def policy(observation): 62 | next_obj_idx = observation['next_object_idx'] 63 | grip = observation['gripper_arr'][None] 64 | obj = observation['object_arr'][next_obj_idx][None] 65 | obj_type = np.squeeze(obj[..., -1]) 66 | name = 'switch' if obj_type > 0 else 'push' 67 | g = observation['desired_goal_arr'][next_obj_idx][None] 68 | obj, g = obj[None], g[None] 69 | x_norm, actor_network = networks[name] 70 | inputs = preproc_inputs(grip, obj, g, x_norm) 71 | pi = actor_network(*inputs) 72 | return pi.numpy().squeeze() 73 | else: 74 | model_path = ENV_2_CKPT_RL[args.env] 75 | x_norm, actor_network = torch.load(model_path, map_location=lambda storage, loc: storage) 76 | actor_network.eval() 77 | def policy(observation): 78 | grip = observation['gripper_arr'][None] 79 | obj, g = observation['object_arr'], observation['desired_goal_arr'] 80 | obj, g = obj[None], g[None] 81 | inputs = preproc_inputs(grip, obj, g, x_norm) 82 | pi = actor_network(*inputs) 83 | return pi.numpy().squeeze() 84 | print(f"Collecting demos in {args.env}.") 85 | env = gym.make(args.env) 86 | 
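# What follows: optionally wrap the env in gym.wrappers.Monitor to record
# videos, roll the frozen policy for args.num_eps episodes while logging
# per-step gripper/object/achieved-goal/desired-goal arrays, and finally
# save everything to ./data as an .npz demo dataset for behavior cloning.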
observation = env.reset() 87 | if args.render: 88 | env = gym.wrappers.Monitor(env, "./demo_videos", force=True) 89 | observation = env.reset() 90 | successes = [] 91 | ret_arr, solved_t_arr = [], [] 92 | grip_arr, obj_arr, act_arr, ag_arr, goal_arr, success_arr = [], [], [], [], [], [] 93 | for i in range(args.num_eps): 94 | # start to do the demo 95 | ep_grip, ep_obj, ep_act, ep_ag, ep_g, ep_success = [], [], [], [], [], [] 96 | observation = env.reset() 97 | ret, solved_t = 0, -1 98 | for t in range(env._max_episode_steps): 99 | action = policy(observation) 100 | # put actions into the environment 101 | observation_new, rew, done, info = env.step(action) 102 | ep_grip.append(observation["gripper_arr"].copy()) 103 | ep_obj.append(observation["object_arr"].copy()) 104 | ep_ag.append(observation["achieved_goal_arr"].copy()) 105 | ep_g.append(observation["desired_goal_arr"].copy()) 106 | ep_act.append(action.copy()) 107 | ep_success.append(info["is_success"]) 108 | observation = observation_new 109 | if solved_t < 0: 110 | ret += rew 111 | if info['is_success'] and solved_t < 0: 112 | solved_t = t 113 | ep_grip.append(observation["gripper_arr"].copy()) 114 | ep_obj.append(observation["object_arr"].copy()) 115 | ep_ag.append(observation["achieved_goal_arr"].copy()) 116 | print(f'episode {i}, is success: {info["is_success"]}, finished t: {solved_t}') 117 | successes.append(info['is_success']) 118 | grip_arr.append(np.stack(ep_grip)) 119 | obj_arr.append(np.stack(ep_obj)) 120 | act_arr.append(np.stack(ep_act)) 121 | ag_arr.append(np.stack(ep_ag)) 122 | goal_arr.append(np.stack(ep_g)) 123 | success_arr.append(np.stack(ep_success)) 124 | ret_arr.append(ret) 125 | solved_t_arr.append(env._max_episode_steps if solved_t < 0 else solved_t) 126 | grip_arr = np.stack(grip_arr) 127 | obj_arr = np.stack(obj_arr) 128 | act_arr = np.stack(act_arr) 129 | ag_arr = np.stack(ag_arr) 130 | goal_arr = np.stack(goal_arr) 131 | success_arr = np.stack(success_arr) 132 | ret_arr = np.array(ret_arr) 133 | solved_t_arr = np.array(solved_t_arr) 134 | name = "init_trajs" if args.chain else "rl_expert" 135 | np.savez( 136 | f"./data/{name}_{args.env}_{args.num_eps}.npz", 137 | grip=grip_arr, 138 | obj=obj_arr, 139 | action=act_arr, 140 | ag=ag_arr, 141 | g=goal_arr, 142 | success=success_arr, 143 | ret_arr=ret_arr, 144 | solved_t_arr=solved_t_arr, 145 | ) 146 | print(np.mean(successes)) 147 | 148 | 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("env", type=str, default="Fetch3Push-v1") 152 | parser.add_argument("--chain", action="store_true") 153 | parser.add_argument("--num_eps", type=int, default=1500) 154 | parser.add_argument("--render", action="store_true") 155 | args = parser.parse_args() 156 | main(args) 157 | -------------------------------------------------------------------------------- /configs/actor/big_deepset.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.actor_deepset_big 2 | agg: sum -------------------------------------------------------------------------------- /configs/actor/deepset.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.actor_deepset 2 | agg: sum -------------------------------------------------------------------------------- /configs/actor/mlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.actor 
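Note: the `_target_` keys in the actor/critic configs above and below follow Hydra's object-instantiation convention, presumably resolved with `hydra.utils.instantiate` inside the training scripts. A minimal sketch of how one of these configs turns into a model (illustrative only; any extra constructor arguments the real models in `rl_modules/models.py` require would be passed the same way):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Same shape as configs/actor/deepset.yaml.
cfg = OmegaConf.create({"_target_": "rl_modules.models.actor_deepset", "agg": "sum"})

# Hydra imports the dotted path and calls it with the remaining keys,
# i.e. rl_modules.models.actor_deepset(agg="sum").
actor = instantiate(cfg)
```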
-------------------------------------------------------------------------------- /configs/actor/padded_mlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.actor_padded -------------------------------------------------------------------------------- /configs/actor/simple_deepset.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.actor_simple_deepset 2 | agg: sum -------------------------------------------------------------------------------- /configs/actor/transformer.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.actor_tfm 2 | pos_enc: False 3 | embed_dim: 256 4 | dim_ff: 256 5 | n_head: 4 6 | n_blocks: 2 7 | dropout_p: 0.0 -------------------------------------------------------------------------------- /configs/bc_experiment/big_transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: transformer 5 | 6 | lr: .0001 7 | actor: 8 | embed_dim: 512 9 | dim_ff: 1024 10 | n_head: 8 11 | n_blocks: 6 12 | 13 | lr_sched: linear_warmup 14 | weight_decay: .0001 15 | warmup_steps: 30000 -------------------------------------------------------------------------------- /configs/bc_experiment/deepset.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: deepset -------------------------------------------------------------------------------- /configs/bc_experiment/mlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: mlp -------------------------------------------------------------------------------- /configs/bc_setup/2s2p.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch2Switch2Push-v1 4 | num_demos: 1000 5 | data_path: ./data/rl_expert_Fetch2Switch2Push-v1_${num_demos}.npz -------------------------------------------------------------------------------- /configs/bc_setup/3p.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3Push-v1 4 | num_demos: 1000 5 | data_path: ./data/rl_expert_Fetch3Push-v1_${num_demos}.npz -------------------------------------------------------------------------------- /configs/bc_setup/3s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3Switch-v1 4 | num_demos: 1000 5 | data_path: ./data/rl_expert_Fetch3Switch-v1_${num_demos}.npz -------------------------------------------------------------------------------- /configs/critic/big_deepset.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.critic_deepset_big 2 | agg: sum -------------------------------------------------------------------------------- /configs/critic/deepset.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.critic_deepset 2 | agg: sum -------------------------------------------------------------------------------- /configs/critic/mlp.yaml: -------------------------------------------------------------------------------- 1 | 
_target_: rl_modules.models.critic -------------------------------------------------------------------------------- /configs/critic/padded_mlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.critic_padded -------------------------------------------------------------------------------- /configs/critic/simple_deepset.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.critic_simple_deepset 2 | agg: sum -------------------------------------------------------------------------------- /configs/critic/transformer.yaml: -------------------------------------------------------------------------------- 1 | _target_: rl_modules.models.critic_tfm 2 | pos_enc: False 3 | embed_dim: 256 4 | dim_ff: 256 5 | n_head: 4 6 | n_blocks: 2 7 | dropout_p: 0.0 -------------------------------------------------------------------------------- /configs/experiment/big_deepset.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: big_deepset 5 | - override /critic: big_deepset -------------------------------------------------------------------------------- /configs/experiment/big_transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: transformer 5 | - override /critic: transformer 6 | 7 | lr: .0001 8 | optim_actor: 9 | lr: ${lr} 10 | optim_critic: 11 | lr: ${lr} 12 | 13 | actor: 14 | embed_dim: 512 15 | dim_ff: 1024 16 | n_head: 8 17 | n_blocks: 6 18 | 19 | critic: 20 | embed_dim: 512 21 | dim_ff: 1024 22 | n_head: 8 23 | n_blocks: 6 24 | 25 | warmup: 0 26 | warmup_actor: ${warmup} 27 | warmup_critic: ${warmup} -------------------------------------------------------------------------------- /configs/experiment/deepset.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: deepset 5 | - override /critic: deepset -------------------------------------------------------------------------------- /configs/experiment/deepset_policyonly.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: deepset 5 | - override /critic: mlp -------------------------------------------------------------------------------- /configs/experiment/med_transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: transformer 5 | - override /critic: transformer 6 | 7 | lr: .0001 8 | optim_actor: 9 | lr: ${lr} 10 | optim_critic: 11 | lr: ${lr} 12 | 13 | actor: 14 | embed_dim: 512 15 | dim_ff: 1024 16 | n_head: 8 17 | n_blocks: 2 18 | 19 | critic: 20 | embed_dim: 512 21 | dim_ff: 1024 22 | n_head: 8 23 | n_blocks: 2 24 | 25 | warmup: 0 26 | warmup_actor: ${warmup} 27 | warmup_critic: ${warmup} -------------------------------------------------------------------------------- /configs/experiment/mlp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: mlp 5 | - override /critic: mlp -------------------------------------------------------------------------------- /configs/experiment/padded_mlp.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: padded_mlp 5 | - override /critic: padded_mlp 6 | -------------------------------------------------------------------------------- /configs/experiment/simple_deepset.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: simple_deepset 5 | - override /critic: simple_deepset 6 | -------------------------------------------------------------------------------- /configs/experiment/transformer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: transformer 5 | - override /critic: transformer 6 | 7 | lr: .0001 8 | optim_actor: 9 | lr: ${lr} 10 | optim_critic: 11 | lr: ${lr} 12 | 13 | warmup: 0 14 | warmup_actor: ${warmup} 15 | warmup_critic: ${warmup} -------------------------------------------------------------------------------- /configs/experiment/transformer_policyonly.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /actor: transformer 5 | - override /critic: mlp 6 | 7 | optim_actor: 8 | lr: 0.0001 -------------------------------------------------------------------------------- /configs/hydra/launcher/local.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | launcher: 4 | cpus_per_task: 40 5 | gpus_per_node: 1 6 | tasks_per_node: 1 7 | timeout_min: 600 8 | mem_gb: 64 9 | name: ${hydra.job.name} 10 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher 11 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 12 | run: 13 | dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 14 | subdir: ${hydra.job.num}_${hydra.job.override_dirname} 15 | sweep: 16 | dir: ./outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 17 | subdir: ${hydra.job.num}_${hydra.job.override_dirname} -------------------------------------------------------------------------------- /configs/hydra/launcher/slurm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | launcher: 4 | partition: learnfair 5 | cpus_per_task: 32 6 | gpus_per_node: 1 7 | tasks_per_node: 1 8 | timeout_min: 1440 9 | mem_gb: 32 10 | name: ${hydra.job.name} 11 | array_parallelism: 256 12 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher 13 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 14 | run: 15 | dir: /checkpoint/${env:USER}/outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 16 | subdir: ${hydra.job.override_dirname} 17 | sweep: 18 | dir: /checkpoint/${env:USER}/outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 19 | subdir: ${hydra.job.override_dirname} -------------------------------------------------------------------------------- /configs/hydra/launcher/slurm_bc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | launcher: 4 | partition: learnfair 5 | cpus_per_task: 16 6 | gpus_per_node: 1 7 | tasks_per_node: 1 8 | timeout_min: 90 9 | mem_gb: 32 10 | name: ${hydra.job.name} 11 | array_parallelism: 256 12 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher 13 | submitit_folder: 
${hydra.sweep.dir}/.submitit/%j 14 | run: 15 | dir: /checkpoint/${env:USER}/outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 16 | subdir: ${hydra.job.override_dirname} 17 | sweep: 18 | dir: /checkpoint/${env:USER}/outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 19 | subdir: ${hydra.job.override_dirname} -------------------------------------------------------------------------------- /configs/hydra/launcher/slurm_cpu.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | launcher: 4 | partition: learnfair 5 | cpus_per_task: 64 6 | gpus_per_node: 0 7 | tasks_per_node: 1 8 | timeout_min: 1440 9 | mem_gb: 64 10 | name: ${hydra.job.name} 11 | array_parallelism: 256 12 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher 13 | submitit_folder: ${hydra.sweep.dir}/.submitit/%j 14 | run: 15 | dir: /checkpoint/${env:USER}/outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 16 | subdir: ${hydra.job.override_dirname} 17 | sweep: 18 | dir: /checkpoint/${env:USER}/outputs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} 19 | subdir: ${hydra.job.override_dirname} -------------------------------------------------------------------------------- /configs/main.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - actor: mlp 3 | - critic: mlp 4 | 5 | env_name: Fetch1Push-v1 6 | n_epochs: 50 7 | n_cycles: 50 8 | n_batches: 40 9 | save_interval: 5 10 | seed: 123 11 | num_workers: 16 12 | replay_strategy: future 13 | clip_return: 50 14 | noise_eps: 0.2 15 | random_eps: 0.3 16 | exp_schedule: null 17 | buffer_size: 1000000 18 | replay_k: 4 19 | clip_obs: 200 20 | batch_size: 2048 # 256 * 8. MPI version had effectively a larger batch size due to multiple workers. 
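# Note on the HER settings above: with replay_strategy "future" and
# replay_k: 4, roughly replay_k / (replay_k + 1) = 80% of sampled
# transitions have their goal relabeled to an achieved goal from later in
# the same episode (standard HER "future" relabeling; see her_modules/her.py).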
21 | gamma: 0.98 22 | action_l2: 1 23 | optim_actor: 24 | _target_: torch.optim.Adam 25 | lr: 0.001 26 | weight_decay: 0 27 | optim_critic: 28 | _target_: torch.optim.Adam 29 | lr: 0.001 30 | weight_decay: 0 31 | warmup_actor: 0 32 | warmup_critic: 0 33 | polyak: 0.95 34 | n_test_eps: 100 35 | clip_range: 5 36 | demo_length: 20 37 | cuda: True 38 | num_rollouts: 1 39 | init_trajs: null 40 | n_init_steps: 0 41 | norm_reward: False 42 | min_samples: 0 43 | r_scale: 1 -------------------------------------------------------------------------------- /configs/main_bc.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - actor: mlp 3 | 4 | env_name: Fetch3Push-v1 5 | lr: 0.001 6 | data_path: ./data/rl_expert_Fetch3Push-v1_3000.npz 7 | n_steps: 60000 8 | device: cuda 9 | clip_range: 5 10 | lr_sched: null 11 | batch_size: 128 12 | seed: 0 -------------------------------------------------------------------------------- /configs/setup/1p.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch1Push-v1 4 | r_scale: 1 5 | n_epochs: 50 -------------------------------------------------------------------------------- /configs/setup/1pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch1PushDense-v1 4 | r_scale: 5 -------------------------------------------------------------------------------- /configs/setup/1s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch1Switch-v1 4 | r_scale: 1 5 | n_epochs: 10 -------------------------------------------------------------------------------- /configs/setup/1s1pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch1Switch1PushDense-v1 4 | r_scale: 5 5 | n_epochs: 150 6 | polyak: 0.99 7 | exp_schedule: 8 | - 60 9 | - 100 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/1s_or_1pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch1SwitchOr1PushDense-v1 4 | r_scale: 5 5 | n_epochs: 100 6 | polyak: 0.99 7 | exp_schedule: 8 | - 40 9 | - 80 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/2pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch2PushDense-v1 4 | r_scale: 5 5 | n_epochs: 150 6 | polyak: 0.99 7 | exp_schedule: 8 | - 75 9 | - 125 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/2s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch2Switch-v1 4 | r_scale: 1 5 | n_epochs: 50 -------------------------------------------------------------------------------- /configs/setup/2s2pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch2Switch2PushDense-v1 4 | r_scale: 5 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 100 9 | - 175 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/2s2pdense_fastexp.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch2Switch2PushDense-v1 4 | r_scale: 5 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 75 9 | - 150 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/2s_or_2pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch2SwitchOr2PushDense-v1 4 | r_scale: 5 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 100 9 | - 175 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/2s_or_2pdense_fastexp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch2SwitchOr2PushDense-v1 4 | r_scale: 5 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 50 9 | - 100 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/2stackdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: FetchStack2DenseStitch-v1 4 | r_scale: 5 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 75 9 | - 150 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/3p.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3Push-v1 4 | r_scale: 1 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 50 9 | - 100 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/3pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3PushDense-v1 4 | r_scale: 5 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 100 9 | - 175 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/3pdense_fastexp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3PushDense-v1 4 | r_scale: 5 5 | n_epochs: 250 6 | polyak: 0.99 7 | exp_schedule: 8 | - 30 9 | - 80 10 | - .01 -------------------------------------------------------------------------------- /configs/setup/3s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3Switch-v1 4 | r_scale: 1 5 | n_epochs: 100 -------------------------------------------------------------------------------- /configs/setup/3sdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3SwitchDense-v1 4 | r_scale: 5 5 | n_epochs: 150 6 | exp_schedule: 7 | - 100 8 | - 150 9 | - .01 -------------------------------------------------------------------------------- /configs/setup/demos_1pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch1PushDense-v1 4 | init_trajs: ./data/init_trajs_Fetch1PushDense-v1_1000.npz 5 | r_scale: 5 6 | n_init_steps: 100 -------------------------------------------------------------------------------- /configs/setup/demos_3p.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3Push-v1 4 | init_trajs: ./data/init_trajs_Fetch3Push-v1_3000.npz 5 | r_scale: 1 6 | n_epochs: 100 7 | n_init_steps: 100 -------------------------------------------------------------------------------- /configs/setup/demos_3pdense.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3PushDense-v1 4 | init_trajs: ./data/init_trajs_Fetch3PushDense-v1_3000.npz 5 | r_scale: 15 6 | n_epochs: 100 7 | n_init_steps: 100 -------------------------------------------------------------------------------- /configs/setup/demos_3s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env_name: Fetch3Switch-v1 4 | init_trajs: ./data/init_trajs_Fetch3Switch-v1_3000.npz 5 | r_scale: 1 6 | n_epochs: 50 7 | n_init_steps: 100 -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: her 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2021.7.5=h06a4308_1 10 | - certifi=2021.5.30=py37h06a4308_0 11 | - cudatoolkit=11.0.221=h6bb024c_0 12 | - freetype=2.10.4=h5ab3b9f_0 13 | - intel-openmp=2021.3.0=h06a4308_3350 14 | - jpeg=9b=h024ee3a_2 15 | - lcms2=2.12=h3be6417_0 16 | - ld_impl_linux-64=2.35.1=h7274673_9 17 | - libffi=3.3=he6710b0_2 18 | - libgcc-ng=9.3.0=h5101ec6_17 19 | - libgomp=9.3.0=h5101ec6_17 20 | - libpng=1.6.37=hbc83047_0 21 | - libstdcxx-ng=9.3.0=hd4cf53a_17 22 | - libtiff=4.2.0=h85742a9_0 23 | - libuv=1.40.0=h7b6447c_0 24 | - libwebp-base=1.2.0=h27cfd23_0 25 | - lz4-c=1.9.3=h295c915_1 26 | - mkl=2021.3.0=h06a4308_520 27 | - mkl-service=2.4.0=py37h7f8727e_0 28 | - mkl_fft=1.3.0=py37h42c9631_2 29 | - mkl_random=1.2.2=py37h51133e4_0 30 | - ncurses=6.2=he6710b0_1 31 | - ninja=1.10.2=hff7bd54_1 32 | - numpy-base=1.20.3=py37h74d4b33_0 33 | - olefile=0.46=py37_0 34 | - openjpeg=2.4.0=h3ad879b_0 35 | - openssl=1.1.1l=h7f8727e_0 36 | - pip=21.0.1=py37h06a4308_0 37 | - python=3.7.10=h12debd9_4 38 | - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 39 | - readline=8.1=h27cfd23_0 40 | - setuptools=52.0.0=py37h06a4308_0 41 | - six=1.16.0=pyhd3eb1b0_0 42 | - sqlite=3.36.0=hc218d9a_0 43 | - tk=8.6.10=hbc83047_0 44 | - torchaudio=0.7.2=py37 45 | - torchvision=0.8.2=py37_cu110 46 | - typing_extensions=3.10.0.0=pyh06a4308_0 47 | - wheel=0.37.0=pyhd3eb1b0_0 48 | - xz=5.2.5=h7b6447c_0 49 | - zlib=1.2.11=h7b6447c_3 50 | - zstd=1.4.9=haebb681_0 51 | - pip: 52 | - cffi==1.14.6 53 | - configargparse==1.5.2 54 | - cython==0.29.24 55 | - glfw==2.1.0 56 | - gym==0.12.5 57 | - imageio==2.9.0 58 | - mpi4py==3.1.1 59 | - mujoco-py==1.50.1.56 60 | - numpy==1.21.2 61 | - pillow==8.3.2 62 | - pycparser==2.20 63 | - pyglet==1.5.19 64 | - pyyaml==5.4.1 65 | - scipy==1.7.1 66 | -------------------------------------------------------------------------------- /envs/LICENSE.md: -------------------------------------------------------------------------------- 1 | # gym 2 | 3 | The MIT License 4 | 5 | Copyright (c) 2016 OpenAI (https://openai.com) 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without 
restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | 25 | # MuJoCo models 26 | This work is derived from [MuJoCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license: 27 | ``` 28 | This file is part of MuJoCo. 29 | Copyright 2009-2015 Roboti LLC. 30 | Mujoco :: Advanced physics simulation engine 31 | Source : www.roboti.us 32 | Version : 1.31 33 | Released : 23Apr16 34 | Author :: Vikash Kumar 35 | Contacts : kumar@roboti.us 36 | ``` -------------------------------------------------------------------------------- /envs/README.md: -------------------------------------------------------------------------------- 1 | # Fetch N-Push and N-Switch environments 2 | 3 | If your code is outside the `envs` directory, you can initialize the environments as follows: 4 | 5 | ```python 6 | import envs 7 | import gym 8 | 9 | env = gym.make("Fetch3Push-v1") # 3-Push, sparse reward 10 | env = gym.make("Fetch3PushDense-v1") # 3-Push, dense reward 11 | 12 | env = gym.make("Fetch3Switch-v1") # 3-Switch, sparse reward 13 | 14 | env = gym.make("Fetch2Switch2Push-v1") # 2 switches and 2 cubes 15 | 16 | env = gym.make("Fetch2SwitchOr2Push-v1") # 2 switches and 2 cubes, but the goal only involves either the switches or the cubes 17 | ``` 18 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
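# This module registers every Fetch N-Switch / N-Push variant with gym:
# one env id per combination of reward type (sparse, dense, step, hybrid, eff),
# entity count, collision flag, and push/switch-exclusive goal mode; see the
# register() loops below.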
6 | 7 | 8 | from gym.envs.registration import registry, register, make, spec 9 | 10 | for reward_type in ["sparse", "dense", "step", "hybrid", "eff"]: 11 | suffix = { 12 | "sparse": "", 13 | "dense": "Dense", 14 | "step": "Step", 15 | "hybrid": "Hybrid", 16 | "eff": "Eff", 17 | }[reward_type] 18 | kwargs = { 19 | "reward_type": reward_type, 20 | } 21 | 22 | for i in range(1, 4): 23 | 24 | register( 25 | id=f"Fetch{i}Switch{i}Push{suffix}-v1", 26 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 27 | kwargs={"num_objects": i, "num_switches": i, **kwargs}, 28 | max_episode_steps=70 * i, 29 | ) 30 | 31 | register( 32 | id=f"Fetch{i}Switch{i}PushCollide{suffix}-v1", 33 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 34 | kwargs={"num_objects": i, "num_switches": i, 'collisions': True, **kwargs}, 35 | max_episode_steps=70 * i, 36 | ) 37 | 38 | register( 39 | id=f"Fetch{i}SwitchOr{i}Push{suffix}-v1", 40 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 41 | kwargs={"num_objects": i, "num_switches": i, "push_switch_exclusive": "random", **kwargs}, 42 | max_episode_steps=50 * i, 43 | ) 44 | 45 | register( 46 | id=f"Fetch{i}SwitchOr{i}PushCollide{suffix}-v1", 47 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 48 | kwargs={"num_objects": i, "num_switches": i, "collisions": True, "push_switch_exclusive": "random", **kwargs}, 49 | max_episode_steps=50 * i, 50 | ) 51 | 52 | register( 53 | id=f"Fetch{i}SwitchOr{i}PushOnlyCube{suffix}-v1", 54 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 55 | kwargs={"num_objects": i, "num_switches": i, "push_switch_exclusive": "cube", **kwargs}, 56 | max_episode_steps=50 * i, 57 | ) 58 | 59 | register( 60 | id=f"Fetch{i}SwitchOr{i}PushOnlySwitch{suffix}-v1", 61 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 62 | kwargs={"num_objects": i, "num_switches": i, "push_switch_exclusive": "switch", **kwargs}, 63 | max_episode_steps=50 * i, 64 | ) 65 | 66 | register( 67 | id=f"Fetch{i}SwitchOr{i}PushOnlyCubeCollide{suffix}-v1", 68 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 69 | kwargs={"num_objects": i, "num_switches": i, "push_switch_exclusive": "cube", "collisions": True, **kwargs}, 70 | max_episode_steps=50 * i, 71 | ) 72 | 73 | register( 74 | id=f"Fetch{i}SwitchOr{i}PushOnlySwitchCollide{suffix}-v1", 75 | entry_point="envs.fetch_push_multi:FetchNSwitchMPushEnv", 76 | kwargs={"num_objects": i, "num_switches": i, "push_switch_exclusive": "switch", "collisions": True, **kwargs}, 77 | max_episode_steps=50 * i, 78 | ) 79 | 80 | for num_pairs in range(1, 7): 81 | max_steps = 50 * num_pairs 82 | 83 | register( 84 | id=f"Fetch{num_pairs}Push{suffix}-v1", 85 | entry_point="envs.fetch_push_multi:FetchNPushEnv", 86 | kwargs={"num_objects": num_pairs, **kwargs}, 87 | max_episode_steps=max_steps, 88 | ) 89 | 90 | register( 91 | id=f"Fetch{num_pairs}PushCollide{suffix}-v1", 92 | entry_point="envs.fetch_push_multi:FetchNPushEnv", 93 | kwargs={"num_objects": num_pairs, "collisions": True, **kwargs}, 94 | max_episode_steps=max_steps, 95 | ) 96 | 97 | for num_switches in range(1, 7): 98 | max_steps = 20 * num_switches 99 | register( 100 | id=f"Fetch{num_switches}Switch{suffix}-v1", 101 | entry_point="envs.fetch_push_multi:FetchNSwitchEnv", 102 | kwargs={"num_switches": num_switches, **kwargs}, 103 | max_episode_steps=max_steps, 104 | ) -------------------------------------------------------------------------------- /envs/assets/fetch/pick_and_place.xml: 
--------------------------------------------------------------------------------
[MJCF content lost: the XML markup of this and the following asset files was stripped when this dump was generated, leaving only bare line numbers, so each file is summarized by a placeholder.]
-------------------------------------------------------------------------------- /envs/assets/fetch/push1.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push1_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push2.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push2_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push3.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push3_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push4.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push4_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push5.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
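The MJCF bodies above and below are unrecoverable from this dump, but the environments built on these assets are still fully specified by `envs/__init__.py`. A quick smoke test of the registered ids (assuming a working `mujoco-py` setup per `env.yml`; the observation keys are the ones `collect_demos.py` reads):

```python
import gym

import envs  # noqa: F401 -- importing the package registers the Fetch* env ids

env = gym.make("Fetch3Push-v1")
obs = env.reset()
# Dict observation with one row per entity:
print(obs["gripper_arr"].shape)       # gripper state
print(obs["object_arr"].shape)        # one row per cube/switch
print(obs["desired_goal_arr"].shape)  # one goal row per entity
obs, reward, done, info = env.step(env.action_space.sample())
print(info["is_success"])
```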
-------------------------------------------------------------------------------- /envs/assets/fetch/push5_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push6.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/push6_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/reach.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/robot.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/shared.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch1.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch1push1.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch1push1_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch2.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch2push2.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch2push2_collide.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch3.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
-------------------------------------------------------------------------------- /envs/assets/fetch/switch3push3.xml: -------------------------------------------------------------------------------- [MJCF content lost in extraction]
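The file names above follow a systematic scheme: `pushN`, `switchN`, and `switchNpushN`, each with an optional `_collide` variant. A hypothetical helper illustrating the mapping (an assumption for illustration only; the real file selection happens inside `envs/fetch_push_multi.py`, which this section does not include):

```python
import os

def asset_path(num_switches: int, num_objects: int, collisions: bool = False) -> str:
    """Map entity counts to an MJCF file name (illustrative, not the repo's code),
    e.g. (3, 3, True) -> envs/assets/fetch/switch3push3_collide.xml."""
    name = ""
    if num_switches > 0:
        name += f"switch{num_switches}"
    if num_objects > 0:
        name += f"push{num_objects}"
    if collisions:
        name += "_collide"
    return os.path.join("envs", "assets", "fetch", name + ".xml")
```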
-------------------------------------------------------------------------------- /envs/assets/fetch/switch3push3_collide.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /envs/assets/fetch/switch4.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /envs/assets/fetch/switch5.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /envs/assets/fetch/switch6.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /envs/assets/hand/manipulate_block.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /envs/assets/hand/manipulate_block_touch_sensors.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /envs/assets/hand/manipulate_egg.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 
-------------------------------------------------------------------------------- /envs/assets/hand/manipulate_egg_touch_sensors.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /envs/assets/hand/manipulate_pen.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /envs/assets/hand/manipulate_pen_touch_sensors.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /envs/assets/hand/reach.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /envs/assets/hand/shared_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /envs/assets/hand/shared_touch_sensors_92.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /envs/assets/stls/fetch/base_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/base_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/bellows_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/bellows_link_collision.stl -------------------------------------------------------------------------------- 
/envs/assets/stls/fetch/elbow_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/elbow_flex_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/estop_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/estop_link.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/forearm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/forearm_roll_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/gripper_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/gripper_link.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/head_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/head_pan_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/head_tilt_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/head_tilt_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/l_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/l_wheel_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/laser_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/laser_link.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/r_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/r_wheel_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/shoulder_lift_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/shoulder_lift_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/shoulder_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/shoulder_pan_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/torso_fixed_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/torso_fixed_link.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/torso_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/torso_lift_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/upperarm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/upperarm_roll_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/wrist_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/wrist_flex_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/fetch/wrist_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/fetch/wrist_roll_link_collision.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/F1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/F1.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/F2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/F2.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/F3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/F3.stl -------------------------------------------------------------------------------- 
/envs/assets/stls/hand/TH1_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/TH1_z.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/TH2_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/TH2_z.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/TH3_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/TH3_z.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/forearm_electric.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/forearm_electric.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/forearm_electric_cvx.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/forearm_electric_cvx.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/knuckle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/knuckle.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/lfmetacarpal.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/lfmetacarpal.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/palm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/palm.stl -------------------------------------------------------------------------------- /envs/assets/stls/hand/wrist.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/hand/wrist.stl -------------------------------------------------------------------------------- /envs/assets/stls/switch/lightswitch.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/switch/lightswitch.stl -------------------------------------------------------------------------------- /envs/assets/stls/switch/lightswitchbase.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/stls/switch/lightswitchbase.stl -------------------------------------------------------------------------------- /envs/assets/textures/block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/textures/block.png -------------------------------------------------------------------------------- /envs/assets/textures/block_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/envs/assets/textures/block_hidden.png -------------------------------------------------------------------------------- /evaluate.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 |
9 | import argparse
10 |
11 | import torch
12 | import gym
13 | import envs, gym_fetch_stack
14 | import numpy as np
15 | import wandb
16 |
17 |
18 | # preprocess the inputs
19 | def preproc_inputs(grip, obj, g, x_norm):
20 | # normalize the inputs and convert them to tensors
21 | outs = x_norm.normalize(grip, obj, g)
22 | return [torch.tensor(x, dtype=torch.float32) for x in outs]
23 |
24 |
25 | @torch.no_grad()
26 | def main(args):
27 | api = wandb.Api()
28 | if args.run_tag:
29 | assert len(args.run_path) == 1
30 | runs = api.runs(path=args.run_path[0], filters={"tags": {"$in": [args.run_tag]}})
31 | else:
32 | runs = []
33 | for run_path in args.run_path:
34 | runs.append(api.run(run_path))
35 | # load the model parameters
36 | for run in runs:
37 | if run._state != "finished":
38 | print("Run not finished, skipping.")
39 | continue
40 | try:
41 | run.file('models/best.pt').download(root='/tmp', replace=True)
42 | x_norm, actor_network = torch.load('/tmp/models/best.pt', map_location=lambda storage, loc: storage)
43 | except Exception: # fall back to the latest checkpoint if best.pt is unavailable
44 | run.file('models/latest.pt').download(root='/tmp', replace=True)
45 | x_norm, actor_network = torch.load('/tmp/models/latest.pt', map_location=lambda storage, loc: storage)
46 | actor_network.eval()
47 | print(f"Evaluating {run.config['actor']['_target_']} in {args.env_name}")
48 | for env_name in args.env_name:
49 | if not args.overwrite and f"eval_return/{env_name}" in run.summary:
50 | print(f"Run {run.id} already has eval for {env_name}.")
51 | continue
52 | env = gym.make(env_name)
53 | results, rets, timesteps = [], [], []
54 | # run the evaluation episodes
55 | for ep_num in range(args.num_eps):
56 | done, ret, t = False, 0, 0
57 | observation = env.reset()
58 | while not done:
59 | grip, obj = observation['gripper_arr'][None], observation['object_arr'][None]
60 | g = observation['desired_goal_arr'][None]
61 | inputs = preproc_inputs(grip, obj, g, x_norm)
62 | with torch.no_grad():
63 | pi = actor_network(*inputs)
64 | action = pi.detach().numpy().squeeze()
65 | # step the environment
66 | observation_new, reward, done, info = env.step(action)
67 | observation = observation_new
68 | ret += reward
69 | done = done or info['is_success']
70 | t += 1
71 |
results.append(info['is_success'])
72 | rets.append(ret)
73 | timesteps.append(t)
74 | print(f"{env_name} avg SR ({len(results)} eps): {np.mean(results):0.4f}.")
75 | run.summary[f"eval_success/{env_name}"] = np.mean(results)
76 | run.summary[f"eval_return/{env_name}"] = np.mean(rets)
77 | run.summary[f"eval_length/{env_name}"] = np.mean(timesteps)
78 | run.summary.update()
79 |
80 |
81 | if __name__ == '__main__':
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument('--env_name', type=str, nargs='+')
84 | parser.add_argument('--run_tag', type=str)
85 | parser.add_argument('--run_path', type=str, nargs='+')
86 | parser.add_argument('--num_eps', type=int, default=200)
87 | parser.add_argument('--render', action='store_true')
88 | parser.add_argument('--overwrite', action='store_true')
89 | args = parser.parse_args()
90 | main(args) -------------------------------------------------------------------------------- /her_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/her_modules/__init__.py -------------------------------------------------------------------------------- /her_modules/her.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Based on the HER implementation at https://github.com/TianhongDai/hindsight-experience-replay.
8 |
9 | import numpy as np
10 |
11 | class her_sampler:
12 | def __init__(self, replay_strategy, replay_k, reward_func=None):
13 | self.replay_strategy = replay_strategy
14 | self.replay_k = replay_k
15 | if self.replay_strategy == 'future':
16 | self.future_p = 1 - (1. / (1 + replay_k))
17 | else:
18 | self.future_p = 0
19 | self.reward_func = reward_func
20 |
21 | def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
22 | rollout_batch_size = episode_batch['actions'].shape[0]
23 | batch_size = batch_size_in_transitions
24 | # select which rollouts and which timesteps to use
25 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
26 | # t_samples = np.random.randint(T, size=batch_size)
27 | final_t = episode_batch['final_t'][episode_idxs]
28 | t_samples = np.random.randint(final_t)
29 | transitions = {}
30 | for key in episode_batch.keys():
31 | if key != 'final_t':
32 | transitions[key] = episode_batch[key][episode_idxs, t_samples].copy()
33 | # choose which transitions get HER goal relabeling
34 | her_indexes = np.where(np.random.uniform(size=batch_size) < self.future_p)
35 | future_offset = np.random.uniform(size=batch_size) * (final_t - t_samples)
36 | future_offset = future_offset.astype(int)
37 | future_t = (t_samples + 1 + future_offset)[her_indexes]
38 | # replace the goal with a future achieved goal
39 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
40 | transitions['g'][her_indexes] = future_ag
41 | # gather the info needed to recompute the reward
42 | info = {'gripper_arr': transitions['gripper_next']}
43 | transitions['r'] = np.expand_dims(self.reward_func(transitions['ag_next'], transitions['g'], info), 1)
44 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) for k in transitions.keys()}
45 |
46 | return transitions -------------------------------------------------------------------------------- /launch.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import hydra
9 | from omegaconf import OmegaConf
10 | from train import launch
11 |
12 |
13 | @hydra.main(config_name="main", config_path="configs")
14 | def main(cfg=None):
15 | print(OmegaConf.to_yaml(cfg, resolve=True))
16 | launch(cfg)
17 |
18 |
19 | if __name__ == "__main__":
20 | main() -------------------------------------------------------------------------------- /launch_bc.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import hydra
9 | from omegaconf import OmegaConf
10 | from train_bc import train
11 |
12 |
13 | @hydra.main(config_name="main_bc", config_path="configs")
14 | def main(cfg=None):
15 | print(OmegaConf.to_yaml(cfg, resolve=True))
16 | train(cfg)
17 |
18 |
19 | if __name__ == "__main__":
20 | main() -------------------------------------------------------------------------------- /render_viz.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import argparse
9 | import os
10 | import torch
11 | import gym
12 | import envs, gym_fetch_stack
13 | from gym.wrappers.monitoring.video_recorder import VideoRecorder
14 | import numpy as np
15 | import wandb
16 |
17 |
18 | # preprocess the inputs
19 | def preproc_inputs(grip, obj, g, x_norm):
20 | # normalize the inputs and convert them to tensors
21 | outs = x_norm.normalize(grip, obj, g)
22 | return [torch.tensor(x, dtype=torch.float32) for x in outs]
23 |
24 |
25 | @torch.no_grad()
26 | def main(args):
27 | if not os.path.exists('output_videos'):
28 | os.mkdir('output_videos')
29 | api = wandb.Api()
30 | runs = []
31 | for run_id in args.run_id:
32 | runs.append(api.run(os.path.join('ayzhong/fetch-her', run_id)))
33 |
34 | # load the model parameters
35 | for run in runs:
36 | run.file('models/best.pt').download(root='/tmp', replace=True)
37 | x_norm, actor_network = torch.load('/tmp/models/best.pt', map_location=lambda storage, loc: storage)
38 | actor_network.eval()
39 | # create the environment
40 | print(args.env_name)
41 | for env_name in args.env_name:
42 | env = gym.make(env_name)
43 | actor_cls = run.config['actor']['_target_'].split('.')[-1]
44 | out_dir = f'./output_videos/{env_name}/{actor_cls}/' # renamed from `dir` to avoid shadowing the builtin
45 | os.makedirs(out_dir)
46 | for i in range(10):
47 | rec = VideoRecorder(env, base_path=os.path.join(out_dir, f'vid{i}'))
48 | observation = env.reset()
49 | # roll out and record the episode
50 | t, solved, solved_step = 0, False, 0
51 | while t < env._max_episode_steps:
52 | if solved and t > solved_step + 10:
53 | break
54 | rec.capture_frame()
55 | grip, obj = observation['gripper_arr'][None], observation['object_arr'][None]
56 | g = observation['desired_goal_arr'][None]
57 | # env.render()
58 | inputs = preproc_inputs(grip, obj, g, x_norm)
59 | with torch.no_grad():
60 | pi = actor_network(*inputs)
61 | action = pi.detach().numpy().squeeze()
62 | # step the environment
63 | observation_new, reward, done, info = env.step(action)
64 | if not solved and info['is_success']:
65 | solved_step = t
66 | solved = solved or info['is_success']
67 | observation = observation_new
68 | t += 1
69 | rec.capture_frame()
70 | rec.close()
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--env_name', type=str, nargs='+')
76 | parser.add_argument('--run_id', type=str, nargs='+')
77 | args = parser.parse_args()
78 | main(args) -------------------------------------------------------------------------------- /rl_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/rl_modules/__init__.py -------------------------------------------------------------------------------- /rl_modules/normalizer.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Based on the HER implementation at https://github.com/TianhongDai/hindsight-experience-replay.
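# The normalizer below keeps running totals (sum S, sum of squares S2, count n)
# so the statistics can be recomputed cheaply at any time:
#   mean = S / n,  std = sqrt(max(eps^2, S2 / n - (S / n)^2)),
# where eps lower-bounds the std so near-constant features do not blow up.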
8 |
9 | import numpy as np
10 |
11 | class normalizer:
12 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf):
13 | self.size = size
14 | self.eps = eps
15 | self.default_clip_range = default_clip_range
16 | # running sum, sum of squares, and count
17 | self.total_sum = np.zeros(self.size, np.float32)
18 | self.total_sumsq = np.zeros(self.size, np.float32)
19 | self.total_count = np.ones(1, np.float32)
20 | # current mean and std estimates
21 | self.mean = np.zeros(self.size, np.float32)
22 | self.std = np.ones(self.size, np.float32)
23 |
24 | # update the parameters of the normalizer
25 | def update(self, v):
26 | v = v.reshape(-1, self.size)
27 | # accumulate the running statistics
28 | self.total_sum += v.sum(axis=0)
29 | self.total_sumsq += (np.square(v)).sum(axis=0)
30 | self.total_count[0] += v.shape[0]
31 |
32 | def recompute_stats(self):
33 | # recompute the mean and std from the running totals
34 | self.mean = self.total_sum / self.total_count
35 | self.std = np.sqrt(np.maximum(np.square(self.eps), (self.total_sumsq / self.total_count) - np.square(self.total_sum / self.total_count)))
36 |
37 | # normalize the observation
38 | def normalize(self, v, clip_range=None):
39 | if clip_range is None:
40 | clip_range = self.default_clip_range
41 | return np.clip((v - self.mean) / self.std, -clip_range, clip_range)
42 |
43 |
44 | class ArrayNormalizer:
45 | def __init__(self, env_params, default_clip_range=np.inf):
46 | self.gripper_norm = normalizer(env_params['gripper'], default_clip_range=default_clip_range)
47 | self.objects_norm = normalizer(env_params['object'], default_clip_range=default_clip_range)
48 | self.goal_norm = normalizer(env_params['goal'], default_clip_range=default_clip_range)
49 |
50 | def update(self, gripper, objects, goal):
51 | self.gripper_norm.update(gripper)
52 | self.objects_norm.update(objects)
53 | self.goal_norm.update(goal)
54 |
55 | def recompute_stats(self):
56 | self.gripper_norm.recompute_stats()
57 | self.objects_norm.recompute_stats()
58 | self.goal_norm.recompute_stats()
59 |
60 | # normalize the observation
61 | def normalize(self, gripper, objects, goal, clip_range=None):
62 | gripper_norm = self.gripper_norm.normalize(gripper, clip_range)
63 | objects_norm = self.objects_norm.normalize(objects, clip_range)
64 | goal_norm = self.goal_norm.normalize(goal, clip_range)
65 | return gripper_norm, objects_norm, goal_norm -------------------------------------------------------------------------------- /rl_modules/replay_buffer.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Based on the HER implementation at https://github.com/TianhongDai/hindsight-experience-replay.
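# Episodes are stored as fixed-size arrays; the observation buffers ('gripper',
# 'obj', 'ag') hold T + 1 steps so next-step views can be sliced out at sample
# time, and 'final_t' records each episode's true length so episodes shorter
# than T are sampled correctly.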
8 |
9 | import numpy as np
10 |
11 | """
12 | The replay buffer here is based on the OpenAI Baselines code.
13 |
14 | """
15 | class replay_buffer:
16 | def __init__(self, env_params, buffer_size, sample_func):
17 | self.env_params = env_params
18 | self.T = env_params['max_timesteps']
19 | self.size = buffer_size // self.T
20 | # memory management
21 | self.current_size = 0
22 | self.n_transitions_stored = 0
23 | self.sample_func = sample_func
24 | # allocate the buffers that store each episode's data
25 | self.buffers = {'obj': np.empty([self.size, self.T + 1, self.env_params['n_objects'], self.env_params['object']]),
26 | 'ag': np.empty([self.size, self.T + 1, self.env_params['n_objects'], self.env_params['goal']]),
27 | 'g': np.empty([self.size, self.T, self.env_params['n_objects'], self.env_params['goal']]),
28 | 'actions': np.empty([self.size, self.T, self.env_params['action']]),
29 | 'gripper': np.empty([self.size, self.T + 1, self.env_params['gripper']]),
30 | 'final_t': np.empty(self.size, dtype=np.int64),
31 | }
32 |
33 | # store the episode
34 | def store_episode(self, episode_batch):
35 | mb_grip, mb_obj, mb_ag, mb_g, mb_actions = episode_batch
36 | T = mb_actions.shape[1]
37 | batch_size = mb_grip.shape[0]
38 | idxs = self._get_storage_idx(inc=batch_size)
39 | # store the episode data
40 | self.buffers['gripper'][idxs, :T+1] = mb_grip
41 | self.buffers['obj'][idxs, :T+1] = mb_obj
42 | self.buffers['ag'][idxs, :T+1] = mb_ag
43 | self.buffers['g'][idxs, :T] = mb_g
44 | self.buffers['actions'][idxs, :T] = mb_actions
45 | self.buffers['final_t'][idxs] = T
46 | self.n_transitions_stored += self.T * batch_size
47 |
48 | # sample the data from the replay buffer
49 | def sample(self, batch_size):
50 | temp_buffers = {}
51 | for key in self.buffers.keys():
52 | temp_buffers[key] = self.buffers[key][:self.current_size]
53 | temp_buffers['gripper_next'] = temp_buffers['gripper'][:, 1:, :]
54 | temp_buffers['obj_next'] = temp_buffers['obj'][:, 1:, :]
55 | temp_buffers['ag_next'] = temp_buffers['ag'][:, 1:, :]
56 | # sample transitions
57 | transitions = self.sample_func(temp_buffers, batch_size)
58 | return transitions
59 |
60 | def _get_storage_idx(self, inc=None):
61 | inc = inc or 1
62 | if self.current_size+inc <= self.size:
63 | idx = np.arange(self.current_size, self.current_size+inc)
64 | elif self.current_size < self.size:
65 | overflow = inc - (self.size - self.current_size)
66 | idx_a = np.arange(self.current_size, self.size)
67 | idx_b = np.random.randint(0, self.current_size, overflow)
68 | idx = np.concatenate([idx_a, idx_b])
69 | else:
70 | idx = np.random.randint(0, self.size, inc)
71 | self.current_size = min(self.size, self.current_size+inc)
72 | if inc == 1:
73 | idx = idx[0]
74 | return idx -------------------------------------------------------------------------------- /scripts/eval.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | python evaluate.py --env_name Fetch1Push-v1 Fetch2Push-v1 Fetch3Push-v1 Fetch4Push-v1 Fetch5Push-v1 Fetch6Push-v1 --run_path ayzhong/fetch-her --run_tag eval
6 | python evaluate.py --env_name Fetch1Switch-v1 Fetch2Switch-v1 Fetch3Switch-v1 Fetch4Switch-v1 Fetch5Switch-v1 Fetch6Switch-v1 --run_path ayzhong/fetch-her --run_tag eval-3s
7 | python evaluate.py --env_name Fetch1Switch1Push-v1 Fetch2Switch2Push-v1 Fetch3Switch3Push-v1 --run_path ayzhong/fetch-her --run_tag eval-2s2p
8 | python evaluate.py --env_name Fetch2SwitchOr2PushOnlyCube-v1 Fetch2SwitchOr2PushOnlySwitch-v1 Fetch2Switch2Push-v1 --run_path ayzhong/fetch-her --run_tag eval-2sor2p
9 | python evaluate.py --env_name FetchStack2Stage1-v1 FetchStack2StitchOnlyStack-v1 FetchStack2Stage3-v1 --run_path ayzhong/fetch-her --run_tag eval-2stack
10 |
11 | python evaluate.py --env_name Fetch1PushCollide-v1 Fetch2PushCollide-v1 Fetch3PushCollide-v1 Fetch4PushCollide-v1 Fetch5PushCollide-v1 Fetch6PushCollide-v1 --run_path ayzhong/fetch-her --run_tag eval-3pcolliide
12 | python evaluate.py --env_name Fetch1Switch1PushCollide-v1 Fetch2Switch2PushCollide-v1 Fetch3Switch3PushCollide-v1 --run_path ayzhong/fetch-her --run_tag eval-2s2pcollide
13 | python evaluate.py --env_name Fetch2SwitchOr2PushOnlyCubeCollide-v1 Fetch2SwitchOr2PushOnlySwitchCollide-v1 Fetch2Switch2PushCollide-v1 --run_path ayzhong/fetch-her --run_tag eval-2sor2pcollide -------------------------------------------------------------------------------- /scripts/make_data.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | python collect_demos.py Fetch3Switch-v1 --num_eps 1000
6 | python collect_demos.py Fetch3Switch-v1 --num_eps 2000
7 | python collect_demos.py Fetch3Switch-v1 --num_eps 3000
8 | python collect_demos.py Fetch3Switch-v1 --num_eps 4000
9 | python collect_demos.py Fetch3Switch-v1 --num_eps 5000
10 |
11 | python collect_demos.py Fetch3Push-v1 --num_eps 1000
12 | python collect_demos.py Fetch3Push-v1 --num_eps 2000
13 | python collect_demos.py Fetch3Push-v1 --num_eps 3000
14 | python collect_demos.py Fetch3Push-v1 --num_eps 4000
15 | python collect_demos.py Fetch3Push-v1 --num_eps 5000
16 |
17 | python collect_demos.py Fetch2Switch2Push-v1 --num_eps 1000
18 | python collect_demos.py Fetch2Switch2Push-v1 --num_eps 2000
19 | python collect_demos.py Fetch2Switch2Push-v1 --num_eps 3000
20 | python collect_demos.py Fetch2Switch2Push-v1 --num_eps 4000
21 | python collect_demos.py Fetch2Switch2Push-v1 --num_eps 5000
22 | -------------------------------------------------------------------------------- /scripts/make_vids.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | python render_viz.py --env_name Fetch1Push-v1 Fetch3Push-v1 Fetch6Push-v1 --run_id 3khi7siu 3ly9m71w
6 | python render_viz.py --env_name Fetch1Switch-v1 Fetch3Switch-v1 Fetch6Switch-v1 --run_id 2mb98s47 1sfvs6m4
7 | python render_viz.py --env_name Fetch2Switch2Push-v1 Fetch3Switch3Push-v1 --run_id 2uo2uaqh rvwv2fkn
8 | python render_viz.py --env_name FetchStack2Stage1-v1 FetchStack2StitchOnlyStack-v1 FetchStack2Stage3-v1 --run_id 2pw2m40g pflc1mye -------------------------------------------------------------------------------- /scripts/oracle.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | python collect_demos.py Fetch1Switch-v1 --chain --num_eps 300
6 | python collect_demos.py Fetch2Switch-v1 --chain --num_eps 300
7 | python collect_demos.py Fetch3Switch-v1 --chain --num_eps 300
8 | python collect_demos.py Fetch4Switch-v1 --chain --num_eps 300
9 | python collect_demos.py Fetch5Switch-v1 --chain --num_eps 300
10 | python collect_demos.py Fetch6Switch-v1 --chain --num_eps 300
11 |
12 | python collect_demos.py Fetch1Push-v1 --chain --num_eps 300
13 | python collect_demos.py Fetch2Push-v1 --chain --num_eps 300
14 | python collect_demos.py Fetch3Push-v1 --chain --num_eps 300
15 | python collect_demos.py Fetch4Push-v1 --chain --num_eps 300
16 | python collect_demos.py Fetch5Push-v1 --chain --num_eps 300
17 | python collect_demos.py Fetch6Push-v1 --chain --num_eps 300
18 |
19 | python collect_demos.py Fetch1Switch1Push-v1 --num_eps 300 --chain
20 | python collect_demos.py Fetch2Switch2Push-v1 --num_eps 300 --chain
21 | python collect_demos.py Fetch3Switch3Push-v1 --num_eps 300 --chain
22 |
23 | python collect_demos.py FetchStack2Stage1-v1 --num_eps 300 --chain
24 | python collect_demos.py FetchStack2Stage3-v1 --num_eps 300 --chain
25 | python collect_demos.py FetchStack2StitchOnlyStack-v1 --num_eps 300 --chain
26 |
27 | python collect_demos.py Fetch1PushCollide-v1 --num_eps 300 --chain
28 | python collect_demos.py Fetch2PushCollide-v1 --num_eps 300 --chain
29 | python collect_demos.py Fetch3PushCollide-v1 --num_eps 300 --chain
30 | python collect_demos.py Fetch4PushCollide-v1 --num_eps 300 --chain
31 | python collect_demos.py Fetch5PushCollide-v1 --num_eps 300 --chain
32 | python collect_demos.py Fetch6PushCollide-v1 --num_eps 300 --chain
33 |
34 | python collect_demos.py Fetch1Switch1PushCollide-v1 --num_eps 300 --chain
35 | python collect_demos.py Fetch2Switch2PushCollide-v1 --num_eps 300 --chain
36 | python collect_demos.py Fetch3Switch3PushCollide-v1 --num_eps 300 --chain -------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Based on the HER implementation at https://github.com/TianhongDai/hindsight-experience-replay.
8 |
9 | import os
10 | import pickle
11 | import hydra
12 | import wandb
13 | from omegaconf import OmegaConf
14 | import gym
15 | import mujoco_py
16 | from rl_modules.ddpg_agent import ddpg_agent
17 | from vec_env.subproc_vec_env import SubprocVecEnv
18 | from stable_baselines3.common.vec_env import VecVideoRecorder, vec_normalize
19 | import envs, gym_fetch_stack
20 |
21 |
22 | def get_env_params(env):
23 | obs = env.reset()
24 | # infer the entity array sizes from a sample observation
25 | params = {
26 | 'gripper': obs['gripper_arr'].shape[-1],
27 | 'goal': obs['desired_goal_arr'].shape[-1],
28 | 'action': env.action_space.shape[0],
29 | 'action_max': env.action_space.high[0],
30 | 'object': obs['object_arr'].shape[-1],
31 | 'n_objects': obs['object_arr'].shape[0],
32 | }
33 | params['max_timesteps'] = env._max_episode_steps
34 | return params
35 |
36 |
37 | def launch(args):
38 | # get the environment parameters
39 | test_env = gym.make(args.env_name)
40 | env_params = get_env_params(test_env)
41 | test_env.close()
42 | # create the vectorized training and eval environments
43 | def make_env():
44 | import envs, gym_fetch_stack # needed when using start_method="spawn"
45 | return gym.make(args.env_name)
46 | env = SubprocVecEnv([make_env for i in range(args.num_workers)], start_method="spawn")
47 | if args.norm_reward:
48 | env = vec_normalize.VecNormalize(env, norm_obs=False, norm_reward=True, gamma=args.gamma)
49 | eval_env = SubprocVecEnv([make_env for i in range(args.num_workers)], start_method="spawn", auto_reset=True)
50 | eval_env = VecVideoRecorder(eval_env, "eval_vids", lambda i: i < env_params['max_timesteps'], env_params['max_timesteps'])
51 |
52 | ckpt_data, wid = None, None
53 | ckpt_path = "./checkpoint.pkl"
54 | if os.path.exists(ckpt_path):
55 | with open(ckpt_path, "rb") as f:
56 | print(f"Loading data from {ckpt_path}.")
57 | ckpt_data = pickle.load(f)
58 | wid =
ckpt_data["wandb_run_id"] 59 | # create the ddpg agent to interact with the environment 60 | wandb.init(project='fetch-her', entity='ayzhong', id=wid, resume="allow", dir=hydra.utils.get_original_cwd()) 61 | if wid is None: 62 | wandb.config.update(OmegaConf.to_container(args, resolve=True)) 63 | ddpg_trainer = ddpg_agent(args, env, eval_env, env_params, test_env.compute_reward, ckpt_data=ckpt_data) 64 | ddpg_trainer.learn() 65 | -------------------------------------------------------------------------------- /train_bc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | import time 10 | import hydra 11 | from omegaconf import OmegaConf 12 | import torch 13 | from torch import optim 14 | from torch.optim import lr_scheduler 15 | from torch.utils.data import TensorDataset, DataLoader 16 | import torch.nn.functional as F 17 | import numpy as np 18 | import wandb 19 | import gym 20 | import envs 21 | 22 | from stable_baselines3.common.vec_env import VecVideoRecorder, SubprocVecEnv 23 | from rl_modules.normalizer import ArrayNormalizer 24 | from train import get_env_params 25 | 26 | 27 | def make_env(name): 28 | def helper(): 29 | import envs 30 | return gym.make(name) 31 | return SubprocVecEnv([helper for i in range(16)], start_method="spawn") 32 | 33 | 34 | def load_data(data_path, x_norm): 35 | data = np.load(data_path) 36 | grip, obj, goal, action, success = data["grip"][:, :-1], data["obj"][:, :-1], data["g"], data["action"], data["success"] 37 | succ_grip, succ_obj, succ_goal, succ_action = [], [], [], [] 38 | print("Filtering data.") 39 | for ep_idx in range(grip.shape[0]): 40 | if np.sum(success[ep_idx]) > 0: 41 | first_succ = np.argmax(success[ep_idx]) 42 | succ_grip.append(grip[ep_idx, :first_succ + 1]) 43 | succ_obj.append(obj[ep_idx, :first_succ + 1]) 44 | succ_goal.append(goal[ep_idx, :first_succ + 1]) 45 | succ_action.append(action[ep_idx, :first_succ + 1]) 46 | succ_grip = np.concatenate(succ_grip, 0) 47 | succ_obj = np.concatenate(succ_obj, 0) 48 | succ_goal = np.concatenate(succ_goal, 0) 49 | succ_action = np.concatenate(succ_action, 0) 50 | x_norm.update(succ_grip, succ_obj, succ_goal) 51 | x_norm.recompute_stats() 52 | succ_grip = torch.from_numpy(succ_grip).float() 53 | succ_obj = torch.from_numpy(succ_obj).float() 54 | succ_goal = torch.from_numpy(succ_goal).float() 55 | succ_action = torch.from_numpy(succ_action).float() 56 | print(f"Finished filtering data. 
Total transitions: {succ_grip.shape[0]}.")
57 | return TensorDataset(succ_grip, succ_obj, succ_goal, succ_action)
58 |
59 |
60 | def get_opt_sched(policy, args):
61 | if args.lr_sched == "linear_warmup":
62 | opt = optim.AdamW(policy.parameters(), lr=args.lr, weight_decay=args.weight_decay)
63 | scheduler = lr_scheduler.LambdaLR(opt, lambda t: min((t+1) / args.warmup_steps, 1))
64 | else:
65 | opt = optim.Adam(policy.parameters(), lr=args.lr)
66 | scheduler = None
67 | return opt, scheduler
68 |
69 |
70 | @torch.no_grad()
71 | def evaluate(env_name, policy, x_norm, device, folder):
72 | env = make_env(env_name)
73 | # Not necessarily the same as the training env's max_episode_steps
74 | max_steps = env.get_attr("_max_episode_steps")[0]
75 | env = VecVideoRecorder(env, folder, lambda i: i < max_steps, max_steps, name_prefix=env_name)
76 | def predict_act(observation):
77 | inp = x_norm.normalize(
78 | observation["gripper_arr"],
79 | observation["object_arr"],
80 | observation["desired_goal_arr"]
81 | )
82 | inp = [torch.from_numpy(x).float().to(device) for x in inp]
83 | return policy(*inp).cpu().numpy()
84 | observation = env.reset()
85 | results, step_results = [], []
86 | # roll out vectorized episodes until 100 have finished
87 | while len(results) < 100:
88 | # env.render()
89 | action = predict_act(observation)
90 | # step the environments
91 | observation_new, reward, done, info = env.step(action)
92 | if np.any(done):
93 | for idx in np.nonzero(done)[0]:
94 | results.append(info[idx]['is_success'])
95 | step_results.append(info[idx]['step_success'])
96 | observation = observation_new
97 | metrics = {
98 | "success_rate": np.mean(results),
99 | "step_success_rate": np.mean(step_results),
100 | "vids/0": wandb.Video(env.video_recorder.path, format="mp4"),
101 | }
102 | env.close()
103 | return metrics
104 |
105 |
106 | def train(args):
107 | wandb.init(project="fetch-bc", entity="ayzhong", dir=hydra.utils.get_original_cwd())
108 | wandb.config.update(OmegaConf.to_container(args, resolve=True))
109 |
110 | device = torch.device(args.device)
111 | print(f"Using device {device}.")
112 |
113 | train_env = gym.make(args.env_name)
114 | env_params = get_env_params(train_env)
115 | del train_env
116 | x_norm = ArrayNormalizer(env_params, default_clip_range=args.clip_range)
117 |
118 | model_dir = os.path.join(wandb.run.dir, "models")
119 | os.mkdir(model_dir)
120 |
121 | policy = hydra.utils.instantiate(args.actor, env_params).to(device)
122 | opt, scheduler = get_opt_sched(policy, args)
123 | # The path in the config files is relative to the code, but hydra cwd is different.
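# (Hydra runs each job in its own working directory, so the relative path is
# resolved against the original launch directory via get_original_cwd().)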
124 | data_path = os.path.join(hydra.utils.get_original_cwd(), args.data_path) 125 | dset = load_data(data_path, x_norm) 126 | dloader = DataLoader(dset, batch_size=args.batch_size, shuffle=True) 127 | 128 | best_success_rate = -1 129 | i = 0 130 | while True: 131 | for grip, obj, goal, action in dloader: 132 | inp = x_norm.normalize(grip.numpy(), obj.numpy(), goal.numpy()) 133 | inp = [torch.from_numpy(x).to(device) for x in inp] 134 | act_predict = policy(*inp) 135 | loss = F.mse_loss(act_predict, action.to(device)) 136 | opt.zero_grad() 137 | loss.backward() 138 | opt.step() 139 | if scheduler is not None: 140 | scheduler.step() 141 | if not i % 500: 142 | wandb.log({ 143 | "loss/train": loss.detach().cpu().item(), 144 | "lr": opt.param_groups[0]['lr'], 145 | }, step=i) 146 | if i % 10000 == 0 or i == args.n_steps - 1: 147 | save_data = [x_norm, policy] 148 | torch.save(save_data, os.path.join(model_dir, "latest.pt")) 149 | start_time = time.time() 150 | vid_dir = os.path.join(wandb.run.dir, f"./vids_{i}") 151 | if not os.path.exists(vid_dir): 152 | os.mkdir(vid_dir) 153 | sr_metrics = {} 154 | policy.eval() 155 | metrics = evaluate(args.env_name, policy, x_norm, device, vid_dir) 156 | sr_metrics.update(metrics) 157 | if metrics["success_rate"] >= best_success_rate or i == 0: 158 | best_success_rate = metrics["success_rate"] 159 | torch.save(save_data, os.path.join(model_dir, "best.pt")) 160 | policy.train() 161 | wandb.log(sr_metrics, step=i) 162 | print(f"Eval in {time.time() - start_time} seconds.") 163 | i += 1 164 | if i == args.n_steps: 165 | return -------------------------------------------------------------------------------- /vec_env/LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019 Antonin Raffin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /vec_env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/entity-factored-rl/afbd84f9398dd884537589135b182c751e19c709/vec_env/__init__.py -------------------------------------------------------------------------------- /vec_env/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Based on the vec_env implementations at https://github.com/DLR-RM/stable-baselines3.
8 |
9 |
10 | import os
11 | import multiprocessing
12 | from collections import OrderedDict
13 | from typing import Sequence
14 |
15 | import gym
16 | import mujoco_py
17 | import numpy as np
18 |
19 | from vec_env.base_vec_env import VecEnv, CloudpickleWrapper
20 |
21 |
22 | def _worker(remote, parent_remote, env_fn_wrapper, auto_reset):
23 | parent_remote.close()
24 | with mujoco_py.ignore_mujoco_warnings():
25 | env = env_fn_wrapper.var()
26 | # Ignore "Nan, Inf or huge value in QACC" issues
27 | while True:
28 | try:
29 | cmd, data = remote.recv()
30 | if cmd == 'step':
31 | observation, reward, done, info = env.step(data)
32 | if auto_reset and done:
33 | # save final observation where user can get it, then reset
34 | info["terminal_observation"] = observation
35 | observation = env.reset()
36 | remote.send((observation, reward, done, info))
37 | elif cmd == 'seed':
38 | remote.send(env.seed(data))
39 | elif cmd == 'reset':
40 | observation = env.reset()
41 | remote.send(observation)
42 | elif cmd == 'render':
43 | remote.send(env.render(data))
44 | elif cmd == 'close':
45 | env.close()
46 | remote.close()
47 | break
48 | elif cmd == 'get_spaces':
49 | remote.send((env.observation_space, env.action_space))
50 | elif cmd == 'env_method':
51 | method = getattr(env, data[0])
52 | remote.send(method(*data[1], **data[2]))
53 | elif cmd == 'get_attr':
54 | remote.send(getattr(env, data))
55 | elif cmd == 'set_attr':
56 | remote.send(setattr(env, data[0], data[1]))
57 | else:
58 | raise NotImplementedError("`{}` is not implemented in the worker".format(cmd))
59 | except EOFError:
60 | break
61 |
62 |
63 | class SubprocVecEnv(VecEnv):
64 | """
65 | Creates a multiprocess vectorized wrapper for multiple environments, distributing each environment to its own
66 | process, allowing a significant speed-up when the environment is computationally complex.
67 |
68 | For performance reasons, if your environment is not IO bound, the number of environments should not exceed the
69 | number of logical cores on your CPU.
70 |
71 | .. warning::
72 |
73 | Only 'forkserver' and 'spawn' start methods are thread-safe,
74 | which is important when TensorFlow sessions or other non thread-safe
75 | libraries are used in the parent (see issue #217). However, compared to
76 | 'fork' they incur a small start-up cost and have restrictions on
77 | global variables. With those methods, users must wrap the code in an
78 | ``if __name__ == "__main__":`` block.
79 | For more information, see the multiprocessing documentation.
80 |
81 | :param env_fns: ([callable]) A list of functions that will create the environments
82 | (each callable returns a `Gym.Env` instance when called).
83 | :param start_method: (str) method used to start the subprocesses.
84 | Must be one of the methods returned by multiprocessing.get_all_start_methods().
85 | Defaults to 'forkserver' on available platforms, and 'spawn' otherwise.
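Example (a minimal sketch following the pattern in train.py; with 'spawn' or
'forkserver', modules that register the environments must be imported inside
the thunk, and construction must live under the main guard):

    def make_env():
        import envs  # registers the Fetch environments
        return gym.make("Fetch3Push-v1")

    if __name__ == "__main__":
        venv = SubprocVecEnv([make_env for _ in range(4)], start_method="spawn")
        obs = venv.reset()
        venv.close()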
86 | """ 87 | 88 | def __init__(self, env_fns, start_method=None, auto_reset=False): 89 | self.waiting = False 90 | self.closed = False 91 | n_envs = len(env_fns) 92 | 93 | # In some cases (like on GitHub workflow machine when running tests), 94 | # "forkserver" method results in an "connection error" (probably due to mpi) 95 | # We allow to bypass the default start method if an environment variable 96 | # is specified by the user 97 | if start_method is None: 98 | start_method = os.environ.get("DEFAULT_START_METHOD") 99 | 100 | # No DEFAULT_START_METHOD was specified, start_method may still be None 101 | if start_method is None: 102 | # Fork is not a thread safe method (see issue #217) 103 | # but is more user friendly (does not require to wrap the code in 104 | # a `if __name__ == "__main__":`) 105 | forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods() 106 | start_method = 'forkserver' if forkserver_available else 'spawn' 107 | ctx = multiprocessing.get_context(start_method) 108 | 109 | self.remotes, self.work_remotes = zip(*[ctx.Pipe(duplex=True) for _ in range(n_envs)]) 110 | self.processes = [] 111 | for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns): 112 | args = (work_remote, remote, CloudpickleWrapper(env_fn), auto_reset) 113 | # daemon=True: if the main process crashes, we should not cause things to hang 114 | process = ctx.Process(target=_worker, args=args, daemon=True) # pytype:disable=attribute-error 115 | process.start() 116 | self.processes.append(process) 117 | work_remote.close() 118 | 119 | self.remotes[0].send(('get_spaces', None)) 120 | observation_space, action_space = self.remotes[0].recv() 121 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 122 | 123 | def step_async(self, actions): 124 | for remote, action in zip(self.remotes, actions): 125 | remote.send(('step', action)) 126 | self.waiting = True 127 | 128 | def step_wait(self): 129 | results = [remote.recv() for remote in self.remotes] 130 | self.waiting = False 131 | obs, rews, dones, infos = zip(*results) 132 | return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos 133 | 134 | def seed(self, seed=None): 135 | for idx, remote in enumerate(self.remotes): 136 | remote.send(('seed', seed + idx)) 137 | return [remote.recv() for remote in self.remotes] 138 | 139 | def reset(self): 140 | for remote in self.remotes: 141 | remote.send(('reset', None)) 142 | obs = [remote.recv() for remote in self.remotes] 143 | return _flatten_obs(obs, self.observation_space) 144 | 145 | def close(self): 146 | if self.closed: 147 | return 148 | if self.waiting: 149 | for remote in self.remotes: 150 | remote.recv() 151 | for remote in self.remotes: 152 | remote.send(('close', None)) 153 | for process in self.processes: 154 | process.join() 155 | self.closed = True 156 | 157 | def get_images(self) -> Sequence[np.ndarray]: 158 | for pipe in self.remotes: 159 | # gather images from subprocesses 160 | # `mode` will be taken into account later 161 | pipe.send(('render', 'rgb_array')) 162 | imgs = [pipe.recv() for pipe in self.remotes] 163 | return imgs 164 | 165 | def get_attr(self, attr_name, indices=None): 166 | """Return attribute from vectorized environment (see base class).""" 167 | target_remotes = self._get_target_remotes(indices) 168 | for remote in target_remotes: 169 | remote.send(('get_attr', attr_name)) 170 | return [remote.recv() for remote in target_remotes] 171 | 172 | def set_attr(self, attr_name, value, 
172 | def set_attr(self, attr_name, value, indices=None):
173 | """Set attribute inside vectorized environments (see base class)."""
174 | target_remotes = self._get_target_remotes(indices)
175 | for remote in target_remotes:
176 | remote.send(('set_attr', (attr_name, value)))
177 | for remote in target_remotes:
178 | remote.recv()
179 |
180 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
181 | """Call instance methods of vectorized environments."""
182 | target_remotes = self._get_target_remotes(indices)
183 | for remote in target_remotes:
184 | remote.send(('env_method', (method_name, method_args, method_kwargs)))
185 | return [remote.recv() for remote in target_remotes]
186 |
187 | def _get_target_remotes(self, indices):
188 | """
189 | Get the connection objects needed to communicate with the target
190 | envs, which live in subprocesses.
191 |
192 | :param indices: (None, int, or Iterable) indices of the target envs.
193 | :return: ([multiprocessing.Connection]) Connection objects used to communicate with those processes.
194 | """
195 | indices = self._get_indices(indices)
196 | return [self.remotes[i] for i in indices]
197 |
198 |
199 | def _flatten_obs(obs, space):
200 | """
201 | Flatten observations, depending on the observation space.
202 |
203 | :param obs: (list<X> or tuple<X> where X is dict, tuple or ndarray) observations.
204 | A list or tuple of observations, one per environment.
205 | Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
206 | :return: (OrderedDict, tuple or ndarray) flattened observations.
207 | A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
208 | Each NumPy array has the environment index as its first axis.
209 | """
210 | assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
211 | assert len(obs) > 0, "need observations from at least one environment"
212 |
213 | if isinstance(space, gym.spaces.Dict):
214 | assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
215 | assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
216 | return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
217 | elif isinstance(space, gym.spaces.Tuple):
218 | assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
219 | obs_len = len(space.spaces)
220 | return tuple((np.stack([o[i] for o in obs]) for i in range(obs_len)))
221 | else:
222 | return np.stack(obs)
223 | --------------------------------------------------------------------------------
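A note on the goal-relabeling ratio used in her_modules/her.py: with replay_strategy='future', each sampled transition has its goal replaced by a future achieved goal with probability future_p = 1 - 1/(1 + replay_k), so replay_k = 4 gives future_p = 0.8, i.e. four relabeled transitions for every original one in expectation. A minimal standalone sketch of that relabeling step (toy shapes and fixed-length episodes are assumptions for illustration; the repo's sampler additionally handles variable episode lengths via final_t):

import numpy as np

# Toy HER 'future' relabeling: for each sampled (episode, t) pair, with
# probability future_p replace the desired goal with an achieved goal drawn
# uniformly from a strictly later timestep of the same episode.
replay_k = 4
future_p = 1 - 1.0 / (1 + replay_k)  # = 0.8

rng = np.random.default_rng(0)
n_episodes, T, goal_dim, batch = 10, 50, 3, 256
ag = rng.normal(size=(n_episodes, T + 1, goal_dim))  # achieved goal after each step
g = rng.normal(size=(n_episodes, T, goal_dim))       # desired goal at each step

ep = rng.integers(0, n_episodes, size=batch)
t = rng.integers(0, T, size=batch)
goals = g[ep, t].copy()

relabel = rng.uniform(size=batch) < future_p
offset = (rng.uniform(size=batch) * (T - t)).astype(int)
future_t = (t + 1 + offset)[relabel]                 # in (t, T], a valid index into ag
goals[relabel] = ag[ep[relabel], future_t]
print(f"relabeled fraction: {relabel.mean():.2f} (expected ~{future_p})")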