├── .gitignore
├── LICENSE
├── README.md
├── conda_environment.yaml
├── equi_diffpo
│   ├── codecs
│   │   └── imagecodecs_numcodecs.py
│   ├── common
│   │   ├── checkpoint_util.py
│   │   ├── cv2_util.py
│   │   ├── env_util.py
│   │   ├── json_logger.py
│   │   ├── nested_dict_util.py
│   │   ├── normalize_util.py
│   │   ├── pose_trajectory_interpolator.py
│   │   ├── precise_sleep.py
│   │   ├── pymunk_override.py
│   │   ├── pymunk_util.py
│   │   ├── pytorch_util.py
│   │   ├── replay_buffer.py
│   │   ├── robomimic_config_util.py
│   │   ├── robomimic_util.py
│   │   ├── sampler.py
│   │   └── timestamp_accumulator.py
│   ├── config
│   │   ├── dp3.yaml
│   │   ├── task
│   │   │   ├── mimicgen_abs.yaml
│   │   │   ├── mimicgen_pc_abs.yaml
│   │   │   ├── mimicgen_rel.yaml
│   │   │   ├── mimicgen_voxel_abs.yaml
│   │   │   └── mimicgen_voxel_rel.yaml
│   │   ├── train_act_abs.yaml
│   │   ├── train_bc_rnn.yaml
│   │   ├── train_diffusion_transformer.yaml
│   │   ├── train_diffusion_unet.yaml
│   │   ├── train_diffusion_unet_voxel_abs.yaml
│   │   ├── train_equi_diffusion_unet_abs.yaml
│   │   ├── train_equi_diffusion_unet_rel.yaml
│   │   ├── train_equi_diffusion_unet_voxel_abs.yaml
│   │   └── train_equi_diffusion_unet_voxel_rel.yaml
│   ├── dataset
│   │   ├── base_dataset.py
│   │   ├── robomimic_replay_image_dataset.py
│   │   ├── robomimic_replay_image_sym_dataset.py
│   │   ├── robomimic_replay_lowdim_dataset.py
│   │   ├── robomimic_replay_lowdim_sym_dataset.py
│   │   ├── robomimic_replay_point_cloud_dataset.py
│   │   └── robomimic_replay_voxel_sym_dataset.py
│   ├── env
│   │   └── robomimic
│   │       ├── robomimic_image_wrapper.py
│   │       └── robomimic_lowdim_wrapper.py
│   ├── env_runner
│   │   ├── base_image_runner.py
│   │   ├── base_lowdim_runner.py
│   │   ├── robomimic_image_runner.py
│   │   └── robomimic_lowdim_runner.py
│   ├── gym_util
│   │   ├── async_vector_env.py
│   │   ├── multistep_wrapper.py
│   │   ├── sync_vector_env.py
│   │   ├── video_recording_wrapper.py
│   │   └── video_wrapper.py
│   ├── model
│   │   ├── common
│   │   │   ├── dict_of_tensor_mixin.py
│   │   │   ├── lr_scheduler.py
│   │   │   ├── module_attr_mixin.py
│   │   │   ├── normalizer.py
│   │   │   ├── rotation_transformer.py
│   │   │   ├── shape_util.py
│   │   │   └── tensor_util.py
│   │   ├── detr
│   │   │   ├── LICENSE
│   │   │   ├── README.md
│   │   │   ├── main.py
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── backbone.py
│   │   │   │   ├── detr_vae.py
│   │   │   │   ├── position_encoding.py
│   │   │   │   └── transformer.py
│   │   │   ├── setup.py
│   │   │   └── util
│   │   │       ├── __init__.py
│   │   │       ├── box_ops.py
│   │   │       ├── misc.py
│   │   │       └── plot_utils.py
│   │   ├── diffusion
│   │   │   ├── conditional_unet1d.py
│   │   │   ├── conv1d_components.py
│   │   │   ├── dp3_conditional_unet1d.py
│   │   │   ├── ema_model.py
│   │   │   ├── mask_generator.py
│   │   │   ├── positional_embedding.py
│   │   │   └── transformer_for_diffusion.py
│   │   ├── equi
│   │   │   ├── __init__.py
│   │   │   ├── equi_conditional_unet1d.py
│   │   │   ├── equi_conditional_unet1d_vel.py
│   │   │   ├── equi_encoder.py
│   │   │   └── equi_obs_encoder.py
│   │   ├── unet
│   │   │   └── obs_cond_unet1d.py
│   │   └── vision
│   │       ├── crop_randomizer.py
│   │       ├── model_getter.py
│   │       ├── multi_image_obs_encoder.py
│   │       ├── pointnet_extractor.py
│   │       ├── rot_randomizer.py
│   │       ├── rot_randomizer_vel.py
│   │       ├── voxel_crop_randomizer.py
│   │       ├── voxel_rot_randomizer.py
│   │       └── voxel_rot_randomizer_rel.py
│   ├── policy
│   │   ├── act_policy.py
│   │   ├── base_image_policy.py
│   │   ├── base_lowdim_policy.py
│   │   ├── diffusion_equi_unet_cnn_enc_policy.py
│   │   ├── diffusion_equi_unet_cnn_enc_policy_se2.py
│   │   ├── diffusion_equi_unet_cnn_enc_rel_policy.py
│   │   ├── diffusion_equi_unet_voxel_policy.py
│   │   ├── diffusion_equi_unet_voxel_rel_policy.py
│   │   ├── diffusion_transformer_hybrid_image_policy.py
│   │   ├── diffusion_unet_hybrid_image_policy.py
│   │   ├── diffusion_unet_voxel_policy.py
│   │   ├── dp3.py
│   │   └── robomimic_image_policy.py
│   ├── scripts
│   │   ├── dataset_states_to_obs.py
│   │   ├── download_datasets.py
│   │   ├── robomimic_dataset_action_comparison.py
│   │   ├── robomimic_dataset_conversion.py
│   │   └── robomimic_dataset_obs_conversion.py
│   ├── shared_memory
│   │   ├── shared_memory_queue.py
│   │   ├── shared_memory_ring_buffer.py
│   │   ├── shared_memory_util.py
│   │   └── shared_ndarray.py
│   └── workspace
│       ├── base_workspace.py
│       ├── train_act_workspace.py
│       ├── train_diffusion_transformer_hybrid_workspace.py
│       ├── train_diffusion_unet_hybrid_workspace.py
│       ├── train_dp3_workspace.py
│       ├── train_equi_workspace.py
│       └── train_robomimic_image_workspace.py
├── img
│   └── equi.gif
├── setup.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | bin
2 | logs
3 | wandb
4 | outputs
5 | data
6 | data_local
7 | .vscode
8 | _wandb
9 |
10 | **/.DS_Store
11 |
12 | fuse.cfg
13 |
14 | *.ai
15 |
16 | # Generation results
17 | results/
18 |
19 | ray/auth.json
20 |
21 | # Byte-compiled / optimized / DLL files
22 | __pycache__/
23 | *.py[cod]
24 | *$py.class
25 |
26 | # C extensions
27 | *.so
28 |
29 | # Distribution / packaging
30 | .Python
31 | build/
32 | develop-eggs/
33 | dist/
34 | downloads/
35 | eggs/
36 | .eggs/
37 | lib/
38 | lib64/
39 | parts/
40 | sdist/
41 | var/
42 | wheels/
43 | pip-wheel-metadata/
44 | share/python-wheels/
45 | *.egg-info/
46 | .installed.cfg
47 | *.egg
48 | MANIFEST
49 |
50 | # PyInstaller
51 | # Usually these files are written by a python script from a template
52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest
54 | *.spec
55 |
56 | # Installer logs
57 | pip-log.txt
58 | pip-delete-this-directory.txt
59 |
60 | # Unit test / coverage reports
61 | htmlcov/
62 | .tox/
63 | .nox/
64 | .coverage
65 | .coverage.*
66 | .cache
67 | nosetests.xml
68 | coverage.xml
69 | *.cover
70 | *.py,cover
71 | .hypothesis/
72 | .pytest_cache/
73 |
74 | # Translations
75 | *.mo
76 | *.pot
77 |
78 | # Django stuff:
79 | *.log
80 | local_settings.py
81 | db.sqlite3
82 | db.sqlite3-journal
83 |
84 | # Flask stuff:
85 | instance/
86 | .webassets-cache
87 |
88 | # Scrapy stuff:
89 | .scrapy
90 |
91 | # Sphinx documentation
92 | docs/_build/
93 |
94 | # PyBuilder
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # IPython
101 | profile_default/
102 | ipython_config.py
103 |
104 | # pyenv
105 | .python-version
106 |
107 | # pipenv
108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
111 | # install all needed dependencies.
112 | #Pipfile.lock
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 | equi_diffpo/scripts/equidiff_data_conversion.py
142 | equi_diffpo/model/equi/vec_conditional_unet1d_1.py
143 | update_max_score.py
144 | test4.png
145 | test3.png
146 | test2.png
147 | test1.png
148 | test.png
149 | sampled_xyz.png
150 | pc.pt
151 | metric3.py
152 | metric2.py
153 | metric1.py
154 | grouped_xyz.png
155 | all_xyz.png
156 | 1.png
157 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Dian Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Equivariant Diffusion Policy
2 | [Project Website](https://equidiff.github.io) | [Paper](https://arxiv.org/pdf/2407.01812) | [Video](https://youtu.be/xIFSx_NVROU?si=MaxsHmih6AnQKAVy)
3 | Dian Wang<sup>1</sup>, Stephen Hart<sup>2</sup>, David Surovik<sup>2</sup>, Tarik Kelestemur<sup>2</sup>, Haojie Huang<sup>1</sup>, Haibo Zhao<sup>1</sup>, Mark Yeatman<sup>2</sup>, Jiuguang Wang<sup>2</sup>, Robin Walters<sup>1</sup>, Robert Platt<sup>1,2</sup>
4 | <sup>1</sup>Northeastern University, <sup>2</sup>Boston Dynamics AI Institute
5 | Conference on Robot Learning 2024 (Oral)
6 | ![Equivariant Diffusion Policy](img/equi.gif)
7 | ## Installation
8 | 1. Install the following apt packages for mujoco:
9 | ```bash
10 | sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
11 | ```
12 | 1. Install gfortran (dependency for escnn)
13 | ```bash
14 | sudo apt install -y gfortran
15 | ```
16 |
17 | 1. Install [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge) (strongly recommended) or Anaconda
18 | 1. Clone this repo
19 | ```bash
20 | git clone https://github.com/pointW/equidiff.git
21 | cd equidiff
22 | ```
23 | 1. Install environment:
24 | Use Mambaforge (strongly recommended):
25 | ```bash
26 | mamba env create -f conda_environment.yaml
27 | conda activate equidiff
28 | ```
29 | or use Anaconda (not recommended):
30 | ```bash
31 | conda env create -f conda_environment.yaml
32 | conda activate equidiff
33 | ```
34 | 1. Install mimicgen:
35 | ```bash
36 | cd ..
37 | git clone https://github.com/NVlabs/mimicgen_environments.git
38 | cd mimicgen_environments
39 | # This project was developed with MimicGen v0.1.0. The latest version should also work, but it has not been tested
40 | git checkout 081f7dbbe5fff17b28c67ce8ec87c371f32526a9
41 | pip install -e .
42 | cd ../equidiff
43 | ```
44 | 1. Make sure the mujoco version is 2.3.2 (required by MimicGen):
45 | ```bash
46 | pip list | grep mujoco
47 | ```
48 |
49 | ## Dataset
50 | ### Download Dataset
51 | Download the datasets from MimicGen's Hugging Face page: https://huggingface.co/datasets/amandlek/mimicgen_datasets/tree/main/core (see the download sketch below).
52 | Make sure each dataset is kept under `/path/to/equidiff/data/robomimic/datasets/[dataset]/[dataset].hdf5`
53 |
54 | ### Generating Voxel and Point Cloud Observation
55 |
56 | ```bash
57 | # Template
58 | python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/[dataset]/[dataset].hdf5 --output data/robomimic/datasets/[dataset]/[dataset]_voxel.hdf5 --num_workers=[n_worker]
59 | # Replace [dataset] and [n_worker] with your choices.
60 | # E.g., use 24 workers to generate point cloud and voxel observation for stack_d1
61 | python equi_diffpo/scripts/dataset_states_to_obs.py --input data/robomimic/datasets/stack_d1/stack_d1.hdf5 --output data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 --num_workers=24
62 | ```
63 |
64 | ### Convert Action Space in Dataset
65 | The downloaded datasets use a relative action space. To train with an absolute action space, first convert the dataset accordingly:
66 | ```bash
67 | # Template
68 | python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/[dataset]/[dataset].hdf5 -o data/robomimic/datasets/[dataset]/[dataset]_abs.hdf5 -n [n_worker]
69 | # Replace [dataset] and [n_worker] with your choices.
70 | # E.g., convert stack_d1 (non-voxel) with 12 workers
71 | python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_abs.hdf5 -n 12
72 | # E.g., convert stack_d1_voxel (voxel) with 12 workers
73 | python equi_diffpo/scripts/robomimic_dataset_conversion.py -i data/robomimic/datasets/stack_d1/stack_d1_voxel.hdf5 -o data/robomimic/datasets/stack_d1/stack_d1_voxel_abs.hdf5 -n 12
74 | ```
75 |
76 | ## Training with image observation
77 | To train Equivariant Diffusion Policy (with absolute pose control) on the Stack D1 task:
78 | ```bash
79 | # Make sure you have the non-voxel converted dataset with absolute action space from the previous step
80 | python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 n_demo=100
81 | ```
82 | To train with relative pose control instead:
83 | ```bash
84 | python train.py --config-name=train_equi_diffusion_unet_rel task_name=stack_d1 n_demo=100
85 | ```
86 | To train on other tasks, replace `stack_d1` with `stack_three_d1`, `square_d2`, `threading_d2`, `coffee_d2`, `three_piece_assembly_d2`, `hammer_cleanup_d1`, `mug_cleanup_d1`, `kitchen_d1`, `nut_assembly_d0`, `pick_place_d0`, or `coffee_preparation_d1`. Note that the corresponding dataset must be downloaded first; when training with absolute pose control, the action-space conversion above is also required.
87 |
88 | To run the environments on CPU (to save GPU memory), use `osmesa` instead of `egl` by setting `MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa`, e.g.,
89 | ```bash
90 | MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1
91 | ```
92 |
93 | Equivariant Diffusion Policy requires around 22 GB of GPU memory with the default batch size of 128. To reduce GPU usage, train with a smaller batch size and/or a smaller hidden dimension:
94 | ```bash
95 | # to train with batch size of 64 and hidden dimension of 64
96 | MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa python train.py --config-name=train_equi_diffusion_unet_abs task_name=stack_d1 policy.enc_n_hidden=64 dataloader.batch_size=64
97 | ```
98 |
99 | ## Training with voxel observation
100 | To train Equivariant Diffusion Policy (with absolute pose control) on the Stack D1 task:
101 | ```bash
102 | # Make sure you have the voxel converted dataset with absolute action space from the previous step
103 | python train.py --config-name=train_equi_diffusion_unet_voxel_abs task_name=stack_d1 n_demo=100
104 | ```
105 |
106 | ## License
107 | This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
108 |
109 | ## Acknowledgement
110 | * Our repo is built upon the original [Diffusion Policy](https://github.com/real-stanford/diffusion_policy)
111 | * Our ACT baseline is adapted from its [original repo](https://github.com/tonyzhaozh/act)
112 | * Our DP3 baseline is adapted from its [original repo](https://github.com/YanjieZe/3D-Diffusion-Policy)
113 |
--------------------------------------------------------------------------------
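The download step in the README's Dataset section can also be scripted. This is a minimal sketch, not part of the repo, assuming the `huggingface_hub` package is available; the `core/stack_d1.hdf5` filename follows the linked repository layout but should be verified against the actual file listing.

```python
# Hedged sketch: fetch one MimicGen dataset and place it where the
# configs expect it. "core/stack_d1.hdf5" is an assumed filename.
import shutil
from pathlib import Path

from huggingface_hub import hf_hub_download

cached = hf_hub_download(
    repo_id="amandlek/mimicgen_datasets",
    repo_type="dataset",
    filename="core/stack_d1.hdf5",
)
target = Path("data/robomimic/datasets/stack_d1/stack_d1.hdf5")
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(cached, target)  # copy out of the Hugging Face cache
```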
/conda_environment.yaml:
--------------------------------------------------------------------------------
1 | name: equidiff
2 | channels:
3 | - pytorch
4 | - pytorch3d
5 | - nvidia
6 | - conda-forge
7 | dependencies:
8 | - python=3.9
9 | - pip=22.2.2
10 | - pytorch=2.1.0
11 | - torchaudio=2.1.0
12 | - torchvision=0.16.0
13 | - pytorch-cuda=11.8
14 | - pytorch3d=0.7.5
15 | - numpy=1.23.3
16 | - numba==0.56.4
17 | - scipy==1.9.1
18 | - py-opencv=4.6.0
19 | - cffi=1.15.1
20 | - ipykernel=6.16
21 | - matplotlib=3.6.1
22 | - zarr=2.12.0
23 | - numcodecs=0.10.2
24 | - h5py=3.7.0
25 | - hydra-core=1.2.0
26 | - einops=0.4.1
27 | - tqdm=4.64.1
28 | - dill=0.3.5.1
29 | - scikit-video=1.1.11
30 | - scikit-image=0.19.3
31 | - gym=0.21.0
32 | - pymunk=6.2.1
33 | - threadpoolctl=3.1.0
34 | - shapely=1.8.4
35 | - cython=0.29.32
36 | - imageio=2.22.0
37 | - imageio-ffmpeg=0.4.7
38 | - termcolor=2.0.1
39 | - tensorboard=2.10.1
40 | - tensorboardx=2.5.1
41 | - psutil=5.9.2
42 | - click=8.0.4
43 | - boto3=1.24.96
44 | - accelerate=0.13.2
45 | - datasets=2.6.1
46 | - diffusers=0.11.1
47 | - av=10.0.0
48 | - cmake=3.24.3
49 | # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625
50 | - llvm-openmp=14
51 | # trick to force reinstall imagecodecs via pip
52 | - imagecodecs==2022.8.8
53 | - pip:
54 | - open3d
55 | - wandb==0.17.0
56 | - pygame
57 | - imagecodecs==2022.9.26
58 | - escnn @ https://github.com/pointW/escnn/archive/fc4714cb6dc0d2a32f9fcea35771968b89911109.tar.gz
59 | - robosuite @ https://github.com/ARISE-Initiative/robosuite/archive/b9d8d3de5e3dfd1724f4a0e6555246c460407daa.tar.gz
60 | - robomimic @ https://github.com/pointW/robomimic/archive/8aad5b3caaaac9289b1504438a7f5d3a76d06c07.tar.gz
61 | - robosuite-task-zoo @ https://github.com/pointW/robosuite-task-zoo/archive/0f8a7b2fa5d192e4e8800bebfe8090b28926f3ed.tar.gz
62 |
--------------------------------------------------------------------------------
/equi_diffpo/common/checkpoint_util.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 | import os
3 |
4 | class TopKCheckpointManager:
5 | def __init__(self,
6 | save_dir,
7 | monitor_key: str,
8 | mode='min',
9 | k=1,
10 | format_str='epoch={epoch:03d}-train_loss={train_loss:.3f}.ckpt'
11 | ):
12 | assert mode in ['max', 'min']
13 | assert k >= 0
14 |
15 | self.save_dir = save_dir
16 | self.monitor_key = monitor_key
17 | self.mode = mode
18 | self.k = k
19 | self.format_str = format_str
20 | self.path_value_map = dict()
21 |
22 | def get_ckpt_path(self, data: Dict[str, float]) -> Optional[str]:
23 | if self.k == 0:
24 | return None
25 |
26 | value = data[self.monitor_key]
27 | ckpt_path = os.path.join(
28 | self.save_dir, self.format_str.format(**data))
29 |
30 | if len(self.path_value_map) < self.k:
31 | # under-capacity
32 | self.path_value_map[ckpt_path] = value
33 | return ckpt_path
34 |
35 | # at capacity
36 | sorted_map = sorted(self.path_value_map.items(), key=lambda x: x[1])
37 | min_path, min_value = sorted_map[0]
38 | max_path, max_value = sorted_map[-1]
39 |
40 | delete_path = None
41 | if self.mode == 'max':
42 | if value > min_value:
43 | delete_path = min_path
44 | else:
45 | if value < max_value:
46 | delete_path = max_path
47 |
48 | if delete_path is None:
49 | return None
50 | else:
51 | del self.path_value_map[delete_path]
52 | self.path_value_map[ckpt_path] = value
53 |
54 | if not os.path.exists(self.save_dir):
55 | os.mkdir(self.save_dir)
56 |
57 | if os.path.exists(delete_path):
58 | os.remove(delete_path)
59 | return ckpt_path
60 |
--------------------------------------------------------------------------------
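Usage sketch for `TopKCheckpointManager` above (not part of the file; the score sequence is hypothetical). The keys mirror the `checkpoint.topk` settings used by the training configs later in this repo:

```python
# Keep only the 3 best checkpoints by test_mean_score.
manager = TopKCheckpointManager(
    save_dir="data/outputs/checkpoints",
    monitor_key="test_mean_score",
    mode="max",
    k=3,
    format_str="epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt",
)
for epoch, score in enumerate([0.10, 0.50, 0.30, 0.70]):
    ckpt_path = manager.get_ckpt_path({"epoch": epoch, "test_mean_score": score})
    if ckpt_path is not None:
        ...  # save the checkpoint here; a None return means the score is not top-k
```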
/equi_diffpo/common/cv2_util.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import math
3 | import cv2
4 | import numpy as np
5 |
6 | def draw_reticle(img, u, v, label_color):
7 | """
8 | Draws a reticle (cross-hair) on the image at the given position on top of
9 | the original image.
10 | @param img (In/Out) uint8 3 channel image
11 | @param u X coordinate (width)
12 | @param v Y coordinate (height)
13 | @param label_color tuple of 3 ints for RGB color used for drawing.
14 | """
15 | # Cast to int.
16 | u = int(u)
17 | v = int(v)
18 |
19 | white = (255, 255, 255)
20 | cv2.circle(img, (u, v), 10, label_color, 1)
21 | cv2.circle(img, (u, v), 11, white, 1)
22 | cv2.circle(img, (u, v), 12, label_color, 1)
23 | cv2.line(img, (u, v + 1), (u, v + 3), white, 1)
24 | cv2.line(img, (u + 1, v), (u + 3, v), white, 1)
25 | cv2.line(img, (u, v - 1), (u, v - 3), white, 1)
26 | cv2.line(img, (u - 1, v), (u - 3, v), white, 1)
27 |
28 |
29 | def draw_text(
30 | img,
31 | *,
32 | text,
33 | uv_top_left,
34 | color=(255, 255, 255),
35 | fontScale=0.5,
36 | thickness=1,
37 | fontFace=cv2.FONT_HERSHEY_SIMPLEX,
38 | outline_color=(0, 0, 0),
39 | line_spacing=1.5,
40 | ):
41 | """
42 | Draws multiline with an outline.
43 | """
44 | assert isinstance(text, str)
45 |
46 | uv_top_left = np.array(uv_top_left, dtype=float)
47 | assert uv_top_left.shape == (2,)
48 |
49 | for line in text.splitlines():
50 | (w, h), _ = cv2.getTextSize(
51 | text=line,
52 | fontFace=fontFace,
53 | fontScale=fontScale,
54 | thickness=thickness,
55 | )
56 | uv_bottom_left_i = uv_top_left + [0, h]
57 | org = tuple(uv_bottom_left_i.astype(int))
58 |
59 | if outline_color is not None:
60 | cv2.putText(
61 | img,
62 | text=line,
63 | org=org,
64 | fontFace=fontFace,
65 | fontScale=fontScale,
66 | color=outline_color,
67 | thickness=thickness * 3,
68 | lineType=cv2.LINE_AA,
69 | )
70 | cv2.putText(
71 | img,
72 | text=line,
73 | org=org,
74 | fontFace=fontFace,
75 | fontScale=fontScale,
76 | color=color,
77 | thickness=thickness,
78 | lineType=cv2.LINE_AA,
79 | )
80 |
81 | uv_top_left += [0, h * line_spacing]
82 |
83 |
84 | def get_image_transform(
85 | input_res: Tuple[int,int]=(1280,720),
86 | output_res: Tuple[int,int]=(640,480),
87 | bgr_to_rgb: bool=False):
88 |
89 | iw, ih = input_res
90 | ow, oh = output_res
91 | rw, rh = None, None
92 | interp_method = cv2.INTER_AREA
93 |
94 | if (iw/ih) >= (ow/oh):
95 | # input is wider
96 | rh = oh
97 | rw = math.ceil(rh / ih * iw)
98 | if oh > ih:
99 | interp_method = cv2.INTER_LINEAR
100 | else:
101 | rw = ow
102 | rh = math.ceil(rw / iw * ih)
103 | if ow > iw:
104 | interp_method = cv2.INTER_LINEAR
105 |
106 | w_slice_start = (rw - ow) // 2
107 | w_slice = slice(w_slice_start, w_slice_start + ow)
108 | h_slice_start = (rh - oh) // 2
109 | h_slice = slice(h_slice_start, h_slice_start + oh)
110 | c_slice = slice(None)
111 | if bgr_to_rgb:
112 | c_slice = slice(None, None, -1)
113 |
114 | def transform(img: np.ndarray):
115 | assert img.shape == (ih, iw, 3)
116 | # resize
117 | img = cv2.resize(img, (rw, rh), interpolation=interp_method)
118 | # crop
119 | img = img[h_slice, w_slice, c_slice]
120 | return img
121 | return transform
122 |
123 | def optimal_row_cols(
124 | n_cameras,
125 | in_wh_ratio,
126 | max_resolution=(1920, 1080)
127 | ):
128 | out_w, out_h = max_resolution
129 | out_wh_ratio = out_w / out_h
130 |
131 | n_rows = np.arange(n_cameras,dtype=np.int64) + 1
132 | n_cols = np.ceil(n_cameras / n_rows).astype(np.int64)
133 | cat_wh_ratio = in_wh_ratio * (n_cols / n_rows)
134 | ratio_diff = np.abs(out_wh_ratio - cat_wh_ratio)
135 | best_idx = np.argmin(ratio_diff)
136 | best_n_row = n_rows[best_idx]
137 | best_n_col = n_cols[best_idx]
138 | best_cat_wh_ratio = cat_wh_ratio[best_idx]
139 |
140 | rw, rh = None, None
141 | if best_cat_wh_ratio >= out_wh_ratio:
142 | # cat is wider
143 | rw = math.floor(out_w / best_n_col)
144 | rh = math.floor(rw / in_wh_ratio)
145 | else:
146 | rh = math.floor(out_h / best_n_row)
147 | rw = math.floor(rh * in_wh_ratio)
148 |
149 | # crop_resolution = (rw, rh)
150 | return rw, rh, best_n_col, best_n_row
151 |
--------------------------------------------------------------------------------
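A quick sketch of `get_image_transform` (not part of the file): the returned closure resizes to fill the output aspect ratio, center-crops, and optionally reverses the channel order. Note the `(h, w, c)` input layout asserted inside:

```python
import numpy as np

tf = get_image_transform(input_res=(1280, 720), output_res=(640, 480), bgr_to_rgb=True)
frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # (h, w, c)
out = tf(frame)
assert out.shape == (480, 640, 3)  # resized to 854x480, then center-cropped to 640x480
```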
/equi_diffpo/common/env_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def render_env_video(env, states, actions=None):
6 | observations = states
7 | imgs = list()
8 | for i in range(len(observations)):
9 | state = observations[i]
10 | env.set_state(state)
11 | if i == 0:
12 | env.set_state(state)
13 | img = env.render()
14 | # draw action
15 | if actions is not None:
16 | action = actions[i]
17 | coord = (action / 512 * 96).astype(np.int32)
18 | cv2.drawMarker(img, coord,
19 | color=(255,0,0), markerType=cv2.MARKER_CROSS,
20 | markerSize=8, thickness=1)
21 | imgs.append(img)
22 | imgs = np.array(imgs)
23 | return imgs
24 |
--------------------------------------------------------------------------------
/equi_diffpo/common/json_logger.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable, Any, Sequence
2 | import os
3 | import copy
4 | import json
5 | import numbers
6 | import pandas as pd
7 |
8 |
9 | def read_json_log(path: str,
10 | required_keys: Sequence[str]=tuple(),
11 | **kwargs) -> pd.DataFrame:
12 | """
13 | Read json-per-line file, with potentially incomplete lines.
14 | kwargs passed to pd.read_json
15 | """
16 | lines = list()
17 | with open(path, 'r') as f:
18 | while True:
19 | # one json per line
20 | line = f.readline()
21 | if len(line) == 0:
22 | # EOF
23 | break
24 | elif not line.endswith('\n'):
25 | # incomplete line
26 | break
27 | is_relevant = False
28 | for k in required_keys:
29 | if k in line:
30 | is_relevant = True
31 | break
32 | if is_relevant:
33 | lines.append(line)
34 | if len(lines) < 1:
35 | return pd.DataFrame()
36 | json_buf = f'[{",".join([line for line in (line.strip() for line in lines) if line])}]'
37 | df = pd.read_json(json_buf, **kwargs)
38 | return df
39 |
40 | class JsonLogger:
41 | def __init__(self, path: str,
42 | filter_fn: Optional[Callable[[str,Any],bool]]=None):
43 | if filter_fn is None:
44 | filter_fn = lambda k,v: isinstance(v, numbers.Number)
45 |
46 | # default to append mode
47 | self.path = path
48 | self.filter_fn = filter_fn
49 | self.file = None
50 | self.last_log = None
51 |
52 | def start(self):
53 | # use line buffering
54 | try:
55 | self.file = file = open(self.path, 'r+', buffering=1)
56 | except FileNotFoundError:
57 | self.file = file = open(self.path, 'w+', buffering=1)
58 |
59 | # Move the pointer (similar to a cursor in a text editor) to the end of the file
60 | pos = file.seek(0, os.SEEK_END)
61 |
62 | # Read each character in the file one at a time from the last
63 | # character going backwards, searching for a newline character
64 | # If we find a new line, exit the search
65 | while pos > 0 and file.read(1) != "\n":
66 | pos -= 1
67 | file.seek(pos, os.SEEK_SET)
68 | # now the file pointer is at one past the last '\n'
69 | # and pos is at the last '\n'.
70 | last_line_end = file.tell()
71 |
72 | # find the start of second last line
73 | pos = max(0, pos-1)
74 | file.seek(pos, os.SEEK_SET)
75 | while pos > 0 and file.read(1) != "\n":
76 | pos -= 1
77 | file.seek(pos, os.SEEK_SET)
78 | # now the file pointer is at one past the second last '\n'
79 | last_line_start = file.tell()
80 |
81 | if last_line_start < last_line_end:
82 | # has last line of json
83 | last_line = file.readline()
84 | self.last_log = json.loads(last_line)
85 |
86 | # remove the last incomplete line
87 | file.seek(last_line_end)
88 | file.truncate()
89 |
90 | def stop(self):
91 | self.file.close()
92 | self.file = None
93 |
94 | def __enter__(self):
95 | self.start()
96 | return self
97 |
98 | def __exit__(self, exc_type, exc_val, exc_tb):
99 | self.stop()
100 |
101 | def log(self, data: dict):
102 | filtered_data = dict(
103 | filter(lambda x: self.filter_fn(*x), data.items()))
104 | # save current as last log
105 | self.last_log = filtered_data
106 | for k, v in filtered_data.items():
107 | if isinstance(v, numbers.Integral):
108 | filtered_data[k] = int(v)
109 | elif isinstance(v, numbers.Number):
110 | filtered_data[k] = float(v)
111 | buf = json.dumps(filtered_data)
112 | # ensure one line per json
113 | buf = buf.replace('\n','') + '\n'
114 | self.file.write(buf)
115 |
116 | def get_last_log(self):
117 | return copy.deepcopy(self.last_log)
118 |
--------------------------------------------------------------------------------
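Usage sketch for `JsonLogger` and `read_json_log` (the log path is hypothetical). The default filter keeps only numeric values, and each `log` call appends one JSON object per line:

```python
with JsonLogger("train_log.json.txt") as logger:
    logger.log({"epoch": 0, "train_loss": 0.42, "note": "dropped"})  # "note" is filtered out
    logger.log({"epoch": 1, "train_loss": 0.35})

df = read_json_log("train_log.json.txt", required_keys=["train_loss"])
print(df["train_loss"].tolist())  # [0.42, 0.35]
```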
/equi_diffpo/common/nested_dict_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | def nested_dict_map(f, x):
4 | """
5 | Map f over all leaves of nested dict x
6 | """
7 |
8 | if not isinstance(x, dict):
9 | return f(x)
10 | y = dict()
11 | for key, value in x.items():
12 | y[key] = nested_dict_map(f, value)
13 | return y
14 |
15 | def nested_dict_reduce(f, x):
16 | """
17 | Map f over all values of nested dict x, and reduce to a single value
18 | """
19 | if not isinstance(x, dict):
20 | return x
21 |
22 | reduced_values = list()
23 | for value in x.values():
24 | reduced_values.append(nested_dict_reduce(f, value))
25 | y = functools.reduce(f, reduced_values)
26 | return y
27 |
28 |
29 | def nested_dict_check(f, x):
30 | bool_dict = nested_dict_map(f, x)
31 | result = nested_dict_reduce(lambda x, y: x and y, bool_dict)
32 | return result
33 |
--------------------------------------------------------------------------------
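A small sketch (not part of the file) of the three helpers on a nested observation dict:

```python
x = {"obs": {"image": 1.0, "pos": -2.0}, "action": 3.0}

doubled = nested_dict_map(lambda v: v * 2, x)            # leaves doubled, structure kept
total = nested_dict_reduce(lambda a, b: a + b, doubled)  # sums all leaves: 4.0
all_positive = nested_dict_check(lambda v: v > 0, x)     # False, since "pos" is negative
```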
/equi_diffpo/common/precise_sleep.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | def precise_sleep(dt: float, slack_time: float=0.001, time_func=time.monotonic):
4 | """
5 | Use hybrid of time.sleep and spinning to minimize jitter.
6 | Sleep dt - slack_time seconds first, then spin for the rest.
7 | """
8 | t_start = time_func()
9 | if dt > slack_time:
10 | time.sleep(dt - slack_time)
11 | t_end = t_start + dt
12 | while time_func() < t_end:
13 | pass
14 | return
15 |
16 | def precise_wait(t_end: float, slack_time: float=0.001, time_func=time.monotonic):
17 | t_start = time_func()
18 | t_wait = t_end - t_start
19 | if t_wait > 0:
20 | t_sleep = t_wait - slack_time
21 | if t_sleep > 0:
22 | time.sleep(t_sleep)
23 | while time_func() < t_end:
24 | pass
25 | return
26 |
--------------------------------------------------------------------------------
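A sketch of the intended call pattern, a fixed-rate control loop that sleeps most of the period and spins for the tail (the loop body is hypothetical):

```python
import time

t_start = time.monotonic()
dt = 0.01  # 100 Hz control period
for i in range(3):
    ...  # one control step goes here
    precise_wait(t_start + (i + 1) * dt)  # coarse sleep, then spin to the deadline
```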
/equi_diffpo/common/pymunk_util.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import pymunk
3 | import pymunk.pygame_util
4 | import numpy as np
5 |
6 | COLLTYPE_DEFAULT = 0
7 | COLLTYPE_MOUSE = 1
8 | COLLTYPE_BALL = 2
9 |
10 | def get_body_type(static=False):
11 | body_type = pymunk.Body.DYNAMIC
12 | if static:
13 | body_type = pymunk.Body.STATIC
14 | return body_type
15 |
16 |
17 | def create_rectangle(space,
18 | pos_x,pos_y,width,height,
19 | density=3,static=False):
20 | body = pymunk.Body(body_type=get_body_type(static))
21 | body.position = (pos_x,pos_y)
22 | shape = pymunk.Poly.create_box(body,(width,height))
23 | shape.density = density
24 | space.add(body,shape)
25 | return body, shape
26 |
27 |
28 | def create_rectangle_bb(space,
29 | left, bottom, right, top,
30 | **kwargs):
31 | pos_x = (left + right) / 2
32 | pos_y = (top + bottom) / 2
33 | height = top - bottom
34 | width = right - left
35 | return create_rectangle(space, pos_x, pos_y, width, height, **kwargs)
36 |
37 | def create_circle(space, pos_x, pos_y, radius, density=3, static=False):
38 | body = pymunk.Body(body_type=get_body_type(static))
39 | body.position = (pos_x, pos_y)
40 | shape = pymunk.Circle(body, radius=radius)
41 | shape.density = density
42 | shape.collision_type = COLLTYPE_BALL
43 | space.add(body, shape)
44 | return body, shape
45 |
46 | def get_body_state(body):
47 | state = np.zeros(6, dtype=np.float32)
48 | state[:2] = body.position
49 | state[2] = body.angle
50 | state[3:5] = body.velocity
51 | state[5] = body.angular_velocity
52 | return state
53 |
--------------------------------------------------------------------------------
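A minimal sketch exercising the helpers above: a dynamic ball dropped under gravity for a few physics steps:

```python
import pymunk

space = pymunk.Space()
space.gravity = (0.0, -981.0)
body, _shape = create_circle(space, pos_x=50.0, pos_y=100.0, radius=10.0)
for _ in range(10):
    space.step(1 / 60)
print(get_body_state(body))  # [x, y, angle, vx, vy, angular_velocity]
```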
/equi_diffpo/common/pytorch_util.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Callable, List
2 | import collections
3 | import torch
4 | import torch.nn as nn
5 |
6 | def dict_apply(
7 | x: Dict[str, torch.Tensor],
8 | func: Callable[[torch.Tensor], torch.Tensor]
9 | ) -> Dict[str, torch.Tensor]:
10 | result = dict()
11 | for key, value in x.items():
12 | if isinstance(value, dict):
13 | result[key] = dict_apply(value, func)
14 | else:
15 | result[key] = func(value)
16 | return result
17 |
18 | def pad_remaining_dims(x, target):
19 | assert x.shape == target.shape[:len(x.shape)]
20 | return x.reshape(x.shape + (1,)*(len(target.shape) - len(x.shape)))
21 |
22 | def dict_apply_split(
23 | x: Dict[str, torch.Tensor],
24 | split_func: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
25 | ) -> Dict[str, torch.Tensor]:
26 | results = collections.defaultdict(dict)
27 | for key, value in x.items():
28 | result = split_func(value)
29 | for k, v in result.items():
30 | results[k][key] = v
31 | return results
32 |
33 | def dict_apply_reduce(
34 | x: List[Dict[str, torch.Tensor]],
35 | reduce_func: Callable[[List[torch.Tensor]], torch.Tensor]
36 | ) -> Dict[str, torch.Tensor]:
37 | result = dict()
38 | for key in x[0].keys():
39 | result[key] = reduce_func([x_[key] for x_ in x])
40 | return result
41 |
42 |
43 | def replace_submodules(
44 | root_module: nn.Module,
45 | predicate: Callable[[nn.Module], bool],
46 | func: Callable[[nn.Module], nn.Module]) -> nn.Module:
47 | """
48 | predicate: Return true if the module is to be replaced.
49 | func: Return new module to use.
50 | """
51 | if predicate(root_module):
52 | return func(root_module)
53 |
54 | bn_list = [k.split('.') for k, m
55 | in root_module.named_modules(remove_duplicate=True)
56 | if predicate(m)]
57 | for *parent, k in bn_list:
58 | parent_module = root_module
59 | if len(parent) > 0:
60 | parent_module = root_module.get_submodule('.'.join(parent))
61 | if isinstance(parent_module, nn.Sequential):
62 | src_module = parent_module[int(k)]
63 | else:
64 | src_module = getattr(parent_module, k)
65 | tgt_module = func(src_module)
66 | if isinstance(parent_module, nn.Sequential):
67 | parent_module[int(k)] = tgt_module
68 | else:
69 | setattr(parent_module, k, tgt_module)
70 | # verify that all modules matching the predicate have been replaced
71 | bn_list = [k.split('.') for k, m
72 | in root_module.named_modules(remove_duplicate=True)
73 | if predicate(m)]
74 | assert len(bn_list) == 0
75 | return root_module
76 |
77 | def optimizer_to(optimizer, device):
78 | for state in optimizer.state.values():
79 | for k, v in state.items():
80 | if isinstance(v, torch.Tensor):
81 | state[k] = v.to(device=device)
82 | return optimizer
83 |
--------------------------------------------------------------------------------
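`replace_submodules` is typically used to swap normalization layers wholesale. A sketch (not part of the file) replacing every `BatchNorm2d` with a `GroupNorm`:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
model = replace_submodules(
    root_module=model,
    predicate=lambda m: isinstance(m, nn.BatchNorm2d),
    func=lambda m: nn.GroupNorm(num_groups=8, num_channels=m.num_features),
)
# the assert inside replace_submodules guarantees no BatchNorm2d remains
```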
/equi_diffpo/common/robomimic_config_util.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from robomimic.config import config_factory
3 | import robomimic.scripts.generate_paper_configs as gpc
4 | from robomimic.scripts.generate_paper_configs import (
5 | modify_config_for_default_image_exp,
6 | modify_config_for_default_low_dim_exp,
7 | modify_config_for_dataset,
8 | )
9 |
10 | def get_robomimic_config(
11 | algo_name='bc_rnn',
12 | hdf5_type='low_dim',
13 | task_name='square',
14 | dataset_type='ph'
15 | ):
16 | base_dataset_dir = '/tmp/null'
17 | filter_key = None
18 |
19 | # decide whether to use low-dim or image training defaults
20 | modifier_for_obs = modify_config_for_default_image_exp
21 | if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
22 | modifier_for_obs = modify_config_for_default_low_dim_exp
23 |
24 | algo_config_name = "bc" if algo_name == "bc_rnn" else algo_name
25 | config = config_factory(algo_name=algo_config_name)
26 | # turn into default config for observation modalities (e.g.: low-dim or rgb)
27 | config = modifier_for_obs(config)
28 | # add in config based on the dataset
29 | config = modify_config_for_dataset(
30 | config=config,
31 | task_name=task_name,
32 | dataset_type=dataset_type,
33 | hdf5_type=hdf5_type,
34 | base_dataset_dir=base_dataset_dir,
35 | filter_key=filter_key,
36 | )
37 | # add in algo hypers based on dataset
38 | algo_config_modifier = getattr(gpc, f'modify_{algo_name}_config_for_dataset')
39 | config = algo_config_modifier(
40 | config=config,
41 | task_name=task_name,
42 | dataset_type=dataset_type,
43 | hdf5_type=hdf5_type,
44 | )
45 | return config
46 |
47 |
48 |
--------------------------------------------------------------------------------
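Usage sketch for `get_robomimic_config` (not part of the file). The argument values below are among those the function's defaults already accept; the returned object is a standard robomimic `Config`:

```python
# Build default BC-RNN hyperparameters for an image-based square/ph run.
config = get_robomimic_config(
    algo_name="bc_rnn",
    hdf5_type="image",
    task_name="square",
    dataset_type="ph",
)
print(config.algo.rnn.enabled)  # expected True for the bc_rnn defaults
```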
/equi_diffpo/common/sampler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import numpy as np
3 | import numba
4 | from equi_diffpo.common.replay_buffer import ReplayBuffer
5 |
6 |
7 | @numba.jit(nopython=True)
8 | def create_indices(
9 | episode_ends:np.ndarray, sequence_length:int,
10 | episode_mask: np.ndarray,
11 | pad_before: int=0, pad_after: int=0,
12 | debug:bool=True) -> np.ndarray:
13 | assert episode_mask.shape == episode_ends.shape
14 | pad_before = min(max(pad_before, 0), sequence_length-1)
15 | pad_after = min(max(pad_after, 0), sequence_length-1)
16 |
17 | indices = list()
18 | for i in range(len(episode_ends)):
19 | if not episode_mask[i]:
20 | # skip episode
21 | continue
22 | start_idx = 0
23 | if i > 0:
24 | start_idx = episode_ends[i-1]
25 | end_idx = episode_ends[i]
26 | episode_length = end_idx - start_idx
27 |
28 | min_start = -pad_before
29 | max_start = episode_length - sequence_length + pad_after
30 |
31 | # range stops one idx before end
32 | for idx in range(min_start, max_start+1):
33 | buffer_start_idx = max(idx, 0) + start_idx
34 | buffer_end_idx = min(idx+sequence_length, episode_length) + start_idx
35 | start_offset = buffer_start_idx - (idx+start_idx)
36 | end_offset = (idx+sequence_length+start_idx) - buffer_end_idx
37 | sample_start_idx = 0 + start_offset
38 | sample_end_idx = sequence_length - end_offset
39 | if debug:
40 | assert(start_offset >= 0)
41 | assert(end_offset >= 0)
42 | assert (sample_end_idx - sample_start_idx) == (buffer_end_idx - buffer_start_idx)
43 | indices.append([
44 | buffer_start_idx, buffer_end_idx,
45 | sample_start_idx, sample_end_idx])
46 | indices = np.array(indices)
47 | return indices
48 |
49 |
50 | def get_val_mask(n_episodes, val_ratio, seed=0):
51 | val_mask = np.zeros(n_episodes, dtype=bool)
52 | if val_ratio <= 0:
53 | return val_mask
54 |
55 | # have at least 1 episode for validation, and at least 1 episode for train
56 | n_val = min(max(1, round(n_episodes * val_ratio)), n_episodes-1)
57 | rng = np.random.default_rng(seed=seed)
58 | val_idxs = rng.choice(n_episodes, size=n_val, replace=False)
59 | val_mask[val_idxs] = True
60 | return val_mask
61 |
62 |
63 | def downsample_mask(mask, max_n, seed=0):
64 | # subsample training data
65 | train_mask = mask
66 | if (max_n is not None) and (np.sum(train_mask) > max_n):
67 | n_train = int(max_n)
68 | curr_train_idxs = np.nonzero(train_mask)[0]
69 | rng = np.random.default_rng(seed=seed)
70 | train_idxs_idx = rng.choice(len(curr_train_idxs), size=n_train, replace=False)
71 | train_idxs = curr_train_idxs[train_idxs_idx]
72 | train_mask = np.zeros_like(train_mask)
73 | train_mask[train_idxs] = True
74 | assert np.sum(train_mask) == n_train
75 | return train_mask
76 |
77 | class SequenceSampler:
78 | def __init__(self,
79 | replay_buffer: ReplayBuffer,
80 | sequence_length:int,
81 | pad_before:int=0,
82 | pad_after:int=0,
83 | keys=None,
84 | key_first_k=dict(),
85 | episode_mask: Optional[np.ndarray]=None,
86 | ):
87 | """
88 | key_first_k: dict str: int
89 | Only take first k data from these keys (to improve perf)
90 | """
91 |
92 | super().__init__()
93 | assert(sequence_length >= 1)
94 | if keys is None:
95 | keys = list(replay_buffer.keys())
96 |
97 | episode_ends = replay_buffer.episode_ends[:]
98 | if episode_mask is None:
99 | episode_mask = np.ones(episode_ends.shape, dtype=bool)
100 |
101 | if np.any(episode_mask):
102 | indices = create_indices(episode_ends,
103 | sequence_length=sequence_length,
104 | pad_before=pad_before,
105 | pad_after=pad_after,
106 | episode_mask=episode_mask
107 | )
108 | else:
109 | indices = np.zeros((0,4), dtype=np.int64)
110 |
111 | # (buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)
112 | self.indices = indices
113 | self.keys = list(keys) # prevent OmegaConf list performance problem
114 | self.sequence_length = sequence_length
115 | self.replay_buffer = replay_buffer
116 | self.key_first_k = key_first_k
117 |
118 | def __len__(self):
119 | return len(self.indices)
120 |
121 | def sample_sequence(self, idx):
122 | buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx \
123 | = self.indices[idx]
124 | result = dict()
125 | for key in self.keys:
126 | input_arr = self.replay_buffer[key]
127 | # performance optimization, avoid small allocation if possible
128 | if key not in self.key_first_k:
129 | sample = input_arr[buffer_start_idx:buffer_end_idx]
130 | else:
131 | # performance optimization, only load used obs steps
132 | n_data = buffer_end_idx - buffer_start_idx
133 | k_data = min(self.key_first_k[key], n_data)
134 | # fill value with Nan to catch bugs
135 | # the non-loaded region should never be used
136 | sample = np.full((n_data,) + input_arr.shape[1:],
137 | fill_value=np.nan, dtype=input_arr.dtype)
138 | try:
139 | sample[:k_data] = input_arr[buffer_start_idx:buffer_start_idx+k_data]
140 | except Exception as e:
141 | import pdb; pdb.set_trace()
142 | data = sample
143 | if (sample_start_idx > 0) or (sample_end_idx < self.sequence_length):
144 | data = np.zeros(
145 | shape=(self.sequence_length,) + input_arr.shape[1:],
146 | dtype=input_arr.dtype)
147 | if sample_start_idx > 0:
148 | data[:sample_start_idx] = sample[0]
149 | if sample_end_idx < self.sequence_length:
150 | data[sample_end_idx:] = sample[-1]
151 | data[sample_start_idx:sample_end_idx] = sample
152 | result[key] = data
153 | return result
154 |
--------------------------------------------------------------------------------
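A worked example of `create_indices` (not part of the file) on two short episodes. Each row is `(buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx)`: the buffer indices slice the replay buffer, and the sample indices say where that slice lands inside the padded `sequence_length` window:

```python
import numpy as np

episode_ends = np.array([5, 9], dtype=np.int64)  # episode 0: steps 0-4, episode 1: steps 5-8
episode_mask = np.ones(2, dtype=bool)
indices = create_indices(
    episode_ends,
    sequence_length=4,
    episode_mask=episode_mask,
    pad_before=1,
    pad_after=1,
)
# The first window starts one step "before" episode 0, so buffer slice 0:3
# fills positions 1:4 of the 4-step sample window (position 0 gets padded).
print(indices[0])  # [0 3 1 4]
```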
/equi_diffpo/config/dp3.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_pc_abs
4 |
5 | name: train_dp3
6 | _target_: equi_diffpo.workspace.train_dp3_workspace.TrainDP3Workspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "debug"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 16
14 | n_obs_steps: 2
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | keypoint_visible_rate: 1.0
19 | obs_as_global_cond: True
20 | dataset_target: equi_diffpo.dataset.robomimic_replay_point_cloud_dataset.RobomimicReplayPointCloudDataset
21 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
22 |
23 | policy:
24 | _target_: equi_diffpo.policy.dp3.DP3
25 | use_point_crop: true
26 | condition_type: film
27 | use_down_condition: true
28 | use_mid_condition: true
29 | use_up_condition: true
30 |
31 | diffusion_step_embed_dim: 128
32 | down_dims:
33 | - 512
34 | - 1024
35 | - 2048
36 | crop_shape:
37 | - 80
38 | - 80
39 | encoder_output_dim: 64
40 | horizon: ${horizon}
41 | kernel_size: 5
42 | n_action_steps: ${n_action_steps}
43 | n_groups: 8
44 | n_obs_steps: ${n_obs_steps}
45 |
46 | noise_scheduler:
47 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler
48 | num_train_timesteps: 100
49 | beta_start: 0.0001
50 | beta_end: 0.02
51 | beta_schedule: squaredcos_cap_v2
52 | clip_sample: True
53 | set_alpha_to_one: True
54 | steps_offset: 0
55 | prediction_type: sample
56 |
57 |
58 | num_inference_steps: 10
59 | obs_as_global_cond: true
60 | shape_meta: ${shape_meta}
61 |
62 | use_pc_color: true
63 | pointnet_type: "pointnet"
64 |
65 |
66 | pointcloud_encoder_cfg:
67 | in_channels: 3
68 | out_channels: ${policy.encoder_output_dim}
69 | use_layernorm: true
70 | final_norm: layernorm # layernorm, none
71 | normal_channel: false
72 |
73 |
74 | ema:
75 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
76 | update_after_step: 0
77 | inv_gamma: 1.0
78 | power: 0.75
79 | min_value: 0.0
80 | max_value: 0.9999
81 |
82 | dataloader:
83 | batch_size: 128
84 | num_workers: 8
85 | shuffle: True
86 | pin_memory: True
87 | persistent_workers: True
88 |
89 | val_dataloader:
90 | batch_size: 128
91 | num_workers: 8
92 | shuffle: False
93 | pin_memory: True
94 | persistent_workers: True
95 |
96 | optimizer:
97 | _target_: torch.optim.AdamW
98 | lr: 1.0e-4
99 | betas: [0.95, 0.999]
100 | eps: 1.0e-8
101 | weight_decay: 1.0e-6
102 |
103 | training:
104 | device: "cuda:0"
105 | seed: 42
106 | debug: False
107 | resume: True
108 | lr_scheduler: cosine
109 | lr_warmup_steps: 500
110 | num_epochs: ${eval:'50000 / ${n_demo}'}
111 | gradient_accumulate_every: 1
112 | use_ema: True
113 | rollout_every: ${eval:'1000 / ${n_demo}'}
114 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
115 | val_every: 1
116 | sample_every: 5
117 | max_train_steps: null
118 | max_val_steps: null
119 | tqdm_interval_sec: 1.0
120 |
121 | logging:
122 | project: dp3_${task_name}
123 | resume: true
124 | mode: online
125 | name: dp3_${n_demo}
126 | tags: ["${name}", "${task_name}", "${exp_name}"]
127 | id: null
128 | group: null
129 |
130 |
131 | checkpoint:
132 | save_ckpt: False # if True, save checkpoint every checkpoint_every
133 | topk:
134 | monitor_key: test_mean_score
135 | mode: max
136 | k: 1
137 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
138 | save_last_ckpt: True # this only saves when save_ckpt is True
139 | save_last_snapshot: False
140 |
141 | multi_run:
142 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
143 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
144 |
145 | hydra:
146 | job:
147 | override_dirname: ${name}
148 | run:
149 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
150 | sweep:
151 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
152 | subdir: ${hydra.job.num}
--------------------------------------------------------------------------------
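A note on the `${eval:...}` resolvers in the config above: the schedule scales inversely with dataset size, keeping the total number of gradient updates roughly constant. With the default `n_demo: 200`, `num_epochs` resolves to `50000 / 200 = 250`, while `rollout_every` and `checkpoint_every` resolve to `1000 / 200 = 5` epochs.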
/equi_diffpo/config/task/mimicgen_abs.yaml:
--------------------------------------------------------------------------------
1 | name: mimicgen_abs
2 |
3 | shape_meta: &shape_meta
4 | # acceptable types: rgb, low_dim
5 | obs:
6 | agentview_image:
7 | shape: [3, 84, 84]
8 | type: rgb
9 | robot0_eye_in_hand_image:
10 | shape: [3, 84, 84]
11 | type: rgb
12 | robot0_eef_pos:
13 | shape: [3]
14 | # type default: low_dim
15 | robot0_eef_quat:
16 | shape: [4]
17 | robot0_gripper_qpos:
18 | shape: [2]
19 | action:
20 | shape: [10]
21 |
22 | abs_action: &abs_action True
23 |
24 | env_runner:
25 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
26 | dataset_path: ${dataset_path}
27 | shape_meta: *shape_meta
28 | n_train: 6
29 | n_train_vis: 2
30 | train_start_idx: 0
31 | n_test: 50
32 | n_test_vis: 4
33 | test_start_seed: 100000
34 | max_steps: ${get_max_steps:${task_name}}
35 | n_obs_steps: ${n_obs_steps}
36 | n_action_steps: ${n_action_steps}
37 | render_obs_key: 'agentview_image'
38 | fps: 10
39 | crf: 22
40 | past_action: ${past_action_visible}
41 | abs_action: *abs_action
42 | tqdm_interval_sec: 1.0
43 | n_envs: 28
44 |
45 | dataset:
46 | # _target_: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
47 | _target_: ${dataset}
48 | n_demo: ${n_demo}
49 | shape_meta: *shape_meta
50 | dataset_path: ${dataset_path}
51 | horizon: ${horizon}
52 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
53 | pad_after: ${eval:'${n_action_steps}-1'}
54 | n_obs_steps: ${dataset_obs_steps}
55 | abs_action: *abs_action
56 | rotation_rep: 'rotation_6d'
57 | use_legacy_normalizer: False
58 | use_cache: True
59 | seed: 42
60 | val_ratio: 0.02
61 |
--------------------------------------------------------------------------------
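A note on the shapes above: the 10-dimensional absolute action is consistent with the `rotation_6d` setting under `dataset`, namely 3 (end-effector position) + 6 (6D rotation) + 1 (gripper). The relative-action tasks (e.g. `mimicgen_rel.yaml`) use a 7-dimensional action, which matches a delta-pose command: 3 (position) + 3 (axis-angle rotation) + 1 (gripper).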
/equi_diffpo/config/task/mimicgen_pc_abs.yaml:
--------------------------------------------------------------------------------
1 | name: mimicgen_pc_abs
2 |
3 | shape_meta: &shape_meta
4 | # acceptable types: rgb, low_dim
5 | obs:
6 | robot0_eye_in_hand_image:
7 | shape: [3, 84, 84]
8 | type: rgb
9 | point_cloud:
10 | shape: [1024, 6]
11 | type: point_cloud
12 | robot0_eef_pos:
13 | shape: [3]
14 | # type default: low_dim
15 | robot0_eef_quat:
16 | shape: [4]
17 | robot0_gripper_qpos:
18 | shape: [2]
19 | action:
20 | shape: [10]
21 |
22 | env_runner_shape_meta: &env_runner_shape_meta
23 | # acceptable types: rgb, low_dim
24 | obs:
25 | robot0_eye_in_hand_image:
26 | shape: [3, 84, 84]
27 | type: rgb
28 | agentview_image:
29 | shape: [3, 84, 84]
30 | type: rgb
31 | point_cloud:
32 | shape: [1024, 6]
33 | type: point_cloud
34 | robot0_eef_pos:
35 | shape: [3]
36 | # type default: low_dim
37 | robot0_eef_quat:
38 | shape: [4]
39 | robot0_gripper_qpos:
40 | shape: [2]
41 | action:
42 | shape: [10]
43 |
44 | abs_action: &abs_action True
45 |
46 | env_runner:
47 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
48 | dataset_path: ${dataset_path}
49 | shape_meta: *env_runner_shape_meta
50 | n_train: 6
51 | n_train_vis: 2
52 | train_start_idx: 0
53 | n_test: 50
54 | n_test_vis: 4
55 | test_start_seed: 100000
56 | max_steps: ${get_max_steps:${task_name}}
57 | n_obs_steps: ${n_obs_steps}
58 | n_action_steps: ${n_action_steps}
59 | render_obs_key: 'agentview_image'
60 | fps: 10
61 | crf: 22
62 | past_action: False
63 | abs_action: *abs_action
64 | tqdm_interval_sec: 1.0
65 | n_envs: 28
66 |
67 | dataset:
68 | _target_: ${dataset_target}
69 | n_demo: ${n_demo}
70 | shape_meta: *shape_meta
71 | dataset_path: ${dataset_path}
72 | horizon: ${horizon}
73 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
74 | pad_after: ${eval:'${n_action_steps}-1'}
75 | n_obs_steps: ${dataset_obs_steps}
76 | abs_action: *abs_action
77 | rotation_rep: 'rotation_6d'
78 | use_legacy_normalizer: False
79 | use_cache: False
80 | seed: 42
81 | val_ratio: 0.02
82 |
--------------------------------------------------------------------------------
/equi_diffpo/config/task/mimicgen_rel.yaml:
--------------------------------------------------------------------------------
1 | name: mimicgen_rel
2 |
3 | shape_meta: &shape_meta
4 | # acceptable types: rgb, low_dim
5 | obs:
6 | agentview_image:
7 | shape: [3, 84, 84]
8 | type: rgb
9 | robot0_eye_in_hand_image:
10 | shape: [3, 84, 84]
11 | type: rgb
12 | robot0_eef_pos:
13 | shape: [3]
14 | # type default: low_dim
15 | robot0_eef_quat:
16 | shape: [4]
17 | robot0_gripper_qpos:
18 | shape: [2]
19 | action:
20 | shape: [7]
21 |
22 | abs_action: &abs_action False
23 |
24 | env_runner:
25 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
26 | dataset_path: ${dataset_path}
27 | shape_meta: *shape_meta
28 | n_train: 6
29 | n_train_vis: 2
30 | train_start_idx: 0
31 | n_test: 50
32 | n_test_vis: 4
33 | test_start_seed: 100000
34 | max_steps: ${get_max_steps:${task_name}}
35 | n_obs_steps: ${n_obs_steps}
36 | n_action_steps: ${n_action_steps}
37 | render_obs_key: 'agentview_image'
38 | fps: 10
39 | crf: 22
40 | past_action: ${past_action_visible}
41 | abs_action: *abs_action
42 | tqdm_interval_sec: 1.0
43 | n_envs: 28
44 |
45 | dataset:
46 | # _target_: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
47 | _target_: ${dataset}
48 | n_demo: ${n_demo}
49 | shape_meta: *shape_meta
50 | dataset_path: ${dataset_path}
51 | horizon: ${horizon}
52 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
53 | pad_after: ${eval:'${n_action_steps}-1'}
54 | n_obs_steps: ${dataset_obs_steps}
55 | abs_action: *abs_action
56 | rotation_rep: 'rotation_6d'
57 | use_legacy_normalizer: False
58 | use_cache: True
59 | seed: 42
60 | val_ratio: 0.02
61 |
--------------------------------------------------------------------------------
/equi_diffpo/config/task/mimicgen_voxel_abs.yaml:
--------------------------------------------------------------------------------
1 | name: mimicgen_abs
2 |
3 | shape_meta: &shape_meta
4 | # acceptable types: rgb, low_dim
5 | obs:
6 | robot0_eye_in_hand_image:
7 | shape: [3, 84, 84]
8 | type: rgb
9 | voxels:
10 | shape: [4, 64, 64, 64]
11 | type: voxel
12 | robot0_eef_pos:
13 | shape: [3]
14 | # type default: low_dim
15 | robot0_eef_quat:
16 | shape: [4]
17 | robot0_gripper_qpos:
18 | shape: [2]
19 | action:
20 | shape: [10]
21 |
22 | env_runner_shape_meta: &env_runner_shape_meta
23 | # acceptable types: rgb, low_dim
24 | obs:
25 | robot0_eye_in_hand_image:
26 | shape: [3, 84, 84]
27 | type: rgb
28 | agentview_image:
29 | shape: [3, 84, 84]
30 | type: rgb
31 | voxels:
32 | shape: [4, 64, 64, 64]
33 | type: voxel
34 | robot0_eef_pos:
35 | shape: [3]
36 | # type default: low_dim
37 | robot0_eef_quat:
38 | shape: [4]
39 | robot0_gripper_qpos:
40 | shape: [2]
41 | action:
42 | shape: [10]
43 |
44 | # dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
45 | abs_action: &abs_action True
46 |
47 | env_runner:
48 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
49 | dataset_path: ${dataset_path}
50 | shape_meta: *env_runner_shape_meta
51 | n_train: 6
52 | n_train_vis: 2
53 | train_start_idx: 0
54 | n_test: 50
55 | n_test_vis: 4
56 | test_start_seed: 100000
57 | max_steps: ${get_max_steps:${task_name}}
58 | n_obs_steps: ${n_obs_steps}
59 | n_action_steps: ${n_action_steps}
60 | render_obs_key: 'agentview_image'
61 | fps: 10
62 | crf: 22
63 | past_action: ${past_action_visible}
64 | abs_action: *abs_action
65 | tqdm_interval_sec: 1.0
66 | n_envs: 28
67 |
68 | dataset:
69 | _target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset
70 | n_demo: ${n_demo}
71 | shape_meta: *shape_meta
72 | dataset_path: ${dataset_path}
73 | horizon: ${horizon}
74 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
75 | pad_after: ${eval:'${n_action_steps}-1'}
76 | n_obs_steps: ${dataset_obs_steps}
77 | abs_action: *abs_action
78 | rotation_rep: 'rotation_6d'
79 | use_legacy_normalizer: False
80 | use_cache: True
81 | seed: 42
82 | val_ratio: 0.02
83 | ws_x_center: ${get_ws_x_center:${task_name}}
84 | ws_y_center: ${get_ws_y_center:${task_name}}
85 |
--------------------------------------------------------------------------------
/equi_diffpo/config/task/mimicgen_voxel_rel.yaml:
--------------------------------------------------------------------------------
1 | name: mimicgen_rel
2 |
3 | shape_meta: &shape_meta
4 | # acceptable types: rgb, low_dim
5 | obs:
6 | robot0_eye_in_hand_image:
7 | shape: [3, 84, 84]
8 | type: rgb
9 | voxels:
10 | shape: [4, 64, 64, 64]
11 | type: voxel
12 | robot0_eef_pos:
13 | shape: [3]
14 | # type default: low_dim
15 | robot0_eef_quat:
16 | shape: [4]
17 | robot0_gripper_qpos:
18 | shape: [2]
19 | action:
20 | shape: [7]
21 |
22 | env_runner_shape_meta: &env_runner_shape_meta
23 | # acceptable types: rgb, low_dim
24 | obs:
25 | robot0_eye_in_hand_image:
26 | shape: [3, 84, 84]
27 | type: rgb
28 | agentview_image:
29 | shape: [3, 84, 84]
30 | type: rgb
31 | voxels:
32 | shape: [4, 64, 64, 64]
33 | type: voxel
34 | robot0_eef_pos:
35 | shape: [3]
36 | # type default: low_dim
37 | robot0_eef_quat:
38 | shape: [4]
39 | robot0_gripper_qpos:
40 | shape: [2]
41 | action:
42 | shape: [7]
43 |
44 | # dataset_path: &dataset_path data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5
45 | abs_action: &abs_action False
46 |
47 | env_runner:
48 | _target_: equi_diffpo.env_runner.robomimic_image_runner.RobomimicImageRunner
49 | dataset_path: ${dataset_path}
50 | shape_meta: *env_runner_shape_meta
51 | n_train: 6
52 | n_train_vis: 2
53 | train_start_idx: 0
54 | n_test: 50
55 | n_test_vis: 4
56 | test_start_seed: 100000
57 | max_steps: ${get_max_steps:${task_name}}
58 | n_obs_steps: ${n_obs_steps}
59 | n_action_steps: ${n_action_steps}
60 | render_obs_key: 'agentview_image'
61 | fps: 10
62 | crf: 22
63 | past_action: ${past_action_visible}
64 | abs_action: *abs_action
65 | tqdm_interval_sec: 1.0
66 | n_envs: 28
67 |
68 | dataset:
69 | _target_: equi_diffpo.dataset.robomimic_replay_voxel_sym_dataset.RobomimicReplayVoxelSymDataset
70 | n_demo: ${n_demo}
71 | shape_meta: *shape_meta
72 | dataset_path: ${dataset_path}
73 | horizon: ${horizon}
74 | pad_before: ${eval:'${n_obs_steps}-1+${n_latency_steps}'}
75 | pad_after: ${eval:'${n_action_steps}-1'}
76 | n_obs_steps: ${dataset_obs_steps}
77 | abs_action: *abs_action
78 | rotation_rep: 'rotation_6d'
79 | use_legacy_normalizer: False
80 | use_cache: True
81 | seed: 42
82 | val_ratio: 0.02
83 | ws_x_center: ${get_ws_x_center:${task_name}}
84 | ws_y_center: ${get_ws_y_center:${task_name}}
85 |
--------------------------------------------------------------------------------
/equi_diffpo/config/train_act_abs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_abs
4 |
5 | name: act
6 | _target_: equi_diffpo.workspace.train_act_workspace.TrainActWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 10
14 | n_obs_steps: 1
15 | n_action_steps: 10
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
21 |
22 | policy:
23 | _target_: equi_diffpo.policy.act_policy.ACTPolicyWrapper
24 |
25 | shape_meta: ${shape_meta}
26 |
27 | max_timesteps: ${task.env_runner.max_steps}
28 | temporal_agg: false
29 | n_envs: ${task.env_runner.n_envs}
30 | horizon: ${horizon}
31 |
32 | dataloader:
33 | batch_size: 64
34 | num_workers: 4
35 | shuffle: True
36 | pin_memory: True
37 | persistent_workers: True
38 |
39 | val_dataloader:
40 | batch_size: 64
41 | num_workers: 4
42 | shuffle: False
43 | pin_memory: True
44 | persistent_workers: True
45 |
46 | training:
47 | device: "cuda:0"
48 | seed: 0
49 | debug: False
50 | resume: True
51 | num_epochs: ${eval:'50000 / ${n_demo}'}
52 | rollout_every: ${eval:'1000 / ${n_demo}'}
53 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
54 | val_every: 1
55 | max_train_steps: null
56 | max_val_steps: null
57 | tqdm_interval_sec: 1.0
58 |
59 | logging:
60 | project: diffusion_policy_${task_name}
61 | resume: True
62 | mode: online
63 | name: act_demo${n_demo}
64 | tags: ["${name}", "${task_name}", "${exp_name}"]
65 | id: null
66 | group: null
67 |
68 | checkpoint:
69 | topk:
70 | monitor_key: test_mean_score
71 | mode: max
72 | k: 5
73 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
74 | save_last_ckpt: True
75 | save_last_snapshot: False
76 |
77 | multi_run:
78 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
79 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
80 |
81 | hydra:
82 | job:
83 | override_dirname: ${name}
84 | run:
85 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
86 | sweep:
87 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
88 | subdir: ${hydra.job.num}
89 |
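Note: the ${eval:...} expressions above tie the training schedule to dataset size. A quick check of the arithmetic at the default n_demo of 200 (plain Python, mirroring what the resolver computes):

    n_demo = 200
    num_epochs = 50000 / n_demo    # 250.0 epochs
    rollout_every = 1000 / n_demo  # rollout/checkpoint every 5.0 epochs
    # Python's / yields floats; the workspace presumably coerces these to int.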
--------------------------------------------------------------------------------
/equi_diffpo/config/train_bc_rnn.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_rel
4 |
5 | name: bc_rnn
6 | _target_: equi_diffpo.workspace.train_robomimic_image_workspace.TrainRobomimicImageWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: &horizon 10
14 | n_obs_steps: 1
15 | n_action_steps: 1
16 | n_latency_steps: 0
17 | dataset_obs_steps: *horizon
18 | past_action_visible: False
19 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5
21 |
22 | policy:
23 | _target_: equi_diffpo.policy.robomimic_image_policy.RobomimicImagePolicy
24 | shape_meta: ${shape_meta}
25 | algo_name: bc_rnn
26 | obs_type: image
27 | # oc.select resolver: key, default
28 | task_name: ${oc.select:task.task_name,lift}
29 | dataset_type: ${oc.select:task.dataset_type,ph}
30 | crop_shape: [76,76]
31 |
32 | dataloader:
33 | batch_size: 64
34 | num_workers: 4
35 | shuffle: True
36 | pin_memory: True
37 | persistent_workers: True
38 |
39 | val_dataloader:
40 | batch_size: 64
41 | num_workers: 4
42 | shuffle: False
43 | pin_memory: True
44 | persistent_workers: True
45 |
46 | training:
47 | device: "cuda:0"
48 | seed: 0
49 | debug: False
50 | resume: True
51 | # optimization
52 | num_epochs: ${eval:'50000 / ${n_demo}'}
53 | # training loop control
54 | # in epochs
55 | rollout_every: ${eval:'1000 / ${n_demo}'}
56 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
57 | val_every: 1
58 | sample_every: 5
59 | # steps per epoch
60 | max_train_steps: null
61 | max_val_steps: null
62 | # misc
63 | tqdm_interval_sec: 1.0
64 |
65 | logging:
66 | project: diffusion_policy_${task_name}
67 | resume: True
68 | mode: online
69 | name: bc_rnn_demo${n_demo}
70 | tags: ["${name}", "${task_name}", "${exp_name}"]
71 | id: null
72 | group: null
73 |
74 | checkpoint:
75 | topk:
76 | monitor_key: test_mean_score
77 | mode: max
78 | k: 5
79 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
80 | save_last_ckpt: True
81 | save_last_snapshot: False
82 |
83 | multi_run:
84 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
85 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
86 |
87 | hydra:
88 | job:
89 | override_dirname: ${name}
90 | run:
91 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
92 | sweep:
93 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
94 | subdir: ${hydra.job.num}
95 |
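Note: oc.select is a built-in OmegaConf resolver: it returns the value at a key when the key exists, and the given default otherwise. A self-contained sketch of the fallback behavior used by policy.task_name and policy.dataset_type above:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "task": {"dataset_type": "ph"},  # task.task_name is absent
        "task_name": "${oc.select:task.task_name,lift}",
        "dataset_type": "${oc.select:task.dataset_type,mh}",
    })
    assert cfg.task_name == "lift"   # key missing -> default
    assert cfg.dataset_type == "ph"  # key present -> its value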
--------------------------------------------------------------------------------
/equi_diffpo/config/train_diffusion_transformer.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_abs
4 |
5 | name: diff_t
6 | _target_: equi_diffpo.workspace.train_diffusion_transformer_hybrid_workspace.TrainDiffusionTransformerHybridWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 10
14 | n_obs_steps: 2
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | obs_as_cond: True
20 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
21 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
22 |
23 | policy:
24 | _target_: equi_diffpo.policy.diffusion_transformer_hybrid_image_policy.DiffusionTransformerHybridImagePolicy
25 |
26 | shape_meta: ${shape_meta}
27 |
28 | noise_scheduler:
29 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
30 | num_train_timesteps: 100
31 | beta_start: 0.0001
32 | beta_end: 0.02
33 | beta_schedule: squaredcos_cap_v2
34 |     variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
35 | clip_sample: True # required when predict_epsilon=False
36 | prediction_type: epsilon # or sample
37 |
38 | horizon: ${horizon}
39 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
40 | n_obs_steps: ${n_obs_steps}
41 | num_inference_steps: 100
42 |
43 | crop_shape: [76, 76]
44 | obs_encoder_group_norm: True
45 | eval_fixed_crop: True
46 |
47 | n_layer: 8
48 | n_cond_layers: 0 # >0: use transformer encoder for cond, otherwise use MLP
49 | n_head: 4
50 | n_emb: 256
51 | p_drop_emb: 0.0
52 | p_drop_attn: 0.3
53 | causal_attn: True
54 |     time_as_cond: True # if False, use a BERT-like encoder-only architecture with time as an input token
55 | obs_as_cond: ${obs_as_cond}
56 |
57 | # scheduler.step params
58 | # predict_epsilon: True
59 |
60 | ema:
61 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
62 | update_after_step: 0
63 | inv_gamma: 1.0
64 | power: 0.75
65 | min_value: 0.0
66 | max_value: 0.9999
67 |
68 | dataloader:
69 | batch_size: 64
70 | num_workers: 4
71 | shuffle: True
72 | pin_memory: True
73 | persistent_workers: True
74 |
75 | val_dataloader:
76 | batch_size: 64
77 | num_workers: 4
78 | shuffle: False
79 | pin_memory: True
80 | persistent_workers: True
81 |
82 | optimizer:
83 | transformer_weight_decay: 1.0e-3
84 | obs_encoder_weight_decay: 1.0e-6
85 | learning_rate: 1.0e-4
86 | betas: [0.9, 0.95]
87 |
88 | training:
89 | device: "cuda:0"
90 | seed: 0
91 | debug: False
92 | resume: True
93 | # optimization
94 | lr_scheduler: cosine
95 | # Transformer needs LR warmup
96 | lr_warmup_steps: 1000
97 | num_epochs: ${eval:'50000 / ${n_demo}'}
98 | gradient_accumulate_every: 1
99 | # EMA destroys performance when used with BatchNorm
100 | # replace BatchNorm with GroupNorm.
101 | use_ema: True
102 | # training loop control
103 | # in epochs
104 | rollout_every: ${eval:'1000 / ${n_demo}'}
105 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
106 | val_every: 1
107 | sample_every: 5
108 | # steps per epoch
109 | max_train_steps: null
110 | max_val_steps: null
111 | # misc
112 | tqdm_interval_sec: 1.0
113 |
114 | logging:
115 | project: diffusion_policy_${task_name}
116 | resume: True
117 | mode: online
118 | name: diff_t_demo${n_demo}
119 | tags: ["${name}", "${task_name}", "${exp_name}"]
120 | id: null
121 | group: null
122 |
123 | checkpoint:
124 | topk:
125 | monitor_key: test_mean_score
126 | mode: max
127 | k: 5
128 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
129 | save_last_ckpt: True
130 | save_last_snapshot: False
131 |
132 | multi_run:
133 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
134 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
135 |
136 | hydra:
137 | job:
138 | override_dirname: ${name}
139 | run:
140 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
141 | sweep:
142 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
143 | subdir: ${hydra.job.num}
144 |
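Note: the noise_scheduler block maps one-to-one onto the diffusers constructor it targets. A standalone equivalent, assuming the diffusers package is installed (set_timesteps mirrors policy.num_inference_steps above):

    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=100,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="squaredcos_cap_v2",
        variance_type="fixed_small",
        clip_sample=True,           # clamp predicted samples to [-1, 1]
        prediction_type="epsilon",  # the network predicts the added noise
    )
    noise_scheduler.set_timesteps(100)  # matches policy.num_inference_steps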
--------------------------------------------------------------------------------
/equi_diffpo/config/train_diffusion_unet.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_abs
4 |
5 | name: diff_c
6 | _target_: equi_diffpo.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 16
14 | n_obs_steps: 2
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | obs_as_global_cond: True
20 | dataset: equi_diffpo.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
21 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
22 |
23 | policy:
24 | _target_: equi_diffpo.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
25 |
26 | shape_meta: ${shape_meta}
27 |
28 | noise_scheduler:
29 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
30 | num_train_timesteps: 100
31 | beta_start: 0.0001
32 | beta_end: 0.02
33 | beta_schedule: squaredcos_cap_v2
34 |     variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
35 | clip_sample: True # required when predict_epsilon=False
36 | prediction_type: epsilon # or sample
37 |
38 | horizon: ${horizon}
39 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
40 | n_obs_steps: ${n_obs_steps}
41 | num_inference_steps: 100
42 | obs_as_global_cond: ${obs_as_global_cond}
43 | crop_shape: [76, 76]
44 | # crop_shape: null
45 | diffusion_step_embed_dim: 128
46 | down_dims: [512, 1024, 2048]
47 | kernel_size: 5
48 | n_groups: 8
49 | cond_predict_scale: True
50 | obs_encoder_group_norm: True
51 | eval_fixed_crop: True
52 | rot_aug: False
53 |
54 | # scheduler.step params
55 | # predict_epsilon: True
56 |
57 | ema:
58 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
59 | update_after_step: 0
60 | inv_gamma: 1.0
61 | power: 0.75
62 | min_value: 0.0
63 | max_value: 0.9999
64 |
65 | dataloader:
66 | batch_size: 64
67 | num_workers: 4
68 | shuffle: True
69 | pin_memory: True
70 | persistent_workers: True
71 |
72 | val_dataloader:
73 | batch_size: 64
74 | num_workers: 4
75 | shuffle: False
76 | pin_memory: True
77 | persistent_workers: True
78 |
79 | optimizer:
80 | _target_: torch.optim.AdamW
81 | lr: 1.0e-4
82 | betas: [0.95, 0.999]
83 | eps: 1.0e-8
84 | weight_decay: 1.0e-6
85 |
86 | training:
87 | device: "cuda:0"
88 | seed: 0
89 | debug: False
90 | resume: True
91 | # optimization
92 | lr_scheduler: cosine
93 | lr_warmup_steps: 500
94 | num_epochs: ${eval:'50000 / ${n_demo}'}
95 | gradient_accumulate_every: 1
96 | # EMA destroys performance when used with BatchNorm
97 | # replace BatchNorm with GroupNorm.
98 | use_ema: True
99 | # training loop control
100 | # in epochs
101 | rollout_every: ${eval:'1000 / ${n_demo}'}
102 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
103 | val_every: 1
104 | sample_every: 5
105 | # steps per epoch
106 | max_train_steps: null
107 | max_val_steps: null
108 | # misc
109 | tqdm_interval_sec: 1.0
110 |
111 | logging:
112 | project: diffusion_policy_${task_name}
113 | resume: True
114 | mode: online
115 | name: diff_c_demo${n_demo}
116 | tags: ["${name}", "${task_name}", "${exp_name}"]
117 | id: null
118 | group: null
119 |
120 | checkpoint:
121 | topk:
122 | monitor_key: test_mean_score
123 | mode: max
124 | k: 5
125 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
126 | save_last_ckpt: True
127 | save_last_snapshot: False
128 |
129 | multi_run:
130 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
131 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
132 |
133 | hydra:
134 | job:
135 | override_dirname: ${name}
136 | run:
137 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
138 | sweep:
139 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
140 | subdir: ${hydra.job.num}
141 |
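Note: unlike the transformer config above, optimizer here carries a _target_, so the workspace can build it with hydra.utils.instantiate and supply the parameters at call time. A self-contained sketch (the Linear model is a stand-in for the real policy network):

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    opt_cfg = OmegaConf.create({
        "_target_": "torch.optim.AdamW",
        "lr": 1.0e-4,
        "betas": [0.95, 0.999],
        "eps": 1.0e-8,
        "weight_decay": 1.0e-6,
    })
    model = torch.nn.Linear(10, 7)  # stand-in for the diffusion policy
    optimizer = instantiate(opt_cfg, params=model.parameters())
    assert isinstance(optimizer, torch.optim.AdamW)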
--------------------------------------------------------------------------------
/equi_diffpo/config/train_diffusion_unet_voxel_abs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_voxel_abs
4 |
5 | name: diff_voxel
6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 16
14 | n_obs_steps: 1
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
21 |
22 | policy:
23 | _target_: equi_diffpo.policy.diffusion_unet_voxel_policy.DiffusionUNetPolicyVoxel
24 |
25 | shape_meta: ${shape_meta}
26 |
27 | noise_scheduler:
28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29 | num_train_timesteps: 100
30 | beta_start: 0.0001
31 | beta_end: 0.02
32 | beta_schedule: squaredcos_cap_v2
33 |     variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
34 | clip_sample: True # required when predict_epsilon=False
35 | prediction_type: epsilon # or sample
36 |
37 | horizon: ${horizon}
38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39 | n_obs_steps: ${n_obs_steps}
40 | num_inference_steps: 100
41 | crop_shape: [58, 58, 58]
42 | # crop_shape: null
43 | diffusion_step_embed_dim: 128
44 | enc_n_hidden: 256
45 | down_dims: [256, 512, 1024]
46 | kernel_size: 5
47 | n_groups: 8
48 | cond_predict_scale: True
49 | rot_aug: False
50 |
51 | # scheduler.step params
52 | # predict_epsilon: True
53 |
54 | ema:
55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56 | update_after_step: 0
57 | inv_gamma: 1.0
58 | power: 0.75
59 | min_value: 0.0
60 | max_value: 0.9999
61 |
62 | dataloader:
63 | batch_size: 64
64 | num_workers: 16
65 | shuffle: True
66 | pin_memory: True
67 | persistent_workers: True
68 |   drop_last: True
69 |
70 | val_dataloader:
71 | batch_size: 64
72 | num_workers: 16
73 | shuffle: False
74 | pin_memory: True
75 | persistent_workers: True
76 |
77 | optimizer:
78 | betas: [0.95, 0.999]
79 | eps: 1.0e-08
80 | learning_rate: 0.0001
81 | weight_decay: 1.0e-06
82 |
83 | training:
84 | device: "cuda:0"
85 | seed: 0
86 | debug: False
87 | resume: True
88 | # optimization
89 | lr_scheduler: cosine
90 | lr_warmup_steps: 500
91 | num_epochs: ${eval:'50000 / ${n_demo}'}
92 | gradient_accumulate_every: 1
93 | # EMA destroys performance when used with BatchNorm
94 | # replace BatchNorm with GroupNorm.
95 | use_ema: True
96 | # training loop control
97 | # in epochs
98 | rollout_every: ${eval:'1000 / ${n_demo}'}
99 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
100 | val_every: 1
101 | sample_every: 5
102 | # steps per epoch
103 | max_train_steps: null
104 | max_val_steps: null
105 | # misc
106 | tqdm_interval_sec: 1.0
107 |
108 | logging:
109 | project: equi_diff_${task_name}_voxel
110 | resume: True
111 | mode: online
112 | name: diff_voxel_${n_demo}
113 | tags: ["${name}", "${task_name}", "${exp_name}"]
114 | id: null
115 | group: null
116 |
117 | checkpoint:
118 | topk:
119 | monitor_key: test_mean_score
120 | mode: max
121 | k: 5
122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123 | save_last_ckpt: True
124 | save_last_snapshot: False
125 |
126 | multi_run:
127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129 |
130 | hydra:
131 | job:
132 | override_dirname: ${name}
133 | run:
134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135 | sweep:
136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137 | subdir: ${hydra.job.num}
138 |
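Note: crop_shape: [58, 58, 58] crops the 64^3 voxel grid declared in the task's shape_meta; the repo's voxel_crop_randomizer.py implements this. A minimal sketch of the operation (illustrative, not the repo's code):

    import torch

    def random_voxel_crop(voxels: torch.Tensor, crop=(58, 58, 58)) -> torch.Tensor:
        # voxels: (C, D, H, W), e.g. (4, 64, 64, 64) as in the voxel shape_meta
        _, d, h, w = voxels.shape
        cd, ch, cw = crop
        i = torch.randint(0, d - cd + 1, (1,)).item()
        j = torch.randint(0, h - ch + 1, (1,)).item()
        k = torch.randint(0, w - cw + 1, (1,)).item()
        return voxels[:, i:i + cd, j:j + ch, k:k + cw]

    out = random_voxel_crop(torch.rand(4, 64, 64, 64))
    assert out.shape == (4, 58, 58, 58)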
--------------------------------------------------------------------------------
/equi_diffpo/config/train_equi_diffusion_unet_abs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_abs
4 |
5 | name: equi_diff
6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 16
14 | n_obs_steps: 2
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_abs.hdf5
21 |
22 | policy:
23 | _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_policy.DiffusionEquiUNetCNNEncPolicy
24 |
25 | shape_meta: ${shape_meta}
26 |
27 | noise_scheduler:
28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29 | num_train_timesteps: 100
30 | beta_start: 0.0001
31 | beta_end: 0.02
32 | beta_schedule: squaredcos_cap_v2
33 |     variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
34 | clip_sample: True # required when predict_epsilon=False
35 | prediction_type: epsilon # or sample
36 |
37 | horizon: ${horizon}
38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39 | n_obs_steps: ${n_obs_steps}
40 | num_inference_steps: 100
41 | crop_shape: [76, 76]
42 | # crop_shape: null
43 | diffusion_step_embed_dim: 128
44 | enc_n_hidden: 128
45 | down_dims: [512, 1024, 2048]
46 | kernel_size: 5
47 | n_groups: 8
48 | cond_predict_scale: True
49 | rot_aug: False
50 |
51 | # scheduler.step params
52 | # predict_epsilon: True
53 |
54 | ema:
55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56 | update_after_step: 0
57 | inv_gamma: 1.0
58 | power: 0.75
59 | min_value: 0.0
60 | max_value: 0.9999
61 |
62 | dataloader:
63 | batch_size: 128
64 | num_workers: 4
65 | shuffle: True
66 | pin_memory: True
67 | persistent_workers: True
68 |   drop_last: True
69 |
70 | val_dataloader:
71 | batch_size: 128
72 | num_workers: 8
73 | shuffle: False
74 | pin_memory: True
75 | persistent_workers: True
76 |
77 | optimizer:
78 | betas: [0.95, 0.999]
79 | eps: 1.0e-08
80 | learning_rate: 0.0001
81 | weight_decay: 1.0e-06
82 |
83 | training:
84 | device: "cuda:0"
85 | seed: 0
86 | debug: False
87 | resume: True
88 | # optimization
89 | lr_scheduler: cosine
90 | lr_warmup_steps: 500
91 | num_epochs: ${eval:'50000 / ${n_demo}'}
92 | gradient_accumulate_every: 1
93 | # EMA destroys performance when used with BatchNorm
94 | # replace BatchNorm with GroupNorm.
95 | use_ema: True
96 | # training loop control
97 | # in epochs
98 | rollout_every: ${eval:'1000 / ${n_demo}'}
99 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
100 | val_every: 1
101 | sample_every: 5
102 | # steps per epoch
103 | max_train_steps: null
104 | max_val_steps: null
105 | # misc
106 | tqdm_interval_sec: 1.0
107 |
108 | logging:
109 | project: diffusion_policy_${task_name}
110 | resume: True
111 | mode: online
112 | name: equidiff_demo${n_demo}
113 | tags: ["${name}", "${task_name}", "${exp_name}"]
114 | id: null
115 | group: null
116 |
117 | checkpoint:
118 | topk:
119 | monitor_key: test_mean_score
120 | mode: max
121 | k: 5
122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123 | save_last_ckpt: True
124 | save_last_snapshot: False
125 |
126 | multi_run:
127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129 |
130 | hydra:
131 | job:
132 | override_dirname: ${name}
133 | run:
134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135 | sweep:
136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137 | subdir: ${hydra.job.num}
138 |
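Note: the ema fields parameterize a warm-up schedule: the decay starts at 0, grows as 1 - (1 + step/inv_gamma)^(-power), and is clamped to [min_value, max_value]. A sketch of the shape (following the usual EMA warm-up formula; the repo's ema_model.py is authoritative):

    def ema_decay(step, update_after_step=0, inv_gamma=1.0, power=0.75,
                  min_value=0.0, max_value=0.9999):
        step = max(0, step - update_after_step - 1)
        if step <= 0:
            return 0.0
        value = 1 - (1 + step / inv_gamma) ** -power
        return max(min_value, min(value, max_value))

    # Decay ramps up quickly, then saturates at the cap:
    # ema_decay(10) ~ 0.82, ema_decay(1000) ~ 0.994, ema_decay(10**6) == 0.9999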
--------------------------------------------------------------------------------
/equi_diffpo/config/train_equi_diffusion_unet_rel.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_rel
4 |
5 | name: equi_diff
6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 16
14 | n_obs_steps: 2
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}.hdf5
21 |
22 | policy:
23 | _target_: equi_diffpo.policy.diffusion_equi_unet_cnn_enc_rel_policy.DiffusionEquiUNetCNNEncRelPolicy
24 |
25 | shape_meta: ${shape_meta}
26 |
27 | noise_scheduler:
28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29 | num_train_timesteps: 100
30 | beta_start: 0.0001
31 | beta_end: 0.02
32 | beta_schedule: squaredcos_cap_v2
33 |     variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
34 | clip_sample: True # required when predict_epsilon=False
35 | prediction_type: epsilon # or sample
36 |
37 | horizon: ${horizon}
38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39 | n_obs_steps: ${n_obs_steps}
40 | num_inference_steps: 100
41 | crop_shape: [76, 76]
42 | # crop_shape: null
43 | diffusion_step_embed_dim: 128
44 | enc_n_hidden: 128
45 | down_dims: [512, 1024, 2048]
46 | kernel_size: 5
47 | n_groups: 8
48 | cond_predict_scale: True
49 |
50 | # scheduler.step params
51 | # predict_epsilon: True
52 |
53 | ema:
54 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
55 | update_after_step: 0
56 | inv_gamma: 1.0
57 | power: 0.75
58 | min_value: 0.0
59 | max_value: 0.9999
60 |
61 | dataloader:
62 | batch_size: 128
63 | num_workers: 4
64 | shuffle: True
65 | pin_memory: True
66 | persistent_workers: True
67 |   drop_last: True
68 |
69 | val_dataloader:
70 | batch_size: 128
71 | num_workers: 8
72 | shuffle: False
73 | pin_memory: True
74 | persistent_workers: True
75 |
76 | optimizer:
77 | betas: [0.95, 0.999]
78 | eps: 1.0e-08
79 | learning_rate: 0.0001
80 | weight_decay: 1.0e-06
81 |
82 | training:
83 | device: "cuda:0"
84 | seed: 0
85 | debug: False
86 | resume: True
87 | # optimization
88 | lr_scheduler: cosine
89 | lr_warmup_steps: 500
90 | num_epochs: ${eval:'50000 / ${n_demo}'}
91 | gradient_accumulate_every: 1
92 | # EMA destroys performance when used with BatchNorm
93 | # replace BatchNorm with GroupNorm.
94 | use_ema: True
95 | # training loop control
96 | # in epochs
97 | rollout_every: ${eval:'1000 / ${n_demo}'}
98 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
99 | val_every: 1
100 | sample_every: 5
101 | # steps per epoch
102 | max_train_steps: null
103 | max_val_steps: null
104 | # misc
105 | tqdm_interval_sec: 1.0
106 |
107 | logging:
108 | project: diffusion_policy_${task_name}_vel
109 | resume: True
110 | mode: online
111 | name: equidiff_demo${n_demo}
112 | tags: ["${name}", "${task_name}", "${exp_name}"]
113 | id: null
114 | group: null
115 |
116 | checkpoint:
117 | topk:
118 | monitor_key: test_mean_score
119 | mode: max
120 | k: 5
121 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
122 | save_last_ckpt: True
123 | save_last_snapshot: False
124 |
125 | multi_run:
126 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
127 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
128 |
129 | hydra:
130 | job:
131 | override_dirname: ${name}
132 | run:
133 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
134 | sweep:
135 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
136 | subdir: ${hydra.job.num}
137 |
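Note: lr_scheduler: cosine with lr_warmup_steps: 500 describes a linear warm-up followed by cosine decay. Whether the workspace builds it through the repo's lr_scheduler.py or through diffusers, the shape is the same; a sketch using the diffusers helper (step counts are illustrative):

    import torch
    from diffusers.optimization import get_scheduler

    optimizer = torch.optim.AdamW(
        [torch.nn.Parameter(torch.zeros(1))], lr=1.0e-4)
    total_steps = 250 * 400  # num_epochs * batches per epoch (illustrative)
    lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=500,            # training.lr_warmup_steps
        num_training_steps=total_steps,  # cosine decay afterwards
    )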
--------------------------------------------------------------------------------
/equi_diffpo/config/train_equi_diffusion_unet_voxel_abs.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_voxel_abs
4 |
5 | name: equi_diff_voxel
6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 16
14 | n_obs_steps: 1
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel_abs.hdf5
21 |
22 | policy:
23 | _target_: equi_diffpo.policy.diffusion_equi_unet_voxel_policy.DiffusionEquiUNetPolicyVoxel
24 |
25 | shape_meta: ${shape_meta}
26 |
27 | noise_scheduler:
28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29 | num_train_timesteps: 100
30 | beta_start: 0.0001
31 | beta_end: 0.02
32 | beta_schedule: squaredcos_cap_v2
33 |     variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
34 | clip_sample: True # required when predict_epsilon=False
35 | prediction_type: epsilon # or sample
36 |
37 | horizon: ${horizon}
38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39 | n_obs_steps: ${n_obs_steps}
40 | num_inference_steps: 100
41 | crop_shape: [58, 58, 58]
42 | # crop_shape: null
43 | diffusion_step_embed_dim: 128
44 | enc_n_hidden: 128
45 | down_dims: [256, 512, 1024]
46 | kernel_size: 5
47 | n_groups: 8
48 | cond_predict_scale: True
49 | rot_aug: True
50 |
51 | # scheduler.step params
52 | # predict_epsilon: True
53 |
54 | ema:
55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56 | update_after_step: 0
57 | inv_gamma: 1.0
58 | power: 0.75
59 | min_value: 0.0
60 | max_value: 0.9999
61 |
62 | dataloader:
63 | batch_size: 64
64 | num_workers: 16
65 | shuffle: True
66 | pin_memory: True
67 | persistent_workers: True
68 |   drop_last: True
69 |
70 | val_dataloader:
71 | batch_size: 64
72 | num_workers: 16
73 | shuffle: False
74 | pin_memory: True
75 | persistent_workers: True
76 |
77 | optimizer:
78 | betas: [0.95, 0.999]
79 | eps: 1.0e-08
80 | learning_rate: 0.0001
81 | weight_decay: 1.0e-06
82 |
83 | training:
84 | device: "cuda:0"
85 | seed: 0
86 | debug: False
87 | resume: True
88 | # optimization
89 | lr_scheduler: cosine
90 | lr_warmup_steps: 500
91 | num_epochs: ${eval:'50000 / ${n_demo}'}
92 | gradient_accumulate_every: 1
93 | # EMA destroys performance when used with BatchNorm
94 | # replace BatchNorm with GroupNorm.
95 | use_ema: True
96 | # training loop control
97 | # in epochs
98 | rollout_every: ${eval:'1000 / ${n_demo}'}
99 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
100 | val_every: 1
101 | sample_every: 5
102 | # steps per epoch
103 | max_train_steps: null
104 | max_val_steps: null
105 | # misc
106 | tqdm_interval_sec: 1.0
107 |
108 | logging:
109 | project: equi_diff_${task_name}_voxel
110 | resume: True
111 | mode: online
112 | name: equi_diff_voxel_${n_demo}
113 | tags: ["${name}", "${task_name}", "${exp_name}"]
114 | id: null
115 | group: null
116 |
117 | checkpoint:
118 | topk:
119 | monitor_key: test_mean_score
120 | mode: max
121 | k: 5
122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123 | save_last_ckpt: True
124 | save_last_snapshot: False
125 |
126 | multi_run:
127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129 |
130 | hydra:
131 | job:
132 | override_dirname: ${name}
133 | run:
134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135 | sweep:
136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137 | subdir: ${hydra.job.num}
138 |
--------------------------------------------------------------------------------
/equi_diffpo/config/train_equi_diffusion_unet_voxel_rel.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - task: mimicgen_voxel_rel
4 |
5 | name: equi_diff_voxel
6 | _target_: equi_diffpo.workspace.train_equi_workspace.TrainEquiWorkspace
7 |
8 | shape_meta: ${task.shape_meta}
9 | exp_name: "default"
10 |
11 | task_name: stack_d1
12 | n_demo: 200
13 | horizon: 16
14 | n_obs_steps: 1
15 | n_action_steps: 8
16 | n_latency_steps: 0
17 | dataset_obs_steps: ${n_obs_steps}
18 | past_action_visible: False
19 | # dataset: equi_diffpo.dataset.robomimic_replay_image_sym_dataset.RobomimicReplayImageSymDataset
20 | dataset_path: data/robomimic/datasets/${task_name}/${task_name}_voxel.hdf5
21 |
22 | policy:
23 | _target_: equi_diffpo.policy.diffusion_equi_unet_voxel_rel_policy.DiffusionEquiUNetRelPolicyVoxel
24 |
25 | shape_meta: ${shape_meta}
26 |
27 | noise_scheduler:
28 | _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
29 | num_train_timesteps: 100
30 | beta_start: 0.0001
31 | beta_end: 0.02
32 | beta_schedule: squaredcos_cap_v2
33 |     variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but it can easily cause NaNs
34 | clip_sample: True # required when predict_epsilon=False
35 | prediction_type: epsilon # or sample
36 |
37 | horizon: ${horizon}
38 | n_action_steps: ${eval:'${n_action_steps}+${n_latency_steps}'}
39 | n_obs_steps: ${n_obs_steps}
40 | num_inference_steps: 100
41 | crop_shape: [58, 58, 58]
42 | # crop_shape: null
43 | diffusion_step_embed_dim: 128
44 | enc_n_hidden: 128
45 | down_dims: [256, 512, 1024]
46 | kernel_size: 5
47 | n_groups: 8
48 | cond_predict_scale: True
49 | rot_aug: True
50 |
51 | # scheduler.step params
52 | # predict_epsilon: True
53 |
54 | ema:
55 | _target_: equi_diffpo.model.diffusion.ema_model.EMAModel
56 | update_after_step: 0
57 | inv_gamma: 1.0
58 | power: 0.75
59 | min_value: 0.0
60 | max_value: 0.9999
61 |
62 | dataloader:
63 | batch_size: 64
64 | num_workers: 16
65 | shuffle: True
66 | pin_memory: True
67 | persistent_workers: True
68 |   drop_last: True
69 |
70 | val_dataloader:
71 | batch_size: 64
72 | num_workers: 16
73 | shuffle: False
74 | pin_memory: True
75 | persistent_workers: True
76 |
77 | optimizer:
78 | betas: [0.95, 0.999]
79 | eps: 1.0e-08
80 | learning_rate: 0.0001
81 | weight_decay: 1.0e-06
82 |
83 | training:
84 | device: "cuda:0"
85 | seed: 0
86 | debug: False
87 | resume: True
88 | # optimization
89 | lr_scheduler: cosine
90 | lr_warmup_steps: 500
91 | num_epochs: ${eval:'50000 / ${n_demo}'}
92 | gradient_accumulate_every: 1
93 | # EMA destroys performance when used with BatchNorm
94 | # replace BatchNorm with GroupNorm.
95 | use_ema: True
96 | # training loop control
97 | # in epochs
98 | rollout_every: ${eval:'1000 / ${n_demo}'}
99 | checkpoint_every: ${eval:'1000 / ${n_demo}'}
100 | val_every: 1
101 | sample_every: 5
102 | # steps per epoch
103 | max_train_steps: null
104 | max_val_steps: null
105 | # misc
106 | tqdm_interval_sec: 1.0
107 |
108 | logging:
109 | project: equi_diff_${task_name}_voxel_rel
110 | resume: True
111 | mode: online
112 | name: equi_diff_voxel_${n_demo}
113 | tags: ["${name}", "${task_name}", "${exp_name}"]
114 | id: null
115 | group: null
116 |
117 | checkpoint:
118 | topk:
119 | monitor_key: test_mean_score
120 | mode: max
121 | k: 5
122 | format_str: 'epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt'
123 | save_last_ckpt: True
124 | save_last_snapshot: False
125 |
126 | multi_run:
127 | run_dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
128 | wandb_name_base: ${now:%Y.%m.%d-%H.%M.%S}_${name}_${task_name}
129 |
130 | hydra:
131 | job:
132 | override_dirname: ${name}
133 | run:
134 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
135 | sweep:
136 | dir: data/outputs/${now:%Y.%m.%d}/${now:%H.%M.%S}_${name}_${task_name}
137 | subdir: ${hydra.job.num}
138 |
--------------------------------------------------------------------------------
/equi_diffpo/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn
5 | from equi_diffpo.model.common.normalizer import LinearNormalizer
6 |
7 | class BaseLowdimDataset(torch.utils.data.Dataset):
8 | def get_validation_dataset(self) -> 'BaseLowdimDataset':
9 | # return an empty dataset by default
10 | return BaseLowdimDataset()
11 |
12 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
13 | raise NotImplementedError()
14 |
15 | def get_all_actions(self) -> torch.Tensor:
16 | raise NotImplementedError()
17 |
18 | def __len__(self) -> int:
19 | return 0
20 |
21 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
22 | """
23 | output:
24 | obs: T, Do
25 | action: T, Da
26 | """
27 | raise NotImplementedError()
28 |
29 |
30 | class BaseImageDataset(torch.utils.data.Dataset):
31 |     def get_validation_dataset(self) -> 'BaseImageDataset':
32 | # return an empty dataset by default
33 | return BaseImageDataset()
34 |
35 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
36 | raise NotImplementedError()
37 |
38 | def get_all_actions(self) -> torch.Tensor:
39 | raise NotImplementedError()
40 |
41 | def __len__(self) -> int:
42 | return 0
43 |
44 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
45 | """
46 | output:
47 | obs:
48 | key: T, *
49 | action: T, Da
50 | """
51 | raise NotImplementedError()
52 |
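Note: concrete datasets override all four methods. A hypothetical minimal subclass showing the contract (ToyImageDataset and its shapes are invented for illustration; the fit call assumes the repo's LinearNormalizer.fit, which takes a dict of tensors):

    from typing import Dict
    import torch
    from equi_diffpo.dataset.base_dataset import BaseImageDataset
    from equi_diffpo.model.common.normalizer import LinearNormalizer

    class ToyImageDataset(BaseImageDataset):
        def __init__(self, n: int = 8):
            self.n = n

        def get_normalizer(self, **kwargs) -> LinearNormalizer:
            normalizer = LinearNormalizer()
            normalizer.fit({'action': self.get_all_actions()})
            return normalizer

        def get_all_actions(self) -> torch.Tensor:
            return torch.zeros(self.n, 7)

        def __len__(self) -> int:
            return self.n

        def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
            return {
                'obs': {'agentview_image': torch.zeros(2, 3, 84, 84)},  # T, C, H, W
                'action': torch.zeros(2, 7),                            # T, Da
            }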
--------------------------------------------------------------------------------
/equi_diffpo/dataset/robomimic_replay_image_sym_dataset.py:
--------------------------------------------------------------------------------
1 | from equi_diffpo.dataset.base_dataset import LinearNormalizer
2 | from equi_diffpo.model.common.normalizer import LinearNormalizer
3 | from equi_diffpo.dataset.robomimic_replay_image_dataset import RobomimicReplayImageDataset, normalizer_from_stat
4 | from equi_diffpo.common.normalize_util import (
5 | robomimic_abs_action_only_symmetric_normalizer_from_stat,
6 | get_range_normalizer_from_stat,
7 | get_range_symmetric_normalizer_from_stat,
8 | get_image_range_normalizer,
9 | get_identity_normalizer_from_stat,
10 | array_to_stats
11 | )
12 | import numpy as np
13 |
14 | class RobomimicReplayImageSymDataset(RobomimicReplayImageDataset):
15 | def __init__(self,
16 | shape_meta: dict,
17 | dataset_path: str,
18 | horizon=1,
19 | pad_before=0,
20 | pad_after=0,
21 | n_obs_steps=None,
22 | abs_action=False,
23 | rotation_rep='rotation_6d', # ignored when abs_action=False
24 | use_legacy_normalizer=False,
25 | use_cache=False,
26 | seed=42,
27 | val_ratio=0.0,
28 | n_demo=100
29 | ):
30 | super().__init__(
31 | shape_meta,
32 | dataset_path,
33 | horizon,
34 | pad_before,
35 | pad_after,
36 | n_obs_steps,
37 | abs_action,
38 | rotation_rep,
39 | use_legacy_normalizer,
40 | use_cache,
41 | seed,
42 | val_ratio,
43 | n_demo
44 | )
45 |
46 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
47 | normalizer = LinearNormalizer()
48 |
49 | # action
50 | stat = array_to_stats(self.replay_buffer['action'])
51 | if self.abs_action:
52 | if stat['mean'].shape[-1] > 10:
53 | # dual arm
54 | raise NotImplementedError
55 | else:
56 | this_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(stat)
57 |
58 | if self.use_legacy_normalizer:
59 | this_normalizer = normalizer_from_stat(stat)
60 | else:
61 | # already normalized
62 | this_normalizer = get_identity_normalizer_from_stat(stat)
63 | normalizer['action'] = this_normalizer
64 |
65 | # obs
66 | for key in self.lowdim_keys:
67 | stat = array_to_stats(self.replay_buffer[key])
68 |
69 | if key.endswith('qpos'):
70 | this_normalizer = get_range_normalizer_from_stat(stat)
71 | elif key.endswith('pos'):
72 | this_normalizer = get_range_symmetric_normalizer_from_stat(stat)
73 | elif key.endswith('quat'):
74 | # quaternion is in [-1,1] already
75 | this_normalizer = get_identity_normalizer_from_stat(stat)
76 | elif key.find('bbox') > -1:
77 | this_normalizer = get_identity_normalizer_from_stat(stat)
78 | else:
79 | raise RuntimeError('unsupported')
80 | normalizer[key] = this_normalizer
81 |
82 | # image
83 | for key in self.rgb_keys:
84 | normalizer[key] = get_image_range_normalizer()
85 |
86 | normalizer['pos_vecs'] = get_identity_normalizer_from_stat({'min': -1 * np.ones([10, 2], np.float32), 'max': np.ones([10, 2], np.float32)})
87 | normalizer['crops'] = get_image_range_normalizer()
88 |
89 | return normalizer
90 |
91 |
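Note: the split between get_range_normalizer_from_stat (gripper qpos) and get_range_symmetric_normalizer_from_stat (eef pos) matters for equivariance: a symmetric normalizer keeps a zero offset so the workspace origin is preserved under rotation. It presumably behaves along the lines of normalizer_from_stat in robomimic_replay_lowdim_dataset.py below; an illustrative sketch:

    import numpy as np

    def symmetric_scale(stat):
        # Zero offset: x_norm = x * scale, so sign-symmetric quantities such
        # as eef positions stay centered and rotation augmentation is consistent.
        max_abs = np.maximum(np.abs(stat['min']), np.abs(stat['max'])).max()
        return 1.0 / max_abs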
--------------------------------------------------------------------------------
/equi_diffpo/dataset/robomimic_replay_lowdim_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | import torch
3 | import numpy as np
4 | import h5py
5 | from tqdm import tqdm
6 | import copy
7 | from equi_diffpo.common.pytorch_util import dict_apply
8 | from equi_diffpo.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer
9 | from equi_diffpo.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
10 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer
11 | from equi_diffpo.common.replay_buffer import ReplayBuffer
12 | from equi_diffpo.common.sampler import (
13 | SequenceSampler, get_val_mask, downsample_mask)
14 | from equi_diffpo.common.normalize_util import (
15 | robomimic_abs_action_only_normalizer_from_stat,
16 | robomimic_abs_action_only_dual_arm_normalizer_from_stat,
17 | get_identity_normalizer_from_stat,
18 | array_to_stats
19 | )
20 |
21 | class RobomimicReplayLowdimDataset(BaseLowdimDataset):
22 | def __init__(self,
23 | dataset_path: str,
24 | horizon=1,
25 | pad_before=0,
26 | pad_after=0,
27 | obs_keys: List[str]=[
28 | 'object',
29 | 'robot0_eef_pos',
30 | 'robot0_eef_quat',
31 | 'robot0_gripper_qpos'],
32 | abs_action=False,
33 | rotation_rep='rotation_6d',
34 | use_legacy_normalizer=False,
35 | seed=42,
36 | val_ratio=0.0,
37 | max_train_episodes=None,
38 | n_demo=100
39 | ):
40 | obs_keys = list(obs_keys)
41 | rotation_transformer = RotationTransformer(
42 | from_rep='axis_angle', to_rep=rotation_rep)
43 |
44 | replay_buffer = ReplayBuffer.create_empty_numpy()
45 | with h5py.File(dataset_path) as file:
46 | demos = file['data']
47 | for i in tqdm(range(n_demo), desc="Loading hdf5 to ReplayBuffer"):
48 | demo = demos[f'demo_{i}']
49 | episode = _data_to_obs(
50 | raw_obs=demo['obs'],
51 | raw_actions=demo['actions'][:].astype(np.float32),
52 | obs_keys=obs_keys,
53 | abs_action=abs_action,
54 | rotation_transformer=rotation_transformer)
55 | replay_buffer.add_episode(episode)
56 |
57 | val_mask = get_val_mask(
58 | n_episodes=replay_buffer.n_episodes,
59 | val_ratio=val_ratio,
60 | seed=seed)
61 | train_mask = ~val_mask
62 | train_mask = downsample_mask(
63 | mask=train_mask,
64 | max_n=max_train_episodes,
65 | seed=seed)
66 |
67 | sampler = SequenceSampler(
68 | replay_buffer=replay_buffer,
69 | sequence_length=horizon,
70 | pad_before=pad_before,
71 | pad_after=pad_after,
72 | episode_mask=train_mask)
73 |
74 | self.replay_buffer = replay_buffer
75 | self.sampler = sampler
76 | self.abs_action = abs_action
77 | self.train_mask = train_mask
78 | self.horizon = horizon
79 | self.pad_before = pad_before
80 | self.pad_after = pad_after
81 | self.use_legacy_normalizer = use_legacy_normalizer
82 |
83 | def get_validation_dataset(self):
84 | val_set = copy.copy(self)
85 | val_set.sampler = SequenceSampler(
86 | replay_buffer=self.replay_buffer,
87 | sequence_length=self.horizon,
88 | pad_before=self.pad_before,
89 | pad_after=self.pad_after,
90 | episode_mask=~self.train_mask
91 | )
92 | val_set.train_mask = ~self.train_mask
93 | return val_set
94 |
95 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
96 | normalizer = LinearNormalizer()
97 |
98 | # action
99 | stat = array_to_stats(self.replay_buffer['action'])
100 | if self.abs_action:
101 | if stat['mean'].shape[-1] > 10:
102 | # dual arm
103 | this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
104 | else:
105 | this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)
106 |
107 | if self.use_legacy_normalizer:
108 | this_normalizer = normalizer_from_stat(stat)
109 | else:
110 | # already normalized
111 | this_normalizer = get_identity_normalizer_from_stat(stat)
112 | normalizer['action'] = this_normalizer
113 |
114 | # aggregate obs stats
115 | obs_stat = array_to_stats(self.replay_buffer['obs'])
116 |
117 |
118 | normalizer['obs'] = normalizer_from_stat(obs_stat)
119 | return normalizer
120 |
121 | def get_all_actions(self) -> torch.Tensor:
122 | return torch.from_numpy(self.replay_buffer['action'])
123 |
124 | def __len__(self):
125 | return len(self.sampler)
126 |
127 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
128 | data = self.sampler.sample_sequence(idx)
129 | torch_data = dict_apply(data, torch.from_numpy)
130 | return torch_data
131 |
132 | def normalizer_from_stat(stat):
133 | max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
134 | scale = np.full_like(stat['max'], fill_value=1/max_abs)
135 | offset = np.zeros_like(stat['max'])
136 | return SingleFieldLinearNormalizer.create_manual(
137 | scale=scale,
138 | offset=offset,
139 | input_stats_dict=stat
140 | )
141 |
142 | def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
143 | obs = np.concatenate([
144 | raw_obs[key] for key in obs_keys
145 | ], axis=-1).astype(np.float32)
146 |
147 | if abs_action:
148 | is_dual_arm = False
149 | if raw_actions.shape[-1] == 14:
150 | # dual arm
151 | raw_actions = raw_actions.reshape(-1,2,7)
152 | is_dual_arm = True
153 |
154 | pos = raw_actions[...,:3]
155 | rot = raw_actions[...,3:6]
156 | gripper = raw_actions[...,6:]
157 | rot = rotation_transformer.forward(rot)
158 | raw_actions = np.concatenate([
159 | pos, rot, gripper
160 | ], axis=-1).astype(np.float32)
161 |
162 | if is_dual_arm:
163 | raw_actions = raw_actions.reshape(-1,20)
164 |
165 | data = {
166 | 'obs': obs,
167 | 'action': raw_actions
168 | }
169 | return data
170 |
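Note: with abs_action=True, _data_to_obs widens each 7-dim action (3 position + 3 axis-angle + 1 gripper) to 10 dims (3 + 6 + 1) via the 6D rotation representation. A quick shape check using the same transformer (zeros stand in for real data):

    import numpy as np
    from equi_diffpo.model.common.rotation_transformer import RotationTransformer

    tf = RotationTransformer(from_rep='axis_angle', to_rep='rotation_6d')
    raw_actions = np.zeros((100, 7), dtype=np.float32)  # pos, axis-angle, gripper
    rot6d = tf.forward(raw_actions[..., 3:6])           # (100, 3) -> (100, 6)
    actions_10d = np.concatenate(
        [raw_actions[..., :3], rot6d, raw_actions[..., 6:]], axis=-1)
    assert actions_10d.shape == (100, 10)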
--------------------------------------------------------------------------------
/equi_diffpo/dataset/robomimic_replay_lowdim_sym_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | import torch
3 | import numpy as np
4 | from equi_diffpo.common.pytorch_util import dict_apply
5 | from equi_diffpo.dataset.base_dataset import LinearNormalizer
6 | from equi_diffpo.dataset.robomimic_replay_lowdim_dataset import RobomimicReplayLowdimDataset, normalizer_from_stat
7 |
8 | from equi_diffpo.common.normalize_util import (
9 | robomimic_abs_action_only_symmetric_normalizer_from_stat,
10 | get_identity_normalizer_from_stat,
11 | array_to_stats
12 | )
13 |
14 |
15 | class RobomimicReplayLowdimSymDataset(RobomimicReplayLowdimDataset):
16 | def __init__(self,
17 | dataset_path: str,
18 | horizon=1,
19 | pad_before=0,
20 | pad_after=0,
21 | obs_keys: List[str]=[
22 | 'object',
23 | 'robot0_eef_pos',
24 | 'robot0_eef_quat',
25 | 'robot0_gripper_qpos'],
26 | abs_action=False,
27 | rotation_rep='rotation_6d',
28 | use_legacy_normalizer=False,
29 | seed=42,
30 | val_ratio=0.0,
31 | max_train_episodes=None,
32 | n_demo=100
33 | ):
34 | super().__init__(
35 | dataset_path,
36 | horizon,
37 | pad_before,
38 | pad_after,
39 | obs_keys,
40 | abs_action,
41 | rotation_rep,
42 | use_legacy_normalizer,
43 | seed,
44 | val_ratio,
45 | max_train_episodes,
46 | n_demo,
47 | )
48 |
49 | def get_normalizer(self, **kwargs) -> LinearNormalizer:
50 | normalizer = LinearNormalizer()
51 |
52 | # action
53 | stat = array_to_stats(self.replay_buffer['action'])
54 | if self.abs_action:
55 | if stat['mean'].shape[-1] > 10:
56 | # dual arm
57 | raise NotImplementedError
58 | else:
59 | this_normalizer = robomimic_abs_action_only_symmetric_normalizer_from_stat(stat)
60 |
61 | if self.use_legacy_normalizer:
62 | this_normalizer = normalizer_from_stat(stat)
63 | else:
64 | # already normalized
65 | this_normalizer = get_identity_normalizer_from_stat(stat)
66 | normalizer['action'] = this_normalizer
67 |
68 | # aggregate obs stats
69 | obs_stat = array_to_stats(self.replay_buffer['obs'])
70 |
71 |
72 | normalizer['obs'] = normalizer_from_stat(obs_stat)
73 | return normalizer
74 |
--------------------------------------------------------------------------------
/equi_diffpo/env/robomimic/robomimic_image_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import numpy as np
4 | import gym
5 | from gym import spaces
6 | from omegaconf import OmegaConf
7 | from robomimic.envs.env_robosuite import EnvRobosuite
8 |
9 | class RobomimicImageWrapper(gym.Env):
10 | def __init__(self,
11 | env: EnvRobosuite,
12 | shape_meta: dict,
13 | init_state: Optional[np.ndarray]=None,
14 | render_obs_key='agentview_image',
15 | ):
16 |
17 | self.env = env
18 | self.render_obs_key = render_obs_key
19 | self.init_state = init_state
20 | self.seed_state_map = dict()
21 | self._seed = None
22 | self.shape_meta = shape_meta
23 | self.render_cache = None
24 | self.has_reset_before = False
25 |
26 | # setup spaces
27 | action_shape = shape_meta['action']['shape']
28 | action_space = spaces.Box(
29 | low=-1,
30 | high=1,
31 | shape=action_shape,
32 | dtype=np.float32
33 | )
34 | self.action_space = action_space
35 |
36 | observation_space = spaces.Dict()
37 | for key, value in shape_meta['obs'].items():
38 | shape = value['shape']
39 | min_value, max_value = -1, 1
40 | if key.endswith('image'):
41 | min_value, max_value = 0, 1
42 | elif key.endswith('depth'):
43 | min_value, max_value = 0, 1
44 | elif key.endswith('voxels'):
45 | min_value, max_value = 0, 1
46 | elif key.endswith('point_cloud'):
47 | min_value, max_value = -10, 10
48 | elif key.endswith('quat'):
49 | min_value, max_value = -1, 1
50 | elif key.endswith('qpos'):
51 | min_value, max_value = -1, 1
52 | elif key.endswith('pos'):
53 | # better range?
54 | min_value, max_value = -1, 1
55 | else:
56 | raise RuntimeError(f"Unsupported type {key}")
57 |
58 | this_space = spaces.Box(
59 | low=min_value,
60 | high=max_value,
61 | shape=shape,
62 | dtype=np.float32
63 | )
64 | observation_space[key] = this_space
65 | self.observation_space = observation_space
66 |
67 |
68 | def get_observation(self, raw_obs=None):
69 | if raw_obs is None:
70 | raw_obs = self.env.get_observation()
71 |
72 | self.render_cache = raw_obs[self.render_obs_key]
73 |
74 | obs = dict()
75 | for key in self.observation_space.keys():
76 | obs[key] = raw_obs[key]
77 | return obs
78 |
79 | def seed(self, seed=None):
80 | np.random.seed(seed=seed)
81 | self._seed = seed
82 |
83 | def reset(self):
84 | if self.init_state is not None:
85 | if not self.has_reset_before:
86 | # the env must be fully reset at least once to ensure correct rendering
87 | self.env.reset()
88 | self.has_reset_before = True
89 |
90 | # always reset to the same state
91 | # to be compatible with gym
92 | raw_obs = self.env.reset_to({'states': self.init_state})
93 | elif self._seed is not None:
94 | # reset to a specific seed
95 | seed = self._seed
96 | if seed in self.seed_state_map:
97 | # env.reset is expensive, use cache
98 | raw_obs = self.env.reset_to({'states': self.seed_state_map[seed]})
99 | else:
100 |                 # robosuite's initialization uses the numpy global random state
101 | np.random.seed(seed=seed)
102 | raw_obs = self.env.reset()
103 | state = self.env.get_state()['states']
104 | self.seed_state_map[seed] = state
105 | self._seed = None
106 | else:
107 | # random reset
108 | raw_obs = self.env.reset()
109 |
110 | # return obs
111 | obs = self.get_observation(raw_obs)
112 | return obs
113 |
114 | def step(self, action):
115 | raw_obs, reward, done, info = self.env.step(action)
116 | obs = self.get_observation(raw_obs)
117 | return obs, reward, done, info
118 |
119 | def render(self, mode='rgb_array'):
120 | if self.render_cache is None:
121 | raise RuntimeError('Must run reset or step before render.')
122 | img = np.moveaxis(self.render_cache, 0, -1)
123 | img = (img * 255).astype(np.uint8)
124 | return img
125 |
126 |
127 | def test():
128 | import os
129 | from omegaconf import OmegaConf
130 | cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
131 | cfg = OmegaConf.load(cfg_path)
132 | shape_meta = cfg['shape_meta']
133 |
134 |
135 | import robomimic.utils.file_utils as FileUtils
136 | import robomimic.utils.env_utils as EnvUtils
137 | from matplotlib import pyplot as plt
138 |
139 | dataset_path = os.path.expanduser('~/dev/diffusion_policy/data/robomimic/datasets/square/ph/image.hdf5')
140 | env_meta = FileUtils.get_env_metadata_from_dataset(
141 | dataset_path)
142 |
143 | env = EnvUtils.create_env_from_metadata(
144 | env_meta=env_meta,
145 | render=False,
146 | render_offscreen=False,
147 | use_image_obs=True,
148 | )
149 |
150 | wrapper = RobomimicImageWrapper(
151 | env=env,
152 | shape_meta=shape_meta
153 | )
154 | wrapper.seed(0)
155 | obs = wrapper.reset()
156 | img = wrapper.render()
157 | plt.imshow(img)
158 |
159 |
160 | # states = list()
161 | # for _ in range(2):
162 | # wrapper.seed(0)
163 | # wrapper.reset()
164 | # states.append(wrapper.env.get_state()['states'])
165 | # assert np.allclose(states[0], states[1])
166 |
167 | # img = wrapper.render()
168 | # plt.imshow(img)
169 | # wrapper.seed()
170 | # states.append(wrapper.env.get_state()['states'])
171 |
--------------------------------------------------------------------------------
/equi_diffpo/env/robomimic/robomimic_lowdim_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional
2 | import numpy as np
3 | import gym
4 | from gym.spaces import Box
5 | from robomimic.envs.env_robosuite import EnvRobosuite
6 |
7 | class RobomimicLowdimWrapper(gym.Env):
8 | def __init__(self,
9 | env: EnvRobosuite,
10 | obs_keys: List[str]=[
11 | 'object',
12 | 'robot0_eef_pos',
13 | 'robot0_eef_quat',
14 | 'robot0_gripper_qpos'],
15 | init_state: Optional[np.ndarray]=None,
16 | render_hw=(256,256),
17 | render_camera_name='agentview'
18 | ):
19 |
20 | self.env = env
21 | self.obs_keys = obs_keys
22 | self.init_state = init_state
23 | self.render_hw = render_hw
24 | self.render_camera_name = render_camera_name
25 | self.seed_state_map = dict()
26 | self._seed = None
27 |
28 | # setup spaces
29 | low = np.full(env.action_dimension, fill_value=-1)
30 | high = np.full(env.action_dimension, fill_value=1)
31 | self.action_space = Box(
32 | low=low,
33 | high=high,
34 | shape=low.shape,
35 | dtype=low.dtype
36 | )
37 | obs_example = self.get_observation()
38 | low = np.full_like(obs_example, fill_value=-1)
39 | high = np.full_like(obs_example, fill_value=1)
40 | self.observation_space = Box(
41 | low=low,
42 | high=high,
43 | shape=low.shape,
44 | dtype=low.dtype
45 | )
46 |
47 | def get_observation(self):
48 | raw_obs = self.env.get_observation()
49 | obs = np.concatenate([
50 | raw_obs[key] for key in self.obs_keys
51 | ], axis=0)
52 | return obs
53 |
54 | def seed(self, seed=None):
55 | np.random.seed(seed=seed)
56 | self._seed = seed
57 |
58 | def reset(self):
59 | if self.init_state is not None:
60 | # always reset to the same state
61 | # to be compatible with gym
62 | self.env.reset_to({'states': self.init_state})
63 | elif self._seed is not None:
64 | # reset to a specific seed
65 | seed = self._seed
66 | if seed in self.seed_state_map:
67 | # env.reset is expensive, use cache
68 | self.env.reset_to({'states': self.seed_state_map[seed]})
69 | else:
70 |                 # robosuite's initialization uses the numpy global random state
71 | np.random.seed(seed=seed)
72 | self.env.reset()
73 | state = self.env.get_state()['states']
74 | self.seed_state_map[seed] = state
75 | self._seed = None
76 | else:
77 | # random reset
78 | self.env.reset()
79 |
80 | # return obs
81 | obs = self.get_observation()
82 | return obs
83 |
84 | def step(self, action):
85 | raw_obs, reward, done, info = self.env.step(action)
86 | obs = np.concatenate([
87 | raw_obs[key] for key in self.obs_keys
88 | ], axis=0)
89 | return obs, reward, done, info
90 |
91 | def render(self, mode='rgb_array'):
92 | h, w = self.render_hw
93 | return self.env.render(mode=mode,
94 | height=h, width=w,
95 | camera_name=self.render_camera_name)
96 |
97 |
98 | def test():
99 | import robomimic.utils.file_utils as FileUtils
100 | import robomimic.utils.env_utils as EnvUtils
101 | from matplotlib import pyplot as plt
102 |
103 | dataset_path = '/home/cchi/dev/diffusion_policy/data/robomimic/datasets/square/ph/low_dim.hdf5'
104 | env_meta = FileUtils.get_env_metadata_from_dataset(
105 | dataset_path)
106 |
107 | env = EnvUtils.create_env_from_metadata(
108 | env_meta=env_meta,
109 | render=False,
110 | render_offscreen=False,
111 | use_image_obs=False,
112 | )
113 | wrapper = RobomimicLowdimWrapper(
114 | env=env,
115 | obs_keys=[
116 | 'object',
117 | 'robot0_eef_pos',
118 | 'robot0_eef_quat',
119 | 'robot0_gripper_qpos'
120 | ]
121 | )
122 |
123 | states = list()
124 | for _ in range(2):
125 | wrapper.seed(0)
126 | wrapper.reset()
127 | states.append(wrapper.env.get_state()['states'])
128 | assert np.allclose(states[0], states[1])
129 |
130 | img = wrapper.render()
131 | plt.imshow(img)
132 | # wrapper.seed()
133 | # states.append(wrapper.env.get_state()['states'])
134 |
--------------------------------------------------------------------------------
/equi_diffpo/env_runner/base_image_runner.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from equi_diffpo.policy.base_image_policy import BaseImagePolicy
3 |
4 | class BaseImageRunner:
5 | def __init__(self, output_dir):
6 | self.output_dir = output_dir
7 |
8 | def run(self, policy: BaseImagePolicy) -> Dict:
9 | raise NotImplementedError()
10 |
--------------------------------------------------------------------------------
/equi_diffpo/env_runner/base_lowdim_runner.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from equi_diffpo.policy.base_lowdim_policy import BaseLowdimPolicy
3 |
4 | class BaseLowdimRunner:
5 | def __init__(self, output_dir):
6 | self.output_dir = output_dir
7 |
8 | def run(self, policy: BaseLowdimPolicy) -> Dict:
9 | raise NotImplementedError()
10 |
--------------------------------------------------------------------------------
/equi_diffpo/gym_util/multistep_wrapper.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import spaces
3 | import numpy as np
4 | from collections import defaultdict, deque
5 | import dill
6 |
7 | def stack_repeated(x, n):
8 | return np.repeat(np.expand_dims(x,axis=0),n,axis=0)
9 |
10 | def repeated_box(box_space, n):
11 | return spaces.Box(
12 | low=stack_repeated(box_space.low, n),
13 | high=stack_repeated(box_space.high, n),
14 | shape=(n,) + box_space.shape,
15 | dtype=box_space.dtype
16 | )
17 |
18 | def repeated_space(space, n):
19 | if isinstance(space, spaces.Box):
20 | return repeated_box(space, n)
21 | elif isinstance(space, spaces.Dict):
22 | result_space = spaces.Dict()
23 | for key, value in space.items():
24 | result_space[key] = repeated_space(value, n)
25 | return result_space
26 | else:
27 | raise RuntimeError(f'Unsupported space type {type(space)}')
28 |
29 | def take_last_n(x, n):
30 | x = list(x)
31 | n = min(len(x), n)
32 | return np.array(x[-n:])
33 |
34 | def dict_take_last_n(x, n):
35 | result = dict()
36 | for key, value in x.items():
37 | result[key] = take_last_n(value, n)
38 | return result
39 |
40 | def aggregate(data, method='max'):
41 | if method == 'max':
42 | # equivalent to any
43 | return np.max(data)
44 | elif method == 'min':
45 | # equivalent to all
46 | return np.min(data)
47 | elif method == 'mean':
48 | return np.mean(data)
49 | elif method == 'sum':
50 | return np.sum(data)
51 | else:
52 | raise NotImplementedError()
53 |
54 | def stack_last_n_obs(all_obs, n_steps):
55 | assert(len(all_obs) > 0)
56 | all_obs = list(all_obs)
57 | result = np.zeros((n_steps,) + all_obs[-1].shape,
58 | dtype=all_obs[-1].dtype)
59 | start_idx = -min(n_steps, len(all_obs))
60 | result[start_idx:] = np.array(all_obs[start_idx:])
61 | if n_steps > len(all_obs):
62 | # pad
63 | result[:start_idx] = result[start_idx]
64 | return result
65 |
66 |
67 | class MultiStepWrapper(gym.Wrapper):
68 | def __init__(self,
69 | env,
70 | n_obs_steps,
71 | n_action_steps,
72 | max_episode_steps=None,
73 | reward_agg_method='max'
74 | ):
75 | super().__init__(env)
76 | self._action_space = repeated_space(env.action_space, n_action_steps)
77 | self._observation_space = repeated_space(env.observation_space, n_obs_steps)
78 | self.max_episode_steps = max_episode_steps
79 | self.n_obs_steps = n_obs_steps
80 | self.n_action_steps = n_action_steps
81 | self.reward_agg_method = reward_agg_method
 82 | 
83 |
84 | self.obs = deque(maxlen=n_obs_steps+1)
85 | self.reward = list()
86 | self.done = list()
87 | self.info = defaultdict(lambda : deque(maxlen=n_obs_steps+1))
88 |
89 | def reset(self):
 90 |         """Resets the environment and clears the step buffers."""
91 | obs = super().reset()
92 |
93 | self.obs = deque([obs], maxlen=self.n_obs_steps+1)
94 | self.reward = list()
95 | self.done = list()
96 | self.info = defaultdict(lambda : deque(maxlen=self.n_obs_steps+1))
97 |
98 | obs = self._get_obs(self.n_obs_steps)
99 | return obs
100 |
101 | def step(self, action):
102 | """
103 | actions: (n_action_steps,) + action_shape
104 | """
105 | for act in action:
106 | if len(self.done) > 0 and self.done[-1]:
107 | # termination
108 | break
109 | observation, reward, done, info = super().step(act)
110 |
111 | self.obs.append(observation)
112 | self.reward.append(reward)
113 | if (self.max_episode_steps is not None) \
114 | and (len(self.reward) >= self.max_episode_steps):
115 | # truncation
116 | done = True
117 | self.done.append(done)
118 | self._add_info(info)
119 |
120 | observation = self._get_obs(self.n_obs_steps)
121 | reward = aggregate(self.reward, self.reward_agg_method)
122 | done = aggregate(self.done, 'max')
123 | info = dict_take_last_n(self.info, self.n_obs_steps)
124 | return observation, reward, done, info
125 |
126 | def _get_obs(self, n_steps=1):
127 | """
128 | Output (n_steps,) + obs_shape
129 | """
130 | assert(len(self.obs) > 0)
131 | if isinstance(self.observation_space, spaces.Box):
132 | return stack_last_n_obs(self.obs, n_steps)
133 | elif isinstance(self.observation_space, spaces.Dict):
134 | result = dict()
135 | for key in self.observation_space.keys():
136 | result[key] = stack_last_n_obs(
137 | [obs[key] for obs in self.obs],
138 | n_steps
139 | )
140 | return result
141 | else:
142 | raise RuntimeError('Unsupported space type')
143 |
144 | def _add_info(self, info):
145 | for key, value in info.items():
146 | self.info[key].append(value)
147 |
148 | def get_rewards(self):
149 | return self.reward
150 |
151 | def get_attr(self, name):
152 | return getattr(self, name)
153 |
154 | def run_dill_function(self, dill_fn):
155 | fn = dill.loads(dill_fn)
156 | return fn(self)
157 |
158 | def get_infos(self):
159 | result = dict()
160 | for k, v in self.info.items():
161 | result[k] = list(v)
162 | return result
163 |
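A minimal usage sketch (illustrative, not part of the file; assumes a Box-observation env with the classic gym API, e.g. Pendulum-v1):

    import gym
    from equi_diffpo.gym_util.multistep_wrapper import MultiStepWrapper

    env = MultiStepWrapper(
        gym.make('Pendulum-v1'),
        n_obs_steps=2, n_action_steps=8, max_episode_steps=200)
    obs = env.reset()                    # (2, 3): the last 2 observations, stacked
    action = env.action_space.sample()   # (8, 1): 8 low-level actions per call
    obs, reward, done, info = env.step(action)  # reward aggregated with 'max' over the episode so far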
--------------------------------------------------------------------------------
/equi_diffpo/gym_util/video_wrapper.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 |
4 | class VideoWrapper(gym.Wrapper):
5 | def __init__(self,
6 | env,
7 | mode='rgb_array',
8 | enabled=True,
9 | steps_per_render=1,
10 | **kwargs
11 | ):
12 | super().__init__(env)
13 |
14 | self.mode = mode
15 | self.enabled = enabled
16 | self.render_kwargs = kwargs
17 | self.steps_per_render = steps_per_render
18 |
19 | self.frames = list()
20 | self.step_count = 0
21 |
22 | def reset(self, **kwargs):
23 | obs = super().reset(**kwargs)
24 | self.frames = list()
25 | self.step_count = 1
26 | if self.enabled:
27 | frame = self.env.render(
28 | mode=self.mode, **self.render_kwargs)
29 | assert frame.dtype == np.uint8
30 | self.frames.append(frame)
31 | return obs
32 |
33 | def step(self, action):
34 | result = super().step(action)
35 | self.step_count += 1
36 | if self.enabled and ((self.step_count % self.steps_per_render) == 0):
37 | frame = self.env.render(
38 | mode=self.mode, **self.render_kwargs)
39 | assert frame.dtype == np.uint8
40 | self.frames.append(frame)
41 | return result
42 |
43 | def render(self, mode='rgb_array', **kwargs):
44 | return self.frames
45 |
--------------------------------------------------------------------------------
/equi_diffpo/model/common/dict_of_tensor_mixin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class DictOfTensorMixin(nn.Module):
5 | def __init__(self, params_dict=None):
6 | super().__init__()
7 | if params_dict is None:
8 | params_dict = nn.ParameterDict()
9 | self.params_dict = params_dict
10 |
11 | @property
12 | def device(self):
13 | return next(iter(self.parameters())).device
14 |
15 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
16 | def dfs_add(dest, keys, value: torch.Tensor):
17 | if len(keys) == 1:
18 | dest[keys[0]] = value
19 | return
20 |
21 | if keys[0] not in dest:
22 | dest[keys[0]] = nn.ParameterDict()
23 | dfs_add(dest[keys[0]], keys[1:], value)
24 |
25 | def load_dict(state_dict, prefix):
26 | out_dict = nn.ParameterDict()
27 | for key, value in state_dict.items():
28 | value: torch.Tensor
29 | if key.startswith(prefix):
30 | param_keys = key[len(prefix):].split('.')[1:]
31 | # if len(param_keys) == 0:
32 | # import pdb; pdb.set_trace()
33 | dfs_add(out_dict, param_keys, value.clone())
34 | return out_dict
35 |
36 | self.params_dict = load_dict(state_dict, prefix + 'params_dict')
37 | self.params_dict.requires_grad_(False)
38 | return
39 |
--------------------------------------------------------------------------------
/equi_diffpo/model/common/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from diffusers.optimization import (
2 | Union, SchedulerType, Optional,
3 | Optimizer, TYPE_TO_SCHEDULER_FUNCTION
4 | )
5 |
6 | def get_scheduler(
7 | name: Union[str, SchedulerType],
8 | optimizer: Optimizer,
9 | num_warmup_steps: Optional[int] = None,
10 | num_training_steps: Optional[int] = None,
11 | **kwargs
12 | ):
13 | """
 14 |     Added kwargs vs. the diffusers library's original implementation.
15 |
16 | Unified API to get any scheduler from its name.
17 |
18 | Args:
19 | name (`str` or `SchedulerType`):
20 | The name of the scheduler to use.
21 | optimizer (`torch.optim.Optimizer`):
22 | The optimizer that will be used during training.
23 | num_warmup_steps (`int`, *optional*):
24 | The number of warmup steps to do. This is not required by all schedulers (hence the argument being
25 | optional), the function will raise an error if it's unset and the scheduler type requires it.
 26 |         num_training_steps (`int`, *optional*):
27 | The number of training steps to do. This is not required by all schedulers (hence the argument being
28 | optional), the function will raise an error if it's unset and the scheduler type requires it.
29 | """
30 | name = SchedulerType(name)
31 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
32 | if name == SchedulerType.CONSTANT:
33 | return schedule_func(optimizer, **kwargs)
34 |
35 | # All other schedulers require `num_warmup_steps`
36 | if num_warmup_steps is None:
37 | raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
38 |
39 | if name == SchedulerType.CONSTANT_WITH_WARMUP:
40 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)
41 |
42 | # All other schedulers require `num_training_steps`
43 | if num_training_steps is None:
44 | raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
45 |
46 | return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs)
47 |
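A usage sketch (illustrative; 'cosine' is one of the diffusers SchedulerType values):

    import torch
    from equi_diffpo.model.common.lr_scheduler import get_scheduler

    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    sched = get_scheduler(
        'cosine', optimizer=opt,
        num_warmup_steps=500, num_training_steps=100_000)
    for _ in range(3):
        opt.step()
        sched.step()   # linear warmup, then cosine decay of the learning rate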
--------------------------------------------------------------------------------
/equi_diffpo/model/common/module_attr_mixin.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | class ModuleAttrMixin(nn.Module):
4 | def __init__(self):
5 | super().__init__()
6 | self._dummy_variable = nn.Parameter()
7 |
8 | @property
9 | def device(self):
10 | return next(iter(self.parameters())).device
11 |
12 | @property
13 | def dtype(self):
14 | return next(iter(self.parameters())).dtype
15 |
--------------------------------------------------------------------------------
/equi_diffpo/model/common/rotation_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import pytorch3d.transforms as pt
3 | import torch
4 | import numpy as np
5 | import functools
6 |
7 | class RotationTransformer:
8 | valid_reps = [
9 | 'axis_angle',
10 | 'euler_angles',
11 | 'quaternion',
12 | 'rotation_6d',
13 | 'matrix'
14 | ]
15 |
16 | def __init__(self,
17 | from_rep='axis_angle',
18 | to_rep='rotation_6d',
19 | from_convention=None,
20 | to_convention=None):
21 | """
22 | Valid representations
23 |
24 | Always use matrix as intermediate representation.
25 | """
26 | assert from_rep != to_rep
27 | assert from_rep in self.valid_reps
28 | assert to_rep in self.valid_reps
29 | if from_rep == 'euler_angles':
30 | assert from_convention is not None
31 | if to_rep == 'euler_angles':
32 | assert to_convention is not None
33 |
34 | forward_funcs = list()
35 | inverse_funcs = list()
36 |
37 | if from_rep != 'matrix':
38 | funcs = [
39 | getattr(pt, f'{from_rep}_to_matrix'),
40 | getattr(pt, f'matrix_to_{from_rep}')
41 | ]
42 | if from_convention is not None:
43 | funcs = [functools.partial(func, convention=from_convention)
44 | for func in funcs]
45 | forward_funcs.append(funcs[0])
46 | inverse_funcs.append(funcs[1])
47 |
48 | if to_rep != 'matrix':
49 | funcs = [
50 | getattr(pt, f'matrix_to_{to_rep}'),
51 | getattr(pt, f'{to_rep}_to_matrix')
52 | ]
53 | if to_convention is not None:
54 | funcs = [functools.partial(func, convention=to_convention)
55 | for func in funcs]
56 | forward_funcs.append(funcs[0])
57 | inverse_funcs.append(funcs[1])
58 |
59 | inverse_funcs = inverse_funcs[::-1]
60 |
61 | self.forward_funcs = forward_funcs
62 | self.inverse_funcs = inverse_funcs
63 |
64 | @staticmethod
65 | def _apply_funcs(x: Union[np.ndarray, torch.Tensor], funcs: list) -> Union[np.ndarray, torch.Tensor]:
66 | x_ = x
67 | if isinstance(x, np.ndarray):
68 | x_ = torch.from_numpy(x)
69 | x_: torch.Tensor
70 | for func in funcs:
71 | x_ = func(x_)
72 | y = x_
73 | if isinstance(x, np.ndarray):
74 | y = x_.numpy()
75 | return y
76 |
77 | def forward(self, x: Union[np.ndarray, torch.Tensor]
78 | ) -> Union[np.ndarray, torch.Tensor]:
79 | return self._apply_funcs(x, self.forward_funcs)
80 |
81 | def inverse(self, x: Union[np.ndarray, torch.Tensor]
82 | ) -> Union[np.ndarray, torch.Tensor]:
83 | return self._apply_funcs(x, self.inverse_funcs)
84 |
85 |
86 | def test():
87 | tf = RotationTransformer()
88 |
89 | rotvec = np.random.uniform(-2*np.pi,2*np.pi,size=(1000,3))
90 | rot6d = tf.forward(rotvec)
91 | new_rotvec = tf.inverse(rot6d)
92 |
93 | from scipy.spatial.transform import Rotation
94 | diff = Rotation.from_rotvec(rotvec) * Rotation.from_rotvec(new_rotvec).inv()
95 | dist = diff.magnitude()
96 | assert dist.max() < 1e-7
97 |
98 | tf = RotationTransformer('rotation_6d', 'matrix')
99 | rot6d_wrong = rot6d + np.random.normal(scale=0.1, size=rot6d.shape)
100 | mat = tf.forward(rot6d_wrong)
101 | mat_det = np.linalg.det(mat)
102 | assert np.allclose(mat_det, 1)
103 |     # rotation_6d will be normalized to a valid rotation matrix
104 |
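A usage sketch (illustrative; pytorch3d quaternions are ordered w-x-y-z, real part first):

    import numpy as np
    from equi_diffpo.model.common.rotation_transformer import RotationTransformer

    tf = RotationTransformer(from_rep='quaternion', to_rep='rotation_6d')
    quat = np.array([1.0, 0.0, 0.0, 0.0])  # identity rotation, (w, x, y, z)
    rot6d = tf.forward(quat)               # (6,): first two rows of the rotation matrix
    quat_back = tf.inverse(rot6d)          # round-trips back to a quaternion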
--------------------------------------------------------------------------------
/equi_diffpo/model/common/shape_util.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Callable
2 | import torch
3 | import torch.nn as nn
4 |
5 | def get_module_device(m: nn.Module):
6 | device = torch.device('cpu')
7 | try:
8 | param = next(iter(m.parameters()))
9 | device = param.device
10 | except StopIteration:
11 | pass
12 | return device
13 |
14 | @torch.no_grad()
15 | def get_output_shape(
16 | input_shape: Tuple[int],
17 | net: Callable[[torch.Tensor], torch.Tensor]
18 | ):
19 | device = get_module_device(net)
20 | test_input = torch.zeros((1,)+tuple(input_shape), device=device)
21 | test_output = net(test_input)
22 | output_shape = tuple(test_output.shape[1:])
23 | return output_shape
24 |
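A usage sketch (illustrative): probe a network's output shape with a dummy forward pass.

    import torch.nn as nn
    from equi_diffpo.model.common.shape_util import get_output_shape

    net = nn.Sequential(nn.Conv2d(3, 16, 5), nn.MaxPool2d(2))
    print(get_output_shape((3, 84, 84), net))  # (16, 40, 40)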
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/README.md:
--------------------------------------------------------------------------------
1 | This part of the codebase is modified from DETR https://github.com/facebookresearch/detr under APACHE 2.0.
2 |
3 | @article{Carion2020EndtoEndOD,
4 | title={End-to-End Object Detection with Transformers},
5 | author={Nicolas Carion and Francisco Massa and Gabriel Synnaeve and Nicolas Usunier and Alexander Kirillov and Sergey Zagoruyko},
6 | journal={ArXiv},
7 | year={2020},
8 | volume={abs/2005.12872}
9 | }
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import argparse
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import torch
7 | from .models import build_ACT_model, build_CNNMLP_model
8 |
9 | import IPython
10 | e = IPython.embed
11 |
12 | # def get_args_parser():
13 | # parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
14 | # parser.add_argument('--lr', default=1e-4, type=float) # will be overridden
15 | # parser.add_argument('--lr_backbone', default=1e-5, type=float) # will be overridden
16 | # parser.add_argument('--batch_size', default=2, type=int) # not used
17 | # parser.add_argument('--weight_decay', default=1e-4, type=float)
18 | # parser.add_argument('--epochs', default=300, type=int) # not used
19 | # parser.add_argument('--lr_drop', default=200, type=int) # not used
20 | # parser.add_argument('--clip_max_norm', default=0.1, type=float, # not used
21 | # help='gradient clipping max norm')
22 |
23 | # # Model parameters
24 | # # * Backbone
25 | # parser.add_argument('--backbone', default='resnet18', type=str, # will be overridden
26 | # help="Name of the convolutional backbone to use")
27 | # parser.add_argument('--dilation', action='store_true',
28 | # help="If true, we replace stride with dilation in the last convolutional block (DC5)")
29 | # parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
30 | # help="Type of positional embedding to use on top of the image features")
31 | # parser.add_argument('--camera_names', default=[], type=list, # will be overridden
32 | # help="A list of camera names")
33 |
34 | # # * Transformer
35 | # parser.add_argument('--enc_layers', default=4, type=int, # will be overridden
36 | # help="Number of encoding layers in the transformer")
37 | # parser.add_argument('--dec_layers', default=6, type=int, # will be overridden
38 | # help="Number of decoding layers in the transformer")
39 | # parser.add_argument('--dim_feedforward', default=2048, type=int, # will be overridden
40 | # help="Intermediate size of the feedforward layers in the transformer blocks")
41 | # parser.add_argument('--hidden_dim', default=256, type=int, # will be overridden
42 | # help="Size of the embeddings (dimension of the transformer)")
43 | # parser.add_argument('--dropout', default=0.1, type=float,
44 | # help="Dropout applied in the transformer")
45 | # parser.add_argument('--nheads', default=8, type=int, # will be overridden
46 | # help="Number of attention heads inside the transformer's attentions")
47 | # parser.add_argument('--num_queries', default=400, type=int, # will be overridden
48 | # help="Number of query slots")
49 | # parser.add_argument('--pre_norm', action='store_true')
50 |
51 | # # * Segmentation
52 | # parser.add_argument('--masks', action='store_true',
53 | # help="Train segmentation head if the flag is provided")
54 |
55 | # # repeat args in imitate_episodes just to avoid error. Will not be used
56 | # parser.add_argument('--eval', action='store_true')
57 | # parser.add_argument('--onscreen_render', action='store_true')
58 | # parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
59 | # parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
60 | # parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
61 | # parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
62 | # parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
63 | # parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
64 | # parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
65 | # parser.add_argument('--temporal_agg', action='store_true')
66 |
67 | # return parser
68 |
69 | class Config:
70 | def __init__(self, dictionary):
71 | for key, value in dictionary.items():
72 | setattr(self, key, value)
73 |
74 | def build_ACT_model_and_optimizer(args_override):
75 | args = {
76 | "lr": 1e-4,
77 | "lr_backbone": 1e-05,
78 | "weight_decay": 1e-4,
79 | "backbone": "resnet18",
80 | "dilation": False,
81 | "position_embedding": "sine",
82 | "camera_names": [],
83 | "enc_layers": 4,
84 | "dec_layers": 6,
85 | "dim_feedforward": 2048,
86 | "hidden_dim": 256,
87 | "dropout": 0.1,
88 | "nheads": 8,
89 | "num_queries": 400,
90 | "pre_norm": False,
91 | "masks": False,
92 | }
93 |
94 | # parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
95 | # args = parser.parse_args()
96 | args = Config(args)
97 |
98 | for k, v in args_override.items():
99 | # args[k] = v
100 | setattr(args, k, v)
101 |
102 | model = build_ACT_model(args)
103 |
104 | param_dicts = [
105 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
106 | {
107 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
108 | "lr": args.lr_backbone,
109 | },
110 | ]
111 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
112 | weight_decay=args.weight_decay)
113 |
114 | return model, optimizer
115 |
116 |
117 | def build_CNNMLP_model_and_optimizer(args_override):
118 |     # the original implementation called get_args_parser(), which is
119 |     # commented out above and would raise a NameError at runtime; reuse
120 |     # the same defaults as build_ACT_model_and_optimizer instead
121 |     args = Config({
122 |         "lr": 1e-4,
123 |         "lr_backbone": 1e-05,
124 |         "weight_decay": 1e-4,
125 |         "backbone": "resnet18",
126 |         "dilation": False,
127 |         "position_embedding": "sine",
128 |         "camera_names": [],
129 |         "enc_layers": 4,
130 |         "dec_layers": 6,
131 |         "dim_feedforward": 2048,
132 |         "hidden_dim": 256,
133 |         "dropout": 0.1,
134 |         "nheads": 8,
135 |         "num_queries": 400,
136 |         "pre_norm": False,
137 |         "masks": False,
138 |     })
139 | 
140 |     for k, v in args_override.items():
141 |         setattr(args, k, v)
142 | 
143 |     model = build_CNNMLP_model(args)
144 | 
145 |     param_dicts = [
146 |         {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
147 |         {
148 |             "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
149 |             "lr": args.lr_backbone,
150 |         },
151 |     ]
152 |     optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
153 |                                   weight_decay=args.weight_decay)
154 | 
155 |     return model, optimizer
156 | 
157 | 
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from .detr_vae import build as build_vae
3 | from .detr_vae import build_cnnmlp as build_cnnmlp
4 |
5 | def build_ACT_model(args):
6 | return build_vae(args)
7 |
8 | def build_CNNMLP_model(args):
9 | return build_cnnmlp(args)
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/models/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Backbone modules.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 | from torch import nn
11 | from torchvision.models._utils import IntermediateLayerGetter
12 | from typing import Dict, List
13 |
14 | from ..util.misc import NestedTensor, is_main_process
15 |
16 | from .position_encoding import build_position_encoding
17 |
18 | import IPython
19 | e = IPython.embed
20 |
21 | class FrozenBatchNorm2d(torch.nn.Module):
22 | """
23 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
24 |
 25 |     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
 26 |     without which models other than torchvision.models.resnet[18,34,50,101]
 27 |     produce NaNs.
28 | """
29 |
30 | def __init__(self, n):
31 | super(FrozenBatchNorm2d, self).__init__()
32 | self.register_buffer("weight", torch.ones(n))
33 | self.register_buffer("bias", torch.zeros(n))
34 | self.register_buffer("running_mean", torch.zeros(n))
35 | self.register_buffer("running_var", torch.ones(n))
36 |
37 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
38 | missing_keys, unexpected_keys, error_msgs):
39 | num_batches_tracked_key = prefix + 'num_batches_tracked'
40 | if num_batches_tracked_key in state_dict:
41 | del state_dict[num_batches_tracked_key]
42 |
43 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
44 | state_dict, prefix, local_metadata, strict,
45 | missing_keys, unexpected_keys, error_msgs)
46 |
47 | def forward(self, x):
48 | # move reshapes to the beginning
49 | # to make it fuser-friendly
50 | w = self.weight.reshape(1, -1, 1, 1)
51 | b = self.bias.reshape(1, -1, 1, 1)
52 | rv = self.running_var.reshape(1, -1, 1, 1)
53 | rm = self.running_mean.reshape(1, -1, 1, 1)
54 | eps = 1e-5
55 | scale = w * (rv + eps).rsqrt()
56 | bias = b - rm * scale
57 | return x * scale + bias
58 |
59 |
60 | class BackboneBase(nn.Module):
61 |
62 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
63 | super().__init__()
64 | # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
65 | # if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
66 | # parameter.requires_grad_(False)
67 | if return_interm_layers:
68 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
69 | else:
70 | return_layers = {'layer4': "0"}
71 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
72 | self.num_channels = num_channels
73 |
74 | def forward(self, tensor):
75 | xs = self.body(tensor)
76 | return xs
77 | # out: Dict[str, NestedTensor] = {}
78 | # for name, x in xs.items():
79 | # m = tensor_list.mask
80 | # assert m is not None
81 | # mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
82 | # out[name] = NestedTensor(x, mask)
83 | # return out
84 |
85 |
86 | class Backbone(BackboneBase):
87 | """ResNet backbone with frozen BatchNorm."""
88 | def __init__(self, name: str,
89 | train_backbone: bool,
90 | return_interm_layers: bool,
91 | dilation: bool):
92 | backbone = getattr(torchvision.models, name)(
93 | replace_stride_with_dilation=[False, False, dilation],
94 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
95 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
96 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
97 |
98 |
99 | class Joiner(nn.Sequential):
100 | def __init__(self, backbone, position_embedding):
101 | super().__init__(backbone, position_embedding)
102 |
103 | def forward(self, tensor_list: NestedTensor):
104 | xs = self[0](tensor_list)
105 | out: List[NestedTensor] = []
106 | pos = []
107 | for name, x in xs.items():
108 | out.append(x)
109 | # position encoding
110 | pos.append(self[1](x).to(x.dtype))
111 |
112 | return out, pos
113 |
114 |
115 | def build_backbone(args):
116 | position_embedding = build_position_encoding(args)
117 | train_backbone = args.lr_backbone > 0
118 | return_interm_layers = args.masks
119 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
120 | model = Joiner(backbone, position_embedding)
121 | model.num_channels = backbone.num_channels
122 | return model
123 |
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 |
9 | from ..util.misc import NestedTensor
10 |
11 | import IPython
12 | e = IPython.embed
13 |
14 | class PositionEmbeddingSine(nn.Module):
15 | """
16 | This is a more standard version of the position embedding, very similar to the one
17 | used by the Attention is all you need paper, generalized to work on images.
18 | """
19 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
20 | super().__init__()
21 | self.num_pos_feats = num_pos_feats
22 | self.temperature = temperature
23 | self.normalize = normalize
24 | if scale is not None and normalize is False:
25 | raise ValueError("normalize should be True if scale is passed")
26 | if scale is None:
27 | scale = 2 * math.pi
28 | self.scale = scale
29 |
30 | def forward(self, tensor):
31 | x = tensor
32 | # mask = tensor_list.mask
33 | # assert mask is not None
34 | # not_mask = ~mask
35 |
36 | not_mask = torch.ones_like(x[0, [0]])
37 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
38 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
39 | if self.normalize:
40 | eps = 1e-6
41 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
42 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
43 |
44 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
45 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
46 |
47 | pos_x = x_embed[:, :, :, None] / dim_t
48 | pos_y = y_embed[:, :, :, None] / dim_t
49 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
50 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
52 | return pos
53 |
54 |
55 | class PositionEmbeddingLearned(nn.Module):
56 | """
57 | Absolute pos embedding, learned.
58 | """
59 | def __init__(self, num_pos_feats=256):
60 | super().__init__()
61 | self.row_embed = nn.Embedding(50, num_pos_feats)
62 | self.col_embed = nn.Embedding(50, num_pos_feats)
63 | self.reset_parameters()
64 |
65 | def reset_parameters(self):
66 | nn.init.uniform_(self.row_embed.weight)
67 | nn.init.uniform_(self.col_embed.weight)
68 |
69 | def forward(self, tensor_list: NestedTensor):
70 | x = tensor_list.tensors
71 | h, w = x.shape[-2:]
72 | i = torch.arange(w, device=x.device)
73 | j = torch.arange(h, device=x.device)
74 | x_emb = self.col_embed(i)
75 | y_emb = self.row_embed(j)
76 | pos = torch.cat([
77 | x_emb.unsqueeze(0).repeat(h, 1, 1),
78 | y_emb.unsqueeze(1).repeat(1, w, 1),
79 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
80 | return pos
81 |
82 |
83 | def build_position_encoding(args):
84 | N_steps = args.hidden_dim // 2
85 | if args.position_embedding in ('v2', 'sine'):
86 | # TODO find a better way of exposing other arguments
87 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
88 | elif args.position_embedding in ('v3', 'learned'):
89 | position_embedding = PositionEmbeddingLearned(N_steps)
90 | else:
91 | raise ValueError(f"not supported {args.position_embedding}")
92 |
93 | return position_embedding
94 |
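A usage sketch (illustrative): the sine variant ignores batch content and emits one positional map per spatial location, broadcast over the batch.

    import torch
    from equi_diffpo.model.detr.models.position_encoding import PositionEmbeddingSine

    pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
    feat = torch.zeros(2, 256, 7, 7)   # (B, C, H, W) backbone feature map
    pos = pe(feat)                     # (1, 256, 7, 7): y-features then x-features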
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from setuptools import find_packages
3 |
4 | setup(
5 | name='detr',
6 | version='0.0.0',
7 | packages=find_packages(),
8 | license='MIT License',
9 | long_description=open('README.md').read(),
10 | )
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/util/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 |
23 | # modified from torchvision to also return the union
24 | def box_iou(boxes1, boxes2):
25 | area1 = box_area(boxes1)
26 | area2 = box_area(boxes2)
27 |
28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
30 |
31 | wh = (rb - lt).clamp(min=0) # [N,M,2]
32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
33 |
34 | union = area1[:, None] + area2 - inter
35 |
36 | iou = inter / union
37 | return iou, union
38 |
39 |
40 | def generalized_box_iou(boxes1, boxes2):
41 | """
42 | Generalized IoU from https://giou.stanford.edu/
43 |
44 | The boxes should be in [x0, y0, x1, y1] format
45 |
46 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
47 | and M = len(boxes2)
48 | """
49 | # degenerate boxes gives inf / nan results
50 | # so do an early check
51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
53 | iou, union = box_iou(boxes1, boxes2)
54 |
55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
57 |
58 | wh = (rb - lt).clamp(min=0) # [N,M,2]
59 | area = wh[:, :, 0] * wh[:, :, 1]
60 |
61 | return iou - (area - union) / area
62 |
63 |
64 | def masks_to_boxes(masks):
65 | """Compute the bounding boxes around the provided masks
66 |
67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
68 |
69 | Returns a [N, 4] tensors, with the boxes in xyxy format
70 | """
71 | if masks.numel() == 0:
72 | return torch.zeros((0, 4), device=masks.device)
73 |
74 | h, w = masks.shape[-2:]
75 |
76 | y = torch.arange(0, h, dtype=torch.float)
77 | x = torch.arange(0, w, dtype=torch.float)
78 | y, x = torch.meshgrid(y, x)
79 |
80 | x_mask = (masks * x.unsqueeze(0))
81 | x_max = x_mask.flatten(1).max(-1)[0]
82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
83 |
84 | y_mask = (masks * y.unsqueeze(0))
85 | y_max = y_mask.flatten(1).max(-1)[0]
86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
87 |
88 | return torch.stack([x_min, y_min, x_max, y_max], 1)
89 |
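A usage sketch (illustrative): GIoU equals IoU minus the fraction of the enclosing box not covered by the union, so disjoint boxes score negative.

    import torch
    from equi_diffpo.model.detr.util.box_ops import generalized_box_iou

    a = torch.tensor([[0., 0., 2., 2.]])
    b = torch.tensor([[1., 1., 3., 3.],
                      [4., 4., 5., 5.]])
    print(generalized_box_iou(a, b))  # tensor([[-0.0794, -0.8000]])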
--------------------------------------------------------------------------------
/equi_diffpo/model/detr/util/plot_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Plotting utilities to visualize training logs.
3 | """
4 | import torch
5 | import pandas as pd
6 | import numpy as np
7 | import seaborn as sns
8 | import matplotlib.pyplot as plt
9 |
10 | from pathlib import Path, PurePath
11 |
12 |
13 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
14 | '''
15 | Function to plot specific fields from training log(s). Plots both training and test results.
16 |
17 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
18 | - fields = which results to plot from each log file - plots both training and test for each field.
19 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
20 | - log_name = optional, name of log file if different than default 'log.txt'.
21 |
22 | :: Outputs - matplotlib plots of results in fields, color coded for each log file.
23 | - solid lines are training results, dashed lines are test results.
24 |
25 | '''
26 | func_name = "plot_utils.py::plot_logs"
27 |
28 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
29 | # convert single Path to list to avoid 'not iterable' error
30 |
31 | if not isinstance(logs, list):
32 | if isinstance(logs, PurePath):
33 | logs = [logs]
34 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
35 | else:
36 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
37 | Expect list[Path] or single Path obj, received {type(logs)}")
38 |
39 | # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
40 | for i, dir in enumerate(logs):
41 | if not isinstance(dir, PurePath):
42 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
43 | if not dir.exists():
44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
45 | # verify log_name exists
46 | fn = Path(dir / log_name)
47 | if not fn.exists():
48 | print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
49 | print(f"--> full path of missing log file: {fn}")
50 | return
51 |
52 | # load log file(s) and plot
53 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
54 |
55 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
56 |
57 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
58 | for j, field in enumerate(fields):
59 | if field == 'mAP':
60 | coco_eval = pd.DataFrame(
61 | np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
62 | ).ewm(com=ewm_col).mean()
63 | axs[j].plot(coco_eval, c=color)
64 | else:
65 | df.interpolate().ewm(com=ewm_col).mean().plot(
66 | y=[f'train_{field}', f'test_{field}'],
67 | ax=axs[j],
68 | color=[color] * 2,
69 | style=['-', '--']
70 | )
71 | for ax, field in zip(axs, fields):
72 | ax.legend([Path(p).name for p in logs])
73 | ax.set_title(field)
74 |
75 |
76 | def plot_precision_recall(files, naming_scheme='iter'):
77 | if naming_scheme == 'exp_id':
78 | # name becomes exp_id
79 | names = [f.parts[-3] for f in files]
80 | elif naming_scheme == 'iter':
81 | names = [f.stem for f in files]
82 | else:
83 | raise ValueError(f'not supported {naming_scheme}')
84 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
85 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
86 | data = torch.load(f)
87 | # precision is n_iou, n_points, n_cat, n_area, max_det
88 | precision = data['precision']
89 | recall = data['params'].recThrs
90 | scores = data['scores']
91 | # take precision for all classes, all areas and 100 detections
92 | precision = precision[0, :, :, 0, -1].mean(1)
93 | scores = scores[0, :, :, 0, -1].mean(1)
94 | prec = precision.mean()
95 | rec = data['recall'][0, :, 0, -1].mean()
96 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
97 | f'score={scores.mean():0.3f}, ' +
98 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
99 | )
100 | axs[0].plot(recall, precision, c=color)
101 | axs[1].plot(recall, scores, c=color)
102 |
103 | axs[0].set_title('Precision / Recall')
104 | axs[0].legend(names)
105 | axs[1].set_title('Scores / Recall')
106 | axs[1].legend(names)
107 | return fig, axs
108 |
--------------------------------------------------------------------------------
/equi_diffpo/model/diffusion/conv1d_components.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | # from einops.layers.torch import Rearrange
5 |
6 |
7 | class Downsample1d(nn.Module):
8 | def __init__(self, dim):
9 | super().__init__()
10 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
11 |
12 | def forward(self, x):
13 | return self.conv(x)
14 |
15 | class Upsample1d(nn.Module):
16 | def __init__(self, dim):
17 | super().__init__()
18 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
19 |
20 | def forward(self, x):
21 | return self.conv(x)
22 |
23 | class Conv1dBlock(nn.Module):
24 | '''
25 | Conv1d --> GroupNorm --> Mish
26 | '''
27 |
28 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
29 | super().__init__()
30 |
31 | self.block = nn.Sequential(
32 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
33 | # Rearrange('batch channels horizon -> batch channels 1 horizon'),
34 | nn.GroupNorm(n_groups, out_channels),
35 | # Rearrange('batch channels 1 horizon -> batch channels horizon'),
36 | nn.Mish(),
37 | )
38 |
39 | def forward(self, x):
40 | return self.block(x)
41 |
42 |
43 | def test():
44 | cb = Conv1dBlock(256, 128, kernel_size=3)
45 | x = torch.zeros((1,256,16))
46 | o = cb(x)
47 |
--------------------------------------------------------------------------------
/equi_diffpo/model/diffusion/ema_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch.nn.modules.batchnorm import _BatchNorm
4 |
5 | class EMAModel:
6 | """
7 | Exponential Moving Average of models weights
8 | """
9 |
10 | def __init__(
11 | self,
12 | model,
13 | update_after_step=0,
14 | inv_gamma=1.0,
15 | power=2 / 3,
16 | min_value=0.0,
17 | max_value=0.9999
18 | ):
19 | """
20 | @crowsonkb's notes on EMA Warmup:
21 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
22 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
23 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
24 | at 215.4k steps).
25 | Args:
26 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
27 | power (float): Exponential factor of EMA warmup. Default: 2/3.
28 | min_value (float): The minimum EMA decay rate. Default: 0.
29 | """
30 |
31 | self.averaged_model = model
32 | self.averaged_model.eval()
33 | self.averaged_model.requires_grad_(False)
34 |
35 | self.update_after_step = update_after_step
36 | self.inv_gamma = inv_gamma
37 | self.power = power
38 | self.min_value = min_value
39 | self.max_value = max_value
40 |
41 | self.decay = 0.0
42 | self.optimization_step = 0
43 |
44 | def get_decay(self, optimization_step):
45 | """
46 | Compute the decay factor for the exponential moving average.
47 | """
48 | step = max(0, optimization_step - self.update_after_step - 1)
49 | value = 1 - (1 + step / self.inv_gamma) ** -self.power
50 |
51 | if step <= 0:
52 | return 0.0
53 |
54 | return max(self.min_value, min(value, self.max_value))
55 |
56 | @torch.no_grad()
57 | def step(self, new_model):
58 | self.decay = self.get_decay(self.optimization_step)
59 |
60 | # old_all_dataptrs = set()
61 | # for param in new_model.parameters():
62 | # data_ptr = param.data_ptr()
63 | # if data_ptr != 0:
64 | # old_all_dataptrs.add(data_ptr)
65 |
66 | all_dataptrs = set()
67 | for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
68 | for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):
 69 |                 # iterate over immediate parameters only.
70 | if isinstance(param, dict):
71 | raise RuntimeError('Dict parameter not supported')
72 |
73 | # data_ptr = param.data_ptr()
74 | # if data_ptr != 0:
75 | # all_dataptrs.add(data_ptr)
76 |
77 | if isinstance(module, _BatchNorm):
78 | # skip batchnorms
79 | ema_param.copy_(param.to(dtype=ema_param.dtype).data)
80 | elif not param.requires_grad:
81 | ema_param.copy_(param.to(dtype=ema_param.dtype).data)
82 | else:
83 | ema_param.mul_(self.decay)
84 | ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
85 |
86 | # verify that iterating over module and then parameters is identical to parameters recursively.
87 | # assert old_all_dataptrs == all_dataptrs
88 | self.optimization_step += 1
89 |
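The warmup schedule is decay = 1 - (1 + step / inv_gamma) ** (-power), clipped to [min_value, max_value]. A numeric sketch with the defaults (inv_gamma=1, power=2/3):

    import torch.nn as nn
    from equi_diffpo.model.diffusion.ema_model import EMAModel

    ema = EMAModel(nn.Linear(1, 1))
    for step in (100, 31_600, 1_000_000):
        print(step, round(ema.get_decay(step), 5))
    # decay ramps toward max_value: ~0.95 at 100 steps, ~0.999 at 31.6K,
    # ~0.9999 at 1M, matching @crowsonkb's notes quoted above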
--------------------------------------------------------------------------------
/equi_diffpo/model/diffusion/positional_embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | class SinusoidalPosEmb(nn.Module):
6 | def __init__(self, dim):
7 | super().__init__()
8 | self.dim = dim
9 |
10 | def forward(self, x):
11 | device = x.device
12 | half_dim = self.dim // 2
13 | emb = math.log(10000) / (half_dim - 1)
14 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
15 | emb = x[:, None] * emb[None, :]
16 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
17 | return emb
18 |
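A usage sketch (illustrative): embed a batch of diffusion timesteps.

    import torch
    from equi_diffpo.model.diffusion.positional_embedding import SinusoidalPosEmb

    emb = SinusoidalPosEmb(dim=128)
    t = torch.arange(4)      # diffusion steps, shape (B,)
    out = emb(t)             # (4, 128): sin features then cos features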
--------------------------------------------------------------------------------
/equi_diffpo/model/equi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pointW/equidiff/e40abb003b25071d4e5b01bfa9933dc16cd32c67/equi_diffpo/model/equi/__init__.py
--------------------------------------------------------------------------------
/equi_diffpo/model/equi/equi_conditional_unet1d_vel.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import torch
3 | from escnn import gspaces, nn
4 | from escnn.group import CyclicGroup
5 | from einops import rearrange, repeat
6 | from equi_diffpo.model.diffusion.conditional_unet1d import ConditionalUnet1D
7 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer
8 |
9 |
10 | class EquiDiffusionUNetVel(torch.nn.Module):
11 | def __init__(self, act_emb_dim, local_cond_dim, global_cond_dim, diffusion_step_embed_dim, down_dims, kernel_size, n_groups, cond_predict_scale, N):
12 | super().__init__()
13 | self.unet = ConditionalUnet1D(
14 | input_dim=act_emb_dim,
15 | local_cond_dim=local_cond_dim,
16 | global_cond_dim=global_cond_dim,
17 | diffusion_step_embed_dim=diffusion_step_embed_dim,
18 | down_dims=down_dims,
19 | kernel_size=kernel_size,
20 | n_groups=n_groups,
21 | cond_predict_scale=cond_predict_scale
22 | )
23 | self.N = N
24 | self.group = gspaces.no_base_space(CyclicGroup(self.N))
25 | self.order = self.N
26 | self.act_type = nn.FieldType(self.group, act_emb_dim * [self.group.regular_repr])
27 | self.out_layer = nn.Linear(self.act_type,
28 | self.getOutFieldType())
29 | self.enc_a = nn.SequentialModule(
30 | nn.Linear(self.getOutFieldType(), self.act_type),
31 | nn.ReLU(self.act_type)
32 | )
33 |
34 | self.p = torch.tensor([
35 | [1, 0, 0, 0, 1, 0, 0, 0, 0],
36 | [0, -1, 0, 1, 0, 0, 0, 0, 0],
37 | [0, 0, 0, 0, 0, 0, 0, 0, 1],
38 | [0, 0, 1, 0, 0, 0, 0, 0, 0],
39 | [0, 0, 0, 0, 0, 1, 0, 0, 0],
40 | [0, 0, 0, 0, 0, 0, 1, 0, 0],
41 | [0, 0, 0, 0, 0, 0, 0, 1, 0],
42 | [0, 1, 0, 1, 0, 0, 0, 0, 0],
43 | [-1, 0, 0, 0, 1, 0, 0, 0, 0]
44 | ]).float()
45 | self.p_inv = torch.linalg.inv(self.p)
46 | self.axisangle_to_matrix = RotationTransformer('axis_angle', 'matrix')
47 |
48 | def getOutFieldType(self):
49 | return nn.FieldType(
50 | self.group,
51 | 1 * [self.group.irrep(2)] # 2
52 | + 3 * [self.group.irrep(1)] # 6
53 | + 5 * [self.group.trivial_repr], # 5
54 | )
55 |
56 | # matrix
57 | def getOutput(self, conv_out):
58 | rho2 = conv_out[:, 0:2]
59 | xy = conv_out[:, 2:4]
60 | rho11 = conv_out[:, 4:6]
61 | rho12 = conv_out[:, 6:8]
62 | rho01 = conv_out[:, 8:9]
63 | rho02 = conv_out[:, 9:10]
64 | rho03 = conv_out[:, 10:11]
65 | z = conv_out[:, 11:12]
66 | g = conv_out[:, 12:13]
67 |
68 | v = torch.cat((rho01, rho02, rho03, rho11, rho12, rho2), dim=1)
69 | m = torch.matmul(self.p_inv.to(conv_out.device), v.reshape(-1, 9, 1)).reshape(-1, 9)
70 |
71 | action = torch.cat((xy, z, m, g), dim=1)
72 | return action
73 |
74 | def getActionGeometricTensor(self, act):
75 | batch_size = act.shape[0]
76 | xy = act[:, 0:2]
77 | z = act[:, 2:3]
78 | m = act[:, 3:12]
79 | g = act[:, 12:]
80 |
81 | v = torch.matmul(self.p.to(act.device), m.reshape(-1, 9, 1)).reshape(-1, 9)
82 |
83 | cat = torch.cat(
84 | (
85 | v[:, 7:9].reshape(batch_size, 2),
86 | xy.reshape(batch_size, 2),
87 | v[:, 5:7].reshape(batch_size, 2),
88 | v[:, 3:5].reshape(batch_size, 2),
89 | v[:, 0:3].reshape(batch_size, 3),
90 | z.reshape(batch_size, 1),
91 | g.reshape(batch_size, 1),
92 | ),
93 | dim=1,
94 | )
95 | return nn.GeometricTensor(cat, self.getOutFieldType())
96 |
97 | def forward(self,
98 | sample: torch.Tensor,
99 | timestep: Union[torch.Tensor, float, int],
100 | local_cond=None, global_cond=None, **kwargs):
101 | """
102 |         sample: (B,T,action_dim)
103 | timestep: (B,) or int, diffusion step
104 | local_cond: (B,T,local_cond_dim)
105 | global_cond: (B,global_cond_dim)
106 | output: (B,T,input_dim)
107 | """
108 | B, T = sample.shape[:2]
109 | sample = rearrange(sample, "b t d -> (b t) d")
110 | sample = self.getActionGeometricTensor(sample)
111 | enc_a_out = self.enc_a(sample).tensor.reshape(B, T, -1)
112 | enc_a_out = rearrange(enc_a_out, "b t (c f) -> (b f) t c", f=self.order)
113 | if type(timestep) == torch.Tensor and len(timestep.shape) == 1:
114 | timestep = repeat(timestep, "b -> (b f)", f=self.order)
115 | if local_cond is not None:
116 | local_cond = rearrange(local_cond, "b t (c f) -> (b f) t c", f=self.order)
117 | if global_cond is not None:
118 | global_cond = rearrange(global_cond, "b (c f) -> (b f) c", f=self.order)
119 | out = self.unet(enc_a_out, timestep, local_cond, global_cond, **kwargs)
120 | out = rearrange(out, "(b f) t c -> (b t) (c f)", f=self.order)
121 | out = nn.GeometricTensor(out, self.act_type)
122 | out = self.out_layer(out).tensor.reshape(B * T, -1)
123 | out = self.getOutput(out)
124 | out = rearrange(out, "(b t) n -> b t n", b=B)
125 | return out
126 |
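A usage sketch (illustrative; requires escnn). Each 13-dim action is x-y, z, a flattened 3x3 rotation matrix, and a gripper scalar; the group dimension N is folded into the batch before the inner UNet runs:

    import torch
    from equi_diffpo.model.equi.equi_conditional_unet1d_vel import EquiDiffusionUNetVel

    net = EquiDiffusionUNetVel(
        act_emb_dim=64, local_cond_dim=None, global_cond_dim=128,
        diffusion_step_embed_dim=128, down_dims=[256, 512, 1024],
        kernel_size=5, n_groups=8, cond_predict_scale=True, N=8)
    sample = torch.randn(2, 16, 13)          # xy(2) + z(1) + 3x3 rotation(9) + gripper(1)
    t = torch.randint(0, 100, (2,))
    out = net(sample, t, global_cond=torch.randn(2, 128 * 8))  # (2, 16, 13)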
--------------------------------------------------------------------------------
/equi_diffpo/model/unet/obs_cond_unet1d.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import logging
3 | import torch
4 | import torch.nn as nn
5 | import einops
6 | from einops.layers.torch import Rearrange
7 |
8 | from equi_diffpo.model.diffusion.conv1d_components import (
9 | Downsample1d, Upsample1d, Conv1dBlock)
10 | from equi_diffpo.model.diffusion.positional_embedding import SinusoidalPosEmb
11 | from equi_diffpo.model.diffusion.conditional_unet1d import ConditionalResidualBlock1D
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | class ObsConditionalUnet1D(nn.Module):
16 | def __init__(self,
17 | out_dim=[16,64],
18 | global_cond_dim=128,
19 | down_dims=[256,512,1024],
20 | kernel_size=3,
21 | n_groups=8,
22 | cond_predict_scale=False,
23 | ):
24 | super().__init__()
25 | out_c = out_dim[1]
26 | self.inp = torch.nn.Parameter(torch.randn(list(out_dim)))
27 |
28 | all_dims = [out_c] + list(down_dims)
29 | start_dim = down_dims[0]
30 |
31 | cond_dim = global_cond_dim
32 |
33 | in_out = list(zip(all_dims[:-1], all_dims[1:]))
34 |
35 | mid_dim = all_dims[-1]
36 | self.mid_modules = nn.ModuleList([
37 | ConditionalResidualBlock1D(
38 | mid_dim, mid_dim, cond_dim=cond_dim,
39 | kernel_size=kernel_size, n_groups=n_groups,
40 | cond_predict_scale=cond_predict_scale
41 | ),
42 | ConditionalResidualBlock1D(
43 | mid_dim, mid_dim, cond_dim=cond_dim,
44 | kernel_size=kernel_size, n_groups=n_groups,
45 | cond_predict_scale=cond_predict_scale
46 | ),
47 | ])
48 |
49 | down_modules = nn.ModuleList([])
50 | for ind, (dim_in, dim_out) in enumerate(in_out):
51 | is_last = ind >= (len(in_out) - 1)
52 | down_modules.append(nn.ModuleList([
53 | ConditionalResidualBlock1D(
54 | dim_in, dim_out, cond_dim=cond_dim,
55 | kernel_size=kernel_size, n_groups=n_groups,
56 | cond_predict_scale=cond_predict_scale),
57 | ConditionalResidualBlock1D(
58 | dim_out, dim_out, cond_dim=cond_dim,
59 | kernel_size=kernel_size, n_groups=n_groups,
60 | cond_predict_scale=cond_predict_scale),
61 | Downsample1d(dim_out) if not is_last else nn.Identity()
62 | ]))
63 |
64 | up_modules = nn.ModuleList([])
65 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
66 | is_last = ind >= (len(in_out) - 1)
67 | up_modules.append(nn.ModuleList([
68 | ConditionalResidualBlock1D(
69 | dim_out*2, dim_in, cond_dim=cond_dim,
70 | kernel_size=kernel_size, n_groups=n_groups,
71 | cond_predict_scale=cond_predict_scale),
72 | ConditionalResidualBlock1D(
73 | dim_in, dim_in, cond_dim=cond_dim,
74 | kernel_size=kernel_size, n_groups=n_groups,
75 | cond_predict_scale=cond_predict_scale),
76 | Upsample1d(dim_in) if not is_last else nn.Identity()
77 | ]))
78 |
79 | final_conv = nn.Sequential(
80 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
81 | nn.Conv1d(start_dim, out_c, 1),
82 | )
83 |
84 | self.up_modules = up_modules
85 | self.down_modules = down_modules
86 | self.final_conv = final_conv
87 |
88 | logger.info(
89 | "number of parameters: %e", sum(p.numel() for p in self.parameters())
90 | )
91 |
92 | def forward(self, global_cond):
93 | """
94 | global_cond: (B,global_cond_dim)
 95 |         output: (B, out_dim[0], out_dim[1])
96 | """
97 | B = global_cond.shape[0]
98 | x = einops.repeat(self.inp, 't h -> b h t', b=B)
99 |
100 | h = []
101 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
102 | x = resnet(x, global_cond)
103 | x = resnet2(x, global_cond)
104 | h.append(x)
105 | x = downsample(x)
106 |
107 | for mid_module in self.mid_modules:
108 | x = mid_module(x, global_cond)
109 |
110 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
111 | x = torch.cat((x, h.pop()), dim=1)
112 | x = resnet(x, global_cond)
113 | x = resnet2(x, global_cond)
114 | x = upsample(x)
115 |
116 | x = self.final_conv(x)
117 |
118 | x = einops.rearrange(x, 'b h t -> b t h')
119 | return x
120 |
121 |
122 |
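A usage sketch (illustrative): the sequence is decoded from a learned input parameter of shape out_dim, conditioned only on the global feature.

    import torch
    from equi_diffpo.model.unet.obs_cond_unet1d import ObsConditionalUnet1D

    net = ObsConditionalUnet1D(out_dim=[16, 64], global_cond_dim=128)
    cond = torch.randn(4, 128)   # e.g. an observation embedding
    out = net(cond)              # (4, 16, 64)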
--------------------------------------------------------------------------------
/equi_diffpo/model/vision/model_getter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | def get_resnet(name, weights=None, **kwargs):
5 | """
6 | name: resnet18, resnet34, resnet50
7 | weights: "IMAGENET1K_V1", "r3m"
8 | """
9 | # load r3m weights
10 | if (weights == "r3m") or (weights == "R3M"):
11 | return get_r3m(name=name, **kwargs)
12 |
13 | func = getattr(torchvision.models, name)
14 | resnet = func(weights=weights, **kwargs)
15 | resnet.fc = torch.nn.Identity()
16 | return resnet
17 |
18 | def get_r3m(name, **kwargs):
19 | """
20 | name: resnet18, resnet34, resnet50
21 | """
22 | import r3m
23 | r3m.device = 'cpu'
24 | model = r3m.load_r3m(name)
25 | r3m_model = model.module
26 | resnet_model = r3m_model.convnet
27 | resnet_model = resnet_model.to('cpu')
28 | return resnet_model
29 |
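A usage sketch (illustrative): the classification head is replaced with Identity, so the model returns pooled features.

    import torch
    from equi_diffpo.model.vision.model_getter import get_resnet

    encoder = get_resnet('resnet18', weights=None)
    feat = encoder(torch.zeros(1, 3, 224, 224))  # (1, 512) feature vector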
--------------------------------------------------------------------------------
/equi_diffpo/model/vision/rot_randomizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | import torchvision.transforms.functional as TF
5 | import torch.nn.functional as F
6 | import random
7 | import numpy as np
8 | from einops import rearrange
9 | import math
10 |
11 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer
12 |
13 | class RotRandomizer(nn.Module):
14 | """
15 | Continuously and randomly rotate the input tensor during training.
16 | Does not rotate the tensor during evaluation.
17 | """
18 |
19 | def __init__(self, min_angle=-180, max_angle=180):
20 | """
21 | Args:
22 | min_angle (float): Minimum rotation angle.
23 | max_angle (float): Maximum rotation angle.
24 | """
25 | super().__init__()
26 | self.min_angle = min_angle
27 | self.max_angle = max_angle
28 | self.tf = RotationTransformer('quaternion', 'matrix')
29 |
30 |
31 | def forward(self, nobs, naction):
32 | """
33 | Randomly rotates the inputs if in training mode.
34 | Keeps inputs unchanged if in evaluation mode.
35 |
36 | Args:
37 | inputs (torch.Tensor): input tensors
38 |
39 | Returns:
40 | torch.Tensor: rotated or unrotated tensors based on the mode
41 | """
42 | if self.training:
43 | obs = nobs["agentview_image"]
44 | pos = nobs["robot0_eef_pos"]
45 | # x, y, z, w -> w, x, y, z
46 | quat = nobs["robot0_eef_quat"][:, :, [3, 0, 1, 2]]
47 | batch_size = obs.shape[0]
48 | T = obs.shape[1]
49 |
50 | for i in range(1000):
51 | angles = torch.rand(batch_size) * 2 * np.pi - np.pi
52 | rotation_matrix = torch.zeros((batch_size, 3, 3), device=obs.device)
53 | rotation_matrix[:, 2, 2] = 1
54 |
55 | angles[torch.rand(batch_size) < 1/64] = 0
56 | rotation_matrix[:, 0, 0] = torch.cos(angles)
57 | rotation_matrix[:, 0, 1] = -torch.sin(angles)
58 | rotation_matrix[:, 1, 0] = torch.sin(angles)
59 | rotation_matrix[:, 1, 1] = torch.cos(angles)
60 |
61 | rotated_naction = naction.clone()
62 | rotated_naction[:, :, 0:3] = (rotation_matrix @ naction[:, :, 0:3].permute(0, 2, 1)).permute(0, 2, 1)
63 | rotated_naction[:, :, [3, 6]] = (rotation_matrix[:, :2, :2] @ naction[:, :, [3, 6]].permute(0, 2, 1)).permute(0, 2, 1)
64 | rotated_naction[:, :, [4, 7]] = (rotation_matrix[:, :2, :2] @ naction[:, :, [4, 7]].permute(0, 2, 1)).permute(0, 2, 1)
65 | rotated_naction[:, :, [5, 8]] = (rotation_matrix[:, :2, :2] @ naction[:, :, [5, 8]].permute(0, 2, 1)).permute(0, 2, 1)
66 |
67 | rotated_pos = (rotation_matrix @ pos.permute(0, 2, 1)).permute(0, 2, 1)
68 | rot = self.tf.forward(quat)
69 | rotated_rot = rotation_matrix.unsqueeze(1) @ rot
70 | rotated_quat = self.tf.inverse(rotated_rot)
71 |
72 | if rotated_pos.min() >= -1 and rotated_pos.max() <= 1 and rotated_naction[:, :, :2].min() > -1 and rotated_naction[:, :, :2].max() < 1:
73 | break
74 | if i == 999:
75 | return nobs, naction
76 |
77 | obs = rearrange(obs, "b t c h w -> b (t c) h w")
78 | grid = F.affine_grid(rotation_matrix[:, :2], obs.size(), align_corners=True)
79 | rotated_obs = F.grid_sample(obs, grid, align_corners=True, mode='bilinear')
80 | rotated_obs = rearrange(rotated_obs, "b (t c) h w -> b t c h w", b=batch_size, t=T)
81 |
82 | nobs["agentview_image"] = rotated_obs
83 | nobs["robot0_eef_pos"] = rotated_pos
84 | # w, x, y, z -> x, y, z, w
85 | nobs["robot0_eef_quat"] = rotated_quat[:, :, [1, 2, 3, 0]]
86 | naction = rotated_naction
87 |
88 | return nobs, naction
89 |
90 | def __repr__(self):
91 | """Pretty print the network."""
92 | header = '{}'.format(str(self.__class__.__name__))
93 | msg = header + "(min_angle={}, max_angle={})".format(self.min_angle, self.max_angle)
94 | return msg
95 |
96 |
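A usage sketch (illustrative; tensor ranges are chosen small enough that the rejection loop accepts the first sampled rotation):

    import torch
    from equi_diffpo.model.vision.rot_randomizer import RotRandomizer

    rr = RotRandomizer()
    rr.train()   # rotation is only applied in training mode
    nobs = {
        'agentview_image': torch.rand(8, 2, 3, 84, 84),        # (B, T, C, H, W)
        'robot0_eef_pos': torch.rand(8, 2, 3) * 0.5 - 0.25,    # normalized to [-1, 1]
        'robot0_eef_quat': torch.tensor([0., 0., 0., 1.]).repeat(8, 2, 1),  # x, y, z, w
    }
    naction = torch.rand(8, 16, 10) - 0.5  # pos(3) + 6d rotation(6) + gripper(1)
    nobs, naction = rr(nobs, naction)      # same shapes, rotated about the z axis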
--------------------------------------------------------------------------------
/equi_diffpo/model/vision/rot_randomizer_vel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | import torchvision.transforms.functional as TF
5 | import random
6 | import numpy as np
7 | from einops import rearrange
8 | import math
9 |
10 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer
11 |
12 | class RotRandomizerVel(nn.Module):
13 | """
14 | Continuously and randomly rotate the input tensor during training.
15 | Does not rotate the tensor during evaluation.
16 | """
17 |
18 | def __init__(self, min_angle=-180, max_angle=180):
19 | """
20 | Args:
21 | min_angle (float): Minimum rotation angle.
22 | max_angle (float): Maximum rotation angle.
23 | """
24 | super().__init__()
25 | self.min_angle = min_angle
26 | self.max_angle = max_angle
27 | self.quat_to_matrix = RotationTransformer('quaternion', 'matrix')
28 | self.axisangle_to_matrix = RotationTransformer('axis_angle', 'matrix')
29 |
30 |
31 | def forward(self, nobs, naction):
32 | """
33 | Randomly rotates the inputs if in training mode.
34 | Keeps inputs unchanged if in evaluation mode.
35 |
36 | Args:
37 | nobs (dict): normalized observations ("agentview_image", "robot0_eef_pos", "robot0_eef_quat", optionally "crops"/"pos_vecs")
38 | naction (torch.Tensor): normalized action tensor
39 | Returns:
40 | tuple: (nobs, naction), rotated in training mode, unchanged in evaluation mode
41 | """
42 | if self.training and np.random.random() > 1/64:
43 | obs = nobs["agentview_image"]
44 | batch_size = obs.shape[0]
45 |
46 | angle = random.uniform(self.min_angle, self.max_angle)
47 | angle_rad = math.radians(angle)
48 | rotation_matrix = torch.tensor([[math.cos(angle_rad), -math.sin(angle_rad), 0],
49 | [math.sin(angle_rad), math.cos(angle_rad), 0],
50 | [0, 0, 1]]).to(obs.device)
51 |
52 |
53 | obs = rearrange(obs, "b t c h w -> (b t) c h w")
54 | rotated_obs = TF.rotate(obs, angle)
55 | rotated_obs = rearrange(rotated_obs, "(b t) c h w -> b t c h w", b=batch_size)
56 | nobs["agentview_image"] = rotated_obs
57 |
58 | if "crops" in nobs:
59 | crops = nobs["crops"]
60 | n_crop = crops.shape[2]
61 | crops = rearrange(crops, "b t n c h w -> (b t n) c h w")
62 | crops = TF.rotate(crops, angle)
63 | crops = rearrange(crops, "(b t n) c h w -> b t n c h w", b=batch_size, n=n_crop)
64 | nobs["crops"] = crops
65 |
66 | if "pos_vecs" in nobs:
67 | pos_vecs = nobs["pos_vecs"]
68 | pos_vecs = rearrange(pos_vecs, "b t n d -> (b t n) d")
69 | pos_vecs = (rotation_matrix[:2, :2] @ pos_vecs.T).T
70 | pos_vecs = rearrange(pos_vecs, "(b t n) d -> b t n d", b=batch_size, n=n_crop)
71 | nobs["pos_vecs"] = pos_vecs
72 |
73 | pos = nobs["robot0_eef_pos"]
74 | quat = nobs["robot0_eef_quat"]
75 | pos = rearrange(pos, "b t d -> (b t) d")
76 | quat = rearrange(quat, "b t d -> (b t) d")
77 | rot = self.quat_to_matrix.forward(quat)
78 | pos = (rotation_matrix @ pos.T).T
79 | rot = rotation_matrix @ rot
80 | quat = self.quat_to_matrix.inverse(rot)
81 | pos = rearrange(pos, "(b t) d -> b t d", b=batch_size)
82 | quat = rearrange(quat, "(b t) d -> b t d", b=batch_size)
83 | nobs["robot0_eef_pos"] = pos
84 | nobs["robot0_eef_quat"] = quat
85 |
86 | naction = rearrange(naction, "b t d -> (b t) d")
87 | naction[:, 0:3] = (rotation_matrix @ naction[:, 0:3].T).T
88 | axis_angle = naction[:, 3:6]
89 | m = self.axisangle_to_matrix.forward(axis_angle)
90 | gm = rotation_matrix @ m @ rotation_matrix.T  # inverse of a rotation matrix is its transpose
91 | g_axis_angle = self.axisangle_to_matrix.inverse(gm)
92 | naction[:, 3:6] = g_axis_angle
93 |
94 | naction = rearrange(naction, "(b t) d -> b t d", b=batch_size)
95 | naction = torch.clip(naction, -1, 1)
96 | return nobs, naction
97 |
98 | def __repr__(self):
99 | """Pretty print the network."""
100 | header = '{}'.format(str(self.__class__.__name__))
101 | msg = header + "(min_angle={}, max_angle={})".format(self.min_angle, self.max_angle)
102 | return msg
103 |
104 |
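Editorial note: for the delta (velocity) action, the rotation part is an axis-angle vector, and the randomizer rotates it by conjugation, g(M) = R M R^-1. For rotation matrices this is equivalent to rotating the axis-angle vector directly, which gives a quick way to check the logic above. A sketch using scipy (already imported by the repo's scripts); the numbers are illustrative:

import numpy as np
from scipy.spatial.transform import Rotation

aa = np.array([0.1, -0.2, 0.3])                        # an axis-angle action component
R = Rotation.from_euler('z', 37, degrees=True).as_matrix()

M = Rotation.from_rotvec(aa).as_matrix()
aa_conjugated = Rotation.from_matrix(R @ M @ R.T).as_rotvec()

# conjugating the matrix == rotating the axis-angle vector itself
assert np.allclose(aa_conjugated, R @ aa, atol=1e-8)
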
--------------------------------------------------------------------------------
/equi_diffpo/model/vision/voxel_crop_randomizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class VoxelCropRandomizer(nn.Module):
5 | def __init__(
6 | self,
7 | crop_depth,
8 | crop_height,
9 | crop_width,
10 | ):
11 | super().__init__()
12 | self.crop_depth = crop_depth
13 | self.crop_height = crop_height
14 | self.crop_width = crop_width
15 |
16 | def forward(self, voxels):
17 | B, C, D, H, W = voxels.shape
18 | if self.training:
19 | cropped_voxel = []
20 | for i in range(B):
21 | d_start = torch.randint(0, D-self.crop_depth+1, [1])[0]
22 | h_start = torch.randint(0, H-self.crop_height+1, [1])[0]
23 | w_start = torch.randint(0, W-self.crop_width+1, [1])[0]
24 | cropped_voxel.append(voxels[i,
25 | :,
26 | d_start:d_start+self.crop_depth,
27 | h_start:h_start+self.crop_height,
28 | w_start:w_start+self.crop_width,
29 | ])
30 | return torch.stack(cropped_voxel, 0)
31 | else:
32 | voxels = voxels[:,
33 | :,
34 | D//2 - self.crop_depth//2: D//2 + self.crop_depth//2,
35 | H//2 - self.crop_height//2: H//2 + self.crop_height//2,
36 | W//2 - self.crop_width//2: W//2 + self.crop_width//2,
37 | ]
38 | return voxels
39 |
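Editorial note: a usage sketch with made-up shapes, assuming the class above is importable. Note the eval-mode center crop uses `crop // 2` arithmetic, so it implicitly assumes even crop sizes; an odd size would come out one voxel short.

import torch

randomizer = VoxelCropRandomizer(crop_depth=58, crop_height=58, crop_width=58)
voxels = torch.rand(2, 4, 64, 64, 64)       # B, C, D, H, W

randomizer.train()
print(randomizer(voxels).shape)             # (2, 4, 58, 58, 58), random offsets per sample

randomizer.eval()
print(randomizer(voxels).shape)             # (2, 4, 58, 58, 58), deterministic center crop
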
--------------------------------------------------------------------------------
/equi_diffpo/model/vision/voxel_rot_randomizer_rel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.transforms as transforms
4 | import torchvision.transforms.functional as TF
5 | import torch.nn.functional as F
6 | import random
7 | import numpy as np
8 | from einops import rearrange, repeat
9 | import math
10 | from copy import deepcopy
11 |
12 | from equi_diffpo.model.common.rotation_transformer import RotationTransformer
13 |
14 | class VoxelRotRandomizerRel(nn.Module):
15 | """
16 | Continuously and randomly rotate the input tensor during training.
17 | Does not rotate the tensor during evaluation.
18 | """
19 |
20 | def __init__(self, min_angle=-180, max_angle=180):
21 | """
22 | Args:
23 | min_angle (float): Minimum rotation angle.
24 | max_angle (float): Maximum rotation angle.
25 | """
26 | super().__init__()
27 | self.min_angle = min_angle
28 | self.max_angle = max_angle
29 | self.quat_to_matrix = RotationTransformer('quaternion', 'matrix')
30 | self.axisangle_to_matrix = RotationTransformer('axis_angle', 'matrix')
31 |
32 |
33 | def forward(self, nobs, naction: torch.Tensor):
34 | """
35 | Randomly rotates the inputs if in training mode.
36 | Keeps inputs unchanged if in evaluation mode.
37 |
38 | Args:
39 | nobs (dict): normalized observations ("voxels", "robot0_eef_pos", "robot0_eef_quat")
40 | naction (torch.Tensor): normalized action tensor
41 | Returns:
42 | tuple: (nobs, naction), rotated in training mode, unchanged in evaluation mode
43 | """
44 | if self.training:
45 | obs = nobs["voxels"]
46 | pos = nobs["robot0_eef_pos"]
47 | # x, y, z, w -> w, x, y, z
48 | quat = nobs["robot0_eef_quat"][:, :, [3, 0, 1, 2]]
49 | batch_size = obs.shape[0]
50 | To = obs.shape[1]
51 | C = obs.shape[2]
52 | Ta = naction.shape[1]
53 |
54 | for i in range(1000):
55 | angles = torch.rand(batch_size) * 2 * np.pi - np.pi
56 | rotation_matrix = torch.zeros((batch_size, 3, 3), device=obs.device)
57 | rotation_matrix[:, 2, 2] = 1
58 | # construct rotation matrix
59 | angles[torch.rand(batch_size) < 1/64] = 0
60 | rotation_matrix[:, 0, 0] = torch.cos(angles)
61 | rotation_matrix[:, 0, 1] = -torch.sin(angles)
62 | rotation_matrix[:, 1, 0] = torch.sin(angles)
63 | rotation_matrix[:, 1, 1] = torch.cos(angles)
64 | # rotating the xyz vector in action
65 | rotated_naction = naction.clone()
66 | expanded_rotation_matrix = repeat(rotation_matrix, 'b d1 d2 -> b t d1 d2', t=Ta)
67 | rotated_naction[:, :, :3] = (expanded_rotation_matrix @ naction[:, :, :3].unsqueeze(-1)).squeeze(-1)
68 | # rotating the axis angle rotation vector in action
69 | axis_angle = rotated_naction[:, :, 3:6]
70 | m = self.axisangle_to_matrix.forward(axis_angle)
71 | rotation_matrix_inv = rotation_matrix.transpose(1, 2)
72 | expanded_rotation_matrix_inv = repeat(rotation_matrix_inv, 'b d1 d2 -> b t d1 d2', t=Ta)
73 | gm = expanded_rotation_matrix @ m @ expanded_rotation_matrix_inv
74 | g_axis_angle = self.axisangle_to_matrix.inverse(gm)
75 | rotated_naction[:, :, 3:6] = g_axis_angle
76 | # rotating state pos and quat
77 | rotated_pos = (rotation_matrix @ pos.permute(0, 2, 1)).permute(0, 2, 1)
78 | rot = self.quat_to_matrix.forward(quat)
79 | rotated_rot = rotation_matrix.unsqueeze(1) @ rot
80 | rotated_quat = self.quat_to_matrix.inverse(rotated_rot)
81 |
82 | if rotated_pos.min() >= -1 and rotated_pos.max() <= 1:
83 | break
84 | if i == 999:  # give up after 1000 attempts and keep the batch unrotated
85 | return nobs, naction
86 |
87 | obs = rearrange(obs, "b t c h w d -> b t c d w h")  # move the plane being rotated into grid_sample's last two dims
88 | obs = torch.flip(obs, (3, 4))
89 | obs = rearrange(obs, "b t c d w h -> b (t c d) w h")
90 | grid = F.affine_grid(rotation_matrix[:, :2], obs.size(), align_corners=True)
91 | rotated_obs = F.grid_sample(obs, grid, align_corners=True, mode='nearest')
92 | rotated_obs = rearrange(rotated_obs, "b (t c d) w h -> b t c d w h", c=C, t=To)
93 | rotated_obs = torch.flip(rotated_obs, (3, 4))
94 | rotated_obs = rearrange(rotated_obs, "b t c d w h -> b t c h w d")
95 |
96 | nobs["voxels"] = rotated_obs
97 | nobs["robot0_eef_pos"] = rotated_pos
98 | # w, x, y, z -> x, y, z, w
99 | nobs["robot0_eef_quat"] = rotated_quat[:, :, [1, 2, 3, 0]]
100 | naction = rotated_naction
101 |
102 | return nobs, naction
103 |
104 | def __repr__(self):
105 | """Pretty print the network."""
106 | header = '{}'.format(str(self.__class__.__name__))
107 | msg = header + "(min_angle={}, max_angle={})".format(self.min_angle, self.max_angle)
108 | return msg
109 |
110 |
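Editorial note: a driving sketch with hypothetical shapes (`b t c h w d` voxels, 7-dim actions of xyz + axis-angle + gripper); it assumes the repo and the dependencies behind RotationTransformer are installed. Positions are kept small so the in-bounds rejection check passes quickly, and the identity quaternion is given in the x, y, z, w order the module expects:

import torch

randomizer = VoxelRotRandomizerRel().train()
nobs = {
    "voxels": torch.rand(2, 2, 4, 64, 64, 64),
    "robot0_eef_pos": torch.rand(2, 2, 3) - 0.5,   # well inside [-1, 1]
    "robot0_eef_quat": torch.tensor([0., 0., 0., 1.]).expand(2, 2, 4).clone(),
}
naction = torch.zeros(2, 16, 7)
nobs, naction = randomizer(nobs, naction)
print(nobs["voxels"].shape, naction.shape)         # shapes are preserved
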
--------------------------------------------------------------------------------
/equi_diffpo/policy/base_image_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | import torch.nn as nn
4 | from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin
5 | from equi_diffpo.model.common.normalizer import LinearNormalizer
6 |
7 | class BaseImagePolicy(ModuleAttrMixin):
8 | # init accepts keyword argument shape_meta, see config/task/*_image.yaml
9 |
10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
11 | """
12 | obs_dict:
13 | str: B,To,*
14 | return: B,Ta,Da
15 | """
16 | raise NotImplementedError()
17 |
18 | # reset state for stateful policies
19 | def reset(self):
20 | pass
21 |
22 | # ========== training ===========
23 | # no standard training interface except setting normalizer
24 | def set_normalizer(self, normalizer: LinearNormalizer):
25 | raise NotImplementedError()
26 |
--------------------------------------------------------------------------------
/equi_diffpo/policy/base_lowdim_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | import torch.nn as nn
4 | from equi_diffpo.model.common.module_attr_mixin import ModuleAttrMixin
5 | from equi_diffpo.model.common.normalizer import LinearNormalizer
6 |
7 | class BaseLowdimPolicy(ModuleAttrMixin):
8 | # ========= inference ============
9 | # also as self.device and self.dtype for inference device transfer
10 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
11 | """
12 | obs_dict:
13 | obs: B,To,Do
14 | return:
15 | action: B,Ta,Da
16 | To = 3
17 | Ta = 4
18 | T = 6
19 | |o|o|o|
20 | | | |a|a|a|a|
21 | |o|o|
22 | | |a|a|a|a|a|
23 | | | | | |a|a|
24 | """
25 | raise NotImplementedError()
26 |
27 | # reset state for stateful policies
28 | def reset(self):
29 | pass
30 |
31 | # ========== training ===========
32 | # no standard training interface except setting normalizer
33 | def set_normalizer(self, normalizer: LinearNormalizer):
34 | raise NotImplementedError()
35 |
36 |
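Editorial note: the ASCII diagram above encodes the receding-horizon convention; a tiny sketch of the index arithmetic it implies (To observation steps, Ta executed actions out of T predicted steps), not part of the repository:

To, Ta, T = 3, 4, 6

start = To - 1          # the prediction window overlaps the latest observation step
end = start + Ta
print(f"execute predicted steps [{start}:{end}) of {T}")   # execute predicted steps [2:6) of 6
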
--------------------------------------------------------------------------------
/equi_diffpo/policy/robomimic_image_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | from equi_diffpo.model.common.normalizer import LinearNormalizer
4 | from equi_diffpo.policy.base_image_policy import BaseImagePolicy
5 | from equi_diffpo.common.pytorch_util import dict_apply
6 |
7 | from robomimic.algo import algo_factory
8 | from robomimic.algo.algo import PolicyAlgo
9 | import robomimic.utils.obs_utils as ObsUtils
10 | from equi_diffpo.common.robomimic_config_util import get_robomimic_config
11 |
12 | class RobomimicImagePolicy(BaseImagePolicy):
13 | def __init__(self,
14 | shape_meta: dict,
15 | algo_name='bc_rnn',
16 | obs_type='image',
17 | task_name='square',
18 | dataset_type='ph',
19 | crop_shape=(76,76)
20 | ):
21 | super().__init__()
22 |
23 | # parse shape_meta
24 | action_shape = shape_meta['action']['shape']
25 | assert len(action_shape) == 1
26 | action_dim = action_shape[0]
27 | obs_shape_meta = shape_meta['obs']
28 | obs_config = {
29 | 'low_dim': [],
30 | 'rgb': [],
31 | 'depth': [],
32 | 'scan': []
33 | }
34 | obs_key_shapes = dict()
35 | for key, attr in obs_shape_meta.items():
36 | shape = attr['shape']
37 | obs_key_shapes[key] = list(shape)
38 |
39 | attr_type = attr.get('type', 'low_dim')
40 | if attr_type == 'rgb':
41 | obs_config['rgb'].append(key)
42 | elif attr_type == 'low_dim':
43 | obs_config['low_dim'].append(key)
44 | else:
45 | raise RuntimeError(f"Unsupported obs type: {attr_type}")
46 |
47 | # get raw robomimic config
48 | config = get_robomimic_config(
49 | algo_name=algo_name,
50 | hdf5_type=obs_type,
51 | task_name=task_name,
52 | dataset_type=dataset_type)
53 |
54 |
55 | with config.unlocked():
56 | # set config with shape_meta
57 | config.observation.modalities.obs = obs_config
58 |
59 | if crop_shape is None:
60 | for key, modality in config.observation.encoder.items():
61 | if modality.obs_randomizer_class == 'CropRandomizer':
62 | modality['obs_randomizer_class'] = None
63 | else:
64 | # set random crop parameter
65 | ch, cw = crop_shape
66 | for key, modality in config.observation.encoder.items():
67 | if modality.obs_randomizer_class == 'CropRandomizer':
68 | modality.obs_randomizer_kwargs.crop_height = ch
69 | modality.obs_randomizer_kwargs.crop_width = cw
70 |
71 | # init global state
72 | ObsUtils.initialize_obs_utils_with_config(config)
73 |
74 | # load model
75 | model: PolicyAlgo = algo_factory(
76 | algo_name=config.algo_name,
77 | config=config,
78 | obs_key_shapes=obs_key_shapes,
79 | ac_dim=action_dim,
80 | device='cpu',
81 | )
82 |
83 | self.model = model
84 | self.nets = model.nets
85 | self.normalizer = LinearNormalizer()
86 | self.config = config
87 |
88 | def to(self,*args,**kwargs):
89 | device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
90 | if device is not None:
91 | self.model.device = device
92 | super().to(*args,**kwargs)
93 |
94 | # =========== inference =============
95 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
96 | nobs_dict = self.normalizer(obs_dict)
97 | robomimic_obs_dict = dict_apply(nobs_dict, lambda x: x[:,0,...])
98 | naction = self.model.get_action(robomimic_obs_dict)
99 | action = self.normalizer['action'].unnormalize(naction)
100 | # (B, Da)
101 | result = {
102 | 'action': action[:,None,:] # (B, 1, Da)
103 | }
104 | return result
105 |
106 | def reset(self):
107 | self.model.reset()
108 |
109 | # =========== training ==============
110 | def set_normalizer(self, normalizer: LinearNormalizer):
111 | self.normalizer.load_state_dict(normalizer.state_dict())
112 |
113 | def train_on_batch(self, batch, epoch, validate=False):
114 | nobs = self.normalizer.normalize(batch['obs'])
115 | nactions = self.normalizer['action'].normalize(batch['action'])
116 | robomimic_batch = {
117 | 'obs': nobs,
118 | 'actions': nactions
119 | }
120 | input_batch = self.model.process_batch_for_training(
121 | robomimic_batch)
122 | info = self.model.train_on_batch(
123 | batch=input_batch, epoch=epoch, validate=validate)
124 | # keys: losses, predictions
125 | return info
126 |
127 | def on_epoch_end(self, epoch):
128 | self.model.on_epoch_end(epoch)
129 |
130 | def get_optimizer(self):
131 | return self.model.optimizers['policy']
132 |
133 |
134 | def test():
135 | import os
136 | from omegaconf import OmegaConf
137 | cfg_path = os.path.expanduser('~/dev/diffusion_policy/diffusion_policy/config/task/lift_image.yaml')
138 | cfg = OmegaConf.load(cfg_path)
139 | shape_meta = cfg.shape_meta
140 |
141 | policy = RobomimicImagePolicy(shape_meta=shape_meta)
142 |
143 |
--------------------------------------------------------------------------------
/equi_diffpo/scripts/download_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4 |
5 | """
6 | Script to download datasets packaged with the repository.
7 | """
8 | import os
9 | import argparse
10 |
11 | import mimicgen_envs
12 | import mimicgen_envs.utils.file_utils as FileUtils
13 | from mimicgen_envs import DATASET_REGISTRY
14 |
15 |
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 |
19 | # directory to download datasets to
20 | parser.add_argument(
21 | "--download_dir",
22 | type=str,
23 | default=None,
24 | help="Base download directory. Created if it doesn't exist. Defaults to datasets folder in repository.",
25 | )
26 |
27 | # dataset type to download datasets for
28 | parser.add_argument(
29 | "--dataset_type",
30 | type=str,
31 | default="core",
32 | choices=list(DATASET_REGISTRY.keys()),
33 | help="Dataset type to download datasets for (e.g. source, core, object, robot, large_interpolation). Defaults to core.",
34 | )
35 |
36 | # tasks to download datasets for
37 | parser.add_argument(
38 | "--tasks",
39 | type=str,
40 | nargs='+',
41 | default=["square_d0"],
42 | help="Tasks to download datasets for. Defaults to square_d0 task. Pass 'all' to download all tasks\
43 | for the provided dataset type or directly specify the list of tasks.",
44 | )
45 |
46 | # dry run - don't actually download datasets, but print which datasets would be downloaded
47 | parser.add_argument(
48 | "--dry_run",
49 | action='store_true',
50 | help="set this flag to do a dry run to only print which datasets would be downloaded"
51 | )
52 |
53 | args = parser.parse_args()
54 |
55 | # set default base directory for downloads
56 | default_base_dir = args.download_dir
57 | if default_base_dir is None:
58 | default_base_dir = "data/robomimic/datasets"
59 |
60 | # load args
61 | download_dataset_type = args.dataset_type
62 | download_tasks = args.tasks
63 | if "all" in download_tasks:
64 | assert len(download_tasks) == 1, "'all' should be the only task argument, but got: {}".format(args.tasks)
65 | download_tasks = list(DATASET_REGISTRY[download_dataset_type].keys())
66 | else:
67 | for task in download_tasks:
68 | assert task in DATASET_REGISTRY[download_dataset_type], "got unknown task {} for dataset type {}. Choose one of {}".format(task, download_dataset_type, list(DATASET_REGISTRY[download_dataset_type].keys()))
69 |
70 | # download requested datasets
71 | for task in download_tasks:
72 | download_dir = os.path.abspath(os.path.join(default_base_dir, task))
73 | download_path = os.path.join(download_dir, "{}.hdf5".format(task))
74 | print("\nDownloading dataset:\n dataset type: {}\n task: {}\n download path: {}"
75 | .format(download_dataset_type, task, download_path))
76 | url = DATASET_REGISTRY[download_dataset_type][task]["url"]
77 | if args.dry_run:
78 | print("\ndry run: skip download")
79 | else:
80 | # Make sure path exists and create if it doesn't
81 | os.makedirs(download_dir, exist_ok=True)
82 | print("")
83 | FileUtils.download_url_from_gdrive(
84 | url=url,
85 | download_dir=download_dir,
86 | check_overwrite=True,
87 | )
88 | print("")
--------------------------------------------------------------------------------
/equi_diffpo/scripts/robomimic_dataset_action_comparison.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 | sys.path.append(ROOT_DIR)
8 |
9 | import os
10 | import click
11 | import pathlib
12 | import h5py
13 | import numpy as np
14 | from tqdm import tqdm
15 | from scipy.spatial.transform import Rotation
16 |
17 | def read_all_actions(hdf5_file, metric_skip_steps=1):
18 | n_demos = len(hdf5_file['data'])
19 | all_actions = list()
20 | for i in tqdm(range(n_demos)):
21 | actions = hdf5_file[f'data/demo_{i}/actions'][:]
22 | all_actions.append(actions[metric_skip_steps:])
23 | all_actions = np.concatenate(all_actions, axis=0)
24 | return all_actions
25 |
26 |
27 | @click.command()
28 | @click.option('-i', '--input', required=True, help='input hdf5 path')
29 | @click.option('-o', '--output', required=True, help='second hdf5 path to compare against; must be an existing file')
30 | def main(input, output):
31 | # process inputs
32 | input = pathlib.Path(input).expanduser()
33 | assert input.is_file()
34 | output = pathlib.Path(output).expanduser()
35 | assert output.is_file()
36 |
37 | input_file = h5py.File(str(input), 'r')
38 | output_file = h5py.File(str(output), 'r')
39 |
40 | input_all_actions = read_all_actions(input_file)
41 | output_all_actions = read_all_actions(output_file)
42 | pos_dist = np.linalg.norm(input_all_actions[:,:3] - output_all_actions[:,:3], axis=-1)
43 | rot_dist = (Rotation.from_rotvec(input_all_actions[:,3:6]
44 | ) * Rotation.from_rotvec(output_all_actions[:,3:6]).inv()
45 | ).magnitude()
46 |
47 | print(f'max pos dist: {pos_dist.max()}')
48 | print(f'max rot dist: {rot_dist.max()}')
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/equi_diffpo/scripts/robomimic_dataset_conversion.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 | sys.path.append(ROOT_DIR)
8 |
9 | import multiprocessing
10 | import os
11 | import shutil
12 | import click
13 | import pathlib
14 | import h5py
15 | from tqdm import tqdm
16 | import collections
17 | import pickle
18 | from equi_diffpo.common.robomimic_util import RobomimicAbsoluteActionConverter
19 |
20 | def worker(x):
21 | path, idx, do_eval = x
22 | converter = RobomimicAbsoluteActionConverter(path)
23 | if do_eval:
24 | abs_actions, info = converter.convert_and_eval_idx(idx)
25 | else:
26 | abs_actions = converter.convert_idx(idx)
27 | info = dict()
28 | return abs_actions, info
29 |
30 | @click.command()
31 | @click.option('-i', '--input', required=True, help='input hdf5 path')
32 | @click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
33 | @click.option('-e', '--eval_dir', default=None, help='directory to output evaluation metrics')
34 | @click.option('-n', '--num_workers', default=None, type=int)
35 | def main(input, output, eval_dir, num_workers):
36 | # process inputs
37 | input = pathlib.Path(input).expanduser()
38 | assert input.is_file()
39 | output = pathlib.Path(output).expanduser()
40 | assert output.parent.is_dir()
41 | assert not output.is_dir()
42 |
43 | do_eval = False
44 | if eval_dir is not None:
45 | eval_dir = pathlib.Path(eval_dir).expanduser()
46 | assert eval_dir.parent.exists()
47 | do_eval = True
48 |
49 | converter = RobomimicAbsoluteActionConverter(input)
50 |
51 | # run
52 | with multiprocessing.Pool(num_workers) as pool:
53 | results = pool.map(worker, [(input, i, do_eval) for i in range(len(converter))])
54 |
55 | # save output
56 | print('Copying hdf5')
57 | shutil.copy(str(input), str(output))
58 |
59 | # modify action
60 | with h5py.File(output, 'r+') as out_file:
61 | for i in tqdm(range(len(converter)), desc="Writing to output"):
62 | abs_actions, info = results[i]
63 | demo = out_file[f'data/demo_{i}']
64 | demo['actions'][:] = abs_actions
65 |
66 | # save eval
67 | if do_eval:
68 | eval_dir.mkdir(parents=False, exist_ok=True)
69 |
70 | print("Writing error_stats.pkl")
71 | infos = [info for _, info in results]
72 | pickle.dump(infos, eval_dir.joinpath('error_stats.pkl').open('wb'))
73 |
74 | print("Generating visualization")
75 | metrics = ['pos', 'rot']
76 | metrics_dicts = dict()
77 | for m in metrics:
78 | metrics_dicts[m] = collections.defaultdict(list)
79 |
80 | for i in range(len(infos)):
81 | info = infos[i]
82 | for k, v in info.items():
83 | for m in metrics:
84 | metrics_dicts[m][k].append(v[m])
85 |
86 | from matplotlib import pyplot as plt
87 | plt.switch_backend('PDF')
88 |
89 | fig, ax = plt.subplots(1, len(metrics))
90 | for i in range(len(metrics)):
91 | axis = ax[i]
92 | data = metrics_dicts[metrics[i]]
93 | for key, value in data.items():
94 | axis.plot(value, label=key)
95 | axis.legend()
96 | axis.set_title(metrics[i])
97 | fig.set_size_inches(10,4)
98 | fig.savefig(str(eval_dir.joinpath('error_stats.pdf')))
99 | fig.savefig(str(eval_dir.joinpath('error_stats.png')))
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
/equi_diffpo/scripts/robomimic_dataset_obs_conversion.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | import sys
3 | import os
4 | import pathlib
5 |
6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
7 | sys.path.append(ROOT_DIR)
8 |
9 | import multiprocessing
10 | import os
11 | import shutil
12 | import click
13 | import pathlib
14 | import h5py
15 | from tqdm import tqdm
16 | import numpy as np
17 | import collections
18 | import pickle
19 | from equi_diffpo.common.robomimic_util import RobomimicObsConverter
20 |
21 | multiprocessing.set_start_method('spawn', force=True)
22 |
23 | def worker(x):
24 | path, idx = x
25 | converter = RobomimicObsConverter(path)
26 | obss = converter.convert_idx(idx)
27 | return obss
28 |
29 | @click.command()
30 | @click.option('-i', '--input', required=True, help='input hdf5 path')
31 | @click.option('-o', '--output', required=True, help='output hdf5 path. Parent directory must exist')
32 | @click.option('-n', '--num_workers', default=None, type=int)
33 | def main(input, output, num_workers):
34 | # process inputs
35 | input = pathlib.Path(input).expanduser()
36 | assert input.is_file()
37 | output = pathlib.Path(output).expanduser()
38 | assert output.parent.is_dir()
39 | assert not output.is_dir()
40 |
41 | converter = RobomimicObsConverter(input)
42 |
43 | # save output
44 | print('Copying hdf5')
45 | shutil.copy(str(input), str(output))
46 |
47 | # run
48 | idx = 0
49 | while idx < len(converter):
50 | with multiprocessing.Pool(num_workers) as pool:
51 | end = min(idx + (num_workers or multiprocessing.cpu_count()), len(converter))  # num_workers may be None (Pool's default); fall back to cpu_count for chunking
52 | results = pool.map(worker, [(input, i) for i in range(idx, end)])
53 |
54 | # modify action
55 | print('Writing {} to {}'.format(idx, end))
56 | with h5py.File(output, 'r+') as out_file:
57 | for i in tqdm(range(idx, end), desc="Writing to output"):
58 | obss = results[i - idx]
59 | demo = out_file[f'data/demo_{i}']
60 | del demo['obs']
61 | for k in obss:
62 | demo.create_dataset("obs/{}".format(k), data=np.array(obss[k]), compression="gzip")
63 |
64 | idx = end
65 | del results
66 |
67 |
68 | if __name__ == "__main__":
69 | main()
70 |
--------------------------------------------------------------------------------
/equi_diffpo/shared_memory/shared_memory_queue.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Union
2 | import numbers
3 | from queue import (Empty, Full)
4 | from multiprocessing.managers import SharedMemoryManager
5 | import numpy as np
6 | from equi_diffpo.shared_memory.shared_memory_util import ArraySpec, SharedAtomicCounter
7 | from equi_diffpo.shared_memory.shared_ndarray import SharedNDArray
8 |
9 |
10 | class SharedMemoryQueue:
11 | """
12 | A Lock-Free FIFO Shared Memory Data Structure.
13 | Stores a sequence of dict of numpy arrays.
14 | """
15 |
16 | def __init__(self,
17 | shm_manager: SharedMemoryManager,
18 | array_specs: List[ArraySpec],
19 | buffer_size: int
20 | ):
21 |
22 | # create atomic counter
23 | write_counter = SharedAtomicCounter(shm_manager)
24 | read_counter = SharedAtomicCounter(shm_manager)
25 |
26 | # allocate shared memory
27 | shared_arrays = dict()
28 | for spec in array_specs:
29 | key = spec.name
30 | assert key not in shared_arrays
31 | array = SharedNDArray.create_from_shape(
32 | mem_mgr=shm_manager,
33 | shape=(buffer_size,) + tuple(spec.shape),
34 | dtype=spec.dtype)
35 | shared_arrays[key] = array
36 |
37 | self.buffer_size = buffer_size
38 | self.array_specs = array_specs
39 | self.write_counter = write_counter
40 | self.read_counter = read_counter
41 | self.shared_arrays = shared_arrays
42 |
43 | @classmethod
44 | def create_from_examples(cls,
45 | shm_manager: SharedMemoryManager,
46 | examples: Dict[str, Union[np.ndarray, numbers.Number]],
47 | buffer_size: int
48 | ):
49 | specs = list()
50 | for key, value in examples.items():
51 | shape = None
52 | dtype = None
53 | if isinstance(value, np.ndarray):
54 | shape = value.shape
55 | dtype = value.dtype
56 | assert dtype != np.dtype('O')
57 | elif isinstance(value, numbers.Number):
58 | shape = tuple()
59 | dtype = np.dtype(type(value))
60 | else:
61 | raise TypeError(f'Unsupported type {type(value)}')
62 |
63 | spec = ArraySpec(
64 | name=key,
65 | shape=shape,
66 | dtype=dtype
67 | )
68 | specs.append(spec)
69 |
70 | obj = cls(
71 | shm_manager=shm_manager,
72 | array_specs=specs,
73 | buffer_size=buffer_size
74 | )
75 | return obj
76 |
77 | def qsize(self):
78 | read_count = self.read_counter.load()
79 | write_count = self.write_counter.load()
80 | n_data = write_count - read_count
81 | return n_data
82 |
83 | def empty(self):
84 | n_data = self.qsize()
85 | return n_data <= 0
86 |
87 | def clear(self):
88 | self.read_counter.store(self.write_counter.load())
89 |
90 | def put(self, data: Dict[str, Union[np.ndarray, numbers.Number]]):
91 | read_count = self.read_counter.load()
92 | write_count = self.write_counter.load()
93 | n_data = write_count - read_count
94 | if n_data >= self.buffer_size:
95 | raise Full()
96 |
97 | next_idx = write_count % self.buffer_size
98 |
99 | # write to shared memory
100 | for key, value in data.items():
101 | arr: np.ndarray
102 | arr = self.shared_arrays[key].get()
103 | if isinstance(value, np.ndarray):
104 | arr[next_idx] = value
105 | else:
106 | arr[next_idx] = np.array(value, dtype=arr.dtype)
107 |
108 | # update idx
109 | self.write_counter.add(1)
110 |
111 | def get(self, out=None) -> Dict[str, np.ndarray]:
112 | write_count = self.write_counter.load()
113 | read_count = self.read_counter.load()
114 | n_data = write_count - read_count
115 | if n_data <= 0:
116 | raise Empty()
117 |
118 | if out is None:
119 | out = self._allocate_empty()
120 |
121 | next_idx = read_count % self.buffer_size
122 | for key, value in self.shared_arrays.items():
123 | arr = value.get()
124 | np.copyto(out[key], arr[next_idx])
125 |
126 | # update idx
127 | self.read_counter.add(1)
128 | return out
129 |
130 | def get_k(self, k, out=None) -> Dict[str, np.ndarray]:
131 | write_count = self.write_counter.load()
132 | read_count = self.read_counter.load()
133 | n_data = write_count - read_count
134 | if n_data <= 0:
135 | raise Empty()
136 | assert k <= n_data
137 |
138 | out = self._get_k_impl(k, read_count, out=out)
139 | self.read_counter.add(k)
140 | return out
141 |
142 | def get_all(self, out=None) -> Dict[str, np.ndarray]:
143 | write_count = self.write_counter.load()
144 | read_count = self.read_counter.load()
145 | n_data = write_count - read_count
146 | if n_data <= 0:
147 | raise Empty()
148 |
149 | out = self._get_k_impl(n_data, read_count, out=out)
150 | self.read_counter.add(n_data)
151 | return out
152 |
153 | def _get_k_impl(self, k, read_count, out=None) -> Dict[str, np.ndarray]:
154 | if out is None:
155 | out = self._allocate_empty(k)
156 |
157 | curr_idx = read_count % self.buffer_size
158 | for key, value in self.shared_arrays.items():
159 | arr = value.get()
160 | target = out[key]
161 |
162 | start = curr_idx
163 | end = min(start + k, self.buffer_size)
164 | target_start = 0
165 | target_end = (end - start)
166 | target[target_start: target_end] = arr[start:end]
167 |
168 | remainder = k - (end - start)
169 | if remainder > 0:
170 | # wrap around
171 | start = 0
172 | end = start + remainder
173 | target_start = target_end
174 | target_end = k
175 | target[target_start: target_end] = arr[start:end]
176 |
177 | return out
178 |
179 | def _allocate_empty(self, k=None):
180 | result = dict()
181 | for spec in self.array_specs:
182 | shape = spec.shape
183 | if k is not None:
184 | shape = (k,) + shape
185 | result[spec.name] = np.empty(
186 | shape=shape, dtype=spec.dtype)
187 | return result
188 |
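Editorial note: a round-trip sketch. The `create_from_examples` / `put` / `get` calls are the class's own API; the field names are illustrative. The two counters suggest a single-producer, single-consumer usage pattern.

from multiprocessing.managers import SharedMemoryManager
import numpy as np

with SharedMemoryManager() as shm_manager:
    queue = SharedMemoryQueue.create_from_examples(
        shm_manager=shm_manager,
        examples={'eef_pos': np.zeros(3, dtype=np.float32), 'timestamp': 0.0},
        buffer_size=8)

    queue.put({'eef_pos': np.array([0.1, 0.2, 0.3], dtype=np.float32), 'timestamp': 1.0})
    sample = queue.get()                      # dict of freshly allocated np arrays
    print(sample['eef_pos'], queue.empty())   # [0.1 0.2 0.3] True
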
--------------------------------------------------------------------------------
/equi_diffpo/shared_memory/shared_memory_util.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from dataclasses import dataclass
3 | import numpy as np
4 | from multiprocessing.managers import SharedMemoryManager
5 | from atomics import atomicview, MemoryOrder, UINT
6 |
7 | @dataclass
8 | class ArraySpec:
9 | name: str
10 | shape: Tuple[int]
11 | dtype: np.dtype
12 |
13 |
14 | class SharedAtomicCounter:
15 | def __init__(self,
16 | shm_manager: SharedMemoryManager,
17 | size: int = 8  # 64-bit int
18 | ):
19 | shm = shm_manager.SharedMemory(size=size)
20 | self.shm = shm
21 | self.size = size
22 | self.store(0) # initialize
23 |
24 | @property
25 | def buf(self):
26 | return self.shm.buf[:self.size]
27 |
28 | def load(self) -> int:
29 | with atomicview(buffer=self.buf, atype=UINT) as a:
30 | value = a.load(order=MemoryOrder.ACQUIRE)
31 | return value
32 |
33 | def store(self, value: int):
34 | with atomicview(buffer=self.buf, atype=UINT) as a:
35 | a.store(value, order=MemoryOrder.RELEASE)
36 |
37 | def add(self, value: int):
38 | with atomicview(buffer=self.buf, atype=UINT) as a:
39 | a.add(value, order=MemoryOrder.ACQ_REL)
40 |
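Editorial note: the counter is just an atomic unsigned integer living in a SharedMemoryManager block; a minimal sketch of its use, not part of the repository:

from multiprocessing.managers import SharedMemoryManager

with SharedMemoryManager() as shm_manager:
    counter = SharedAtomicCounter(shm_manager)
    counter.add(5)
    counter.add(2)
    print(counter.load())   # 7
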
--------------------------------------------------------------------------------
/equi_diffpo/workspace/base_workspace.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import os
3 | import pathlib
4 | import hydra
5 | import copy
6 | from hydra.core.hydra_config import HydraConfig
7 | from omegaconf import OmegaConf
8 | import dill
9 | import torch
10 | import threading
11 |
12 |
13 | class BaseWorkspace:
14 | include_keys = tuple()
15 | exclude_keys = tuple()
16 |
17 | def __init__(self, cfg: OmegaConf, output_dir: Optional[str]=None):
18 | self.cfg = cfg
19 | self._output_dir = output_dir
20 | self._saving_thread = None
21 |
22 | @property
23 | def output_dir(self):
24 | output_dir = self._output_dir
25 | if output_dir is None:
26 | output_dir = HydraConfig.get().runtime.output_dir
27 | return output_dir
28 |
29 | def run(self):
30 | """
31 | Create any resources that shouldn't be serialized as local variables here.
32 | """
33 | pass
34 |
35 | def save_checkpoint(self, path=None, tag='latest',
36 | exclude_keys=None,
37 | include_keys=None,
38 | use_thread=True):
39 | if path is None:
40 | path = pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')
41 | else:
42 | path = pathlib.Path(path)
43 | if exclude_keys is None:
44 | exclude_keys = tuple(self.exclude_keys)
45 | if include_keys is None:
46 | include_keys = tuple(self.include_keys) + ('_output_dir',)
47 |
48 | path.parent.mkdir(parents=False, exist_ok=True)
49 | payload = {
50 | 'cfg': self.cfg,
51 | 'state_dicts': dict(),
52 | 'pickles': dict()
53 | }
54 |
55 | for key, value in self.__dict__.items():
56 | if hasattr(value, 'state_dict') and hasattr(value, 'load_state_dict'):
57 | # modules, optimizers and samplers etc
58 | if key not in exclude_keys:
59 | if use_thread:
60 | payload['state_dicts'][key] = _copy_to_cpu(value.state_dict())
61 | else:
62 | payload['state_dicts'][key] = value.state_dict()
63 | elif key in include_keys:
64 | payload['pickles'][key] = dill.dumps(value)
65 | if use_thread:
66 | self._saving_thread = threading.Thread(
67 | target=lambda : torch.save(payload, path.open('wb'), pickle_module=dill))
68 | self._saving_thread.start()
69 | else:
70 | torch.save(payload, path.open('wb'), pickle_module=dill)
71 | return str(path.absolute())
72 |
73 | def get_checkpoint_path(self, tag='latest'):
74 | return pathlib.Path(self.output_dir).joinpath('checkpoints', f'{tag}.ckpt')
75 |
76 | def load_payload(self, payload, exclude_keys=None, include_keys=None, **kwargs):
77 | if exclude_keys is None:
78 | exclude_keys = tuple()
79 | if include_keys is None:
80 | include_keys = payload['pickles'].keys()
81 |
82 | for key, value in payload['state_dicts'].items():
83 | if key not in exclude_keys:
84 | self.__dict__[key].load_state_dict(value, **kwargs)
85 | for key in include_keys:
86 | if key in payload['pickles']:
87 | self.__dict__[key] = dill.loads(payload['pickles'][key])
88 |
89 | def load_checkpoint(self, path=None, tag='latest',
90 | exclude_keys=None,
91 | include_keys=None,
92 | **kwargs):
93 | if path is None:
94 | path = self.get_checkpoint_path(tag=tag)
95 | else:
96 | path = pathlib.Path(path)
97 | payload = torch.load(path.open('rb'), pickle_module=dill, **kwargs)
98 | self.load_payload(payload,
99 | exclude_keys=exclude_keys,
100 | include_keys=include_keys)
101 | return payload
102 |
103 | @classmethod
104 | def create_from_checkpoint(cls, path,
105 | exclude_keys=None,
106 | include_keys=None,
107 | **kwargs):
108 | payload = torch.load(open(path, 'rb'), pickle_module=dill)
109 | instance = cls(payload['cfg'])
110 | instance.load_payload(
111 | payload=payload,
112 | exclude_keys=exclude_keys,
113 | include_keys=include_keys,
114 | **kwargs)
115 | return instance
116 |
117 | def save_snapshot(self, tag='latest'):
118 | """
119 | Quick loading and saving for research; saves the full state of the workspace.
120 |
121 | However, loading a snapshot assumes the code stays exactly the same.
122 | Use save_checkpoint for long-term storage.
123 | """
124 | path = pathlib.Path(self.output_dir).joinpath('snapshots', f'{tag}.pkl')
125 | path.parent.mkdir(parents=False, exist_ok=True)
126 | torch.save(self, path.open('wb'), pickle_module=dill)
127 | return str(path.absolute())
128 |
129 | @classmethod
130 | def create_from_snapshot(cls, path):
131 | return torch.load(open(path, 'rb'), pickle_module=dill)
132 |
133 |
134 | def _copy_to_cpu(x):
135 | if isinstance(x, torch.Tensor):
136 | return x.detach().to('cpu')
137 | elif isinstance(x, dict):
138 | result = dict()
139 | for k, v in x.items():
140 | result[k] = _copy_to_cpu(v)
141 | return result
142 | elif isinstance(x, list):
143 | return [_copy_to_cpu(k) for k in x]
144 | else:
145 | return copy.deepcopy(x)
146 |
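Editorial note: a sketch of the checkpoint round trip. The payload keys mirror `save_checkpoint` above; `MyWorkspace` and the `/tmp/ws_demo` path are illustrative stand-ins:

import os
import dill
import torch
from omegaconf import OmegaConf

class MyWorkspace(BaseWorkspace):   # minimal stand-in subclass
    pass

os.makedirs('/tmp/ws_demo', exist_ok=True)
ws = MyWorkspace(OmegaConf.create({}), output_dir='/tmp/ws_demo')
path = ws.save_checkpoint(use_thread=False)
payload = torch.load(open(path, 'rb'), pickle_module=dill)
print(payload.keys())               # dict_keys(['cfg', 'state_dicts', 'pickles'])

restored = MyWorkspace.create_from_checkpoint(path)   # fresh instance, state restored
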
--------------------------------------------------------------------------------
/img/equi.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pointW/equidiff/e40abb003b25071d4e5b01bfa9933dc16cd32c67/img/equi.gif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'equi_diffpo',
5 | packages = find_packages(),
6 | )
7 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | Training:
4 | python train.py --config-name=train_diffusion_lowdim_workspace
5 | """
6 |
7 | import sys
8 | # use line-buffering for both stdout and stderr
9 | sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
10 | sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)
11 |
12 | import hydra
13 | from omegaconf import OmegaConf
14 | import pathlib
15 | from equi_diffpo.workspace.base_workspace import BaseWorkspace
16 |
17 | max_steps = {
18 | 'stack_d1': 400,
19 | 'stack_three_d1': 400,
20 | 'square_d2': 400,
21 | 'threading_d2': 400,
22 | 'coffee_d2': 400,
23 | 'three_piece_assembly_d2': 500,
24 | 'hammer_cleanup_d1': 500,
25 | 'mug_cleanup_d1': 500,
26 | 'kitchen_d1': 800,
27 | 'nut_assembly_d0': 500,
28 | 'pick_place_d0': 1000,
29 | 'coffee_preparation_d1': 800,
30 | 'tool_hang': 700,
31 | 'can': 400,
32 | 'lift': 400,
33 | 'square': 400,
34 | }
35 |
36 | def get_ws_x_center(task_name):
37 | if task_name.startswith('kitchen_') or task_name.startswith('hammer_cleanup_'):
38 | return -0.2
39 | else:
40 | return 0.
41 |
42 | def get_ws_y_center(task_name):
43 | return 0.
44 |
45 | OmegaConf.register_new_resolver("get_max_steps", lambda x: max_steps[x], replace=True)
46 | OmegaConf.register_new_resolver("get_ws_x_center", get_ws_x_center, replace=True)
47 | OmegaConf.register_new_resolver("get_ws_y_center", get_ws_y_center, replace=True)
48 |
49 | # allows arbitrary python code execution in configs using the ${eval:''} resolver
50 | OmegaConf.register_new_resolver("eval", eval, replace=True)
51 |
52 | @hydra.main(
53 | version_base=None,
54 | config_path=str(pathlib.Path(__file__).parent.joinpath(
55 | 'equi_diffpo','config'))
56 | )
57 | def main(cfg: OmegaConf):
58 | # resolve immediately so all the ${now:} resolvers
59 | # will use the same time.
60 | OmegaConf.resolve(cfg)
61 |
62 | cls = hydra.utils.get_class(cfg._target_)
63 | workspace: BaseWorkspace = cls(cfg)
64 | workspace.run()
65 |
66 | if __name__ == "__main__":
67 | main()
68 |
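Editorial note: a sketch of how the resolvers registered above expand inside a config. The keys here are illustrative, not from a shipped config file, and the registrations must have run first (e.g. after importing this module):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'task_name': 'square',
    'max_steps': '${get_max_steps:${task_name}}',
    'ws_x': '${get_ws_x_center:${task_name}}',
})
OmegaConf.resolve(cfg)
print(cfg.max_steps, cfg.ws_x)   # 400 0.0
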
--------------------------------------------------------------------------------