├── .gitignore ├── Download.md ├── ENV.md ├── Eval.md ├── Experiments.md ├── README.md ├── Train.md ├── aloha ├── arp.py ├── compute_waypoints.py ├── configs │ ├── arp.yaml │ └── experiments │ │ ├── act.yaml │ │ ├── diffusion_policy.yaml │ │ └── one_step_prediction.yaml ├── demo.ipynb ├── lerobot │ ├── __init__.py │ ├── __version__.py │ ├── common │ │ ├── datasets │ │ │ ├── compute_stats.py │ │ │ ├── factory.py │ │ │ ├── lerobot_dataset.py │ │ │ ├── push_dataset_to_hub │ │ │ │ ├── _aloha_raw_urls │ │ │ │ │ ├── mobile_cabinet.txt │ │ │ │ │ ├── mobile_chair.txt │ │ │ │ │ ├── mobile_elevator.txt │ │ │ │ │ ├── mobile_shrimp.txt │ │ │ │ │ ├── mobile_wash_pan.txt │ │ │ │ │ ├── mobile_wipe_wine.txt │ │ │ │ │ ├── sim_insertion_human.txt │ │ │ │ │ ├── sim_insertion_scripted.txt │ │ │ │ │ ├── sim_transfer_cube_human.txt │ │ │ │ │ ├── sim_transfer_cube_scripted.txt │ │ │ │ │ ├── static_battery.txt │ │ │ │ │ ├── static_candy.txt │ │ │ │ │ ├── static_coffee.txt │ │ │ │ │ ├── static_coffee_new.txt │ │ │ │ │ ├── static_cups_open.txt │ │ │ │ │ ├── static_fork_pick_up.txt │ │ │ │ │ ├── static_pingpong_test.txt │ │ │ │ │ ├── static_pro_pencil.txt │ │ │ │ │ ├── static_screw_driver.txt │ │ │ │ │ ├── static_tape.txt │ │ │ │ │ ├── static_thread_velcro.txt │ │ │ │ │ ├── static_towel.txt │ │ │ │ │ ├── static_vinh_cup.txt │ │ │ │ │ ├── static_vinh_cup_left.txt │ │ │ │ │ └── static_ziploc_slide.txt │ │ │ │ ├── _diffusion_policy_replay_buffer.py │ │ │ │ ├── _download_raw.py │ │ │ │ ├── _umi_imagecodecs_numcodecs.py │ │ │ │ ├── aloha_hdf5_format.py │ │ │ │ ├── cam_png_format.py │ │ │ │ ├── dora_parquet_format.py │ │ │ │ ├── pusht_zarr_format.py │ │ │ │ ├── umi_zarr_format.py │ │ │ │ ├── utils.py │ │ │ │ └── xarm_pkl_format.py │ │ │ ├── sampler.py │ │ │ ├── transforms.py │ │ │ ├── utils.py │ │ │ └── video_utils.py │ │ ├── envs │ │ │ ├── factory.py │ │ │ └── utils.py │ │ ├── logger.py │ │ ├── policies │ │ │ ├── act │ │ │ │ ├── configuration_act.py │ │ │ │ └── modeling_act.py │ │ │ ├── 
aloha_diffusion_policy │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration.py │ │ │ │ ├── modeling.py │ │ │ │ └── tinydiffp.py │ │ │ ├── autoregressive_policy │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration.py │ │ │ │ ├── modeling.py │ │ │ │ └── one_step │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── configuration.py │ │ │ │ │ └── modeling.py │ │ │ ├── diffusion │ │ │ │ ├── configuration_diffusion.py │ │ │ │ └── modeling_diffusion.py │ │ │ ├── factory.py │ │ │ ├── normalize.py │ │ │ ├── policy_protocol.py │ │ │ ├── tdmpc │ │ │ │ ├── configuration_tdmpc.py │ │ │ │ └── modeling_tdmpc.py │ │ │ ├── utils.py │ │ │ └── vqbet │ │ │ │ ├── configuration_vqbet.py │ │ │ │ ├── modeling_vqbet.py │ │ │ │ └── vqbet_utils.py │ │ └── utils │ │ │ ├── benchmark.py │ │ │ ├── import_utils.py │ │ │ ├── io_utils.py │ │ │ ├── nn.py │ │ │ └── utils.py │ ├── configs │ │ ├── default.yaml │ │ ├── env │ │ │ ├── aloha.yaml │ │ │ ├── dora_aloha_real.yaml │ │ │ ├── pusht.yaml │ │ │ └── xarm.yaml │ │ └── policy │ │ │ ├── act.single.yaml │ │ │ ├── act.yaml │ │ │ ├── act_real.yaml │ │ │ ├── act_real_no_state.yaml │ │ │ ├── diffusion.yaml │ │ │ ├── diffusion_pusht_keypoints.yaml │ │ │ ├── tdmpc.yaml │ │ │ └── vqbet.yaml │ └── scripts │ │ ├── display_sys_info.py │ │ ├── eval.py │ │ ├── push_dataset_to_hub.py │ │ ├── train.py │ │ ├── visualize_dataset.py │ │ └── visualize_image_transforms.py └── train.py ├── arp.py ├── assets ├── demo.mp4 ├── keydiff.jpg ├── main-fig.jpg ├── result.jpg └── why-chunking-autoregression-works.png ├── environment.yaml ├── profile.ipynb ├── pusht ├── 000.jpg ├── arp.py ├── configs │ ├── arp.yaml │ ├── arp_flat.yaml │ ├── arp_flat_discrete.yaml │ └── experiments │ │ ├── chunk_size │ │ ├── high_level_plan_c1.yaml │ │ ├── high_level_plan_c2.yaml │ │ ├── high_level_plan_c3.yaml │ │ ├── high_level_plan_c4.yaml │ │ ├── low_level_action_c1.yaml │ │ ├── low_level_action_c2.yaml │ │ ├── low_level_action_c4.yaml │ │ └── low_level_action_c8.yaml │ │ ├── 
image_pusht_diffusion_policy_trans.yaml │ │ ├── one_step_prediction.yaml │ │ └── pusht_act.yaml ├── demo.ipynb ├── diffusion_policy │ ├── codecs │ │ └── imagecodecs_numcodecs.py │ ├── common │ │ ├── act.py │ │ ├── checkpoint_util.py │ │ ├── cv2_util.py │ │ ├── env_util.py │ │ ├── json_logger.py │ │ ├── nested_dict_util.py │ │ ├── normalize_util.py │ │ ├── pose_trajectory_interpolator.py │ │ ├── precise_sleep.py │ │ ├── pymunk_override.py │ │ ├── pymunk_util.py │ │ ├── pytorch_util.py │ │ ├── replay_buffer.py │ │ ├── robomimic_config_util.py │ │ ├── robomimic_util.py │ │ ├── sampler.py │ │ └── timestamp_accumulator.py │ ├── config │ │ ├── task │ │ │ ├── blockpush_lowdim_seed.yaml │ │ │ ├── blockpush_lowdim_seed_abs.yaml │ │ │ ├── can_image.yaml │ │ │ ├── can_image_abs.yaml │ │ │ ├── can_lowdim.yaml │ │ │ ├── can_lowdim_abs.yaml │ │ │ ├── kitchen_lowdim.yaml │ │ │ ├── kitchen_lowdim_abs.yaml │ │ │ ├── lift_image.yaml │ │ │ ├── lift_image_abs.yaml │ │ │ ├── lift_lowdim.yaml │ │ │ ├── lift_lowdim_abs.yaml │ │ │ ├── pusht_image.yaml │ │ │ ├── pusht_lowdim.yaml │ │ │ ├── real_pusht_image.yaml │ │ │ ├── square_image.yaml │ │ │ ├── square_image_abs.yaml │ │ │ ├── square_lowdim.yaml │ │ │ ├── square_lowdim_abs.yaml │ │ │ ├── tool_hang_image.yaml │ │ │ ├── tool_hang_image_abs.yaml │ │ │ ├── tool_hang_lowdim.yaml │ │ │ ├── tool_hang_lowdim_abs.yaml │ │ │ ├── transport_image.yaml │ │ │ ├── transport_image_abs.yaml │ │ │ ├── transport_lowdim.yaml │ │ │ └── transport_lowdim_abs.yaml │ │ ├── train_bet_lowdim_workspace.yaml │ │ ├── train_diffusion_transformer_hybrid_workspace.single.yaml │ │ ├── train_diffusion_transformer_hybrid_workspace.yaml │ │ ├── train_diffusion_transformer_lowdim_kitchen_workspace.yaml │ │ ├── train_diffusion_transformer_lowdim_pusht_workspace.yaml │ │ ├── train_diffusion_transformer_lowdim_workspace.yaml │ │ ├── train_diffusion_transformer_real_hybrid_workspace.yaml │ │ ├── train_diffusion_unet_ddim_hybrid_workspace.yaml │ │ ├── 
train_diffusion_unet_ddim_lowdim_workspace.yaml │ │ ├── train_diffusion_unet_hybrid_workspace.yaml │ │ ├── train_diffusion_unet_image_pretrained_workspace.yaml │ │ ├── train_diffusion_unet_image_workspace.yaml │ │ ├── train_diffusion_unet_lowdim_workspace.yaml │ │ ├── train_diffusion_unet_real_hybrid_workspace.yaml │ │ ├── train_diffusion_unet_real_image_workspace.yaml │ │ ├── train_diffusion_unet_real_pretrained_workspace.yaml │ │ ├── train_diffusion_unet_video_workspace.yaml │ │ ├── train_ibc_dfo_hybrid_workspace.yaml │ │ ├── train_ibc_dfo_lowdim_workspace.yaml │ │ ├── train_ibc_dfo_real_hybrid_workspace.yaml │ │ ├── train_robomimic_image_workspace.yaml │ │ ├── train_robomimic_lowdim_workspace.yaml │ │ └── train_robomimic_real_image_workspace.yaml │ ├── dataset │ │ ├── base_dataset.py │ │ ├── blockpush_lowdim_dataset.py │ │ ├── kitchen_lowdim_dataset.py │ │ ├── kitchen_mjl_lowdim_dataset.py │ │ ├── pusht_dataset.py │ │ ├── pusht_image_dataset.py │ │ ├── real_pusht_image_dataset.py │ │ ├── robomimic_replay_image_dataset.py │ │ └── robomimic_replay_lowdim_dataset.py │ ├── env │ │ ├── pusht │ │ │ ├── __init__.py │ │ │ ├── pusht_env.py │ │ │ ├── pusht_image_env.py │ │ │ ├── pusht_keypoints_env.py │ │ │ ├── pymunk_keypoint_manager.py │ │ │ └── pymunk_override.py │ │ └── robomimic │ │ │ ├── robomimic_image_wrapper.py │ │ │ └── robomimic_lowdim_wrapper.py │ ├── env_runner │ │ ├── base_image_runner.py │ │ ├── base_lowdim_runner.py │ │ ├── blockpush_lowdim_runner.py │ │ ├── kitchen_lowdim_runner.py │ │ ├── pusht_image_runner.py │ │ ├── pusht_keypoints_runner.py │ │ ├── real_pusht_image_runner.py │ │ ├── robomimic_image_runner.py │ │ └── robomimic_lowdim_runner.py │ ├── gym_util │ │ ├── async_vector_env.py │ │ ├── multistep_wrapper.py │ │ ├── sync_vector_env.py │ │ ├── video_recording_wrapper.py │ │ └── video_wrapper.py │ ├── model │ │ ├── bet │ │ │ ├── action_ae │ │ │ │ ├── __init__.py │ │ │ │ └── discretizers │ │ │ │ │ └── k_means.py │ │ │ ├── latent_generators │ │ │ │ 
├── latent_generator.py │ │ │ │ ├── mingpt.py │ │ │ │ └── transformer.py │ │ │ ├── libraries │ │ │ │ ├── loss_fn.py │ │ │ │ └── mingpt │ │ │ │ │ ├── LICENSE │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── model.py │ │ │ │ │ ├── trainer.py │ │ │ │ │ └── utils.py │ │ │ └── utils.py │ │ ├── common │ │ │ ├── dict_of_tensor_mixin.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_attr_mixin.py │ │ │ ├── normalizer.py │ │ │ ├── rotation_transformer.py │ │ │ ├── shape_util.py │ │ │ └── tensor_util.py │ │ ├── diffusion │ │ │ ├── conditional_unet1d.py │ │ │ ├── conv1d_components.py │ │ │ ├── ema_model.py │ │ │ ├── mask_generator.py │ │ │ ├── positional_embedding.py │ │ │ └── transformer_for_diffusion.py │ │ └── vision │ │ │ ├── crop_randomizer.py │ │ │ ├── model_getter.py │ │ │ └── multi_image_obs_encoder.py │ ├── policy │ │ ├── autoregressive_policy.py │ │ ├── autoregressive_policy_flat.py │ │ ├── autoregressive_policy_flat_discrete.py │ │ ├── base_image_policy.py │ │ ├── base_lowdim_policy.py │ │ ├── bet_lowdim_policy.py │ │ ├── diffusion_transformer_hybrid_image_policy.py │ │ ├── diffusion_transformer_lowdim_policy.py │ │ ├── diffusion_unet_hybrid_image_policy.py │ │ ├── diffusion_unet_image_policy.py │ │ ├── diffusion_unet_lowdim_policy.py │ │ ├── diffusion_unet_video_policy.py │ │ ├── ibc_dfo_hybrid_image_policy.py │ │ ├── ibc_dfo_lowdim_policy.py │ │ ├── robomimic_image_policy.py │ │ ├── robomimic_lowdim_policy.py │ │ └── vae_act_policy.py │ ├── real_world │ │ ├── keystroke_counter.py │ │ ├── multi_camera_visualizer.py │ │ ├── multi_realsense.py │ │ ├── real_data_conversion.py │ │ ├── real_env.py │ │ ├── real_inference_util.py │ │ ├── realsense_config │ │ │ ├── 415_high_accuracy_mode.json │ │ │ └── 435_high_accuracy_mode.json │ │ ├── rtde_interpolation_controller.py │ │ ├── single_realsense.py │ │ ├── spacemouse.py │ │ ├── spacemouse_shared_memory.py │ │ └── video_recorder.py │ ├── scripts │ │ ├── bet_blockpush_conversion.py │ │ ├── blockpush_abs_conversion.py │ │ ├── 
episode_lengths.py │ │ ├── generate_bet_blockpush.py │ │ ├── real_dataset_conversion.py │ │ ├── real_pusht_metrics.py │ │ ├── real_pusht_successrate.py │ │ ├── robomimic_dataset_action_comparison.py │ │ └── robomimic_dataset_conversion.py │ ├── shared_memory │ │ ├── shared_memory_queue.py │ │ ├── shared_memory_ring_buffer.py │ │ ├── shared_memory_util.py │ │ └── shared_ndarray.py │ └── workspace │ │ ├── arp_workspace.py │ │ ├── base_workspace.py │ │ ├── train_bet_lowdim_workspace.py │ │ ├── train_diffusion_transformer_hybrid_workspace.py │ │ ├── train_diffusion_transformer_lowdim_workspace.py │ │ ├── train_diffusion_unet_hybrid_workspace.py │ │ ├── train_diffusion_unet_image_workspace.py │ │ ├── train_diffusion_unet_lowdim_workspace.py │ │ ├── train_diffusion_unet_video_workspace.py │ │ ├── train_ibc_dfo_hybrid_workspace.py │ │ ├── train_ibc_dfo_lowdim_workspace.py │ │ ├── train_robomimic_image_workspace.py │ │ └── train_robomimic_lowdim_workspace.py ├── draw_human_trajectory.ipynb ├── pusht_human.py ├── qualitative-visualize.ipynb └── train.py ├── real-robot ├── 2d-waypoints-real-robot.ipynb ├── arp.py ├── config.yaml ├── dataset.py ├── eef_T_wrenchHead.txt ├── eef_sm_T_wrenchHead.txt ├── network.py ├── readme.ipynb ├── train.py ├── utils │ ├── __init__.py │ ├── math3d.py │ ├── object.py │ ├── optim.py │ ├── preprocess.py │ ├── spatial.py │ ├── stat_dict.py │ └── vis.py ├── vit.py └── wrench.py └── rlb ├── __init__.py ├── act_policy.py ├── arp.py ├── autoregressive_policy.py ├── autoregressive_policy_plus.py ├── configs ├── act.yaml ├── arp.yaml ├── arp_plus.yaml ├── rvt1.official.yaml ├── rvt2.official.yaml └── rvt2.yaml ├── dataset.py ├── eval.py ├── preprocess.py ├── rvt.py ├── rvt2.py ├── train.py └── utils ├── __init__.py ├── act.py ├── clip.py ├── dist.py ├── env.py ├── layers.py ├── math3d.py ├── metric.py ├── optim.py ├── rollout.py ├── str.py ├── structure.py └── vis.py /.gitignore: 
-------------------------------------------------------------------------------- 1 | bin 2 | logs 3 | wandb 4 | outputs 5 | data 6 | data_local 7 | .vscode 8 | _wandb 9 | 10 | **/.DS_Store 11 | 12 | fuse.cfg 13 | 14 | *.ai 15 | 16 | # Generation results 17 | results/ 18 | 19 | ray/auth.json 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | pusht/data 143 | aloha/data 144 | rlb/data 145 | 146 | pusht/outputs 147 | aloha/outputs 148 | rlb/outputs 149 | 150 | pusht/weights 151 | aloha/weights 152 | rlb/weights 153 | 154 | rlb/point-renderer 155 | 156 | real-robot/outputs 157 | real-robot/weights 158 | real-robot/data -------------------------------------------------------------------------------- /Eval.md: -------------------------------------------------------------------------------- 1 | Please first set up your dataset and pretrained weights. 2 | 3 | # Push-T 4 | 5 | Check out [pusht/demo.ipynb](pusht/demo.ipynb). It loads and tests the pretrained model, and saves the videos into the `pusht/outputs/demo` folder. The full evaluation is done periodically during training. 6 | 7 | 8 | # ALOHA 9 | 10 | Check out [aloha/demo.ipynb](aloha/demo.ipynb). It loads and tests the pretrained model, and saves the videos into the `aloha/outputs/demo` folder. As with Push-T, the full evaluation is done periodically during training. 11 | 12 | 13 | 14 | # RLBench 15 | 16 | The following command loads the ARP+ model and evaluates it on RLBench.
17 | 18 | ```bash 19 | cd rlb 20 | python3 eval.py config=./configs/arp_plus.yaml model.weights=./weights/arp_plus_model_70000.pth hydra.job.name=eval.arp_plus eval.device=0 output_dir=outputs/eval.arp_plus/`date +"%Y-%m-%d_%H-%M"` 21 | ``` 22 | 23 | Adding `eval.save_video=True` saves rotating videos like the one in [assets/demo.mp4](assets/demo.mp4) to `rlb/outputs/recording`, but it makes the evaluation extremely slow. 24 | 25 | Make sure you run `eval.py` on a machine with a GUI and with the `DISPLAY` environment variable set. Otherwise, see the tip below. 26 | 27 | 28 | ## Tip: Running RLBench on a Headless Server without Sudo 29 | 30 | RLBench evaluation requires a GUI environment and the `DISPLAY` environment variable to be set. 31 | 32 | If you only have access to a remote server, you can use `xvfb` to create a CPU-based virtual display. It is a bit slow but tolerable. If you do not have sudo permission to install software, install it with Anaconda using the command below: 33 | 34 | ```bash 35 | conda install anaconda::xorg-x11-server-xvfb-cos6-x86_64 36 | ``` 37 | 38 | The above command installs an xvfb build tailored for CentOS 6, but it works on other Linux distributions as well. Run xvfb with this command: 39 | 40 | ```bash 41 | $HOME/anaconda/x86_64-conda-linux-gnu/sysroot/usr/bin/Xvfb :99 -screen 0 1024x768x24 +extension GLX +render -noreset 42 | ``` 43 | 44 | Here `$HOME/anaconda` is where my Anaconda is installed. If yours is installed elsewhere (for example `$HOME/anaconda3`), change the path accordingly. Then set the `DISPLAY` environment variable: 45 | 46 | ```bash 47 | export DISPLAY=:99 48 | ``` 49 | 50 | in the shell where you run the evaluation script. 51 | 52 | 53 | If the `xvfb` command complains about missing shared libraries, for example `libcrypto.so.10`, you can download the library from the internet (or copy it from another machine that has it) and put it in your home directory.
Then run xvfb with `LD_PRELOAD` (adjust the Anaconda path to match your installation): 54 | 55 | ```bash 56 | LD_PRELOAD=$HOME/libcrypto.so.10 $HOME/anaconda3/x86_64-conda-linux-gnu/sysroot/usr/bin/Xvfb :99 -screen 0 1024x768x24 +extension GLX +render -noreset 57 | ``` 58 | 59 | If the above does not work for you because of glibc version issues (usually a sign that the server's OS is outdated), the last-resort option that reliably works is to use a [Singularity](https://github.com/sylabs/singularity) container and run the evaluation inside it. I have built a working one; let me know if you need it. -------------------------------------------------------------------------------- /Experiments.md: -------------------------------------------------------------------------------- 1 | Here we document some of the experiments included in the paper. For most experiments, we only list the configuration files without repeating the training command (see [Train.md](Train.md) for detailed commands). 2 | 3 | Our training & eval logs can be found in the `training-logs` folder of the Box share https://rutgers.box.com/s/uzozemx67kje58ycy3lyzf1zgddz8tyq. 4 | 5 | # Baselines 6 | 7 | - Push-T (Diffusion Policy [T]): `pusht/configs/experiments/image_pusht_diffusion_policy_trans.yaml` 8 | 9 | - ALOHA (ACT): `aloha/configs/experiments/act.yaml` 10 | > Note that the original ACT paper mentions 7 decoding layers, but its implementation only uses 1 due to a code issue. We set the number of decoding layers to 4. 11 | 12 | - RLBench (RVT): `rlb/configs/rvt2.yaml`. 13 | 14 | The pretrained RVT2 model can be found at `weights/rlb/rvt/rvt2_without_time_model_70000.pth` (in Box). You can run the evaluation from the `rlb` directory with the following command (download the model first): 15 | 16 | ```bash 17 | python3 eval.py config=./configs/rvt2.yaml model.weights=./weights/rvt/rvt2_without_time_model_70000.pth 18 | ``` 19 | 20 | Note that you can also run the official RVT/RVT2 models. These models are in the same folder `weights/rlb/rvt` (in Box). Their config files are `rlb/configs/rvt1.official.yaml` and `rlb/configs/rvt2.official.yaml`.
 21 | 22 | ```bash 23 | # rvt 24 | python3 eval.py config=./configs/rvt1.official.yaml model.weights=./weights/rvt/rvt1_official_model_14.pth 25 | 26 | # rvt2 27 | python3 eval.py config=./configs/rvt2.official.yaml model.weights=./weights/rvt/rvt2_official_model_99.pth 28 | ``` 29 | 30 | 31 | They should reproduce the results reported in their papers. For changing the output folder, setting up headless rendering, and more detailed commands, see [Eval.md](Eval.md). 32 | 33 | 34 | # One-step Prediction 35 | 36 | - Push-T: `pusht/configs/experiments/one_step_prediction.yaml` 37 | 38 | - ALOHA: `aloha/configs/experiments/one_step_prediction.yaml` 39 | 40 | # Chunk Sizes 41 | 42 | - Push-T: 43 | - `pusht/configs/experiments/chunk_size/high_level_plan_c{1,2,3,4}.yaml` 44 | - `pusht/configs/experiments/chunk_size/low_level_action_c{1,2,4,8}.yaml` 45 | 46 | 47 | # Existing Methods on Different Environments 48 | 49 | - Diffusion Policy on ALOHA (does not work well): `aloha/configs/experiments/diffusion_policy.yaml` 50 | 51 | - ACT on Push-T: `pusht/configs/experiments/pusht_act.yaml` 52 | 53 | - ACT on RLBench: `rlb/configs/act.yaml` 54 | -------------------------------------------------------------------------------- /Train.md: -------------------------------------------------------------------------------- 1 | Please first set up your dataset and pretrained weights. Our logs are stored in the `training-logs/main-results/` folder (in Box). 2 | 3 | > Feel free to organize the following snippets into a bash script.
 4 | 5 | # Push-T 6 | 7 | Set the CUDA device: 8 | 9 | ```bash 10 | export CUDA_VISIBLE_DEVICES=0 11 | ``` 12 | 13 | Set environment variables for MuJoCo server rendering: 14 | 15 | ```bash 16 | export PYOPENGL_PLATFORM=egl 17 | export MUJOCO_GL=egl 18 | export EGL_DEVICE_ID=$CUDA_VISIBLE_DEVICES 19 | export MUJOCO_EGL_DEVICE_ID=$CUDA_VISIBLE_DEVICES 20 | ``` 21 | 22 | Then start training with: 23 | 24 | ```bash 25 | cd ./pusht 26 | 27 | timestamp=`date +"%y-%m-%d_%H_%M_%S"` 28 | python3 ./train.py --config-dir ./configs --config-name arp.yaml hydra.run.dir=outputs/arp/${timestamp} \ 29 | training.device=cuda:0 logging.mode=offline logging.name="arp@${timestamp}" name=arp 30 | ``` 31 | 32 | If you have wandb configured, set `logging.mode=online`. 33 | 34 | 35 | 36 | # ALOHA 37 | 38 | Set the CUDA device: 39 | 40 | ```bash 41 | export CUDA_VISIBLE_DEVICES=0 42 | ``` 43 | 44 | Set environment variables for MuJoCo server rendering: 45 | 46 | ```bash 47 | export PYOPENGL_PLATFORM=egl 48 | export MUJOCO_GL=egl 49 | export EGL_DEVICE_ID=$CUDA_VISIBLE_DEVICES 50 | export MUJOCO_EGL_DEVICE_ID=$CUDA_VISIBLE_DEVICES 51 | ``` 52 | 53 | Then, select the task: 54 | 55 | ```bash 56 | export task=insertion # or set to `transfer_cube` 57 | 58 | if [[ $task == insertion ]]; then 59 | export aloha_env=AlohaInsertion 60 | elif [[ $task == transfer_cube ]]; then 61 | export aloha_env=AlohaTransferCube 62 | fi 63 | echo "ALOHA ENV: $aloha_env" 64 | ``` 65 | 66 | 67 | Next, start training with: 68 | 69 | ```bash 70 | cd ./aloha 71 | timestamp=`date +"%y-%m-%d_%H_%M_%S"` 72 | run_dir=outputs/${task}-arp/${timestamp} 73 | mkdir -p ${run_dir} 74 | 75 | python3 train.py --config-dir ./configs --config-name arp device=cuda:0 \ 76 | hydra.run.dir="${run_dir}" hydra.job.name="${task}-arp@${timestamp}" \ 77 | env.task=${aloha_env}-v0 dataset_repo_id=lerobot/aloha_sim_${task}_human seed=$RANDOM 78 | ``` 79 | 80 | 81 | 82 | # RLBench 83 | 84 | Suppose you have 2 GPUs and want a batch size of 96 on
 each GPU. From the `rlb` directory, run: 85 | 86 | ```bash 87 | python3 train.py config=./configs/arp_plus.yaml hydra.job.name=arp_plus train.num_gpus=2 train.bs=96 88 | ``` 89 | 90 | If you have more GPUs, feel free to change the number of GPUs and the batch size accordingly. 91 | 92 | -------------------------------------------------------------------------------- /aloha/arp.py: -------------------------------------------------------------------------------- 1 | ../arp.py -------------------------------------------------------------------------------- /aloha/lerobot/__version__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | """To enable `lerobot.__version__`""" 17 | 18 | from importlib.metadata import PackageNotFoundError, version 19 | 20 | try: 21 | __version__ = version("lerobot") 22 | except PackageNotFoundError: 23 | __version__ = "unknown" 24 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_elevator.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/12ctkOAdkCNGN1JLbZb5ww3XTBn2LFpGI/view?usp=drive_link 2 | https://drive.google.com/file/d/1G_Vd46_4fq6O64gHHjUbJX5Ld44ZZx0y/view?usp=drive_link 3 | https://drive.google.com/file/d/1uKgUy73B3xBogQAOUhfZjO0X5qZGsi2c/view?usp=drive_link 4 | https://drive.google.com/file/d/1fu9cIrfI-fE2LhdGUxbx7-8Ci_PF8Ypm/view?usp=drive_link 5 | https://drive.google.com/file/d/1Ygk9ZPJzx8xw2A9JF3NHbJ44TqnvSTQR/view?usp=drive_link 6 | https://drive.google.com/file/d/18m5xPuccNsEB20WPshm3zhxmXc6k63ED/view?usp=drive_link 7 | https://drive.google.com/file/d/1DiqqxC44rriviRQpqogcv0-EB-Y6nr9g/view?usp=drive_link 8 | https://drive.google.com/file/d/1qPdaoTVDizJXkfXLioWU7iJ8hqCXSyOQ/view?usp=drive_link 9 | https://drive.google.com/file/d/1Fj9kIA_mG7f67WFfACJEaZ7izcHG7vUm/view?usp=drive_link 10 | https://drive.google.com/file/d/1WpYehZnI2P7dUdJPfkE-ij1rqCnjZEbB/view?usp=drive_link 11 | https://drive.google.com/file/d/1_zwWkT4jPyzB38STWb6whlzsPzXmfA9r/view?usp=drive_link 12 | https://drive.google.com/file/d/1U6-J4I_fPlSFFGfhZPxS5_YzKXwXIZYp/view?usp=drive_link 13 | https://drive.google.com/file/d/1pRhxxcTfZp5tQo_EScvJUwfc3amiS6Vk/view?usp=drive_link 14 | https://drive.google.com/file/d/1lWLntqra83RlYU_gN7Vostnfydf6gutd/view?usp=drive_link 15 | https://drive.google.com/file/d/1vIBKo0x-NYEHV1FvRpco1lQMpRdAWAIL/view?usp=drive_link 16 | https://drive.google.com/file/d/1pdrLV3JTQou_XH0Aap61Ssf60iVKm1jJ/view?usp=drive_link 17 | 
https://drive.google.com/file/d/1QTsLoQ7SwmKdQHjBGVDaR2uTwfFwtrOf/view?usp=drive_link 18 | https://drive.google.com/file/d/1Gytai8M_12J36GY6L_TulEcOC-035jwS/view?usp=drive_link 19 | https://drive.google.com/file/d/14LJudNc629NT-i8xreXtzl27ce_DxOFJ/view?usp=drive_link 20 | https://drive.google.com/file/d/1sBvPCODbzxGAI0S3lgN5cSG9Go3lRi00/view?usp=drive_link 21 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_shrimp.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/file/d/1MJn9GbC8p9lN4gC9KDMLEkTkP_gGpXj0/view?usp=drive_link 2 | https://drive.google.com/file/d/1-4LXgjl7ZCOgp-8GCJmFRD8OeqN5Jf7-/view?usp=drive_link 3 | https://drive.google.com/file/d/1Ho06Ce0SPbqU3juaMxNUwAt3zCRLGC8W/view?usp=drive_link 4 | https://drive.google.com/file/d/1ivHoj7_7olBSxH-Y8kqXEW7ttITK-45j/view?usp=drive_link 5 | https://drive.google.com/file/d/1qjY4hM_IvZ8cq2II_n9MeJbvyeuN4oBP/view?usp=drive_link 6 | https://drive.google.com/file/d/1rKVhO_f92-7sw13T8hTVrza3B9oAVgoy/view?usp=drive_link 7 | https://drive.google.com/file/d/1pcLPHO8fBkc1-CRa88tyQtEueE4xiXNi/view?usp=drive_link 8 | https://drive.google.com/file/d/1Vev_chCsIeEdvQ8poEYNsOJFGy_QU8kZ/view?usp=drive_link 9 | https://drive.google.com/file/d/1l5G4zpRkxSLCQjvGPYSN4zfCvVRQuzMz/view?usp=drive_link 10 | https://drive.google.com/file/d/14vgthE1eoakXkr2-DRw50E6lAqYOiUuE/view?usp=drive_link 11 | https://drive.google.com/file/d/17nPSmKKmgQ2B7zkzWrZYiLM3RBuFod82/view?usp=drive_link 12 | https://drive.google.com/file/d/1QcDsxplVvb_ID9BVrihl5FvlC-j7waXi/view?usp=drive_link 13 | https://drive.google.com/file/d/18pEejBpI-eEVaWAAjBCyC0vgbX3T1Esj/view?usp=drive_link 14 | https://drive.google.com/file/d/1H8eH6_IRODtEFT6WoM77ltR5OoOrqXmI/view?usp=drive_link 15 | https://drive.google.com/file/d/1IWlpFRZhoxyG4nS13CWK4leZVk5wbNx4/view?usp=drive_link 16 | 
https://drive.google.com/file/d/1PbZA8_OCGmMLxNP9xbkLRSChniL4uGxl/view?usp=drive_link 17 | https://drive.google.com/file/d/1p9XAdmG2f_WeflNO4DIJ_tr1rK6M9B4B/view?usp=drive_link 18 | https://drive.google.com/file/d/1nS59Et1cNAvKo3Y4SeSGRuZD5TvBbCF3/view?usp=drive_link 19 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wash_pan.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1S8eFg98IaGAIKVZ8QFWG1bx4mHa-O204 2 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/mobile_wipe_wine.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1tC_g1AJ8lglBLY-fjsQrG6DMBa3Ucp-0 2 | https://drive.google.com/file/d/1fG_Yi2MJrFjiUVN3XoiWXLtTxHlwwaDv/view?usp=drive_link 3 | https://drive.google.com/file/d/1WX32VWfzzX3Blmd06DRxLwFbMJfVe7P4/view?usp=drive_link 4 | https://drive.google.com/file/d/18onsX3vXg3xkFwP5bVUCjdV4n9TRn0C9/view?usp=drive_link 5 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_human.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF 2 | https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link 3 | https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link 4 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_insertion_scripted.txt: -------------------------------------------------------------------------------- 1 | 
https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N 2 | https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link 3 | https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link 4 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_human.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo 2 | https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link 3 | https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link 4 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/sim_transfer_cube_scripted.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj 2 | https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link 3 | https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link 4 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_battery.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/19qS_n7vKgDcPeTMnvDHQ5-n73xEbJz5D 2 | https://drive.google.com/file/d/1oC31By0A2bsBeHyUwBdQw1z4ng6yi9Za/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_candy.txt: -------------------------------------------------------------------------------- 1 | 
https://drive.google.com/drive/folders/1m5rQ6UVH8Q9RQp_6c0CxkQ88-L-ScO7q 2 | https://drive.google.com/file/d/1wHz2qcmwcVG0C0CZ9MjQDQcmj4OY9_a3/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1seQGay470nGQ-knBI5TjsTr8iL9Qws5q 2 | https://drive.google.com/file/d/1T89hSX5U99wLGvGTE7yUBaQPOpyj6Sai/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_coffee_new.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1t3eDc5Rg0DveyRe8oTm6Dia_FYU5mXyf 2 | https://drive.google.com/file/d/1TXFaduTakvS0ZWJqKCX-HIvYglum_5CY/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_cups_open.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1Z9X3DNzd6LS0FFjQemNUMoMA5yk5VQOh 2 | https://drive.google.com/file/d/1Wlyc0vTkjXuWB6zbaVOWhEfD7BmPgUV_/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pingpong_test.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1Ut2cv6o6Pkfgg46DgwVUM7Z5PkNG8eJ- 2 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_pro_pencil.txt: 
-------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1FqxPV0PgvgIu8XFjtvZSPSExuNcxVVAY 2 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_screw_driver.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1SKtG0ct9q0nVdYssJNMWSOjikcXliT58 2 | https://drive.google.com/file/d/1nchD21O30B3i3LDoqramo1zgW5YvpJIN/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_tape.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ 2 | https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_thread_velcro.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1_4DHf2cma0xsChLQFghwigX6Ukti5-zQ 2 | https://drive.google.com/file/d/1_8vS4hDNDgUQY-SmekrNaa7dF67QJYU-/view?usp=drive_link 3 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_towel.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1fAD7vkyTGTFB_nGXIKofCU1U05oE3MFv 2 | https://drive.google.com/file/d/1XzyQ2B6LLvcurIonOpEu4nij2qwNWshH/view?usp=drive_link 3 | -------------------------------------------------------------------------------- 
/aloha/lerobot/common/datasets/push_dataset_to_hub/_aloha_raw_urls/static_ziploc_slide.txt: -------------------------------------------------------------------------------- 1 | https://drive.google.com/drive/folders/1EgKar7rWBmTIRmeJYZciSwjZx3uP2mHO 2 | https://drive.google.com/file/d/12eYWQO15atK2hBjXhynPJd9MKAj_42pz/view?usp=drive_link 3 | https://drive.google.com/file/d/1Ul4oEeICJDjgfYTl4H1uaisTzVYIM6wd/view?usp=drive_link 4 | https://drive.google.com/file/d/1WSF-OG8lKSe2wVYCv5D1aJNipxpgddk-/view?usp=drive_link 5 | https://drive.google.com/file/d/1_ppD5j5sFh26aWW0JmhLzJMeNB-lCArk/view?usp=drive_link 6 | https://drive.google.com/file/d/1WUp846dgWXYhu4oJfhHxiU6YL_7N6s4W/view?usp=drive_link 7 | https://drive.google.com/file/d/1HRZNAIoAQw_uYiPwnBvtBioQoqiqoXdA/view?usp=drive_link 8 | https://drive.google.com/file/d/1hedGq-QDMnIn8GlXXBC3GiEJ_Y-LTxyt/view?usp=drive_link 9 | -------------------------------------------------------------------------------- /aloha/lerobot/common/datasets/push_dataset_to_hub/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import numpy
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule used below
import torch


def concatenate_episodes(ep_dicts):
    """Merge a list of per-episode data dicts into a single flat data dict.

    The keys of the first episode dict decide which keys are merged. For each key, if the first
    element of the first episode's value is a tensor, the per-episode values are concatenated with
    ``torch.cat``. Otherwise all items are gathered, in episode order, into one Python list.

    A global ``"index"`` entry is added: ``torch.arange(total_frames)``, where ``total_frames`` is
    the length of the concatenated ``"frame_index"`` tensor.

    Args:
        ep_dicts: Non-empty list of episode dicts sharing the same keys; each must contain a
            tensor-valued ``"frame_index"``.

    Returns:
        The merged data dict.
    """
    data_dict = {}
    for key in ep_dicts[0]:
        if torch.is_tensor(ep_dicts[0][key][0]):
            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
        else:
            # Non-tensor values (e.g. lists of frame references) are flattened into one list.
            data_dict[key] = [item for ep_dict in ep_dicts for item in ep_dict[key]]

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
    return data_dict


def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
    """Write every image of ``imgs_array`` to ``out_dir/frame_XXXXXX.png`` using a thread pool.

    Args:
        imgs_array: Sequence of images accepted by ``PIL.Image.fromarray``.
        out_dir: Output directory; created (with parents) if missing.
        max_workers: Number of writer threads.

    Raises:
        Any exception raised while converting or saving an image is re-raised in the caller
        (previously such failures were silently dropped inside the worker threads).
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    def save_image(img_array, i, out_dir):
        img = PIL.Image.fromarray(img_array)
        img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)

    num_images = len(imgs_array)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]

    # Leaving the `with` block waits for all workers; surface the first failure, if any.
    for future in futures:
        future.result()


# ---- /aloha/lerobot/common/datasets/sampler.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union

import torch


class EpisodeAwareSampler:
    def __init__(
        self,
        episode_data_index: dict,
        episode_indices_to_use: Union[list, None] = None,
        drop_n_first_frames: int = 0,
        drop_n_last_frames: int = 0,
        shuffle: bool = False,
    ):
        """Sampler that optionally incorporates episode boundary information.

        Args:
            episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
            episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
                Assumes that episodes are indexed from 0 to N-1.
            drop_n_first_frames: Number of frames to drop from the start of each episode.
            drop_n_last_frames: Number of frames to drop from the end of each episode.
            shuffle: Whether to shuffle the indices.
        """
        # Membership is tested once per episode; build the lookup set once instead of scanning a list.
        selected = None if episode_indices_to_use is None else set(episode_indices_to_use)

        indices = []
        for episode_idx, (start_index, end_index) in enumerate(
            zip(episode_data_index["from"], episode_data_index["to"])
        ):
            if selected is None or episode_idx in selected:
                # An episode shorter than the dropped frames yields an empty range (no samples).
                indices.extend(
                    range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
                )

        self.indices = indices
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        """Yield the stored frame indices, in a fresh random order per pass when `shuffle` is set."""
        if self.shuffle:
            for i in torch.randperm(len(self.indices)).tolist():
                yield self.indices[i]
        else:
            yield from self.indices

    def __len__(self) -> int:
        return len(self.indices)


# ---- /aloha/lerobot/common/envs/factory.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Optional

import gymnasium as gym
from omegaconf import DictConfig


def make_env(cfg: DictConfig, n_envs: Optional[int] = None) -> gym.vector.VectorEnv:
    """Makes a gym vector environment according to the evaluation config.

    The environment package ``gym_<cfg.env.name>`` is imported so that it registers its ids, then
    ``n_envs`` (or ``cfg.eval.batch_size``) copies of ``<package>/<cfg.env.task>`` are wrapped in an
    ``AsyncVectorEnv`` or ``SyncVectorEnv`` depending on ``cfg.eval.use_async_envs``.

    Args:
        cfg: Full config; reads ``cfg.env`` (``name``, ``task``, optional ``gym`` kwargs and
            ``episode_length``) and ``cfg.eval`` (``use_async_envs``, ``batch_size``).
        n_envs: Overrides ``cfg.eval.batch_size`` when given. Must be at least 1.

    Raises:
        ValueError: If ``n_envs`` is given and smaller than 1.
        ModuleNotFoundError: If the ``gym_<name>`` package is not installed.
    """
    if n_envs is not None and n_envs < 1:
        raise ValueError("`n_envs` must be at least 1")

    package_name = f"gym_{cfg.env.name}"

    try:
        importlib.import_module(package_name)
    except ModuleNotFoundError:
        print(
            f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
        )
        raise

    gym_handle = f"{package_name}/{cfg.env.task}"
    gym_kwgs = dict(cfg.env.get("gym", {}))

    if cfg.env.get("episode_length"):
        gym_kwgs["max_episode_steps"] = cfg.env.episode_length

    # batched version of the env that returns an observation of shape (b, c)
    env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
    num_envs = n_envs if n_envs is not None else cfg.eval.batch_size
    env = env_cls(
        [lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) for _ in range(num_envs)]
    )

    return env


# ---- /aloha/lerobot/common/envs/utils.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import einops
from typing import Dict
import numpy as np
import torch
from torch import Tensor


def preprocess_observation(observations: Dict[str, np.ndarray]) -> Dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.

    Args:
        observations: Dictionary of observation batches from a Gym vector environment. Images under
            "pixels" (a single array or a dict of named arrays) must be uint8 and channel-last,
            i.e. shaped (b, h, w, c). "agent_pos" is required.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors:
        images become float32 channel-first (b, c, h, w) tensors scaled to [0, 1].
    """
    # map to expected inputs for the policy
    return_observations = {}
    if "pixels" in observations:
        if isinstance(observations["pixels"], dict):
            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
        else:
            imgs = {"observation.image": observations["pixels"]}

        for imgkey, img in imgs.items():
            img = torch.from_numpy(img)

            # sanity check that images are channel last (the channel dim is the smallest)
            _, h, w, c = img.shape
            assert c < h and c < w, f"expect channel last images, but instead {img.shape}"

            # sanity check that images are uint8
            assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

            # convert to channel first of type float32 in range [0,1]
            img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
            img = img.type(torch.float32)
            img /= 255

            return_observations[imgkey] = img

    if "environment_state" in observations:
        return_observations["observation.environment_state"] = torch.from_numpy(
            observations["environment_state"]
        ).float()

    # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
    # requirement for "agent_pos"
    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
    return return_observations
# ---- /aloha/lerobot/common/policies/aloha_diffusion_policy/__init__.py ----
# https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/aloha/lerobot/common/policies/aloha_diffusion_policy/__init__.py
# ---- /aloha/lerobot/common/policies/autoregressive_policy/__init__.py ----
# https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/aloha/lerobot/common/policies/autoregressive_policy/__init__.py
# ---- /aloha/lerobot/common/policies/autoregressive_policy/configuration.py ----
from dataclasses import dataclass, field
from typing import Optional, List, Dict


@dataclass
class ARPConfig:
    """Configuration of the autoregressive policy (ARP).

    ``__post_init__`` validates the backbone and step settings. It also resolves
    ``n_action_steps_eval`` (-1 means "use n_action_steps") and writes the derived
    ``arp_cfg["max_seq_len"]``.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 100
    n_action_steps: int = 100
    n_action_steps_eval: int = -1  # -1: resolved to n_action_steps in __post_init__

    input_shapes: Dict[str, List[int]] = field(
        default_factory=lambda: {
            "observation.images.top": [3, 480, 640],
            "observation.state": [14],
        }
    )
    output_shapes: Dict[str, List[int]] = field(
        default_factory=lambda: {
            "action": [14],
        }
    )

    # Normalization / Unnormalization
    input_normalization_modes: Dict[str, str] = field(
        default_factory=lambda: {
            "observation.images.top": "mean_std",
            "observation.state": "mean_std",
        }
    )
    output_normalization_modes: Dict[str, str] = field(
        default_factory=lambda: {
            "action": "mean_std",
        }
    )

    # Architecture.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: Optional[str] = "ResNet18_Weights.IMAGENET1K_V1"
    replace_final_stride_with_dilation: bool = False
    # Transformer layers.
    pre_norm: bool = False
    dim_model: int = 512
    n_heads: int = 8
    dim_feedforward: int = 3200
    feedforward_activation: str = "relu"
    n_encoder_layers: int = 4
    dropout: float = 0.1

    num_latents: int = 1
    num_guide_points: int = 10
    guide_pts_downsample: int = 1
    guide_pts_corr_dim: int = 64
    guide_pts_heatmap_sigma: float = 1.5

    # Free-form ARP settings. NOTE: __post_init__ stores "max_seq_len" into this mapping in place,
    # so a mapping supplied by the caller is mutated.
    arp_cfg: dict = field(default_factory=lambda: {})
    sample: bool = False

    guide_chunk_size: int = -1
    action_chunk_size: int = -1

    def __post_init__(self):
        """Input validation (not exhaustive) and derived fields."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

        if self.n_action_steps_eval == -1:
            self.n_action_steps_eval = self.n_action_steps

        # 1 + chunk_size + 2 * num_guide_points; presumably one leading token, one per action and two
        # per guide point -- TODO confirm against the model that consumes arp_cfg.
        self.arp_cfg['max_seq_len'] = 1 + self.chunk_size + self.num_guide_points * 2


# ---- /aloha/lerobot/common/policies/autoregressive_policy/one_step/__init__.py ----
# https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/aloha/lerobot/common/policies/autoregressive_policy/one_step/__init__.py
# ---- /aloha/lerobot/common/policies/autoregressive_policy/one_step/configuration.py ----
from dataclasses import dataclass, field
from typing import Optional, List, Dict


@dataclass
class ARPConfig:
    """Configuration of the one-step variant of the autoregressive policy (ARP).

    ``__post_init__`` validates the backbone and step settings. It also resolves
    ``n_action_steps_eval`` (-1 means "use n_action_steps") and writes the derived
    ``arp_cfg["max_seq_len"]``.
    """

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 100
    n_action_steps: int = 100
    n_action_steps_eval: int = -1  # -1: resolved to n_action_steps in __post_init__

    input_shapes: Dict[str, List[int]] = field(
        default_factory=lambda: {
            "observation.images.top": [3, 480, 640],
            "observation.state": [14],
        }
    )
    output_shapes: Dict[str, List[int]] = field(
        default_factory=lambda: {
            "action": [14],
        }
    )

    # Normalization / Unnormalization
    input_normalization_modes: Dict[str, str] = field(
        default_factory=lambda: {
            "observation.images.top": "mean_std",
            "observation.state": "mean_std",
        }
    )
    output_normalization_modes: Dict[str, str] = field(
        default_factory=lambda: {
            "action": "mean_std",
        }
    )

    # Architecture.
    # Vision backbone.
    vision_backbone: str = "resnet18"
    pretrained_backbone_weights: Optional[str] = "ResNet18_Weights.IMAGENET1K_V1"
    replace_final_stride_with_dilation: bool = False
    # Transformer layers.
    pre_norm: bool = False
    dim_model: int = 512
    n_heads: int = 8
    dim_feedforward: int = 3200
    feedforward_activation: str = "relu"
    n_encoder_layers: int = 4
    dropout: float = 0.1

    num_latents: int = 1
    num_guide_points: int = 10
    guide_pts_downsample: int = 1
    guide_pts_corr_dim: int = 64
    guide_pts_heatmap_sigma: float = 1.5

    # Free-form ARP settings. NOTE: __post_init__ stores "max_seq_len" into this mapping in place,
    # so a mapping supplied by the caller is mutated.
    arp_cfg: Dict = field(default_factory=lambda: {})
    sample: bool = False

    guide_chunk_size: int = -1
    action_chunk_size: int = -1

    def __post_init__(self):
        """Input validation (not exhaustive) and derived fields."""
        if not self.vision_backbone.startswith("resnet"):
            raise ValueError(
                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
            )

        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )

        if self.n_action_steps_eval == -1:
            self.n_action_steps_eval = self.n_action_steps

        # 1 + chunk_size + 2 * num_guide_points; presumably one leading token, one per action and two
        # per guide point -- TODO confirm against the model that consumes arp_cfg.
        self.arp_cfg['max_seq_len'] = 1 + self.chunk_size + self.num_guide_points * 2


# ---- /aloha/lerobot/common/policies/policy_protocol.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A protocol that all policies should follow.

This provides a mechanism for type-hinting and isinstance checks without requiring the policies classes
subclass a base class.

The protocol structure, method signatures, and docstrings should be used by developers as a reference for
how to implement new policies.
"""

from typing import Protocol, runtime_checkable, Optional, Dict

from torch import Tensor


@runtime_checkable
class Policy(Protocol):
    """The required interface for implementing a policy.

    We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
    """

    name: str

    def __init__(self, cfg, dataset_stats: Optional[Dict[str, Dict[str, Tensor]]] = None):
        """
        Args:
            cfg: Policy configuration class instance or None, in which case the default instantiation of the
                configuration class is used.
            dataset_stats: Dataset statistics to be used for normalization.
        """

    def reset(self):
        """To be called whenever the environment is reset.

        Does things like clearing caches.
        """

    def forward(self, batch: Dict[str, Tensor]) -> dict:
        """Run the batch through the model and compute the loss for training or validation.

        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
        other items should be logging-friendly, native Python types.
        """

    def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
        """Return one action to run in the environment (potentially in batch mode).

        When the model uses a history of observations, or outputs a sequence of actions, this method deals
        with caching.
        """


@runtime_checkable
class PolicyWithUpdate(Policy, Protocol):
    def update(self):
        """An update method that is to be called after a training optimization step.

        Implements any additional updates that the model parameters may need (for example, doing an EMA
        step for a target model, or incrementing an internal buffer).
        """


# ---- /aloha/lerobot/common/policies/utils.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn


def populate_queues(queues, batch):
    """Push the values of ``batch`` onto the matching deques in ``queues`` (mutated and returned).

    Keys of ``batch`` that are not already in ``queues`` are ignored; the caller decides which keys
    are tracked. A bounded deque that is not yet full is padded by repeating the new value until it
    reaches ``maxlen``; a full one gets the value appended, evicting its oldest entry. A deque without
    ``maxlen`` simply gets the value appended (this used to loop forever).
    """
    for key in batch:
        if key not in queues:
            continue
        queue = queues[key]
        value = batch[key]
        if queue.maxlen is not None and len(queue) != queue.maxlen:
            # initialize by copying the first observation several times until the queue is full
            queue.extend([value] * (queue.maxlen - len(queue)))
        else:
            # add latest observation to the queue
            queue.append(value)
    return queues


def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Note: assumes that all parameters have the same device
    """
    return next(iter(module.parameters())).device


def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
    """Get a module's parameter dtype by checking one of its parameters.

    Note: assumes that all parameters have the same dtype.
    """
    return next(iter(module.parameters())).dtype


# ---- /aloha/lerobot/common/utils/benchmark.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from contextlib import ContextDecorator


class TimeBenchmark(ContextDecorator):
    """
    Measures execution time using a context manager or decorator.

    This class supports both context manager and decorator usage, and is thread-safe for multithreaded
    environments: timings are stored in thread-local storage, so each thread reads its own last result.

    Args:
        print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
            to False.

    Examples:

        Using as a context manager:

        >>> benchmark = TimeBenchmark()
        >>> with benchmark:
        ...     time.sleep(1)
        >>> print(f"Block took {benchmark.result:.4f} seconds")
        Block took approximately 1.0000 seconds

        Using with multithreading:

        ```python
        import threading

        benchmark = TimeBenchmark()

        def context_manager_example():
            with benchmark:
                time.sleep(0.01)
            print(f"Block took {benchmark.result_ms:.2f} milliseconds")

        threads = []
        for _ in range(3):
            t1 = threading.Thread(target=context_manager_example)
            threads.append(t1)

        for t in threads:
            t.start()

        for t in threads:
            t.join()
        ```
        Expected output:
        Block took approximately 10.00 milliseconds
        Block took approximately 10.00 milliseconds
        Block took approximately 10.00 milliseconds
    """

    def __init__(self, print=False):
        self.local = threading.local()
        self.print_time = print

    def __enter__(self):
        self.local.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.local.end_time = time.perf_counter()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        # Never suppress exceptions raised inside the timed block.
        return False

    @property
    def result(self):
        """Seconds taken by this thread's last completed block, or None if none has completed."""
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        """Same as `result` but in milliseconds (None instead of a TypeError when nothing was timed)."""
        seconds = self.result
        return None if seconds is None else seconds * 1e3


# ---- /aloha/lerobot/common/utils/import_utils.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.metadata  # submodules are not guaranteed to be loaded by `import importlib`
import importlib.util
import logging
from typing import Tuple, Union


def is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
    """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
    Check if the package spec exists and grab its version to avoid importing a local directory.
    **Note:** this doesn't work for all packages.

    A module that is importable but has no installed distribution metadata is reported as
    unavailable, except "torch", which counts as available when its ``__version__`` contains "dev".

    Returns:
        ``(available, version)`` when ``return_version`` is True (version is "N/A" if unknown),
        otherwise just ``available``.
    """
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Fallback method: Only for "torch" and versions containing "dev"
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(pkg_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # Check if the version contains "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and set as not available
                package_exists = False
    logging.debug("Detected %s version: %s", pkg_name, package_version)
    if return_version:
        return package_exists, package_version
    return package_exists


_torch_available, _torch_version = is_package_available("torch", return_version=True)
_gym_xarm_available = is_package_available("gym_xarm")
_gym_aloha_available = is_package_available("gym_aloha")
_gym_pusht_available = is_package_available("gym_pusht")


# ---- /aloha/lerobot/common/utils/io_utils.py ----
#!/usr/bin/env python
#
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import warnings 17 | 18 | import imageio 19 | 20 | 21 | def write_video(video_path, stacked_frames, fps):  # Save the frame sequence stacked_frames to video_path via imageio.mimsave at the given fps 22 | # Filter out DeprecationWarnings raised from pkg_resources; catch_warnings restores the previous filters when the block exits 23 | with warnings.catch_warnings(): 24 | warnings.filterwarnings( 25 | "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning 26 | ) 27 | imageio.mimsave(video_path, stacked_frames, fps=fps) 28 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/env/aloha.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | fps: 50 4 | 5 | env: 6 | name: aloha 7 | task: AlohaInsertion-v0 8 | state_dim: 14 9 | action_dim: 14 10 | fps: ${fps} 11 | episode_length: 400 12 | gym: 13 | obs_type: pixels_agent_pos 14 | render_mode: rgb_array 15 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/env/dora_aloha_real.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | fps: 30 4 | 5 | env: 6 | name: dora 7 | task: DoraAloha-v0 8 | state_dim: 14 9 | action_dim: 14 10 | fps: ${fps} 11 | episode_length: 400 12 | gym: 13 | fps: ${fps} 14 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/env/pusht.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | fps: 10 4 | 5 | env: 6 | name: pusht 7 | task: PushT-v0
8 | image_size: 96 9 | state_dim: 2 10 | action_dim: 2 11 | fps: ${fps} 12 | episode_length: 300 13 | gym: 14 | obs_type: pixels_agent_pos 15 | render_mode: rgb_array 16 | visualization_width: 384 17 | visualization_height: 384 18 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/env/xarm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | fps: 15 4 | 5 | env: 6 | name: xarm 7 | task: XarmLift-v0 8 | image_size: 84 9 | state_dim: 4 10 | action_dim: 4 11 | fps: ${fps} 12 | episode_length: 25 13 | gym: 14 | obs_type: pixels_agent_pos 15 | render_mode: rgb_array 16 | visualization_width: 384 17 | visualization_height: 384 18 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/policy/act.single.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | seed: 1000 4 | dataset_repo_id: lerobot/aloha_sim_insertion_human 5 | 6 | fps: 50 7 | 8 | env: 9 | name: aloha 10 | task: AlohaInsertion-v0 11 | state_dim: 14 12 | action_dim: 14 13 | fps: ${fps} 14 | episode_length: 400 15 | gym: 16 | obs_type: pixels_agent_pos 17 | render_mode: rgb_array 18 | 19 | 20 | 21 | override_dataset_stats: 22 | observation.images.top: 23 | # stats from imagenet, since we use a pretrained vision model 24 | mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) 25 | std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) 26 | 27 | training: 28 | offline_steps: 100000 29 | online_steps: 0 30 | eval_freq: 20000 31 | save_freq: 20000 32 | save_checkpoint: true 33 | 34 | batch_size: 8 35 | lr: 1e-5 36 | lr_backbone: 1e-5 37 | weight_decay: 1e-4 38 | grad_clip_norm: 10 39 | online_steps_between_rollouts: 1 40 | 41 | delta_timestamps: 42 | action: "[i / ${fps} for i in range(${policy.chunk_size})]" 43 | 44 | eval: 45 | n_episodes: 50 46 | batch_size: 50 47 | 48 | # See 
`configuration_act.py` for more details. 49 | policy: 50 | name: act 51 | 52 | # Input / output structure. 53 | n_obs_steps: 1 54 | chunk_size: 100 # chunk_size 55 | n_action_steps: 100 56 | 57 | input_shapes: 58 | # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? 59 | observation.images.top: [3, 480, 640] 60 | observation.state: ["${env.state_dim}"] 61 | output_shapes: 62 | action: ["${env.action_dim}"] 63 | 64 | # Normalization / Unnormalization 65 | input_normalization_modes: 66 | observation.images.top: mean_std 67 | observation.state: mean_std 68 | output_normalization_modes: 69 | action: mean_std 70 | 71 | # Architecture. 72 | # Vision backbone. 73 | vision_backbone: resnet18 74 | pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1 75 | replace_final_stride_with_dilation: false 76 | # Transformer layers. 77 | pre_norm: false 78 | dim_model: 512 79 | n_heads: 8 80 | dim_feedforward: 3200 81 | feedforward_activation: relu 82 | n_encoder_layers: 4 83 | # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code 84 | # that means only the first layer is used. `act.yaml` matches the original implementation by setting this to 1; this config uses 4 instead. 85 | # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. 86 | n_decoder_layers: 4 87 | # VAE. 88 | use_vae: true 89 | latent_dim: 32 90 | n_vae_encoder_layers: 4 91 | 92 | # Inference. 93 | temporal_ensemble_momentum: null 94 | 95 | # Training and loss computation.
96 | dropout: 0.1 97 | kl_weight: 10.0 98 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/policy/act.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | seed: 1000 4 | dataset_repo_id: lerobot/aloha_sim_insertion_human 5 | 6 | override_dataset_stats: 7 | observation.images.top: 8 | # stats from imagenet, since we use a pretrained vision model 9 | mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1) 10 | std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1) 11 | 12 | training: 13 | offline_steps: 100000 14 | online_steps: 0 15 | eval_freq: 20000 16 | save_freq: 20000 17 | save_checkpoint: true 18 | 19 | batch_size: 8 20 | lr: 1e-5 21 | lr_backbone: 1e-5 22 | weight_decay: 1e-4 23 | grad_clip_norm: 10 24 | online_steps_between_rollouts: 1 25 | 26 | delta_timestamps: 27 | action: "[i / ${fps} for i in range(${policy.chunk_size})]" 28 | 29 | eval: 30 | n_episodes: 50 31 | batch_size: 50 32 | 33 | # See `configuration_act.py` for more details. 34 | policy: 35 | name: act 36 | 37 | # Input / output structure. 38 | n_obs_steps: 1 39 | chunk_size: 100 # chunk_size 40 | n_action_steps: 100 41 | 42 | input_shapes: 43 | # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? 44 | observation.images.top: [3, 480, 640] 45 | observation.state: ["${env.state_dim}"] 46 | output_shapes: 47 | action: ["${env.action_dim}"] 48 | 49 | # Normalization / Unnormalization 50 | input_normalization_modes: 51 | observation.images.top: mean_std 52 | observation.state: mean_std 53 | output_normalization_modes: 54 | action: mean_std 55 | 56 | # Architecture. 57 | # Vision backbone. 58 | vision_backbone: resnet18 59 | pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1 60 | replace_final_stride_with_dilation: false 61 | # Transformer layers. 
62 | pre_norm: false 63 | dim_model: 512 64 | n_heads: 8 65 | dim_feedforward: 3200 66 | feedforward_activation: relu 67 | n_encoder_layers: 4 68 | # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code 69 | # that means only the first layer is used. Here we match the original implementation by setting this to 1. 70 | # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521. 71 | n_decoder_layers: 1 72 | # VAE. 73 | use_vae: true 74 | latent_dim: 32 75 | n_vae_encoder_layers: 4 76 | 77 | # Inference. 78 | temporal_ensemble_momentum: null 79 | 80 | # Training and loss computation. 81 | dropout: 0.1 82 | kl_weight: 10.0 83 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/policy/tdmpc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | seed: 1 4 | dataset_repo_id: lerobot/xarm_lift_medium 5 | 6 | training: 7 | offline_steps: 25000 8 | # TODO(alexander-soare): uncomment when online training gets reinstated 9 | online_steps: 0 # 25000 not implemented yet 10 | eval_freq: 5000 11 | online_steps_between_rollouts: 1 12 | online_sampling_ratio: 0.5 13 | online_env_seed: 10000 14 | log_freq: 100 15 | 16 | batch_size: 256 17 | grad_clip_norm: 10.0 18 | lr: 3e-4 19 | 20 | delta_timestamps: 21 | observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]" 22 | observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" 23 | action: "[i / ${fps} for i in range(${policy.horizon})]" 24 | next.reward: "[i / ${fps} for i in range(${policy.horizon})]" 25 | 26 | policy: 27 | name: tdmpc 28 | 29 | pretrained_model_path: 30 | 31 | # Input / output structure. 32 | n_action_repeats: 2 33 | horizon: 5 34 | 35 | input_shapes: 36 | # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? 
37 | observation.image: [3, 84, 84] 38 | observation.state: ["${env.state_dim}"] 39 | output_shapes: 40 | action: ["${env.action_dim}"] 41 | 42 | # Normalization / Unnormalization 43 | input_normalization_modes: null 44 | output_normalization_modes: 45 | action: min_max 46 | 47 | # Architecture / modeling. 48 | # Neural networks. 49 | image_encoder_hidden_dim: 32 50 | state_encoder_hidden_dim: 256 51 | latent_dim: 50 52 | q_ensemble_size: 5 53 | mlp_dim: 512 54 | # Reinforcement learning. 55 | discount: 0.9 56 | 57 | # Inference. 58 | use_mpc: true 59 | cem_iterations: 6 60 | max_std: 2.0 61 | min_std: 0.05 62 | n_gaussian_samples: 512 63 | n_pi_samples: 51 64 | uncertainty_regularizer_coeff: 1.0 65 | n_elites: 50 66 | elite_weighting_temperature: 0.5 67 | gaussian_mean_momentum: 0.1 68 | 69 | # Training and loss computation. 70 | max_random_shift_ratio: 0.0476 71 | # Loss coefficients. 72 | reward_coeff: 0.5 73 | expectile_weight: 0.9 74 | value_coeff: 0.1 75 | consistency_coeff: 20.0 76 | advantage_scaling: 3.0 77 | pi_coeff: 0.5 78 | temporal_decay_coeff: 0.5 79 | # Target model. 80 | target_model_momentum: 0.995 81 | -------------------------------------------------------------------------------- /aloha/lerobot/configs/policy/vqbet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Defaults for training for the PushT dataset. 4 | 5 | seed: 100000 6 | dataset_repo_id: lerobot/pusht 7 | 8 | override_dataset_stats: 9 | # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model? 
10 | observation.image: 11 | mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) 12 | std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) 13 | # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model 14 | # from the original codebase, but we should remove these and train our own pretrained model 15 | observation.state: 16 | min: [13.456424, 32.938293] 17 | max: [496.14618, 510.9579] 18 | action: 19 | min: [12.0, 25.0] 20 | max: [511.0, 511.0] 21 | 22 | training: 23 | offline_steps: 250000 24 | online_steps: 0 25 | eval_freq: 25000 26 | save_freq: 25000 27 | save_checkpoint: true 28 | 29 | batch_size: 64 30 | grad_clip_norm: 10 31 | lr: 1.0e-4 32 | lr_scheduler: cosine 33 | lr_warmup_steps: 500 34 | adam_betas: [0.95, 0.999] 35 | adam_eps: 1.0e-8 36 | adam_weight_decay: 1.0e-6 37 | online_steps_between_rollouts: 1 38 | 39 | # VQ-BeT specific 40 | vqvae_lr: 1.0e-3 41 | n_vqvae_training_steps: 20000 42 | bet_weight_decay: 2e-4 43 | bet_learning_rate: 5.5e-5 44 | bet_betas: [0.9, 0.999] 45 | 46 | delta_timestamps: 47 | observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" 48 | observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" 49 | action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]" 50 | 51 | eval: 52 | n_episodes: 50 53 | batch_size: 50 54 | 55 | policy: 56 | name: vqbet 57 | 58 | # Input / output structure. 59 | n_obs_steps: 5 60 | n_action_pred_token: 7 61 | action_chunk_size: 5 62 | 63 | input_shapes: 64 | # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? 
65 | observation.image: [3, 96, 96] 66 | observation.state: ["${env.state_dim}"] 67 | output_shapes: 68 | action: ["${env.action_dim}"] 69 | 70 | # Normalization / Unnormalization 71 | input_normalization_modes: 72 | observation.image: mean_std 73 | observation.state: min_max 74 | output_normalization_modes: 75 | action: min_max 76 | 77 | # Architecture / modeling. 78 | # Vision backbone. 79 | vision_backbone: resnet18 80 | crop_shape: [84, 84] 81 | crop_is_random: True 82 | pretrained_backbone_weights: null 83 | use_group_norm: True 84 | spatial_softmax_num_keypoints: 32 85 | # VQ-VAE 86 | n_vqvae_training_steps: ${training.n_vqvae_training_steps} 87 | vqvae_n_embed: 16 88 | vqvae_embedding_dim: 256 89 | vqvae_enc_hidden_dim: 128 90 | # VQ-BeT 91 | gpt_block_size: 500 92 | gpt_input_dim: 512 93 | gpt_output_dim: 512 94 | gpt_n_layer: 8 95 | gpt_n_head: 8 96 | gpt_hidden_dim: 512 97 | dropout: 0.1 98 | mlp_hidden_dim: 1024 99 | offset_loss_weight: 10000. 100 | primary_code_loss_weight: 5.0 101 | secondary_code_loss_weight: 0.5 102 | bet_softmax_temperature: 0.1 103 | sequentially_select: False 104 | -------------------------------------------------------------------------------- /aloha/lerobot/scripts/display_sys_info.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Use this script to get a quick summary of your system config. 18 | It should be able to run without any of LeRobot's dependencies or LeRobot itself installed. 19 | """ 20 | 21 | import platform 22 | 23 | HAS_HF_HUB = True  # each HAS_* flag is flipped to False below if the matching import fails 24 | HAS_HF_DATASETS = True 25 | HAS_NP = True 26 | HAS_TORCH = True 27 | HAS_LEROBOT = True 28 | 29 | try: 30 | import huggingface_hub 31 | except ImportError: 32 | HAS_HF_HUB = False 33 | 34 | try: 35 | import datasets 36 | except ImportError: 37 | HAS_HF_DATASETS = False 38 | 39 | try: 40 | import numpy as np 41 | except ImportError: 42 | HAS_NP = False 43 | 44 | try: 45 | import torch 46 | except ImportError: 47 | HAS_TORCH = False 48 | 49 | try: 50 | import lerobot 51 | except ImportError: 52 | HAS_LEROBOT = False 53 | 54 | 55 | lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A" 56 | hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A" 57 | hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A" 58 | np_version = np.__version__ if HAS_NP else "N/A" 59 | 60 | torch_version = torch.__version__ if HAS_TORCH else "N/A" 61 | torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" 62 | cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"  # NOTE(review): private torch API; the torch.version.cuda guard presumably skips CPU-only builds — confirm 63 | 64 | 65 | # TODO(aliberts): refactor into an actual command `lerobot env` 66 | def display_sys_info() -> dict: 67 | """Print basic system info (platform and library versions) for bug reports, and return it as a dict.""" 68 | info = { 69 | "`lerobot` version": lerobot_version, 70 | "Platform": platform.platform(), 71 | "Python version": platform.python_version(), 72 | "Huggingface_hub version": hf_hub_version, 73 | "Dataset version": hf_datasets_version, 74 | "Numpy version": np_version, 75 | "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})", 76 | "Cuda version": cuda_version,
77 | "Using GPU in script?": "", 78 | # "Using distributed or parallel set-up in script?": "", 79 | } 80 | print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") 81 | print(format_dict(info)) 82 | return info 83 | 84 | 85 | def format_dict(d: dict) -> str:  # one "- key: value" line per item, followed by a trailing newline 86 | return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" 87 | 88 | 89 | if __name__ == "__main__": 90 | display_sys_info() 91 | -------------------------------------------------------------------------------- /assets/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/assets/demo.mp4 -------------------------------------------------------------------------------- /assets/keydiff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/assets/keydiff.jpg -------------------------------------------------------------------------------- /assets/main-fig.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/assets/main-fig.jpg -------------------------------------------------------------------------------- /assets/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/assets/result.jpg -------------------------------------------------------------------------------- /assets/why-chunking-autoregression-works.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/assets/why-chunking-autoregression-works.png
-------------------------------------------------------------------------------- /pusht/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/pusht/000.jpg -------------------------------------------------------------------------------- /pusht/arp.py: -------------------------------------------------------------------------------- 1 | ../arp.py -------------------------------------------------------------------------------- /pusht/diffusion_policy/common/checkpoint_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | import os 3 | 4 | class TopKCheckpointManager: 5 | def __init__(self, 6 | save_dir, 7 | monitor_key: str, 8 | mode='min', 9 | k=1, 10 | format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt' 11 | ): 12 | assert mode in ['max', 'min'] 13 | assert k >= 0 14 | 15 | self.save_dir = save_dir 16 | self.monitor_key = monitor_key 17 | self.mode = mode 18 | self.k = k 19 | self.format_str = format_str 20 | self.path_value_map = dict() 21 | 22 | def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]: 23 | if self.k == 0: 24 | return None 25 | 26 | value = data[self.monitor_key] 27 | ckpt_path = os.path.join( 28 | self.save_dir, self.format_str.format(**data)) 29 | 30 | if len(self.path_value_map) < self.k: 31 | # under-capacity 32 | self.path_value_map[ckpt_path] = value 33 | return ckpt_path 34 | 35 | # at capacity 36 | sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1]) 37 | min_path, min_value = sorted_map[0] 38 | max_path, max_value = sorted_map[-1] 39 | 40 | delete_path = None 41 | if self.mode == 'max': 42 | if value > min_value: 43 | delete_path = min_path 44 | else: 45 | if value < max_value: 46 | delete_path = max_path 47 | 48 | if delete_path is None: 49 | return None 50 | else: 51 | del self.path_value_map[delete_path] 
52 | self.path_value_map[ckpt_path] = value 53 | 54 | if not os.path.exists(self.save_dir): 55 | os.mkdir(self.save_dir) 56 | 57 | if os.path.exists(delete_path): 58 | os.remove(delete_path) 59 | return ckpt_path 60 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/common/env_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def render_env_video(env, states, actions=None): 6 | observations = states 7 | imgs = list() 8 | for i in range(len(observations)): 9 | state = observations[i] 10 | env.set_state(state) 11 | if i == 0: 12 | env.set_state(state) 13 | img = env.render() 14 | # draw action 15 | if actions is not None: 16 | action = actions[i] 17 | coord = (action / 512 * 96).astype(np.int32) 18 | cv2.drawMarker(img, coord, 19 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 20 | markerSize=8, thickness=1) 21 | imgs.append(img) 22 | imgs = np.array(imgs) 23 | return imgs 24 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/common/nested_dict_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | def nested_dict_map(f, x): 4 | """ 5 | Map f over all leaf of nested dict x 6 | """ 7 | 8 | if not isinstance(x, dict): 9 | return f(x) 10 | y = dict() 11 | for key, value in x.items(): 12 | y[key] = nested_dict_map(f, value) 13 | return y 14 | 15 | def nested_dict_reduce(f, x): 16 | """ 17 | Map f over all values of nested dict x, and reduce to a single value 18 | """ 19 | if not isinstance(x, dict): 20 | return x 21 | 22 | reduced_values = list() 23 | for value in x.values(): 24 | reduced_values.append(nested_dict_reduce(f, value)) 25 | y = functools.reduce(f, reduced_values) 26 | return y 27 | 28 | 29 | def nested_dict_check(f, x): 30 | bool_dict = nested_dict_map(f, x) 31 | result = nested_dict_reduce(lambda 
x, y: x and y, bool_dict) 32 | return result 33 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/common/precise_sleep.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | def precise_sleep(dt: float, slack_time: float=0.001, time_func=time.monotonic): 4 | """ 5 | Use hybrid of time.sleep and spinning to minimize jitter. 6 | Sleep dt - slack_time seconds first, then spin for the rest. 7 | """ 8 | t_start = time_func() 9 | if dt > slack_time: 10 | time.sleep(dt - slack_time) 11 | t_end = t_start + dt 12 | while time_func() < t_end: 13 | pass 14 | return 15 | 16 | def precise_wait(t_end: float, slack_time: float=0.001, time_func=time.monotonic): 17 | t_start = time_func() 18 | t_wait = t_end - t_start 19 | if t_wait > 0: 20 | t_sleep = t_wait - slack_time 21 | if t_sleep > 0: 22 | time.sleep(t_sleep) 23 | while time_func() < t_end: 24 | pass 25 | return 26 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/common/pymunk_util.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import pymunk 3 | import pymunk.pygame_util 4 | import numpy as np 5 | 6 | COLLTYPE_DEFAULT = 0 7 | COLLTYPE_MOUSE = 1 8 | COLLTYPE_BALL = 2 9 | 10 | def get_body_type(static=False): 11 | body_type = pymunk.Body.DYNAMIC 12 | if static: 13 | body_type = pymunk.Body.STATIC 14 | return body_type 15 | 16 | 17 | def create_rectangle(space, 18 | pos_x,pos_y,width,height, 19 | density=3,static=False): 20 | body = pymunk.Body(body_type=get_body_type(static)) 21 | body.position = (pos_x,pos_y) 22 | shape = pymunk.Poly.create_box(body,(width,height)) 23 | shape.density = density 24 | space.add(body,shape) 25 | return body, shape 26 | 27 | 28 | def create_rectangle_bb(space, 29 | left, bottom, right, top, 30 | **kwargs): 31 | pos_x = (left + right) / 2 32 | pos_y = (top + bottom) / 2 33 
| height = top - bottom 34 | width = right - left 35 | return create_rectangle(space, pos_x, pos_y, width, height, **kwargs) 36 | 37 | def create_circle(space, pos_x, pos_y, radius, density=3, static=False): 38 | body = pymunk.Body(body_type=get_body_type(static)) 39 | body.position = (pos_x, pos_y) 40 | shape = pymunk.Circle(body, radius=radius) 41 | shape.density = density 42 | shape.collision_type = COLLTYPE_BALL 43 | space.add(body, shape) 44 | return body, shape 45 | 46 | def get_body_state(body): 47 | state = np.zeros(6, dtype=np.float32) 48 | state[:2] = body.position 49 | state[2] = body.angle 50 | state[3:5] = body.velocity 51 | state[5] = body.angular_velocity 52 | return state 53 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/common/pytorch_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, List 2 | import collections 3 | import torch 4 | import torch.nn as nn 5 | 6 | def dict_apply( 7 | x: Dict[str, torch.Tensor], 8 | func: Callable[[torch.Tensor], torch.Tensor] 9 | ) -> Dict[str, torch.Tensor]: 10 | result = dict() 11 | for key, value in x.items(): 12 | if isinstance(value, dict): 13 | result[key] = dict_apply(value, func) 14 | else: 15 | result[key] = func(value) 16 | return result 17 | 18 | def pad_remaining_dims(x, target): 19 | assert x.shape == target.shape[:len(x.shape)] 20 | return x.reshape(x.shape + (1,)*(len(target.shape) - len(x.shape))) 21 | 22 | def dict_apply_split( 23 | x: Dict[str, torch.Tensor], 24 | split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]] 25 | ) -> Dict[str, torch.Tensor]: 26 | results = collections.defaultdict(dict) 27 | for key, value in x.items(): 28 | result = split_func(value) 29 | for k, v in result.items(): 30 | results[k][key] = v 31 | return results 32 | 33 | def dict_apply_reduce( 34 | x: List[Dict[str, torch.Tensor]], 35 | reduce_func: Callable[[List[torch.Tensor]], 
torch.Tensor] 36 | ) -> Dict[str, torch.Tensor]: 37 | result = dict() 38 | for key in x[0].keys(): 39 | result[key] = reduce_func([x_[key] for x_ in x]) 40 | return result 41 | 42 | 43 | def replace_submodules( 44 | root_module: nn.Module, 45 | predicate: Callable[[nn.Module], bool], 46 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 47 | """ 48 | predicate: Return true if the module is to be replaced. 49 | func: Return new module to use. 50 | """ 51 | if predicate(root_module): 52 | return func(root_module) 53 | 54 | bn_list = [k.split('.') for k, m 55 | in root_module.named_modules(remove_duplicate=True) 56 | if predicate(m)] 57 | for *parent, k in bn_list: 58 | parent_module = root_module 59 | if len(parent) > 0: 60 | parent_module = root_module.get_submodule('.'.join(parent)) 61 | if isinstance(parent_module, nn.Sequential): 62 | src_module = parent_module[int(k)] 63 | else: 64 | src_module = getattr(parent_module, k) 65 | tgt_module = func(src_module) 66 | if isinstance(parent_module, nn.Sequential): 67 | parent_module[int(k)] = tgt_module 68 | else: 69 | setattr(parent_module, k, tgt_module) 70 | # verify that all BN are replaced 71 | bn_list = [k.split('.') for k, m 72 | in root_module.named_modules(remove_duplicate=True) 73 | if predicate(m)] 74 | assert len(bn_list) == 0 75 | return root_module 76 | 77 | def optimizer_to(optimizer, device): 78 | for state in optimizer.state.values(): 79 | for k, v in state.items(): 80 | if isinstance(v, torch.Tensor): 81 | state[k] = v.to(device=device) 82 | return optimizer 83 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/common/robomimic_config_util.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from robomimic.config import config_factory 3 | import robomimic.scripts.generate_paper_configs as gpc 4 | from robomimic.scripts.generate_paper_configs import ( 5 | 
modify_config_for_default_image_exp, 6 | modify_config_for_default_low_dim_exp, 7 | modify_config_for_dataset, 8 | ) 9 | 10 | def get_robomimic_config(  # Build a robomimic config via robomimic's paper-config helpers for the given algorithm, observation type, task, and dataset type 11 | algo_name='bc_rnn', 12 | hdf5_type='low_dim', 13 | task_name='square', 14 | dataset_type='ph' 15 | ): 16 | base_dataset_dir = '/tmp/null'  # NOTE(review): placeholder directory; presumably the dataset paths derived from it are never read — confirm 17 | filter_key = None 18 | 19 | # decide whether to use low-dim or image training defaults 20 | modifier_for_obs = modify_config_for_default_image_exp 21 | if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]: 22 | modifier_for_obs = modify_config_for_default_low_dim_exp 23 | 24 | algo_config_name = "bc" if algo_name == "bc_rnn" else algo_name  # bc_rnn starts from the "bc" base config 25 | config = config_factory(algo_name=algo_config_name) 26 | # turn into default config for observation modalities (e.g.: low-dim or rgb) 27 | config = modifier_for_obs(config) 28 | # add in config based on the dataset 29 | config = modify_config_for_dataset( 30 | config=config, 31 | task_name=task_name, 32 | dataset_type=dataset_type, 33 | hdf5_type=hdf5_type, 34 | base_dataset_dir=base_dataset_dir, 35 | filter_key=filter_key, 36 | ) 37 | # add in algo hypers based on dataset 38 | algo_config_modifier = getattr(gpc, f'modify_{algo_name}_config_for_dataset')  # raises AttributeError if gpc has no modifier for algo_name 39 | config = algo_config_modifier( 40 | config=config, 41 | task_name=task_name, 42 | dataset_type=dataset_type, 43 | hdf5_type=hdf5_type, 44 | ) 45 | return config 46 | 47 | 48 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/blockpush_lowdim_seed.yaml: -------------------------------------------------------------------------------- 1 | name: blockpush_lowdim_seed 2 | 3 | obs_dim: 16 4 | action_dim: 2 5 | keypoint_dim: 2 6 | obs_eef_target: True 7 | 8 | env_runner: 9 | _target_: diffusion_policy.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner 10 | n_train: 6 11 | n_train_vis: 2 12 | train_start_seed: 0 13 | n_test: 50 14 | n_test_vis: 4 15 | test_start_seed: 100000 16 | max_steps: 350 17 |
n_obs_steps: ${n_obs_steps} 18 | n_action_steps: ${n_action_steps} 19 | fps: 5 20 | past_action: ${past_action_visible} 21 | abs_action: False 22 | obs_eef_target: ${task.obs_eef_target} 23 | n_envs: null 24 | 25 | dataset: 26 | _target_: diffusion_policy.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset 27 | zarr_path: data/block_pushing/multimodal_push_seed.zarr 28 | horizon: ${horizon} 29 | pad_before: ${eval:'${n_obs_steps}-1'} 30 | pad_after: ${eval:'${n_action_steps}-1'} 31 | obs_eef_target: ${task.obs_eef_target} 32 | use_manual_normalizer: False 33 | seed: 42 34 | val_ratio: 0.02 35 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/blockpush_lowdim_seed_abs.yaml: -------------------------------------------------------------------------------- 1 | name: blockpush_lowdim_seed_abs 2 | 3 | obs_dim: 16 4 | action_dim: 2 5 | keypoint_dim: 2 6 | obs_eef_target: True 7 | 8 | env_runner: 9 | _target_: diffusion_policy.env_runner.blockpush_lowdim_runner.BlockPushLowdimRunner 10 | n_train: 6 11 | n_train_vis: 2 12 | train_start_seed: 0 13 | n_test: 50 14 | n_test_vis: 4 15 | test_start_seed: 100000 16 | max_steps: 350 17 | n_obs_steps: ${n_obs_steps} 18 | n_action_steps: ${n_action_steps} 19 | fps: 5 20 | past_action: ${past_action_visible} 21 | abs_action: True 22 | obs_eef_target: ${task.obs_eef_target} 23 | n_envs: null 24 | 25 | dataset: 26 | _target_: diffusion_policy.dataset.blockpush_lowdim_dataset.BlockPushLowdimDataset 27 | zarr_path: data/block_pushing/multimodal_push_seed_abs.zarr 28 | horizon: ${horizon} 29 | pad_before: ${eval:'${n_obs_steps}-1'} 30 | pad_after: ${eval:'${n_action_steps}-1'} 31 | obs_eef_target: ${task.obs_eef_target} 32 | use_manual_normalizer: False 33 | seed: 42 34 | val_ratio: 0.02 35 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/can_image.yaml: 
-------------------------------------------------------------------------------- 1 | name: can_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [7] 21 | 22 | task_name: &task_name can 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image.hdf5 25 | abs_action: &abs_action False 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | # costs 1GB per env 32 | n_train: 6 33 | n_train_vis: 2 34 | train_start_idx: 0 35 | n_test: 50 36 | n_test_vis: 4 37 | test_start_seed: 100000 38 | # use python's eval function as resolver, single-quoted string as argument 39 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 40 | n_obs_steps: ${n_obs_steps} 41 | n_action_steps: ${n_action_steps} 42 | render_obs_key: 'agentview_image' 43 | fps: 10 44 | crf: 22 45 | past_action: ${past_action_visible} 46 | abs_action: *abs_action 47 | tqdm_interval_sec: 1.0 48 | n_envs: 28 49 | # evaluation at this config requires a 16 core 64GB instance. 
50 | 51 | dataset: 52 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 53 | shape_meta: *shape_meta 54 | dataset_path: *dataset_path 55 | horizon: ${horizon} 56 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 57 | pad_after: ${eval:'${n_action_steps}-1'} 58 | n_obs_steps: ${dataset_obs_steps} 59 | abs_action: *abs_action 60 | rotation_rep: 'rotation_6d' 61 | use_legacy_normalizer: False 62 | use_cache: True 63 | seed: 42 64 | val_ratio: 0.02 65 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/can_image_abs.yaml: -------------------------------------------------------------------------------- 1 | name: can_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [10] 21 | 22 | task_name: &task_name can 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image_abs.hdf5 25 | abs_action: &abs_action True 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | # costs 1GB per env 32 | n_train: 6 33 | n_train_vis: 2 34 | train_start_idx: 0 35 | n_test: 50 36 | n_test_vis: 4 37 | test_start_seed: 100000 38 | # use python's eval function as resolver, single-quoted string as argument 39 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 40 | n_obs_steps: ${n_obs_steps} 41 | n_action_steps: ${n_action_steps} 42 | render_obs_key: 'agentview_image' 43 | fps: 10 44 | crf: 22 45 | past_action: ${past_action_visible} 46 
| abs_action: *abs_action 47 | tqdm_interval_sec: 1.0 48 | n_envs: 28 49 | # evaluation at this config requires a 16 core 64GB instance. 50 | 51 | dataset: 52 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 53 | shape_meta: *shape_meta 54 | dataset_path: *dataset_path 55 | horizon: ${horizon} 56 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 57 | pad_after: ${eval:'${n_action_steps}-1'} 58 | n_obs_steps: ${dataset_obs_steps} 59 | abs_action: *abs_action 60 | rotation_rep: 'rotation_6d' 61 | use_legacy_normalizer: False 62 | use_cache: True 63 | seed: 42 64 | val_ratio: 0.02 65 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/can_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: can_lowdim 2 | 3 | obs_dim: 23 4 | action_dim: 7 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name can 9 | dataset_type: &dataset_type ph 10 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim.hdf5 11 | abs_action: &abs_action False 12 | 13 | env_runner: 14 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 15 | dataset_path: *dataset_path 16 | obs_keys: *obs_keys 17 | n_train: 6 18 | n_train_vis: 2 19 | train_start_idx: 0 20 | n_test: 50 21 | n_test_vis: 4 22 | test_start_seed: 100000 23 | # use python's eval function as resolver, single-quoted string as argument 24 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 25 | n_obs_steps: ${n_obs_steps} 26 | n_action_steps: ${n_action_steps} 27 | n_latency_steps: ${n_latency_steps} 28 | render_hw: [128,128] 29 | fps: 10 30 | crf: 22 31 | past_action: ${past_action_visible} 32 | abs_action: *abs_action 33 | n_envs: 28 34 | 35 | dataset: 36 | _target_: 
diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 37 | dataset_path: *dataset_path 38 | horizon: ${horizon} 39 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 40 | pad_after: ${eval:'${n_action_steps}-1'} 41 | obs_keys: *obs_keys 42 | abs_action: *abs_action 43 | use_legacy_normalizer: False 44 | seed: 42 45 | val_ratio: 0.02 46 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/can_lowdim_abs.yaml: -------------------------------------------------------------------------------- 1 | name: can_lowdim 2 | 3 | obs_dim: 23 4 | action_dim: 10 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name can 9 | dataset_type: &dataset_type ph 10 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim_abs.hdf5 11 | abs_action: &abs_action True 12 | 13 | env_runner: 14 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 15 | dataset_path: *dataset_path 16 | obs_keys: *obs_keys 17 | n_train: 6 18 | n_train_vis: 2 19 | train_start_idx: 0 20 | n_test: 50 21 | n_test_vis: 4 22 | test_start_seed: 100000 23 | # use python's eval function as resolver, single-quoted string as argument 24 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 25 | n_obs_steps: ${n_obs_steps} 26 | n_action_steps: ${n_action_steps} 27 | n_latency_steps: ${n_latency_steps} 28 | render_hw: [128,128] 29 | fps: 10 30 | crf: 22 31 | past_action: ${past_action_visible} 32 | abs_action: *abs_action 33 | n_envs: 28 34 | 35 | dataset: 36 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 37 | dataset_path: *dataset_path 38 | horizon: ${horizon} 39 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 40 | pad_after: ${eval:'${n_action_steps}-1'} 41 | obs_keys: *obs_keys 
42 | abs_action: *abs_action 43 | use_legacy_normalizer: False 44 | rotation_rep: rotation_6d 45 | seed: 42 46 | val_ratio: 0.02 47 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/kitchen_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: kitchen_lowdim 2 | 3 | obs_dim: 60 4 | action_dim: 9 5 | keypoint_dim: 3 6 | 7 | dataset_dir: &dataset_dir data/kitchen 8 | 9 | env_runner: 10 | _target_: diffusion_policy.env_runner.kitchen_lowdim_runner.KitchenLowdimRunner 11 | dataset_dir: *dataset_dir 12 | n_train: 6 13 | n_train_vis: 2 14 | train_start_seed: 0 15 | n_test: 50 16 | n_test_vis: 4 17 | test_start_seed: 100000 18 | max_steps: 280 19 | n_obs_steps: ${n_obs_steps} 20 | n_action_steps: ${n_action_steps} 21 | render_hw: [240, 360] 22 | fps: 12.5 23 | past_action: ${past_action_visible} 24 | n_envs: null 25 | 26 | dataset: 27 | _target_: diffusion_policy.dataset.kitchen_lowdim_dataset.KitchenLowdimDataset 28 | dataset_dir: *dataset_dir 29 | horizon: ${horizon} 30 | pad_before: ${eval:'${n_obs_steps}-1'} 31 | pad_after: ${eval:'${n_action_steps}-1'} 32 | seed: 42 33 | val_ratio: 0.02 34 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/kitchen_lowdim_abs.yaml: -------------------------------------------------------------------------------- 1 | name: kitchen_lowdim 2 | 3 | obs_dim: 60 4 | action_dim: 9 5 | keypoint_dim: 3 6 | 7 | abs_action: True 8 | robot_noise_ratio: 0.1 9 | 10 | env_runner: 11 | _target_: diffusion_policy.env_runner.kitchen_lowdim_runner.KitchenLowdimRunner 12 | dataset_dir: data/kitchen 13 | n_train: 6 14 | n_train_vis: 2 15 | train_start_seed: 0 16 | n_test: 50 17 | n_test_vis: 4 18 | test_start_seed: 100000 19 | max_steps: 280 20 | n_obs_steps: ${n_obs_steps} 21 | n_action_steps: ${n_action_steps} 22 | render_hw: [240, 360] 23 | fps: 12.5 24 | past_action: 
${past_action_visible} 25 | abs_action: ${task.abs_action} 26 | robot_noise_ratio: ${task.robot_noise_ratio} 27 | n_envs: null 28 | 29 | dataset: 30 | _target_: diffusion_policy.dataset.kitchen_mjl_lowdim_dataset.KitchenMjlLowdimDataset 31 | dataset_dir: data/kitchen/kitchen_demos_multitask 32 | horizon: ${horizon} 33 | pad_before: ${eval:'${n_obs_steps}-1'} 34 | pad_after: ${eval:'${n_action_steps}-1'} 35 | abs_action: ${task.abs_action} 36 | robot_noise_ratio: ${task.robot_noise_ratio} 37 | seed: 42 38 | val_ratio: 0.02 39 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/lift_image.yaml: -------------------------------------------------------------------------------- 1 | name: lift_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [7] 21 | 22 | task_name: &task_name lift 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image.hdf5 25 | abs_action: &abs_action False 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | # costs 1GB per env 32 | n_train: 6 33 | n_train_vis: 1 34 | train_start_idx: 0 35 | n_test: 50 36 | n_test_vis: 3 37 | test_start_seed: 100000 38 | # use python's eval function as resolver, single-quoted string as argument 39 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 40 | n_obs_steps: ${n_obs_steps} 41 | n_action_steps: ${n_action_steps} 42 | render_obs_key: 'agentview_image' 43 | fps: 10 44 | crf: 22 45 | past_action: 
${past_action_visible} 46 | abs_action: *abs_action 47 | tqdm_interval_sec: 1.0 48 | n_envs: 28 49 | # evaluation at this config requires a 16 core 64GB instance. 50 | 51 | dataset: 52 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 53 | shape_meta: *shape_meta 54 | dataset_path: *dataset_path 55 | horizon: ${horizon} 56 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 57 | pad_after: ${eval:'${n_action_steps}-1'} 58 | n_obs_steps: ${dataset_obs_steps} 59 | abs_action: *abs_action 60 | rotation_rep: 'rotation_6d' 61 | use_legacy_normalizer: False 62 | use_cache: True 63 | seed: 42 64 | val_ratio: 0.02 65 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/lift_image_abs.yaml: -------------------------------------------------------------------------------- 1 | name: lift_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [10] 21 | 22 | task_name: &task_name lift 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image_abs.hdf5 25 | abs_action: &abs_action True 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | n_train: 6 32 | n_train_vis: 2 33 | train_start_idx: 0 34 | n_test: 50 35 | n_test_vis: 4 36 | test_start_seed: 100000 37 | # use python's eval function as resolver, single-quoted string as argument 38 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 39 | n_obs_steps: ${n_obs_steps} 40 | 
n_action_steps: ${n_action_steps} 41 | render_obs_key: 'agentview_image' 42 | fps: 10 43 | crf: 22 44 | past_action: ${past_action_visible} 45 | abs_action: *abs_action 46 | tqdm_interval_sec: 1.0 47 | n_envs: 28 48 | # evaluation at this config requires a 16 core 64GB instance. 49 | 50 | dataset: 51 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 52 | shape_meta: *shape_meta 53 | dataset_path: *dataset_path 54 | horizon: ${horizon} 55 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 56 | pad_after: ${eval:'${n_action_steps}-1'} 57 | n_obs_steps: ${dataset_obs_steps} 58 | abs_action: *abs_action 59 | rotation_rep: 'rotation_6d' 60 | use_legacy_normalizer: False 61 | use_cache: True 62 | seed: 42 63 | val_ratio: 0.02 64 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/lift_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: lift_lowdim 2 | 3 | obs_dim: 19 4 | action_dim: 7 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name lift 9 | dataset_type: &dataset_type ph 10 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim.hdf5 11 | abs_action: &abs_action False 12 | 13 | env_runner: 14 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 15 | dataset_path: *dataset_path 16 | obs_keys: *obs_keys 17 | n_train: 6 18 | n_train_vis: 2 19 | train_start_idx: 0 20 | n_test: 50 21 | n_test_vis: 4 22 | test_start_seed: 100000 23 | # use python's eval function as resolver, single-quoted string as argument 24 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 25 | n_obs_steps: ${n_obs_steps} 26 | n_action_steps: ${n_action_steps} 27 | n_latency_steps: ${n_latency_steps} 28 | render_hw: [128,128] 29 | fps: 10 30 | 
crf: 22 31 | past_action: ${past_action_visible} 32 | abs_action: *abs_action 33 | tqdm_interval_sec: 1.0 34 | n_envs: 28 35 | 36 | dataset: 37 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 38 | dataset_path: *dataset_path 39 | horizon: ${horizon} 40 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 41 | pad_after: ${eval:'${n_action_steps}-1'} 42 | obs_keys: *obs_keys 43 | abs_action: *abs_action 44 | use_legacy_normalizer: False 45 | seed: 42 46 | val_ratio: 0.02 47 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/lift_lowdim_abs.yaml: -------------------------------------------------------------------------------- 1 | name: lift_lowdim 2 | 3 | obs_dim: 19 4 | action_dim: 10 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name lift 9 | dataset_type: &dataset_type ph 10 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim_abs.hdf5 11 | abs_action: &abs_action True 12 | 13 | env_runner: 14 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 15 | dataset_path: *dataset_path 16 | obs_keys: *obs_keys 17 | n_train: 6 18 | n_train_vis: 2 19 | train_start_idx: 0 20 | n_test: 50 21 | n_test_vis: 3 22 | test_start_seed: 100000 23 | # use python's eval function as resolver, single-quoted string as argument 24 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 25 | n_obs_steps: ${n_obs_steps} 26 | n_action_steps: ${n_action_steps} 27 | n_latency_steps: ${n_latency_steps} 28 | render_hw: [128,128] 29 | fps: 10 30 | crf: 22 31 | past_action: ${past_action_visible} 32 | abs_action: *abs_action 33 | tqdm_interval_sec: 1.0 34 | n_envs: 28 35 | 36 | dataset: 37 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 38 | 
dataset_path: *dataset_path 39 | horizon: ${horizon} 40 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 41 | pad_after: ${eval:'${n_action_steps}-1'} 42 | obs_keys: *obs_keys 43 | abs_action: *abs_action 44 | use_legacy_normalizer: False 45 | rotation_rep: rotation_6d 46 | seed: 42 47 | val_ratio: 0.02 48 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/pusht_image.yaml: -------------------------------------------------------------------------------- 1 | name: pusht_image 2 | 3 | image_shape: &image_shape [3, 96, 96] 4 | shape_meta: &shape_meta 5 | # acceptable types: rgb, low_dim 6 | obs: 7 | image: 8 | shape: *image_shape 9 | type: rgb 10 | agent_pos: 11 | shape: [2] 12 | type: low_dim 13 | action: 14 | shape: [2] 15 | 16 | env_runner: 17 | _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner 18 | n_train: 6 19 | n_train_vis: 2 20 | train_start_seed: 0 21 | n_test: 50 22 | n_test_vis: 4 23 | legacy_test: True 24 | test_start_seed: 100000 25 | max_steps: 300 26 | n_obs_steps: ${n_obs_steps} 27 | n_action_steps: ${n_action_steps} 28 | fps: 10 29 | past_action: ${past_action_visible} 30 | n_envs: null 31 | 32 | dataset: 33 | _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset 34 | zarr_path: data/pusht/pusht_cchi_v7_replay.zarr 35 | horizon: ${horizon} 36 | pad_before: ${eval:'${n_obs_steps}-1'} 37 | pad_after: ${eval:'${n_action_steps}-1'} 38 | seed: 42 39 | val_ratio: 0.02 40 | max_train_episodes: 90 41 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/pusht_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: pusht_lowdim 2 | 3 | obs_dim: 20 # 9*2 keypoints + 2 state 4 | action_dim: 2 5 | keypoint_dim: 2 6 | 7 | env_runner: 8 | _target_: diffusion_policy.env_runner.pusht_keypoints_runner.PushTKeypointsRunner 9 | 
keypoint_visible_rate: ${keypoint_visible_rate} 10 | n_train: 6 11 | n_train_vis: 2 12 | train_start_seed: 0 13 | n_test: 50 14 | n_test_vis: 4 15 | legacy_test: True 16 | test_start_seed: 100000 17 | max_steps: 300 18 | n_obs_steps: ${n_obs_steps} 19 | n_action_steps: ${n_action_steps} 20 | n_latency_steps: ${n_latency_steps} 21 | fps: 10 22 | agent_keypoints: False 23 | past_action: ${past_action_visible} 24 | n_envs: null 25 | 26 | dataset: 27 | _target_: diffusion_policy.dataset.pusht_dataset.PushTLowdimDataset 28 | zarr_path: data/pusht/pusht_cchi_v7_replay.zarr 29 | horizon: ${horizon} 30 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 31 | pad_after: ${eval:'${n_action_steps}-1'} 32 | seed: 42 33 | val_ratio: 0.02 34 | max_train_episodes: 90 35 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/real_pusht_image.yaml: -------------------------------------------------------------------------------- 1 | name: real_image 2 | 3 | image_shape: [3, 240, 320] 4 | dataset_path: data/pusht_real/real_pusht_20230105 5 | 6 | shape_meta: &shape_meta 7 | # acceptable types: rgb, low_dim 8 | obs: 9 | # camera_0: 10 | # shape: ${task.image_shape} 11 | # type: rgb 12 | camera_1: 13 | shape: ${task.image_shape} 14 | type: rgb 15 | # camera_2: 16 | # shape: ${task.image_shape} 17 | # type: rgb 18 | camera_3: 19 | shape: ${task.image_shape} 20 | type: rgb 21 | # camera_4: 22 | # shape: ${task.image_shape} 23 | # type: rgb 24 | robot_eef_pose: 25 | shape: [2] 26 | type: low_dim 27 | action: 28 | shape: [2] 29 | 30 | env_runner: 31 | _target_: diffusion_policy.env_runner.real_pusht_image_runner.RealPushTImageRunner 32 | 33 | dataset: 34 | _target_: diffusion_policy.dataset.real_pusht_image_dataset.RealPushTImageDataset 35 | shape_meta: *shape_meta 36 | dataset_path: ${task.dataset_path} 37 | horizon: ${horizon} 38 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 39 | pad_after: 
${eval:'${n_action_steps}-1'} 40 | n_obs_steps: ${dataset_obs_steps} 41 | n_latency_steps: ${n_latency_steps} 42 | use_cache: True 43 | seed: 42 44 | val_ratio: 0.00 45 | max_train_episodes: null 46 | delta_action: False 47 | 48 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/square_image.yaml: -------------------------------------------------------------------------------- 1 | name: square_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [7] 21 | 22 | task_name: &task_name square 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image.hdf5 25 | abs_action: &abs_action False 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | # costs 1GB per env 32 | n_train: 6 33 | n_train_vis: 2 34 | train_start_idx: 0 35 | n_test: 50 36 | n_test_vis: 4 37 | test_start_seed: 100000 38 | # use python's eval function as resolver, single-quoted string as argument 39 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 40 | n_obs_steps: ${n_obs_steps} 41 | n_action_steps: ${n_action_steps} 42 | render_obs_key: 'agentview_image' 43 | fps: 10 44 | crf: 22 45 | past_action: ${past_action_visible} 46 | abs_action: *abs_action 47 | tqdm_interval_sec: 1.0 48 | n_envs: 28 49 | # evaluation at this config requires a 16 core 64GB instance. 
50 | 51 | dataset: 52 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 53 | shape_meta: *shape_meta 54 | dataset_path: *dataset_path 55 | horizon: ${horizon} 56 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 57 | pad_after: ${eval:'${n_action_steps}-1'} 58 | n_obs_steps: ${dataset_obs_steps} 59 | abs_action: *abs_action 60 | rotation_rep: 'rotation_6d' 61 | use_legacy_normalizer: False 62 | use_cache: True 63 | seed: 42 64 | val_ratio: 0.02 65 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/square_image_abs.yaml: -------------------------------------------------------------------------------- 1 | name: square_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [10] 21 | 22 | task_name: &task_name square 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image_abs.hdf5 25 | abs_action: &abs_action True 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | # costs 1GB per env 32 | n_train: 6 33 | n_train_vis: 2 34 | train_start_idx: 0 35 | n_test: 50 36 | n_test_vis: 4 37 | test_start_seed: 100000 38 | # use python's eval function as resolver, single-quoted string as argument 39 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 40 | n_obs_steps: ${n_obs_steps} 41 | n_action_steps: ${n_action_steps} 42 | render_obs_key: 'agentview_image' 43 | fps: 10 44 | crf: 22 45 | past_action: 
${past_action_visible} 46 | abs_action: *abs_action 47 | tqdm_interval_sec: 1.0 48 | n_envs: 28 49 | # evaluation at this config requires a 16 core 64GB instance. 50 | 51 | dataset: 52 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 53 | shape_meta: *shape_meta 54 | dataset_path: *dataset_path 55 | horizon: ${horizon} 56 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 57 | pad_after: ${eval:'${n_action_steps}-1'} 58 | n_obs_steps: ${dataset_obs_steps} 59 | abs_action: *abs_action 60 | rotation_rep: 'rotation_6d' 61 | use_legacy_normalizer: False 62 | use_cache: True 63 | seed: 42 64 | val_ratio: 0.02 65 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/square_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: square_lowdim 2 | 3 | obs_dim: 23 4 | action_dim: 7 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name square 9 | dataset_type: &dataset_type ph 10 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim.hdf5 11 | abs_action: &abs_action False 12 | 13 | env_runner: 14 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 15 | dataset_path: *dataset_path 16 | obs_keys: *obs_keys 17 | n_train: 6 18 | n_train_vis: 2 19 | train_start_idx: 0 20 | n_test: 50 21 | n_test_vis: 4 22 | test_start_seed: 100000 23 | # use python's eval function as resolver, single-quoted string as argument 24 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 25 | n_obs_steps: ${n_obs_steps} 26 | n_action_steps: ${n_action_steps} 27 | n_latency_steps: ${n_latency_steps} 28 | render_hw: [128,128] 29 | fps: 10 30 | crf: 22 31 | past_action: ${past_action_visible} 32 | abs_action: *abs_action 33 | n_envs: 28 34 | 35 | dataset: 
36 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 37 | dataset_path: *dataset_path 38 | horizon: ${horizon} 39 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 40 | pad_after: ${eval:'${n_action_steps}-1'} 41 | obs_keys: *obs_keys 42 | abs_action: *abs_action 43 | use_legacy_normalizer: False 44 | seed: 42 45 | val_ratio: 0.02 46 | max_train_episodes: null 47 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/square_lowdim_abs.yaml: -------------------------------------------------------------------------------- 1 | name: square_lowdim 2 | 3 | obs_dim: 23 4 | action_dim: 10 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name square 9 | dataset_type: &dataset_type ph 10 | abs_action: &abs_action True 11 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim_abs.hdf5 12 | 13 | 14 | env_runner: 15 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 16 | dataset_path: *dataset_path 17 | obs_keys: *obs_keys 18 | n_train: 6 19 | n_train_vis: 2 20 | train_start_idx: 0 21 | n_test: 50 22 | n_test_vis: 4 23 | test_start_seed: 100000 24 | # use python's eval function as resolver, single-quoted string as argument 25 | max_steps: ${eval:'500 if "${task.dataset_type}" == "mh" else 400'} 26 | n_obs_steps: ${n_obs_steps} 27 | n_action_steps: ${n_action_steps} 28 | n_latency_steps: ${n_latency_steps} 29 | render_hw: [128,128] 30 | fps: 10 31 | crf: 22 32 | past_action: ${past_action_visible} 33 | abs_action: *abs_action 34 | n_envs: 28 35 | 36 | dataset: 37 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 38 | dataset_path: *dataset_path 39 | horizon: ${horizon} 40 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 41 | 
pad_after: ${eval:'${n_action_steps}-1'} 42 | obs_keys: *obs_keys 43 | abs_action: *abs_action 44 | use_legacy_normalizer: False 45 | seed: 42 46 | val_ratio: 0.02 47 | max_train_episodes: null 48 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/tool_hang_image.yaml: -------------------------------------------------------------------------------- 1 | name: tool_hang_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | sideview_image: 7 | shape: [3, 240, 240] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 240, 240] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [7] 21 | 22 | task_name: &task_name tool_hang 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image.hdf5 25 | abs_action: &abs_action False 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | # costs 1GB per env 32 | n_train: 6 33 | n_train_vis: 2 34 | train_start_idx: 0 35 | n_test: 50 36 | n_test_vis: 4 37 | test_start_seed: 100000 38 | max_steps: 700 39 | n_obs_steps: ${n_obs_steps} 40 | n_action_steps: ${n_action_steps} 41 | render_obs_key: 'sideview_image' 42 | fps: 10 43 | crf: 22 44 | past_action: ${past_action_visible} 45 | abs_action: *abs_action 46 | tqdm_interval_sec: 1.0 47 | n_envs: 28 48 | # evaluation at this config requires a 16 core 64GB instance. 
49 | 50 | dataset: 51 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 52 | shape_meta: *shape_meta 53 | dataset_path: *dataset_path 54 | horizon: ${horizon} 55 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 56 | pad_after: ${eval:'${n_action_steps}-1'} 57 | n_obs_steps: ${dataset_obs_steps} 58 | abs_action: *abs_action 59 | rotation_rep: 'rotation_6d' 60 | use_legacy_normalizer: False 61 | use_cache: True 62 | seed: 42 63 | val_ratio: 0.02 64 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/tool_hang_image_abs.yaml: -------------------------------------------------------------------------------- 1 | name: tool_hang_image_abs 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | sideview_image: 7 | shape: [3, 240, 240] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 240, 240] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [10] 21 | 22 | task_name: &task_name tool_hang 23 | dataset_type: &dataset_type ph 24 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image_abs.hdf5 25 | abs_action: &abs_action True 26 | 27 | env_runner: 28 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 29 | dataset_path: *dataset_path 30 | shape_meta: *shape_meta 31 | # costs 1GB per env 32 | n_train: 6 33 | n_train_vis: 2 34 | train_start_idx: 0 35 | n_test: 50 36 | n_test_vis: 4 37 | test_start_seed: 100000 38 | max_steps: 700 39 | n_obs_steps: ${n_obs_steps} 40 | n_action_steps: ${n_action_steps} 41 | render_obs_key: 'sideview_image' 42 | fps: 10 43 | crf: 22 44 | past_action: ${past_action_visible} 45 | abs_action: *abs_action 46 | tqdm_interval_sec: 1.0 47 | n_envs: 28 48 | # evaluation at this config 
requires a 16 core 64GB instance. 49 | 50 | dataset: 51 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 52 | shape_meta: *shape_meta 53 | dataset_path: *dataset_path 54 | horizon: ${horizon} 55 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 56 | pad_after: ${eval:'${n_action_steps}-1'} 57 | n_obs_steps: ${dataset_obs_steps} 58 | abs_action: *abs_action 59 | rotation_rep: 'rotation_6d' 60 | use_legacy_normalizer: False 61 | use_cache: True 62 | seed: 42 63 | val_ratio: 0.02 64 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/tool_hang_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: tool_hang_lowdim 2 | 3 | obs_dim: 53 4 | action_dim: 7 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name tool_hang 9 | dataset_type: &dataset_type ph 10 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim.hdf5 11 | abs_action: &abs_action False 12 | 13 | env_runner: 14 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 15 | dataset_path: *dataset_path 16 | obs_keys: *obs_keys 17 | n_train: 6 18 | n_train_vis: 2 19 | train_start_idx: 0 20 | n_test: 50 21 | n_test_vis: 4 22 | test_start_seed: 100000 23 | max_steps: 700 24 | n_obs_steps: ${n_obs_steps} 25 | n_action_steps: ${n_action_steps} 26 | n_latency_steps: ${n_latency_steps} 27 | render_hw: [128,128] 28 | fps: 10 29 | crf: 22 30 | past_action: ${past_action_visible} 31 | abs_action: *abs_action 32 | n_envs: 28 33 | # seed 42 will crash MuJoCo for some reason. 
34 | 35 | dataset: 36 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 37 | dataset_path: *dataset_path 38 | horizon: ${horizon} 39 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 40 | pad_after: ${eval:'${n_action_steps}-1'} 41 | obs_keys: *obs_keys 42 | abs_action: *abs_action 43 | use_legacy_normalizer: False 44 | seed: 42 45 | val_ratio: 0.02 46 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/tool_hang_lowdim_abs.yaml: -------------------------------------------------------------------------------- 1 | name: tool_hang_lowdim 2 | 3 | obs_dim: 53 4 | action_dim: 10 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys ['object', 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos'] 8 | task_name: &task_name tool_hang 9 | dataset_type: &dataset_type ph 10 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim_abs.hdf5 11 | abs_action: &abs_action True 12 | 13 | env_runner: 14 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 15 | dataset_path: *dataset_path 16 | obs_keys: *obs_keys 17 | n_train: 6 18 | n_train_vis: 2 19 | train_start_idx: 0 20 | n_test: 50 21 | n_test_vis: 4 22 | test_start_seed: 100000 23 | max_steps: 700 24 | n_obs_steps: ${n_obs_steps} 25 | n_action_steps: ${n_action_steps} 26 | n_latency_steps: ${n_latency_steps} 27 | render_hw: [128,128] 28 | fps: 10 29 | crf: 22 30 | past_action: ${past_action_visible} 31 | abs_action: *abs_action 32 | n_envs: 28 33 | # seed 42 will crash MuJoCo for some reason. 
34 | 35 | dataset: 36 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 37 | dataset_path: *dataset_path 38 | horizon: ${horizon} 39 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 40 | pad_after: ${eval:'${n_action_steps}-1'} 41 | obs_keys: *obs_keys 42 | abs_action: *abs_action 43 | use_legacy_normalizer: False 44 | rotation_rep: rotation_6d 45 | seed: 42 46 | val_ratio: 0.02 47 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/transport_image.yaml: -------------------------------------------------------------------------------- 1 | name: transport_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | shouldercamera0_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | shouldercamera1_image: 20 | shape: [3, 84, 84] 21 | type: rgb 22 | robot1_eye_in_hand_image: 23 | shape: [3, 84, 84] 24 | type: rgb 25 | robot1_eef_pos: 26 | shape: [3] 27 | # type default: low_dim 28 | robot1_eef_quat: 29 | shape: [4] 30 | robot1_gripper_qpos: 31 | shape: [2] 32 | action: 33 | shape: [14] 34 | 35 | task_name: &task_name transport 36 | dataset_type: &dataset_type ph 37 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image.hdf5 38 | abs_action: &abs_action False 39 | 40 | env_runner: 41 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 42 | dataset_path: *dataset_path 43 | shape_meta: *shape_meta 44 | n_train: 6 45 | n_train_vis: 2 46 | train_start_idx: 0 47 | n_test: 50 48 | n_test_vis: 4 49 | test_start_seed: 100000 50 | max_steps: 700 51 | n_obs_steps: ${n_obs_steps} 52 | n_action_steps: ${n_action_steps} 53 | render_obs_key: 
'shouldercamera0_image' 54 | fps: 10 55 | crf: 22 56 | past_action: ${past_action_visible} 57 | abs_action: *abs_action 58 | tqdm_interval_sec: 1.0 59 | n_envs: 28 60 | # evaluation at this config requires a 16 core 64GB instance. 61 | 62 | dataset: 63 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 64 | shape_meta: *shape_meta 65 | dataset_path: *dataset_path 66 | horizon: ${horizon} 67 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 68 | pad_after: ${eval:'${n_action_steps}-1'} 69 | n_obs_steps: ${dataset_obs_steps} 70 | abs_action: *abs_action 71 | rotation_rep: 'rotation_6d' 72 | use_legacy_normalizer: False 73 | use_cache: True 74 | seed: 42 75 | val_ratio: 0.02 76 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/transport_image_abs.yaml: -------------------------------------------------------------------------------- 1 | name: transport_image 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | shouldercamera0_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | shouldercamera1_image: 20 | shape: [3, 84, 84] 21 | type: rgb 22 | robot1_eye_in_hand_image: 23 | shape: [3, 84, 84] 24 | type: rgb 25 | robot1_eef_pos: 26 | shape: [3] 27 | # type default: low_dim 28 | robot1_eef_quat: 29 | shape: [4] 30 | robot1_gripper_qpos: 31 | shape: [2] 32 | action: 33 | shape: [20] 34 | 35 | task_name: &task_name transport 36 | dataset_type: &dataset_type ph 37 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/image_abs.hdf5 38 | abs_action: &abs_action True 39 | 40 | env_runner: 41 | _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunner 42 | 
dataset_path: *dataset_path 43 | shape_meta: *shape_meta 44 | n_train: 6 45 | n_train_vis: 2 46 | train_start_idx: 0 47 | n_test: 50 48 | n_test_vis: 4 49 | test_start_seed: 100000 50 | max_steps: 700 51 | n_obs_steps: ${n_obs_steps} 52 | n_action_steps: ${n_action_steps} 53 | render_obs_key: 'shouldercamera0_image' 54 | fps: 10 55 | crf: 22 56 | past_action: ${past_action_visible} 57 | abs_action: *abs_action 58 | tqdm_interval_sec: 1.0 59 | n_envs: 28 60 | # evaluation at this config requires a 16 core 64GB instance. 61 | 62 | dataset: 63 | _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 64 | shape_meta: *shape_meta 65 | dataset_path: *dataset_path 66 | horizon: ${horizon} 67 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 68 | pad_after: ${eval:'${n_action_steps}-1'} 69 | n_obs_steps: ${dataset_obs_steps} 70 | abs_action: *abs_action 71 | rotation_rep: 'rotation_6d' 72 | use_legacy_normalizer: False 73 | use_cache: True 74 | seed: 42 75 | val_ratio: 0.02 76 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/transport_lowdim.yaml: -------------------------------------------------------------------------------- 1 | name: transport_lowdim 2 | 3 | obs_dim: 59 # 41+(3+4+2)*2 4 | action_dim: 14 # 7*2 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys [ 8 | 'object', 9 | 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos', 10 | 'robot1_eef_pos', 'robot1_eef_quat', 'robot1_gripper_qpos' 11 | ] 12 | task_name: &task_name transport 13 | dataset_type: &dataset_type ph 14 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim.hdf5 15 | abs_action: &abs_action False 16 | 17 | env_runner: 18 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 19 | dataset_path: *dataset_path 20 | obs_keys: *obs_keys 21 | n_train: 6 22 | n_train_vis: 2 23 | train_start_idx: 0 24 | 
n_test: 50 25 | n_test_vis: 5 26 | test_start_seed: 100000 27 | max_steps: 700 28 | n_obs_steps: ${n_obs_steps} 29 | n_action_steps: ${n_action_steps} 30 | n_latency_steps: ${n_latency_steps} 31 | render_hw: [128,128] 32 | fps: 10 33 | crf: 22 34 | past_action: ${past_action_visible} 35 | abs_action: *abs_action 36 | n_envs: 28 37 | # evaluation at this config requires a 16 core 64GB instance. 38 | 39 | dataset: 40 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 41 | dataset_path: *dataset_path 42 | horizon: ${horizon} 43 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 44 | pad_after: ${eval:'${n_action_steps}-1'} 45 | obs_keys: *obs_keys 46 | abs_action: *abs_action 47 | use_legacy_normalizer: False 48 | seed: 42 49 | val_ratio: 0.02 50 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/task/transport_lowdim_abs.yaml: -------------------------------------------------------------------------------- 1 | name: transport_lowdim 2 | 3 | obs_dim: 59 # 41+(3+4+2)*2 4 | action_dim: 20 # 10*2 5 | keypoint_dim: 3 6 | 7 | obs_keys: &obs_keys [ 8 | 'object', 9 | 'robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos', 10 | 'robot1_eef_pos', 'robot1_eef_quat', 'robot1_gripper_qpos' 11 | ] 12 | task_name: &task_name transport 13 | dataset_type: &dataset_type ph 14 | dataset_path: &dataset_path data/robomimic/datasets/${task.task_name}/${task.dataset_type}/low_dim_abs.hdf5 15 | abs_action: &abs_action True 16 | 17 | env_runner: 18 | _target_: diffusion_policy.env_runner.robomimic_lowdim_runner.RobomimicLowdimRunner 19 | dataset_path: *dataset_path 20 | obs_keys: *obs_keys 21 | n_train: 6 22 | n_train_vis: 2 23 | train_start_idx: 0 24 | n_test: 50 25 | n_test_vis: 4 26 | test_start_seed: 100000 27 | max_steps: 700 28 | n_obs_steps: ${n_obs_steps} 29 | n_action_steps: ${n_action_steps} 30 | n_latency_steps: ${n_latency_steps} 31 | render_hw: [128,128] 32 
| fps: 10 33 | crf: 22 34 | past_action: ${past_action_visible} 35 | abs_action: *abs_action 36 | n_envs: 28 37 | # evaluation at this config requires a 16 core 64GB instance. 38 | 39 | dataset: 40 | _target_: diffusion_policy.dataset.robomimic_replay_lowdim_dataset.RobomimicReplayLowdimDataset 41 | dataset_path: *dataset_path 42 | horizon: ${horizon} 43 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 44 | pad_after: ${eval:'${n_action_steps}-1'} 45 | obs_keys: *obs_keys 46 | abs_action: *abs_action 47 | use_legacy_normalizer: False 48 | seed: 42 49 | val_ratio: 0.02 50 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/train_ibc_dfo_hybrid_workspace.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: pusht_image 4 | 5 | name: train_ibc_dfo_hybrid 6 | _target_: diffusion_policy.workspace.train_ibc_dfo_hybrid_workspace.TrainIbcDfoHybridWorkspace 7 | 8 | task_name: ${task.name} 9 | shape_meta: ${task.shape_meta} 10 | exp_name: "default" 11 | 12 | horizon: 2 13 | n_obs_steps: 2 14 | n_action_steps: 1 15 | n_latency_steps: 0 16 | dataset_obs_steps: ${n_obs_steps} 17 | past_action_visible: False 18 | keypoint_visible_rate: 1.0 19 | 20 | policy: 21 | _target_: diffusion_policy.policy.ibc_dfo_hybrid_image_policy.IbcDfoHybridImagePolicy 22 | 23 | shape_meta: ${shape_meta} 24 | 25 | horizon: ${horizon} 26 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 27 | n_obs_steps: ${n_obs_steps} 28 | dropout: 0.1 29 | train_n_neg: 1024 30 | pred_n_iter: 5 31 | pred_n_samples: 1024 32 | kevin_inference: False 33 | andy_train: False 34 | obs_encoder_group_norm: True 35 | eval_fixed_crop: True 36 | crop_shape: [84, 84] 37 | 38 | dataloader: 39 | batch_size: 128 40 | num_workers: 8 41 | shuffle: True 42 | pin_memory: True 43 | persistent_workers: False 44 | 45 | val_dataloader: 46 | batch_size: 128 47 | num_workers: 8 48 | 
shuffle: False 49 | pin_memory: True 50 | persistent_workers: False 51 | 52 | optimizer: 53 | _target_: torch.optim.AdamW 54 | lr: 1.0e-4 55 | betas: [0.95, 0.999] 56 | eps: 1.0e-8 57 | weight_decay: 1.0e-6 58 | 59 | training: 60 | device: "cuda:0" 61 | seed: 42 62 | debug: False 63 | resume: True 64 | # optimization 65 | lr_scheduler: cosine 66 | lr_warmup_steps: 500 67 | num_epochs: 3050 68 | gradient_accumulate_every: 1 69 | # training loop control 70 | # in epochs 71 | rollout_every: 50 72 | checkpoint_every: 50 73 | val_every: 1 74 | sample_every: 5 75 | sample_max_batch: 128 76 | # steps per epoch 77 | max_train_steps: null 78 | max_val_steps: null 79 | # misc 80 | tqdm_interval_sec: 1.0 81 | 82 | logging: 83 | project: diffusion_policy_debug 84 | resume: True 85 | mode: online 86 | name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 87 | tags: ["${name}", "${task_name}", "${exp_name}"] 88 | id: null 89 | group: null 90 | 91 | checkpoint: 92 | topk: 93 | monitor_key: test_mean_score 94 | mode: max 95 | k: 5 96 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 97 | save_last_ckpt: True 98 | save_last_snapshot: False 99 | 100 | multi_run: 101 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 102 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 103 | 104 | hydra: 105 | job: 106 | override_dirname: ${name} 107 | run: 108 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 109 | sweep: 110 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 111 | subdir: ${hydra.job.num} 112 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/train_ibc_dfo_lowdim_workspace.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: pusht_lowdim 4 | 5 | name: train_ibc_dfo_lowdim 6 | _target_: 
diffusion_policy.workspace.train_ibc_dfo_lowdim_workspace.TrainIbcDfoLowdimWorkspace 7 | 8 | obs_dim: ${task.obs_dim} 9 | action_dim: ${task.action_dim} 10 | keypoint_dim: ${task.keypoint_dim} 11 | task_name: ${task.name} 12 | exp_name: "default" 13 | 14 | horizon: 2 15 | n_obs_steps: 2 16 | n_action_steps: 1 17 | n_latency_steps: 0 18 | past_action_visible: False 19 | keypoint_visible_rate: 1.0 20 | 21 | policy: 22 | _target_: diffusion_policy.policy.ibc_dfo_lowdim_policy.IbcDfoLowdimPolicy 23 | 24 | horizon: ${horizon} 25 | obs_dim: ${obs_dim} 26 | action_dim: ${action_dim} 27 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 28 | n_obs_steps: ${n_obs_steps} 29 | dropout: 0.1 30 | train_n_neg: 1024 31 | pred_n_iter: 5 32 | pred_n_samples: 1024 33 | kevin_inference: False 34 | andy_train: False 35 | 36 | dataloader: 37 | batch_size: 256 38 | num_workers: 1 39 | shuffle: True 40 | pin_memory: True 41 | persistent_workers: False 42 | 43 | val_dataloader: 44 | batch_size: 256 45 | num_workers: 1 46 | shuffle: False 47 | pin_memory: True 48 | persistent_workers: False 49 | 50 | optimizer: 51 | _target_: torch.optim.AdamW 52 | lr: 1.0e-4 53 | betas: [0.95, 0.999] 54 | eps: 1.0e-8 55 | weight_decay: 1.0e-6 56 | 57 | training: 58 | device: "cuda:0" 59 | seed: 42 60 | debug: False 61 | resume: True 62 | # optimization 63 | lr_scheduler: cosine 64 | lr_warmup_steps: 500 65 | num_epochs: 5000 66 | gradient_accumulate_every: 1 67 | # training loop control 68 | # in epochs 69 | rollout_every: 50 70 | checkpoint_every: 50 71 | val_every: 1 72 | sample_every: 5 73 | sample_max_batch: 128 74 | # steps per epoch 75 | max_train_steps: null 76 | max_val_steps: null 77 | # misc 78 | tqdm_interval_sec: 1.0 79 | 80 | logging: 81 | project: diffusion_policy_debug 82 | resume: True 83 | mode: online 84 | name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 85 | tags: ["${name}", "${task_name}", "${exp_name}"] 86 | id: null 87 | group: null 88 | 89 | checkpoint: 90 | topk: 
91 | monitor_key: test_mean_score 92 | mode: max 93 | k: 5 94 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 95 | save_last_ckpt: True 96 | save_last_snapshot: False 97 | 98 | multi_run: 99 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 100 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 101 | 102 | hydra: 103 | job: 104 | override_dirname: ${name} 105 | run: 106 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 107 | sweep: 108 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 109 | subdir: ${hydra.job.num} 110 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/train_ibc_dfo_real_hybrid_workspace.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: real_pusht_image 4 | 5 | name: train_ibc_dfo_hybrid 6 | _target_: diffusion_policy.workspace.train_ibc_dfo_hybrid_workspace.TrainIbcDfoHybridWorkspace 7 | 8 | task_name: ${task.name} 9 | shape_meta: ${task.shape_meta} 10 | exp_name: "default" 11 | 12 | horizon: 2 13 | n_obs_steps: 2 14 | n_action_steps: 1 15 | n_latency_steps: 1 16 | dataset_obs_steps: ${n_obs_steps} 17 | past_action_visible: False 18 | keypoint_visible_rate: 1.0 19 | 20 | policy: 21 | _target_: diffusion_policy.policy.ibc_dfo_hybrid_image_policy.IbcDfoHybridImagePolicy 22 | 23 | shape_meta: ${shape_meta} 24 | 25 | horizon: ${horizon} 26 | n_action_steps: ${n_action_steps} 27 | n_obs_steps: ${n_obs_steps} 28 | dropout: 0.1 29 | train_n_neg: 256 30 | pred_n_iter: 3 31 | pred_n_samples: 1024 32 | kevin_inference: False 33 | andy_train: False 34 | obs_encoder_group_norm: True 35 | eval_fixed_crop: True 36 | crop_shape: [216, 288] # ch, cw 320x240 90% 37 | 38 | dataloader: 39 | batch_size: 128 40 | num_workers: 8 41 | shuffle: True 42 | pin_memory: True 43 | persistent_workers: False 44 | 45 | 
val_dataloader: 46 | batch_size: 128 47 | num_workers: 1 48 | shuffle: False 49 | pin_memory: True 50 | persistent_workers: False 51 | 52 | optimizer: 53 | _target_: torch.optim.AdamW 54 | lr: 1.0e-4 55 | betas: [0.95, 0.999] 56 | eps: 1.0e-8 57 | weight_decay: 1.0e-6 58 | 59 | training: 60 | device: "cuda:0" 61 | seed: 42 62 | debug: False 63 | resume: True 64 | # optimization 65 | lr_scheduler: cosine 66 | lr_warmup_steps: 500 67 | num_epochs: 1000 68 | gradient_accumulate_every: 1 69 | # training loop control 70 | # in epochs 71 | rollout_every: 50 72 | checkpoint_every: 5 73 | val_every: 1 74 | sample_every: 5 75 | sample_max_batch: 128 76 | # steps per epoch 77 | max_train_steps: null 78 | max_val_steps: null 79 | # misc 80 | tqdm_interval_sec: 1.0 81 | 82 | logging: 83 | project: diffusion_policy_debug 84 | resume: True 85 | mode: online 86 | name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 87 | tags: ["${name}", "${task_name}", "${exp_name}"] 88 | id: null 89 | group: null 90 | 91 | checkpoint: 92 | topk: 93 | monitor_key: train_action_mse_error 94 | mode: min 95 | k: 5 96 | format_str: 'epoch={epoch:04d}-train_action_mse_error={train_action_mse_error:.3f}.ckpt' 97 | save_last_ckpt: True 98 | save_last_snapshot: False 99 | 100 | multi_run: 101 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 102 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 103 | 104 | hydra: 105 | job: 106 | override_dirname: ${name} 107 | run: 108 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 109 | sweep: 110 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 111 | subdir: ${hydra.job.num} 112 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/train_robomimic_image_workspace.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: lift_image 4 | 5 | name: 
train_robomimic_image 6 | _target_: diffusion_policy.workspace.train_robomimic_image_workspace.TrainRobomimicImageWorkspace 7 | 8 | task_name: ${task.name} 9 | shape_meta: ${task.shape_meta} 10 | exp_name: "default" 11 | 12 | horizon: &horizon 10 13 | n_obs_steps: 1 14 | n_action_steps: 1 15 | n_latency_steps: 0 16 | dataset_obs_steps: *horizon 17 | past_action_visible: False 18 | keypoint_visible_rate: 1.0 19 | 20 | policy: 21 | _target_: diffusion_policy.policy.robomimic_image_policy.RobomimicImagePolicy 22 | shape_meta: ${shape_meta} 23 | algo_name: bc_rnn 24 | obs_type: image 25 | # oc.select resolver: key, default 26 | task_name: ${oc.select:task.task_name,lift} 27 | dataset_type: ${oc.select:task.dataset_type,ph} 28 | crop_shape: [76,76] 29 | 30 | dataloader: 31 | batch_size: 64 32 | num_workers: 16 33 | shuffle: True 34 | pin_memory: True 35 | persistent_workers: False 36 | 37 | val_dataloader: 38 | batch_size: 64 39 | num_workers: 16 40 | shuffle: False 41 | pin_memory: True 42 | persistent_workers: False 43 | 44 | training: 45 | device: "cuda:0" 46 | seed: 42 47 | debug: False 48 | resume: True 49 | # optimization 50 | num_epochs: 3050 51 | # training loop control 52 | # in epochs 53 | rollout_every: 50 54 | checkpoint_every: 50 55 | val_every: 1 56 | sample_every: 5 57 | # steps per epoch 58 | max_train_steps: null 59 | max_val_steps: null 60 | # misc 61 | tqdm_interval_sec: 1.0 62 | 63 | logging: 64 | project: diffusion_policy_debug 65 | resume: True 66 | mode: online 67 | name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 68 | tags: ["${name}", "${task_name}", "${exp_name}"] 69 | id: null 70 | group: null 71 | 72 | checkpoint: 73 | topk: 74 | monitor_key: test_mean_score 75 | mode: max 76 | k: 5 77 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 78 | save_last_ckpt: True 79 | save_last_snapshot: False 80 | 81 | multi_run: 82 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 83 | wandb_name_base: 
${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 84 | 85 | hydra: 86 | job: 87 | override_dirname: ${name} 88 | run: 89 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 90 | sweep: 91 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 92 | subdir: ${hydra.job.num} 93 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/train_robomimic_lowdim_workspace.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: pusht_lowdim 4 | 5 | name: train_robomimic_lowdim 6 | _target_: diffusion_policy.workspace.train_robomimic_lowdim_workspace.TrainRobomimicLowdimWorkspace 7 | 8 | obs_dim: ${task.obs_dim} 9 | action_dim: ${task.action_dim} 10 | transition_dim: "${eval: ${task.obs_dim} + ${task.action_dim}}" 11 | task_name: ${task.name} 12 | exp_name: "default" 13 | 14 | horizon: 10 15 | n_obs_steps: 1 16 | n_action_steps: 1 17 | n_latency_steps: 0 18 | past_action_visible: False 19 | keypoint_visible_rate: 1.0 20 | 21 | policy: 22 | _target_: diffusion_policy.policy.robomimic_lowdim_policy.RobomimicLowdimPolicy 23 | action_dim: ${action_dim} 24 | obs_dim: ${obs_dim} 25 | algo_name: bc_rnn 26 | obs_type: low_dim 27 | # oc.select resolver: key, default 28 | task_name: ${oc.select:task.task_name,lift} 29 | dataset_type: ${oc.select:task.dataset_type,ph} 30 | 31 | dataloader: 32 | batch_size: 256 33 | num_workers: 1 34 | shuffle: True 35 | pin_memory: True 36 | persistent_workers: False 37 | 38 | val_dataloader: 39 | batch_size: 256 40 | num_workers: 1 41 | shuffle: False 42 | pin_memory: True 43 | persistent_workers: False 44 | 45 | training: 46 | device: "cuda:0" 47 | seed: 42 48 | debug: False 49 | resume: True 50 | # optimization 51 | num_epochs: 5000 52 | # training loop control 53 | # in epochs 54 | rollout_every: 50 55 | checkpoint_every: 50 56 | val_every: 1 57 | # steps per epoch 58 | max_train_steps: 
null 59 | max_val_steps: null 60 | # misc 61 | tqdm_interval_sec: 1.0 62 | 63 | logging: 64 | project: diffusion_policy_debug 65 | resume: True 66 | mode: online 67 | name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 68 | tags: ["${name}", "${task_name}", "${exp_name}"] 69 | id: null 70 | group: null 71 | 72 | checkpoint: 73 | topk: 74 | monitor_key: test_mean_score 75 | mode: max 76 | k: 5 77 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 78 | save_last_ckpt: True 79 | save_last_snapshot: False 80 | 81 | multi_run: 82 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 83 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 84 | 85 | hydra: 86 | job: 87 | override_dirname: ${name} 88 | run: 89 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 90 | sweep: 91 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 92 | subdir: ${hydra.job.num} 93 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/config/train_robomimic_real_image_workspace.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: real_pusht_image 4 | 5 | name: train_robomimic_image 6 | _target_: diffusion_policy.workspace.train_robomimic_image_workspace.TrainRobomimicImageWorkspace 7 | 8 | task_name: ${task.name} 9 | shape_meta: ${task.shape_meta} 10 | exp_name: "default" 11 | 12 | horizon: &horizon 10 13 | n_obs_steps: 1 14 | n_action_steps: 1 15 | n_latency_steps: 1 16 | dataset_obs_steps: *horizon 17 | past_action_visible: False 18 | keypoint_visible_rate: 1.0 19 | 20 | policy: 21 | _target_: diffusion_policy.policy.robomimic_image_policy.RobomimicImagePolicy 22 | shape_meta: ${shape_meta} 23 | algo_name: bc_rnn 24 | obs_type: image 25 | # oc.select resolver: key, default 26 | task_name: ${oc.select:task.task_name,tool_hang} 27 | dataset_type: 
${oc.select:task.dataset_type,ph} 28 | crop_shape: [216, 288] # ch, cw 320x240 90% 29 | 30 | dataloader: 31 | batch_size: 32 32 | num_workers: 8 33 | shuffle: True 34 | pin_memory: True 35 | persistent_workers: True 36 | 37 | val_dataloader: 38 | batch_size: 32 39 | num_workers: 1 40 | shuffle: False 41 | pin_memory: True 42 | persistent_workers: False 43 | 44 | training: 45 | device: "cuda:0" 46 | seed: 42 47 | debug: False 48 | resume: True 49 | # optimization 50 | num_epochs: 1000 51 | # training loop control 52 | # in epochs 53 | rollout_every: 50 54 | checkpoint_every: 50 55 | val_every: 1 56 | sample_every: 5 57 | # steps per epoch 58 | max_train_steps: null 59 | max_val_steps: null 60 | # misc 61 | tqdm_interval_sec: 1.0 62 | 63 | logging: 64 | project: diffusion_policy_debug 65 | resume: True 66 | mode: online 67 | name: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 68 | tags: ["${name}", "${task_name}", "${exp_name}"] 69 | id: null 70 | group: null 71 | 72 | checkpoint: 73 | topk: 74 | monitor_key: train_loss 75 | mode: min 76 | k: 5 77 | format_str: 'epoch={epoch:04d}-train_loss={train_loss:.3f}.ckpt' 78 | save_last_ckpt: True 79 | save_last_snapshot: False 80 | 81 | multi_run: 82 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 83 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 84 | 85 | hydra: 86 | job: 87 | override_dirname: ${name} 88 | run: 89 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 90 | sweep: 91 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 92 | subdir: ${hydra.job.num} 93 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn 5 | from diffusion_policy.model.common.normalizer import LinearNormalizer 6 | 7 | class 
BaseLowdimDataset(torch.utils.data.Dataset): 8 | def get_validation_dataset(self) -> 'BaseLowdimDataset': 9 | # return an empty dataset by default 10 | return BaseLowdimDataset() 11 | 12 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 13 | raise NotImplementedError() 14 | 15 | def get_all_actions(self) -> torch.Tensor: 16 | raise NotImplementedError() 17 | 18 | def __len__(self) -> int: 19 | return 0 20 | 21 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 22 | """ 23 | output: 24 | obs: T, Do 25 | action: T, Da 26 | """ 27 | raise NotImplementedError() 28 | 29 | 30 | class BaseImageDataset(torch.utils.data.Dataset): 31 | def get_validation_dataset(self) -> 'BaseLowdimDataset': 32 | # return an empty dataset by default 33 | return BaseImageDataset() 34 | 35 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 36 | raise NotImplementedError() 37 | 38 | def get_all_actions(self) -> torch.Tensor: 39 | raise NotImplementedError() 40 | 41 | def __len__(self) -> int: 42 | return 0 43 | 44 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 45 | """ 46 | output: 47 | obs: 48 | key: T, * 49 | action: T, Da 50 | """ 51 | raise NotImplementedError() 52 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/env/pusht/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | import diffusion_policy.env.pusht 3 | 4 | register( 5 | id='pusht-keypoints-v0', 6 | entry_point='envs.pusht.pusht_keypoints_env:PushTKeypointsEnv', 7 | max_episode_steps=200, 8 | reward_threshold=1.0 9 | ) -------------------------------------------------------------------------------- /pusht/diffusion_policy/env/pusht/pusht_image_env.py: -------------------------------------------------------------------------------- 1 | from gym import spaces 2 | from diffusion_policy.env.pusht.pusht_env import PushTEnv 3 | import numpy as np 4 | 
import cv2 5 | 6 | class PushTImageEnv(PushTEnv): 7 | metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10} 8 | 9 | def __init__(self, 10 | legacy=False, 11 | block_cog=None, 12 | damping=None, 13 | render_size=96): 14 | super().__init__( 15 | legacy=legacy, 16 | block_cog=block_cog, 17 | damping=damping, 18 | render_size=render_size, 19 | render_action=False) 20 | ws = self.window_size 21 | self.observation_space = spaces.Dict({ 22 | 'image': spaces.Box( 23 | low=0, 24 | high=1, 25 | shape=(3,render_size,render_size), 26 | dtype=np.float32 27 | ), 28 | 'agent_pos': spaces.Box( 29 | low=0, 30 | high=ws, 31 | shape=(2,), 32 | dtype=np.float32 33 | ) 34 | }) 35 | self.render_cache = None 36 | 37 | def _get_obs(self): 38 | img = super()._render_frame(mode='rgb_array') 39 | 40 | agent_pos = np.array(self.agent.position) 41 | img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0) 42 | obs = { 43 | 'image': img_obs, 44 | 'agent_pos': agent_pos 45 | } 46 | 47 | # draw action 48 | if self.latest_action is not None: 49 | action = np.array(self.latest_action) 50 | coord = (action / 512 * 96).astype(np.int32) 51 | marker_size = int(8/96*self.render_size) 52 | thickness = int(1/96*self.render_size) 53 | cv2.drawMarker(img, coord, 54 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 55 | markerSize=marker_size, thickness=thickness) 56 | self.render_cache = img 57 | 58 | return obs 59 | 60 | def render(self, mode): 61 | assert mode == 'rgb_array' 62 | 63 | if self.render_cache is None: 64 | self._get_obs() 65 | 66 | return self.render_cache 67 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/env_runner/base_image_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from diffusion_policy.policy.base_image_policy import BaseImagePolicy 3 | 4 | class BaseImageRunner: 5 | def __init__(self, output_dir): 6 | self.output_dir = 
output_dir 7 | 8 | def run(self, policy: BaseImagePolicy) -> Dict: 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/env_runner/base_lowdim_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy 3 | 4 | class BaseLowdimRunner: 5 | def __init__(self, output_dir): 6 | self.output_dir = output_dir 7 | 8 | def run(self, policy: BaseLowdimPolicy) -> Dict: 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/env_runner/real_pusht_image_runner.py: -------------------------------------------------------------------------------- 1 | from diffusion_policy.policy.base_image_policy import BaseImagePolicy 2 | from diffusion_policy.env_runner.base_image_runner import BaseImageRunner 3 | 4 | class RealPushTImageRunner(BaseImageRunner): 5 | def __init__(self, 6 | output_dir): 7 | super().__init__(output_dir) 8 | 9 | def run(self, policy: BaseImagePolicy): 10 | return dict() 11 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/gym_util/video_recording_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | from diffusion_policy.real_world.video_recorder import VideoRecorder 4 | 5 | class VideoRecordingWrapper(gym.Wrapper): 6 | def __init__(self, 7 | env, 8 | video_recoder: VideoRecorder, 9 | mode='rgb_array', 10 | file_path=None, 11 | steps_per_render=1, 12 | **kwargs 13 | ): 14 | """ 15 | When file_path is None, don't record. 
16 | """ 17 | super().__init__(env) 18 | 19 | self.mode = mode 20 | self.render_kwargs = kwargs 21 | self.steps_per_render = steps_per_render 22 | self.file_path = file_path 23 | self.video_recoder = video_recoder 24 | 25 | self.step_count = 0 26 | 27 | def reset(self, **kwargs): 28 | obs = super().reset(**kwargs) 29 | self.frames = list() 30 | self.step_count = 1 31 | self.video_recoder.stop() 32 | return obs 33 | 34 | def step(self, action): 35 | result = super().step(action) 36 | self.step_count += 1 37 | if self.file_path is not None \ 38 | and ((self.step_count % self.steps_per_render) == 0): 39 | if not self.video_recoder.is_ready(): 40 | self.video_recoder.start(self.file_path) 41 | 42 | frame = self.env.render( 43 | mode=self.mode, **self.render_kwargs) 44 | assert frame.dtype == np.uint8 45 | self.video_recoder.write_frame(frame) 46 | return result 47 | 48 | def render(self, mode='rgb_array', **kwargs): 49 | if self.video_recoder.is_ready(): 50 | self.video_recoder.stop() 51 | return self.file_path 52 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/gym_util/video_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | class VideoWrapper(gym.Wrapper): 5 | def __init__(self, 6 | env, 7 | mode='rgb_array', 8 | enabled=True, 9 | steps_per_render=1, 10 | **kwargs 11 | ): 12 | super().__init__(env) 13 | 14 | self.mode = mode 15 | self.enabled = enabled 16 | self.render_kwargs = kwargs 17 | self.steps_per_render = steps_per_render 18 | 19 | self.frames = list() 20 | self.step_count = 0 21 | 22 | def reset(self, **kwargs): 23 | obs = super().reset(**kwargs) 24 | self.frames = list() 25 | self.step_count = 1 26 | if self.enabled: 27 | frame = self.env.render( 28 | mode=self.mode, **self.render_kwargs) 29 | assert frame.dtype == np.uint8 30 | self.frames.append(frame) 31 | return obs 32 | 33 | def step(self, action): 34 | 
result = super().step(action) 35 | self.step_count += 1 36 | if self.enabled and ((self.step_count % self.steps_per_render) == 0): 37 | frame = self.env.render( 38 | mode=self.mode, **self.render_kwargs) 39 | assert frame.dtype == np.uint8 40 | self.frames.append(frame) 41 | return result 42 | 43 | def render(self, mode='rgb_array', **kwargs): 44 | return self.frames 45 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/bet/action_ae/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | import abc 5 | 6 | from typing import Optional, Union 7 | 8 | import diffusion_policy.model.bet.utils as utils 9 | 10 | 11 | class AbstractActionAE(utils.SaveModule, abc.ABC): 12 | @abc.abstractmethod 13 | def fit_model( 14 | self, 15 | input_dataloader: DataLoader, 16 | eval_dataloader: DataLoader, 17 | obs_encoding_net: Optional[nn.Module] = None, 18 | ) -> None: 19 | pass 20 | 21 | @abc.abstractmethod 22 | def encode_into_latent( 23 | self, 24 | input_action: torch.Tensor, 25 | input_rep: Optional[torch.Tensor], 26 | ) -> torch.Tensor: 27 | """ 28 | Given the input action, discretize it. 29 | 30 | Inputs: 31 | input_action (shape: ... x action_dim): The input action to discretize. This can be in a batch, 32 | and is generally assumed that the last dimnesion is the action dimension. 33 | 34 | Outputs: 35 | discretized_action (shape: ... x num_tokens): The discretized action. 36 | """ 37 | raise NotImplementedError 38 | 39 | @abc.abstractmethod 40 | def decode_actions( 41 | self, 42 | latent_action_batch: Optional[torch.Tensor], 43 | input_rep_batch: Optional[torch.Tensor] = None, 44 | ) -> torch.Tensor: 45 | """ 46 | Given a discretized action, convert it to a continuous action. 47 | 48 | Inputs: 49 | latent_action_batch (shape: ... 
x num_tokens): The discretized action 50 | generated by the discretizer. 51 | 52 | Outputs: 53 | continuous_action (shape: ... x action_dim): The continuous action. 54 | """ 55 | raise NotImplementedError 56 | 57 | @property 58 | @abc.abstractmethod 59 | def num_latents(self) -> Union[int, float]: 60 | """ 61 | Number of possible latents for this generator, useful for state priors that use softmax. 62 | """ 63 | return float("inf") 64 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/bet/latent_generators/latent_generator.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from typing import Tuple, Optional 4 | 5 | import diffusion_policy.model.bet.utils as utils 6 | 7 | 8 | class AbstractLatentGenerator(abc.ABC, utils.SaveModule): 9 | """ 10 | Abstract class for a generative model that can generate latents given observation representations. 11 | 12 | In the probabilisitc sense, this model fits and samples from P(latent|observation) given some observation. 13 | """ 14 | 15 | @abc.abstractmethod 16 | def get_latent_and_loss( 17 | self, 18 | obs_rep: torch.Tensor, 19 | target_latents: torch.Tensor, 20 | seq_masks: Optional[torch.Tensor] = None, 21 | ) -> Tuple[torch.Tensor, torch.Tensor]: 22 | """ 23 | Given a set of observation representation and generated latents, get the encoded latent and the loss. 24 | 25 | Inputs: 26 | input_action: Batch of the actions taken in the multimodal demonstrations. 27 | target_latents: Batch of the latents that the generator should learn to generate the actions from. 28 | seq_masks: Batch of masks that indicate which timesteps are valid. 29 | 30 | Outputs: 31 | latent: The sampled latent from the observation. 32 | loss: The loss of the latent generator. 
33 | """ 34 | pass 35 | 36 | @abc.abstractmethod 37 | def generate_latents( 38 | self, seq_obses: torch.Tensor, seq_masks: torch.Tensor 39 | ) -> torch.Tensor: 40 | """ 41 | Given a batch of sequences of observations, generate a batch of sequences of latents. 42 | 43 | Inputs: 44 | seq_obses: Batch of sequences of observations, of shape seq x batch x dim, following the transformer convention. 45 | seq_masks: Batch of sequences of masks, of shape seq x batch, following the transformer convention. 46 | 47 | Outputs: 48 | seq_latents: Batch of sequences of latents of shape seq x batch x latent_dim. 49 | """ 50 | pass 51 | 52 | def get_optimizer( 53 | self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] 54 | ) -> torch.optim.Optimizer: 55 | """ 56 | Default optimizer class. Override this if you want to use a different optimizer. 57 | """ 58 | return torch.optim.Adam( 59 | self.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=betas 60 | ) 61 | 62 | 63 | class LatentGeneratorDataParallel(torch.nn.DataParallel): 64 | def get_latent_and_loss(self, *args, **kwargs): 65 | return self.module.get_latent_and_loss(*args, **kwargs) # type: ignore 66 | 67 | def generate_latents(self, *args, **kwargs): 68 | return self.module.generate_latents(*args, **kwargs) # type: ignore 69 | 70 | def get_optimizer(self, *args, **kwargs): 71 | return self.module.get_optimizer(*args, **kwargs) # type: ignore 72 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/bet/libraries/mingpt/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, 
publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | 9 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/bet/libraries/mingpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/pusht/diffusion_policy/model/bet/libraries/mingpt/__init__.py -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/bet/libraries/mingpt/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def set_seed(seed): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed_all(seed) 13 | 14 | 15 | def top_k_logits(logits, k): 16 | v, ix = torch.topk(logits, k) 17 | out = logits.clone() 18 | out[out < v[:, [-1]]] = -float("Inf") 19 | return out 20 | 21 | 22 | @torch.no_grad() 23 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 24 | """ 25 | take a conditioning sequence of indices in x (of shape 
(b,t)) and predict the next token in 26 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 27 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 28 | of block_size, unlike an RNN that has an infinite context window. 29 | """ 30 | block_size = model.get_block_size() 31 | model.eval() 32 | for k in range(steps): 33 | x_cond = ( 34 | x if x.size(1) <= block_size else x[:, -block_size:] 35 | ) # crop context if needed 36 | logits, _ = model(x_cond) 37 | # pluck the logits at the final step and scale by temperature 38 | logits = logits[:, -1, :] / temperature 39 | # optionally crop probabilities to only the top k options 40 | if top_k is not None: 41 | logits = top_k_logits(logits, top_k) 42 | # apply softmax to convert to probabilities 43 | probs = F.softmax(logits, dim=-1) 44 | # sample from the distribution or take the most likely 45 | if sample: 46 | ix = torch.multinomial(probs, num_samples=1) 47 | else: 48 | _, ix = torch.topk(probs, k=1, dim=-1) 49 | # append to the sequence and continue 50 | x = torch.cat((x, ix), dim=1) 51 | 52 | return x 53 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/common/dict_of_tensor_mixin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DictOfTensorMixin(nn.Module): 5 | def __init__(self, params_dict=None): 6 | super().__init__() 7 | if params_dict is None: 8 | params_dict = nn.ParameterDict() 9 | self.params_dict = params_dict 10 | 11 | @property 12 | def device(self): 13 | return next(iter(self.parameters())).device 14 | 15 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 16 | def dfs_add(dest, keys, value: torch.Tensor): 17 | if len(keys) == 1: 18 | dest[keys[0]] = value 19 | return 20 | 21 | if keys[0] not in dest: 22 | 
dest[keys[0]] = nn.ParameterDict() 23 | dfs_add(dest[keys[0]], keys[1:], value) 24 | 25 | def load_dict(state_dict, prefix): 26 | out_dict = nn.ParameterDict() 27 | for key, value in state_dict.items(): 28 | value: torch.Tensor 29 | if key.startswith(prefix): 30 | param_keys = key[len(prefix):].split('.')[1:] 31 | # if len(param_keys) == 0: 32 | # import pdb; pdb.set_trace() 33 | dfs_add(out_dict, param_keys, value.clone()) 34 | return out_dict 35 | 36 | self.params_dict = load_dict(state_dict, prefix + 'params_dict') 37 | self.params_dict.requires_grad_(False) 38 | return 39 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/common/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers.optimization import ( 2 | Union, SchedulerType, Optional, 3 | Optimizer, TYPE_TO_SCHEDULER_FUNCTION 4 | ) 5 | 6 | def get_scheduler( 7 | name: Union[str, SchedulerType], 8 | optimizer: Optimizer, 9 | num_warmup_steps: Optional[int] = None, 10 | num_training_steps: Optional[int] = None, 11 | **kwargs 12 | ): 13 | """ 14 | Added kwargs vs diffuser's original implementation 15 | 16 | Unified API to get any scheduler from its name. 17 | 18 | Args: 19 | name (`str` or `SchedulerType`): 20 | The name of the scheduler to use. 21 | optimizer (`torch.optim.Optimizer`): 22 | The optimizer that will be used during training. 23 | num_warmup_steps (`int`, *optional*): 24 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 25 | optional), the function will raise an error if it's unset and the scheduler type requires it. 26 | num_training_steps (`int``, *optional*): 27 | The number of training steps to do. This is not required by all schedulers (hence the argument being 28 | optional), the function will raise an error if it's unset and the scheduler type requires it. 
29 | """ 30 | name = SchedulerType(name) 31 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 32 | if name == SchedulerType.CONSTANT: 33 | return schedule_func(optimizer, **kwargs) 34 | 35 | # All other schedulers require `num_warmup_steps` 36 | if num_warmup_steps is None: 37 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 38 | 39 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 40 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) 41 | 42 | # All other schedulers require `num_training_steps` 43 | if num_training_steps is None: 44 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 45 | 46 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs) 47 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/common/module_attr_mixin.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ModuleAttrMixin(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self._dummy_variable = nn.Parameter() 7 | 8 | @property 9 | def device(self): 10 | return next(iter(self.parameters())).device 11 | 12 | @property 13 | def dtype(self): 14 | return next(iter(self.parameters())).dtype 15 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/common/shape_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | def get_module_device(m: nn.Module): 6 | device = torch.device('cpu') 7 | try: 8 | param = next(iter(m.parameters())) 9 | device = param.device 10 | except StopIteration: 11 | pass 12 | return device 13 | 14 | @torch.no_grad() 15 | def get_output_shape( 16 | input_shape: Tuple[int], 17 | 
net: Callable[[torch.Tensor], torch.Tensor] 18 | ): 19 | device = get_module_device(net) 20 | test_input = torch.zeros((1,)+tuple(input_shape), device=device) 21 | test_output = net(test_input) 22 | output_shape = tuple(test_output.shape[1:]) 23 | return output_shape 24 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/diffusion/conv1d_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from einops.layers.torch import Rearrange 5 | 6 | 7 | class Downsample1d(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 11 | 12 | def forward(self, x): 13 | return self.conv(x) 14 | 15 | class Upsample1d(nn.Module): 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 19 | 20 | def forward(self, x): 21 | return self.conv(x) 22 | 23 | class Conv1dBlock(nn.Module): 24 | ''' 25 | Conv1d --> GroupNorm --> Mish 26 | ''' 27 | 28 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 29 | super().__init__() 30 | 31 | self.block = nn.Sequential( 32 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 33 | # Rearrange('batch channels horizon -> batch channels 1 horizon'), 34 | nn.GroupNorm(n_groups, out_channels), 35 | # Rearrange('batch channels 1 horizon -> batch channels horizon'), 36 | nn.Mish(), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.block(x) 41 | 42 | 43 | def test(): 44 | cb = Conv1dBlock(256, 128, kernel_size=3) 45 | x = torch.zeros((1,256,16)) 46 | o = cb(x) 47 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/diffusion/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import 
torch 3 | import torch.nn as nn 4 | 5 | class SinusoidalPosEmb(nn.Module): 6 | def __init__(self, dim): 7 | super().__init__() 8 | self.dim = dim 9 | 10 | def forward(self, x): 11 | device = x.device 12 | half_dim = self.dim // 2 13 | emb = math.log(10000) / (half_dim - 1) 14 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 15 | emb = x[:, None] * emb[None, :] 16 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 17 | return emb 18 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/model/vision/model_getter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | def get_resnet(name, weights=None, **kwargs): 5 | """ 6 | name: resnet18, resnet34, resnet50 7 | weights: "IMAGENET1K_V1", "r3m" 8 | """ 9 | # load r3m weights 10 | if (weights == "r3m") or (weights == "R3M"): 11 | return get_r3m(name=name, **kwargs) 12 | 13 | func = getattr(torchvision.models, name) 14 | resnet = func(weights=weights, **kwargs) 15 | resnet.fc = torch.nn.Identity() 16 | return resnet 17 | 18 | def get_r3m(name, **kwargs): 19 | """ 20 | name: resnet18, resnet34, resnet50 21 | """ 22 | import r3m 23 | r3m.device = 'cpu' 24 | model = r3m.load_r3m(name) 25 | r3m_model = model.module 26 | resnet_model = r3m_model.convnet 27 | resnet_model = resnet_model.to('cpu') 28 | return resnet_model 29 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/policy/base_image_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin 5 | from diffusion_policy.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseImagePolicy(ModuleAttrMixin): 8 | # init accepts keyword argument shape_meta, see config/task/*_image.yaml 9 | 
10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 11 | """ 12 | obs_dict: 13 | str: B,To,* 14 | return: B,Ta,Da 15 | """ 16 | raise NotImplementedError() 17 | 18 | # reset state for stateful policies 19 | def reset(self): 20 | pass 21 | 22 | # ========== training =========== 23 | # no standard training interface except setting normalizer 24 | def set_normalizer(self, normalizer: LinearNormalizer): 25 | raise NotImplementedError() 26 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/policy/base_lowdim_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin 5 | from diffusion_policy.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseLowdimPolicy(ModuleAttrMixin): 8 | # ========= inference ============ 9 | # also as self.device and self.dtype for inference device transfer 10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 11 | """ 12 | obs_dict: 13 | obs: B,To,Do 14 | return: 15 | action: B,Ta,Da 16 | To = 3 17 | Ta = 4 18 | T = 6 19 | |o|o|o| 20 | | | |a|a|a|a| 21 | |o|o| 22 | | |a|a|a|a|a| 23 | | | | | |a|a| 24 | """ 25 | raise NotImplementedError() 26 | 27 | # reset state for stateful policies 28 | def reset(self): 29 | pass 30 | 31 | # ========== training =========== 32 | # no standard training interface except setting normalizer 33 | def set_normalizer(self, normalizer: LinearNormalizer): 34 | raise NotImplementedError() 35 | 36 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/real_world/keystroke_counter.py: -------------------------------------------------------------------------------- 1 | from pynput.keyboard import Key, KeyCode, Listener 2 | from collections 
import defaultdict 3 | from threading import Lock 4 | 5 | class KeystrokeCounter(Listener): 6 | def __init__(self): 7 | self.key_count_map = defaultdict(lambda:0) 8 | self.key_press_list = list() 9 | self.lock = Lock() 10 | super().__init__(on_press=self.on_press, on_release=self.on_release) 11 | 12 | def on_press(self, key): 13 | with self.lock: 14 | self.key_count_map[key] += 1 15 | self.key_press_list.append(key) 16 | 17 | def on_release(self, key): 18 | pass 19 | 20 | def clear(self): 21 | with self.lock: 22 | self.key_count_map = defaultdict(lambda:0) 23 | self.key_press_list = list() 24 | 25 | def __getitem__(self, key): 26 | with self.lock: 27 | return self.key_count_map[key] 28 | 29 | def get_press_events(self): 30 | with self.lock: 31 | events = list(self.key_press_list) 32 | self.key_press_list = list() 33 | return events 34 | 35 | if __name__ == '__main__': 36 | import time 37 | with KeystrokeCounter() as counter: 38 | try: 39 | while True: 40 | print('Space:', counter[Key.space]) 41 | print('q:', counter[KeyCode(char='q')]) 42 | time.sleep(1/60) 43 | except KeyboardInterrupt: 44 | events = counter.get_press_events() 45 | print(events) 46 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/real_world/multi_camera_visualizer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import multiprocessing as mp 3 | import numpy as np 4 | import cv2 5 | from threadpoolctl import threadpool_limits 6 | from diffusion_policy.real_world.multi_realsense import MultiRealsense 7 | 8 | class MultiCameraVisualizer(mp.Process): 9 | def __init__(self, 10 | realsense: MultiRealsense, 11 | row, col, 12 | window_name='Multi Cam Vis', 13 | vis_fps=60, 14 | fill_value=0, 15 | rgb_to_bgr=True 16 | ): 17 | super().__init__() 18 | self.row = row 19 | self.col = col 20 | self.window_name = window_name 21 | self.vis_fps = vis_fps 22 | self.fill_value = fill_value 23 | 
self.rgb_to_bgr=rgb_to_bgr 24 | self.realsense = realsense 25 | # shared variables 26 | self.stop_event = mp.Event() 27 | 28 | def start(self, wait=False): 29 | super().start() 30 | 31 | def stop(self, wait=False): 32 | self.stop_event.set() 33 | if wait: 34 | self.stop_wait() 35 | 36 | def start_wait(self): 37 | pass 38 | 39 | def stop_wait(self): 40 | self.join() 41 | 42 | def run(self): 43 | cv2.setNumThreads(1) 44 | threadpool_limits(1) 45 | channel_slice = slice(None) 46 | if self.rgb_to_bgr: 47 | channel_slice = slice(None,None,-1) 48 | 49 | vis_data = None 50 | vis_img = None 51 | while not self.stop_event.is_set(): 52 | vis_data = self.realsense.get_vis(out=vis_data) 53 | color = vis_data['color'] 54 | N, H, W, C = color.shape 55 | assert C == 3 56 | oh = H * self.row 57 | ow = W * self.col 58 | if vis_img is None: 59 | vis_img = np.full((oh, ow, 3), 60 | fill_value=self.fill_value, dtype=np.uint8) 61 | for row in range(self.row): 62 | for col in range(self.col): 63 | idx = col + row * self.col 64 | h_start = H * row 65 | h_end = h_start + H 66 | w_start = W * col 67 | w_end = w_start + W 68 | if idx < N: 69 | # opencv uses bgr 70 | vis_img[h_start:h_end,w_start:w_end 71 | ] = color[idx,:,:,channel_slice] 72 | cv2.imshow(self.window_name, vis_img) 73 | cv2.pollKey() 74 | time.sleep(1 / self.vis_fps) 75 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/real_world/real_inference_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, Tuple 2 | import numpy as np 3 | from diffusion_policy.common.cv2_util import get_image_transform 4 | 5 | def get_real_obs_dict( 6 | env_obs: Dict[str, np.ndarray], 7 | shape_meta: dict, 8 | ) -> Dict[str, np.ndarray]: 9 | obs_dict_np = dict() 10 | obs_shape_meta = shape_meta['obs'] 11 | for key, attr in obs_shape_meta.items(): 12 | type = attr.get('type', 'low_dim') 13 | shape = attr.get('shape') 14 | 
if type == 'rgb': 15 | this_imgs_in = env_obs[key] 16 | t,hi,wi,ci = this_imgs_in.shape 17 | co,ho,wo = shape 18 | assert ci == co 19 | out_imgs = this_imgs_in 20 | if (ho != hi) or (wo != wi) or (this_imgs_in.dtype == np.uint8): 21 | tf = get_image_transform( 22 | input_res=(wi,hi), 23 | output_res=(wo,ho), 24 | bgr_to_rgb=False) 25 | out_imgs = np.stack([tf(x) for x in this_imgs_in]) 26 | if this_imgs_in.dtype == np.uint8: 27 | out_imgs = out_imgs.astype(np.float32) / 255 28 | # THWC to TCHW 29 | obs_dict_np[key] = np.moveaxis(out_imgs,-1,1) 30 | elif type == 'low_dim': 31 | this_data_in = env_obs[key] 32 | if 'pose' in key and shape == (2,): 33 | # take X,Y coordinates 34 | this_data_in = this_data_in[...,[0,1]] 35 | obs_dict_np[key] = this_data_in 36 | return obs_dict_np 37 | 38 | 39 | def get_real_obs_resolution( 40 | shape_meta: dict 41 | ) -> Tuple[int, int]: 42 | out_res = None 43 | obs_shape_meta = shape_meta['obs'] 44 | for key, attr in obs_shape_meta.items(): 45 | type = attr.get('type', 'low_dim') 46 | shape = attr.get('shape') 47 | if type == 'rgb': 48 | co,ho,wo = shape 49 | if out_res is None: 50 | out_res = (wo, ho) 51 | assert out_res == (wo, ho) 52 | return out_res 53 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/scripts/bet_blockpush_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | 10 | import os 11 | import click 12 | import pathlib 13 | import numpy as np 14 | from diffusion_policy.common.replay_buffer import ReplayBuffer 15 | 16 | @click.command() 17 | @click.option('-i', '--input', required=True, help='input dir contains npy files') 18 | @click.option('-o', '--output', required=True, help='output zarr path') 19 | @click.option('--abs_action', 
is_flag=True, default=False) 20 | def main(input, output, abs_action): 21 | data_directory = pathlib.Path(input) 22 | observations = np.load( 23 | data_directory / "multimodal_push_observations.npy" 24 | ) 25 | actions = np.load(data_directory / "multimodal_push_actions.npy") 26 | masks = np.load(data_directory / "multimodal_push_masks.npy") 27 | 28 | buffer = ReplayBuffer.create_empty_numpy() 29 | for i in range(len(masks)): 30 | eps_len = int(masks[i].sum()) 31 | obs = observations[i,:eps_len].astype(np.float32) 32 | action = actions[i,:eps_len].astype(np.float32) 33 | if abs_action: 34 | prev_eef_target = obs[:,8:10] 35 | next_eef_target = prev_eef_target + action 36 | action = next_eef_target 37 | data = { 38 | 'obs': obs, 39 | 'action': action 40 | } 41 | buffer.add_episode(data) 42 | 43 | buffer.save_to_path(zarr_path=output, chunk_length=-1) 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/scripts/blockpush_abs_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import pathlib 12 | from diffusion_policy.common.replay_buffer import ReplayBuffer 13 | 14 | 15 | @click.command() 16 | @click.option('-i', '--input', required=True) 17 | @click.option('-o', '--output', required=True) 18 | @click.option('-t', '--target_eef_idx', default=8, type=int) 19 | def main(input, output, target_eef_idx): 20 | buffer = ReplayBuffer.copy_from_path(input) 21 | obs = buffer['obs'] 22 | action = buffer['action'] 23 | prev_eef_target = obs[:,target_eef_idx:target_eef_idx+action.shape[1]] 24 | next_eef_target = prev_eef_target + action 25 | action[:] = next_eef_target 26 | buffer.save_to_path(zarr_path=output, 
chunk_length=-1) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/scripts/episode_lengths.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import click 10 | import numpy as np 11 | import json 12 | from diffusion_policy.common.replay_buffer import ReplayBuffer 13 | 14 | @click.command() 15 | @click.option('--input', '-i', required=True) 16 | @click.option('--dt', default=0.1, type=float) 17 | def main(input, dt): 18 | buffer = ReplayBuffer.create_from_path(input) 19 | lengths = buffer.episode_lengths 20 | durations = lengths * dt 21 | result = { 22 | 'duration/mean': np.mean(durations) 23 | } 24 | 25 | text = json.dumps(result, indent=2) 26 | print(text) 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/scripts/generate_bet_blockpush.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | 10 | import os 11 | import click 12 | import pathlib 13 | import numpy as np 14 | from tqdm import tqdm 15 | from diffusion_policy.common.replay_buffer import ReplayBuffer 16 | from tf_agents.environments.wrappers import TimeLimit 17 | from tf_agents.environments.gym_wrapper import GymWrapper 18 | from tf_agents.trajectories.time_step import StepType 19 | from diffusion_policy.env.block_pushing.block_pushing_multimodal import BlockPushMultimodal 20 | from diffusion_policy.env.block_pushing.block_pushing import BlockPush 21 | from 
diffusion_policy.env.block_pushing.oracles.multimodal_push_oracle import MultimodalOrientedPushOracle 22 | 23 | @click.command() 24 | @click.option('-o', '--output', required=True) 25 | @click.option('-n', '--n_episodes', default=1000) 26 | @click.option('-c', '--chunk_length', default=-1) 27 | def main(output, n_episodes, chunk_length): 28 | 29 | buffer = ReplayBuffer.create_empty_numpy() 30 | env = TimeLimit(GymWrapper(BlockPushMultimodal()), duration=350) 31 | for i in tqdm(range(n_episodes)): 32 | print(i) 33 | obs_history = list() 34 | action_history = list() 35 | 36 | env.seed(i) 37 | policy = MultimodalOrientedPushOracle(env) 38 | time_step = env.reset() 39 | policy_state = policy.get_initial_state(1) 40 | while True: 41 | action_step = policy.action(time_step, policy_state) 42 | obs = np.concatenate(list(time_step.observation.values()), axis=-1) 43 | action = action_step.action 44 | obs_history.append(obs) 45 | action_history.append(action) 46 | 47 | if time_step.step_type == 2: 48 | break 49 | 50 | # state = env.wrapped_env().gym.get_pybullet_state() 51 | time_step = env.step(action) 52 | obs_history = np.array(obs_history) 53 | action_history = np.array(action_history) 54 | 55 | episode = { 56 | 'obs': obs_history, 57 | 'action': action_history 58 | } 59 | buffer.add_episode(episode) 60 | 61 | buffer.save_to_path(output, chunk_length=chunk_length) 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/scripts/real_dataset_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import pathlib 12 | import zarr 13 | import cv2 14 | import threadpoolctl 15 | from 
diffusion_policy.real_world.real_data_conversion import real_data_to_replay_buffer 16 | 17 | @click.command() 18 | @click.option('--input', '-i', required=True) 19 | @click.option('--output', '-o', default=None) 20 | @click.option('--resolution', '-r', default='640x480') 21 | @click.option('--n_decoding_threads', '-nd', default=-1, type=int) 22 | @click.option('--n_encoding_threads', '-ne', default=-1, type=int) 23 | def main(input, output, resolution, n_decoding_threads, n_encoding_threads): 24 | out_resolution = tuple(int(x) for x in resolution.split('x')) 25 | input = pathlib.Path(os.path.expanduser(input)) 26 | in_zarr_path = input.joinpath('replay_buffer.zarr') 27 | in_video_dir = input.joinpath('videos') 28 | assert in_zarr_path.is_dir() 29 | assert in_video_dir.is_dir() 30 | if output is None: 31 | output = input.joinpath(resolution + '.zarr.zip') 32 | else: 33 | output = pathlib.Path(os.path.expanduser(output)) 34 | 35 | if output.exists(): 36 | click.confirm('Output path already exists! 
Overrite?', abort=True) 37 | 38 | cv2.setNumThreads(1) 39 | with threadpoolctl.threadpool_limits(1): 40 | replay_buffer = real_data_to_replay_buffer( 41 | dataset_path=str(input), 42 | out_resolutions=out_resolution, 43 | n_decoding_threads=n_decoding_threads, 44 | n_encoding_threads=n_encoding_threads 45 | ) 46 | 47 | print('Saving to disk') 48 | if output.suffix == '.zip': 49 | with zarr.ZipStore(output) as zip_store: 50 | replay_buffer.save_to_store( 51 | store=zip_store 52 | ) 53 | else: 54 | with zarr.DirectoryStore(output) as store: 55 | replay_buffer.save_to_store( 56 | store=store 57 | ) 58 | 59 | if __name__ == '__main__': 60 | main() 61 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/scripts/real_pusht_successrate.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import collections 12 | import numpy as np 13 | from tqdm import tqdm 14 | import json 15 | 16 | @click.command() 17 | @click.option( 18 | '--reference', '-r', required=True, 19 | help='Reference metrics_raw.json from demonstration dataset.' 
20 | ) 21 | @click.option( 22 | '--input', '-i', required=True, 23 | help='Data search path' 24 | ) 25 | def main(reference, input): 26 | # compute the min last metric for demo metrics 27 | demo_metrics = json.load(open(reference, 'r')) 28 | demo_min_metrics = collections.defaultdict(lambda:float('inf')) 29 | for episode_idx, metrics in demo_metrics.items(): 30 | for key, value in metrics.items(): 31 | last_value = value[-1] 32 | demo_min_metrics[key] = min(demo_min_metrics[key], last_value) 33 | print(demo_min_metrics) 34 | 35 | # find all metric 36 | name = 'metrics_raw.json' 37 | search_dir = pathlib.Path(input) 38 | success_rate_map = dict() 39 | for json_path in search_dir.glob('**/'+name): 40 | rel_path = json_path.relative_to(search_dir) 41 | rel_name = str(rel_path.parent) 42 | this_metrics = json.load(json_path.open('r')) 43 | metric_success_idxs = collections.defaultdict(list) 44 | metric_failure_idxs = collections.defaultdict(list) 45 | for episode_idx, metrics in this_metrics.items(): 46 | for key, value in metrics.items(): 47 | last_value = value[-1] 48 | # print(episode_idx, key, last_value) 49 | demo_min = demo_min_metrics[key] 50 | if last_value >= demo_min: 51 | # success 52 | metric_success_idxs[key].append(episode_idx) 53 | else: 54 | metric_failure_idxs[key].append(episode_idx) 55 | # in case of no success 56 | _ = metric_success_idxs[key] 57 | _ = metric_failure_idxs[key] 58 | metric_success_rate = dict() 59 | n_episodes = len(this_metrics) 60 | for key, value in metric_success_idxs.items(): 61 | metric_success_rate[key] = len(value) / n_episodes 62 | # metric_success_rate['failured_idxs'] = metric_failure_idxs 63 | success_rate_map[rel_name] = metric_success_rate 64 | 65 | text = json.dumps(success_rate_map, indent=2) 66 | print(text) 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/scripts/robomimic_dataset_action_comparison.py: 
-------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import pathlib 12 | import h5py 13 | import numpy as np 14 | from tqdm import tqdm 15 | from scipy.spatial.transform import Rotation 16 | 17 | def read_all_actions(hdf5_file, metric_skip_steps=1): 18 | n_demos = len(hdf5_file['data']) 19 | all_actions = list() 20 | for i in tqdm(range(n_demos)): 21 | actions = hdf5_file[f'data/demo_{i}/actions'][:] 22 | all_actions.append(actions[metric_skip_steps:]) 23 | all_actions = np.concatenate(all_actions, axis=0) 24 | return all_actions 25 | 26 | 27 | @click.command() 28 | @click.option('-i', '--input', required=True, help='input hdf5 path') 29 | @click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist') 30 | def main(input, output): 31 | # process inputs 32 | input = pathlib.Path(input).expanduser() 33 | assert input.is_file() 34 | output = pathlib.Path(output).expanduser() 35 | assert output.is_file() 36 | 37 | input_file = h5py.File(str(input), 'r') 38 | output_file = h5py.File(str(output), 'r') 39 | 40 | input_all_actions = read_all_actions(input_file) 41 | output_all_actions = read_all_actions(output_file) 42 | pos_dist = np.linalg.norm(input_all_actions[:,:3] - output_all_actions[:,:3], axis=-1) 43 | rot_dist = (Rotation.from_rotvec(input_all_actions[:,3:6] 44 | ) * Rotation.from_rotvec(output_all_actions[:,3:6]).inv() 45 | ).magnitude() 46 | 47 | print(f'max pos dist: {pos_dist.max()}') 48 | print(f'max rot dist: {rot_dist.max()}') 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /pusht/diffusion_policy/shared_memory/shared_memory_util.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from dataclasses import dataclass 3 | import numpy as np 4 | from multiprocessing.managers import SharedMemoryManager 5 | from atomics import atomicview, MemoryOrder, UINT 6 | 7 | @dataclass 8 | class ArraySpec: 9 | name: str 10 | shape: Tuple[int] 11 | dtype: np.dtype 12 | 13 | 14 | class SharedAtomicCounter: 15 | def __init__(self, 16 | shm_manager: SharedMemoryManager, 17 | size :int=8 # 64bit int 18 | ): 19 | shm = shm_manager.SharedMemory(size=size) 20 | self.shm = shm 21 | self.size = size 22 | self.store(0) # initialize 23 | 24 | @property 25 | def buf(self): 26 | return self.shm.buf[:self.size] 27 | 28 | def load(self) -> int: 29 | with atomicview(buffer=self.buf, atype=UINT) as a: 30 | value = a.load(order=MemoryOrder.ACQUIRE) 31 | return value 32 | 33 | def store(self, value: int): 34 | with atomicview(buffer=self.buf, atype=UINT) as a: 35 | a.store(value, order=MemoryOrder.RELEASE) 36 | 37 | def add(self, value: int): 38 | with atomicview(buffer=self.buf, atype=UINT) as a: 39 | a.add(value, order=MemoryOrder.ACQ_REL) 40 | -------------------------------------------------------------------------------- /pusht/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | Training: 4 | python train.py --config-name=train_diffusion_lowdim_workspace 5 | """ 6 | 7 | import sys 8 | # use line-buffering for both stdout and stderr 9 | sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1) 10 | sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1) 11 | 12 | import hydra 13 | from omegaconf import OmegaConf 14 | import pathlib 15 | from diffusion_policy.workspace.base_workspace import BaseWorkspace 16 | 17 | # allows arbitrary python code execution in configs using the ${eval:''} resolver 18 | OmegaConf.register_new_resolver("eval", eval, replace=True) 19 | 20 | @hydra.main( 21 | 
version_base=None, 22 | config_path=str(pathlib.Path(__file__).parent.joinpath( 23 | 'diffusion_policy','config')) 24 | ) 25 | def main(cfg: OmegaConf): 26 | # resolve immediately so all the ${now:} resolvers 27 | # will use the same time. 28 | OmegaConf.resolve(cfg) 29 | 30 | cls = hydra.utils.get_class(cfg._target_) 31 | workspace: BaseWorkspace = cls(cfg) 32 | workspace.run() 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /real-robot/arp.py: -------------------------------------------------------------------------------- 1 | ../arp.py -------------------------------------------------------------------------------- /real-robot/config.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: default 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | 9 | 10 | output_dir: ${hydra:run.dir} 11 | wandb: null 12 | device: 0 13 | 14 | data: 15 | folder: './data/demonstrations' 16 | train_episodes: [78, 4, 72, 37, 16, 61, 59, 23, 66, 11, 3, 9, 75, 71, 7, 34, 43, 49, 68, 62, 15, 27, 50, 57, 20, 18, 12, 65, 44, 33, 39, 76, 0, 46, 31, 74, 8, 2, 28, 55, 22, 10, 67, 60, 17, 58, 25, 52, 73, 5, 79, 36, 38, 32, 45, 1, 77, 21, 56, 64, 13, 19, 14, 63, 69, 6, 48, 42, 35] 17 | eval_episodes: [51, 26, 70, 54, 29, 30, 47, 40, 53, 24] 18 | 19 | train: 20 | lr: 5e-5 # 1.25e-5 21 | batch_size: 16 22 | num_workers: 6 23 | 24 | lambda_weight_l2: 1e-4 25 | warmup_steps: 1000 26 | num_steps: 50000 27 | save_freq: 1000 28 | eval_freq: 1000 29 | disp_freq: 100 30 | 31 | 32 | model: 33 | image_size: 420 34 | patch_size: 14 35 | hidden_dim: 128 36 | dropout: 0.1 37 | pre_norm: False 38 | feedforward_activation: relu 39 | dim_feedforward: 256 40 | n_heads: 8 41 | n_encoder_layers: 8 42 | 43 | depth: 4 44 | 45 | max_seq_len: 40 46 | max_chunk_size: 40 47 | 48 | trans_aug_range: [0.1, 0.05, 0.05] 49 | rot_aug_range: 
[0, 0, 45] 50 | 51 | -------------------------------------------------------------------------------- /real-robot/eef_T_wrenchHead.txt: -------------------------------------------------------------------------------- 1 | -9.981230952413527868e-01 -6.016334860001502222e-02 -1.145471021895335534e-02 -1.265643587985196272e-01 2 | -6.013419166226659229e-02 9.981870760766909934e-01 -2.910802394098028177e-03 2.027579258541842289e-02 3 | 1.160872906480183743e-02 -2.216438772415468654e-03 -9.999307399171909472e-01 1.308345914334110016e-01 4 | 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 5 | -------------------------------------------------------------------------------- /real-robot/eef_sm_T_wrenchHead.txt: -------------------------------------------------------------------------------- 1 | -9.981230952413527868e-01 -6.016334860001502222e-02 -1.145471021895335534e-02 -2.027579258541842289e-02 2 | -6.013419166226659229e-02 9.981870760766909934e-01 -2.910802394098028177e-03 -2.027579258541842289e-02 3 | 1.160872906480183743e-02 -2.216438772415468654e-03 -9.999307399171909472e-01 0 4 | 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 5 | -------------------------------------------------------------------------------- /real-robot/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | ssl._create_default_https_context = ssl._create_unverified_context 3 | import hydra 4 | import os.path as osp 5 | import sys 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | def configurable(config_path="config/default.yaml"): 9 | def wrapper(main_func): 10 | config_path_arg = None 11 | i = 1 12 | for a in sys.argv[1:]: 13 | if a.startswith("config="): 14 | config_path_arg = a.split("=")[-1] 15 | sys.argv.pop(i) 16 | i += 1 17 | break 18 | if config_path_arg is None: 19 | config_path_arg = config_path 20 | assert 
config_path_arg, "config file must be given by `config=path/to/file`" 21 | main_wrapper = hydra.main(config_path=osp.abspath(osp.dirname(config_path_arg)), 22 | config_name=osp.splitext(osp.basename(config_path_arg))[0], 23 | version_base=None) 24 | return main_wrapper(main_func) 25 | return wrapper 26 | 27 | 28 | 29 | def load_hydra_config(config_path, overrides=[]): 30 | 31 | from hydra import compose, initialize 32 | from omegaconf import OmegaConf 33 | 34 | with initialize(version_base=None, config_path=osp.dirname(config_path), job_name="load_config"): 35 | cfg = compose(config_name=osp.splitext(osp.basename(config_path))[0], overrides=overrides) 36 | return cfg 37 | 38 | 39 | def config_to_dict(cfg): 40 | return OmegaConf.to_container(cfg) 41 | 42 | -------------------------------------------------------------------------------- /real-robot/utils/object.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | from pathlib import Path 4 | 5 | def load_pkl(fp): 6 | if isinstance(fp, (str, Path)): 7 | with open(fp, 'rb') as f: 8 | return pickle.load(f) 9 | else: 10 | return pickle.load(fp) 11 | 12 | 13 | def to_device(lst, dev): 14 | if isinstance(lst, torch.Tensor): 15 | return lst.to(dev) 16 | else: 17 | if isinstance(lst[0], torch.Tensor): 18 | return [v.to(dev) for v in lst] 19 | else: 20 | return torch.as_tensor(lst, device=dev) -------------------------------------------------------------------------------- /real-robot/utils/stat_dict.py: -------------------------------------------------------------------------------- 1 | from runstats import Statistics 2 | from copy import copy 3 | 4 | def item(v): 5 | if hasattr(v, 'item'): 6 | return v.item() 7 | else: 8 | return v 9 | 10 | class StatisticsDict: 11 | def __init__(self) -> None: 12 | self._run_stats = {} 13 | self._loss_dict = {} 14 | self._stat_dict = {} 15 | 16 | def push(self, loss_dict): 17 | for k in loss_dict.keys(): 18 | if 'loss' 
in k and k not in self._run_stats: 19 | self._run_stats[k] = Statistics() 20 | self._loss_dict = copy(loss_dict) 21 | 22 | self._stat_dict.clear() 23 | for k, _ in self._run_stats.items(): 24 | if k in loss_dict: 25 | self._run_stats[k].push(item(loss_dict[k])) 26 | self._stat_dict[k] = self._run_stats[k].mean() 27 | 28 | @property 29 | def current(self): 30 | return self._loss_dict 31 | 32 | @property 33 | def running(self): 34 | return self._stat_dict 35 | 36 | def reset(self): 37 | self._run_stats.clear() 38 | self._loss_dict.clear() 39 | self._stat_dict.clear() 40 | -------------------------------------------------------------------------------- /real-robot/utils/vis.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | def im_concat(*imgs): 4 | dst = Image.new('RGB', (sum(im.width for im in imgs), imgs[0].height)) 5 | for i, im in enumerate(imgs): 6 | dst.paste(im, (i * imgs[0].width, 0)) 7 | return dst -------------------------------------------------------------------------------- /real-robot/wrench.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial.transform import Rotation as R 2 | import numpy as np 3 | 4 | eef_T_wrenchHead = np.loadtxt("eef_T_wrenchHead.txt") # (4, 4) 5 | 6 | eef_sm_T_wrenchHead = np.loadtxt("eef_sm_T_wrenchHead.txt") # (4, 4) 7 | 8 | 9 | def wrench_head_to_eef(wrenchHead_T_base: np.array) -> np.array: # (..., 4, 4) 10 | r""" 11 | convert wrench pose in the base's frame to the end-effector's pose in the base frame 12 | """ 13 | eef_T_base = wrenchHead_T_base @ eef_T_wrenchHead 14 | return eef_T_base # (..., 4, 4) 15 | 16 | def eef_to_wrench_head(eef_T_base: np.array) -> np.array: # (..., 4, 4) 17 | r""" 18 | convert end-effector pose in the base's frame to the wrench pose in the base frame 19 | """ 20 | wrenchHead_T_base = eef_T_base @ np.linalg.inv(eef_T_wrenchHead) 21 | return wrenchHead_T_base # (..., 4, 4) 22 | 23 
| def pose_to_T(pose_vec: np.array) -> np.array: 24 | r""" 25 | pose_vec[np.array]: [x, y, z, qx, qy, qz, qw] 26 | --- 27 | T[np.array]: (4, 4) 28 | """ 29 | T = np.eye(4) 30 | T[:3, 3] = pose_vec[:3] 31 | T[:3, :3] = R.from_quat(pose_vec[3:]).as_matrix() 32 | return T 33 | 34 | def T_to_pose(T: np.array) -> np.array: 35 | r""" 36 | T[np.array]: (4, 4) 37 | --- 38 | pose_vec[np.array]: [x, y, z, qx, qy, qz, qw] 39 | """ 40 | pose_vec = np.zeros(7) 41 | pose_vec[:3] = T[:3, 3] 42 | pose_vec[3:] = R.from_matrix(T[:3, :3]).as_quat() 43 | return pose_vec 44 | 45 | 46 | def p3_to_p7(p3): 47 | x = np.zeros(7) 48 | x[:3] = p3 49 | x[-1]= 1.0 50 | return x 51 | 52 | 53 | def wrench_pose7_to_eef_pose7(p7): 54 | return T_to_pose(wrench_head_to_eef(pose_to_T(p7))) 55 | 56 | 57 | def wrench_pose3_to_eef_pose3(p3): 58 | p7 = p3_to_p7(p3) 59 | return T_to_pose(wrench_head_to_eef(pose_to_T(p7)))[:3] 60 | 61 | 62 | def eef_pose7_to_wrench_pose7(p7): 63 | return T_to_pose(eef_to_wrench_head(pose_to_T(p7))) 64 | 65 | 66 | def wrench_pose7_to_sm_eef_pose7(p7): 67 | return T_to_pose(pose_to_T(p7) @ np.linalg.inv(eef_sm_T_wrenchHead)) 68 | 69 | -------------------------------------------------------------------------------- /rlb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlzxy/arp/73cb339c032b3c350966269f4729bd1f29e927a9/rlb/__init__.py -------------------------------------------------------------------------------- /rlb/arp.py: -------------------------------------------------------------------------------- 1 | ../arp.py -------------------------------------------------------------------------------- /rlb/configs/act.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: default 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | # ================================================ # 9 | 10 | py_module: 
act_policy 11 | 12 | model: 13 | weights: null 14 | 15 | hp: 16 | add_corr: true 17 | add_depth: true 18 | add_lang: true 19 | add_pixel_loc: true 20 | add_proprio: true 21 | attn_dim: 512 22 | attn_dim_head: 64 23 | attn_dropout: 0.1 24 | attn_heads: 8 25 | depth: 8 26 | feat_dim: 220 # 72*3 + 4 27 | im_channels: 64 28 | point_augment_noise: 0.05 29 | img_feat_dim: 3 30 | img_patch_size: 14 31 | img_size: 224 32 | lang_dim: 512 33 | lang_len: 77 34 | norm_corr: true 35 | pe_fix: true 36 | proprio_dim: 3 # 4 # 18 37 | mvt_cameras: ['top', 'left', 'front'] 38 | stage2_zoom_scale: 4 39 | stage2_waypoint_label_noise: 0.05 40 | rotation_aug: #null 41 | - [-2, -1, 0, -1, -2] 42 | - [0.1, 0.2, 0.4, 0.2, 0.1] 43 | use_xformers: true 44 | 45 | gt_hm_sigma: 1.5 46 | move_pc_in_bound: true 47 | place_with_mean: false 48 | 49 | amp: True 50 | bnb: True 51 | 52 | # lr should be thought on per sample basis 53 | # effective lr is multiplied by bs * num_devices 54 | lr: 1.25e-5 #1.25e-5 # 1e-4 55 | warmup_steps: 2000 56 | optimizer_type: lamb 57 | lr_cos_dec: true 58 | add_rgc_loss: true 59 | transform_augmentation: true 60 | transform_augmentation_xyz: [0.125, 0.125, 0.125] 61 | transform_augmentation_rpy: [0.0, 0.0, 45.0] 62 | lambda_weight_l2: 1e-4 # 1e-6 63 | num_rotation_classes: 72 64 | 65 | cos_dec_max_step: -1 # will be override during training 66 | 67 | render_with_cpp: true 68 | 69 | 70 | 71 | 72 | env: 73 | tasks: all # stack_cups 74 | cameras: ["front", "left_shoulder", "right_shoulder", "wrist"] 75 | scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] # [x_min, y_min, z_min, x_max, y_max, z_max] - the metric volume to be voxelized 76 | image_size: 128 77 | time_in_state: false 78 | voxel_size: 100 79 | episode_length: 25 80 | rotation_resolution: 5 81 | origin_style_state: true 82 | 83 | train: 84 | bs: 96 # 48 85 | demo_folder: ./data/train 86 | epochs: 100 # 100 87 | num_gpus: 4 88 | num_workers: 8 #, need larger value 89 | num_transitions_per_epoch: 160000 90 | 
disp_freq: 100 91 | cached_dataset_path: null 92 | save_freq: 10000 93 | eval_mode: false 94 | k2k_sample_ratios: 95 | place_cups: 1.0 96 | stack_cups: 1.0 97 | close_jar: 1.0 98 | push_buttons: 1.0 99 | meat_off_grill: 1.0 100 | stack_blocks: 1.0 101 | reach_and_drag: 1.0 102 | slide_block_to_color_target: 1.0 103 | place_shape_in_shape_sorter: 1.0 104 | open_drawer: 1.0 105 | sweep_to_dustpan_of_size: 1.0 106 | put_groceries_in_cupboard: 1.0 107 | light_bulb_in: 1.0 108 | turn_tap: 1.0 109 | insert_onto_square_peg: 1.0 110 | put_item_in_drawer: 1.0 111 | put_money_in_safe: 1.0 112 | place_wine_at_rack_location: 1.0 113 | 114 | eval: 115 | datafolder: ./data/test 116 | episode_num: 25 117 | start_episode: 0 118 | headless: true 119 | save_video: false 120 | device: 0 121 | 122 | 123 | 124 | output_dir: ${hydra:run.dir} 125 | wandb: null 126 | -------------------------------------------------------------------------------- /rlb/configs/rvt1.official.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: default 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | # ================================================ # 9 | 10 | py_module: rvt 11 | 12 | model: 13 | weights: null 14 | 15 | hp: 16 | depth: 8 17 | img_size: 220 18 | add_proprio: true 19 | proprio_dim: 4 # 4, time is not part of the state 20 | add_lang: true 21 | lang_dim: 512 22 | lang_len: 77 23 | img_feat_dim: 3 24 | feat_dim: 220 # (72 * 3) + 2 + 2 25 | im_channels: 64 26 | attn_dim: 512 27 | attn_heads: 8 28 | attn_dim_head: 64 29 | activation: "lrelu" 30 | weight_tie_layers: false 31 | attn_dropout: 0.1 32 | decoder_dropout: 0.0 33 | img_patch_size: 11 34 | final_dim: 64 35 | self_cross_ver: 1 36 | add_corr: true 37 | add_pixel_loc: true 38 | add_depth: true 39 | pe_fix: true 40 | place_with_mean: true 41 | gt_hm_sigma: 1.5 42 | augmentation_ratio: 0.1 43 | move_pc_in_bound: True 44 | 45 | # lr should 
be thought on per sample basis 46 | # effective lr is multiplied by bs * num_devices 47 | lr: 1e-4 48 | warmup_steps: 2000 49 | optimizer_type: lamb 50 | lr_cos_dec: true 51 | add_rgc_loss: true 52 | transform_augmentation: true 53 | transform_augmentation_xyz: [0.125, 0.125, 0.125] 54 | transform_augmentation_rpy: [0.0, 0.0, 45.0] 55 | lambda_weight_l2: 1e-6 56 | num_rotation_classes: 72 57 | 58 | cos_dec_max_step: -1 # will be override during training 59 | 60 | env: 61 | tasks: all 62 | cameras: ["front", "left_shoulder", "right_shoulder", "wrist"] 63 | scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] # [x_min, y_min, z_min, x_max, y_max, z_max] - the metric volume to be voxelized 64 | image_size: 128 65 | time_in_state: true 66 | voxel_size: 100 67 | episode_length: 25 68 | rotation_resolution: 5 69 | origin_style_state: true 70 | 71 | 72 | train: 73 | bs: 6 74 | demo_folder: ./data/train 75 | epochs: 15 76 | num_gpus: 1 77 | num_workers: 8 # need larger value 78 | num_transitions_per_epoch: 160000 79 | disp_freq: 100 80 | cached_dataset_path: null 81 | save_freq: 10000 82 | eval_mode: false 83 | k2k_sample_ratios: 84 | place_cups: 1.0 85 | stack_cups: 1.0 86 | close_jar: 1.0 87 | push_buttons: 1.0 88 | meat_off_grill: 1.0 89 | stack_blocks: 1.0 90 | reach_and_drag: 1.0 91 | slide_block_to_color_target: 1.0 92 | place_shape_in_shape_sorter: 1.0 93 | open_drawer: 1.0 94 | sweep_to_dustpan_of_size: 1.0 95 | put_groceries_in_cupboard: 1.0 96 | light_bulb_in: 1.0 97 | turn_tap: 1.0 98 | insert_onto_square_peg: 1.0 99 | put_item_in_drawer: 1.0 100 | put_money_in_safe: 1.0 101 | place_wine_at_rack_location: 1.0 102 | 103 | 104 | eval: 105 | datafolder: ./data/test 106 | episode_num: 25 107 | start_episode: 0 108 | headless: true 109 | save_video: false 110 | device: 0 111 | 112 | 113 | output_dir: ${hydra:run.dir} 114 | wandb: null -------------------------------------------------------------------------------- /rlb/configs/rvt2.official.yaml: 
-------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: default 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | # ================================================ # 9 | 10 | py_module: rvt2 11 | 12 | model: 13 | weights: null 14 | 15 | hp: 16 | add_corr: true 17 | add_depth: true 18 | add_lang: true 19 | add_pixel_loc: true 20 | add_proprio: true 21 | attn_dim: 512 22 | attn_dim_head: 64 23 | attn_dropout: 0.1 24 | attn_heads: 8 25 | depth: 8 26 | feat_dim: 220 # 72*3 + 4 27 | im_channels: 64 28 | point_augment_noise: 0.05 29 | img_feat_dim: 3 30 | img_patch_size: 14 31 | img_size: 224 32 | lang_dim: 512 33 | lang_len: 77 34 | norm_corr: true 35 | pe_fix: true 36 | proprio_dim: 4 # 37 | mvt_cameras: ['top', 'left', 'front'] 38 | stage2_zoom_scale: 4 39 | stage2_waypoint_label_noise: 0.05 40 | rot_x_y_aug: 2 41 | use_xformers: true 42 | 43 | gt_hm_sigma: 1.5 44 | move_pc_in_bound: true 45 | place_with_mean: false 46 | 47 | amp: True 48 | bnb: True 49 | 50 | # lr should be thought on per sample basis 51 | # effective lr is multiplied by bs * num_devices 52 | lr: 1.25e-5 # 1e-4 53 | warmup_steps: 2000 54 | optimizer_type: lamb 55 | lr_cos_dec: true 56 | add_rgc_loss: true 57 | transform_augmentation: true 58 | transform_augmentation_xyz: [0.125, 0.125, 0.125] 59 | transform_augmentation_rpy: [0.0, 0.0, 45.0] 60 | lambda_weight_l2: 1e-4 # 1e-6 61 | num_rotation_classes: 72 62 | 63 | cos_dec_max_step: -1 # will be override during training 64 | 65 | render_with_cpp: true 66 | 67 | env: 68 | tasks: all 69 | cameras: ["front", "left_shoulder", "right_shoulder", "wrist"] 70 | scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] # [x_min, y_min, z_min, x_max, y_max, z_max] - the metric volume to be voxelized 71 | image_size: 128 72 | time_in_state: true 73 | voxel_size: 100 74 | episode_length: 25 75 | rotation_resolution: 5 76 | origin_style_state: true 77 | 78 | train: 79 | bs: 48 
80 | demo_folder: ./data/train 81 | epochs: 100 # 100 82 | num_gpus: 4 83 | num_workers: 8 # need larger value 84 | num_transitions_per_epoch: 160000 85 | disp_freq: 100 86 | cached_dataset_path: null 87 | save_freq: 10000 88 | eval_mode: false 89 | k2k_sample_ratios: 90 | place_cups: 1.0 91 | stack_cups: 1.0 92 | close_jar: 1.0 93 | push_buttons: 1.0 94 | meat_off_grill: 1.0 95 | stack_blocks: 1.0 96 | reach_and_drag: 1.0 97 | slide_block_to_color_target: 1.0 98 | place_shape_in_shape_sorter: 1.0 99 | open_drawer: 1.0 100 | sweep_to_dustpan_of_size: 1.0 101 | put_groceries_in_cupboard: 1.0 102 | light_bulb_in: 1.0 103 | turn_tap: 1.0 104 | insert_onto_square_peg: 1.0 105 | put_item_in_drawer: 1.0 106 | put_money_in_safe: 1.0 107 | place_wine_at_rack_location: 1.0 108 | 109 | eval: 110 | datafolder: ./data/test 111 | episode_num: 25 112 | start_episode: 0 113 | headless: true 114 | save_video: false 115 | device: 0 116 | 117 | 118 | 119 | output_dir: ${hydra:run.dir} 120 | wandb: null -------------------------------------------------------------------------------- /rlb/configs/rvt2.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: default 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | # ================================================ # 9 | 10 | py_module: rvt2 11 | 12 | model: 13 | weights: null 14 | 15 | hp: 16 | add_corr: true 17 | add_depth: true 18 | add_lang: true 19 | add_pixel_loc: true 20 | add_proprio: true 21 | attn_dim: 512 22 | attn_dim_head: 64 23 | attn_dropout: 0.1 24 | attn_heads: 8 25 | depth: 8 26 | feat_dim: 220 # 72*3 + 4 27 | im_channels: 64 28 | point_augment_noise: 0.05 29 | img_feat_dim: 3 30 | img_patch_size: 14 31 | img_size: 224 32 | lang_dim: 512 33 | lang_len: 77 34 | norm_corr: true 35 | pe_fix: true 36 | proprio_dim: 3 # 4 # 18 37 | mvt_cameras: ['top', 'left', 'front'] 38 | stage2_zoom_scale: 4 39 | 
stage2_waypoint_label_noise: 0.05 40 | rot_x_y_aug: 2 41 | use_xformers: true 42 | 43 | gt_hm_sigma: 1.5 44 | move_pc_in_bound: true 45 | place_with_mean: false 46 | 47 | amp: True 48 | bnb: True 49 | 50 | # lr is specified on a per-sample basis: 51 | # the effective lr is this value multiplied by bs * num_devices 52 | lr: 1.25e-5 # 1e-4 53 | warmup_steps: 2000 54 | optimizer_type: lamb 55 | lr_cos_dec: true 56 | add_rgc_loss: true 57 | transform_augmentation: true 58 | transform_augmentation_xyz: [0.125, 0.125, 0.125] 59 | transform_augmentation_rpy: [0.0, 0.0, 45.0] 60 | lambda_weight_l2: 1e-4 # 1e-6 61 | num_rotation_classes: 72 62 | 63 | cos_dec_max_step: -1 # will be overridden during training 64 | 65 | render_with_cpp: true 66 | 67 | env: 68 | tasks: all 69 | cameras: ["front", "left_shoulder", "right_shoulder", "wrist"] 70 | scene_bounds: [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] # [x_min, y_min, z_min, x_max, y_max, z_max] - the metric volume to be voxelized 71 | image_size: 128 72 | time_in_state: false 73 | voxel_size: 100 74 | episode_length: 25 75 | rotation_resolution: 5 76 | origin_style_state: true 77 | 78 | train: 79 | bs: 48 80 | demo_folder: ./data/train 81 | epochs: 100 # 100 82 | num_gpus: 4 83 | num_workers: 8 # need larger value 84 | num_transitions_per_epoch: 160000 85 | disp_freq: 100 86 | cached_dataset_path: null 87 | save_freq: 10000 88 | eval_mode: false 89 | k2k_sample_ratios: 90 | place_cups: 1.0 91 | stack_cups: 1.0 92 | close_jar: 1.0 93 | push_buttons: 1.0 94 | meat_off_grill: 1.0 95 | stack_blocks: 1.0 96 | reach_and_drag: 1.0 97 | slide_block_to_color_target: 1.0 98 | place_shape_in_shape_sorter: 1.0 99 | open_drawer: 1.0 100 | sweep_to_dustpan_of_size: 1.0 101 | put_groceries_in_cupboard: 1.0 102 | light_bulb_in: 1.0 103 | turn_tap: 1.0 104 | insert_onto_square_peg: 1.0 105 | put_item_in_drawer: 1.0 106 | put_money_in_safe: 1.0 107 | place_wine_at_rack_location: 1.0 108 | 109 | eval: 110 | datafolder: ./data/test 111 | episode_num: 25 112 |
start_episode: 0 113 | headless: true 114 | save_video: false 115 | device: 0 116 | 117 | 118 | 119 | output_dir: ${hydra:run.dir} 120 | wandb: null -------------------------------------------------------------------------------- /rlb/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import ssl 2 | ssl._create_default_https_context = ssl._create_unverified_context 3 | import hydra 4 | import os.path as osp 5 | import sys 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | def configurable(config_path="config/default.yaml"): 9 | def wrapper(main_func): 10 | config_path_arg = None 11 | i = 1 12 | for a in sys.argv[1:]: 13 | if a.startswith("config="): 14 | config_path_arg = a.split("=")[-1] 15 | sys.argv.pop(i) 16 | i += 1 17 | break 18 | if config_path_arg is None: 19 | config_path_arg = config_path 20 | assert config_path_arg, "config file must be given by `config=path/to/file`" 21 | main_wrapper = hydra.main(config_path=osp.abspath(osp.dirname(config_path_arg)), 22 | config_name=osp.splitext(osp.basename(config_path_arg))[0], 23 | version_base=None) 24 | return main_wrapper(main_func) 25 | return wrapper 26 | 27 | 28 | 29 | def load_hydra_config(config_path, overrides=[]): 30 | 31 | from hydra import compose, initialize 32 | from omegaconf import OmegaConf 33 | 34 | with initialize(version_base=None, config_path=osp.dirname(config_path), job_name="load_config"): 35 | cfg = compose(config_name=osp.splitext(osp.basename(config_path))[0], overrides=overrides) 36 | return cfg 37 | 38 | 39 | def config_to_dict(cfg): 40 | return OmegaConf.to_container(cfg) 41 | 42 | -------------------------------------------------------------------------------- /rlb/utils/clip.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | 4 | """ 5 | tokens = clip.tokenize([description]).numpy() 6 | token_tensor = torch.from_numpy(tokens).to(device) 7 | with 
torch.no_grad(): 8 | lang_feats, lang_embs = _clip_encode_text(clip_model, token_tensor) 9 | """ 10 | 11 | # extract CLIP language features for goal string 12 | def clip_encode_text(clip_model, text): 13 | x = clip_model.token_embedding(text).type( 14 | clip_model.dtype 15 | ) # [batch_size, n_ctx, d_model] 16 | 17 | x = x + clip_model.positional_embedding.type(clip_model.dtype) 18 | x = x.permute(1, 0, 2) # NLD -> LND 19 | x = clip_model.transformer(x) 20 | x = x.permute(1, 0, 2) # LND -> NLD 21 | x = clip_model.ln_final(x).type(clip_model.dtype) 22 | 23 | emb = x.clone() 24 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ clip_model.text_projection 25 | 26 | return x, emb -------------------------------------------------------------------------------- /rlb/utils/dist.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from contextlib import closing 3 | 4 | def find_free_port(): 5 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 6 | s.bind(('', 0)) 7 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 8 | return s.getsockname()[1] -------------------------------------------------------------------------------- /rlb/utils/str.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | 4 | def insert_uline_before_cap(str): 5 | return reduce(lambda x, y: x + ('_' if y.isupper() else '') + y, str).lower() -------------------------------------------------------------------------------- /rlb/utils/vis.py: -------------------------------------------------------------------------------- 1 | from torchvision.utils import Optional, Tuple, Union, ImageDraw, List, Image 2 | from torchvision.transforms.functional import pil_to_tensor, to_pil_image 3 | import torch 4 | import numpy as np 5 | 6 | 7 | 8 | 9 | @torch.no_grad() 10 | def draw_keypoints( 11 | image: torch.Tensor, 12 | keypoints: torch.Tensor, 13 | connectivity: 
Optional[List[Tuple[int, int]]] = None, 14 | colors: Optional[Union[str, Tuple[int, int, int]]] = (255, 0, 0), 15 | line_color = 'white', 16 | radius: int = 2, 17 | width: int = 3, 18 | output_pil=True 19 | ) -> torch.Tensor: 20 | def is_valid(*args): 21 | return all([a >= 0 for a in args]) 22 | 23 | if isinstance(image, np.ndarray): 24 | image = torch.from_numpy(image) 25 | if image.shape[-1] == 3: 26 | image = image.permute(2, 0, 1) 27 | 28 | POINT_SIZE = keypoints.shape[-1] 29 | if isinstance(keypoints, np.ndarray): 30 | keypoints = torch.from_numpy(keypoints) 31 | 32 | keypoints = keypoints.reshape(1, -1, POINT_SIZE) 33 | 34 | ndarr = image.permute(1, 2, 0).cpu().numpy() 35 | img_to_draw = Image.fromarray(ndarr) 36 | draw = ImageDraw.Draw(img_to_draw, None if POINT_SIZE == 2 else 'RGBA') 37 | keypoints = keypoints.clone() 38 | if POINT_SIZE == 3: 39 | keypoints[:, :, -1] *= 255 40 | img_kpts = keypoints.to(torch.int64).tolist() 41 | 42 | for kpt_id, kpt_inst in enumerate(img_kpts): 43 | for inst_id, kpt in enumerate(kpt_inst): 44 | if not is_valid(*kpt): 45 | continue 46 | x1 = kpt[0] - radius 47 | x2 = kpt[0] + radius 48 | y1 = kpt[1] - radius 49 | y2 = kpt[1] + radius 50 | if len(kpt) == 3: 51 | kp_color = colors + (int(kpt[2]), ) 52 | else: 53 | kp_color = colors 54 | draw.ellipse([x1, y1, x2, y2], fill=kp_color, outline=None, width=0) 55 | 56 | if connectivity: 57 | for connection in connectivity: 58 | start_pt_x = kpt_inst[connection[0]][0] 59 | start_pt_y = kpt_inst[connection[0]][1] 60 | 61 | end_pt_x = kpt_inst[connection[1]][0] 62 | end_pt_y = kpt_inst[connection[1]][1] 63 | 64 | if not is_valid(start_pt_x, start_pt_y, end_pt_x, end_pt_y): 65 | continue 66 | 67 | draw.line( 68 | ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), 69 | width=width, fill=line_color 70 | ) 71 | if output_pil: 72 | return img_to_draw 73 | else: 74 | return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) 75 | 76 | 77 | 78 | 
--------------------------------------------------------------------------------