├── .gitignore ├── LICENSE ├── README.md ├── conda_environment.yaml ├── equi_diffpo ├── codecs │ └── imagecodecs_numcodecs.py ├── common │ ├── checkpoint_util.py │ ├── cv2_util.py │ ├── env_util.py │ ├── json_logger.py │ ├── nested_dict_util.py │ ├── normalize_util.py │ ├── pose_trajectory_interpolator.py │ ├── precise_sleep.py │ ├── pymunk_override.py │ ├── pymunk_util.py │ ├── pytorch_util.py │ ├── replay_buffer.py │ ├── robomimic_config_util.py │ ├── robomimic_util.py │ ├── sampler.py │ └── timestamp_accumulator.py ├── config │ ├── dp3.yaml │ ├── task │ │ ├── mimicgen_abs.yaml │ │ ├── mimicgen_pc_abs.yaml │ │ ├── mimicgen_rel.yaml │ │ ├── mimicgen_voxel_abs.yaml │ │ └── mimicgen_voxel_rel.yaml │ ├── train_act_abs.yaml │ ├── train_bc_rnn.yaml │ ├── train_diffusion_transformer.yaml │ ├── train_diffusion_unet.yaml │ ├── train_diffusion_unet_voxel_abs.yaml │ ├── train_equi_diffusion_unet_abs.yaml │ ├── train_equi_diffusion_unet_rel.yaml │ ├── train_equi_diffusion_unet_voxel_abs.yaml │ └── train_equi_diffusion_unet_voxel_rel.yaml ├── dataset │ ├── base_dataset.py │ ├── robomimic_replay_image_dataset.py │ ├── robomimic_replay_image_sym_dataset.py │ ├── robomimic_replay_lowdim_dataset.py │ ├── robomimic_replay_lowdim_sym_dataset.py │ ├── robomimic_replay_point_cloud_dataset.py │ └── robomimic_replay_voxel_sym_dataset.py ├── env │ └── robomimic │ │ ├── robomimic_image_wrapper.py │ │ └── robomimic_lowdim_wrapper.py ├── env_runner │ ├── base_image_runner.py │ ├── base_lowdim_runner.py │ ├── robomimic_image_runner.py │ └── robomimic_lowdim_runner.py ├── gym_util │ ├── async_vector_env.py │ ├── multistep_wrapper.py │ ├── sync_vector_env.py │ ├── video_recording_wrapper.py │ └── video_wrapper.py ├── model │ ├── common │ │ ├── dict_of_tensor_mixin.py │ │ ├── lr_scheduler.py │ │ ├── module_attr_mixin.py │ │ ├── normalizer.py │ │ ├── rotation_transformer.py │ │ ├── shape_util.py │ │ └── tensor_util.py │ ├── detr │ │ ├── LICENSE │ │ ├── README.md │ │ ├── main.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── backbone.py │ │ │ ├── detr_vae.py │ │ │ ├── position_encoding.py │ │ │ └── transformer.py │ │ ├── setup.py │ │ └── util │ │ │ ├── __init__.py │ │ │ ├── box_ops.py │ │ │ ├── misc.py │ │ │ └── plot_utils.py │ ├── diffusion │ │ ├── conditional_unet1d.py │ │ ├── conv1d_components.py │ │ ├── dp3_conditional_unet1d.py │ │ ├── ema_model.py │ │ ├── mask_generator.py │ │ ├── positional_embedding.py │ │ └── transformer_for_diffusion.py │ ├── equi │ │ ├── __init__.py │ │ ├── equi_conditional_unet1d.py │ │ ├── equi_conditional_unet1d_vel.py │ │ ├── equi_encoder.py │ │ └── equi_obs_encoder.py │ ├── unet │ │ └── obs_cond_unet1d.py │ └── vision │ │ ├── crop_randomizer.py │ │ ├── model_getter.py │ │ ├── multi_image_obs_encoder.py │ │ ├── pointnet_extractor.py │ │ ├── rot_randomizer.py │ │ ├── rot_randomizer_vel.py │ │ ├── voxel_crop_randomizer.py │ │ ├── voxel_rot_randomizer.py │ │ └── voxel_rot_randomizer_rel.py ├── policy │ ├── act_policy.py │ ├── base_image_policy.py │ ├── base_lowdim_policy.py │ ├── diffusion_equi_unet_cnn_enc_policy.py │ ├── diffusion_equi_unet_cnn_enc_policy_se2.py │ ├── diffusion_equi_unet_cnn_enc_rel_policy.py │ ├── diffusion_equi_unet_voxel_policy.py │ ├── diffusion_equi_unet_voxel_rel_policy.py │ ├── diffusion_transformer_hybrid_image_policy.py │ ├── diffusion_unet_hybrid_image_policy.py │ ├── diffusion_unet_voxel_policy.py │ ├── dp3.py │ └── robomimic_image_policy.py ├── scripts │ ├── dataset_states_to_obs.py │ ├── download_datasets.py │ ├── robomimic_dataset_action_comparison.py │ ├── 
robomimic_dataset_conversion.py │ └── robomimic_dataset_obs_conversion.py ├── shared_memory │ ├── shared_memory_queue.py │ ├── shared_memory_ring_buffer.py │ ├── shared_memory_util.py │ └── shared_ndarray.py └── workspace │ ├── base_workspace.py │ ├── train_act_workspace.py │ ├── train_diffusion_transformer_hybrid_workspace.py │ ├── train_diffusion_unet_hybrid_workspace.py │ ├── train_dp3_workspace.py │ ├── train_equi_workspace.py │ └── train_robomimic_image_workspace.py ├── img └── equi.gif ├── setup.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | logs 3 | wandb 4 | outputs 5 | data 6 | data_local 7 | .vscode 8 | _wandb 9 | 10 | **/.DS_Store 11 | 12 | fuse.cfg 13 | 14 | *.ai 15 | 16 | # Generation results 17 | results/ 18 | 19 | ray/auth.json 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | *.py,cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # IPython 101 | profile_default/ 102 | ipython_config.py 103 | 104 | # pyenv 105 | .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | equi_diffpo/scripts/equidiff_data_conversion.py 142 | equi_diffpo/model/equi/vec_conditional_unet1d_1.py 143 | update_max_score.py 144 | test4.png 145 | test3.png 146 | test2.png 147 | test1.png 148 | test.png 149 | sampled_xyz.png 150 | pc.pt 151 | metric3.py 152 | metric2.py 153 | metric1.py 154 | grouped_xyz.png 155 | all_xyz.png 156 | 1.png 157 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Dian Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Equivariant Diffusion Policy 2 | [Project Website](https://equidiff.github.io) | [Paper](https://arxiv.org/pdf/2407.01812) | [Video](https://youtu.be/xIFSx_NVROU?si=MaxsHmih6AnQKAVy) 3 | Dian Wang<sup>1</sup>, Stephen Hart<sup>2</sup>, David Surovik<sup>2</sup>, Tarik Kelestemur<sup>2</sup>, Haojie Huang<sup>1</sup>, Haibo Zhao<sup>1</sup>, Mark Yeatman<sup>2</sup>, Jiuguang Wang<sup>2</sup>, Robin Walters<sup>1</sup>, Robert Platt<sup>1,2</sup> 4 | <sup>1</sup>Northeastern University, <sup>2</sup>Boston Dynamics AI Institute 5 | Conference on Robot Learning 2024 (Oral) 6 | ![](img/equi.gif) 7 | ## Installation 8 | 1. Install the following apt packages for mujoco: 9 | ```bash 10 | sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf 11 | ``` 12 | 1. Install gfortran (dependency for escnn) 13 | ```bash 14 | sudo apt install -y gfortran 15 | ``` 16 | 17 | 1. Install [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) (strongly recommended) or Anaconda 18 | 1. Clone this repo 19 | ```bash 20 | git clone https://github.com/pointW/equidiff.git 21 | cd equidiff 22 | ``` 23 | 1. 
Install environment: 24 | Use Mambaforge (strongly recommended): 25 | ```bash 26 | mamba env create -f conda_environment.yaml 27 | conda activate equidiff 28 | ``` 29 | or use Anaconda (not recommended): 30 | ```bash 31 | conda env create -f conda_environment.yaml 32 | conda activate equidiff 33 | ``` 34 | 1. Install mimicgen: 35 | ```bash 36 | cd .. 37 | git clone https://github.com/NVlabs/mimicgen_environments.git 38 | cd mimicgen_environments 39 | # This project was developed with MimicGen v0.1.0. The latest version should work, but it has not been tested 40 | git checkout 081f7dbbe5fff17b28c67ce8ec87c371f32526a9 41 | pip install -e . 42 | cd ../equidiff 43 | ``` 44 | 1. Make sure the mujoco version is 2.3.2 (required by mimicgen) 45 | ```bash 46 | pip list | grep mujoco 47 | ``` 48 | 49 | ## Dataset 50 | ### Download Dataset 51 | Download the datasets from MimicGen's Hugging Face: https://huggingface.co/datasets/amandlek/mimicgen_datasets/tree/main/core 52 | Make sure each dataset is kept under `/path/to/equidiff/data/robomimic/datasets/[dataset]/[dataset].hdf5` 53 | 54 | ### Generating Voxel and Point Cloud Observations 55 | 56 | ```bash 57 | # Template 58 | python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/[dataset]/[dataset].hdf5 --output data/robomimic/datasets/[dataset]/[dataset]_voxel.hdf5 --num_workers=[n_worker] 59 | # Replace [dataset] and [n_worker] with your choices. 60 | # E.g., use 24 workers to generate point cloud and voxel observations for stack_d1 61 | python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/stack_d1/stack_d1.hdf5 --output data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 --num_workers=24 62 | ``` 63 | 64 | ### Convert Action Space in Dataset 65 | The downloaded dataset has a relative action space. To train with an absolute action space, the dataset needs to be converted accordingly: 66 | ```bash 67 | # Template 68 | python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/[dataset]/[dataset].hdf5 -o data/robomimic/datasets/[dataset]/[dataset]_abs.hdf5 -n [n_worker] 69 | # Replace [dataset] and [n_worker] with your choices. 70 | # E.g., convert stack_d1 (non-voxel) with 12 workers 71 | python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_abs.hdf5 -n 12 72 | # E.g., convert stack_d1_voxel (voxel) with 12 workers 73 | python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_voxel_abs.hdf5 -n 12 74 | ``` 75 | 76 | ## Training with image observation 77 | To train Equivariant Diffusion Policy (with absolute pose control) in the Stack D1 task: 78 | ```bash 79 | # Make sure you have the non-voxel converted dataset with absolute action space from the previous step 80 | python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 n_demo=100 81 | ``` 82 | To train with relative pose control instead: 83 | ```bash 84 | python train.py --config-name=train_equi_diffusion_unet_rel task_name=stack_d1 n_demo=100 85 | ``` 86 | To train in other tasks, replace `stack_d1` with `stack_three_d1`, `square_d2`, `threading_d2`, `coffee_d2`, `three_piece_assembly_d2`, `hammer_cleanup_d1`, `mug_cleanup_d1`, `kitchen_d1`, `nut_assembly_d0`, `pick_place_d0`, or `coffee_preparation_d1`. Note that the corresponding dataset must already be downloaded, and if training with absolute pose control, the data conversion above is also needed; see the end-to-end example below.
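As a concrete end-to-end example (hypothetical: `square_d2` is chosen purely for illustration, and the commands simply instantiate the templates above), downloading, converting, and training on Square D2 might look like:
```bash
# assumes square_d2.hdf5 has already been downloaded to data/robomimic/datasets/square_d2/
python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/square_d2/square_d2.hdf5 -o data/robomimic/datasets/square_d2/square_d2_abs.hdf5 -n 12
python train.py --config-name=train_equi_diffusion_unet_abs task_name=square_d2 n_demo=100
```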
87 | 88 | To run environments on CPU (to save GPU memory), use `osmesa` instead of `egl` through `MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa`, e.g., 89 | ```bash 90 | MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 91 | ``` 92 | 93 | Equivariant Diffusion Policy requires around 22 GB of GPU memory to run with the default batch size of 128. To reduce GPU usage, consider training with a smaller batch size and/or a smaller hidden dimension: 94 | ```bash 95 | # to train with batch size of 64 and hidden dimension of 64 96 | MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 policy.enc_n_hidden=64 dataloader.batch_size=64 97 | ``` 98 | 99 | ## Training with voxel observation 100 | To train Equivariant Diffusion Policy (with absolute pose control) in the Stack D1 task: 101 | ```bash 102 | # Make sure you have the voxel converted dataset with absolute action space from the previous step 103 | python train.py --config-name=train_equi_diffusion_unet_voxel_abs task_name=stack_d1 n_demo=100 104 | ``` 105 | 106 | ## License 107 | This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details. 108 | 109 | ## Acknowledgement 110 | * Our repo is built upon the original [Diffusion Policy](https://github.com/real-stanford/diffusion_policy) 111 | * Our ACT baseline is adapted from its [original repo](https://github.com/tonyzhaozh/act) 112 | * Our DP3 baseline is adapted from its [original repo](https://github.com/YanjieZe/3D-Diffusion-Policy) 113 | -------------------------------------------------------------------------------- /conda_environment.yaml: -------------------------------------------------------------------------------- 1 | name: equidiff 2 | channels: 3 | - pytorch 4 | - pytorch3d 5 | - nvidia 6 | - conda-forge 7 | dependencies: 8 | - python=3.9 9 | - pip=22.2.2 10 | - pytorch=2.1.0 11 | - torchaudio=2.1.0 12 | - torchvision=0.16.0 13 | - pytorch-cuda=11.8 14 | - pytorch3d=0.7.5 15 | - numpy=1.23.3 16 | - numba==0.56.4 17 | - scipy==1.9.1 18 | - py-opencv=4.6.0 19 | - cffi=1.15.1 20 | - ipykernel=6.16 21 | - matplotlib=3.6.1 22 | - zarr=2.12.0 23 | - numcodecs=0.10.2 24 | - h5py=3.7.0 25 | - hydra-core=1.2.0 26 | - einops=0.4.1 27 | - tqdm=4.64.1 28 | - dill=0.3.5.1 29 | - scikit-video=1.1.11 30 | - scikit-image=0.19.3 31 | - gym=0.21.0 32 | - pymunk=6.2.1 33 | - threadpoolctl=3.1.0 34 | - shapely=1.8.4 35 | - cython=0.29.32 36 | - imageio=2.22.0 37 | - imageio-ffmpeg=0.4.7 38 | - termcolor=2.0.1 39 | - tensorboard=2.10.1 40 | - tensorboardx=2.5.1 41 | - psutil=5.9.2 42 | - click=8.0.4 43 | - boto3=1.24.96 44 | - accelerate=0.13.2 45 | - datasets=2.6.1 46 | - diffusers=0.11.1 47 | - av=10.0.0 48 | - cmake=3.24.3 49 | # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625 50 | - llvm-openmp=14 51 | # trick to force reinstall imagecodecs via pip 52 | - imagecodecs==2022.8.8 53 | - pip: 54 | - open3d 55 | - wandb==0.17.0 56 | - pygame 57 | - imagecodecs==2022.9.26 58 | - escnn @ https://github.com/pointW/escnn/archive/fc4714cb6dc0d2a32f9fcea35771968b89911109.tar.gz 59 | - robosuite @ https://github.com/ARISE-Initiative/robosuite/archive/b9d8d3de5e3dfd1724f4a0e6555246c460407daa.tar.gz 60 | - robomimic @ https://github.com/pointW/robomimic/archive/8aad5b3caaaac9289b1504438a7f5d3a76d06c07.tar.gz 61 | - robosuite-task-zoo @ 
https://github.com/pointW/robosuite-task-zoo/archive/0f8a7b2fa5d192e4e8800bebfe8090b28926f3ed.tar.gz 62 | -------------------------------------------------------------------------------- /equi_diffpo/common/checkpoint_util.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | import os 3 | 4 | class TopKCheckpointManager: 5 | def __init__(self, 6 | save_dir, 7 | monitor_key: str, 8 | mode='min', 9 | k=1, 10 | format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt' 11 | ): 12 | assert mode in ['max', 'min'] 13 | assert k >= 0 14 | 15 | self.save_dir = save_dir 16 | self.monitor_key = monitor_key 17 | self.mode = mode 18 | self.k = k 19 | self.format_str = format_str 20 | self.path_value_map = dict() 21 | 22 | def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]: 23 | if self.k == 0: 24 | return None 25 | 26 | value = data[self.monitor_key] 27 | ckpt_path = os.path.join( 28 | self.save_dir, self.format_str.format(**data)) 29 | 30 | if len(self.path_value_map) < self.k: 31 | # under-capacity 32 | self.path_value_map[ckpt_path] = value 33 | return ckpt_path 34 | 35 | # at capacity 36 | sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1]) 37 | min_path, min_value = sorted_map[0] 38 | max_path, max_value = sorted_map[-1] 39 | 40 | delete_path = None 41 | if self.mode == 'max': 42 | if value > min_value: 43 | delete_path = min_path 44 | else: 45 | if value < max_value: 46 | delete_path = max_path 47 | 48 | if delete_path is None: 49 | return None 50 | else: 51 | del self.path_value_map[delete_path] 52 | self.path_value_map[ckpt_path] = value 53 | 54 | if not os.path.exists(self.save_dir): 55 | os.mkdir(self.save_dir) 56 | 57 | if os.path.exists(delete_path): 58 | os.remove(delete_path) 59 | return ckpt_path 60 | -------------------------------------------------------------------------------- /equi_diffpo/common/cv2_util.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import math 3 | import cv2 4 | import numpy as np 5 | 6 | def draw_reticle(img, u, v, label_color): 7 | """ 8 | Draws a reticle (cross-hair) on the image at the given position on top of 9 | the original image. 10 | @param img (In/Out) uint8 3 channel image 11 | @param u X coordinate (width) 12 | @param v Y coordinate (height) 13 | @param label_color tuple of 3 ints for RGB color used for drawing. 14 | """ 15 | # Cast to int. 16 | u = int(u) 17 | v = int(v) 18 | 19 | white = (255, 255, 255) 20 | cv2.circle(img, (u, v), 10, label_color, 1) 21 | cv2.circle(img, (u, v), 11, white, 1) 22 | cv2.circle(img, (u, v), 12, label_color, 1) 23 | cv2.line(img, (u, v + 1), (u, v + 3), white, 1) 24 | cv2.line(img, (u + 1, v), (u + 3, v), white, 1) 25 | cv2.line(img, (u, v - 1), (u, v - 3), white, 1) 26 | cv2.line(img, (u - 1, v), (u - 3, v), white, 1) 27 | 28 | 29 | def draw_text( 30 | img, 31 | *, 32 | text, 33 | uv_top_left, 34 | color=(255, 255, 255), 35 | fontScale=0.5, 36 | thickness=1, 37 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, 38 | outline_color=(0, 0, 0), 39 | line_spacing=1.5, 40 | ): 41 | """ 42 | Draws multiline with an outline. 
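Each line of text is rendered separately; after a line is drawn, the anchor point advances downward by the line height times line_spacing.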
43 | """ 44 | assert isinstance(text, str) 45 | 46 | uv_top_left = np.array(uv_top_left, dtype=float) 47 | assert uv_top_left.shape == (2,) 48 | 49 | for line in text.splitlines(): 50 | (w, h), _ = cv2.getTextSize( 51 | text=line, 52 | fontFace=fontFace, 53 | fontScale=fontScale, 54 | thickness=thickness, 55 | ) 56 | uv_bottom_left_i = uv_top_left + [0, h] 57 | org = tuple(uv_bottom_left_i.astype(int)) 58 | 59 | if outline_color is not None: 60 | cv2.putText( 61 | img, 62 | text=line, 63 | org=org, 64 | fontFace=fontFace, 65 | fontScale=fontScale, 66 | color=outline_color, 67 | thickness=thickness * 3, 68 | lineType=cv2.LINE_AA, 69 | ) 70 | cv2.putText( 71 | img, 72 | text=line, 73 | org=org, 74 | fontFace=fontFace, 75 | fontScale=fontScale, 76 | color=color, 77 | thickness=thickness, 78 | lineType=cv2.LINE_AA, 79 | ) 80 | 81 | uv_top_left += [0, h * line_spacing] 82 | 83 | 84 | def get_image_transform( 85 | input_res: Tuple[int,int]=(1280,720), 86 | output_res: Tuple[int,int]=(640,480), 87 | bgr_to_rgb: bool=False): 88 | 89 | iw, ih = input_res 90 | ow, oh = output_res 91 | rw, rh = None, None 92 | interp_method = cv2.INTER_AREA 93 | 94 | if (iw/ih) >= (ow/oh): 95 | # input is wider 96 | rh = oh 97 | rw = math.ceil(rh / ih * iw) 98 | if oh > ih: 99 | interp_method = cv2.INTER_LINEAR 100 | else: 101 | rw = ow 102 | rh = math.ceil(rw / iw * ih) 103 | if ow > iw: 104 | interp_method = cv2.INTER_LINEAR 105 | 106 | w_slice_start = (rw - ow) // 2 107 | w_slice = slice(w_slice_start, w_slice_start + ow) 108 | h_slice_start = (rh - oh) // 2 109 | h_slice = slice(h_slice_start, h_slice_start + oh) 110 | c_slice = slice(None) 111 | if bgr_to_rgb: 112 | c_slice = slice(None, None, -1) 113 | 114 | def transform(img: np.ndarray): 115 | assert img.shape == ((ih,iw,3)) 116 | # resize 117 | img = cv2.resize(img, (rw, rh), interpolation=interp_method) 118 | # crop 119 | img = img[h_slice, w_slice, c_slice] 120 | return img 121 | return transform 122 | 123 | def optimal_row_cols( 124 | n_cameras, 125 | in_wh_ratio, 126 | max_resolution=(1920, 1080) 127 | ): 128 | out_w, out_h = max_resolution 129 | out_wh_ratio = out_w / out_h 130 | 131 | n_rows = np.arange(n_cameras,dtype=np.int64) + 1 132 | n_cols = np.ceil(n_cameras / n_rows).astype(np.int64) 133 | cat_wh_ratio = in_wh_ratio * (n_cols / n_rows) 134 | ratio_diff = np.abs(out_wh_ratio - cat_wh_ratio) 135 | best_idx = np.argmin(ratio_diff) 136 | best_n_row = n_rows[best_idx] 137 | best_n_col = n_cols[best_idx] 138 | best_cat_wh_ratio = cat_wh_ratio[best_idx] 139 | 140 | rw, rh = None, None 141 | if best_cat_wh_ratio >= out_wh_ratio: 142 | # cat is wider 143 | rw = math.floor(out_w / best_n_col) 144 | rh = math.floor(rw / in_wh_ratio) 145 | else: 146 | rh = math.floor(out_h / best_n_row) 147 | rw = math.floor(rh * in_wh_ratio) 148 | 149 | # crop_resolution = (rw, rh) 150 | return rw, rh, best_n_col, best_n_row 151 | -------------------------------------------------------------------------------- /equi_diffpo/common/env_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def render_env_video(env, states, actions=None): 6 | observations = states 7 | imgs = list() 8 | for i in range(len(observations)): 9 | state = observations[i] 10 | env.set_state(state) 11 | if i == 0: 12 | env.set_state(state) 13 | img = env.render() 14 | # draw action 15 | if actions is not None: 16 | action = actions[i] 17 | coord = (action / 512 * 96).astype(np.int32) 18 | cv2.drawMarker(img, coord, 
19 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 20 | markerSize=8, thickness=1) 21 | imgs.append(img) 22 | imgs = np.array(imgs) 23 | return imgs 24 | -------------------------------------------------------------------------------- /equi_diffpo/common/json_logger.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Any, Sequence 2 | import os 3 | import copy 4 | import json 5 | import numbers 6 | import pandas as pd 7 | 8 | 9 | def read_json_log(path: str, 10 | required_keys: Sequence[str]=tuple(), 11 | **kwargs) -> pd.DataFrame: 12 | """ 13 | Read json-per-line file, with potentially incomplete lines. 14 | kwargs passed to pd.read_json 15 | """ 16 | lines = list() 17 | with open(path, 'r') as f: 18 | while True: 19 | # one json per line 20 | line = f.readline() 21 | if len(line) == 0: 22 | # EOF 23 | break 24 | elif not line.endswith('\n'): 25 | # incomplete line 26 | break 27 | is_relevant = False 28 | for k in required_keys: 29 | if k in line: 30 | is_relevant = True 31 | break 32 | if is_relevant: 33 | lines.append(line) 34 | if len(lines) < 1: 35 | return pd.DataFrame() 36 | json_buf = f'[{",".join([line for line in (line.strip() for line in lines) if line])}]' 37 | df = pd.read_json(json_buf, **kwargs) 38 | return df 39 | 40 | class JsonLogger: 41 | def __init__(self, path: str, 42 | filter_fn: Optional[Callable[[str,Any],bool]]=None): 43 | if filter_fn is None: 44 | filter_fn = lambda k,v: isinstance(v, numbers.Number) 45 | 46 | # default to append mode 47 | self.path = path 48 | self.filter_fn = filter_fn 49 | self.file = None 50 | self.last_log = None 51 | 52 | def start(self): 53 | # use line buffering 54 | try: 55 | self.file = file = open(self.path, 'r+', buffering=1) 56 | except FileNotFoundError: 57 | self.file = file = open(self.path, 'w+', buffering=1) 58 | 59 | # Move the pointer (similar to a cursor in a text editor) to the end of the file 60 | pos = file.seek(0, os.SEEK_END) 61 | 62 | # Read each character in the file one at a time from the last 63 | # character going backwards, searching for a newline character 64 | # If we find a new line, exit the search 65 | while pos > 0 and file.read(1) != "\n": 66 | pos -= 1 67 | file.seek(pos, os.SEEK_SET) 68 | # now the file pointer is at one past the last '\n' 69 | # and pos is at the last '\n'. 
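# last_line_end marks where the final (possibly incomplete) line begins; file.truncate() below cuts everything from this offset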
70 | last_line_end = file.tell() 71 | 72 | # find the start of second last line 73 | pos = max(0, pos-1) 74 | file.seek(pos, os.SEEK_SET) 75 | while pos > 0 and file.read(1) != "\n": 76 | pos -= 1 77 | file.seek(pos, os.SEEK_SET) 78 | # now the file pointer is at one past the second last '\n' 79 | last_line_start = file.tell() 80 | 81 | if last_line_start < last_line_end: 82 | # has last line of json 83 | last_line = file.readline() 84 | self.last_log = json.loads(last_line) 85 | 86 | # remove the last incomplete line 87 | file.seek(last_line_end) 88 | file.truncate() 89 | 90 | def stop(self): 91 | self.file.close() 92 | self.file = None 93 | 94 | def __enter__(self): 95 | self.start() 96 | return self 97 | 98 | def __exit__(self, exc_type, exc_val, exc_tb): 99 | self.stop() 100 | 101 | def log(self, data: dict): 102 | filtered_data = dict( 103 | filter(lambda x: self.filter_fn(*x), data.items())) 104 | # save current as last log 105 | self.last_log = filtered_data 106 | for k, v in filtered_data.items(): 107 | if isinstance(v, numbers.Integral): 108 | filtered_data[k] = int(v) 109 | elif isinstance(v, numbers.Number): 110 | filtered_data[k] = float(v) 111 | buf = json.dumps(filtered_data) 112 | # ensure one line per json 113 | buf = buf.replace('\n','') + '\n' 114 | self.file.write(buf) 115 | 116 | def get_last_log(self): 117 | return copy.deepcopy(self.last_log) 118 | -------------------------------------------------------------------------------- /equi_diffpo/common/nested_dict_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | def nested_dict_map(f, x): 4 | """ 5 | Map f over all leaf of nested dict x 6 | """ 7 | 8 | if not isinstance(x, dict): 9 | return f(x) 10 | y = dict() 11 | for key, value in x.items(): 12 | y[key] = nested_dict_map(f, value) 13 | return y 14 | 15 | def nested_dict_reduce(f, x): 16 | """ 17 | Map f over all values of nested dict x, and reduce to a single value 18 | """ 19 | if not isinstance(x, dict): 20 | return x 21 | 22 | reduced_values = list() 23 | for value in x.values(): 24 | reduced_values.append(nested_dict_reduce(f, value)) 25 | y = functools.reduce(f, reduced_values) 26 | return y 27 | 28 | 29 | def nested_dict_check(f, x): 30 | bool_dict = nested_dict_map(f, x) 31 | result = nested_dict_reduce(lambda x, y: x and y, bool_dict) 32 | return result 33 | -------------------------------------------------------------------------------- /equi_diffpo/common/precise_sleep.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | def precise_sleep(dt: float, slack_time: float=0.001, time_func=time.monotonic): 4 | """ 5 | Use hybrid of time.sleep and spinning to minimize jitter. 6 | Sleep dt - slack_time seconds first, then spin for the rest. 
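Spinning for the final slack_time avoids the scheduler-dependent wake-up jitter of time.sleep, at the cost of briefly busy-waiting.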
7 | """ 8 | t_start = time_func() 9 | if dt > slack_time: 10 | time.sleep(dt - slack_time) 11 | t_end = t_start + dt 12 | while time_func() < t_end: 13 | pass 14 | return 15 | 16 | def precise_wait(t_end: float, slack_time: float=0.001, time_func=time.monotonic): 17 | t_start = time_func() 18 | t_wait = t_end - t_start 19 | if t_wait > 0: 20 | t_sleep = t_wait - slack_time 21 | if t_sleep > 0: 22 | time.sleep(t_sleep) 23 | while time_func() < t_end: 24 | pass 25 | return 26 | -------------------------------------------------------------------------------- /equi_diffpo/common/pymunk_util.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import pymunk 3 | import pymunk.pygame_util 4 | import numpy as np 5 | 6 | COLLTYPE_DEFAULT = 0 7 | COLLTYPE_MOUSE = 1 8 | COLLTYPE_BALL = 2 9 | 10 | def get_body_type(static=False): 11 | body_type = pymunk.Body.DYNAMIC 12 | if static: 13 | body_type = pymunk.Body.STATIC 14 | return body_type 15 | 16 | 17 | def create_rectangle(space, 18 | pos_x,pos_y,width,height, 19 | density=3,static=False): 20 | body = pymunk.Body(body_type=get_body_type(static)) 21 | body.position = (pos_x,pos_y) 22 | shape = pymunk.Poly.create_box(body,(width,height)) 23 | shape.density = density 24 | space.add(body,shape) 25 | return body, shape 26 | 27 | 28 | def create_rectangle_bb(space, 29 | left, bottom, right, top, 30 | **kwargs): 31 | pos_x = (left + right) / 2 32 | pos_y = (top + bottom) / 2 33 | height = top - bottom 34 | width = right - left 35 | return create_rectangle(space, pos_x, pos_y, width, height, **kwargs) 36 | 37 | def create_circle(space, pos_x, pos_y, radius, density=3, static=False): 38 | body = pymunk.Body(body_type=get_body_type(static)) 39 | body.position = (pos_x, pos_y) 40 | shape = pymunk.Circle(body, radius=radius) 41 | shape.density = density 42 | shape.collision_type = COLLTYPE_BALL 43 | space.add(body, shape) 44 | return body, shape 45 | 46 | def get_body_state(body): 47 | state = np.zeros(6, dtype=np.float32) 48 | state[:2] = body.position 49 | state[2] = body.angle 50 | state[3:5] = body.velocity 51 | state[5] = body.angular_velocity 52 | return state 53 | -------------------------------------------------------------------------------- /equi_diffpo/common/pytorch_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, List 2 | import collections 3 | import torch 4 | import torch.nn as nn 5 | 6 | def dict_apply( 7 | x: Dict[str, torch.Tensor], 8 | func: Callable[[torch.Tensor], torch.Tensor] 9 | ) -> Dict[str, torch.Tensor]: 10 | result = dict() 11 | for key, value in x.items(): 12 | if isinstance(value, dict): 13 | result[key] = dict_apply(value, func) 14 | else: 15 | result[key] = func(value) 16 | return result 17 | 18 | def pad_remaining_dims(x, target): 19 | assert x.shape == target.shape[:len(x.shape)] 20 | return x.reshape(x.shape + (1,)*(len(target.shape) - len(x.shape))) 21 | 22 | def dict_apply_split( 23 | x: Dict[str, torch.Tensor], 24 | split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]] 25 | ) -> Dict[str, torch.Tensor]: 26 | results = collections.defaultdict(dict) 27 | for key, value in x.items(): 28 | result = split_func(value) 29 | for k, v in result.items(): 30 | results[k][key] = v 31 | return results 32 | 33 | def dict_apply_reduce( 34 | x: List[Dict[str, torch.Tensor]], 35 | reduce_func: Callable[[List[torch.Tensor]], torch.Tensor] 36 | ) -> Dict[str, torch.Tensor]: 37 | result = dict() 38 | for 
key in x[0].keys(): 39 | result[key] = reduce_func([x_[key] for x_ in x]) 40 | return result 41 | 42 | 43 | def replace_submodules( 44 | root_module: nn.Module, 45 | predicate: Callable[[nn.Module], bool], 46 | func: Callable[[nn.Module], nn.Module]) -> nn.Module: 47 | """ 48 | predicate: Return true if the module is to be replaced. 49 | func: Return new module to use. 50 | """ 51 | if predicate(root_module): 52 | return func(root_module) 53 | 54 | bn_list = [k.split('.') for k, m 55 | in root_module.named_modules(remove_duplicate=True) 56 | if predicate(m)] 57 | for *parent, k in bn_list: 58 | parent_module = root_module 59 | if len(parent) > 0: 60 | parent_module = root_module.get_submodule('.'.join(parent)) 61 | if isinstance(parent_module, nn.Sequential): 62 | src_module = parent_module[int(k)] 63 | else: 64 | src_module = getattr(parent_module, k) 65 | tgt_module = func(src_module) 66 | if isinstance(parent_module, nn.Sequential): 67 | parent_module[int(k)] = tgt_module 68 | else: 69 | setattr(parent_module, k, tgt_module) 70 | # verify that all BN are replaced 71 | bn_list = [k.split('.') for k, m 72 | in root_module.named_modules(remove_duplicate=True) 73 | if predicate(m)] 74 | assert len(bn_list) == 0 75 | return root_module 76 | 77 | def optimizer_to(optimizer, device): 78 | for state in optimizer.state.values(): 79 | for k, v in state.items(): 80 | if isinstance(v, torch.Tensor): 81 | state[k] = v.to(device=device) 82 | return optimizer 83 | -------------------------------------------------------------------------------- /equi_diffpo/common/robomimic_config_util.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from robomimic.config import config_factory 3 | import robomimic.scripts.generate_paper_configs as gpc 4 | from robomimic.scripts.generate_paper_configs import ( 5 | modify_config_for_default_image_exp, 6 | modify_config_for_default_low_dim_exp, 7 | modify_config_for_dataset, 8 | ) 9 | 10 | def get_robomimic_config( 11 | algo_name='bc_rnn', 12 | hdf5_type='low_dim', 13 | task_name='square', 14 | dataset_type='ph' 15 | ): 16 | base_dataset_dir = '/tmp/null' 17 | filter_key = None 18 | 19 | # decide whether to use low-dim or image training defaults 20 | modifier_for_obs = modify_config_for_default_image_exp 21 | if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]: 22 | modifier_for_obs = modify_config_for_default_low_dim_exp 23 | 24 | algo_config_name = "bc" if algo_name == "bc_rnn" else algo_name 25 | config = config_factory(algo_name=algo_config_name) 26 | # turn into default config for observation modalities (e.g.: low-dim or rgb) 27 | config = modifier_for_obs(config) 28 | # add in config based on the dataset 29 | config = modify_config_for_dataset( 30 | config=config, 31 | task_name=task_name, 32 | dataset_type=dataset_type, 33 | hdf5_type=hdf5_type, 34 | base_dataset_dir=base_dataset_dir, 35 | filter_key=filter_key, 36 | ) 37 | # add in algo hypers based on dataset 38 | algo_config_modifier = getattr(gpc, f'modify_{algo_name}_config_for_dataset') 39 | config = algo_config_modifier( 40 | config=config, 41 | task_name=task_name, 42 | dataset_type=dataset_type, 43 | hdf5_type=hdf5_type, 44 | ) 45 | return config 46 | 47 | 48 | -------------------------------------------------------------------------------- /equi_diffpo/common/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | import 
numba 4 | from equi_diffpo.common.replay_buffer import ReplayBuffer 5 | 6 | 7 | @numba.jit(nopython=True) 8 | def create_indices( 9 | episode_ends:np.ndarray, sequence_length:int, 10 | episode_mask: np.ndarray, 11 | pad_before: int=0, pad_after: int=0, 12 | debug:bool=True) -> np.ndarray: 13 | assert episode_mask.shape == episode_ends.shape 14 | pad_before = min(max(pad_before, 0), sequence_length-1) 15 | pad_after = min(max(pad_after, 0), sequence_length-1) 16 | 17 | indices = list() 18 | for i in range(len(episode_ends)): 19 | if not episode_mask[i]: 20 | # skip episode 21 | continue 22 | start_idx = 0 23 | if i > 0: 24 | start_idx = episode_ends[i-1] 25 | end_idx = episode_ends[i] 26 | episode_length = end_idx - start_idx 27 | 28 | min_start = -pad_before 29 | max_start = episode_length - sequence_length + pad_after 30 | 31 | # range stops one idx before end 32 | for idx in range(min_start, max_start+1): 33 | buffer_start_idx = max(idx, 0) + start_idx 34 | buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx 35 | start_offset = buffer_start_idx - (idx+start_idx) 36 | end_offset = (idx+sequence_length+start_idx) - buffer_end_idx 37 | sample_start_idx = 0 + start_offset 38 | sample_end_idx = sequence_length - end_offset 39 | if debug: 40 | assert(start_offset >= 0) 41 | assert(end_offset >= 0) 42 | assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx) 43 | indices.append([ 44 | buffer_start_idx, buffer_end_idx, 45 | sample_start_idx, sample_end_idx]) 46 | indices = np.array(indices) 47 | return indices 48 | 49 | 50 | def get_val_mask(n_episodes, val_ratio, seed=0): 51 | val_mask = np.zeros(n_episodes, dtype=bool) 52 | if val_ratio <= 0: 53 | return val_mask 54 | 55 | # have at least 1 episode for validation, and at least 1 episode for train 56 | n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes-1) 57 | rng = np.random.default_rng(seed=seed) 58 | val_idxs = rng.choice(n_episodes, size=n_val, replace=False) 59 | val_mask[val_idxs] = True 60 | return val_mask 61 | 62 | 63 | def downsample_mask(mask, max_n, seed=0): 64 | # subsample training data 65 | train_mask = mask 66 | if (max_n is not None) and (np.sum(train_mask) > max_n): 67 | n_train = int(max_n) 68 | curr_train_idxs = np.nonzero(train_mask)[0] 69 | rng = np.random.default_rng(seed=seed) 70 | train_idxs_idx = rng.choice(len(curr_train_idxs), size=n_train, replace=False) 71 | train_idxs = curr_train_idxs[train_idxs_idx] 72 | train_mask = np.zeros_like(train_mask) 73 | train_mask[train_idxs] = True 74 | assert np.sum(train_mask) == n_train 75 | return train_mask 76 | 77 | class SequenceSampler: 78 | def __init__(self, 79 | replay_buffer: ReplayBuffer, 80 | sequence_length:int, 81 | pad_before:int=0, 82 | pad_after:int=0, 83 | keys=None, 84 | key_first_k=dict(), 85 | episode_mask: Optional[np.ndarray]=None, 86 | ): 87 | """ 88 | key_first_k: dict str: int 89 | Only take first k data from these keys (to improve perf) 90 | """ 91 | 92 | super().__init__() 93 | assert(sequence_length >= 1) 94 | if keys is None: 95 | keys = list(replay_buffer.keys()) 96 | 97 | episode_ends = replay_buffer.episode_ends[:] 98 | if episode_mask is None: 99 | episode_mask = np.ones(episode_ends.shape, dtype=bool) 100 | 101 | if np.any(episode_mask): 102 | indices = create_indices(episode_ends, 103 | sequence_length=sequence_length, 104 | pad_before=pad_before, 105 | pad_after=pad_after, 106 | episode_mask=episode_mask 107 | ) 108 | else: 109 | indices = np.zeros((0,4), dtype=np.int64) 110 | 111 | # 
(buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx) 112 | self.indices = indices 113 | self.keys = list(keys) # prevent OmegaConf list performance problem 114 | self.sequence_length = sequence_length 115 | self.replay_buffer = replay_buffer 116 | self.key_first_k = key_first_k 117 | 118 | def __len__(self): 119 | return len(self.indices) 120 | 121 | def sample_sequence(self, idx): 122 | buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \ 123 | = self.indices[idx] 124 | result = dict() 125 | for key in self.keys: 126 | input_arr = self.replay_buffer[key] 127 | # performance optimization, avoid small allocation if possible 128 | if key not in self.key_first_k: 129 | sample = input_arr[buffer_start_idx:buffer_end_idx] 130 | else: 131 | # performance optimization, only load used obs steps 132 | n_data = buffer_end_idx - buffer_start_idx 133 | k_data = min(self.key_first_k[key], n_data) 134 | # fill value with Nan to catch bugs 135 | # the non-loaded region should never be used 136 | sample = np.full((n_data,) + input_arr.shape[1:], 137 | fill_value=np.nan, dtype=input_arr.dtype) 138 | try: 139 | sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data] 140 | except Exception as e: 141 | import pdb; pdb.set_trace() 142 | data = sample 143 | if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length): 144 | data = np.zeros( 145 | shape=(self.sequence_length,) + input_arr.shape[1:], 146 | dtype=input_arr.dtype) 147 | if sample_start_idx > 0: 148 | data[:sample_start_idx] = sample[0] 149 | if sample_end_idx < self.sequence_length: 150 | data[sample_end_idx:] = sample[-1] 151 | data[sample_start_idx:sample_end_idx] = sample 152 | result[key] = data 153 | return result 154 | -------------------------------------------------------------------------------- /equi_diffpo/config/dp3.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_pc_abs 4 | 5 | name: train_dp3 6 | _target_: equi_diffpo.workspace.train_dp3_workspace.TrainDP3Workspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "debug" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 16 14 | n_obs_steps: 2 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | keypoint_visible_rate: 1.0 19 | obs_as_global_cond: True 20 | dataset_target: equi_diffpo.dataset.robomimic_replay_point_cloud_dataset.RobomimicReplayPointCloudDataset 21 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5 22 | 23 | policy: 24 | _target_: equi_diffpo.policy.dp3.DP3 25 | use_point_crop: true 26 | condition_type: film 27 | use_down_condition: true 28 | use_mid_condition: true 29 | use_up_condition: true 30 | 31 | diffusion_step_embed_dim: 128 32 | down_dims: 33 | - 512 34 | - 1024 35 | - 2048 36 | crop_shape: 37 | - 80 38 | - 80 39 | encoder_output_dim: 64 40 | horizon: ${horizon} 41 | kernel_size: 5 42 | n_action_steps: ${n_action_steps} 43 | n_groups: 8 44 | n_obs_steps: ${n_obs_steps} 45 | 46 | noise_scheduler: 47 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 48 | num_train_timesteps: 100 49 | beta_start: 0.0001 50 | beta_end: 0.02 51 | beta_schedule: squaredcos_cap_v2 52 | clip_sample: True 53 | set_alpha_to_one: True 54 | steps_offset: 0 55 | prediction_type: sample 56 | 57 | 58 | num_inference_steps: 10 59 | obs_as_global_cond: true 60 | shape_meta: ${shape_meta} 61 | 62 | use_pc_color: true 63 | pointnet_type: "pointnet" 64 | 65 | 66 | 
pointcloud_encoder_cfg: 67 | in_channels: 3 68 | out_channels: ${policy.encoder_output_dim} 69 | use_layernorm: true 70 | final_norm: layernorm # layernorm, none 71 | normal_channel: false 72 | 73 | 74 | ema: 75 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 76 | update_after_step: 0 77 | inv_gamma: 1.0 78 | power: 0.75 79 | min_value: 0.0 80 | max_value: 0.9999 81 | 82 | dataloader: 83 | batch_size: 128 84 | num_workers: 8 85 | shuffle: True 86 | pin_memory: True 87 | persistent_workers: True 88 | 89 | val_dataloader: 90 | batch_size: 128 91 | num_workers: 8 92 | shuffle: False 93 | pin_memory: True 94 | persistent_workers: True 95 | 96 | optimizer: 97 | _target_: torch.optim.AdamW 98 | lr: 1.0e-4 99 | betas: [0.95, 0.999] 100 | eps: 1.0e-8 101 | weight_decay: 1.0e-6 102 | 103 | training: 104 | device: "cuda:0" 105 | seed: 42 106 | debug: False 107 | resume: True 108 | lr_scheduler: cosine 109 | lr_warmup_steps: 500 110 | num_epochs: ${eval:'50000 / ${n_demo}'} 111 | gradient_accumulate_every: 1 112 | use_ema: True 113 | rollout_every: ${eval:'1000 / ${n_demo}'} 114 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 115 | val_every: 1 116 | sample_every: 5 117 | max_train_steps: null 118 | max_val_steps: null 119 | tqdm_interval_sec: 1.0 120 | 121 | logging: 122 | project: dp3_${task_name} 123 | resume: true 124 | mode: online 125 | name: dp3_${n_demo} 126 | tags: ["${name}", "${task_name}", "${exp_name}"] 127 | id: null 128 | group: null 129 | 130 | 131 | checkpoint: 132 | save_ckpt: False # if True, save checkpoint every checkpoint_every 133 | topk: 134 | monitor_key: test_mean_score 135 | mode: max 136 | k: 1 137 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 138 | save_last_ckpt: True # this only saves when save_ckpt is True 139 | save_last_snapshot: False 140 | 141 | multi_run: 142 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 143 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 144 | 145 | hydra: 146 | job: 147 | override_dirname: ${name} 148 | run: 149 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 150 | sweep: 151 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 152 | subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /equi_diffpo/config/task/mimicgen_abs.yaml: -------------------------------------------------------------------------------- 1 | name: mimicgen_abs 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [10] 21 | 22 | abs_action: &abs_action True 23 | 24 | env_runner: 25 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner 26 | dataset_path: ${dataset_path} 27 | shape_meta: *shape_meta 28 | n_train: 6 29 | n_train_vis: 2 30 | train_start_idx: 0 31 | n_test: 50 32 | n_test_vis: 4 33 | test_start_seed: 100000 34 | max_steps: ${get_max_steps:${task_name}} 35 | n_obs_steps: ${n_obs_steps} 36 | n_action_steps: ${n_action_steps} 37 | render_obs_key: 'agentview_image' 38 | fps: 10 39 | crf: 22 40 | past_action: ${past_action_visible} 41 | abs_action: *abs_action 42 | tqdm_interval_sec: 1.0 43 | n_envs: 28 44 | 45 | dataset: 46 | # _target_: 
equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 47 | _target_: ${dataset} 48 | n_demo: ${n_demo} 49 | shape_meta: *shape_meta 50 | dataset_path: ${dataset_path} 51 | horizon: ${horizon} 52 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 53 | pad_after: ${eval:'${n_action_steps}-1'} 54 | n_obs_steps: ${dataset_obs_steps} 55 | abs_action: *abs_action 56 | rotation_rep: 'rotation_6d' 57 | use_legacy_normalizer: False 58 | use_cache: True 59 | seed: 42 60 | val_ratio: 0.02 61 | -------------------------------------------------------------------------------- /equi_diffpo/config/task/mimicgen_pc_abs.yaml: -------------------------------------------------------------------------------- 1 | name: mimicgen_pc_abs 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | robot0_eye_in_hand_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | point_cloud: 10 | shape: [1024, 6] 11 | type: point_cloud 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [10] 21 | 22 | env_runner_shape_meta: &env_runner_shape_meta 23 | # acceptable types: rgb, low_dim 24 | obs: 25 | robot0_eye_in_hand_image: 26 | shape: [3, 84, 84] 27 | type: rgb 28 | agentview_image: 29 | shape: [3, 84, 84] 30 | type: rgb 31 | point_cloud: 32 | shape: [1024, 6] 33 | type: point_cloud 34 | robot0_eef_pos: 35 | shape: [3] 36 | # type default: low_dim 37 | robot0_eef_quat: 38 | shape: [4] 39 | robot0_gripper_qpos: 40 | shape: [2] 41 | action: 42 | shape: [10] 43 | 44 | abs_action: &abs_action True 45 | 46 | env_runner: 47 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner 48 | dataset_path: ${dataset_path} 49 | shape_meta: *env_runner_shape_meta 50 | n_train: 6 51 | n_train_vis: 2 52 | train_start_idx: 0 53 | n_test: 50 54 | n_test_vis: 4 55 | test_start_seed: 100000 56 | max_steps: ${get_max_steps:${task_name}} 57 | n_obs_steps: ${n_obs_steps} 58 | n_action_steps: ${n_action_steps} 59 | render_obs_key: 'agentview_image' 60 | fps: 10 61 | crf: 22 62 | past_action: False 63 | abs_action: *abs_action 64 | tqdm_interval_sec: 1.0 65 | n_envs: 28 66 | 67 | dataset: 68 | _target_: ${dataset_target} 69 | n_demo: ${n_demo} 70 | shape_meta: *shape_meta 71 | dataset_path: ${dataset_path} 72 | horizon: ${horizon} 73 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 74 | pad_after: ${eval:'${n_action_steps}-1'} 75 | n_obs_steps: ${dataset_obs_steps} 76 | abs_action: *abs_action 77 | rotation_rep: 'rotation_6d' 78 | use_legacy_normalizer: False 79 | use_cache: False 80 | seed: 42 81 | val_ratio: 0.02 82 | -------------------------------------------------------------------------------- /equi_diffpo/config/task/mimicgen_rel.yaml: -------------------------------------------------------------------------------- 1 | name: mimicgen_rel 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | agentview_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | robot0_eye_in_hand_image: 10 | shape: [3, 84, 84] 11 | type: rgb 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [7] 21 | 22 | abs_action: &abs_action False 23 | 24 | env_runner: 25 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner 26 | dataset_path: ${dataset_path} 27 | shape_meta: *shape_meta 28 | n_train: 6 29 | 
n_train_vis: 2 30 | train_start_idx: 0 31 | n_test: 50 32 | n_test_vis: 4 33 | test_start_seed: 100000 34 | max_steps: ${get_max_steps:${task_name}} 35 | n_obs_steps: ${n_obs_steps} 36 | n_action_steps: ${n_action_steps} 37 | render_obs_key: 'agentview_image' 38 | fps: 10 39 | crf: 22 40 | past_action: ${past_action_visible} 41 | abs_action: *abs_action 42 | tqdm_interval_sec: 1.0 43 | n_envs: 28 44 | 45 | dataset: 46 | # _target_: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 47 | _target_: ${dataset} 48 | n_demo: ${n_demo} 49 | shape_meta: *shape_meta 50 | dataset_path: ${dataset_path} 51 | horizon: ${horizon} 52 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 53 | pad_after: ${eval:'${n_action_steps}-1'} 54 | n_obs_steps: ${dataset_obs_steps} 55 | abs_action: *abs_action 56 | rotation_rep: 'rotation_6d' 57 | use_legacy_normalizer: False 58 | use_cache: True 59 | seed: 42 60 | val_ratio: 0.02 61 | -------------------------------------------------------------------------------- /equi_diffpo/config/task/mimicgen_voxel_abs.yaml: -------------------------------------------------------------------------------- 1 | name: mimicgen_abs 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | robot0_eye_in_hand_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | voxels: 10 | shape: [4, 64, 64, 64] 11 | type: voxel 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [10] 21 | 22 | env_runner_shape_meta: &env_runner_shape_meta 23 | # acceptable types: rgb, low_dim 24 | obs: 25 | robot0_eye_in_hand_image: 26 | shape: [3, 84, 84] 27 | type: rgb 28 | agentview_image: 29 | shape: [3, 84, 84] 30 | type: rgb 31 | voxels: 32 | shape: [4, 64, 64, 64] 33 | type: voxel 34 | robot0_eef_pos: 35 | shape: [3] 36 | # type default: low_dim 37 | robot0_eef_quat: 38 | shape: [4] 39 | robot0_gripper_qpos: 40 | shape: [2] 41 | action: 42 | shape: [10] 43 | 44 | # dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5 45 | abs_action: &abs_action True 46 | 47 | env_runner: 48 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner 49 | dataset_path: ${dataset_path} 50 | shape_meta: *env_runner_shape_meta 51 | n_train: 6 52 | n_train_vis: 2 53 | train_start_idx: 0 54 | n_test: 50 55 | n_test_vis: 4 56 | test_start_seed: 100000 57 | max_steps: ${get_max_steps:${task_name}} 58 | n_obs_steps: ${n_obs_steps} 59 | n_action_steps: ${n_action_steps} 60 | render_obs_key: 'agentview_image' 61 | fps: 10 62 | crf: 22 63 | past_action: ${past_action_visible} 64 | abs_action: *abs_action 65 | tqdm_interval_sec: 1.0 66 | n_envs: 28 67 | 68 | dataset: 69 | _target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset 70 | n_demo: ${n_demo} 71 | shape_meta: *shape_meta 72 | dataset_path: ${dataset_path} 73 | horizon: ${horizon} 74 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 75 | pad_after: ${eval:'${n_action_steps}-1'} 76 | n_obs_steps: ${dataset_obs_steps} 77 | abs_action: *abs_action 78 | rotation_rep: 'rotation_6d' 79 | use_legacy_normalizer: False 80 | use_cache: True 81 | seed: 42 82 | val_ratio: 0.02 83 | ws_x_center: ${get_ws_x_center:${task_name}} 84 | ws_y_center: ${get_ws_y_center:${task_name}} 85 | -------------------------------------------------------------------------------- /equi_diffpo/config/task/mimicgen_voxel_rel.yaml: 
-------------------------------------------------------------------------------- 1 | name: mimicgen_rel 2 | 3 | shape_meta: &shape_meta 4 | # acceptable types: rgb, low_dim 5 | obs: 6 | robot0_eye_in_hand_image: 7 | shape: [3, 84, 84] 8 | type: rgb 9 | voxels: 10 | shape: [4, 64, 64, 64] 11 | type: voxel 12 | robot0_eef_pos: 13 | shape: [3] 14 | # type default: low_dim 15 | robot0_eef_quat: 16 | shape: [4] 17 | robot0_gripper_qpos: 18 | shape: [2] 19 | action: 20 | shape: [7] 21 | 22 | env_runner_shape_meta: &env_runner_shape_meta 23 | # acceptable types: rgb, low_dim 24 | obs: 25 | robot0_eye_in_hand_image: 26 | shape: [3, 84, 84] 27 | type: rgb 28 | agentview_image: 29 | shape: [3, 84, 84] 30 | type: rgb 31 | voxels: 32 | shape: [4, 64, 64, 64] 33 | type: voxel 34 | robot0_eef_pos: 35 | shape: [3] 36 | # type default: low_dim 37 | robot0_eef_quat: 38 | shape: [4] 39 | robot0_gripper_qpos: 40 | shape: [2] 41 | action: 42 | shape: [7] 43 | 44 | # dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5 45 | abs_action: &abs_action False 46 | 47 | env_runner: 48 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner 49 | dataset_path: ${dataset_path} 50 | shape_meta: *env_runner_shape_meta 51 | n_train: 6 52 | n_train_vis: 2 53 | train_start_idx: 0 54 | n_test: 50 55 | n_test_vis: 4 56 | test_start_seed: 100000 57 | max_steps: ${get_max_steps:${task_name}} 58 | n_obs_steps: ${n_obs_steps} 59 | n_action_steps: ${n_action_steps} 60 | render_obs_key: 'agentview_image' 61 | fps: 10 62 | crf: 22 63 | past_action: ${past_action_visible} 64 | abs_action: *abs_action 65 | tqdm_interval_sec: 1.0 66 | n_envs: 28 67 | 68 | dataset: 69 | _target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset 70 | n_demo: ${n_demo} 71 | shape_meta: *shape_meta 72 | dataset_path: ${dataset_path} 73 | horizon: ${horizon} 74 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'} 75 | pad_after: ${eval:'${n_action_steps}-1'} 76 | n_obs_steps: ${dataset_obs_steps} 77 | abs_action: *abs_action 78 | rotation_rep: 'rotation_6d' 79 | use_legacy_normalizer: False 80 | use_cache: True 81 | seed: 42 82 | val_ratio: 0.02 83 | ws_x_center: ${get_ws_x_center:${task_name}} 84 | ws_y_center: ${get_ws_y_center:${task_name}} 85 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_act_abs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_abs 4 | 5 | name: act 6 | _target_: equi_diffpo.workspace.train_act_workspace.TrainActWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 10 14 | n_obs_steps: 1 15 | n_action_steps: 10 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5 21 | 22 | policy: 23 | _target_: equi_diffpo.policy.act_policy.ACTPolicyWrapper 24 | 25 | shape_meta: ${shape_meta} 26 | 27 | max_timesteps: ${task.env_runner.max_steps} 28 | temporal_agg: false 29 | n_envs: ${task.env_runner.n_envs} 30 | horizon: ${horizon} 31 | 32 | dataloader: 33 | batch_size: 64 34 | num_workers: 4 35 | shuffle: True 36 | pin_memory: True 37 | persistent_workers: True 38 | 39 | val_dataloader: 40 | batch_size: 64 41 | 
num_workers: 4 42 | shuffle: False 43 | pin_memory: True 44 | persistent_workers: True 45 | 46 | training: 47 | device: "cuda:0" 48 | seed: 0 49 | debug: False 50 | resume: True 51 | num_epochs: ${eval:'50000 / ${n_demo}'} 52 | rollout_every: ${eval:'1000 / ${n_demo}'} 53 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 54 | val_every: 1 55 | max_train_steps: null 56 | max_val_steps: null 57 | tqdm_interval_sec: 1.0 58 | 59 | logging: 60 | project: diffusion_policy_${task_name} 61 | resume: True 62 | mode: online 63 | name: act_demo${n_demo} 64 | tags: ["${name}", "${task_name}", "${exp_name}"] 65 | id: null 66 | group: null 67 | 68 | checkpoint: 69 | topk: 70 | monitor_key: test_mean_score 71 | mode: max 72 | k: 5 73 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 74 | save_last_ckpt: True 75 | save_last_snapshot: False 76 | 77 | multi_run: 78 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 79 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 80 | 81 | hydra: 82 | job: 83 | override_dirname: ${name} 84 | run: 85 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 86 | sweep: 87 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 88 | subdir: ${hydra.job.num} 89 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_bc_rnn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_rel 4 | 5 | name: bc_rnn 6 | _target_: equi_diffpo.workspace.train_robomimic_image_workspace.TrainRobomimicImageWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: &horizon 10 14 | n_obs_steps: 1 15 | n_action_steps: 1 16 | n_latency_steps: 0 17 | dataset_obs_steps: *horizon 18 | past_action_visible: False 19 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5 21 | 22 | policy: 23 | _target_: equi_diffpo.policy.robomimic_image_policy.RobomimicImagePolicy 24 | shape_meta: ${shape_meta} 25 | algo_name: bc_rnn 26 | obs_type: image 27 | # oc.select resolver: key, default 28 | task_name: ${oc.select:task.task_name,lift} 29 | dataset_type: ${oc.select:task.dataset_type,ph} 30 | crop_shape: [76,76] 31 | 32 | dataloader: 33 | batch_size: 64 34 | num_workers: 4 35 | shuffle: True 36 | pin_memory: True 37 | persistent_workers: True 38 | 39 | val_dataloader: 40 | batch_size: 64 41 | num_workers: 4 42 | shuffle: False 43 | pin_memory: True 44 | persistent_workers: True 45 | 46 | training: 47 | device: "cuda:0" 48 | seed: 0 49 | debug: False 50 | resume: True 51 | # optimization 52 | num_epochs: ${eval:'50000 / ${n_demo}'} 53 | # training loop control 54 | # in epochs 55 | rollout_every: ${eval:'1000 / ${n_demo}'} 56 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 57 | val_every: 1 58 | sample_every: 5 59 | # steps per epoch 60 | max_train_steps: null 61 | max_val_steps: null 62 | # misc 63 | tqdm_interval_sec: 1.0 64 | 65 | logging: 66 | project: diffusion_policy_${task_name} 67 | resume: True 68 | mode: online 69 | name: bc_rnn_demo${n_demo} 70 | tags: ["${name}", "${task_name}", "${exp_name}"] 71 | id: null 72 | group: null 73 | 74 | checkpoint: 75 | topk: 76 | monitor_key: test_mean_score 77 | mode: max 78 | k: 5 79 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 80 
| save_last_ckpt: True 81 | save_last_snapshot: False 82 | 83 | multi_run: 84 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 85 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 86 | 87 | hydra: 88 | job: 89 | override_dirname: ${name} 90 | run: 91 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 92 | sweep: 93 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 94 | subdir: ${hydra.job.num} 95 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_diffusion_transformer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_abs 4 | 5 | name: diff_t 6 | _target_: equi_diffpo.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 10 14 | n_obs_steps: 2 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | obs_as_cond: True 20 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 21 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5 22 | 23 | policy: 24 | _target_: equi_diffpo.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy 25 | 26 | shape_meta: ${shape_meta} 27 | 28 | noise_scheduler: 29 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler 30 | num_train_timesteps: 100 31 | beta_start: 0.0001 32 | beta_end: 0.02 33 | beta_schedule: squaredcos_cap_v2 34 | variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan 35 | clip_sample: True # required when predict_epsilon=False 36 | prediction_type: epsilon # or sample 37 | 38 | horizon: ${horizon} 39 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 40 | n_obs_steps: ${n_obs_steps} 41 | num_inference_steps: 100 42 | 43 | crop_shape: [76, 76] 44 | obs_encoder_group_norm: True 45 | eval_fixed_crop: True 46 | 47 | n_layer: 8 48 | n_cond_layers: 0 # >0: use transformer encoder for cond, otherwise use MLP 49 | n_head: 4 50 | n_emb: 256 51 | p_drop_emb: 0.0 52 | p_drop_attn: 0.3 53 | causal_attn: True 54 | time_as_cond: True # if false, use BERT like encoder only arch, time as input 55 | obs_as_cond: ${obs_as_cond} 56 | 57 | # scheduler.step params 58 | # predict_epsilon: True 59 | 60 | ema: 61 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 62 | update_after_step: 0 63 | inv_gamma: 1.0 64 | power: 0.75 65 | min_value: 0.0 66 | max_value: 0.9999 67 | 68 | dataloader: 69 | batch_size: 64 70 | num_workers: 4 71 | shuffle: True 72 | pin_memory: True 73 | persistent_workers: True 74 | 75 | val_dataloader: 76 | batch_size: 64 77 | num_workers: 4 78 | shuffle: False 79 | pin_memory: True 80 | persistent_workers: True 81 | 82 | optimizer: 83 | transformer_weight_decay: 1.0e-3 84 | obs_encoder_weight_decay: 1.0e-6 85 | learning_rate: 1.0e-4 86 | betas: [0.9, 0.95] 87 | 88 | training: 89 | device: "cuda:0" 90 | seed: 0 91 | debug: False 92 | resume: True 93 | # optimization 94 | lr_scheduler: cosine 95 | # Transformer needs LR warmup 96 | lr_warmup_steps: 1000 97 | num_epochs: ${eval:'50000 / ${n_demo}'} 98 | gradient_accumulate_every: 1 99 | # EMA destroys performance when used with BatchNorm 100 | # replace BatchNorm with GroupNorm. 
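# note: obs_encoder_group_norm: True in the policy block above already swaps the
# vision encoder's BatchNorm layers for GroupNorm, so enabling EMA below is safe.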
101 | use_ema: True 102 | # training loop control 103 | # in epochs 104 | rollout_every: ${eval:'1000 / ${n_demo}'} 105 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 106 | val_every: 1 107 | sample_every: 5 108 | # steps per epoch 109 | max_train_steps: null 110 | max_val_steps: null 111 | # misc 112 | tqdm_interval_sec: 1.0 113 | 114 | logging: 115 | project: diffusion_policy_${task_name} 116 | resume: True 117 | mode: online 118 | name: diff_t_demo${n_demo} 119 | tags: ["${name}", "${task_name}", "${exp_name}"] 120 | id: null 121 | group: null 122 | 123 | checkpoint: 124 | topk: 125 | monitor_key: test_mean_score 126 | mode: max 127 | k: 5 128 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 129 | save_last_ckpt: True 130 | save_last_snapshot: False 131 | 132 | multi_run: 133 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 134 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 135 | 136 | hydra: 137 | job: 138 | override_dirname: ${name} 139 | run: 140 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 141 | sweep: 142 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 143 | subdir: ${hydra.job.num} 144 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_diffusion_unet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_abs 4 | 5 | name: diff_c 6 | _target_: equi_diffpo.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 16 14 | n_obs_steps: 2 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | obs_as_global_cond: True 20 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset 21 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5 22 | 23 | policy: 24 | _target_: equi_diffpo.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy 25 | 26 | shape_meta: ${shape_meta} 27 | 28 | noise_scheduler: 29 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler 30 | num_train_timesteps: 100 31 | beta_start: 0.0001 32 | beta_end: 0.02 33 | beta_schedule: squaredcos_cap_v2 34 | variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan 35 | clip_sample: True # required when predict_epsilon=False 36 | prediction_type: epsilon # or sample 37 | 38 | horizon: ${horizon} 39 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 40 | n_obs_steps: ${n_obs_steps} 41 | num_inference_steps: 100 42 | obs_as_global_cond: ${obs_as_global_cond} 43 | crop_shape: [76, 76] 44 | # crop_shape: null 45 | diffusion_step_embed_dim: 128 46 | down_dims: [512, 1024, 2048] 47 | kernel_size: 5 48 | n_groups: 8 49 | cond_predict_scale: True 50 | obs_encoder_group_norm: True 51 | eval_fixed_crop: True 52 | rot_aug: False 53 | 54 | # scheduler.step params 55 | # predict_epsilon: True 56 | 57 | ema: 58 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 59 | update_after_step: 0 60 | inv_gamma: 1.0 61 | power: 0.75 62 | min_value: 0.0 63 | max_value: 0.9999 64 | 65 | dataloader: 66 | batch_size: 64 67 | num_workers: 4 68 | shuffle: True 69 | pin_memory: True 70 | persistent_workers: True 71 | 72 | val_dataloader: 73 | 
batch_size: 64 74 | num_workers: 4 75 | shuffle: False 76 | pin_memory: True 77 | persistent_workers: True 78 | 79 | optimizer: 80 | _target_: torch.optim.AdamW 81 | lr: 1.0e-4 82 | betas: [0.95, 0.999] 83 | eps: 1.0e-8 84 | weight_decay: 1.0e-6 85 | 86 | training: 87 | device: "cuda:0" 88 | seed: 0 89 | debug: False 90 | resume: True 91 | # optimization 92 | lr_scheduler: cosine 93 | lr_warmup_steps: 500 94 | num_epochs: ${eval:'50000 / ${n_demo}'} 95 | gradient_accumulate_every: 1 96 | # EMA destroys performance when used with BatchNorm 97 | # replace BatchNorm with GroupNorm. 98 | use_ema: True 99 | # training loop control 100 | # in epochs 101 | rollout_every: ${eval:'1000 / ${n_demo}'} 102 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 103 | val_every: 1 104 | sample_every: 5 105 | # steps per epoch 106 | max_train_steps: null 107 | max_val_steps: null 108 | # misc 109 | tqdm_interval_sec: 1.0 110 | 111 | logging: 112 | project: diffusion_policy_${task_name} 113 | resume: True 114 | mode: online 115 | name: diff_c_demo${n_demo} 116 | tags: ["${name}", "${task_name}", "${exp_name}"] 117 | id: null 118 | group: null 119 | 120 | checkpoint: 121 | topk: 122 | monitor_key: test_mean_score 123 | mode: max 124 | k: 5 125 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 126 | save_last_ckpt: True 127 | save_last_snapshot: False 128 | 129 | multi_run: 130 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 131 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 132 | 133 | hydra: 134 | job: 135 | override_dirname: ${name} 136 | run: 137 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 138 | sweep: 139 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 140 | subdir: ${hydra.job.num} 141 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_diffusion_unet_voxel_abs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_voxel_abs 4 | 5 | name: diff_voxel 6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 16 14 | n_obs_steps: 1 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset 20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5 21 | 22 | policy: 23 | _target_: equi_diffpo.policy.diffusion_unet_voxel_policy.DiffusionUNetPolicyVoxel 24 | 25 | shape_meta: ${shape_meta} 26 | 27 | noise_scheduler: 28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler 29 | num_train_timesteps: 100 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | beta_schedule: squaredcos_cap_v2 33 | variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan 34 | clip_sample: True # required when predict_epsilon=False 35 | prediction_type: epsilon # or sample 36 | 37 | horizon: ${horizon} 38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 39 | n_obs_steps: ${n_obs_steps} 40 | num_inference_steps: 100 41 | crop_shape: [58, 58, 58] 42 | # crop_shape: null 43 | diffusion_step_embed_dim: 128 44 | enc_n_hidden: 256 45 | down_dims: [256, 512, 1024] 46 | kernel_size: 5 47 | n_groups: 8 48 | 
cond_predict_scale: True 49 | rot_aug: False 50 | 51 | # scheduler.step params 52 | # predict_epsilon: True 53 | 54 | ema: 55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 56 | update_after_step: 0 57 | inv_gamma: 1.0 58 | power: 0.75 59 | min_value: 0.0 60 | max_value: 0.9999 61 | 62 | dataloader: 63 | batch_size: 64 64 | num_workers: 16 65 | shuffle: True 66 | pin_memory: True 67 | persistent_workers: True 68 | drop_last: true 69 | 70 | val_dataloader: 71 | batch_size: 64 72 | num_workers: 16 73 | shuffle: False 74 | pin_memory: True 75 | persistent_workers: True 76 | 77 | optimizer: 78 | betas: [0.95, 0.999] 79 | eps: 1.0e-08 80 | learning_rate: 0.0001 81 | weight_decay: 1.0e-06 82 | 83 | training: 84 | device: "cuda:0" 85 | seed: 0 86 | debug: False 87 | resume: True 88 | # optimization 89 | lr_scheduler: cosine 90 | lr_warmup_steps: 500 91 | num_epochs: ${eval:'50000 / ${n_demo}'} 92 | gradient_accumulate_every: 1 93 | # EMA destroys performance when used with BatchNorm 94 | # replace BatchNorm with GroupNorm. 95 | use_ema: True 96 | # training loop control 97 | # in epochs 98 | rollout_every: ${eval:'1000 / ${n_demo}'} 99 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 100 | val_every: 1 101 | sample_every: 5 102 | # steps per epoch 103 | max_train_steps: null 104 | max_val_steps: null 105 | # misc 106 | tqdm_interval_sec: 1.0 107 | 108 | logging: 109 | project: equi_diff_${task_name}_voxel 110 | resume: True 111 | mode: online 112 | name: diff_voxel_${n_demo} 113 | tags: ["${name}", "${task_name}", "${exp_name}"] 114 | id: null 115 | group: null 116 | 117 | checkpoint: 118 | topk: 119 | monitor_key: test_mean_score 120 | mode: max 121 | k: 5 122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 123 | save_last_ckpt: True 124 | save_last_snapshot: False 125 | 126 | multi_run: 127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 129 | 130 | hydra: 131 | job: 132 | override_dirname: ${name} 133 | run: 134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 135 | sweep: 136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 137 | subdir: ${hydra.job.num} 138 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_equi_diffusion_unet_abs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_abs 4 | 5 | name: equi_diff 6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 16 14 | n_obs_steps: 2 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset 20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5 21 | 22 | policy: 23 | _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy 24 | 25 | shape_meta: ${shape_meta} 26 | 27 | noise_scheduler: 28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler 29 | num_train_timesteps: 100 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | beta_schedule: squaredcos_cap_v2 33 | variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan 34 | 
clip_sample: True # required when predict_epsilon=False 35 | prediction_type: epsilon # or sample 36 | 37 | horizon: ${horizon} 38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 39 | n_obs_steps: ${n_obs_steps} 40 | num_inference_steps: 100 41 | crop_shape: [76, 76] 42 | # crop_shape: null 43 | diffusion_step_embed_dim: 128 44 | enc_n_hidden: 128 45 | down_dims: [512, 1024, 2048] 46 | kernel_size: 5 47 | n_groups: 8 48 | cond_predict_scale: True 49 | rot_aug: False 50 | 51 | # scheduler.step params 52 | # predict_epsilon: True 53 | 54 | ema: 55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 56 | update_after_step: 0 57 | inv_gamma: 1.0 58 | power: 0.75 59 | min_value: 0.0 60 | max_value: 0.9999 61 | 62 | dataloader: 63 | batch_size: 128 64 | num_workers: 4 65 | shuffle: True 66 | pin_memory: True 67 | persistent_workers: True 68 | drop_last: true 69 | 70 | val_dataloader: 71 | batch_size: 128 72 | num_workers: 8 73 | shuffle: False 74 | pin_memory: True 75 | persistent_workers: True 76 | 77 | optimizer: 78 | betas: [0.95, 0.999] 79 | eps: 1.0e-08 80 | learning_rate: 0.0001 81 | weight_decay: 1.0e-06 82 | 83 | training: 84 | device: "cuda:0" 85 | seed: 0 86 | debug: False 87 | resume: True 88 | # optimization 89 | lr_scheduler: cosine 90 | lr_warmup_steps: 500 91 | num_epochs: ${eval:'50000 / ${n_demo}'} 92 | gradient_accumulate_every: 1 93 | # EMA destroys performance when used with BatchNorm 94 | # replace BatchNorm with GroupNorm. 95 | use_ema: True 96 | # training loop control 97 | # in epochs 98 | rollout_every: ${eval:'1000 / ${n_demo}'} 99 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 100 | val_every: 1 101 | sample_every: 5 102 | # steps per epoch 103 | max_train_steps: null 104 | max_val_steps: null 105 | # misc 106 | tqdm_interval_sec: 1.0 107 | 108 | logging: 109 | project: diffusion_policy_${task_name} 110 | resume: True 111 | mode: online 112 | name: equidiff_demo${n_demo} 113 | tags: ["${name}", "${task_name}", "${exp_name}"] 114 | id: null 115 | group: null 116 | 117 | checkpoint: 118 | topk: 119 | monitor_key: test_mean_score 120 | mode: max 121 | k: 5 122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 123 | save_last_ckpt: True 124 | save_last_snapshot: False 125 | 126 | multi_run: 127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 129 | 130 | hydra: 131 | job: 132 | override_dirname: ${name} 133 | run: 134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 135 | sweep: 136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 137 | subdir: ${hydra.job.num} 138 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_equi_diffusion_unet_rel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_rel 4 | 5 | name: equi_diff 6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 16 14 | n_obs_steps: 2 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset 20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5 21 | 22 | policy: 23 
| _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_rel_policy.DiffusionEquiUNetCNNEncRelPolicy 24 | 25 | shape_meta: ${shape_meta} 26 | 27 | noise_scheduler: 28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler 29 | num_train_timesteps: 100 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | beta_schedule: squaredcos_cap_v2 33 | variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan 34 | clip_sample: True # required when predict_epsilon=False 35 | prediction_type: epsilon # or sample 36 | 37 | horizon: ${horizon} 38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 39 | n_obs_steps: ${n_obs_steps} 40 | num_inference_steps: 100 41 | crop_shape: [76, 76] 42 | # crop_shape: null 43 | diffusion_step_embed_dim: 128 44 | enc_n_hidden: 128 45 | down_dims: [512, 1024, 2048] 46 | kernel_size: 5 47 | n_groups: 8 48 | cond_predict_scale: True 49 | 50 | # scheduler.step params 51 | # predict_epsilon: True 52 | 53 | ema: 54 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 55 | update_after_step: 0 56 | inv_gamma: 1.0 57 | power: 0.75 58 | min_value: 0.0 59 | max_value: 0.9999 60 | 61 | dataloader: 62 | batch_size: 128 63 | num_workers: 4 64 | shuffle: True 65 | pin_memory: True 66 | persistent_workers: True 67 | drop_last: true 68 | 69 | val_dataloader: 70 | batch_size: 128 71 | num_workers: 8 72 | shuffle: False 73 | pin_memory: True 74 | persistent_workers: True 75 | 76 | optimizer: 77 | betas: [0.95, 0.999] 78 | eps: 1.0e-08 79 | learning_rate: 0.0001 80 | weight_decay: 1.0e-06 81 | 82 | training: 83 | device: "cuda:0" 84 | seed: 0 85 | debug: False 86 | resume: True 87 | # optimization 88 | lr_scheduler: cosine 89 | lr_warmup_steps: 500 90 | num_epochs: ${eval:'50000 / ${n_demo}'} 91 | gradient_accumulate_every: 1 92 | # EMA destroys performance when used with BatchNorm 93 | # replace BatchNorm with GroupNorm. 
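# note: with the ema block above, EMAModel's decay at optimizer step t typically
# ramps as 1 - (1 + t / inv_gamma) ** -power, clipped to [min_value, max_value],
# so the weight average starts loose at t=0 and tightens toward 0.9999.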
94 | use_ema: True 95 | # training loop control 96 | # in epochs 97 | rollout_every: ${eval:'1000 / ${n_demo}'} 98 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 99 | val_every: 1 100 | sample_every: 5 101 | # steps per epoch 102 | max_train_steps: null 103 | max_val_steps: null 104 | # misc 105 | tqdm_interval_sec: 1.0 106 | 107 | logging: 108 | project: diffusion_policy_${task_name}_vel 109 | resume: True 110 | mode: online 111 | name: equidiff_demo${n_demo} 112 | tags: ["${name}", "${task_name}", "${exp_name}"] 113 | id: null 114 | group: null 115 | 116 | checkpoint: 117 | topk: 118 | monitor_key: test_mean_score 119 | mode: max 120 | k: 5 121 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 122 | save_last_ckpt: True 123 | save_last_snapshot: False 124 | 125 | multi_run: 126 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 127 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 128 | 129 | hydra: 130 | job: 131 | override_dirname: ${name} 132 | run: 133 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 134 | sweep: 135 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 136 | subdir: ${hydra.job.num} 137 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_equi_diffusion_unet_voxel_abs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_voxel_abs 4 | 5 | name: equi_diff_voxel 6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 16 14 | n_obs_steps: 1 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset 20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5 21 | 22 | policy: 23 | _target_: equi_diffpo.policy.diffusion_equi_unet_voxel_policy.DiffusionEquiUNetPolicyVoxel 24 | 25 | shape_meta: ${shape_meta} 26 | 27 | noise_scheduler: 28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler 29 | num_train_timesteps: 100 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | beta_schedule: squaredcos_cap_v2 33 | variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan 34 | clip_sample: True # required when predict_epsilon=False 35 | prediction_type: epsilon # or sample 36 | 37 | horizon: ${horizon} 38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 39 | n_obs_steps: ${n_obs_steps} 40 | num_inference_steps: 100 41 | crop_shape: [58, 58, 58] 42 | # crop_shape: null 43 | diffusion_step_embed_dim: 128 44 | enc_n_hidden: 128 45 | down_dims: [256, 512, 1024] 46 | kernel_size: 5 47 | n_groups: 8 48 | cond_predict_scale: True 49 | rot_aug: True 50 | 51 | # scheduler.step params 52 | # predict_epsilon: True 53 | 54 | ema: 55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 56 | update_after_step: 0 57 | inv_gamma: 1.0 58 | power: 0.75 59 | min_value: 0.0 60 | max_value: 0.9999 61 | 62 | dataloader: 63 | batch_size: 64 64 | num_workers: 16 65 | shuffle: True 66 | pin_memory: True 67 | persistent_workers: True 68 | drop_last: true 69 | 70 | val_dataloader: 71 | batch_size: 64 72 | num_workers: 16 73 | shuffle: False 74 | pin_memory: True 75 | 
persistent_workers: True 76 | 77 | optimizer: 78 | betas: [0.95, 0.999] 79 | eps: 1.0e-08 80 | learning_rate: 0.0001 81 | weight_decay: 1.0e-06 82 | 83 | training: 84 | device: "cuda:0" 85 | seed: 0 86 | debug: False 87 | resume: True 88 | # optimization 89 | lr_scheduler: cosine 90 | lr_warmup_steps: 500 91 | num_epochs: ${eval:'50000 / ${n_demo}'} 92 | gradient_accumulate_every: 1 93 | # EMA destroys performance when used with BatchNorm 94 | # replace BatchNorm with GroupNorm. 95 | use_ema: True 96 | # training loop control 97 | # in epochs 98 | rollout_every: ${eval:'1000 / ${n_demo}'} 99 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 100 | val_every: 1 101 | sample_every: 5 102 | # steps per epoch 103 | max_train_steps: null 104 | max_val_steps: null 105 | # misc 106 | tqdm_interval_sec: 1.0 107 | 108 | logging: 109 | project: equi_diff_${task_name}_voxel 110 | resume: True 111 | mode: online 112 | name: equi_diff_voxel_${n_demo} 113 | tags: ["${name}", "${task_name}", "${exp_name}"] 114 | id: null 115 | group: null 116 | 117 | checkpoint: 118 | topk: 119 | monitor_key: test_mean_score 120 | mode: max 121 | k: 5 122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 123 | save_last_ckpt: True 124 | save_last_snapshot: False 125 | 126 | multi_run: 127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 129 | 130 | hydra: 131 | job: 132 | override_dirname: ${name} 133 | run: 134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 135 | sweep: 136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 137 | subdir: ${hydra.job.num} 138 | -------------------------------------------------------------------------------- /equi_diffpo/config/train_equi_diffusion_unet_voxel_rel.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - task: mimicgen_voxel_rel 4 | 5 | name: equi_diff_voxel 6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace 7 | 8 | shape_meta: ${task.shape_meta} 9 | exp_name: "default" 10 | 11 | task_name: stack_d1 12 | n_demo: 200 13 | horizon: 16 14 | n_obs_steps: 1 15 | n_action_steps: 8 16 | n_latency_steps: 0 17 | dataset_obs_steps: ${n_obs_steps} 18 | past_action_visible: False 19 | # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset 20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5 21 | 22 | policy: 23 | _target_: equi_diffpo.policy.diffusion_equi_unet_voxel_rel_policy.DiffusionEquiUNetRelPolicyVoxel 24 | 25 | shape_meta: ${shape_meta} 26 | 27 | noise_scheduler: 28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler 29 | num_train_timesteps: 100 30 | beta_start: 0.0001 31 | beta_end: 0.02 32 | beta_schedule: squaredcos_cap_v2 33 | variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan 34 | clip_sample: True # required when predict_epsilon=False 35 | prediction_type: epsilon # or sample 36 | 37 | horizon: ${horizon} 38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'} 39 | n_obs_steps: ${n_obs_steps} 40 | num_inference_steps: 100 41 | crop_shape: [58, 58, 58] 42 | # crop_shape: null 43 | diffusion_step_embed_dim: 128 44 | enc_n_hidden: 128 45 | down_dims: [256, 512, 1024] 46 | kernel_size: 5 47 | n_groups: 8 48 | cond_predict_scale: True 49 | rot_aug: True 50 | 51 | # scheduler.step params 52 | 
# predict_epsilon: True 53 | 54 | ema: 55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel 56 | update_after_step: 0 57 | inv_gamma: 1.0 58 | power: 0.75 59 | min_value: 0.0 60 | max_value: 0.9999 61 | 62 | dataloader: 63 | batch_size: 64 64 | num_workers: 16 65 | shuffle: True 66 | pin_memory: True 67 | persistent_workers: True 68 | drop_last: true 69 | 70 | val_dataloader: 71 | batch_size: 64 72 | num_workers: 16 73 | shuffle: False 74 | pin_memory: True 75 | persistent_workers: True 76 | 77 | optimizer: 78 | betas: [0.95, 0.999] 79 | eps: 1.0e-08 80 | learning_rate: 0.0001 81 | weight_decay: 1.0e-06 82 | 83 | training: 84 | device: "cuda:0" 85 | seed: 0 86 | debug: False 87 | resume: True 88 | # optimization 89 | lr_scheduler: cosine 90 | lr_warmup_steps: 500 91 | num_epochs: ${eval:'50000 / ${n_demo}'} 92 | gradient_accumulate_every: 1 93 | # EMA destroys performance when used with BatchNorm 94 | # replace BatchNorm with GroupNorm. 95 | use_ema: True 96 | # training loop control 97 | # in epochs 98 | rollout_every: ${eval:'1000 / ${n_demo}'} 99 | checkpoint_every: ${eval:'1000 / ${n_demo}'} 100 | val_every: 1 101 | sample_every: 5 102 | # steps per epoch 103 | max_train_steps: null 104 | max_val_steps: null 105 | # misc 106 | tqdm_interval_sec: 1.0 107 | 108 | logging: 109 | project: equi_diff_${task_name}_voxel_rel 110 | resume: True 111 | mode: online 112 | name: equi_diff_voxel_${n_demo} 113 | tags: ["${name}", "${task_name}", "${exp_name}"] 114 | id: null 115 | group: null 116 | 117 | checkpoint: 118 | topk: 119 | monitor_key: test_mean_score 120 | mode: max 121 | k: 5 122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt' 123 | save_last_ckpt: True 124 | save_last_snapshot: False 125 | 126 | multi_run: 127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name} 129 | 130 | hydra: 131 | job: 132 | override_dirname: ${name} 133 | run: 134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 135 | sweep: 136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name} 137 | subdir: ${hydra.job.num} 138 | -------------------------------------------------------------------------------- /equi_diffpo/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn 5 | from equi_diffpo.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseLowdimDataset(torch.utils.data.Dataset): 8 | def get_validation_dataset(self) -> 'BaseLowdimDataset': 9 | # return an empty dataset by default 10 | return BaseLowdimDataset() 11 | 12 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 13 | raise NotImplementedError() 14 | 15 | def get_all_actions(self) -> torch.Tensor: 16 | raise NotImplementedError() 17 | 18 | def __len__(self) -> int: 19 | return 0 20 | 21 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 22 | """ 23 | output: 24 | obs: T, Do 25 | action: T, Da 26 | """ 27 | raise NotImplementedError() 28 | 29 | 30 | class BaseImageDataset(torch.utils.data.Dataset): 31 | def get_validation_dataset(self) -> 'BaseImageDataset': 32 | # return an empty dataset by default 33 | return BaseImageDataset() 34 | 35 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 36 | raise NotImplementedError() 37 | 38 | def get_all_actions(self) -> torch.Tensor: 39 | raise NotImplementedError() 40 | 41 | def __len__(self) -> 
int: 42 | return 0 43 | 44 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 45 | """ 46 | output: 47 | obs: 48 | key: T, * 49 | action: T, Da 50 | """ 51 | raise NotImplementedError() 52 | -------------------------------------------------------------------------------- /equi_diffpo/dataset/robomimic_replay_image_sym_dataset.py: -------------------------------------------------------------------------------- 1 | from equi_diffpo.dataset.base_dataset import LinearNormalizer 2 | from equi_diffpo.model.common.normalizer import LinearNormalizer 3 | from equi_diffpo.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset, normalizer_from_stat 4 | from equi_diffpo.common.normalize_util import ( 5 | robomimic_abs_action_only_symmetric_normalizer_from_stat, 6 | get_range_normalizer_from_stat, 7 | get_range_symmetric_normalizer_from_stat, 8 | get_image_range_normalizer, 9 | get_identity_normalizer_from_stat, 10 | array_to_stats 11 | ) 12 | import numpy as np 13 | 14 | class RobomimicReplayImageSymDataset(RobomimicReplayImageDataset): 15 | def __init__(self, 16 | shape_meta: dict, 17 | dataset_path: str, 18 | horizon=1, 19 | pad_before=0, 20 | pad_after=0, 21 | n_obs_steps=None, 22 | abs_action=False, 23 | rotation_rep='rotation_6d', # ignored when abs_action=False 24 | use_legacy_normalizer=False, 25 | use_cache=False, 26 | seed=42, 27 | val_ratio=0.0, 28 | n_demo=100 29 | ): 30 | super().__init__( 31 | shape_meta, 32 | dataset_path, 33 | horizon, 34 | pad_before, 35 | pad_after, 36 | n_obs_steps, 37 | abs_action, 38 | rotation_rep, 39 | use_legacy_normalizer, 40 | use_cache, 41 | seed, 42 | val_ratio, 43 | n_demo 44 | ) 45 | 46 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 47 | normalizer = LinearNormalizer() 48 | 49 | # action 50 | stat = array_to_stats(self.replay_buffer['action']) 51 | if self.abs_action: 52 | if stat['mean'].shape[-1] > 10: 53 | # dual arm 54 | raise NotImplementedError 55 | else: 56 | this_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(stat) 57 | 58 | if self.use_legacy_normalizer: 59 | this_normalizer = normalizer_from_stat(stat) 60 | else: 61 | # already normalized 62 | this_normalizer = get_identity_normalizer_from_stat(stat) 63 | normalizer['action'] = this_normalizer 64 | 65 | # obs 66 | for key in self.lowdim_keys: 67 | stat = array_to_stats(self.replay_buffer[key]) 68 | 69 | if key.endswith('qpos'): 70 | this_normalizer = get_range_normalizer_from_stat(stat) 71 | elif key.endswith('pos'): 72 | this_normalizer = get_range_symmetric_normalizer_from_stat(stat) 73 | elif key.endswith('quat'): 74 | # quaternion is in [-1,1] already 75 | this_normalizer = get_identity_normalizer_from_stat(stat) 76 | elif key.find('bbox') > -1: 77 | this_normalizer = get_identity_normalizer_from_stat(stat) 78 | else: 79 | raise RuntimeError('unsupported') 80 | normalizer[key] = this_normalizer 81 | 82 | # image 83 | for key in self.rgb_keys: 84 | normalizer[key] = get_image_range_normalizer() 85 | 86 | normalizer['pos_vecs'] = get_identity_normalizer_from_stat({'min': -1 * np.ones([10, 2], np.float32), 'max': np.ones([10, 2], np.float32)}) 87 | normalizer['crops'] = get_image_range_normalizer() 88 | 89 | return normalizer 90 | 91 | -------------------------------------------------------------------------------- /equi_diffpo/dataset/robomimic_replay_lowdim_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import torch 3 | import numpy as np 4 | 
import h5py 5 | from tqdm import tqdm 6 | import copy 7 | from equi_diffpo.common.pytorch_util import dict_apply 8 | from equi_diffpo.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer 9 | from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer 10 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer 11 | from equi_diffpo.common.replay_buffer import ReplayBuffer 12 | from equi_diffpo.common.sampler import ( 13 | SequenceSampler, get_val_mask, downsample_mask) 14 | from equi_diffpo.common.normalize_util import ( 15 | robomimic_abs_action_only_normalizer_from_stat, 16 | robomimic_abs_action_only_dual_arm_normalizer_from_stat, 17 | get_identity_normalizer_from_stat, 18 | array_to_stats 19 | ) 20 | 21 | class RobomimicReplayLowdimDataset(BaseLowdimDataset): 22 | def __init__(self, 23 | dataset_path: str, 24 | horizon=1, 25 | pad_before=0, 26 | pad_after=0, 27 | obs_keys: List[str]=[ 28 | 'object', 29 | 'robot0_eef_pos', 30 | 'robot0_eef_quat', 31 | 'robot0_gripper_qpos'], 32 | abs_action=False, 33 | rotation_rep='rotation_6d', 34 | use_legacy_normalizer=False, 35 | seed=42, 36 | val_ratio=0.0, 37 | max_train_episodes=None, 38 | n_demo=100 39 | ): 40 | obs_keys = list(obs_keys) 41 | rotation_transformer = RotationTransformer( 42 | from_rep='axis_angle', to_rep=rotation_rep) 43 | 44 | replay_buffer = ReplayBuffer.create_empty_numpy() 45 | with h5py.File(dataset_path) as file: 46 | demos = file['data'] 47 | for i in tqdm(range(n_demo), desc="Loading hdf5 to ReplayBuffer"): 48 | demo = demos[f'demo_{i}'] 49 | episode = _data_to_obs( 50 | raw_obs=demo['obs'], 51 | raw_actions=demo['actions'][:].astype(np.float32), 52 | obs_keys=obs_keys, 53 | abs_action=abs_action, 54 | rotation_transformer=rotation_transformer) 55 | replay_buffer.add_episode(episode) 56 | 57 | val_mask = get_val_mask( 58 | n_episodes=replay_buffer.n_episodes, 59 | val_ratio=val_ratio, 60 | seed=seed) 61 | train_mask = ~val_mask 62 | train_mask = downsample_mask( 63 | mask=train_mask, 64 | max_n=max_train_episodes, 65 | seed=seed) 66 | 67 | sampler = SequenceSampler( 68 | replay_buffer=replay_buffer, 69 | sequence_length=horizon, 70 | pad_before=pad_before, 71 | pad_after=pad_after, 72 | episode_mask=train_mask) 73 | 74 | self.replay_buffer = replay_buffer 75 | self.sampler = sampler 76 | self.abs_action = abs_action 77 | self.train_mask = train_mask 78 | self.horizon = horizon 79 | self.pad_before = pad_before 80 | self.pad_after = pad_after 81 | self.use_legacy_normalizer = use_legacy_normalizer 82 | 83 | def get_validation_dataset(self): 84 | val_set = copy.copy(self) 85 | val_set.sampler = SequenceSampler( 86 | replay_buffer=self.replay_buffer, 87 | sequence_length=self.horizon, 88 | pad_before=self.pad_before, 89 | pad_after=self.pad_after, 90 | episode_mask=~self.train_mask 91 | ) 92 | val_set.train_mask = ~self.train_mask 93 | return val_set 94 | 95 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 96 | normalizer = LinearNormalizer() 97 | 98 | # action 99 | stat = array_to_stats(self.replay_buffer['action']) 100 | if self.abs_action: 101 | if stat['mean'].shape[-1] > 10: 102 | # dual arm 103 | this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat) 104 | else: 105 | this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat) 106 | 107 | if self.use_legacy_normalizer: 108 | this_normalizer = normalizer_from_stat(stat) 109 | else: 110 | # already normalized 111 | this_normalizer = 
get_identity_normalizer_from_stat(stat) 112 | normalizer['action'] = this_normalizer 113 | 114 | # aggregate obs stats 115 | obs_stat = array_to_stats(self.replay_buffer['obs']) 116 | 117 | # normalizer_from_stat (below) applies one shared symmetric scale, 1/max_abs, with zero offset 118 | normalizer['obs'] = normalizer_from_stat(obs_stat) 119 | return normalizer 120 | 121 | def get_all_actions(self) -> torch.Tensor: 122 | return torch.from_numpy(self.replay_buffer['action']) 123 | 124 | def __len__(self): 125 | return len(self.sampler) 126 | 127 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 128 | data = self.sampler.sample_sequence(idx) 129 | torch_data = dict_apply(data, torch.from_numpy) 130 | return torch_data 131 | 132 | def normalizer_from_stat(stat): 133 | max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max()) 134 | scale = np.full_like(stat['max'], fill_value=1/max_abs) 135 | offset = np.zeros_like(stat['max']) 136 | return SingleFieldLinearNormalizer.create_manual( 137 | scale=scale, 138 | offset=offset, 139 | input_stats_dict=stat 140 | ) 141 | 142 | def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer): 143 | obs = np.concatenate([ 144 | raw_obs[key] for key in obs_keys 145 | ], axis=-1).astype(np.float32) 146 | 147 | if abs_action: 148 | is_dual_arm = False 149 | if raw_actions.shape[-1] == 14: 150 | # dual arm 151 | raw_actions = raw_actions.reshape(-1,2,7) 152 | is_dual_arm = True 153 | 154 | pos = raw_actions[...,:3] 155 | rot = raw_actions[...,3:6] 156 | gripper = raw_actions[...,6:] 157 | rot = rotation_transformer.forward(rot) 158 | raw_actions = np.concatenate([ 159 | pos, rot, gripper 160 | ], axis=-1).astype(np.float32) 161 | 162 | if is_dual_arm: 163 | raw_actions = raw_actions.reshape(-1,20) 164 | 165 | data = { 166 | 'obs': obs, 167 | 'action': raw_actions 168 | } 169 | return data 170 | -------------------------------------------------------------------------------- /equi_diffpo/dataset/robomimic_replay_lowdim_sym_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import torch 3 | import numpy as np 4 | from equi_diffpo.common.pytorch_util import dict_apply 5 | from equi_diffpo.dataset.base_dataset import LinearNormalizer 6 | from equi_diffpo.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset, normalizer_from_stat 7 | # robomimic_abs_action_only_symmetric_normalizer_from_stat is imported once via the grouped import below 8 | from equi_diffpo.common.normalize_util import ( 9 | robomimic_abs_action_only_symmetric_normalizer_from_stat, 10 | get_identity_normalizer_from_stat, 11 | array_to_stats 12 | ) 13 | 14 | 15 | class RobomimicReplayLowdimSymDataset(RobomimicReplayLowdimDataset): 16 | def __init__(self, 17 | dataset_path: str, 18 | horizon=1, 19 | pad_before=0, 20 | pad_after=0, 21 | obs_keys: List[str]=[ 22 | 'object', 23 | 'robot0_eef_pos', 24 | 'robot0_eef_quat', 25 | 'robot0_gripper_qpos'], 26 | abs_action=False, 27 | rotation_rep='rotation_6d', 28 | use_legacy_normalizer=False, 29 | seed=42, 30 | val_ratio=0.0, 31 | max_train_episodes=None, 32 | n_demo=100 33 | ): 34 | super().__init__( 35 | dataset_path, 36 | horizon, 37 | pad_before, 38 | pad_after, 39 | obs_keys, 40 | abs_action, 41 | rotation_rep, 42 | use_legacy_normalizer, 43 | seed, 44 | val_ratio, 45 | max_train_episodes, 46 | n_demo, 47 | ) 48 | 49 | def get_normalizer(self, **kwargs) -> LinearNormalizer: 50 | normalizer = LinearNormalizer() 51 | 52 | # action 53 | stat = array_to_stats(self.replay_buffer['action']) 54 | if 
self.abs_action: 55 | if stat['mean'].shape[-1] > 10: 56 | # dual arm 57 | raise NotImplementedError 58 | else: 59 | this_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(stat) 60 | 61 | if self.use_legacy_normalizer: 62 | this_normalizer = normalizer_from_stat(stat) 63 | else: 64 | # already normalized 65 | this_normalizer = get_identity_normalizer_from_stat(stat) 66 | normalizer['action'] = this_normalizer 67 | 68 | # aggregate obs stats 69 | obs_stat = array_to_stats(self.replay_buffer['obs']) 70 | 71 | 72 | normalizer['obs'] = normalizer_from_stat(obs_stat) 73 | return normalizer 74 | -------------------------------------------------------------------------------- /equi_diffpo/env/robomimic/robomimic_image_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from matplotlib.pyplot import fill 3 | import numpy as np 4 | import gym 5 | from gym import spaces 6 | from omegaconf import OmegaConf 7 | from robomimic.envs.env_robosuite import EnvRobosuite 8 | 9 | class RobomimicImageWrapper(gym.Env): 10 | def __init__(self, 11 | env: EnvRobosuite, 12 | shape_meta: dict, 13 | init_state: Optional[np.ndarray]=None, 14 | render_obs_key='agentview_image', 15 | ): 16 | 17 | self.env = env 18 | self.render_obs_key = render_obs_key 19 | self.init_state = init_state 20 | self.seed_state_map = dict() 21 | self._seed = None 22 | self.shape_meta = shape_meta 23 | self.render_cache = None 24 | self.has_reset_before = False 25 | 26 | # setup spaces 27 | action_shape = shape_meta['action']['shape'] 28 | action_space = spaces.Box( 29 | low=-1, 30 | high=1, 31 | shape=action_shape, 32 | dtype=np.float32 33 | ) 34 | self.action_space = action_space 35 | 36 | observation_space = spaces.Dict() 37 | for key, value in shape_meta['obs'].items(): 38 | shape = value['shape'] 39 | min_value, max_value = -1, 1 40 | if key.endswith('image'): 41 | min_value, max_value = 0, 1 42 | elif key.endswith('depth'): 43 | min_value, max_value = 0, 1 44 | elif key.endswith('voxels'): 45 | min_value, max_value = 0, 1 46 | elif key.endswith('point_cloud'): 47 | min_value, max_value = -10, 10 48 | elif key.endswith('quat'): 49 | min_value, max_value = -1, 1 50 | elif key.endswith('qpos'): 51 | min_value, max_value = -1, 1 52 | elif key.endswith('pos'): 53 | # better range? 
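# note: these Box bounds are metadata for downstream consumers and are not
# enforced by robosuite; eef positions are metric, so [-1, 1] is a loose
# placeholder rather than a tight bound.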
54 | min_value, max_value = -1, 1 55 | else: 56 | raise RuntimeError(f"Unsupported type {key}") 57 | 58 | this_space = spaces.Box( 59 | low=min_value, 60 | high=max_value, 61 | shape=shape, 62 | dtype=np.float32 63 | ) 64 | observation_space[key] = this_space 65 | self.observation_space = observation_space 66 | 67 | 68 | def get_observation(self, raw_obs=None): 69 | if raw_obs is None: 70 | raw_obs = self.env.get_observation() 71 | 72 | self.render_cache = raw_obs[self.render_obs_key] 73 | 74 | obs = dict() 75 | for key in self.observation_space.keys(): 76 | obs[key] = raw_obs[key] 77 | return obs 78 | 79 | def seed(self, seed=None): 80 | np.random.seed(seed=seed) 81 | self._seed = seed 82 | 83 | def reset(self): 84 | if self.init_state is not None: 85 | if not self.has_reset_before: 86 | # the env must be fully reset at least once to ensure correct rendering 87 | self.env.reset() 88 | self.has_reset_before = True 89 | 90 | # always reset to the same state 91 | # to be compatible with gym 92 | raw_obs = self.env.reset_to({'states': self.init_state}) 93 | elif self._seed is not None: 94 | # reset to a specific seed 95 | seed = self._seed 96 | if seed in self.seed_state_map: 97 | # env.reset is expensive, use cache 98 | raw_obs = self.env.reset_to({'states': self.seed_state_map[seed]}) 99 | else: 100 | # robosuite's initializes all use numpy global random state 101 | np.random.seed(seed=seed) 102 | raw_obs = self.env.reset() 103 | state = self.env.get_state()['states'] 104 | self.seed_state_map[seed] = state 105 | self._seed = None 106 | else: 107 | # random reset 108 | raw_obs = self.env.reset() 109 | 110 | # return obs 111 | obs = self.get_observation(raw_obs) 112 | return obs 113 | 114 | def step(self, action): 115 | raw_obs, reward, done, info = self.env.step(action) 116 | obs = self.get_observation(raw_obs) 117 | return obs, reward, done, info 118 | 119 | def render(self, mode='rgb_array'): 120 | if self.render_cache is None: 121 | raise RuntimeError('Must run reset or step before render.') 122 | img = np.moveaxis(self.render_cache, 0, -1) 123 | img = (img * 255).astype(np.uint8) 124 | return img 125 | 126 | 127 | def test(): 128 | import os 129 | from omegaconf import OmegaConf 130 | cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml') 131 | cfg = OmegaConf.load(cfg_path) 132 | shape_meta = cfg['shape_meta'] 133 | 134 | 135 | import robomimic.utils.file_utils as FileUtils 136 | import robomimic.utils.env_utils as EnvUtils 137 | from matplotlib import pyplot as plt 138 | 139 | dataset_path = os.path.expanduser('~/dev/diffusion_policy/data/robomimic/datasets/square/ph/image.hdf5') 140 | env_meta = FileUtils.get_env_metadata_from_dataset( 141 | dataset_path) 142 | 143 | env = EnvUtils.create_env_from_metadata( 144 | env_meta=env_meta, 145 | render=False, 146 | render_offscreen=False, 147 | use_image_obs=True, 148 | ) 149 | 150 | wrapper = RobomimicImageWrapper( 151 | env=env, 152 | shape_meta=shape_meta 153 | ) 154 | wrapper.seed(0) 155 | obs = wrapper.reset() 156 | img = wrapper.render() 157 | plt.imshow(img) 158 | 159 | 160 | # states = list() 161 | # for _ in range(2): 162 | # wrapper.seed(0) 163 | # wrapper.reset() 164 | # states.append(wrapper.env.get_state()['states']) 165 | # assert np.allclose(states[0], states[1]) 166 | 167 | # img = wrapper.render() 168 | # plt.imshow(img) 169 | # wrapper.seed() 170 | # states.append(wrapper.env.get_state()['states']) 171 | 
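# Usage sketch (hypothetical, mirroring test() above). Seeded resets are cached
# in seed_state_map, so repeating seed(0) + reset() is cheap after the first episode:
# wrapper.seed(0)
# obs = wrapper.reset()
# done = False
# while not done:
#     action = wrapper.action_space.sample()  # stand-in for a policy's action
#     obs, reward, done, info = wrapper.step(action)
# frame = wrapper.render()  # HxWx3 uint8 frame taken from the cached render_obs_key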
-------------------------------------------------------------------------------- /equi_diffpo/env/robomimic/robomimic_lowdim_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Optional 2 | import numpy as np 3 | import gym 4 | from gym.spaces import Box 5 | from robomimic.envs.env_robosuite import EnvRobosuite 6 | 7 | class RobomimicLowdimWrapper(gym.Env): 8 | def __init__(self, 9 | env: EnvRobosuite, 10 | obs_keys: List[str]=[ 11 | 'object', 12 | 'robot0_eef_pos', 13 | 'robot0_eef_quat', 14 | 'robot0_gripper_qpos'], 15 | init_state: Optional[np.ndarray]=None, 16 | render_hw=(256,256), 17 | render_camera_name='agentview' 18 | ): 19 | 20 | self.env = env 21 | self.obs_keys = obs_keys 22 | self.init_state = init_state 23 | self.render_hw = render_hw 24 | self.render_camera_name = render_camera_name 25 | self.seed_state_map = dict() 26 | self._seed = None 27 | 28 | # setup spaces 29 | low = np.full(env.action_dimension, fill_value=-1) 30 | high = np.full(env.action_dimension, fill_value=1) 31 | self.action_space = Box( 32 | low=low, 33 | high=high, 34 | shape=low.shape, 35 | dtype=low.dtype 36 | ) 37 | obs_example = self.get_observation() 38 | low = np.full_like(obs_example, fill_value=-1) 39 | high = np.full_like(obs_example, fill_value=1) 40 | self.observation_space = Box( 41 | low=low, 42 | high=high, 43 | shape=low.shape, 44 | dtype=low.dtype 45 | ) 46 | 47 | def get_observation(self): 48 | raw_obs = self.env.get_observation() 49 | obs = np.concatenate([ 50 | raw_obs[key] for key in self.obs_keys 51 | ], axis=0) 52 | return obs 53 | 54 | def seed(self, seed=None): 55 | np.random.seed(seed=seed) 56 | self._seed = seed 57 | 58 | def reset(self): 59 | if self.init_state is not None: 60 | # always reset to the same state 61 | # to be compatible with gym 62 | self.env.reset_to({'states': self.init_state}) 63 | elif self._seed is not None: 64 | # reset to a specific seed 65 | seed = self._seed 66 | if seed in self.seed_state_map: 67 | # env.reset is expensive, use cache 68 | self.env.reset_to({'states': self.seed_state_map[seed]}) 69 | else: 70 | # robosuite's initializes all use numpy global random state 71 | np.random.seed(seed=seed) 72 | self.env.reset() 73 | state = self.env.get_state()['states'] 74 | self.seed_state_map[seed] = state 75 | self._seed = None 76 | else: 77 | # random reset 78 | self.env.reset() 79 | 80 | # return obs 81 | obs = self.get_observation() 82 | return obs 83 | 84 | def step(self, action): 85 | raw_obs, reward, done, info = self.env.step(action) 86 | obs = np.concatenate([ 87 | raw_obs[key] for key in self.obs_keys 88 | ], axis=0) 89 | return obs, reward, done, info 90 | 91 | def render(self, mode='rgb_array'): 92 | h, w = self.render_hw 93 | return self.env.render(mode=mode, 94 | height=h, width=w, 95 | camera_name=self.render_camera_name) 96 | 97 | 98 | def test(): 99 | import robomimic.utils.file_utils as FileUtils 100 | import robomimic.utils.env_utils as EnvUtils 101 | from matplotlib import pyplot as plt 102 | 103 | dataset_path = '/home/cchi/dev/diffusion_policy/data/robomimic/datasets/square/ph/low_dim.hdf5' 104 | env_meta = FileUtils.get_env_metadata_from_dataset( 105 | dataset_path) 106 | 107 | env = EnvUtils.create_env_from_metadata( 108 | env_meta=env_meta, 109 | render=False, 110 | render_offscreen=False, 111 | use_image_obs=False, 112 | ) 113 | wrapper = RobomimicLowdimWrapper( 114 | env=env, 115 | obs_keys=[ 116 | 'object', 117 | 'robot0_eef_pos', 118 | 'robot0_eef_quat', 119 
| 'robot0_gripper_qpos' 120 | ] 121 | ) 122 | 123 | states = list() 124 | for _ in range(2): 125 | wrapper.seed(0) 126 | wrapper.reset() 127 | states.append(wrapper.env.get_state()['states']) 128 | assert np.allclose(states[0], states[1]) 129 | 130 | img = wrapper.render() 131 | plt.imshow(img) 132 | # wrapper.seed() 133 | # states.append(wrapper.env.get_state()['states']) 134 | -------------------------------------------------------------------------------- /equi_diffpo/env_runner/base_image_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from equi_diffpo.policy.base_image_policy import BaseImagePolicy 3 | 4 | class BaseImageRunner: 5 | def __init__(self, output_dir): 6 | self.output_dir = output_dir 7 | 8 | def run(self, policy: BaseImagePolicy) -> Dict: 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /equi_diffpo/env_runner/base_lowdim_runner.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from equi_diffpo.policy.base_lowdim_policy import BaseLowdimPolicy 3 | 4 | class BaseLowdimRunner: 5 | def __init__(self, output_dir): 6 | self.output_dir = output_dir 7 | 8 | def run(self, policy: BaseLowdimPolicy) -> Dict: 9 | raise NotImplementedError() 10 | -------------------------------------------------------------------------------- /equi_diffpo/gym_util/multistep_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | from collections import defaultdict, deque 5 | import dill 6 | 7 | def stack_repeated(x, n): 8 | return np.repeat(np.expand_dims(x,axis=0),n,axis=0) 9 | 10 | def repeated_box(box_space, n): 11 | return spaces.Box( 12 | low=stack_repeated(box_space.low, n), 13 | high=stack_repeated(box_space.high, n), 14 | shape=(n,) + box_space.shape, 15 | dtype=box_space.dtype 16 | ) 17 | 18 | def repeated_space(space, n): 19 | if isinstance(space, spaces.Box): 20 | return repeated_box(space, n) 21 | elif isinstance(space, spaces.Dict): 22 | result_space = spaces.Dict() 23 | for key, value in space.items(): 24 | result_space[key] = repeated_space(value, n) 25 | return result_space 26 | else: 27 | raise RuntimeError(f'Unsupported space type {type(space)}') 28 | 29 | def take_last_n(x, n): 30 | x = list(x) 31 | n = min(len(x), n) 32 | return np.array(x[-n:]) 33 | 34 | def dict_take_last_n(x, n): 35 | result = dict() 36 | for key, value in x.items(): 37 | result[key] = take_last_n(value, n) 38 | return result 39 | 40 | def aggregate(data, method='max'): 41 | if method == 'max': 42 | # equivalent to any 43 | return np.max(data) 44 | elif method == 'min': 45 | # equivalent to all 46 | return np.min(data) 47 | elif method == 'mean': 48 | return np.mean(data) 49 | elif method == 'sum': 50 | return np.sum(data) 51 | else: 52 | raise NotImplementedError() 53 | 54 | def stack_last_n_obs(all_obs, n_steps): 55 | assert(len(all_obs) > 0) 56 | all_obs = list(all_obs) 57 | result = np.zeros((n_steps,) + all_obs[-1].shape, 58 | dtype=all_obs[-1].dtype) 59 | start_idx = -min(n_steps, len(all_obs)) 60 | result[start_idx:] = np.array(all_obs[start_idx:]) 61 | if n_steps > len(all_obs): 62 | # pad 63 | result[:start_idx] = result[start_idx] 64 | return result 65 | 66 | 67 | class MultiStepWrapper(gym.Wrapper): 68 | def __init__(self, 69 | env, 70 | n_obs_steps, 71 | n_action_steps, 72 | 
max_episode_steps=None, 73 | reward_agg_method='max' 74 | ): 75 | super().__init__(env) 76 | self._action_space = repeated_space(env.action_space, n_action_steps) 77 | self._observation_space = repeated_space(env.observation_space, n_obs_steps) 78 | self.max_episode_steps = max_episode_steps 79 | self.n_obs_steps = n_obs_steps 80 | self.n_action_steps = n_action_steps 81 | self.reward_agg_method = reward_agg_method 82 | self.n_obs_steps = n_obs_steps 83 | 84 | self.obs = deque(maxlen=n_obs_steps+1) 85 | self.reward = list() 86 | self.done = list() 87 | self.info = defaultdict(lambda : deque(maxlen=n_obs_steps+1)) 88 | 89 | def reset(self): 90 | """Resets the environment using kwargs.""" 91 | obs = super().reset() 92 | 93 | self.obs = deque([obs], maxlen=self.n_obs_steps+1) 94 | self.reward = list() 95 | self.done = list() 96 | self.info = defaultdict(lambda : deque(maxlen=self.n_obs_steps+1)) 97 | 98 | obs = self._get_obs(self.n_obs_steps) 99 | return obs 100 | 101 | def step(self, action): 102 | """ 103 | actions: (n_action_steps,) + action_shape 104 | """ 105 | for act in action: 106 | if len(self.done) > 0 and self.done[-1]: 107 | # termination 108 | break 109 | observation, reward, done, info = super().step(act) 110 | 111 | self.obs.append(observation) 112 | self.reward.append(reward) 113 | if (self.max_episode_steps is not None) \ 114 | and (len(self.reward) >= self.max_episode_steps): 115 | # truncation 116 | done = True 117 | self.done.append(done) 118 | self._add_info(info) 119 | 120 | observation = self._get_obs(self.n_obs_steps) 121 | reward = aggregate(self.reward, self.reward_agg_method) 122 | done = aggregate(self.done, 'max') 123 | info = dict_take_last_n(self.info, self.n_obs_steps) 124 | return observation, reward, done, info 125 | 126 | def _get_obs(self, n_steps=1): 127 | """ 128 | Output (n_steps,) + obs_shape 129 | """ 130 | assert(len(self.obs) > 0) 131 | if isinstance(self.observation_space, spaces.Box): 132 | return stack_last_n_obs(self.obs, n_steps) 133 | elif isinstance(self.observation_space, spaces.Dict): 134 | result = dict() 135 | for key in self.observation_space.keys(): 136 | result[key] = stack_last_n_obs( 137 | [obs[key] for obs in self.obs], 138 | n_steps 139 | ) 140 | return result 141 | else: 142 | raise RuntimeError('Unsupported space type') 143 | 144 | def _add_info(self, info): 145 | for key, value in info.items(): 146 | self.info[key].append(value) 147 | 148 | def get_rewards(self): 149 | return self.reward 150 | 151 | def get_attr(self, name): 152 | return getattr(self, name) 153 | 154 | def run_dill_function(self, dill_fn): 155 | fn = dill.loads(dill_fn) 156 | return fn(self) 157 | 158 | def get_infos(self): 159 | result = dict() 160 | for k, v in self.info.items(): 161 | result[k] = list(v) 162 | return result 163 | -------------------------------------------------------------------------------- /equi_diffpo/gym_util/video_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | class VideoWrapper(gym.Wrapper): 5 | def __init__(self, 6 | env, 7 | mode='rgb_array', 8 | enabled=True, 9 | steps_per_render=1, 10 | **kwargs 11 | ): 12 | super().__init__(env) 13 | 14 | self.mode = mode 15 | self.enabled = enabled 16 | self.render_kwargs = kwargs 17 | self.steps_per_render = steps_per_render 18 | 19 | self.frames = list() 20 | self.step_count = 0 21 | 22 | def reset(self, **kwargs): 23 | obs = super().reset(**kwargs) 24 | self.frames = list() 25 | self.step_count = 1 26 | 
if self.enabled: 27 | frame = self.env.render( 28 | mode=self.mode, **self.render_kwargs) 29 | assert frame.dtype == np.uint8 30 | self.frames.append(frame) 31 | return obs 32 | 33 | def step(self, action): 34 | result = super().step(action) 35 | self.step_count += 1 36 | if self.enabled and ((self.step_count % self.steps_per_render) == 0): 37 | frame = self.env.render( 38 | mode=self.mode, **self.render_kwargs) 39 | assert frame.dtype == np.uint8 40 | self.frames.append(frame) 41 | return result 42 | 43 | def render(self, mode='rgb_array', **kwargs): 44 | return self.frames 45 | -------------------------------------------------------------------------------- /equi_diffpo/model/common/dict_of_tensor_mixin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DictOfTensorMixin(nn.Module): 5 | def __init__(self, params_dict=None): 6 | super().__init__() 7 | if params_dict is None: 8 | params_dict = nn.ParameterDict() 9 | self.params_dict = params_dict 10 | 11 | @property 12 | def device(self): 13 | return next(iter(self.parameters())).device 14 | 15 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 16 | def dfs_add(dest, keys, value: torch.Tensor): 17 | if len(keys) == 1: 18 | dest[keys[0]] = value 19 | return 20 | 21 | if keys[0] not in dest: 22 | dest[keys[0]] = nn.ParameterDict() 23 | dfs_add(dest[keys[0]], keys[1:], value) 24 | 25 | def load_dict(state_dict, prefix): 26 | out_dict = nn.ParameterDict() 27 | for key, value in state_dict.items(): 28 | value: torch.Tensor 29 | if key.startswith(prefix): 30 | param_keys = key[len(prefix):].split('.')[1:] 31 | # if len(param_keys) == 0: 32 | # import pdb; pdb.set_trace() 33 | dfs_add(out_dict, param_keys, value.clone()) 34 | return out_dict 35 | 36 | self.params_dict = load_dict(state_dict, prefix + 'params_dict') 37 | self.params_dict.requires_grad_(False) 38 | return 39 | -------------------------------------------------------------------------------- /equi_diffpo/model/common/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from diffusers.optimization import ( 2 | Union, SchedulerType, Optional, 3 | Optimizer, TYPE_TO_SCHEDULER_FUNCTION 4 | ) 5 | 6 | def get_scheduler( 7 | name: Union[str, SchedulerType], 8 | optimizer: Optimizer, 9 | num_warmup_steps: Optional[int] = None, 10 | num_training_steps: Optional[int] = None, 11 | **kwargs 12 | ): 13 | """ 14 | Added kwargs vs diffuser's original implementation 15 | 16 | Unified API to get any scheduler from its name. 17 | 18 | Args: 19 | name (`str` or `SchedulerType`): 20 | The name of the scheduler to use. 21 | optimizer (`torch.optim.Optimizer`): 22 | The optimizer that will be used during training. 23 | num_warmup_steps (`int`, *optional*): 24 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being 25 | optional), the function will raise an error if it's unset and the scheduler type requires it. 26 | num_training_steps (`int``, *optional*): 27 | The number of training steps to do. This is not required by all schedulers (hence the argument being 28 | optional), the function will raise an error if it's unset and the scheduler type requires it. 
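Illustrative usage (scheduler names follow diffusers' SchedulerType; the step counts here are made up): get_scheduler('cosine', optimizer, num_warmup_steps=500, num_training_steps=100000) gives linear warmup followed by cosine decay, while get_scheduler('constant', optimizer) needs neither step argument.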
29 | """ 30 | name = SchedulerType(name) 31 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] 32 | if name == SchedulerType.CONSTANT: 33 | return schedule_func(optimizer, **kwargs) 34 | 35 | # All other schedulers require `num_warmup_steps` 36 | if num_warmup_steps is None: 37 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") 38 | 39 | if name == SchedulerType.CONSTANT_WITH_WARMUP: 40 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) 41 | 42 | # All other schedulers require `num_training_steps` 43 | if num_training_steps is None: 44 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") 45 | 46 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs) 47 | -------------------------------------------------------------------------------- /equi_diffpo/model/common/module_attr_mixin.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class ModuleAttrMixin(nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self._dummy_variable = nn.Parameter() 7 | 8 | @property 9 | def device(self): 10 | return next(iter(self.parameters())).device 11 | 12 | @property 13 | def dtype(self): 14 | return next(iter(self.parameters())).dtype 15 | -------------------------------------------------------------------------------- /equi_diffpo/model/common/rotation_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import pytorch3d.transforms as pt 3 | import torch 4 | import numpy as np 5 | import functools 6 | 7 | class RotationTransformer: 8 | valid_reps = [ 9 | 'axis_angle', 10 | 'euler_angles', 11 | 'quaternion', 12 | 'rotation_6d', 13 | 'matrix' 14 | ] 15 | 16 | def __init__(self, 17 | from_rep='axis_angle', 18 | to_rep='rotation_6d', 19 | from_convention=None, 20 | to_convention=None): 21 | """ 22 | Valid representations 23 | 24 | Always use matrix as intermediate representation. 
25 | """ 26 | assert from_rep != to_rep 27 | assert from_rep in self.valid_reps 28 | assert to_rep in self.valid_reps 29 | if from_rep == 'euler_angles': 30 | assert from_convention is not None 31 | if to_rep == 'euler_angles': 32 | assert to_convention is not None 33 | 34 | forward_funcs = list() 35 | inverse_funcs = list() 36 | 37 | if from_rep != 'matrix': 38 | funcs = [ 39 | getattr(pt, f'{from_rep}_to_matrix'), 40 | getattr(pt, f'matrix_to_{from_rep}') 41 | ] 42 | if from_convention is not None: 43 | funcs = [functools.partial(func, convention=from_convention) 44 | for func in funcs] 45 | forward_funcs.append(funcs[0]) 46 | inverse_funcs.append(funcs[1]) 47 | 48 | if to_rep != 'matrix': 49 | funcs = [ 50 | getattr(pt, f'matrix_to_{to_rep}'), 51 | getattr(pt, f'{to_rep}_to_matrix') 52 | ] 53 | if to_convention is not None: 54 | funcs = [functools.partial(func, convention=to_convention) 55 | for func in funcs] 56 | forward_funcs.append(funcs[0]) 57 | inverse_funcs.append(funcs[1]) 58 | 59 | inverse_funcs = inverse_funcs[::-1] 60 | 61 | self.forward_funcs = forward_funcs 62 | self.inverse_funcs = inverse_funcs 63 | 64 | @staticmethod 65 | def _apply_funcs(x: Union[np.ndarray, torch.Tensor], funcs: list) -> Union[np.ndarray, torch.Tensor]: 66 | x_ = x 67 | if isinstance(x, np.ndarray): 68 | x_ = torch.from_numpy(x) 69 | x_: torch.Tensor 70 | for func in funcs: 71 | x_ = func(x_) 72 | y = x_ 73 | if isinstance(x, np.ndarray): 74 | y = x_.numpy() 75 | return y 76 | 77 | def forward(self, x: Union[np.ndarray, torch.Tensor] 78 | ) -> Union[np.ndarray, torch.Tensor]: 79 | return self._apply_funcs(x, self.forward_funcs) 80 | 81 | def inverse(self, x: Union[np.ndarray, torch.Tensor] 82 | ) -> Union[np.ndarray, torch.Tensor]: 83 | return self._apply_funcs(x, self.inverse_funcs) 84 | 85 | 86 | def test(): 87 | tf = RotationTransformer() 88 | 89 | rotvec = np.random.uniform(-2*np.pi,2*np.pi,size=(1000,3)) 90 | rot6d = tf.forward(rotvec) 91 | new_rotvec = tf.inverse(rot6d) 92 | 93 | from scipy.spatial.transform import Rotation 94 | diff = Rotation.from_rotvec(rotvec) * Rotation.from_rotvec(new_rotvec).inv() 95 | dist = diff.magnitude() 96 | assert dist.max() < 1e-7 97 | 98 | tf = RotationTransformer('rotation_6d', 'matrix') 99 | rot6d_wrong = rot6d + np.random.normal(scale=0.1, size=rot6d.shape) 100 | mat = tf.forward(rot6d_wrong) 101 | mat_det = np.linalg.det(mat) 102 | assert np.allclose(mat_det, 1) 103 | # rotaiton_6d will be normalized to rotation matrix 104 | -------------------------------------------------------------------------------- /equi_diffpo/model/common/shape_util.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | def get_module_device(m: nn.Module): 6 | device = torch.device('cpu') 7 | try: 8 | param = next(iter(m.parameters())) 9 | device = param.device 10 | except StopIteration: 11 | pass 12 | return device 13 | 14 | @torch.no_grad() 15 | def get_output_shape( 16 | input_shape: Tuple[int], 17 | net: Callable[[torch.Tensor], torch.Tensor] 18 | ): 19 | device = get_module_device(net) 20 | test_input = torch.zeros((1,)+tuple(input_shape), device=device) 21 | test_output = net(test_input) 22 | output_shape = tuple(test_output.shape[1:]) 23 | return output_shape 24 | -------------------------------------------------------------------------------- /equi_diffpo/model/detr/README.md: 
-------------------------------------------------------------------------------- 1 | This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0. 2 | 3 | @article{Carion2020EndtoEndOD, 4 | title={End-to-End Object Detection with Transformers}, 5 | author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko}, 6 | journal={ArXiv}, 7 | year={2020}, 8 | volume={abs/2005.12872} 9 | } -------------------------------------------------------------------------------- /equi_diffpo/model/detr/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import argparse 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from .models import build_ACT_model, build_CNNMLP_model 8 | 9 | import IPython 10 | e = IPython.embed 11 | 12 | # def get_args_parser(): 13 | # parser = argparse.ArgumentParser('Set transformer detector', add_help=False) 14 | # parser.add_argument('--lr', default=1e-4, type=float) # will be overridden 15 | # parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden 16 | # parser.add_argument('--batch_size', default=2, type=int) # not used 17 | # parser.add_argument('--weight_decay', default=1e-4, type=float) 18 | # parser.add_argument('--epochs', default=300, type=int) # not used 19 | # parser.add_argument('--lr_drop', default=200, type=int) # not used 20 | # parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used 21 | # help='gradient clipping max norm') 22 | 23 | # # Model parameters 24 | # # * Backbone 25 | # parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden 26 | # help="Name of the convolutional backbone to use") 27 | # parser.add_argument('--dilation', action='store_true', 28 | # help="If true, we replace stride with dilation in the last convolutional block (DC5)") 29 | # parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 30 | # help="Type of positional embedding to use on top of the image features") 31 | # parser.add_argument('--camera_names', default=[], type=list, # will be overridden 32 | # help="A list of camera names") 33 | 34 | # # * Transformer 35 | # parser.add_argument('--enc_layers', default=4, type=int, # will be overridden 36 | # help="Number of encoding layers in the transformer") 37 | # parser.add_argument('--dec_layers', default=6, type=int, # will be overridden 38 | # help="Number of decoding layers in the transformer") 39 | # parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden 40 | # help="Intermediate size of the feedforward layers in the transformer blocks") 41 | # parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden 42 | # help="Size of the embeddings (dimension of the transformer)") 43 | # parser.add_argument('--dropout', default=0.1, type=float, 44 | # help="Dropout applied in the transformer") 45 | # parser.add_argument('--nheads', default=8, type=int, # will be overridden 46 | # help="Number of attention heads inside the transformer's attentions") 47 | # parser.add_argument('--num_queries', default=400, type=int, # will be overridden 48 | # help="Number of query slots") 49 | # parser.add_argument('--pre_norm', action='store_true') 50 | 51 | # # * Segmentation 52 | # parser.add_argument('--masks', action='store_true', 53 | # 
help="Train segmentation head if the flag is provided") 54 | 55 | # # repeat args in imitate_episodes just to avoid error. Will not be used 56 | # parser.add_argument('--eval', action='store_true') 57 | # parser.add_argument('--onscreen_render', action='store_true') 58 | # parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) 59 | # parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True) 60 | # parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True) 61 | # parser.add_argument('--seed', action='store', type=int, help='seed', required=True) 62 | # parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True) 63 | # parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False) 64 | # parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False) 65 | # parser.add_argument('--temporal_agg', action='store_true') 66 | 67 | # return parser 68 | 69 | class Config: 70 | def __init__(self, dictionary): 71 | for key, value in dictionary.items(): 72 | setattr(self, key, value) 73 | 74 | def build_ACT_model_and_optimizer(args_override): 75 | args = { 76 | "lr": 1e-4, 77 | "lr_backbone": 1e-05, 78 | "weight_decay": 1e-4, 79 | "backbone": "resnet18", 80 | "dilation": False, 81 | "position_embedding": "sine", 82 | "camera_names": [], 83 | "enc_layers": 4, 84 | "dec_layers": 6, 85 | "dim_feedforward": 2048, 86 | "hidden_dim": 256, 87 | "dropout": 0.1, 88 | "nheads": 8, 89 | "num_queries": 400, 90 | "pre_norm": False, 91 | "masks": False, 92 | } 93 | 94 | # parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) 95 | # args = parser.parse_args() 96 | args = Config(args) 97 | 98 | for k, v in args_override.items(): 99 | # args[k] = v 100 | setattr(args, k, v) 101 | 102 | model = build_ACT_model(args) 103 | 104 | param_dicts = [ 105 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 106 | { 107 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 108 | "lr": args.lr_backbone, 109 | }, 110 | ] 111 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 112 | weight_decay=args.weight_decay) 113 | 114 | return model, optimizer 115 | 116 | 117 | def build_CNNMLP_model_and_optimizer(args_override): 118 | parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) 119 | args = parser.parse_args() 120 | 121 | for k, v in args_override.items(): 122 | setattr(args, k, v) 123 | 124 | model = build_CNNMLP_model(args) 125 | 126 | param_dicts = [ 127 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 128 | { 129 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 130 | "lr": args.lr_backbone, 131 | }, 132 | ] 133 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 134 | weight_decay=args.weight_decay) 135 | 136 | return model, optimizer 137 | 138 | -------------------------------------------------------------------------------- /equi_diffpo/model/detr/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | from .detr_vae import build as build_vae 3 | from .detr_vae import build_cnnmlp as build_cnnmlp 4 | 5 | def build_ACT_model(args): 6 | return build_vae(args) 7 | 8 | def build_CNNMLP_model(args): 9 | return build_cnnmlp(args) -------------------------------------------------------------------------------- /equi_diffpo/model/detr/models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from ..util.misc import NestedTensor, is_main_process 15 | 16 | from .position_encoding import build_position_encoding 17 | 18 | import IPython 19 | e = IPython.embed 20 | 21 | class FrozenBatchNorm2d(torch.nn.Module): 22 | """ 23 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 24 | 25 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 26 | without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101] 27 | produce nans. 28 | """ 29 | 30 | def __init__(self, n): 31 | super(FrozenBatchNorm2d, self).__init__() 32 | self.register_buffer("weight", torch.ones(n)) 33 | self.register_buffer("bias", torch.zeros(n)) 34 | self.register_buffer("running_mean", torch.zeros(n)) 35 | self.register_buffer("running_var", torch.ones(n)) 36 | 37 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 38 | missing_keys, unexpected_keys, error_msgs): 39 | num_batches_tracked_key = prefix + 'num_batches_tracked' 40 | if num_batches_tracked_key in state_dict: 41 | del state_dict[num_batches_tracked_key] 42 | 43 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 44 | state_dict, prefix, local_metadata, strict, 45 | missing_keys, unexpected_keys, error_msgs) 46 | 47 | def forward(self, x): 48 | # move reshapes to the beginning 49 | # to make it fuser-friendly 50 | w = self.weight.reshape(1, -1, 1, 1) 51 | b = self.bias.reshape(1, -1, 1, 1) 52 | rv = self.running_var.reshape(1, -1, 1, 1) 53 | rm = self.running_mean.reshape(1, -1, 1, 1) 54 | eps = 1e-5 55 | scale = w * (rv + eps).rsqrt() 56 | bias = b - rm * scale 57 | return x * scale + bias 58 | 59 | 60 | class BackboneBase(nn.Module): 61 | 62 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 63 | super().__init__() 64 | # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this? 
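# (editorial note) when enabled, the commented-out condition below freezes every backbone parameter if train_backbone is False, and otherwise freezes only parameters outside layer2-4, i.e. the stem and layer1 would stay frozen even while training the backbone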
65 | # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 66 | # parameter.requires_grad_(False) 67 | if return_interm_layers: 68 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 69 | else: 70 | return_layers = {'layer4': "0"} 71 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 72 | self.num_channels = num_channels 73 | 74 | def forward(self, tensor): 75 | xs = self.body(tensor) 76 | return xs 77 | # out: Dict[str, NestedTensor] = {} 78 | # for name, x in xs.items(): 79 | # m = tensor_list.mask 80 | # assert m is not None 81 | # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 82 | # out[name] = NestedTensor(x, mask) 83 | # return out 84 | 85 | 86 | class Backbone(BackboneBase): 87 | """ResNet backbone with frozen BatchNorm.""" 88 | def __init__(self, name: str, 89 | train_backbone: bool, 90 | return_interm_layers: bool, 91 | dilation: bool): 92 | backbone = getattr(torchvision.models, name)( 93 | replace_stride_with_dilation=[False, False, dilation], 94 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm?? 95 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 96 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 97 | 98 | 99 | class Joiner(nn.Sequential): 100 | def __init__(self, backbone, position_embedding): 101 | super().__init__(backbone, position_embedding) 102 | 103 | def forward(self, tensor_list: NestedTensor): 104 | xs = self[0](tensor_list) 105 | out: List[NestedTensor] = [] 106 | pos = [] 107 | for name, x in xs.items(): 108 | out.append(x) 109 | # position encoding 110 | pos.append(self[1](x).to(x.dtype)) 111 | 112 | return out, pos 113 | 114 | 115 | def build_backbone(args): 116 | position_embedding = build_position_encoding(args) 117 | train_backbone = args.lr_backbone > 0 118 | return_interm_layers = args.masks 119 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 120 | model = Joiner(backbone, position_embedding) 121 | model.num_channels = backbone.num_channels 122 | return model 123 | -------------------------------------------------------------------------------- /equi_diffpo/model/detr/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from ..util.misc import NestedTensor 10 | 11 | import IPython 12 | e = IPython.embed 13 | 14 | class PositionEmbeddingSine(nn.Module): 15 | """ 16 | This is a more standard version of the position embedding, very similar to the one 17 | used by the Attention is all you need paper, generalized to work on images. 
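Concretely (see forward() below): each spatial position gets num_pos_feats sine/cosine channels per axis at frequencies temperature**(2i/num_pos_feats), and the y- and x-axis encodings are concatenated along the channel dimension, for 2*num_pos_feats channels in total.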
18 | """ 19 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 20 | super().__init__() 21 | self.num_pos_feats = num_pos_feats 22 | self.temperature = temperature 23 | self.normalize = normalize 24 | if scale is not None and normalize is False: 25 | raise ValueError("normalize should be True if scale is passed") 26 | if scale is None: 27 | scale = 2 * math.pi 28 | self.scale = scale 29 | 30 | def forward(self, tensor): 31 | x = tensor 32 | # mask = tensor_list.mask 33 | # assert mask is not None 34 | # not_mask = ~mask 35 | 36 | not_mask = torch.ones_like(x[0, [0]]) 37 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 38 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 39 | if self.normalize: 40 | eps = 1e-6 41 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 42 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 43 | 44 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 45 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 46 | 47 | pos_x = x_embed[:, :, :, None] / dim_t 48 | pos_y = y_embed[:, :, :, None] / dim_t 49 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 50 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | 55 | class PositionEmbeddingLearned(nn.Module): 56 | """ 57 | Absolute pos embedding, learned. 58 | """ 59 | def __init__(self, num_pos_feats=256): 60 | super().__init__() 61 | self.row_embed = nn.Embedding(50, num_pos_feats) 62 | self.col_embed = nn.Embedding(50, num_pos_feats) 63 | self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | nn.init.uniform_(self.row_embed.weight) 67 | nn.init.uniform_(self.col_embed.weight) 68 | 69 | def forward(self, tensor_list: NestedTensor): 70 | x = tensor_list.tensors 71 | h, w = x.shape[-2:] 72 | i = torch.arange(w, device=x.device) 73 | j = torch.arange(h, device=x.device) 74 | x_emb = self.col_embed(i) 75 | y_emb = self.row_embed(j) 76 | pos = torch.cat([ 77 | x_emb.unsqueeze(0).repeat(h, 1, 1), 78 | y_emb.unsqueeze(1).repeat(1, w, 1), 79 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 80 | return pos 81 | 82 | 83 | def build_position_encoding(args): 84 | N_steps = args.hidden_dim // 2 85 | if args.position_embedding in ('v2', 'sine'): 86 | # TODO find a better way of exposing other arguments 87 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 88 | elif args.position_embedding in ('v3', 'learned'): 89 | position_embedding = PositionEmbeddingLearned(N_steps) 90 | else: 91 | raise ValueError(f"not supported {args.position_embedding}") 92 | 93 | return position_embedding 94 | -------------------------------------------------------------------------------- /equi_diffpo/model/detr/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='detr', 6 | version='0.0.0', 7 | packages=find_packages(), 8 | license='MIT License', 9 | long_description=open('README.md').read(), 10 | ) -------------------------------------------------------------------------------- /equi_diffpo/model/detr/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | -------------------------------------------------------------------------------- /equi_diffpo/model/detr/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /equi_diffpo/model/detr/util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 
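Illustrative call (paths are hypothetical): plot_logs([Path('output/exp1'), Path('output/exp2')], fields=('loss_bbox_unscaled', 'mAP')) overlays the two runs, one color per log directory.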
3 | """ 4 | import torch 5 | import pandas as pd 6 | import numpy as np 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | 10 | from pathlib import Path, PurePath 11 | 12 | 13 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 14 | ''' 15 | Function to plot specific fields from training log(s). Plots both training and test results. 16 | 17 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 18 | - fields = which results to plot from each log file - plots both training and test for each field. 19 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 20 | - log_name = optional, name of log file if different than default 'log.txt'. 21 | 22 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 23 | - solid lines are training results, dashed lines are test results. 24 | 25 | ''' 26 | func_name = "plot_utils.py::plot_logs" 27 | 28 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 29 | # convert single Path to list to avoid 'not iterable' error 30 | 31 | if not isinstance(logs, list): 32 | if isinstance(logs, PurePath): 33 | logs = [logs] 34 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 35 | else: 36 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 37 | Expect list[Path] or single Path obj, received {type(logs)}") 38 | 39 | # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir 40 | for i, dir in enumerate(logs): 41 | if not isinstance(dir, PurePath): 42 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 43 | if not dir.exists(): 44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 45 | # verify log_name exists 46 | fn = Path(dir / log_name) 47 | if not fn.exists(): 48 | print(f"-> missing {log_name}. 
Have you gotten to Epoch 1 in training?") 49 | print(f"--> full path of missing log file: {fn}") 50 | return 51 | 52 | # load log file(s) and plot 53 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 54 | 55 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 56 | 57 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 58 | for j, field in enumerate(fields): 59 | if field == 'mAP': 60 | coco_eval = pd.DataFrame( 61 | np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] 62 | ).ewm(com=ewm_col).mean() 63 | axs[j].plot(coco_eval, c=color) 64 | else: 65 | df.interpolate().ewm(com=ewm_col).mean().plot( 66 | y=[f'train_{field}', f'test_{field}'], 67 | ax=axs[j], 68 | color=[color] * 2, 69 | style=['-', '--'] 70 | ) 71 | for ax, field in zip(axs, fields): 72 | ax.legend([Path(p).name for p in logs]) 73 | ax.set_title(field) 74 | 75 | 76 | def plot_precision_recall(files, naming_scheme='iter'): 77 | if naming_scheme == 'exp_id': 78 | # name becomes exp_id 79 | names = [f.parts[-3] for f in files] 80 | elif naming_scheme == 'iter': 81 | names = [f.stem for f in files] 82 | else: 83 | raise ValueError(f'not supported {naming_scheme}') 84 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 85 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 86 | data = torch.load(f) 87 | # precision is n_iou, n_points, n_cat, n_area, max_det 88 | precision = data['precision'] 89 | recall = data['params'].recThrs 90 | scores = data['scores'] 91 | # take precision for all classes, all areas and 100 detections 92 | precision = precision[0, :, :, 0, -1].mean(1) 93 | scores = scores[0, :, :, 0, -1].mean(1) 94 | prec = precision.mean() 95 | rec = data['recall'][0, :, 0, -1].mean() 96 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 97 | f'score={scores.mean():0.3f}, ' + 98 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 99 | ) 100 | axs[0].plot(recall, precision, c=color) 101 | axs[1].plot(recall, scores, c=color) 102 | 103 | axs[0].set_title('Precision / Recall') 104 | axs[0].legend(names) 105 | axs[1].set_title('Scores / Recall') 106 | axs[1].legend(names) 107 | return fig, axs 108 | -------------------------------------------------------------------------------- /equi_diffpo/model/diffusion/conv1d_components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | # from einops.layers.torch import Rearrange 5 | 6 | 7 | class Downsample1d(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 11 | 12 | def forward(self, x): 13 | return self.conv(x) 14 | 15 | class Upsample1d(nn.Module): 16 | def __init__(self, dim): 17 | super().__init__() 18 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 19 | 20 | def forward(self, x): 21 | return self.conv(x) 22 | 23 | class Conv1dBlock(nn.Module): 24 | ''' 25 | Conv1d --> GroupNorm --> Mish 26 | ''' 27 | 28 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 29 | super().__init__() 30 | 31 | self.block = nn.Sequential( 32 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 33 | # Rearrange('batch channels horizon -> batch channels 1 horizon'), 34 | nn.GroupNorm(n_groups, out_channels), 35 | # Rearrange('batch channels 1 horizon -> batch channels horizon'), 36 | nn.Mish(), 37 | ) 38 | 39 | def forward(self, x): 40 | return self.block(x) 41 | 42 | 43 | def 
test(): 44 | cb = Conv1dBlock(256, 128, kernel_size=3) 45 | x = torch.zeros((1,256,16)) 46 | o = cb(x) 47 | -------------------------------------------------------------------------------- /equi_diffpo/model/diffusion/ema_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch.nn.modules.batchnorm import _BatchNorm 4 | 5 | class EMAModel: 6 | """ 7 | Exponential Moving Average of model weights 8 | """ 9 | 10 | def __init__( 11 | self, 12 | model, 13 | update_after_step=0, 14 | inv_gamma=1.0, 15 | power=2 / 3, 16 | min_value=0.0, 17 | max_value=0.9999 18 | ): 19 | """ 20 | @crowsonkb's notes on EMA Warmup: 21 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 22 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 23 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 24 | at 215.4k steps). 25 | Args: 26 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 27 | power (float): Exponential factor of EMA warmup. Default: 2/3. 28 | min_value (float): The minimum EMA decay rate. Default: 0. 29 | """ 30 | 31 | self.averaged_model = model 32 | self.averaged_model.eval() 33 | self.averaged_model.requires_grad_(False) 34 | 35 | self.update_after_step = update_after_step 36 | self.inv_gamma = inv_gamma 37 | self.power = power 38 | self.min_value = min_value 39 | self.max_value = max_value 40 | 41 | self.decay = 0.0 42 | self.optimization_step = 0 43 | 44 | def get_decay(self, optimization_step): 45 | """ 46 | Compute the decay factor for the exponential moving average. 47 | """ 48 | step = max(0, optimization_step - self.update_after_step - 1) 49 | value = 1 - (1 + step / self.inv_gamma) ** -self.power 50 | 51 | if step <= 0: 52 | return 0.0 53 | 54 | return max(self.min_value, min(value, self.max_value)) 55 | 56 | @torch.no_grad() 57 | def step(self, new_model): 58 | self.decay = self.get_decay(self.optimization_step) 59 | 60 | # old_all_dataptrs = set() 61 | # for param in new_model.parameters(): 62 | # data_ptr = param.data_ptr() 63 | # if data_ptr != 0: 64 | # old_all_dataptrs.add(data_ptr) 65 | 66 | all_dataptrs = set() 67 | for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()): 68 | for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)): 69 | # iterate over immediate parameters only. 70 | if isinstance(param, dict): 71 | raise RuntimeError('Dict parameter not supported') 72 | 73 | # data_ptr = param.data_ptr() 74 | # if data_ptr != 0: 75 | # all_dataptrs.add(data_ptr) 76 | 77 | if isinstance(module, _BatchNorm): 78 | # skip batchnorms 79 | ema_param.copy_(param.to(dtype=ema_param.dtype).data) 80 | elif not param.requires_grad: 81 | ema_param.copy_(param.to(dtype=ema_param.dtype).data) 82 | else: 83 | ema_param.mul_(self.decay) 84 | ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) 85 | 86 | # verify that iterating over module and then parameters is identical to parameters recursively.
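# rough numbers for the default warmup (inv_gamma=1, power=2/3), derived from get_decay above: decay is ~0.99 after 1e3 steps and ~0.999 after ~31.6e3 steps, consistent with the class docstring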
87 | # assert old_all_dataptrs == all_dataptrs 88 | self.optimization_step += 1 89 | -------------------------------------------------------------------------------- /equi_diffpo/model/diffusion/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | class SinusoidalPosEmb(nn.Module): 6 | def __init__(self, dim): 7 | super().__init__() 8 | self.dim = dim 9 | 10 | def forward(self, x): 11 | device = x.device 12 | half_dim = self.dim // 2 13 | emb = math.log(10000) / (half_dim - 1) 14 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 15 | emb = x[:, None] * emb[None, :] 16 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 17 | return emb 18 | -------------------------------------------------------------------------------- /equi_diffpo/model/equi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pointW/equidiff/e40abb003b25071d4e5b01bfa9933dc16cd32c67/equi_diffpo/model/equi/__init__.py -------------------------------------------------------------------------------- /equi_diffpo/model/equi/equi_conditional_unet1d_vel.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch 3 | from escnn import gspaces, nn 4 | from escnn.group import CyclicGroup 5 | from einops import rearrange, repeat 6 | from equi_diffpo.model.diffusion.conditional_unet1d import ConditionalUnet1D 7 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer 8 | 9 | 10 | class EquiDiffusionUNetVel(torch.nn.Module): 11 | def __init__(self, act_emb_dim, local_cond_dim, global_cond_dim, diffusion_step_embed_dim, down_dims, kernel_size, n_groups, cond_predict_scale, N): 12 | super().__init__() 13 | self.unet = ConditionalUnet1D( 14 | input_dim=act_emb_dim, 15 | local_cond_dim=local_cond_dim, 16 | global_cond_dim=global_cond_dim, 17 | diffusion_step_embed_dim=diffusion_step_embed_dim, 18 | down_dims=down_dims, 19 | kernel_size=kernel_size, 20 | n_groups=n_groups, 21 | cond_predict_scale=cond_predict_scale 22 | ) 23 | self.N = N 24 | self.group = gspaces.no_base_space(CyclicGroup(self.N)) 25 | self.order = self.N 26 | self.act_type = nn.FieldType(self.group, act_emb_dim * [self.group.regular_repr]) 27 | self.out_layer = nn.Linear(self.act_type, 28 | self.getOutFieldType()) 29 | self.enc_a = nn.SequentialModule( 30 | nn.Linear(self.getOutFieldType(), self.act_type), 31 | nn.ReLU(self.act_type) 32 | ) 33 | 34 | self.p = torch.tensor([ 35 | [1, 0, 0, 0, 1, 0, 0, 0, 0], 36 | [0, -1, 0, 1, 0, 0, 0, 0, 0], 37 | [0, 0, 0, 0, 0, 0, 0, 0, 1], 38 | [0, 0, 1, 0, 0, 0, 0, 0, 0], 39 | [0, 0, 0, 0, 0, 1, 0, 0, 0], 40 | [0, 0, 0, 0, 0, 0, 1, 0, 0], 41 | [0, 0, 0, 0, 0, 0, 0, 1, 0], 42 | [0, 1, 0, 1, 0, 0, 0, 0, 0], 43 | [-1, 0, 0, 0, 1, 0, 0, 0, 0] 44 | ]).float() 45 | self.p_inv = torch.linalg.inv(self.p) 46 | self.axisangle_to_matrix = RotationTransformer('axis_angle', 'matrix') 47 | 48 | def getOutFieldType(self): 49 | return nn.FieldType( 50 | self.group, 51 | 1 * [self.group.irrep(2)] # 2 52 | + 3 * [self.group.irrep(1)] # 6 53 | + 5 * [self.group.trivial_repr], # 5 54 | ) 55 | 56 | # matrix 57 | def getOutput(self, conv_out): 58 | rho2 = conv_out[:, 0:2] 59 | xy = conv_out[:, 2:4] 60 | rho11 = conv_out[:, 4:6] 61 | rho12 = conv_out[:, 6:8] 62 | rho01 = conv_out[:, 8:9] 63 | rho02 = conv_out[:, 9:10] 64 | rho03 = conv_out[:, 10:11] 65 | z = 
conv_out[:, 11:12] 66 | g = conv_out[:, 12:13] 67 | 68 | v = torch.cat((rho01, rho02, rho03, rho11, rho12, rho2), dim=1) 69 | m = torch.matmul(self.p_inv.to(conv_out.device), v.reshape(-1, 9, 1)).reshape(-1, 9) 70 | 71 | action = torch.cat((xy, z, m, g), dim=1) 72 | return action 73 | 74 | def getActionGeometricTensor(self, act): 75 | batch_size = act.shape[0] 76 | xy = act[:, 0:2] 77 | z = act[:, 2:3] 78 | m = act[:, 3:12] 79 | g = act[:, 12:] 80 | 81 | v = torch.matmul(self.p.to(act.device), m.reshape(-1, 9, 1)).reshape(-1, 9) 82 | 83 | cat = torch.cat( 84 | ( 85 | v[:, 7:9].reshape(batch_size, 2), 86 | xy.reshape(batch_size, 2), 87 | v[:, 5:7].reshape(batch_size, 2), 88 | v[:, 3:5].reshape(batch_size, 2), 89 | v[:, 0:3].reshape(batch_size, 3), 90 | z.reshape(batch_size, 1), 91 | g.reshape(batch_size, 1), 92 | ), 93 | dim=1, 94 | ) 95 | return nn.GeometricTensor(cat, self.getOutFieldType()) 96 | 97 | def forward(self, 98 | sample: torch.Tensor, 99 | timestep: Union[torch.Tensor, float, int], 100 | local_cond=None, global_cond=None, **kwargs): 101 | """ 102 | x: (B,T,input_dim) 103 | timestep: (B,) or int, diffusion step 104 | local_cond: (B,T,local_cond_dim) 105 | global_cond: (B,global_cond_dim) 106 | output: (B,T,input_dim) 107 | """ 108 | B, T = sample.shape[:2] 109 | sample = rearrange(sample, "b t d -> (b t) d") 110 | sample = self.getActionGeometricTensor(sample) 111 | enc_a_out = self.enc_a(sample).tensor.reshape(B, T, -1) 112 | enc_a_out = rearrange(enc_a_out, "b t (c f) -> (b f) t c", f=self.order) 113 | if type(timestep) == torch.Tensor and len(timestep.shape) == 1: 114 | timestep = repeat(timestep, "b -> (b f)", f=self.order) 115 | if local_cond is not None: 116 | local_cond = rearrange(local_cond, "b t (c f) -> (b f) t c", f=self.order) 117 | if global_cond is not None: 118 | global_cond = rearrange(global_cond, "b (c f) -> (b f) c", f=self.order) 119 | out = self.unet(enc_a_out, timestep, local_cond, global_cond, **kwargs) 120 | out = rearrange(out, "(b f) t c -> (b t) (c f)", f=self.order) 121 | out = nn.GeometricTensor(out, self.act_type) 122 | out = self.out_layer(out).tensor.reshape(B * T, -1) 123 | out = self.getOutput(out) 124 | out = rearrange(out, "(b t) n -> b t n", b=B) 125 | return out 126 | -------------------------------------------------------------------------------- /equi_diffpo/model/unet/obs_cond_unet1d.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | import einops 6 | from einops.layers.torch import Rearrange 7 | 8 | from equi_diffpo.model.diffusion.conv1d_components import ( 9 | Downsample1d, Upsample1d, Conv1dBlock) 10 | from equi_diffpo.model.diffusion.positional_embedding import SinusoidalPosEmb 11 | from equi_diffpo.model.diffusion.conditional_unet1d import ConditionalResidualBlock1D 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class ObsConditionalUnet1D(nn.Module): 16 | def __init__(self, 17 | out_dim=[16,64], 18 | global_cond_dim=128, 19 | down_dims=[256,512,1024], 20 | kernel_size=3, 21 | n_groups=8, 22 | cond_predict_scale=False, 23 | ): 24 | super().__init__() 25 | out_c = out_dim[1] 26 | self.inp = torch.nn.Parameter(torch.randn(list(out_dim))) 27 | 28 | all_dims = [out_c] + list(down_dims) 29 | start_dim = down_dims[0] 30 | 31 | cond_dim = global_cond_dim 32 | 33 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 34 | 35 | mid_dim = all_dims[-1] 36 | self.mid_modules = nn.ModuleList([ 37 | 
ConditionalResidualBlock1D( 38 | mid_dim, mid_dim, cond_dim=cond_dim, 39 | kernel_size=kernel_size, n_groups=n_groups, 40 | cond_predict_scale=cond_predict_scale 41 | ), 42 | ConditionalResidualBlock1D( 43 | mid_dim, mid_dim, cond_dim=cond_dim, 44 | kernel_size=kernel_size, n_groups=n_groups, 45 | cond_predict_scale=cond_predict_scale 46 | ), 47 | ]) 48 | 49 | down_modules = nn.ModuleList([]) 50 | for ind, (dim_in, dim_out) in enumerate(in_out): 51 | is_last = ind >= (len(in_out) - 1) 52 | down_modules.append(nn.ModuleList([ 53 | ConditionalResidualBlock1D( 54 | dim_in, dim_out, cond_dim=cond_dim, 55 | kernel_size=kernel_size, n_groups=n_groups, 56 | cond_predict_scale=cond_predict_scale), 57 | ConditionalResidualBlock1D( 58 | dim_out, dim_out, cond_dim=cond_dim, 59 | kernel_size=kernel_size, n_groups=n_groups, 60 | cond_predict_scale=cond_predict_scale), 61 | Downsample1d(dim_out) if not is_last else nn.Identity() 62 | ])) 63 | 64 | up_modules = nn.ModuleList([]) 65 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 66 | is_last = ind >= (len(in_out) - 1) 67 | up_modules.append(nn.ModuleList([ 68 | ConditionalResidualBlock1D( 69 | dim_out*2, dim_in, cond_dim=cond_dim, 70 | kernel_size=kernel_size, n_groups=n_groups, 71 | cond_predict_scale=cond_predict_scale), 72 | ConditionalResidualBlock1D( 73 | dim_in, dim_in, cond_dim=cond_dim, 74 | kernel_size=kernel_size, n_groups=n_groups, 75 | cond_predict_scale=cond_predict_scale), 76 | Upsample1d(dim_in) if not is_last else nn.Identity() 77 | ])) 78 | 79 | final_conv = nn.Sequential( 80 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 81 | nn.Conv1d(start_dim, out_c, 1), 82 | ) 83 | 84 | self.up_modules = up_modules 85 | self.down_modules = down_modules 86 | self.final_conv = final_conv 87 | 88 | logger.info( 89 | "number of parameters: %e", sum(p.numel() for p in self.parameters()) 90 | ) 91 | 92 | def forward(self, global_cond): 93 | """ 94 | global_cond: (B,global_cond_dim) 95 | output: (B,T,input_dim) 96 | """ 97 | B = global_cond.shape[0] 98 | x = einops.repeat(self.inp, 't h -> b h t', b=B) 99 | 100 | h = [] 101 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 102 | x = resnet(x, global_cond) 103 | x = resnet2(x, global_cond) 104 | h.append(x) 105 | x = downsample(x) 106 | 107 | for mid_module in self.mid_modules: 108 | x = mid_module(x, global_cond) 109 | 110 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 111 | x = torch.cat((x, h.pop()), dim=1) 112 | x = resnet(x, global_cond) 113 | x = resnet2(x, global_cond) 114 | x = upsample(x) 115 | 116 | x = self.final_conv(x) 117 | 118 | x = einops.rearrange(x, 'b h t -> b t h') 119 | return x 120 | 121 | 122 | -------------------------------------------------------------------------------- /equi_diffpo/model/vision/model_getter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | def get_resnet(name, weights=None, **kwargs): 5 | """ 6 | name: resnet18, resnet34, resnet50 7 | weights: "IMAGENET1K_V1", "r3m" 8 | """ 9 | # load r3m weights 10 | if (weights == "r3m") or (weights == "R3M"): 11 | return get_r3m(name=name, **kwargs) 12 | 13 | func = getattr(torchvision.models, name) 14 | resnet = func(weights=weights, **kwargs) 15 | resnet.fc = torch.nn.Identity() 16 | return resnet 17 | 18 | def get_r3m(name, **kwargs): 19 | """ 20 | name: resnet18, resnet34, resnet50 21 | """ 22 | import r3m 23 | r3m.device = 'cpu' 24 | model = r3m.load_r3m(name) 25 
| r3m_model = model.module 26 | resnet_model = r3m_model.convnet 27 | resnet_model = resnet_model.to('cpu') 28 | return resnet_model 29 | -------------------------------------------------------------------------------- /equi_diffpo/model/vision/rot_randomizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torchvision.transforms.functional as TF 5 | import torch.nn.functional as F 6 | import random 7 | import numpy as np 8 | from einops import rearrange 9 | import math 10 | 11 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer 12 | 13 | class RotRandomizer(nn.Module): 14 | """ 15 | Continuously and randomly rotate the input tensor during training. 16 | Does not rotate the tensor during evaluation. 17 | """ 18 | 19 | def __init__(self, min_angle=-180, max_angle=180): 20 | """ 21 | Args: 22 | min_angle (float): Minimum rotation angle. 23 | max_angle (float): Maximum rotation angle. 24 | """ 25 | super().__init__() 26 | self.min_angle = min_angle 27 | self.max_angle = max_angle 28 | self.tf = RotationTransformer('quaternion', 'matrix') 29 | 30 | 31 | def forward(self, nobs, naction): 32 | """ 33 | Randomly rotates the inputs if in training mode. 34 | Keeps inputs unchanged if in evaluation mode. 35 | 36 | Args: 37 | inputs (torch.Tensor): input tensors 38 | 39 | Returns: 40 | torch.Tensor: rotated or unrotated tensors based on the mode 41 | """ 42 | if self.training: 43 | obs = nobs["agentview_image"] 44 | pos = nobs["robot0_eef_pos"] 45 | # x, y, z, w -> w, x, y, z 46 | quat = nobs["robot0_eef_quat"][:, :, [3, 0, 1, 2]] 47 | batch_size = obs.shape[0] 48 | T = obs.shape[1] 49 | 50 | for i in range(1000): 51 | angles = torch.rand(batch_size) * 2 * np.pi - np.pi 52 | rotation_matrix = torch.zeros((batch_size, 3, 3), device=obs.device) 53 | rotation_matrix[:, 2, 2] = 1 54 | 55 | angles[torch.rand(batch_size) < 1/64] = 0 56 | rotation_matrix[:, 0, 0] = torch.cos(angles) 57 | rotation_matrix[:, 0, 1] = -torch.sin(angles) 58 | rotation_matrix[:, 1, 0] = torch.sin(angles) 59 | rotation_matrix[:, 1, 1] = torch.cos(angles) 60 | 61 | rotated_naction = naction.clone() 62 | rotated_naction[:, :, 0:3] = (rotation_matrix @ naction[:, :, 0:3].permute(0, 2, 1)).permute(0, 2, 1) 63 | rotated_naction[:, :, [3, 6]] = (rotation_matrix[:, :2, :2] @ naction[:, :, [3, 6]].permute(0, 2, 1)).permute(0, 2, 1) 64 | rotated_naction[:, :, [4, 7]] = (rotation_matrix[:, :2, :2] @ naction[:, :, [4, 7]].permute(0, 2, 1)).permute(0, 2, 1) 65 | rotated_naction[:, :, [5, 8]] = (rotation_matrix[:, :2, :2] @ naction[:, :, [5, 8]].permute(0, 2, 1)).permute(0, 2, 1) 66 | 67 | rotated_pos = (rotation_matrix @ pos.permute(0, 2, 1)).permute(0, 2, 1) 68 | rot = self.tf.forward(quat) 69 | rotated_rot = rotation_matrix.unsqueeze(1) @ rot 70 | rotated_quat = self.tf.inverse(rotated_rot) 71 | 72 | if rotated_pos.min() >= -1 and rotated_pos.max() <= 1 and rotated_naction[:, :, :2].min() > -1 and rotated_naction[:, :, :2].max() < 1: 73 | break 74 | if i == 999: 75 | return nobs, naction 76 | 77 | obs = rearrange(obs, "b t c h w -> b (t c) h w") 78 | grid = F.affine_grid(rotation_matrix[:, :2], obs.size(), align_corners=True) 79 | rotated_obs = F.grid_sample(obs, grid, align_corners=True, mode='bilinear') 80 | rotated_obs = rearrange(rotated_obs, "b (t c) h w -> b t c h w", b=batch_size, t=T) 81 | 82 | nobs["agentview_image"] = rotated_obs 83 | nobs["robot0_eef_pos"] = 
rotated_pos 84 | # w, x, y, z -> x, y, z, w 85 | nobs["robot0_eef_quat"] = rotated_quat[:, :, [1, 2, 3, 0]] 86 | naction = rotated_naction 87 | 88 | return nobs, naction 89 | 90 | def __repr__(self): 91 | """Pretty print the network.""" 92 | header = '{}'.format(str(self.__class__.__name__)) 93 | msg = header + "(min_angle={}, max_angle={})".format(self.min_angle, self.max_angle) 94 | return msg 95 | 96 | -------------------------------------------------------------------------------- /equi_diffpo/model/vision/rot_randomizer_vel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torchvision.transforms.functional as TF 5 | import random 6 | import numpy as np 7 | from einops import rearrange 8 | import math 9 | 10 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer 11 | 12 | class RotRandomizerVel(nn.Module): 13 | """ 14 | Continuously and randomly rotate the input tensor during training. 15 | Does not rotate the tensor during evaluation. 16 | """ 17 | 18 | def __init__(self, min_angle=-180, max_angle=180): 19 | """ 20 | Args: 21 | min_angle (float): Minimum rotation angle. 22 | max_angle (float): Maximum rotation angle. 23 | """ 24 | super().__init__() 25 | self.min_angle = min_angle 26 | self.max_angle = max_angle 27 | self.quat_to_matrix = RotationTransformer('quaternion', 'matrix') 28 | self.axisangle_to_matrix = RotationTransformer('axis_angle', 'matrix') 29 | 30 | 31 | def forward(self, nobs, naction): 32 | """ 33 | Randomly rotates the inputs if in training mode. 34 | Keeps inputs unchanged if in evaluation mode. 35 | 36 | Args: 37 | inputs (torch.Tensor): input tensors 38 | 39 | Returns: 40 | torch.Tensor: rotated or unrotated tensors based on the mode 41 | """ 42 | if self.training and np.random.random() > 1/64: 43 | obs = nobs["agentview_image"] 44 | batch_size = obs.shape[0] 45 | 46 | angle = random.uniform(self.min_angle, self.max_angle) 47 | angle_rad = math.radians(angle) 48 | rotation_matrix = torch.tensor([[math.cos(angle_rad), -math.sin(angle_rad), 0], 49 | [math.sin(angle_rad), math.cos(angle_rad), 0], 50 | [0, 0, 1]]).to(obs.device) 51 | 52 | 53 | obs = rearrange(obs, "b t c h w -> (b t) c h w") 54 | rotated_obs = TF.rotate(obs, angle) 55 | rotated_obs = rearrange(rotated_obs, "(b t) c h w -> b t c h w", b=batch_size) 56 | nobs["agentview_image"] = rotated_obs 57 | 58 | if "crops" in nobs: 59 | crops = nobs["crops"] 60 | n_crop = crops.shape[2] 61 | crops = rearrange(crops, "b t n c h w -> (b t n) c h w") 62 | crops = TF.rotate(crops, angle) 63 | crops = rearrange(crops, "(b t n) c h w -> b t n c h w", b=batch_size, n=n_crop) 64 | nobs["crops"] = crops 65 | 66 | if "pos_vecs" in nobs: 67 | pos_vecs = nobs["pos_vecs"] 68 | pos_vecs = rearrange(pos_vecs, "b t n d -> (b t n) d") 69 | pos_vecs = (rotation_matrix[:2, :2] @ pos_vecs.T).T 70 | pos_vecs = rearrange(pos_vecs, "(b t n) d -> b t n d", b=batch_size, n=n_crop) 71 | nobs["pos_vecs"] = pos_vecs 72 | 73 | pos = nobs["robot0_eef_pos"] 74 | quat = nobs["robot0_eef_quat"] 75 | pos = rearrange(pos, "b t d -> (b t) d") 76 | quat = rearrange(quat, "b t d -> (b t) d") 77 | rot = self.quat_to_matrix.forward(quat) 78 | pos = (rotation_matrix @ pos.T).T 79 | rot = rotation_matrix @ rot 80 | quat = self.quat_to_matrix.inverse(rot) 81 | pos = rearrange(pos, "(b t) d -> b t d", b=batch_size) 82 | quat = rearrange(quat, "(b t) d -> b t d", b=batch_size) 83 | 
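# note: unlike RotRandomizer above, the quaternion here is used as stored in nobs without reordering to w-first; pytorch3d's quaternion ops expect (w, x, y, z), so this code assumes that convention already holds for robot0_eef_quat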
nobs["robot0_eef_pos"] = pos 84 | nobs["robot0_eef_quat"] = quat 85 | 86 | naction = rearrange(naction, "b t d -> (b t) d") 87 | naction[:, 0:3] = (rotation_matrix @ naction[:, 0:3].T).T 88 | axis_angle = naction[:, 3:6] 89 | m = self.axisangle_to_matrix.forward(axis_angle) 90 | gm = rotation_matrix @ m @ torch.linalg.inv(rotation_matrix) 91 | g_axis_angle = self.axisangle_to_matrix.inverse(gm) 92 | naction[:, 3:6] = g_axis_angle 93 | 94 | naction = rearrange(naction, "(b t) d -> b t d", b=batch_size) 95 | naction = torch.clip(naction, -1, 1) 96 | return nobs, naction 97 | 98 | def __repr__(self): 99 | """Pretty print the network.""" 100 | header = '{}'.format(str(self.__class__.__name__)) 101 | msg = header + "(min_angle={}, max_angle={})".format(self.min_angle, self.max_angle) 102 | return msg 103 | 104 | -------------------------------------------------------------------------------- /equi_diffpo/model/vision/voxel_crop_randomizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class VoxelCropRandomizer(nn.Module): 5 | def __init__( 6 | self, 7 | crop_depth, 8 | crop_height, 9 | crop_width, 10 | ): 11 | super().__init__() 12 | self.crop_depth = crop_depth 13 | self.crop_height = crop_height 14 | self.crop_width = crop_width 15 | 16 | def forward(self, voxels): 17 | B, C, D, H, W = voxels.shape 18 | if self.training: 19 | cropped_voxel = [] 20 | for i in range(B): 21 | d_start = torch.randint(0, D-self.crop_depth+1, [1])[0] 22 | h_start = torch.randint(0, H-self.crop_height+1, [1])[0] 23 | w_start = torch.randint(0, W-self.crop_width+1, [1])[0] 24 | cropped_voxel.append(voxels[i, 25 | :, 26 | d_start:d_start+self.crop_depth, 27 | h_start:h_start+self.crop_height, 28 | w_start:w_start+self.crop_width, 29 | ]) 30 | return torch.stack(cropped_voxel, 0) 31 | else: 32 | voxels = voxels[:, 33 | :, 34 | D//2 - self.crop_depth//2: D//2 + self.crop_depth//2, 35 | H//2 - self.crop_height//2: H//2 + self.crop_height//2, 36 | W//2 - self.crop_width//2: W//2 + self.crop_width//2, 37 | ] 38 | return voxels 39 | -------------------------------------------------------------------------------- /equi_diffpo/model/vision/voxel_rot_randomizer_rel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | import torchvision.transforms.functional as TF 5 | import torch.nn.functional as F 6 | import random 7 | import numpy as np 8 | from einops import rearrange, repeat 9 | import math 10 | from copy import deepcopy 11 | 12 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer 13 | 14 | class VoxelRotRandomizerRel(nn.Module): 15 | """ 16 | Continuously and randomly rotate the input tensor during training. 17 | Does not rotate the tensor during evaluation. 18 | """ 19 | 20 | def __init__(self, min_angle=-180, max_angle=180): 21 | """ 22 | Args: 23 | min_angle (float): Minimum rotation angle. 24 | max_angle (float): Maximum rotation angle. 25 | """ 26 | super().__init__() 27 | self.min_angle = min_angle 28 | self.max_angle = max_angle 29 | self.quat_to_matrix = RotationTransformer('quaternion', 'matrix') 30 | self.axisangle_to_matrix = RotationTransformer('axis_angle', 'matrix') 31 | 32 | 33 | def forward(self, nobs, naction: torch.Tensor): 34 | """ 35 | Randomly rotates the inputs if in training mode. 36 | Keeps inputs unchanged if in evaluation mode. 
37 | 38 | Args: 39 | nobs (dict): normalized observation tensors; naction (torch.Tensor): normalized relative actions (B, Ta, Da) 40 | 41 | Returns: 42 | tuple: (nobs, naction), rotated in training mode and unchanged in evaluation mode 43 | """ 44 | if self.training: 45 | obs = nobs["voxels"] 46 | pos = nobs["robot0_eef_pos"] 47 | # x, y, z, w -> w, x, y, z 48 | quat = nobs["robot0_eef_quat"][:, :, [3, 0, 1, 2]] 49 | batch_size = obs.shape[0] 50 | To = obs.shape[1] 51 | C = obs.shape[2] 52 | Ta = naction.shape[1] 53 | # rejection-sample z-rotations that keep the end-effector position inside the normalized range 54 | for i in range(1000): 55 | angles = torch.rand(batch_size) * 2 * np.pi - np.pi 56 | rotation_matrix = torch.zeros((batch_size, 3, 3), device=obs.device) 57 | rotation_matrix[:, 2, 2] = 1 58 | # construct rotation matrix 59 | angles[torch.rand(batch_size) < 1/64] = 0 60 | rotation_matrix[:, 0, 0] = torch.cos(angles) 61 | rotation_matrix[:, 0, 1] = -torch.sin(angles) 62 | rotation_matrix[:, 1, 0] = torch.sin(angles) 63 | rotation_matrix[:, 1, 1] = torch.cos(angles) 64 | # rotating the xyz vector in action 65 | rotated_naction = naction.clone() 66 | expanded_rotation_matrix = repeat(rotation_matrix, 'b d1 d2 -> b t d1 d2', t=Ta) 67 | rotated_naction[:, :, :3] = (expanded_rotation_matrix @ naction[:, :, :3].unsqueeze(-1)).squeeze(-1) 68 | # rotating the axis angle rotation vector in action 69 | axis_angle = rotated_naction[:, :, 3:6] 70 | m = self.axisangle_to_matrix.forward(axis_angle) 71 | rotation_matrix_inv = rotation_matrix.transpose(1, 2) 72 | expanded_rotation_matrix_inv = repeat(rotation_matrix_inv, 'b d1 d2 -> b t d1 d2', t=Ta) 73 | gm = expanded_rotation_matrix @ m @ expanded_rotation_matrix_inv 74 | g_axis_angle = self.axisangle_to_matrix.inverse(gm) 75 | rotated_naction[:, :, 3:6] = g_axis_angle 76 | # rotating state pos and quat 77 | rotated_pos = (rotation_matrix @ pos.permute(0, 2, 1)).permute(0, 2, 1) 78 | rot = self.quat_to_matrix.forward(quat) 79 | rotated_rot = rotation_matrix.unsqueeze(1) @ rot 80 | rotated_quat = self.quat_to_matrix.inverse(rotated_rot) 81 | 82 | if rotated_pos.min() >= -1 and rotated_pos.max() <= 1: 83 | break 84 | if i == 999: 85 | return nobs, naction 86 | 87 | obs = rearrange(obs, "b t c h w d -> b t c d w h") 88 | obs = torch.flip(obs, (3, 4)) 89 | obs = rearrange(obs, "b t c d w h -> b (t c d) w h") 90 | grid = F.affine_grid(rotation_matrix[:, :2], obs.size(), align_corners=True) 91 | rotated_obs = F.grid_sample(obs, grid, align_corners=True, mode='nearest') 92 | rotated_obs = rearrange(rotated_obs, "b (t c d) w h -> b t c d w h", c=C, t=To) 93 | rotated_obs = torch.flip(rotated_obs, (3, 4)) 94 | rotated_obs = rearrange(rotated_obs, "b t c d w h -> b t c h w d") 95 | 96 | nobs["voxels"] = rotated_obs 97 | nobs["robot0_eef_pos"] = rotated_pos 98 | # w, x, y, z -> x, y, z, w 99 | nobs["robot0_eef_quat"] = rotated_quat[:, :, [1, 2, 3, 0]] 100 | naction = rotated_naction 101 | 102 | return nobs, naction 103 | 104 | def __repr__(self): 105 | """Pretty print the network.""" 106 | header = '{}'.format(str(self.__class__.__name__)) 107 | msg = header + "(min_angle={}, max_angle={})".format(self.min_angle, self.max_angle) 108 | return msg 109 | 110 | -------------------------------------------------------------------------------- /equi_diffpo/policy/base_image_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin 5 | from equi_diffpo.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseImagePolicy(ModuleAttrMixin): 8 | # init accepts 
keyword argument shape_meta, see config/task/*_image.yaml 9 | 10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 11 | """ 12 | obs_dict: 13 | str: B,To,* 14 | return: B,Ta,Da 15 | """ 16 | raise NotImplementedError() 17 | 18 | # reset state for stateful policies 19 | def reset(self): 20 | pass 21 | 22 | # ========== training =========== 23 | # no standard training interface except setting normalizer 24 | def set_normalizer(self, normalizer: LinearNormalizer): 25 | raise NotImplementedError() 26 | -------------------------------------------------------------------------------- /equi_diffpo/policy/base_lowdim_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin 5 | from equi_diffpo.model.common.normalizer import LinearNormalizer 6 | 7 | class BaseLowdimPolicy(ModuleAttrMixin): 8 | # ========= inference ============ 9 | # also as self.device and self.dtype for inference device transfer 10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 11 | """ 12 | obs_dict: 13 | obs: B,To,Do 14 | return: 15 | action: B,Ta,Da 16 | To = 3 17 | Ta = 4 18 | T = 6 19 | |o|o|o| 20 | | | |a|a|a|a| 21 | |o|o| 22 | | |a|a|a|a|a| 23 | | | | | |a|a| 24 | """ 25 | raise NotImplementedError() 26 | 27 | # reset state for stateful policies 28 | def reset(self): 29 | pass 30 | 31 | # ========== training =========== 32 | # no standard training interface except setting normalizer 33 | def set_normalizer(self, normalizer: LinearNormalizer): 34 | raise NotImplementedError() 35 | 36 | -------------------------------------------------------------------------------- /equi_diffpo/policy/robomimic_image_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from equi_diffpo.model.common.normalizer import LinearNormalizer 4 | from equi_diffpo.policy.base_image_policy import BaseImagePolicy 5 | from equi_diffpo.common.pytorch_util import dict_apply 6 | 7 | from robomimic.algo import algo_factory 8 | from robomimic.algo.algo import PolicyAlgo 9 | import robomimic.utils.obs_utils as ObsUtils 10 | from equi_diffpo.common.robomimic_config_util import get_robomimic_config 11 | 12 | class RobomimicImagePolicy(BaseImagePolicy): 13 | def __init__(self, 14 | shape_meta: dict, 15 | algo_name='bc_rnn', 16 | obs_type='image', 17 | task_name='square', 18 | dataset_type='ph', 19 | crop_shape=(76,76) 20 | ): 21 | super().__init__() 22 | 23 | # parse shape_meta 24 | action_shape = shape_meta['action']['shape'] 25 | assert len(action_shape) == 1 26 | action_dim = action_shape[0] 27 | obs_shape_meta = shape_meta['obs'] 28 | obs_config = { 29 | 'low_dim': [], 30 | 'rgb': [], 31 | 'depth': [], 32 | 'scan': [] 33 | } 34 | obs_key_shapes = dict() 35 | for key, attr in obs_shape_meta.items(): 36 | shape = attr['shape'] 37 | obs_key_shapes[key] = list(shape) 38 | 39 | type = attr.get('type', 'low_dim') 40 | if type == 'rgb': 41 | obs_config['rgb'].append(key) 42 | elif type == 'low_dim': 43 | obs_config['low_dim'].append(key) 44 | else: 45 | raise RuntimeError(f"Unsupported obs type: {type}") 46 | 47 | # get raw robomimic config 48 | config = get_robomimic_config( 49 | algo_name=algo_name, 50 | hdf5_type=obs_type, 51 | task_name=task_name, 52 | dataset_type=dataset_type) 53 | 54 | 55 | with config.unlocked(): 
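# NOTE: robomimic Config objects are locked after construction; unlocked() temporarily permits editing the observation modalities and the crop-randomizer settings below.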
56 | # set config with shape_meta 57 | config.observation.modalities.obs = obs_config 58 | 59 | if crop_shape is None: 60 | for key, modality in config.observation.encoder.items(): 61 | if modality.obs_randomizer_class == 'CropRandomizer': 62 | modality['obs_randomizer_class'] = None 63 | else: 64 | # set random crop parameter 65 | ch, cw = crop_shape 66 | for key, modality in config.observation.encoder.items(): 67 | if modality.obs_randomizer_class == 'CropRandomizer': 68 | modality.obs_randomizer_kwargs.crop_height = ch 69 | modality.obs_randomizer_kwargs.crop_width = cw 70 | 71 | # init global state 72 | ObsUtils.initialize_obs_utils_with_config(config) 73 | 74 | # load model 75 | model: PolicyAlgo = algo_factory( 76 | algo_name=config.algo_name, 77 | config=config, 78 | obs_key_shapes=obs_key_shapes, 79 | ac_dim=action_dim, 80 | device='cpu', 81 | ) 82 | 83 | self.model = model 84 | self.nets = model.nets 85 | self.normalizer = LinearNormalizer() 86 | self.config = config 87 | 88 | def to(self,*args,**kwargs): 89 | device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) 90 | if device is not None: 91 | self.model.device = device 92 | super().to(*args,**kwargs) 93 | 94 | # =========== inference ============= 95 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 96 | nobs_dict = self.normalizer(obs_dict) 97 | robomimic_obs_dict = dict_apply(nobs_dict, lambda x: x[:,0,...]) 98 | naction = self.model.get_action(robomimic_obs_dict) 99 | action = self.normalizer['action'].unnormalize(naction) 100 | # (B, Da) 101 | result = { 102 | 'action': action[:,None,:] # (B, 1, Da) 103 | } 104 | return result 105 | 106 | def reset(self): 107 | self.model.reset() 108 | 109 | # =========== training ============== 110 | def set_normalizer(self, normalizer: LinearNormalizer): 111 | self.normalizer.load_state_dict(normalizer.state_dict()) 112 | 113 | def train_on_batch(self, batch, epoch, validate=False): 114 | nobs = self.normalizer.normalize(batch['obs']) 115 | nactions = self.normalizer['action'].normalize(batch['action']) 116 | robomimic_batch = { 117 | 'obs': nobs, 118 | 'actions': nactions 119 | } 120 | input_batch = self.model.process_batch_for_training( 121 | robomimic_batch) 122 | info = self.model.train_on_batch( 123 | batch=input_batch, epoch=epoch, validate=validate) 124 | # keys: losses, predictions 125 | return info 126 | 127 | def on_epoch_end(self, epoch): 128 | self.model.on_epoch_end(epoch) 129 | 130 | def get_optimizer(self): 131 | return self.model.optimizers['policy'] 132 | 133 | 134 | def test(): 135 | import os 136 | from omegaconf import OmegaConf 137 | cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml') 138 | cfg = OmegaConf.load(cfg_path) 139 | shape_meta = cfg.shape_meta 140 | 141 | policy = RobomimicImagePolicy(shape_meta=shape_meta) 142 | 143 | -------------------------------------------------------------------------------- /equi_diffpo/scripts/download_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details]. 4 | 5 | """ 6 | Script to download datasets packaged with the repository. 
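Example usage (flags follow the argument definitions below; paths are the script defaults):

    # download the default core square_d0 dataset into data/robomimic/datasets
    python equi_diffpo/scripts/download_datasets.py

    # list all core datasets that would be downloaded, without downloading them
    python equi_diffpo/scripts/download_datasets.py --dataset_type core --tasks all --dry_run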
7 | """ 8 | import os 9 | import argparse 10 | 11 | import mimicgen_envs 12 | import mimicgen_envs.utils.file_utils as FileUtils 13 | from mimicgen_envs import DATASET_REGISTRY 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | 19 | # directory to download datasets to 20 | parser.add_argument( 21 | "--download_dir", 22 | type=str, 23 | default=None, 24 | help="Base download directory. Created if it doesn't exist. Defaults to datasets folder in repository.", 25 | ) 26 | 27 | # dataset type to download datasets for 28 | parser.add_argument( 29 | "--dataset_type", 30 | type=str, 31 | default="core", 32 | choices=list(DATASET_REGISTRY.keys()), 33 | help="Dataset type to download datasets for (e.g. source, core, object, robot, large_interpolation). Defaults to core.", 34 | ) 35 | 36 | # tasks to download datasets for 37 | parser.add_argument( 38 | "--tasks", 39 | type=str, 40 | nargs='+', 41 | default=["square_d0"], 42 | help="Tasks to download datasets for. Defaults to square_d0 task. Pass 'all' to download all tasks\ 43 | for the provided dataset type or directly specify the list of tasks.", 44 | ) 45 | 46 | # dry run - don't actually download datasets, but print which datasets would be downloaded 47 | parser.add_argument( 48 | "--dry_run", 49 | action='store_true', 50 | help="set this flag to do a dry run to only print which datasets would be downloaded" 51 | ) 52 | 53 | args = parser.parse_args() 54 | 55 | # set default base directory for downloads 56 | default_base_dir = args.download_dir 57 | if default_base_dir is None: 58 | default_base_dir = "data/robomimic/datasets" 59 | 60 | # load args 61 | download_dataset_type = args.dataset_type 62 | download_tasks = args.tasks 63 | if "all" in download_tasks: 64 | assert len(download_tasks) == 1, "all should be only tasks argument but got: {}".format(args.tasks) 65 | download_tasks = list(DATASET_REGISTRY[download_dataset_type].keys()) 66 | else: 67 | for task in download_tasks: 68 | assert task in DATASET_REGISTRY[download_dataset_type], "got unknown task {} for dataset type {}. 
Choose one of {}".format(task, download_dataset_type, list(DATASET_REGISTRY[download_dataset_type].keys())) 69 | 70 | # download requested datasets 71 | for task in download_tasks: 72 | download_dir = os.path.abspath(os.path.join(default_base_dir, task)) 73 | download_path = os.path.join(download_dir, "{}.hdf5".format(task)) 74 | print("\nDownloading dataset:\n dataset type: {}\n task: {}\n download path: {}" 75 | .format(download_dataset_type, task, download_path)) 76 | url = DATASET_REGISTRY[download_dataset_type][task]["url"] 77 | if args.dry_run: 78 | print("\ndry run: skip download") 79 | else: 80 | # Make sure path exists and create if it doesn't 81 | os.makedirs(download_dir, exist_ok=True) 82 | print("") 83 | FileUtils.download_url_from_gdrive( 84 | url=url, 85 | download_dir=download_dir, 86 | check_overwrite=True, 87 | ) 88 | print("") -------------------------------------------------------------------------------- /equi_diffpo/scripts/robomimic_dataset_action_comparison.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import os 10 | import click 11 | import pathlib 12 | import h5py 13 | import numpy as np 14 | from tqdm import tqdm 15 | from scipy.spatial.transform import Rotation 16 | 17 | def read_all_actions(hdf5_file, metric_skip_steps=1): 18 | n_demos = len(hdf5_file['data']) 19 | all_actions = list() 20 | for i in tqdm(range(n_demos)): 21 | actions = hdf5_file[f'data/demo_{i}/actions'][:] 22 | all_actions.append(actions[metric_skip_steps:]) 23 | all_actions = np.concatenate(all_actions, axis=0) 24 | return all_actions 25 | 26 | 27 | @click.command() 28 | @click.option('-i', '--input', required=True, help='input hdf5 path') 29 | @click.option('-o', '--output', required=True, help='output hdf5 path. 
Must be an existing file; it is only read for comparison') 30 | def main(input, output): 31 | # process inputs 32 | input = pathlib.Path(input).expanduser() 33 | assert input.is_file() 34 | output = pathlib.Path(output).expanduser() 35 | assert output.is_file() 36 | 37 | input_file = h5py.File(str(input), 'r') 38 | output_file = h5py.File(str(output), 'r') 39 | 40 | input_all_actions = read_all_actions(input_file) 41 | output_all_actions = read_all_actions(output_file) 42 | pos_dist = np.linalg.norm(input_all_actions[:,:3] - output_all_actions[:,:3], axis=-1) 43 | rot_dist = (Rotation.from_rotvec(input_all_actions[:,3:6] 44 | ) * Rotation.from_rotvec(output_all_actions[:,3:6]).inv() 45 | ).magnitude() 46 | 47 | print(f'max pos dist: {pos_dist.max()}') 48 | print(f'max rot dist: {rot_dist.max()}') 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /equi_diffpo/scripts/robomimic_dataset_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import multiprocessing 10 | import os 11 | import shutil 12 | import click 13 | import pathlib 14 | import h5py 15 | from tqdm import tqdm 16 | import collections 17 | import pickle 18 | from equi_diffpo.common.robomimic_util import RobomimicAbsoluteActionConverter 19 | 20 | def worker(x): 21 | path, idx, do_eval = x 22 | converter = RobomimicAbsoluteActionConverter(path) 23 | if do_eval: 24 | abs_actions, info = converter.convert_and_eval_idx(idx) 25 | else: 26 | abs_actions = converter.convert_idx(idx) 27 | info = dict() 28 | return abs_actions, info 29 | 30 | @click.command() 31 | @click.option('-i', '--input', required=True, help='input hdf5 path') 32 | @click.option('-o', '--output', required=True, help='output hdf5 path. 
Parent directory must exist') 33 | @click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics') 34 | @click.option('-n', '--num_workers', default=None, type=int) 35 | def main(input, output, eval_dir, num_workers): 36 | # process inputs 37 | input = pathlib.Path(input).expanduser() 38 | assert input.is_file() 39 | output = pathlib.Path(output).expanduser() 40 | assert output.parent.is_dir() 41 | assert not output.is_dir() 42 | 43 | do_eval = False 44 | if eval_dir is not None: 45 | eval_dir = pathlib.Path(eval_dir).expanduser() 46 | assert eval_dir.parent.exists() 47 | do_eval = True 48 | 49 | converter = RobomimicAbsoluteActionConverter(input) 50 | 51 | # run 52 | with multiprocessing.Pool(num_workers) as pool: 53 | results = pool.map(worker, [(input, i, do_eval) for i in range(len(converter))]) 54 | 55 | # save output 56 | print('Copying hdf5') 57 | shutil.copy(str(input), str(output)) 58 | 59 | # modify action 60 | with h5py.File(output, 'r+') as out_file: 61 | for i in tqdm(range(len(converter)), desc="Writing to output"): 62 | abs_actions, info = results[i] 63 | demo = out_file[f'data/demo_{i}'] 64 | demo['actions'][:] = abs_actions 65 | 66 | # save eval 67 | if do_eval: 68 | eval_dir.mkdir(parents=False, exist_ok=True) 69 | 70 | print("Writing error_stats.pkl") 71 | infos = [info for _, info in results] 72 | pickle.dump(infos, eval_dir.joinpath('error_stats.pkl').open('wb')) 73 | 74 | print("Generating visualization") 75 | metrics = ['pos', 'rot'] 76 | metrics_dicts = dict() 77 | for m in metrics: 78 | metrics_dicts[m] = collections.defaultdict(list) 79 | 80 | for i in range(len(infos)): 81 | info = infos[i] 82 | for k, v in info.items(): 83 | for m in metrics: 84 | metrics_dicts[m][k].append(v[m]) 85 | 86 | from matplotlib import pyplot as plt 87 | plt.switch_backend('PDF') 88 | 89 | fig, ax = plt.subplots(1, len(metrics)) 90 | for i in range(len(metrics)): 91 | axis = ax[i] 92 | data = metrics_dicts[metrics[i]] 93 | for key, value in data.items(): 94 | axis.plot(value, label=key) 95 | axis.legend() 96 | axis.set_title(metrics[i]) 97 | fig.set_size_inches(10,4) 98 | fig.savefig(str(eval_dir.joinpath('error_stats.pdf'))) 99 | fig.savefig(str(eval_dir.joinpath('error_stats.png'))) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /equi_diffpo/scripts/robomimic_dataset_obs_conversion.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | 9 | import multiprocessing 10 | import os 11 | import shutil 12 | import click 13 | import pathlib 14 | import h5py 15 | from tqdm import tqdm 16 | import numpy as np 17 | import collections 18 | import pickle 19 | from equi_diffpo.common.robomimic_util import RobomimicObsConverter 20 | 21 | multiprocessing.set_start_method('spawn', force=True) 22 | 23 | def worker(x): 24 | path, idx = x 25 | converter = RobomimicObsConverter(path) 26 | obss = converter.convert_idx(idx) 27 | return obss 28 | 29 | @click.command() 30 | @click.option('-i', '--input', required=True, help='input hdf5 path') 31 | @click.option('-o', '--output', required=True, help='output hdf5 path. 
Parent directory must exist') 32 | @click.option('-n', '--num_workers', default=None, type=int) 33 | def main(input, output, num_workers): 34 | # process inputs 35 | input = pathlib.Path(input).expanduser() 36 | assert input.is_file() 37 | output = pathlib.Path(output).expanduser() 38 | assert output.parent.is_dir() 39 | assert not output.is_dir() 40 | 41 | converter = RobomimicObsConverter(input) 42 | 43 | # save output 44 | print('Copying hdf5') 45 | shutil.copy(str(input), str(output)) 46 | 47 | # run 48 | idx = 0 49 | while idx < len(converter): 50 | with multiprocessing.Pool(num_workers) as pool: 51 | end = min(idx + (num_workers or multiprocessing.cpu_count()), len(converter))  # num_workers may be None, in which case Pool also falls back to cpu_count() 52 | results = pool.map(worker, [(input, i) for i in range(idx, end)]) 53 | 54 | # write converted observations 55 | print('Writing {} to {}'.format(idx, end)) 56 | with h5py.File(output, 'r+') as out_file: 57 | for i in tqdm(range(idx, end), desc="Writing to output"): 58 | obss = results[i - idx] 59 | demo = out_file[f'data/demo_{i}'] 60 | del demo['obs'] 61 | for k in obss: 62 | demo.create_dataset("obs/{}".format(k), data=np.array(obss[k]), compression="gzip") 63 | 64 | idx = end 65 | del results 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /equi_diffpo/shared_memory/shared_memory_queue.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Union 2 | import numbers 3 | from queue import (Empty, Full) 4 | from multiprocessing.managers import SharedMemoryManager 5 | import numpy as np 6 | from equi_diffpo.shared_memory.shared_memory_util import ArraySpec, SharedAtomicCounter 7 | from equi_diffpo.shared_memory.shared_ndarray import SharedNDArray 8 | 9 | 10 | class SharedMemoryQueue: 11 | """ 12 | A Lock-Free FIFO Shared Memory Data Structure. 13 | Stores a sequence of dicts of numpy arrays. 
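Minimal usage sketch (illustrative):

    import numpy as np
    from multiprocessing.managers import SharedMemoryManager

    with SharedMemoryManager() as shm_manager:
        queue = SharedMemoryQueue.create_from_examples(
            shm_manager=shm_manager,
            examples={'obs': np.zeros(3, dtype=np.float32), 'reward': 0.0},
            buffer_size=8)
        queue.put({'obs': np.ones(3, dtype=np.float32), 'reward': 1.0})
        data = queue.get()  # dict of numpy arrays; raises Empty when no data is queued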
14 | """ 15 | 16 | def __init__(self, 17 | shm_manager: SharedMemoryManager, 18 | array_specs: List[ArraySpec], 19 | buffer_size: int 20 | ): 21 | 22 | # create atomic counter 23 | write_counter = SharedAtomicCounter(shm_manager) 24 | read_counter = SharedAtomicCounter(shm_manager) 25 | 26 | # allocate shared memory 27 | shared_arrays = dict() 28 | for spec in array_specs: 29 | key = spec.name 30 | assert key not in shared_arrays 31 | array = SharedNDArray.create_from_shape( 32 | mem_mgr=shm_manager, 33 | shape=(buffer_size,) + tuple(spec.shape), 34 | dtype=spec.dtype) 35 | shared_arrays[key] = array 36 | 37 | self.buffer_size = buffer_size 38 | self.array_specs = array_specs 39 | self.write_counter = write_counter 40 | self.read_counter = read_counter 41 | self.shared_arrays = shared_arrays 42 | 43 | @classmethod 44 | def create_from_examples(cls, 45 | shm_manager: SharedMemoryManager, 46 | examples: Dict[str, Union[np.ndarray, numbers.Number]], 47 | buffer_size: int 48 | ): 49 | specs = list() 50 | for key, value in examples.items(): 51 | shape = None 52 | dtype = None 53 | if isinstance(value, np.ndarray): 54 | shape = value.shape 55 | dtype = value.dtype 56 | assert dtype != np.dtype('O') 57 | elif isinstance(value, numbers.Number): 58 | shape = tuple() 59 | dtype = np.dtype(type(value)) 60 | else: 61 | raise TypeError(f'Unsupported type {type(value)}') 62 | 63 | spec = ArraySpec( 64 | name=key, 65 | shape=shape, 66 | dtype=dtype 67 | ) 68 | specs.append(spec) 69 | 70 | obj = cls( 71 | shm_manager=shm_manager, 72 | array_specs=specs, 73 | buffer_size=buffer_size 74 | ) 75 | return obj 76 | 77 | def qsize(self): 78 | read_count = self.read_counter.load() 79 | write_count = self.write_counter.load() 80 | n_data = write_count - read_count 81 | return n_data 82 | 83 | def empty(self): 84 | n_data = self.qsize() 85 | return n_data <= 0 86 | 87 | def clear(self): 88 | self.read_counter.store(self.write_counter.load()) 89 | 90 | def put(self, data: Dict[str, Union[np.ndarray, numbers.Number]]): 91 | read_count = self.read_counter.load() 92 | write_count = self.write_counter.load() 93 | n_data = write_count - read_count 94 | if n_data >= self.buffer_size: 95 | raise Full() 96 | 97 | next_idx = write_count % self.buffer_size 98 | 99 | # write to shared memory 100 | for key, value in data.items(): 101 | arr: np.ndarray 102 | arr = self.shared_arrays[key].get() 103 | if isinstance(value, np.ndarray): 104 | arr[next_idx] = value 105 | else: 106 | arr[next_idx] = np.array(value, dtype=arr.dtype) 107 | 108 | # update idx 109 | self.write_counter.add(1) 110 | 111 | def get(self, out=None) -> Dict[str, np.ndarray]: 112 | write_count = self.write_counter.load() 113 | read_count = self.read_counter.load() 114 | n_data = write_count - read_count 115 | if n_data <= 0: 116 | raise Empty() 117 | 118 | if out is None: 119 | out = self._allocate_empty() 120 | 121 | next_idx = read_count % self.buffer_size 122 | for key, value in self.shared_arrays.items(): 123 | arr = value.get() 124 | np.copyto(out[key], arr[next_idx]) 125 | 126 | # update idx 127 | self.read_counter.add(1) 128 | return out 129 | 130 | def get_k(self, k, out=None) -> Dict[str, np.ndarray]: 131 | write_count = self.write_counter.load() 132 | read_count = self.read_counter.load() 133 | n_data = write_count - read_count 134 | if n_data <= 0: 135 | raise Empty() 136 | assert k <= n_data 137 | 138 | out = self._get_k_impl(k, read_count, out=out) 139 | self.read_counter.add(k) 140 | return out 141 | 142 | def get_all(self, out=None) -> Dict[str, 
np.ndarray]: 143 | write_count = self.write_counter.load() 144 | read_count = self.read_counter.load() 145 | n_data = write_count - read_count 146 | if n_data <= 0: 147 | raise Empty() 148 | 149 | out = self._get_k_impl(n_data, read_count, out=out) 150 | self.read_counter.add(n_data) 151 | return out 152 | 153 | def _get_k_impl(self, k, read_count, out=None) -> Dict[str, np.ndarray]: 154 | if out is None: 155 | out = self._allocate_empty(k) 156 | 157 | curr_idx = read_count % self.buffer_size 158 | for key, value in self.shared_arrays.items(): 159 | arr = value.get() 160 | target = out[key] 161 | 162 | start = curr_idx 163 | end = min(start + k, self.buffer_size) 164 | target_start = 0 165 | target_end = (end - start) 166 | target[target_start: target_end] = arr[start:end] 167 | 168 | remainder = k - (end - start) 169 | if remainder > 0: 170 | # wrap around 171 | start = 0 172 | end = start + remainder 173 | target_start = target_end 174 | target_end = k 175 | target[target_start: target_end] = arr[start:end] 176 | 177 | return out 178 | 179 | def _allocate_empty(self, k=None): 180 | result = dict() 181 | for spec in self.array_specs: 182 | shape = spec.shape 183 | if k is not None: 184 | shape = (k,) + shape 185 | result[spec.name] = np.empty( 186 | shape=shape, dtype=spec.dtype) 187 | return result 188 | -------------------------------------------------------------------------------- /equi_diffpo/shared_memory/shared_memory_util.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from dataclasses import dataclass 3 | import numpy as np 4 | from multiprocessing.managers import SharedMemoryManager 5 | from atomics import atomicview, MemoryOrder, UINT 6 | 7 | @dataclass 8 | class ArraySpec: 9 | name: str 10 | shape: Tuple[int] 11 | dtype: np.dtype 12 | 13 | 14 | class SharedAtomicCounter: 15 | def __init__(self, 16 | shm_manager: SharedMemoryManager, 17 | size :int=8 # 64bit int 18 | ): 19 | shm = shm_manager.SharedMemory(size=size) 20 | self.shm = shm 21 | self.size = size 22 | self.store(0) # initialize 23 | 24 | @property 25 | def buf(self): 26 | return self.shm.buf[:self.size] 27 | 28 | def load(self) -> int: 29 | with atomicview(buffer=self.buf, atype=UINT) as a: 30 | value = a.load(order=MemoryOrder.ACQUIRE) 31 | return value 32 | 33 | def store(self, value: int): 34 | with atomicview(buffer=self.buf, atype=UINT) as a: 35 | a.store(value, order=MemoryOrder.RELEASE) 36 | 37 | def add(self, value: int): 38 | with atomicview(buffer=self.buf, atype=UINT) as a: 39 | a.add(value, order=MemoryOrder.ACQ_REL) 40 | -------------------------------------------------------------------------------- /equi_diffpo/workspace/base_workspace.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import os 3 | import pathlib 4 | import hydra 5 | import copy 6 | from hydra.core.hydra_config import HydraConfig 7 | from omegaconf import OmegaConf 8 | import dill 9 | import torch 10 | import threading 11 | 12 | 13 | class BaseWorkspace: 14 | include_keys = tuple() 15 | exclude_keys = tuple() 16 | 17 | def __init__(self, cfg: OmegaConf, output_dir: Optional[str]=None): 18 | self.cfg = cfg 19 | self._output_dir = output_dir 20 | self._saving_thread = None 21 | 22 | @property 23 | def output_dir(self): 24 | output_dir = self._output_dir 25 | if output_dir is None: 26 | output_dir = HydraConfig.get().runtime.output_dir 27 | return output_dir 28 | 29 | def run(self): 30 | """ 31 | 
Create any resources that shouldn't be serialized as local variables 32 | """ 33 | pass 34 | 35 | def save_checkpoint(self, path=None, tag='latest', 36 | exclude_keys=None, 37 | include_keys=None, 38 | use_thread=True): 39 | if path is None: 40 | path = pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt') 41 | else: 42 | path = pathlib.Path(path) 43 | if exclude_keys is None: 44 | exclude_keys = tuple(self.exclude_keys) 45 | if include_keys is None: 46 | include_keys = tuple(self.include_keys) + ('_output_dir',) 47 | 48 | path.parent.mkdir(parents=False, exist_ok=True) 49 | payload = { 50 | 'cfg': self.cfg, 51 | 'state_dicts': dict(), 52 | 'pickles': dict() 53 | } 54 | 55 | for key, value in self.__dict__.items(): 56 | if hasattr(value, 'state_dict') and hasattr(value, 'load_state_dict'): 57 | # modules, optimizers and samplers etc 58 | if key not in exclude_keys: 59 | if use_thread: 60 | payload['state_dicts'][key] = _copy_to_cpu(value.state_dict()) 61 | else: 62 | payload['state_dicts'][key] = value.state_dict() 63 | elif key in include_keys: 64 | payload['pickles'][key] = dill.dumps(value) 65 | if use_thread: 66 | self._saving_thread = threading.Thread( 67 | target=lambda : torch.save(payload, path.open('wb'), pickle_module=dill)) 68 | self._saving_thread.start() 69 | else: 70 | torch.save(payload, path.open('wb'), pickle_module=dill) 71 | return str(path.absolute()) 72 | 73 | def get_checkpoint_path(self, tag='latest'): 74 | return pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt') 75 | 76 | def load_payload(self, payload, exclude_keys=None, include_keys=None, **kwargs): 77 | if exclude_keys is None: 78 | exclude_keys = tuple() 79 | if include_keys is None: 80 | include_keys = payload['pickles'].keys() 81 | 82 | for key, value in payload['state_dicts'].items(): 83 | if key not in exclude_keys: 84 | self.__dict__[key].load_state_dict(value, **kwargs) 85 | for key in include_keys: 86 | if key in payload['pickles']: 87 | self.__dict__[key] = dill.loads(payload['pickles'][key]) 88 | 89 | def load_checkpoint(self, path=None, tag='latest', 90 | exclude_keys=None, 91 | include_keys=None, 92 | **kwargs): 93 | if path is None: 94 | path = self.get_checkpoint_path(tag=tag) 95 | else: 96 | path = pathlib.Path(path) 97 | payload = torch.load(path.open('rb'), pickle_module=dill, **kwargs) 98 | self.load_payload(payload, 99 | exclude_keys=exclude_keys, 100 | include_keys=include_keys) 101 | return payload 102 | 103 | @classmethod 104 | def create_from_checkpoint(cls, path, 105 | exclude_keys=None, 106 | include_keys=None, 107 | **kwargs): 108 | payload = torch.load(open(path, 'rb'), pickle_module=dill) 109 | instance = cls(payload['cfg']) 110 | instance.load_payload( 111 | payload=payload, 112 | exclude_keys=exclude_keys, 113 | include_keys=include_keys, 114 | **kwargs) 115 | return instance 116 | 117 | def save_snapshot(self, tag='latest'): 118 | """ 119 | Quick saving and loading for research; saves the full state of the workspace. 120 | 121 | However, loading a snapshot assumes the code stays exactly the same. 122 | Use save_checkpoint for long-term storage. 
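Example:
    path = workspace.save_snapshot()
    # later, with identical code:
    workspace = BaseWorkspace.create_from_snapshot(path)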
123 | """ 124 | path = pathlib.Path(self.output_dir).joinpath('snapshots', f'{tag}.pkl') 125 | path.parent.mkdir(parents=False, exist_ok=True) 126 | torch.save(self, path.open('wb'), pickle_module=dill) 127 | return str(path.absolute()) 128 | 129 | @classmethod 130 | def create_from_snapshot(cls, path): 131 | return torch.load(open(path, 'rb'), pickle_module=dill) 132 | 133 | 134 | def _copy_to_cpu(x): 135 | if isinstance(x, torch.Tensor): 136 | return x.detach().to('cpu') 137 | elif isinstance(x, dict): 138 | result = dict() 139 | for k, v in x.items(): 140 | result[k] = _copy_to_cpu(v) 141 | return result 142 | elif isinstance(x, list): 143 | return [_copy_to_cpu(k) for k in x] 144 | else: 145 | return copy.deepcopy(x) 146 | -------------------------------------------------------------------------------- /img/equi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pointW/equidiff/e40abb003b25071d4e5b01bfa9933dc16cd32c67/img/equi.gif -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'equi_diffpo', 5 | packages = find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | Training: 4 | python train.py --config-name=train_diffusion_lowdim_workspace 5 | """ 6 | 7 | import sys 8 | # use line-buffering for both stdout and stderr 9 | sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1) 10 | sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1) 11 | 12 | import hydra 13 | from omegaconf import OmegaConf 14 | import pathlib 15 | from equi_diffpo.workspace.base_workspace import BaseWorkspace 16 | 17 | max_steps = { 18 | 'stack_d1': 400, 19 | 'stack_three_d1': 400, 20 | 'square_d2': 400, 21 | 'threading_d2': 400, 22 | 'coffee_d2': 400, 23 | 'three_piece_assembly_d2': 500, 24 | 'hammer_cleanup_d1': 500, 25 | 'mug_cleanup_d1': 500, 26 | 'kitchen_d1': 800, 27 | 'nut_assembly_d0': 500, 28 | 'pick_place_d0': 1000, 29 | 'coffee_preparation_d1': 800, 30 | 'tool_hang': 700, 31 | 'can': 400, 32 | 'lift': 400, 33 | 'square': 400, 34 | } 35 | 36 | def get_ws_x_center(task_name): 37 | if task_name.startswith('kitchen_') or task_name.startswith('hammer_cleanup_'): 38 | return -0.2 39 | else: 40 | return 0. 41 | 42 | def get_ws_y_center(task_name): 43 | return 0. 44 | 45 | OmegaConf.register_new_resolver("get_max_steps", lambda x: max_steps[x], replace=True) 46 | OmegaConf.register_new_resolver("get_ws_x_center", get_ws_x_center, replace=True) 47 | OmegaConf.register_new_resolver("get_ws_y_center", get_ws_y_center, replace=True) 48 | 49 | # allows arbitrary python code execution in configs using the ${eval:''} resolver 50 | OmegaConf.register_new_resolver("eval", eval, replace=True) 51 | 52 | @hydra.main( 53 | version_base=None, 54 | config_path=str(pathlib.Path(__file__).parent.joinpath( 55 | 'equi_diffpo','config')) 56 | ) 57 | def main(cfg: OmegaConf): 58 | # resolve immediately so all the ${now:} resolvers 59 | # will use the same time. 
60 | OmegaConf.resolve(cfg) 61 | 62 | cls = hydra.utils.get_class(cfg._target_) 63 | workspace: BaseWorkspace = cls(cfg) 64 | workspace.run() 65 | 66 | if __name__ == "__main__": 67 | main() 68 | --------------------------------------------------------------------------------
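Usage note (illustrative; the config field names below are assumptions, only the resolver names and config files are taken from the repository): the resolvers registered in train.py are meant to be referenced from the Hydra configs, e.g.

    max_steps: ${get_max_steps:${task_name}}
    ws_x_center: ${get_ws_x_center:${task_name}}
    horizon: ${eval:'${n_obs_steps} + ${n_action_steps}'}

A run is then launched with one of the configs under equi_diffpo/config, for example:

    python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1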