├── .gitignore
├── README.md
├── __init__.py
├── agents
│   ├── __init__.py
│   ├── agent.py
│   ├── dp_agent.py
│   ├── dp_agent_zmq.py
│   ├── quest_agent.py
│   ├── quest_agent_eef.py
│   └── universal_robots_ur5e
│       ├── LICENSE
│       ├── README.md
│       ├── assets
│       │   ├── base_0.obj
│       │   ├── base_1.obj
│       │   ├── forearm_0.obj
│       │   ├── forearm_1.obj
│       │   ├── forearm_2.obj
│       │   ├── forearm_3.obj
│       │   ├── shoulder_0.obj
│       │   ├── shoulder_1.obj
│       │   ├── shoulder_2.obj
│       │   ├── upperarm_0.obj
│       │   ├── upperarm_1.obj
│       │   ├── upperarm_2.obj
│       │   ├── upperarm_3.obj
│       │   ├── wrist1_0.obj
│       │   ├── wrist1_1.obj
│       │   ├── wrist1_2.obj
│       │   ├── wrist2_0.obj
│       │   ├── wrist2_1.obj
│       │   ├── wrist2_2.obj
│       │   └── wrist3.obj
│       ├── scene.xml
│       ├── ur5e.png
│       └── ur5e.xml
├── camera_node.py
├── cameras
│   ├── camera.py
│   └── realsense_camera.py
├── env.py
├── eval_dir.py
├── inference_node.py
├── launch_inference_nodes.py
├── launch_nodes.py
├── learning
│   └── dp
│       ├── .gitignore
│       ├── data_processing.py
│       ├── dataset.py
│       ├── learner.py
│       ├── models.py
│       ├── pipeline.py
│       └── utils.py
├── requirements.txt
├── robot_node.py
├── robots
│   ├── ability_gripper.py
│   ├── robot.py
│   ├── robotiq_gripper.py
│   └── ur.py
├── run_env.py
├── run_openloop.py
├── test_dp_agent.py
├── test_dp_agent_zmq.py
└── workflow
    ├── create_eval.py
    ├── download_dataset.sh
    ├── gen_deploy_scripts.py
    └── split_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | */__pycache__/
2 | *.pyc
3 | *.pkl
4 | *.pth
5 | *.zip
6 | *.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🕊️ HATO: Learning Visuotactile Skills with Two Multifingered Hands
2 |
3 | [[Project](https://toruowo.github.io/hato/)]
4 | [[Paper](https://arxiv.org/abs/2404.16823)]
5 |
6 | [Toru Lin](https://toruowo.github.io/),
7 | [Yu Zhang*](),
8 | [Qiyang Li*](https://colinqiyangli.github.io/),
9 | [Haozhi Qi*](https://haozhi.io/),
10 | [Brent Yi](https://scholar.google.com/citations?user=Ecy6lXwAAAAJ&hl=en),
11 | [Sergey Levine](https://people.eecs.berkeley.edu/~svlevine/),
12 | [Jitendra Malik](https://people.eecs.berkeley.edu/~malik/)
13 |
14 |
15 | ## Overview
16 |
17 | This repo contains code, datasets, and instructions to support the following use cases:
18 | - Collecting Demonstration Data
19 | - Training and Evaluating Diffusion Policies
20 | - Deploying Policies on Hardware
21 |
22 | ## Installation
23 |
24 | ```
25 | conda create -n hato python=3.9
26 | conda activate hato
27 | conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
28 | pip install -r ./requirements.txt
29 | ```
30 |
31 | ## Collecting Demonstration Data
32 |
33 |
34 | This repo supports Meta Quest 2 as the teleoperation device. To start, install [oculus_reader](https://github.com/rail-berkeley/oculus_reader/blob/main/oculus_reader/reader.py) by following instructions in the link.
35 |
36 | We use [ZMQ](https://zeromq.org/) to handle communication between hardware components.
37 | Our data collection code is structured in the following way (credit to [gello_software](https://github.com/wuphilipp/gello_software) for the clean and modular template):
38 |
39 |
40 | | Directory / File | Detail |
41 | | :-------------: |:-------------:|
42 | | agents/ | contains Agent classes that generate low-level hardware control commands from teleoperation device or diffusion policy |
43 | | cameras/ | contains Camera classes that provide utilities to obtain real-time camera data |
44 | | robots/ | contains Robot classes that interface between Python code and physical hardware to read observations or send low-level control commands |
45 | | *_node.py | contains ZMQ node classes for camera / robot / policy |
46 | | env.py | contains environment classes that organize the teleoperation and data collection process |
47 | | launch_nodes.py | script to launch robot hardware and sensor ZMQ nodes |
48 | | run_env.py | script to start the teleoperation and data collection process |
49 |
50 | (Code files not mentioned contain utilities for training and evaluating diffusion policies, and deploying policies on hardware; please see the next two sections for more details.)
51 |
52 | Currently available classes:
53 | - `agents/quest_agent.py`, `agents/quest_agent_eef.py`: support using Meta Quest 2 for hardware teleoperation, in either joint-space or end-effector-space control mode
54 | - `agents/dp_agent.py`, `agents/dp_agent_zmq.py`: support using learned diffusion policies to provide control commands, either synchronously or asynchronously (see "Deploying Policies on Hardware" section for more details)
55 | - `cameras/realsense_camera.py`: supports reading and preprocessing RGB-D data from RealSense cameras
56 | - `robots/ur.py`, `robots/ability_gripper.py`, `robots/robotiq_gripper.py`: support hardware setup with a single UR5e arm or two UR5e arms, using Ability Hand or Robotiq gripper as the end-effector(s)
57 |
58 | These classes can be flexibly modified or extended to support other teleoperation devices / learning pipelines / cameras / robot hardware.
59 |
60 | Example usage (note that `launch_nodes.py` and `run_env.py` should be run simultaneously in two separate windows):
61 |
62 | ```
63 | # to collect data with two UR5e arms and Ability hands at 10Hz
64 | python launch_nodes.py --robot bimanual_ur --hand_type ability
65 | python run_env.py --agent quest_hand --no-show-camera-view --hz 10 --save_data
66 |
67 | # to collect data with two UR5e arms and Ability hands at 10Hz, showing camera view during data collection
68 | python launch_nodes.py --robot bimanual_ur --hand_type ability
69 | python run_env.py --agent quest_hand --hz 10 --save_data
70 | ```
71 | Node processes can be cleaned up by running `pkill -9 -f launch_nodes.py`.
72 |
73 | ## Training and Evaluating Diffusion Policies
74 |
75 | ### Download Datasets
76 |
77 | [[Data Files](https://berkeley.app.box.com/s/379cf57zqm1akvr00vdcloxqxi3ucb9g?sortColumn=name&sortDirection=ASC)]
78 |
79 | The linked data folder contains datasets for the four tasks featured in [the HATO paper](https://arxiv.org/abs/2404.16823): `banana`, `stack`, `pour`, and `steak`.
80 |
81 | Full dataset files can be unzipped using the `unzip` command.
82 | Note that the `pour` and `steak` datasets are split into part files because of the large file size. Before unzipping, the part files should be first concatenated back into single files using the following commands:
83 | ```
84 | cat data_pour_part_0* > data_pour.zip
85 | cat data_steak_part_0* > data_steak.zip
86 | ```
87 |
88 | Scripts to download and concatenate the datasets can be found in `workflow/download_dataset.sh`.
89 |
90 | ### Run Training
91 |
92 | 1. Run `python workflow/split_data.py --base_path Traj_Folder_Path --output_path Output_Folder_Path --data_name Data_Name --num_trajs N1 N2 N3` to split the data into train and validation sets. Number of trajectories used can be specified via the `num_trajs` argument.
93 | 2. Run `python ./learning/dp/pipeline.py --data_path Split_Folder_Path/Data_Name --model_save_path Model_Path` to train the model, where
94 |    - `--data_path` is the split trajectory folder, i.e. `output_path` + `data_name` from step 1 (`data_name` should not include a suffix such as `_train` or `_train_10`)
95 |    - `--model_save_path` is the path where the model is saved
96 |
97 | Important Training Arguments
98 | 1. `--batch_size` : the batch size for training.
99 | 2. `--num_epochs` : the number of epochs for training.
100 | 3. `--representation_type`: the data representation type for the model. Format: `repr1--repr2--...`. Repr can be `eef`, `img`, `depth`, `touch`, `pos`, `hand_pos`
101 | 4. `--camera_indices`: the camera indices to use for the image data modality. Format: `01`,`012`,`02`, etc.
102 | 5. `--train_suffix`: the suffix for the training folder. This is useful when you want to train the model on different data splits and should be used with the `--num_trajs` arg of `split_data.py`. Format: `_10`, `_50`, etc.
103 | 6. `--load_img`: whether to load all the images into memory. If set to `True`, the training will be faster but will consume more memory.
104 | 7. `--use_memmap_cache`: whether to use memmap cache for the images. If set to `True`, it will create a memmap file in the training folder to accelerate the data loading.
105 | 8. `--use_wandb`: whether to use wandb for logging.
106 | 9. `--wandb_project_name`: the wandb project name.
107 | 10. `--wandb_entity_name`: the wandb entity name.
108 | 11. `--load_path`: the path to a saved model. If set, the model is loaded from this path and training continues from it. This should be the path of the non-EMA model.
109 |
110 |
111 | ### Run Evaluation
112 |
113 | Run `python ./eval_dir.py --eval_dir Traj_Folder_Path --ckpt_path Model_Path_1 Model_Path_2` to evaluate multiple models on all trajectories in the folder.
114 |
115 |
116 | ## Deploying Policies on Hardware
117 |
118 | A set of useful bash scripts can be generated using the following command:
119 |
120 | ```python workflow/gen_deploy_scripts.py -c [ckpt_folder]```
121 |
122 | where `ckpt_folder` is a path that contains one or more checkpoints resulting from the training above. The generated bash scripts provide the following functionalities.
123 |
124 | - To run policy deployment with the asynchronous setup (see Section IV-D in the paper for more details), first run `*_inference.sh` to launch the inference server, then run `*_node.sh` and `*_env_jit.sh` in separate terminal windows to deploy the robot with the inference server.
125 |
126 | - To run policy deployment without the asynchronous setup, run `*_node.sh` and `*_env.sh` in separate terminal windows.
127 |
128 | - To run policy evaluation for individual checkpoints, run `*_test.sh`. Note that the path to a folder containing trajectories for evaluation needs to be specified in addition to the model checkpoint path.
129 |
130 | - To run an open-loop policy test, run `*_openloop.sh`. Note that the path to a demonstration data trajectory needs to be specified in addition to the model checkpoint path.
131 |
132 |
133 | ## Acknowledgement
134 |
135 | This project was developed with help from the following codebases.
136 |
137 | - [ability-hand-api](https://github.com/psyonicinc/ability-hand-api/tree/master/python)
138 | - [diffusion_policy](https://github.com/real-stanford/diffusion_policy)
139 | - [gello_software](https://github.com/wuphilipp/gello_software/tree/main)
140 | - [mujoco_menagerie](https://github.com/google-deepmind/mujoco_menagerie/blob/main/universal_robots_ur5e/ur5e.xml)
141 | - [oculus_reader](https://github.com/rail-berkeley/oculus_reader)
142 |
143 | ## Reference
144 |
145 | If you find HATO or this codebase helpful in your research, please consider citing:
146 |
147 | ```
148 | @article{lin2024learning,
149 | author={Lin, Toru and Zhang, Yu and Li, Qiyang and Qi, Haozhi and Yi, Brent and Levine, Sergey and Malik, Jitendra},
150 | title={Learning Visuotactile Skills with Two Multifingered Hands},
151 | journal={arXiv:2404.16823},
152 | year={2024}
153 | }
154 | ```
155 |
--------------------------------------------------------------------------------
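The README's dataset section rejoins the split `pour` and `steak` archives with `cat` before unzipping. The sketch below shows the same step in Python, which may help on systems without `cat`. The helper name `concat_parts` is hypothetical and not part of this repo, and the sketch assumes the part files sort lexicographically into the correct order, as the `data_pour_part_0*` glob in the README implies.

```python
import glob
import shutil


def concat_parts(pattern: str, output_path: str) -> list:
    """Concatenate files matching `pattern` (in sorted order) into `output_path`.

    Hypothetical helper mirroring `cat data_pour_part_0* > data_pour.zip`.
    Returns the list of part files that were concatenated.
    """
    parts = sorted(glob.glob(pattern))
    if not parts:
        raise FileNotFoundError(f"no files match {pattern!r}")
    with open(output_path, "wb") as out:
        for part in parts:
            # Stream each part so large archives are never held fully in memory.
            with open(part, "rb") as src:
                shutil.copyfileobj(src, out)
    return parts
```

A call such as `concat_parts("data_pour_part_0*", "data_pour.zip")` would then be followed by the usual `unzip` step.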
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToruOwO/hato/3b6f4496e11a386df8f362f398b36e832944b2b8/__init__.py
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToruOwO/hato/3b6f4496e11a386df8f362f398b36e832944b2b8/agents/__init__.py
--------------------------------------------------------------------------------
/agents/agent.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Protocol
2 |
3 | import numpy as np
4 |
5 |
6 | class Agent(Protocol):
7 |     def act(self, obs: Dict[str, Any]) -> np.ndarray:
8 |         """Returns an action given an observation.
9 |
10 |         Args:
11 |             obs: observation from the environment.
12 |
13 |         Returns:
14 |             action: action to take on the environment.
15 |         """
16 |         raise NotImplementedError
17 |
18 |
19 | class BimanualAgent(Agent):
20 |     def __init__(self, agent_left: Agent, agent_right: Agent):
21 |         self.agent_left = agent_left
22 |         self.agent_right = agent_right
23 |
24 |     def act(self, obs: Dict[str, Any]) -> np.ndarray:
25 |         left_obs = {}
26 |         right_obs = {}
27 |         for key, val in obs.items():
28 |             L = val.shape[0]
29 |             half_dim = L // 2
30 |             if key.endswith("rgb") or key.endswith("depth"):
31 |                 left_obs[key] = val
32 |                 right_obs[key] = val
33 |             else:
34 |                 assert L == half_dim * 2, f"{key} must be even, something is wrong"
35 |                 left_obs[key] = val[:half_dim]
36 |                 right_obs[key] = val[half_dim:]
37 |         return np.concatenate(
38 |             [self.agent_left.act(left_obs), self.agent_right.act(right_obs)]
39 |         )
40 |
41 |
42 | class SafetyWrapper:
43 |     def __init__(self, ur_idx, hand_idx, agent, delta=0.5, hand_delta=0.1):
44 |         self.ur_idx = ur_idx
45 |         self.hand_idx = hand_idx
46 |         self.agent = agent
47 |         self.delta = delta
48 |         self.hand_delta = hand_delta
49 |
50 |         # Ability Hand ranges
51 |         self.num_hand_dofs = 12
52 |         self.upper_ranges = [110, 110, 110, 110, 90, 120] * 2
53 |         self.lower_ranges = [5] * self.num_hand_dofs
54 |
55 |     def _hand_pos_to_cmd(self, pos):
56 |         """
57 |         pos: desired hand degrees for Ability Hands
58 |         """
59 |         assert len(pos) == self.num_hand_dofs
60 |         cmd = [0] * self.num_hand_dofs
61 |         for i in range(self.num_hand_dofs):
62 |             if i in [5, 11]:
63 |                 pos[i] = -pos[i]
64 |             cmd[i] = (pos[i] - self.lower_ranges[i]) / (
65 |                 self.upper_ranges[i] - self.lower_ranges[i]
66 |             )
67 |         return cmd
68 |
69 |     def act_safe(self, agent, obs, eef=False):
70 |         joints = obs["joint_positions"]
71 |         action = agent.act(obs)
72 |         if eef:
73 |             eef_pose = obs["ee_pos_quat"]
74 |             left_eef_pos = eef_pose[:3]
75 |             right_eef_pos = eef_pose[6:9]
76 |             left_eef_target = action[:3]
77 |             right_eef_target = action[12:15]
78 |             if np.linalg.norm(left_eef_pos - left_eef_target) > 0.5:
79 |                 print("Left EEF action is too big")
80 |                 print(
81 |                     f"Left EEF pos: {left_eef_pos}, target: {left_eef_target}, diff: {left_eef_pos - left_eef_target}"
82 |                 )
83 |             if np.linalg.norm(right_eef_pos - right_eef_target) > 0.5:
84 |                 print("Right EEF action is too big")
85 |                 print(
86 |                     f"Right EEF pos: {right_eef_pos}, target: {right_eef_target}, diff: {right_eef_pos - right_eef_target}"
87 |                 )
88 |
89 |             left_eef_target = np.clip(
90 |                 left_eef_target,
91 |                 left_eef_pos - self.delta,
92 |                 left_eef_pos + self.delta,
93 |             )
94 |             right_eef_target = np.clip(
95 |                 right_eef_target,
96 |                 right_eef_pos - self.delta,
97 |                 right_eef_pos + self.delta,
98 |             )
99 |             action[:3] = left_eef_target
100 |             action[12:15] = right_eef_target
101 |         else:
102 |             # check if action is too big
103 |             if (np.abs(action[self.ur_idx] - joints[self.ur_idx]) > self.delta).any():
104 |                 print("Action is too big")
105 |
106 |                 # print which joints are too big
107 |                 joint_index = np.where(np.abs(action - joints) > self.delta)[0]
108 |                 for j in joint_index:
109 |                     if j in self.ur_idx:
110 |                         print(
111 |                             f"Joint [{j}], leader: {action[j]}, follower: {joints[j]}, diff: {action[j] - joints[j]}"
112 |                         )
113 |             action[self.ur_idx] = np.clip(
114 |                 action[self.ur_idx],
115 |                 joints[self.ur_idx] - self.delta,
116 |                 joints[self.ur_idx] + self.delta,
117 |             )
118 |
119 |         if self.hand_idx is not None:
120 |             joint_cmd = self._hand_pos_to_cmd(joints[self.hand_idx])
121 |             action[self.hand_idx] = joint_cmd + np.clip(
122 |                 action[self.hand_idx] - joint_cmd, -self.hand_delta, self.hand_delta
123 |             )
124 |             action[self.hand_idx] = np.clip(action[self.hand_idx], 0, 1)
125 |         return action
126 |
--------------------------------------------------------------------------------
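`SafetyWrapper.act_safe` in `agents/agent.py` limits how far a commanded arm joint may move from its current position in a single step, and `_hand_pos_to_cmd` maps Ability Hand joint angles in degrees to normalized commands. The dependency-free sketch below reproduces that arithmetic on plain lists so it can be checked by hand. The names `clip_to_delta` and `hand_deg_to_cmd` are hypothetical, the arm values are assumed to be in radians, and the sketch leaves out the sign flip that the repo applies to hand indices 5 and 11.

```python
def clip_to_delta(target, current, delta):
    """Clamp each target to within +/- delta of the current value (like np.clip)."""
    return [min(max(t, c - delta), c + delta) for t, c in zip(target, current)]


def hand_deg_to_cmd(pos_deg, lower, upper):
    """Normalize joint angles (degrees) to commands: (pos - lower) / (upper - lower)."""
    return [(p - lo) / (hi - lo) for p, lo, hi in zip(pos_deg, lower, upper)]


# Arm joints: a 1.2 jump is limited to the wrapper's default delta of 0.5,
# while a small -0.1 move passes through unchanged.
print(clip_to_delta([1.2, -0.1], [0.0, 0.0], delta=0.5))  # [0.5, -0.1]

# Finger joints with range [5, 110] degrees: 5 -> 0.0, 57.5 -> 0.5, 110 -> 1.0.
print(hand_deg_to_cmd([5, 57.5, 110], [5, 5, 5], [110, 110, 110]))  # [0.0, 0.5, 1.0]
```

Clamping relative to the current state, rather than to fixed joint limits, bounds the per-step motion no matter where the arm currently is.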
/agents/dp_agent.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | import os
4 | from typing import Any, Dict
5 |
6 | import numpy as np
7 | import quaternion
8 | import torch
9 |
10 | from learning.dp.pipeline import Agent as DPAgent
11 |
12 |
13 | def from_numpy(data, device, unsqueeze=True):
14 |     return {
15 |         key: torch.from_numpy(value).to(device=device)[None]
16 |         if unsqueeze
17 |         else torch.from_numpy(value).to(device=device)
18 |         for key, value in data.items()
19 |         if key != "activated"
20 |     }
21 |
22 |
23 | UR_IDX = list(range(6)) + list(range(12, 18))
24 | LEFT_HAND_IDX = list(range(6, 12))
25 | RIGHT_HAND_IDX = list(range(18, 24))
26 |
27 |
28 | def get_reset_joints(ur_eef=False):
29 |     if ur_eef:
30 |         # these are EEF pose
31 |         arm_joints_left = [
32 |             -0.10244499252760966,
33 |             -0.7492784625293504,
34 |             0.14209881911585326,
35 |             -0.3622358797572402,
36 |             -1.4347279978985925,
37 |             0.8691789808786153,
38 |         ]
39 |
40 |         arm_joints_right = [
41 |             0.2313341406775527,
42 |             -0.7512396951283128,
43 |             0.06337444935928868,
44 |             0.6512089317940273,
45 |             1.3246649009637026,
46 |             0.5471978290474188,
47 |         ]
48 |     else:
49 |         arm_joints_left = [-80, -140, -80, -85, -10, 80]
50 |         arm_joints_right = [-270, -30, 70, -85, 10, 0]
51 |     hand_joints = [0, 0, 0, 0, 0.5, 0.5]
52 |     reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
53 |     reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
54 |     reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
55 |     return reset_joints
56 |
57 |
58 | def get_eef_pose(eef_pose, eef_delta):
59 |     pos_delta = eef_delta[:3]
60 |     rot_delta = eef_delta[3:]
61 |     pos = eef_pose[:3] + pos_delta
62 |     # quaternion multiplication
63 |     rot = quaternion.as_rotation_vector(
64 |         quaternion.from_rotation_vector(rot_delta)
65 |         * quaternion.from_rotation_vector(eef_pose[3:])
66 |     )
67 |     return np.concatenate((pos, rot))
68 |
69 |
70 | def parse_txt_to_json(input_file_path, output_file_path):
71 |     # Initialize an empty dictionary to store the parsed key-value pairs
72 |     data = {}
73 |
74 |     # Open and read the input text file line by line
75 |     with open(input_file_path, "r") as file:
76 |         for line in file:
77 |             kv = line.strip().split(": ", 1)
78 |             if len(kv) != 2:
79 |                 continue
80 |             key, value = kv
81 |             if key == "camera_indices":
82 |                 data[key] = list(map(int, value))
83 |             elif key == "representation_type":
84 |                 data[key] = value.split("-")
85 |             else:
86 |                 try:
87 |                     value = int(value)
88 |                 except ValueError:
89 |                     try:
90 |                         value = float(value)
91 |                     except ValueError:
92 |                         if value == "True":
93 |                             value = True
94 |                         elif value == "False":
95 |                             value = False
96 |                         elif value == "None":
97 |                             value = None
98 |
99 |                 data[key] = value
100 |
101 |     with open(output_file_path, "w") as json_file:
102 |         json.dump(data, json_file, indent=4)
103 |     return data
104 |
105 |
106 | class BimanualDPAgent:
107 |     def __init__(
108 |         self,
109 |         ckpt_path,
110 |         dp_args=None,
111 |         binarize_finger_action=False,
112 |     ):
113 |         if dp_args is None:
114 |             dp_args = self.get_default_dp_args()
115 |
116 |         # rewrite dp_args based on saved args
117 |         args_txt = os.path.join(os.path.dirname(ckpt_path), "args_log.txt")
118 |         args_json = os.path.join(os.path.dirname(ckpt_path), "args_log.json")
119 |         args = parse_txt_to_json(args_txt, args_json)
120 |         for k in dp_args.keys():
121 |             if k == "output_sizes":
122 |                 dp_args[k]["img"] = args["image_output_size"]
123 |             else:
124 |                 if k in args:
125 |                     dp_args[k] = args[k]
126 |
127 |         # save dp args in ckpt path as json
128 |         ckpt_dir = os.path.dirname(ckpt_path)
129 |         args_path = os.path.join(ckpt_dir, "dp_args.json")
130 |         if os.path.exists(args_path):
131 |             with open(args_path, "r") as f:
132 |                 dp_args = json.load(f)
133 |         else:
134 |             with open(args_path, "w") as f:
135 |                 json.dump(dp_args, f)
136 |
137 |         torch.cuda.set_device(0)
138 |         self.dp = DPAgent(
139 |             output_sizes=dp_args["output_sizes"],
140 |             representation_type=dp_args["representation_type"],
141 |             identity_encoder=dp_args["identity_encoder"],
142 |             camera_indices=dp_args["camera_indices"],
143 |             pred_horizon=dp_args["pred_horizon"],
144 |             obs_horizon=dp_args["obs_horizon"],
145 |             action_horizon=dp_args["action_horizon"],
146 |             without_sampling=dp_args["without_sampling"],
147 |             predict_eef_delta=dp_args["predict_eef_delta"],
148 |             predict_pos_delta=dp_args["predict_pos_delta"],
149 |             use_ddim=dp_args["use_ddim"],
150 |         )
151 |         self.dp_args = dp_args
152 |         self.obsque = collections.deque(maxlen=dp_args["obs_horizon"])
153 |         self.dp.load(ckpt_path)
154 |         self.action_queue = collections.deque(maxlen=dp_args["action_horizon"])
155 |         self.max_length = 100
156 |         self.count = 0
157 |         self.except_thumb_hand_indices = np.array([6, 7, 8, 9, 18, 19, 20, 21])
158 |         self.binaraize_finger_action = binarize_finger_action
159 |         self.clip_far = dp_args["clip_far"]
160 |         self.predict_eef_delta = dp_args["predict_eef_delta"]
161 |         self.predict_pos_delta = dp_args["predict_pos_delta"]
162 |         assert not (self.predict_eef_delta and self.predict_pos_delta)
163 |         self.control = get_reset_joints(ur_eef=self.predict_eef_delta)
164 |
165 |         self.num_diffusion_iters = dp_args["num_diffusion_iters"]
166 |
167 |         self.hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
168 |         self.hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])
169 |
170 |         # TODO: remove hack
171 |         self.hand_new_uppers = np.array([75] * 4 + [90.0, 120.0])
172 |
173 |         self.trigger_state = {"l": True, "r": True}
174 |
175 |     @staticmethod
176 |     def get_default_dp_args():
177 |         return {
178 |             "output_sizes": {
179 |                 "eef": 64,
180 |                 "hand_pos": 64,
181 |                 "img": 128,
182 |                 "pos": 128,
183 |                 "touch": 64,
184 |             },
185 |             "representation_type": ["img", "pos", "touch", "depth"],
186 |             "identity_encoder": False,
187 |             "camera_indices": [0, 1, 2],
188 |             "obs_horizon": 4,
189 |             "pred_horizon": 16,
190 |             "action_horizon": 8,
191 |             "num_diffusion_iters": 15,
192 |             "without_sampling": False,
193 |             "clip_far": False,
194 |             "predict_eef_delta": False,
195 |             "predict_pos_delta": False,
196 |             "use_ddim": False,
197 |         }
198 |
199 |     def compile_inference(self, example_obs, precision="high", num_inference_iters=5):
200 |         torch.set_float32_matmul_precision(precision)
201 |         self.dp.policy.forward = torch.compile(torch.no_grad(self.dp.policy.forward))
202 |         self.num_diffusion_iters = num_inference_iters
203 |
204 |         for i in range(25):  # burn in
205 |             self.act(example_obs)
206 |
207 |     def act(self, obs: Dict[str, Any]) -> np.ndarray:
208 |         curr_joint_pos = obs["joint_positions"]
209 |         curr_eef_pose = obs["ee_pos_quat"]
210 |         obs = self.dp.get_observation([obs], load_img=True)
211 |         if "img" in obs:
212 |             obs["img"] = self.dp.eval_transform(obs["img"].squeeze(0))
213 |
214 |         # if obsque is empty, fill it with the current observation
215 |         if len(self.obsque) == 0:
216 |             self.obsque.extend([obs] * self.dp_args["obs_horizon"])
217 |         else:
218 |             self.obsque.append(obs)
219 |
220 |         # if action queue is not empty, return the first action in the queue
221 |         if len(self.action_queue) > 0:
222 |             act = self.action_queue.popleft()
223 |
224 |         # if action queue is empty, predict new actions
225 |         else:
226 |             pred = self.dp.predict(
227 |                 self.obsque, num_diffusion_iters=self.num_diffusion_iters
228 |             )
229 |             for i in range(self.dp_args["action_horizon"]):
230 |                 act = pred[i]
231 |                 self.action_queue.append(act)
232 |
233 |             act = self.action_queue.popleft()
234 |
235 |         if self.predict_pos_delta:
236 |             self.control[UR_IDX] = curr_joint_pos[UR_IDX]
237 |             self.control = self.control + act
238 |             act = self.control
239 |             # act = curr_joint_pos + act
240 |
241 |         if self.predict_eef_delta:
242 |             left_arm_act = get_eef_pose(curr_eef_pose[:6], act[:6])
243 |             left_hand_act = act[6:12]
244 |             right_arm_act = get_eef_pose(curr_eef_pose[6:], act[12:18])
245 |             right_hand_act = act[18:24]
246 |             act = np.concatenate(
247 |                 [left_arm_act, left_hand_act, right_arm_act, right_hand_act],
248 |                 axis=-1,
249 |             )
250 |
251 |         # if binarize_finger_action is True, binarize the finger action
252 |
253 |         if self.binaraize_finger_action:
254 |             mean_act = np.mean(act[self.except_thumb_hand_indices])
255 |             if mean_act > 0.5:
256 |                 act[self.except_thumb_hand_indices] = 1.0
257 |             else:
258 |                 act[self.except_thumb_hand_indices] = 0.0
259 |         else:
260 |             left_hand = (
261 |                 act[LEFT_HAND_IDX] * (self.hand_uppers - self.hand_lowers)
262 |                 + self.hand_lowers
263 |             )
264 |             act[LEFT_HAND_IDX] = (left_hand - self.hand_lowers) / (
265 |                 self.hand_new_uppers - self.hand_lowers
266 |             )
267 |             right_hand = (
268 |                 act[RIGHT_HAND_IDX] * (self.hand_uppers - self.hand_lowers)
269 |                 + self.hand_lowers
270 |             )
271 |             act[RIGHT_HAND_IDX] = (right_hand - self.hand_lowers) / (
272 |                 self.hand_new_uppers - self.hand_lowers
273 |             )
274 |             act[list(range(6, 12)) + list(range(18, 24))] = np.clip(
275 |                 act[list(range(6, 12)) + list(range(18, 24))], 0, 1
276 |             )
277 |
278 |         return act
279 |
--------------------------------------------------------------------------------
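`BimanualDPAgent.act` in `agents/dp_agent.py` runs the diffusion policy only when its action queue is empty: each prediction yields a chunk of actions, the first `action_horizon` of them are queued, and later calls pop from the queue until it drains. The sketch below isolates that queueing pattern with a dummy predictor. `ChunkedPolicy` and `predict_fn` are hypothetical names rather than repo APIs, and the observation history, delta handling and hand rescaling from the real method are left out.

```python
import collections


class ChunkedPolicy:
    """Hypothetical stand-in for the action-queue logic in BimanualDPAgent.act."""

    def __init__(self, predict_fn, action_horizon):
        self.predict_fn = predict_fn  # returns a sequence of >= action_horizon actions
        self.action_horizon = action_horizon
        self.action_queue = collections.deque(maxlen=action_horizon)
        self.num_predictions = 0

    def act(self, obs):
        # Re-plan only once every queued action has been executed.
        if not self.action_queue:
            pred = self.predict_fn(obs)
            self.num_predictions += 1
            self.action_queue.extend(pred[: self.action_horizon])
        return self.action_queue.popleft()


# Dummy predictor: a chunk of 16 actions, each tagged with the observation it came from.
policy = ChunkedPolicy(lambda obs: [(obs, i) for i in range(16)], action_horizon=8)
actions = [policy.act(step) for step in range(10)]
print(actions[0], actions[7], actions[8])  # (0, 0) (0, 7) (8, 0)
print(policy.num_predictions)  # 2
```

With `pred_horizon=16` and `action_horizon=8` (the repo's defaults), ten control steps need only two policy evaluations, which is what keeps the per-step cost low between re-planning points.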
/agents/dp_agent_zmq.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | import os
4 | from typing import Any, Dict
5 |
6 | import numpy as np
7 | import quaternion
8 | import torch
9 | import time
10 |
11 | import pickle
12 |
13 | from learning.dp.pipeline import Agent as DPAgent
14 | from inference_node import (
15 |     ZMQInferenceClient,
16 |     ZMQInferenceServer,
17 |     DEFAULT_INFERENCE_PORT,
18 | )
19 |
20 |
21 | def from_numpy(data, device, unsqueeze=True):
22 |     return {
23 |         key: torch.from_numpy(value).to(device=device)[None]
24 |         if unsqueeze
25 |         else torch.from_numpy(value).to(device=device)
26 |         for key, value in data.items()
27 |         if key != "activated"
28 |     }
29 |
30 |
31 | UR_IDX = list(range(6)) + list(range(12, 18))
32 | LEFT_HAND_IDX = list(range(6, 12))
33 | RIGHT_HAND_IDX = list(range(18, 24))
34 |
35 |
36 | def get_reset_joints(ur_eef=False):
37 |     if ur_eef:
38 |         # these are EEF pose
39 |         arm_joints_left = [
40 |             -0.10244499252760966,
41 |             -0.7492784625293504,
42 |             0.14209881911585326,
43 |             -0.3622358797572402,
44 |             -1.4347279978985925,
45 |             0.8691789808786153,
46 |         ]
47 |
48 |         arm_joints_right = [
49 |             0.2313341406775527,
50 |             -0.7512396951283128,
51 |             0.06337444935928868,
52 |             0.6512089317940273,
53 |             1.3246649009637026,
54 |             0.5471978290474188,
55 |         ]
56 |     else:
57 |         arm_joints_left = [-80, -140, -80, -85, -10, 80]
58 |         arm_joints_right = [-270, -30, 70, -85, 10, 0]
59 |     hand_joints = [0, 0, 0, 0, 0.5, 0.5]
60 |     reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
61 |     reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
62 |     reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
63 |     return reset_joints
64 |
65 |
66 | def get_eef_pose(eef_pose, eef_delta):
67 |     pos_delta = eef_delta[:3]
68 |     rot_delta = eef_delta[3:]
69 |     pos = eef_pose[:3] + pos_delta
70 |     # quaternion multiplication
71 |     rot = quaternion.as_rotation_vector(
72 |         quaternion.from_rotation_vector(rot_delta)
73 |         * quaternion.from_rotation_vector(eef_pose[3:])
74 |     )
75 |     return np.concatenate((pos, rot))
76 |
77 |
78 | def parse_txt_to_json(input_file_path, output_file_path):
79 |     # Initialize an empty dictionary to store the parsed key-value pairs
80 |     data = {}
81 |
82 |     # Open and read the input text file line by line
83 |     with open(input_file_path, "r") as file:
84 |         for line in file:
85 |             kv = line.strip().split(": ", 1)
86 |             if len(kv) != 2:
87 |                 continue
88 |             key, value = kv
89 |             if key == "camera_indices":
90 |                 data[key] = list(map(int, value))
91 |             elif key == "representation_type":
92 |                 data[key] = value.split("-")
93 |             else:
94 |                 try:
95 |                     value = int(value)
96 |                 except ValueError:
97 |                     try:
98 |                         value = float(value)
99 |                     except ValueError:
100 |                         if value == "True":
101 |                             value = True
102 |                         elif value == "False":
103 |                             value = False
104 |                         elif value == "None":
105 |                             value = None
106 |
107 |                 data[key] = value
108 |
109 |     with open(output_file_path, "w") as json_file:
110 |         json.dump(data, json_file, indent=4)
111 |     return data
112 |
113 |
114 | class BimanualDPAgentServer(ZMQInferenceServer):
115 |     def __init__(
116 |         self,
117 |         ckpt_path,
118 |         dp_args=None,
119 |         binarize_finger_action=False,
120 |         *args,
121 |         **kwargs,
122 |     ):
123 |         super().__init__(*args, **kwargs)
124 |
125 |         if dp_args is None:
126 |             dp_args = self.get_default_dp_args()
127 |
128 |         # rewrite dp_args based on saved args
129 |         args_txt = os.path.join(os.path.dirname(ckpt_path), "args_log.txt")
130 |         args_json = os.path.join(os.path.dirname(ckpt_path), "args_log.json")
131 |         args = parse_txt_to_json(args_txt, args_json)
132 |         for k in dp_args.keys():
133 |             if k == "output_sizes":
134 |                 dp_args[k]["img"] = args["image_output_size"]
135 |             else:
136 |                 if k in args:
137 |                     dp_args[k] = args[k]
138 |
139 |         # save dp args in ckpt path as json
140 |         ckpt_dir = os.path.dirname(ckpt_path)
141 |         args_path = os.path.join(ckpt_dir, "dp_args.json")
142 |         if os.path.exists(args_path):
143 |             with open(args_path, "r") as f:
144 |                 dp_args = json.load(f)
145 |         else:
146 |             with open(args_path, "w") as f:
147 |                 json.dump(dp_args, f)
148 |
149 |         torch.cuda.set_device(0)
150 |         self.dp = DPAgent(
151 |             output_sizes=dp_args["output_sizes"],
152 |             representation_type=dp_args["representation_type"],
153 |             identity_encoder=dp_args["identity_encoder"],
154 |             camera_indices=dp_args["camera_indices"],
155 |             pred_horizon=dp_args["pred_horizon"],
156 |             obs_horizon=dp_args["obs_horizon"],
157 |             action_horizon=dp_args["action_horizon"],
158 |             without_sampling=dp_args["without_sampling"],
159 |             predict_eef_delta=dp_args["predict_eef_delta"],
160 |             predict_pos_delta=dp_args["predict_pos_delta"],
161 |             use_ddim=dp_args["use_ddim"],
162 |         )
163 |         self.dp_args = dp_args
164 |         self.obsque = collections.deque(maxlen=dp_args["obs_horizon"])
165 |         self.dp.load(ckpt_path)
166 |         self.action_queue = collections.deque(maxlen=dp_args["action_horizon"])
167 |         self.max_length = 100
168 |         self.count = 0
169 |         self.except_thumb_hand_indices = np.array([6, 7, 8, 9, 18, 19, 20, 21])
170 |         self.binaraize_finger_action = binarize_finger_action
171 |         self.clip_far = dp_args["clip_far"]
172 |         self.predict_eef_delta = dp_args["predict_eef_delta"]
173 |         self.predict_pos_delta = dp_args["predict_pos_delta"]
174 |         assert not (self.predict_eef_delta and self.predict_pos_delta)
175 |         self.control = get_reset_joints(ur_eef=self.predict_eef_delta)
176 |
177 |         self.num_diffusion_iters = dp_args["num_diffusion_iters"]
178 |
179 |         self.hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
180 |         self.hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])
181 |
182 |         # TODO: remove hack
183 |         self.hand_new_uppers = np.array([75] * 4 + [90.0, 120.0])
184 |
185 |         self.trigger_state = {"l": True, "r": True}
186 |
187 |     @staticmethod
188 |     def get_default_dp_args():
189 |         return {
190 |             "output_sizes": {
191 |                 "eef": 64,
192 |                 "hand_pos": 64,
193 |                 "img": 128,
194 |                 "pos": 128,
195 |                 "touch": 64,
196 |             },
197 |             "representation_type": ["img", "pos", "touch", "depth"],
198 |             "identity_encoder": False,
199 |             "camera_indices": [0, 1, 2],
200 |             "obs_horizon": 4,
201 |             "pred_horizon": 16,
202 |             "action_horizon": 8,
203 |             "num_diffusion_iters": 15,
204 |             "without_sampling": False,
205 |             "clip_far": False,
206 |             "predict_eef_delta": False,
207 |             "predict_pos_delta": False,
208 |             "use_ddim": False,
209 |         }
210 |
211 |     def compile_inference(self, precision="high"):
212 |         message = self._socket.recv()
213 |         start_time = time.time()
214 |         state_dict = pickle.loads(message)
215 |         self.num_diffusion_iters = state_dict["num_diffusion_iters"]
216 |         example_obs = state_dict["example_obs"]
217 |         print(
218 |             f"received compilation request: # diff iters = {state_dict['num_diffusion_iters']}"
219 |         )
220 |
221 |         torch.set_float32_matmul_precision(precision)
222 |         self.dp.policy.forward = torch.compile(torch.no_grad(self.dp.policy.forward))
223 |
224 |         for i in range(25):  # burn in
225 |             self.act(example_obs)
226 |         print("success, compile time: " + str(time.time() - start_time))
227 |         self._socket.send_string("success")
228 |
229 | def infer(self, obs: Dict[str, Any]) -> np.ndarray:
230 | return self.dp.predict([obs], num_diffusion_iters=self.num_diffusion_iters)
231 |
232 | def act(self, obs: Dict[str, Any]) -> np.ndarray:
233 | curr_joint_pos = obs["joint_positions"]
234 | curr_eef_pose = obs["ee_pos_quat"]
235 | obs = self.dp.get_observation([obs], load_img=True)
236 | if "img" in obs:
237 | obs["img"] = self.dp.eval_transform(obs["img"].squeeze(0))
238 | return self.infer(obs)
239 |
240 |
241 | class BimanualDPAgent(ZMQInferenceClient):
242 | def __init__(
243 | self,
244 | ckpt_path,
245 | dp_args=None,
246 | binarize_finger_action=False,
247 | port=DEFAULT_INFERENCE_PORT,
248 | host="127.0.0.1",
249 | temporal_ensemble_mode="new",
250 | temporal_ensemble_act_tau=0.5,
251 | ):
252 | super().__init__(
253 | default_action=get_reset_joints(),
254 | port=port,
255 | host=host,
256 | ensemble_mode=temporal_ensemble_mode,
257 | act_tau=temporal_ensemble_act_tau,
258 | )
259 |
260 | if dp_args is None:
261 | dp_args = self.get_default_dp_args()
262 |
263 | # override dp_args with the values saved alongside the checkpoint
264 | args_txt = os.path.join(os.path.dirname(ckpt_path), "args_log.txt")
265 | args_json = os.path.join(os.path.dirname(ckpt_path), "args_log.json")
266 | args = parse_txt_to_json(args_txt, args_json)
267 | for k in dp_args.keys():
268 | if k == "output_sizes":
269 | dp_args[k]["img"] = args["image_output_size"]
270 | else:
271 | if k in args:
272 | dp_args[k] = args[k]
273 |
274 | self.dp_args = dp_args
275 | self.obsque = collections.deque(maxlen=dp_args["obs_horizon"])
276 | # self.dp.load(ckpt_path)
277 | self.action_queue = collections.deque(maxlen=dp_args["action_horizon"])
278 | self.max_length = 100
279 | self.count = 0
280 | self.except_thumb_hand_indices = np.array([6, 7, 8, 9, 18, 19, 20, 21])
281 | self.binaraize_finger_action = binarize_finger_action
282 | self.clip_far = dp_args["clip_far"]
283 | self.predict_eef_delta = dp_args["predict_eef_delta"]
284 | self.predict_pos_delta = dp_args["predict_pos_delta"]
285 | assert not (self.predict_eef_delta and self.predict_pos_delta)
286 | self.control = get_reset_joints(ur_eef=self.predict_eef_delta)
287 |
288 | self.num_diffusion_iters = dp_args["num_diffusion_iters"]
289 |
290 | self.hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
291 | self.hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])
292 |
293 | # TODO: remove hack
294 | self.hand_new_uppers = np.array([75] * 4 + [90.0, 120.0])
295 |
296 | self.trigger_state = {"l": True, "r": True}
297 |
298 | @staticmethod
299 | def get_default_dp_args():
300 | return {
301 | "output_sizes": {
302 | "eef": 64,
303 | "hand_pos": 64,
304 | "img": 128,
305 | "pos": 128,
306 | "touch": 64,
307 | },
308 | "representation_type": ["img", "pos", "touch", "depth"],
309 | "identity_encoder": False,
310 | "camera_indices": [0, 1, 2],
311 | "obs_horizon": 4,
312 | "pred_horizon": 16,
313 | "action_horizon": 8,
314 | "num_diffusion_iters": 15,
315 | "without_sampling": False,
316 | "clip_far": False,
317 | "predict_eef_delta": False,
318 | "predict_pos_delta": False,
319 | "use_ddim": False,
320 | }
321 |
322 | def compile_inference(self, example_obs, num_diffusion_iters):
323 | message = pickle.dumps(
324 | {"example_obs": example_obs, "num_diffusion_iters": num_diffusion_iters}
325 | )
326 | self._socket.send(message)
327 |
328 | message = self._socket.recv()
329 | assert message == b"success"
330 |
331 | def act(self, obs: Dict[str, Any]) -> np.ndarray:
332 | curr_joint_pos = obs["joint_positions"]
333 | curr_eef_pose = obs["ee_pos_quat"]
334 | act = super().act(obs)
335 |
336 | if self.predict_pos_delta:
337 | self.control[UR_IDX] = curr_joint_pos[UR_IDX]
338 | self.control = self.control + act
339 | act = self.control
340 | # act = curr_joint_pos + act
341 |
342 | if self.predict_eef_delta:
343 | left_arm_act = get_eef_pose(curr_eef_pose[:6], act[:6])
344 | left_hand_act = act[6:12]
345 | right_arm_act = get_eef_pose(curr_eef_pose[6:], act[12:18])
346 | right_hand_act = act[18:24]
347 | act = np.concatenate(
348 | [left_arm_act, left_hand_act, right_arm_act, right_hand_act],
349 | axis=-1,
350 | )
351 |
352 | # if binarize_finger_action is True, binarize the finger action
353 |
354 | if self.binaraize_finger_action:
355 | mean_act = np.mean(act[self.except_thumb_hand_indices])
356 | if mean_act > 0.5:
357 | act[self.except_thumb_hand_indices] = 1.0
358 | else:
359 | act[self.except_thumb_hand_indices] = 0.0
360 | else:
361 | left_hand = (
362 | act[LEFT_HAND_IDX] * (self.hand_uppers - self.hand_lowers)
363 | + self.hand_lowers
364 | )
365 | act[LEFT_HAND_IDX] = (left_hand - self.hand_lowers) / (
366 | self.hand_new_uppers - self.hand_lowers
367 | )
368 | right_hand = (
369 | act[RIGHT_HAND_IDX] * (self.hand_uppers - self.hand_lowers)
370 | + self.hand_lowers
371 | )
372 | act[RIGHT_HAND_IDX] = (right_hand - self.hand_lowers) / (
373 | self.hand_new_uppers - self.hand_lowers
374 | )
375 | act[list(range(6, 12)) + list(range(18, 24))] = np.clip(
376 | act[list(range(6, 12)) + list(range(18, 24))], 0, 1
377 | )
378 |
379 | return act
380 |
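The non-binarized branch of `BimanualDPAgent.act` above converts each normalized finger command back to a joint angle using the original limits, renormalizes it against the tighter `hand_new_uppers`, and clips the result to [0, 1]. A minimal standalone sketch of that remapping for one hand follows. The limit values are copied from the listing; the helper name `rescale_hand_action` is hypothetical, and plain lists stand in for the NumPy arrays used by the agent.

```python
from typing import List, Sequence

# Limits copied from the agent listing above.
HAND_LOWERS = [5.0, 5.0, 5.0, 5.0, 5.0, 5.0]
HAND_UPPERS = [110.0, 110.0, 110.0, 110.0, 90.0, 120.0]
HAND_NEW_UPPERS = [75.0, 75.0, 75.0, 75.0, 90.0, 120.0]


def rescale_hand_action(act: Sequence[float]) -> List[float]:
    """Hypothetical helper: remap one hand's normalized action to the new range."""
    out = []
    for a, lo, hi, new_hi in zip(act, HAND_LOWERS, HAND_UPPERS, HAND_NEW_UPPERS):
        angle = a * (hi - lo) + lo              # denormalize with the original limits
        scaled = (angle - lo) / (new_hi - lo)   # renormalize against the new upper limit
        out.append(min(max(scaled, 0.0), 1.0))  # clip to [0, 1]
    return out


print(rescale_hand_action([0.5, 1.0, 0.0, 0.25, 0.5, 1.0]))
```

Because the new upper limit of the four non-thumb joints is lower than the original one, any command above roughly 0.67 saturates at 1.0 for those joints, while the two thumb joints pass through unchanged.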
--------------------------------------------------------------------------------
/agents/quest_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import numpy as np
4 | import quaternion
5 | from dm_control import mjcf
6 | from dm_control.mujoco.wrapper import mjbindings
7 | from dm_control.utils.inverse_kinematics import qpos_from_site_pose, nullspace_method
8 | from oculus_reader.reader import OculusReader
9 | from agents.agent import Agent
10 |
11 | mjlib = mjbindings.mjlib
12 |
13 | mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
14 | ur2mj = np.linalg.inv(mj2ur)
15 |
16 | trigger_state = {"l": False, "r": False}
17 |
18 |
19 | def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray:
20 | # xyz can be 3dim or 4dim (homogeneous) or can be a rotation matrix
21 | if len(xyz) == 3:
22 | xyz = np.append(xyz, 1)
23 | return np.matmul(mat, xyz)[:3]
24 |
25 |
26 | quest2ur = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
27 | # 45 deg CCW
28 | ur2left = np.array(
29 | [
30 | [1 / 2 * np.sqrt(2), 1 / 2 * np.sqrt(2), 0, 0],
31 | [-1 / 2 * np.sqrt(2), 1 / 2 * np.sqrt(2), 0, 0],
32 | [0, 0, 1, 0],
33 | [0, 0, 0, 1],
34 | ]
35 | )
36 | # 45 deg CW
37 | ur2right = np.array(
38 | [
39 | [1 / 2 * np.sqrt(2), -1 / 2 * np.sqrt(2), 0, 0],
40 | [1 / 2 * np.sqrt(2), 1 / 2 * np.sqrt(2), 0, 0],
41 | [0, 0, 1, 0],
42 | [0, 0, 0, 1],
43 | ]
44 | )
45 | ur2quest = np.linalg.inv(quest2ur)
46 |
47 |
48 | def velocity_ik(physics,
49 | site_name,
50 | delta_rot_mat,
51 | delta_pos,
52 | joint_names):
53 |
54 | dtype = physics.data.qpos.dtype
55 |
56 | # Convert site name to index.
57 | site_id = physics.model.name2id(site_name, 'site')
58 |
59 |
60 | jac = np.empty((6, physics.model.nv), dtype=dtype)
61 | err = np.zeros(6, dtype=dtype)
62 | jac_pos, jac_rot = jac[:3], jac[3:]
63 | err_pos, err_rot = err[:3], err[3:]
64 |
65 | err_pos[:] = delta_pos[:]
66 |
67 | delta_rot_quat = np.empty(4, dtype=dtype)
68 | mjlib.mju_mat2Quat(delta_rot_quat, delta_rot_mat)
69 | mjlib.mju_quat2Vel(err_rot, delta_rot_quat, 1)
70 |
71 | mjlib.mj_jacSite(
72 | physics.model.ptr, physics.data.ptr, jac_pos, jac_rot, site_id)
73 |
74 | dof_indices = []
75 | for jn in joint_names:
76 | dof_idx = physics.model.joint(jn).id
77 | dof_indices.append(dof_idx)
78 |
79 | jac_joints = jac[:, dof_indices]
80 |
81 | update_joints = nullspace_method(
82 | jac_joints, err, regularization_strength=0.03)
83 |
84 | return update_joints
85 |
86 |
87 | class SingleArmQuestAgent(Agent):
88 | def __init__(
89 | self,
90 | robot_type: str,
91 | which_hand: str,
92 | eef_control_mode: int = 0,
93 | verbose: bool = False,
94 | use_vel_ik: bool = False,
95 | vel_ik_speed_scale: float = 0.95,
96 | ) -> None:
97 | """Teleoperate one arm with a Quest controller.
98 |
99 | Trigger (leftTrig / rightTrig): hold to control the arm; the poses at the moment it is first pressed are recorded as the reference.
100 | Joystick (leftJS / rightJS): an (x, y) tuple; used only when eef_control_mode is not 0 or 1, to drive the last two hand values.
101 | """
102 | self.which_hand = which_hand
103 | self.eef_control_mode = eef_control_mode
104 | assert self.which_hand in ["l", "r"]
105 |
106 | self.oculus_reader = OculusReader()
107 | if robot_type == "ur5":
108 | mjcf_model = mjcf.from_path("universal_robots_ur5e/ur5e.xml")
109 | mjcf_model.name = robot_type
110 | else:
111 | raise ValueError(f"Unknown robot type: {robot_type}")
112 | self.physics = mjcf.Physics.from_mjcf_model(mjcf_model)
113 | self.control_active = False
114 | self.reference_quest_pose = None
115 | self.reference_ee_rot_ur = None
116 | self.reference_ee_pos_ur = None
117 | self.reference_js = [0.5, 0.5]
118 | self.js_speed_scale = 0.1
119 |
120 | self.use_vel_ik = use_vel_ik
121 | self.vel_ik_speed_scale = vel_ik_speed_scale
122 |
123 | if use_vel_ik:
124 | self.translation_scaling_factor = 1.0
125 | else:
126 | self.translation_scaling_factor = 2.0
127 |
128 | self.robot_type = robot_type
129 | self._verbose = verbose
130 |
131 | def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
132 | if self.robot_type == "ur5":
133 | num_dof = 6
134 | current_qpos = obs["joint_positions"][:num_dof] # entries after the arm joints belong to the gripper
135 | # run the fk
136 | self.physics.data.qpos[:num_dof] = current_qpos
137 | self.physics.step()
138 |
139 | ee_rot_mj = np.array(
140 | self.physics.named.data.site_xmat["attachment_site"]
141 | ).reshape(3, 3)
142 | ee_pos_mj = np.array(self.physics.named.data.site_xpos["attachment_site"])
143 | if self.which_hand == "l":
144 | pose_key = "l"
145 | trigger_key = "leftTrig"
146 | grip_key = "leftGrip"
147 | joystick_key = "leftJS"
148 | elif self.which_hand == "r":
149 | pose_key = "r"
150 | trigger_key = "rightTrig"
151 | grip_key = "rightGrip"
152 | joystick_key = "rightJS"
153 | else:
154 | raise ValueError(f"Unknown hand: {self.which_hand}")
155 | # check the trigger button state
156 | (
157 | pose_data,
158 | button_data,
159 | ) = self.oculus_reader.get_transformations_and_buttons()
160 | if len(pose_data) == 0 or len(button_data) == 0:
161 | print("no data, quest not yet ready")
162 | return np.concatenate(
163 | [current_qpos, obs["joint_positions"][num_dof:] * 0.0]
164 | )
165 |
166 | if self.eef_control_mode == 0:
167 | new_gripper_angle = [button_data[grip_key][0]]
168 | elif self.eef_control_mode == 1:
169 | new_gripper_angle = [button_data[grip_key][0]] * 6 # [0, 1]
170 | else:
171 | # (x, y) position of joystick, range (-1.0, 1.0)
172 | if self.which_hand == "r":
173 | js_y = button_data[joystick_key][0] * -1
174 | js_x = button_data[joystick_key][1] * -1
175 | else:
176 | js_y = button_data[joystick_key][0]
177 | js_x = button_data[joystick_key][1] * -1
178 |
179 | if self.eef_control_mode == 2:
180 | # convert js_x, js_y from range (-1.0, 1.0) to (0.0, 1.0)
181 | js_x = (js_x + 1) / 2
182 | js_y = (js_y + 1) / 2
183 | # control absolute position using joystick
184 | self.reference_js = [js_x, js_y]
185 | else:
186 | # control relative position using joystick
187 | self.reference_js = [
188 | max(0, min(self.reference_js[0] + js_x * self.js_speed_scale, 1)),
189 | max(0, min(self.reference_js[1] + js_y * self.js_speed_scale, 1)),
190 | ]
191 | new_gripper_angle = [
192 | button_data[grip_key][0],
193 | button_data[grip_key][0],
194 | button_data[grip_key][0],
195 | button_data[grip_key][0],
196 | self.reference_js[0],
197 | self.reference_js[1],
198 | ] # [0, 1]
199 | arm_not_move_return = np.concatenate([current_qpos, new_gripper_angle])
200 | if len(pose_data) == 0:
201 | print("no data, quest not yet ready")
202 | return arm_not_move_return
203 |
204 | global trigger_state
205 | trigger_state[self.which_hand] = button_data[trigger_key][0] > 0.5
206 | if trigger_state[self.which_hand]:
207 | if self.control_active is True:
208 | if self._verbose:
209 | print("controlling the arm")
210 | current_pose = pose_data[pose_key]
211 | delta_rot = current_pose[:3, :3] @ np.linalg.inv(
212 | self.reference_quest_pose[:3, :3]
213 | )
214 | delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3]
215 | if self.which_hand == "l":
216 | t_mat = np.matmul(ur2left, quest2ur)
217 | else:
218 | t_mat = np.matmul(ur2right, quest2ur)
219 | delta_pos_ur = (
220 | apply_transfer(t_mat, delta_pos) * self.translation_scaling_factor
221 | )
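222 | # TODO: verify; the rotation delta is conjugated with quest2ur only, without the per-arm 45 deg rotation applied to delta_pos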
223 | delta_rot_ur = quest2ur[:3, :3] @ delta_rot @ ur2quest[:3, :3]
224 | if self._verbose:
225 | print(f"delta pos and rot in VR space: \n{delta_pos}, {delta_rot}")
226 | print(
227 | f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}"
228 | )
229 | next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur
230 | next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur
231 |
232 | if self.use_vel_ik:
233 | next_ee_pos_mj = apply_transfer(ur2mj, next_ee_pos_ur)
234 | next_ee_rot_mj = ur2mj[:3, :3] @ next_ee_rot_ur
235 |
236 | err_rot_mj = next_ee_rot_mj @ np.linalg.inv(ee_rot_mj)
237 | err_pos_mj = next_ee_pos_mj - ee_pos_mj
238 |
239 | print(err_pos_mj)
240 |
241 | delta_qpos = velocity_ik(
242 | self.physics,
243 | "attachment_site",
244 | err_rot_mj.flatten(),
245 | err_pos_mj,
246 | joint_names=[
247 | "shoulder_pan_joint",
248 | "shoulder_lift_joint",
249 | "elbow_joint",
250 | "wrist_1_joint",
251 | "wrist_2_joint",
252 | "wrist_3_joint",
253 | ],
254 | )
255 |
256 | new_qpos = current_qpos + delta_qpos * self.vel_ik_speed_scale
257 |
258 | else:
259 | target_quat = quaternion.as_float_array(
260 | quaternion.from_rotation_matrix(ur2mj[:3, :3] @ next_ee_rot_ur)
261 | )
262 | ik_result = qpos_from_site_pose(
263 | self.physics,
264 | "attachment_site",
265 | target_pos=apply_transfer(ur2mj, next_ee_pos_ur),
266 | target_quat=target_quat,
267 | tol=1e-14,
268 | max_steps=400,
269 | )
270 | self.physics.reset()
271 | if ik_result.success:
272 | new_qpos = ik_result.qpos[:num_dof]
273 | else:
274 | print("ik failed, using the original qpos")
275 | return arm_not_move_return
276 | command = np.concatenate([new_qpos, new_gripper_angle])
277 | return command
278 |
279 | else: # control was inactive on the previous step
280 | self.control_active = True
281 | if self._verbose:
282 | print("control activated!")
283 | self.reference_quest_pose = pose_data[pose_key]
284 |
285 | self.reference_ee_rot_ur = mj2ur[:3, :3] @ ee_rot_mj
286 | self.reference_ee_pos_ur = apply_transfer(mj2ur, ee_pos_mj)
287 | return arm_not_move_return
288 | else:
289 | if self._verbose:
290 | print("control deactivated")
291 | self.control_active = False
292 | self.reference_quest_pose = None
293 | return arm_not_move_return
294 |
295 |
296 | class DualArmQuestAgent(Agent):
297 | def __init__(self, agent_left: Agent, agent_right: Agent):
298 | self.agent_left = agent_left
299 | self.agent_right = agent_right
300 | global trigger_state
301 | self.trigger_state = trigger_state
302 |
303 | def act(self, obs: Dict) -> np.ndarray:
304 | left_obs = {}
305 | right_obs = {}
306 | for key, val in obs.items():
307 | L = val.shape[0]
308 | half_dim = L // 2
309 | if key.endswith("rgb") or key.endswith("depth"):
310 | left_obs[key] = val
311 | right_obs[key] = val
312 | else:
313 | assert L == half_dim * 2, f"{key} must have an even length, something is wrong"
314 | left_obs[key] = val[:half_dim]
315 | right_obs[key] = val[half_dim:]
316 | return np.concatenate(
317 | [self.agent_left.act(left_obs), self.agent_right.act(right_obs)]
318 | )
319 |
320 |
321 | if __name__ == "__main__":
322 | oculus_reader = OculusReader()
323 | while True:
324 | """
325 | example output:
326 | ({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ],
327 | [-0.107737 , 0.0958919, 0.989544 , -0.833478 ],
328 | [ 0.549685 , 0.835106 , -0.0210789, -0.892425 ],
329 | [ 0. , 0. , 0. , 1. ]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ],
330 | [ 0.070887, 0.516083, -0.8536 , -0.238691],
331 | [-0.941994, -0.246809, -0.227447, -0.370447],
332 | [ 0. , 0. , 0. , 1. ]])},
333 | {'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)})
334 |
335 | """
336 | pose_data, button_data = oculus_reader.get_transformations_and_buttons()
337 | if len(pose_data) == 0:
338 | print("no data")
339 | continue
340 | else:
341 | print(pose_data, button_data)
342 |
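In `SingleArmQuestAgent.act` above, control modes other than 0 and 1 derive two hand values from the joystick. Mode 2 maps each axis from (-1, 1) to (0, 1) and uses it as an absolute position; any other mode treats the axis as a velocity, scales it by `js_speed_scale`, accumulates it into `reference_js`, and clamps the result to [0, 1]. The sketch below isolates that branch. The function name `update_reference_js` is hypothetical, and the per-hand sign flips from the listing are omitted for brevity.

```python
from typing import List, Tuple


def update_reference_js(
    reference_js: List[float],
    js_xy: Tuple[float, float],
    mode: int,
    speed_scale: float = 0.1,
) -> List[float]:
    """Hypothetical helper mirroring the joystick branch of the agent."""
    js_x, js_y = js_xy
    if mode == 2:
        # Absolute control: map each axis from (-1, 1) to (0, 1).
        return [(js_x + 1) / 2, (js_y + 1) / 2]
    # Relative control: integrate the axis value and clamp to [0, 1].
    return [
        max(0.0, min(reference_js[0] + js_x * speed_scale, 1.0)),
        max(0.0, min(reference_js[1] + js_y * speed_scale, 1.0)),
    ]
```

In relative mode the joystick can be released without the hand values snapping back, since the accumulated reference is kept between calls; in absolute mode a centered joystick always corresponds to 0.5.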
--------------------------------------------------------------------------------
/agents/quest_agent_eef.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import numpy as np
4 | import quaternion
5 | from oculus_reader.reader import OculusReader
6 | from agents.agent import Agent
7 |
8 |
9 | mj2ur = np.array([[0, -1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
10 | ur2mj = np.linalg.inv(mj2ur)
11 |
12 | trigger_state = {"l": False, "r": False}
13 |
14 |
15 | def apply_transfer(mat: np.ndarray, xyz: np.ndarray) -> np.ndarray:
16 | # xyz can be 3dim or 4dim (homogeneous) or can be a rotation matrix
17 | if len(xyz) == 3:
18 | xyz = np.append(xyz, 1)
19 | return np.matmul(mat, xyz)[:3]
20 |
21 |
22 | quest2isaac = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
23 | left2isaac = np.array(
24 | [
25 | [0, -1, 0, 0],
26 | [-1 / 2 * np.sqrt(2), 0, -1 / 2 * np.sqrt(2), 0],
27 | [1 / 2 * np.sqrt(2), 0, -1 / 2 * np.sqrt(2), 0],
28 | [0, 0, 0, 1],
29 | ]
30 | )
31 | right2isaac = np.array(
32 | [
33 | [0, -1, 0, 0],
34 | [-1 / 2 * np.sqrt(2), 0, 1 / 2 * np.sqrt(2), 0],
35 | [-1 / 2 * np.sqrt(2), 0, -1 / 2 * np.sqrt(2), 0],
36 | [0, 0, 0, 1],
37 | ]
38 | )
39 | isaac2left = np.linalg.inv(left2isaac)
40 | isaac2right = np.linalg.inv(right2isaac)
41 |
42 | quest2left = np.matmul(isaac2left, quest2isaac)
43 | quest2right = np.matmul(isaac2right, quest2isaac)
44 | left2quest = np.linalg.inv(quest2left)
45 | right2quest = np.linalg.inv(quest2right)
46 |
47 | translation_scaling_factor = 1.0
48 |
49 |
50 | class SingleArmQuestAgent(Agent):
51 | def __init__(
52 | self,
53 | robot_type: str,
54 | which_hand: str,
55 | eef_control_mode: int = 0,
56 | verbose: bool = False,
57 | ) -> None:
58 | """Teleoperate one arm with a Quest controller.
59 |
60 | Trigger (leftTrig / rightTrig): hold to control the arm; the poses at the moment it is first pressed are recorded as the reference.
61 | Joystick (leftJS / rightJS): an (x, y) tuple; used only when eef_control_mode is not 0 or 1, to drive the last two hand values.
62 | """
63 | self.which_hand = which_hand
64 | self.eef_control_mode = eef_control_mode
65 | assert self.which_hand in ["l", "r"]
66 |
67 | self.oculus_reader = OculusReader()
68 | self.control_active = False
69 | self.reference_quest_pose = None
70 | self.reference_ee_rot_ur = None
71 | self.reference_ee_pos_ur = None
72 | self.reference_js = [0.5, 0.5]
73 | self.js_speed_scale = 0.1
74 |
75 | self.robot_type = robot_type
76 | self._verbose = verbose
77 |
78 | def act(self, obs: Dict[str, np.ndarray]) -> np.ndarray:
79 | if self.robot_type == "ur5":
80 | num_dof = 6
81 | current_eef_pose = obs["ee_pos_quat"]
82 |
83 | # pos and rot in robot base frame
84 | ee_pos = current_eef_pose[:3]
85 | ee_rot = current_eef_pose[3:]
86 | ee_rot = quaternion.as_rotation_matrix(quaternion.from_rotation_vector(ee_rot))
87 |
88 | if self.which_hand == "l":
89 | pose_key = "l"
90 | trigger_key = "leftTrig"
91 | grip_key = "leftGrip"
92 | joystick_key = "leftJS"
93 | # left controller: Y opens the gripper, X closes it
94 | gripper_open_key = "Y"
95 | gripper_close_key = "X"
96 | elif self.which_hand == "r":
97 | pose_key = "r"
98 | trigger_key = "rightTrig"
99 | grip_key = "rightGrip"
100 | joystick_key = "rightJS"
101 | # right controller: B opens the gripper, A closes it
102 | gripper_open_key = "B"
103 | gripper_close_key = "A"
104 | else:
105 | raise ValueError(f"Unknown hand: {self.which_hand}")
106 | # check the trigger button state
107 | (
108 | pose_data,
109 | button_data,
110 | ) = self.oculus_reader.get_transformations_and_buttons()
111 | if len(pose_data) == 0 or len(button_data) == 0:
112 | print("no data, quest not yet ready")
113 | return np.concatenate(
114 | [current_eef_pose, obs["joint_positions"][num_dof:] * 0.0]
115 | )
116 |
117 | if self.eef_control_mode == 0:
118 | new_gripper_angle = [button_data[grip_key][0]]
119 | elif self.eef_control_mode == 1:
120 | new_gripper_angle = [button_data[grip_key][0]] * 6 # [0, 1]
121 | else:
122 | # (x, y) position of joystick, range (-1.0, 1.0)
123 | if self.which_hand == "r":
124 | js_y = button_data[joystick_key][0] * -1
125 | js_x = button_data[joystick_key][1] * -1
126 | else:
127 | js_y = button_data[joystick_key][0]
128 | js_x = button_data[joystick_key][1] * -1
129 |
130 | if self.eef_control_mode == 2:
131 | # convert js_x, js_y from range (-1.0, 1.0) to (0.0, 1.0)
132 | js_x = (js_x + 1) / 2
133 | js_y = (js_y + 1) / 2
134 | # control absolute position using joystick
135 | self.reference_js = [js_x, js_y]
136 | else:
137 | # control relative position using joystick
138 | self.reference_js = [
139 | max(0, min(self.reference_js[0] + js_x * self.js_speed_scale, 1)),
140 | max(0, min(self.reference_js[1] + js_y * self.js_speed_scale, 1)),
141 | ]
142 | new_gripper_angle = [
143 | button_data[grip_key][0],
144 | button_data[grip_key][0],
145 | button_data[grip_key][0],
146 | button_data[grip_key][0],
147 | self.reference_js[0],
148 | self.reference_js[1],
149 | ] # [0, 1]
150 | arm_not_move_return = np.concatenate([current_eef_pose, new_gripper_angle])
151 | if len(pose_data) == 0:
152 | print("no data, quest not yet ready")
153 | return arm_not_move_return
154 |
155 | global trigger_state
156 | trigger_state[self.which_hand] = button_data[trigger_key][0] > 0.5
157 | if trigger_state[self.which_hand]:
158 | if self.control_active is True:
159 | if self._verbose:
160 | print("controlling the arm")
161 | current_pose = pose_data[pose_key]
162 | delta_rot = current_pose[:3, :3] @ np.linalg.inv(
163 | self.reference_quest_pose[:3, :3]
164 | )
165 | delta_pos = current_pose[:3, 3] - self.reference_quest_pose[:3, 3]
166 | if self.which_hand == "l":
167 | t_mat = quest2left
168 | t_mat_inv = left2quest
169 | ur2isaac = left2isaac
170 | else:
171 | t_mat = quest2right
172 | t_mat_inv = right2quest
173 | ur2isaac = right2isaac
174 | delta_pos_ur = (
175 | apply_transfer(t_mat, delta_pos) * translation_scaling_factor
176 | )
177 | delta_rot_ur = t_mat[:3, :3] @ delta_rot @ t_mat_inv[:3, :3]
178 | if self._verbose:
179 | print(f"delta pos and rot in VR space: \n{delta_pos}, {delta_rot}")
180 | print(
181 | f"delta pos and rot in ur space: \n{delta_pos_ur}, {delta_rot_ur}"
182 | )
183 | delta_pos_isaac = apply_transfer(quest2isaac, delta_pos)
184 | delta_pos_isaac_ur = apply_transfer(ur2isaac, delta_pos_ur)
185 | print("delta pos in isaac", delta_pos_isaac)
186 | print("delta pos in ur in isaac", delta_pos_isaac_ur)
187 | print("delta", delta_pos_isaac - delta_pos_isaac_ur)
188 |
189 | next_ee_rot_ur = delta_rot_ur @ self.reference_ee_rot_ur # [3, 3]
190 | next_ee_pos_ur = delta_pos_ur + self.reference_ee_pos_ur
191 | next_ee_rot_ur = quaternion.as_rotation_vector(
192 | quaternion.from_rotation_matrix(next_ee_rot_ur)
193 | )
194 | new_eef_pose = np.concatenate([next_ee_pos_ur, next_ee_rot_ur])
195 | command = np.concatenate([new_eef_pose, new_gripper_angle])
196 | return command
197 |
198 | else: # control was inactive on the previous step
199 | self.control_active = True
200 | if self._verbose:
201 | print("control activated!")
202 | self.reference_quest_pose = pose_data[pose_key]
203 |
204 | # reference in their local TCP frames
205 | self.reference_ee_rot_ur = ee_rot
206 | self.reference_ee_pos_ur = ee_pos
207 | return arm_not_move_return
208 | else:
209 | if self._verbose:
210 | print("control deactivated")
211 | self.control_active = False
212 | self.reference_quest_pose = None
213 | return arm_not_move_return
214 |
215 |
216 | class DualArmQuestAgent(Agent):
217 | def __init__(self, agent_left: Agent, agent_right: Agent):
218 | self.agent_left = agent_left
219 | self.agent_right = agent_right
220 | global trigger_state
221 | self.trigger_state = trigger_state
222 |
223 | def act(self, obs: Dict) -> np.ndarray:
224 | left_obs = {}
225 | right_obs = {}
226 | for key, val in obs.items():
227 | L = val.shape[0]
228 | half_dim = L // 2
229 | if key.endswith("rgb") or key.endswith("depth"):
230 | left_obs[key] = val
231 | right_obs[key] = val
232 | else:
233 | assert L == half_dim * 2, f"{key} must have an even length, something is wrong"
234 | left_obs[key] = val[:half_dim]
235 | right_obs[key] = val[half_dim:]
236 | return np.concatenate(
237 | [self.agent_left.act(left_obs), self.agent_right.act(right_obs)]
238 | )
239 |
240 |
241 | if __name__ == "__main__":
242 | oculus_reader = OculusReader()
243 | while True:
244 | """
245 | example output:
246 | ({'l': array([[-0.828395 , 0.541667 , -0.142682 , 0.219646 ],
247 | [-0.107737 , 0.0958919, 0.989544 , -0.833478 ],
248 | [ 0.549685 , 0.835106 , -0.0210789, -0.892425 ],
249 | [ 0. , 0. , 0. , 1. ]]), 'r': array([[-0.328058, 0.82021 , 0.468652, -1.8288 ],
250 | [ 0.070887, 0.516083, -0.8536 , -0.238691],
251 | [-0.941994, -0.246809, -0.227447, -0.370447],
252 | [ 0. , 0. , 0. , 1. ]])},
253 | {'A': False, 'B': False, 'RThU': True, 'RJ': False, 'RG': False, 'RTr': False, 'X': False, 'Y': False, 'LThU': True, 'LJ': False, 'LG': False, 'LTr': False, 'leftJS': (0.0, 0.0), 'leftTrig': (0.0,), 'leftGrip': (0.0,), 'rightJS': (0.0, 0.0), 'rightTrig': (0.0,), 'rightGrip': (0.0,)})
254 |
255 | """
256 | pose_data, button_data = oculus_reader.get_transformations_and_buttons()
257 | if len(pose_data) == 0:
258 | print("no data")
259 | continue
260 | else:
261 | print(pose_data, button_data)
262 |
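Both Quest agents above use the trigger as a clutch. On the first step with the trigger held, the controller pose and the end-effector pose are stored as references and the arm holds still. On later steps the controller displacement from its reference is added to the end-effector reference. Releasing the trigger clears the controller reference. A position-only sketch of that state machine is shown below; the class name `ClutchedTeleop` is hypothetical, and the frame transforms and rotation handling from the listings are deliberately left out.

```python
from typing import List, Optional, Sequence


class ClutchedTeleop:
    """Hypothetical position-only version of the trigger-gated control loop."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale
        self.control_active = False
        self.ref_controller: Optional[List[float]] = None
        self.ref_ee: Optional[List[float]] = None

    def step(
        self,
        trigger: float,
        controller_pos: Sequence[float],
        ee_pos: Sequence[float],
    ) -> List[float]:
        if trigger <= 0.5:
            # Trigger released: drop the reference and hold the current position.
            self.control_active = False
            self.ref_controller = None
            return list(ee_pos)
        if not self.control_active:
            # First step with the trigger held: record references, do not move.
            self.control_active = True
            self.ref_controller = list(controller_pos)
            self.ref_ee = list(ee_pos)
            return list(ee_pos)
        # Trigger still held: target = reference + scaled controller displacement.
        return [
            e + self.scale * (c - r)
            for e, c, r in zip(self.ref_ee, controller_pos, self.ref_controller)
        ]
```

Re-recording the references on every new press lets the operator release the trigger, reposition the controller, and resume without the arm jumping.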
--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018 ROS Industrial Consortium
2 |
3 | Redistribution and use in source and binary forms, with or without modification,
4 | are permitted provided that the following conditions are met:
5 |
6 | 1. Redistributions of source code must retain the above copyright notice, this
7 | list of conditions and the following disclaimer.
8 |
9 | 2. Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation and/or
11 | other materials provided with the distribution.
12 |
13 | 3. Neither the name of the copyright holder nor the names of its contributors
14 | may be used to endorse or promote products derived from this software without
15 | specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 |
--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/README.md:
--------------------------------------------------------------------------------
1 | # Universal Robots UR5e Description (MJCF)
2 |
3 | Requires MuJoCo 2.3.3 or later.
4 |
5 | ## Overview
6 |
7 | This package contains a simplified robot description (MJCF) of the
8 | [UR5e](https://www.universal-robots.com/products/ur5-robot/) developed by
9 | [Universal Robots](https://www.universal-robots.com/). It is derived from the
10 | [publicly available URDF
11 | description](https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_e_description).
12 |
13 |
14 |
15 |
16 |
17 | ### URDF → MJCF derivation steps
18 |
19 | 1. Converted the DAE [mesh
20 | files](https://github.com/ros-industrial/universal_robot/tree/kinetic-devel/ur_e_description/meshes/ur5e/visual)
21 | to OBJ format using [Blender](https://www.blender.org/).
22 | 2. Processed `.obj` files with [`obj2mjcf`](https://github.com/kevinzakka/obj2mjcf).
23 | 3. Added `<mujoco> <compiler discardvisual="false"/> </mujoco>` to the URDF's
24 | `<robot>` clause in order to preserve visual geometries.
25 | 4. Loaded the URDF into MuJoCo and saved a corresponding MJCF.
26 | 5. Added a tracking light to the base.
27 | 6. Manually edited the MJCF to extract common properties into the `<default>` section.
28 | 7. Added position-controlled actuators. Max joint torque values were taken from
29 | [here](https://www.universal-robots.com/articles/ur/robot-care-maintenance/max-joint-torques/).
30 | 8. Added home joint configuration as a `keyframe`.
31 | 9. Manually designed collision geometries.
32 | 10. Added `scene.xml` which includes the robot, with a textured ground plane, skybox and haze.
33 |
34 | ## License
35 |
36 | This model is released under a [BSD-3-Clause License](LICENSE).
37 |
--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/scene.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/ur5e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ToruOwO/hato/3b6f4496e11a386df8f362f398b36e832944b2b8/agents/universal_robots_ur5e/ur5e.png
--------------------------------------------------------------------------------
/agents/universal_robots_ur5e/ur5e.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/camera_node.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import threading
3 | import time
4 | from typing import Optional, Tuple
5 |
6 | import numpy as np
7 | import zmq
8 |
9 | from cameras.camera import CameraDriver
10 |
11 | DEFAULT_CAMERA_PORT = 5000
12 |
13 |
14 | class ZMQClientCamera(CameraDriver):
15 | """A ZMQ client that requests frames from a camera server."""
16 |
17 | def __init__(self, port: int = DEFAULT_CAMERA_PORT, host: str = "127.0.0.1"):
18 | self._context = zmq.Context()
19 | self._socket = self._context.socket(zmq.REQ)
20 | self._socket.connect(f"tcp://{host}:{port}")
21 |
22 | def read(
23 | self,
24 | img_size: Optional[Tuple[int, int]] = None,
25 | ) -> Tuple[np.ndarray, np.ndarray]:
26 | """Request a frame from the camera server.
27 |
28 | Returns:
29 | The server's reply: the result of the wrapped camera's `read(img_size)`.
30 | """
31 | # pack the image_size and send it to the server
32 | send_message = pickle.dumps(img_size)
33 | self._socket.send(send_message)
34 | state_dict = pickle.loads(self._socket.recv())
35 | return state_dict
36 |
37 |
38 | class ZMQServerCamera:
39 | def __init__(
40 | self,
41 | camera: CameraDriver,
42 | port: int = DEFAULT_CAMERA_PORT,
43 | host: str = "127.0.0.1",
44 | ):
45 | self._camera = camera
46 | self._context = zmq.Context()
47 | self._socket = self._context.socket(zmq.REP)
48 | addr = f"tcp://{host}:{port}"
49 | debug_message = f"Camera Server Binding to {addr}, Camera: {camera}"
50 | print(debug_message)
51 | self._timout_message = f"Timeout in Camera Server, Camera: {camera}"
52 | self._socket.bind(addr)
53 | self._stop_event = threading.Event()
54 |
55 | def serve(self) -> None:
56 | """Serve camera frames over ZMQ."""
57 | self._socket.setsockopt(zmq.RCVTIMEO, 1000) # Set timeout to 1000 ms
58 | while not self._stop_event.is_set():
59 | try:
60 | message = self._socket.recv()
61 | img_size = pickle.loads(message)
62 | camera_read = self._camera.read(img_size)
63 | self._socket.send(pickle.dumps(camera_read))
64 | except zmq.Again:
65 | print(self._timout_message)
66 | # Timeout occurred, check if the stop event is set
67 |
68 | def stop(self) -> None:
69 | """Signal the server to stop serving."""
70 | self._stop_event.set()
71 |
72 |
73 | class ZMQServerCameraFaster:
74 | def __init__(
75 | self,
76 | camera: CameraDriver,
77 | port: int = DEFAULT_CAMERA_PORT,
78 | host: str = "127.0.0.1",
79 | ):
80 | self._camera = camera
81 | self._context = zmq.Context()
82 | self._socket = self._context.socket(zmq.REP)
83 | addr = f"tcp://{host}:{port}"
84 | debug_message = f"Camera Server Binding to {addr}, Camera: {camera}"
85 | print(debug_message)
86 | self._timout_message = f"Timeout in Camera Server, Camera: {camera}"
87 | self._socket.bind(addr)
88 | self._stop_event = threading.Event()
89 |
90 | self.cam_buffer = None
91 | self.refresh_interval = 1 / 30 # Refresh every 1/30 second
92 | self.refresh_thread = threading.Thread(target=self._refresh_buffer)
93 | self.refresh_thread.daemon = True
94 | self.refresh_thread.start()
95 |
96 | def serve(self) -> None:
97 | """Serve the most recently buffered camera frame over ZMQ."""
98 | self._socket.setsockopt(zmq.RCVTIMEO, 1000) # Set timeout to 1000 ms
99 | while not self._stop_event.is_set():
100 | try:
101 | _ = self._socket.recv()
102 | if self.cam_buffer is not None:
103 | self._socket.send(pickle.dumps(self.cam_buffer))
104 | else:
105 | self._socket.send(b"Buffer is empty.")
106 | except zmq.Again:
107 | print(self._timout_message)
108 |
109 | def _refresh_buffer(self):
110 | """Periodically refresh the buffer."""
111 | while not self._stop_event.is_set():
112 | self.cam_buffer = self._camera.read()
113 | time.sleep(self.refresh_interval)
114 |
115 | def stop(self) -> None:
116 | """Signal the server to stop serving."""
117 | self._stop_event.set()
118 |
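The buffering idea behind `ZMQServerCameraFaster` (a background thread keeps overwriting a single "latest frame" slot so that requests never wait on the camera) can be sketched without ZMQ. This is a minimal illustration, not the repository's implementation: `FakeCamera` and `LatestFrameBuffer` are hypothetical names, and the lock and `Event.wait` are additions that the listing above does not use.

```python
import threading
import time


class FakeCamera:
    """Hypothetical stand-in for a CameraDriver; returns an increasing counter."""

    def __init__(self):
        self.count = 0

    def read(self):
        self.count += 1
        return self.count


class LatestFrameBuffer:
    """Keeps only the most recent camera reading, refreshed by a daemon thread."""

    def __init__(self, camera, refresh_interval=1 / 30):
        self._camera = camera
        self._interval = refresh_interval
        self._lock = threading.Lock()
        self._stop_event = threading.Event()
        self._latest = None
        self._thread = threading.Thread(target=self._refresh, daemon=True)
        self._thread.start()

    def _refresh(self):
        while not self._stop_event.is_set():
            frame = self._camera.read()
            with self._lock:
                self._latest = frame
            # Event.wait returns early once stop() has been called.
            self._stop_event.wait(self._interval)

    def latest(self):
        with self._lock:
            return self._latest

    def stop(self):
        self._stop_event.set()
        self._thread.join()


buffer = LatestFrameBuffer(FakeCamera(), refresh_interval=0.001)
deadline = time.monotonic() + 2.0
# Poll until the refresh thread has produced a first frame (or give up).
while buffer.latest() is None and time.monotonic() < deadline:
    time.sleep(0.001)
first = buffer.latest()
buffer.stop()
```

Waiting on the stop event instead of calling `time.sleep` lets `stop()` end the refresh loop promptly rather than after a full refresh interval.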
--------------------------------------------------------------------------------
/cameras/camera.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Protocol, Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | class CameraDriver(Protocol):
7 | """Camera protocol.
8 |
9 | A protocol for a camera driver. This is used to abstract the camera from the rest of the code.
10 | """
11 |
12 | def read(
13 | self,
14 | img_size: Optional[Tuple[int, int]] = None,
15 | ) -> Tuple[np.ndarray, np.ndarray]:
16 | """Read a frame from the camera.
17 |
18 | Args:
19 | img_size: The size of the image to return. If None, the original size is returned.
21 |
22 | Returns:
23 | np.ndarray: The color image.
24 | np.ndarray: The depth image.
25 | """
26 |
--------------------------------------------------------------------------------
/cameras/realsense_camera.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import List, Optional, Tuple
3 |
4 | import cv2
5 | import numpy as np
6 | import pyrealsense2 as rs
7 |
8 | from .camera import CameraDriver
9 |
10 |
11 | def get_device_ids() -> List[str]:
12 | device_ids = []
13 | for dev in rs.context().query_devices():
14 | device_ids.append(dev.get_info(rs.camera_info.serial_number))
15 | return device_ids
16 |
17 |
18 | class RealSenseCamera(CameraDriver):
19 | def __repr__(self) -> str:
20 | return f"RealSenseCamera(device_ids={self.device_ids})"
21 |
22 | def __init__(
23 | self,
24 | device_ids: Optional[List] = None,
25 | height: int = 480,
26 | width: int = 640,
27 | fps: int = 30,
28 | warm_start: int = 60,
29 | img_size: Optional[Tuple[int, int]] = None,
30 | ):
31 | self.height = height
32 | self.width = width
33 | self.fps = fps
34 | if device_ids is None:
35 | self.device_ids = get_device_ids()
36 | else:
37 | self.device_ids = device_ids
38 | self.img_size = img_size
39 |
40 | # Start stream
41 | print(f"Connecting to RealSense cameras ({len(self.device_ids)} found) ...")
42 | self.pipes = []
43 | self.profiles = OrderedDict()
44 | for i, device_id in enumerate(self.device_ids):
45 | pipe = rs.pipeline()
46 | config = rs.config()
47 |
48 | config.enable_device(device_id)
49 | config.enable_stream(
50 | rs.stream.depth, self.width, self.height, rs.format.z16, self.fps
51 | )
52 | config.enable_stream(
53 | rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps
54 | )
55 |
56 | self.pipes.append(pipe)
57 | self.profiles[device_id] = pipe.start(config)
58 | print(f"Connected to camera {i} ({device_id}).")
59 |
60 | self.align = rs.align(rs.stream.color)
61 |
62 | # Warm start camera (realsense automatically adjusts brightness during initial frames)
63 | for _ in range(warm_start):
64 | self._get_frames()
65 |
66 | def _get_frames(self):
67 | framesets = [pipe.wait_for_frames() for pipe in self.pipes]
68 | return [self.align.process(frameset) for frameset in framesets]
69 |
70 | def get_num_cameras(self):
71 | return len(self.device_ids)
72 |
73 | def read(
74 | self,
75 | img_size: Optional[Tuple[int, int]] = None,
76 | concatenate: bool = False,
77 | ) -> Tuple[np.ndarray, np.ndarray]:
78 | """Read a frame from the camera.
79 |
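80 | Args:
81 | img_size: Currently unused here; resizing is controlled by the img_size passed to the constructor.
82 | concatenate: If True, concatenate the per-camera images side by side along the width axis.
83 |
84 | Returns:
85 | np.ndarray: The color images, shape=(num_cams, H, W, 3), or (H, num_cams * W, 3) if concatenate is True.
86 | np.ndarray: The depth images, shape=(num_cams, H, W), or (H, num_cams * W) if concatenate is True.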
87 | """
88 | framesets = self._get_frames()
89 | num_cams = self.get_num_cameras()
90 | rgbd = np.empty([num_cams, self.height, self.width, 4], dtype=np.uint16)
91 | if self.img_size is not None:
92 | rgbd_resized = np.empty(
93 | [num_cams, self.img_size[1], self.img_size[0], 4], dtype=np.uint16
94 | )
95 |
96 | for i, frameset in enumerate(framesets):
97 | color_frame = frameset.get_color_frame()
98 | rgbd[i, :, :, :3] = np.asanyarray(color_frame.get_data())
99 |
100 | depth_frame = frameset.get_depth_frame()
101 | depth_frame = rs.decimation_filter(1).process(depth_frame)
102 | depth_frame = rs.disparity_transform(True).process(depth_frame)
103 | depth_frame = rs.spatial_filter().process(depth_frame)
104 | depth_frame = rs.temporal_filter().process(depth_frame)
105 | depth_frame = rs.disparity_transform(False).process(depth_frame)
106 | depth_frame = rs.hole_filling_filter().process(depth_frame)
107 | rgbd[i, :, :, 3] = np.asanyarray(depth_frame.get_data())
108 |
109 | if self.img_size is not None:
110 | rgbd_resized[i] = cv2.resize(rgbd[i], self.img_size)
111 |
112 | if self.img_size is not None:
113 | rgbd = rgbd_resized
114 |
115 | if concatenate:
116 | image = np.concatenate(rgbd[..., :3], axis=1, dtype=np.uint8)
117 | depth = np.concatenate(rgbd[..., -1], axis=1, dtype=np.uint8)
118 | else:
119 | image = rgbd[..., :3].astype(np.uint8)
120 | depth = rgbd[..., -1].astype(np.uint16)
121 | return image, depth
122 |
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import pickle
4 | import time
5 | from typing import Any, Dict, Optional
6 |
7 | import cv2
8 | import numpy as np
9 | from natsort import natsorted
10 |
11 | from cameras.camera import CameraDriver
12 | from robots.robot import Robot
13 |
14 |
15 | class Rate:
16 | def __init__(self, rate: float):
17 | self.last = time.time()
18 | self.rate = rate
19 |
20 | def sleep(self) -> None:
21 | while self.last + 1.0 / self.rate > time.time():
22 | time.sleep(0.0001)
23 | self.last = time.time()
24 |
25 |
26 | class EvalRobotEnv:
27 | def __init__(
28 | self,
29 | robot: Robot,
30 | traj_path: str,
31 | control_rate_hz: float,
32 | camera_dict: Optional[Dict[str, CameraDriver]] = None,
33 | ) -> None:
34 | self._robot = robot
35 | self._rate = Rate(control_rate_hz)
36 | self._camera_dict = {} if camera_dict is None else camera_dict
37 | self.traj_path = traj_path
38 |
39 | self.pkls = natsorted(
40 | glob.glob(os.path.join(self.traj_path, "*.pkl"), recursive=True)
41 | )
42 | print("Finished reading dir", self.traj_path)
43 | print("No. of files:", len(self.pkls))
44 | self.traj_len = len(self.pkls)
45 | self.count = 0
46 |
47 | def robot(self) -> Robot:
48 | """Get the robot object.
49 |
50 | Returns:
51 | robot: the robot object.
52 | """
53 | return self._robot
54 |
55 | def __len__(self):
56 | # Return positive integer for batched envs.
57 | return self.traj_len
58 |
59 | def step_eef(self, eef_pose: np.ndarray) -> Dict[str, Any]:
60 | """Step the environment forward.
61 |
62 | Args:
63 | eef_pose: end effector pose command to step the environment with.
64 |
65 | Returns:
66 | obs: observation from the environment.
67 | """
68 | assert len(eef_pose) == self._robot.num_dofs(), f"input:{len(eef_pose)}"
69 | self._robot.command_eef_pose(eef_pose)
70 | self._rate.sleep()
71 | return self.get_obs()
72 |
73 | def step(self, joints: np.ndarray) -> Dict[str, Any]:
74 | """Step the environment forward.
75 |
76 | Args:
77 | joints: joint angles command to step the environment with.
78 |
79 | Returns:
80 | obs: observation from the environment.
81 | """
82 | assert len(joints) == (
83 | self._robot.num_dofs()
84 | ), f"input:{len(joints)}, robot:{self._robot.num_dofs()}"
86 | self._robot.command_joint_state(joints)
87 | self._rate.sleep()
88 | return self.get_obs()
89 |
90 | def get_real_obs(self) -> Dict[str, Any]:
91 | observations = {}
92 | for name, camera in self._camera_dict.items():
93 | image, depth = camera.read()
94 | observations[f"{name}_rgb"] = image
95 | observations[f"{name}_depth"] = depth
96 |
97 | robot_obs = self._robot.get_observations()
98 | for k, v in robot_obs.items():
99 | observations[k] = v
100 | return observations
101 |
102 | def get_obs(self) -> Optional[Dict[str, Any]]:
103 | """Get observation from the environment.
104 |
105 | Returns:
106 | obs: observation from the environment.
107 | """
108 | if self.count >= self.traj_len:
109 | return None
110 | pkl = self.pkls[self.count]
111 | with open(pkl, "rb") as f:
112 | observations = pickle.load(f)
113 | self.count += 1
114 | return observations
115 |
116 |
117 | class RobotEnv:
118 | def __init__(
119 | self,
120 | robot: Robot,
121 | control_rate_hz: float = 100.0,
122 | camera_dict: Optional[Dict[str, CameraDriver]] = None,
123 | show_camera_view: bool = True,
124 | save_depth: bool = True,
125 | ) -> None:
126 | self._robot = robot
127 | self._rate = Rate(control_rate_hz)
128 | print("RobotEnv: control_rate_hz", control_rate_hz)
129 | self._camera_dict = {} if camera_dict is None else camera_dict
130 |
131 | self._show_camera_view = show_camera_view
132 | if self._show_camera_view:
133 | for name in list(self._camera_dict.keys()):
134 | cv2.namedWindow(name, cv2.WINDOW_NORMAL)
135 |
136 | self._save_depth = save_depth
137 |
138 | def robot(self) -> Robot:
139 | """Get the robot object.
140 |
141 | Returns:
142 | robot: the robot object.
143 | """
144 | return self._robot
145 |
146 | def __len__(self):
147 | # Return positive integer for batched envs.
148 | return 0
149 |
150 | def step_eef(self, eef_pose: np.ndarray) -> Dict[str, Any]:
151 | """Step the environment forward.
152 |
153 | Args:
154 | eef_pose: end effector pose command to step the environment with.
155 |
156 | Returns:
157 | obs: observation from the environment.
158 | """
159 | assert len(eef_pose) == self._robot.num_dofs(), f"input:{len(eef_pose)}"
160 | self._robot.command_eef_pose(eef_pose)
161 | self._rate.sleep()
162 | return self.get_obs()
163 |
164 | def step(self, joints: np.ndarray) -> Dict[str, Any]:
165 | """Step the environment forward.
166 |
167 | Args:
168 | joints: joint angles command to step the environment with.
169 |
170 | Returns:
171 | obs: observation from the environment.
172 | """
173 | assert len(joints) == (
174 | self._robot.num_dofs()
175 | ), f"input:{len(joints)}, robot:{self._robot.num_dofs()}"
177 | self._robot.command_joint_state(joints)
178 | self._rate.sleep()
179 | return self.get_obs()
180 |
181 | def get_obs(self) -> Dict[str, Any]:
182 | """Get observation from the environment.
183 |
184 | Returns:
185 | obs: observation from the environment.
186 | """
187 | observations = {}
188 | for name, camera in self._camera_dict.items():
189 | image, depth = camera.read()
190 | observations[f"{name}_rgb"] = image
191 | if self._save_depth:
192 | observations[f"{name}_depth"] = depth
193 |
194 | if self._show_camera_view:
195 | depth = cv2.applyColorMap(depth, cv2.COLORMAP_JET)
196 | image_depth = cv2.hconcat([image[:, :, ::-1], depth])
197 | cv2.imshow(name, image_depth)
198 | cv2.waitKey(1)
199 |
200 | robot_obs = self._robot.get_observations()
201 | for k, v in robot_obs.items():
202 | observations[k] = v
203 | return observations
204 |
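The `Rate` helper above polls in 0.1 ms sleeps and restarts its timer from the moment each wait ends, so any overrun is carried into every later step. A deadline-based variant is sketched below as one possible alternative. `DeadlineRate` is a hypothetical name and is not part of this repository; it uses `time.monotonic`, which is unaffected by system clock adjustments.

```python
import time


class DeadlineRate:
    """Hypothetical variant of Rate that schedules against fixed deadlines."""

    def __init__(self, rate_hz: float):
        self.period = 1.0 / rate_hz
        self.next_deadline = time.monotonic() + self.period

    def sleep(self) -> None:
        remaining = self.next_deadline - time.monotonic()
        if remaining > 0:
            # On schedule: sleep once, then advance the deadline by one period.
            time.sleep(remaining)
            self.next_deadline += self.period
        else:
            # Overran: restart the schedule from now instead of bursting to catch up.
            self.next_deadline = time.monotonic() + self.period


rate = DeadlineRate(200.0)  # 200 Hz, i.e. a 5 ms period
start = time.monotonic()
for _ in range(5):
    rate.sleep()
elapsed = time.monotonic() - start
```

Because each deadline is derived from the previous deadline rather than from the wake-up time, small wake-up delays do not accumulate across steps.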
--------------------------------------------------------------------------------
/eval_dir.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 |
5 | from agents.dp_agent import BimanualDPAgent
6 |
7 |
8 | def eval_ckpts(ckpt_paths, eval_dir, save_path):
9 | mse_dict = {}
10 |
11 | if os.path.exists(save_path):
12 | with open(save_path, "rb") as f:
13 | mse_dict = pickle.load(f)
14 | print(f"Loaded previous MSE dict from {save_path}")
15 |
16 | eval_loader = None
17 | last_arg = None
18 |
19 | for ckpt_path in ckpt_paths:
20 | ckpt_name = os.path.basename(os.path.dirname(ckpt_path))
21 | ckpt_num = os.path.basename(ckpt_path)
22 | agent = BimanualDPAgent(ckpt_path)
23 | if eval_loader is None:
24 | eval_loader = agent.dp.get_eval_loader(eval_dir)
25 | elif (
26 | agent.dp_args["representation_type"] != last_arg["representation_type"]
27 | or agent.dp_args["camera_indices"] != last_arg["camera_indices"]
28 | ):
29 | eval_loader = agent.dp.get_eval_loader(eval_dir)
30 | last_arg = agent.dp_args
31 | mse, action_mse = agent.dp.eval_dir(eval_loader)
32 | if mse_dict.get(ckpt_name) is None:
33 | mse_dict[ckpt_name] = {}
34 | mse_dict[ckpt_name]["config"] = agent.dp_args
35 |
36 | mse_dict[ckpt_name][ckpt_num] = {}
37 | mse_dict[ckpt_name][ckpt_num]["mse"] = mse
38 | mse_dict[ckpt_name][ckpt_num]["action_mse"] = action_mse
39 |
40 | print(f"MSE for {ckpt_name}: {mse}")
41 |
42 | with open(save_path, "wb") as f:
43 | pickle.dump(mse_dict, f)
44 | print(f"Saved MSE dict to {save_path}")
45 |
46 |
47 | if __name__ == "__main__":
48 | args = argparse.ArgumentParser()
49 | args.add_argument(
50 | "--ckpt_path",
51 | nargs="+",
52 | type=str,
53 | default=[
54 | "model_epoch_100.ckpt",
55 | "model_epoch_200.ckpt",
56 | "model_epoch_300.ckpt",
57 | ],
58 | )
59 | args.add_argument(
60 | "--eval_dir",
61 | type=str,
62 | default="/split_data/data_pour_train_10",
63 | )
64 | args.add_argument("--save_path", type=str, default=None)
65 |
66 | args = args.parse_args()
67 |
68 | if args.save_path is None:
69 | data_name = os.path.basename(args.eval_dir)
70 | args.save_path = "./eval_results/eval_{}.pkl".format(data_name)
71 | eval_ckpts(args.ckpt_path, args.eval_dir, args.save_path)
72 |
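The result layout used by `eval_ckpts` is a nested dict keyed by run name and then by checkpoint file name. The sketch below reproduces that layout with standard-library code only and also creates the output directory before pickling, a step the script above leaves to the user. `record_result` and `save_results` are hypothetical helpers, and all metric and config values are made-up placeholders.

```python
import os
import pickle
import tempfile


def record_result(results, ckpt_name, ckpt_num, mse, action_mse, config=None):
    """Store one checkpoint's metrics under results[ckpt_name][ckpt_num]."""
    # The config is recorded only the first time a run name is seen.
    run = results.setdefault(ckpt_name, {"config": config})
    run[ckpt_num] = {"mse": mse, "action_mse": action_mse}
    return results


def save_results(results, save_path):
    """Pickle results, creating the parent directory first if needed."""
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(save_path, "wb") as f:
        pickle.dump(results, f)


results = {}
record_result(results, "run_a", "model_epoch_100.ckpt", 0.12, 0.03, config={"lr": 1e-4})
record_result(results, "run_a", "model_epoch_200.ckpt", 0.08, 0.02)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "eval_results", "eval_demo.pkl")
    save_results(results, path)
    with open(path, "rb") as f:
        loaded = pickle.load(f)
```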
--------------------------------------------------------------------------------
/inference_node.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import pickle
4 | import threading
5 | import time
6 |
7 | import numpy as np
8 | import zmq
9 |
10 | DEFAULT_INFERENCE_PORT = 4321
11 |
12 |
13 | class ZMQInferenceServer:
14 | """A ZMQ server that receives observations and replies with predicted actions."""
15 |
16 | def __init__(self, port: int = DEFAULT_INFERENCE_PORT):
17 | self._context = zmq.Context()
18 | self._socket = self._context.socket(zmq.PAIR)
19 | self._socket.bind(f"tcp://*:{port}")
20 | self._stop_event = threading.Event()
21 |
22 | def get_obs(self):
23 | state_dict = None
24 | while True:
25 | try:
26 | # check for a message, this will not block
27 | message = self._socket.recv(flags=zmq.NOBLOCK)
28 |
29 | except zmq.Again as e:
30 | # print("observation queue exhausted")
31 | break
32 | else:
33 | state_dict = pickle.loads(message)
34 | if state_dict is None: # block until an observation is received
35 | while True:
36 | message = self._socket.recv()
37 | state_dict = pickle.loads(message)
38 | if "obs" not in state_dict and "t" not in state_dict:
39 | if "num_diffusion_iters" in state_dict: # ignore and send success
40 | self._socket.send_string("success")
41 | continue
42 | break
43 |
44 | return state_dict["obs"], state_dict["t"]
45 |
46 | def infer(self, *args, **kwargs):
47 | raise NotImplementedError
48 |
49 | def act(self, obs):
50 | raise NotImplementedError
51 |
52 | def serve(self):
53 | # self._socket.setsockopt(zmq.RCVTIMEO, 10000) # Set timeout to 10000 ms
54 | while not self._stop_event.is_set():
55 | obs, t = self.get_obs() # get obs from the client
56 |
57 | print(f"Received observation at time {t}. Inference start!")
58 | pred = self.act(obs)
59 | print("Inference ended.")
60 |
61 | message = pickle.dumps({"acts": pred, "t": t})
62 | self._socket.send(message) # send the action back to the client
63 |
64 | def stop(self) -> None:
65 | """Signal the server to stop serving."""
66 | self._stop_event.set()
67 |
68 |
69 | class ZMQInferenceClient:
70 | """A ZMQ client that sends observations to an inference server and ensembles the returned actions."""
71 |
72 | def __init__(
73 | self,
74 | port: int = DEFAULT_INFERENCE_PORT,
75 | host: str = "111.11.111.11",
76 | default_action=None,
77 | queue_size=32,
78 | ensemble_mode="new",
79 | act_tau=0.5,
80 | ):
81 | self._context = zmq.Context()
82 | self._socket = self._context.socket(zmq.PAIR)
83 | self._socket.connect(f"tcp://{host}:{port}")
84 | print(f"connected -- tcp://{host}:{port}")
85 |
86 | self.act_q = collections.deque(maxlen=queue_size)
87 | self.t = 0
88 | self.last_act = default_action
89 | self.ensemble_mode = ensemble_mode
90 | self.act_tau = act_tau
91 |
92 | def act(self, obs):
93 | self.t += 1
94 | final_act = self.last_act
95 |
96 | # send the observation
97 | message = pickle.dumps({"obs": obs, "t": self.t})
98 | self._socket.send(message)
99 |
100 | # process the incoming message queue
101 | while True:
102 | try:
103 | # check for a message, this will not block
104 | message = self._socket.recv(flags=zmq.NOBLOCK)
105 |
106 | except zmq.Again as e:
107 | # print("action queue exhausted")
108 | break
109 | else:
110 | state_dict = pickle.loads(message)
111 | acts, pt = state_dict["acts"], state_dict["t"]
112 |
113 | while len(self.act_q) > 0 and self.act_q[0][1] < self.t:
114 | self.act_q.popleft()
115 | while pt < self.t and len(acts) > 0:
116 | pt += 1
117 | acts = acts[1:]
118 | for c_acts, ct in self.act_q:
119 | if ct == pt:
120 | c_acts.append(acts[0])
121 | pt += 1
122 | acts = acts[1:]
123 | if len(acts) == 0:
124 | break
125 | # for
126 | # push all the new actions in
127 | for i, act in enumerate(acts):
128 | self.act_q.append(([act], pt + i))
129 |
130 | # now searching for the matching time stamp
131 | while len(self.act_q) > 0:
132 | c_acts, tt = self.act_q.popleft()
133 | if tt == self.t:
134 | if self.ensemble_mode == "act":
135 | z_act = c_acts[0]
136 | for act in c_acts[1:]:
137 | z_act = z_act * self.act_tau + act * (1.0 - self.act_tau)
138 | final_act = z_act
139 | elif self.ensemble_mode == "avg":
140 | final_act = np.mean(np.array(c_acts), axis=0)
141 | elif self.ensemble_mode == "old":
142 | final_act = c_acts[0]
143 | elif self.ensemble_mode == "new":
144 | final_act = c_acts[-1]
145 |
146 | break
147 | print("action queue (dt):", [t - self.t for a, t in self.act_q])
148 | print("action queue (size):", [len(a) for a, t in self.act_q])
149 | self.last_act = final_act
150 | # print("action:", final_act)
151 | return final_act
152 |
153 |
154 | if __name__ == "__main__":
155 | args = argparse.ArgumentParser()
156 | args.add_argument("--type", type=str, default="client")
157 | args.add_argument("--freq", type=float, default=10)
158 | args.add_argument("--inft", type=float, default=0.2)
159 | args = args.parse_args()
160 |
161 | action_dim = 24
162 |
163 | class DummyAgentServer(ZMQInferenceServer):
164 | def act(self, obs):
165 | print("start inference")
166 | time.sleep(args.inft)
167 | print("stop inference")
168 | return np.zeros((16, action_dim))
169 |
170 | if args.type == "server":
171 | server = DummyAgentServer()
172 | server.serve()
173 |
174 | elif args.type == "client":
175 | client = ZMQInferenceClient(
176 | default_action=np.zeros((action_dim,)), queue_size=32
177 | )
178 | obs = {
179 | "img": np.zeros((4, 240, 360), dtype=np.uint16),
180 | "eef": np.zeros((24,), dtype=np.float32),
181 | }
182 |
183 | while True:
184 | time.sleep(1 / args.freq)
185 | action = client.act(obs)
186 |
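When several action chunks overlap, `ZMQInferenceClient.act` ends up with multiple predictions for the same timestep and combines them according to `ensemble_mode`. The four modes can be isolated into a small function, sketched here on scalars for readability (the client itself operates on action arrays). The function name `ensemble` is hypothetical, and unlike the client, which silently keeps the previous action for an unknown mode, this sketch raises an error.

```python
def ensemble(candidates, mode="new", tau=0.5):
    """Combine predictions for one timestep, ordered oldest to newest."""
    if not candidates:
        raise ValueError("no candidate actions for this timestep")
    if mode == "old":
        # Trust the earliest prediction.
        return candidates[0]
    if mode == "new":
        # Trust the most recent prediction.
        return candidates[-1]
    if mode == "avg":
        # Unweighted mean over all predictions.
        return sum(candidates) / len(candidates)
    if mode == "act":
        # Exponential blend: each newer prediction gets weight (1 - tau).
        blended = candidates[0]
        for act in candidates[1:]:
            blended = blended * tau + act * (1.0 - tau)
        return blended
    raise ValueError(f"unknown ensemble mode: {mode}")


preds = [0.0, 4.0, 8.0]
old_act = ensemble(preds, "old")
new_act = ensemble(preds, "new")
avg_act = ensemble(preds, "avg")
blend_act = ensemble(preds, "act", tau=0.5)
```

With `tau=0.5` the blend weights the newest prediction most heavily: the three values above combine as ((0 * 0.5 + 4 * 0.5) * 0.5 + 8 * 0.5) = 5.0.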
--------------------------------------------------------------------------------
/launch_inference_nodes.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import tyro
4 |
5 | from agents.dp_agent_zmq import BimanualDPAgentServer
6 |
7 |
8 | def boolean_string(s):
9 | if s not in {"False", "True"}:
10 | raise ValueError("Not a valid boolean string")
11 | return s == "True"
12 |
13 |
14 | @dataclass
15 | class Args:
16 | port: str = "4321"
17 | dp_ckpt_path: str = "best.ckpt"
18 |
19 |
20 | def launch_server(args: Args):
21 | server = BimanualDPAgentServer(ckpt_path=args.dp_ckpt_path, port=args.port)
22 | print(f"Starting inference server on {args.port}")
23 |
24 | print("Compiling inference")
25 | server.compile_inference()
26 | print("Done. Inference available.")
27 | server.serve()
28 |
29 |
30 | def main(args):
31 | launch_server(args)
32 |
33 |
34 | if __name__ == "__main__":
35 | main(tyro.cli(Args))
36 |
--------------------------------------------------------------------------------
/launch_nodes.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from multiprocessing import Process
3 | from typing import List, Optional, Tuple
4 |
5 | import tyro
6 |
7 | from camera_node import ZMQServerCamera, ZMQServerCameraFaster
8 | from robot_node import ZMQServerRobot
9 | from robots.robot import BimanualRobot
10 |
11 |
12 | @dataclass
13 | class Args:
14 | robot: str = "bimanual_ur"
15 | hand_type: str = ""
16 | hostname: str = "127.0.0.1"
17 | robot_ip: str = "111.111.1.1"
18 | faster: bool = True
19 | cam_names: Tuple[str, ...] = ("435",)
20 | ability_gripper_grip_range: int = 110
21 | img_size: Optional[Tuple[int, int]] = None # (320, 240)
22 |
23 |
24 | def launch_server_cameras(port: int, camera_id: List[str], args: Args):
25 | from cameras.realsense_camera import RealSenseCamera
26 |
27 | camera = RealSenseCamera(camera_id, img_size=args.img_size)
28 |
29 | if args.faster:
30 | server = ZMQServerCameraFaster(camera, port=port, host=args.hostname)
31 | else:
32 | server = ZMQServerCamera(camera, port=port, host=args.hostname)
33 | print(f"Starting camera server on port {port}")
34 | server.serve()
35 |
36 |
37 | def launch_robot_server(port: int, args: Args):
38 | if args.robot == "ur":
39 | from robots.ur import URRobot
40 |
41 | robot = URRobot(robot_ip=args.robot_ip)
42 | elif args.robot == "bimanual_ur":
43 | from robots.ur import URRobot
44 |
45 | if args.hand_type == "ability":
46 | # 6 DoF Ability Hand
47 | # robot_l - right hand; robot_r - left hand
48 | _robot_l = URRobot(
49 | robot_ip="111.111.1.3",
50 | no_gripper=False,
51 | gripper_type="ability",
52 | grip_range=args.ability_gripper_grip_range,
53 | port_idx=1,
54 | )
55 | _robot_r = URRobot(
56 | robot_ip="111.111.2.3",
57 | no_gripper=False,
58 | gripper_type="ability",
59 | grip_range=args.ability_gripper_grip_range,
60 | port_idx=2,
61 | )
62 | else:
63 | # Robotiq gripper
64 | _robot_l = URRobot(robot_ip="111.111.1.3", no_gripper=False)
65 | _robot_r = URRobot(robot_ip="111.111.2.3", no_gripper=False)
66 | robot = BimanualRobot(_robot_l, _robot_r)
67 | else:
68 | raise NotImplementedError(f"Robot {args.robot} not implemented")
69 | server = ZMQServerRobot(robot, port=port, host=args.hostname)
70 | print(f"Starting robot server on port {port}")
71 | server.serve()
72 |
73 |
74 | CAM_IDS = {
75 | "435": "000000000000",
76 | }
77 |
78 |
79 | def create_camera_server(args: Args) -> Process:
80 | ids = [CAM_IDS[name] for name in args.cam_names]
81 | camera_port = 5000
82 | # start a single python process for all cameras
83 | print(f"Launching cameras {ids} on port {camera_port}")
84 | server = Process(target=launch_server_cameras, args=(camera_port, ids, args))
85 | return server
86 |
87 |
88 | def main(args):
89 | camera_server = create_camera_server(args)
90 | print("Starting camera server process")
91 | camera_server.start()
92 |
93 | launch_robot_server(6000, args)
94 |
95 |
96 | if __name__ == "__main__":
97 | main(tyro.cli(Args))
98 |
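A note on `cam_names`: `create_camera_server` iterates over it, and iterating a bare string yields single characters rather than camera names, so the value needs to be a real tuple such as `("435",)`. The sketch below illustrates the difference and adds a defensive normalization step. `DemoArgs` and `resolve_ids` are hypothetical names, and the serial number is the same placeholder used in `CAM_IDS` above.

```python
from dataclasses import dataclass
from typing import Tuple

CAM_IDS = {"435": "000000000000"}  # placeholder serial number


@dataclass
class DemoArgs:
    # A one-element tuple needs the trailing comma; ("435") is just a string.
    cam_names: Tuple[str, ...] = ("435",)


def resolve_ids(cam_names):
    """Map camera names to serial numbers, tolerating a single bare string."""
    if isinstance(cam_names, str):
        cam_names = (cam_names,)
    return [CAM_IDS[name] for name in cam_names]


ids_from_tuple = resolve_ids(DemoArgs().cam_names)
ids_from_str = resolve_ids("435")
chars = list("435")  # what a plain loop over the string would see
```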
--------------------------------------------------------------------------------
/learning/dp/.gitignore:
--------------------------------------------------------------------------------
1 | model
2 | eval
3 | __pycache__
4 | *.sh
5 | *.yml
6 | eval
7 | runs
--------------------------------------------------------------------------------
/learning/dp/data_processing.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import os
3 | import pickle
4 |
5 | import cv2
6 | import natsort
7 | import numpy as np
8 |
9 |
10 | def from_pickle(path, load_img=True, num_cam=3):
11 | with open(path, "rb") as f:
12 | data = pickle.load(f)
13 | if "base_rgb" not in data and load_img:
14 | rgb = []
15 | for i in range(num_cam):
16 | rgb_path = path.replace(".pkl", f"-{i}.png")
17 | if os.path.exists(rgb_path):
18 | rgb.append(cv2.imread(rgb_path))
19 | data["base_rgb"] = np.stack(rgb, axis=0)
20 |
21 | return data
22 |
23 | # Get the trajectory data from the given directory
24 | def iterate(path, workers=32, load_img=True, num_cam=3):
25 | dir = os.listdir(path)
26 | dir = [d for d in dir if d.endswith(".pkl")]
27 | dir = natsort.natsorted(dir)
28 | dirname = os.path.basename(path)
29 | root_path = "./mask_cache"
30 | data = []
31 | with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
32 | futures = {executor.submit(from_pickle, os.path.join(path, file), load_img, num_cam): (i, file) for i, file in enumerate(dir)}
33 | for future in futures:
34 | try:
35 | i, file = futures[future]
36 | d = future.result()
37 | if not d["activated"]["l"] and not d["activated"]["r"]:
38 | continue
39 | basedirfile = os.path.join(dirname, file)
40 | maskfile = os.path.join(root_path, basedirfile)
41 | if os.path.exists(maskfile):
42 | d["mask"] = from_pickle(maskfile)
43 | d["mask_path"] = maskfile
44 | d["file_path"] = os.path.join(path, file)
45 | data.append(d)
46 | except Exception:
47 | print(f"Failed to load {file}")
49 | return data
50 |
51 |
52 | def get_latest(path):
53 | dir = os.listdir(path)
54 | dir = natsort.natsorted(dir)
55 | return from_pickle(os.path.join(path, dir[-1]))
56 |
57 | # Get all trajectory directories from the given path
58 | def get_epi_dir(path, traj_type, prefix=None):
59 | dir = natsort.natsorted(os.listdir(path))
60 | if prefix is not None:
61 | prefixs = prefix.split("-")
62 |
63 | new_dir = []
64 | for d in dir:
65 | if os.path.isdir(os.path.join(path, d)):
66 | matched = False
67 | if prefix is None:
68 | matched = True
69 | else:
70 | for prefix in prefixs:
71 | if d.startswith(prefix):
72 | matched = True
73 | if matched:
74 | new_dir.append(d)
75 |
76 | print("All Directories")
77 | print(new_dir)
78 | print("==========")
79 | dir = new_dir
80 | if traj_type == "plain":
81 | dir = [
82 | d
83 | for d in dir
84 | if not d.endswith("failed")
85 | and not d.endswith("ood")
86 | and not d.endswith("ikbad")
87 | and not d.endswith("heated")
88 | and not d.endswith("stop")
89 | and not d.endswith("hard")
90 | ]
91 | elif traj_type == "all":
92 | pass
93 | else:
94 | raise NotImplementedError
95 | dir_list = [
96 | os.path.join(path, d) for d in dir if os.path.isdir(os.path.join(path, d))
97 | ]
98 | return dir_list
99 |
100 |
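The name filtering in `get_epi_dir` (keep directories matching any of the dash-separated prefixes, then drop flagged suffixes when `traj_type` is `"plain"`) can be expressed compactly, because `str.startswith` and `str.endswith` both accept a tuple of alternatives. The sketch below works on a plain list of names and skips the natural sorting and `os.path.isdir` checks. `filter_episode_names` is a hypothetical helper and the directory names are invented for illustration.

```python
EXCLUDED_SUFFIXES = ("failed", "ood", "ikbad", "heated", "stop", "hard")


def filter_episode_names(names, traj_type="plain", prefix=None):
    """Filter episode directory names by prefix and, optionally, by suffix."""
    # "pour-stack" means: accept names starting with "pour" or with "stack".
    prefixes = tuple(prefix.split("-")) if prefix is not None else None
    kept = [n for n in names if prefixes is None or n.startswith(prefixes)]
    if traj_type == "plain":
        kept = [n for n in kept if not n.endswith(EXCLUDED_SUFFIXES)]
    elif traj_type != "all":
        raise NotImplementedError(traj_type)
    return kept


names = ["pour_01", "pour_02_failed", "stack_01", "stack_02_ood"]
plain_pour = filter_episode_names(names, "plain", "pour")
all_both = filter_episode_names(names, "all", "pour-stack")
plain_any = filter_episode_names(names, "plain")
```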
--------------------------------------------------------------------------------
/learning/dp/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def create_sample_indices(
9 | episode_ends: np.ndarray,
10 | sequence_length: int,
11 | pad_before: int = 0,
12 | pad_after: int = 0,
13 | ):
14 | indices = list()
15 | for i in range(len(episode_ends)):
16 | start_idx = 0
17 | if i > 0:
18 | start_idx = episode_ends[i - 1]
19 | end_idx = episode_ends[i]
20 | episode_length = end_idx - start_idx
21 |
22 | min_start = -pad_before
23 | max_start = episode_length - sequence_length + pad_after
24 |
25 | # range stops one idx before end
26 | for idx in range(min_start, max_start + 1):
27 | buffer_start_idx = max(idx, 0) + start_idx
28 | buffer_end_idx = min(idx + sequence_length, episode_length) + start_idx
29 | start_offset = buffer_start_idx - (idx + start_idx)
30 | end_offset = (idx + sequence_length + start_idx) - buffer_end_idx
31 | sample_start_idx = 0 + start_offset
32 | sample_end_idx = sequence_length - end_offset
33 | indices.append(
34 | [buffer_start_idx, buffer_end_idx, sample_start_idx, sample_end_idx]
35 | )
36 | indices = np.array(indices)
37 | return indices
38 |
39 |
40 | def sample_sequence(
41 | train_data,
42 | sequence_length,
43 | buffer_start_idx,
44 | buffer_end_idx,
45 | sample_start_idx,
46 | sample_end_idx,
47 | ):
48 | result = dict()
49 | for key, input_arr in train_data.items():
50 | sample = input_arr[buffer_start_idx:buffer_end_idx]
51 | data = sample
52 | if (sample_start_idx > 0) or (sample_end_idx < sequence_length):
53 | data = np.zeros(
54 | shape=(sequence_length,) + input_arr.shape[1:], dtype=input_arr.dtype
55 | )
56 | if sample_start_idx > 0:
57 | data[:sample_start_idx] = sample[0]
58 | if sample_end_idx < sequence_length:
59 | data[sample_end_idx:] = sample[-1]
60 | data[sample_start_idx:sample_end_idx] = sample
61 | result[key] = data
62 | return result
63 |
64 |
65 | # normalize data
66 | def get_data_stats(data):
67 | data = data.reshape(-1, data.shape[-1])
68 | stats = {"min": np.min(data, axis=0), "max": np.max(data, axis=0)}
69 | if np.any(stats["max"] > 1e5) or np.any(stats["min"] < -1e5):
70 | raise ValueError("data out of range")
71 | return stats
72 |
73 |
74 | def normalize_data(data, stats):
75 | # normalize to [0,1]
76 | ndata = (data - stats["min"]) / (stats["max"] - stats["min"] + 1e-8)
77 | # normalize to [-1, 1]
78 | ndata = ndata * 2 - 1
79 | return ndata
80 |
81 |
82 | def unnormalize_data(ndata, stats):
83 | ndata = (ndata + 1) / 2
84 | data = ndata * (stats["max"] - stats["min"] + 1e-8) + stats["min"]
85 | return data
86 |
87 |
88 | class MemmapLoader:
89 | def __init__(self, path):
90 | with open(os.path.join(path, "metadata.pkl"), "rb") as f:
91 | meta_data = pickle.load(f)
92 |
93 | print("Meta Data:", meta_data)
94 | self.fps = {}
95 |
96 | self.length = None
97 | for key, (shape, dtype) in meta_data.items():
98 | self.fps[key] = np.memmap(
99 | os.path.join(path, key + ".dat"), dtype=dtype, shape=shape, mode="r"
100 | )
101 | if self.length is None:
102 | self.length = shape[0]
103 | else:
104 | assert self.length == shape[0]
105 |
106 | def __getitem__(self, index):
107 | rets = {}
108 | for key in self.fps.keys():
109 | value = self.fps[key]
110 | value = value[index]
111 | value_cp = np.empty(dtype=value.dtype, shape=value.shape)
112 | value_cp[:] = value
113 | rets[key] = value_cp
114 | return rets
115 |
116 | def __len__(self):
117 | return self.length
118 |
119 |
120 | # dataset
121 | class Dataset(torch.utils.data.Dataset):
122 | def __init__(
123 | self,
124 | data: dict,
125 | representation_type: list,
126 | pred_horizon: int,
127 | obs_horizon: int,
128 | action_horizon: int,
129 | stats: dict = None,
130 | transform=None,
131 | get_img=None,
132 | load_img: bool = False,
133 | hand_grip_range: int = 110,
134 | binarize_touch: bool = False,
135 | state_noise: float = 0.0,
136 | ):
137 | self.state_noise = state_noise
138 | self.memmap_loader = None
139 | if "memmap_loader_path" in data.keys():
140 | self.memmap_loader = MemmapLoader(data["memmap_loader_path"])
141 | self.representation_type = representation_type
142 | self.transform = transform
143 | self.get_img = get_img
144 | self.load_img = load_img
145 |
146 | print("Representation type: ", representation_type)
147 | except_img_representation_type = representation_type.copy()
148 |
149 | if "img" in representation_type:
150 | train_image_data = data["data"]["img"][:]
151 | except_img_representation_type.remove("img")
152 |
153 | train_data = {
154 | rt: data["data"][rt][:, :] for rt in except_img_representation_type
155 | }
156 | train_data["action"] = data["data"]["action"][:]
157 | episode_ends = data["meta"]["episode_ends"][:]
158 |
159 | # compute start and end of each state-action sequence
160 | # also handles padding
161 | indices = create_sample_indices(
162 | episode_ends=episode_ends,
163 | sequence_length=pred_horizon,
164 | pad_before=obs_horizon - 1,
165 | pad_after=action_horizon - 1,
166 | )
167 |
168 | normalized_train_data = dict()
169 |
170 | # compute statistics and normalized data to [-1,1]
171 | if stats is None:
172 | stats = dict()
173 | for key, data in train_data.items():
174 | stats[key] = get_data_stats(data)
175 |
176 | # overwrite the min max of hand info
177 | left_hand_indices = np.array([6, 7, 8, 9, 10, 11])
178 | right_hand_indices = np.array([18, 19, 20, 21, 22, 23])
179 |
180 | # hand joint_position range
181 | hand_upper_ranges = np.array(
182 | [hand_grip_range] * 4 + [90, 120], dtype=np.float32
183 | )
184 | hand_lower_ranges = np.array([5, 5, 5, 5, 5, 5], dtype=np.float32)
185 |
186 | if "pos" in representation_type:
187 | stats["pos"]["min"][left_hand_indices] = hand_lower_ranges
188 | stats["pos"]["min"][right_hand_indices] = hand_lower_ranges
189 | stats["pos"]["max"][left_hand_indices] = hand_upper_ranges
190 | stats["pos"]["max"][right_hand_indices] = hand_upper_ranges
191 | elif "hand_pos" in representation_type:
192 | stats["hand_pos"]["min"][range(6)] = hand_lower_ranges
193 | stats["hand_pos"]["min"][range(6, 12)] = hand_lower_ranges
194 | stats["hand_pos"]["max"][range(6)] = hand_upper_ranges
195 | stats["hand_pos"]["max"][range(6, 12)] = hand_upper_ranges
196 |
197 | # hand action is normalized to [0,1]
198 | stats["action"]["min"][left_hand_indices] = 0.0
199 | stats["action"]["max"][left_hand_indices] = 1.0
200 | stats["action"]["min"][right_hand_indices] = 0.0
201 | stats["action"]["max"][right_hand_indices] = 1.0
202 |
203 | for key, data in train_data.items():
204 | if key == "touch" and binarize_touch:
205 | normalized_train_data[key] = (
206 | data # don't normalize if binarize touch in model
207 | )
208 | else:
209 | normalized_train_data[key] = normalize_data(data, stats[key])
210 |
211 | # images are already normalized
212 | if "img" in representation_type:
213 | normalized_train_data["img"] = train_image_data
214 |
215 | self.indices = indices
216 | self.stats = stats
217 | self.normalized_train_data = normalized_train_data
218 | self.pred_horizon = pred_horizon
219 | self.action_horizon = action_horizon
220 | self.obs_horizon = obs_horizon
221 | self.binarize_touch = binarize_touch
222 |
223 | def __len__(self):
224 | return len(self.indices)
225 |
226 | def read_img(self, image_pathes, idx):
227 | if self.memmap_loader is not None:
228 | # using memmap loader
229 | indices = range(idx, idx + self.obs_horizon)
230 | data = self.memmap_loader[indices]
231 | data = [
232 | {"base_rgb": data["base_rgb"][i], "base_depth": data["base_depth"][i]}
233 | for i in range(data["base_rgb"].shape[0])
234 | ]
235 | else:
236 | # not using memmap loader and loading images while training
237 | data = [pickle.load(open(image_path, "rb")) for image_path in image_pathes]
238 | imgs = self.get_img(data)
239 | return imgs
240 |
241 | def __getitem__(self, idx):
242 | # get the start/end indices for this datapoint
243 | (
244 | buffer_start_idx,
245 | buffer_end_idx,
246 | sample_start_idx,
247 | sample_end_idx,
248 | ) = self.indices[idx]
249 |
250 | # get normalized data using these indices
251 | nsample = sample_sequence(
252 | train_data=self.normalized_train_data,
253 | sequence_length=self.pred_horizon,
254 | buffer_start_idx=buffer_start_idx,
255 | buffer_end_idx=buffer_end_idx,
256 | sample_start_idx=sample_start_idx,
257 | sample_end_idx=sample_end_idx,
258 | )
259 |
260 | for k in self.representation_type:
261 | # discard unused observations
262 | nsample[k] = nsample[k][: self.obs_horizon]
263 | if k == "img":
264 | if not self.load_img:
265 | nsample["img"] = self.read_img(nsample["img"], idx)
266 | else:
267 | nsample["img"] = torch.tensor(
268 | nsample["img"].astype(np.float32), dtype=torch.float32
269 | )
270 | nsample_shape = nsample["img"].shape
271 | # transform the img
272 | nsample["img"] = nsample["img"].reshape(
273 | nsample_shape[0] * nsample_shape[1], *nsample_shape[2:]
274 | ) # (Batch * num_cam, Channel, Height, Width)
275 | nsample["img"] = self.transform(nsample["img"])
276 | nsample["img"] = nsample["img"].reshape(nsample_shape[:3] + (216, 288)) # (Batch, num_cam, Channel, Height, Width)
277 |
278 | else:
279 | nsample[k] = torch.tensor(nsample[k], dtype=torch.float32)
280 | if self.state_noise > 0.0:
281 | # add noise to the state
282 | nsample[k] = nsample[k] + torch.randn_like(nsample[k]) * self.state_noise
283 | nsample["action"] = torch.tensor(nsample["action"], dtype=torch.float32)
284 | return nsample
285 |
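The windowing and edge padding done by `create_sample_indices` and `sample_sequence` can be easier to follow on a tiny episode. The block below is a minimal pure-Python sketch, not the repository's code: `window_indices` and `pad_window` are hypothetical names, lists stand in for NumPy arrays, and only a single episode starting at `start_idx` is handled.

```python
def window_indices(episode_length, sequence_length, pad_before, pad_after, start_idx=0):
    """Return (buffer_start, buffer_end, sample_start, sample_end) for each window."""
    rows = []
    min_start = -pad_before
    max_start = episode_length - sequence_length + pad_after
    for idx in range(min_start, max_start + 1):
        # Clip the window to the episode, then record how much was clipped.
        buffer_start = max(idx, 0) + start_idx
        buffer_end = min(idx + sequence_length, episode_length) + start_idx
        start_offset = buffer_start - (idx + start_idx)
        end_offset = (idx + sequence_length + start_idx) - buffer_end
        rows.append((buffer_start, buffer_end, start_offset, sequence_length - end_offset))
    return rows


def pad_window(values, sequence_length, buffer_start, buffer_end, sample_start, sample_end):
    """Repeat the first/last available element to fill the clipped part of a window."""
    chunk = list(values[buffer_start:buffer_end])
    head = [chunk[0]] * sample_start
    tail = [chunk[-1]] * (sequence_length - sample_end)
    return head + chunk + tail


episode = [10, 11, 12, 13, 14]
rows = window_indices(len(episode), sequence_length=4, pad_before=1, pad_after=2)
windows = [pad_window(episode, 4, *row) for row in rows]
for row, window in zip(rows, windows):
    print(row, window)
```

With `pad_before=1` the first window repeats the first step once, and with `pad_after=2` the last window repeats the final step twice, so every window has exactly `sequence_length` entries.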
--------------------------------------------------------------------------------
/learning/dp/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pprint
4 | import random
5 | import string
6 | import tempfile
7 | import time
8 | from copy import copy
9 | from datetime import datetime
10 |
11 | import absl.flags
12 | import numpy as np
13 | import quaternion
14 | import wandb
15 | from absl import logging
16 | from ml_collections import ConfigDict
17 | from ml_collections.config_dict import config_dict
18 | from ml_collections.config_flags import config_flags
19 |
20 |
21 | def save_args(args, output_dir):
22 | args_dict = vars(args)
23 | args_str = "\n".join(f"{key}: {value}" for key, value in args_dict.items())
24 | with open(os.path.join(output_dir, "args_log.txt"), "w") as file:
25 | file.write(args_str)
26 | with open(os.path.join(output_dir, "args_log.json"), "w") as json_file:
27 | json.dump(args_dict, json_file, indent=4)
28 |
29 |
30 | def generate_random_string(length=4, characters=string.ascii_letters + string.digits):
31 | """
32 | Generate a random string of the specified length using the given characters.
33 |
34 | :param length: The length of the random string (default is 4).
35 | :param characters: The characters to choose from when generating the string
36 | (default is uppercase letters, lowercase letters, and digits).
37 | :return: A random string of the specified length.
38 | """
39 | return "".join(random.choice(characters) for _ in range(length))
40 |
41 |
42 | def get_eef_delta(eef_pose, eef_pose_target):
43 | pos_delta = eef_pose_target[:3] - eef_pose[:3]
44 | # axis angle to quaternion
45 | ee_rot = quaternion.from_rotation_vector(eef_pose[3:])
46 | ee_rot_target = quaternion.from_rotation_vector(eef_pose_target[3:])
47 | # calculate the quaternion difference
48 | rot_delta = quaternion.as_rotation_vector(ee_rot_target * ee_rot.inverse())
49 | return np.concatenate((pos_delta, rot_delta))
50 |
51 |
52 | class Timer(object):
53 | def __init__(self):
54 | self._time = None
55 |
56 | def __enter__(self):
57 | self._start_time = time.time()
58 | return self
59 |
60 | def __exit__(self, exc_type, exc_value, exc_tb):
61 | self._time = time.time() - self._start_time
62 |
63 | def __call__(self):
64 | return self._time
65 |
66 |
67 | class WandBLogger(object):
68 | @staticmethod
69 | def get_default_config(updates=None):
70 | config = ConfigDict()
71 | config.mode = "online"
72 | config.project = "hato"
73 | config.entity = "user"
74 | config.output_dir = "."
75 | config.exp_name = str(datetime.now())[:19].replace(" ", "_")
76 | config.random_delay = 0.5
77 | config.experiment_id = config_dict.placeholder(str)
78 | config.anonymous = config_dict.placeholder(str)
79 | config.notes = config_dict.placeholder(str)
80 | config.time = str(datetime.now())[:19].replace(" ", "_")
81 |
82 | if updates is not None:
83 | config.update(ConfigDict(updates).copy_and_resolve_references())
84 | return config
85 |
86 | def __init__(self, config, variant, prefix=None):
87 | self.config = self.get_default_config(config)
88 |
89 | for key, val in sorted(self.config.items()):
90 | if not isinstance(val, str):
91 | continue
92 | new_val = _parse(val, variant)
93 | if val != new_val:
94 | logging.info(
95 | "processing configs: {}: {} => {}".format(key, val, new_val)
96 | )
97 | setattr(self.config, key, new_val)
98 |
99 | output = flatten_config_dict(self.config, prefix=prefix)
100 | variant.update(output)
101 |
102 | if self.config.output_dir == "":
103 | self.config.output_dir = tempfile.mkdtemp()
104 |
105 | output = flatten_config_dict(self.config, prefix=prefix)
106 | variant.update(output)
107 |
108 | self._variant = copy(variant)
109 |
110 | logging.info(
111 | "wandb logging with hyperparameters: \n{}".format(
112 | pprint.pformat(
113 | ["{}: {}".format(key, val) for key, val in self.variant.items()]
114 | )
115 | )
116 | )
117 |
118 | if self.config.random_delay > 0:
119 | time.sleep(np.random.uniform(0.1, 0.1 + self.config.random_delay))
120 |
121 | self.run = wandb.init(
122 | entity=self.config.entity,
123 | reinit=True,
124 | config=self._variant,
125 | project=self.config.project,
126 | dir=self.config.output_dir,
127 | name=self.config.exp_name,
128 | anonymous=self.config.anonymous,
129 | monitor_gym=False,
130 | notes=self.config.notes,
131 | settings=wandb.Settings(
132 | start_method="thread",
133 | _disable_stats=True,
134 | ),
135 | mode=self.config.mode,
136 | )
137 |
138 | self.logging_step = 0
139 |
140 | def log(self, *args, **kwargs):
141 | self.run.log(*args, **kwargs, step=self.logging_step)
142 |
143 | def step(self):
144 | self.logging_step += 1
145 |
146 | @property
147 | def experiment_id(self):
148 | return self.config.experiment_id
149 |
150 | @property
151 | def variant(self):
152 | return self._variant
153 |
154 | @property
155 | def output_dir(self):
156 | return self.config.output_dir
157 |
158 |
159 | def define_flags_with_default(**kwargs):
160 | for key, val in kwargs.items():
161 | if isinstance(val, ConfigDict):
162 | config_flags.DEFINE_config_dict(key, val)
163 | elif isinstance(val, bool):
164 | # Note that True and False are instances of int.
165 | absl.flags.DEFINE_bool(key, val, "automatically defined flag")
166 | elif isinstance(val, int):
167 | absl.flags.DEFINE_integer(key, val, "automatically defined flag")
168 | elif isinstance(val, float):
169 | absl.flags.DEFINE_float(key, val, "automatically defined flag")
170 | elif isinstance(val, str):
171 | absl.flags.DEFINE_string(key, val, "automatically defined flag")
172 | else:
173 | raise ValueError("Incorrect value type")
174 | return kwargs
175 |
176 |
177 | def _parse(s, variant):
178 | orig_s = copy(s)
179 | final_s = []
180 |
181 | while len(s) > 0:
182 | indx = s.find("{")
183 | if indx == -1:
184 | final_s.append(s)
185 | break
186 | final_s.append(s[:indx])
187 | s = s[indx + 1 :]
188 | indx = s.find("}")
189 | assert indx != -1, "can't find the matching right bracket for {}".format(orig_s)
190 | final_s.append(str(variant[s[:indx]]))
191 | s = s[indx + 1 :]
192 |
193 | return "".join(final_s)
194 |
195 |
196 | def get_user_flags(flags, flags_def):
197 | output = {}
198 | for key in sorted(flags_def):
199 | val = getattr(flags, key)
200 | if isinstance(val, ConfigDict):
201 | flatten_config_dict(val, prefix=key, output=output)
202 | else:
203 | output[key] = val
204 |
205 | return output
206 |
207 |
208 | def flatten_config_dict(config, prefix=None, output=None):
209 | if output is None:
210 | output = {}
211 | for key, val in sorted(config.items()):
212 | if prefix is not None:
213 | next_prefix = "{}.{}".format(prefix, key)
214 | else:
215 | next_prefix = key
216 | if isinstance(val, ConfigDict):
217 | flatten_config_dict(val, prefix=next_prefix, output=output)
218 | else:
219 | output[next_prefix] = val
220 | return output
221 |
222 |
223 | def to_config_dict(flattened):
224 | config = config_dict.ConfigDict()
225 | for key, val in flattened.items():
226 | c_config = config
227 | ks = key.split(".")
228 | for k in ks[:-1]:
229 | if k not in c_config:
230 | c_config[k] = config_dict.ConfigDict()
231 | c_config = c_config[k]
232 | c_config[ks[-1]] = val
233 | return config.to_dict()
234 |
235 |
236 | def prefix_metrics(metrics, prefix):
237 | return {"{}/{}".format(prefix, key): value for key, value in metrics.items()}
238 |
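The config helpers above (`flatten_config_dict`, `to_config_dict`, `_parse`) amount to a flatten/unflatten pair plus `{key}` template expansion. The following is a rough sketch of the same idea using plain nested dicts in place of `ConfigDict`; the names `flatten`, `unflatten`, and `expand` are illustrative and not part of the repository.

```python
def flatten(config, prefix=None, output=None):
    """Flatten nested dicts into {"a.b": value} form, visiting keys in sorted order."""
    if output is None:
        output = {}
    for key, val in sorted(config.items()):
        next_prefix = key if prefix is None else f"{prefix}.{key}"
        if isinstance(val, dict):
            flatten(val, prefix=next_prefix, output=output)
        else:
            output[next_prefix] = val
    return output


def unflatten(flattened):
    """Inverse of flatten: rebuild nested dicts from dotted keys."""
    config = {}
    for key, val in flattened.items():
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = val
    return config


def expand(template, variant):
    """Replace each "{key}" in template with str(variant[key])."""
    parts = []
    rest = template
    while rest:
        left = rest.find("{")
        if left == -1:
            parts.append(rest)
            break
        parts.append(rest[:left])
        rest = rest[left + 1:]
        right = rest.find("}")
        if right == -1:
            raise ValueError(f"no matching right bracket in {template!r}")
        parts.append(str(variant[rest[:right]]))
        rest = rest[right + 1:]
    return "".join(parts)


nested = {"logging": {"project": "hato", "mode": "online"}, "seed": 0}
flat = flatten(nested)
run_name = expand("{logging.project}_seed{seed}", flat)
print(flat, run_name)
```

Because placeholders are looked up in the flattened dictionary, dotted keys such as `{logging.project}` resolve directly, which is the same way `WandBLogger` resolves string config values against `variant`.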
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # keep in alphabetical order
2 | diffusers
3 | ml-collections
4 | natsort
5 | numpy
6 | numpy-quaternion
7 | tensorboard
8 | tyro
--------------------------------------------------------------------------------
/robot_node.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import threading
3 | from typing import Any, Dict
4 |
5 | import numpy as np
6 | import zmq
7 |
8 | from robots.robot import Robot
9 |
10 | DEFAULT_ROBOT_PORT = 6000
11 |
12 |
13 | class ZMQServerRobot:
14 | def __init__(
15 | self,
16 | robot: Robot,
17 | port: int = DEFAULT_ROBOT_PORT,
18 | host: str = "127.0.0.1",
19 | ):
20 | self._robot = robot
21 | self._context = zmq.Context()
22 | self._socket = self._context.socket(zmq.REP)
23 | addr = f"tcp://{host}:{port}"
24 | debug_message = f"Robot Server Binding to {addr}, Robot: {robot}"
25 | print(debug_message)
26 | self._timout_message = f"Timeout in Robot Server, Robot: {robot}"
27 | self._socket.bind(addr)
28 | self._stop_event = threading.Event()
29 |
30 | def serve(self) -> None:
31 | """Serve the leader robot state over ZMQ."""
32 | self._socket.setsockopt(zmq.RCVTIMEO, 1000) # Set timeout to 1000 ms
33 | while not self._stop_event.is_set():
34 | try:
35 | # Wait for next request from client
36 | message = self._socket.recv()
37 | request = pickle.loads(message)
38 |
39 | # Call the appropriate method based on the request
40 | method = request.get("method")
41 | args = request.get("args", {})
42 | result: Any
43 | if method == "num_dofs":
44 | result = self._robot.num_dofs()
45 | elif method == "get_joint_state":
46 | result = self._robot.get_joint_state()
47 | elif method == "command_joint_state":
48 | result = self._robot.command_joint_state(**args)
49 | elif method == "command_eef_pose":
50 | result = self._robot.command_eef_pose(**args)
51 | elif method == "get_observations":
52 | result = self._robot.get_observations()
53 | else:
54 | result = {"error": "Invalid method"}
55 | print(result)
56 | raise NotImplementedError(
57 | f"Invalid method: {method}, {args, result}"
58 | )
59 |
60 | self._socket.send(pickle.dumps(result))
61 | except zmq.Again:
62 | print(self._timout_message)
63 | # Timeout occurred, check if the stop event is set
64 |
65 | def stop(self) -> None:
66 | """Signal the server to stop serving."""
67 | self._stop_event.set()
68 |
69 |
70 | class ZMQClientRobot(Robot):
71 | """A class representing a ZMQ client for a leader robot."""
72 |
73 | def __init__(self, port: int = DEFAULT_ROBOT_PORT, host: str = "127.0.0.1"):
74 | self._context = zmq.Context()
75 | self._socket = self._context.socket(zmq.REQ)
76 | self._socket.connect(f"tcp://{host}:{port}")
77 |
78 | def num_dofs(self) -> int:
79 | """Get the number of joints in the robot.
80 |
81 | Returns:
82 | int: The number of joints in the robot.
83 | """
84 | request = {"method": "num_dofs"}
85 | send_message = pickle.dumps(request)
86 | self._socket.send(send_message)
87 | result = pickle.loads(self._socket.recv())
88 | return result
89 |
90 | def get_joint_state(self) -> np.ndarray:
91 | """Get the current state of the leader robot.
92 |
93 | Returns:
94 | np.ndarray: The current state of the leader robot.
95 | """
96 | request = {"method": "get_joint_state"}
97 | send_message = pickle.dumps(request)
98 | self._socket.send(send_message)
99 | result = pickle.loads(self._socket.recv())
100 | return result
101 |
102 | def command_joint_state(self, joint_state: np.ndarray) -> None:
103 | """Command the leader robot to the given state.
104 |
105 | Args:
106 | joint_state (np.ndarray): The state to command the leader robot to.
107 | """
108 | request = {
109 | "method": "command_joint_state",
110 | "args": {"joint_state": joint_state},
111 | }
112 | send_message = pickle.dumps(request)
113 | self._socket.send(send_message)
114 | result = pickle.loads(self._socket.recv())
115 | return result
116 |
117 | def command_eef_pose(self, eef_pose: np.ndarray) -> None:
118 | """Command the leader robot to the given end-effector pose.
119 |
120 | Args:
121 | eef_pose: end effector pose command to step the environment with.
122 | """
123 | request = {
124 | "method": "command_eef_pose",
125 | "args": {"eef_pose": eef_pose},
126 | }
127 | send_message = pickle.dumps(request)
128 | self._socket.send(send_message)
129 | result = pickle.loads(self._socket.recv())
130 | return result
131 |
132 | def get_observations(self) -> Dict[str, np.ndarray]:
133 | """Get the current observations of the leader robot.
134 |
135 | Returns:
136 | Dict[str, np.ndarray]: The current observations of the leader robot.
137 | """
138 | request = {"method": "get_observations"}
139 | send_message = pickle.dumps(request)
140 | self._socket.send(send_message)
141 | result = pickle.loads(self._socket.recv())
142 | return result
143 |
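The request/reply protocol used by `ZMQServerRobot` and `ZMQClientRobot` is a pickled dict with a `method` name and optional `args`. The sketch below exercises that message format without any sockets, so it runs without pyzmq. `FakeRobot`, `handle_request`, and `call` are hypothetical helpers, and two choices differ from the server above and are assumptions of this sketch: dispatch goes through an allow-list with `getattr` rather than an if/elif chain, and an unknown method yields an error reply instead of raising.

```python
import pickle

ALLOWED_METHODS = {
    "num_dofs",
    "get_joint_state",
    "command_joint_state",
    "command_eef_pose",
    "get_observations",
}


class FakeRobot:
    """Hypothetical in-memory robot used only for this illustration."""

    def __init__(self, dofs=3):
        self._joints = [0.0] * dofs
        self._last_eef_pose = None

    def num_dofs(self):
        return len(self._joints)

    def get_joint_state(self):
        return list(self._joints)

    def command_joint_state(self, joint_state):
        self._joints = list(joint_state)

    def command_eef_pose(self, eef_pose):
        self._last_eef_pose = list(eef_pose)

    def get_observations(self):
        return {"joint_positions": list(self._joints)}


def handle_request(robot, message):
    """Decode one pickled request, call the robot, and return the pickled reply."""
    request = pickle.loads(message)
    method = request.get("method")
    args = request.get("args", {})
    if method not in ALLOWED_METHODS:
        return pickle.dumps({"error": f"Invalid method: {method}"})
    return pickle.dumps(getattr(robot, method)(**args))


def call(robot, method, **args):
    """Client side: build the request dict the same way ZMQClientRobot does."""
    request = {"method": method}
    if args:
        request["args"] = args
    return pickle.loads(handle_request(robot, pickle.dumps(request)))


robot = FakeRobot()
dofs = call(robot, "num_dofs")
call(robot, "command_joint_state", joint_state=[0.1, 0.2, 0.3])
state = call(robot, "get_joint_state")
bad = call(robot, "shutdown")
print(dofs, state, bad)
```

Note that `pickle.loads` can execute arbitrary code from the payload, so this message format is only appropriate between trusted processes, such as the localhost default used here.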
--------------------------------------------------------------------------------
/robots/ability_gripper.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import struct
3 | import threading
4 | import time
5 |
6 | import numpy as np
7 | import serial
8 | from serial.tools import list_ports
9 |
10 |
11 | ## Send Miscellaneous Command to Ability Hand
12 | def create_misc_msg(cmd):
13 | barr = []
14 | barr.append((struct.pack(" 150:
167 | need_reset = True
168 |
169 | ## Extract Touch Data if Available
170 | if reply_len == 71:
171 | ## Extract Data two at a time
172 | for i in range(0, 15):
173 | dual_data = data[(i * 3) + 24 : ((i + 1) * 3) + 24]
174 | data1 = struct.unpack("> 4
178 | self.touch[i * 2] = int(data1)
179 | self.touch[(i * 2) + 1] = int(data2)
180 | if data1 > 4096 or data2 > 4096:
181 | need_reset = True
182 | else:
183 | need_reset = True
184 | else:
185 | need_reset = True
186 |
187 | if need_reset:
188 | self.ser.reset_input_buffer()
189 | self.reset_count += 1
190 | need_reset = False
191 | self.pos = self.prev_pos.copy()
192 | self.touch = self.prev_touch.copy()
193 |
194 | self.prev_pos = self.pos.copy()
195 | self.prev_touch = self.touch.copy()
196 |
197 | self.total_count += 1
198 |
199 | def _process_pos_cmd(self, pos_cmd):
200 | """
201 | pos_cmd: a list of floats [0, 1] normalized command from external input devices
202 | """
203 | assert len(pos_cmd) == 6
204 | positions = [1] * 6
205 | for i in range(6):
206 | positions[i] = self.lower_ranges[i] + pos_cmd[i] * (
207 | self.upper_ranges[i] - self.lower_ranges[i]
208 | )
209 | if i == 5:
210 | # invert thumb rotator
211 | positions[i] = -positions[i]
212 | return positions
213 |
214 | def _pos_to_cmd(self, pos):
215 | """
216 | pos: desired hand degrees
217 | """
218 | assert len(pos) == 6
219 | cmd = [0] * 6
220 | for i in range(6):
221 | if i == 5:
222 | pos[i] = -pos[i]
223 | cmd[i] = (pos[i] - self.lower_ranges[i]) / (
224 | self.upper_ranges[i] - self.lower_ranges[i]
225 | )
226 | return cmd
227 |
228 | def _generate_tx(self, positions):
229 | """
230 | Position control mode - reply format must be one of the 3 variants (0x10, 0x11, 0x12)
231 | """
232 | txBuf = []
233 |
234 | ## Address in byte 0
235 | txBuf.append((struct.pack("> 8) & 0xFF))[0])
245 |
246 | ## calculate checksum
247 | cksum = 0
248 | for b in txBuf:
249 | cksum = cksum + b
250 | cksum = (-cksum) & 0xFF
251 | txBuf.append((struct.pack(" 1000 means contact
286 | return self.last_pos_msg + self.pos + self.touch
287 |
288 | def get_current_touch(self):
289 | return self.touch
290 |
291 | def get_current_position(self):
292 | # pos readings
293 | return self.pos
294 |
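The mapping between normalized commands and joint angles in `_process_pos_cmd` and `_pos_to_cmd` can be checked with a round trip. This is a small sketch under stated assumptions: `LOWER` and `UPPER` are assumed values that mirror `hand_lower_ranges` and `hand_upper_ranges` from `learning/dp/dataset.py` (with `hand_grip_range=110`), not values read from this file, and `cmd_to_degrees` and `degrees_to_cmd` are hypothetical names.

```python
# Assumed joint limits in degrees (taken from the dataset defaults, not from this file).
LOWER = [5.0] * 6
UPPER = [110.0] * 4 + [90.0, 120.0]
THUMB_ROTATOR = 5  # index whose sign is inverted


def cmd_to_degrees(cmd):
    """Map normalized commands in [0, 1] to joint angles, inverting the thumb rotator."""
    assert len(cmd) == 6
    positions = []
    for i, c in enumerate(cmd):
        deg = LOWER[i] + c * (UPPER[i] - LOWER[i])
        positions.append(-deg if i == THUMB_ROTATOR else deg)
    return positions


def degrees_to_cmd(positions):
    """Inverse mapping; works on a local copy of each value, so the input is not modified."""
    assert len(positions) == 6
    cmd = []
    for i, deg in enumerate(positions):
        if i == THUMB_ROTATOR:
            deg = -deg
        cmd.append((deg - LOWER[i]) / (UPPER[i] - LOWER[i]))
    return cmd


open_hand = cmd_to_degrees([0.0] * 6)
closed_hand = cmd_to_degrees([1.0] * 6)
half = [0.5] * 6
round_trip = degrees_to_cmd(cmd_to_degrees(half))
print(open_hand, closed_hand, round_trip)
```

Unlike `_pos_to_cmd` above, which negates `pos[5]` in place, this version leaves the caller's list untouched, which avoids a sign flip if the same list is converted twice.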
--------------------------------------------------------------------------------
/robots/robot.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from typing import Dict, Protocol
3 |
4 | import numpy as np
5 |
6 |
7 | class Robot(Protocol):
8 | """Robot protocol.
9 |
10 | A protocol for a robot that can be controlled.
11 | """
12 |
13 | @abstractmethod
14 | def num_dofs(self) -> int:
15 | """Get the number of joints of the robot.
16 |
17 | Returns:
18 | int: The number of joints of the robot.
19 | """
20 | raise NotImplementedError
21 |
22 | @abstractmethod
23 | def get_joint_state(self) -> np.ndarray:
24 | """Get the current state of the leader robot.
25 |
26 | Returns:
27 | np.ndarray: The current state of the leader robot.
28 | """
29 | raise NotImplementedError
30 |
31 | @abstractmethod
32 | def command_joint_state(self, joint_state: np.ndarray) -> None:
33 | """Command the leader robot to a given state.
34 |
35 | Args:
36 | joint_state (np.ndarray): The state to command the leader robot to.
37 | """
38 | raise NotImplementedError
39 |
40 | @abstractmethod
41 | def command_eef_pose(self, eef_pose: np.ndarray) -> None:
42 | """Command the leader robot to a given end-effector pose.
43 |
44 | Args:
45 | eef_pose (np.ndarray): The EEF pose to command the leader robot to.
46 | """
47 | raise NotImplementedError
48 |
49 | @abstractmethod
50 | def get_observations(self) -> Dict[str, np.ndarray]:
51 | """Get the current observations of the robot.
52 |
53 | This is to extract all the information that is available from the robot,
54 | such as joint positions, joint velocities, etc. This may also include
55 | information from additional sensors, such as cameras, force sensors, etc.
56 |
57 | Returns:
58 | Dict[str, np.ndarray]: A dictionary of observations.
59 | """
60 | raise NotImplementedError
61 |
62 |
63 | class BimanualRobot(Robot):
64 | def __init__(self, robot_l: Robot, robot_r: Robot):
65 | self._robot_l = robot_l
66 | self._robot_r = robot_r
67 |
68 | def num_dofs(self) -> int:
69 | return self._robot_l.num_dofs() + self._robot_r.num_dofs()
70 |
71 | def get_joint_state(self) -> np.ndarray:
72 | return np.concatenate(
73 | (self._robot_l.get_joint_state(), self._robot_r.get_joint_state())
74 | )
75 |
76 | def command_joint_state(self, joint_state: np.ndarray) -> None:
77 | self._robot_l.command_joint_state(joint_state[: self._robot_l.num_dofs()])
78 | self._robot_r.command_joint_state(joint_state[self._robot_l.num_dofs() :])
79 |
80 | def command_eef_pose(self, eef_pose: np.ndarray) -> None:
81 | self._robot_l.command_eef_pose(eef_pose[: self._robot_l.num_dofs()])
82 | self._robot_r.command_eef_pose(eef_pose[self._robot_l.num_dofs() :])
83 |
84 | def get_observations(self) -> Dict[str, np.ndarray]:
85 | l_obs = self._robot_l.get_observations()
86 | r_obs = self._robot_r.get_observations()
87 | assert l_obs.keys() == r_obs.keys()
88 | return_obs = {}
89 | for k in l_obs.keys():
90 | try:
91 | return_obs[k] = np.concatenate((l_obs[k], r_obs[k]))
92 | except Exception as e:
93 | print(e)
94 | print(k)
95 | print(l_obs[k])
96 | print(r_obs[k])
97 | raise RuntimeError()
98 |
99 | return return_obs
100 |
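`BimanualRobot` treats two arms as one robot by concatenating their states and splitting commands at the left arm's `num_dofs()`. The sketch below shows that composition with plain lists standing in for NumPy arrays; `ListArm` and `ListBimanual` are hypothetical stand-ins written for this illustration and only cover the joint-state methods.

```python
class ListArm:
    """Hypothetical single arm that stores its joint state in a list."""

    def __init__(self, dofs):
        self._state = [0.0] * dofs

    def num_dofs(self):
        return len(self._state)

    def get_joint_state(self):
        return list(self._state)

    def command_joint_state(self, joint_state):
        assert len(joint_state) == len(self._state)
        self._state = list(joint_state)


class ListBimanual:
    """Composite robot: left joints come first, right joints second."""

    def __init__(self, left, right):
        self._left = left
        self._right = right

    def num_dofs(self):
        return self._left.num_dofs() + self._right.num_dofs()

    def get_joint_state(self):
        return self._left.get_joint_state() + self._right.get_joint_state()

    def command_joint_state(self, joint_state):
        split = self._left.num_dofs()  # boundary between left and right commands
        self._left.command_joint_state(joint_state[:split])
        self._right.command_joint_state(joint_state[split:])


left, right = ListArm(2), ListArm(3)
pair = ListBimanual(left, right)
pair.command_joint_state([1.0, 2.0, 3.0, 4.0, 5.0])
print(pair.num_dofs(), left.get_joint_state(), right.get_joint_state())
```

Because the split point is derived from the left arm, the two arms may have different numbers of joints, as long as the combined command has `num_dofs()` entries.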
--------------------------------------------------------------------------------
/robots/robotiq_gripper.py:
--------------------------------------------------------------------------------
1 | """Module to control Robotiq's grippers - tested with HAND-E.
2 |
3 | Taken from https://github.com/githubuser0xFFFF/py_robotiq_gripper/blob/master/src/robotiq_gripper.py
4 | """
5 |
6 | import socket
7 | import threading
8 | import time
9 | from enum import Enum
10 | from typing import OrderedDict, Tuple, Union
11 |
12 |
13 | class RobotiqGripper:
14 | """Communicates with the gripper directly, via socket with string commands, leveraging string names for variables."""
15 |
16 | # WRITE VARIABLES (CAN ALSO READ)
17 | ACT = (
18 | "ACT" # act : activate (1 while activated, can be reset to clear fault status)
19 | )
20 | GTO = (
21 | "GTO" # gto : go to (will perform go to with the actions set in pos, for, spe)
22 | )
23 | ATR = "ATR" # atr : auto-release (emergency slow move)
24 | ADR = (
25 | "ADR" # adr : auto-release direction (open(1) or close(0) during auto-release)
26 | )
27 | FOR = "FOR" # for : force (0-255)
28 | SPE = "SPE" # spe : speed (0-255)
29 | POS = "POS" # pos : position (0-255), 0 = open
30 | # READ VARIABLES
31 | STA = "STA" # status (0 = is reset, 1 = activating, 3 = active)
32 | PRE = "PRE" # position request (echo of last commanded position)
33 | OBJ = "OBJ" # object detection (0 = moving, 1 = outer grip, 2 = inner grip, 3 = no object at rest)
34 | FLT = "FLT" # fault (0=ok, see manual for errors if not zero)
35 |
36 | ENCODING = "UTF-8" # ASCII and UTF-8 both seem to work
37 |
38 | class GripperStatus(Enum):
39 | """Gripper status reported by the gripper. The integer values have to match what the gripper sends."""
40 |
41 | RESET = 0
42 | ACTIVATING = 1
43 | # UNUSED = 2 # This value is currently not used by the gripper firmware
44 | ACTIVE = 3
45 |
46 | class ObjectStatus(Enum):
47 | """Object status reported by the gripper. The integer values have to match what the gripper sends."""
48 |
49 | MOVING = 0
50 | STOPPED_OUTER_OBJECT = 1
51 | STOPPED_INNER_OBJECT = 2
52 | AT_DEST = 3
53 |
54 | def __init__(self):
55 | """Constructor."""
56 | self.socket = None
57 | self.command_lock = threading.Lock()
58 | self._min_position = 0
59 | self._max_position = 255
60 | self._min_speed = 0
61 | self._max_speed = 255
62 | self._min_force = 0
63 | self._max_force = 255
64 |
65 | def connect(self, hostname: str, port: int, socket_timeout: float = 10.0) -> None:
66 | """Connects to a gripper at the given address.
67 |
68 | :param hostname: Hostname or ip.
69 | :param port: Port.
70 | :param socket_timeout: Timeout for blocking socket operations.
71 | """
72 | self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
73 | assert self.socket is not None
74 | self.socket.connect((hostname, port))
75 | self.socket.settimeout(socket_timeout)
76 |
77 | def disconnect(self) -> None:
78 | """Closes the connection with the gripper."""
79 | assert self.socket is not None
80 | self.socket.close()
81 |
82 | def _set_vars(self, var_dict: OrderedDict[str, Union[int, float]]):
83 | """Sends the appropriate command via socket to set the value of n variables, and waits for its 'ack' response.
84 |
85 | :param var_dict: Dictionary of variables to set (variable_name, value).
86 | :return: True on successful reception of ack, false if no ack was received, indicating the set may not
87 | have been effective.
88 | """
89 | assert self.socket is not None
90 | # construct unique command
91 | cmd = "SET"
92 | for variable, value in var_dict.items():
93 | cmd += f" {variable} {str(value)}"
94 | cmd += "\n" # new line is required for the command to finish
95 | # atomic commands send/rcv
96 | with self.command_lock:
97 | self.socket.sendall(cmd.encode(self.ENCODING))
98 | data = self.socket.recv(1024)
99 | return self._is_ack(data)
100 |
101 | def _set_var(self, variable: str, value: Union[int, float]):
102 | """Sends the appropriate command via socket to set the value of a variable, and waits for its 'ack' response.
103 |
104 | :param variable: Variable to set.
105 | :param value: Value to set for the variable.
106 | :return: True on successful reception of ack, false if no ack was received, indicating the set may not
107 | have been effective.
108 | """
109 | return self._set_vars(OrderedDict([(variable, value)]))
110 |
111 | def _get_var(self, variable: str):
112 | """Sends the appropriate command to retrieve the value of a variable from the gripper, blocking until the response is received or the socket times out.
113 |
114 | :param variable: Name of the variable to retrieve.
115 | :return: Value of the variable as integer.
116 | """
117 | assert self.socket is not None
118 | # atomic commands send/rcv
119 | with self.command_lock:
120 | cmd = f"GET {variable}\n"
121 | self.socket.sendall(cmd.encode(self.ENCODING))
122 | data = self.socket.recv(1024)
123 |
124 | # expect data of the form 'VAR x', where VAR is an echo of the variable name, and X the value
125 | # note some special variables (like FLT) may send 2 bytes, instead of an integer. We assume integer here
126 | var_name, value_str = data.decode(self.ENCODING).split()
127 | if var_name != variable:
128 | raise ValueError(
129 | f"Unexpected response {data} ({data.decode(self.ENCODING)}): does not match '{variable}'"
130 | )
131 | value = int(value_str)
132 | return value
133 |
134 | @staticmethod
135 | def _is_ack(data: bytes):
136 | return data == b"ack"
137 |
138 | def _reset(self):
139 | """Reset the gripper.
140 |
141 | The following code is executed in the corresponding script function
142 | def rq_reset(gripper_socket="1"):
143 | rq_set_var("ACT", 0, gripper_socket)
144 | rq_set_var("ATR", 0, gripper_socket)
145 |
146 | while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
147 | rq_set_var("ACT", 0, gripper_socket)
148 | rq_set_var("ATR", 0, gripper_socket)
149 | sync()
150 | end
151 |
152 | sleep(0.5)
153 | end
154 | """
155 | self._set_var(self.ACT, 0)
156 | self._set_var(self.ATR, 0)
157 | while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0:
158 | self._set_var(self.ACT, 0)
159 | self._set_var(self.ATR, 0)
160 | time.sleep(0.5)
161 |
162 | def activate(self, auto_calibrate: bool = True):
163 | """Resets the activation flag in the gripper, and sets it back to one, clearing previous fault flags.
164 |
165 | :param auto_calibrate: Whether to calibrate the minimum and maximum positions based on actual motion.
166 |
167 | The following code is executed in the corresponding script function
168 | def rq_activate(gripper_socket="1"):
169 | if (not rq_is_gripper_activated(gripper_socket)):
170 | rq_reset(gripper_socket)
171 |
172 | while(not rq_get_var("ACT", 1, gripper_socket) == 0 or not rq_get_var("STA", 1, gripper_socket) == 0):
173 | rq_reset(gripper_socket)
174 | sync()
175 | end
176 |
177 | rq_set_var("ACT",1, gripper_socket)
178 | end
179 | end
180 |
181 | def rq_activate_and_wait(gripper_socket="1"):
182 | if (not rq_is_gripper_activated(gripper_socket)):
183 | rq_activate(gripper_socket)
184 | sleep(1.0)
185 |
186 | while(not rq_get_var("ACT", 1, gripper_socket) == 1 or not rq_get_var("STA", 1, gripper_socket) == 3):
187 | sleep(0.1)
188 | end
189 |
190 | sleep(0.5)
191 | end
192 | end
193 | """
194 | if not self.is_active():
195 | self._reset()
196 | while not self._get_var(self.ACT) == 0 or not self._get_var(self.STA) == 0:
197 | time.sleep(0.01)
198 |
199 | self._set_var(self.ACT, 1)
200 | time.sleep(1.0)
201 | while not self._get_var(self.ACT) == 1 or not self._get_var(self.STA) == 3:
202 | time.sleep(0.01)
203 |
204 | # auto-calibrate position range if desired
205 | if auto_calibrate:
206 | self.auto_calibrate()
207 |
208 | def is_active(self):
209 | """Returns whether the gripper is active."""
210 | status = self._get_var(self.STA)
211 | return (
212 | RobotiqGripper.GripperStatus(status) == RobotiqGripper.GripperStatus.ACTIVE
213 | )
214 |
215 | def get_min_position(self) -> int:
216 | """Returns the minimum position the gripper can reach (open position)."""
217 | return self._min_position
218 |
219 | def get_max_position(self) -> int:
220 | """Returns the maximum position the gripper can reach (closed position)."""
221 | return self._max_position
222 |
223 | def get_open_position(self) -> int:
224 | """Returns what is considered the open position for gripper (minimum position value)."""
225 | return self.get_min_position()
226 |
227 | def get_closed_position(self) -> int:
228 | """Returns what is considered the closed position for gripper (maximum position value)."""
229 | return self.get_max_position()
230 |
231 | def is_open(self):
232 | """Returns whether the current position is considered as being fully open."""
233 | return self.get_current_position() <= self.get_open_position()
234 |
235 | def is_closed(self):
236 | """Returns whether the current position is considered as being fully closed."""
237 | return self.get_current_position() >= self.get_closed_position()
238 |
239 | def get_current_position(self) -> int:
240 | """Returns the current position as returned by the physical hardware."""
241 | return self._get_var(self.POS)
242 |
243 | def auto_calibrate(self, log: bool = True) -> None:
244 | """Attempts to calibrate the open and closed positions, by slowly closing and opening the gripper.
245 |
246 | :param log: Whether to print the results to log.
247 | """
248 | # first try to open in case we are holding an object
249 | (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
250 | if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
251 | raise RuntimeError(f"Calibration failed opening to start: {str(status)}")
252 |
253 | # try to close as far as possible, and record the number
254 | (position, status) = self.move_and_wait_for_pos(
255 | self.get_closed_position(), 64, 1
256 | )
257 | if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
258 | raise RuntimeError(
259 | f"Calibration failed because of an object: {str(status)}"
260 | )
261 | assert position <= self._max_position
262 | self._max_position = position
263 |
264 | # try to open as far as possible, and record the number
265 | (position, status) = self.move_and_wait_for_pos(self.get_open_position(), 64, 1)
266 | if RobotiqGripper.ObjectStatus(status) != RobotiqGripper.ObjectStatus.AT_DEST:
267 | raise RuntimeError(
268 | f"Calibration failed because of an object: {str(status)}"
269 | )
270 | assert position >= self._min_position
271 | self._min_position = position
272 |
273 | if log:
274 | print(
275 | f"Gripper auto-calibrated to [{self.get_min_position()}, {self.get_max_position()}]"
276 | )
277 |
278 | def move(self, position: int, speed: int, force: int) -> Tuple[bool, int]:
279 | """Sends commands to start moving towards the given position, with the specified speed and force.
280 |
281 | :param position: Position to move to [min_position, max_position]
282 | :param speed: Speed to move at [min_speed, max_speed]
283 | :param force: Force to use [min_force, max_force]
284 | :return: A tuple with a bool indicating whether the action was successfully sent, and an integer with
285 | the actual position that was requested, after being adjusted to the min/max calibrated range.
286 | """
287 |
288 | def clip_val(min_val, val, max_val):
289 | return max(min_val, min(val, max_val))
290 |
291 | clip_pos = clip_val(self._min_position, position, self._max_position)
292 | clip_spe = clip_val(self._min_speed, speed, self._max_speed)
293 | clip_for = clip_val(self._min_force, force, self._max_force)
294 |
295 | # moves to the given position with the given speed and force
296 | var_dict = OrderedDict(
297 | [
298 | (self.POS, clip_pos),
299 | (self.SPE, clip_spe),
300 | (self.FOR, clip_for),
301 | (self.GTO, 1),
302 | ]
303 | )
304 | succ = self._set_vars(var_dict)
305 | time.sleep(0.008) # short wait needed after sending the command (reason unknown)
306 | return succ, clip_pos
307 |
308 | def move_and_wait_for_pos(
309 | self, position: int, speed: int, force: int
310 | ) -> Tuple[int, ObjectStatus]: # noqa
311 | """Sends commands to start moving towards the given position, with the specified speed and force, and then waits for the move to complete.
312 |
313 | :param position: Position to move to [min_position, max_position]
314 | :param speed: Speed to move at [min_speed, max_speed]
315 | :param force: Force to use [min_force, max_force]
316 | :return: A tuple with an integer representing the last position returned by the gripper after it notified
317 | that the move had completed, and a status indicating how the move ended (see ObjectStatus enum for details). Note
318 | that it is possible that the position was not reached, if an object was detected during motion.
319 | """
320 | set_ok, cmd_pos = self.move(position, speed, force)
321 | if not set_ok:
322 | raise RuntimeError("Failed to set variables for move.")
323 |
324 | # wait until the gripper acknowledges that it will try to go to the requested position
325 | while self._get_var(self.PRE) != cmd_pos:
326 | time.sleep(0.001)
327 |
328 | # wait until not moving
329 | cur_obj = self._get_var(self.OBJ)
330 | while (
331 | RobotiqGripper.ObjectStatus(cur_obj) == RobotiqGripper.ObjectStatus.MOVING
332 | ):
333 | cur_obj = self._get_var(self.OBJ)
334 |
335 | # report the actual position and the object status
336 | final_pos = self._get_var(self.POS)
337 | final_obj = cur_obj
338 | return final_pos, RobotiqGripper.ObjectStatus(final_obj)
339 |
340 |
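The reply parsing in `_get_var` and the clamping in `move` can be tried without hardware. A minimal sketch, assuming a reply of the form `VAR value` and a UTF-8 encoding (the class's actual `ENCODING` constant is defined outside this excerpt). `parse_get_reply` and the module-level `clip_val` are hypothetical helper names, not part of `RobotiqGripper`:

```python
def parse_get_reply(data: bytes, variable: str, encoding: str = "utf-8") -> int:
    """Parse a reply such as b'POS 128\n' and return the integer value.

    The encoding default is an assumption; the class uses its own ENCODING constant.
    """
    var_name, value_str = data.decode(encoding).split()
    if var_name != variable:
        raise ValueError(f"Unexpected response {data!r}: does not match '{variable}'")
    return int(value_str)


def clip_val(min_val: int, val: int, max_val: int) -> int:
    """Clamp val into [min_val, max_val], mirroring the nested helper in move()."""
    return max(min_val, min(val, max_val))
```

For example, `parse_get_reply(b"POS 128\n", "POS")` yields the integer position, and `clip_val(0, 300, 255)` clamps an out-of-range request to the upper bound.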
--------------------------------------------------------------------------------
/robots/ur.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import numpy as np
4 |
5 | from robots.robot import Robot
6 |
7 |
8 | class URRobot(Robot):
9 | """A class representing a UR robot."""
10 |
11 | def __init__(
12 | self,
13 | robot_ip: str = "111.111.1.11",
14 | no_gripper: bool = False,
15 | gripper_type="",
16 | grip_range=110,
17 | port_idx=-1,
18 | ):
19 | import rtde_control
20 | import rtde_receive
21 |
22 | [print("in ur robot:", robot_ip) for _ in range(3)]
23 | self.robot = rtde_control.RTDEControlInterface(robot_ip)
24 | self.r_inter = rtde_receive.RTDEReceiveInterface(robot_ip)
25 | if not no_gripper:
26 | if gripper_type == "ability":
27 | from robots.ability_gripper import AbilityGripper
28 |
29 | self.gripper = AbilityGripper(port_idx=port_idx, grip_range=grip_range)
30 | self.gripper.connect()
31 | else:
32 | from robots.robotiq_gripper import RobotiqGripper
33 |
34 | self.gripper = RobotiqGripper()
35 | self.gripper.connect(hostname=robot_ip, port=63352)
36 |
37 | [print("connect") for _ in range(3)]
38 |
39 | self._free_drive = False
40 | self.robot.endFreedriveMode()
41 | self._use_gripper = not no_gripper
42 | self.gripper_type = gripper_type
43 |
44 | self.velocity = 0.5
45 | self.acceleration = 0.5
46 |
47 | # EEF
48 | self.velocity_l = 0.3
49 | self.acceleration_l = 0.3
50 | self.dt = 1.0 / 500 # 2ms
51 | self.lookahead_time = 0.2 # [0.03, 0.2]s smoothens the trajectory
52 | self.gain = 100 # [100, 2000] proportional gain for following target position
53 |
54 | def num_dofs(self) -> int:
55 | """Get the number of joints of the robot.
56 |
57 | Returns:
58 | int: The number of joints of the robot.
59 | """
60 | if self._use_gripper:
61 | if self.gripper_type == "ability":
62 | return 12
63 | else:
64 | return 7
65 | return 6
66 |
67 | def _get_gripper_pos(self) -> float:
68 | if self.gripper_type in ["ability"]:
69 | gripper_pos = self.gripper.get_current_position()
70 | return gripper_pos
71 | else:
72 | gripper_pos = self.gripper.get_current_position()
73 | assert 0 <= gripper_pos <= 255, "Gripper position must be between 0 and 255"
74 | return gripper_pos / 255
75 |
76 | def get_joint_state(self) -> np.ndarray:
77 | """Get the current joint state of the robot (arm joints, followed by the gripper position(s) if a gripper is used).
78 |
79 | Returns:
80 | np.ndarray: The current joint state of the robot.
81 | """
82 | robot_joints = self.r_inter.getActualQ()
83 | if self._use_gripper:
84 | gripper_pos = self._get_gripper_pos()
85 | pos = np.append(robot_joints, gripper_pos)
86 | else:
87 | pos = robot_joints
88 | return pos
89 |
90 | def get_joint_velocities(self) -> np.ndarray:
91 | return self.r_inter.getActualQd()
92 |
93 | def get_eef_speed(self) -> np.ndarray:
94 | return self.r_inter.getActualTCPSpeed()
95 |
96 | def get_eef_pose(self) -> np.ndarray:
97 | """Get the current pose of the robot's end effector.
98 |
99 | Returns:
100 | np.ndarray: The current pose of the robot's end effector.
101 | """
102 | return self.r_inter.getActualTCPPose()
103 |
104 | def command_joint_state(self, joint_state: np.ndarray) -> None:
105 | """Command the robot to a given joint state.
106 |
107 | Args:
108 | joint_state (np.ndarray): The joint state to command the robot to.
109 | """
110 | robot_joints = joint_state[:6]
111 | t_start = self.robot.initPeriod()
112 | self.robot.servoJ(
113 | robot_joints,
114 | self.velocity,
115 | self.acceleration,
116 | self.dt,
117 | self.lookahead_time,
118 | self.gain,
119 | )
120 | if self._use_gripper:
121 | if self.gripper_type == "ability":
122 | assert (
123 | max(joint_state[6:]) <= 1 and min(joint_state[6:]) >= 0
124 | ), f"Gripper position must be between 0 and 1:{joint_state[6:]}"
125 | self.gripper.move(joint_state[6:], debug=False)
126 | elif self.gripper_type == "allegro":
127 | self.gripper.move(joint_state[6:])
128 | elif self.gripper_type == "dummy":
129 | pass
130 | else:
131 | gripper_pos = joint_state[-1] * 255
132 | # print(f"gripper move command: {gripper_pos}")
133 | self.gripper.move(int(gripper_pos), 255, 10)
134 | self.robot.waitPeriod(t_start)
135 |
136 | def command_eef_pose(self, eef_pose: np.ndarray) -> None:
137 | """Command the robot's end effector to a given pose.
138 |
139 | Args:
140 | eef_pose (np.ndarray): The EEF pose to command the robot to.
141 | """
142 | pose_command = eef_pose[:6]
143 | # print("current TCP:", self.r_inter.getActualTCPPose())
144 | # print("pose_command:", pose_command)
145 | # input("press enter to continue")
146 | t_start = self.robot.initPeriod()
147 | self.robot.servoL(
148 | pose_command,
149 | self.velocity_l,
150 | self.acceleration_l,
151 | self.dt,
152 | self.lookahead_time,
153 | self.gain,
154 | )
155 | if self._use_gripper:
156 | if self.gripper_type == "ability":
157 | assert (
158 | max(eef_pose[6:]) <= 1 and min(eef_pose[6:]) >= 0
159 | ), f"Gripper position must be between 0 and 1:{eef_pose[6:]}"
160 | self.gripper.move(eef_pose[6:])
161 | else:
162 | gripper_pos = eef_pose[-1] * 255
163 | # print(f"gripper move command: {gripper_pos}")
164 | self.gripper.move(int(gripper_pos), 255, 10)
165 | self.robot.waitPeriod(t_start)
166 |
167 | def freedrive_enabled(self) -> bool:
168 | """Check if the robot is in freedrive mode.
169 |
170 | Returns:
171 | bool: True if the robot is in freedrive mode, False otherwise.
172 | """
173 | return self._free_drive
174 |
175 | def set_freedrive_mode(self, enable: bool) -> None:
176 | """Set the freedrive mode of the robot.
177 |
178 | Args:
179 | enable (bool): True to enable freedrive mode, False to disable it.
180 | """
181 | if enable and not self._free_drive:
182 | self._free_drive = True
183 | self.robot.freedriveMode()
184 | elif not enable and self._free_drive:
185 | self._free_drive = False
186 | self.robot.endFreedriveMode()
187 |
188 | def get_observations(self) -> Dict[str, np.ndarray]:
189 | joints = self.get_joint_state()
190 | joint_vels = self.get_joint_velocities()
191 | eef_speed = self.get_eef_speed()
192 | pos_quat = self.get_eef_pose()
193 | gripper_pos = np.array([joints[-1]])
194 | if self._use_gripper and self.gripper_type == "ability":
195 | # include Ability hand touch data
196 | touch = self.gripper.get_current_touch()
197 | else:
198 | touch = np.zeros(30)
199 | return {
200 | "joint_positions": joints,
201 | "joint_velocities": joint_vels,
202 | "eef_speed": eef_speed,
203 | "ee_pos_quat": pos_quat, # TODO: this is pos_rot actually
204 | "gripper_position": gripper_pos,
205 | "touch": touch,
206 | }
207 |
208 |
209 |
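For the non-Ability gripper path, `URRobot` converts between the raw 0..255 position and a normalized [0, 1] value. A minimal sketch of that conversion, assuming the same truncating `int(...)` cast used in `command_joint_state`. `raw_to_normalized` and `normalized_to_raw` are hypothetical names introduced only for illustration:

```python
def raw_to_normalized(raw: int) -> float:
    """Map a raw gripper reading in [0, 255] to [0, 1], as _get_gripper_pos does."""
    if not 0 <= raw <= 255:
        raise ValueError("Gripper position must be between 0 and 255")
    return raw / 255


def normalized_to_raw(normalized: float) -> int:
    """Map a normalized command in [0, 1] to a raw position.

    int() truncates toward zero, so 0.5 maps to 127 rather than 128.
    """
    return int(normalized * 255)
```

Because of the truncation, a round trip through both functions is not guaranteed to return the exact starting value for every input.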
--------------------------------------------------------------------------------
/run_env.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import pickle
4 | import time
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 | from typing import Dict
8 |
9 | import cv2
10 | import numpy as np
11 | import termcolor
12 | import tyro
13 |
14 | # foot pedal
15 | from pynput import keyboard
16 |
17 | from agents.agent import BimanualAgent, SafetyWrapper
18 | from camera_node import ZMQClientCamera
19 | from env import RobotEnv
20 | from robot_node import ZMQClientRobot
21 |
22 | trigger_state = {"l": False, "r": False}
23 |
24 |
25 | def listen_key(key):
26 | global trigger_state
27 | try:
28 | trigger_state[key.char] = True
29 | except AttributeError: # special keys (e.g. shift) have no .char
30 | pass
31 |
32 |
33 | def reset_key(key):
34 | global trigger_state
35 | try:
36 | trigger_state[key.char] = False
37 | except AttributeError: # special keys (e.g. shift) have no .char
38 | pass
39 |
40 |
41 | listener = keyboard.Listener(on_press=listen_key)
42 | listener2 = keyboard.Listener(on_release=reset_key)
43 | listener.start()
44 | listener2.start()
45 |
46 | ###
47 |
48 |
49 | def count_folders(path):
50 | """Counts the number of folders under the given path."""
51 | folder_count = 0
52 | for root, dirs, files in os.walk(path):
53 | folder_count += len(dirs) # Count directories only at current level
54 | break # Prevents descending into subdirectories
55 | return folder_count
56 |
57 |
58 | def print_color(*args, color=None, attrs=(), **kwargs):
59 | if len(args) > 0:
60 | args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
61 | print(*args, **kwargs)
62 |
63 |
64 | def save_frame(
65 | folder: Path,
66 | timestamp: datetime.datetime,
67 | obs: Dict[str, np.ndarray],
68 | action: np.ndarray,
69 | activated=True,
70 | save_png=False,
71 | ) -> None:
72 | obs["activated"] = activated
73 | obs["control"] = action # add action to obs
74 | recorded_file = folder / (
75 | timestamp.isoformat().replace(":", "-").replace(".", "-") + ".pkl"
76 | )
77 | with open(recorded_file, "wb") as f:
78 | pickle.dump(obs, f)
79 |
80 | # save rgb image as png
81 | if save_png:
82 | rgb = obs["base_rgb"]
83 | for i in range(rgb.shape[0]):
84 | rgbi = cv2.cvtColor(rgb[i], cv2.COLOR_RGB2BGR)
85 | fn = str(recorded_file)[:-4] + f"-{i}.png"
86 | cv2.imwrite(fn, rgbi)
87 |
88 |
89 | @dataclass
90 | class Args:
91 | robot_port: int = 6000
92 | wrist_camera_port: int = 5001
93 | base_camera_port: int = 5000
94 | hostname: str = "111.0.0.1"
95 | hz: int = 100
96 | show_camera_view: bool = True
97 |
98 | agent: str = "quest"
99 | robot_type: str = "ur5"
100 | save_data: bool = False
101 | save_depth: bool = True
102 | save_png: bool = False
103 | data_dir: str = "/shared/data/bc_data"
104 | verbose: bool = False
105 | safe: bool = False
106 | use_vel_ik: bool = False
107 |
108 | num_diffusion_iters_compile: int = 15 # used for compilation only for now
109 | jit_compile: bool = False # send the compilation signal to the server (only need to do this once per inference server run).
110 | use_jit_agent: bool = False # use the inference server to get actions. The inference_agent_port and the inference_agent_host need to be set to the proper values.
111 | inference_agent_port: str = (
112 | "1234" # port must be the same as the inference server port
113 | )
114 | inference_agent_host: str = "111.11.111.11" # IP of the inference server (localhost if running locally; currently defaults to bt). The inference server must use the same checkpoint folder when launching the inference node (args need to match).
115 |
116 | dp_ckpt_path: str = "/shared/ckpts/best.ckpt"
117 |
118 | temporal_ensemble_mode: str = "avg"
119 | temporal_ensemble_act_tau: float = 0.5
120 |
121 |
122 | def main(args):
123 | camera_clients = {
124 | "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
125 | }
126 | robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
127 | env = RobotEnv(
128 | robot_client,
129 | control_rate_hz=args.hz,
130 | camera_dict=camera_clients,
131 | show_camera_view=args.show_camera_view,
132 | save_depth=args.save_depth,
133 | )
134 |
135 | if args.agent == "quest":
136 | from agents.quest_agent import SingleArmQuestAgent
137 |
138 | left_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="l")
139 | right_agent = SingleArmQuestAgent(robot_type=args.robot_type, which_hand="r")
140 | agent = BimanualAgent(left_agent, right_agent)
141 | print("Quest agent created")
142 | elif args.agent == "quest_hand":
143 | # some custom mapping from Quest controller to hand control
144 | from agents.quest_agent import (
145 | DualArmQuestAgent,
146 | SingleArmQuestAgent,
147 | )
148 |
149 | left_agent = SingleArmQuestAgent(
150 | robot_type=args.robot_type,
151 | which_hand="l",
152 | eef_control_mode=3,
153 | use_vel_ik=args.use_vel_ik,
154 | )
155 | right_agent = SingleArmQuestAgent(
156 | robot_type=args.robot_type,
157 | which_hand="r",
158 | eef_control_mode=3,
159 | use_vel_ik=args.use_vel_ik,
160 | )
161 | agent = DualArmQuestAgent(left_agent, right_agent)
162 | print("Quest agent created")
163 | elif args.agent == "quest_hand_eef":
164 | # some custom mapping from Quest controller to hand control
165 | from agents.quest_agent_eef import (
166 | DualArmQuestAgent,
167 | SingleArmQuestAgent,
168 | )
169 |
170 | left_agent = SingleArmQuestAgent(
171 | robot_type=args.robot_type,
172 | which_hand="l",
173 | eef_control_mode=3,
174 | )
175 | right_agent = SingleArmQuestAgent(
176 | robot_type=args.robot_type,
177 | which_hand="r",
178 | eef_control_mode=3,
179 | )
180 | agent = DualArmQuestAgent(left_agent, right_agent)
181 | print("Quest EEF agent created")
182 | elif args.agent in ["dp", "dp_eef"]:
183 | if args.use_jit_agent:
184 | from agents.dp_agent_zmq import BimanualDPAgent
185 |
186 | agent = BimanualDPAgent(
187 | ckpt_path=args.dp_ckpt_path,
188 | port=args.inference_agent_port,
189 | host=args.inference_agent_host,
190 | temporal_ensemble_act_tau=args.temporal_ensemble_act_tau,
191 | temporal_ensemble_mode=args.temporal_ensemble_mode,
192 | )
193 | else:
194 | from agents.dp_agent import BimanualDPAgent
195 |
196 | agent = BimanualDPAgent(ckpt_path=args.dp_ckpt_path)
197 | else:
198 | raise ValueError(f"Invalid agent name for bimanual: {args.agent}")
199 |
200 | if args.agent == "quest":
201 | # using grippers
202 | reset_joints_left = np.deg2rad([-80, -140, -80, -85, -10, 80, 0])
203 | reset_joints_right = np.deg2rad([-270, -30, 70, -85, 10, 0, 0])
204 | else:
205 | # using Ability hands
206 | arm_joints_left = [-80, -140, -80, -85, -10, 80]
207 | arm_joints_right = [-270, -30, 70, -85, 10, 0]
208 | hand_joints = [0, 0, 0, 0, 0.5, 0.5]
209 | reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
210 | reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
211 | reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
212 | curr_joints = env.get_obs()["joint_positions"]
213 | curr_joints[6:12] = hand_joints
214 | curr_joints[18:] = hand_joints
215 | print("Current joints:", curr_joints)
216 | print("Reset joints:", reset_joints)
217 | max_delta = (np.abs(curr_joints - reset_joints)).max()
218 | steps = min(int(max_delta / 0.01), 20)
219 | for jnt in np.linspace(curr_joints, reset_joints, steps):
220 | env.step(jnt)
221 |
222 | obs = env.get_obs()
223 |
224 | if args.jit_compile:
225 | agent.compile_inference(
226 | obs, num_diffusion_iters=args.num_diffusion_iters_compile
227 | )
228 |
229 | # going to start position
230 | print("Going to start position")
231 | start_pos = agent.act(obs)
232 | obs = env.get_obs()
233 | joints = obs["joint_positions"]
234 |
235 | if args.agent == "quest":
236 | ur_idx = [i for i in range(len(joints))]
237 | hand_idx = None
238 | else:
239 | ur_idx = list(range(0, 6)) + list(range(12, 18))
240 | hand_idx = list(range(6, 12)) + list(range(18, 24))
241 |
242 | if args.safe:
243 | max_joint_delta = 0.5
244 | max_hand_delta = 0.1
245 | safety_wrapper = SafetyWrapper(
246 | ur_idx, hand_idx, agent, delta=max_joint_delta, hand_delta=max_hand_delta
247 | )
248 |
249 | print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
250 | assert len(start_pos) == len(
251 | joints
252 | ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
253 |
254 | print(f"Collecting traj no.{count_folders(args.data_dir) + 1}")
255 |
256 | # time.sleep(2.0)
257 | while not trigger_state["r"]:
258 | print(">>> Step on right")
259 | time.sleep(0.2)
260 |
261 | print_color("\nReady to go 🚀🚀🚀", color="green", attrs=("bold",))
262 |
263 | start_time = time.time()
264 |
265 | if args.save_data:
266 | time_str = datetime.datetime.now().strftime("%m%d_%H%M%S")
267 | if args.agent.startswith("dp"):
268 | # eval
269 | save_path = (
270 | Path(args.data_dir).expanduser()
271 | / "_".join(
272 | [
273 | args.dp_ckpt_path.split("/")[-2],
274 | args.dp_ckpt_path.split("/")[-1][:-5],
275 | ]
276 | )
277 | / time_str
278 | )
279 | else:
280 | save_path = Path(args.data_dir).expanduser() / time_str
281 | save_path.mkdir(parents=True, exist_ok=True)
282 | print(f"Saving to {save_path}")
283 |
284 | is_first_frame = True
285 | try:
286 | frame_freq = []
287 | while True:
288 | new_start_time = time.time()
289 | num = new_start_time - start_time
290 | message = f"\rTime passed: {round(num, 2)} "
291 | print_color(
292 | message,
293 | color="white",
294 | attrs=("bold",),
295 | end="",
296 | flush=True,
297 | )
298 | if args.safe:
299 | action = safety_wrapper.act_safe(
300 | agent, obs, eef=(args.agent.endswith("_eef"))
301 | )
302 | else:
303 | action = agent.act(obs)
304 | dt = datetime.datetime.now()
305 |
306 | if args.save_data:
307 | if is_first_frame:
308 | is_first_frame = False
309 | else:
310 | save_frame(
311 | save_path,
312 | dt,
313 | obs,
314 | action,
315 | activated=agent.trigger_state,
316 | save_png=args.save_png,
317 | )
318 |
319 | if args.agent.endswith("_eef"):
320 | obs = env.step_eef(action)
321 | else:
322 | obs = env.step(action)
323 |
324 | ff = 1 / (time.time() - new_start_time)
325 | frame_freq.append(ff)
326 |
327 | if trigger_state["l"]:
328 | print_color("\nTriggered!", color="red", attrs=("bold",))
329 | break
330 |
331 | except KeyboardInterrupt:
332 | print_color("\nInterrupted!", color="red", attrs=("bold",))
333 | finally:
334 | if "dp" in args.agent:
335 | import glob
336 |
337 | from moviepy.editor import ImageSequenceClip
338 |
339 | # find all the pkl files in the episode directory
340 | pkls = sorted(glob.glob(os.path.join(save_path, "*.pkl")))
341 | print("Total number of pkls: ", len(pkls))
342 | frames = []
343 | for pkl in pkls:
344 | with open(pkl, "rb") as f:
345 | try:
346 | data = pickle.load(f)
347 | except Exception: # skip frames that fail to unpickle (e.g. truncated files)
348 | continue
349 | rgb = data["base_rgb"]
350 | rgb = np.concatenate([rgb[i] for i in range(rgb.shape[0])], axis=1)
351 | frames.append(rgb)
352 | clip = ImageSequenceClip(frames, fps=5)
353 | ckpt_path = os.path.dirname(args.dp_ckpt_path)
354 | parent_name = os.path.basename(ckpt_path)
355 | clip.write_videofile(
356 | os.path.join(ckpt_path, f"{parent_name}_{time_str}.mp4")
357 | )
358 |
359 | # save frame freq as txt
360 | with open(os.path.join(ckpt_path, f"freq_{time_str}.txt"), "w") as f:
361 | for step, freq in enumerate(frame_freq):
362 | f.write(f"{step}: {freq}\n")
363 | else:
364 | print("Done")
365 |
366 | # save frame freq as txt
367 | with open(save_path / "freq.txt", "w") as f:
368 | f.write(
369 | f"Average FPS: {np.mean(frame_freq[1:])}\n"
370 | f"Max FPS: {np.max(frame_freq[1:])}\n"
371 | f"Min FPS: {np.min(frame_freq[1:])}\n"
372 | f"Std FPS: {np.std(frame_freq[1:])}\n\n"
373 | )
374 | for step, freq in enumerate(frame_freq):
375 | f.write(f"{step}: {freq}\n")
376 |
377 | os._exit(0)
378 |
379 |
380 | if __name__ == "__main__":
381 | main(tyro.cli(Args))
382 |
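The move to the reset pose in `main` interpolates linearly from the current joints to the reset joints, using at most 20 waypoints. A minimal pure-Python sketch of that step, intended to mirror `np.linspace(curr_joints, reset_joints, steps)` for plain lists. `reset_waypoints` is a hypothetical helper name, and the `step` and `max_steps` defaults are copied from the script:

```python
from typing import List, Sequence


def reset_waypoints(
    curr: Sequence[float],
    target: Sequence[float],
    step: float = 0.01,
    max_steps: int = 20,
) -> List[List[float]]:
    """Return evenly spaced joint waypoints from curr to target (endpoints included)."""
    max_delta = max(abs(c - t) for c, t in zip(curr, target))
    steps = min(int(max_delta / step), max_steps)
    if steps == 0:
        return []  # already within one step of the target: nothing to command
    if steps == 1:
        return [list(curr)]  # same as np.linspace with num=1
    return [
        [c + (t - c) * i / (steps - 1) for c, t in zip(curr, target)]
        for i in range(steps)
    ]
```

Note that the cap bounds the number of waypoints, not the per-step joint change, so a large initial offset still produces large jumps between consecutive commands.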
--------------------------------------------------------------------------------
/run_openloop.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import pickle
3 | import time
4 | from dataclasses import dataclass
5 | from pathlib import Path
6 | from typing import Dict
7 |
8 | import numpy as np
9 | import tyro
10 |
11 | from agents.agent import SafetyWrapper
12 | from camera_node import ZMQClientCamera
13 | from env import EvalRobotEnv
14 | from robot_node import ZMQClientRobot
15 |
16 |
17 | def print_color(*args, color=None, attrs=(), **kwargs):
18 | import termcolor
19 |
20 | if len(args) > 0:
21 | args = tuple(termcolor.colored(arg, color=color, attrs=attrs) for arg in args)
22 | print(*args, **kwargs)
23 |
24 |
25 | def save_frame(
26 | folder: Path,
27 | timestamp: datetime.datetime,
28 | obs: Dict[str, np.ndarray],
29 | action: np.ndarray,
30 | activated=True,
31 | ) -> None:
32 | obs["activated"] = activated
33 | obs["control"] = action # add action to obs
34 | recorded_file = folder / (timestamp.isoformat() + ".pkl")
35 | with open(recorded_file, "wb") as f:
36 | pickle.dump(obs, f)
37 |
38 |
39 | @dataclass
40 | class Args:
41 | robot_port: int = 6000
42 | wrist_camera_port: int = 5001
43 | base_camera_port: int = 5000
44 | hostname: str = "127.0.0.1"
45 | hz: int = 100
46 | show_camera_view: bool = True
47 |
48 | agent: str = "dp"
49 | robot_type: str = "ur5"
50 | hand_type: str = "ability"
51 | save_data: bool = False
52 | data_dir: str = "/shared/data/bc_data"
53 | verbose: bool = False
54 | safe: bool = False
55 | use_vel_ik: bool = False
56 |
57 | traj_path: str = "/shared/data/test_data"
58 | dp_ckpt_path: str = "best.ckpt"
59 |
60 |
61 | def main(args):
62 | camera_clients = {
63 | "base": ZMQClientCamera(port=args.base_camera_port, host=args.hostname),
64 | }
65 | robot_client = ZMQClientRobot(port=args.robot_port, host=args.hostname)
66 | env = EvalRobotEnv(
67 | robot_client,
68 | traj_path=args.traj_path,
69 | control_rate_hz=args.hz,
70 | camera_dict=camera_clients,
71 | )
72 | if args.agent.startswith("dp"):
73 | from agents.dp_agent import BimanualDPAgent
74 |
75 | agent = BimanualDPAgent(ckpt_path=args.dp_ckpt_path)
76 | else:
77 | raise ValueError(f"Invalid agent name: {args.agent}")
78 |
79 | if args.hand_type == "ability":
80 | arm_joints_left = [-80, -140, -80, -85, -10, 80]
81 | arm_joints_right = [-270, -30, 70, -85, 10, 0]
82 | hand_joints = [0, 0, 0, 0, 0.5, 0.5]
83 | else:
84 | raise ValueError(f"Invalid hand type: {args.hand_type}")
85 | reset_joints_left = np.concatenate([np.deg2rad(arm_joints_left), hand_joints])
86 | reset_joints_right = np.concatenate([np.deg2rad(arm_joints_right), hand_joints])
87 | reset_joints = np.concatenate([reset_joints_left, reset_joints_right])
88 | curr_joints = env.get_real_obs()["joint_positions"]
89 | if args.hand_type == "ability":
90 | curr_joints[6:12] = hand_joints
91 | curr_joints[18:] = hand_joints
92 | print("Current joints:", curr_joints)
93 | print("Reset joints:", reset_joints)
94 | max_delta = (np.abs(curr_joints - reset_joints)).max()
95 | steps = min(int(max_delta / 0.01), 20)
96 | for jnt in np.linspace(curr_joints, reset_joints, steps):
97 | obs = env.step(jnt)
98 |
99 | # going to start position
100 | print("Going to start position")
101 | start_pos = agent.act(env.get_real_obs())
102 | obs = env.get_real_obs()
103 | joints = obs["joint_positions"]
104 |
105 | # if args.hand_type == "ability":
106 | ur_idx = list(range(0, 6)) + list(range(12, 18))
107 | hand_idx = list(range(6, 12)) + list(range(18, 24))
108 |
109 | if args.safe:
110 | max_joint_delta = 0.5
111 | max_hand_delta = 0.1
112 | safety_wrapper = SafetyWrapper(
113 | ur_idx, hand_idx, agent, delta=max_joint_delta, hand_delta=max_hand_delta
114 | )
115 |
116 | print(f"Start pos: {len(start_pos)}", f"Joints: {len(joints)}")
117 | assert len(start_pos) == len(
118 | joints
119 | ), f"agent output dim = {len(start_pos)}, but env dim = {len(joints)}"
120 |
121 | for step in range(3):
122 | print("Countdown step", step)
123 | time.sleep(0.5)
124 |
125 | print_color("\nReady to go 🚀🚀🚀", color="green", attrs=("bold",))
126 |
127 | start_time = time.time()
128 |
129 | if args.save_data:
130 | time_str = datetime.datetime.now().strftime("%m%d_%H%M%S")
131 | save_path = (
132 | Path(args.data_dir).expanduser()
133 | / (args.traj_path.split("/")[-1] + "_openloop")
134 | / time_str
135 | )
136 | save_path.mkdir(parents=True, exist_ok=True)
137 | print(f"Saving to {save_path}")
138 |
139 | while obs is not None:
140 | num = time.time() - start_time
141 | message = f"\rTime passed: {round(num, 2)} "
142 | print_color(
143 | message,
144 | color="white",
145 | attrs=("bold",),
146 | end="",
147 | flush=True,
148 | )
149 | if args.safe:
150 | action = safety_wrapper.act_safe(
151 | agent, obs, eef=(args.agent.endswith("_eef"))
152 | )
153 | else:
154 | action = agent.act(obs)
155 | dt = datetime.datetime.now()
156 | img, depth = camera_clients["base"].read()
157 | if args.save_data:
158 | obs["base_rgb"] = img
159 | obs["base_depth"] = depth
160 | save_frame(save_path, dt, obs, action, activated=agent.trigger_state)
161 | # input("Press Enter to continue...")
162 |
163 | if args.agent.endswith("_eef"):
164 | obs = env.step_eef(action)
165 | else:
166 | obs = env.step(action)
167 |
168 | # save eval video
169 | import glob
170 | import os
171 |
172 | from moviepy.editor import ImageSequenceClip
173 |
174 | episode_dir = save_path
175 |
176 | # find all the pkl files in the episode directory
177 | pkls = sorted(glob.glob(os.path.join(episode_dir, "*.pkl")))
178 |
179 | # read all images
180 | frames = []
181 | for pkl in pkls:
182 | with open(pkl, "rb") as f:
183 | try:
184 | data = pickle.load(f)
185 | except Exception: # skip frames that fail to unpickle (e.g. truncated files)
186 | continue
187 | rgb = data["base_rgb"]
188 | rgb = np.concatenate([rgb[i] for i in range(rgb.shape[0])], axis=1)
189 | frames.append(rgb)
190 |
191 | # Create a video clip
192 | clip = ImageSequenceClip(frames, fps=10)
193 | ckpt_path = os.path.dirname(args.dp_ckpt_path)
194 | clip.write_videofile(os.path.join(ckpt_path, f"{time_str}_openloop.mp4"))
195 |
196 |
197 | if __name__ == "__main__":
198 | main(tyro.cli(Args))
199 |
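Both run scripts treat the bimanual state as a 24-dimensional vector laid out as left arm (0..5), left hand (6..11), right arm (12..17), right hand (18..23). A minimal sketch of splitting such a vector using the same `ur_idx` and `hand_idx` lists. `split_bimanual` is a hypothetical helper, not a function in the repository:

```python
from typing import List, Sequence, Tuple

# Index layout copied from the run scripts (Ability hand configuration).
UR_IDX = list(range(0, 6)) + list(range(12, 18))
HAND_IDX = list(range(6, 12)) + list(range(18, 24))


def split_bimanual(vec: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Split a 24-dim bimanual vector into (arm joints, hand joints)."""
    if len(vec) != 24:
        raise ValueError(f"expected 24 values, got {len(vec)}")
    arm = [vec[i] for i in UR_IDX]
    hand = [vec[i] for i in HAND_IDX]
    return arm, hand
```

Keeping the two groups separate makes it possible to apply different limits to each, as `SafetyWrapper` is configured to do with `delta` and `hand_delta`.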
--------------------------------------------------------------------------------
/test_dp_agent.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from agents.dp_agent import BimanualDPAgent
10 | from learning.dp.data_processing import iterate
11 |
12 | torch.cuda.set_device(0)
13 |
14 |
15 | def main(args):
16 |     hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
17 |     hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])
18 |
19 |     dp_args = BimanualDPAgent.get_default_dp_args()
20 |     data = iterate(args.data_dir)
21 |     for num_diffusion_iters in args.num_diffusion_iters:
22 |         dp_args["num_diffusion_iters"] = num_diffusion_iters
23 |         dp_agent = BimanualDPAgent(ckpt_path=args.ckpt_path, dp_args=dp_args)
24 |
25 |         controls = []
26 |         pred_actions = []
27 |         delta_action = []
28 |         last_action = data[0]["control"]
29 |
30 |         start = time.time()
31 |         dp_agent.compile_inference(data[0], num_inference_iters=num_diffusion_iters)
32 |         end = time.time()
33 |         ctime = end - start
34 |         print(f"compilation time: {ctime}")
35 |
36 |         start = time.time()
37 |         max_infer_time = 0
38 |         for i, obs in enumerate(data):
39 |             control = obs["control"]
40 |             delta_action.append(control - last_action)
41 |             last_action = control
42 |             controls.append(control)
43 |             if i != 0:
44 |                 obs["joint_positions"][list(range(6)) + list(range(12, 18))] = (
45 |                     pred_actions[-1][list(range(6)) + list(range(12, 18))]
46 |                 )
47 |                 obs["joint_positions"][6:12] = (
48 |                     pred_actions[-1][6:12] * (hand_uppers - hand_lowers) + hand_lowers
49 |                 )
50 |                 obs["joint_positions"][18:24] = (
51 |                     pred_actions[-1][18:24] * (hand_uppers - hand_lowers) + hand_lowers
52 |                 )
53 |             infer_start = time.time()
54 |             pred_actions.append(dp_agent.act(obs))
55 |             infer_time = time.time() - infer_start
56 |             max_infer_time = max(max_infer_time, infer_time)
57 |
58 |         print("num_diffusion_iters:", num_diffusion_iters)
59 |         print("time:", time.time() - start)
60 |         print("Hz:", len(data) / (time.time() - start))
61 |         print("max_infer_time:", max_infer_time)
62 |         print("lowest_freq:", 1 / max_infer_time)
63 |
64 |         pred_actions = np.array(pred_actions)
65 |         controls = np.array(controls)
66 |
67 |         mse = np.mean(np.abs((pred_actions - controls)), axis=0)  # mean absolute error, despite the "mse" name
68 |         mean_delta_action = np.mean(np.abs(delta_action), axis=0)
69 |
70 |         print_str = "\n".join(
71 |             [
72 |                 "mse:",
73 |                 str(mse.tolist()),
74 |                 "\n",
75 |                 "mean_delta_action:",
76 |                 str(mean_delta_action.tolist()),
77 |                 "\nfinal_diff:",
78 |                 str((pred_actions[-1] - controls[-1]).tolist()),
79 |             ]
80 |         )
81 |         print_str += "\n"
82 |         print_str += (
83 |             "mse: "
84 |             + str(mse.mean())
85 |             + " mean_delta_action: "
86 |             + str(mean_delta_action.mean())
87 |         )
88 |         print(print_str)
89 |
90 |         # save print_str as txt in ckpt dir
91 |         ckpt_dir = os.path.dirname(args.ckpt_path)
92 |         with open(
93 |             os.path.join(ckpt_dir, f"eval_stats_{num_diffusion_iters}.txt"), "w"
94 |         ) as f:
95 |             f.write(print_str)
96 |
97 |         traj_name = os.path.basename(args.data_dir)
98 |         save_path = os.path.join(
99 |             ckpt_dir, f"openloop_{traj_name}_{num_diffusion_iters}"
100 |         )
101 |         os.makedirs(save_path, exist_ok=True)
102 |         for i in range(len(pred_actions)):
103 |             with open(os.path.join(save_path, str(i) + ".pkl"), "wb") as f:
104 |                 pickle.dump(
105 |                     {
106 |                         "control": pred_actions[i],
107 |                         "joint_positions": data[i]["joint_positions"],
108 |                     },
109 |                     f,
110 |                 )
111 |
112 |
113 | if __name__ == "__main__":
114 |     parser = argparse.ArgumentParser()
115 |     parser.add_argument(
116 |         "--ckpt_path",
117 |         type=str,
118 |         default="best.ckpt",
119 |     )
120 |     parser.add_argument(
121 |         "--data_dir",
122 |         type=str,
123 |         default="data/",
124 |     )
125 |     parser.add_argument(
126 |         "--num_diffusion_iters", default=[5, 15, 25, 50], type=int, nargs="+"
127 |     )
128 |     args = parser.parse_args()
129 |     main(args)
130 |
--------------------------------------------------------------------------------
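The statistic that test_dp_agent.py prints as "mse" is computed as the mean of absolute differences over timesteps, which is a per-dimension mean absolute error rather than a mean squared error. The sketch below shows both quantities side by side in plain Python so the difference is easy to see. The function names and the toy numbers are illustrative only and do not come from the repository.

```python
def per_dim_mae(pred, target):
    """Mean absolute error per action dimension (rows are timesteps)."""
    n = len(pred)
    dims = len(pred[0])
    return [
        sum(abs(pred[t][d] - target[t][d]) for t in range(n)) / n
        for d in range(dims)
    ]


def per_dim_mse(pred, target):
    """Mean squared error per action dimension (rows are timesteps)."""
    n = len(pred)
    dims = len(pred[0])
    return [
        sum((pred[t][d] - target[t][d]) ** 2 for t in range(n)) / n
        for d in range(dims)
    ]


# Two timesteps, two action dimensions (made-up values).
pred = [[0.0, 1.0], [1.0, 3.0]]
target = [[0.0, 0.0], [3.0, 3.0]]
mae = per_dim_mae(pred, target)
mse = per_dim_mse(pred, target)
```

Because squaring weights large errors more heavily, the two metrics can rank checkpoints differently, so it is worth knowing which one a saved eval_stats file actually contains.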
/test_dp_agent_zmq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from agents.dp_agent_zmq import BimanualDPAgent
10 | from learning.dp.data_processing import iterate
11 |
12 | torch.cuda.set_device(0)
13 |
14 |
15 | def main(args):
16 |     hand_uppers = np.array([110.0, 110.0, 110.0, 110.0, 90.0, 120.0])
17 |     hand_lowers = np.array([5.0, 5.0, 5.0, 5.0, 5.0, 5.0])
18 |
19 |     data = iterate(args.data_dir)
20 |
21 |     num_diffusion_iters = args.num_diffusion_iters
22 |     dp_agent = BimanualDPAgent(
23 |         ckpt_path=args.ckpt_path,
24 |         host="localhost",
25 |         port="4321",
26 |     )
27 |     dp_agent.compile_inference(data[0], num_diffusion_iters=num_diffusion_iters)
28 |     controls = []
29 |     pred_actions = []
30 |     delta_action = []
31 |     last_action = data[0]["control"]
32 |
33 |     # start = time.time()
34 |     # end = time.time()
35 |     # ctime = end - start
36 |     # print(f"compilation time: {ctime}")
37 |
38 |     start = time.time()
39 |     max_infer_time = 0
40 |     for i, obs in enumerate(data):
41 |         time.sleep(0.1)
42 |         control = obs["control"]
43 |         delta_action.append(control - last_action)
44 |         last_action = control
45 |         controls.append(control)
46 |         if i != 0:
47 |             obs["joint_positions"][list(range(6)) + list(range(12, 18))] = pred_actions[
48 |                 -1
49 |             ][list(range(6)) + list(range(12, 18))]
50 |             obs["joint_positions"][6:12] = (
51 |                 pred_actions[-1][6:12] * (hand_uppers - hand_lowers) + hand_lowers
52 |             )
53 |             obs["joint_positions"][18:24] = (
54 |                 pred_actions[-1][18:24] * (hand_uppers - hand_lowers) + hand_lowers
55 |             )
56 |         infer_start = time.time()
57 |         pred_actions.append(dp_agent.act(obs))
58 |         infer_time = time.time() - infer_start
59 |         max_infer_time = max(max_infer_time, infer_time)
60 |
61 |     print("num_diffusion_iters:", num_diffusion_iters)
62 |     print("time:", time.time() - start)
63 |     print("Hz:", len(data) / (time.time() - start))
64 |     print("max_infer_time:", max_infer_time)
65 |     print("lowest_freq:", 1 / max_infer_time)
66 |
67 |     pred_actions = np.array(pred_actions)
68 |     controls = np.array(controls)
69 |
70 |     mse = np.mean(np.abs((pred_actions - controls)), axis=0)  # mean absolute error, despite the "mse" name
71 |     mean_delta_action = np.mean(np.abs(delta_action), axis=0)
72 |
73 |     print_str = "\n".join(
74 |         [
75 |             "mse:",
76 |             str(mse.tolist()),
77 |             "\n",
78 |             "mean_delta_action:",
79 |             str(mean_delta_action.tolist()),
80 |             "\nfinal_diff:",
81 |             str((pred_actions[-1] - controls[-1]).tolist()),
82 |         ]
83 |     )
84 |     print_str += "\n"
85 |     print_str += (
86 |         "mse: "
87 |         + str(mse.mean())
88 |         + " mean_delta_action: "
89 |         + str(mean_delta_action.mean())
90 |     )
91 |     print(print_str)
92 |
93 |     # save print_str as txt in ckpt dir
94 |     ckpt_dir = os.path.dirname(args.ckpt_path)
95 |     with open(
96 |         os.path.join(ckpt_dir, f"eval_stats_{num_diffusion_iters}.txt"), "w"
97 |     ) as f:
98 |         f.write(print_str)
99 |
100 |     traj_name = os.path.basename(args.data_dir)
101 |     save_path = os.path.join(ckpt_dir, f"openloop_{traj_name}_{num_diffusion_iters}")
102 |     os.makedirs(save_path, exist_ok=True)
103 |     for i in range(len(pred_actions)):
104 |         with open(os.path.join(save_path, str(i) + ".pkl"), "wb") as f:
105 |             pickle.dump(
106 |                 {
107 |                     "control": pred_actions[i],
108 |                     "joint_positions": data[i]["joint_positions"],
109 |                 },
110 |                 f,
111 |             )
112 |
113 |
114 | if __name__ == "__main__":
115 |     parser = argparse.ArgumentParser()
116 |     # main() reads args.ckpt_path for both the agent and the output directory,
117 |     # so the flag has to be defined here; the default keeps the checkpoint
118 |     # name that was previously hard-coded in main().
119 |     parser.add_argument("--ckpt_path", type=str, default="best.ckpt")
120 |     parser.add_argument("--data_dir", type=str, default="data/")
121 |     parser.add_argument("--num_diffusion_iters", default=15, type=int)
122 |     args = parser.parse_args()
123 |     main(args)
124 |
--------------------------------------------------------------------------------
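Both test scripts feed the previous prediction back into the observation, and for the hand entries (indices 6:12 and 18:24 of `joint_positions`) they rescale the prediction with `x * (upper - lower) + lower`. The sketch below isolates that mapping and its inverse in plain Python. The limits are copied from the scripts; the assumption that predicted hand actions lie in [0, 1] is inferred from the rescaling and is not stated in the source, and both function names are hypothetical.

```python
# Per-finger limits copied from the test scripts.
HAND_UPPERS = [110.0, 110.0, 110.0, 110.0, 90.0, 120.0]
HAND_LOWERS = [5.0, 5.0, 5.0, 5.0, 5.0, 5.0]


def denormalize_hand(normalized):
    """Map six normalized values to joint units: lower + x * (upper - lower)."""
    return [
        lo + x * (hi - lo)
        for x, lo, hi in zip(normalized, HAND_LOWERS, HAND_UPPERS)
    ]


def normalize_hand(raw):
    """Inverse mapping: (value - lower) / (upper - lower)."""
    return [
        (v - lo) / (hi - lo)
        for v, lo, hi in zip(raw, HAND_LOWERS, HAND_UPPERS)
    ]
```

Under this assumption a prediction of 0 lands on the lower limit and 1 on the upper limit; the arm entries (indices 0:6 and 12:18) are copied through without rescaling in the scripts.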
/workflow/create_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 |
6 | random.seed(0)
7 |
8 |
9 | def create_eval(origin_path, target_path, num=20):
10 |     if not os.path.exists(target_path):
11 |         os.makedirs(target_path)
12 |     data_dir = sorted(os.listdir(origin_path))  # sorted so the seeded sample is reproducible
13 |     assert (
14 |         len(data_dir) >= num
15 |     ), "Fewer trajectories are available than were requested to move"
16 |     assert len(data_dir) > 100, "need more than 100 trajectories"
17 |     # sample without replacement
18 |     eval_dir = random.sample(data_dir, num)
19 |     print(eval_dir)
20 |     for file in eval_dir:
21 |         shutil.move(os.path.join(origin_path, file), target_path)
22 |         print("Move {} to {}".format(file, target_path))
23 |
24 |
25 | if __name__ == "__main__":
26 |     parser = argparse.ArgumentParser()
27 |     parser.add_argument(
28 |         "--origin_path", type=str, required=True, help="path to the original data"
29 |     )
30 |     parser.add_argument(
31 |         "--target_path", type=str, required=True, help="path to the target data"
32 |     )
33 |     parser.add_argument("--num", type=int, default=20, help="number of data to move")
34 |     args = parser.parse_args()
35 |     create_eval(args.origin_path, args.target_path, args.num)
36 |
--------------------------------------------------------------------------------
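create_eval.py seeds the random module so that the held-out set can be reproduced, but `os.listdir` returns entries in arbitrary order, so a seeded sample is only repeatable if the input order is fixed first. The following sketch illustrates the idea with a hypothetical helper `pick_eval`; it uses a local `random.Random` instance instead of the global seed, which is a design choice of this example rather than something the repository does.

```python
import random


def pick_eval(names, num, seed=0):
    """Choose `num` names reproducibly, regardless of the input order."""
    rng = random.Random(seed)  # local generator; leaves global state alone
    return rng.sample(sorted(names), num)


# The same names in two different orders yield the same selection.
names = [f"traj_{i:03d}" for i in range(10)]
first = pick_eval(names, 3)
second = pick_eval(list(reversed(names)), 3)
```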
/workflow/download_dataset.sh:
--------------------------------------------------------------------------------
1 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1512475235180' > data_banana.zip
2 | md5sum data_banana.zip
3 |
4 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1512698219144' > data_stack.zip
5 | md5sum data_stack.zip
6 |
7 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513394792580' > data_pour.zip
8 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513418334040' >> data_pour.zip
9 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513412800402' >> data_pour.zip
10 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513388327963' >> data_pour.zip
11 | md5sum data_pour.zip
12 |
13 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513478982439' > data_steak.zip
14 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513509873535' >> data_steak.zip
15 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513526807695' >> data_steak.zip
16 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513547399539' >> data_steak.zip
17 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513550058414' >> data_steak.zip
18 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513544174072' >> data_steak.zip
19 | curl -L 'https://berkeley.app.box.com/index.php?rm=box_download_shared_file&shared_name=379cf57zqm1akvr00vdcloxqxi3ucb9g&file_id=f_1513541198161' >> data_steak.zip
20 | md5sum data_steak.zip
--------------------------------------------------------------------------------
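download_dataset.sh builds the larger archives by appending several downloads to one file with `>>` and then prints an MD5 digest, but it has no reference value to compare against. A rough Python equivalent of the concatenate-then-checksum step is sketched below. The function names are hypothetical, and comparing the digest against a known-good value is left to the caller because the source does not list expected checksums.

```python
import hashlib
import shutil


def concat_parts(part_paths, out_path):
    """Append each part to out_path in order, like `cat part1 part2 > out`."""
    with open(out_path, "wb") as out:
        for part in part_paths:
            with open(part, "rb") as f:
                shutil.copyfileobj(f, out)


def md5_of(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

Reading in chunks keeps memory use flat even for multi-gigabyte archives, and the part order matters: a wrong order produces a different digest and an unreadable archive.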
/workflow/gen_deploy_scripts.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 |
5 | def find_ckpts(ckpt_dirs):
6 |     """
7 |     Find the highest-index checkpoint of each run under a list of checkpoint directories.
8 |     """
9 |     ckpt_dict = {}
10 |     for ckpt_dir in ckpt_dirs:
11 |         for root, _, files in os.walk(ckpt_dir):
12 |             for file in files:
13 |                 if file.endswith(".ckpt"):
14 |                     run = os.path.basename(os.path.normpath(root))
15 |                     ckpt_idx = int(file.split(".")[-2].split("_")[-1])
16 |                     if (run not in ckpt_dict) or (
17 |                         run in ckpt_dict and ckpt_dict[run]["idx"] < ckpt_idx
18 |                     ):
19 |                         # update the ckpt path
20 |                         if os.path.exists(os.path.join(root, "args_log.txt")):
21 |                             args_path = os.path.join(root, "args_log.txt")
22 |                         else:
23 |                             args_path = None
24 |                         ckpt_dict[run] = {
25 |                             "root": root,
26 |                             "ckpt": os.path.join(root, file),
27 |                             "args": args_path,
28 |                             "idx": ckpt_idx,
29 |                         }
30 |     return ckpt_dict
31 |
32 |
33 | def gen_deploy_scripts(
34 |     ckpt_dirs, agent_type="dp", test_data_path="", conda_env_name="hato"
35 | ):
36 |     ckpt_dict = find_ckpts(ckpt_dirs)
37 |
38 |     for run, ckpt_args in ckpt_dict.items():
39 |         ckpt = ckpt_args["ckpt"]
40 |
41 |         # generate env script
42 |         env_script = f"python run_env.py --agent {agent_type} --no-show-camera-view --save_data --dp_ckpt_path {ckpt}"
43 |
44 |         if test_data_path is None or test_data_path == "":
45 |             print("No test traj for this ckpt.")
46 |         openloop_script = f"python run_openloop.py --agent {agent_type} --no-show-camera-view --save_data --dp_ckpt_path {ckpt} --traj_path {test_data_path}"
47 |
48 |         # generate node script
49 |         node_script = (
50 |             'python launch_nodes.py --hand_type ability --faster --cam_names "435"'
51 |         )
52 |         env_script += " --hz 10"
53 |         openloop_script += " --hz 10"
54 |
55 |         if "pour" in ckpt or "steak" in ckpt:
56 |             node_script += " --ability_gripper_grip_range 75"
57 |
58 |         test_script = (
59 |             f"python test_dp_agent.py --ckpt_path {ckpt} --data_dir {test_data_path}"
60 |         )
61 |         env_jit_script = (
62 |             env_script
63 |             + " --use_jit_agent --inference_agent_port=1325 --num_diffusion_iters_compile=50 --jit_compile"
64 |         )
65 |
66 |         inference_script = f"'conda activate {conda_env_name} && python launch_inference_nodes.py --dp_ckpt_path {ckpt} --port=1325'"
67 |         # save the scripts at the ckpt directory
68 |         save_dir = ckpt_args["root"]
69 |         with open(os.path.join(save_dir, f"{run}_node.sh"), "w") as f:
70 |             f.write(node_script)
71 |         with open(os.path.join(save_dir, f"{run}_env.sh"), "w") as f:
72 |             f.write(env_script)
73 |         with open(os.path.join(save_dir, f"{run}_env_jit.sh"), "w") as f:
74 |             f.write(env_jit_script)
75 |         with open(os.path.join(save_dir, f"{run}_openloop.sh"), "w") as f:
76 |             f.write(openloop_script)
77 |         with open(os.path.join(save_dir, f"{run}_test.sh"), "w") as f:
78 |             f.write(test_script)
79 |         with open(os.path.join(save_dir, f"{run}_inference.sh"), "w") as f:
80 |             f.write(inference_script)
81 |
82 |     return ckpt_dict
83 |
84 |
85 | def run_test_scripts(ckpt_dirs, filter="", run_all=True):
86 |     for checkpoint_dir in ckpt_dirs:
87 |         for root, _, files in os.walk(checkpoint_dir):
88 |             for file in files:
89 |                 if file.endswith("_test.sh") and filter in root:
90 |                     print("Running test script: ", file)
91 |                     if run_all:
92 |                         os.system(f"bash {os.path.join(root, file)}")
93 |                     else:
94 |                         # run the test script if enter yes
95 |                         answer = input(f"\nRun {os.path.join(root, file)}? (y/n)")
96 |                         if answer == "y":
97 |                             os.system(f"bash {os.path.join(root, file)}")
98 |                         else:
99 |                             print("Skip.")
100 |
101 |
102 | if __name__ == "__main__":
103 |     import argparse
104 |
105 |     parser = argparse.ArgumentParser()
106 |     parser.add_argument(
107 |         "-c",
108 |         "--ckpt_dirs",
109 |         nargs="+",
110 |         default=[
111 |             "/shared/ckpts/data_banana",
112 |         ],
113 |     )
114 |     parser.add_argument("-m", "--mode", type=str, default="gen")
115 |     parser.add_argument("-f", "--filter", type=str, default="")
116 |     args = parser.parse_args()
117 |
118 |     if args.mode == "gen":
119 |         checkpoint_paths = gen_deploy_scripts(args.ckpt_dirs)
120 |         print(checkpoint_paths)
121 |         print(len(checkpoint_paths))
122 |         print("Done.")
123 |     elif args.mode == "test":
124 |         run_test_scripts(args.ckpt_dirs, filter=args.filter)
125 |     elif args.mode == "all":
126 |         checkpoint_paths = gen_deploy_scripts(args.ckpt_dirs)
127 |         print(checkpoint_paths)
128 |         print(len(checkpoint_paths))
129 |         run_test_scripts(args.ckpt_dirs, filter=args.filter)
130 |
--------------------------------------------------------------------------------
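`find_ckpts` in gen_deploy_scripts.py takes the trailing integer of each `.ckpt` filename as the checkpoint index and keeps the largest one per run. That parsing raises a `ValueError` on a name without a numeric suffix, such as the `best.ckpt` default used by the test scripts. The sketch below shows one tolerant variant; the function names and the `epoch_<n>.ckpt` naming are assumptions made for illustration, not the repository's confirmed convention.

```python
import os


def ckpt_index(filename):
    """Return the trailing integer of '<name>_<idx>.ckpt', or None."""
    stem, ext = os.path.splitext(filename)
    if ext != ".ckpt":
        return None
    tail = stem.split("_")[-1]
    return int(tail) if tail.isdecimal() else None


def latest_ckpt(filenames):
    """Pick the filename with the largest index, ignoring unnumbered ones."""
    indexed = [(ckpt_index(name), name) for name in filenames]
    indexed = [(idx, name) for idx, name in indexed if idx is not None]
    return max(indexed)[1] if indexed else None
```

Comparing integers rather than strings matters here: as strings, `epoch_9` would sort after `epoch_10`.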
/workflow/split_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import random
5 |
6 |
7 | def split_symlink_train_test_merge_dataset(root_list, t_root, num_traj_per_task):
8 |     target_train_root = t_root + "_train"
9 |     target_test_root = t_root + "_test"
10 |     print(
11 |         "split data points from",
12 |         root_list,
13 |         f"to {target_train_root}: train and {target_test_root}: test",
14 |     )
15 |     if os.path.exists(target_train_root):
16 |         return
17 |     assert not os.path.exists(target_test_root)
18 |
19 |     os.makedirs(target_train_root)
20 |     os.makedirs(target_test_root)
21 |
22 |     all_paths = []
23 |     for root in sorted(root_list):
24 |         for path in sorted(os.listdir(root))[:num_traj_per_task]:
25 |             if not os.path.isfile(os.path.join(root, path)):
26 |                 d = path
27 |                 if (
28 |                     d.endswith("failed")
29 |                     or d.endswith("ood")
30 |                     or d.endswith("ikbad")
31 |                     or d.endswith("heated")
32 |                     or d.endswith("stop")
33 |                     or d.endswith("hard")
34 |                 ):
35 |                     continue
36 |                 all_paths.append((root, path))
37 |
38 |     train_paths = all_paths[:-1]
39 |     test_paths = all_paths[-1:]
40 |
41 |     for root, sl_path in train_paths:
42 |         src_path = os.path.abspath(os.path.join(root, sl_path))
43 |         tgt_path = os.path.abspath(os.path.join(target_train_root, sl_path))
44 |         print("\rlinking", src_path, tgt_path, end="")
45 |         os.symlink(src_path, tgt_path)
46 |
47 |     for root, sl_path in test_paths:
48 |         src_path = os.path.abspath(os.path.join(root, sl_path))
49 |         tgt_path = os.path.abspath(os.path.join(target_test_root, sl_path))
50 |         print("\rlinking", src_path, tgt_path, end="")
51 |         os.symlink(src_path, tgt_path)
52 |
53 |     print()
54 |     print("Done!!")
55 |
56 |
57 | def split_symlink_train_test_dataset(root, t_root):
58 |     target_train_root = t_root + "_train"
59 |     target_test_root = t_root + "_test"
60 |     print(
61 |         f"split data points from {root} to {target_train_root}: train and {target_test_root}: test"
62 |     )
63 |
64 |     if os.path.exists(target_train_root):
65 |         return
66 |     assert not os.path.exists(target_test_root)
67 |
68 |     os.makedirs(target_train_root)
69 |     os.makedirs(target_test_root)
70 |
71 |     all_paths = []
72 |     for path in sorted(os.listdir(root)):
73 |         if not os.path.isfile(os.path.join(root, path)):
74 |             d = path
75 |             if (
76 |                 d.endswith("failed")
77 |                 or d.endswith("ood")
78 |                 or d.endswith("ikbad")
79 |                 or d.endswith("heated")
80 |                 or d.endswith("stop")
81 |                 or d.endswith("hard")
82 |             ):
83 |                 continue
84 |             all_paths.append(path)
85 |
86 |     train_paths = all_paths[:-1]
87 |     test_paths = all_paths[-1:]
88 |
89 |     for sl_path in train_paths:
90 |         src_path = os.path.abspath(os.path.join(root, sl_path))
91 |         tgt_path = os.path.abspath(os.path.join(target_train_root, sl_path))
92 |         print("\rlinking", src_path, tgt_path, end="")
93 |         os.symlink(src_path, tgt_path)
94 |     for sl_path in test_paths:
95 |         src_path = os.path.abspath(os.path.join(root, sl_path))
96 |         tgt_path = os.path.abspath(os.path.join(target_test_root, sl_path))
97 |         print("\rlinking", src_path, tgt_path, end="")
98 |         os.symlink(src_path, tgt_path)
99 |     print()
100 |     print("Done!!")
101 |
102 |
103 | def split_symlink_dataset(root, num_trajs):
104 |     target_root = root + f"_{num_trajs}"
105 |     print(f"split {num_trajs} data points from {root} to {target_root}")
106 |
107 |     if os.path.exists(target_root):
108 |         return
109 |     assert not os.path.exists(target_root)
110 |
111 |     os.makedirs(target_root)
112 |
113 |     all_paths = []
114 |     for path in sorted(os.listdir(root)):  # sorted so the seeded shuffle is reproducible
115 |         if not os.path.isfile(os.path.join(root, path)):
116 |             all_paths.append(path)
117 |
118 |     random.seed(0)
119 |     random.shuffle(all_paths)
120 |
121 |     sl_paths = all_paths[:num_trajs]
122 |
123 |     assert len(all_paths) >= num_trajs
124 |     for sl_path in sl_paths:
125 |         src_path = os.path.abspath(os.path.join(root, sl_path))
126 |         tgt_path = os.path.abspath(os.path.join(target_root, sl_path))
127 |         print("\rlinking", src_path, tgt_path, end="")
128 |         os.symlink(src_path, tgt_path)
129 |     print()
130 |     print("Done!!")
131 |
132 |
133 | if __name__ == "__main__":
134 |     arg = argparse.ArgumentParser()
135 |     arg.add_argument("--base_path", type=str, default="/hato/")
136 |     arg.add_argument("--output_path", type=str, default="/split_data")
137 |     arg.add_argument(
138 |         "--data_name",
139 |         nargs="+",
140 |         type=str,
141 |         default=[
142 |             "data_banana",
143 |         ],
144 |     )
145 |     arg.add_argument("--num_trajs", nargs="+", type=int, default=[10, 25, 50, 75])
146 |     arg.add_argument("--merge", action="store_true")
147 |     arg.add_argument("--merge_name", type=str, default="data_banana_all")
148 |     arg.add_argument("--num_traj_per_task", type=int, default=20)
149 |     args = arg.parse_args()
150 |
151 |     if not args.merge:
152 |         for data_name in args.data_name:
153 |             split_symlink_train_test_dataset(
154 |                 os.path.join(args.base_path, data_name),
155 |                 os.path.join(args.output_path, data_name),
156 |             )
157 |             for num_trajs in args.num_trajs:
158 |                 split_symlink_dataset(
159 |                     os.path.join(args.output_path, data_name) + "_train", num_trajs
160 |                 )
161 |
162 |     else:
163 |         data_name_list = [
164 |             os.path.join(args.base_path, data_name) for data_name in args.data_name
165 |         ]
166 |         split_symlink_train_test_merge_dataset(
167 |             data_name_list,
168 |             os.path.join(args.output_path, args.merge_name),
169 |             args.num_traj_per_task,
170 |         )
171 |
--------------------------------------------------------------------------------
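The train/test helpers in split_data.py drop trajectory directories whose names end in one of six tags, then hold out the last remaining directory (in sorted order) as the single test trajectory. The selection logic can be sketched without touching the filesystem, as below. `split_train_test` is a hypothetical helper, and passing a tuple to `str.endswith` is simply a compact way to write the same chain of `or` checks.

```python
# Name suffixes that split_data.py excludes from both splits.
SKIP_SUFFIXES = ("failed", "ood", "ikbad", "heated", "stop", "hard")


def split_train_test(names):
    """Filter out tagged trajectories, then hold out the last one for test."""
    kept = [name for name in sorted(names) if not name.endswith(SKIP_SUFFIXES)]
    return kept[:-1], kept[-1:]


# Made-up directory names for illustration.
train, test = split_train_test(["t3", "t1", "t2_failed", "t4_ood"])
```

With slicing, an empty input yields two empty lists rather than an error, which mirrors how the original `[:-1]` and `[-1:]` slices behave.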