├── .gitignore
├── peg_in_hole_visual_servoing
│   ├── __init__.py
│   ├── annotator.py
│   ├── crop.py
│   ├── servo.py
│   ├── servo_configure.py
│   ├── unet.py
│   └── utils.py
├── readme.md
├── ros
│   ├── .catkin_workspace
│   └── src
│       ├── CMakeLists.txt
│       └── peg_in_hole_visual_servoing_api
│           ├── CMakeLists.txt
│           ├── package.xml
│           └── srv
│               └── SetString.srv
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | ros/build 2 | ros/devel 3 | .idea 4 | **/__pycache__ 5 | *.egg-info 6 | -------------------------------------------------------------------------------- /peg_in_hole_visual_servoing/__init__.py: -------------------------------------------------------------------------------- 1 | from .servo import servo 2 | from .servo_configure import config_from_demonstration 3 | -------------------------------------------------------------------------------- /peg_in_hole_visual_servoing/annotator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | from threading import Event 4 | from typing import List, Union 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | from PIL import Image 11 | 12 | import rospy 13 | from std_msgs.msg import Float64MultiArray 14 | from sensor_msgs.msg import Image as ImageMsg 15 | from ros_numpy.image import numpy_to_image, image_to_numpy 16 | 17 | from .utils import draw_points 18 | from .unet import ResNetUNet 19 | 20 | size = 224 21 | imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--device', default='cuda:0') 25 | parser.add_argument('--model', default='models/synth-e1+-lr1e-3-wd1e-4-fn.pth') 26 | parser.add_argument('--max-crops', type=int, default=2) 27 | args = parser.parse_args() 28 | 29 | # TODO: potentially set max crops dynamically 30 | 31 | device = torch.device(args.device) 32 | model = ResNetUNet(2, pretrained=False).half().to(device) 33 | model.load_state_dict( 34 | torch.load(args.model, map_location=lambda s, l: s)['model'] 35 | ) 36 | 37 | 38 | def infer(model, img: Image) -> np.ndarray: 39 | img = torchvision.transforms.ToTensor()(img) 40 | img = torchvision.transforms.Normalize(*imagenet_stats)(img) 41 | with torch.no_grad(): 42 | result = model(img.half().unsqueeze(0).to(device)).squeeze(0) 43 | return result.detach().cpu().numpy() 44 | 45 | 46 | def main(): 47 | rospy.init_node('annotator') 48 | 49 | annotation_pubs = [] 50 | annotated_img_pubs = [] 51 | images = [None for _ in range(args.max_crops)] # type: List[Union[None, ImageMsg]] 52 | new_data_event = Event() 53 | 54 | def cb(img_msg: ImageMsg, img_idx: int): 55 | images[img_idx] = img_msg 56 | new_data_event.set() 57 | 58 | for i in range(args.max_crops): 59 | annotated_img_pubs.append( 60 | rospy.Publisher('/servo/crop_{}/annotated'.format(i), ImageMsg, queue_size=1) 61 | ) 62 | annotation_pubs.append( 63 | rospy.Publisher('/servo/crop_{}/points'.format(i), Float64MultiArray, queue_size=1) 64 | ) 65 | rospy.Subscriber('/servo/crop_{}'.format(i), ImageMsg, partial(cb, img_idx=i), queue_size=1, 66 | buff_size=size * size * 8 * 3 + 2 ** 8) 67 | 68 | while not rospy.is_shutdown(): 69 | new_data_event.wait(1) 70 | new_data_event.clear() 71 | for i, annotation_pub, annotated_img_pub in zip(range(args.max_crops), annotation_pubs, annotated_img_pubs): 72 | # sequential inference in the main thread, interlaced between image sources 73 | # TODO: could be optimized
by collating input from multiple queues 74 | # and feeding the batch to the model instead 75 | img_msg = images[i] 76 | if img_msg is None: 77 | continue 78 | images[i] = None 79 | timestamp = img_msg.header.stamp.to_sec() 80 | img = image_to_numpy(img_msg) 81 | hms = infer(model, Image.fromarray(img)) 82 | points = [] 83 | for hm in hms: 84 | points.append(np.unravel_index(np.argmax(hm), hm.shape)) 85 | points_flat = np.array(points)[:, ::-1].reshape(-1) 86 | points_flat = np.concatenate(([timestamp], points_flat)) 87 | annotation_pub.publish(Float64MultiArray(data=points_flat.astype(np.float64))) 88 | for p, c in zip(points, 'br'): 89 | draw_points(img, [p], c=c) 90 | annotated_img_pub.publish(numpy_to_image(img, 'rgb8')) 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /peg_in_hole_visual_servoing/crop.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from typing import List 4 | from functools import partial 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | import rospy 10 | import sensor_msgs.msg 11 | from ros_numpy.image import numpy_to_image, image_to_numpy 12 | 13 | import peg_in_hole_visual_servoing_api.srv 14 | from . import servo_configure 15 | 16 | 17 | def configure(servo_config): 18 | _configure = rospy.ServiceProxy('/servo/configure', peg_in_hole_visual_servoing_api.srv.SetString) 19 | _configure.wait_for_service() 20 | return _configure(json.dumps(servo_config)) 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--config', default='plastic') 26 | parser.add_argument('--crop-size', type=int, default=224) 27 | parser.add_argument('--crop-hole-scale', type=float, default=5.) 28 | 29 | args = parser.parse_args() 30 | crop_size = args.crop_size 31 | crop_hole_scale = args.crop_hole_scale 32 | 33 | rospy.init_node('servo_crop_node') 34 | subs = [] # type: List[rospy.Subscriber] 35 | pubs = [] # type: List[rospy.Publisher] 36 | topics = [] 37 | rect_maps = [] 38 | 39 | # Reusing pub-subs is a little messy. 40 | # I tried a cleaner approach that would always unregister and register pubsubs on configure, 41 | # but there was a delay of approx 2 seconds. 
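# For reference, the configure request handled below is a JSON string shaped like the servo config
# returned by servo_configure.config_from_demonstration; a rough, illustrative sketch:
# {
#   "crop_configs": [
#     {"image_topic": "/camera_a/color/image_raw",
#      "K": [[fx, 0, cx], [0, fy, cy], [0, 0, 1]],
#      "dist_coeffs": [...],
#      "hole_center": [u, v],      # px in the undistorted full-resolution image
#      "hole_size": 40,            # px, the longest line within the hole
#      "hole_normal": [du, dv]},   # image-plane direction from the hole towards the peg
#     ...
#   ]
#   # the full servo config also carries robots_q_init, robots_t_init, z_ests and diameter_est,
#   # but only 'crop_configs' is used here
# }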
42 | 43 | def _configure(msg: peg_in_hole_visual_servoing_api.srv.SetStringRequest): 44 | crop_configs = json.loads(msg.str)['crop_configs'] 45 | n = len(crop_configs) 46 | if n != len(subs): 47 | for sub in subs: 48 | sub.unregister() 49 | for pub in pubs: 50 | pub.unregister() 51 | subs.clear() 52 | pubs.clear() 53 | topics.clear() 54 | rect_maps.clear() 55 | 56 | for i in range(n): 57 | pubs.append(rospy.Publisher('/servo/crop_{}'.format(i), sensor_msgs.msg.Image, queue_size=1)) 58 | 59 | for i, crop_config in enumerate(crop_configs): 60 | image_topic, K, dist_coeffs, hole_center, hole_size, hole_normal = (crop_config[key] for key in ( 61 | 'image_topic', 'K', 'dist_coeffs', 'hole_center', 'hole_size', 'hole_normal' 62 | )) 63 | 64 | K, dist_coeffs = np.array(K), np.array(dist_coeffs) 65 | roi_transform = servo_configure.get_roi_transform(crop_config, crop_hole_scale, crop_size) 66 | roi_K = roi_transform @ K 67 | rect_map = cv2.initUndistortRectifyMap(K, dist_coeffs, np.eye(3), roi_K, 68 | (crop_size, crop_size), cv2.CV_32FC1) 69 | if i < len(rect_maps): 70 | rect_maps[i] = rect_map 71 | else: 72 | rect_maps.append(rect_map) 73 | 74 | if i >= len(subs) or topics[i] != image_topic: 75 | def cb(img_msg: sensor_msgs.msg.Image, i): 76 | img = image_to_numpy(img_msg) 77 | crop = cv2.remap(img, *rect_maps[i], cv2.INTER_LINEAR) 78 | crop_msg = numpy_to_image(crop, 'rgb8') 79 | crop_msg.header.stamp = img_msg.header.stamp 80 | pubs[i].publish(crop_msg) 81 | 82 | sub = rospy.Subscriber( 83 | image_topic, sensor_msgs.msg.Image, partial(cb, i=i), 84 | queue_size=1, buff_size=1920 * 1080 * 3 * 8 + 2 ** 16 85 | ) 86 | if i < len(subs): 87 | subs[i].unregister() 88 | subs[i] = sub 89 | topics[i] = image_topic 90 | else: 91 | subs.append(sub) 92 | topics.append(image_topic) 93 | 94 | return True 95 | 96 | rospy.Service('/servo/configure', peg_in_hole_visual_servoing_api.srv.SetString, _configure) 97 | rospy.spin() 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /peg_in_hole_visual_servoing/servo.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List, Sequence, Union 3 | from functools import partial 4 | 5 | import numpy as np 6 | from transform3d import Transform, SceneNode, SceneState 7 | from ur_control import Robot, DeviatingMotionError 8 | 9 | import rospy 10 | from std_msgs.msg import Float64MultiArray 11 | 12 | from .servo_configure import get_roi_transform 13 | from . import utils 14 | from . 
import crop 15 | 16 | 17 | def servo(peg_robot: Robot, peg_tcp_node: SceneNode, scene_state: SceneState, 18 | servo_config: dict, camera_nodes: Sequence[SceneNode], 19 | aux_robots: Sequence[Robot] = (), aux_tcp_nodes: Sequence[SceneNode] = (), 20 | insertion_direction_tcp=np.array((0, 0, 1)), 21 | err_tolerance_scale=0.05, timeout=5., 22 | max_travel: float = None, max_travel_scale=3., 23 | alpha_target=0.9, alpha_err=0.9): 24 | state = scene_state.copy() 25 | crop_configs = servo_config['crop_configs'] 26 | insertion_direction_tcp = np.asarray(insertion_direction_tcp) / np.linalg.norm(insertion_direction_tcp) 27 | assert len(aux_robots) == len(aux_tcp_nodes) 28 | tcp_nodes = (peg_tcp_node, *aux_tcp_nodes) 29 | robots = (peg_robot, *aux_robots) 30 | assert len(servo_config['robots_q_init']) == len(robots) 31 | assert len(aux_robots) 32 | n_cams = len(crop_configs) 33 | z_ests, diameter_est = servo_config['z_ests'], servo_config['diameter_est'] 34 | assert n_cams == len(z_ests) == len(camera_nodes) 35 | err_tolerance = diameter_est * err_tolerance_scale 36 | if max_travel is None: 37 | max_travel = diameter_est * max_travel_scale 38 | assert crop.configure(servo_config) 39 | 40 | crop_K_invs = [np.linalg.inv(get_roi_transform(crop_config) @ crop_config['K']) for crop_config in crop_configs] 41 | cams_points = [None for _ in range(n_cams)] # type: List[Union[None, (float, np.ndarray)]] 42 | new_data = [False] 43 | 44 | def handle_points(msg: Float64MultiArray, cam_idx: int): 45 | timestamp, points = msg.data[0], np.array(msg.data[1:]).reshape(2, 2) 46 | cams_points[cam_idx] = timestamp, points 47 | new_data[0] = True 48 | 49 | subs = [] 50 | for cam_idx in range(n_cams): 51 | subs.append(rospy.Subscriber( 52 | '/servo/crop_{}/points'.format(cam_idx), Float64MultiArray, 53 | partial(handle_points, cam_idx=cam_idx), queue_size=1 54 | )) 55 | 56 | scene_configs_times = [] # type: List[float] 57 | scene_configs = [] # type: List[Sequence[Transform]] 58 | 59 | def add_current_scene_config(): 60 | scene_configs.append([r.base_t_tcp() for r in robots]) 61 | scene_configs_times.append(time.time()) 62 | 63 | def update_scene_state(timestamp): 64 | transforms = scene_configs[utils.bisect_closest(scene_configs_times, timestamp)] 65 | for tcp_node, transform in zip(tcp_nodes, transforms): 66 | state[tcp_node] = transform 67 | 68 | add_current_scene_config() 69 | update_scene_state(0) 70 | 71 | peg_tcp_init_node = SceneNode(parent=peg_tcp_node.parent) 72 | state[peg_tcp_init_node] = scene_configs[-1][0] 73 | peg_tcp_cur_node = SceneNode(parent=peg_tcp_node.parent) 74 | 75 | base_p_tcp_rolling = state[peg_tcp_init_node].p 76 | err_rolling, err_size_rolling = None, err_tolerance * 10 77 | start = time.time() 78 | try: 79 | while err_size_rolling > err_tolerance: 80 | loop_start = time.time() 81 | add_current_scene_config() 82 | if new_data[0]: 83 | new_data[0] = False 84 | state[peg_tcp_cur_node] = scene_configs[-1][0] 85 | 86 | peg_tcp_init_t_peg_tcp = peg_tcp_init_node.t(peg_tcp_node, state) 87 | if np.linalg.norm(peg_tcp_init_t_peg_tcp.p) > max_travel: 88 | raise DeviatingMotionError() 89 | if rospy.is_shutdown(): 90 | raise RuntimeError() 91 | if loop_start - start > timeout: 92 | raise TimeoutError() 93 | 94 | # TODO: check for age of points 95 | # TODO: handle no camera inputs (raise appropriate error) 96 | move_dirs = [] 97 | move_errs = [] 98 | for cam_points, K_inv, cam_node, z_est in zip(cams_points, crop_K_invs, camera_nodes, z_ests): 99 | if cam_points is None: 100 | continue 101 | timestamp, 
cam_points = cam_points 102 | update_scene_state(timestamp) 103 | peg_tcp_t_cam = peg_tcp_node.t(cam_node, state) 104 | 105 | pts_peg_tcp = [] 106 | for p_img in cam_points: 107 | p_cam = K_inv @ (*p_img, 1) 108 | p_cam *= z_est / p_cam[2] 109 | pts_peg_tcp.append(peg_tcp_t_cam @ p_cam) 110 | hole_peg_tcp, peg_peg_tcp = pts_peg_tcp 111 | 112 | view_dir = (peg_peg_tcp + hole_peg_tcp) / 2 - peg_tcp_t_cam.p 113 | move_dir = np.cross(view_dir, insertion_direction_tcp) 114 | move_dir /= np.linalg.norm(move_dir) 115 | 116 | already_moved = state[peg_tcp_init_node].inv.rotate( 117 | state[peg_tcp_cur_node].p - state[peg_tcp_node].p 118 | ) 119 | move_err = np.dot(move_dir, (hole_peg_tcp - peg_peg_tcp) - already_moved) 120 | 121 | move_dirs.append(move_dir) 122 | move_errs.append(move_err) 123 | 124 | move_dirs = np.array(move_dirs) 125 | move_errs = np.array(move_errs) 126 | if len(move_dirs) > 0: 127 | err_tcp, *_ = np.linalg.lstsq(move_dirs, move_errs, rcond=None) 128 | if err_rolling is None: 129 | err_rolling = err_tcp 130 | err_rolling = alpha_err * err_rolling + (1 - alpha_err) * err_tcp 131 | err_size_rolling = alpha_err * err_size_rolling + (1 - alpha_err) * np.linalg.norm(err_rolling) 132 | base_t_tcp_target = state[peg_tcp_cur_node] @ Transform(p=err_tcp) 133 | base_p_tcp_rolling = alpha_target * base_p_tcp_rolling + (1 - alpha_target) * base_t_tcp_target.p 134 | peg_robot.ctrl.servoL( 135 | (*base_p_tcp_rolling, *state[peg_tcp_init_node].rotvec), 136 | 0.5, 0.25, peg_robot.dt, 0.2, 300 137 | ) 138 | loop_duration = time.time() - loop_start 139 | time.sleep(max(0., peg_robot.dt - loop_duration)) 140 | finally: 141 | peg_robot.ctrl.servoStop() 142 | for sub in subs: 143 | sub.unregister() 144 | -------------------------------------------------------------------------------- /peg_in_hole_visual_servoing/servo_configure.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | import cv2 5 | from transform3d import SceneNode, SceneState 6 | from ur_control import Robot 7 | 8 | import rospy 9 | import sensor_msgs.msg 10 | from ros_numpy.image import image_to_numpy 11 | 12 | from . 
import utils 13 | 14 | 15 | def get_z_ests(servo_config, cam_nodes: Sequence[SceneNode], state: SceneState, resolution=224): 16 | crop_Ks = [get_roi_transform(crop_config) @ crop_config['K'] for crop_config in servo_config['crop_configs']] 17 | 18 | frame_node = cam_nodes[0] 19 | line_points = [] 20 | line_directions = [] 21 | 22 | for K, cam_node in zip(crop_Ks, cam_nodes): 23 | K_inv = np.linalg.inv(K) 24 | dir_cam = K_inv @ (resolution / 2, resolution / 3, 1) 25 | frame_t_cam = frame_node.t(cam_node, state) 26 | line_directions.append(frame_t_cam.rotate(dir_cam)) 27 | line_points.append(frame_t_cam.p) 28 | 29 | p = utils.closest_point_to_lines(np.array(line_points), np.array(line_directions)) 30 | z_ests = [(cam_node.t(frame_node, state) @ p)[2] for cam_node in cam_nodes] 31 | return z_ests 32 | 33 | 34 | def get_diameter_est(servo_config, zs): 35 | diameter_ests = [] 36 | for crop_config, z in zip(servo_config['crop_configs'], zs): 37 | hole_size_px = crop_config['hole_size'] 38 | K = np.array(crop_config['K']) 39 | f = np.sqrt(np.linalg.det(K[:2, :2])) 40 | diameter_ests.append(hole_size_px * z / f) 41 | return np.mean(diameter_ests) 42 | 43 | 44 | def get_roi_transform(crop_config, crop_hole_scale=5., crop_size=224): 45 | hole_center, hole_size, hole_normal = (crop_config[key] for key in ('hole_center', 'hole_size', 'hole_normal')) 46 | angle = -np.arctan2(hole_normal[1], hole_normal[0]) + np.pi / 2 47 | S, C = np.sin(angle), np.cos(angle) 48 | M = np.eye(3) 49 | M[:2, :2] = np.array(((C, -S), (S, C))) 50 | size = hole_size * crop_hole_scale 51 | M[:2, 2] = (M[:2, :2] @ -np.array(hole_center)) + (size / 2, size / 3) 52 | M[:2] *= crop_size / size 53 | return M 54 | 55 | 56 | # TODO: from known peg-tcp position 57 | 58 | def config_from_demonstration( 59 | peg_robot: Robot, aux_robots: Sequence[Robot], 60 | peg_tcp_node: SceneNode, aux_tcp_nodes: Sequence[SceneNode], 61 | scene_state: SceneState, 62 | image_topics: Sequence[str], camera_nodes: Sequence[SceneNode], 63 | Ks: Sequence[np.ndarray], dist_coeffs: Sequence[np.ndarray], 64 | diameter_est: float = None, 65 | ): 66 | n_cams = len(camera_nodes) 67 | assert n_cams == len(Ks) == len(dist_coeffs) == len(image_topics) 68 | robots = (peg_robot, *aux_robots) 69 | tcp_nodes = (peg_tcp_node, *aux_tcp_nodes) 70 | assert len(robots) == len(tcp_nodes) 71 | for r in robots: 72 | r.ctrl.teachMode() 73 | input('move robots into start position for visual servoing and press enter') 74 | for r in robots: 75 | r.ctrl.endTeachMode() 76 | 77 | robots_q_init = [r.recv.getActualQ() for r in robots] 78 | robots_t_init = [r.base_t_tcp() for r in robots] 79 | state = scene_state.copy() 80 | 81 | for base_t_tcp, tcp_node in zip(robots_t_init, tcp_nodes): 82 | state[tcp_node] = base_t_tcp 83 | 84 | crop_configs = [] 85 | for image_topic, K, dist_coeff in zip(image_topics, Ks, dist_coeffs): 86 | K, dist_coeff = np.array(K), np.array(dist_coeff) 87 | rect_maps = cv2.initUndistortRectifyMap(K, dist_coeff, np.eye(3), K, (1920, 1080), cv2.CV_32FC1) 88 | img = rospy.wait_for_message(image_topic, sensor_msgs.msg.Image, timeout=3) 89 | img = image_to_numpy(img) 90 | img = cv2.remap(img, *rect_maps, cv2.INTER_LINEAR) 91 | hole_points = utils.gui_select_vector('mark hole (longest line within the hole)', lambda: img) 92 | hole_center = np.mean(hole_points, axis=0).round().astype(int) 93 | hole_size = int(np.round(np.linalg.norm(hole_points[1] - hole_points[0]))) 94 | hole_points = utils.gui_select_vector('draw vector from hole towards peg along insertion direction', 
95 | lambda: img, arrow=True, first_point=hole_center) 96 | hole_normal = hole_points[1] - hole_points[0] 97 | crop_configs.append({ 98 | 'image_topic': image_topic, 99 | 'K': K.tolist(), 100 | 'dist_coeffs': dist_coeff.tolist(), 101 | 'hole_center': hole_center.tolist(), 102 | 'hole_size': hole_size, 103 | 'hole_normal': hole_normal.tolist() 104 | }) 105 | 106 | servo_config = { 107 | 'robots_q_init': [list(q) for q in robots_q_init], 108 | 'robots_t_init': [list(t.xyz_rotvec) for t in robots_t_init], 109 | 'crop_configs': crop_configs, 110 | } 111 | 112 | z_ests = get_z_ests(servo_config, camera_nodes, state) 113 | diameter_est = get_diameter_est(servo_config, z_ests) if diameter_est is None else diameter_est 114 | servo_config['z_ests'] = z_ests 115 | servo_config['diameter_est'] = diameter_est 116 | 117 | return servo_config 118 | -------------------------------------------------------------------------------- /peg_in_hole_visual_servoing/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/usuyama/pytorch-unet 3 | 4 | MIT License 5 | 6 | Copyright (c) 2018 Naoto Usuyama 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 
25 | """ 26 | 27 | import torch 28 | import torch.nn as nn 29 | from torchvision import models 30 | 31 | 32 | def convrelu(in_channels, out_channels, kernel, padding): 33 | return nn.Sequential( 34 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 35 | nn.ReLU(inplace=True), 36 | ) 37 | 38 | 39 | def projection(c, n, k): 40 | return nn.Sequential( 41 | *[convrelu(c, c, k, k // 2) for _ in range(n)] 42 | ) 43 | 44 | 45 | class ResNetUNet(nn.Module): 46 | def __init__(self, n_class, pretrained=True, freeze=False, projection_n=1, projection_k=1): 47 | super().__init__() 48 | 49 | self.base_model = models.resnet18(pretrained=pretrained) 50 | if freeze: 51 | for child in list(self.base_model.children())[:6]: 52 | for p in child.parameters(): 53 | p.requires_grad = False 54 | self.base_layers = list(self.base_model.children()) 55 | 56 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) 57 | self.layer0_1x1 = projection(64, projection_n, projection_k) 58 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) 59 | self.layer1_1x1 = projection(64, projection_n, projection_k) 60 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) 61 | self.layer2_1x1 = projection(128, projection_n, projection_k) 62 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) 63 | self.layer3_1x1 = projection(256, projection_n, projection_k) 64 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) 65 | self.layer4_1x1 = projection(512, projection_n, projection_k) 66 | 67 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 68 | 69 | self.conv_up3 = convrelu(256 + 512, 512, 3, 1) 70 | self.conv_up2 = convrelu(128 + 512, 256, 3, 1) 71 | self.conv_up1 = convrelu(64 + 256, 256, 3, 1) 72 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1) 73 | 74 | self.conv_original_size0 = convrelu(3, 64, 3, 1) 75 | self.conv_original_size1 = convrelu(64, 64, 3, 1) 76 | self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1) 77 | 78 | self.conv_last = nn.Conv2d(64, n_class, 1) 79 | 80 | def forward(self, input): 81 | x_original = self.conv_original_size0(input) 82 | x_original = self.conv_original_size1(x_original) 83 | 84 | layer0 = self.layer0(input) 85 | layer1 = self.layer1(layer0) 86 | layer2 = self.layer2(layer1) 87 | layer3 = self.layer3(layer2) 88 | layer4 = self.layer4(layer3) 89 | 90 | layer4 = self.layer4_1x1(layer4) 91 | x = self.upsample(layer4) 92 | layer3 = self.layer3_1x1(layer3) 93 | x = torch.cat([x, layer3], dim=1) 94 | x = self.conv_up3(x) 95 | 96 | x = self.upsample(x) 97 | layer2 = self.layer2_1x1(layer2) 98 | x = torch.cat([x, layer2], dim=1) 99 | x = self.conv_up2(x) 100 | 101 | x = self.upsample(x) 102 | layer1 = self.layer1_1x1(layer1) 103 | x = torch.cat([x, layer1], dim=1) 104 | x = self.conv_up1(x) 105 | 106 | x = self.upsample(x) 107 | layer0 = self.layer0_1x1(layer0) 108 | x = torch.cat([x, layer0], dim=1) 109 | x = self.conv_up0(x) 110 | 111 | x = self.upsample(x) 112 | x = torch.cat([x, x_original], dim=1) 113 | x = self.conv_original_size2(x) 114 | 115 | out = self.conv_last(x) 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /peg_in_hole_visual_servoing/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | import bisect 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | COLORS = { 8 | 'r': (255, 0, 0, 255), 9 | 'g': (0, 255, 0, 255), 
10 | 'b': (0, 0, 255, 255), 11 | 'k': (0, 0, 0, 255), 12 | 'w': (255, 255, 255, 255), 13 | } 14 | 15 | 16 | def draw_points(img, points, c: Union[str, tuple] = 'r'): 17 | if isinstance(c, str): 18 | c = COLORS[c] 19 | for i, p in enumerate(points): 20 | cv2.drawMarker(img, tuple(p[::-1]), c, cv2.MARKER_TILTED_CROSS, 10, 1, cv2.LINE_AA) 21 | 22 | 23 | def gui_select_vector(window_title: str, grab_frame_cb: Callable[[], np.ndarray], first_point=None, 24 | roi=None, arrow=False, destroy_window_on_end=True) -> np.ndarray: 25 | points = [] if first_point is None else [tuple(first_point)] 26 | mouse_pos = [None] 27 | 28 | def mouse_callback(event, x, y, flags, param): 29 | if event == cv2.EVENT_LBUTTONDOWN: 30 | points.append((x, y)) 31 | elif event == cv2.EVENT_MOUSEMOVE and len(points) == 1: 32 | mouse_pos[0] = (x, y) 33 | elif event == cv2.EVENT_LBUTTONUP: 34 | points.append((x, y)) 35 | 36 | cv2.namedWindow(window_title) 37 | cv2.setMouseCallback(window_title, mouse_callback) 38 | left, upper, right, lower = roi or (0, 0, None, None) 39 | while True: 40 | cv2.waitKey(16) 41 | img = grab_frame_cb()[upper:lower, left:right].copy() 42 | if mouse_pos[0]: 43 | cv2.arrowedLine(img, points[0], mouse_pos[0], (255, 0, 0), 2, tipLength=.2 if arrow else 0.) 44 | cv2.imshow(window_title, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 45 | if len(points) > 1: 46 | break 47 | if destroy_window_on_end: 48 | cv2.destroyWindow(window_title) 49 | return np.array(points) + (left, upper) 50 | 51 | 52 | def bisect_closest(a, x): 53 | idx = bisect.bisect(a, x) 54 | if idx == len(a): 55 | return idx - 1 56 | if idx > 0: 57 | if x - a[idx - 1] < a[idx] - x: 58 | idx = idx - 1 59 | return idx 60 | 61 | 62 | def closest_point_to_lines(line_points: np.ndarray, line_directions: np.ndarray, 63 | direction_is_unit=False): 64 | assert line_points.shape == line_directions.shape 65 | *pre_shape, N, d = line_points.shape 66 | 67 | if not direction_is_unit: 68 | line_directions = line_directions / np.linalg.norm(line_directions, axis=-1, keepdims=True) 69 | 70 | A = np.empty((*pre_shape, d, d)) 71 | for k in range(d): 72 | uk = np.zeros(d) 73 | uk[k] = 1. 
74 | A[..., k, :] = (line_directions[..., k:k + 1] * line_directions - uk).sum(axis=-2) 75 | b = line_directions * np.sum(line_points * line_directions, axis=-1, keepdims=True) - line_points 76 | b = b.sum(axis=-2) 77 | return np.linalg.solve(A, b) 78 | 79 | 80 | def _gui_selector_test(): 81 | print(gui_select_vector( 82 | 'hey', 83 | lambda: np.zeros((500, 500, 3), dtype=np.uint8), 84 | roi=(300, 0, 500, 200), 85 | arrow=True, 86 | )) 87 | 88 | 89 | def _bisect_closest_test(): 90 | for x, idx in [(-0.1, 0), (0.49, 0), (1.51, 2), (10, 2)]: 91 | idx_ = bisect_closest([0, 1, 2], x) 92 | assert idx_ == idx, '{}, expected {}, but got {}'.format(x, idx, idx_) 93 | 94 | 95 | def _closest_point_to_lines_test(): 96 | import matplotlib.pyplot as plt 97 | colors = 'rgbcmyk'[:3] 98 | n, N, d = len(colors), 3, 2 99 | line_points, line_directions = np.random.uniform(-1, 1, (2, n, N, d)) 100 | 101 | pts = closest_point_to_lines(line_points, line_directions) 102 | 103 | for line_points_, line_directions_, p, c in zip(line_points, line_directions, pts, colors): 104 | for line_point, line_direction in zip(line_points_, line_directions_): 105 | x0, y0 = line_point[:2] 106 | dx, dy = line_direction[:2] * 100 107 | plt.plot([x0 - dx, x0 + dx], [y0 - dy, y0 + dy], c=c) 108 | 109 | plt.scatter([p[0]], [p[1]], c=c, zorder=3) 110 | plt.xlim(-2, 2) 111 | plt.ylim(-2, 2) 112 | plt.gca().set_aspect(1) 113 | plt.show() 114 | 115 | 116 | if __name__ == '__main__': 117 | _gui_selector_test() 118 | _bisect_closest_test() 119 | _closest_point_to_lines_test() 120 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # peg-in-hole-visual-servoing 2 | 3 | Visual servoing for peg-in-hole. 4 | 5 | The servoing consists of three nodes. 6 | 1) A crop node that captures and crops images from ros image topics. 7 | 2) An annotator node that processes the cropped images. 8 | 3) A client node that configures the crop node, subscribes to the annotations, and controls the robot(s). 9 | 10 | The images can be captured and cropped by one computer, 11 | processed by a second computer (with a GPU) and 12 | the robot control can happen on a third computer. 13 | If the computer that captures images is not the same as the computer that controls the robots, 14 | then make sure the computers' clocks are synchronized, e.g. with [chrony](https://chrony.tuxfamily.org/). 15 | 16 | Requires a trained model, see 17 | [peg-in-hole-visual-servoing-model](https://github.com/RasmusHaugaard/peg-in-hole-visual-servoing-model). 18 | 19 | 20 | #### install 21 | install opencv 22 | ``$ pip3 install opencv-python`` 23 | 24 | On the annotator node, *torch* and *torchvision* should also be installed. 25 | 26 | install visual servoing module 27 | ``$ pip3 install -e .`` 28 | 29 | ROS is used for communication between the crop and client nodes.
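Concretely, the client configures the crop node through the `/servo/configure` *SetString* service (a JSON-encoded servo config), the crop node publishes the cropped images on `/servo/crop_{i}`, and the annotator publishes the detected hole and peg pixel positions on `/servo/crop_{i}/points` as a `Float64MultiArray` laid out as `[timestamp, hole_x, hole_y, peg_x, peg_y]`, plus a debug overlay on `/servo/crop_{i}/annotated` (see `crop.py`, `annotator.py` and `servo.py`). Build and source the workspace to make the service definition available: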
30 | ``$ catkin_make --directory ros`` 31 | ``$ source ros/devel/setup.bash`` 32 | 33 | #### on the computer that is connected to the cameras 34 | ``python3 -m peg_in_hole_visual_servoing.crop`` 35 | 36 | #### on a computer with a GPU 37 | ``python3 -m peg_in_hole_visual_servoing.annotator --model [model path]`` 38 | 39 | #### on the computer connected to the robots 40 | ```python 41 | import json 42 | from ur_control import Robot 43 | import peg_in_hole_visual_servoing 44 | from transform3d import SceneNode, SceneState 45 | 46 | peg_robot = Robot.from_ip('192.168.1.123') 47 | aux_robots = [Robot.from_ip('192.168.1.124')] 48 | image_topics = '/camera_a/color/image_raw', '/camera_b/color/image_raw' 49 | 50 | # build the scene structure 51 | peg_robot_tcp, cams_robot_tcp, peg_robot_base, cams_robot_base, \ 52 | cam_a, cam_b, table = SceneNode.n(7) 53 | table.adopt( 54 | peg_robot_base.adopt(peg_robot_tcp), 55 | cams_robot_base.adopt(cams_robot_tcp.adopt(cam_a, cam_b)) 56 | ) 57 | # insert necessary transforms from calibrations 58 | state = SceneState() 59 | state[peg_robot_base] = get_table_peg_base_calibration() 60 | state[cams_robot_base] = get_table_cams_base_calibration() 61 | state[cam_a] = get_tcp_cam_a_calibration() 62 | state[cam_b] = get_tcp_cam_b_calibration() 63 | 64 | Ks, dist_coeffs = get_camera_intrinsic_calibrations() 65 | 66 | ### Once, create a servo configuration: 67 | # config_from_demonstration will let you move the robots in place for insertion 68 | # and mark the holes in the images. 69 | config = peg_in_hole_visual_servoing.config_from_demonstration( 70 | peg_robot=peg_robot, aux_robots=aux_robots, 71 | peg_tcp_node=peg_robot_tcp, aux_tcp_nodes=[cams_robot_tcp], 72 | scene_state=state, 73 | image_topics=image_topics, camera_nodes=[cam_a, cam_b], 74 | Ks=Ks, dist_coeffs=dist_coeffs 75 | ) 76 | 77 | # the configuration is json serializable 78 | json.dump(config, open('servo_config.json', 'w')) 79 | 80 | 81 | ### When servoing is needed 82 | config = json.load(open('servo_config.json')) 83 | peg_in_hole_visual_servoing.servo( 84 | peg_robot=peg_robot, aux_robots=aux_robots, 85 | peg_tcp_node=peg_robot_tcp, aux_tcp_nodes=[cams_robot_tcp], 86 | scene_state=state, camera_nodes=[cam_a, cam_b], 87 | servo_config=config, insertion_direction_tcp=(0, 0, 1) 88 | ) 89 | ``` 90 | 91 | 92 | -------------------------------------------------------------------------------- /ros/.catkin_workspace: -------------------------------------------------------------------------------- 1 | # This file currently only serves to mark the location of a catkin workspace for tool integration 2 | -------------------------------------------------------------------------------- /ros/src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | /opt/ros/melodic/share/catkin/cmake/toplevel.cmake -------------------------------------------------------------------------------- /ros/src/peg_in_hole_visual_servoing_api/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(peg_in_hole_visual_servoing_api) 3 | 4 | find_package(catkin REQUIRED COMPONENTS genmsg message_generation std_msgs sensor_msgs) 5 | 6 | add_service_files( 7 | FILES 8 | SetString.srv 9 | ) 10 | 11 | generate_messages( 12 | DEPENDENCIES 13 | std_msgs 14 | ) 15 | 16 | catkin_package( 17 | CATKIN_DEPENDS message_runtime std_msgs sensor_msgs 18 | )
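For reference, the `SetString` service generated by this package is what the client uses to (re)configure the crop node (see `crop.configure`). A minimal sketch of calling it by hand, assuming the workspace is built and sourced, the crop node is running, and a servo config has been saved as in the readme example above:

```python
import json

import rospy
from peg_in_hole_visual_servoing_api.srv import SetString

rospy.init_node('servo_configure_client')

# load a previously saved servo config (see the readme example)
servo_config = json.load(open('servo_config.json'))

configure = rospy.ServiceProxy('/servo/configure', SetString)
configure.wait_for_service()

# the request has a single string field; the crop node expects JSON with a 'crop_configs' list
response = configure(json.dumps(servo_config))
print(response.success)
```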
-------------------------------------------------------------------------------- /ros/src/peg_in_hole_visual_servoing_api/package.xml: -------------------------------------------------------------------------------- 1 | <?xml version="1.0"?> 2 | <package> 3 | <name>peg_in_hole_visual_servoing_api</name> 4 | <version>0.0.0</version> 5 | <description>The peg_in_hole_visual_servoing_api package</description> 6 | <maintainer>Rasmus Laurvig Haugaard</maintainer> 7 | <license>TODO</license> 8 | <build_depend>message_generation</build_depend> 9 | <build_depend>std_msgs</build_depend> 10 | <build_depend>sensor_msgs</build_depend> 11 | <run_depend>message_runtime</run_depend> 12 | <run_depend>std_msgs</run_depend> 13 | <run_depend>sensor_msgs</run_depend> 14 | <buildtool_depend>catkin</buildtool_depend> 15 | </package> 16 | -------------------------------------------------------------------------------- /ros/src/peg_in_hole_visual_servoing_api/srv/SetString.srv: -------------------------------------------------------------------------------- 1 | # request 2 | string str 3 | --- 4 | # response 5 | bool success -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from setuptools import setup 3 | 4 | setup( 5 | name='peg_in_hole_visual_servoing', 6 | version='0.0.0', 7 | install_requires=[ 8 | 'numpy', 9 | 'pillow', 10 | 'typing', 11 | 'transform3d', 12 | 'ur_control' 13 | ] 14 | ) 15 | --------------------------------------------------------------------------------