├── .gitignore
├── README.md
├── __init__.py
├── agents
    ├── __init__.py
    ├── agent.py
    ├── dp_agent.py
    ├── dp_agent_zmq.py
    ├── quest_agent.py
    ├── quest_agent_eef.py
    └── universal_robots_ur5e
    │   ├── LICENSE
    │   ├── README.md
    │   ├── assets
    │       ├── base_0.obj
    │       ├── base_1.obj
    │       ├── forearm_0.obj
    │       ├── forearm_1.obj
    │       ├── forearm_2.obj
    │       ├── forearm_3.obj
    │       ├── shoulder_0.obj
    │       ├── shoulder_1.obj
    │       ├── shoulder_2.obj
    │       ├── upperarm_0.obj
    │       ├── upperarm_1.obj
    │       ├── upperarm_2.obj
    │       ├── upperarm_3.obj
    │       ├── wrist1_0.obj
    │       ├── wrist1_1.obj
    │       ├── wrist1_2.obj
    │       ├── wrist2_0.obj
    │       ├── wrist2_1.obj
    │       ├── wrist2_2.obj
    │       └── wrist3.obj
    │   ├── scene.xml
    │   ├── ur5e.png
    │   └── ur5e.xml
├── camera_node.py
├── cameras
    ├── camera.py
    └── realsense_camera.py
├── env.py
├── eval_dir.py
├── inference_node.py
├── launch_inference_nodes.py
├── launch_nodes.py
├── learning
    └── dp
    │   ├── .gitignore
    │   ├── data_processing.py
    │   ├── dataset.py
    │   ├── learner.py
    │   ├── models.py
    │   ├── pipeline.py
    │   └── utils.py
├── requirements.txt
├── robot_node.py
├── robots
    ├── ability_gripper.py
    ├── robot.py
    ├── robotiq_gripper.py
    └── ur.py
├── run_env.py
├── run_openloop.py
├── test_dp_agent.py
├── test_dp_agent_zmq.py
└── workflow
    ├── create_eval.py
    ├── download_dataset.sh
    ├── gen_deploy_scripts.py
    └── split_data.py


/.gitignore:
--------------------------------------------------------------------------------
*/__pycache__/
*.pyc
*.pkl
*.pth
*.zip
*.DS_Store

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🕊️ HATO: Learning Visuotactile Skills with Two Multifingered Hands

[[Project](https://toruowo.github.io/hato/)]
[[Paper](https://arxiv.org/abs/2404.16823)]

[Toru Lin](https://toruowo.github.io/),
[Yu Zhang*](),
[Qiyang Li*](https://colinqiyangli.github.io/),
[Haozhi Qi*](https://haozhi.io/),
[Brent Yi](https://scholar.google.com/citations?user=Ecy6lXwAAAAJ&hl=en),
[Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/),
[Jitendra Malik](https://people.eecs.berkeley.edu/~malik/)

## Overview

This repo contains code, datasets, and instructions to support the following use cases:
- Collecting Demonstration Data
- Training and Evaluating Diffusion Policies
- Deploying Policies on Hardware

## Installation

```
conda create -n hato python=3.9
conda activate hato
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
pip install -r ./requirements.txt
```

## Collecting Demonstration Data

This repo supports Meta Quest 2 as the teleoperation device. To start, install [oculus_reader](https://github.com/rail-berkeley/oculus_reader/blob/main/oculus_reader/reader.py) by following the instructions in the link.

We use [ZMQ](https://zeromq.org/) to handle communication between hardware components.
Our data collection code is structured in the following way (credit to [gello_software](https://github.com/wuphilipp/gello_software) for the clean and modular template):

| Directory / File | Detail |
| :-------------: |:-------------:|
| agents/ | contains Agent classes that generate low-level hardware control commands from a teleoperation device or a diffusion policy |
| cameras/ | contains Camera classes that provide utilities to obtain real-time camera data |
| robots/ | contains Robot classes that interface between Python code and physical hardware to read observations or send low-level control commands |
| *_node.py | contains ZMQ node classes for camera / robot / policy |
| env.py | contains environment classes that organize the teleoperation and data collection process |
| launch_nodes.py | script to launch robot hardware and sensor ZMQ nodes |
| run_env.py | script to start the teleoperation and data collection process |

(Code files not mentioned contain utilities for training and evaluating diffusion policies, and deploying policies on hardware; please see the next two sections for more details.)

Currently available classes:
- `agents/quest_agent.py`, `agents/quest_agent_eef.py`: support using Meta Quest 2 for hardware teleoperation, in either joint-space or end-effector-space control mode
- `agents/dp_agent.py`, `agents/dp_agent_zmq.py`: support using learned diffusion policies to provide control commands, either synchronously or asynchronously (see "Deploying Policies on Hardware" section for more details)
- `cameras/realsense_camera.py`: supports reading and preprocessing RGB-D data from RealSense cameras
- `robots/ur.py`, `robots/ability_gripper.py`, `robots/robotiq_gripper.py`: support hardware setup with a single UR5e arm or two UR5e arms, using Ability Hand or Robotiq gripper as the end-effector(s)

These classes can be flexibly modified or extended to support other teleoperation devices / learning pipelines / cameras / robot hardware.

Example usage (note that `launch_nodes.py` and `run_env.py` should be run simultaneously in two separate windows):

```
# to collect data with two UR5e arms and Ability hands at 10Hz
python launch_nodes.py --robot bimanual_ur --hand_type ability
python run_env.py --agent quest_hand --no-show-camera-view --hz 10 --save_data

# to collect data with two UR5e arms and Ability hands at 10Hz, showing camera view during data collection
python launch_nodes.py --robot bimanual_ur --hand_type ability
python run_env.py --agent quest_hand --hz 10 --save_data
```
Node processes can be cleaned up by running `pkill -9 -f launch_nodes.py`.
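
As a rough illustration of how the agent classes might be extended, the sketch below defines a hypothetical `ReplayAgent` that exposes the same `act(obs)` method as the agents in `agents/`. The class name, its constructor, and the replay behavior are assumptions made for this example and are not part of the repo; wiring such an agent into `run_env.py` is not shown.

```python
from typing import Any, Dict, Sequence

import numpy as np


class ReplayAgent:
    """Hypothetical agent that replays a fixed sequence of actions."""

    def __init__(self, actions: Sequence[Sequence[float]]):
        # Store the actions as a 2D array: one row per control step.
        self.actions = np.asarray(actions, dtype=float)
        if self.actions.ndim != 2 or len(self.actions) == 0:
            raise ValueError("actions must be a non-empty 2D sequence")
        self.step = 0

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        # The observation is ignored here; a real agent would use it.
        # Once the sequence is exhausted, keep returning the last action.
        idx = min(self.step, len(self.actions) - 1)
        self.step += 1
        return self.actions[idx].copy()


if __name__ == "__main__":
    agent = ReplayAgent([[0.0, 0.1], [0.2, 0.3]])
    for _ in range(3):
        print(agent.act({}))
```

Any object with a compatible `act` method should satisfy the `Agent` protocol structurally, so a class like this would not need to inherit from the repo's agent classes.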

## Training and Evaluating Diffusion Policies

### Download Datasets

[[Data Files](https://berkeley.app.box.com/s/379cf57zqm1akvr00vdcloxqxi3ucb9g?sortColumn=name&sortDirection=ASC)]

The linked data folder contains datasets for the four tasks featured in [the HATO paper](): `banana`, `stack`, `pour`, and `steak`.

Full dataset files can be unzipped using the `unzip` command.
Note that the `pour` and `steak` datasets are split into part files because of their large size. Before unzipping, the part files should first be concatenated back into single files using the following commands:
```
cat data_pour_part_0* > data_pour.zip
cat data_steak_part_0* > data_steak.zip
```

Scripts to download and concatenate the datasets can be found in `workflow/download_dataset.sh`.

### Run Training

1. Run `python workflow/split_data.py --base_path Traj_Folder_Path --output_path Output_Folder_Path --data_name Data_Name --num_trajs N1 N2 N3` to split the data into train and validation sets. The number of trajectories used can be specified via the `num_trajs` argument.
2. Run `python ./learning/dp/pipeline.py --data_path Split_Folder_Path/Data_Name --model_save_path Model_Path` to train the model, where
    - `--data_path` is the split trajectory folder, i.e. the `output_path` + `data_name` from step 1 (`data_name` should not include a suffix like `_train` or `_train_10`)
    - `--model_save_path` is the path to save the model

Important Training Arguments
1. `--batch_size`: the batch size for training.
2. `--num_epochs`: the number of epochs for training.
3. `--representation_type`: the data representation type for the model. Format: `repr1--repr2--...`. Repr can be `eef`, `img`, `depth`, `touch`, `pos`, `hand_pos`
4. `--camera_indices`: the camera indices to use for the image data modality. Format: `01`,`012`,`02`, etc.
5. `--train_suffix`: the suffix for the training folder. This is useful when you want to train the model on different data splits and should be used with the `--num_trajs` arg of `split_data.py`. Format: `_10`, `_50`, etc.
6. `--load_img`: whether to load all the images into memory. If set to `True`, training will be faster but will consume more memory.
7. `--use_memmap_cache`: whether to use a memmap cache for the images. If set to `True`, a memmap file is created in the training folder to accelerate data loading.
8. `--use_wandb`: whether to use wandb for logging.
9. `--wandb_project_name`: the wandb project name.
10. `--wandb_entity_name`: the wandb entity name.
11. `--load_path`: the path to load the model from. If set, the model will be loaded from this path and training will continue from it. This should be the path of the non-EMA model.

### Run Evaluation

Run `python ./eval_dir.py --eval_dir Traj_Folder_Path --ckpt_path Model_Path_1 Model_Path_2` to evaluate multiple models on all trajectories in the folder.

## Deploying Policies on Hardware

A set of useful bash scripts can be generated using the following command:

```python workflow/gen_deploy_scripts.py -c [ckpt_folder]```

where `ckpt_folder` is a path that contains one or more checkpoints produced by the training above. The generated bash scripts provide the following functionalities.

- To run policy deployment with the asynchronous setup (see Section IV-D in the paper for more details), first run `*_inference.sh` to launch the server, then run `*_node.sh` and `*_env_jit.sh` in separate terminal windows to deploy the robot with the inference server.

- To run policy deployment without the asynchronous setup, run `*_node.sh` and `*_env.sh` in separate terminal windows.

- To run policy evaluation for individual checkpoints, run `*_test.sh`. Note that the path to a folder containing trajectories for evaluation needs to be specified in addition to the model checkpoint path.

- To run an open-loop policy test, run `*_openloop.sh`. Note that the path to a demonstration data trajectory needs to be specified in addition to the model checkpoint path.

## Acknowledgement

This project was developed with help from the following codebases.

- [ability-hand-api](https://github.com/psyonicinc/ability-hand-api/tree/master/python)
- [diffusion_policy](https://github.com/real-stanford/diffusion_policy)
- [gello_software](https://github.com/wuphilipp/gello_software/tree/main)
- [mujoco_menagerie](https://github.com/google-deepmind/mujoco_menagerie/blob/main/universal_robots_ur5e/ur5e.xml)
- [oculus_reader](https://github.com/rail-berkeley/oculus_reader)

## Reference

If you find HATO or this codebase helpful in your research, please consider citing:

```
@article{lin2024learning,
  author={Lin, Toru and Zhang, Yu and Li, Qiyang and Qi, Haozhi and Yi, Brent and Levine, Sergey and Malik, Jitendra},
  title={Learning Visuotactile Skills with Two Multifingered Hands},
  journal={arXiv:2404.16823},
  year={2024}
}
```

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToruOwO/hato/3b6f4496e11a386df8f362f398b36e832944b2b8/__init__.py

--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToruOwO/hato/3b6f4496e11a386df8f362f398b36e832944b2b8/agents/__init__.py

--------------------------------------------------------------------------------
/agents/agent.py:
--------------------------------------------------------------------------------
from typing import Any, Dict, Protocol

import numpy as np


class Agent(Protocol):
    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        """Returns an action given an observation.

        Args:
            obs: observation from the environment.

        Returns:
            action: action to take on the environment.
        """
        raise NotImplementedError


class BimanualAgent(Agent):
    def __init__(self, agent_left: Agent, agent_right: Agent):
        self.agent_left = agent_left
        self.agent_right = agent_right

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        left_obs = {}
        right_obs = {}
        for key, val in obs.items():
            L = val.shape[0]
            half_dim = L // 2
            if key.endswith("rgb") or key.endswith("depth"):
                left_obs[key] = val
                right_obs[key] = val
            else:
                assert L == half_dim * 2, f"{key} must be even, something is wrong"
                left_obs[key] = val[:half_dim]
                right_obs[key] = val[half_dim:]
        return np.concatenate(
            [self.agent_left.act(left_obs), self.agent_right.act(right_obs)]
        )


class SafetyWrapper:
    def __init__(self, ur_idx, hand_idx, agent, delta=0.5, hand_delta=0.1):
        self.ur_idx = ur_idx
        self.hand_idx = hand_idx
        self.agent = agent
        self.delta = delta
        self.hand_delta = hand_delta

        # Ability Hand ranges
        self.num_hand_dofs = 12
        self.upper_ranges = [110, 110, 110, 110, 90, 120] * 2
        self.lower_ranges = [5] * self.num_hand_dofs

    def _hand_pos_to_cmd(self, pos):
        """
        pos: desired hand degrees for Ability Hands
        """
        assert len(pos) == self.num_hand_dofs
        cmd = [0] * self.num_hand_dofs
        for i in range(self.num_hand_dofs):
            if i in [5, 11]:
                pos[i] = -pos[i]
            cmd[i] = (pos[i] - self.lower_ranges[i]) / (
                self.upper_ranges[i] - self.lower_ranges[i]
            )
        return cmd

    def act_safe(self, agent, obs, eef=False):
        joints = obs["joint_positions"]
        action = agent.act(obs)
        if eef:
            eef_pose = obs["ee_pos_quat"]
            left_eef_pos = eef_pose[:3]
            right_eef_pos = eef_pose[6:9]
            left_eef_target = action[:3]
            right_eef_target = action[12:15]
            if np.linalg.norm(left_eef_pos - left_eef_target) > 0.5:
                print("Left EEF action is too big")
                print(
                    f"Left EEF pos: {left_eef_pos}, target: {left_eef_target}, diff: {left_eef_pos - left_eef_target}"
                )
            if np.linalg.norm(right_eef_pos - right_eef_target) > 0.5:
                print("Right EEF action is too big")
                print(
                    f"Right EEF pos: {right_eef_pos}, target: {right_eef_target}, diff: {right_eef_pos - right_eef_target}"
                )

            left_eef_target = np.clip(
                left_eef_target,
                left_eef_pos - self.delta,
                left_eef_pos + self.delta,
            )
            right_eef_target = np.clip(
                right_eef_target,
                right_eef_pos - self.delta,
                right_eef_pos + self.delta,
            )
            action[:3] = left_eef_target
            action[12:15] = right_eef_target
        else:
            # check if action is too big
            if (np.abs(action[self.ur_idx] - joints[self.ur_idx]) > self.delta).any():
                print("Action is too big")

                # print which joints are too big
                joint_index = np.where(np.abs(action - joints) > self.delta)[0]
                for j in joint_index:
                    if j in self.ur_idx:
                        print(
                            f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
                        )
            action[self.ur_idx] = np.clip(
                action[self.ur_idx],
                joints[self.ur_idx] - self.delta,
                joints[self.ur_idx] + self.delta,
            )

        if self.hand_idx is not None:
            joint_cmd = self._hand_pos_to_cmd(joints[self.hand_idx])
            action[self.hand_idx] = joint_cmd + np.clip(
                action[self.hand_idx] - joint_cmd, -self.hand_delta, self.hand_delta
            )
            action[self.hand_idx] = np.clip(action[self.hand_idx], 0, 1)
        return action

--------------------------------------------------------------------------------
/agents/dp_agent.py:
--------------------------------------------------------------------------------
import collections
import json
import os
from typing import Any, Dict

import numpy as np
import quaternion
import torch

from learning.dp.pipeline import Agent as DPAgent


def from_numpy(data, device, unsqueeze=True):
    return {
        key: torch.from_numpy(value).to(device=device)[None]
        if unsqueeze
        else torch.from_numpy(value).to(device=device)
        for key, value in data.items()
        if key != "activated"
    }


UR_IDX = list(range(6)) + list(range(12, 18))
LEFT_HAND_IDX = list(range(6, 12))
RIGHT_HAND_IDX = list(range(18, 24))


def get_reset_joints(ur_eef=False):
    if ur_eef:
        # these are EEF pose
        arm_joints_left = [
            -0.10244499252760966,
            -0.7492784625293504,
            0.14209881911585326,
            -0.3622358797572402,
            -1.4347279978985925,
            0.8691789808786153,
        ]

        arm_joints_right = [
            0.2313341406775527,
            -0.7512396951283128,
            0.06337444935928868,
            0.6512089317940273,
            1.3246649009637026,
            0.5471978290474188,
        ]
    else:
        arm_joints_left = [-80, -140, -80, -85, -10, 80]
        arm_joints_right = [-270, -30, 70, -85, 10, 0]
    hand_joints = [0, 0, 0, 0, 0.5, 0.5]
    reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
    reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
    reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
    return reset_joints


def get_eef_pose(eef_pose, eef_delta):
    pos_delta = eef_delta[:3]
    rot_delta = eef_delta[3:]
    pos = eef_pose[:3] + pos_delta
    # quaternion multiplication
    rot = quaternion.as_rotation_vector(
        quaternion.from_rotation_vector(rot_delta)
        * quaternion.from_rotation_vector(eef_pose[3:])
    )
    return np.concatenate((pos, rot))


def parse_txt_to_json(input_file_path, output_file_path):
    # Initialize an empty dictionary to store the parsed key-value pairs
    data = {}

    # Open and read the input text file line by line
    with open(input_file_path, "r") as file:
        for line in file:
            kv = line.strip().split(": ", 1)
            if len(kv) != 2:
                continue
            key, value = kv
            if key == "camera_indices":
                data[key] = list(map(int, value))
            elif key == "representation_type":
                data[key] = value.split("-")
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        if value == "True":
                            value = True
                        elif value == "False":
                            value = False
                        elif value == "None":
                            value = None

                data[key] = value

    with open(output_file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)
    return data


class BimanualDPAgent:
    def __init__(
        self,
        ckpt_path,
        dp_args=None,
        binarize_finger_action=False,
    ):
        if dp_args is None:
            dp_args = self.get_default_dp_args()

        # rewrite dp_args based on saved args
        args_txt = os.path.join(os.path.dirname(ckpt_path), "args_log.txt")
        args_json = os.path.join(os.path.dirname(ckpt_path), "args_log.json")
        args = parse_txt_to_json(args_txt, args_json)
        for k in dp_args.keys():
            if k == "output_sizes":
                dp_args[k]["img"] = args["image_output_size"]
            else:
                if k in args:
                    dp_args[k] = args[k]

        # save dp args in ckpt path as json
        ckpt_dir = os.path.dirname(ckpt_path)
        args_path = os.path.join(ckpt_dir, "dp_args.json")
        if os.path.exists(args_path):
            with open(args_path, "r") as f:
                dp_args = json.load(f)
        else:
            with open(args_path, "w") as f:
                json.dump(dp_args, f)

        torch.cuda.set_device(0)
        self.dp = DPAgent(
            output_sizes=dp_args["output_sizes"],
            representation_type=dp_args["representation_type"],
            identity_encoder=dp_args["identity_encoder"],
            camera_indices=dp_args["camera_indices"],
            pred_horizon=dp_args["pred_horizon"],
            obs_horizon=dp_args["obs_horizon"],
            action_horizon=dp_args["action_horizon"],
            without_sampling=dp_args["without_sampling"],
            predict_eef_delta=dp_args["predict_eef_delta"],
            predict_pos_delta=dp_args["predict_pos_delta"],
            use_ddim=dp_args["use_ddim"],
        )
        self.dp_args = dp_args
        self.obsque = collections.deque(maxlen=dp_args["obs_horizon"])
        self.dp.load(ckpt_path)
        self.action_queue = collections.deque(maxlen=dp_args["action_horizon"])
        self.max_length = 100
        self.count = 0
        self.except_thumb_hand_indices = np.array([6, 7, 8, 9, 18, 19, 20, 21])
        self.binaraize_finger_action = binarize_finger_action
        self.clip_far = dp_args["clip_far"]
        self.predict_eef_delta = dp_args["predict_eef_delta"]
        self.predict_pos_delta = dp_args["predict_pos_delta"]
        assert not (self.predict_eef_delta and self.predict_pos_delta)
        self.control = get_reset_joints(ur_eef=self.predict_eef_delta)

        self.num_diffusion_iters = dp_args["num_diffusion_iters"]

        self.hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
        self.hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])

        # TODO: remove hack
        self.hand_new_uppers = np.array([75] * 4 + [90.0, 120.0])

        self.trigger_state = {"l": True, "r": True}

    @staticmethod
    def get_default_dp_args():
        return {
            "output_sizes": {
                "eef": 64,
                "hand_pos": 64,
                "img": 128,
                "pos": 128,
                "touch": 64,
            },
            "representation_type": ["img", "pos", "touch", "depth"],
            "identity_encoder": False,
            "camera_indices": [0, 1, 2],
            "obs_horizon": 4,
            "pred_horizon": 16,
            "action_horizon": 8,
            "num_diffusion_iters": 15,
            "without_sampling": False,
            "clip_far": False,
            "predict_eef_delta": False,
            "predict_pos_delta": False,
            "use_ddim": False,
        }

    def compile_inference(self, example_obs, precision="high", num_inference_iters=5):
        torch.set_float32_matmul_precision(precision)
        self.dp.policy.forward = torch.compile(torch.no_grad(self.dp.policy.forward))
        self.num_diffusion_iters = num_inference_iters

        for i in range(25):  # burn in
            self.act(example_obs)

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        curr_joint_pos = obs["joint_positions"]
        curr_eef_pose = obs["ee_pos_quat"]
        obs = self.dp.get_observation([obs], load_img=True)
        if "img" in obs:
            obs["img"] = self.dp.eval_transform(obs["img"].squeeze(0))

        # if obsque is empty, fill it with the current observation
        if len(self.obsque) == 0:
            self.obsque.extend([obs] * self.dp_args["obs_horizon"])
        else:
            self.obsque.append(obs)

        # if action queue is not empty, return the first action in the queue
        if len(self.action_queue) > 0:
            act = self.action_queue.popleft()

        # if action queue is empty, predict new actions
        else:
            pred = self.dp.predict(
                self.obsque, num_diffusion_iters=self.num_diffusion_iters
            )
            for i in range(self.dp_args["action_horizon"]):
                act = pred[i]
                self.action_queue.append(act)

            act = self.action_queue.popleft()

        if self.predict_pos_delta:
            self.control[UR_IDX] = curr_joint_pos[UR_IDX]
            self.control = self.control + act
            act = self.control
            # act = curr_joint_pos + act

        if self.predict_eef_delta:
            left_arm_act = get_eef_pose(curr_eef_pose[:6], act[:6])
            left_hand_act = act[6:12]
            right_arm_act = get_eef_pose(curr_eef_pose[6:], act[12:18])
            right_hand_act = act[18:24]
            act = np.concatenate(
                [left_arm_act, left_hand_act, right_arm_act, right_hand_act],
                axis=-1,
            )

        # if binarize_finger_action is True, binarize the finger action

        if self.binaraize_finger_action:
            mean_act = np.mean(act[self.except_thumb_hand_indices])
            if mean_act > 0.5:
                act[self.except_thumb_hand_indices] = 1.0
            else:
                act[self.except_thumb_hand_indices] = 0.0
        else:
            left_hand = (
                act[LEFT_HAND_IDX] * (self.hand_uppers - self.hand_lowers)
                + self.hand_lowers
            )
            act[LEFT_HAND_IDX] = (left_hand - self.hand_lowers) / (
                self.hand_new_uppers - self.hand_lowers
            )
            right_hand = (
                act[RIGHT_HAND_IDX] * (self.hand_uppers - self.hand_lowers)
                + self.hand_lowers
            )
            act[RIGHT_HAND_IDX] = (right_hand - self.hand_lowers) / (
                self.hand_new_uppers - self.hand_lowers
            )
            act[list(range(6, 12)) + list(range(18, 24))] = np.clip(
                act[list(range(6, 12)) + list(range(18, 24))], 0, 1
            )

        return act

--------------------------------------------------------------------------------
/agents/dp_agent_zmq.py:
--------------------------------------------------------------------------------
import collections
import json
import os
from typing import Any, Dict

import numpy as np
import quaternion
import torch
import time

import pickle

from learning.dp.pipeline import Agent as DPAgent
from inference_node import (
    ZMQInferenceClient,
    ZMQInferenceServer,
    DEFAULT_INFERENCE_PORT,
)


def from_numpy(data, device, unsqueeze=True):
    return {
        key: torch.from_numpy(value).to(device=device)[None]
        if unsqueeze
        else torch.from_numpy(value).to(device=device)
        for key, value in data.items()
        if key != "activated"
    }


UR_IDX = list(range(6)) + list(range(12, 18))
LEFT_HAND_IDX = list(range(6, 12))
RIGHT_HAND_IDX = list(range(18, 24))


def get_reset_joints(ur_eef=False):
    if ur_eef:
        # these are EEF pose
        arm_joints_left = [
            -0.10244499252760966,
            -0.7492784625293504,
            0.14209881911585326,
            -0.3622358797572402,
            -1.4347279978985925,
            0.8691789808786153,
        ]

        arm_joints_right = [
            0.2313341406775527,
            -0.7512396951283128,
            0.06337444935928868,
            0.6512089317940273,
            1.3246649009637026,
            0.5471978290474188,
        ]
    else:
        arm_joints_left = [-80, -140, -80, -85, -10, 80]
        arm_joints_right = [-270, -30, 70, -85, 10, 0]
    hand_joints = [0, 0, 0, 0, 0.5, 0.5]
    reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
    reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
    reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
    return reset_joints


def get_eef_pose(eef_pose, eef_delta):
    pos_delta = eef_delta[:3]
    rot_delta = eef_delta[3:]
    pos = eef_pose[:3] + pos_delta
    # quaternion multiplication
    rot = quaternion.as_rotation_vector(
        quaternion.from_rotation_vector(rot_delta)
        * quaternion.from_rotation_vector(eef_pose[3:])
    )
    return np.concatenate((pos, rot))


def parse_txt_to_json(input_file_path, output_file_path):
    # Initialize an empty dictionary to store the parsed key-value pairs
    data = {}

    # Open and read the input text file line by line
    with open(input_file_path, "r") as file:
        for line in file:
            kv = line.strip().split(": ", 1)
            if len(kv) != 2:
                continue
            key, value = kv
            if key == "camera_indices":
                data[key] = list(map(int, value))
            elif key == "representation_type":
                data[key] = value.split("-")
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        if value == "True":
                            value = True
                        elif value == "False":
                            value = False
                        elif value == "None":
                            value = None

                data[key] = value

    with open(output_file_path, "w") as json_file:
        json.dump(data, json_file, indent=4)
    return data


class BimanualDPAgentServer(ZMQInferenceServer):
    def __init__(
        self,
        ckpt_path,
        dp_args=None,
        binarize_finger_action=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        if dp_args is None:
            dp_args = self.get_default_dp_args()

        # rewrite dp_args based on saved args
        args_txt = os.path.join(os.path.dirname(ckpt_path), "args_log.txt")
        args_json = os.path.join(os.path.dirname(ckpt_path), "args_log.json")
        args = parse_txt_to_json(args_txt, args_json)
        for k in dp_args.keys():
            if k == "output_sizes":
                dp_args[k]["img"] = args["image_output_size"]
            else:
                if k in args:
                    dp_args[k] = args[k]

        # save dp args in ckpt path as json
        ckpt_dir = os.path.dirname(ckpt_path)
        args_path = os.path.join(ckpt_dir, "dp_args.json")
        if os.path.exists(args_path):
            with open(args_path, "r") as f:
                dp_args = json.load(f)
        else:
            with open(args_path, "w") as f:
                json.dump(dp_args, f)

        torch.cuda.set_device(0)
        self.dp = DPAgent(
            output_sizes=dp_args["output_sizes"],
            representation_type=dp_args["representation_type"],
            identity_encoder=dp_args["identity_encoder"],
            camera_indices=dp_args["camera_indices"],
            pred_horizon=dp_args["pred_horizon"],
            obs_horizon=dp_args["obs_horizon"],
            action_horizon=dp_args["action_horizon"],
            without_sampling=dp_args["without_sampling"],
            predict_eef_delta=dp_args["predict_eef_delta"],
            predict_pos_delta=dp_args["predict_pos_delta"],
            use_ddim=dp_args["use_ddim"],
        )
        self.dp_args = dp_args
        self.obsque = collections.deque(maxlen=dp_args["obs_horizon"])
        self.dp.load(ckpt_path)
        self.action_queue = collections.deque(maxlen=dp_args["action_horizon"])
        self.max_length = 100
        self.count = 0
        self.except_thumb_hand_indices = np.array([6, 7, 8, 9, 18, 19, 20, 21])
        self.binaraize_finger_action = binarize_finger_action
        self.clip_far = dp_args["clip_far"]
        self.predict_eef_delta = dp_args["predict_eef_delta"]
        self.predict_pos_delta = dp_args["predict_pos_delta"]
        assert not (self.predict_eef_delta and self.predict_pos_delta)
        self.control = get_reset_joints(ur_eef=self.predict_eef_delta)

        self.num_diffusion_iters = dp_args["num_diffusion_iters"]

        self.hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
        self.hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])

        # TODO: remove hack
        self.hand_new_uppers = np.array([75] * 4 + [90.0, 120.0])

        self.trigger_state = {"l": True, "r": True}

    @staticmethod
    def get_default_dp_args():
        return {
            "output_sizes": {
                "eef": 64,
                "hand_pos": 64,
                "img": 128,
                "pos": 128,
                "touch": 64,
            },
            "representation_type": ["img", "pos", "touch", "depth"],
            "identity_encoder": False,
            "camera_indices": [0, 1, 2],
            "obs_horizon": 4,
            "pred_horizon": 16,
            "action_horizon": 8,
            "num_diffusion_iters": 15,
            "without_sampling": False,
            "clip_far": False,
            "predict_eef_delta": False,
            "predict_pos_delta": False,
            "use_ddim": False,
        }

    def compile_inference(self, precision="high"):
        message = self._socket.recv()
        start_time = time.time()
        state_dict = pickle.loads(message)
        self.num_diffusion_iters = state_dict["num_diffusion_iters"]
        example_obs = state_dict["example_obs"]
        print(
            f"received compilation request: # diff iters = {state_dict['num_diffusion_iters']}"
        )

        torch.set_float32_matmul_precision(precision)
        self.dp.policy.forward = torch.compile(torch.no_grad(self.dp.policy.forward))

        for i in range(25):  # burn in
            self.act(example_obs)
        print("success, compile time: " + str(time.time() - start_time))
        self._socket.send_string("success")

    def infer(self, obs: Dict[str, Any]) -> np.ndarray:
        return self.dp.predict([obs], num_diffusion_iters=self.num_diffusion_iters)

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
        curr_joint_pos = obs["joint_positions"]
        curr_eef_pose = obs["ee_pos_quat"]
        obs = self.dp.get_observation([obs], load_img=True)
        if "img" in obs:
            obs["img"] = self.dp.eval_transform(obs["img"].squeeze(0))
        return self.infer(obs)


class BimanualDPAgent(ZMQInferenceClient):
    def __init__(
        self,
        ckpt_path,
        dp_args=None,
        binarize_finger_action=False,
        port=DEFAULT_INFERENCE_PORT,
        host="127.0.0.1",
        temporal_ensemble_mode="new",
        temporal_ensemble_act_tau=0.5,
    ):
        super().__init__(
            default_action=get_reset_joints(),
            port=port,
            host=host,
            ensemble_mode=temporal_ensemble_mode,
            act_tau=temporal_ensemble_act_tau,
        )

        if dp_args is None:
            dp_args = self.get_default_dp_args()

        # rewrite dp_args based on saved args
        args_txt = os.path.join(os.path.dirname(ckpt_path), "args_log.txt")
        args_json = os.path.join(os.path.dirname(ckpt_path), "args_log.json")
        args = parse_txt_to_json(args_txt, args_json)
        for k in dp_args.keys():
            if k == "output_sizes":
                dp_args[k]["img"] = args["image_output_size"]
            else:
                if k in args:
                    dp_args[k] = args[k]

        self.dp_args = dp_args
        self.obsque = collections.deque(maxlen=dp_args["obs_horizon"])
        # self.dp.load(ckpt_path)
        self.action_queue = collections.deque(maxlen=dp_args["action_horizon"])
        self.max_length = 100
        self.count = 0
        self.except_thumb_hand_indices = np.array([6, 7, 8, 9, 18, 19, 20, 21])
        self.binaraize_finger_action = binarize_finger_action
        self.clip_far = dp_args["clip_far"]
        self.predict_eef_delta = dp_args["predict_eef_delta"]
        self.predict_pos_delta = dp_args["predict_pos_delta"]
        assert not (self.predict_eef_delta and self.predict_pos_delta)
        self.control = get_reset_joints(ur_eef=self.predict_eef_delta)

        self.num_diffusion_iters = dp_args["num_diffusion_iters"]

        self.hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
        self.hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])

        # TODO: remove hack
        self.hand_new_uppers = np.array([75] * 4 + [90.0, 120.0])

        self.trigger_state = {"l": True, "r": True}

    @staticmethod
    def get_default_dp_args():
        return {
            "output_sizes": {
                "eef": 64,
                "hand_pos": 64,
                "img": 128,
                "pos": 128,
                "touch": 64,
            },
            "representation_type": ["img", "pos", "touch", "depth"],
            "identity_encoder": False,
            "camera_indices": [0, 1, 2],
            "obs_horizon": 4,
            "pred_horizon": 16,
            "action_horizon": 8,
            "num_diffusion_iters": 15,
            "without_sampling": False,
            "clip_far": False,
            "predict_eef_delta": False,
            "predict_pos_delta": False,
            "use_ddim": False,
        }

    def compile_inference(self, example_obs, num_diffusion_iters):
        message = pickle.dumps(
            {"example_obs": example_obs, "num_diffusion_iters": num_diffusion_iters}
        )
        self._socket.send(message)

        message = self._socket.recv()
        assert message == b"success"

    def act(self, obs: Dict[str, Any]) -> np.ndarray:
curr_joint_pos = obs["joint_positions"] 333 | curr_eef_pose = obs["ee_pos_quat"] 334 | act = super().act(obs) 335 | 336 | if self.predict_pos_delta: 337 | self.control[UR_IDX] = curr_joint_pos[UR_IDX] 338 | self.control = self.control + act 339 | act = self.control 340 | # act = curr_joint_pos + act 341 | 342 | if self.predict_eef_delta: 343 | left_arm_act = get_eef_pose(curr_eef_pose[:6], act[:6]) 344 | left_hand_act = act[6:12] 345 | right_arm_act = get_eef_pose(curr_eef_pose[6:], act[12:18]) 346 | right_hand_act = act[18:24] 347 | act = np.concatenate( 348 | [left_arm_act, left_hand_act, right_arm_act, right_hand_act], 349 | axis=-1, 350 | ) 351 | 352 | # if binarize_finger_action is True, binarize the finger action 353 | 354 | if self.binaraize_finger_action: 355 | mean_act = np.mean(act[self.except_thumb_hand_indices]) 356 | if mean_act > 0.5: 357 | act[self.except_thumb_hand_indices] = 1.0 358 | else: 359 | act[self.except_thumb_hand_indices] = 0.0 360 | else: 361 | left_hand = ( 362 | act[LEFT_HAND_IDX] * (self.hand_uppers - self.hand_lowers) 363 | + self.hand_lowers 364 | ) 365 | act[LEFT_HAND_IDX] = (left_hand - self.hand_lowers) / ( 366 | self.hand_new_uppers - self.hand_lowers 367 | ) 368 | right_hand = ( 369 | act[RIGHT_HAND_IDX] * (self.hand_uppers - self.hand_lowers) 370 | + self.hand_lowers 371 | ) 372 | act[RIGHT_HAND_IDX] = (right_hand - self.hand_lowers) / ( 373 | self.hand_new_uppers - self.hand_lowers 374 | ) 375 | act[list(range(6, 12)) + list(range(18, 24))] = np.clip( 376 | act[list(range(6, 12)) + list(range(18, 24))], 0, 1 377 | ) 378 | 379 | return act 380 | -------------------------------------------------------------------------------- /agents/quest_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import quaternion 5 | from dm_control import mjcf 6 | from dm_control.mujoco.wrapper import mjbindings 7 | from 
dm_control.utils.inverse_kinematics import qpos_from_site_pose, nullspace_method 8 | from oculus_reader.reader import OculusReader 9 | from agents.agent import Agent 10 | 11 | mjlib = mjbindings.mjlib 12 | 13 | mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 14 | ur2mj = np.linalg.inv(mj2ur) 15 | 16 | trigger_state = {"l": False, "r": False} 17 | 18 | 19 | def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray: 20 | # xyz can be 3dim or 4dim (homogeneous) or can be a rotation matrix 21 | if len(xyz) == 3: 22 | xyz = np.append(xyz, 1) 23 | return np.matmul(mat, xyz)[:3] 24 | 25 | 26 | quest2ur = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) 27 | # 45 deg CCW 28 | ur2left = np.array( 29 | [ 30 | [1 / 2 * np.sqrt(2), 1 / 2 * np.sqrt(2), 0, 0], 31 | [-1 / 2 * np.sqrt(2), 1 / 2 * np.sqrt(2), 0, 0], 32 | [0, 0, 1, 0], 33 | [0, 0, 0, 1], 34 | ] 35 | ) 36 | # 45 deg CW 37 | ur2right = np.array( 38 | [ 39 | [1 / 2 * np.sqrt(2), -1 / 2 * np.sqrt(2), 0, 0], 40 | [1 / 2 * np.sqrt(2), 1 / 2 * np.sqrt(2), 0, 0], 41 | [0, 0, 1, 0], 42 | [0, 0, 0, 1], 43 | ] 44 | ) 45 | ur2quest = np.linalg.inv(quest2ur) 46 | 47 | 48 | def velocity_ik(physics, 49 | site_name, 50 | delta_rot_mat, 51 | delta_pos, 52 | joint_names): 53 | 54 | dtype = physics.data.qpos.dtype 55 | 56 | # Convert site name to index. 
57 | site_id = physics.model.name2id(site_name, 'site') 58 | 59 | 60 | jac = np.empty((6, physics.model.nv), dtype=dtype) 61 | err = np.zeros(6, dtype=dtype) 62 | jac_pos, jac_rot = jac[:3], jac[3:] 63 | err_pos, err_rot = err[:3], err[3:] 64 | 65 | err_pos[:] = delta_pos[:] 66 | 67 | delta_rot_quat = np.empty(4, dtype=dtype) 68 | mjlib.mju_mat2Quat(delta_rot_quat, delta_rot_mat) 69 | mjlib.mju_quat2Vel(err_rot, delta_rot_quat, 1) 70 | 71 | mjlib.mj_jacSite( 72 | physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id) 73 | 74 | dof_indices = [] 75 | for jn in joint_names: 76 | dof_idx = physics.model.joint(jn).id 77 | dof_indices.append(dof_idx) 78 | 79 | jac_joints = jac[:, dof_indices] 80 | 81 | update_joints = nullspace_method( 82 | jac_joints, err, regularization_strength=0.03) 83 | 84 | return update_joints 85 | 86 | 87 | class SingleArmQuestAgent(Agent): 88 | def __init__( 89 | self, 90 | robot_type: str, 91 | which_hand: str, 92 | eef_control_mode: int = 0, 93 | verbose: bool = False, 94 | use_vel_ik: bool = False, 95 | vel_ik_speed_scale: float = 0.95, 96 | ) -> None: 97 | """Interact with the robot using the quest controller. 
98 | 99 | leftTrig: press to start control (also record the current position as the home position) 100 | leftJS: a tuple of (x,y) for the joystick, only need y to control the gripper 101 | """ 102 | self.which_hand = which_hand 103 | self.eef_control_mode = eef_control_mode 104 | assert self.which_hand in ["l", "r"] 105 | 106 | self.oculus_reader = OculusReader() 107 | if robot_type == "ur5": 108 | mjcf_model = mjcf.from_path("universal_robots_ur5e/ur5e.xml") 109 | mjcf_model.name = robot_type 110 | else: 111 | raise ValueError(f"Unknown robot type: {robot_type}") 112 | self.physics = mjcf.Physics.from_mjcf_model(mjcf_model) 113 | self.control_active = False 114 | self.reference_quest_pose = None 115 | self.reference_ee_rot_ur = None 116 | self.reference_ee_pos_ur = None 117 | self.reference_js = [0.5, 0.5] 118 | self.js_speed_scale = 0.1 119 | 120 | self.use_vel_ik = use_vel_ik 121 | self.vel_ik_speed_scale = vel_ik_speed_scale 122 | 123 | if use_vel_ik: 124 | self.translation_scaling_factor = 1.0 125 | else: 126 | self.translation_scaling_factor = 2.0 127 | 128 | self.robot_type = robot_type 129 | self._verbose = verbose 130 | 131 | def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray: 132 | if self.robot_type == "ur5": 133 | num_dof = 6 134 | current_qpos = obs["joint_positions"][:num_dof] # last one dim is the gripper 135 | # run the fk 136 | self.physics.data.qpos[:num_dof] = current_qpos 137 | self.physics.step() 138 | 139 | ee_rot_mj = np.array( 140 | self.physics.named.data.site_xmat["attachment_site"] 141 | ).reshape(3, 3) 142 | ee_pos_mj = np.array(self.physics.named.data.site_xpos["attachment_site"]) 143 | if self.which_hand == "l": 144 | pose_key = "l" 145 | trigger_key = "leftTrig" 146 | grip_key = "leftGrip" 147 | joystick_key = "leftJS" 148 | elif self.which_hand == "r": 149 | pose_key = "r" 150 | trigger_key = "rightTrig" 151 | grip_key = "rightGrip" 152 | joystick_key = "rightJS" 153 | else: 154 | raise ValueError(f"Unknown hand: 
{self.which_hand}") 155 | # check the trigger button state 156 | ( 157 | pose_data, 158 | button_data, 159 | ) = self.oculus_reader.get_transformations_and_buttons() 160 | if len(pose_data) == 0 or len(button_data) == 0: 161 | print("no data, quest not yet ready") 162 | return np.concatenate( 163 | [current_qpos, obs["joint_positions"][num_dof:] * 0.0] 164 | ) 165 | 166 | if self.eef_control_mode == 0: 167 | new_gripper_angle = [button_data[grip_key][0]] 168 | elif self.eef_control_mode == 1: 169 | new_gripper_angle = [button_data[grip_key][0]] * 6 # [0, 1] 170 | else: 171 | # (x, y) position of joystick, range (-1.0, 1.0) 172 | if self.which_hand == "r": 173 | js_y = button_data[joystick_key][0] * -1 174 | js_x = button_data[joystick_key][1] * -1 175 | else: 176 | js_y = button_data[joystick_key][0] 177 | js_x = button_data[joystick_key][1] * -1 178 | 179 | if self.eef_control_mode == 2: 180 | # convert js_x, js_y from range (-1.0, 1.0) to (0.0, 1.0) 181 | js_x = (js_x + 1) / 2 182 | js_y = (js_y + 1) / 2 183 | # control absolute position using joystick 184 | self.reference_js = [js_x, js_y] 185 | else: 186 | # control relative position using joystick 187 | self.reference_js = [ 188 | max(0, min(self.reference_js[0] + js_x * self.js_speed_scale, 1)), 189 | max(0, min(self.reference_js[1] + js_y * self.js_speed_scale, 1)), 190 | ] 191 | new_gripper_angle = [ 192 | button_data[grip_key][0], 193 | button_data[grip_key][0], 194 | button_data[grip_key][0], 195 | button_data[grip_key][0], 196 | self.reference_js[0], 197 | self.reference_js[1], 198 | ] # [0, 1] 199 | arm_not_move_return = np.concatenate([current_qpos, new_gripper_angle]) 200 | if len(pose_data) == 0: 201 | print("no data, quest not yet ready") 202 | return arm_not_move_return 203 | 204 | global trigger_state 205 | trigger_state[self.which_hand] = button_data[trigger_key][0] > 0.5 206 | if trigger_state[self.which_hand]: 207 | if self.control_active is True: 208 | if self._verbose: 209 | 
print("controlling the arm") 210 | current_pose = pose_data[pose_key] 211 | delta_rot = current_pose[:3, :3] @ np.linalg.inv( 212 | self.reference_quest_pose[:3, :3] 213 | ) 214 | delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3] 215 | if self.which_hand == "l": 216 | t_mat = np.matmul(ur2left, quest2ur) 217 | else: 218 | t_mat = np.matmul(ur2right, quest2ur) 219 | delta_pos_ur = ( 220 | apply_transfer(t_mat, delta_pos) * self.translation_scaling_factor 221 | ) 222 | # ? is this the case? 223 | delta_rot_ur = quest2ur[:3, :3] @ delta_rot @ ur2quest[:3, :3] 224 | if self._verbose: 225 | print(f"delta pos and rot in VR space: \n{delta_pos}, {delta_rot}") 226 | print( 227 | f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}" 228 | ) 229 | next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur 230 | next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur 231 | 232 | if self.use_vel_ik: 233 | next_ee_pos_mj = apply_transfer(ur2mj, next_ee_pos_ur) 234 | next_ee_rot_mj = ur2mj[:3, :3] @ next_ee_rot_ur 235 | 236 | err_rot_mj = next_ee_rot_mj @ np.linalg.inv(ee_rot_mj) 237 | err_pos_mj = next_ee_pos_mj - ee_pos_mj 238 | 239 | print(err_pos_mj) 240 | 241 | delta_qpos = velocity_ik( 242 | self.physics, 243 | "attachment_site", 244 | err_rot_mj.flatten(), 245 | err_pos_mj, 246 | joint_names=[ 247 | "shoulder_pan_joint", 248 | "shoulder_lift_joint", 249 | "elbow_joint", 250 | "wrist_1_joint", 251 | "wrist_2_joint", 252 | "wrist_3_joint", 253 | ], 254 | ) 255 | 256 | new_qpos = current_qpos + delta_qpos * self.vel_ik_speed_scale 257 | 258 | else: 259 | target_quat = quaternion.as_float_array( 260 | quaternion.from_rotation_matrix(ur2mj[:3, :3] @ next_ee_rot_ur) 261 | ) 262 | ik_result = qpos_from_site_pose( 263 | self.physics, 264 | "attachment_site", 265 | target_pos=apply_transfer(ur2mj, next_ee_pos_ur), 266 | target_quat=target_quat, 267 | tol=1e-14, 268 | max_steps=400, 269 | ) 270 | self.physics.reset() 271 | if ik_result.success: 272 | 
new_qpos = ik_result.qpos[:num_dof] 273 | else: 274 | print("ik failed, using the original qpos") 275 | return arm_not_move_return 276 | command = np.concatenate([new_qpos, new_gripper_angle]) 277 | return command 278 | 279 | else: # last state is not in active 280 | self.control_active = True 281 | if self._verbose: 282 | print("control activated!") 283 | self.reference_quest_pose = pose_data[pose_key] 284 | 285 | self.reference_ee_rot_ur = mj2ur[:3, :3] @ ee_rot_mj 286 | self.reference_ee_pos_ur = apply_transfer(mj2ur, ee_pos_mj) 287 | return arm_not_move_return 288 | else: 289 | if self._verbose: 290 | print("deactive control") 291 | self.control_active = False 292 | self.reference_quest_pose = None 293 | return arm_not_move_return 294 | 295 | 296 | class DualArmQuestAgent(Agent): 297 | def __init__(self, agent_left: Agent, agent_right: Agent): 298 | self.agent_left = agent_left 299 | self.agent_right = agent_right 300 | global trigger_state 301 | self.trigger_state = trigger_state 302 | 303 | def act(self, obs: Dict) -> np.ndarray: 304 | left_obs = {} 305 | right_obs = {} 306 | for key, val in obs.items(): 307 | L = val.shape[0] 308 | half_dim = L // 2 309 | if key.endswith("rgb") or key.endswith("depth"): 310 | left_obs[key] = val 311 | right_obs[key] = val 312 | else: 313 | assert L == half_dim * 2, f"{key} must be even, something is wrong" 314 | left_obs[key] = val[:half_dim] 315 | right_obs[key] = val[half_dim:] 316 | return np.concatenate( 317 | [self.agent_left.act(left_obs), self.agent_right.act(right_obs)] 318 | ) 319 | 320 | 321 | if __name__ == "__main__": 322 | oculus_reader = OculusReader() 323 | while True: 324 | """ 325 | example output: 326 | ({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ], 327 | [-0.107737 , 0.0958919, 0.989544 , -0.833478 ], 328 | [ 0.549685 , 0.835106 , -0.0210789, -0.892425 ], 329 | [ 0. , 0. , 0. , 1. 
]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ], 330 | [ 0.070887, 0.516083, -0.8536 , -0.238691], 331 | [-0.941994, -0.246809, -0.227447, -0.370447], 332 | [ 0. , 0. , 0. , 1. ]])}, 333 | {'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)}) 334 | 335 | """ 336 | pose_data, button_data = oculus_reader.get_transformations_and_buttons() 337 | if len(pose_data) == 0: 338 | print("no data") 339 | continue 340 | else: 341 | print(pose_data, button_data) 342 | -------------------------------------------------------------------------------- /agents/quest_agent_eef.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import quaternion 5 | from oculus_reader.reader import OculusReader 6 | from agents.agent import Agent 7 | 8 | 9 | mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 10 | ur2mj = np.linalg.inv(mj2ur) 11 | 12 | trigger_state = {"l": False, "r": False} 13 | 14 | 15 | def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray: 16 | # xyz can be 3dim or 4dim (homogeneous) or can be a rotation matrix 17 | if len(xyz) == 3: 18 | xyz = np.append(xyz, 1) 19 | return np.matmul(mat, xyz)[:3] 20 | 21 | 22 | quest2isaac = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) 23 | left2isaac = np.array( 24 | [ 25 | [0, -1, 0, 0], 26 | [-1 / 2 * np.sqrt(2), 0, -1 / 2 * np.sqrt(2), 0], 27 | [1 / 2 * np.sqrt(2), 0, -1 / 2 * np.sqrt(2), 0], 28 | [0, 0, 0, 1], 29 | ] 30 | ) 31 | right2isaac = np.array( 32 | [ 33 | [0, -1, 0, 0], 34 | [-1 / 2 * np.sqrt(2), 0, 1 / 2 * np.sqrt(2), 0], 35 | [-1 / 2 * np.sqrt(2), 0, -1 / 2 * np.sqrt(2), 0], 36 | [0, 0, 0, 1], 37 | ] 38 | ) 39 | isaac2left = 
np.linalg.inv(left2isaac) 40 | isaac2right = np.linalg.inv(right2isaac) 41 | 42 | quest2left = np.matmul(isaac2left, quest2isaac) 43 | quest2right = np.matmul(isaac2right, quest2isaac) 44 | left2quest = np.linalg.inv(quest2left) 45 | right2quest = np.linalg.inv(quest2right) 46 | 47 | translation_scaling_factor = 1.0 48 | 49 | 50 | class SingleArmQuestAgent(Agent): 51 | def __init__( 52 | self, 53 | robot_type: str, 54 | which_hand: str, 55 | eef_control_mode: int = 0, 56 | verbose: bool = False, 57 | ) -> None: 58 | """Interact with the robot using the quest controller. 59 | 60 | leftTrig: press to start control (also record the current position as the home position) 61 | leftJS: a tuple of (x,y) for the joystick, only need y to control the gripper 62 | """ 63 | self.which_hand = which_hand 64 | self.eef_control_mode = eef_control_mode 65 | assert self.which_hand in ["l", "r"] 66 | 67 | self.oculus_reader = OculusReader() 68 | self.control_active = False 69 | self.reference_quest_pose = None 70 | self.reference_ee_rot_ur = None 71 | self.reference_ee_pos_ur = None 72 | self.reference_js = [0.5, 0.5] 73 | self.js_speed_scale = 0.1 74 | 75 | self.robot_type = robot_type 76 | self._verbose = verbose 77 | 78 | def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray: 79 | if self.robot_type == "ur5": 80 | num_dof = 6 81 | current_eef_pose = obs["ee_pos_quat"] 82 | 83 | # pos and rot in robot base frame 84 | ee_pos = current_eef_pose[:3] 85 | ee_rot = current_eef_pose[3:] 86 | ee_rot = quaternion.as_rotation_matrix(quaternion.from_rotation_vector(ee_rot)) 87 | 88 | if self.which_hand == "l": 89 | pose_key = "l" 90 | trigger_key = "leftTrig" 91 | grip_key = "leftGrip" 92 | joystick_key = "leftJS" 93 | # left yx 94 | gripper_open_key = "Y" 95 | gripper_close_key = "X" 96 | elif self.which_hand == "r": 97 | pose_key = "r" 98 | trigger_key = "rightTrig" 99 | grip_key = "rightGrip" 100 | joystick_key = "rightJS" 101 | # right ba for the key 102 | gripper_open_key = "B" 103 | 
gripper_close_key = "A" 104 | else: 105 | raise ValueError(f"Unknown hand: {self.which_hand}") 106 | # check the trigger button state 107 | ( 108 | pose_data, 109 | button_data, 110 | ) = self.oculus_reader.get_transformations_and_buttons() 111 | if len(pose_data) == 0 or len(button_data) == 0: 112 | print("no data, quest not yet ready") 113 | return np.concatenate( 114 | [current_eef_pose, obs["joint_positions"][num_dof:] * 0.0] 115 | ) 116 | 117 | if self.eef_control_mode == 0: 118 | new_gripper_angle = [button_data[grip_key][0]] 119 | elif self.eef_control_mode == 1: 120 | new_gripper_angle = [button_data[grip_key][0]] * 6 # [0, 1] 121 | else: 122 | # (x, y) position of joystick, range (-1.0, 1.0) 123 | if self.which_hand == "r": 124 | js_y = button_data[joystick_key][0] * -1 125 | js_x = button_data[joystick_key][1] * -1 126 | else: 127 | js_y = button_data[joystick_key][0] 128 | js_x = button_data[joystick_key][1] * -1 129 | 130 | if self.eef_control_mode == 2: 131 | # convert js_x, js_y from range (-1.0, 1.0) to (0.0, 1.0) 132 | js_x = (js_x + 1) / 2 133 | js_y = (js_y + 1) / 2 134 | # control absolute position using joystick 135 | self.reference_js = [js_x, js_y] 136 | else: 137 | # control relative position using joystick 138 | self.reference_js = [ 139 | max(0, min(self.reference_js[0] + js_x * self.js_speed_scale, 1)), 140 | max(0, min(self.reference_js[1] + js_y * self.js_speed_scale, 1)), 141 | ] 142 | new_gripper_angle = [ 143 | button_data[grip_key][0], 144 | button_data[grip_key][0], 145 | button_data[grip_key][0], 146 | button_data[grip_key][0], 147 | self.reference_js[0], 148 | self.reference_js[1], 149 | ] # [0, 1] 150 | arm_not_move_return = np.concatenate([current_eef_pose, new_gripper_angle]) 151 | if len(pose_data) == 0: 152 | print("no data, quest not yet ready") 153 | return arm_not_move_return 154 | 155 | global trigger_state 156 | trigger_state[self.which_hand] = button_data[trigger_key][0] > 0.5 157 | if trigger_state[self.which_hand]: 
158 | if self.control_active is True: 159 | if self._verbose: 160 | print("controlling the arm") 161 | current_pose = pose_data[pose_key] 162 | delta_rot = current_pose[:3, :3] @ np.linalg.inv( 163 | self.reference_quest_pose[:3, :3] 164 | ) 165 | delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3] 166 | if self.which_hand == "l": 167 | t_mat = quest2left 168 | t_mat_inv = left2quest 169 | ur2isaac = left2isaac 170 | else: 171 | t_mat = quest2right 172 | t_mat_inv = right2quest 173 | ur2isaac = right2isaac 174 | delta_pos_ur = ( 175 | apply_transfer(t_mat, delta_pos) * translation_scaling_factor 176 | ) 177 | delta_rot_ur = t_mat[:3, :3] @ delta_rot @ t_mat_inv[:3, :3] 178 | if self._verbose: 179 | print(f"delta pos and rot in VR space: \n{delta_pos}, {delta_rot}") 180 | print( 181 | f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}" 182 | ) 183 | delta_pos_isaac = apply_transfer(quest2isaac, delta_pos) 184 | delta_pos_isaac_ur = apply_transfer(ur2isaac, delta_pos_ur) 185 | print("delta pos in isaac", delta_pos_isaac) 186 | print("delta pos in ur in isaac", delta_pos_isaac_ur) 187 | print("delta", delta_pos_isaac - delta_pos_isaac_ur) 188 | 189 | next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur # [3, 3] 190 | next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur 191 | next_ee_rot_ur = quaternion.as_rotation_vector( 192 | quaternion.from_rotation_matrix(next_ee_rot_ur) 193 | ) 194 | new_eef_pose = np.concatenate([next_ee_pos_ur, next_ee_rot_ur]) 195 | command = np.concatenate([new_eef_pose, new_gripper_angle]) 196 | return command 197 | 198 | else: # last state is not in active 199 | self.control_active = True 200 | if self._verbose: 201 | print("control activated!") 202 | self.reference_quest_pose = pose_data[pose_key] 203 | 204 | # reference in their local TCP frames 205 | self.reference_ee_rot_ur = ee_rot 206 | self.reference_ee_pos_ur = ee_pos 207 | return arm_not_move_return 208 | else: 209 | if self._verbose: 210 | 
print("deactive control") 211 | self.control_active = False 212 | self.reference_quest_pose = None 213 | return arm_not_move_return 214 | 215 | 216 | class DualArmQuestAgent(Agent): 217 | def __init__(self, agent_left: Agent, agent_right: Agent): 218 | self.agent_left = agent_left 219 | self.agent_right = agent_right 220 | global trigger_state 221 | self.trigger_state = trigger_state 222 | 223 | def act(self, obs: Dict) -> np.ndarray: 224 | left_obs = {} 225 | right_obs = {} 226 | for key, val in obs.items(): 227 | L = val.shape[0] 228 | half_dim = L // 2 229 | if key.endswith("rgb") or key.endswith("depth"): 230 | left_obs[key] = val 231 | right_obs[key] = val 232 | else: 233 | assert L == half_dim * 2, f"{key} must be even, something is wrong" 234 | left_obs[key] = val[:half_dim] 235 | right_obs[key] = val[half_dim:] 236 | return np.concatenate( 237 | [self.agent_left.act(left_obs), self.agent_right.act(right_obs)] 238 | ) 239 | 240 | 241 | if __name__ == "__main__": 242 | oculus_reader = OculusReader() 243 | while True: 244 | """ 245 | example output: 246 | ({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ], 247 | [-0.107737 , 0.0958919, 0.989544 , -0.833478 ], 248 | [ 0.549685 , 0.835106 , -0.0210789, -0.892425 ], 249 | [ 0. , 0. , 0. , 1. ]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ], 250 | [ 0.070887, 0.516083, -0.8536 , -0.238691], 251 | [-0.941994, -0.246809, -0.227447, -0.370447], 252 | [ 0. , 0. , 0. , 1. 
]])}, 253 | {'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)}) 254 | 255 | """ 256 | pose_data, button_data = oculus_reader.get_transformations_and_buttons() 257 | if len(pose_data) == 0: 258 | print("no data") 259 | continue 260 | else: 261 | print(pose_data, button_data) 262 | -------------------------------------------------------------------------------- /agents/universal_robots_ur5e/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 ROS Industrial Consortium 2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation and/or 11 | other materials provided with the distribution. 12 | 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software without 15 | specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /agents/universal_robots_ur5e/README.md: -------------------------------------------------------------------------------- 1 | # Universal Robots UR5e Description (MJCF) 2 | 3 | Requires MuJoCo 2.3.3 or later. 4 | 5 | ## Overview 6 | 7 | This package contains a simplified robot description (MJCF) of the 8 | [UR5e](https://www.universal-robots.com/products/ur5-robot/) developed by 9 | [Universal Robots](https://www.universal-robots.com/). It is derived from the 10 | [publicly available URDF 11 | description](https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_e_description). 12 | 13 |


### URDF → MJCF derivation steps

1. Converted the DAE [mesh
   files](https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_e_description/meshes/ur5e/visual)
   to OBJ format using [Blender](https://www.blender.org/).
2. Processed `.obj` files with [`obj2mjcf`](https://github.com/kevinzakka/obj2mjcf).
3. Added `<mujoco> <compiler discardvisual="false"/> </mujoco>` to the URDF's
   `<robot>` clause in order to preserve visual geometries.
4. Loaded the URDF into MuJoCo and saved a corresponding MJCF.
5. Added a tracking light to the base.
6. Manually edited the MJCF to extract common properties into the `<default>` section.
7. Added position-controlled actuators. Max joint torque values were taken from
   the Universal Robots [max joint torques article](https://www.universal-robots.com/articles/ur/robot-care-maintenance/max-joint-torques/).
8. Added home joint configuration as a `keyframe`.
9. Manually designed collision geometries.
10. Added `scene.xml`, which includes the robot along with a textured ground plane, skybox and haze.

## License

This model is released under a [BSD-3-Clause License](LICENSE).
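As a rough illustration of steps 6 to 8 of the derivation above (shared defaults, position-controlled actuators, and a home keyframe), the sketch below parses a tiny MJCF-style fragment with Python's standard XML parser. The fragment, its names, and all of its numbers are hypothetical and are not copied from `ur5e.xml`. It is only parsed as XML here, and it is not a complete model that MuJoCo could load.

```python
import xml.etree.ElementTree as ET

# Hypothetical, heavily trimmed MJCF-style fragment. Names and numbers are
# illustrative only and are NOT taken from ur5e.xml.
MJCF_SNIPPET = """
<mujoco model="toy_arm">
  <default>
    <position kp="2000" ctrlrange="-6.2831 6.2831"/>
  </default>
  <actuator>
    <position name="shoulder_pan" joint="shoulder_pan_joint" forcerange="-150 150"/>
    <position name="elbow" joint="elbow_joint" forcerange="-150 150"/>
  </actuator>
  <keyframe>
    <key name="home" qpos="-1.5708 1.5708"/>
  </keyframe>
</mujoco>
"""


def summarize_mjcf(xml_text):
    """Return (actuator name -> joint name, keyframe name -> qpos list)."""
    root = ET.fromstring(xml_text)
    # Position-controlled actuators and the joints they drive.
    actuators = {
        act.get("name"): act.get("joint")
        for act in root.findall("./actuator/position")
    }
    # Named keyframes with their joint configuration parsed as floats.
    keyframes = {
        key.get("name"): [float(v) for v in key.get("qpos", "").split()]
        for key in root.findall("./keyframe/key")
    }
    return actuators, keyframes


actuators, keyframes = summarize_mjcf(MJCF_SNIPPET)
print(actuators)
print(keyframes)
```

The same helper could be pointed at the text of a real MJCF file to list its position actuators and keyframes, assuming the file uses the same element layout as this fragment.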
--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/scene.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/ur5e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToruOwO/hato/3b6f4496e11a386df8f362f398b36e832944b2b8/agents/universal_robots_ur5e/ur5e.png
--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/ur5e.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/camera_node.py:
--------------------------------------------------------------------------------
import pickle
import threading
import time
from typing import Optional, Tuple

import numpy as np
import zmq

from cameras.camera import CameraDriver

DEFAULT_CAMERA_PORT = 5000


class ZMQClientCamera(CameraDriver):
    """A ZMQ client that requests camera frames from a camera server."""

    def __init__(self, port: int = DEFAULT_CAMERA_PORT, host: str = "127.0.0.1"):
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REQ)
        self._socket.connect(f"tcp://{host}:{port}")

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Request the current frame from the camera server.

        Args:
            img_size: The image size to request. It is pickled and sent to the server.

        Returns:
            The unpickled reply from the camera server.
        """
        # Pickle the requested image size and send it to the server.
        send_message = pickle.dumps(img_size)
        self._socket.send(send_message)
        # The reply is expected to be a pickled camera read.
        state_dict = pickle.loads(self._socket.recv())
        return state_dict


class ZMQServerCamera:
    def __init__(
        self,
        camera: CameraDriver,
        port: int = DEFAULT_CAMERA_PORT,
        host: str = "127.0.0.1",
    ):
        self._camera = camera
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        debug_message = f"Camera Server Binding to {addr}, Camera: {camera}"
        print(debug_message)
        self._timout_message = f"Timeout in Camera Server, Camera: {camera}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

    def serve(self) -> None:
        """Serve camera frames over ZMQ, reading one frame per request."""
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                message = self._socket.recv()
                img_size = pickle.loads(message)
                camera_read = self._camera.read(img_size)
                self._socket.send(pickle.dumps(camera_read))
            except zmq.Again:
                # Timeout occurred; the loop condition re-checks the stop event.
                print(self._timout_message)

    def stop(self) -> None:
        """Signal the server to stop serving."""
        self._stop_event.set()


class ZMQServerCameraFaster:
    def __init__(
        self,
        camera: CameraDriver,
        port: int = DEFAULT_CAMERA_PORT,
        host: str = "127.0.0.1",
    ):
        self._camera = camera
        self._context = zmq.Context()
        self._socket = self._context.socket(zmq.REP)
        addr = f"tcp://{host}:{port}"
        debug_message = f"Camera Server Binding to {addr}, Camera: {camera}"
        print(debug_message)
        self._timout_message = f"Timeout in Camera Server, Camera: {camera}"
        self._socket.bind(addr)
        self._stop_event = threading.Event()

        # Latest camera read, filled in by the background refresh thread.
        self.cam_buffer = None
        self.refresh_interval = 1 / 30  # Refresh every 1/30 second
        self.refresh_thread = threading.Thread(target=self._refresh_buffer)
        self.refresh_thread.daemon = True
        self.refresh_thread.start()

    def serve(self) -> None:
        """Serve the most recently buffered camera read over ZMQ."""
        self._socket.setsockopt(zmq.RCVTIMEO, 1000)  # Set timeout to 1000 ms
        while not self._stop_event.is_set():
            try:
                # The request payload (the pickled img_size) is ignored here.
                _ = self._socket.recv()
                if self.cam_buffer is not None:
                    self._socket.send(pickle.dumps(self.cam_buffer))
                else:
                    # NOTE: this reply is raw bytes, not a pickle, so a client
                    # that calls pickle.loads() on it will raise an error.
                    self._socket.send(b"Buffer is empty.")
            except zmq.Again:
                print(self._timout_message)

    def _refresh_buffer(self):
        """Periodically refresh the buffer with a new camera read."""
        while not self._stop_event.is_set():
            # read() is called without img_size, so the camera's default applies.
            self.cam_buffer = self._camera.read()
            time.sleep(self.refresh_interval)

    def stop(self) -> None:
        """Signal the server to stop serving."""
        self._stop_event.set()
--------------------------------------------------------------------------------
/cameras/camera.py:
--------------------------------------------------------------------------------
from typing import Optional, Protocol, Tuple

import numpy as np


class CameraDriver(Protocol):
    """Camera protocol.

    A protocol for a camera driver. This is used to abstract the camera from the rest of the code.
    """

    def read(
        self,
        img_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Read a frame from the camera.

        Args:
            img_size: The size of the image to return. If None, the original size is returned.

        Returns:
            np.ndarray: The color image.
            np.ndarray: The depth image.
25 | """ 26 | -------------------------------------------------------------------------------- /cameras/realsense_camera.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import List, Optional, Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | import pyrealsense2 as rs 7 | 8 | from .camera import CameraDriver 9 | 10 | 11 | def get_device_ids() -> List[str]: 12 | device_ids = [] 13 | for dev in rs.context().query_devices(): 14 | device_ids.append(dev.get_info(rs.camera_info.serial_number)) 15 | return device_ids 16 | 17 | 18 | class RealSenseCamera(CameraDriver): 19 | def __repr__(self) -> str: 20 | return f"RealSenseCamera(device_ids={self.device_ids})" 21 | 22 | def __init__( 23 | self, 24 | device_ids: Optional[List] = None, 25 | height: int = 480, 26 | width: int = 640, 27 | fps: int = 30, 28 | warm_start: int = 60, 29 | img_size: Optional[Tuple[int, int]] = None, 30 | ): 31 | self.height = height 32 | self.width = width 33 | self.fps = fps 34 | if device_ids is None: 35 | self.device_ids = get_device_ids() 36 | else: 37 | self.device_ids = device_ids 38 | self.img_size = img_size 39 | 40 | # Start stream 41 | print(f"Connecting to RealSense cameras ({len(self.device_ids)} found) ...") 42 | self.pipes = [] 43 | self.profiles = OrderedDict() 44 | for i, device_id in enumerate(self.device_ids): 45 | pipe = rs.pipeline() 46 | config = rs.config() 47 | 48 | config.enable_device(device_id) 49 | config.enable_stream( 50 | rs.stream.depth, self.width, self.height, rs.format.z16, self.fps 51 | ) 52 | config.enable_stream( 53 | rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps 54 | ) 55 | 56 | self.pipes.append(pipe) 57 | self.profiles[device_id] = pipe.start(config) 58 | print(f"Connected to camera {i} ({device_id}).") 59 | 60 | self.align = rs.align(rs.stream.color) 61 | 62 | # Warm start camera (realsense automatically adjusts brightness during initial frames) 
63 | for _ in range(warm_start): 64 | self._get_frames() 65 | 66 | def _get_frames(self): 67 | framesets = [pipe.wait_for_frames() for pipe in self.pipes] 68 | return [self.align.process(frameset) for frameset in framesets] 69 | 70 | def get_num_cameras(self): 71 | return len(self.device_ids) 72 | 73 | def read( 74 | self, 75 | img_size: Optional[Tuple[int, int]] = None, 76 | concatenate: bool = False, 77 | ) -> Tuple[np.ndarray, np.ndarray]: 78 | """Read a frame from the camera. 79 | 80 | Args: 81 | img_size: The size of the image to return. If None, the original size is returned. 82 | farthest: The farthest distance to map to 255. 83 | 84 | Returns: 85 | np.ndarray: The color image, shape=(H, W, 3) 86 | np.ndarray: The depth image, shape=(H, W) 87 | """ 88 | framesets = self._get_frames() 89 | num_cams = self.get_num_cameras() 90 | rgbd = np.empty([num_cams, self.height, self.width, 4], dtype=np.uint16) 91 | if self.img_size is not None: 92 | rgbd_resized = np.empty( 93 | [num_cams, self.img_size[1], self.img_size[0], 4], dtype=np.uint16 94 | ) 95 | 96 | for i, frameset in enumerate(framesets): 97 | color_frame = frameset.get_color_frame() 98 | rgbd[i, :, :, :3] = np.asanyarray(color_frame.get_data()) 99 | 100 | depth_frame = frameset.get_depth_frame() 101 | depth_frame = rs.decimation_filter(1).process(depth_frame) 102 | depth_frame = rs.disparity_transform(True).process(depth_frame) 103 | depth_frame = rs.spatial_filter().process(depth_frame) 104 | depth_frame = rs.temporal_filter().process(depth_frame) 105 | depth_frame = rs.disparity_transform(False).process(depth_frame) 106 | depth_frame = rs.hole_filling_filter().process(depth_frame) 107 | rgbd[i, :, :, 3] = np.asanyarray(depth_frame.get_data()) 108 | 109 | if self.img_size is not None: 110 | rgbd_resized[i] = cv2.resize(rgbd[i], self.img_size) 111 | 112 | if self.img_size is not None: 113 | rgbd = rgbd_resized 114 | 115 | if concatenate: 116 | image = np.concatenate(rgbd[..., :3], axis=1, dtype=np.uint8) 
117 | depth = np.concatenate(rgbd[..., -1], axis=1, dtype=np.uint8) 118 | else: 119 | image = rgbd[..., :3].astype(np.uint8) 120 | depth = rgbd[..., -1].astype(np.uint16) 121 | return image, depth 122 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | import time 5 | from typing import Any, Dict, Optional 6 | 7 | import cv2 8 | import numpy as np 9 | from natsort import natsorted 10 | 11 | from cameras.camera import CameraDriver 12 | from robots.robot import Robot 13 | 14 | 15 | class Rate: 16 | def __init__(self, rate: float): 17 | self.last = time.time() 18 | self.rate = rate 19 | 20 | def sleep(self) -> None: 21 | while self.last + 1.0 / self.rate > time.time(): 22 | time.sleep(0.0001) 23 | self.last = time.time() 24 | 25 | 26 | class EvalRobotEnv: 27 | def __init__( 28 | self, 29 | robot: Robot, 30 | traj_path: str, 31 | control_rate_hz: float, 32 | camera_dict: Optional[Dict[str, CameraDriver]] = None, 33 | ) -> None: 34 | self._robot = robot 35 | self._rate = Rate(control_rate_hz) 36 | self._camera_dict = {} if camera_dict is None else camera_dict 37 | self.traj_path = traj_path 38 | 39 | self.pkls = natsorted( 40 | glob.glob(os.path.join(self.traj_path, "*.pkl"), recursive=True) 41 | ) 42 | print("Finished reading dir", self.traj_path) 43 | print("No. of files:", len(self.pkls)) 44 | self.traj_len = len(self.pkls) 45 | self.count = 0 46 | 47 | def robot(self) -> Robot: 48 | """Get the robot object. 49 | 50 | Returns: 51 | robot: the robot object. 52 | """ 53 | return self._robot 54 | 55 | def __len__(self): 56 | # Return positive integer for batched envs. 57 | return self.traj_len 58 | 59 | def step_eef(self, eef_pose: np.ndarray) -> Dict[str, Any]: 60 | """Step the environment forward. 61 | 62 | Args: 63 | eef_pose: end effector pose command to step the environment with. 
64 | 65 | Returns: 66 | obs: observation from the environment. 67 | """ 68 | assert len(eef_pose) == self._robot.num_dofs(), f"input:{len(eef_pose)}" 69 | self._robot.command_eef_pose(eef_pose) 70 | self._rate.sleep() 71 | return self.get_obs() 72 | 73 | def step(self, joints: np.ndarray) -> Dict[str, Any]: 74 | """Step the environment forward. 75 | 76 | Args: 77 | joints: joint angles command to step the environment with. 78 | 79 | Returns: 80 | obs: observation from the environment. 81 | """ 82 | assert len(joints) == ( 83 | self._robot.num_dofs() 84 | ), f"input:{len(joints)}, robot:{self._robot.num_dofs()}" 85 | assert self._robot.num_dofs() == len(joints) 86 | self._robot.command_joint_state(joints) 87 | self._rate.sleep() 88 | return self.get_obs() 89 | 90 | def get_real_obs(self) -> Dict[str, Any]: 91 | observations = {} 92 | for name, camera in self._camera_dict.items(): 93 | image, depth = camera.read() 94 | observations[f"{name}_rgb"] = image 95 | observations[f"{name}_depth"] = depth 96 | 97 | robot_obs = self._robot.get_observations() 98 | for k, v in robot_obs.items(): 99 | observations[k] = v 100 | return observations 101 | 102 | def get_obs(self) -> Dict[str, Any]: 103 | """Get observation from the environment. 104 | 105 | Returns: 106 | obs: observation from the environment. 
107 | """ 108 | if self.count >= self.traj_len: 109 | return None 110 | pkl = self.pkls[self.count] 111 | with open(pkl, "rb") as f: 112 | observations = pickle.load(f) 113 | self.count += 1 114 | return observations 115 | 116 | 117 | class RobotEnv: 118 | def __init__( 119 | self, 120 | robot: Robot, 121 | control_rate_hz: float = 100.0, 122 | camera_dict: Optional[Dict[str, CameraDriver]] = None, 123 | show_camera_view: bool = True, 124 | save_depth: bool = True, 125 | ) -> None: 126 | self._robot = robot 127 | self._rate = Rate(control_rate_hz) 128 | print("RobotEnv: control_rate_hz", control_rate_hz) 129 | self._camera_dict = {} if camera_dict is None else camera_dict 130 | 131 | self._show_camera_view = show_camera_view 132 | if self._show_camera_view: 133 | for name in list(self._camera_dict.keys()): 134 | cv2.namedWindow(name, cv2.WINDOW_NORMAL) 135 | 136 | self._save_depth = save_depth 137 | 138 | def robot(self) -> Robot: 139 | """Get the robot object. 140 | 141 | Returns: 142 | robot: the robot object. 143 | """ 144 | return self._robot 145 | 146 | def __len__(self): 147 | # Return positive integer for batched envs. 148 | return 0 149 | 150 | def step_eef(self, eef_pose: np.ndarray) -> Dict[str, Any]: 151 | """Step the environment forward. 152 | 153 | Args: 154 | eef_pose: end effector pose command to step the environment with. 155 | 156 | Returns: 157 | obs: observation from the environment. 158 | """ 159 | assert len(eef_pose) == self._robot.num_dofs(), f"input:{len(eef_pose)}" 160 | self._robot.command_eef_pose(eef_pose) 161 | self._rate.sleep() 162 | return self.get_obs() 163 | 164 | def step(self, joints: np.ndarray) -> Dict[str, Any]: 165 | """Step the environment forward. 166 | 167 | Args: 168 | joints: joint angles command to step the environment with. 169 | 170 | Returns: 171 | obs: observation from the environment. 
172 | """ 173 | assert len(joints) == ( 174 | self._robot.num_dofs() 175 | ), f"input:{len(joints)}, robot:{self._robot.num_dofs()}" 176 | assert self._robot.num_dofs() == len(joints) 177 | self._robot.command_joint_state(joints) 178 | self._rate.sleep() 179 | return self.get_obs() 180 | 181 | def get_obs(self) -> Dict[str, Any]: 182 | """Get observation from the environment. 183 | 184 | Returns: 185 | obs: observation from the environment. 186 | """ 187 | observations = {} 188 | for name, camera in self._camera_dict.items(): 189 | image, depth = camera.read() 190 | observations[f"{name}_rgb"] = image 191 | if self._save_depth: 192 | observations[f"{name}_depth"] = depth 193 | 194 | if self._show_camera_view: 195 | depth = cv2.applyColorMap(depth, cv2.COLORMAP_JET) 196 | image_depth = cv2.hconcat([image[:, :, ::-1], depth]) 197 | cv2.imshow(name, image_depth) 198 | cv2.waitKey(1) 199 | 200 | robot_obs = self._robot.get_observations() 201 | for k, v in robot_obs.items(): 202 | observations[k] = v 203 | return observations 204 | -------------------------------------------------------------------------------- /eval_dir.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | from agents.dp_agent import BimanualDPAgent 6 | 7 | 8 | def eval_ckpts(ckpt_paths, eval_dir, save_path): 9 | mse_dict = {} 10 | 11 | if os.path.exists(save_path): 12 | with open(save_path, "rb") as f: 13 | mse_dict = pickle.load(f) 14 | print(f"Loaded previous MSE dict from {save_path}") 15 | 16 | eval_loader = None 17 | last_arg = None 18 | 19 | for ckpt_path in ckpt_paths: 20 | ckpt_name = os.path.basename(os.path.dirname(ckpt_path)) 21 | ckpt_num = os.path.basename(ckpt_path) 22 | agent = BimanualDPAgent(ckpt_path) 23 | if eval_loader is None: 24 | eval_loader = agent.dp.get_eval_loader(eval_dir) 25 | elif ( 26 | agent.dp_args["representation_type"] != last_arg["representation_type"] 27 | or 
agent.dp_args["camera_indices"] != last_arg["camera_indices"] 28 | ): 29 | eval_loader = agent.dp.get_eval_loader(eval_dir) 30 | last_arg = agent.dp_args 31 | mse, action_mse = agent.dp.eval_dir(eval_loader) 32 | if mse_dict.get(ckpt_name) is None: 33 | mse_dict[ckpt_name] = {} 34 | mse_dict[ckpt_name]["config"] = agent.dp_args 35 | 36 | mse_dict[ckpt_name][ckpt_num] = {} 37 | mse_dict[ckpt_name][ckpt_num]["mse"] = mse 38 | mse_dict[ckpt_name][ckpt_num]["action_mse"] = action_mse 39 | 40 | print(f"MSE for {ckpt_name}: {mse}") 41 | 42 | with open(save_path, "wb") as f: 43 | pickle.dump(mse_dict, f) 44 | print(f"Saved MSE dict to {save_path}") 45 | 46 | 47 | if __name__ == "__main__": 48 | args = argparse.ArgumentParser() 49 | args.add_argument( 50 | "--ckpt_path", 51 | nargs="+", 52 | type=str, 53 | default=[ 54 | "model_epoch_100.ckpt", 55 | "model_epoch_200.ckpt", 56 | "model_epoch_300.ckpt", 57 | ], 58 | ) 59 | args.add_argument( 60 | "--eval_dir", 61 | type=str, 62 | default="/split_data/data_pour_train_10", 63 | ) 64 | args.add_argument("--save_path", type=str, default=None) 65 | 66 | args = args.parse_args() 67 | 68 | if args.save_path is None: 69 | data_name = os.path.basename(args.eval_dir) 70 | args.save_path = "./eval_results/eval_{}.pkl".format(data_name) 71 | eval_ckpts(args.ckpt_path, args.eval_dir, args.save_path) 72 | -------------------------------------------------------------------------------- /inference_node.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import pickle 4 | import threading 5 | import time 6 | 7 | import numpy as np 8 | import zmq 9 | 10 | DEFAULT_INFERENCE_PORT = 4321 11 | 12 | 13 | class ZMQInferenceServer: 14 | """A ZMQ server that receives observations and replies with predicted actions.""" 15 | 16 | def __init__(self, port: int = DEFAULT_INFERENCE_PORT): 17 | self._context = zmq.Context() 18 | self._socket = self._context.socket(zmq.PAIR) 19 | 
self._socket.bind(f"tcp://*:{port}") 20 | self._stop_event = threading.Event() 21 | 22 | def get_obs(self): 23 | state_dict = None 24 | while True: 25 | try: 26 | # check for a message, this will not block 27 | message = self._socket.recv(flags=zmq.NOBLOCK) 28 | 29 | except zmq.Again: 30 | # print("observation queue exhausted") 31 | break 32 | else: 33 | state_dict = pickle.loads(message) 34 | if state_dict is None: # block until an observation is received 35 | while True: 36 | message = self._socket.recv() 37 | state_dict = pickle.loads(message) 38 | if "obs" not in state_dict and "t" not in state_dict: 39 | if "num_diffusion_iters" in state_dict: # ignore and send success 40 | self._socket.send_string("success") 41 | continue 42 | break 43 | 44 | return state_dict["obs"], state_dict["t"] 45 | 46 | def infer(self, *args, **kwargs): 47 | raise NotImplementedError 48 | 49 | def act(self, obs): 50 | raise NotImplementedError 51 | 52 | def serve(self): 53 | # self._socket.setsockopt(zmq.RCVTIMEO, 10000) # Set timeout to 10000 ms 54 | while not self._stop_event.is_set(): 55 | obs, t = self.get_obs() # get obs from the client 56 | 57 | print(f"Received observation at time {t}. 
Inference start!") 58 | pred = self.act(obs) 59 | print("Inference ended.") 60 | 61 | message = pickle.dumps({"acts": pred, "t": t}) 62 | self._socket.send(message) # send the action back to the client 63 | 64 | def stop(self) -> None: 65 | """Signal the server to stop serving.""" 66 | self._stop_event.set() 67 | 68 | 69 | class ZMQInferenceClient: 70 | """A ZMQ client that sends observations to the inference server and ensembles the returned actions.""" 71 | 72 | def __init__( 73 | self, 74 | port: int = DEFAULT_INFERENCE_PORT, 75 | host: str = "111.11.111.11", 76 | default_action=None, 77 | queue_size=32, 78 | ensemble_mode="new", 79 | act_tau=0.5, 80 | ): 81 | self._context = zmq.Context() 82 | self._socket = self._context.socket(zmq.PAIR) 83 | self._socket.connect(f"tcp://{host}:{port}") 84 | print(f"connected -- tcp://{host}:{port}") 85 | 86 | self.act_q = collections.deque(maxlen=queue_size) 87 | self.t = 0 88 | self.last_act = default_action 89 | self.ensemble_mode = ensemble_mode 90 | self.act_tau = act_tau 91 | 92 | def act(self, obs): 93 | self.t += 1 94 | final_act = self.last_act 95 | 96 | # send the observation 97 | message = pickle.dumps({"obs": obs, "t": self.t}) 98 | self._socket.send(message) 99 | 100 | # process the incoming message queue 101 | while True: 102 | try: 103 | # check for a message, this will not block 104 | message = self._socket.recv(flags=zmq.NOBLOCK) 105 | 106 | except zmq.Again: 107 | # print("action queue exhausted") 108 | break 109 | else: 110 | state_dict = pickle.loads(message) 111 | acts, pt = state_dict["acts"], state_dict["t"] 112 | 113 | while len(self.act_q) > 0 and self.act_q[0][1] < self.t: 114 | self.act_q.popleft() 115 | while pt < self.t and len(acts) > 0: 116 | pt += 1 117 | acts = acts[1:] 118 | for c_acts, ct in self.act_q: 119 | if ct == pt: 120 | c_acts.append(acts[0]) 121 | pt += 1 122 | acts = acts[1:] 123 | if len(acts) == 0: 124 | break 125 | # for 126 | # push all the new actions in 127 | for i, act in enumerate(acts): 128 | 
self.act_q.append(([act], pt + i)) 129 | 130 | # now searching for the matching time stamp 131 | while len(self.act_q) > 0: 132 | c_acts, tt = self.act_q.popleft() 133 | if tt == self.t: 134 | if self.ensemble_mode == "act": 135 | z_act = c_acts[0] 136 | for act in c_acts[1:]: 137 | z_act = z_act * self.act_tau + act * (1.0 - self.act_tau) 138 | final_act = z_act 139 | elif self.ensemble_mode == "avg": 140 | final_act = np.mean(np.array(c_acts), axis=0) 141 | elif self.ensemble_mode == "old": 142 | final_act = c_acts[0] 143 | elif self.ensemble_mode == "new": 144 | final_act = c_acts[-1] 145 | 146 | break 147 | print("action queue (dt):", [t - self.t for _, t in self.act_q]) 148 | print("action queue (size):", [len(a) for a, _ in self.act_q]) 149 | self.last_act = final_act 150 | # print("action:", final_act) 151 | return final_act 152 | 153 | 154 | if __name__ == "__main__": 155 | args = argparse.ArgumentParser() 156 | args.add_argument("--type", type=str, default="client") 157 | args.add_argument("--freq", type=float, default=10) 158 | args.add_argument("--inft", type=float, default=0.2) 159 | args = args.parse_args() 160 | 161 | action_dim = 24 162 | 163 | class DummyAgentServer(ZMQInferenceServer): 164 | def act(self, obs): # serve() calls act(), so the dummy must override it 165 | print("start inference") 166 | time.sleep(args.inft) 167 | print("stop inference") 168 | return np.zeros((16, action_dim)) 169 | 170 | if args.type == "server": 171 | server = DummyAgentServer() 172 | server.serve() 173 | 174 | elif args.type == "client": 175 | client = ZMQInferenceClient( 176 | default_action=np.zeros((action_dim,)), queue_size=32 177 | ) 178 | obs = { 179 | "img": np.zeros((4, 240, 360), dtype=np.uint16), 180 | "eef": np.zeros((24,), dtype=np.float32), 181 | } 182 | 183 | while True: 184 | time.sleep(1 / args.freq) 185 | action = client.act(obs) 186 | -------------------------------------------------------------------------------- /launch_inference_nodes.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import tyro 4 | 5 | from agents.dp_agent_zmq import BimanualDPAgentServer 6 | 7 | 8 | def boolean_string(s): 9 | if s not in {"False", "True"}: 10 | raise ValueError("Not a valid boolean string") 11 | return s == "True" 12 | 13 | 14 | @dataclass 15 | class Args: 16 | port: str = "4321" 17 | dp_ckpt_path: str = "best.ckpt" 18 | 19 | 20 | def launch_server(args: Args): 21 | server = BimanualDPAgentServer(ckpt_path=args.dp_ckpt_path, port=args.port) 22 | print(f"Starting inference server on {args.port}") 23 | 24 | print("Compiling inference") 25 | server.compile_inference() 26 | print("Done. Inference available.") 27 | server.serve() 28 | 29 | 30 | def main(args): 31 | launch_server(args) 32 | 33 | 34 | if __name__ == "__main__": 35 | main(tyro.cli(Args)) 36 | -------------------------------------------------------------------------------- /launch_nodes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from multiprocessing import Process 3 | from typing import List, Optional, Tuple 4 | 5 | import tyro 6 | 7 | from camera_node import ZMQServerCamera, ZMQServerCameraFaster 8 | from robot_node import ZMQServerRobot 9 | from robots.robot import BimanualRobot 10 | 11 | 12 | @dataclass 13 | class Args: 14 | robot: str = "bimanual_ur" 15 | hand_type: str = "" 16 | hostname: str = "127.0.0.1" 17 | robot_ip: str = "111.111.1.1" 18 | faster: bool = True 19 | cam_names: Tuple[str, ...] 
= ("435",) 20 | ability_gripper_grip_range: int = 110 21 | img_size: Optional[Tuple[int, int]] = None # (320, 240) 22 | 23 | 24 | def launch_server_cameras(port: int, camera_id: List[str], args: Args): 25 | from cameras.realsense_camera import RealSenseCamera 26 | 27 | camera = RealSenseCamera(camera_id, img_size=args.img_size) 28 | 29 | if args.faster: 30 | server = ZMQServerCameraFaster(camera, port=port, host=args.hostname) 31 | else: 32 | server = ZMQServerCamera(camera, port=port, host=args.hostname) 33 | print(f"Starting camera server on port {port}") 34 | server.serve() 35 | 36 | 37 | def launch_robot_server(port: int, args: Args): 38 | if args.robot == "ur": 39 | from robots.ur import URRobot 40 | 41 | robot = URRobot(robot_ip=args.robot_ip) 42 | elif args.robot == "bimanual_ur": 43 | from robots.ur import URRobot 44 | 45 | if args.hand_type == "ability": 46 | # 6 DoF Ability Hand 47 | # robot_l - right hand; robot_r - left hand 48 | _robot_l = URRobot( 49 | robot_ip="111.111.1.3", 50 | no_gripper=False, 51 | gripper_type="ability", 52 | grip_range=args.ability_gripper_grip_range, 53 | port_idx=1, 54 | ) 55 | _robot_r = URRobot( 56 | robot_ip="111.111.2.3", 57 | no_gripper=False, 58 | gripper_type="ability", 59 | grip_range=args.ability_gripper_grip_range, 60 | port_idx=2, 61 | ) 62 | else: 63 | # Robotiq gripper 64 | _robot_l = URRobot(robot_ip="111.111.1.3", no_gripper=False) 65 | _robot_r = URRobot(robot_ip="111.111.2.3", no_gripper=False) 66 | robot = BimanualRobot(_robot_l, _robot_r) 67 | else: 68 | raise NotImplementedError(f"Robot {args.robot} not implemented") 69 | server = ZMQServerRobot(robot, port=port, host=args.hostname) 70 | print(f"Starting robot server on port {port}") 71 | server.serve() 72 | 73 | 74 | CAM_IDS = { 75 | "435": "000000000000", 76 | } 77 | 78 | 79 | def create_camera_server(args: Args) -> Process: 80 | ids = [CAM_IDS[name] for name in args.cam_names] 81 | camera_port = 5000 82 | # start a single python process for all 
cameras 83 | print(f"Launching cameras {ids} on port {camera_port}") 84 | server = Process(target=launch_server_cameras, args=(camera_port, ids, args)) 85 | return server 86 | 87 | 88 | def main(args): 89 | camera_server = create_camera_server(args) 90 | print("Starting camera server process") 91 | camera_server.start() 92 | 93 | launch_robot_server(6000, args) 94 | 95 | 96 | if __name__ == "__main__": 97 | main(tyro.cli(Args)) 98 | -------------------------------------------------------------------------------- /learning/dp/.gitignore: -------------------------------------------------------------------------------- 1 | model 2 | eval 3 | __pycache__ 4 | *.sh 5 | *.yml 6 | eval 7 | runs -------------------------------------------------------------------------------- /learning/dp/data_processing.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | import pickle 4 | 5 | import cv2 6 | import natsort 7 | import numpy as np 8 | 9 | 10 | def from_pickle(path, load_img = True, num_cam = 3): 11 | with open(path, "rb") as f: 12 | data = pickle.load(f) 13 | if "base_rgb" not in data and load_img: 14 | rgb = [] 15 | for i in range(num_cam): 16 | rgb_path = path.replace(".pkl", f"-{i}.png") 17 | if os.path.exists(rgb_path): 18 | rgb.append(cv2.imread(rgb_path)) 19 | data["base_rgb"] = np.stack(rgb, axis=0) 20 | 21 | return data 22 | 23 | # Get the trajectory data from the given directory 24 | def iterate(path, workers=32, load_img=True, num_cam=3): 25 | dir = os.listdir(path) 26 | dir = [d for d in dir if d.endswith(".pkl")] 27 | dir = natsort.natsorted(dir) 28 | dirname = os.path.basename(path) 29 | root_path = "./mask_cache" 30 | data = [] 31 | with concurrent.futures.ThreadPoolExecutor(max_workers = workers) as executor: 32 | futures = {executor.submit(from_pickle, os.path.join(path, file), load_img, num_cam): (i, file) for i, file in enumerate(dir)} 33 | for future in futures: 34 | try: 35 | 
i, file = futures[future] 36 | d = future.result() 37 | if not d["activated"]["l"] and not d["activated"]["r"]: 38 | continue 39 | basedirfile = os.path.join(dirname, file) 40 | maskfile = os.path.join(root_path, basedirfile) 41 | if os.path.exists(maskfile): 42 | d["mask"] = from_pickle(maskfile) 43 | d["mask_path"] = maskfile 44 | d["file_path"] = os.path.join(path, file) 45 | data.append(d) 46 | except Exception: 47 | print(f"Failed to load {file}") 48 | pass 49 | return data 50 | 51 | 52 | def get_latest(path): 53 | dir = os.listdir(path) 54 | dir = natsort.natsorted(dir) 55 | return from_pickle(os.path.join(path, dir[-1])) 56 | 57 | # Get all trajectory directories from the given path 58 | def get_epi_dir(path, traj_type, prefix=None): 59 | dir = natsort.natsorted(os.listdir(path)) 60 | if prefix is not None: 61 | prefixs = prefix.split("-") 62 | 63 | new_dir = [] 64 | for d in dir: 65 | if os.path.isdir(os.path.join(path, d)): 66 | matched = False 67 | if prefix is None: 68 | matched = True 69 | else: 70 | for prefix in prefixs: 71 | if d.startswith(prefix): 72 | matched = True 73 | if matched: 74 | new_dir.append(d) 75 | 76 | print("All Directories") 77 | print(new_dir) 78 | print("==========") 79 | dir = new_dir 80 | if traj_type == "plain": 81 | dir = [ 82 | d 83 | for d in dir 84 | if not d.endswith("failed") 85 | and not d.endswith("ood") 86 | and not d.endswith("ikbad") 87 | and not d.endswith("heated") 88 | and not d.endswith("stop") 89 | and not d.endswith("hard") 90 | ] 91 | elif traj_type == "all": 92 | dir = dir 93 | else: 94 | raise NotImplementedError 95 | dir_list = [ 96 | os.path.join(path, d) for d in dir if os.path.isdir(os.path.join(path, d)) 97 | ] 98 | return dir_list 99 | 100 | -------------------------------------------------------------------------------- /learning/dp/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | 
def create_sample_indices( 9 | episode_ends: np.ndarray, 10 | sequence_length: int, 11 | pad_before: int = 0, 12 | pad_after: int = 0, 13 | ): 14 | indices = list() 15 | for i in range(len(episode_ends)): 16 | start_idx = 0 17 | if i > 0: 18 | start_idx = episode_ends[i - 1] 19 | end_idx = episode_ends[i] 20 | episode_length = end_idx - start_idx 21 | 22 | min_start = -pad_before 23 | max_start = episode_length - sequence_length + pad_after 24 | 25 | # range stops one idx before end 26 | for idx in range(min_start, max_start + 1): 27 | buffer_start_idx = max(idx, 0) + start_idx 28 | buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx 29 | start_offset = buffer_start_idx - (idx + start_idx) 30 | end_offset = (idx + sequence_length + start_idx) - buffer_end_idx 31 | sample_start_idx = 0 + start_offset 32 | sample_end_idx = sequence_length - end_offset 33 | indices.append( 34 | [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx] 35 | ) 36 | indices = np.array(indices) 37 | return indices 38 | 39 | 40 | def sample_sequence( 41 | train_data, 42 | sequence_length, 43 | buffer_start_idx, 44 | buffer_end_idx, 45 | sample_start_idx, 46 | sample_end_idx, 47 | ): 48 | result = dict() 49 | for key, input_arr in train_data.items(): 50 | sample = input_arr[buffer_start_idx:buffer_end_idx] 51 | data = sample 52 | if (sample_start_idx > 0) or (sample_end_idx < sequence_length): 53 | data = np.zeros( 54 | shape=(sequence_length,) + input_arr.shape[1:], dtype=input_arr.dtype 55 | ) 56 | if sample_start_idx > 0: 57 | data[:sample_start_idx] = sample[0] 58 | if sample_end_idx < sequence_length: 59 | data[sample_end_idx:] = sample[-1] 60 | data[sample_start_idx:sample_end_idx] = sample 61 | result[key] = data 62 | return result 63 | 64 | 65 | # normalize data 66 | def get_data_stats(data): 67 | data = data.reshape(-1, data.shape[-1]) 68 | stats = {"min": np.min(data, axis=0), "max": np.max(data, axis=0)} 69 | if np.any(stats["max"] > 1e5) or 
np.any(stats["min"] < -1e5): 70 | raise ValueError("data out of range") 71 | return stats 72 | 73 | 74 | def normalize_data(data, stats): 75 | # normalize to [0,1] 76 | ndata = (data - stats["min"]) / (stats["max"] - stats["min"] + 1e-8) 77 | # normalize to [-1, 1] 78 | ndata = ndata * 2 - 1 79 | return ndata 80 | 81 | 82 | def unnormalize_data(ndata, stats): 83 | ndata = (ndata + 1) / 2 84 | data = ndata * (stats["max"] - stats["min"] + 1e-8) + stats["min"] 85 | return data 86 | 87 | 88 | class MemmapLoader: 89 | def __init__(self, path): 90 | with open(os.path.join(path, "metadata.pkl"), "rb") as f: 91 | meta_data = pickle.load(f) 92 | 93 | print("Meta Data:", meta_data) 94 | self.fps = {} 95 | 96 | self.length = None 97 | for key, (shape, dtype) in meta_data.items(): 98 | self.fps[key] = np.memmap( 99 | os.path.join(path, key + ".dat"), dtype=dtype, shape=shape, mode="r" 100 | ) 101 | if self.length is None: 102 | self.length = shape[0] 103 | else: 104 | assert self.length == shape[0] 105 | 106 | def __getitem__(self, index): 107 | rets = {} 108 | for key in self.fps.keys(): 109 | value = self.fps[key] 110 | value = value[index] 111 | value_cp = np.empty(dtype=value.dtype, shape=value.shape) 112 | value_cp[:] = value 113 | rets[key] = value_cp 114 | return rets 115 | 116 | def __len__(self): 117 | return self.length 118 | 119 | 120 | # dataset 121 | class Dataset(torch.utils.data.Dataset): 122 | def __init__( 123 | self, 124 | data: dict, 125 | representation_type: list, 126 | pred_horizon: int, 127 | obs_horizon: int, 128 | action_horizon: int, 129 | stats: dict = None, 130 | transform=None, 131 | get_img=None, 132 | load_img: bool = False, 133 | hand_grip_range: int = 110, 134 | binarize_touch: bool = False, 135 | state_noise: float = 0.0, 136 | ): 137 | self.state_noise = state_noise 138 | self.memmap_loader = None 139 | if "memmap_loader_path" in data.keys(): 140 | self.memmap_loader = MemmapLoader(data["memmap_loader_path"]) 141 | self.representation_type 
= representation_type 142 | self.transform = transform 143 | self.get_img = get_img 144 | self.load_img = load_img 145 | 146 | print("Representation type: ", representation_type) 147 | except_img_representation_type = representation_type.copy() 148 | 149 | if "img" in representation_type: 150 | train_image_data = data["data"]["img"][:] 151 | except_img_representation_type.remove("img") 152 | 153 | train_data = { 154 | rt: data["data"][rt][:, :] for rt in except_img_representation_type 155 | } 156 | train_data["action"] = data["data"]["action"][:] 157 | episode_ends = data["meta"]["episode_ends"][:] 158 | 159 | # compute start and end of each state-action sequence 160 | # also handles padding 161 | indices = create_sample_indices( 162 | episode_ends=episode_ends, 163 | sequence_length=pred_horizon, 164 | pad_before=obs_horizon - 1, 165 | pad_after=action_horizon - 1, 166 | ) 167 | 168 | normalized_train_data = dict() 169 | 170 | # compute statistics and normalized data to [-1,1] 171 | if stats is None: 172 | stats = dict() 173 | for key, data in train_data.items(): 174 | stats[key] = get_data_stats(data) 175 | 176 | # overwrite the min max of hand info 177 | left_hand_indices = np.array([6, 7, 8, 9, 10, 11]) 178 | right_hand_indices = np.array([18, 19, 20, 21, 22, 23]) 179 | 180 | # hand joint_position range 181 | hand_upper_ranges = np.array( 182 | [hand_grip_range] * 4 + [90, 120], dtype=np.float32 183 | ) 184 | hand_lower_ranges = np.array([5, 5, 5, 5, 5, 5], dtype=np.float32) 185 | 186 | if "pos" in representation_type: 187 | stats["pos"]["min"][left_hand_indices] = hand_lower_ranges 188 | stats["pos"]["min"][right_hand_indices] = hand_lower_ranges 189 | stats["pos"]["max"][left_hand_indices] = hand_upper_ranges 190 | stats["pos"]["max"][right_hand_indices] = hand_upper_ranges 191 | elif "hand_pos" in representation_type: 192 | stats["hand_pos"]["min"][range(6)] = hand_lower_ranges 193 | stats["hand_pos"]["min"][range(6, 12)] = hand_lower_ranges 194 | 
stats["hand_pos"]["max"][range(6)] = hand_upper_ranges 195 | stats["hand_pos"]["max"][range(6, 12)] = hand_upper_ranges 196 | 197 | # hand action is normalized to [0,1] 198 | stats["action"]["min"][left_hand_indices] = 0.0 199 | stats["action"]["max"][left_hand_indices] = 1.0 200 | stats["action"]["min"][right_hand_indices] = 0.0 201 | stats["action"]["max"][right_hand_indices] = 1.0 202 | 203 | for key, data in train_data.items(): 204 | if key == "touch" and binarize_touch: 205 | normalized_train_data[key] = ( 206 | data # don't normalize if binarize touch in model 207 | ) 208 | else: 209 | normalized_train_data[key] = normalize_data(data, stats[key]) 210 | 211 | # images are already normalized 212 | if "img" in representation_type: 213 | normalized_train_data["img"] = train_image_data 214 | 215 | self.indices = indices 216 | self.stats = stats 217 | self.normalized_train_data = normalized_train_data 218 | self.pred_horizon = pred_horizon 219 | self.action_horizon = action_horizon 220 | self.obs_horizon = obs_horizon 221 | self.binarize_touch = binarize_touch 222 | 223 | def __len__(self): 224 | return len(self.indices) 225 | 226 | def read_img(self, image_paths, idx): 227 | if self.memmap_loader is not None: 228 | # using memmap loader 229 | indices = range(idx, idx + self.obs_horizon) 230 | data = self.memmap_loader[indices] 231 | data = [ 232 | {"base_rgb": data["base_rgb"][i], "base_depth": data["base_depth"][i]} 233 | for i in range(data["base_rgb"].shape[0]) 234 | ] 235 | else: 236 | # not using memmap loader and loading images while training 237 | data = [pickle.load(open(image_path, "rb")) for image_path in image_paths] 238 | imgs = self.get_img(data) 239 | return imgs 240 | 241 | def __getitem__(self, idx): 242 | # get the start/end indices for this datapoint 243 | ( 244 | buffer_start_idx, 245 | buffer_end_idx, 246 | sample_start_idx, 247 | sample_end_idx, 248 | ) = self.indices[idx] 249 | 250 | # get normalized data using these indices 251 | nsample = 
sample_sequence( 252 | train_data=self.normalized_train_data, 253 | sequence_length=self.pred_horizon, 254 | buffer_start_idx=buffer_start_idx, 255 | buffer_end_idx=buffer_end_idx, 256 | sample_start_idx=sample_start_idx, 257 | sample_end_idx=sample_end_idx, 258 | ) 259 | 260 | for k in self.representation_type: 261 | # discard unused observations 262 | nsample[k] = nsample[k][: self.obs_horizon] 263 | if k == "img": 264 | if not self.load_img: 265 | nsample["img"] = self.read_img(nsample["img"], idx) 266 | else: 267 | nsample["img"] = torch.tensor( 268 | nsample["img"].astype(np.float32), dtype=torch.float32 269 | ) 270 | nsample_shape = nsample["img"].shape 271 | # transform the img 272 | nsample["img"] = nsample["img"].reshape( 273 | nsample_shape[0] * nsample_shape[1], *nsample_shape[2:] 274 | ) # (Batch * num_cam, Channel, Height, Width) 275 | nsample["img"] = self.transform(nsample["img"]) 276 | nsample["img"] = nsample["img"].reshape(nsample_shape[:3] + (216, 288)) # (Batch, num_cam, Channel, Height, Width) 277 | 278 | else: 279 | nsample[k] = torch.tensor(nsample[k], dtype=torch.float32) 280 | if self.state_noise > 0.0: 281 | # add noise to the state 282 | nsample[k] = nsample[k] + torch.randn_like(nsample[k]) * self.state_noise 283 | nsample["action"] = torch.tensor(nsample["action"], dtype=torch.float32) 284 | return nsample 285 | -------------------------------------------------------------------------------- /learning/dp/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pprint 4 | import random 5 | import string 6 | import tempfile 7 | import time 8 | from copy import copy 9 | from datetime import datetime 10 | 11 | import absl.flags 12 | import numpy as np 13 | import quaternion 14 | import wandb 15 | from absl import logging 16 | from ml_collections import ConfigDict 17 | from ml_collections.config_dict import config_dict 18 | from ml_collections.config_flags import 
config_flags 19 | 20 | 21 | def save_args(args, output_dir): 22 | args_dict = vars(args) 23 | args_str = "\n".join(f"{key}: {value}" for key, value in args_dict.items()) 24 | with open(os.path.join(output_dir, "args_log.txt"), "w") as file: 25 | file.write(args_str) 26 | with open(os.path.join(output_dir, "args_log.json"), "w") as json_file: 27 | json.dump(args_dict, json_file, indent=4) 28 | 29 | 30 | def generate_random_string(length=4, characters=string.ascii_letters + string.digits): 31 | """ 32 | Generate a random string of the specified length using the given characters. 33 | 34 | :param length: The length of the random string (default is 4). 35 | :param characters: The characters to choose from when generating the string 36 | (default is uppercase letters, lowercase letters, and digits). 37 | :return: A random string of the specified length. 38 | """ 39 | return "".join(random.choice(characters) for _ in range(length)) 40 | 41 | 42 | def get_eef_delta(eef_pose, eef_pose_target): 43 | pos_delta = eef_pose_target[:3] - eef_pose[:3] 44 | # axis angle to quaternion 45 | ee_rot = quaternion.from_rotation_vector(eef_pose[3:]) 46 | ee_rot_target = quaternion.from_rotation_vector(eef_pose_target[3:]) 47 | # calculate the quaternion difference 48 | rot_delta = quaternion.as_rotation_vector(ee_rot_target * ee_rot.inverse()) 49 | return np.concatenate((pos_delta, rot_delta)) 50 | 51 | 52 | class Timer(object): 53 | def __init__(self): 54 | self._time = None 55 | 56 | def __enter__(self): 57 | self._start_time = time.time() 58 | return self 59 | 60 | def __exit__(self, exc_type, exc_value, exc_tb): 61 | self._time = time.time() - self._start_time 62 | 63 | def __call__(self): 64 | return self._time 65 | 66 | 67 | class WandBLogger(object): 68 | @staticmethod 69 | def get_default_config(updates=None): 70 | config = ConfigDict() 71 | config.mode = "online" 72 | config.project = "hato" 73 | config.entity = "user" 74 | config.output_dir = "." 
75 | config.exp_name = str(datetime.now())[:19].replace(" ", "_") 76 | config.random_delay = 0.5 77 | config.experiment_id = config_dict.placeholder(str) 78 | config.anonymous = config_dict.placeholder(str) 79 | config.notes = config_dict.placeholder(str) 80 | config.time = str(datetime.now())[:19].replace(" ", "_") 81 | 82 | if updates is not None: 83 | config.update(ConfigDict(updates).copy_and_resolve_references()) 84 | return config 85 | 86 | def __init__(self, config, variant, prefix=None): 87 | self.config = self.get_default_config(config) 88 | 89 | for key, val in sorted(self.config.items()): 90 | if type(val) != str: 91 | continue 92 | new_val = _parse(val, variant) 93 | if val != new_val: 94 | logging.info( 95 | "processing configs: {}: {} => {}".format(key, val, new_val) 96 | ) 97 | setattr(self.config, key, new_val) 98 | 99 | output = flatten_config_dict(self.config, prefix=prefix) 100 | variant.update(output) 101 | 102 | if self.config.output_dir == "": 103 | self.config.output_dir = tempfile.mkdtemp() 104 | 105 | output = flatten_config_dict(self.config, prefix=prefix) 106 | variant.update(output) 107 | 108 | self._variant = copy(variant) 109 | 110 | logging.info( 111 | "wandb logging with hyperparameters: \n{}".format( 112 | pprint.pformat( 113 | ["{}: {}".format(key, val) for key, val in self.variant.items()] 114 | ) 115 | ) 116 | ) 117 | 118 | if self.config.random_delay > 0: 119 | time.sleep(np.random.uniform(0.1, 0.1 + self.config.random_delay)) 120 | 121 | self.run = wandb.init( 122 | entity=self.config.entity, 123 | reinit=True, 124 | config=self._variant, 125 | project=self.config.project, 126 | dir=self.config.output_dir, 127 | name=self.config.exp_name, 128 | anonymous=self.config.anonymous, 129 | monitor_gym=False, 130 | notes=self.config.notes, 131 | settings=wandb.Settings( 132 | start_method="thread", 133 | _disable_stats=True, 134 | ), 135 | mode=self.config.mode, 136 | ) 137 | 138 | self.logging_step = 0 139 | 140 | def log(self, *args, 
**kwargs): 141 | self.run.log(*args, **kwargs, step=self.logging_step) 142 | 143 | def step(self): 144 | self.logging_step += 1 145 | 146 | @property 147 | def experiment_id(self): 148 | return self.config.experiment_id 149 | 150 | @property 151 | def variant(self): 152 | return self._variant 153 | 154 | @property 155 | def output_dir(self): 156 | return self.config.output_dir 157 | 158 | 159 | def define_flags_with_default(**kwargs): 160 | for key, val in kwargs.items(): 161 | if isinstance(val, ConfigDict): 162 | config_flags.DEFINE_config_dict(key, val) 163 | elif isinstance(val, bool): 164 | # Note that True and False are instances of int. 165 | absl.flags.DEFINE_bool(key, val, "automatically defined flag") 166 | elif isinstance(val, int): 167 | absl.flags.DEFINE_integer(key, val, "automatically defined flag") 168 | elif isinstance(val, float): 169 | absl.flags.DEFINE_float(key, val, "automatically defined flag") 170 | elif isinstance(val, str): 171 | absl.flags.DEFINE_string(key, val, "automatically defined flag") 172 | else: 173 | raise ValueError("Incorrect value type") 174 | return kwargs 175 | 176 | 177 | def _parse(s, variant): 178 | orig_s = copy(s) 179 | final_s = [] 180 | 181 | while len(s) > 0: 182 | indx = s.find("{") 183 | if indx == -1: 184 | final_s.append(s) 185 | break 186 | final_s.append(s[:indx]) 187 | s = s[indx + 1 :] 188 | indx = s.find("}") 189 | assert indx != -1, "can't find the matching right bracket for {}".format(orig_s) 190 | final_s.append(str(variant[s[:indx]])) 191 | s = s[indx + 1 :] 192 | 193 | return "".join(final_s) 194 | 195 | 196 | def get_user_flags(flags, flags_def): 197 | output = {} 198 | for key in sorted(flags_def): 199 | val = getattr(flags, key) 200 | if isinstance(val, ConfigDict): 201 | flatten_config_dict(val, prefix=key, output=output) 202 | else: 203 | output[key] = val 204 | 205 | return output 206 | 207 | 208 | def flatten_config_dict(config, prefix=None, output=None): 209 | if output is None: 210 | output = 
{} 211 | for key, val in sorted(config.items()): 212 | if prefix is not None: 213 | next_prefix = "{}.{}".format(prefix, key) 214 | else: 215 | next_prefix = key 216 | if isinstance(val, ConfigDict): 217 | flatten_config_dict(val, prefix=next_prefix, output=output) 218 | else: 219 | output[next_prefix] = val 220 | return output 221 | 222 | 223 | def to_config_dict(flattened): 224 | config = config_dict.ConfigDict() 225 | for key, val in flattened.items(): 226 | c_config = config 227 | ks = key.split(".") 228 | for k in ks[:-1]: 229 | if k not in c_config: 230 | c_config[k] = config_dict.ConfigDict() 231 | c_config = c_config[k] 232 | c_config[ks[-1]] = val 233 | return config.to_dict() 234 | 235 | 236 | def prefix_metrics(metrics, prefix): 237 | return {"{}/{}".format(prefix, key): value for key, value in metrics.items()} 238 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # keep in alphabetical order 2 | diffusers 3 | ml-collections 4 | natsort 5 | numpy 6 | numpy-quaternion 7 | tensorboard 8 | tyro -------------------------------------------------------------------------------- /robot_node.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import threading 3 | from typing import Any, Dict 4 | 5 | import numpy as np 6 | import zmq 7 | 8 | from robots.robot import Robot 9 | 10 | DEFAULT_ROBOT_PORT = 6000 11 | 12 | 13 | class ZMQServerRobot: 14 | def __init__( 15 | self, 16 | robot: Robot, 17 | port: int = DEFAULT_ROBOT_PORT, 18 | host: str = "127.0.0.1", 19 | ): 20 | self._robot = robot 21 | self._context = zmq.Context() 22 | self._socket = self._context.socket(zmq.REP) 23 | addr = f"tcp://{host}:{port}" 24 | debug_message = f"Robot Server Binding to {addr}, Robot: {robot}" 25 | print(debug_message) 26 | self._timout_message = f"Timeout in Robot Server, Robot: {robot}" 27 | 
self._socket.bind(addr) 28 | self._stop_event = threading.Event() 29 | 30 | def serve(self) -> None: 31 | """Serve the leader robot state over ZMQ.""" 32 | self._socket.setsockopt(zmq.RCVTIMEO, 1000) # Set timeout to 1000 ms 33 | while not self._stop_event.is_set(): 34 | try: 35 | # Wait for next request from client 36 | message = self._socket.recv() 37 | request = pickle.loads(message) 38 | 39 | # Call the appropriate method based on the request 40 | method = request.get("method") 41 | args = request.get("args", {}) 42 | result: Any 43 | if method == "num_dofs": 44 | result = self._robot.num_dofs() 45 | elif method == "get_joint_state": 46 | result = self._robot.get_joint_state() 47 | elif method == "command_joint_state": 48 | result = self._robot.command_joint_state(**args) 49 | elif method == "command_eef_pose": 50 | result = self._robot.command_eef_pose(**args) 51 | elif method == "get_observations": 52 | result = self._robot.get_observations() 53 | else: 54 | result = {"error": "Invalid method"} 55 | print(result) 56 | raise NotImplementedError( 57 | f"Invalid method: {method}, {args, result}" 58 | ) 59 | 60 | self._socket.send(pickle.dumps(result)) 61 | except zmq.Again: 62 | print(self._timout_message) 63 | # Timeout occurred, check if the stop event is set 64 | 65 | def stop(self) -> None: 66 | """Signal the server to stop serving.""" 67 | self._stop_event.set() 68 | 69 | 70 | class ZMQClientRobot(Robot): 71 | """A class representing a ZMQ client for a leader robot.""" 72 | 73 | def __init__(self, port: int = DEFAULT_ROBOT_PORT, host: str = "127.0.0.1"): 74 | self._context = zmq.Context() 75 | self._socket = self._context.socket(zmq.REQ) 76 | self._socket.connect(f"tcp://{host}:{port}") 77 | 78 | def num_dofs(self) -> int: 79 | """Get the number of joints in the robot. 80 | 81 | Returns: 82 | int: The number of joints in the robot. 
83 | """ 84 | request = {"method": "num_dofs"} 85 | send_message = pickle.dumps(request) 86 | self._socket.send(send_message) 87 | result = pickle.loads(self._socket.recv()) 88 | return result 89 | 90 | def get_joint_state(self) -> np.ndarray: 91 | """Get the current state of the leader robot. 92 | 93 | Returns: 94 | T: The current state of the leader robot. 95 | """ 96 | request = {"method": "get_joint_state"} 97 | send_message = pickle.dumps(request) 98 | self._socket.send(send_message) 99 | result = pickle.loads(self._socket.recv()) 100 | return result 101 | 102 | def command_joint_state(self, joint_state: np.ndarray) -> None: 103 | """Command the leader robot to the given state. 104 | 105 | Args: 106 | joint_state (T): The state to command the leader robot to. 107 | """ 108 | request = { 109 | "method": "command_joint_state", 110 | "args": {"joint_state": joint_state}, 111 | } 112 | send_message = pickle.dumps(request) 113 | self._socket.send(send_message) 114 | result = pickle.loads(self._socket.recv()) 115 | return result 116 | 117 | def command_eef_pose(self, eef_pose: np.ndarray) -> None: 118 | """Command the leader robot to the given state. 119 | 120 | Args: 121 | eef_pose: end effector pose command to step the environment with. 122 | """ 123 | request = { 124 | "method": "command_eef_pose", 125 | "args": {"eef_pose": eef_pose}, 126 | } 127 | send_message = pickle.dumps(request) 128 | self._socket.send(send_message) 129 | result = pickle.loads(self._socket.recv()) 130 | return result 131 | 132 | def get_observations(self) -> Dict[str, np.ndarray]: 133 | """Get the current observations of the leader robot. 134 | 135 | Returns: 136 | Dict[str, np.ndarray]: The current observations of the leader robot. 
137 | """ 138 | request = {"method": "get_observations"} 139 | send_message = pickle.dumps(request) 140 | self._socket.send(send_message) 141 | result = pickle.loads(self._socket.recv()) 142 | return result 143 | -------------------------------------------------------------------------------- /robots/ability_gripper.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import struct 3 | import threading 4 | import time 5 | 6 | import numpy as np 7 | import serial 8 | from serial.tools import list_ports 9 | 10 | 11 | ## Send Miscellaneous Command to Ability Hand 12 | def create_misc_msg(cmd): 13 | barr = [] 14 | barr.append((struct.pack(" 150: 167 | need_reset = True 168 | 169 | ## Extract Touch Data if Available 170 | if reply_len == 71: 171 | ## Extract Data two at a time 172 | for i in range(0, 15): 173 | dual_data = data[(i * 3) + 24 : ((i + 1) * 3) + 24] 174 | data1 = struct.unpack("> 4 178 | self.touch[i * 2] = int(data1) 179 | self.touch[(i * 2) + 1] = int(data2) 180 | if data1 > 4096 or data2 > 4096: 181 | need_reset = True 182 | else: 183 | need_reset = True 184 | else: 185 | need_reset = True 186 | 187 | if need_reset: 188 | self.ser.reset_input_buffer() 189 | self.reset_count += 1 190 | need_reset = False 191 | self.pos = self.prev_pos.copy() 192 | self.touch = self.prev_touch.copy() 193 | 194 | self.prev_pos = self.pos.copy() 195 | self.prev_touch = self.touch.copy() 196 | 197 | self.total_count += 1 198 | 199 | def _process_pos_cmd(self, pos_cmd): 200 | """ 201 | pos_cmd: a list of floats [0, 1] normalized command from external input devices 202 | """ 203 | assert len(pos_cmd) == 6 204 | positions = [1] * 6 205 | for i in range(6): 206 | positions[i] = self.lower_ranges[i] + pos_cmd[i] * ( 207 | self.upper_ranges[i] - self.lower_ranges[i] 208 | ) 209 | if i == 5: 210 | # invert thumb rotator 211 | positions[i] = -positions[i] 212 | return positions 213 | 214 | def _pos_to_cmd(self, pos): 215 | """ 
216 | pos: desired hand degrees 217 | """ 218 | assert len(pos) == 6 219 | cmd = [0] * 6 220 | for i in range(6): 221 | if i == 5: 222 | pos[i] = -pos[i] 223 | cmd[i] = (pos[i] - self.lower_ranges[i]) / ( 224 | self.upper_ranges[i] - self.lower_ranges[i] 225 | ) 226 | return cmd 227 | 228 | def _generate_tx(self, positions): 229 | """ 230 | Position control mode - reply format must be one of the 3 variants (0x10, 0x11, 0x12) 231 | """ 232 | txBuf = [] 233 | 234 | ## Address in byte 0 235 | txBuf.append((struct.pack("> 8) & 0xFF))[0]) 245 | 246 | ## calculate checksum 247 | cksum = 0 248 | for b in txBuf: 249 | cksum = cksum + b 250 | cksum = (-cksum) & 0xFF 251 | txBuf.append((struct.pack(" 1000 means contact 286 | return self.last_pos_msg + self.pos + self.touch 287 | 288 | def get_current_touch(self): 289 | return self.touch 290 | 291 | def get_current_position(self): 292 | # pos readings 293 | return self.pos 294 | -------------------------------------------------------------------------------- /robots/robot.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Dict, Protocol 3 | 4 | import numpy as np 5 | 6 | 7 | class Robot(Protocol): 8 | """Robot protocol. 9 | 10 | A protocol for a robot that can be controlled. 11 | """ 12 | 13 | @abstractmethod 14 | def num_dofs(self) -> int: 15 | """Get the number of joints of the robot. 16 | 17 | Returns: 18 | int: The number of joints of the robot. 19 | """ 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def get_joint_state(self) -> np.ndarray: 24 | """Get the current state of the leader robot. 25 | 26 | Returns: 27 | T: The current state of the leader robot. 28 | """ 29 | raise NotImplementedError 30 | 31 | @abstractmethod 32 | def command_joint_state(self, joint_state: np.ndarray) -> None: 33 | """Command the leader robot to a given state. 
34 | 35 | Args: 36 | joint_state (np.ndarray): The state to command the leader robot to. 37 | """ 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def command_eef_pose(self, eef_pose: np.ndarray) -> None: 42 | """Command the leader robot to a given state. 43 | 44 | Args: 45 | eef_pose (np.ndarray): The EEF pose to command the leader robot to. 46 | """ 47 | raise NotImplementedError 48 | 49 | @abstractmethod 50 | def get_observations(self) -> Dict[str, np.ndarray]: 51 | """Get the current observations of the robot. 52 | 53 | This is to extract all the information that is available from the robot, 54 | such as joint positions, joint velocities, etc. This may also include 55 | information from additional sensors, such as cameras, force sensors, etc. 56 | 57 | Returns: 58 | Dict[str, np.ndarray]: A dictionary of observations. 59 | """ 60 | raise NotImplementedError 61 | 62 | 63 | class BimanualRobot(Robot): 64 | def __init__(self, robot_l: Robot, robot_r: Robot): 65 | self._robot_l = robot_l 66 | self._robot_r = robot_r 67 | 68 | def num_dofs(self) -> int: 69 | return self._robot_l.num_dofs() + self._robot_r.num_dofs() 70 | 71 | def get_joint_state(self) -> np.ndarray: 72 | return np.concatenate( 73 | (self._robot_l.get_joint_state(), self._robot_r.get_joint_state()) 74 | ) 75 | 76 | def command_joint_state(self, joint_state: np.ndarray) -> None: 77 | self._robot_l.command_joint_state(joint_state[: self._robot_l.num_dofs()]) 78 | self._robot_r.command_joint_state(joint_state[self._robot_l.num_dofs() :]) 79 | 80 | def command_eef_pose(self, eef_pose: np.ndarray) -> None: 81 | self._robot_l.command_eef_pose(eef_pose[: self._robot_l.num_dofs()]) 82 | self._robot_r.command_eef_pose(eef_pose[self._robot_l.num_dofs() :]) 83 | 84 | def get_observations(self) -> Dict[str, np.ndarray]: 85 | l_obs = self._robot_l.get_observations() 86 | r_obs = self._robot_r.get_observations() 87 | assert l_obs.keys() == r_obs.keys() 88 | return_obs = {} 89 | for k in l_obs.keys(): 
90 | try: 91 | return_obs[k] = np.concatenate((l_obs[k], r_obs[k])) 92 | except Exception as e: 93 | print(e) 94 | print(k) 95 | print(l_obs[k]) 96 | print(r_obs[k]) 97 | raise RuntimeError() 98 | 99 | return return_obs 100 | -------------------------------------------------------------------------------- /robots/robotiq_gripper.py: -------------------------------------------------------------------------------- 1 | """Module to control Robotiq's grippers - tested with HAND-E. 2 | 3 | Taken from https://github.com/githubuser0xFFFF/py_robotiq_gripper/blob/master/src/robotiq_gripper.py 4 | """ 5 | 6 | import socket 7 | import threading 8 | import time 9 | from enum import Enum 10 | from typing import OrderedDict, Tuple, Union 11 | 12 | 13 | class RobotiqGripper: 14 | """Communicates with the gripper directly, via socket with string commands, leveraging string names for variables.""" 15 | 16 | # WRITE VARIABLES (CAN ALSO READ) 17 | ACT = ( 18 | "ACT" # act : activate (1 while activated, can be reset to clear fault status) 19 | ) 20 | GTO = ( 21 | "GTO" # gto : go to (will perform go to with the actions set in pos, for, spe) 22 | ) 23 | ATR = "ATR" # atr : auto-release (emergency slow move) 24 | ADR = ( 25 | "ADR" # adr : auto-release direction (open(1) or close(0) during auto-release) 26 | ) 27 | FOR = "FOR" # for : force (0-255) 28 | SPE = "SPE" # spe : speed (0-255) 29 | POS = "POS" # pos : position (0-255), 0 = open 30 | # READ VARIABLES 31 | STA = "STA" # status (0 = is reset, 1 = activating, 3 = active) 32 | PRE = "PRE" # position request (echo of last commanded position) 33 | OBJ = "OBJ" # object detection (0 = moving, 1 = outer grip, 2 = inner grip, 3 = no object at rest) 34 | FLT = "FLT" # fault (0=ok, see manual for errors if not zero) 35 | 36 | ENCODING = "UTF-8" # ASCII and UTF-8 both seem to work 37 | 38 | class GripperStatus(Enum): 39 | """Gripper status reported by the gripper. 
The integer values have to match what the gripper sends.""" 40 | 41 | RESET = 0 42 | ACTIVATING = 1 43 | # UNUSED = 2 # This value is currently not used by the gripper firmware 44 | ACTIVE = 3 45 | 46 | class ObjectStatus(Enum): 47 | """Object status reported by the gripper. The integer values have to match what the gripper sends.""" 48 | 49 | MOVING = 0 50 | STOPPED_OUTER_OBJECT = 1 51 | STOPPED_INNER_OBJECT = 2 52 | AT_DEST = 3 53 | 54 | def __init__(self): 55 | """Constructor.""" 56 | self.socket = None 57 | self.command_lock = threading.Lock() 58 | self._min_position = 0 59 | self._max_position = 255 60 | self._min_speed = 0 61 | self._max_speed = 255 62 | self._min_force = 0 63 | self._max_force = 255 64 | 65 | def connect(self, hostname: str, port: int, socket_timeout: float = 10.0) -> None: 66 | """Connects to a gripper at the given address. 67 | 68 | :param hostname: Hostname or ip. 69 | :param port: Port. 70 | :param socket_timeout: Timeout for blocking socket operations. 71 | """ 72 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 73 | assert self.socket is not None 74 | self.socket.connect((hostname, port)) 75 | self.socket.settimeout(socket_timeout) 76 | 77 | def disconnect(self) -> None: 78 | """Closes the connection with the gripper.""" 79 | assert self.socket is not None 80 | self.socket.close() 81 | 82 | def _set_vars(self, var_dict: OrderedDict[str, Union[int, float]]): 83 | """Sends the appropriate command via socket to set the value of n variables, and waits for its 'ack' response. 84 | 85 | :param var_dict: Dictionary of variables to set (variable_name, value). 86 | :return: True on successful reception of ack, false if no ack was received, indicating the set may not 87 | have been effective. 
88 | """ 89 | assert self.socket is not None 90 | # construct unique command 91 | cmd = "SET" 92 | for variable, value in var_dict.items(): 93 | cmd += f" {variable} {str(value)}" 94 | cmd += "\n" # new line is required for the command to finish 95 | # atomic commands send/rcv 96 | with self.command_lock: 97 | self.socket.sendall(cmd.encode(self.ENCODING)) 98 | data = self.socket.recv(1024) 99 | return self._is_ack(data) 100 | 101 | def _set_var(self, variable: str, value: Union[int, float]): 102 | """Sends the appropriate command via socket to set the value of a variable, and waits for its 'ack' response. 103 | 104 | :param variable: Variable to set. 105 | :param value: Value to set for the variable. 106 | :return: True on successful reception of ack, false if no ack was received, indicating the set may not 107 | have been effective. 108 | """ 109 | return self._set_vars(OrderedDict([(variable, value)])) 110 | 111 | def _get_var(self, variable: str): 112 | """Sends the appropriate command to retrieve the value of a variable from the gripper, blocking until the response is received or the socket times out. 113 | 114 | :param variable: Name of the variable to retrieve. 115 | :return: Value of the variable as integer. 116 | """ 117 | assert self.socket is not None 118 | # atomic commands send/rcv 119 | with self.command_lock: 120 | cmd = f"GET {variable}\n" 121 | self.socket.sendall(cmd.encode(self.ENCODING)) 122 | data = self.socket.recv(1024) 123 | 124 | # expect data of the form 'VAR x', where VAR is an echo of the variable name, and X the value 125 | # note some special variables (like FLT) may send 2 bytes, instead of an integer. 
We assume integer here 126 | var_name, value_str = data.decode(self.ENCODING).split() 127 | if var_name != variable: 128 | raise ValueError( 129 | f"Unexpected response {data} ({data.decode(self.ENCODING)}): does not match '{variable}'" 130 | ) 131 | value = int(value_str) 132 | return value 133 | 134 | @staticmethod 135 | def _is_ack(data: bytes): 136 | return data == b"ack" 137 | 138 | def _reset(self): 139 | """Reset the gripper. 140 | 141 | The following code is executed in the corresponding script function 142 | def rq_reset(gripper_socket="1"): 143 | rq_set_var("ACT", 0, gripper_socket) 144 | rq_set_var("ATR", 0, gripper_socket) 145 | 146 | while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0): 147 | rq_set_var("ACT", 0, gripper_socket) 148 | rq_set_var("ATR", 0, gripper_socket) 149 | sync() 150 | end 151 | 152 | sleep(0.5) 153 | end 154 | """ 155 | self._set_var(self.ACT, 0) 156 | self._set_var(self.ATR, 0) 157 | while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0: 158 | self._set_var(self.ACT, 0) 159 | self._set_var(self.ATR, 0) 160 | time.sleep(0.5) 161 | 162 | def activate(self, auto_calibrate: bool = True): 163 | """Resets the activation flag in the gripper, and sets it back to one, clearing previous fault flags. 164 | 165 | :param auto_calibrate: Whether to calibrate the minimum and maximum positions based on actual motion. 
166 | 167 | The following code is executed in the corresponding script function 168 | def rq_activate(gripper_socket="1"): 169 | if (not rq_is_gripper_activated(gripper_socket)): 170 | rq_reset(gripper_socket) 171 | 172 | while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0): 173 | rq_reset(gripper_socket) 174 | sync() 175 | end 176 | 177 | rq_set_var("ACT",1, gripper_socket) 178 | end 179 | end 180 | 181 | def rq_activate_and_wait(gripper_socket="1"): 182 | if (not rq_is_gripper_activated(gripper_socket)): 183 | rq_activate(gripper_socket) 184 | sleep(1.0) 185 | 186 | while(not rq_get_var("ACT", 1, gripper_socket) == 1 or not rq_get_var("STA", 1, gripper_socket) == 3): 187 | sleep(0.1) 188 | end 189 | 190 | sleep(0.5) 191 | end 192 | end 193 | """ 194 | if not self.is_active(): 195 | self._reset() 196 | while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0: 197 | time.sleep(0.01) 198 | 199 | self._set_var(self.ACT, 1) 200 | time.sleep(1.0) 201 | while not self._get_var(self.ACT) == 1 or not self._get_var(self.STA) == 3: 202 | time.sleep(0.01) 203 | 204 | # auto-calibrate position range if desired 205 | if auto_calibrate: 206 | self.auto_calibrate() 207 | 208 | def is_active(self): 209 | """Returns whether the gripper is active.""" 210 | status = self._get_var(self.STA) 211 | return ( 212 | RobotiqGripper.GripperStatus(status) == RobotiqGripper.GripperStatus.ACTIVE 213 | ) 214 | 215 | def get_min_position(self) -> int: 216 | """Returns the minimum position the gripper can reach (open position).""" 217 | return self._min_position 218 | 219 | def get_max_position(self) -> int: 220 | """Returns the maximum position the gripper can reach (closed position).""" 221 | return self._max_position 222 | 223 | def get_open_position(self) -> int: 224 | """Returns what is considered the open position for gripper (minimum position value).""" 225 | return self.get_min_position() 226 | 227 | def 
get_closed_position(self) -> int: 228 | """Returns what is considered the closed position for gripper (maximum position value).""" 229 | return self.get_max_position() 230 | 231 | def is_open(self): 232 | """Returns whether the current position is considered as being fully open.""" 233 | return self.get_current_position() <= self.get_open_position() 234 | 235 | def is_closed(self): 236 | """Returns whether the current position is considered as being fully closed.""" 237 | return self.get_current_position() >= self.get_closed_position() 238 | 239 | def get_current_position(self) -> int: 240 | """Returns the current position as returned by the physical hardware.""" 241 | return self._get_var(self.POS) 242 | 243 | def auto_calibrate(self, log: bool = True) -> None: 244 | """Attempts to calibrate the open and closed positions, by slowly closing and opening the gripper. 245 | 246 | :param log: Whether to print the results to log. 247 | """ 248 | # first try to open in case we are holding an object 249 | (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1) 250 | if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST: 251 | raise RuntimeError(f"Calibration failed opening to start: {str(status)}") 252 | 253 | # try to close as far as possible, and record the number 254 | (position, status) = self.move_and_wait_for_pos( 255 | self.get_closed_position(), 64, 1 256 | ) 257 | if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST: 258 | raise RuntimeError( 259 | f"Calibration failed because of an object: {str(status)}" 260 | ) 261 | assert position <= self._max_position 262 | self._max_position = position 263 | 264 | # try to open as far as possible, and record the number 265 | (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1) 266 | if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST: 267 | raise RuntimeError( 268 | f"Calibration failed because of 
an object: {str(status)}" 269 | ) 270 | assert position >= self._min_position 271 | self._min_position = position 272 | 273 | if log: 274 | print( 275 | f"Gripper auto-calibrated to [{self.get_min_position()}, {self.get_max_position()}]" 276 | ) 277 | 278 | def move(self, position: int, speed: int, force: int) -> Tuple[bool, int]: 279 | """Sends commands to start moving towards the given position, with the specified speed and force. 280 | 281 | :param position: Position to move to [min_position, max_position] 282 | :param speed: Speed to move at [min_speed, max_speed] 283 | :param force: Force to use [min_force, max_force] 284 | :return: A tuple with a bool indicating whether the action was successfully sent, and an integer with 285 | the actual position that was requested, after being adjusted to the min/max calibrated range. 286 | """ 287 | 288 | def clip_val(min_val, val, max_val): 289 | return max(min_val, min(val, max_val)) 290 | 291 | clip_pos = clip_val(self._min_position, position, self._max_position) 292 | clip_spe = clip_val(self._min_speed, speed, self._max_speed) 293 | clip_for = clip_val(self._min_force, force, self._max_force) 294 | 295 | # moves to the given position with the given speed and force 296 | var_dict = OrderedDict( 297 | [ 298 | (self.POS, clip_pos), 299 | (self.SPE, clip_spe), 300 | (self.FOR, clip_for), 301 | (self.GTO, 1), 302 | ] 303 | ) 304 | succ = self._set_vars(var_dict) 305 | time.sleep(0.008) # need to wait (don't know why) 306 | return succ, clip_pos 307 | 308 | def move_and_wait_for_pos( 309 | self, position: int, speed: int, force: int 310 | ) -> Tuple[int, ObjectStatus]: # noqa 311 | """Sends commands to start moving towards the given position, with the specified speed and force, and then waits for the move to complete. 

        :param position: Position to move to [min_position, max_position]
        :param speed: Speed to move at [min_speed, max_speed]
        :param force: Force to use [min_force, max_force]
        :return: A tuple with an integer representing the last position returned by the gripper after it notified
        that the move had completed, and a status indicating how the move ended (see ObjectStatus enum for details). Note
        that it is possible that the position was not reached, if an object was detected during motion.
        """
        set_ok, cmd_pos = self.move(position, speed, force)
        if not set_ok:
            raise RuntimeError("Failed to set variables for move.")

        # wait until the gripper acknowledges that it will try to go to the requested position
        while self._get_var(self.PRE) != cmd_pos:
            time.sleep(0.001)

        # wait until not moving
        cur_obj = self._get_var(self.OBJ)
        while (
            RobotiqGripper.ObjectStatus(cur_obj) == RobotiqGripper.ObjectStatus.MOVING
        ):
            cur_obj = self._get_var(self.OBJ)

        # report the actual position and the object status
        final_pos = self._get_var(self.POS)
        final_obj = cur_obj
        return final_pos, RobotiqGripper.ObjectStatus(final_obj)


--------------------------------------------------------------------------------
/robots/ur.py:
--------------------------------------------------------------------------------
from typing import Dict

import numpy as np

from robots.robot import Robot


class URRobot(Robot):
    """A class representing a UR robot."""

    def __init__(
        self,
        robot_ip: str = "111.111.1.11",
        no_gripper: bool = False,
        gripper_type="",
        grip_range=110,
        port_idx=-1,
    ):
        import rtde_control
        import rtde_receive

        for _ in range(3):
            print("in ur robot:", robot_ip)
        self.robot = rtde_control.RTDEControlInterface(robot_ip)
        self.r_inter = 
rtde_receive.RTDEReceiveInterface(robot_ip)
        if not no_gripper:
            if gripper_type == "ability":
                from robots.ability_gripper import AbilityGripper

                self.gripper = AbilityGripper(port_idx=port_idx, grip_range=grip_range)
                self.gripper.connect()
            else:
                from robots.robotiq_gripper import RobotiqGripper

                self.gripper = RobotiqGripper()
                self.gripper.connect(hostname=robot_ip, port=63352)

        for _ in range(3):
            print("connect")

        self._free_drive = False
        self.robot.endFreedriveMode()
        self._use_gripper = not no_gripper
        self.gripper_type = gripper_type

        self.velocity = 0.5
        self.acceleration = 0.5

        # EEF
        self.velocity_l = 0.3
        self.acceleration_l = 0.3
        self.dt = 1.0 / 500  # 2ms
        self.lookahead_time = 0.2  # [0.03, 0.2]s smoothens the trajectory
        self.gain = 100  # [100, 2000] proportional gain for following target position

    def num_dofs(self) -> int:
        """Get the number of joints of the robot.

        Returns:
            int: The number of joints of the robot.
        """
        if self._use_gripper:
            if self.gripper_type == "ability":
                return 12
            else:
                return 7
        return 6

    def _get_gripper_pos(self) -> float:
        if self.gripper_type in ["ability"]:
            gripper_pos = self.gripper.get_current_position()
            return gripper_pos
        else:
            gripper_pos = self.gripper.get_current_position()
            assert 0 <= gripper_pos <= 255, "Gripper position must be between 0 and 255"
            return gripper_pos / 255

    def get_joint_state(self) -> np.ndarray:
        """Get the current state of the robot.

        Returns:
            np.ndarray: The current state of the robot.
        """
        robot_joints = self.r_inter.getActualQ()
        if self._use_gripper:
            gripper_pos = self._get_gripper_pos()
            pos = np.append(robot_joints, gripper_pos)
        else:
            pos = robot_joints
        return pos

    def get_joint_velocities(self) -> np.ndarray:
        return self.r_inter.getActualQd()

    def get_eef_speed(self) -> np.ndarray:
        return self.r_inter.getActualTCPSpeed()

    def get_eef_pose(self) -> np.ndarray:
        """Get the current pose of the robot's end effector.

        Returns:
            np.ndarray: The current pose of the robot's end effector.
        """
        return self.r_inter.getActualTCPPose()

    def command_joint_state(self, joint_state: np.ndarray) -> None:
        """Command the robot to a given joint state.

        Args:
            joint_state (np.ndarray): The state to command the robot to.
        """
        robot_joints = joint_state[:6]
        t_start = self.robot.initPeriod()
        self.robot.servoJ(
            robot_joints,
            self.velocity,
            self.acceleration,
            self.dt,
            self.lookahead_time,
            self.gain,
        )
        if self._use_gripper:
            if self.gripper_type == "ability":
                assert (
                    max(joint_state[6:]) <= 1 and min(joint_state[6:]) >= 0
                ), f"Gripper position must be between 0 and 1:{joint_state[6:]}"
                self.gripper.move(joint_state[6:], debug=False)
            elif self.gripper_type == "allegro":
                self.gripper.move(joint_state[6:])
            elif self.gripper_type == "dummy":
                pass
            else:
                gripper_pos = joint_state[-1] * 255
                # print(f"gripper move command: {gripper_pos}")
                self.gripper.move(int(gripper_pos), 255, 10)
        self.robot.waitPeriod(t_start)

    def command_eef_pose(self, eef_pose: np.ndarray) -> None:
        """Command the robot to a given end-effector pose.

        Args:
            eef_pose (np.ndarray): The EEF pose to command the robot to.
        """
        pose_command = eef_pose[:6]
        # print("current TCP:", self.r_inter.getActualTCPPose())
        # print("pose_command:", pose_command)
        # input("press enter to continue")
        t_start = self.robot.initPeriod()
        self.robot.servoL(
            pose_command,
            self.velocity_l,
            self.acceleration_l,
            self.dt,
            self.lookahead_time,
            self.gain,
        )
        if self._use_gripper:
            if self.gripper_type == "ability":
                assert (
                    max(eef_pose[6:]) <= 1 and min(eef_pose[6:]) >= 0
                ), f"Gripper position must be between 0 and 1:{eef_pose[6:]}"
                self.gripper.move(eef_pose[6:])
            else:
                gripper_pos = eef_pose[-1] * 255
                # print(f"gripper move command: {gripper_pos}")
                self.gripper.move(int(gripper_pos), 255, 10)
        self.robot.waitPeriod(t_start)

    def freedrive_enabled(self) -> bool:
        """Check if the robot is in freedrive mode.

        Returns:
            bool: True if the robot is in freedrive mode, False otherwise.
        """
        return self._free_drive

    def set_freedrive_mode(self, enable: bool) -> None:
        """Set the freedrive mode of the robot.

        Args:
            enable (bool): True to enable freedrive mode, False to disable it.
        """
        if enable and not self._free_drive:
            self._free_drive = True
            self.robot.freedriveMode()
        elif not enable and self._free_drive:
            self._free_drive = False
            self.robot.endFreedriveMode()

    def get_observations(self) -> Dict[str, np.ndarray]:
        joints = self.get_joint_state()
        joint_vels = self.get_joint_velocities()
        eef_speed = self.get_eef_speed()
        pos_quat = self.get_eef_pose()
        gripper_pos = np.array([joints[-1]])
        if self._use_gripper and self.gripper_type == "ability":
            # include Ability hand touch data
            touch = self.gripper.get_current_touch()
        else:
            touch = np.zeros(30)
        return {
            "joint_positions": joints,
            "joint_velocities": joint_vels,
            "eef_speed": eef_speed,
            "ee_pos_quat": pos_quat,  # TODO: this is pos_rot actually
            "gripper_position": gripper_pos,
            "touch": touch,
        }


--------------------------------------------------------------------------------
/run_env.py:
--------------------------------------------------------------------------------
import datetime
import os
import pickle
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import cv2
import numpy as np
import termcolor
import tyro

# foot pedal
from pynput import keyboard

from agents.agent import BimanualAgent, SafetyWrapper
from camera_node import ZMQClientCamera
from env import RobotEnv
from robot_node import ZMQClientRobot

trigger_state = {"l": False, "r": False}


def listen_key(key):
    global trigger_state
    try:
        trigger_state[key.char] = True
    except AttributeError:
        # special keys (e.g. shift) have no .char attribute
        pass


def reset_key(key):
    global trigger_state
    try:
        trigger_state[key.char] = False
    except AttributeError:
        # special keys (e.g. shift) have no .char attribute
        pass


listener = keyboard.Listener(on_press=listen_key)
listener2 = keyboard.Listener(on_release=reset_key)
listener.start()
listener2.start()

###


def count_folders(path):
    """Counts the number of folders under the given path."""
    folder_count = 0
    for root, dirs, files in os.walk(path):
        folder_count += len(dirs)  # Count directories only at current level
        break  # Prevents descending into subdirectories
    return folder_count


def print_color(*args, color=None, attrs=(), **kwargs):
    if len(args) > 0:
        args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
    print(*args, **kwargs)


def save_frame(
    folder: Path,
    timestamp: datetime.datetime,
    obs: Dict[str, np.ndarray],
    action: np.ndarray,
    activated=True,
    save_png=False,
) -> None:
    obs["activated"] = activated
    obs["control"] = action  # add action to obs
    recorded_file = folder / (
        timestamp.isoformat().replace(":", "-").replace(".", "-") + ".pkl"
    )
    with open(recorded_file, "wb") as f:
        pickle.dump(obs, f)

    # save rgb image as png
    if save_png:
        rgb = obs["base_rgb"]
        for i in range(rgb.shape[0]):
            rgbi = cv2.cvtColor(rgb[i], cv2.COLOR_RGB2BGR)
            fn = str(recorded_file)[:-4] + f"-{i}.png"
            cv2.imwrite(fn, rgbi)


@dataclass
class Args:
    robot_port: int = 6000
    wrist_camera_port: int = 5001
    base_camera_port: int = 5000
    hostname: str = "111.0.0.1"
    hz: int = 100
    show_camera_view: bool = True

    agent: str = "quest"
    robot_type: str = "ur5"
    save_data: bool = False
    save_depth: bool = True
    save_png: bool = False
    data_dir: str = "/shared/data/bc_data"
    verbose: bool = False
    safe: bool = False
    use_vel_ik: bool = False

    num_diffusion_iters_compile: int = 15  # used for compilation only for now
    jit_compile: bool = False  # send the compilation signal to the
    # server (only need to do this once per inference server run).
    use_jit_agent: bool = False  # use the inference server to get actions. The inference_agent_port and the inference_agent_host need to be set to the proper values.
    inference_agent_port: str = (
        "1234"  # port must be the same as the inference server port
    )
    # IP of the inference server (localhost if running locally; currently defaults to bt).
    # The inference server needs to use the same checkpoint folder when launching the
    # inference node (args need to match).
    inference_agent_host: str = "111.11.111.11"

    dp_ckpt_path: str = "/shared/ckpts/best.ckpt"

    temporal_ensemble_mode: str = "avg"
    temporal_ensemble_act_tau: float = 0.5


def main(args):
    camera_clients = {
        "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
    }
    robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
    env = RobotEnv(
        robot_client,
        control_rate_hz=args.hz,
        camera_dict=camera_clients,
        show_camera_view=args.show_camera_view,
        save_depth=args.save_depth,
    )

    if args.agent == "quest":
        from agents.quest_agent import SingleArmQuestAgent

        left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
        right_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="r")
        agent = BimanualAgent(left_agent, right_agent)
        print("Quest agent created")
    elif args.agent == "quest_hand":
        # some custom mapping from Quest controller to hand control
        from agents.quest_agent import (
            DualArmQuestAgent,
            SingleArmQuestAgent,
        )

        left_agent = SingleArmQuestAgent(
            robot_type=args.robot_type,
            which_hand="l",
            eef_control_mode=3,
            use_vel_ik=args.use_vel_ik,
        )
        right_agent = SingleArmQuestAgent(
            robot_type=args.robot_type,
            which_hand="r",
            eef_control_mode=3,
            use_vel_ik=args.use_vel_ik,
        )
        agent = DualArmQuestAgent(left_agent, right_agent)
        print("Quest agent created")
    elif args.agent == "quest_hand_eef":
        # some custom mapping from Quest controller to hand control
        from agents.quest_agent_eef import (
            DualArmQuestAgent,
            SingleArmQuestAgent,
        )

        left_agent = SingleArmQuestAgent(
            robot_type=args.robot_type,
            which_hand="l",
            eef_control_mode=3,
        )
        right_agent = SingleArmQuestAgent(
            robot_type=args.robot_type,
            which_hand="r",
            eef_control_mode=3,
        )
        agent = DualArmQuestAgent(left_agent, right_agent)
        print("Quest EEF agent created")
    elif args.agent in ["dp", "dp_eef"]:
        if args.use_jit_agent:
            from agents.dp_agent_zmq import BimanualDPAgent

            agent = BimanualDPAgent(
                ckpt_path=args.dp_ckpt_path,
                port=args.inference_agent_port,
                host=args.inference_agent_host,
                temporal_ensemble_act_tau=args.temporal_ensemble_act_tau,
                temporal_ensemble_mode=args.temporal_ensemble_mode,
            )
        else:
            from agents.dp_agent import BimanualDPAgent

            agent = BimanualDPAgent(ckpt_path=args.dp_ckpt_path)
    else:
        raise ValueError(f"Invalid agent name for bimanual: {args.agent}")

    if args.agent == "quest":
        # using grippers
        reset_joints_left = np.deg2rad([-80, -140, -80, -85, -10, 80, 0])
        reset_joints_right = np.deg2rad([-270, -30, 70, -85, 10, 0, 0])
    else:
        # using Ability hands
        arm_joints_left = [-80, -140, -80, -85, -10, 80]
        arm_joints_right = [-270, -30, 70, -85, 10, 0]
        hand_joints = [0, 0, 0, 0, 0.5, 0.5]
        reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
        reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
    reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
    curr_joints = env.get_obs()["joint_positions"]
    if args.agent != "quest":
        # hand_joints is only defined for the Ability-hand layout (6 arm + 6 hand joints per side)
        curr_joints[6:12] = hand_joints
        curr_joints[18:] = hand_joints
    print("Current joints:", curr_joints)
    print("Reset joints:", reset_joints)
    max_delta = (np.abs(curr_joints - reset_joints)).max()
    steps = min(int(max_delta / 0.01), 20)
    for jnt in np.linspace(curr_joints, reset_joints, steps):
        env.step(jnt)

    obs = env.get_obs()

    if args.jit_compile:
        agent.compile_inference(
            obs, num_diffusion_iters=args.num_diffusion_iters_compile
        )

    # going to start position
    print("Going to start position")
    start_pos = agent.act(obs)
    obs = env.get_obs()
    joints = obs["joint_positions"]

    if args.agent == "quest":
        ur_idx = [i for i in range(len(joints))]
        hand_idx = None
    else:
        ur_idx = list(range(0, 6)) + list(range(12, 18))
        hand_idx = list(range(6, 12)) + list(range(18, 24))

    if args.safe:
        max_joint_delta = 0.5
        max_hand_delta = 0.1
        safety_wrapper = SafetyWrapper(
            ur_idx, hand_idx, agent, delta=max_joint_delta, hand_delta=max_hand_delta
        )

    print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
    assert len(start_pos) == len(
        joints
    ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"

    print(f"Collecting traj no.{count_folders(args.data_dir) + 1}")

    # time.sleep(2.0)
    while not trigger_state["r"]:
        print(">>> Step on right")
        time.sleep(0.2)

    print_color("\nReady to go 🚀🚀🚀", color="green", attrs=("bold",))

    start_time = time.time()

    if args.save_data:
        time_str = datetime.datetime.now().strftime("%m%d_%H%M%S")
        if args.agent.startswith("dp"):
            # eval
            save_path = (
                Path(args.data_dir).expanduser()
                / "_".join(
                    [
                        args.dp_ckpt_path.split("/")[-2],
                        args.dp_ckpt_path.split("/")[-1][:-5],
                    ]
                )
                / time_str
            )
        else:
            save_path = Path(args.data_dir).expanduser() / time_str
        save_path.mkdir(parents=True, exist_ok=True)
        print(f"Saving to {save_path}")

    is_first_frame = True
    try:
        frame_freq = []
        while True:
            new_start_time = time.time()
            num = new_start_time - start_time
            message = f"\rTime passed: {round(num, 2)} "
            print_color(
                message,
                color="white",
                attrs=("bold",),
                end="",
                flush=True,
            )
            if args.safe:
                action = safety_wrapper.act_safe(
                    agent, obs, eef=(args.agent.endswith("_eef"))
                )
            else:
                action = agent.act(obs)
            dt = datetime.datetime.now()

            if args.save_data:
                if is_first_frame:
                    is_first_frame = False
                else:
                    save_frame(
                        save_path,
                        dt,
                        obs,
                        action,
                        activated=agent.trigger_state,
                        save_png=args.save_png,
                    )

            if args.agent.endswith("_eef"):
                obs = env.step_eef(action)
            else:
                obs = env.step(action)

            ff = 1 / (time.time() - new_start_time)
            frame_freq.append(ff)

            if trigger_state["l"]:
                print_color("\nTriggered!", color="red", attrs=("bold",))
                break

    except KeyboardInterrupt:
        print_color("\nInterrupted!", color="red", attrs=("bold",))
    finally:
        if "dp" in args.agent:
            import glob

            from moviepy.editor import ImageSequenceClip

            # find all the pkl files in the episode directory
            pkls = sorted(glob.glob(os.path.join(save_path, "*.pkl")))
            print("Total number of pkls: ", len(pkls))
            frames = []
            for pkl in pkls:
                with open(pkl, "rb") as f:
                    try:
                        data = pickle.load(f)
                    except Exception:
                        # skip frames that cannot be unpickled
                        continue
                rgb = data["base_rgb"]
                rgb = np.concatenate([rgb[i] for i in range(rgb.shape[0])], axis=1)
                frames.append(rgb)
            clip
= ImageSequenceClip(frames, fps=5)
            ckpt_path = os.path.dirname(args.dp_ckpt_path)
            parent_name = os.path.basename(ckpt_path)
            clip.write_videofile(
                os.path.join(ckpt_path, f"{parent_name}_{time_str}.mp4")
            )

            # save frame freq as txt
            with open(os.path.join(ckpt_path, f"freq_{time_str}.txt"), "w") as f:
                for step, freq in enumerate(frame_freq):
                    f.write(f"{step}: {freq}\n")
        else:
            print("Done")

            # save frame freq as txt
            with open(save_path / "freq.txt", "w") as f:
                f.write(
                    f"Average FPS: {np.mean(frame_freq[1:])}\n"
                    f"Max FPS: {np.max(frame_freq[1:])}\n"
                    f"Min FPS: {np.min(frame_freq[1:])}\n"
                    f"Std FPS: {np.std(frame_freq[1:])}\n\n"
                )
                for step, freq in enumerate(frame_freq):
                    f.write(f"{step}: {freq}\n")

    os._exit(0)


if __name__ == "__main__":
    main(tyro.cli(Args))

--------------------------------------------------------------------------------
/run_openloop.py:
--------------------------------------------------------------------------------
import datetime
import pickle
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import numpy as np
import tyro

from agents.agent import SafetyWrapper
from camera_node import ZMQClientCamera
from env import EvalRobotEnv
from robot_node import ZMQClientRobot


def print_color(*args, color=None, attrs=(), **kwargs):
    import termcolor

    if len(args) > 0:
        args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
    print(*args, **kwargs)


def save_frame(
    folder: Path,
    timestamp: datetime.datetime,
    obs: Dict[str, np.ndarray],
    action: np.ndarray,
    activated=True,
) -> None:
    obs["activated"] = activated
    obs["control"] = action  # add action to obs
    recorded_file = folder / (timestamp.isoformat() + ".pkl")
    with open(recorded_file, "wb") as f:
        pickle.dump(obs, f)


@dataclass
class Args:
    robot_port: int = 6000
    wrist_camera_port: int = 5001
    base_camera_port: int = 5000
    hostname: str = "127.0.0.1"
    hz: int = 100
    show_camera_view: bool = True

    agent: str = "dp"
    robot_type: str = "ur5"
    hand_type: str = "ability"
    save_data: bool = False
    data_dir: str = "/shared/data/bc_data"
    verbose: bool = False
    safe: bool = False
    use_vel_ik: bool = False

    traj_path: str = "/shared/data/test_data"
    dp_ckpt_path: str = "best.ckpt"


def main(args):
    camera_clients = {
        "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
    }
    robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
    env = EvalRobotEnv(
        robot_client,
        traj_path=args.traj_path,
        control_rate_hz=args.hz,
        camera_dict=camera_clients,
    )
    if args.agent.startswith("dp"):
        from agents.dp_agent import BimanualDPAgent

        agent = BimanualDPAgent(ckpt_path=args.dp_ckpt_path)
    else:
        raise ValueError(f"Invalid agent name: {args.agent}")

    if args.hand_type == "ability":
        arm_joints_left = [-80, -140, -80, -85, -10, 80]
        arm_joints_right = [-270, -30, 70, -85, 10, 0]
        hand_joints = [0, 0, 0, 0, 0.5, 0.5]
    else:
        raise ValueError(f"Invalid hand type: {args.hand_type}")
    reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
    reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
    reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
    curr_joints = env.get_real_obs()["joint_positions"]
    if args.hand_type == "ability":
        curr_joints[6:12] = hand_joints
        curr_joints[18:] = hand_joints
    print("Current joints:", curr_joints)
    print("Reset joints:", reset_joints)
    max_delta = (np.abs(curr_joints - reset_joints)).max()
    steps = min(int(max_delta / 0.01), 20)
    for jnt in np.linspace(curr_joints, reset_joints, steps):
        obs = env.step(jnt)

    # going to start position
    print("Going to start position")
    start_pos = agent.act(env.get_real_obs())
    obs = env.get_real_obs()
    joints = obs["joint_positions"]

    # if args.hand_type == "ability":
    ur_idx = list(range(0, 6)) + list(range(12, 18))
    hand_idx = list(range(6, 12)) + list(range(18, 24))

    if args.safe:
        max_joint_delta = 0.5
        max_hand_delta = 0.1
        safety_wrapper = SafetyWrapper(
            ur_idx, hand_idx, agent, delta=max_joint_delta, hand_delta=max_hand_delta
        )

    print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
    assert len(start_pos) == len(
        joints
    ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"

    for step in range(3):
        print("Countdown step", step)
        time.sleep(0.5)

    print_color("\nReady to go 🚀🚀🚀", color="green", attrs=("bold",))

    start_time = time.time()

    if args.save_data:
        time_str = datetime.datetime.now().strftime("%m%d_%H%M%S")
        save_path = (
            Path(args.data_dir).expanduser()
            / (args.traj_path.split("/")[-1] + "_openloop")
            / time_str
        )
        save_path.mkdir(parents=True, exist_ok=True)
        print(f"Saving to {save_path}")

    while obs is not None:
        num = time.time() - start_time
        message = f"\rTime passed: {round(num, 2)} "
        print_color(
            message,
            color="white",
            attrs=("bold",),
            end="",
            flush=True,
        )
        if args.safe:
            action = safety_wrapper.act_safe(
                agent, obs, eef=(args.agent.endswith("_eef"))
            )
        else:
            action = agent.act(obs)
        dt = datetime.datetime.now()
        img, depth = 
camera_clients["base"].read()
        if args.save_data:
            obs["base_rgb"] = img
            obs["base_depth"] = depth
            save_frame(save_path, dt, obs, action, activated=agent.trigger_state)
        # input("Press Enter to continue...")

        if args.agent.endswith("_eef"):
            obs = env.step_eef(action)
        else:
            obs = env.step(action)

    # save eval video
    import glob
    import os

    from moviepy.editor import ImageSequenceClip

    episode_dir = save_path

    # find all the pkl files in the episode directory
    pkls = sorted(glob.glob(os.path.join(episode_dir, "*.pkl")))

    # read all images
    frames = []
    for pkl in pkls:
        with open(pkl, "rb") as f:
            try:
                data = pickle.load(f)
            except Exception:
                # skip frames that cannot be unpickled
                continue
        rgb = data["base_rgb"]
        rgb = np.concatenate([rgb[i] for i in range(rgb.shape[0])], axis=1)
        frames.append(rgb)

    # Create a video clip
    clip = ImageSequenceClip(frames, fps=10)
    ckpt_path = os.path.dirname(args.dp_ckpt_path)
    clip.write_videofile(os.path.join(ckpt_path, f"{time_str}_openloop.mp4"))


if __name__ == "__main__":
    main(tyro.cli(Args))

--------------------------------------------------------------------------------
/test_dp_agent.py:
--------------------------------------------------------------------------------
import argparse
import os
import pickle
import time

import numpy as np
import torch

from agents.dp_agent import BimanualDPAgent
from learning.dp.data_processing import iterate

torch.cuda.set_device(0)


def main(args):
    hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
    hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])

    dp_args = BimanualDPAgent.get_default_dp_args()
    data = iterate(args.data_dir)
    for num_diffusion_iters in args.num_diffusion_iters:
        dp_args["num_diffusion_iters"] = num_diffusion_iters
        dp_agent = BimanualDPAgent(ckpt_path=args.ckpt_path, dp_args=dp_args)

        controls = []
        pred_actions = []
        delta_action = []
        last_action = data[0]["control"]

        start = time.time()
        dp_agent.compile_inference(data[0], num_inference_iters=num_diffusion_iters)
        end = time.time()
        ctime = end - start
        print(f"compilation time: {ctime}")

        start = time.time()
        max_infer_time = 0
        for i, obs in enumerate(data):
            control = obs["control"]
            delta_action.append(control - last_action)
            last_action = control
            controls.append(control)
            if i != 0:
                # feed the previous prediction back in as the current joint state (open loop)
                obs["joint_positions"][list(range(6)) + list(range(12, 18))] = (
                    pred_actions[-1][list(range(6)) + list(range(12, 18))]
                )
                obs["joint_positions"][6:12] = (
                    pred_actions[-1][6:12] * (hand_uppers - hand_lowers) + hand_lowers
                )
                obs["joint_positions"][18:24] = (
                    pred_actions[-1][18:24] * (hand_uppers - hand_lowers) + hand_lowers
                )
            infer_start = time.time()
            pred_actions.append(dp_agent.act(obs))
            infer_time = time.time() - infer_start
            max_infer_time = max(max_infer_time, infer_time)

        print("num_diffusion_iters:", num_diffusion_iters)
        print("time:", time.time() - start)
        print("Hz:", len(data) / (time.time() - start))
        print("max_infer_time:", max_infer_time)
        print("lowest_freq:", 1 / max_infer_time)

        pred_actions = np.array(pred_actions)
        controls = np.array(controls)

        # note: despite the name, this is the per-dimension mean absolute error
        mse = np.mean(np.abs((pred_actions - controls)), axis=0)
        mean_delta_action = np.mean(np.abs(delta_action), axis=0)

        print_str = "\n".join(
            [
                "mse:",
                str(mse.tolist()),
                "\n",
                "mean_delta_action:",
                str(mean_delta_action.tolist()),
                "\nfinal_diff:",
                str((pred_actions[-1] - controls[-1]).tolist()),
            ]
        )
        print_str += "\n"
        print_str += (
            "mse: "
            + 
str(mse.mean())
            + " mean_delta_action: "
            + str(mean_delta_action.mean())
        )
        print(print_str)

        # save print_str as txt in ckpt dir
        ckpt_dir = os.path.dirname(args.ckpt_path)
        with open(
            os.path.join(ckpt_dir, f"eval_stats_{num_diffusion_iters}.txt"), "w"
        ) as f:
            f.write(print_str)

        traj_name = os.path.basename(args.data_dir)
        save_path = os.path.join(
            ckpt_dir, f"openloop_{traj_name}_{num_diffusion_iters}"
        )
        os.makedirs(save_path, exist_ok=True)
        for i in range(len(pred_actions)):
            with open(os.path.join(save_path, str(i) + ".pkl"), "wb") as f:
                pickle.dump(
                    {
                        "control": pred_actions[i],
                        "joint_positions": data[i]["joint_positions"],
                    },
                    f,
                )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="best.ckpt",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/",
    )
    parser.add_argument(
        "--num_diffusion_iters", default=[5, 15, 25, 50], type=int, nargs="+"
    )
    args = parser.parse_args()
    main(args)

--------------------------------------------------------------------------------
/test_dp_agent_zmq.py:
--------------------------------------------------------------------------------
import argparse
import os
import pickle
import time

import numpy as np
import torch

from agents.dp_agent_zmq import BimanualDPAgent
from learning.dp.data_processing import iterate

torch.cuda.set_device(0)


def main(args):
    hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
    hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])

    data = iterate(args.data_dir)

    num_diffusion_iters = args.num_diffusion_iters
    dp_agent = BimanualDPAgent(
        ckpt_path="best.ckpt",
        host="localhost",
        port="4321",
    )
    dp_agent.compile_inference(data[0], num_diffusion_iters=num_diffusion_iters)
    controls = []
    pred_actions = []
    delta_action = []
    last_action = data[0]["control"]

    # start = time.time()
    # end = time.time()
    # ctime = end - start
    # print(f"compilation time: {ctime}")

    start = time.time()
    max_infer_time = 0
    for i, obs in enumerate(data):
        time.sleep(0.1)
        control = obs["control"]
        delta_action.append(control - last_action)
        last_action = control
        controls.append(control)
        if i != 0:
            # feed the previous prediction back in as the current joint state (open loop)
            obs["joint_positions"][list(range(6)) + list(range(12, 18))] = pred_actions[
                -1
            ][list(range(6)) + list(range(12, 18))]
            obs["joint_positions"][6:12] = (
                pred_actions[-1][6:12] * (hand_uppers - hand_lowers) + hand_lowers
            )
            obs["joint_positions"][18:24] = (
                pred_actions[-1][18:24] * (hand_uppers - hand_lowers) + hand_lowers
            )
        infer_start = time.time()
        pred_actions.append(dp_agent.act(obs))
        infer_time = time.time() - infer_start
        max_infer_time = max(max_infer_time, infer_time)

    print("num_diffusion_iters:", num_diffusion_iters)
    print("time:", time.time() - start)
    print("Hz:", len(data) / (time.time() - start))
    print("max_infer_time:", max_infer_time)
    print("lowest_freq:", 1 / max_infer_time)

    pred_actions = np.array(pred_actions)
    controls = np.array(controls)

    # note: despite the name, this is the per-dimension mean absolute error
    mse = np.mean(np.abs((pred_actions - controls)), axis=0)
    mean_delta_action = np.mean(np.abs(delta_action), axis=0)

    print_str = "\n".join(
        [
            "mse:",
            str(mse.tolist()),
            "\n",
            "mean_delta_action:",
            str(mean_delta_action.tolist()),
            "\nfinal_diff:",
            str((pred_actions[-1] - controls[-1]).tolist()),
        ]
    )
    print_str += "\n"
    print_str += (
        "mse: "
        + str(mse.mean())
        + " mean_delta_action: "
        + 
str(mean_delta_action.mean()) 90 | ) 91 | print(print_str) 92 | 93 | # save print_str as txt in ckpt dir 94 | ckpt_dir = os.path.dirname(args.ckpt_path) 95 | with open( 96 | os.path.join(ckpt_dir, f"eval_stats_{num_diffusion_iters}.txt"), "w" 97 | ) as f: 98 | f.write(print_str) 99 | 100 | traj_name = os.path.basename(args.data_dir) 101 | save_path = os.path.join(ckpt_dir, f"openloop_{traj_name}_{num_diffusion_iters}") 102 | os.makedirs(save_path, exist_ok=True) 103 | for i in range(len(pred_actions)): 104 | with open(os.path.join(save_path, str(i) + ".pkl"), "wb") as f: 105 | pickle.dump( 106 | { 107 | "control": pred_actions[i], 108 | "joint_positions": data[i]["joint_positions"], 109 | }, 110 | f, 111 | ) 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument( 117 | "--ckpt_path", 118 | type=str, 119 | default="best.ckpt", 120 | ) 121 | parser.add_argument( 122 | "--data_dir", 123 | type=str, 124 | default="data/", 125 | ) 126 | parser.add_argument("--num_diffusion_iters", default=15, type=int) 127 | args = parser.parse_args() 128 | main(args) 129 | -------------------------------------------------------------------------------- /workflow/create_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | 6 | random.seed(0) 7 | 8 | 9 | def create_eval(origin_path, target_path, num=20): 10 | if not os.path.exists(target_path): 11 | os.makedirs(target_path) 12 | data_dir = os.listdir(origin_path) 13 | assert ( 14 | len(data_dir) >= num 15 | ), "fewer trajectories available than requested to move" 16 | assert len(data_dir) > 100, "need more than 100 trajectories" 17 | # sample without replacement 18 | eval_dir = random.sample(data_dir, num) 19 | print(eval_dir) 20 | for file in eval_dir: 21 | shutil.move(os.path.join(origin_path, file), target_path) 22 | print("Move {} to {}".format(file, target_path)) 23 | 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 
| "--origin_path", type=str, default=None, help="path to the original data" 29 | ) 30 | parser.add_argument( 31 | "--target_path", type=str, default=None, help="path to the target data" 32 | ) 33 | parser.add_argument("--num", type=int, default=20, help="number of data to move") 34 | args = parser.parse_args() 35 | create_eval(args.origin_path, args.target_path, args.num) 36 | -------------------------------------------------------------------------------- /workflow/download_dataset.sh: -------------------------------------------------------------------------------- 1 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1512475235180' > data_banana.zip 2 | md5sum data_banana.zip 3 | 4 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1512698219144' > data_stack.zip 5 | md5sum data_stack.zip 6 | 7 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513394792580' > data_pour.zip 8 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513418334040' >> data_pour.zip 9 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513412800402' >> data_pour.zip 10 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513388327963' >> data_pour.zip 11 | md5sum data_pour.zip 12 | 13 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513478982439' > data_steak.zip 14 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513509873535' >> 
data_steak.zip 15 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513526807695' >> data_steak.zip 16 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513547399539' >> data_steak.zip 17 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513550058414' >> data_steak.zip 18 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513544174072' >> data_steak.zip 19 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513541198161' >> data_steak.zip 20 | md5sum data_steak.zip -------------------------------------------------------------------------------- /workflow/gen_deploy_scripts.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | def find_ckpts(ckpt_dirs): 6 | """ 7 | Find all checkpoint paths given a parent dir or a list of checkpoint directory names. 
8 | """ 9 | ckpt_dict = {} 10 | for ckpt_dir in ckpt_dirs: 11 | for root, _, files in os.walk(ckpt_dir): 12 | for file in files: 13 | if file.endswith(".ckpt"): 14 | run = root.split("/")[-1] 15 | ckpt_idx = int(file.split(".")[-2].split("_")[-1]) 16 | if (run not in ckpt_dict) or ( 17 | run in ckpt_dict and ckpt_dict[run]["idx"] < ckpt_idx 18 | ): 19 | # update the ckpt path 20 | if os.path.exists(os.path.join(root, "args_log.txt")): 21 | args_path = os.path.join(root, "args_log.txt") 22 | else: 23 | args_path = None 24 | ckpt_dict[run] = { 25 | "root": root, 26 | "ckpt": os.path.join(root, file), 27 | "args": args_path, 28 | "idx": ckpt_idx, 29 | } 30 | return ckpt_dict 31 | 32 | 33 | def gen_deploy_scripts( 34 | ckpt_dirs, agent_type="dp", test_data_path="", conda_env_name="hato" 35 | ): 36 | ckpt_dict = find_ckpts(ckpt_dirs) 37 | 38 | for run, ckpt_args in ckpt_dict.items(): 39 | ckpt = ckpt_args["ckpt"] 40 | 41 | # generate env script 42 | env_script = f"python run_env.py --agent {agent_type} --no-show-camera-view --save_data --dp_ckpt_path {ckpt}" 43 | 44 | if test_data_path is None or test_data_path == "": 45 | print("No test traj for this ckpt.") 46 | openloop_script = f"python run_openloop.py --agent {agent_type} --no-show-camera-view --save_data --dp_ckpt_path {ckpt} --traj_path {test_data_path}" 47 | 48 | # generate node script 49 | node_script = ( 50 | 'python launch_nodes.py --hand_type ability --faster --cam_names "435"' 51 | ) 52 | env_script += " --hz 10" 53 | openloop_script += " --hz 10" 54 | 55 | if "pour" in ckpt or "steak" in ckpt: 56 | node_script += " --ability_gripper_grip_range 75" 57 | 58 | test_script = ( 59 | f"python test_dp_agent.py --ckpt_path {ckpt} --data_dir {test_data_path}" 60 | ) 61 | env_jit_script = ( 62 | env_script 63 | + " --use_jit_agent --inference_agent_port=1325 --num_diffusion_iters_compile=50 --jit_compile" 64 | ) 65 | 66 | inference_script = f"'conda activate {conda_env_name} && python launch_inference_nodes.py 
--dp_ckpt_path {ckpt} --port=1325'" 67 | # save the scripts at the ckpt directory 68 | save_dir = ckpt_args["root"] 69 | with open(os.path.join(save_dir, f"{run}_node.sh"), "w") as f: 70 | f.write(node_script) 71 | with open(os.path.join(save_dir, f"{run}_env.sh"), "w") as f: 72 | f.write(env_script) 73 | with open(os.path.join(save_dir, f"{run}_env_jit.sh"), "w") as f: 74 | f.write(env_jit_script) 75 | with open(os.path.join(save_dir, f"{run}_openloop.sh"), "w") as f: 76 | f.write(openloop_script) 77 | with open(os.path.join(save_dir, f"{run}_test.sh"), "w") as f: 78 | f.write(test_script) 79 | with open(os.path.join(save_dir, f"{run}_inference.sh"), "w") as f: 80 | f.write(inference_script) 81 | 82 | return ckpt_dict 83 | 84 | 85 | def run_test_scripts(ckpt_dirs, filter="", run_all=True): 86 | for checkpoint_dir in ckpt_dirs: 87 | for root, _, files in os.walk(checkpoint_dir): 88 | for file in files: 89 | if file.endswith("_test.sh") and filter in root: 90 | print("Running test script: ", file) 91 | if run_all: 92 | os.system(f"bash {os.path.join(root, file)}") 93 | else: 94 | # run the test script if enter yes 95 | answer = input(f"\nRun {os.path.join(root, file)}? 
(y/n)") 96 | if answer == "y": 97 | os.system(f"bash {os.path.join(root, file)}") 98 | else: 99 | print("Skip.") 100 | 101 | 102 | if __name__ == "__main__": 103 | import argparse 104 | 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument( 107 | "-c", 108 | "--ckpt_dirs", 109 | nargs="+", 110 | default=[ 111 | "/shared/ckpts/data_banana", 112 | ], 113 | ) 114 | parser.add_argument("-m", "--mode", type=str, default="gen") 115 | parser.add_argument("-f", "--filter", type=str, default="") 116 | args = parser.parse_args() 117 | 118 | if args.mode == "gen": 119 | checkpoint_paths = gen_deploy_scripts(args.ckpt_dirs) 120 | print(checkpoint_paths) 121 | print(len(checkpoint_paths)) 122 | print("Done.") 123 | elif args.mode == "test": 124 | run_test_scripts(args.ckpt_dirs, filter=args.filter) 125 | elif args.mode == "all": 126 | checkpoint_paths = gen_deploy_scripts(args.ckpt_dirs) 127 | print(checkpoint_paths) 128 | print(len(checkpoint_paths)) 129 | run_test_scripts(args.ckpt_dirs, filter=args.filter) 130 | -------------------------------------------------------------------------------- /workflow/split_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import random 5 | 6 | 7 | def split_symlink_train_test_merge_dataset(root_list, t_root, num_traj_per_task): 8 | target_train_root = t_root + f"_train" 9 | target_test_root = t_root + f"_test" 10 | print( 11 | f"split data points from", 12 | root_list, 13 | f"to {target_train_root}: train and {target_test_root}: test", 14 | ) 15 | if os.path.exists(target_train_root): 16 | return 17 | assert not os.path.exists(target_test_root) 18 | 19 | os.makedirs(target_train_root) 20 | os.makedirs(target_test_root) 21 | 22 | all_paths = [] 23 | for root in sorted(root_list): 24 | for path in sorted(os.listdir(root))[:num_traj_per_task]: 25 | if not os.path.isfile(os.path.join(root, path)): 26 | d = path 27 | if ( 28 | 
d.endswith("failed") 29 | or d.endswith("ood") 30 | or d.endswith("ikbad") 31 | or d.endswith("heated") 32 | or d.endswith("stop") 33 | or d.endswith("hard") 34 | ): 35 | continue 36 | all_paths.append((root, path)) 37 | 38 | train_paths = all_paths[:-1] 39 | test_paths = all_paths[-1:] 40 | 41 | for root, sl_path in train_paths: 42 | src_path = os.path.abspath(os.path.join(root, sl_path)) 43 | tgt_path = os.path.abspath(os.path.join(target_train_root, sl_path)) 44 | print("\rlinking", src_path, tgt_path, end="") 45 | os.symlink(src_path, tgt_path) 46 | 47 | for root, sl_path in test_paths: 48 | src_path = os.path.abspath(os.path.join(root, sl_path)) 49 | tgt_path = os.path.abspath(os.path.join(target_test_root, sl_path)) 50 | print("\rlinking", src_path, tgt_path, end="") 51 | os.symlink(src_path, tgt_path) 52 | 53 | print() 54 | print("Done!!") 55 | 56 | 57 | def split_symlink_train_test_dataset(root, t_root): 58 | target_train_root = t_root + f"_train" 59 | target_test_root = t_root + f"_test" 60 | print( 61 | f"split data points from {root} to {target_train_root}: train and {target_test_root}: test" 62 | ) 63 | 64 | if os.path.exists(target_train_root): 65 | return 66 | assert not os.path.exists(target_test_root) 67 | 68 | os.makedirs(target_train_root) 69 | os.makedirs(target_test_root) 70 | 71 | all_paths = [] 72 | for path in sorted(os.listdir(root)): 73 | if not os.path.isfile(os.path.join(root, path)): 74 | d = path 75 | if ( 76 | d.endswith("failed") 77 | or d.endswith("ood") 78 | or d.endswith("ikbad") 79 | or d.endswith("heated") 80 | or d.endswith("stop") 81 | or d.endswith("hard") 82 | ): 83 | continue 84 | all_paths.append(path) 85 | 86 | train_paths = all_paths[:-1] 87 | test_paths = all_paths[-1:] 88 | 89 | for sl_path in train_paths: 90 | src_path = os.path.abspath(os.path.join(root, sl_path)) 91 | tgt_path = os.path.abspath(os.path.join(target_train_root, sl_path)) 92 | print("\rlinking", src_path, tgt_path, end="") 93 | os.symlink(src_path, 
tgt_path) 94 | for sl_path in test_paths: 95 | src_path = os.path.abspath(os.path.join(root, sl_path)) 96 | tgt_path = os.path.abspath(os.path.join(target_test_root, sl_path)) 97 | print("\rlinking", src_path, tgt_path, end="") 98 | os.symlink(src_path, tgt_path) 99 | print() 100 | print("Done!!") 101 | 102 | 103 | def split_symlink_dataset(root, num_trajs): 104 | target_root = root + f"_{num_trajs}" 105 | print(f"split {num_trajs} data points from {root} to {target_root}") 106 | 107 | if os.path.exists(target_root): 108 | return 109 | assert not os.path.exists(target_root) 110 | 111 | os.makedirs(target_root) 112 | 113 | all_paths = [] 114 | for path in os.listdir(root): 115 | if not os.path.isfile(os.path.join(root, path)): 116 | all_paths.append(path) 117 | 118 | random.seed(0) 119 | random.shuffle(all_paths) 120 | 121 | sl_paths = all_paths[:num_trajs] 122 | 123 | assert len(all_paths) >= num_trajs 124 | for sl_path in sl_paths: 125 | src_path = os.path.abspath(os.path.join(root, sl_path)) 126 | tgt_path = os.path.abspath(os.path.join(target_root, sl_path)) 127 | print("\rlinking", src_path, tgt_path, end="") 128 | os.symlink(src_path, tgt_path) 129 | print() 130 | print("Done!!") 131 | 132 | 133 | if __name__ == "__main__": 134 | arg = argparse.ArgumentParser() 135 | arg.add_argument("--base_path", type=str, default="/hato/") 136 | arg.add_argument("--output_path", type=str, default="/split_data") 137 | arg.add_argument( 138 | "--data_name", 139 | nargs="+", 140 | type=str, 141 | default=[ 142 | "data_banana", 143 | ], 144 | ) 145 | arg.add_argument("--num_trajs", nargs="+", type=int, default=[10, 25, 50, 75]) 146 | arg.add_argument("--merge", action="store_true") 147 | arg.add_argument("--merge_name", type=str, default="data_banana_all") 148 | arg.add_argument("--num_traj_per_task", type=int, default=20) 149 | args = arg.parse_args() 150 | 151 | if not args.merge: 152 | for data_name in args.data_name: 153 | split_symlink_train_test_dataset( 154 | 
os.path.join(args.base_path, data_name), 155 | os.path.join(args.output_path, data_name), 156 | ) 157 | for num_trajs in args.num_trajs: 158 | split_symlink_dataset( 159 | os.path.join(args.output_path, data_name) + "_train", num_trajs 160 | ) 161 | 162 | else: 163 | data_name_list = [ 164 | os.path.join(args.base_path, data_name) for data_name in args.data_name 165 | ] 166 | split_symlink_train_test_merge_dataset( 167 | data_name_list, 168 | os.path.join(args.output_path, args.merge_name), 169 | args.num_traj_per_task, 170 | ) 171 | --------------------------------------------------------------------------------