├── frankateach ├── __init__.py ├── configs │ ├── osc-pose-controller.yml │ ├── deoxys_left.yml │ └── deoxys_right.yml ├── constants.py ├── camera_server.py ├── oculus_stick.py ├── sensors │ ├── reskin.py │ ├── fisheye_cam.py │ └── realsense.py ├── utils.py ├── messages.py ├── franka_server.py ├── teleoperator.py ├── data_collector.py └── network.py ├── .gitignore ├── imgs ├── fci.png ├── foxy_proxy.png └── unlock_joints.png ├── configs ├── franka_server.yaml ├── reskin.yaml ├── teleop.yaml ├── collect_data.yaml └── camera.yaml ├── requirements.txt ├── franka-env ├── setup.py └── franka_env │ ├── envs │ ├── __init__.py │ ├── franka_env_relative.py │ └── franka_env.py │ └── __init__.py ├── setup.py ├── franka_server.py ├── reskin_server.py ├── .pre-commit-config.yaml ├── collect_data.py ├── camera_server.py ├── teleop.py ├── views_test.py └── README.md /frankateach/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | __pycache__/ 3 | logs/ 4 | extracted_data*/ 5 | outputs/ 6 | -------------------------------------------------------------------------------- /imgs/fci.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYU-robot-learning/Franka-Teach/HEAD/imgs/fci.png -------------------------------------------------------------------------------- /configs/franka_server.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | 4 | deoxys_config_path: "deoxys_right.yml" -------------------------------------------------------------------------------- /configs/reskin.yaml: -------------------------------------------------------------------------------- 1 | reskin_config: 2 | port: "/dev/ttyACM0" 3 | num_mags: 10 4 | history: 40 5 | -------------------------------------------------------------------------------- /imgs/foxy_proxy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYU-robot-learning/Franka-Teach/HEAD/imgs/foxy_proxy.png -------------------------------------------------------------------------------- /imgs/unlock_joints.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NYU-robot-learning/Franka-Teach/HEAD/imgs/unlock_joints.png -------------------------------------------------------------------------------- /configs/teleop.yaml: -------------------------------------------------------------------------------- 1 | init_gripper_state: open # open, close 2 | teleop_mode: robot # robot, human 3 | home_offset: null 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyzmq 2 | scipy 3 | opencv-python 4 | blosc 5 | hydra-core 6 | pyrealsense2 7 | h5py 8 | pre-commit 9 | -------------------------------------------------------------------------------- /franka-env/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="franka_env", 5 | version="0.0.1", 6 | packages=["franka_env"], 7 | install_requires=["gym"], 8 | ) 9 | 
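With both packages installed (e.g. `pip install -e .` at the repo root and inside `franka-env/`), the environments registered in `franka-env/franka_env/__init__.py` can be created through the standard Gym API. A minimal sketch, not part of the repo, assuming the robot-side servers are already running:

```python
import gym
import franka_env  # noqa: F401  # side effect: registers Franka-v1 / FrankaRelative-v1

env = gym.make("Franka-v1")  # "FrankaRelative-v1" for relative (delta) actions
obs = env.reset()
```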
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="frankateach", 5 | version="0.0.1", 6 | packages=["frankateach"], 7 | install_requires=["gymnasium"], 8 | ) 9 | -------------------------------------------------------------------------------- /franka-env/franka_env/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from franka_env.envs.franka_env import FrankaEnv 2 | from franka_env.envs.franka_env_relative import FrankaEnvRelative 3 | 4 | __all__ = ["FrankaEnv", "FrankaEnvRelative"] 5 | -------------------------------------------------------------------------------- /configs/collect_data.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - camera 4 | 5 | storage_path: "path/to/data/extracted_data/task_name" 6 | demo_num: 0 7 | collect_img: True 8 | collect_depth: False 9 | collect_state: True 10 | collect_reskin: False 11 | -------------------------------------------------------------------------------- /franka_server.py: -------------------------------------------------------------------------------- 1 | from frankateach.franka_server import FrankaServer 2 | import hydra 3 | 4 | @hydra.main(version_base="1.2", config_path="configs", config_name="franka_server") 5 | def main(cfg): 6 | fs = FrankaServer(cfg.deoxys_config_path) 7 | fs.init_server() 8 | 9 | 10 | if __name__ == "__main__": 11 | main() -------------------------------------------------------------------------------- /franka-env/franka_env/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id="Franka-v1", 5 | entry_point="franka_env.envs:FrankaEnv", 6 | max_episode_steps=400, 7 | ) 8 | 9 | register( 10 | id="FrankaRelative-v1", 11 | entry_point="franka_env.envs:FrankaEnvRelative", 12 | max_episode_steps=400, 13 | ) 14 | -------------------------------------------------------------------------------- /reskin_server.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from frankateach.sensors.reskin import ReskinSensorPublisher 3 | 4 | 5 | @hydra.main(version_base="1.2", config_path="configs", config_name="reskin") 6 | def main(cfg): 7 | reskin_publisher = ReskinSensorPublisher(reskin_config=cfg.reskin_config) 8 | reskin_publisher.stream() 9 | 10 | 11 | if __name__ == "__main__": 12 | main() 13 | -------------------------------------------------------------------------------- /configs/camera.yaml: -------------------------------------------------------------------------------- 1 | cam_info: 2 | - 3 | cam_id: 1 4 | cam_serial_num: "147322072736" 5 | type: realsense 6 | - 7 | cam_id: 2 8 | cam_serial_num: "028522071213" 9 | type: realsense 10 | 11 | cam_config: 12 | realsense: 13 | width: 640 14 | height: 480 15 | fps: 30 16 | processing_preset: 1 17 | depth: True 18 | fisheye: 19 | width: 640 20 | height: 480 21 | fps: 30 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: 
trailing-whitespace 8 | - repo: https://github.com/astral-sh/ruff-pre-commit 9 | # Ruff version. 10 | rev: v0.6.8 11 | hooks: 12 | # Run the linter. 13 | - id: ruff 14 | args: [ --fix ] 15 | # Run the formatter. 16 | - id: ruff-format 17 | -------------------------------------------------------------------------------- /frankateach/configs/osc-pose-controller.yml: -------------------------------------------------------------------------------- 1 | controller_type: OSC_POSE 2 | 3 | is_delta: true 4 | 5 | traj_interpolator_cfg: 6 | traj_interpolator_type: LINEAR_POSE 7 | time_fraction: 0.3 8 | 9 | Kp: 10 | # translation: 350.0 11 | # rotation: 300.0 12 | translation: 150.0 13 | rotation: 250.0 14 | 15 | action_scale: 16 | translation: 1.0 # 0.05 17 | rotation: 1.0 18 | 19 | residual_mass_vec: [0.0, 0.0, 0.0, 0.0, 0.1, 0.5, 0.5] 20 | 21 | state_estimator_cfg: 22 | is_estimation: false 23 | state_estimator_type: EXPONENTIAL_SMOOTHING 24 | alpha_q: 0.9 25 | alpha_dq: 0.9 26 | alpha_eef: 1.0 27 | alpha_eef_vel: 1.0 28 | -------------------------------------------------------------------------------- /collect_data.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from frankateach.data_collector import DataCollector 3 | 4 | 5 | @hydra.main(config_path="configs", config_name="collect_data", version_base="1.2") 6 | def main(cfg): 7 | data_collector = DataCollector( 8 | storage_path=cfg.storage_path, 9 | demo_num=cfg.demo_num, 10 | cams=cfg.cam_info, 11 | cam_config=cfg.cam_config, 12 | collect_img=cfg.collect_img, 13 | collect_depth=cfg.collect_depth, 14 | collect_state=cfg.collect_state, 15 | collect_reskin=cfg.collect_reskin, 16 | ) 17 | data_collector.start() 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /camera_server.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from frankateach.camera_server import CameraServer 3 | from frankateach.constants import HOST, CAM_PORT 4 | import hydra 5 | import argparse 6 | 7 | 8 | @hydra.main(version_base="1.2", config_path="configs", config_name="camera") 9 | def main(cfg): 10 | cam_configs = defaultdict(list) 11 | for camera in cfg.cam_info: 12 | cam_config = argparse.Namespace(**camera, **cfg.cam_config[camera.type]) 13 | cam_configs[camera.type].append(cam_config) 14 | 15 | camera_server = CameraServer( 16 | host=HOST, 17 | cam_port=CAM_PORT, 18 | cam_configs=cam_configs, 19 | ) 20 | 21 | camera_server._init_camera_threads() 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /frankateach/configs/deoxys_left.yml: -------------------------------------------------------------------------------- 1 | PC: 2 | NAME: "lambda" 3 | IP: 192.16.1.2 4 | 5 | NUC: 6 | NAME: "nuc" 7 | IP: 192.16.1.3 8 | PUB_PORT: 5566 9 | SUB_PORT: 5565 10 | GRIPPER_PUB_PORT: 5568 11 | GRIPPER_SUB_PORT: 5567 12 | 13 | ROBOT: 14 | IP: 172.16.1.4 15 | 16 | CONTROL: 17 | STATE_PUBLISHER_RATE: 100 18 | POLICY_RATE: 20 19 | TRAJ_RATE: 500 20 | ZMQ_NOBLOCK: true 21 | 22 | ARM_LOGGER: 23 | CONSOLE: 24 | LOGGER_NAME: "arm_logger" 25 | LEVEL: "info" 26 | USE: true 27 | FILE: 28 | LOGGER_NAME: "logs/deoxys_control_arm_program.log" 29 | LEVEL: "debug" 30 | USE: true 31 | 32 | GRIPPER_LOGGER: 33 | CONSOLE: 34 | LOGGER_NAME: "gripper_logger" 35 | LEVEL: "info" 36 | USE: true 37 | FILE: 38 
| LOGGER_NAME: "logs/deoxys_control_gripper_program.log" 39 | LEVEL: "debug" 40 | USE: true 41 | -------------------------------------------------------------------------------- /frankateach/configs/deoxys_right.yml: -------------------------------------------------------------------------------- 1 | PC: 2 | NAME: "lambda" 3 | IP: 192.16.1.2 4 | 5 | NUC: 6 | NAME: "nuc" 7 | IP: 192.16.1.3 8 | PUB_PORT: 5556 9 | SUB_PORT: 5555 10 | GRIPPER_PUB_PORT: 5558 11 | GRIPPER_SUB_PORT: 5557 12 | 13 | ROBOT: 14 | IP: 172.16.0.4 15 | 16 | CONTROL: 17 | STATE_PUBLISHER_RATE: 100 18 | POLICY_RATE: 20 19 | TRAJ_RATE: 500 20 | ZMQ_NOBLOCK: true 21 | 22 | ARM_LOGGER: 23 | CONSOLE: 24 | LOGGER_NAME: "arm_logger" 25 | LEVEL: "info" 26 | USE: true 27 | FILE: 28 | LOGGER_NAME: "logs/deoxys_control_arm_program.log" 29 | LEVEL: "debug" 30 | USE: true 31 | 32 | GRIPPER_LOGGER: 33 | CONSOLE: 34 | LOGGER_NAME: "gripper_logger" 35 | LEVEL: "info" 36 | USE: true 37 | FILE: 38 | LOGGER_NAME: "logs/deoxys_control_gripper_program.log" 39 | LEVEL: "debug" 40 | USE: true 41 | -------------------------------------------------------------------------------- /frankateach/constants.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Network constants 4 | HOST = "localhost" 5 | CAM_PORT = 10005 6 | VR_CONTROLLER_STATE_PORT = 8889 7 | STATE_PORT = 8900 8 | CONTROL_PORT = 8901 9 | COMMANDED_STATE_PORT = 8902 10 | RESKIN_STREAM_PORT = 12005 11 | 12 | 13 | STATE_TOPIC = "state" 14 | CONTROL_TOPIC = "control" 15 | 16 | # VR constants 17 | VR_TCP_HOST = "10.19.225.15" 18 | VR_TCP_PORT = 5555 19 | VR_CONTROLLER_TOPIC = b"oculus_controller" 20 | 21 | # Robot constants 22 | GRIPPER_OPEN = -1 23 | GRIPPER_CLOSE = 1 24 | H_R_V = np.array([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) 25 | H_R_V_star = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) 26 | x_min, x_max = 0.2, 0.75 27 | y_min, y_max = -0.4, 0.4 28 | z_min, z_max = 0.05, 0.7 # 232, 550 29 | ROBOT_WORKSPACE_MIN = np.array([x_min, y_min, z_min]) 30 | ROBOT_WORKSPACE_MAX = np.array([x_max, y_max, z_max]) 31 | 32 | TRANSLATIONAL_POSE_VELOCITY_SCALE = 5 33 | ROTATIONAL_POSE_VELOCITY_SCALE = 0.75 34 | ROTATION_VELOCITY_LIMIT = 0.5 35 | TRANSLATION_VELOCITY_LIMIT = 1 36 | 37 | # Frequencies 38 | # TODO: Separate VR and deploy frequencies 39 | VR_FREQ = 20 40 | CONTROL_FREQ = 20 41 | STATE_FREQ = 100 42 | CAM_FPS = 30 43 | DEPTH_PORT_OFFSET = 1000 44 | -------------------------------------------------------------------------------- /teleop.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from multiprocessing import Process 3 | from frankateach.teleoperator import FrankaOperator 4 | from frankateach.oculus_stick import OculusVRStickDetector 5 | from frankateach.constants import HOST, VR_CONTROLLER_STATE_PORT 6 | 7 | 8 | def start_teleop(init_gripper_state="open", teleop_mode="robot", home_offset=None): 9 | operator = FrankaOperator( 10 | init_gripper_state=init_gripper_state, 11 | teleop_mode=teleop_mode, 12 | home_offset=home_offset, 13 | ) 14 | operator.stream() 15 | 16 | 17 | def start_oculus_stick(): 18 | detector = OculusVRStickDetector(HOST, VR_CONTROLLER_STATE_PORT) 19 | detector.stream() 20 | 21 | 22 | @hydra.main(version_base="1.2", config_path="configs", config_name="teleop") 23 | def main(cfg): 24 | teleop_process = Process( 25 | target=start_teleop, 26 | args=( 27 | cfg.init_gripper_state, 28 | cfg.teleop_mode, 29 | 
cfg.home_offset,
30 |         ),
31 |     )
32 |     oculus_stick_process = Process(target=start_oculus_stick)
33 | 
34 |     teleop_process.start()
35 |     oculus_stick_process.start()
36 | 
37 |     teleop_process.join()
38 |     oculus_stick_process.join()
39 | 
40 | 
41 | if __name__ == "__main__":
42 |     main()
43 | 
--------------------------------------------------------------------------------
/franka-env/franka_env/envs/franka_env_relative.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.transform import Rotation as R
3 | 
4 | from franka_env.envs.franka_env import FrankaEnv
5 | 
6 | 
7 | class FrankaEnvRelative(FrankaEnv):
8 |     def step(self, rel_action):
9 |         current_state = self.franka_state
10 | 
11 |         pos_curr = current_state.pos
12 |         ori_curr = current_state.quat
13 |         r_curr = R.from_quat(ori_curr).as_matrix()
14 |         matrix_curr = np.eye(4)
15 |         matrix_curr[:3, :3] = r_curr
16 |         matrix_curr[:3, 3] = pos_curr
17 | 
18 |         # find transformation matrix
19 |         pos_delta = rel_action[:3]
20 |         ori_delta = rel_action[3:6]
21 |         r_delta = R.from_rotvec(ori_delta).as_matrix()
22 |         matrix_delta = np.eye(4)
23 |         matrix_delta[:3, :3] = r_delta
24 |         matrix_delta[:3, 3] = pos_delta
25 | 
26 |         # find desired matrix
27 |         matrix_desired = matrix_curr @ matrix_delta
28 | 
29 |         # NOTE: the translation is applied in the world frame (pos_curr +
30 |         # pos_delta), while the rotation is composed in the tool frame via
31 |         # matrix_desired, whose translation column is left unused
32 |         pos_desired = pos_curr + pos_delta
33 |         r_desired = matrix_desired[:3, :3]
34 |         ori_desired = R.from_matrix(r_desired).as_quat()
35 |         desired_cartesian_pose = np.concatenate([pos_desired, ori_desired])
36 | 
37 |         action = np.concatenate([desired_cartesian_pose, rel_action[6:]])
38 |         return super().step(action)
39 | 
--------------------------------------------------------------------------------
/views_test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import shutil
3 | from pathlib import Path
4 | 
5 | 
6 | def test_camera_index(index):
7 |     cap = cv2.VideoCapture(index)
8 |     cap.set(cv2.CAP_PROP_FRAME_WIDTH, 680)
9 |     cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
10 |     if not cap.isOpened():
11 |         print(f"Failed to open camera at index {index}")
12 |         return
13 | 
14 |     idx = 0
15 |     while True:
16 |         ret, frame = cap.read()
17 |         if not ret:
18 |             print(f"Failed to capture frame from camera at index {index}")
19 |             break
20 | 
21 |         idx += 1
22 |         if idx == 10:
23 |             print(
24 |                 "Frame default resolution: ("
25 |                 + str(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
26 |                 + "; "
27 |                 + str(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
28 |                 + ")"
29 |             )
30 |             cv2.imwrite(f"views_test/camera_{index}.jpg", frame)
31 |             break
32 | 
33 |     cap.release()
34 | 
35 | 
36 | # Test camera indices from 0 to 99
37 | if __name__ == "__main__":
38 |     # make directory to save images
39 |     dir_path = Path("./views_test")
40 |     if dir_path.exists():
41 |         shutil.rmtree(dir_path)
42 |     dir_path.mkdir(parents=True, exist_ok=True)
43 | 
44 |     for camera_id in range(100):
45 |         test_camera_index(camera_id)
46 | 
--------------------------------------------------------------------------------
/frankateach/camera_server.py:
--------------------------------------------------------------------------------
1 | import pyrealsense2 as rs
2 | import time
3 | import threading
4 | 
5 | from frankateach.sensors.realsense import RealsenseCamera
6 | from frankateach.sensors.fisheye_cam import FishEyeCamera
7 | 
8 | 
9 | class CameraServer:
10 |     def __init__(self, host: str, cam_port: int, cam_configs: dict):
11 |         self._host = host
12 |         self._cam_port = cam_port
13 |         self._cam_configs = cam_configs
14 |         self._cam_threads = []
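        # NOTE: hardware_reset() below power-cycles every attached RealSense;
        # the devices briefly drop off the USB bus, so the sleep that follows
        # gives them time to re-enumerate before the pipelines are opened.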
15 | 
16 |         if "realsense" in cam_configs.keys():
17 |             ctx = rs.context()
18 |             devices = ctx.query_devices()
19 | 
20 |             for dev in devices:
21 |                 dev.hardware_reset()
22 | 
23 |             print("Waiting for hardware reset on cameras for 10 seconds...")
24 |             time.sleep(10)
25 | 
26 |     def _start_component(self, cam_idx, cam_config):
27 |         cam_type = cam_config.type
28 |         if cam_type == "realsense":
29 |             component = RealsenseCamera(
30 |                 host=self._host,
31 |                 port=self._cam_port + cam_idx,
32 |                 cam_id=cam_idx,
33 |                 cam_config=cam_config,
34 |             )
35 |         elif cam_type == "fisheye":
36 |             component = FishEyeCamera(
37 |                 host=self._host,
38 |                 port=self._cam_port + cam_idx,
39 |                 cam_id=cam_idx,
40 |                 cam_config=cam_config,
41 |             )
42 |         else:
43 |             raise ValueError(f"Invalid camera type: {cam_type}")
44 |         component.stream()
45 | 
46 |     def _init_camera_threads(self):
47 |         for cam_type in self._cam_configs:
48 |             for cam_cfg in self._cam_configs[cam_type]:
49 |                 cam_thread = threading.Thread(
50 |                     target=self._start_component,
51 |                     args=(cam_cfg.cam_id, cam_cfg),
52 |                     daemon=True,
53 |                 )
54 |                 cam_thread.start()
55 |                 self._cam_threads.append(cam_thread)
56 | 
57 |         for cam_thread in self._cam_threads:
58 |             cam_thread.join()
59 | 
--------------------------------------------------------------------------------
/frankateach/oculus_stick.py:
--------------------------------------------------------------------------------
1 | from frankateach.constants import (
2 |     VR_CONTROLLER_STATE_PORT,
3 |     VR_FREQ,
4 |     VR_TCP_HOST,
5 |     VR_TCP_PORT,
6 | )
7 | from frankateach.utils import FrequencyTimer
8 | from frankateach.network import create_subscriber_socket, ZMQKeypointPublisher
9 | from frankateach.utils import parse_controller_state
10 | 
11 | from frankateach.utils import notify_component_start
12 | 
13 | 
14 | # This class reads the VR controller state from the headset and republishes it.
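# It subscribes to the raw stream the headset publishes at VR_TCP_HOST:VR_TCP_PORT,
# parses each message into a ControllerState, and republishes it locally on the
# "controller_state" topic at VR_FREQ.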
15 | class OculusVRStickDetector: 16 | def __init__(self, host, controller_state_pub_port): 17 | notify_component_start("vr detector") 18 | 19 | # Create a subscriber socket 20 | self.stick_socket = create_subscriber_socket( 21 | VR_TCP_HOST, VR_TCP_PORT, b"", conflate=True 22 | ) # bytes(VR_CONTROLLER_TOPIC, 'utf-8')) 23 | 24 | # Create a publisher for the controller state 25 | self.controller_state_publisher = ZMQKeypointPublisher( 26 | host=host, port=controller_state_pub_port 27 | ) 28 | self.timer = FrequencyTimer(VR_FREQ) 29 | 30 | # Function to Publish the message 31 | def _publish_controller_state(self, controller_state): 32 | self.controller_state_publisher.pub_keypoints( 33 | keypoint_array=controller_state, topic_name="controller_state" 34 | ) 35 | 36 | # Function to publish the left/right hand keypoints and button Feedback 37 | def stream(self): 38 | print("oculus stick stream") 39 | while True: 40 | try: 41 | self.timer.start_loop() 42 | 43 | message = self.stick_socket.recv_string() 44 | if message == "oculus_controller": 45 | continue 46 | 47 | controller_state = parse_controller_state(message) 48 | 49 | # Publish message 50 | self._publish_controller_state(controller_state) 51 | 52 | self.timer.end_loop() 53 | 54 | except KeyboardInterrupt: 55 | break 56 | 57 | self.controller_state_publisher.stop() 58 | print("Stopping the oculus keypoint extraction process.") 59 | 60 | 61 | def main(): 62 | detector = OculusVRStickDetector("localhost", VR_CONTROLLER_STATE_PORT) 63 | detector.stream() 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /frankateach/sensors/reskin.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import time 3 | from frankateach.constants import HOST, RESKIN_STREAM_PORT 4 | from frankateach.network import ZMQKeypointPublisher, ZMQKeypointSubscriber 5 | from frankateach.utils import FrequencyTimer, notify_component_start 6 | 7 | from reskin_sensor import ReSkinProcess 8 | 9 | 10 | class ReskinSensorPublisher: 11 | def __init__(self, reskin_config): 12 | self.reskin_publisher = ZMQKeypointPublisher(HOST, RESKIN_STREAM_PORT) 13 | 14 | self.timer = FrequencyTimer(100) 15 | self.reskin_config = reskin_config 16 | if reskin_config.history is not None: 17 | self.history = deque(maxlen=reskin_config.history) 18 | else: 19 | self.history = deque(maxlen=1) 20 | self._start_reskin() 21 | 22 | def _start_reskin(self): 23 | self.sensor_proc = ReSkinProcess( 24 | num_mags=self.reskin_config["num_mags"], 25 | port=self.reskin_config["port"], 26 | baudrate=100000, 27 | burst_mode=True, 28 | device_id=0, 29 | temp_filtered=True, 30 | reskin_data_struct=True, 31 | ) 32 | self.sensor_proc.start() 33 | time.sleep(0.5) 34 | 35 | def stream(self): 36 | notify_component_start("Reskin sensors") 37 | 38 | while True: 39 | try: 40 | self.timer.start_loop() 41 | reskin_state = self.sensor_proc.get_data(1)[0] 42 | data_dict = {} 43 | data_dict["timestamp"] = reskin_state.time 44 | data_dict["sensor_values"] = reskin_state.data 45 | self.history.append(reskin_state.data) 46 | data_dict["sensor_history"] = list(self.history) 47 | self.reskin_publisher.pub_keypoints(data_dict, topic_name="reskin") 48 | self.timer.end_loop() 49 | 50 | except KeyboardInterrupt: 51 | break 52 | 53 | 54 | class ReskinSensorSubscriber: 55 | def __init__(self): 56 | self.reskin_subscriber = ZMQKeypointSubscriber( 57 | HOST, RESKIN_STREAM_PORT, 
topic="reskin" 58 | ) 59 | 60 | def __repr__(self): 61 | return "reskin" 62 | 63 | def get_sensor_state(self): 64 | reskin_state = self.reskin_subscriber.recv_keypoints() 65 | return reskin_state 66 | -------------------------------------------------------------------------------- /frankateach/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from typing import Tuple 4 | 5 | from frankateach.messages import ControllerState 6 | 7 | 8 | def notify_component_start(component_name): 9 | print("***************************************************************") 10 | print(" Starting {} component".format(component_name)) 11 | print("***************************************************************") 12 | 13 | 14 | class FrequencyTimer(object): 15 | def __init__(self, frequency_rate): 16 | self.time_available = 1e9 / frequency_rate 17 | 18 | def start_loop(self): 19 | self.start_time = time.time_ns() 20 | 21 | def check_time(self, frequency_rate): 22 | # if prev_check_time variable doesn't exist, create it 23 | if not hasattr(self, "prev_check_time"): 24 | self.prev_check_time = self.start_time 25 | 26 | curr_time = time.time_ns() 27 | if (curr_time - self.prev_check_time) > 1e9 / frequency_rate: 28 | self.prev_check_time = curr_time 29 | return True 30 | return False 31 | 32 | def end_loop(self): 33 | wait_time = self.time_available + self.start_time 34 | 35 | while time.time_ns() < wait_time: 36 | continue 37 | 38 | 39 | def parse_controller_state(controller_state_string: str) -> ControllerState: 40 | 41 | left_data, right_data = controller_state_string.split("|") 42 | 43 | left_data = left_data.split(";")[1:-1] 44 | right_data = right_data.split(";")[1:-1] 45 | 46 | def parse_bool(val: str) -> bool: 47 | return val.split(":")[1].lower().strip() == "true" 48 | 49 | def parse_float(val: str) -> float: 50 | return float(val.split(":")[1]) 51 | 52 | def parse_list_float(val: str) -> np.ndarray: 53 | return np.array(list(map(float, val.split(":")[1].split(",")))) 54 | 55 | def parse_section(data: list) -> Tuple: 56 | return ( 57 | # Buttons 58 | parse_bool(data[0]), 59 | parse_bool(data[1]), 60 | parse_bool(data[2]), 61 | parse_bool(data[3]), 62 | # Triggers 63 | parse_float(data[4]), 64 | parse_float(data[5]), 65 | # Thumbstick 66 | parse_list_float(data[6]), 67 | # Pose 68 | parse_list_float(data[7]), 69 | parse_list_float(data[8]), 70 | ) 71 | 72 | left_parsed = parse_section(left_data) 73 | right_parsed = parse_section(right_data) 74 | 75 | return ControllerState(time.time(), *left_parsed, *right_parsed) 76 | -------------------------------------------------------------------------------- /frankateach/sensors/fisheye_cam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from frankateach.constants import CAM_FPS 4 | from frankateach.utils import FrequencyTimer, notify_component_start 5 | from frankateach.network import ZMQCameraPublisher 6 | 7 | import cv2 8 | import time 9 | 10 | 11 | class FishEyeCamera: 12 | def __init__( 13 | self, 14 | host, 15 | port, 16 | cam_id, 17 | cam_config, 18 | ): 19 | # Disabling scientific notations 20 | np.set_printoptions(suppress=True) 21 | self.cam_id = cam_id 22 | self.cam_config = cam_config 23 | self._cam_serial_num = cam_config.cam_serial_num 24 | 25 | # Different publishers to avoid overload 26 | self.rgb_publisher = ZMQCameraPublisher(host, port) 27 | self.timer = FrequencyTimer(CAM_FPS) # 30 fps 28 | 29 | # 
Starting the Fisheye pipeline
30 |         self._start_fisheye()
31 | 
32 |     def _start_fisheye(self):
33 |         self.cap = cv2.VideoCapture(self._cam_serial_num)
34 |         # self.cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
35 | 
36 |         self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 680)
37 |         self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
38 | 
39 |         # Wait until the camera is opened successfully
40 |         while not self.cap.isOpened():
41 |             time.sleep(0.01)
42 | 
43 |     def get_rgb_depth_images(self):
44 |         frame = None
45 |         while frame is None:
46 |             _, frame = self.cap.read()
47 |         timestamp = time.time()
48 |         return frame, timestamp
49 | 
50 |     @staticmethod
51 |     def rescale_image(image, rescale_factor):
52 |         width, height = (
53 |             int(image.shape[1] / rescale_factor),
54 |             int(image.shape[0] / rescale_factor),
55 |         )
56 |         return cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
57 | 
58 |     def stream(self):
59 |         notify_component_start("fisheye camera")
60 |         print(f"Started the pipeline for FishEye camera: {self.cam_id}!")
61 | 
62 |         try:
63 |             while True:
64 |                 self.timer.start_loop()
65 |                 color_image, timestamp = self.get_rgb_depth_images()
66 | 
67 |                 # Publishing the rgb images
68 |                 self.rgb_publisher.pub_rgb_image(color_image, timestamp)
69 | 
70 |                 self.timer.end_loop()
71 |                 if cv2.waitKey(1) == ord("q"):
72 |                     break
73 |         except KeyboardInterrupt:
74 |             pass
75 |         finally:
76 |             self.cap.release()
77 |             print("Shutting down pipeline for camera {}.".format(self.cam_id))
78 | 
--------------------------------------------------------------------------------
/frankateach/messages.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import numpy as np
3 | from typing import Tuple
4 | from scipy.spatial.transform import Rotation as R
5 | 
6 | 
7 | @dataclass
8 | class FrankaState:
9 |     pos: np.ndarray
10 |     quat: np.ndarray
11 |     gripper: np.ndarray
12 |     timestamp: float
13 |     start_teleop: bool = False
14 | 
15 | 
16 | @dataclass
17 | class FrankaAction:
18 |     pos: np.ndarray
19 |     quat: np.ndarray
20 |     gripper: np.ndarray
21 |     reset: bool
22 |     timestamp: float
23 | 
24 | 
25 | @dataclass
26 | class ControllerState:
27 |     created_timestamp: float
28 | 
29 |     left_x: bool
30 |     left_y: bool
31 |     left_menu: bool
32 |     left_thumbstick: bool
33 |     left_index_trigger: float
34 |     left_hand_trigger: float
35 |     left_thumbstick_axes: np.ndarray[Tuple[float, float]]
36 |     left_local_position: np.ndarray[Tuple[float, float, float]]
37 |     left_local_rotation: np.ndarray[Tuple[float, float, float, float]]
38 | 
39 |     right_a: bool
40 |     right_b: bool
41 |     right_menu: bool
42 |     right_thumbstick: bool
43 |     right_index_trigger: float
44 |     right_hand_trigger: float
45 |     right_thumbstick_axes: np.ndarray[Tuple[float, float]]
46 |     right_local_position: np.ndarray[Tuple[float, float, float]]
47 |     right_local_rotation: np.ndarray[Tuple[float, float, float, float]]
48 | 
49 |     @property
50 |     def right_position(self) -> np.ndarray:
51 |         return self.right_affine[:3, 3]
52 | 
53 |     @property
54 |     def left_position(self) -> np.ndarray:
55 |         return self.left_affine[:3, 3]
56 | 
57 |     @property
58 |     def right_rotation_matrix(self) -> np.ndarray:
59 |         return self.right_affine[:3, :3]
60 | 
61 |     @property
62 |     def left_rotation_matrix(self) -> np.ndarray:
63 |         return self.left_affine[:3, :3]
64 | 
65 |     @property
66 |     def left_affine(self) -> np.ndarray:
67 |         return self.get_affine(self.left_local_position, self.left_local_rotation)
68 | 
69 |     @property
70 |     def right_affine(self) -> 
np.ndarray:
71 |         return self.get_affine(self.right_local_position, self.right_local_rotation)
72 | 
73 |     def get_affine(
74 |         self, controller_position: np.ndarray, controller_rotation: np.ndarray
75 |     ):
76 |         """Returns a 4x4 affine matrix from the controller's position and rotation.
77 |         Args:
78 |             controller_position: 3D position of the controller.
79 |             controller_rotation: 4D quaternion of the controller's rotation.
80 | 
81 |         All in headset space.
82 |         """
83 | 
84 |         return np.block(
85 |             [
86 |                 [
87 |                     R.as_matrix(R.from_quat(controller_rotation)),
88 |                     controller_position[:, np.newaxis],
89 |                 ],
90 |                 [np.zeros((1, 3)), 1.0],
91 |             ]
92 |         )
93 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Franka-Teach
2 | 
3 | Bi-manual Franka Research 3 robot setup.
4 | 
5 | 
6 | ## NUC Setup
7 | 
8 | 1. Install Ubuntu 22.04 and a real-time kernel.
9 | 2. Make sure the NUC is booted with the real-time kernel [[link](https://frankaemika.github.io/docs/installation_linux.html#setting-up-the-real-time-kernel)].
10 | 
11 | 
12 | ## Lambda Machine Setup
13 | 
14 | TODO: how to set up the network, etc.
15 | 
16 | 
17 | 1. Set up deoxys_control. NOTE: When doing `./InstallPackage`, select `0.13.3` for installing libfranka:
18 | 
19 | ```bash
20 | git clone git@github.com:NYU-robot-learning/deoxys_control.git
21 | mamba create -n "franka_teach" python=3.10
22 | conda activate franka_teach
23 | cd deoxys_control/deoxys
24 | 
25 | # Instructions from deoxys repo (this takes a while to build everything)
26 | ./InstallPackage
27 | make -j build_deoxys=1
28 | pip install -U -r requirements.txt
29 | ```
30 | 
31 | 2. Install the Franka-Teach requirements:
32 | 
33 | ```bash
34 | cd /path/to/Franka-Teach
35 | pip install -r requirements.txt
36 | ```
37 | 
38 | 3. Install the ReSkin sensor library:
39 | 
40 | ```bash
41 | git clone git@github.com:NYU-robot-learning/reskin_sensor.git
42 | cd reskin_sensor
43 | pip install -e .
44 | ```
45 | 
46 | 
47 | ## Proxy Setup
48 | 
49 | 1. Install the FoxyProxy extension on Chrome or Firefox. Set up the proxy like this:
50 | 
51 | ![Foxy Proxy](./imgs/foxy_proxy.png)
52 | 
53 | 2. Set up the NUC as an SSH host like this:
54 | 
55 | ```bash
56 | Host nuc
57 |     HostName 10.19.248.70
58 |     User robot-lab
59 |     LogLevel ERROR
60 |     DynamicForward 1337
61 | ```
62 | 
63 | 
64 | ## How to run the Franka-Teach environment
65 | 
66 | 1. SSH into the NUC:
67 | 
68 | ```bash
69 | ssh nuc
70 | ```
71 | 
72 | 2. Go to `172.16.0.4/desk` for the Franka Desk interface for the right robot and `172.16.1.4/desk` for the left robot.
73 | 
74 | The following credentials are used for the Franka Desk interface:
75 | 
76 | ```
77 | Username: GRAIL
78 | Password: grail1234
79 | ```
80 | 
81 | NOTE: Franka Desk is cursed. You might face all sorts of issues with it. General troubleshooting:
82 | 
83 | - If it doesn't seem to connect, keep refreshing. If you lose all hope, reboot the robot.
84 | - If the end-effector doesn't show as active, go to the settings page and do a power off/on for the end-effector. You need to re-initialize the gripper on the same page after doing a power cycle.
85 | - Two Desk pages (for two robots) cannot be open at the same time. Close one tab and connect to the other one.
86 | 
87 | 3. Unlock the robot's joints (open the brakes):
88 | 
89 | ![open_brakes](./imgs/unlock_joints.png)
90 | 
91 | 4. Enable FCI mode:
92 | 
93 | ![fci](./imgs/fci.png)
94 | 
95 | 5. Start the deoxys control process on the NUC:
96 | 
97 | ```bash
98 | cd /home/robot-lab/work/deoxys_control/deoxys
99 | ./auto_scripts/auto_arm.sh config/franka_left.yml # franka_right.yml for the right robot
100 | ./auto_scripts/auto_gripper.sh config/franka_left.yml # in a different terminal, if you want to use the gripper
101 | ```
102 | 
103 | 6. From the Lambda, start the servers:
104 | 
105 | ```bash
106 | cd /path/to/Franka-Teach/
107 | python3 franka_server.py
108 | python3 camera_server.py # in a different terminal
109 | ```
110 | 
111 | 7. TODO: run the franka_env test script:
112 | 
113 | ```bash
114 | cd /path/to/Franka-Teach/
115 | python3 test_franka_env.py
116 | ```
117 | 
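Once `franka_server.py` is running, you can sanity-check the connection with a short script along these lines (a minimal sketch, not part of the repo; it assumes the server runs on the same machine and follows the `get_state` request/reply protocol from `frankateach/franka_server.py`):

```python
import pickle

from frankateach.constants import HOST, CONTROL_PORT
from frankateach.network import create_request_socket

socket = create_request_socket(HOST, CONTROL_PORT)
socket.send(b"get_state")  # the server replies with a pickled FrankaState
reply = socket.recv()
if reply == b"state_error":
    print("Robot state not available yet")
else:
    state = pickle.loads(reply)
    print("pos:", state.pos, "quat:", state.quat, "gripper:", state.gripper)
socket.close()
```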
118 | ## How to teleoperate
119 | 
120 | 1. Follow steps 1-6 in the "How to run the Franka-Teach environment" section.
121 | 
122 | 
123 | 2. Start the teleoperation script. Set the teleop mode depending on whether you are collecting robot or human demonstrations:
124 | 
125 | ```bash
126 | python3 teleop.py teleop_mode=<robot|human>
127 | ```
128 | 
129 | 3. Start the data collection by running the `collect_data.py` script. Set `demo_num` to the index of the demonstration you are recording, and set `collect_depth` to `True` if you want to collect depth data from the Intel RealSense cameras.
130 | 
131 | ```bash
132 | python3 collect_data.py demo_num=0 collect_depth=<True|False>
133 | ```
134 | 
135 | 4. For robot teleoperation, use the VR controllers to control the robot. When collecting human data, use the VR controller to start and stop the data collection while performing the actions with the human hand.
--------------------------------------------------------------------------------
/frankateach/sensors/realsense.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from frankateach.network import ZMQCameraPublisher
4 | from frankateach.utils import FrequencyTimer, notify_component_start
5 | from frankateach.constants import CAM_FPS, DEPTH_PORT_OFFSET
6 | 
7 | import pyrealsense2 as rs
8 | import time
9 | 
10 | 
11 | class RealsenseCamera:
12 |     def __init__(
13 |         self,
14 |         host,
15 |         port,
16 |         cam_id,
17 |         cam_config,
18 |     ):
19 |         self.cam_id = cam_id
20 |         self.cam_config = cam_config
21 |         self._cam_serial_num = cam_config.cam_serial_num
22 |         self._depth = cam_config.depth
23 | 
24 |         # Different publishers to avoid overload
25 |         self.rgb_publisher = ZMQCameraPublisher(host, port)
26 | 
27 |         if self._depth:
28 |             self.depth_publisher = ZMQCameraPublisher(
29 |                 host, port=port + DEPTH_PORT_OFFSET
30 |             )
31 | 
32 |         self.timer = FrequencyTimer(CAM_FPS)
33 | 
34 |         # Starting the realsense pipeline
35 |         self._start_realsense(self._cam_serial_num)
36 | 
37 |     def _start_realsense(self, cam_serial_num):
38 |         config = rs.config()
39 |         self.pipeline = rs.pipeline()
40 |         config.enable_device(cam_serial_num)
41 | 
42 |         # Enabling camera streams
43 |         config.enable_stream(
44 |             rs.stream.color,
45 |             self.cam_config.width,
46 |             self.cam_config.height,
47 |             rs.format.bgr8,
48 |             self.cam_config.fps,
49 |         )
50 |         if self._depth:
51 |             config.enable_stream(
52 |                 rs.stream.depth,
53 |                 self.cam_config.width,
54 |                 self.cam_config.height,
55 |                 rs.format.z16,
56 |                 self.cam_config.fps,
57 |             )
58 | 
59 |         # Starting the pipeline
60 |         cfg = self.pipeline.start(config)
61 |         device = cfg.get_device()
62 | 
63 |         time.sleep(1)
64 | 
65 |         if self._depth:
66 |             # Setting the depth mode to high accuracy mode
67 |             depth_sensor = device.first_depth_sensor()
68 |             depth_sensor.set_option(
69 |                 rs.option.visual_preset,
self.cam_config.processing_preset 70 | ) 71 | 72 | self.realsense = self.pipeline 73 | 74 | # Obtaining the color intrinsics matrix for aligning the color and depth images 75 | profile = self.pipeline.get_active_profile() 76 | color_profile = rs.video_stream_profile(profile.get_stream(rs.stream.color)) 77 | intrinsics = color_profile.get_intrinsics() 78 | self.intrinsics_matrix = np.array( 79 | [ 80 | [intrinsics.fx, 0, intrinsics.ppx], 81 | [0, intrinsics.fy, intrinsics.ppy], 82 | [0, 0, 1], 83 | ] 84 | ) 85 | 86 | # Align function - aligns other frames with the color frame 87 | self.align = rs.align(rs.stream.color) 88 | 89 | def get_rgb_depth_images(self): 90 | frames = None 91 | 92 | while frames is None: 93 | # Obtaining and aligning the frames 94 | frames = self.realsense.wait_for_frames() 95 | aligned_frames = self.align.process(frames) 96 | 97 | color_frame = aligned_frames.get_color_frame() 98 | color_image = np.asanyarray(color_frame.get_data()) 99 | if self._depth: 100 | depth_frame = aligned_frames.get_depth_frame() 101 | depth_image = np.asanyarray(depth_frame.get_data()) 102 | else: 103 | depth_image = None 104 | 105 | return color_image, depth_image, frames.get_timestamp() 106 | 107 | def stream(self): 108 | # Starting the realsense stream 109 | notify_component_start("realsense") 110 | print(f"Started the Realsense pipeline for camera: {self._cam_serial_num}!") 111 | 112 | try: 113 | while True: 114 | self.timer.start_loop() 115 | color_image, depth_image, timestamp = self.get_rgb_depth_images() 116 | 117 | # color_image = rotate_image(color_image, self.cam_configs.rotation_angle) 118 | # depth_image = rotate_image(depth_image, self.cam_configs.rotation_angle) 119 | 120 | self.rgb_publisher.pub_rgb_image(color_image, timestamp) 121 | if self._depth: 122 | self.depth_publisher.pub_depth_image(depth_image, timestamp) 123 | 124 | self.timer.end_loop() 125 | except KeyboardInterrupt: 126 | pass 127 | finally: 128 | print("Shutting down realsense pipeline for camera {}.".format(self.cam_id)) 129 | self.rgb_publisher.stop() 130 | if self._depth: 131 | self.depth_publisher.stop() 132 | self.pipeline.stop() 133 | -------------------------------------------------------------------------------- /frankateach/franka_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import pickle 4 | import time 5 | import numpy as np 6 | 7 | from deoxys.utils import YamlConfig 8 | from deoxys.franka_interface import FrankaInterface 9 | from deoxys.utils import transform_utils 10 | from deoxys.utils.config_utils import ( 11 | get_default_controller_config, 12 | verify_controller_config, 13 | ) 14 | 15 | from frankateach.utils import notify_component_start 16 | from frankateach.network import create_response_socket 17 | from frankateach.messages import FrankaAction, FrankaState 18 | from frankateach.constants import ( 19 | CONTROL_PORT, 20 | HOST, 21 | CONTROL_FREQ, 22 | ) 23 | 24 | CONFIG_ROOT = Path(__file__).parent / "configs" 25 | 26 | 27 | class FrankaServer: 28 | def __init__(self, cfg): 29 | self._robot = Robot(cfg, CONTROL_FREQ) 30 | # Action REQ/REP 31 | self.action_socket = create_response_socket(HOST, CONTROL_PORT) 32 | 33 | def init_server(self): 34 | # connect to robot 35 | print("Starting Franka server...") 36 | self._robot.reset_robot() 37 | self.control_daemon() 38 | 39 | def get_state(self): 40 | quat, pos = self._robot.last_eef_quat_and_pos 41 | gripper = self._robot.last_gripper_action 42 | if quat 
is not None and pos is not None and gripper is not None: 43 | state = FrankaState( 44 | pos=pos.flatten().astype(np.float32), 45 | quat=quat.flatten().astype(np.float32), 46 | gripper=gripper, 47 | timestamp=time.time(), 48 | ) 49 | return bytes(pickle.dumps(state, protocol=-1)) 50 | else: 51 | return b"state_error" 52 | 53 | def control_daemon(self): 54 | notify_component_start(component_name="Franka Control Subscriber") 55 | try: 56 | while True: 57 | command = self.action_socket.recv() 58 | if command == b"get_state": 59 | self.action_socket.send(self.get_state()) 60 | else: 61 | franka_control: FrankaAction = pickle.loads(command) 62 | if franka_control.reset: 63 | self._robot.reset_joints(gripper_open=franka_control.gripper) 64 | time.sleep(1) 65 | else: 66 | self._robot.osc_move( 67 | franka_control.pos, 68 | franka_control.quat, 69 | franka_control.gripper, 70 | ) 71 | self.action_socket.send(self.get_state()) 72 | except KeyboardInterrupt: 73 | pass 74 | finally: 75 | self._robot.close() 76 | self.action_socket.close() 77 | 78 | 79 | class Robot(FrankaInterface): 80 | def __init__(self, cfg, control_freq): 81 | super(Robot, self).__init__( 82 | general_cfg_file=os.path.join(CONFIG_ROOT, cfg), 83 | use_visualizer=False, 84 | control_freq=control_freq, 85 | ) 86 | self.velocity_controller_cfg = verify_controller_config( 87 | YamlConfig( 88 | os.path.join(CONFIG_ROOT, "osc-pose-controller.yml") 89 | ).as_easydict() 90 | ) 91 | 92 | def reset_robot(self): 93 | self.reset() 94 | 95 | print("Waiting for the robot to connect...") 96 | while len(self._state_buffer) == 0: 97 | time.sleep(0.01) 98 | 99 | print("Franka is connected") 100 | 101 | def osc_move(self, target_pos, target_quat, gripper_state): 102 | num_steps = 3 103 | 104 | for _ in range(num_steps): 105 | target_mat = transform_utils.pose2mat(pose=(target_pos, target_quat)) 106 | 107 | current_quat, current_pos = self.last_eef_quat_and_pos 108 | current_mat = transform_utils.pose2mat( 109 | pose=(current_pos.flatten(), current_quat.flatten()) 110 | ) 111 | 112 | pose_error = transform_utils.get_pose_error( 113 | target_pose=target_mat, current_pose=current_mat 114 | ) 115 | 116 | if np.dot(target_quat, current_quat) < 0.0: 117 | current_quat = -current_quat 118 | 119 | quat_diff = transform_utils.quat_distance(target_quat, current_quat) 120 | axis_angle_diff = transform_utils.quat2axisangle(quat_diff) 121 | 122 | action_pos = pose_error[:3] 123 | action_axis_angle = axis_angle_diff.flatten() 124 | 125 | action = action_pos.tolist() + action_axis_angle.tolist() + [gripper_state] 126 | 127 | self.control( 128 | controller_type="OSC_POSE", 129 | action=action, 130 | controller_cfg=self.velocity_controller_cfg, 131 | ) 132 | 133 | def reset_joints( 134 | self, 135 | timeout=7, 136 | gripper_open=False, 137 | ): 138 | start_joint_pos = [ 139 | 0.09162008114028396, 140 | -0.19826458111314524, 141 | -0.01990020486871322, 142 | -2.4732269941140346, 143 | -0.01307073642274261, 144 | 2.30396583422025, 145 | 0.8480939705504309, 146 | ] 147 | assert type(start_joint_pos) is list or type(start_joint_pos) is np.ndarray 148 | controller_cfg = get_default_controller_config(controller_type="JOINT_POSITION") 149 | 150 | if gripper_open: 151 | gripper_action = -1 152 | else: 153 | gripper_action = 1 154 | 155 | # This is for varying initialization of joints a little bit to 156 | # increase data variation. 
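        # (disabled by default; enabling it perturbs each start joint by at most
        # ±0.005 rad of clipped Gaussian noise, so collected episodes don't all
        # begin from the identical configuration)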
157 | # start_joint_pos = [ 158 | # e + np.clip(np.random.randn() * 0.005, -0.005, 0.005) 159 | # for e in start_joint_pos 160 | # ] 161 | if type(start_joint_pos) is list: 162 | action = start_joint_pos + [gripper_action] 163 | else: 164 | action = start_joint_pos.tolist() + [gripper_action] 165 | start_time = time.time() 166 | while True: 167 | if self.received_states and self.check_nonzero_configuration(): 168 | if ( 169 | np.max(np.abs(np.array(self.last_q) - np.array(start_joint_pos))) 170 | < 1e-3 171 | ): 172 | break 173 | self.control( 174 | controller_type="JOINT_POSITION", 175 | action=action, 176 | controller_cfg=controller_cfg, 177 | ) 178 | end_time = time.time() 179 | 180 | # Add timeout 181 | if end_time - start_time > timeout: 182 | break 183 | return True 184 | -------------------------------------------------------------------------------- /frankateach/teleoperator.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | from frankateach.utils import notify_component_start 4 | from frankateach.network import ( 5 | ZMQKeypointSubscriber, 6 | create_request_socket, 7 | ZMQKeypointPublisher, 8 | ) 9 | from frankateach.constants import ( 10 | COMMANDED_STATE_PORT, 11 | CONTROL_PORT, 12 | HOST, 13 | STATE_PORT, 14 | VR_CONTROLLER_STATE_PORT, 15 | H_R_V, 16 | H_R_V_star, 17 | ROBOT_WORKSPACE_MIN, 18 | ROBOT_WORKSPACE_MAX, 19 | GRIPPER_OPEN, 20 | GRIPPER_CLOSE, 21 | ) 22 | from frankateach.messages import FrankaAction, FrankaState 23 | 24 | from deoxys.utils import transform_utils 25 | 26 | import numpy as np 27 | from numpy.linalg import pinv 28 | 29 | 30 | def get_relative_affine(init_affine, current_affine): 31 | H_V_des = pinv(init_affine) @ current_affine 32 | 33 | # Transform to robot frame. 
34 |     relative_affine_rot = (pinv(H_R_V) @ H_V_des @ H_R_V)[:3, :3]
35 |     relative_affine_trans = (pinv(H_R_V_star) @ H_V_des @ H_R_V_star)[:3, 3]
36 | 
37 |     # Homogeneous coordinates
38 |     relative_affine = np.block(
39 |         [[relative_affine_rot, relative_affine_trans.reshape(3, 1)], [0, 0, 0, 1]]
40 |     )
41 | 
42 |     return relative_affine
43 | 
44 | 
45 | class FrankaOperator:
46 |     def __init__(
47 |         self,
48 |         init_gripper_state="open",
49 |         teleop_mode="robot",
50 |         home_offset=[0, 0, 0],
51 |     ) -> None:
52 |         # Subscribe controller state
53 |         self._controller_state_subscriber = ZMQKeypointSubscriber(
54 |             host=HOST, port=VR_CONTROLLER_STATE_PORT, topic="controller_state"
55 |         )
56 | 
57 |         self.action_socket = create_request_socket(HOST, CONTROL_PORT)
58 |         self.state_socket = ZMQKeypointPublisher(HOST, STATE_PORT)
59 |         self.commanded_state_socket = ZMQKeypointPublisher(HOST, COMMANDED_STATE_PORT)
60 | 
61 |         # Class variables
62 |         # self._save_states = save_states
63 |         self.is_first_frame = True
64 |         self.gripper_state = (
65 |             GRIPPER_OPEN if init_gripper_state == "open" else GRIPPER_CLOSE
66 |         )
67 |         self.start_teleop = False
68 |         self.init_affine = None
69 |         self.teleop_mode = teleop_mode
70 | 
71 |         if teleop_mode == "human" and home_offset is None:
72 |             home_offset = [-0.22, 0.0, 0.1]
73 |         self.home_offset = (
74 |             np.array(home_offset) if home_offset is not None else np.zeros(3)
75 |         )
76 | 
77 |     def _apply_retargeted_angles(self) -> None:
78 |         self.controller_state = self._controller_state_subscriber.recv_keypoints()
79 | 
80 |         if self.is_first_frame:
81 |             print("Resetting robot..")
82 |             action = FrankaAction(
83 |                 pos=np.zeros(3),
84 |                 quat=np.zeros(4),
85 |                 gripper=self.gripper_state,
86 |                 reset=True,
87 |                 timestamp=time.time(),
88 |             )
89 |             self.action_socket.send(bytes(pickle.dumps(action, protocol=-1)))
90 |             robot_state = pickle.loads(self.action_socket.recv())
91 | 
92 |             # Move to offset position
93 |             target_pos = robot_state.pos + self.home_offset
94 |             target_quat = robot_state.quat
95 |             action = FrankaAction(
96 |                 pos=target_pos.flatten().astype(np.float32),
97 |                 quat=target_quat.flatten().astype(np.float32),
98 |                 gripper=self.gripper_state,
99 |                 reset=False,
100 |                 timestamp=time.time(),
101 |             )
102 |             self.action_socket.send(bytes(pickle.dumps(action, protocol=-1)))
103 |             robot_state = pickle.loads(self.action_socket.recv())
104 |             # HOME <- Pos: [0.457632 0.0321814 0.2653815], Quat: [0.9998586 0.00880853 0.01421072 0.00179784]
105 | 
106 |             print(robot_state)
107 |             self.home_rot, self.home_pos = (
108 |                 transform_utils.quat2mat(robot_state.quat),
109 |                 robot_state.pos,
110 |             )
111 | 
112 |             self.is_first_frame = False
113 |         if self.controller_state.right_a:
114 |             self.start_teleop = True
115 |             self.init_affine = self.controller_state.right_affine
116 |         if self.controller_state.right_b:
117 |             self.start_teleop = False
118 |             self.init_affine = None
119 |             # receive the robot state (check for errors before unpickling)
120 |             self.action_socket.send(b"get_state")
121 |             raw_state = self.action_socket.recv()
122 |             if raw_state == b"state_error":
123 |                 print("Error getting robot state")
124 |                 return
125 |             robot_state: FrankaState = pickle.loads(raw_state)
126 |             self.home_rot, self.home_pos = (
127 |                 transform_utils.quat2mat(robot_state.quat),
128 |                 robot_state.pos,
129 |             )
130 | 
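        # While teleop is active, the controller's motion since the A-press
        # (init_affine) becomes a relative robot-frame transform; otherwise the
        # identity transform below holds the arm at its home pose.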
131 |         if self.start_teleop and self.teleop_mode == "robot":
132 |             relative_affine = get_relative_affine(
133 |                 self.init_affine, self.controller_state.right_affine
134 |             )
135 |         else:
136 |             relative_affine = np.zeros((4, 4))
137 |             relative_affine[3, 3] = 1
138 | 
139 |         gripper_action = None
140 |         if self.teleop_mode == "robot":
141 |             if self.controller_state.right_index_trigger > 0.5:
142 |                 gripper_action = GRIPPER_CLOSE
143 |             elif self.controller_state.right_hand_trigger > 0.5:
144 |                 gripper_action = GRIPPER_OPEN
145 | 
146 |         if gripper_action is not None and gripper_action != self.gripper_state:
147 |             self.gripper_state = gripper_action
148 | 
149 |         if self.start_teleop:
150 |             relative_pos, relative_rot = (
151 |                 relative_affine[:3, 3],
152 |                 relative_affine[:3, :3],
153 |             )
154 | 
155 |             target_pos = self.home_pos + relative_pos
156 |             target_rot = self.home_rot @ relative_rot
157 |             target_quat = transform_utils.mat2quat(target_rot)
158 | 
159 |             target_pos = np.clip(
160 |                 target_pos,
161 |                 a_min=ROBOT_WORKSPACE_MIN,
162 |                 a_max=ROBOT_WORKSPACE_MAX,
163 |             )
164 | 
165 |         else:
166 |             target_pos, target_quat = (
167 |                 self.home_pos + self.home_offset,
168 |                 transform_utils.mat2quat(self.home_rot),
169 |             )
170 | 
171 |         action = FrankaAction(
172 |             pos=target_pos.flatten().astype(np.float32),
173 |             quat=target_quat.flatten().astype(np.float32),
174 |             gripper=self.gripper_state,
175 |             reset=False,
176 |             timestamp=time.time(),
177 |         )
178 |         if self.teleop_mode == "robot":
179 |             self.action_socket.send(bytes(pickle.dumps(action, protocol=-1)))
180 |         else:
181 |             self.action_socket.send(b"get_state")
182 | 
183 |         # self.action_socket.send(bytes(pickle.dumps(action, protocol=-1)))
184 |         robot_state = self.action_socket.recv()
185 | 
186 |         robot_state = pickle.loads(robot_state)
187 |         robot_state.start_teleop = self.start_teleop
188 |         # self.state_socket.send(bytes(pickle.dumps(robot_state, protocol=-1)))
189 |         self.state_socket.pub_keypoints(robot_state, "robot_state")
190 |         self.commanded_state_socket.pub_keypoints(action, "commanded_robot_state")
191 | 
192 |     def stream(self):
193 |         notify_component_start("Franka teleoperator control")
194 |         print("Start controlling the robot using the Oculus headset.\n")
195 | 
196 |         try:
197 |             while True:
198 |                 # Retargeting function
199 |                 self._apply_retargeted_angles()
200 |         except KeyboardInterrupt:
201 |             pass
202 |         finally:
203 |             self._controller_state_subscriber.stop()
204 |             self.action_socket.close()
205 | 
206 |             print("Stopping the teleoperator!")
207 | 
208 | 
209 | def main():
210 |     operator = FrankaOperator()
211 |     operator.stream()
212 | 
213 | 
214 | if __name__ == "__main__":
215 |     main()
216 | 
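The teleoperator publishes every robot state it receives on the `robot_state` topic (port `STATE_PORT`) and every commanded action on `commanded_robot_state`; any consumer just needs the matching subscriber. A minimal listener sketch (not part of the repo; it assumes the teleoperator above is running):

```python
from frankateach.constants import HOST, STATE_PORT
from frankateach.network import ZMQKeypointSubscriber

sub = ZMQKeypointSubscriber(HOST, STATE_PORT, topic="robot_state")
while True:
    state = sub.recv_keypoints()  # FrankaState, as published by the teleoperator
    print(state.pos, state.quat, state.gripper, state.start_teleop)
```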
--------------------------------------------------------------------------------
/frankateach/data_collector.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from pathlib import Path
3 | import pickle
4 | import cv2
5 | import time
6 | import threading
7 | import h5py
8 | import numpy as np
9 | 
10 | from frankateach.network import (
11 |     ZMQCameraSubscriber,
12 |     ZMQKeypointSubscriber,
13 | )
14 | from frankateach.sensors.reskin import ReskinSensorSubscriber
15 | from frankateach.utils import notify_component_start
16 | 
17 | from frankateach.constants import (
18 |     COMMANDED_STATE_PORT,
19 |     HOST,
20 |     CAM_PORT,
21 |     STATE_PORT,
22 |     DEPTH_PORT_OFFSET,
23 |     RESKIN_STREAM_PORT,
24 | )
25 | 
26 | 
27 | class DataCollector:
28 |     def __init__(
29 |         self,
30 |         storage_path: str,
31 |         demo_num: int,
32 |         cams=[],  # camera info dicts
33 |         cam_config=None,  # camera configs by type
34 |         collect_img=False,
35 |         collect_state=False,
36 |         collect_depth=False,
37 |         collect_reskin=False,
38 |     ):
39 |         self.image_subscribers = {}
40 |         self.depth_subscribers = {}
41 |         if collect_img:
42 |             for camera in cams:
43 |                 self.image_subscribers[camera.cam_id] = ZMQCameraSubscriber(
44 |                     HOST, CAM_PORT + camera.cam_id, "RGB"
45 |                 )
46 | 
47 |         if collect_depth:
48 |             for camera in cams:
49 |                 if camera.type == "realsense":
50 |                     self.depth_subscribers[camera.cam_id] = ZMQCameraSubscriber(
51 |                         HOST, CAM_PORT + DEPTH_PORT_OFFSET + camera.cam_id, "Depth"
52 |                     )
53 | 
54 |         if collect_state:
55 |             self.state_socket = ZMQKeypointSubscriber(
56 |                 host=HOST, port=STATE_PORT, topic="robot_state"
57 |             )
58 |             # self.state_socket = create_response_socket(HOST, STATE_PORT)
59 |             self.commanded_state_socket = ZMQKeypointSubscriber(
60 |                 host=HOST, port=COMMANDED_STATE_PORT, topic="commanded_robot_state"
61 |             )
62 | 
63 |         if collect_reskin:
64 |             self.reskin_subscriber = ReskinSensorSubscriber()
65 | 
66 |         # Create the storage directory
67 |         self.storage_path = Path(storage_path) / f"demonstration_{demo_num}"
68 |         self.storage_path.mkdir(parents=True, exist_ok=True)
69 |         print("Storage path: ", self.storage_path)
70 | 
71 |         self.run_event = threading.Event()
72 |         self.run_event.set()
73 |         self.threads = []
74 | 
75 |         # Set up image subscribers
76 |         for camera in cams:
77 |             if collect_img:
78 |                 self.threads.append(
79 |                     threading.Thread(
80 |                         target=self.save_rgb,
81 |                         args=(camera.cam_id, cam_config[camera.type]),
82 |                         daemon=True,
83 |                     )
84 |                 )
85 |             if collect_depth and camera.type == "realsense":
86 |                 self.threads.append(
87 |                     threading.Thread(
88 |                         target=self.save_depth,
89 |                         args=(camera.cam_id, cam_config[camera.type]),
90 |                         daemon=True,
91 |                     )
92 |                 )
93 | 
94 |         if collect_state:
95 |             self.threads.append(threading.Thread(target=self.save_states, daemon=True))
96 | 
97 |         if collect_reskin:
98 |             self.threads.append(threading.Thread(target=self.save_reskin, daemon=True))
99 | 
100 |     def start(self):
101 |         for thread in self.threads:
102 |             thread.start()
103 |         try:
104 |             while True:
105 |                 time.sleep(1)  # idle; avoids a busy-wait while threads record
106 |         except KeyboardInterrupt:
107 |             print("Stopping data collection...")
108 |             self.run_event.clear()  # signal all collector threads to stop
109 |             for thread in self.threads:
110 |                 thread.join()
111 | 
112 |     def save_rgb(self, cam_idx, cam_config):
113 |         notify_component_start(component_name="RGB Image Collector")
114 | 
115 |         filename = self.storage_path / f"cam_{cam_idx}_rgb_video.avi"
116 |         metadata_filename = self.storage_path / f"cam_{cam_idx}_rgb_video.metadata"
117 | 
118 |         recorder = cv2.VideoWriter(
119 |             str(filename),
120 |             cv2.VideoWriter_fourcc(*"XVID"),
121 |             cam_config["fps"],
122 |             (cam_config["width"], cam_config["height"]),
123 |         )
124 | 
125 |         timestamps = []
126 |         metadata = dict(
127 |             cam_idx=cam_idx,
128 |             width=cam_config.width,
129 |             height=cam_config.height,
130 |             fps=cam_config.fps,
131 |             filename=filename,
132 |             record_start_time=time.time(),
133 |         )
134 | 
135 |         try:
136 |             # Loop to capture frames until stopped
137 |             while self.run_event.is_set():
138 |                 rgb_image, timestamp = self.image_subscribers[cam_idx].recv_rgb_image()
139 |                 recorder.write(rgb_image)
140 |                 timestamps.append(timestamp)
141 |         finally:
142 |             # Ensure resources are released regardless of exit conditions
143 |             recorder.release()
144 |             metadata["record_end_time"] = time.time()
145 |             metadata["num_image_frames"] = len(timestamps)
146 |             metadata["timestamps"] = timestamps
147 |             with open(metadata_filename, "wb") as f:
148 |                 pickle.dump(metadata, f)
149 |             self.image_subscribers[cam_idx].stop()
150 |             print(f"Saved video to {filename}")
151 | 
152 |     # def save_depth(self, cam_idx, cam_config):
153 |     #     raise 
NotImplementedError("Depth recording is not yet implemented") 154 | 155 | def save_depth(self, cam_idx, cam_config): 156 | notify_component_start(component_name="Depth Image Collector") 157 | 158 | filename = self.storage_path / f"cam_{cam_idx}_depth.pkl" 159 | metadata_filename = self.storage_path / f"cam_{cam_idx}_depth.metadata" 160 | 161 | depth_frames = [] 162 | 163 | timestamps = [] 164 | metadata = dict( 165 | cam_idx=cam_idx, 166 | width=cam_config.width, 167 | height=cam_config.height, 168 | fps=cam_config.fps, 169 | filename=filename, 170 | record_start_time=time.time(), 171 | ) 172 | 173 | try: 174 | # Loop to capture frames until stopped 175 | while self.run_event.is_set(): 176 | depth_frame, timestamp = self.depth_subscribers[ 177 | cam_idx 178 | ].recv_depth_image() 179 | depth_frames.append(depth_frame) 180 | timestamps.append(timestamp) 181 | finally: 182 | # Ensure resources are released regardless of exit conditions 183 | metadata["record_end_time"] = time.time() 184 | metadata["num_image_frames"] = len(timestamps) 185 | metadata["timestamps"] = timestamps 186 | with open(metadata_filename, "wb") as f: 187 | pickle.dump(metadata, f) 188 | with open(filename, "wb") as f: 189 | pickle.dump(depth_frames, f) 190 | self.depth_subscribers[cam_idx].stop() 191 | print(f"Saved depth to {filename}") 192 | 193 | def save_states(self): 194 | notify_component_start(component_name="State Collector") 195 | 196 | filename = self.storage_path / "states.pkl" 197 | cmd_filename = self.storage_path / "commanded_states.pkl" 198 | states = [] 199 | commanded_states = [] 200 | 201 | while self.run_event.is_set(): 202 | # state = pickle.loads(self.state_socket.recv()) 203 | state = self.state_socket.recv_keypoints() 204 | commanded_state = self.commanded_state_socket.recv_keypoints() 205 | states.append(state) 206 | commanded_states.append(commanded_state) 207 | 208 | with open(filename, "wb") as f: 209 | pickle.dump(states, f) 210 | 211 | with open(cmd_filename, "wb") as f: 212 | pickle.dump(commanded_states, f) 213 | 214 | print("Saved states to ", filename) 215 | # self.state_socket.close() 216 | self.state_socket.stop() 217 | self.commanded_state_socket.stop() 218 | 219 | print( 220 | "Frequency of state savings: ", 221 | (len(states) - 10) / (states[-1].timestamp - states[10].timestamp), 222 | ) 223 | 224 | print( 225 | "Frequency of commanded state savings: ", 226 | (len(commanded_states) - 10) 227 | / (commanded_states[-1].timestamp - commanded_states[10].timestamp), 228 | ) 229 | 230 | def save_reskin(self): 231 | notify_component_start(component_name="Reskin Collector") 232 | 233 | sensor_information = defaultdict(list) 234 | filename = self.storage_path / "reskin_sensor_values.h5" 235 | 236 | print("Starting to record Reskin frames from port:", RESKIN_STREAM_PORT) 237 | 238 | while self.run_event.is_set(): 239 | reskin_state = self.reskin_subscriber.get_sensor_state() 240 | for attr in reskin_state.keys(): 241 | sensor_information[attr].append(reskin_state[attr]) 242 | 243 | print("Finished recording Reskin frames") 244 | 245 | with h5py.File(filename, "w") as hf: 246 | for key in sensor_information.keys(): 247 | sensor_information[key] = np.array( 248 | sensor_information[key], 249 | dtype=np.float32 if key != "timestamp" else np.float64, 250 | ) 251 | hf.create_dataset( 252 | key, 253 | data=sensor_information[key], 254 | compression="gzip", 255 | compression_opts=6, 256 | ) 257 | 258 | print("Saved Reskin sensor data to ", filename) 259 | print( 260 | "ReSkin Data duration: ", 261 | 
-------------------------------------------------------------------------------- /frankateach/network.py: --------------------------------------------------------------------------------
1 | import zmq
2 | import cv2
3 | import base64
4 | import numpy as np
5 | import pickle
6 | import blosc as bl
7 | import threading
8 |
9 |
10 | def flush_socket(socket):
11 |     """Flush all messages currently in the socket."""
12 |     while True:
13 |         try:
14 |             # Check if a message is waiting in the queue
15 |             message = socket.recv(zmq.NOBLOCK)
16 |             print(message)
17 |         except zmq.Again:
18 |             # No more messages to flush
19 |             break
20 |
21 |
22 | # ZMQ Sockets
23 | def create_push_socket(host, port):
24 |     context = zmq.Context()
25 |     socket = context.socket(zmq.PUSH)
26 |     socket.bind("tcp://{}:{}".format(host, port))
27 |     return socket
28 |
29 |
30 | def create_pull_socket(host, port):
31 |     context = zmq.Context()
32 |     socket = context.socket(zmq.PULL)
33 |     socket.setsockopt(zmq.CONFLATE, 1)
34 |     socket.bind("tcp://{}:{}".format(host, port))
35 |     return socket
36 |
37 |
38 | def create_response_socket(host, port):
39 |     context = zmq.Context()
40 |     socket = context.socket(zmq.REP)
41 |     socket.bind("tcp://{}:{}".format(host, port))
42 |     return socket
43 |
44 |
45 | def create_request_socket(host, port):
46 |     context = zmq.Context()
47 |     socket = context.socket(zmq.REQ)
48 |     socket.connect("tcp://{}:{}".format(host, port))
49 |     return socket
50 |
51 |
52 | def create_subscriber_socket(host, port, topic, conflate=False):
53 |     context = zmq.Context()
54 |     socket = context.socket(zmq.SUB)
55 |     socket.setsockopt(zmq.CONFLATE, int(conflate))
56 |     socket.connect("tcp://{}:{}".format(host, port))
57 |     socket.subscribe(topic)
58 |     flush_socket(socket)
59 |     return socket
60 |
61 |
62 | # Pub/Sub classes for Keypoints
63 | class ZMQKeypointPublisher(object):
64 |     def __init__(self, host, port):
65 |         self._host, self._port = host, port
66 |         self._init_publisher()
67 |
68 |     def _init_publisher(self):
69 |         self.context = zmq.Context()
70 |         self.socket = self.context.socket(zmq.PUB)
71 |         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
72 |
73 |     def pub_keypoints(self, keypoint_array, topic_name):
74 |         """
75 |         Pickle the keypoints and publish them prefixed with "<topic_name> ".
76 |         """
77 |         buffer = pickle.dumps(keypoint_array, protocol=-1)
78 |         self.socket.send(bytes("{} ".format(topic_name), "utf-8") + buffer)
79 |
80 |     def stop(self):
81 |         print("Closing the publisher socket in {}:{}.".format(self._host, self._port))
82 |         self.socket.close()
83 |         self.context.term()
84 |
85 |
86 | class ZMQKeypointSubscriber(threading.Thread):
87 |     def __init__(self, host, port, topic):
88 |         self._host, self._port, self._topic = host, port, topic
89 |         self._init_subscriber()
90 |
91 |         # Topic prefix (b"<topic> ") to strip from each message
92 |         self.strip_value = bytes("{} ".format(self._topic), "utf-8")
93 |
94 |     def _init_subscriber(self):
95 |         self.context = zmq.Context()
96 |         self.socket = self.context.socket(zmq.SUB)
97 |         self.socket.setsockopt(zmq.CONFLATE, 1)
98 |         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
99 |         self.socket.setsockopt(zmq.SUBSCRIBE, bytes(self._topic, "utf-8"))
100 |
101 |     def recv_keypoints(self, flags=None):
102 |         if flags is None:
103 |             raw_data = self.socket.recv()
104 |             raw_array = raw_data[len(self.strip_value):]  # drop the topic prefix; lstrip() would strip a character *set*, not a prefix
105 |             return pickle.loads(raw_array)
106 |         else:  # non-blocking receive, e.g. flags=zmq.NOBLOCK
107 |             try:
108 |                 raw_data = self.socket.recv(flags)
109 |                 raw_array = raw_data[len(self.strip_value):]
110 |                 return pickle.loads(raw_array)
111 |             except zmq.Again:
112 |                 # No message available yet
113 |                 return None
114 |
115 |     def stop(self):
116 |         print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
117 |         self.socket.close()
118 |         self.context.term()
119 |
120 |
121 | # Pub/Sub classes for storing data from Realsense Cameras
122 | class ZMQCameraPublisher(object):
123 |     def __init__(self, host, port):
124 |         self._host, self._port = host, port
125 |         self._init_publisher()
126 |
127 |     def _init_publisher(self):
128 |         self.context = zmq.Context()
129 |         self.socket = self.context.socket(zmq.PUB)
130 |         print("tcp://{}:{}".format(self._host, self._port))
131 |         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
132 |
133 |     def pub_intrinsics(self, array):
134 |         self.socket.send(b"intrinsics " + pickle.dumps(array, protocol=-1))
135 |
136 |     def pub_rgb_image(self, rgb_image, timestamp):
137 |         _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 70])
138 |         data = dict(timestamp=timestamp, rgb_image=base64.b64encode(buffer))
139 |         self.socket.send(b"rgb_image " + pickle.dumps(data, protocol=-1))
140 |
141 |     def pub_depth_image(self, depth_image, timestamp):
142 |         compressed_depth = bl.pack_array(
143 |             depth_image, cname="zstd", clevel=1, shuffle=bl.NOSHUFFLE
144 |         )
145 |         data = dict(timestamp=timestamp, depth_image=compressed_depth)
146 |         self.socket.send(b"depth_image " + pickle.dumps(data, protocol=-1))
147 |
148 |     def stop(self):
149 |         print("Closing the publisher socket in {}:{}.".format(self._host, self._port))
150 |         self.socket.close()
151 |         self.context.term()
152 |
153 |
154 | class ZMQCameraSubscriber(threading.Thread):
155 |     def __init__(self, host, port, topic_type):
156 |         self._host, self._port, self._topic_type = host, port, topic_type
157 |         self._init_subscriber()
158 |
159 |     def _init_subscriber(self):
160 |         self.context = zmq.Context()
161 |         self.socket = self.context.socket(zmq.SUB)
162 |         self.socket.setsockopt(zmq.CONFLATE, 1)
163 |         print("tcp://{}:{}".format(self._host, self._port))
164 |         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
165 |
166 |         if self._topic_type == "Intrinsics":
167 |             self.socket.setsockopt(zmq.SUBSCRIBE, b"intrinsics")
168 |         elif self._topic_type == "RGB":
169 |             self.socket.setsockopt(zmq.SUBSCRIBE, b"rgb_image")
170 |         elif self._topic_type == "Depth":
171 |             self.socket.setsockopt(zmq.SUBSCRIBE, b"depth_image")
172 |
173 |     def recv_intrinsics(self):
174 |         raw_data = self.socket.recv()
175 |         raw_array = raw_data[len(b"intrinsics "):]
176 |         return pickle.loads(raw_array)
177 |
178 |     def recv_rgb_image(self):
179 |         raw_data = self.socket.recv()
180 |         data = raw_data[len(b"rgb_image "):]
181 |         data = pickle.loads(data)
182 |         encoded_data = np.frombuffer(base64.b64decode(data["rgb_image"]), np.uint8)  # np.fromstring is deprecated
183 |         return cv2.imdecode(encoded_data, 1), data["timestamp"]
184 |
185 |     def recv_depth_image(self):
186 |         raw_data = self.socket.recv()
187 |         stripped_data = raw_data[len(b"depth_image "):]
188 |         data = pickle.loads(stripped_data)
189 |         depth_image = bl.unpack_array(data["depth_image"])
190 |         return np.array(depth_image, dtype=np.uint16), data["timestamp"]  # RealSense depth is 16-bit unsigned (millimeters)
191 |
192 |     def stop(self):
193 |         print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
194 |         self.socket.close()
195 |         self.context.term()
196 |
197 |
198 | # Publisher for image visualizers
199 | class ZMQCompressedImageTransmitter(object):
200 |     def __init__(self, host, port):
201 |         self._host, self._port = host, port
202 |         # self._init_push_socket()
203 |         self._init_publisher()
204 |
205 |     def _init_publisher(self):
206 |         self.context = zmq.Context()
207 |         self.socket = self.context.socket(zmq.PUB)
208 |         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
209 |
210 |     def _init_push_socket(self):
211 |         self.context = zmq.Context()
212 |         self.socket = self.context.socket(zmq.PUSH)
213 |         self.socket.bind("tcp://{}:{}".format(self._host, self._port))
214 |
215 |     def send_image(self, rgb_image):
216 |         _, buffer = cv2.imencode(".jpg", rgb_image, [int(cv2.IMWRITE_JPEG_QUALITY), 10])  # JPEG quality flag; the WEBP flag was a mismatch for ".jpg"
217 |         self.socket.send(buffer.tobytes())
218 |
219 |     def stop(self):
220 |         print("Closing the publisher in {}:{}.".format(self._host, self._port))
221 |         self.socket.close()
222 |         self.context.term()
223 |
224 |
225 | class ZMQCompressedImageReciever(threading.Thread):  # (sic) spelling kept so existing imports keep working
226 |     def __init__(self, host, port):
227 |         self._host, self._port = host, port
228 |         # self._init_pull_socket()
229 |         self._init_subscriber()
230 |
231 |     def _init_subscriber(self):
232 |         self.context = zmq.Context()
233 |         self.socket = self.context.socket(zmq.SUB)
234 |         self.socket.setsockopt(zmq.CONFLATE, 1)
235 |         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
236 |         self.socket.subscribe("")
237 |
238 |     def _init_pull_socket(self):
239 |         self.context = zmq.Context()
240 |         self.socket = self.context.socket(zmq.PULL)
241 |         self.socket.setsockopt(zmq.CONFLATE, 1)
242 |         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
243 |
244 |     def recv_image(self):
245 |         raw_data = self.socket.recv()
246 |         encoded_data = np.frombuffer(raw_data, np.uint8)  # np.fromstring is deprecated
247 |         decoded_frame = cv2.imdecode(encoded_data, 1)
248 |         return decoded_frame
249 |
250 |     def stop(self):
251 |         print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
252 |         self.socket.close()
253 |         self.context.term()
254 |
255 |
256 | class ZMQButtonFeedbackSubscriber(threading.Thread):
257 |     def __init__(self, host, port):
258 |         self._host, self._port = host, port
259 |         # self._init_pull_socket()
260 |         self._init_subscriber()
261 |
262 |     def _init_subscriber(self):
263 |         self.context = zmq.Context()
264 |         self.socket = self.context.socket(zmq.SUB)
265 |         self.socket.setsockopt(zmq.CONFLATE, 1)
266 |         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
267 |         self.socket.subscribe("")
268 |
269 |     def _init_pull_socket(self):
270 |         self.context = zmq.Context()
271 |         self.socket = self.context.socket(zmq.PULL)
272 |         self.socket.setsockopt(zmq.CONFLATE, 1)
273 |         self.socket.connect("tcp://{}:{}".format(self._host, self._port))
274 |
275 |     def recv_keypoints(self):
276 |         raw_data = self.socket.recv()
277 |         return pickle.loads(raw_data)
278 |
279 |     def stop(self):
280 |         print("Closing the subscriber socket in {}:{}.".format(self._host, self._port))
281 |         self.socket.close()
282 |         self.context.term()
283 |
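The keypoint publisher/subscriber pair above frames every message as b"<topic> " followed by a pickled payload, and ZMQ's CONFLATE option keeps only the newest message on the subscriber side. A minimal round-trip sketch on localhost; the port number and topic name are arbitrary choices for this example:

import time

import zmq

from frankateach.network import ZMQKeypointPublisher, ZMQKeypointSubscriber

pub = ZMQKeypointPublisher(host="127.0.0.1", port=5555)
sub = ZMQKeypointSubscriber(host="127.0.0.1", port=5555, topic="pose")

# PUB/SUB delivery is fire-and-forget: anything published before the
# subscription is established is dropped, so publish until one arrives.
keypoints = None
while keypoints is None:
    pub.pub_keypoints([0.1, 0.2, 0.3], topic_name="pose")
    time.sleep(0.05)
    keypoints = sub.recv_keypoints(flags=zmq.NOBLOCK)
print(keypoints)  # [0.1, 0.2, 0.3]

sub.stop()
pub.stop()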
-------------------------------------------------------------------------------- /franka-env/franka_env/envs/franka_env.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import gym
3 | import numpy as np
4 | import time
5 | import pickle
6 |
7 | from frankateach.constants import (
8 |     CAM_PORT,
9 |     GRIPPER_CLOSE,
10 |     GRIPPER_OPEN,
11 |     HOST,
12 |     CONTROL_PORT,
13 | )
14 | from frankateach.messages import FrankaAction, FrankaState
15 | from frankateach.network import (
16 |     ZMQCameraSubscriber,
17 |     create_request_socket,
18 | )
19 |
20 | try:
21 |     from frankateach.sensors.reskin import ReskinSensorSubscriber
22 | except ImportError:
23 |     print("ReskinSensorSubscriber not found")
24 |     ReskinSensorSubscriber = None
25 |
26 |
27 | class FrankaEnv(gym.Env):
28 |     def __init__(
29 |         self,
30 |         cam_ids=[1, 2, 3, 4, 51],
31 |         width=640,
32 |         height=480,
33 |         use_robot=True,
34 |         sensor_type=None,
35 |         sensor_params=None,
36 |     ):
37 |         super(FrankaEnv, self).__init__()
38 |         self.width = width
39 |         self.height = height
40 |         self.channels = 3
41 |         self.feature_dim = 8
42 |         self.action_dim = 8  # (pos 3, quat 4, gripper 1), matching the slicing in step()
43 |
44 |         self.use_robot = use_robot
45 |         self.sensor_type = sensor_type
46 |         if sensor_type is not None:
47 |             assert sensor_type in ["reskin"]
48 |             assert (
49 |                 ReskinSensorSubscriber is not None
50 |             ), "ReskinSensorSubscriber not found"
51 |             if sensor_type == "reskin":
52 |                 self.n_sensors = 2
53 |                 self.sensor_dim = 15
54 |         self.sensor_params = sensor_params
55 |
56 |         self.n_channels = 3
57 |         self.reward = 0
58 |
59 |         self.franka_state = None
60 |         self.curr_images = None
61 |
62 |         self.action_space = gym.spaces.Box(
63 |             low=-float("inf"), high=float("inf"), shape=(self.action_dim,)
64 |         )
65 |         obs_space = {
66 |             f"pixels{cam_id}": gym.spaces.Box(
67 |                 low=0, high=255, shape=(height, width, self.n_channels), dtype=np.uint8
68 |             )
69 |             for cam_id in cam_ids
70 |         }
71 |         obs_space["features"] = gym.spaces.Box(
72 |             low=-float("inf"),
73 |             high=float("inf"),
74 |             shape=(self.feature_dim,),
75 |             dtype=np.float32,
76 |         )
77 |         obs_space["proprioceptive"] = gym.spaces.Box(
78 |             low=-float("inf"),
79 |             high=float("inf"),
80 |             shape=(self.feature_dim,),
81 |             dtype=np.float32,
82 |         )
83 |         if self.sensor_type == "reskin":
84 |             for sensor_idx in range(self.n_sensors):
85 |                 obs_space[f"sensor{sensor_idx}"] = gym.spaces.Box(
86 |                     low=-float("inf"),
87 |                     high=float("inf"),
88 |                     shape=(self.sensor_dim,),
89 |                     dtype=np.float32,
90 |                 )
91 |                 obs_space[f"sensor{sensor_idx}_diffs"] = gym.spaces.Box(
92 |                     low=-float("inf"),
93 |                     high=float("inf"),
94 |                     shape=(self.sensor_dim,),
95 |                     dtype=np.float32,
96 |                 )
97 |         self.observation_space = gym.spaces.Dict(obs_space)
98 |
99 |         if self.use_robot:
100 |             self.image_subscribers = {}
101 |             for cam_idx in cam_ids:
102 |                 port = CAM_PORT + cam_idx
103 |                 self.image_subscribers[cam_idx] = ZMQCameraSubscriber(
104 |                     host=HOST,
105 |                     port=port,
106 |                     topic_type="RGB",
107 |                 )
108 |
109 |             if self.sensor_type == "reskin":
110 |                 self.sensor_subscriber = ReskinSensorSubscriber()
111 |
112 |                 self.sensor_prev_state = None
113 |                 self.subtract_sensor_baseline = sensor_params[
114 |                     "subtract_sensor_baseline"
115 |                 ]
116 |
117 |                 # Call once to populate initial baseline
118 |                 self._get_reskin_state(update_baseline=True)
119 |
120 |             self.action_request_socket = create_request_socket(HOST, CONTROL_PORT)
121 |
122 |     def get_state(self):
123 |         self.action_request_socket.send(b"get_state")
124 |         franka_state: FrankaState = pickle.loads(self.action_request_socket.recv())
125 |         self.franka_state = franka_state
126 |         return franka_state
127 |
128 |     def step(self, abs_action):
129 |         pos = abs_action[:3]
130 |         quat = abs_action[3:7]
131 |         gripper = abs_action[-1]
132 |         if gripper < 0.0:
133 |             gripper = GRIPPER_OPEN
134 |         else:
135 |             gripper = GRIPPER_CLOSE
136 | 137 | # Send action to the robot 138 | franka_action = FrankaAction( 139 | pos=pos, 140 | quat=quat, 141 | gripper=gripper, 142 | reset=False, 143 | timestamp=time.time(), 144 | ) 145 | # print("sending action to robot: ", franka_action) 146 | 147 | self.action_request_socket.send(bytes(pickle.dumps(franka_action, protocol=-1))) 148 | franka_state: FrankaState = pickle.loads(self.action_request_socket.recv()) 149 | self.franka_state = franka_state 150 | 151 | image_dict = {} 152 | self.curr_images = [] 153 | for cam_id, subscriber in self.image_subscribers.items(): 154 | image, _ = subscriber.recv_rgb_image() 155 | image_dict[f"pixels{cam_id}"] = cv2.resize(image, (self.width, self.height)) 156 | self.curr_images.append(image) 157 | 158 | obs = { 159 | "features": np.concatenate( 160 | (franka_state.pos, franka_state.quat, [franka_state.gripper]) 161 | ), 162 | "proprioceptive": np.concatenate( 163 | (franka_state.pos, franka_state.quat, [franka_state.gripper]) 164 | ), 165 | } 166 | if self.sensor_type == "reskin": 167 | try: 168 | reskin_state = self._get_reskin_state() 169 | obs.update(reskin_state) 170 | except KeyError: 171 | pass 172 | 173 | obs.update(image_dict) 174 | # for i, image in image_dict.items(): 175 | # obs[f"pixels{i}"] = cv2.resize(image, (self.width, self.height)) 176 | return obs, self.reward, False, False, {} 177 | 178 | def reset(self): 179 | print("resetting") 180 | # TODO: send b"reset" to the robot instead of this action 181 | franka_reset_action = FrankaAction( 182 | pos=np.zeros(3), 183 | quat=np.zeros(4), 184 | gripper=GRIPPER_OPEN, 185 | reset=True, 186 | timestamp=time.time(), 187 | ) 188 | 189 | self.action_request_socket.send( 190 | bytes(pickle.dumps(franka_reset_action, protocol=-1)) 191 | ) 192 | franka_state: FrankaState = pickle.loads(self.action_request_socket.recv()) 193 | self.franka_state = franka_state 194 | print("reset done: ", franka_state) 195 | 196 | image_dict = {} 197 | self.curr_images = [] 198 | for cam_id, subscriber in self.image_subscribers.items(): 199 | image, _ = subscriber.recv_rgb_image() 200 | image_dict[f"pixels{cam_id}"] = cv2.resize(image, (self.width, self.height)) 201 | self.curr_images.append(image) 202 | 203 | obs = { 204 | "features": np.concatenate( 205 | (franka_state.pos, franka_state.quat, [franka_state.gripper]) 206 | ), 207 | "proprioceptive": np.concatenate( 208 | (franka_state.pos, franka_state.quat, [franka_state.gripper]) 209 | ), 210 | } 211 | if self.sensor_type == "reskin": 212 | try: 213 | reskin_state = self._get_reskin_state(update_baseline=True) 214 | obs.update(reskin_state) 215 | except KeyError: 216 | pass 217 | 218 | obs.update(image_dict) 219 | # for i, image in enumerate(image_list): 220 | # obs[f"pixels{i}"] = cv2.resize(image, (self.width, self.height)) 221 | print("returning obs") 222 | return obs 223 | 224 | def _get_reskin_state(self, update_baseline=False): 225 | sensor_state = self.sensor_subscriber.get_sensor_state() 226 | sensor_values = np.array(sensor_state["sensor_values"], dtype=np.float32) 227 | if update_baseline: 228 | baseline_meas = [] 229 | while len(baseline_meas) < 5: 230 | sensor_state = self.sensor_subscriber.get_sensor_state() 231 | sensor_values = np.array( 232 | sensor_state["sensor_values"], dtype=np.float32 233 | ) 234 | baseline_meas.append(sensor_values) 235 | self.sensor_baseline = np.mean(baseline_meas, axis=0) 236 | if self.subtract_sensor_baseline: 237 | self.sensor_prev_state = sensor_values - self.sensor_baseline 238 | else: 239 | self.sensor_prev_state = 
sensor_values
240 |         if self.subtract_sensor_baseline:
241 |             sensor_values = sensor_values - self.sensor_baseline
242 |
243 |         sensor_diff = sensor_values - self.sensor_prev_state
244 |         self.sensor_prev_state = sensor_values
245 |         sensor_keys = [f"sensor{sensor_idx}" for sensor_idx in range(self.n_sensors)]
246 |         reskin_state = {}
247 |         for sidx, sensor_key in enumerate(sensor_keys):
248 |             reskin_state[sensor_key] = sensor_values[
249 |                 sidx * self.sensor_dim : (sidx + 1) * self.sensor_dim
250 |             ]
251 |             reskin_state[f"{sensor_key}_diffs"] = sensor_diff[
252 |                 sidx * self.sensor_dim : (sidx + 1) * self.sensor_dim
253 |             ]
254 |         return reskin_state
255 |
256 |     def render(self, mode="rgb_array", width=640, height=480):
257 |         assert self.curr_images is not None, "Must call reset() before render()"
258 |         if mode == "rgb_array":
259 |             image_list = []
260 |             for im in self.curr_images:
261 |                 image_list.append(cv2.resize(im, (width, height)))
262 |
263 |             return np.concatenate(image_list, axis=1)
264 |         else:
265 |             raise NotImplementedError
266 |
267 |
268 | if __name__ == "__main__":
269 |     env = FrankaEnv()
270 |     images = []
271 |     obs = env.reset()
272 |
273 |     apply_deltas = False
274 |     if apply_deltas:  # NOTE: these 7-dim delta actions follow FrankaEnvRelative's convention; FrankaEnv.step expects 8-dim absolute (pos, quat, gripper) actions
275 |         delta_pos = 0.03
276 |         delta_angle = 0.05
277 |         for i in range(100):
278 |             obs, reward, terminated, truncated, _ = env.step([delta_pos, 0, 0, 0, 0, 0, GRIPPER_OPEN])
279 |             images.append(obs["pixels1"])  # observation keys are f"pixels{cam_id}"
280 |
281 |         for i in range(100):
282 |             obs, reward, terminated, truncated, _ = env.step([0, delta_pos, 0, 0, 0, 0, GRIPPER_OPEN])
283 |             images.append(obs["pixels1"])
284 |
285 |         for i in range(100):
286 |             obs, reward, terminated, truncated, _ = env.step([0, 0, delta_pos, 0, 0, 0, GRIPPER_OPEN])
287 |             images.append(obs["pixels1"])
288 |
289 |         for i in range(100):
290 |             obs, reward, terminated, truncated, _ = env.step([0, 0, 0, delta_angle, 0, 0, GRIPPER_OPEN])
291 |             images.append(obs["pixels1"])
292 |
293 |         for i in range(100):
294 |             obs, reward, terminated, truncated, _ = env.step([0, 0, 0, 0, delta_angle, 0, GRIPPER_OPEN])
295 |             images.append(obs["pixels1"])
296 |
297 |         for i in range(100):
298 |             obs, reward, terminated, truncated, _ = env.step([0, 0, 0, 0, 0, delta_angle, GRIPPER_OPEN])
299 |             images.append(obs["pixels1"])
300 |
301 |         np.save("images.npy", np.array(images))
302 |
--------------------------------------------------------------------------------
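Finally, a minimal sketch of driving FrankaEnv with an absolute (pos, quat, gripper) action. This assumes the robot control server and the camera streams are already running, since the env blocks on both the CONTROL_PORT request socket and the per-camera RGB subscribers; it also assumes a camera with cam_id 1 is publishing.

import numpy as np

from franka_env.envs.franka_env import FrankaEnv
from frankateach.constants import GRIPPER_OPEN

env = FrankaEnv(cam_ids=[1])
obs = env.reset()  # reset() returns only the observation, per this file

# Hold the current pose with the gripper open: an 8-dim absolute action.
state = env.get_state()
action = np.concatenate([state.pos, state.quat, [GRIPPER_OPEN]])
obs, reward, terminated, truncated, info = env.step(action)

print(obs["features"])       # (pos, quat, gripper) echoed back by the server
print(obs["pixels1"].shape)  # (480, 640, 3)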