├── media ├── aloha_teaser.png └── aloha_dataset.png ├── .gitignore ├── test.py ├── data4robotics ├── models │ ├── __init__.py │ ├── base.py │ ├── action_distributions.py │ ├── vit.py │ ├── action_transformer.py │ ├── resnet.py │ ├── diffusion_unet.py │ └── diffusion.py ├── trainers │ ├── __init__.py │ ├── bc.py │ ├── utils.py │ └── base.py ├── sim │ ├── __init__.py │ ├── README.md │ ├── base.py │ └── robosuite.py ├── __init__.py ├── load_pretrained.py ├── misc.py ├── task.py ├── transforms.py ├── replay_buffer.py └── agent.py ├── .flake8 ├── experiments ├── agent │ ├── features │ │ ├── r3m.yaml │ │ ├── vit_base.yaml │ │ ├── robomimic.yaml │ │ ├── resnet_gn.yaml │ │ └── resnet_gn_nopool.yaml │ ├── policy │ │ ├── gaussian_mixture.yaml │ │ └── gaussian_constant.yaml │ ├── default.yaml │ ├── diffusion_unet.yaml │ ├── transformer.yaml │ └── diffusion.yaml ├── trainer │ ├── bc.yaml │ ├── bc_step_sched.yaml │ └── bc_cos_sched.yaml ├── hydra │ └── launcher │ │ └── slurm.yaml ├── task │ ├── robomimic_can.yaml │ ├── robomimic_lift.yaml │ ├── robomimic_square.yaml │ ├── robomimic_toolhang.yaml │ ├── viperx.yaml │ ├── aloha.yaml │ ├── end_effector.yaml │ ├── end_effector_r6.yaml │ └── aloha_gc.yaml └── finetune.yaml ├── setup.py ├── download_features.sh ├── pretrained_networks_example.py ├── env.yml ├── LICENSE.md ├── .pre-commit-config.yaml ├── eval_scripts ├── README.md ├── eval_droid.py ├── eval_droid_state.py └── eval_aloha.py ├── jobs.sh ├── diffuse_jobs.sh ├── test_agent.py ├── finetune.py └── README.md /media/aloha_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SudeepDasari/dit-policy/HEAD/media/aloha_teaser.png -------------------------------------------------------------------------------- /media/aloha_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SudeepDasari/dit-policy/HEAD/media/aloha_dataset.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.jpg 2 | *.gif 3 | *.mp4 4 | *.png 5 | __pycache__/ 6 | bc_finetune/ 7 | visual_features/ 8 | *.egg-info/ 9 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from data4robotics.models.resnet import ResNet 2 | 3 | m = ResNet(size=34, weights='IMAGENET1K_V1', norm_cfg=dict(name='batch_norm')) 4 | 5 | -------------------------------------------------------------------------------- /data4robotics/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /data4robotics/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | -------------------------------------------------------------------------------- /data4robotics/sim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from .base import SimTask, SimTaskReplayBuffer 8 | -------------------------------------------------------------------------------- /data4robotics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from .load_pretrained import load_resnet18, load_vit 8 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git,data4robotics/sim/__init__.py 3 | max-line-length = 95 4 | select = E,F,W,C 5 | ignore=W503, 6 | E203, 7 | E731, 8 | E722, 9 | F841, 10 | E402, 11 | E741, 12 | E501, 13 | C406, 14 | -------------------------------------------------------------------------------- /experiments/agent/features/r3m.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.resnet.R3M 8 | size: 18 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | from setuptools import setup 8 | 9 | setup(name="data4robotics", packages=["data4robotics"], version="0.1") 10 | -------------------------------------------------------------------------------- /download_features.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # make sure the folder doesn't already exist 4 | if [ -d "visual_features" ]; then 5 | echo "Data already downloaded!" 6 | exit 0 7 | fi 8 | 9 | wget --output-document features.zip https://www.cs.cmu.edu/~data4robotics/release/features.zip 10 | unzip features.zip 11 | rm features.zip 12 | -------------------------------------------------------------------------------- /experiments/agent/policy/gaussian_mixture.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | _target_: data4robotics.models.action_distributions.GaussianMixture 8 | in_dim: ${index:${agent.shared_mlp}, -1} 9 | ac_dim: ${task.ac_dim} 10 | ac_chunk: ${ac_chunk} 11 | num_modes: 5 12 | -------------------------------------------------------------------------------- /experiments/agent/policy/gaussian_constant.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.action_distributions.GaussianSharedScale 8 | in_dim: ${index:${agent.shared_mlp}, -1} 9 | ac_dim: ${task.ac_dim} 10 | ac_chunk: ${ac_chunk} 11 | std_fixed: True 12 | -------------------------------------------------------------------------------- /experiments/agent/features/vit_base.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.vit.load_vit 8 | restore_path: '' 9 | model: 10 | _target_: data4robotics.models.vit.vit_base_patch16 11 | img_size: 224 12 | use_cls: True 13 | drop_path_rate: 0.0 14 | -------------------------------------------------------------------------------- /experiments/agent/features/robomimic.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.resnet.RobomimicResNet 8 | size: 18 9 | weights: 'IMAGENET1K_V1' # for torchvision weight restore 10 | norm_cfg: 11 | name: diffusion_policy 12 | img_size: 224 13 | feature_dim: 64 14 | -------------------------------------------------------------------------------- /experiments/trainer/bc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.trainers.bc.BehaviorCloning 8 | 9 | optim_builder: 10 | _target_: data4robotics.trainers.utils.optim_builder 11 | optimizer_type: Adam 12 | optimizer_kwargs: 13 | lr: ${lr} 14 | weight_decay: 0.0001 15 | -------------------------------------------------------------------------------- /experiments/hydra/launcher/slurm.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | defaults: 8 | - submitit_slurm 9 | 10 | timeout_min: 360 11 | partition: default 12 | tasks_per_node: ${devices} 13 | gpus_per_node: ${devices} 14 | cpus_per_task: ${num_workers} 15 | mem_gb: ${mult:${devices},124} 16 | nodes: 1 17 | max_num_timeout: 100 18 | -------------------------------------------------------------------------------- /experiments/agent/features/resnet_gn.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.resnet.ResNet 8 | size: 18 9 | weights: 'IMAGENET1K_V1' # for torchvision weight restore 10 | restore_path: '' # for restoring our custom weights 11 | avg_pool: True 12 | norm_cfg: 13 | name: group_norm 14 | num_groups: 16 15 | -------------------------------------------------------------------------------- /experiments/agent/features/resnet_gn_nopool.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.models.resnet.ResNet 8 | size: 18 9 | weights: 'IMAGENET1K_V1' # for torchvision weight restore 10 | restore_path: '' # for restoring our custom weights 11 | avg_pool: False # return all final image tokens 12 | conv_repeat: ${mult:${agent.early_fusion},${agent.imgs_per_cam}} 13 | norm_cfg: 14 | name: group_norm 15 | num_groups: 16 16 | -------------------------------------------------------------------------------- /experiments/agent/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | defaults: 8 | - features: vit_base 9 | - policy: gaussian_mixture 10 | - _self_ 11 | 12 | _target_: data4robotics.agent.MLPAgent 13 | shared_mlp: [512,512] 14 | odim: ${task.obs_dim} 15 | n_cams: ${task.n_cams} 16 | use_obs: add_token 17 | dropout: 0.2 18 | imgs_per_cam: ${add:${img_chunk},${len:${task.train_buffer.goal_indexes}}} 19 | share_cam_features: False 20 | early_fusion: False 21 | feat_norm: batch_norm 22 | -------------------------------------------------------------------------------- /experiments/trainer/bc_step_sched.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | _target_: data4robotics.trainers.bc.BehaviorCloning 8 | 9 | optim_builder: 10 | _target_: data4robotics.trainers.utils.optim_builder 11 | optimizer_type: Adam 12 | optimizer_kwargs: 13 | lr: ${lr} 14 | weight_decay: 0.0001 15 | 16 | schedule_builder: 17 | _target_: data4robotics.trainers.utils.schedule_builder 18 | schedule_type: 'StepLR' 19 | schedule_kwargs: 20 | step_size: 1 21 | gamma: 0.99997 22 | -------------------------------------------------------------------------------- /experiments/task/robomimic_can.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.sim.robosuite.RoboSuiteTask 8 | obs_dim: 11 9 | ac_dim: 7 10 | n_cams: 1 11 | task: can 12 | test_transform: ${transform:preproc} 13 | 14 | train_buffer: 15 | _target_: data4robotics.sim.robosuite.RoboSuiteBuffer 16 | task: ${task.task} 17 | transform: ${transform:${train_transform}} 18 | n_train_demos: 200 19 | ac_chunk: ${ac_chunk} 20 | cam_indexes: [0] 21 | goal_indexes: [] 22 | past_frames: ${add:${img_chunk},-1} 23 | -------------------------------------------------------------------------------- /experiments/task/robomimic_lift.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.sim.robosuite.RoboSuiteTask 8 | obs_dim: 11 9 | ac_dim: 7 10 | n_cams: 1 11 | task: lift 12 | test_transform: ${transform:preproc} 13 | 14 | train_buffer: 15 | _target_: data4robotics.sim.robosuite.RoboSuiteBuffer 16 | task: ${task.task} 17 | transform: ${transform:${train_transform}} 18 | n_train_demos: 200 19 | ac_chunk: ${ac_chunk} 20 | cam_indexes: [0] 21 | goal_indexes: [] 22 | past_frames: ${add:${img_chunk},-1} 23 | -------------------------------------------------------------------------------- /experiments/task/robomimic_square.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.sim.robosuite.RoboSuiteTask 8 | obs_dim: 11 9 | ac_dim: 7 10 | n_cams: 1 11 | task: square 12 | test_transform: ${transform:preproc} 13 | 14 | train_buffer: 15 | _target_: data4robotics.sim.robosuite.RoboSuiteBuffer 16 | task: ${task.task} 17 | transform: ${transform:${train_transform}} 18 | n_train_demos: 200 19 | ac_chunk: ${ac_chunk} 20 | cam_indexes: [0] 21 | goal_indexes: [] 22 | past_frames: ${add:${img_chunk},-1} 23 | -------------------------------------------------------------------------------- /experiments/task/robomimic_toolhang.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | _target_: data4robotics.sim.robosuite.RoboSuiteTask 8 | obs_dim: 11 9 | ac_dim: 7 10 | n_cams: 1 11 | task: tool_hang 12 | test_transform: ${transform:preproc} 13 | 14 | train_buffer: 15 | _target_: data4robotics.sim.robosuite.RoboSuiteBuffer 16 | task: ${task.task} 17 | transform: ${transform:${train_transform}} 18 | n_train_demos: 200 19 | ac_chunk: ${ac_chunk} 20 | cam_indexes: [0] 21 | goal_indexes: [] 22 | past_frames: ${add:${img_chunk},-1} 23 | -------------------------------------------------------------------------------- /experiments/trainer/bc_cos_sched.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.trainers.bc.BehaviorCloning 8 | 9 | optim_builder: 10 | _target_: data4robotics.trainers.utils.optim_builder 11 | optimizer_type: AdamW 12 | optimizer_kwargs: 13 | lr: ${lr} 14 | betas: [0.95, 0.999] 15 | weight_decay: 1.0e-6 16 | eps: 1.0e-8 17 | 18 | schedule_builder: 19 | _target_: data4robotics.trainers.utils.schedule_builder 20 | schedule_type: 'cosine' 21 | from_diffusers: True 22 | schedule_kwargs: 23 | num_warmup_steps: 2000 24 | num_training_steps: ${max_iterations} 25 | -------------------------------------------------------------------------------- /pretrained_networks_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | from data4robotics import load_resnet18, load_vit 9 | 10 | # load strongest vit/resnet models 11 | vit_transform, vit_model = load_vit() 12 | res_transform, res_model = load_resnet18() 13 | 14 | 15 | # get embeddings from each network 16 | input_img = torch.rand((1, 3, 480, 640)).cuda() 17 | emb_vit = vit_model(vit_transform(input_img)) 18 | emb_res = res_model(res_transform(input_img)) 19 | 20 | 21 | # print out shapes 22 | print("vit_base embedding shape:", emb_vit.shape) 23 | print("resnet18 embedding shape:", emb_res.shape) 24 | -------------------------------------------------------------------------------- /experiments/agent/diffusion_unet.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | defaults: 8 | - features: robomimic 9 | - _self_ 10 | 11 | _target_: data4robotics.models.diffusion_unet.DiffusionUnetAgent 12 | odim: ${task.obs_dim} 13 | n_cams: ${task.n_cams} 14 | use_obs: False 15 | dropout: 0.1 16 | train_diffusion_steps: 100 17 | eval_diffusion_steps: 16 18 | ac_dim: ${task.ac_dim} 19 | ac_chunk: ${ac_chunk} 20 | imgs_per_cam: ${add:${img_chunk},${len:${task.train_buffer.goal_indexes}}} 21 | share_cam_features: False 22 | early_fusion: False 23 | feat_norm: null 24 | 25 | noise_net_kwargs: 26 | diffusion_step_embed_dim: 256 27 | down_dims: [256, 512, 1024] 28 | kernel_size: 3 29 | n_groups: 8 30 | -------------------------------------------------------------------------------- /data4robotics/sim/README.md: -------------------------------------------------------------------------------- 1 | # RoboSuite Evaluation 2 | 3 | We provide a bare-bones implementation to reproduce our `robomimic` sim evaluation results. First, install the following `dm-control` versions of [robosuite](https://github.com/SudeepDasari/robosuite/tree/restore_dit) and [robomimic](https://github.com/SudeepDasari/robomimic), along with their associated dependencies. You will also have to [download](https://github.com/SudeepDasari/robomimic/blob/restore_dit/robomimic/scripts/download_datasets.py) the robomimic dataset (no camera obs required in download) into `/path/to/robomimic/downloads`. Then run: 4 | 5 | ``` 6 | python finetune.py exp_name=test agent=diffusion task=[robomimic_lift/can/square/toolhang] agent/features=resnet_gn agent.features.restore_path=/path/to/resnet18/IN_1M_resnet18.pth trainer=bc_cos_sched ac_chunk=10 eval_freq=15000 batch_size=350 7 | ``` 8 | -------------------------------------------------------------------------------- /data4robotics/models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class BaseModel(nn.Module): 12 | def __init__(self, model, restore_path): 13 | super().__init__() 14 | self._model = model 15 | if restore_path: 16 | print("Restoring model from", restore_path) 17 | state_dict = torch.load(restore_path, map_location="cpu") 18 | state_dict = ( 19 | state_dict["features"] 20 | if "features" in state_dict 21 | else state_dict["model"] 22 | ) 23 | self.load_state_dict(state_dict) 24 | 25 | @property 26 | def embed_dim(self): 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /experiments/agent/transformer.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | defaults: 7 | - features: resnet_gn_nopool 8 | - _self_ 9 | 10 | _target_: data4robotics.models.action_transformer.TransformerAgent 11 | odim: ${task.obs_dim} 12 | n_cams: ${task.n_cams} 13 | use_obs: add_token 14 | dropout: 0.1 15 | ac_dim: ${task.ac_dim} 16 | ac_chunk: ${ac_chunk} 17 | imgs_per_cam: ${add:${img_chunk},${len:${task.train_buffer.goal_indexes}}} 18 | share_cam_features: False 19 | early_fusion: True 20 | feat_norm: layer_norm 21 | token_dim: 512 22 | 23 | transformer_kwargs: 24 | d_model: ${agent.token_dim} 25 | dropout: ${agent.dropout} 26 | nhead: 8 27 | num_encoder_layers: 4 28 | num_decoder_layers: 6 29 | dim_feedforward: 3200 30 | activation: relu 31 | -------------------------------------------------------------------------------- /experiments/agent/diffusion.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | defaults: 8 | - features: resnet_gn_nopool 9 | - _self_ 10 | 11 | _target_: data4robotics.models.diffusion.DiffusionTransformerAgent 12 | odim: ${task.obs_dim} 13 | n_cams: ${task.n_cams} 14 | use_obs: add_token 15 | dropout: 0.1 16 | train_diffusion_steps: 100 17 | eval_diffusion_steps: 8 18 | ac_dim: ${task.ac_dim} 19 | ac_chunk: ${ac_chunk} 20 | imgs_per_cam: ${add:${img_chunk},${len:${task.train_buffer.goal_indexes}}} 21 | share_cam_features: False 22 | early_fusion: False 23 | feat_norm: null 24 | 25 | noise_net_kwargs: 26 | time_dim: 256 27 | hidden_dim: 512 28 | num_blocks: 6 29 | dim_feedforward: 2048 30 | dropout: ${agent.dropout} 31 | nhead: 8 32 | activation: "gelu" 33 | -------------------------------------------------------------------------------- /data4robotics/trainers/bc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from data4robotics.trainers.base import BaseTrainer 8 | 9 | 10 | class BehaviorCloning(BaseTrainer): 11 | def training_step(self, batch, global_step): 12 | (imgs, obs), actions, mask = batch 13 | imgs = {k: v.to(self.device_id) for k, v in imgs.items()} 14 | obs, actions, mask = [ar.to(self.device_id) for ar in (obs, actions, mask)] 15 | 16 | ac_flat = actions.reshape((actions.shape[0], -1)) 17 | mask_flat = mask.reshape((mask.shape[0], -1)) 18 | loss = self.model(imgs, obs, ac_flat, mask_flat) 19 | self.log("bc_loss", global_step, loss.item()) 20 | if self.is_train: 21 | self.log("lr", global_step, self.lr) 22 | return loss 23 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: data4robotics 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python ==3.9 9 | - pytorch=2.0.1=py3.9_cuda11.8_cudnn8.7.0_0 10 | - torchvision 11 | - pytorch-cuda=11.8 12 | - cmake=3.22.1=h1fce559_0 13 | - numpy 14 | - pandas 15 | - plotly 16 | - pip 17 | - pytest==7.3.1 18 | - scipy 19 | - tqdm 20 | - patchelf 21 | - pip: 22 | - transforms3d 23 | - opencv-python 24 | - hydra-ray-launcher==1.2.0 25 | - hydra-core==1.2.0 26 | - hydra-submitit-launcher==1.2.0 27 | - wandb==0.13.4 28 | - timm==0.6.11 29 | - gym==0.23.1 30 | - huggingface-hub==0.12.1 31 | - dm-control==1.0.10 32 | - dm-env==1.6 33 | - dm-tree==0.1.8 34 | - cloudpickle==2.0.0 35 | - mujoco==2.3.2 36 | - imageio==2.22.1 37 | - imageio-ffmpeg==0.4.7 38 | - diffusers==0.25.0 39 | - pre-commit == 3.3.3 40 | -------------------------------------------------------------------------------- /experiments/task/viperx.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.task.BCTask 8 | obs_dim: 8 9 | ac_dim: 7 10 | n_cams: ${len:${task.train_buffer.cam_indexes}} 11 | 12 | 13 | train_buffer: 14 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 15 | buffer_path: ${buffer_path} 16 | transform: ${transform:${train_transform}} 17 | n_test_trans: 500 18 | ac_chunk: ${ac_chunk} 19 | mode: train 20 | cam_indexes: [0, 1] 21 | past_frames: ${add:${img_chunk},-1} 22 | ac_dim: ${task.ac_dim} 23 | 24 | test_buffer: 25 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 26 | buffer_path: ${buffer_path} 27 | transform: ${transform:preproc} 28 | n_test_trans: 500 29 | ac_chunk: ${ac_chunk} 30 | mode: test 31 | cam_indexes: ${task.train_buffer.cam_indexes} 32 | past_frames: ${add:${img_chunk},-1} 33 | ac_dim: ${task.ac_dim} 34 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /experiments/task/aloha.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.task.BCTask 8 | obs_dim: 14 9 | ac_dim: 14 10 | n_cams: ${len:${task.train_buffer.cam_indexes}} 11 | 12 | 13 | train_buffer: 14 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 15 | buffer_path: ${buffer_path} 16 | transform: ${transform:${train_transform}} 17 | n_test_trans: 500 18 | ac_chunk: ${ac_chunk} 19 | mode: train 20 | cam_indexes: [0, 1, 3] 21 | goal_indexes: [] 22 | past_frames: ${add:${img_chunk},-1} 23 | ac_dim: ${task.ac_dim} 24 | 25 | test_buffer: 26 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 27 | buffer_path: ${buffer_path} 28 | transform: ${transform:preproc} 29 | n_test_trans: 500 30 | ac_chunk: ${ac_chunk} 31 | mode: test 32 | cam_indexes: ${task.train_buffer.cam_indexes} 33 | goal_indexes: ${task.train_buffer.goal_indexes} 34 | past_frames: ${add:${img_chunk},-1} 35 | ac_dim: ${task.ac_dim} 36 | -------------------------------------------------------------------------------- /experiments/task/end_effector.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | _target_: data4robotics.task.BCTask 8 | obs_dim: 7 9 | ac_dim: 7 10 | n_cams: ${len:${task.train_buffer.cam_indexes}} 11 | 12 | 13 | train_buffer: 14 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 15 | buffer_path: ${buffer_path} 16 | transform: ${transform:${train_transform}} 17 | n_test_trans: 500 18 | ac_chunk: ${ac_chunk} 19 | mode: train 20 | cam_indexes: [0] 21 | goal_indexes: [] 22 | past_frames: ${add:${img_chunk},-1} 23 | ac_dim: ${task.ac_dim} 24 | 25 | test_buffer: 26 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 27 | buffer_path: ${buffer_path} 28 | transform: ${transform:preproc} 29 | n_test_trans: 500 30 | ac_chunk: ${ac_chunk} 31 | mode: test 32 | cam_indexes: ${task.train_buffer.cam_indexes} 33 | goal_indexes: ${task.train_buffer.goal_indexes} 34 | past_frames: ${add:${img_chunk},-1} 35 | ac_dim: ${task.ac_dim} 36 | -------------------------------------------------------------------------------- /experiments/task/end_effector_r6.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | _target_: data4robotics.task.BCTask 8 | obs_dim: 7 9 | ac_dim: 10 10 | n_cams: ${len:${task.train_buffer.cam_indexes}} 11 | 12 | 13 | train_buffer: 14 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 15 | buffer_path: ${buffer_path} 16 | transform: ${transform:${train_transform}} 17 | n_test_trans: 500 18 | ac_chunk: ${ac_chunk} 19 | mode: train 20 | cam_indexes: [0] 21 | goal_indexes: [] 22 | past_frames: ${add:${img_chunk},-1} 23 | ac_dim: ${task.ac_dim} 24 | 25 | test_buffer: 26 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 27 | buffer_path: ${buffer_path} 28 | transform: ${transform:preproc} 29 | n_test_trans: 500 30 | ac_chunk: ${ac_chunk} 31 | mode: test 32 | cam_indexes: ${task.train_buffer.cam_indexes} 33 | goal_indexes: ${task.train_buffer.goal_indexes} 34 | past_frames: ${add:${img_chunk},-1} 35 | ac_dim: ${task.ac_dim} 36 | -------------------------------------------------------------------------------- /experiments/task/aloha_gc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | _target_: data4robotics.task.BCTask 8 | obs_dim: 14 9 | ac_dim: 14 10 | n_cams: ${len:${task.train_buffer.cam_indexes}} 11 | 12 | 13 | train_buffer: 14 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 15 | buffer_path: ${buffer_path} 16 | transform: ${transform:${train_transform}} 17 | n_test_trans: 500 18 | ac_chunk: ${ac_chunk} 19 | mode: train 20 | cam_indexes: [0, 1, 3] 21 | goal_indexes: [0] 22 | past_frames: ${add:${img_chunk},-1} 23 | ac_dim: ${task.ac_dim} 24 | 25 | test_buffer: 26 | _target_: data4robotics.replay_buffer.RobobufReplayBuffer 27 | buffer_path: ${buffer_path} 28 | transform: ${transform:preproc} 29 | n_test_trans: 500 30 | ac_chunk: ${ac_chunk} 31 | mode: test 32 | cam_indexes: ${task.train_buffer.cam_indexes} 33 | goal_indexes: ${task.train_buffer.goal_indexes} 34 | past_frames: ${add:${img_chunk},-1} 35 | ac_dim: ${task.ac_dim} 36 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.9 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v2.3.0 6 | hooks: 7 | - id: check-yaml 8 | - id: check-ast 9 | - id: check-added-large-files 10 | exclude: ^examples/ 11 | - id: check-case-conflict 12 | - id: check-merge-conflict 13 | - id: end-of-file-fixer 14 | - id: trailing-whitespace 15 | - id: detect-private-key 16 | - id: debug-statements 17 | exclude: ^experiments/ 18 | - repo: https://github.com/psf/black 19 | rev: 22.10.0 20 | hooks: 21 | - id: black 22 | exclude: ^experiments/ 23 | - repo: https://github.com/PyCQA/flake8 24 | rev: 6.1.0 25 | hooks: 26 | - id: flake8 27 | exclude: ^experiments/ 28 | - repo: https://github.com/pycqa/isort 29 | rev: 5.12.0 30 | hooks: 31 | - id: isort 32 | exclude: ^experiments/ 33 | args: ["--profile", "black", "--src", "data4robotics", "--src", "experiments"] 34 | - repo: https://github.com/srstevenson/nb-clean 35 | rev: 3.1.0 36 | hooks: 37 | - id: nb-clean 38 | args: 39 | - --remove-empty-cells 40 | - --preserve-cell-outputs 41 | -------------------------------------------------------------------------------- /experiments/finetune.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | defaults: 8 | - agent: default 9 | - task: end_effector 10 | - trainer: bc 11 | - override hydra/launcher: slurm 12 | - _self_ 13 | 14 | 15 | hydra: 16 | run: 17 | dir: bc_finetune/${exp_name}/wandb_${wandb.name}_${hydra:runtime.choices.task}_${hydra:runtime.choices.agent/features}_${now:%Y-%m-%d_%H-%M-%S} 18 | sweep: 19 | dir: ${base:}/../bc_finetune/${exp_name}/${now:%Y-%m-%d_%H-%M-%S} 20 | subdir: run${hydra:job.num}_${hydra:runtime.choices.task}_${hydra:runtime.choices.agent/features} 21 | 22 | rt: ${hydra:runtime.choices.agent/features} 23 | 24 | exp_name: test 25 | checkpoint_path: ${exp_name}.ckpt 26 | batch_size: 150 27 | num_workers: 10 28 | lr: 0.0003 29 | max_iterations: 150000 30 | eval_freq: 10000 31 | save_freq: 10000 32 | schedule_freq: 1 33 | devices: 1 34 | seed: 292285 35 | 36 | buffer_path: ./buffer.pkl 37 | ac_chunk: 1 38 | img_chunk: 1 # number of image timesteps to use (including current one) 39 | train_transform: medium 40 | 41 | wandb: 42 | name: null 43 | project: human_aloha 44 | group: ${exp_name} 45 | sweep_name_prefix: eval 46 | debug: False 47 | entity: cmu-agi-lab 48 | -------------------------------------------------------------------------------- /eval_scripts/README.md: -------------------------------------------------------------------------------- 1 | # Eval Scripts 2 | 3 | We provide some example evaluation scripts that will allow you to run our policies on ALOHA and DROID robots. 4 | * `eval_aloha.py` will deploy our policies on an ALOHA robot assuming the default 14-DoF joint-state action space. It provides an implementation for both temporal ensembling and receding horizon control with chunked action predictions. 5 | * `eval_droid.py` will deploy our policies on a DROID robot using the default cartesian velocity action space. Predicted actions are directly executed on the robot (no test time smoothing). 6 | * `eval_droid_state.py` will deploy our policies on a DROID robot using the cartesian position action space. Note that the rotation actions are predicted using an R6 representation (conversion code [here](https://github.com/AGI-Labs/r2d2_to_robobuf/blob/main/converter.py)) following [Chi et al.](https://diffusion-policy.cs.columbia.edu). After prediction, the actions are further smoothed with temporal ensembling. This action space is ideal for diffusion policy (U-Net) action heads. 7 | 8 | 9 | ## Setup Instructions 10 | 11 | Just download the policy folder (produced by `finetune.py`) and add a file named `obs_config.yaml` to it. This will tell the eval script how to process the observations for the policy.
An example is provided below: 12 | 13 | ``` 14 | img: '26638268_left' 15 | transform: 16 | _target_: data4robotics.transforms.get_transform_by_name 17 | name: preproc 18 | ``` 19 | -------------------------------------------------------------------------------- /jobs.sh: -------------------------------------------------------------------------------- 1 | # vc-1 training command (velocity action space, gaussian mlp policy) 2 | nice -n 19 python finetune.py agent/policy=gaussian_constant exp_name=octo_baselines wandb.name=vc1_baseline buffer_path=/path/to/vel/buf.pkl max_iterations=50000 task.train_buffer.cam_indexes=[] train_transform=hard agent.features.restore_path=/path/to/vc1.pth 3 | 4 | # r3m training command (velocity action space, gaussian mlp policy) 5 | nice -n 19 python finetune.py agent/features=r3m agent/policy=gaussian_constant exp_name=octo_baselines wandb.name=r3m_baseline buffer_path=/path/to/vel/buf.pkl max_iterations=50000 task.train_buffer.cam_indexes=[] train_transform=medium agent.features.size=50 6 | 7 | # single-cam diffusion (position + r6 rotation action space) 8 | nice -n 19 python finetune.py agent=diffusion_unet exp_name=octo_baselines wandb.name=diffusion_singlecam buffer_path=/path/to/abs_r6/buf.pkl max_iterations=500000 trainer=bc_cos_sched ac_chunk=16 train_transform=medium task.train_buffer.cam_indexes=[] agent.features.feature_dim=256 9 | 10 | # wrist-cam + 2-step obs diffusion (position + r6 rotation action space) 11 | nice -n 19 python finetune.py agent=diffusion_unet exp_name=octo_baselines wandb.name=diffusion_multicam buffer_path=/path/to/abs_r6/buf.pkl max_iterations=500000 trainer=bc_cos_sched ac_chunk=16 train_transform=medium task.train_buffer.cam_indexes=[, ] task.train_buffer.cam_indexes=[0,2] img_chunk=2 -------------------------------------------------------------------------------- /data4robotics/trainers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | # schedule_builder inspired from Diffusion Policy Codebase (Chi et al; arXiv:2303.04137) 3 | 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import functools 9 | 10 | import torch.optim as optim 11 | from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType 12 | from torch.optim import lr_scheduler 13 | 14 | 15 | def optim_builder(optimizer_type, optimizer_kwargs): 16 | optimizer_class = getattr(optim, optimizer_type) 17 | return functools.partial(optimizer_class, **optimizer_kwargs) 18 | 19 | 20 | def schedule_builder(schedule_type, schedule_kwargs, from_diffusers=False): 21 | if from_diffusers: 22 | schedule_type = SchedulerType(schedule_type) 23 | schedule_func = TYPE_TO_SCHEDULER_FUNCTION[schedule_type] 24 | 25 | if schedule_type == SchedulerType.CONSTANT: 26 | return functools.partial(schedule_func, **schedule_kwargs) 27 | 28 | assert ( 29 | "num_warmup_steps" in schedule_kwargs 30 | ), "Scheduler requires num_warmup_steps!" 31 | if schedule_type == SchedulerType.CONSTANT_WITH_WARMUP: 32 | return functools.partial(schedule_func, **schedule_kwargs) 33 | 34 | # All other schedulers require `num_training_steps` 35 | assert ( 36 | "num_training_steps" in schedule_kwargs 37 | ), "Scheduler requires num_training_steps!" 
38 | return functools.partial(schedule_func, **schedule_kwargs) 39 | schedule_class = getattr(lr_scheduler, schedule_type) 40 | return functools.partial(schedule_class, **schedule_kwargs) 41 | -------------------------------------------------------------------------------- /data4robotics/load_pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import os 8 | 9 | import torch 10 | from torchvision import transforms 11 | 12 | import data4robotics 13 | from data4robotics.models import resnet, vit 14 | 15 | # feature install path 16 | BASE_PATH = os.path.dirname(data4robotics.__file__) + "/../" 17 | FEATURE_PATH = os.path.join(BASE_PATH, "visual_features") 18 | 19 | 20 | def _check_and_download(): 21 | old_cwd = os.getcwd() 22 | 23 | # change cwd to main folder and run download script 24 | os.chdir(BASE_PATH) 25 | download_script = os.path.join(BASE_PATH, "download_features.sh") 26 | os.system(download_script) 27 | 28 | # change cwd back to old location 29 | os.chdir(old_cwd) 30 | 31 | 32 | def default_transform(): 33 | return transforms.Compose( 34 | [ 35 | transforms.Resize((224, 224), antialias=False), 36 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 37 | ] 38 | ) 39 | 40 | 41 | def load_vit(model_name="IN_hrp", device=torch.device("cuda:0")): 42 | _check_and_download() 43 | model = vit.vit_base_patch16(img_size=224, use_cls=True, drop_path_rate=0.0) 44 | 45 | restore_path = ( 46 | f"hrp/{model_name}.pth" if "hrp" in model_name else f"vit_base/{model_name}.pth" 47 | ) 48 | restore_path = os.path.join(FEATURE_PATH, restore_path) 49 | model = vit.load_vit(model, restore_path) 50 | return default_transform(), model.to(device) 51 | 52 | 53 | def load_resnet18(model_name="IN_1M_resnet18", device=torch.device("cuda:0")): 54 | _check_and_download() 55 | restore_path = os.path.join(FEATURE_PATH, f"resnet18/{model_name}.pth") 56 | model = resnet.ResNet( 57 | size=18, 58 | pretrained=None, 59 | restore_path=restore_path, 60 | norm_cfg=dict(name="group_norm", num_groups=16), 61 | ) 62 | return default_transform(), model.to(device) 63 | -------------------------------------------------------------------------------- /eval_scripts/eval_droid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | from pathlib import Path 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import yaml 11 | 12 | # r2d2 robot imports 13 | from droid.user_interface.eval_gui import EvalGUI 14 | 15 | import hydra 16 | 17 | 18 | class DROIDPolicy: 19 | def __init__(self, agent_path, model_name): 20 | with open(Path(agent_path, "agent_config.yaml"), "r") as f: 21 | config_yaml = f.read() 22 | agent_config = yaml.safe_load(config_yaml) 23 | with open(Path(agent_path, "obs_config.yaml"), "r") as f: 24 | config_yaml = f.read() 25 | obs_config = yaml.safe_load(config_yaml) 26 | with open(Path(agent_path, "ac_norm.json"), "r") as f: 27 | ac_norm_dict = json.load(f) 28 | loc, scale = ac_norm_dict["loc"], ac_norm_dict["scale"] 29 | self.loc = np.array(loc).astype(np.float32) 30 | self.scale = np.array(scale).astype(np.float32) 31 | 32 | agent = hydra.utils.instantiate(agent_config) 33 | save_dict = torch.load(Path(agent_path, model_name), map_location="cpu") 34 | 
agent.load_state_dict(save_dict["model"]) 35 | self.agent = agent.eval().cuda() 36 | 37 | self.transform = hydra.utils.instantiate(obs_config["transform"]) 38 | self.img_key = obs_config["img"] 39 | 40 | print(f"loaded agent from {agent_path}, at step: {save_dict['global_step']}") 41 | self._last_time = None 42 | 43 | def _proc_image(self, zed_img, size=(256, 256)): 44 | bgr_img = zed_img[:, :, :3] 45 | bgr_img = cv2.resize(bgr_img, size, interpolation=cv2.INTER_AREA) 46 | rgb_img = bgr_img[:, :, ::-1].copy() 47 | rgb_img = torch.from_numpy(rgb_img).float().permute((2, 0, 1)) / 255 48 | return {"cam0": self.transform(rgb_img)[None].cuda()} 49 | 50 | def _proc_state(self, cart_pos, grip_pos): 51 | state = np.concatenate((cart_pos, np.array([grip_pos]))).astype(np.float32) 52 | return torch.from_numpy(state)[None].cuda() 53 | 54 | def forward(self, obs): 55 | img = self._proc_image(obs["image"][self.img_key]) 56 | state = self._proc_state( 57 | obs["robot_state"]["cartesian_position"], 58 | obs["robot_state"]["gripper_position"], 59 | ) 60 | with torch.no_grad(): 61 | ac = self.agent.get_actions(img, state) 62 | ac = ac[0, 0].cpu().numpy().astype(np.float32) 63 | ac = np.clip(ac * self.scale + self.loc, -1, 1) # denormalize the actions 64 | 65 | cur_time = time.time() 66 | if self._last_time is not None: 67 | print("Effective HZ:", 1.0 / (cur_time - self._last_time)) 68 | self._last_time = cur_time 69 | return ac 70 | 71 | def load_goal_imgs(self, goal_dict): 72 | pass 73 | 74 | def load_lang(self, text): 75 | pass 76 | 77 | 78 | def main(): 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("checkpoint") 81 | args = parser.parse_args() 82 | 83 | agent_path = os.path.expanduser(os.path.dirname(args.checkpoint)) 84 | model_name = args.checkpoint.split("/")[-1] 85 | policy = DROIDPolicy(agent_path, model_name) 86 | 87 | # start up DROID eval gui 88 | EvalGUI(policy=policy) 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /diffuse_jobs.sh: -------------------------------------------------------------------------------- 1 | nice -n 19 python finetune.py exp_name=t3_base3 agent=diffusion agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=False task.train_buffer.cam_idx=2 max_iterations=100000 trainer=bc_step_sched 2 | nice -n 19 python finetune.py exp_name=t3_base3 agent=diffusion agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=False task.train_buffer.cam_idx=2 max_iterations=100000 trainer=bc_step_sched agent.noise_net_kwargs.hidden_dim=512 3 | 4 | nice -n 19 python finetune.py exp_name=t3_base2 agent=diffusion agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=False task.train_buffer.cam_idx=2 trainer.lr=0.0001 trainer.weight_decay=0.0001; 5 | nice -n 19 python finetune.py exp_name=t3_base2 agent=diffusion agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=False task.train_buffer.cam_idx=2 trainer.lr=0.0005; 6 | nice -n 19 python finetune.py exp_name=t3_base2 agent=diffusion agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=False task.train_buffer.cam_idx=2 trainer.lr=0.001; 7 | nice -n 19 python finetune.py exp_name=t3_base2 agent=diffusion 
agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=False task.train_buffer.cam_idx=2 trainer.lr=0.0001 trainer.weight_decay=0; 8 | 9 | # baselines 10 | # nice -n 19 python finetune.py exp_name=t3_base2 agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=False task.train_buffer.cam_idx=2 11 | 12 | # test 13 | # python finetune.py exp_name=test agent=diffusion agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth buffer_path=/scratch/sudeep/toaster3/buf.pkl wandb.debug=True task.train_buffer.cam_idx=2 batch_size=5 14 | 15 | 16 | nice -n 19 python finetune.py exp_name=diffuse_toaster3 agent=diffusion_unet buffer_path=/scratch/sudeep/toaster3/diffusion/buf.pkl max_iterations=500000 trainer=bc_cos_sched ac_chunk=16 train_transform=medium task.train_buffer.cam_indexes=[0,2] img_chunk=2 17 | nice -n 19 python finetune.py exp_name=diffuse_toaster3 agent=diffusion_unet buffer_path=/scratch/sudeep/toaster3/diffusion/buf.pkl max_iterations=500000 trainer=bc_cos_sched ac_chunk=16 train_transform=medium task.train_buffer.cam_indexes=[2] agent.features.feature_dim=256 18 | 19 | 20 | # baseline command (ImageNet + GMM) 21 | nice -n 19 python finetune.py exp_name=test wandb.name=vit_baseline buffer_path=/scratch/sudeep/toaster3/vel/buf.pkl max_iterations=50000 task.train_buffer.cam_indexes=[2] train_transform=hard agent.features.restore_path=/home/sudeep/dp_pt/IN_1M.pth 22 | 23 | # w/ old net 24 | 25 | nice -n 19 python finetune.py exp_name=diffuse_toaster3 agent=diffusion task=end_effector_r6 buffer_path=/scratch/sudeep/toaster3/diffusion/buf.pkl max_iterations=500000 trainer=bc_cos_sched ac_chunk=16 train_transform=medium task.train_buffer.cam_indexes=[2] 26 | 27 | 28 | # old 29 | # nice -n 19 python finetune.py exp_name=diffuse_toaster3 agent=diffusion_unet buffer_path=/scratch/sudeep/toaster3/abs/buf.pkl task.train_buffer.cam_idx=2 max_iterations=500000 trainer=bc_cos_sched ac_chunk=16 train_transform=diffusion 30 | # nice -n 19 python finetune.py exp_name=test buffer_path=/scratch/sudeep/r2d2_octo/berkeley_pick_data/buf.pkl wandb.debug=True task.train_buffer.cam_idx=0 max_iterations=500000 trainer=bc_cos_sched ac_chunk=16 train_transform=diffusion batch_size=5 -------------------------------------------------------------------------------- /data4robotics/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import functools 8 | import os 9 | import signal 10 | import sys 11 | 12 | import wandb 13 | import yaml 14 | from hydra.core.hydra_config import HydraConfig 15 | from omegaconf import OmegaConf 16 | 17 | from data4robotics.transforms import get_transform_by_name 18 | 19 | OmegaConf.register_new_resolver("env", lambda x: os.environ[x]) 20 | OmegaConf.register_new_resolver( 21 | "base", lambda: os.path.dirname(os.path.abspath(__file__)) 22 | ) 23 | OmegaConf.register_new_resolver("transform", lambda name: get_transform_by_name(name)) 24 | OmegaConf.register_new_resolver("mult", lambda x, y: int(x) * int(y)) 25 | OmegaConf.register_new_resolver("add", lambda x, y: int(x) + int(y)) 26 | OmegaConf.register_new_resolver("index", lambda arr, idx: arr[idx]) 27 | OmegaConf.register_new_resolver("len", lambda arr: len(arr)) 28 | 29 | 30 | GLOBAL_STEP = 0 31 | REQUEUE_CAUGHT = False 32 | 33 | 34 | def _signal_helper(signal, frame, prior_handler, trainer, ckpt_path): 35 | global REQUEUE_CAUGHT, GLOBAL_STEP 36 | REQUEUE_CAUGHT = True 37 | 38 | # save train checkpoint 39 | print(f"Caught requeue signal at step: {GLOBAL_STEP}") 40 | trainer.save_checkpoint(ckpt_path, GLOBAL_STEP) 41 | 42 | # return back to submitit handler if it exists 43 | if callable(prior_handler): 44 | return prior_handler(signal, frame) 45 | return sys.exit(-1) 46 | 47 | 48 | def set_checkpoint_handler(trainer, ckpt_path): 49 | global REQUEUE_CAUGHT 50 | REQUEUE_CAUGHT = False 51 | prior_handler = signal.getsignal(signal.SIGUSR2) 52 | handler = functools.partial( 53 | _signal_helper, 54 | prior_handler=prior_handler, 55 | trainer=trainer, 56 | ckpt_path=ckpt_path, 57 | ) 58 | signal.signal(signal.SIGUSR2, handler) 59 | 60 | 61 | def create_wandb_run(wandb_cfg, job_config, run_id=None): 62 | if wandb_cfg.debug: 63 | return "null_id" 64 | try: 65 | job_id = HydraConfig().get().job.num 66 | override_dirname = HydraConfig().get().job.override_dirname 67 | name = f"{wandb_cfg.sweep_name_prefix}-{job_id}" 68 | notes = f"{override_dirname}" 69 | except: 70 | name, notes = wandb_cfg.name, None 71 | 72 | wandb_run = wandb.init( 73 | project=wandb_cfg.project, 74 | group=wandb_cfg.group, 75 | entity=wandb_cfg.entity, 76 | config=job_config, 77 | name=name, 78 | notes=notes, 79 | id=run_id, 80 | resume=run_id is not None, 81 | ) 82 | return wandb_run.id 83 | 84 | 85 | def init_job(cfg): 86 | cfg_yaml = OmegaConf.to_yaml(cfg) 87 | if os.path.exists("exp_config.yaml"): 88 | old_config = yaml.safe_load(open("exp_config.yaml", "r")) 89 | create_wandb_run(cfg.wandb, old_config["params"], old_config["wandb_id"]) 90 | resume_model = cfg.checkpoint_path 91 | assert os.path.exists(resume_model), "{} does not exist!".format( 92 | cfg.checkpoint_path 93 | ) 94 | else: 95 | params = yaml.safe_load(cfg_yaml) 96 | wandb_id = create_wandb_run(cfg.wandb, params) 97 | save_dict = dict(wandb_id=wandb_id, params=params) 98 | yaml.dump(save_dict, open("exp_config.yaml", "w")) 99 | resume_model = None 100 | print("Training w/ Config:") 101 | print(cfg_yaml) 102 | return resume_model 103 | -------------------------------------------------------------------------------- /data4robotics/trainers/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | from abc import ABC, abstractmethod 8 | 9 | import numpy as np 10 | import torch 11 | import wandb 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | 14 | TRAIN_LOG_FREQ, EVAL_LOG_FREQ = 100, 1 15 | 16 | 17 | class RunningMean: 18 | def __init__(self, max_len=TRAIN_LOG_FREQ): 19 | self._values = [] 20 | self._ctr, self._max_len = 0, max_len 21 | 22 | def append(self, item): 23 | self._ctr = (self._ctr + 1) % self._max_len 24 | if len(self._values) < self._max_len: 25 | self._values.append(item) 26 | else: 27 | self._values[self._ctr] = item 28 | 29 | @property 30 | def mean(self): 31 | if len(self._values) == 0: 32 | raise ValueError 33 | return np.mean(self._values) 34 | 35 | 36 | class BaseTrainer(ABC): 37 | def __init__(self, model, device_id, optim_builder, schedule_builder=None): 38 | self.model, self.device_id = model, device_id 39 | self.set_device(device_id) 40 | self.optim = optim_builder(self.model.parameters()) 41 | self.schedule = ( 42 | None if schedule_builder is None else schedule_builder(self.optim) 43 | ) 44 | self._trackers = dict() 45 | self._is_train = True 46 | self.set_train() 47 | 48 | @abstractmethod 49 | def training_step(self, batch_input, global_step): 50 | pass 51 | 52 | @property 53 | def lr(self): 54 | if self.schedule is None: 55 | return self.optim.param_groups[0]["lr"] 56 | return self.schedule.get_last_lr()[0] 57 | 58 | def step_schedule(self): 59 | if self.schedule is None: 60 | return 61 | self.schedule.step() 62 | 63 | def save_checkpoint(self, save_path, global_step): 64 | model = self.model 65 | model_weights = ( 66 | model.module.state_dict() if isinstance(model, DDP) else model.state_dict() 67 | ) 68 | schedule_state = dict() if self.schedule is None else self.schedule.state_dict() 69 | save_dict = dict( 70 | model=model_weights, 71 | optim=self.optim.state_dict(), 72 | schedule=schedule_state, 73 | global_step=global_step, 74 | ) 75 | torch.save(save_dict, save_path) 76 | 77 | def load_checkpoint(self, load_path): 78 | load_dict = torch.load(load_path) 79 | model = self.model 80 | model = model.module if isinstance(model, DDP) else model 81 | model.load_state_dict(load_dict["model"]) 82 | 83 | self.optim.load_state_dict(load_dict["optim"]) 84 | if self.schedule is not None: 85 | self.schedule.load_state_dict(load_dict["schedule"]) 86 | 87 | return load_dict["global_step"] 88 | 89 | def _load_callback(self, load_path, load_dict): 90 | pass 91 | 92 | @property 93 | def is_train(self): 94 | return self._is_train 95 | 96 | def set_train(self): 97 | self._is_train = True 98 | self.model = self.model.train() 99 | 100 | def set_eval(self): 101 | self._is_train = False 102 | self.model = self.model.eval() 103 | 104 | # reset running mean for eval trackers 105 | for k in self._trackers: 106 | if "eval/" in k: 107 | self._trackers[k] = RunningMean() 108 | 109 | def log(self, key, global_step, value): 110 | log_freq = TRAIN_LOG_FREQ if self._is_train else EVAL_LOG_FREQ 111 | key_prepend = "train/" if self._is_train else "eval/" 112 | key = key_prepend + key 113 | 114 | if key not in self._trackers: 115 | self._trackers[key] = RunningMean() 116 | 117 | tracker = self._trackers[key] 118 | tracker.append(value) 119 | 120 | if global_step % log_freq == 0 and wandb.run is not None: 121 | wandb.log({key: tracker.mean}, step=global_step) 122 | 123 | def set_device(self, device_id): 124 | self.model = self.model.to(device_id) 125 | -------------------------------------------------------------------------------- /data4robotics/task.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import numpy as np 8 | import torch 9 | import wandb 10 | from torch.utils.data import DataLoader, IterableDataset 11 | 12 | from data4robotics.replay_buffer import IterableWrapper 13 | 14 | _TEST_WORKERS = 4 15 | 16 | 17 | def _build_data_loader(buffer, batch_size, num_workers, is_train=False): 18 | if is_train and not isinstance(buffer, IterableDataset): 19 | buffer = IterableWrapper(buffer) 20 | 21 | return DataLoader( 22 | buffer, 23 | batch_size=batch_size, 24 | num_workers=num_workers, 25 | shuffle=not isinstance(buffer, IterableDataset), 26 | pin_memory=True, 27 | persistent_workers=num_workers > 0, 28 | drop_last=True, 29 | worker_init_fn=lambda _: np.random.seed(), 30 | ) 31 | 32 | 33 | class DefaultTask: 34 | def __init__( 35 | self, 36 | train_buffer, 37 | test_buffer, 38 | n_cams, 39 | obs_dim, 40 | ac_dim, 41 | batch_size, 42 | num_workers, 43 | ): 44 | self.n_cams, self.obs_dim, self.ac_dim = n_cams, obs_dim, ac_dim 45 | self.train_loader = _build_data_loader( 46 | train_buffer, batch_size, num_workers, is_train=True 47 | ) 48 | if test_buffer is not None: 49 | test_workers = min(num_workers, _TEST_WORKERS) 50 | self.test_loader = _build_data_loader(test_buffer, batch_size, test_workers) 51 | 52 | def eval(self, trainer, global_step): 53 | losses = [] 54 | for batch in self.test_loader: 55 | with torch.no_grad(): 56 | loss = trainer.training_step(batch, global_step) 57 | losses.append(loss.item()) 58 | 59 | mean_val_loss = np.mean(losses) 60 | print(f"Step: {global_step}\tVal Loss: {mean_val_loss:.4f}") 61 | if wandb.run is not None: 62 | wandb.log({"eval/task_loss": mean_val_loss}, step=global_step) 63 | 64 | 65 | class BCTask(DefaultTask): 66 | def eval(self, trainer, global_step): 67 | losses = [] 68 | action_l2, action_lsig = [], [] 69 | for batch in self.test_loader: 70 | (imgs, obs), actions, mask = batch 71 | imgs = {k: v.to(trainer.device_id) for k, v in imgs.items()} 72 | obs, actions, mask = [ 73 | ar.to(trainer.device_id) for ar in (obs, actions, mask) 74 | ] 75 | 76 | with torch.no_grad(): 77 | loss = trainer.training_step(batch, global_step) 78 | losses.append(loss.item()) 79 | 80 | # compare predicted actions versus GT 81 | pred_actions = trainer.model.get_actions(imgs, obs) 82 | 83 | # calculate l2 loss between pred_action and action 84 | l2_delta = torch.square(mask * (pred_actions - actions)) 85 | l2_delta = l2_delta.sum((1, 2)) / mask.sum((1, 2)) 86 | 87 | # calculate the % of time the signs agree 88 | lsig = torch.logical_or( 89 | torch.logical_and(actions > 0, pred_actions <= 0), 90 | torch.logical_and(actions <= 0, pred_actions > 0), 91 | ) 92 | lsig = (lsig.float() * mask).sum((1, 2)) / mask.sum((1, 2)) 93 | 94 | # log mean error values 95 | action_l2.append(l2_delta.mean().item()) 96 | action_lsig.append(lsig.mean().item()) 97 | 98 | mean_val_loss = np.mean(losses) 99 | ac_l2, ac_lsig = np.mean(action_l2), np.mean(action_lsig) 100 | print(f"Step: {global_step}\tVal Loss: {mean_val_loss:.4f}") 101 | print(f"Step: {global_step}\tAC L2={ac_l2:.2f}\tAC LSig={ac_lsig:.2f}") 102 | 103 | if wandb.run is not None: 104 | wandb.log( 105 | { 106 | "eval/task_loss": mean_val_loss, 107 | "eval/action_l2": ac_l2, 108 | "eval/action_lsig": ac_lsig, 109 | }, 110 | step=global_step, 111 | ) 112 | 
-------------------------------------------------------------------------------- /test_agent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import time 4 | import cv2 5 | import hydra 6 | import numpy as np 7 | import torch 8 | import yaml 9 | import os 10 | import json 11 | import random 12 | import pickle as pkl 13 | from robobuf import ReplayBuffer as RB 14 | from data4robotics.transforms import get_transform_by_name 15 | 16 | 17 | # constants for data loading 18 | BUF_SHUFFLE_RNG = 3904767649 # from replay_buffer.py 19 | n_test_trans = 500 # usually hardocded in task/franka.yaml 20 | 21 | 22 | class BaselinePolicy: 23 | def __init__(self, agent_path, model_name): 24 | with open(Path(agent_path, "agent_config.yaml"), "r") as f: 25 | config_yaml = f.read() 26 | agent_config = yaml.safe_load(config_yaml) 27 | with open(Path(agent_path, "exp_config.yaml"), "r") as f: 28 | config_yaml = f.read() 29 | exp_config = yaml.safe_load(config_yaml) 30 | self.cam_idx = exp_config['params']['task']['train_buffer']['cam_idx'] 31 | 32 | agent = hydra.utils.instantiate(agent_config) 33 | save_dict = torch.load(Path(agent_path, model_name), map_location="cpu") 34 | agent.load_state_dict(save_dict['model']) 35 | self.agent = agent.eval().cuda() 36 | 37 | self.transform = get_transform_by_name('preproc') 38 | 39 | def _proc_image(self, rgb_img, size=(256,256)): 40 | rgb_img = cv2.resize(rgb_img, size, interpolation=cv2.INTER_AREA) 41 | rgb_img = torch.from_numpy(rgb_img).float().permute((2, 0, 1)) / 255 42 | return self.transform(rgb_img)[None].cuda() 43 | 44 | def forward(self, img, obs): 45 | img = self._proc_image(img) 46 | state = torch.from_numpy(obs)[None].float().cuda() 47 | 48 | with torch.no_grad(): 49 | ac = self.agent.get_actions(img, state) 50 | ac = ac[0].cpu().numpy().astype(np.float32) 51 | return ac 52 | 53 | @property 54 | def ac_chunk(self): 55 | return self.agent.ac_chunk 56 | 57 | 58 | def _get_data(idx, buf, ac_chunk, cam_idx): 59 | t = buf[idx] 60 | loop_t, chunked_actions = t, [] 61 | for _ in range(ac_chunk): 62 | if loop_t.next is None: 63 | break 64 | chunked_actions.append(loop_t.action[None]) 65 | loop_t = loop_t.next 66 | 67 | if len(chunked_actions) != ac_chunk: 68 | raise ValueError 69 | 70 | i_t, o_t = t.obs.image(cam_idx), t.obs.state 71 | i_t_prime, o_t_prime = t.next.obs.image(cam_idx), t.next.obs.state 72 | a_t = np.concatenate(chunked_actions, 0) 73 | return i_t, o_t, a_t 74 | 75 | 76 | def main(): 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument("checkpoint") 79 | parser.add_argument("--buffer_path", default='/scratch/sudeep/toaster3/buf.pkl') 80 | args = parser.parse_args() 81 | 82 | agent_path = os.path.expanduser(os.path.dirname(args.checkpoint)) 83 | model_name = args.checkpoint.split('/')[-1] 84 | policy = BaselinePolicy(agent_path, model_name) 85 | 86 | # build data loader 87 | cam_idx = policy.cam_idx 88 | print('cam_idx:', cam_idx) 89 | with open(args.buffer_path, 'rb') as f: 90 | buf = RB.load_traj_list(pkl.load(f)) 91 | 92 | # shuffle the list with the fixed seed 93 | rng = random.Random(BUF_SHUFFLE_RNG) 94 | 95 | # get and shuffle list of buf indices, and get test data 96 | index_list = list(range(len(buf))) 97 | rng.shuffle(index_list) 98 | index_list = index_list[:n_test_trans] 99 | 100 | l2s, lsigs = [], [] 101 | for idx in index_list[:50]: 102 | i_t, o_t, a_t = _get_data(idx, buf, policy.ac_chunk, cam_idx) 103 | pred_ac = policy.forward(i_t, o_t) 
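# NOTE: a_t is the ground-truth action chunk from the buffer and pred_ac is the
# policy's predicted chunk; the metrics below are the raw L2 distance between the
# two chunks and the count of per-dimension sign disagreements.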
104 | 105 | # calculate deltas 106 | l2 = np.linalg.norm(a_t - pred_ac) 107 | lsign = np.sum(np.logical_or(np.logical_and(a_t > 0, pred_ac <= 0), 108 | np.logical_and(a_t <= 0, pred_ac > 0))) 109 | l2s.append(l2); lsigs.append(lsign) 110 | 111 | print('\n') 112 | print('a_t', a_t) 113 | print('pred_ac', pred_ac) 114 | print(f'losses: l2={l2:0.2f}\tlsign={lsign}') 115 | print('\n') 116 | 117 | print(f'avg losses: l2={np.mean(l2s):0.3f}\tlsign={np.mean(lsigs):0.3f}') 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import os 8 | import traceback 9 | 10 | import numpy as np 11 | import torch 12 | import tqdm 13 | from omegaconf import DictConfig, OmegaConf 14 | 15 | import hydra 16 | from data4robotics import misc, transforms 17 | 18 | base_path = os.path.dirname(os.path.abspath(__file__)) 19 | 20 | 21 | @hydra.main( 22 | config_path=os.path.join(base_path, "experiments"), config_name="finetune.yaml" 23 | ) 24 | def bc_finetune(cfg: DictConfig): 25 | try: 26 | resume_model = misc.init_job(cfg) 27 | 28 | # set random seeds for reproducibility 29 | torch.manual_seed(cfg.seed) 30 | np.random.seed(cfg.seed + 1) 31 | 32 | # build agent from hydra configs 33 | with open("agent_config.yaml", "w") as f: 34 | agent_yaml = OmegaConf.to_yaml(cfg.agent, resolve=True) 35 | f.write(agent_yaml) 36 | 37 | agent = hydra.utils.instantiate(cfg.agent) 38 | trainer = hydra.utils.instantiate(cfg.trainer, model=agent, device_id=0) 39 | 40 | # build task, replay buffer, and dataloader 41 | task = hydra.utils.instantiate( 42 | cfg.task, batch_size=cfg.batch_size, num_workers=cfg.num_workers 43 | ) 44 | 45 | # create a gpu train transform (if used) 46 | gpu_transform = ( 47 | transforms.get_gpu_transform_by_name(cfg.train_transform) 48 | if "gpu" in cfg.train_transform 49 | else None 50 | ) 51 | 52 | # restore/save the model as required 53 | if resume_model is not None: 54 | misc.GLOBAL_STEP = trainer.load_checkpoint(resume_model) 55 | elif misc.GLOBAL_STEP == 0: 56 | trainer.save_checkpoint(cfg.checkpoint_path, misc.GLOBAL_STEP) 57 | assert misc.GLOBAL_STEP >= 0, "GLOBAL_STEP not loaded correctly!" 
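# NOTE: when resuming, GLOBAL_STEP is restored from the checkpoint above and the
# loop below fast-forwards past iterations with itr < GLOBAL_STEP, so the schedule,
# eval, and save frequencies stay aligned with the original run.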
58 | 59 | # register checkpoint handler and enter train loop 60 | misc.set_checkpoint_handler(trainer, cfg.checkpoint_path) 61 | print(f"Starting at Global Step {misc.GLOBAL_STEP}") 62 | 63 | trainer.set_train() 64 | train_iterator = iter(task.train_loader) 65 | for itr in ( 66 | pbar := tqdm.tqdm(range(cfg.max_iterations), postfix=dict(Loss=None)) 67 | ): 68 | if itr < misc.GLOBAL_STEP: 69 | continue 70 | 71 | # infinitely sample batches until the train loop is finished 72 | try: 73 | batch = next(train_iterator) 74 | except StopIteration: 75 | train_iterator = iter(task.train_loader) 76 | batch = next(train_iterator) 77 | 78 | # handle the image transform on GPU if specified 79 | if gpu_transform is not None: 80 | (imgs, obs), actions, mask = batch 81 | imgs = {k: v.to(trainer.device_id) for k, v in imgs.items()} 82 | imgs = {k: gpu_transform(v) for k, v in imgs.items()} 83 | batch = ((imgs, obs), actions, mask) 84 | 85 | trainer.optim.zero_grad() 86 | loss = trainer.training_step(batch, misc.GLOBAL_STEP) 87 | loss.backward() 88 | trainer.optim.step() 89 | 90 | pbar.set_postfix(dict(Loss=loss.item())) 91 | misc.GLOBAL_STEP += 1 92 | 93 | if misc.GLOBAL_STEP % cfg.schedule_freq == 0: 94 | trainer.step_schedule() 95 | 96 | if misc.GLOBAL_STEP % cfg.eval_freq == 0: 97 | trainer.set_eval() 98 | task.eval(trainer, misc.GLOBAL_STEP) 99 | trainer.set_train() 100 | 101 | if misc.GLOBAL_STEP >= cfg.max_iterations: 102 | trainer.save_checkpoint(cfg.checkpoint_path, misc.GLOBAL_STEP) 103 | return 104 | elif misc.GLOBAL_STEP % cfg.save_freq == 0: 105 | trainer.save_checkpoint(cfg.checkpoint_path, misc.GLOBAL_STEP) 106 | 107 | # gracefully handle and log errors 108 | except Exception: 109 | traceback.print_exc(file=open("exception.log", "w")) 110 | with open("exception.log", "r") as f: 111 | print(f.read()) 112 | 113 | 114 | if __name__ == "__main__": 115 | bc_finetune() 116 | -------------------------------------------------------------------------------- /data4robotics/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import torch 8 | from torch import nn 9 | from torchvision import transforms 10 | 11 | 12 | class _MediumAug(nn.Module): 13 | def __init__(self, pad=0, size=224): 14 | super().__init__() 15 | self.pad = pad 16 | self.size = size 17 | self.norm = transforms.Normalize( 18 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 19 | ) 20 | 21 | def forward(self, x): 22 | extra_dim = len(x.shape) > 4 23 | if extra_dim: 24 | assert len(x.shape) == 5 25 | B, T, C, H, W = x.shape 26 | x = x.reshape((B * T, C, H, W)) 27 | 28 | n, c, h, w = x.size() 29 | assert h == w 30 | if self.pad > 0: 31 | padding = tuple([self.pad] * 4) 32 | x = torch.nn.functional.pad(x, padding, "replicate") 33 | eps = 1.0 / (h + 2 * self.pad) 34 | arange = torch.linspace( 35 | -1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype 36 | )[: self.size] 37 | arange = arange.unsqueeze(0).repeat(self.size, 1).unsqueeze(2) 38 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 39 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 40 | 41 | shift = torch.randint( 42 | 0, 43 | 2 * self.pad + h - self.size + 1, 44 | size=(n, 1, 1, 2), 45 | device=x.device, 46 | dtype=x.dtype, 47 | ) 48 | shift *= 2.0 / (h + 2 * self.pad) 49 | 50 | grid = base_grid + shift 51 | x = torch.nn.functional.grid_sample( 52 | x, grid, padding_mode="zeros", align_corners=False 53 | ) 54 | x = self.norm(x) 55 | 56 | if extra_dim: 57 | return x.reshape((B, T, C, self.size, self.size)) 58 | return x 59 | 60 | 61 | def get_gpu_transform_by_name(name, size=224): 62 | if name == "gpu_medium": 63 | return _MediumAug(size=size) 64 | raise NotImplementedError 65 | 66 | 67 | def get_transform_by_name(name, size=224): 68 | if "gpu" in name: 69 | return None 70 | 71 | if name == "preproc": 72 | return transforms.Compose( 73 | [ 74 | transforms.Resize((size, size), antialias=False), 75 | transforms.Normalize( 76 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 77 | ), 78 | ] 79 | ) 80 | if name == "basic": 81 | return transforms.Compose( 82 | [ 83 | transforms.RandomResizedCrop( 84 | size=size, scale=(0.2, 1.0), antialias=False 85 | ), 86 | transforms.Normalize( 87 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 88 | ), 89 | ] 90 | ) 91 | if name == "medium": 92 | kernel_size = int(0.05 * size) 93 | kernel_size = kernel_size + (1 - kernel_size % 2) 94 | return transforms.Compose( 95 | [ 96 | transforms.RandomResizedCrop( 97 | size=size, scale=(0.9, 1.0), antialias=False 98 | ), 99 | transforms.GaussianBlur(kernel_size=kernel_size), 100 | transforms.Normalize( 101 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 102 | ), 103 | ] 104 | ) 105 | if name == "hard": 106 | kernel_size = int(0.05 * size) 107 | kernel_size = kernel_size + (1 - kernel_size % 2) 108 | return transforms.Compose( 109 | [ 110 | transforms.RandomResizedCrop( 111 | size=size, scale=(0.2, 1.0), antialias=False 112 | ), 113 | transforms.GaussianBlur(kernel_size=kernel_size), 114 | transforms.Normalize( 115 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 116 | ), 117 | ] 118 | ) 119 | if name == "advanced": 120 | kernel_size = int(0.05 * size) 121 | kernel_size = kernel_size + (1 - kernel_size % 2) 122 | color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 123 | return transforms.Compose( 124 | [ 125 | transforms.RandomResizedCrop( 126 | size=size, scale=(0.2, 1.0), antialias=False 127 | ), 128 | transforms.RandomApply([color_jitter], p=0.8), 129 | transforms.RandomGrayscale(p=0.2), 130 | 
transforms.GaussianBlur(kernel_size=kernel_size), 131 | transforms.Normalize( 132 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 133 | ), 134 | ] 135 | ) 136 | raise NotImplementedError(f"{name} not found!") 137 | -------------------------------------------------------------------------------- /eval_scripts/eval_droid_state.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | from collections import deque 6 | from pathlib import Path 7 | 8 | import cv2 9 | import numpy as np 10 | import torch 11 | import yaml 12 | 13 | # droid robot imports 14 | from droid.robot_env import RobotEnv 15 | from droid.user_interface.eval_gui import EvalGUI 16 | from scipy.spatial.transform import Rotation as R 17 | 18 | import hydra 19 | 20 | PRED_HORIZON = 16 21 | EXP_WEIGHT = 0 22 | GRIP_THRESH = 0.55 23 | 24 | 25 | def rmat_to_euler(rot_mat, degrees=False): 26 | euler = R.from_matrix(rot_mat).as_euler("xyz", degrees=degrees) 27 | return euler 28 | 29 | 30 | def normalize(vec, eps=1e-12): 31 | norm = np.linalg.norm(vec, axis=-1) 32 | norm = np.maximum(norm, eps) 33 | return vec / norm 34 | 35 | 36 | def rot6d_to_euler(d6): 37 | a1, a2 = d6[:3], d6[3:] 38 | b1 = normalize(a1) 39 | b2 = a2 - np.sum(b1 * a2, axis=-1, keepdims=True) * b1 40 | b2 = normalize(b2) 41 | b3 = np.cross(b1, b2, axis=-1) 42 | out = np.stack((b1, b2, b3), axis=-2) 43 | return rmat_to_euler(out) 44 | 45 | 46 | class DROIDStatePolicy: 47 | def __init__(self, agent_path, model_name): 48 | with open(Path(agent_path, "agent_config.yaml"), "r") as f: 49 | config_yaml = f.read() 50 | agent_config = yaml.safe_load(config_yaml) 51 | with open(Path(agent_path, "obs_config.yaml"), "r") as f: 52 | config_yaml = f.read() 53 | obs_config = yaml.safe_load(config_yaml) 54 | with open(Path(agent_path, "ac_norm.json"), "r") as f: 55 | ac_norm_dict = json.load(f) 56 | loc, scale = ac_norm_dict["loc"], ac_norm_dict["scale"] 57 | self.loc = np.array(loc).astype(np.float32) 58 | self.scale = np.array(scale).astype(np.float32) 59 | 60 | agent = hydra.utils.instantiate(agent_config) 61 | save_dict = torch.load(Path(agent_path, model_name), map_location="cpu") 62 | agent.load_state_dict(save_dict["model"]) 63 | self.agent = agent.eval().cuda() 64 | 65 | self.transform = hydra.utils.instantiate(obs_config["transform"]) 66 | self.img_key = obs_config["img"] 67 | 68 | print(f"loaded agent from {agent_path}, at step: {save_dict['global_step']}") 69 | self.act_history = deque(maxlen=PRED_HORIZON) 70 | self._last_time = None 71 | 72 | def _proc_image(self, zed_img, size=(256, 256)): 73 | bgr_img = zed_img[:, :, :3] 74 | bgr_img = cv2.resize(bgr_img, size, interpolation=cv2.INTER_AREA) 75 | rgb_img = bgr_img[:, :, ::-1].copy() 76 | rgb_img = torch.from_numpy(rgb_img).float().permute((2, 0, 1)) / 255 77 | return {"cam0": self.transform(rgb_img)[None].cuda()} 78 | 79 | def _proc_state(self, cart_pos, grip_pos): 80 | state = np.concatenate((cart_pos, np.array([grip_pos]))).astype(np.float32) 81 | return torch.from_numpy(state)[None].cuda() 82 | 83 | def forward(self, obs): 84 | img = self._proc_image(obs["image"][self.img_key]) 85 | state = self._proc_state( 86 | obs["robot_state"]["cartesian_position"], 87 | obs["robot_state"]["gripper_position"], 88 | ) 89 | 90 | with torch.no_grad(): 91 | ac = self.agent.get_actions(img, state) 92 | ac = ac[0].cpu().numpy().astype(np.float32)[:PRED_HORIZON] 93 | self.act_history.append(ac) 94 | 95 | # handle temporal blending 
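# NOTE: temporal ensembling over action chunks -- act_history keeps the last
# PRED_HORIZON predicted chunks, and each surviving chunk contributes its
# prediction for the *current* timestep. Weights follow exp(-EXP_WEIGHT * k)
# with k = 0 for the oldest chunk, so newer predictions are down-weighted when
# EXP_WEIGHT > 0; with EXP_WEIGHT = 0 (set above) this reduces to a uniform average.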
96 | num_actions = len(self.act_history) 97 | curr_act_preds = np.stack( 98 | [ 99 | pred_actions[i] 100 | for (i, pred_actions) in zip( 101 | range(num_actions - 1, -1, -1), self.act_history 102 | ) 103 | ] 104 | ) 105 | 106 | # more recent predictions get exponentially *less* weight than older predictions 107 | weights = np.exp(-EXP_WEIGHT * np.arange(num_actions)) 108 | weights = weights / weights.sum() 109 | # compute the weighted average across all predictions for this timestep 110 | ac = np.sum(weights[:, None] * curr_act_preds, axis=0) 111 | 112 | # denormalize the actions and swap to R6 113 | ac = ac * self.scale + self.loc 114 | if len(ac) == 10: 115 | xyz, r6, grip = ac[:3], ac[3:9], ac[9:] 116 | ac = np.concatenate((xyz, rot6d_to_euler(r6), grip)) 117 | assert len(ac) == 7, "Assuming 7d action dim!" 118 | 119 | # threshold the gripper to make crisp grasp decisions 120 | if ac[-1] > GRIP_THRESH: 121 | ac[-1] = 1.0 122 | 123 | print("current", obs["robot_state"]["cartesian_position"]) 124 | print("action", ac) 125 | cur_time = time.time() 126 | if self._last_time is not None: 127 | print("Effective HZ:", 1.0 / (cur_time - self._last_time)) 128 | self._last_time = cur_time 129 | print() 130 | return ac 131 | 132 | def load_goal_imgs(self, goal_dict): 133 | pass 134 | 135 | def load_lang(self, text): 136 | pass 137 | 138 | 139 | def main(): 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("checkpoint") 142 | args = parser.parse_args() 143 | 144 | agent_path = os.path.expanduser(os.path.dirname(args.checkpoint)) 145 | model_name = args.checkpoint.split("/")[-1] 146 | policy = DROIDStatePolicy(agent_path, model_name) 147 | 148 | # start up DROID eval gui 149 | env = RobotEnv(action_space="cartesian_position") 150 | EvalGUI(policy=policy, env=env) 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /data4robotics/replay_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import os 8 | import pickle as pkl 9 | import random 10 | import shutil 11 | 12 | import numpy as np 13 | import torch 14 | import tqdm 15 | from robobuf import ReplayBuffer as RB 16 | from tensorflow.io import gfile 17 | from torch.utils.data import Dataset, IterableDataset 18 | 19 | # cache loading from the buffer list to half memory overhead 20 | buf_cache = dict() 21 | BUF_SHUFFLE_RNG = 3904767649 22 | 23 | 24 | # helper functions 25 | _img_to_tensor = ( 26 | lambda x: torch.from_numpy(x.copy()).permute((0, 3, 1, 2)).float() / 255 27 | ) 28 | _to_tensor = lambda x: torch.from_numpy(x).float() 29 | 30 | 31 | def _cached_load(path): 32 | global buf_cache 33 | 34 | if path in buf_cache: 35 | return buf_cache[path] 36 | 37 | with gfile.GFile(path, "rb") as f: 38 | buf = RB.load_traj_list(pkl.load(f)) 39 | buf_cache[path] = buf 40 | return buf 41 | 42 | 43 | def _get_imgs(t, cam_idx, past_frames): 44 | imgs = [] 45 | while len(imgs) < past_frames + 1: 46 | imgs.append(t.obs.image(cam_idx)[None]) 47 | 48 | if t.prev is not None: 49 | t = t.prev 50 | return np.concatenate(imgs, axis=0) 51 | 52 | 53 | class IterableWrapper(IterableDataset): 54 | def __init__(self, wrapped_dataset, max_count=float("inf")): 55 | self.wrapped = wrapped_dataset 56 | self.ctr, self.max_count = 0, max_count 57 | 58 | def __iter__(self): 59 | self.ctr = 0 60 | return self 61 | 62 | def __next__(self): 63 | if self.ctr > self.max_count: 64 | raise StopIteration 65 | 66 | self.ctr += 1 67 | idx = int(np.random.choice(len(self.wrapped))) 68 | return self.wrapped[idx] 69 | 70 | 71 | class RobobufReplayBuffer(Dataset): 72 | def __init__( 73 | self, 74 | buffer_path, 75 | transform=None, 76 | n_test_trans=500, 77 | mode="train", 78 | ac_chunk=1, 79 | cam_indexes=[0], 80 | goal_indexes=[], 81 | goal_geom_prob=0.01, 82 | past_frames=0, 83 | ac_dim=7, 84 | ): 85 | assert mode in ("train", "test"), "Mode must be train/test" 86 | buf = _cached_load(buffer_path) 87 | assert len(buf) > n_test_trans, "Not enough transitions!" 
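# NOTE: the buffer indices below are shuffled with the fixed seed BUF_SHUFFLE_RNG
# before the train/test split, so offline evaluation code (e.g. test_agent.py)
# can deterministically recover the same held-out transitions.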
88 | 89 | norm_file = os.path.join(os.path.dirname(buffer_path), "ac_norm.json") 90 | if os.path.exists(norm_file): 91 | shutil.copyfile(norm_file, "./ac_norm.json") 92 | 93 | # shuffle the list with the fixed seed 94 | rng = random.Random(BUF_SHUFFLE_RNG) 95 | 96 | # get and shuffle list of buf indices 97 | index_list = list(range(len(buf))) 98 | rng.shuffle(index_list) 99 | 100 | # split data according to mode 101 | index_list = ( 102 | index_list[n_test_trans:] if mode == "train" else index_list[:n_test_trans] 103 | ) 104 | 105 | self.transform = transform 106 | self.s_a_mask = [] 107 | 108 | self.cam_indexes = cam_indexes = list(cam_indexes) 109 | self.past_frames = past_frames 110 | print(f"Building {mode} buffer with cam_indexes={cam_indexes}") 111 | 112 | self.goal_geom_prob = goal_geom_prob 113 | self.goal_indexes = set(goal_indexes) 114 | assert all([g in self.cam_indexes for g in self.goal_indexes]) 115 | 116 | for idx in tqdm.tqdm(index_list): 117 | t = buf[idx] 118 | 119 | loop_t, chunked_actions, loss_mask = t, [], [] 120 | for _ in range(ac_chunk): 121 | chunked_actions.append(loop_t.action[None]) 122 | loss_mask.append(1.0) 123 | 124 | if loop_t.next is None: 125 | break 126 | loop_t = loop_t.next 127 | 128 | if len(chunked_actions) < ac_chunk: 129 | for _ in range(ac_chunk - len(chunked_actions)): 130 | chunked_actions.append(chunked_actions[-1]) 131 | loss_mask.append(0.0) 132 | 133 | a_t = np.concatenate(chunked_actions, 0).astype(np.float32) 134 | assert ac_dim == a_t.shape[-1] 135 | 136 | loss_mask = np.array(loss_mask, dtype=np.float32) 137 | self.s_a_mask.append((t, a_t, loss_mask, loop_t)) 138 | 139 | def __len__(self): 140 | return len(self.s_a_mask) 141 | 142 | def __getitem__(self, idx): 143 | step, a_t, loss_mask, goal = self.s_a_mask[idx] 144 | 145 | if self.goal_indexes: 146 | while np.random.uniform() > self.goal_geom_prob and goal.next is not None: 147 | goal = goal.next 148 | 149 | i_t, o_t = dict(), step.obs.state 150 | for idx, cam_idx in enumerate(self.cam_indexes): 151 | i_c = _get_imgs(step, cam_idx, self.past_frames) 152 | if self.goal_indexes: 153 | g_c = ( 154 | _get_imgs(goal, cam_idx, 0) 155 | if cam_idx in self.goal_indexes 156 | else np.zeros_like(i_c[:1]) 157 | ) 158 | i_c = np.concatenate((g_c, i_c), axis=0) 159 | 160 | i_c = _img_to_tensor(i_c) 161 | if self.transform is not None: 162 | i_c = self.transform(i_c) 163 | 164 | i_t[f"cam{idx}"] = i_c 165 | 166 | o_t, a_t = _to_tensor(o_t), _to_tensor(a_t) 167 | loss_mask = _to_tensor(loss_mask)[:, None].repeat((1, a_t.shape[-1])) 168 | assert ( 169 | loss_mask.shape[0] == a_t.shape[0] 170 | ), "a_t and mask shape must be ac_chunk!" 171 | return (i_t, o_t), a_t, loss_mask 172 | -------------------------------------------------------------------------------- /data4robotics/sim/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import json 8 | import os 9 | import random 10 | 11 | import cv2 12 | import imageio 13 | import numpy as np 14 | import torch 15 | import tqdm 16 | import wandb 17 | from torch.utils.data import Dataset 18 | 19 | from data4robotics.task import DefaultTask 20 | 21 | # helper functions 22 | _img_to_tensor = ( 23 | lambda x: torch.from_numpy(x.copy()).permute((0, 3, 1, 2)).float() / 255 24 | ) 25 | _to_tensor = lambda x: torch.from_numpy(x).float() 26 | BUF_SHUFFLE_RNG = 3904767649 27 | _VID_FPS = 10 28 | 29 | 30 | class SimTask(DefaultTask): 31 | def __init__( 32 | self, 33 | train_buffer, 34 | test_transform, 35 | task, 36 | n_cams, 37 | obs_dim, 38 | ac_dim, 39 | batch_size, 40 | num_workers, 41 | ): 42 | self.build_eval_env(task, n_cams, obs_dim, ac_dim) 43 | self.test_transform = test_transform 44 | super().__init__( 45 | train_buffer, None, n_cams, obs_dim, ac_dim, batch_size, num_workers 46 | ) 47 | 48 | def build_eval_env(self, task, n_cams, obs_dim, ac_dim): 49 | raise NotImplementedError() 50 | 51 | def rollout_sim(self, agent, device_id, frame_buffer): 52 | raise NotImplementedError() 53 | 54 | def eval(self, trainer, global_step): 55 | frame_buffer = [] 56 | mean_success, mean_rewards = self.rollout_sim( 57 | trainer.model, trainer.device_id, frame_buffer 58 | ) 59 | 60 | # load success log if it exists or make a new one 61 | rollout_logs = dict(success=[], reward=[], step=[]) 62 | if os.path.exists("rollout_log.json"): 63 | rollout_logs = json.load(open("rollout_log.json", "r")) 64 | rollout_logs["success"].append(mean_success) 65 | rollout_logs["reward"].append(mean_rewards) 66 | rollout_logs["step"].append(global_step) 67 | with open("rollout_log.json", "w") as f: 68 | json.dump(rollout_logs, f) 69 | f.write("\n") 70 | 71 | max_success = max(rollout_logs["success"]) 72 | print(f"Max Success={max_success} @ step={global_step}") 73 | 74 | vid_path = f"rollout_itr_{global_step:07d}.mp4" 75 | writer = imageio.get_writer(vid_path, fps=_VID_FPS) 76 | for im in frame_buffer: 77 | im_out = cv2.resize(im, (128, 128), interpolation=cv2.INTER_AREA) 78 | writer.append_data(im_out) 79 | writer.close() 80 | 81 | if wandb.run is not None: 82 | wandb.log( 83 | { 84 | "eval/max_success": max_success, 85 | "eval/success": mean_success, 86 | "eval/rewards": mean_rewards, 87 | "eval/rollout": wandb.Video(vid_path), 88 | }, 89 | step=global_step, 90 | ) 91 | 92 | 93 | class SimTaskReplayBuffer(Dataset): 94 | def __init__( 95 | self, 96 | task, 97 | transform=None, 98 | n_train_demos=200, 99 | ac_chunk=1, 100 | cam_indexes=[0], 101 | goal_indexes=[], 102 | past_frames=0, 103 | ): 104 | 105 | # these tasks don't require conditioning so skip 106 | assert cam_indexes == [0], "only need 0th cam" 107 | assert goal_indexes == [], "no need for goal indexes" 108 | assert past_frames == 0, "past frames should be 0" 109 | 110 | buffer_data = self.load_buffer(task) 111 | assert len(buffer_data) >= n_train_demos, "Not enough demos!" 112 | 113 | # shuffle the list with the fixed seed 114 | rng = random.Random(BUF_SHUFFLE_RNG) 115 | rng.shuffle(buffer_data) 116 | 117 | # take n_train_demos demos for training 118 | buffer_data = buffer_data[:n_train_demos] 119 | 120 | self.transform = transform 121 | self.s_a_mask = [] 122 | for traj in tqdm.tqdm(buffer_data): 123 | imgs, obs, acs = traj["images"], traj["observations"], traj["actions"] 124 | assert len(obs) == len(acs) and len(acs) == len( 125 | imgs 126 | ), "All time dimensions must match!" 
127 | 128 | # pad camera dimension if needed 129 | if len(imgs.shape) == 4: 130 | imgs = imgs[:, None] 131 | 132 | for t in range(len(imgs) - 1): 133 | i_t = {f"cam{c}": imgs[t, c][None] for c in range(imgs.shape[1])} 134 | o_t = obs[t] 135 | 136 | loss_mask = np.ones((ac_chunk,), dtype=np.float32) 137 | a_t = acs[t : t + ac_chunk] 138 | assert len(a_t) > 0 139 | if len(a_t) < ac_chunk: 140 | missing = ac_chunk - len(a_t) 141 | 142 | action_pad = np.zeros((missing, a_t.shape[-1])).astype(np.float32) 143 | a_t = np.concatenate((a_t, action_pad), axis=0) 144 | loss_mask[-missing:] = 0 145 | 146 | self.s_a_mask.append(((i_t, o_t), a_t, loss_mask)) 147 | 148 | def load_buffer(self, buffer_path): 149 | raise NotImplementedError 150 | 151 | def __len__(self): 152 | return len(self.s_a_mask) 153 | 154 | def __getitem__(self, idx): 155 | (i_t, o_t), a_t, loss_mask = self.s_a_mask[idx] 156 | 157 | i_t = {k: _img_to_tensor(v) for k, v in i_t.items()} 158 | if self.transform is not None: 159 | i_t = {k: self.transform(v) for k, v in i_t.items()} 160 | 161 | o_t, a_t = _to_tensor(o_t), _to_tensor(a_t) 162 | loss_mask = _to_tensor(loss_mask)[:, None].repeat((1, a_t.shape[-1])) 163 | assert ( 164 | loss_mask.shape[0] == a_t.shape[0] 165 | ), "a_t and mask shape must be ac_chunk!" 166 | return (i_t, o_t), a_t, loss_mask 167 | -------------------------------------------------------------------------------- /data4robotics/models/action_distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import torch 8 | import torch.distributions as D 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | 13 | class ActionDistribution(nn.Module): 14 | def __init__(self, ac_dim, ac_chunk=1): 15 | super().__init__() 16 | self._ac_chunk, self._ac_dim = ac_chunk, ac_dim 17 | 18 | @property 19 | def ac_dim(self): 20 | return self._ac_dim 21 | 22 | @property 23 | def ac_chunk(self): 24 | return self._ac_chunk 25 | 26 | @property 27 | def num_ac_pred(self): 28 | return self._ac_chunk * self._ac_dim 29 | 30 | def unflatten_ac_tensor(self, ac_tensor): 31 | out_shape = list(ac_tensor.shape[:-1]) + [self._ac_chunk, self._ac_dim] 32 | return ac_tensor.reshape(out_shape) 33 | 34 | def get_actions(self, inputs, zero_std=True): 35 | acs = self._sample(inputs, zero_std) 36 | return self.unflatten_ac_tensor(acs) 37 | 38 | def _sample(self, inputs, zero_std=True): 39 | dist = self(inputs, zero_std) 40 | return dist.sample() 41 | 42 | 43 | class Deterministic(ActionDistribution): 44 | def __init__(self, in_dim, ac_dim, ac_chunk=1): 45 | super().__init__(ac_dim, ac_chunk) 46 | self._layer = nn.Linear(in_dim, self.num_ac_pred) 47 | 48 | def forward(self, inputs, zero_std=True): 49 | assert zero_std, "No std prediction in this network!" 
50 | return self._layer(inputs) 51 | 52 | def _sample(self, inputs, zero_std=True): 53 | return self(inputs, zero_std) 54 | 55 | 56 | class Gaussian(ActionDistribution): 57 | def __init__(self, in_dim, ac_dim, ac_chunk=1, min_std=1e-4, tanh_mean=False): 58 | super().__init__(ac_dim, ac_chunk) 59 | self._min_std, self._tanh_mean = min_std, tanh_mean 60 | self._mean_net = nn.Linear(in_dim, self.num_ac_pred) 61 | self._scale_net = nn.Linear(in_dim, self.num_ac_pred) 62 | 63 | def forward(self, in_repr, zero_std=False): 64 | B = in_repr.shape[0] 65 | mean = self._mean_net(in_repr).reshape(B, self.num_ac_pred) 66 | scale = self._scale_net(in_repr).reshape(B, self.num_ac_pred) 67 | 68 | # bound the action means and convert scale to std 69 | if self._tanh_mean: 70 | mean = torch.tanh(mean) 71 | std = ( 72 | torch.ones_like(scale) * self._min_std 73 | if zero_std 74 | else F.softplus(scale) + self._min_std 75 | ) 76 | 77 | # create Normal action distributions 78 | return D.Normal(loc=mean, scale=std) 79 | 80 | 81 | class GaussianSharedScale(ActionDistribution): 82 | def __init__( 83 | self, 84 | in_dim, 85 | ac_dim, 86 | ac_chunk=1, 87 | min_std=1e-4, 88 | tanh_mean=False, 89 | log_std_init=0, 90 | std_fixed=False, 91 | ): 92 | super().__init__(ac_dim, ac_chunk) 93 | self._min_std, self._tanh_mean = min_std, tanh_mean 94 | self._mean_net = nn.Linear(in_dim, self.num_ac_pred) 95 | 96 | # create log_std vector and store as param 97 | log_std = torch.Tensor([log_std_init] * ac_dim) 98 | self.register_parameter( 99 | "log_std", nn.Parameter(log_std, requires_grad=not std_fixed) 100 | ) 101 | 102 | def forward(self, in_repr, zero_std=False): 103 | B = in_repr.shape[0] 104 | mean = self._mean_net(in_repr).reshape(B, self.num_ac_pred) 105 | scale = self.log_std[None].repeat((B, self._ac_chunk)) 106 | 107 | if self._tanh_mean: 108 | mean = torch.tanh(mean) 109 | std = ( 110 | torch.ones_like(scale) * self._min_std 111 | if zero_std 112 | else torch.exp(scale) + self._min_std 113 | ) 114 | 115 | # create Normal action distributions 116 | return D.Normal(loc=mean, scale=std) 117 | 118 | 119 | class _MaskedIndependent(D.Independent): 120 | def masked_log_prob(self, value, mask): 121 | log_prob = self.base_dist.log_prob(value) 122 | return (log_prob * mask).sum(-1) 123 | 124 | 125 | class _MixtureHelper(D.MixtureSameFamily): 126 | def masked_log_prob(self, x, mask): 127 | if self._validate_args: 128 | self._validate_sample(x) 129 | x, mask = self._pad(x), mask[:, None] 130 | log_prob_x = self.component_distribution.masked_log_prob(x, mask) # [S, B, k] 131 | log_mix_prob = torch.log_softmax( 132 | self.mixture_distribution.logits, dim=-1 133 | ) # [B, k] 134 | return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B] 135 | 136 | 137 | class GaussianMixture(ActionDistribution): 138 | def __init__( 139 | self, num_modes, in_dim, ac_dim, ac_chunk=1, min_std=1e-4, tanh_mean=False 140 | ): 141 | super().__init__(ac_dim, ac_chunk) 142 | self._min_std, self._tanh_mean = min_std, tanh_mean 143 | self._num_modes = num_modes 144 | 145 | self._mean_net = nn.Linear(in_dim, num_modes * self.num_ac_pred) 146 | self._scale_net = nn.Linear(in_dim, num_modes * self.num_ac_pred) 147 | self._logit_net = nn.Linear(in_dim, num_modes) 148 | 149 | def forward(self, in_repr, zero_std=False): 150 | B = in_repr.shape[0] 151 | mean = self._mean_net(in_repr).reshape(B, self._num_modes, self.num_ac_pred) 152 | scale = self._scale_net(in_repr).reshape(B, self._num_modes, self.num_ac_pred) 153 | logits = 
self._logit_net(in_repr).reshape((B, self._num_modes)) 154 | 155 | # bound the action means and convert scale to std 156 | if self._tanh_mean: 157 | mean = torch.tanh(mean) 158 | std = ( 159 | torch.ones_like(scale) * self._min_std 160 | if zero_std 161 | else F.softplus(scale) + self._min_std 162 | ) 163 | 164 | # create num_modes independent action distributions 165 | ac_dist = D.Normal(loc=mean, scale=std) 166 | ac_dist = _MaskedIndependent(ac_dist, 1) 167 | 168 | # parameterize the mixing distribution and the final GMM 169 | mix_dist = D.Categorical(logits=logits) 170 | gmm_dist = _MixtureHelper( 171 | mixture_distribution=mix_dist, component_distribution=ac_dist 172 | ) 173 | return gmm_dist 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Ingredients for Robotic Diffusion Transformers 2 | [![arXiv](https://img.shields.io/badge/arXiv-2410.10088-df2a2a.svg)](https://arxiv.org/pdf/2410.10088) 3 | [![HF Dataset](https://img.shields.io/badge/%F0%9F%A4%97-Dataset-yellow)](https://huggingface.co/datasets/oier-mees/BiPlay) 4 | [![Python](https://img.shields.io/badge/python-3.9-blue)](https://www.python.org) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | [![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://dit-policy.github.io/) 7 | 8 | [Sudeep Dasari](https://sudeepdasari.github.io/), [Oier Mees](https://www.oiermees.com/), [Sebastian Zhao](http://linkedin.com/in/sebbyzhao/), [Mohan Kumar Srirama](https://scholar.google.com/citations?user=Yu18Q6MAAAAJ&hl=en/), [Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/) 9 |
10 | 11 | This repository offers an implementation of our improved Diffusion Transformer Policy (DiT-Block Policy), which achieves state-of-the-art manipulation results on long-horizon bi-manual ALOHA robots and single-arm DROID Franka robots. This repo also allows easy use of our advanced pre-trained representations from [prior](https://data4robotics.github.io) [work](https://hrp-robot.github.io). We've successfully deployed policies from this code on Franka robots (w/ [DROID](https://github.com/droid-dataset/droid/tree/main) and [MaNiMo](https://github.com/AGI-Labs/manimo)), [ALOHA](https://tonyzhaozh.github.io/aloha/) robots, and on [LEAP hands](https://www.leaphand.com). Check out our [eval scripts](eval_scripts/README.md) for more information. These policies can also be tested in simulation (see [Sim README](https://github.com/SudeepDasari/data4robotics/tree/dit_release/data4robotics/sim)). 12 | 13 | ![](media/aloha_teaser.png) 14 | 15 | ## Installation 16 | Our repository is easy to install using miniconda or anaconda: 17 | 18 | ``` 19 | conda env create -f env.yml 20 | conda activate data4robotics 21 | pip install git+https://github.com/AGI-Labs/robobuf.git 22 | pip install git+https://github.com/facebookresearch/r3m.git 23 | pip install -e ./ 24 | pre-commit install # required for pushing back to the source git 25 | ``` 26 | 27 | ## Training DiT Policies (and Baselines) 28 | First, you're going to need to convert your training trajectories into our [robobuf](https://github.com/AGI-Labs/robobuf/tree/main) format (pseudo-code below). Check out some example ALOHA and DROID conversion code [here](https://github.com/AGI-Labs/r2d2_to_robobuf). 29 | 30 | ``` 31 | def _resize_and_encode(rgb_img, size=(256,256)): 32 | rgb_img = cv2.resize(rgb_img, size, interpolation=cv2.INTER_AREA) 33 | _, encoded = cv2.imencode(".jpg", rgb_img) 34 | return encoded 35 | 36 | def convert_trajectories(input_trajs, out_path): 37 | out_buffer = [] 38 | for traj in tqdm(input_trajs): 39 | out_traj = [] 40 | for in_obs, in_ac, in_reward in traj: 41 | out_obs = dict(state=np.array(in_obs['state']).astype(np.float32), 42 | enc_cam_0=_resize_and_encode(in_obs['image'])) 43 | out_action = np.array(in_ac).astype(np.float32) 44 | out_reward = float(in_reward) 45 | out_traj.append((out_obs, out_action, out_reward)) 46 | out_buffer.append(out_traj) 47 | 48 | with open(os.path.join(out_path, 'buf.pkl'), 'wb') as f: 49 | pkl.dump(out_buffer, f) 50 | ``` 51 | 52 | Once the conversion is complete, you can train our models using the example commands below: 53 | ``` 54 | # Training DiT Policy (Diffusion Transformer w/ adaLN + ResNet Tokenizer) 55 | python finetune.py exp_name=test agent=diffusion task=end_effector_r6 agent/features=resnet_gn agent.features.restore_path=/path/to/resnet18/IN_1M_resnet18.pth trainer=bc_cos_sched ac_chunk=100 56 | 57 | ## SOME EXAMPLE BASELINES 58 | 59 | # Gaussian Mixture Model bc-policy with SOUP representations 60 | python finetune.py exp_name=test agent.features.restore_path=/path/to/SOUP_1M_DH.pth buffer_path=/data/path/buffer.pkl 61 | 62 | # Diffusion Policy (U-Net head) w/ HRP representations 63 | python finetune.py exp_name=test agent=diffusion_unet task=end_effector_r6 agent/features=vit_base agent.features.restore_path=/path/to/IN_hrp.pth buffer_path=/data/path/buffer.pkl trainer=bc_cos_sched ac_chunk=16 64 | ``` 65 | This will result in a policy checkpoint saved in the `bc_finetune/` folder.
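Before launching training, it can be worth sanity-checking the converted buffer by loading it back with robobuf, the same way `test_agent.py` and `data4robotics/replay_buffer.py` read it. The snippet below is only a sketch: the buffer path is an example, and it assumes your buffer was written in the (obs, action, reward) format shown above.

```
import pickle as pkl
import numpy as np
from robobuf import ReplayBuffer as RB

# point this at your converted buffer (example path)
with open('/data/path/buf.pkl', 'rb') as f:
    buf = RB.load_traj_list(pkl.load(f))

print(f'loaded {len(buf)} transitions')
t = buf[0]                                        # first transition in the buffer
print('state:', np.asarray(t.obs.state).shape)    # proprioceptive state vector
print('cam0 image:', t.obs.image(0).shape)        # decoded RGB frame for cam 0
print('action:', np.asarray(t.action).shape)      # float32 action vector
```

The printed state and action dimensions should line up with the `obs_dim`/`ac_dim` expected by the task config you train with.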
66 | 67 | ## Downloading the Bi-Play Dataset 68 | We also provide an open-sourced dataset, named BiPlay, with over 7000 diverse, text-annotated, bi-manual expert demonstrations collected on an ALOHA robot. You may download the dataset from the following [HuggingFace dataset](https://huggingface.co/datasets/oier-mees/BiPlay). It can be loaded out of the box with the dataloader from [Octo](https://octo-models.github.io). 69 | 70 |

71 | ![Aloha Dataset](media/aloha_dataset.png) 72 |

73 | 74 | 75 | ## Using Pre-Trained Features 76 | You can easily download our pre-trained representations using the provided script: `./download_features.sh`. You may also download the features individually on our [release website](https://www.cs.cmu.edu/~data4robotics/release/). 77 | 78 | The features are very modular and easy to use in your own code-base! Please refer to the [example code](https://github.com/SudeepDasari/data4robotics/blob/main/pretrained_networks_example.py) if you're interested. 79 | 80 | ## Policy Deployment (Sim and Real) 81 | 82 | Detailed instructions and eval scripts for real-world deployment are provided [here](https://github.com/SudeepDasari/data4robotics/tree/dit_release/eval_scripts). Similarly, you can reproduce our sim results using the command/code provided [here](https://github.com/SudeepDasari/data4robotics/tree/dit_release/data4robotics/sim). 83 | 84 | ## Citations 85 | If you find this codebase or the diffusion transformer useful, please cite: 86 | ``` 87 | @article{dasari2024ditpi, 88 | title={The Ingredients for Robotic Diffusion Transformers}, 89 | author = {Sudeep Dasari and Oier Mees and Sebastian Zhao and Mohan Kumar Srirama and Sergey Levine}, 90 | journal = {arXiv preprint arXiv:2410.10088}, 91 | year={2024}, 92 | } 93 | ``` 94 | 95 | And if you use the representations, please cite: 96 | ``` 97 | @inproceedings{dasari2023datasets, 98 | title={An Unbiased Look at Datasets for Visuo-Motor Pre-Training}, 99 | author={Dasari, Sudeep and Srirama, Mohan Kumar and Jain, Unnat and Gupta, Abhinav}, 100 | booktitle={Conference on Robot Learning}, 101 | year={2023}, 102 | organization={PMLR} 103 | } 104 | 105 | @inproceedings{kumar2024hrp, 106 | title={HRP: Human Affordances for Robotic Pre-Training}, 107 | author = {Mohan Kumar Srirama and Sudeep Dasari and Shikhar Bahl and Abhinav Gupta}, 108 | booktitle = {Proceedings of Robotics: Science and Systems}, 109 | address = {Delft, Netherlands}, 110 | year = {2024}, 111 | } 112 | ``` -------------------------------------------------------------------------------- /data4robotics/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
5 | 6 | 7 | import copy 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class _BatchNorm1DHelper(nn.BatchNorm1d): 14 | def forward(self, x): 15 | if len(x.shape) == 3: 16 | x = x.transpose(1, 2) 17 | x = super().forward(x) 18 | return x.transpose(1, 2) 19 | return super().forward(x) 20 | 21 | 22 | class BaseAgent(nn.Module): 23 | def __init__( 24 | self, 25 | odim, 26 | features, 27 | n_cams, 28 | imgs_per_cam, 29 | use_obs=False, 30 | share_cam_features=False, 31 | early_fusion=False, 32 | dropout=0, 33 | feat_norm=None, 34 | token_dim=None, 35 | ): 36 | super().__init__() 37 | 38 | # store visual features (duplicate weights if shared) 39 | self._share_cam_features = share_cam_features 40 | if self._share_cam_features: 41 | self.visual_features = features 42 | else: 43 | feat_list = [features] + [copy.deepcopy(features) for _ in range(1, n_cams)] 44 | self.visual_features = nn.ModuleList(feat_list) 45 | 46 | self.early_fusion = early_fusion 47 | imgs_per_cam = 1 if early_fusion else imgs_per_cam 48 | self._token_dim = features.embed_dim 49 | self._n_tokens = imgs_per_cam * n_cams * features.n_tokens 50 | 51 | # handle obs tokenization strategies 52 | if use_obs == "add_token": 53 | self._obs_strat = "add_token" 54 | self._n_tokens += 1 55 | self._obs_proc = nn.Sequential( 56 | nn.Dropout(p=0.2), nn.Linear(odim, self._token_dim) 57 | ) 58 | elif use_obs == "pad_img_tokens": 59 | self._obs_strat = "pad_img_tokens" 60 | self._token_dim += odim 61 | self._obs_proc = nn.Dropout(p=0.2) 62 | else: 63 | assert not use_obs 64 | self._obs_strat = None 65 | 66 | # build (optional) token feature projection layer 67 | linear_proj = nn.Identity() 68 | if token_dim is not None and token_dim != self._token_dim: 69 | linear_proj = nn.Linear(self._token_dim, token_dim) 70 | self._token_dim = token_dim 71 | 72 | # build feature normalization layers 73 | if feat_norm == "batch_norm": 74 | norm = _BatchNorm1DHelper(self._token_dim) 75 | elif feat_norm == "layer_norm": 76 | norm = nn.LayerNorm(self._token_dim) 77 | else: 78 | assert feat_norm is None 79 | norm = nn.Identity() 80 | 81 | # final token post proc network 82 | self.post_proc = nn.Sequential(linear_proj, norm, nn.Dropout(dropout)) 83 | 84 | def forward(self, imgs, obs, ac_flat, mask_flat): 85 | raise NotImplementedError 86 | 87 | def get_actions(self, img, obs): 88 | raise NotImplementedError 89 | 90 | def tokenize_obs(self, imgs, obs, flatten=False): 91 | # start by getting image tokens 92 | tokens = self.embed(imgs) 93 | 94 | if self._obs_strat == "add_token": 95 | obs_token = self._obs_proc(obs)[:, None] 96 | tokens = torch.cat((tokens, obs_token), 1) 97 | elif self._obs_strat == "pad_img_tokens": 98 | obs = self._obs_proc(obs) 99 | obs = obs[:, None].repeat((1, tokens.shape[1], 1)) 100 | tokens = torch.cat((obs, tokens), 2) 101 | else: 102 | assert self._obs_strat is None 103 | 104 | tokens = self.post_proc(tokens) 105 | if flatten: 106 | return tokens.reshape((tokens.shape[0], -1)) 107 | return tokens 108 | 109 | def embed(self, imgs): 110 | def embed_helper(net, im): 111 | if self.early_fusion and len(im.shape) == 5: 112 | T = im.shape[1] 113 | im = torch.cat([im[:, t] for t in range(T)], 1) 114 | return net(im) 115 | elif len(im.shape) == 5: 116 | B, T, C, H, W = im.shape 117 | embeds = net(im.reshape((B * T, C, H, W))) 118 | embeds = embeds.reshape((B, -1, net.embed_dim)) 119 | return embeds 120 | 121 | assert len(im.shape) == 4 122 | return net(im) 123 | 124 | if self._share_cam_features: 125 | embeds = [ 126 | 
embed_helper(self.visual_features, imgs[f"cam{i}"]) 127 | for i in range(self._n_cams) 128 | ] 129 | else: 130 | embeds = [ 131 | embed_helper(net, imgs[f"cam{i}"]) 132 | for i, net in enumerate(self.visual_features) 133 | ] 134 | return torch.cat(embeds, dim=1) 135 | 136 | @property 137 | def n_cams(self): 138 | return self._n_cams 139 | 140 | @property 141 | def ac_chunk(self): 142 | raise NotImplementedError 143 | 144 | @property 145 | def token_dim(self): 146 | return self._token_dim 147 | 148 | @property 149 | def n_tokens(self): 150 | return self._n_tokens 151 | 152 | 153 | class MLPAgent(BaseAgent): 154 | def __init__( 155 | self, 156 | features, 157 | policy, 158 | shared_mlp, 159 | odim, 160 | n_cams, 161 | use_obs, 162 | imgs_per_cam=1, 163 | dropout=0, 164 | share_cam_features=False, 165 | early_fusion=False, 166 | feat_norm="layer_norm", 167 | token_dim=None, 168 | ): 169 | 170 | # initialize obs and img tokenizers 171 | super().__init__( 172 | odim=odim, 173 | features=features, 174 | n_cams=n_cams, 175 | imgs_per_cam=imgs_per_cam, 176 | use_obs=use_obs, 177 | share_cam_features=share_cam_features, 178 | early_fusion=early_fusion, 179 | dropout=dropout, 180 | feat_norm=feat_norm, 181 | token_dim=token_dim, 182 | ) 183 | 184 | # assign policy class 185 | self._policy = policy 186 | 187 | mlp_in = self.n_tokens * self.token_dim 188 | mlp_def = [mlp_in] + shared_mlp 189 | layers = [] 190 | for i, o in zip(mlp_def[:-1], mlp_def[1:]): 191 | layers.append(nn.Linear(i, o)) 192 | layers.append(nn.ReLU()) 193 | layers.append(nn.Dropout(dropout)) 194 | self._mlp = nn.Sequential(*layers) 195 | 196 | def forward(self, imgs, obs, ac_flat, mask_flat): 197 | s_t = self._mlp_forward(imgs, obs) 198 | action_dist = self._policy(s_t) 199 | loss = ( 200 | -torch.mean(action_dist.masked_log_prob(ac_flat, mask_flat)) 201 | if hasattr(action_dist, "masked_log_prob") 202 | else -(action_dist.log_prob(ac_flat) * mask_flat).sum() / mask_flat.sum() 203 | ) 204 | return loss 205 | 206 | def get_actions(self, img, obs, zero_std=True): 207 | policy_in = self._mlp_forward(img, obs) 208 | return self._policy.get_actions(policy_in, zero_std=zero_std) 209 | 210 | def _mlp_forward(self, imgs, obs): 211 | tokens_flat = self.tokenize_obs(imgs, obs, flatten=True) 212 | return self._mlp(tokens_flat) 213 | 214 | @property 215 | def ac_chunk(self): 216 | return self._policy.ac_chunk 217 | -------------------------------------------------------------------------------- /data4robotics/sim/robosuite.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | 7 | import os 8 | from collections import deque 9 | 10 | import h5py 11 | import numpy as np 12 | import robomimic.utils.env_utils as EnvUtils 13 | import robomimic.utils.file_utils as FileUtils 14 | import robomimic.utils.obs_utils as ObsUtils 15 | import torch 16 | import tqdm 17 | 18 | from data4robotics.sim import SimTask, SimTaskReplayBuffer 19 | 20 | os.environ["MUJOCO_GL"] = "egl" 21 | OBS_KEYS = [ 22 | "robot0_eef_pos", 23 | "robot0_eef_quat", 24 | "robot0_gripper_qpos", 25 | "robot0_gripper_qvel", 26 | ] 27 | _MAX_STEPS = 200 28 | _N_ROLLOUTS = 50 29 | EXP_WEIGHT = 0.0 30 | 31 | 32 | _AC_LOC = None 33 | _AC_SCALE = None 34 | 35 | 36 | def _normalize_actions(actions): 37 | loc = _AC_LOC[None].astype(actions.dtype) 38 | scale = _AC_SCALE[None].astype(actions.dtype) 39 | return (actions - loc) / scale 40 | 41 | 42 | def _denormalize_action(actions): 43 | return actions * _AC_SCALE + _AC_LOC 44 | 45 | 46 | def _render(env, height=256, width=256): 47 | img = env.render( 48 | mode="rgb_array", height=height, width=width, camera_name="agentview" 49 | ) 50 | return img.astype(np.uint8) 51 | 52 | 53 | def _obs_dict_to_vec(obs_dict): 54 | return np.concatenate([obs_dict[k] for k in OBS_KEYS]).astype(np.float32) 55 | 56 | 57 | def _get_task_path(task): 58 | # little hack now for transport 59 | global OBS_KEYS, _MAX_STEPS 60 | if task == "transport" and "robot1_eef_pos" not in OBS_KEYS: 61 | OBS_KEYS += [ 62 | "robot1_eef_pos", 63 | "robot1_eef_quat", 64 | "robot1_gripper_qpos", 65 | "robot1_gripper_qvel", 66 | ] 67 | 68 | if task in ("transport", "tool_hang"): 69 | _MAX_STEPS = 725 70 | 71 | task_path = os.path.expanduser(f"~/robomimic/datasets/{task}/ph/low_dim.hdf5") 72 | assert os.path.exists(task_path), "Missing task data!" 73 | return task_path 74 | 75 | 76 | def _make_env(task): 77 | task_path = _get_task_path(task) 78 | dummy_spec = dict( 79 | obs=dict( 80 | low_dim=["robot0_eef_pos"], 81 | rgb=[], 82 | ), 83 | ) 84 | ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec) 85 | env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=task_path) 86 | 87 | env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render_offscreen=True) 88 | env.reset() 89 | return env 90 | 91 | 92 | class RoboSuiteBuffer(SimTaskReplayBuffer): 93 | def load_buffer(self, task): 94 | global _AC_LOC, _AC_SCALE 95 | 96 | render_env = _make_env(task) 97 | task_path = _get_task_path(task) 98 | buffer = [] 99 | all_actions = [] 100 | with h5py.File(task_path, "r") as f: 101 | print("Loading demonstration data!") 102 | 103 | for demo in tqdm.tqdm(list(f["data"].keys())): 104 | sim_states = f[f"data/{demo}/states"][()] 105 | demo_acs = f[f"data/{demo}/actions"][()] 106 | 107 | images, obs, actions = [], [], [] 108 | for s, a in zip(sim_states, demo_acs): 109 | # reset to demo state and generate robot obs/images 110 | obs_dict = render_env.reset_to({"states": s}) 111 | obs.append(_obs_dict_to_vec(obs_dict)) 112 | images.append(_render(render_env)) 113 | actions.append(a.astype(np.float32)) 114 | all_actions.append(a.astype(np.float32)) 115 | 116 | buffer.append( 117 | dict( 118 | images=np.array(images), 119 | observations=np.array(obs), 120 | actions=np.array(actions), 121 | ) 122 | ) 123 | 124 | all_actions = np.array(all_actions) 125 | max_ac = np.max(all_actions, axis=0) 126 | min_ac = np.min(all_actions, axis=0) 127 | _AC_LOC = (max_ac + min_ac) / 2 128 | _AC_SCALE = (max_ac - min_ac) / 2 129 | 130 | for t in buffer: 131 | t["actions"] = _normalize_actions(t["actions"]) 132 
| print("built task", task, "with rollout steps", _MAX_STEPS) 133 | return buffer 134 | 135 | 136 | class RoboSuiteTask(SimTask): 137 | def build_eval_env(self, task, n_cams, obs_dim, ac_dim): 138 | assert n_cams == 1, "Only support single cam tasks!" 139 | if task == "transport": 140 | assert obs_dim == 22, "Robosuite obs_dim should be 22!" 141 | assert ac_dim == 14, "Robosuite should have ac_dim of 14!" 142 | else: 143 | assert obs_dim == 11, "Robosuite obs_dim should be 11!" 144 | assert ac_dim == 7, "Robosuite should have ac_dim of 7!" 145 | self.eval_env = _make_env(task) 146 | 147 | def rollout_sim(self, agent, device_id, frame_buffer): 148 | env = self.eval_env 149 | transform = self.test_transform 150 | 151 | success_flags, total_rewards = [], [] 152 | for i in range(_N_ROLLOUTS): 153 | print(f"Rollout {i}", end="\r") 154 | o = env.reset() 155 | done = False 156 | t = 0 157 | total_reward = 0 158 | act_history = None 159 | 160 | while not done and t < _MAX_STEPS and not env.is_success()["task"]: 161 | o = torch.from_numpy(_obs_dict_to_vec(o))[None].to(device_id) 162 | raw_img = _render(env) 163 | frame_buffer.append(raw_img) 164 | i = ( 165 | torch.from_numpy(raw_img).permute((2, 0, 1)).float().to(device_id) 166 | / 255 167 | ) 168 | i = dict(cam0=transform(i)[None][None]) 169 | 170 | with torch.no_grad(): 171 | acs = agent.get_actions(i, o)[0] 172 | 173 | acs = acs.cpu().numpy() 174 | if act_history is None: 175 | act_history = deque(maxlen=len(acs)) 176 | act_history.append(acs) 177 | 178 | num_actions = len(act_history) 179 | curr_act_preds = np.stack( 180 | [ 181 | pred_actions[i] 182 | for (i, pred_actions) in zip( 183 | range(num_actions - 1, -1, -1), act_history 184 | ) 185 | ] 186 | ) 187 | 188 | # compute the weighted average across all predictions for this timestep 189 | weights = np.exp(-EXP_WEIGHT * np.arange(num_actions))[::-1] 190 | weights = weights / weights.sum() 191 | ac = np.sum(weights[:, None] * curr_act_preds, axis=0) 192 | 193 | # denormalize then execute on robot 194 | ac = _denormalize_action(ac) 195 | o, r, done, _ = env.step(ac) 196 | 197 | # process the env return and break if done 198 | t += 1 199 | total_reward += r 200 | if done: 201 | break 202 | 203 | # for ac_step in acs: 204 | # o, r, done, _ = env.step(ac_step) 205 | # t += 1; total_reward += r 206 | # if done: 207 | # break 208 | success_flags.append(float(env.is_success()["task"])) 209 | total_rewards.append(float(total_reward)) 210 | return np.mean(success_flags), np.mean(total_rewards) 211 | -------------------------------------------------------------------------------- /data4robotics/models/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # adapted from: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | # Modified by Sudeep Dasari 12 | 13 | 14 | from functools import partial 15 | 16 | import timm.models.vision_transformer 17 | import torch 18 | import torch.nn as nn 19 | from timm.models.vision_transformer import resize_pos_embed 20 | 21 | 22 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 23 | """Vision Transformer with support for global average pooling""" 24 | 25 | def __init__( 26 | self, global_pool=False, use_cls=True, mask_ratio=None, del_head=True, **kwargs 27 | ): 28 | super(VisionTransformer, self).__init__(**kwargs) 29 | assert use_cls and not global_pool, "token counting only works for use_cls mode" 30 | self.classifier_feature = "use_cls_token" 31 | 32 | if del_head: 33 | del self.head # don't use prediction head 34 | 35 | if self.classifier_feature == "global_pool": 36 | norm_layer = kwargs["norm_layer"] 37 | embed_dim = kwargs["embed_dim"] 38 | self.fc_norm = norm_layer(embed_dim) 39 | 40 | del self.norm # remove the original norm 41 | 42 | if self.classifier_feature == "reshape_embedding": 43 | self.final_spatial = int(self.patch_embed.num_patches**0.5) 44 | self.embed_dim = ( 45 | self.patch_embed.grid_size[0], 46 | self.patch_embed.grid_size[1], 47 | kwargs["embed_dim"], 48 | ) 49 | 50 | self.mask_ratio = mask_ratio 51 | 52 | def random_masking(self, x, mask_ratio): 53 | """ 54 | Perform per-sample random masking by per-sample shuffling. 55 | Per-sample shuffling is done by argsort random noise. 56 | x: [N, L, D], sequence 57 | """ 58 | N, L, D = x.shape # batch, length, dim 59 | len_keep = int(L * (1 - mask_ratio)) 60 | 61 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 62 | 63 | # sort noise for each sample 64 | ids_shuffle = torch.argsort( 65 | noise, dim=1 66 | ) # ascend: small is keep, large is remove 67 | ids_restore = torch.argsort(ids_shuffle, dim=1) 68 | 69 | # keep the first subset 70 | ids_keep = ids_shuffle[:, :len_keep] 71 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 72 | 73 | # generate the binary mask: 0 is keep, 1 is remove 74 | mask = torch.ones([N, L], device=x.device) 75 | mask[:, :len_keep] = 0 76 | # unshuffle to get the binary mask 77 | mask = torch.gather(mask, dim=1, index=ids_restore) 78 | 79 | return x_masked, mask, ids_restore 80 | 81 | def handle_outcome(self, x): 82 | if self.classifier_feature == "global_pool": 83 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 84 | outcome = self.fc_norm(x) 85 | elif self.classifier_feature == "use_cls_token": 86 | x = self.norm(x) 87 | outcome = x[:, :1] # use cls token 88 | elif self.classifier_feature == "reshape_embedding": 89 | x = self.norm(x) 90 | outcome = reshape_embedding( 91 | x[:, 1:] 92 | ) # remove cls token and reshape embedding 93 | else: 94 | raise NotImplementedError 95 | 96 | return outcome 97 | 98 | def forward_features(self, x): 99 | B = x.shape[0] 100 | x = self.patch_embed(x) 101 | 102 | # add pos embed w/o cls token 103 | x = x + self.pos_embed[:, 1:, :] 104 | 105 | # masking: length -> length * mask_ratio 106 | if self.mask_ratio is not None: 107 | x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio) 108 | 109 | # append cls token 110 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 111 | x = 
torch.cat((cls_token.expand(B, -1, -1), x), dim=1) 112 | 113 | x = self.blocks(x) 114 | return self.handle_outcome(x) 115 | 116 | def forward(self, x): 117 | return self.forward_features(x) 118 | 119 | @property 120 | def n_tokens(self): 121 | # hard-coded assuming use_cls_token 122 | return 1 123 | 124 | 125 | class ClipVisionTransformer(VisionTransformer): 126 | def forward_features(self, x): 127 | B = x.shape[0] 128 | x = self.patch_embed(x) 129 | x = torch.cat( 130 | [ 131 | self.cls_token.squeeze() 132 | + torch.zeros(B, 1, x.shape[-1], device=x.device), 133 | x, 134 | ], 135 | dim=1, 136 | ) # shape = [*, grid ** 2 + 1, width] 137 | x = x + self.pos_embed.squeeze().to(x.dtype) 138 | x = self.norm_pre(x) 139 | 140 | x = self.blocks(x) 141 | return self.handle_outcome(x) 142 | 143 | 144 | def reshape_embedding(x): 145 | N, L, D = x.shape 146 | H = W = int(L**0.5) 147 | x = x.reshape(N, H, W, D) 148 | x = torch.einsum("nhwd->ndhw", x) 149 | return x 150 | 151 | 152 | def vit_small_patch16(**kwargs): 153 | """ViT small as defined in the DeiT paper.""" 154 | model = VisionTransformer( 155 | patch_size=16, 156 | embed_dim=384, 157 | depth=12, 158 | num_heads=6, 159 | mlp_ratio=4, 160 | qkv_bias=True, 161 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 162 | **kwargs, 163 | ) 164 | return model 165 | 166 | 167 | def vit_base_patch16(**kwargs): 168 | model = VisionTransformer( 169 | patch_size=16, 170 | embed_dim=768, 171 | depth=12, 172 | num_heads=12, 173 | mlp_ratio=4, 174 | qkv_bias=True, 175 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 176 | **kwargs, 177 | ) 178 | return model 179 | 180 | 181 | def clip_vit_base_patch16(**kwargs): 182 | model = ClipVisionTransformer( 183 | patch_size=16, 184 | embed_dim=768, 185 | depth=12, 186 | num_heads=12, 187 | mlp_ratio=4, 188 | qkv_bias=True, 189 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 190 | # CLIP-specific: 191 | pre_norm=True, 192 | num_classes=512, 193 | **kwargs, 194 | ) 195 | return model 196 | 197 | 198 | def vit_large_patch16(**kwargs): 199 | model = VisionTransformer( 200 | patch_size=16, 201 | embed_dim=1024, 202 | depth=24, 203 | num_heads=16, 204 | mlp_ratio=4, 205 | qkv_bias=True, 206 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 207 | **kwargs, 208 | ) 209 | return model 210 | 211 | 212 | def vit_huge_patch14(**kwargs): 213 | model = VisionTransformer( 214 | patch_size=14, 215 | embed_dim=1280, 216 | depth=32, 217 | num_heads=16, 218 | mlp_ratio=4, 219 | qkv_bias=True, 220 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 221 | **kwargs, 222 | ) 223 | return model 224 | 225 | 226 | def load_vit(model, restore_path): 227 | if restore_path: 228 | print("Restoring model from", restore_path) 229 | state_dict = torch.load(restore_path, map_location="cpu") 230 | state_dict = ( 231 | state_dict["features"] if "features" in state_dict else state_dict["model"] 232 | ) 233 | 234 | # resize pos_embed if required 235 | if state_dict["pos_embed"].shape != model.pos_embed.shape: 236 | print( 237 | f"resizing pos_embed from {state_dict['pos_embed'].shape} to {model.pos_embed.shape}" 238 | ) 239 | state_dict["pos_embed"] = resize_pos_embed( 240 | state_dict["pos_embed"], 241 | model.pos_embed, 242 | getattr(model, "num_tokens", 1), 243 | model.patch_embed.grid_size, 244 | ) 245 | 246 | # filter out keys with name decoder or mask_token 247 | state_dict = { 248 | k: v 249 | for k, v in state_dict.items() 250 | if "decoder" not in k and "mask_token" not in k 251 | } 252 | 253 | # remove norm if using global_pool instead of class token 254 | if 
model.classifier_feature == "global_pool": 255 | print("Removing extra weights for global_pool") 256 | # remove layer that start with norm 257 | state_dict = { 258 | k: v for k, v in state_dict.items() if not k.startswith("norm") 259 | } 260 | # add fc_norm in the state dict from the model 261 | state_dict["fc_norm.weight"] = model.fc_norm.weight 262 | state_dict["fc_norm.bias"] = model.fc_norm.bias 263 | 264 | # load state_dict 265 | model.load_state_dict(state_dict) 266 | return model 267 | -------------------------------------------------------------------------------- /data4robotics/models/action_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | # Heavy inspiration taken from ACT by Tony Zhao: https://github.com/tonyzhaozh/act 3 | # and DETR by Meta AI (Carion et. al.): https://github.com/facebookresearch/detr 4 | 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | import copy 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from data4robotics.agent import BaseAgent 17 | 18 | 19 | def _get_clones(module, N): 20 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 21 | 22 | 23 | def _get_activation_fn(activation): 24 | """Return an activation function given a string""" 25 | if activation == "relu": 26 | return F.relu 27 | if activation == "gelu": 28 | return F.gelu 29 | if activation == "glu": 30 | return F.glu 31 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") 32 | 33 | 34 | def _with_pos_embed(tensor, pos=None): 35 | return tensor if pos is None else tensor + pos 36 | 37 | 38 | class _PositionalEncoding(nn.Module): 39 | def __init__(self, d_model, max_len=5000): 40 | super().__init__() 41 | # Compute the positional encodings once in log space 42 | pe = torch.zeros(max_len, d_model) 43 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 44 | div_term = torch.exp( 45 | torch.arange(0, d_model, 2, dtype=torch.float) 46 | * -(np.log(10000.0) / d_model) 47 | ) 48 | pe[:, 0::2] = torch.sin(position * div_term) 49 | pe[:, 1::2] = torch.cos(position * div_term) 50 | pe = pe.unsqueeze(0).transpose(0, 1) 51 | self.register_buffer("pe", pe) 52 | 53 | def forward(self, x): 54 | """ 55 | Args: 56 | x: Tensor of shape (seq_len, batch_size, d_model) 57 | 58 | Returns: 59 | Tensor of shape (seq_len, batch_size, d_model) with positional encodings added 60 | """ 61 | pe = self.pe[: x.shape[0]] 62 | pe = pe.repeat((1, x.shape[1], 1)) 63 | return pe.detach().clone() 64 | 65 | 66 | class _TransformerEncoderLayer(nn.Module): 67 | def __init__( 68 | self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu" 69 | ): 70 | super().__init__() 71 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 72 | # Implementation of Feedforward model 73 | self.linear1 = nn.Linear(d_model, dim_feedforward) 74 | self.linear2 = nn.Linear(dim_feedforward, d_model) 75 | 76 | self.norm1 = nn.LayerNorm(d_model) 77 | self.norm2 = nn.LayerNorm(d_model) 78 | 79 | self.dropout1 = nn.Dropout(dropout) 80 | self.dropout2 = nn.Dropout(dropout) 81 | self.dropout3 = nn.Dropout(dropout) 82 | 83 | self.activation = _get_activation_fn(activation) 84 | 85 | def forward(self, src, pos): 86 | q = k = _with_pos_embed(src, pos) 87 | src2, _ = self.self_attn(q, k, value=src) 88 | src = src + 
self.dropout1(src2) 89 | src = self.norm1(src) 90 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 91 | src = src + self.dropout3(src2) 92 | src = self.norm2(src) 93 | return src 94 | 95 | 96 | class _TransformerDecoderLayer(nn.Module): 97 | def __init__( 98 | self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu" 99 | ): 100 | super().__init__() 101 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 102 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 103 | # Implementation of Feedforward model 104 | self.linear1 = nn.Linear(d_model, dim_feedforward) 105 | self.linear2 = nn.Linear(dim_feedforward, d_model) 106 | 107 | self.norm1 = nn.LayerNorm(d_model) 108 | self.norm2 = nn.LayerNorm(d_model) 109 | self.norm3 = nn.LayerNorm(d_model) 110 | 111 | self.dropout1 = nn.Dropout(dropout) 112 | self.dropout2 = nn.Dropout(dropout) 113 | self.dropout3 = nn.Dropout(dropout) 114 | self.dropout4 = nn.Dropout(dropout) 115 | 116 | self.activation = _get_activation_fn(activation) 117 | 118 | def forward(self, tgt, memory, pos=None, query_pos=None): 119 | q = k = _with_pos_embed(tgt, query_pos) 120 | tgt2, _ = self.self_attn(q, k, value=tgt) 121 | tgt = tgt + self.dropout1(tgt2) 122 | tgt = self.norm1(tgt) 123 | 124 | tgt2, _ = self.multihead_attn( 125 | query=_with_pos_embed(tgt, query_pos), 126 | key=_with_pos_embed(memory, pos), 127 | value=memory, 128 | ) 129 | tgt = tgt + self.dropout2(tgt2) 130 | tgt = self.norm2(tgt) 131 | 132 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) 133 | tgt = tgt + self.dropout4(tgt2) 134 | tgt = self.norm3(tgt) 135 | return tgt 136 | 137 | 138 | class _TransformerEncoder(nn.Module): 139 | def __init__(self, encoder_layer, num_layers): 140 | super().__init__() 141 | self.layers = _get_clones(encoder_layer, num_layers) 142 | self.num_layers = num_layers 143 | 144 | def forward(self, src, pos): 145 | output = src 146 | for layer in self.layers: 147 | output = layer(output, pos) 148 | return output 149 | 150 | 151 | class _TransformerDecoder(nn.Module): 152 | def __init__(self, decoder_layer, num_layers): 153 | super().__init__() 154 | self.layers = _get_clones(decoder_layer, num_layers) 155 | self.num_layers = num_layers 156 | 157 | def forward(self, tgt, memory, pos, query_pos, return_intermediate=False): 158 | output = tgt 159 | intermediate = [] 160 | for layer in self.layers: 161 | output = layer(output, memory, pos=pos, query_pos=query_pos) 162 | if return_intermediate: 163 | intermediate.append(self.norm(output)) 164 | 165 | if return_intermediate: 166 | return torch.stack(intermediate) 167 | return output 168 | 169 | 170 | class _ACT(nn.Module): 171 | def __init__( 172 | self, 173 | d_model=512, 174 | nhead=8, 175 | num_encoder_layers=6, 176 | num_decoder_layers=6, 177 | dim_feedforward=2048, 178 | dropout=0.1, 179 | activation="relu", 180 | ): 181 | super().__init__() 182 | 183 | encoder_layer = _TransformerEncoderLayer( 184 | d_model, nhead, dim_feedforward, dropout, activation 185 | ) 186 | self.encoder = _TransformerEncoder(encoder_layer, num_encoder_layers) 187 | 188 | decoder_layer = _TransformerDecoderLayer( 189 | d_model, nhead, dim_feedforward, dropout, activation 190 | ) 191 | self.decoder = _TransformerDecoder(decoder_layer, num_decoder_layers) 192 | 193 | self._reset_parameters() 194 | 195 | self.pos_helper = _PositionalEncoding(d_model) 196 | self.d_model = d_model 197 | self.nhead = nhead 198 | 199 | def _reset_parameters(self): 200 | 
for p in self.parameters(): 201 | if p.dim() > 1: 202 | nn.init.xavier_uniform_(p) 203 | 204 | def forward(self, input_tokens, query_enc): 205 | input_tokens = input_tokens.transpose(0, 1) 206 | input_pos = self.pos_helper(input_tokens) 207 | memory = self.encoder(input_tokens, input_pos) 208 | 209 | query_enc = query_enc[:, None].repeat((1, input_tokens.shape[1], 1)) 210 | tgt = torch.zeros_like(query_enc) 211 | acs_tokens = self.decoder(tgt, memory, input_pos, query_enc) 212 | return acs_tokens.transpose(0, 1) 213 | 214 | 215 | class TransformerAgent(BaseAgent): 216 | def __init__( 217 | self, 218 | features, 219 | odim, 220 | n_cams, 221 | ac_dim, 222 | ac_chunk, 223 | use_obs="add_token", 224 | imgs_per_cam=1, 225 | dropout=0, 226 | share_cam_features=False, 227 | early_fusion=False, 228 | feat_norm=False, 229 | token_dim=512, 230 | transformer_kwargs=dict(), 231 | ): 232 | 233 | # initialize obs and img tokenizers 234 | super().__init__( 235 | odim=odim, 236 | features=features, 237 | n_cams=n_cams, 238 | imgs_per_cam=imgs_per_cam, 239 | use_obs=use_obs, 240 | share_cam_features=share_cam_features, 241 | early_fusion=early_fusion, 242 | dropout=dropout, 243 | feat_norm=feat_norm, 244 | token_dim=token_dim, 245 | ) 246 | 247 | self.transformer = _ACT(**transformer_kwargs) 248 | self.ac_query = nn.Embedding(ac_chunk, self.transformer.d_model) 249 | self.ac_proj = nn.Linear(self.transformer.d_model, ac_dim) 250 | self._ac_dim, self._ac_chunk = ac_dim, ac_chunk 251 | 252 | def forward(self, imgs, obs, ac_flat, mask_flat): 253 | tokens = self.tokenize_obs(imgs, obs) 254 | action_tokens = self.transformer(tokens, self.ac_query.weight) 255 | actions = self.ac_proj(action_tokens) 256 | 257 | ac_flat_hat = actions.reshape((actions.shape[0], -1)) 258 | all_l1 = F.l1_loss(ac_flat_hat, ac_flat, reduction="none") 259 | l1 = (all_l1 * mask_flat).mean() 260 | return l1 261 | 262 | def get_actions(self, imgs, obs): 263 | tokens = self.tokenize_obs(imgs, obs) 264 | action_tokens = self.transformer(tokens, self.ac_query.weight) 265 | return self.ac_proj(action_tokens) 266 | 267 | @property 268 | def ac_chunk(self): 269 | return self._ac_chunk 270 | 271 | @property 272 | def ac_dim(self): 273 | return self._ac_dim 274 | -------------------------------------------------------------------------------- /data4robotics/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
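# A minimal shape-check sketch for the _ACT module defined in action_transformer.py
# above (batch size, token count, and chunk length are hypothetical; assumes the
# data4robotics package and its dependencies import cleanly). Observation tokens are
# encoded into a memory, and the learned per-chunk-step action queries (ac_query.weight
# in TransformerAgent) are decoded into one output token per future action, which
# ac_proj then maps to ac_dim.
import torch

from data4robotics.models.action_transformer import _ACT

act = _ACT(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
obs_tokens = torch.randn(2, 3, 512)   # [batch, n_obs_tokens, d_model]
ac_queries = torch.randn(16, 512)     # [ac_chunk, d_model] learned queries
with torch.no_grad():
    out = act(obs_tokens, ac_queries)
print(out.shape)                      # torch.Size([2, 16, 512])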
5 | 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from r3m import load_r3m 13 | from torchvision import models 14 | 15 | from data4robotics.models.base import BaseModel 16 | 17 | 18 | def _make_norm(norm_cfg): 19 | if norm_cfg["name"] == "batch_norm": 20 | return nn.BatchNorm2d 21 | if norm_cfg["name"] == "group_norm": 22 | num_groups = norm_cfg["num_groups"] 23 | return lambda num_channels: nn.GroupNorm(num_groups, num_channels) 24 | if norm_cfg["name"] == "diffusion_policy": 25 | 26 | def _gn_builder(num_channels): 27 | num_groups = int(num_channels // 16) 28 | return nn.GroupNorm(num_groups, num_channels) 29 | 30 | return _gn_builder 31 | raise NotImplementedError(f"Missing norm layer: {norm_cfg['name']}") 32 | 33 | 34 | def _construct_resnet(size, norm, weights=None): 35 | if size == 18: 36 | w = models.ResNet18_Weights 37 | m = models.resnet18(norm_layer=norm) 38 | elif size == 34: 39 | w = models.ResNet34_Weights 40 | m = models.resnet34(norm_layer=norm) 41 | elif size == 50: 42 | w = models.ResNet50_Weights 43 | m = models.resnet50(norm_layer=norm) 44 | else: 45 | raise NotImplementedError(f"Missing size: {size}") 46 | 47 | if weights is not None: 48 | w = w.verify(weights).get_state_dict(progress=True) 49 | if norm is not nn.BatchNorm2d: 50 | w = { 51 | k: v 52 | for k, v in w.items() 53 | if "running_mean" not in k and "running_var" not in k 54 | } 55 | m.load_state_dict(w) 56 | return m 57 | 58 | 59 | class ResNet(BaseModel): 60 | def __init__( 61 | self, 62 | size, 63 | norm_cfg, 64 | weights=None, 65 | restore_path="", 66 | avg_pool=True, 67 | conv_repeat=0, 68 | ): 69 | norm_layer = _make_norm(norm_cfg) 70 | model = _construct_resnet(size, norm_layer, weights) 71 | model.fc = nn.Identity() 72 | if not avg_pool: 73 | model.avgpool = nn.Identity() 74 | 75 | if conv_repeat > 1: 76 | w = model.conv1.weight.data.repeat((1, conv_repeat, 1, 1)) 77 | model.conv1.weight.data = w 78 | model.conv1.in_channels *= conv_repeat 79 | 80 | super().__init__(model, restore_path) 81 | self._size, self._avg_pool = size, avg_pool 82 | 83 | def forward(self, x): 84 | if self._avg_pool: 85 | return self._model(x)[:, None] 86 | B = x.shape[0] 87 | x = self._model(x) 88 | x = x.reshape((B, self.embed_dim, -1)) 89 | return x.transpose(1, 2) 90 | 91 | @property 92 | def embed_dim(self): 93 | return {18: 512, 34: 512, 50: 2048}[self._size] 94 | 95 | @property 96 | def n_tokens(self): 97 | if self._avg_pool: 98 | return 1 99 | return 49 # assuming 224x224 images 100 | 101 | 102 | class R3M(ResNet): 103 | def __init__(self, size, avg_pool=True): 104 | nn.Module.__init__(self) 105 | self._model = load_r3m(f"resnet{size}").module.convnet.cpu() 106 | if not avg_pool: 107 | self._model.avgpool = nn.Identity() 108 | self._size, self._avg_pool = size, avg_pool 109 | 110 | 111 | class SpatialSoftmax(nn.Module): 112 | """ 113 | Spatial Softmax Layer. 114 | Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al. 115 | https://rll.berkeley.edu/dsae/dsae.pdf 116 | """ 117 | 118 | def __init__( 119 | self, 120 | input_shape, 121 | num_kp=None, 122 | temperature=1.0, 123 | learnable_temperature=False, 124 | output_variance=False, 125 | noise_std=0.0, 126 | ): 127 | """ 128 | Args: 129 | input_shape (list): shape of the input feature (C, H, W) 130 | num_kp (int): number of keypoints (None for not use spatialsoftmax) 131 | temperature (float): temperature term for the softmax. 
132 | learnable_temperature (bool): whether to learn the temperature 133 | output_variance (bool): treat attention as a distribution, 134 | and compute second-order statistics to return 135 | noise_std (float): add random spatial noise to the predicted keypoints 136 | """ 137 | super(SpatialSoftmax, self).__init__() 138 | assert len(input_shape) == 3 139 | self._in_c, self._in_h, self._in_w = input_shape # (C, H, W) 140 | 141 | if num_kp is not None: 142 | self.nets = nn.Conv2d(self._in_c, num_kp, kernel_size=1) 143 | self._num_kp = num_kp 144 | else: 145 | self.nets = None 146 | self._num_kp = self._in_c 147 | self.learnable_temperature = learnable_temperature 148 | self.output_variance = output_variance 149 | self.noise_std = noise_std 150 | 151 | if self.learnable_temperature: 152 | # temperature will be learned 153 | temperature = nn.Parameter(torch.ones(1) * temperature, requires_grad=True) 154 | self.register_parameter("temperature", temperature) 155 | else: 156 | # temperature held constant after initialization 157 | temperature = nn.Parameter(torch.ones(1) * temperature, requires_grad=False) 158 | self.register_buffer("temperature", temperature) 159 | 160 | pos_x, pos_y = np.meshgrid( 161 | np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) 162 | ) 163 | pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h * self._in_w)).float() 164 | pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h * self._in_w)).float() 165 | self.register_buffer("pos_x", pos_x) 166 | self.register_buffer("pos_y", pos_y) 167 | 168 | def output_shape(self, input_shape): 169 | """ 170 | Function to compute output shape from inputs to this module. 171 | Args: 172 | input_shape (iterable of int): shape of input. Does not include batch. 173 | Some modules may not need this argument, if their output does not depend 174 | on the size of the input, or if they assume fixed size input. 175 | Returns: 176 | out_shape ([int]): list of integers corresponding to output shape 177 | """ 178 | assert len(input_shape) == 3 179 | assert input_shape[0] == self._in_c 180 | return [self._num_kp, 2] 181 | 182 | def forward(self, feature): 183 | """ 184 | Forward pass through spatial softmax layer. For each keypoint, a 2D spatial 185 | probability distribution is created using a softmax, where the support is the 186 | pixel locations. This distribution is used to compute the expected value of 187 | the pixel location, which becomes a keypoint of dimension 2. K such keypoints 188 | are created. 
189 | Returns: 190 | out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly 191 | keypoint variance of shape [B, K, 2, 2] corresponding to the covariance 192 | under the 2D spatial softmax distribution 193 | """ 194 | assert feature.shape[1] == self._in_c 195 | assert feature.shape[2] == self._in_h 196 | assert feature.shape[3] == self._in_w 197 | if self.nets is not None: 198 | feature = self.nets(feature) 199 | 200 | # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints 201 | feature = feature.reshape(-1, self._in_h * self._in_w) 202 | # 2d softmax normalization 203 | attention = torch.nn.functional.softmax(feature / self.temperature, dim=-1) 204 | # [1, H * W] x [B * K, H * W] -> [B * K, 1] 205 | # for spatial coordinate mean in x and y dimensions 206 | expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True) 207 | expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True) 208 | # stack to [B * K, 2] 209 | expected_xy = torch.cat([expected_x, expected_y], 1) 210 | # reshape to [B, K, 2] 211 | feature_keypoints = expected_xy.view(-1, self._num_kp, 2) 212 | 213 | if self.training: 214 | noise = torch.randn_like(feature_keypoints) * self.noise_std 215 | feature_keypoints += noise 216 | 217 | if self.output_variance: 218 | # treat attention as a distribution 219 | # and compute second-order statistics to return 220 | expected_xx = torch.sum( 221 | self.pos_x * self.pos_x * attention, dim=1, keepdim=True 222 | ) 223 | expected_yy = torch.sum( 224 | self.pos_y * self.pos_y * attention, dim=1, keepdim=True 225 | ) 226 | expected_xy = torch.sum( 227 | self.pos_x * self.pos_y * attention, dim=1, keepdim=True 228 | ) 229 | var_x = expected_xx - expected_x * expected_x 230 | var_y = expected_yy - expected_y * expected_y 231 | var_xy = expected_xy - expected_x * expected_y 232 | # stack to [B * K, 4] and then reshape to [B, K, 2, 2] 233 | # where last 2 dims are covariance matrix 234 | feature_covar = torch.cat([var_x, var_xy, var_xy, var_y], 1).reshape( 235 | -1, self._num_kp, 2, 2 236 | ) 237 | feature_keypoints = (feature_keypoints, feature_covar) 238 | 239 | return feature_keypoints 240 | 241 | 242 | class RobomimicResNet(nn.Module): 243 | def __init__(self, size, norm_cfg, weights=None, img_size=224, feature_dim=64): 244 | super().__init__() 245 | norm_layer = _make_norm(norm_cfg) 246 | model = _construct_resnet(size, norm_layer, weights) 247 | # Cut the last two layers. 
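# For torchvision ResNets, children() ends with the global avgpool and the fc head, so
# slicing [:-2] keeps the convolutional trunk and preserves the spatial
# [B, C, H/32, W/32] feature map that SpatialSoftmax consumes below. The hard-coded
# 512 channels in resnet_output_shape assumes ResNet-18/34 (ResNet-50 outputs 2048).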
248 | self.resnet = nn.Sequential(*(list(model.children())[:-2])) 249 | resnet_out_dim = int(math.ceil(img_size / 32.0)) 250 | resnet_output_shape = [512, resnet_out_dim, resnet_out_dim] 251 | self.spatial_softmax = SpatialSoftmax( 252 | resnet_output_shape, 253 | num_kp=64, 254 | temperature=1.0, 255 | noise_std=0.0, 256 | output_variance=False, 257 | learnable_temperature=False, 258 | ) 259 | pool_output_shape = self.spatial_softmax.output_shape(resnet_output_shape) 260 | self.flatten = nn.Flatten(start_dim=1, end_dim=-1) 261 | self.proj = nn.Linear(int(np.prod(pool_output_shape)), feature_dim) 262 | self.feature_dim = feature_dim 263 | 264 | def forward(self, x): 265 | x = self.resnet(x) 266 | x = self.spatial_softmax(x) 267 | x = self.flatten(x) 268 | x = self.proj(x) 269 | return x[:, None] 270 | 271 | @property 272 | def embed_dim(self): 273 | return self.feature_dim 274 | 275 | @property 276 | def n_tokens(self): 277 | return 1 278 | -------------------------------------------------------------------------------- /eval_scripts/eval_aloha.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import sys 6 | import threading 7 | import time 8 | from collections import deque 9 | from pathlib import Path 10 | 11 | import cv2 12 | import numpy as np 13 | import torch 14 | import yaml 15 | 16 | import hydra 17 | 18 | 19 | # speed up torch agent using tensorcores 20 | torch.set_float32_matmul_precision('high') 21 | 22 | 23 | # add core aloha files to path then do aloha imports 24 | BASE_PATH = os.path.expanduser("~/aloha/") 25 | sys.path.append(BASE_PATH + "src") 26 | sys.path.append(BASE_PATH + "src/aloha_pro/aloha_scripts/") 27 | from aloha_pro.aloha_scripts.real_env import make_real_env 28 | 29 | 30 | class Policy: 31 | def __init__(self, agent_path, model_name, args): 32 | self.args = args 33 | 34 | with open(Path(agent_path, "agent_config.yaml"), "r") as f: 35 | config_yaml = f.read() 36 | agent_config = yaml.safe_load(config_yaml) 37 | with open(Path(agent_path, "obs_config.yaml"), "r") as f: 38 | config_yaml = f.read() 39 | obs_config = yaml.safe_load(config_yaml) 40 | with open(Path(agent_path, "ac_norm.json"), "r") as f: 41 | ac_norm_dict = json.load(f) 42 | loc, scale = ac_norm_dict["loc"], ac_norm_dict["scale"] 43 | self.loc = np.array(loc).astype(np.float32) 44 | self.scale = np.array(scale).astype(np.float32) 45 | 46 | agent = hydra.utils.instantiate(agent_config) 47 | 48 | save_dict = torch.load(Path(agent_path, model_name), map_location="cpu") 49 | agent.load_state_dict(save_dict["model"]) 50 | self.agent = torch.compile(agent.eval().cuda().get_actions) 51 | 52 | self.transform = hydra.utils.instantiate(obs_config["transform"]) 53 | self.img_keys = obs_config["imgs"] 54 | 55 | if args.goal_path: 56 | bgr_img = cv2.imread(args.goal_path)[:, :, :3] 57 | bgr_img = cv2.resize(bgr_img, (256, 256), interpolation=cv2.INTER_AREA) 58 | rgb_img = ( 59 | torch.from_numpy(bgr_img[:, :, ::-1].copy()).float().permute((2, 0, 1)) 60 | / 255 61 | ) 62 | self.goal_img = self.transform(rgb_img)[None].cuda() 63 | else: 64 | self.goal_img = None 65 | 66 | print(f"loaded agent from {agent_path}, at step: {save_dict['global_step']}") 67 | self.temp_ensemble = args.temp_ensemble 68 | self.pred_horizon = args.pred_horizon 69 | self.reset() 70 | 71 | def reset(self): 72 | self.act_history = deque(maxlen=self.pred_horizon) 73 | self.last_ac = None 74 | self._last_time = None 75 | 76 | def 
_proc_images(self, img_dict, size=(256, 256)): 77 | torch_imgs = dict() 78 | for i, k in enumerate(self.img_keys): 79 | bgr_img = img_dict[k][:, :, :3] 80 | bgr_img = cv2.resize(bgr_img, size, interpolation=cv2.INTER_AREA) 81 | rgb_img = bgr_img[:, :, ::-1].copy() 82 | rgb_img = torch.from_numpy(rgb_img).float().permute((2, 0, 1)) / 255 83 | rgb_img = self.transform(rgb_img)[None].cuda() 84 | 85 | if self.goal_img is not None: 86 | if k == "cam_high": 87 | torch_imgs[f"cam{i}"] = torch.cat((self.goal_img, rgb_img), 0)[None] 88 | else: 89 | zero_pad = torch.zeros_like(self.goal_img) 90 | torch_imgs[f"cam{i}"] = torch.cat((zero_pad, rgb_img), 0)[None] 91 | else: 92 | torch_imgs[f"cam{i}"] = rgb_img[None] 93 | 94 | return torch_imgs 95 | 96 | def _proc_state(self, qpos): 97 | return torch.from_numpy(qpos).float()[None].cuda() 98 | 99 | def _infer_policy(self, obs): 100 | start = time.time() 101 | img = self._proc_images(obs["images"]) 102 | print("Image processing time:", time.time() - start) 103 | start = time.time() 104 | state = self._proc_state(obs["qpos"]) 105 | print("State processing time:", time.time() - start) 106 | 107 | start = time.time() 108 | with torch.no_grad(): 109 | ac = self.agent(img, state) 110 | ac = ac[0].cpu().numpy().astype(np.float32)[: self.args.pred_horizon] 111 | print("Inference time:", time.time() - start) 112 | 113 | # make sure the model predicted enough steps 114 | assert ( 115 | len(ac) >= self.args.pred_horizon 116 | ), "model did not return enough predictions!" 117 | return ac 118 | 119 | def _forward_ensemble(self, obs): 120 | ac = self._infer_policy(obs) 121 | self.act_history.append(ac) 122 | 123 | # potentially consider not ensembling every timestep. 124 | 125 | # handle temporal blending 126 | num_actions = len(self.act_history) 127 | print("Num actions:", num_actions) 128 | curr_act_preds = np.stack( 129 | [ 130 | pred_actions[i] 131 | for (i, pred_actions) in zip( 132 | range(num_actions - 1, -1, -1), self.act_history 133 | ) 134 | ] 135 | ) 136 | 137 | # more recent predictions get exponentially *less* weight than older predictions 138 | weights = np.exp(-self.args.exp_weight * np.arange(num_actions)) 139 | weights = weights / weights.sum() 140 | 141 | # return the weighted average across all predictions for this timestep 142 | return np.sum(weights[:, None] * curr_act_preds, axis=0) 143 | 144 | def _forward_chunked(self, obs): 145 | if not len(self.act_history): 146 | acs = self._infer_policy(obs) 147 | for ac in acs: 148 | self.act_history.append(ac) 149 | 150 | raw_ac = self.act_history.popleft() 151 | last_ac = self.last_ac if self.last_ac is not None else raw_ac 152 | self.last_ac = self.args.gamma * raw_ac + (1 - self.args.gamma) * last_ac 153 | return self.last_ac.copy() 154 | 155 | def forward(self, obs): 156 | ac = ( 157 | self._forward_ensemble(obs) 158 | if self.temp_ensemble 159 | else self._forward_chunked(obs) 160 | ) 161 | 162 | # denormalize the actions 163 | ac = ac * self.scale + self.loc 164 | 165 | # check effective HZ 166 | if self._last_time is not None: 167 | delta = time.time() - self._last_time 168 | if delta < self.args.period: 169 | time.sleep(self.args.period - delta) 170 | print("Effective HZ:", 1.0 / (time.time() - self._last_time)) 171 | self._last_time = time.time() 172 | return ac 173 | 174 | 175 | def main(): 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument("checkpoint") 178 | parser.add_argument("--T", default=400, type=int) 179 | parser.add_argument("--temp_ensemble", default=False, 
action="store_true") 180 | parser.add_argument("--num_rollouts", default=1, type=int) 181 | parser.add_argument("--pred_horizon", default=48, type=int) 182 | parser.add_argument("--exp_weight", default=0, type=float) 183 | parser.add_argument("--hz", default=48, type=float) 184 | parser.add_argument("--gamma", default=0.85, type=float) 185 | parser.add_argument("--save_dir", type=str) 186 | parser.add_argument("--goal_path", default=None, type=str) 187 | 188 | args = parser.parse_args() 189 | args.period = 1.0 / args.hz 190 | 191 | if args.save_dir and not os.path.exists(args.save_dir): 192 | os.makedirs(args.save_dir) 193 | 194 | agent_path = os.path.expanduser(os.path.dirname(args.checkpoint)) 195 | model_name = args.checkpoint.split("/")[-1] 196 | policy = Policy(agent_path, model_name, args) 197 | 198 | env = make_real_env(init_node=True) 199 | 200 | if args.save_dir: 201 | next_highest = get_highest_rollout_num(args.save_dir) + 1 202 | else: 203 | # Default to starting at 0 204 | next_highest = 0 205 | 206 | # Roll out the policy num_rollout times 207 | for rollout_num in range(next_highest, args.num_rollouts + next_highest): 208 | 209 | last_input = None 210 | while last_input != "y": 211 | if last_input == "r": 212 | obs = env.reset() 213 | last_input = input("Continue with rollout (y; r to reset now)?") 214 | 215 | policy.reset() 216 | 217 | obs_data = [] 218 | 219 | obs = env.reset() 220 | start_time = time.time() 221 | obs_data.append(obs) 222 | 223 | for _ in range(args.T): 224 | ac = policy.forward(obs.observation) 225 | obs = env.step(ac) 226 | obs_data.append(obs) 227 | 228 | end_time = time.time() 229 | 230 | if args.save_dir: 231 | # Save the rollout video if save_dir is provided 232 | 233 | rollout_name = f"episode_{rollout_num}.mp4" 234 | save_path = os.path.join(args.save_dir, rollout_name) 235 | save_thread = threading.Thread( 236 | target=save_rollout_video, 237 | args=(obs_data, save_path, policy.img_keys, end_time - start_time), 238 | ) 239 | save_thread.start() 240 | 241 | env._reset_gripper() 242 | 243 | 244 | def get_highest_rollout_num(save_dir): 245 | """ 246 | Get the highest rollout number in the save directory 247 | """ 248 | if not os.path.exists(save_dir): 249 | raise ValueError(f"Directory {save_dir} does not exist.") 250 | 251 | files = [ 252 | os.path.basename(f) 253 | for f in os.listdir(save_dir) 254 | if os.path.isfile(os.path.join(save_dir, f)) 255 | ] 256 | if not files: 257 | return -1 # No files yet 258 | return max([int(re.search(r"\d+", f_name)[0]) for f_name in files]) 259 | 260 | 261 | def save_rollout_video(obs, path, camera_names, length_of_episode): 262 | """ 263 | Save the policy rollout to a video 264 | """ 265 | t0 = time.time() 266 | 267 | # Get the list 268 | image_dict = {} 269 | 270 | for cam_name in camera_names: 271 | image_dict[cam_name] = [] 272 | 273 | # len(action): max_timesteps, len(time_steps): max_timesteps + 1 274 | while len(obs) > 1: 275 | ts = obs.pop(0) 276 | for cam_name in camera_names: 277 | image_dict[cam_name].append(ts.observation["images"][cam_name]) 278 | 279 | cam_names = list(image_dict.keys()) 280 | 281 | all_cam_videos = [] 282 | for cam_name in cam_names: 283 | all_cam_videos.append(image_dict[cam_name]) 284 | all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension 285 | 286 | n_frames, h, w, _ = all_cam_videos.shape 287 | fps = int( 288 | n_frames / length_of_episode 289 | ) # This is an estimate, but the frames are not uniformly distributed when rolling out the policy, leading to skips 
in the video. 290 | 291 | out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) 292 | for t in range(n_frames): 293 | image = all_cam_videos[t] 294 | image = image[:, :, [0, 1, 2]] 295 | out.write(image) 296 | out.release() 297 | 298 | print(f"Saving: {time.time() - t0:.1f} secs") 299 | 300 | 301 | if __name__ == "__main__": 302 | main() 303 | -------------------------------------------------------------------------------- /data4robotics/models/diffusion_unet.py: -------------------------------------------------------------------------------- 1 | # U-Net implementation from: Diffusion Policy Codebase (Chi et al; arXiv:2303.04137) 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import math 8 | from typing import Optional, Union 9 | 10 | import torch 11 | import torch.nn as nn 12 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 13 | 14 | from data4robotics.agent import BaseAgent 15 | 16 | 17 | class SinusoidalPosEmb(nn.Module): 18 | def __init__(self, dim): 19 | super().__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | device = x.device 24 | half_dim = self.dim // 2 25 | emb = math.log(10000) / (half_dim - 1) 26 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 27 | emb = x[:, None] * emb[None, :] 28 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 29 | return emb 30 | 31 | 32 | class Downsample1d(nn.Module): 33 | def __init__(self, dim): 34 | super().__init__() 35 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 36 | 37 | def forward(self, x): 38 | return self.conv(x) 39 | 40 | 41 | class Upsample1d(nn.Module): 42 | def __init__(self, dim): 43 | super().__init__() 44 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 45 | 46 | def forward(self, x): 47 | return self.conv(x) 48 | 49 | 50 | class Conv1dBlock(nn.Module): 51 | """ 52 | Conv1d --> GroupNorm --> Mish 53 | """ 54 | 55 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 56 | super().__init__() 57 | 58 | self.block = nn.Sequential( 59 | nn.Conv1d( 60 | inp_channels, out_channels, kernel_size, padding=kernel_size // 2 61 | ), 62 | nn.GroupNorm(n_groups, out_channels), 63 | nn.Mish(), 64 | ) 65 | 66 | def forward(self, x): 67 | return self.block(x) 68 | 69 | 70 | class ConditionalResidualBlock1D(nn.Module): 71 | def __init__(self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8): 72 | super().__init__() 73 | 74 | self.blocks = nn.ModuleList( 75 | [ 76 | Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), 77 | Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), 78 | ] 79 | ) 80 | 81 | # FiLM modulation https://arxiv.org/abs/1709.07871 82 | # predicts per-channel scale and bias 83 | cond_channels = out_channels * 2 84 | self.out_channels = out_channels 85 | self.cond_encoder = nn.Sequential( 86 | nn.Mish(), nn.Linear(cond_dim, cond_channels), nn.Unflatten(-1, (-1, 1)) 87 | ) 88 | 89 | # make sure dimensions compatible 90 | self.residual_conv = ( 91 | nn.Conv1d(in_channels, out_channels, 1) 92 | if in_channels != out_channels 93 | else nn.Identity() 94 | ) 95 | 96 | def forward(self, x, cond): 97 | """ 98 | x : [ batch_size x in_channels x horizon ] 99 | cond : [ batch_size x cond_dim] 100 | 101 | returns: 102 | out : [ batch_size x out_channels x horizon ] 103 | """ 104 | out = self.blocks[0](x) 105 | embed = self.cond_encoder(cond) 106 | 107 | embed = embed.reshape(embed.shape[0], 2, 
self.out_channels, 1) 108 | scale = embed[:, 0, ...] 109 | bias = embed[:, 1, ...] 110 | out = scale * out + bias 111 | 112 | out = self.blocks[1](out) 113 | out = out + self.residual_conv(x) 114 | return out 115 | 116 | 117 | class ConditionalUnet1D(nn.Module): 118 | def __init__( 119 | self, 120 | input_dim, 121 | global_cond_dim, 122 | diffusion_step_embed_dim=256, 123 | down_dims=[256, 512, 1024], 124 | kernel_size=3, 125 | n_groups=8, 126 | ): 127 | """ 128 | input_dim: Dim of actions. 129 | global_cond_dim: Dim of global conditioning applied with FiLM 130 | in addition to diffusion step embedding. This is usually obs_horizon * obs_dim 131 | diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k 132 | down_dims: Channel size for each UNet level. 133 | The length of this array determines numebr of levels. 134 | kernel_size: Conv kernel size 135 | n_groups: Number of groups for GroupNorm 136 | """ 137 | super().__init__() 138 | all_dims = [input_dim] + list(down_dims) 139 | start_dim = down_dims[0] 140 | 141 | dsed = diffusion_step_embed_dim 142 | diffusion_step_encoder = nn.Sequential( 143 | SinusoidalPosEmb(dsed), 144 | nn.Linear(dsed, dsed * 4), 145 | nn.Mish(), 146 | nn.Linear(dsed * 4, dsed), 147 | ) 148 | cond_dim = dsed + global_cond_dim 149 | 150 | in_out = list(zip(all_dims[:-1], all_dims[1:])) 151 | mid_dim = all_dims[-1] 152 | self.mid_modules = nn.ModuleList( 153 | [ 154 | ConditionalResidualBlock1D( 155 | mid_dim, 156 | mid_dim, 157 | cond_dim=cond_dim, 158 | kernel_size=kernel_size, 159 | n_groups=n_groups, 160 | ), 161 | ConditionalResidualBlock1D( 162 | mid_dim, 163 | mid_dim, 164 | cond_dim=cond_dim, 165 | kernel_size=kernel_size, 166 | n_groups=n_groups, 167 | ), 168 | ] 169 | ) 170 | 171 | down_modules = nn.ModuleList([]) 172 | for ind, (dim_in, dim_out) in enumerate(in_out): 173 | is_last = ind >= (len(in_out) - 1) 174 | down_modules.append( 175 | nn.ModuleList( 176 | [ 177 | ConditionalResidualBlock1D( 178 | dim_in, 179 | dim_out, 180 | cond_dim=cond_dim, 181 | kernel_size=kernel_size, 182 | n_groups=n_groups, 183 | ), 184 | ConditionalResidualBlock1D( 185 | dim_out, 186 | dim_out, 187 | cond_dim=cond_dim, 188 | kernel_size=kernel_size, 189 | n_groups=n_groups, 190 | ), 191 | Downsample1d(dim_out) if not is_last else nn.Identity(), 192 | ] 193 | ) 194 | ) 195 | 196 | up_modules = nn.ModuleList([]) 197 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 198 | is_last = ind >= (len(in_out) - 1) 199 | up_modules.append( 200 | nn.ModuleList( 201 | [ 202 | ConditionalResidualBlock1D( 203 | dim_out * 2, 204 | dim_in, 205 | cond_dim=cond_dim, 206 | kernel_size=kernel_size, 207 | n_groups=n_groups, 208 | ), 209 | ConditionalResidualBlock1D( 210 | dim_in, 211 | dim_in, 212 | cond_dim=cond_dim, 213 | kernel_size=kernel_size, 214 | n_groups=n_groups, 215 | ), 216 | Upsample1d(dim_in) if not is_last else nn.Identity(), 217 | ] 218 | ) 219 | ) 220 | 221 | final_conv = nn.Sequential( 222 | Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), 223 | nn.Conv1d(start_dim, input_dim, 1), 224 | ) 225 | 226 | self.diffusion_step_encoder = diffusion_step_encoder 227 | self.up_modules = up_modules 228 | self.down_modules = down_modules 229 | self.final_conv = final_conv 230 | 231 | print( 232 | "number of diffusion parameters: {:e}".format( 233 | sum(p.numel() for p in self.parameters()) 234 | ) 235 | ) 236 | 237 | def forward( 238 | self, 239 | sample: torch.Tensor, 240 | timestep: Union[torch.Tensor, int], 241 | global_cond: 
Optional[torch.Tensor] = None, 242 | ): 243 | """ 244 | x: (B,T,input_dim) 245 | timestep: (B,) or int, diffusion step 246 | global_cond: (B,global_cond_dim) 247 | output: (B,T,input_dim) 248 | """ 249 | # (B,T,C) 250 | sample = sample.moveaxis(-1, -2) 251 | # (B,C,T) 252 | 253 | # 1. time 254 | timesteps = timestep 255 | if not torch.is_tensor(timesteps): 256 | timesteps = torch.tensor( 257 | [timesteps], dtype=torch.long, device=sample.device 258 | ) 259 | elif isinstance(timesteps, torch.Tensor) and len(timesteps.shape) == 0: 260 | timesteps = timesteps[None].to(sample.device) 261 | 262 | assert isinstance(timesteps, torch.Tensor) 263 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 264 | timesteps = timesteps.expand(sample.shape[0]) 265 | 266 | global_feature: torch.Tensor = self.diffusion_step_encoder(timesteps) 267 | 268 | if global_cond is not None: 269 | global_feature = torch.cat([global_feature, global_cond], dim=-1) 270 | 271 | x = sample 272 | h = [] 273 | for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): 274 | x = resnet(x, global_feature) 275 | x = resnet2(x, global_feature) 276 | h.append(x) 277 | x = downsample(x) 278 | 279 | for mid_module in self.mid_modules: 280 | x = mid_module(x, global_feature) 281 | 282 | for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): 283 | x = torch.cat((x, h.pop()), dim=1) 284 | x = resnet(x, global_feature) 285 | x = resnet2(x, global_feature) 286 | x = upsample(x) 287 | 288 | x = self.final_conv(x) 289 | 290 | # (B,C,T) 291 | x = x.moveaxis(-1, -2) 292 | # (B,T,C) 293 | return x 294 | 295 | 296 | class DiffusionUnetAgent(BaseAgent): 297 | def __init__( 298 | self, 299 | features, 300 | odim, 301 | n_cams, 302 | use_obs, 303 | ac_dim, 304 | ac_chunk, 305 | train_diffusion_steps, 306 | eval_diffusion_steps, 307 | imgs_per_cam=1, 308 | dropout=0, 309 | share_cam_features=False, 310 | early_fusion=False, 311 | feat_norm=None, 312 | token_dim=None, 313 | noise_net_kwargs=dict(), 314 | ): 315 | 316 | # initialize obs and img tokenizers 317 | super().__init__( 318 | odim=odim, 319 | features=features, 320 | n_cams=n_cams, 321 | imgs_per_cam=imgs_per_cam, 322 | use_obs=use_obs, 323 | share_cam_features=share_cam_features, 324 | early_fusion=early_fusion, 325 | dropout=dropout, 326 | feat_norm=feat_norm, 327 | token_dim=token_dim, 328 | ) 329 | 330 | cond_dim = self.n_tokens * self.token_dim 331 | self.noise_net = ConditionalUnet1D( 332 | input_dim=ac_dim, global_cond_dim=cond_dim, **noise_net_kwargs 333 | ) 334 | 335 | self._ac_dim, self._ac_chunk = ac_dim, ac_chunk 336 | 337 | assert ( 338 | eval_diffusion_steps <= train_diffusion_steps 339 | ), "Can't eval with more steps!" 
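# DDIM sampling can run on a subset of the training timesteps, which is why fewer
# eval steps than train_diffusion_steps are allowed here (never more).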
340 | self._train_diffusion_steps = train_diffusion_steps 341 | self._eval_diffusion_steps = eval_diffusion_steps 342 | self.diffusion_schedule = DDIMScheduler( 343 | num_train_timesteps=train_diffusion_steps, 344 | beta_start=0.0001, 345 | beta_end=0.02, 346 | beta_schedule="squaredcos_cap_v2", 347 | clip_sample=True, 348 | set_alpha_to_one=True, 349 | steps_offset=0, 350 | prediction_type="epsilon", 351 | ) 352 | 353 | def forward(self, imgs, obs, ac_flat, mask_flat): 354 | # get observation encoding and sample noise/timesteps 355 | B, device = obs.shape[0], obs.device 356 | s_t = self.tokenize_obs(imgs, obs, flatten=True) 357 | timesteps = torch.randint( 358 | low=0, high=self._train_diffusion_steps, size=(B,), device=device 359 | ).long() 360 | 361 | # diffusion unet logic assumes [B, T, adim] 362 | mask = mask_flat.reshape((B, self.ac_chunk, self.ac_dim)) 363 | actions = ac_flat.reshape((B, self.ac_chunk, self.ac_dim)) 364 | noise = torch.randn_like(actions) 365 | 366 | # construct noise actions given real actions, noise, and diffusion schedule 367 | noise_acs = self.diffusion_schedule.add_noise(actions, noise, timesteps) 368 | noise_pred = self.noise_net(noise_acs, timesteps, s_t) 369 | 370 | # calculate loss for noise net 371 | loss = nn.functional.mse_loss(noise_pred, noise, reduction="none") 372 | loss = (loss * mask).sum((1, 2)) # mask the loss to only consider "real" acs 373 | return loss.mean() 374 | 375 | def get_actions(self, imgs, obs, n_steps=None): 376 | # get observation encoding and sample noise 377 | B, device = obs.shape[0], obs.device 378 | s_t = self.tokenize_obs(imgs, obs, flatten=True) 379 | noise_actions = torch.randn(B, self.ac_chunk, self.ac_dim, device=device) 380 | 381 | # set number of steps 382 | eval_steps = self._eval_diffusion_steps 383 | if n_steps is not None: 384 | assert ( 385 | n_steps <= self._train_diffusion_steps 386 | ), f"can't be > {self._train_diffusion_steps}" 387 | eval_steps = n_steps 388 | 389 | # begin diffusion process 390 | self.diffusion_schedule.set_timesteps(eval_steps) 391 | self.diffusion_schedule.alphas_cumprod = ( 392 | self.diffusion_schedule.alphas_cumprod.to(device) 393 | ) 394 | for timestep in self.diffusion_schedule.timesteps: 395 | # predict noise given timestep 396 | batched_timestep = timestep.unsqueeze(0).repeat(B).to(device) 397 | noise_pred = self.noise_net(noise_actions, batched_timestep, s_t) 398 | 399 | # take diffusion step 400 | noise_actions = self.diffusion_schedule.step( 401 | model_output=noise_pred, timestep=timestep, sample=noise_actions 402 | ).prev_sample 403 | 404 | # return final action post diffusion 405 | return noise_actions 406 | 407 | @property 408 | def ac_chunk(self): 409 | return self._ac_chunk 410 | 411 | @property 412 | def ac_dim(self): 413 | return self._ac_dim 414 | -------------------------------------------------------------------------------- /data4robotics/models/diffusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Sudeep Dasari, 2023 2 | # Heavy inspiration taken from DETR by Meta AI (Carion et. al.): https://github.com/facebookresearch/detr 3 | # and DiT by Meta AI (Peebles and Xie): https://github.com/facebookresearch/DiT 4 | 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
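# A minimal sketch of the DDIM training/sampling cycle that DiffusionUnetAgent above
# (and DiffusionTransformerAgent below) implement, using a dummy noise predictor in
# place of the conditional U-Net; shapes and step counts here are hypothetical.
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

B, T, adim = 4, 16, 7
sched = DDIMScheduler(
    num_train_timesteps=100,
    beta_schedule="squaredcos_cap_v2",
    clip_sample=True,
    set_alpha_to_one=True,
    prediction_type="epsilon",
)
noise_net = lambda x, t: torch.zeros_like(x)  # stand-in for ConditionalUnet1D

# training: corrupt real action chunks with scheduled noise and regress the noise
actions = torch.randn(B, T, adim)
noise = torch.randn_like(actions)
timesteps = torch.randint(0, 100, (B,)).long()
noisy_actions = sched.add_noise(actions, noise, timesteps)
loss = torch.nn.functional.mse_loss(noise_net(noisy_actions, timesteps), noise)

# sampling: start from Gaussian noise and run a short DDIM reverse pass
sample = torch.randn(B, T, adim)
sched.set_timesteps(8)  # eval_diffusion_steps <= num_train_timesteps
for t in sched.timesteps:
    eps = noise_net(sample, t)
    sample = sched.step(model_output=eps, timestep=t, sample=sample).prev_sample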
7 | 8 | 9 | import copy 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler 16 | 17 | from data4robotics.agent import BaseAgent 18 | 19 | 20 | def _get_activation_fn(activation): 21 | """Return an activation function given a string""" 22 | if activation == "relu": 23 | return F.relu 24 | if activation == "gelu": 25 | return nn.GELU(approximate="tanh") 26 | if activation == "glu": 27 | return F.glu 28 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") 29 | 30 | 31 | def _with_pos_embed(tensor, pos=None): 32 | return tensor if pos is None else tensor + pos 33 | 34 | 35 | class _PositionalEncoding(nn.Module): 36 | def __init__(self, d_model, max_len=5000): 37 | super().__init__() 38 | # Compute the positional encodings once in log space 39 | pe = torch.zeros(max_len, d_model) 40 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 41 | div_term = torch.exp( 42 | torch.arange(0, d_model, 2, dtype=torch.float) 43 | * -(np.log(10000.0) / d_model) 44 | ) 45 | pe[:, 0::2] = torch.sin(position * div_term) 46 | pe[:, 1::2] = torch.cos(position * div_term) 47 | pe = pe.unsqueeze(0).transpose(0, 1) 48 | self.register_buffer("pe", pe) 49 | 50 | def forward(self, x): 51 | """ 52 | Args: 53 | x: Tensor of shape (seq_len, batch_size, d_model) 54 | 55 | Returns: 56 | Tensor of shape (seq_len, batch_size, d_model) with positional encodings added 57 | """ 58 | pe = self.pe[: x.shape[0]] 59 | pe = pe.repeat((1, x.shape[1], 1)) 60 | return pe.detach().clone() 61 | 62 | 63 | class _TimeNetwork(nn.Module): 64 | def __init__(self, time_dim, out_dim, learnable_w=False): 65 | assert time_dim % 2 == 0, "time_dim must be even!" 
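# Standard sinusoidal (transformer-style) embedding of the scalar diffusion timestep:
# a geometric frequency bank is built below, cos/sin features are concatenated in
# forward(), and the small MLP out_net projects the result to out_dim.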
66 | half_dim = int(time_dim // 2) 67 | super().__init__() 68 | 69 | w = np.log(10000) / (half_dim - 1) 70 | w = torch.exp(torch.arange(half_dim) * -w).float() 71 | self.register_parameter("w", nn.Parameter(w, requires_grad=learnable_w)) 72 | 73 | self.out_net = nn.Sequential( 74 | nn.Linear(time_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim) 75 | ) 76 | 77 | def forward(self, x): 78 | assert len(x.shape) == 1, "assumes 1d input timestep array" 79 | x = x[:, None] * self.w[None] 80 | x = torch.cat((torch.cos(x), torch.sin(x)), dim=1) 81 | return self.out_net(x) 82 | 83 | 84 | class _SelfAttnEncoder(nn.Module): 85 | def __init__( 86 | self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="gelu" 87 | ): 88 | super().__init__() 89 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 90 | # Implementation of Feedforward model 91 | self.linear1 = nn.Linear(d_model, dim_feedforward) 92 | self.linear2 = nn.Linear(dim_feedforward, d_model) 93 | 94 | self.norm1 = nn.LayerNorm(d_model) 95 | self.norm2 = nn.LayerNorm(d_model) 96 | 97 | self.dropout1 = nn.Dropout(dropout) 98 | self.dropout2 = nn.Dropout(dropout) 99 | self.dropout3 = nn.Dropout(dropout) 100 | 101 | self.activation = _get_activation_fn(activation) 102 | 103 | def forward(self, src, pos): 104 | q = k = _with_pos_embed(src, pos) 105 | src2, _ = self.self_attn(q, k, value=src, need_weights=False) 106 | src = src + self.dropout1(src2) 107 | src = self.norm1(src) 108 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 109 | src = src + self.dropout3(src2) 110 | src = self.norm2(src) 111 | return src 112 | 113 | def reset_parameters(self): 114 | for p in self.parameters(): 115 | if p.dim() > 1: 116 | nn.init.xavier_uniform_(p) 117 | 118 | 119 | class _ShiftScaleMod(nn.Module): 120 | def __init__(self, dim): 121 | super().__init__() 122 | self.act = nn.SiLU() 123 | self.scale = nn.Linear(dim, dim) 124 | self.shift = nn.Linear(dim, dim) 125 | 126 | def forward(self, x, c): 127 | c = self.act(c) 128 | return x * self.scale(c)[None] + self.shift(c)[None] 129 | 130 | def reset_parameters(self): 131 | nn.init.xavier_uniform_(self.scale.weight) 132 | nn.init.xavier_uniform_(self.shift.weight) 133 | nn.init.zeros_(self.scale.bias) 134 | nn.init.zeros_(self.shift.bias) 135 | 136 | 137 | class _ZeroScaleMod(nn.Module): 138 | def __init__(self, dim): 139 | super().__init__() 140 | self.act = nn.SiLU() 141 | self.scale = nn.Linear(dim, dim) 142 | 143 | def forward(self, x, c): 144 | c = self.act(c) 145 | return x * self.scale(c)[None] 146 | 147 | def reset_parameters(self): 148 | nn.init.zeros_(self.scale.weight) 149 | nn.init.zeros_(self.scale.bias) 150 | 151 | 152 | class _DiTDecoder(nn.Module): 153 | def __init__( 154 | self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="gelu" 155 | ): 156 | super().__init__() 157 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 158 | # Implementation of Feedforward model 159 | self.linear1 = nn.Linear(d_model, dim_feedforward) 160 | self.linear2 = nn.Linear(dim_feedforward, d_model) 161 | 162 | self.norm1 = nn.LayerNorm(d_model) 163 | self.norm2 = nn.LayerNorm(d_model) 164 | 165 | self.dropout1 = nn.Dropout(dropout) 166 | self.dropout2 = nn.Dropout(dropout) 167 | self.dropout3 = nn.Dropout(dropout) 168 | 169 | self.activation = _get_activation_fn(activation) 170 | 171 | # create modulation layers 172 | self.attn_mod1 = _ShiftScaleMod(d_model) 173 | self.attn_mod2 = _ZeroScaleMod(d_model) 174 | self.mlp_mod1 = 
_ShiftScaleMod(d_model) 175 | self.mlp_mod2 = _ZeroScaleMod(d_model) 176 | 177 | def forward(self, x, t, cond): 178 | # process the conditioning vector first 179 | cond = torch.mean(cond, axis=0) 180 | cond = cond + t 181 | 182 | x2 = self.attn_mod1(self.norm1(x), cond) 183 | x2, _ = self.self_attn(x2, x2, x2, need_weights=False) 184 | x = self.attn_mod2(self.dropout1(x2), cond) + x 185 | 186 | x2 = self.mlp_mod1(self.norm2(x), cond) 187 | x2 = self.linear2(self.dropout2(self.activation(self.linear1(x2)))) 188 | x2 = self.mlp_mod2(self.dropout3(x2), cond) 189 | return x + x2 190 | 191 | def reset_parameters(self): 192 | for p in self.parameters(): 193 | if p.dim() > 1: 194 | nn.init.xavier_uniform_(p) 195 | 196 | for s in (self.attn_mod1, self.attn_mod2, self.mlp_mod1, self.mlp_mod2): 197 | s.reset_parameters() 198 | 199 | 200 | class _FinalLayer(nn.Module): 201 | def __init__(self, hidden_size, out_size): 202 | super().__init__() 203 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 204 | self.linear = nn.Linear(hidden_size, out_size, bias=True) 205 | self.adaLN_modulation = nn.Sequential( 206 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 207 | ) 208 | 209 | def forward(self, x, t, cond): 210 | # process the conditioning vector first 211 | cond = torch.mean(cond, axis=0) 212 | cond = cond + t 213 | 214 | shift, scale = self.adaLN_modulation(cond).chunk(2, dim=1) 215 | x = x * scale[None] + shift[None] 216 | x = self.linear(x) 217 | return x.transpose(0, 1) 218 | 219 | def reset_parameters(self): 220 | for p in self.parameters(): 221 | nn.init.zeros_(p) 222 | 223 | 224 | class _TransformerEncoder(nn.Module): 225 | def __init__(self, base_module, num_layers): 226 | super().__init__() 227 | self.layers = nn.ModuleList( 228 | [copy.deepcopy(base_module) for _ in range(num_layers)] 229 | ) 230 | 231 | for l in self.layers: 232 | l.reset_parameters() 233 | 234 | def forward(self, src, pos): 235 | x, outputs = src, [] 236 | for layer in self.layers: 237 | x = layer(x, pos) 238 | outputs.append(x) 239 | return outputs 240 | 241 | 242 | class _TransformerDecoder(_TransformerEncoder): 243 | def forward(self, src, t, all_conds): 244 | x = src 245 | for layer, cond in zip(self.layers, all_conds): 246 | x = layer(x, t, cond) 247 | return x 248 | 249 | 250 | class _DiTNoiseNet(nn.Module): 251 | def __init__( 252 | self, 253 | ac_dim, 254 | ac_chunk, 255 | time_dim=256, 256 | hidden_dim=512, 257 | num_blocks=6, 258 | dropout=0.1, 259 | dim_feedforward=2048, 260 | nhead=8, 261 | activation="gelu", 262 | ): 263 | super().__init__() 264 | 265 | # positional encoding blocks 266 | self.enc_pos = _PositionalEncoding(hidden_dim) 267 | self.register_parameter( 268 | "dec_pos", 269 | nn.Parameter(torch.empty(ac_chunk, 1, hidden_dim), requires_grad=True), 270 | ) 271 | nn.init.xavier_uniform_(self.dec_pos.data) 272 | 273 | # input encoder mlps 274 | self.time_net = _TimeNetwork(time_dim, hidden_dim) 275 | self.ac_proj = nn.Sequential( 276 | nn.Linear(ac_dim, ac_dim), 277 | nn.GELU(approximate="tanh"), 278 | nn.Linear(ac_dim, hidden_dim), 279 | ) 280 | 281 | # encoder blocks 282 | encoder_module = _SelfAttnEncoder( 283 | hidden_dim, 284 | nhead=nhead, 285 | dim_feedforward=dim_feedforward, 286 | dropout=dropout, 287 | activation=activation, 288 | ) 289 | self.encoder = _TransformerEncoder(encoder_module, num_blocks) 290 | 291 | # decoder blocks 292 | decoder_module = _DiTDecoder( 293 | hidden_dim, 294 | nhead=nhead, 295 | dim_feedforward=dim_feedforward, 296 | 


class _DiTNoiseNet(nn.Module):
    def __init__(
        self,
        ac_dim,
        ac_chunk,
        time_dim=256,
        hidden_dim=512,
        num_blocks=6,
        dropout=0.1,
        dim_feedforward=2048,
        nhead=8,
        activation="gelu",
    ):
        super().__init__()

        # positional encoding blocks
        self.enc_pos = _PositionalEncoding(hidden_dim)
        self.register_parameter(
            "dec_pos",
            nn.Parameter(torch.empty(ac_chunk, 1, hidden_dim), requires_grad=True),
        )
        nn.init.xavier_uniform_(self.dec_pos.data)

        # input encoder mlps
        self.time_net = _TimeNetwork(time_dim, hidden_dim)
        self.ac_proj = nn.Sequential(
            nn.Linear(ac_dim, ac_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(ac_dim, hidden_dim),
        )

        # encoder blocks
        encoder_module = _SelfAttnEncoder(
            hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.encoder = _TransformerEncoder(encoder_module, num_blocks)

        # decoder blocks
        decoder_module = _DiTDecoder(
            hidden_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.decoder = _TransformerDecoder(decoder_module, num_blocks)

        # turns predicted tokens into epsilons
        self.eps_out = _FinalLayer(hidden_dim, ac_dim)

        print(
            "number of diffusion parameters: {:e}".format(
                sum(p.numel() for p in self.parameters())
            )
        )

    def forward(self, noise_actions, time, obs_enc, enc_cache=None):
        if enc_cache is None:
            enc_cache = self.forward_enc(obs_enc)
        return enc_cache, self.forward_dec(noise_actions, time, enc_cache)

    def forward_enc(self, obs_enc):
        obs_enc = obs_enc.transpose(0, 1)
        pos = self.enc_pos(obs_enc)
        enc_cache = self.encoder(obs_enc, pos)
        return enc_cache

    def forward_dec(self, noise_actions, time, enc_cache):
        time_enc = self.time_net(time)

        ac_tokens = self.ac_proj(noise_actions)
        ac_tokens = ac_tokens.transpose(0, 1)
        dec_in = ac_tokens + self.dec_pos

        # apply decoder
        dec_out = self.decoder(dec_in, time_enc, enc_cache)

        # apply final epsilon prediction layer
        return self.eps_out(dec_out, time_enc, enc_cache[-1])
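
# Illustrative shape check (editor's sketch with hypothetical sizes; assumes
# _PositionalEncoding, defined earlier in this file, accepts the transposed
# [tokens, B, hidden_dim] input used by forward_enc):
#
#   net = _DiTNoiseNet(ac_dim=7, ac_chunk=16)            # hidden_dim defaults to 512
#   obs_enc = torch.randn(2, 4, 512)                     # [B, obs tokens, hidden_dim]
#   noise_acs = torch.randn(2, 16, 7)                    # [B, ac_chunk, ac_dim]
#   timesteps = torch.randint(0, 100, (2,))              # one diffusion step per sample
#   enc_cache, eps = net(noise_acs, timesteps, obs_enc)  # eps has shape [2, 16, 7]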


class DiffusionTransformerAgent(BaseAgent):
    def __init__(
        self,
        features,
        odim,
        n_cams,
        use_obs,
        ac_dim,
        ac_chunk,
        train_diffusion_steps,
        eval_diffusion_steps,
        imgs_per_cam=1,
        dropout=0,
        share_cam_features=False,
        early_fusion=False,
        feat_norm=None,
        token_dim=None,
        noise_net_kwargs=dict(),
    ):

        # initialize obs and img tokenizers
        super().__init__(
            odim=odim,
            features=features,
            n_cams=n_cams,
            imgs_per_cam=imgs_per_cam,
            use_obs=use_obs,
            share_cam_features=share_cam_features,
            early_fusion=early_fusion,
            dropout=dropout,
            feat_norm=feat_norm,
            token_dim=token_dim,
        )

        self.noise_net = _DiTNoiseNet(
            ac_dim=ac_dim,
            ac_chunk=ac_chunk,
            **noise_net_kwargs,
        )
        self._ac_dim, self._ac_chunk = ac_dim, ac_chunk

        assert (
            eval_diffusion_steps <= train_diffusion_steps
        ), "Can't eval with more steps!"
        self._train_diffusion_steps = train_diffusion_steps
        self._eval_diffusion_steps = eval_diffusion_steps
        self.diffusion_schedule = DDIMScheduler(
            num_train_timesteps=train_diffusion_steps,
            beta_start=0.0001,
            beta_end=0.02,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            set_alpha_to_one=True,
            steps_offset=0,
            prediction_type="epsilon",
        )

    def forward(self, imgs, obs, ac_flat, mask_flat):
        # get observation encoding and sample noise/timesteps
        B, device = obs.shape[0], obs.device
        s_t = self.tokenize_obs(imgs, obs)
        timesteps = torch.randint(
            low=0, high=self._train_diffusion_steps, size=(B,), device=device
        ).long()

        # diffusion unet logic assumes [B, T, adim]
        mask = mask_flat.reshape((B, self.ac_chunk, self.ac_dim))
        actions = ac_flat.reshape((B, self.ac_chunk, self.ac_dim))
        noise = torch.randn_like(actions)

        # construct noise actions given real actions, noise, and diffusion schedule
        noise_acs = self.diffusion_schedule.add_noise(actions, noise, timesteps)
        _, noise_pred = self.noise_net(noise_acs, timesteps, s_t)

        # calculate loss for noise net
        loss = nn.functional.mse_loss(noise_pred, noise, reduction="none")
        loss = (loss * mask).sum(1)  # mask the loss to only consider "real" acs
        return loss.mean()

    def get_actions(self, imgs, obs, n_steps=None):
        # get observation encoding and sample noise
        B, device = obs.shape[0], obs.device
        s_t = self.tokenize_obs(imgs, obs)
        enc_cache = None
        noise_actions = torch.randn(B, self.ac_chunk, self.ac_dim, device=device)

        # set number of steps
        eval_steps = self._eval_diffusion_steps
        if n_steps is not None:
            assert (
                n_steps <= self._train_diffusion_steps
            ), f"can't be > {self._train_diffusion_steps}"
            eval_steps = n_steps

        enc_cache = self.noise_net.forward_enc(s_t)

        # begin diffusion process
        self.diffusion_schedule.set_timesteps(eval_steps)
        self.diffusion_schedule.alphas_cumprod = (
            self.diffusion_schedule.alphas_cumprod.to(device)
        )
        for timestep in self.diffusion_schedule.timesteps:
            # predict noise given timestep
            batched_timestep = timestep.unsqueeze(0).repeat(B).to(device)
            noise_pred = self.noise_net.forward_dec(noise_actions, batched_timestep, enc_cache)

            # take diffusion step
            noise_actions = self.diffusion_schedule.step(
                model_output=noise_pred, timestep=timestep, sample=noise_actions
            ).prev_sample

        # return final action post diffusion
        return noise_actions

    @property
    def ac_chunk(self):
        return self._ac_chunk

    @property
    def ac_dim(self):
        return self._ac_dim

--------------------------------------------------------------------------------
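
A short, hedged illustration of the sampling loop in get_actions above: the agent drives a
diffusers DDIMScheduler by repeatedly predicting noise and stepping the sample backwards.
The sketch below reproduces that loop in isolation with a dummy epsilon predictor; every
name other than DDIMScheduler and its documented arguments is a placeholder.

    import torch
    from diffusers import DDIMScheduler

    sched = DDIMScheduler(
        num_train_timesteps=100,
        beta_schedule="squaredcos_cap_v2",
        clip_sample=True,
        set_alpha_to_one=True,
        prediction_type="epsilon",
    )
    actions = torch.randn(1, 16, 7)  # [B, ac_chunk, ac_dim], starts as pure noise
    sched.set_timesteps(10)          # evaluate with fewer steps than training
    for t in sched.timesteps:
        eps = torch.zeros_like(actions)  # stand-in for noise_net.forward_dec(...)
        actions = sched.step(model_output=eps, timestep=t, sample=actions).prev_sample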