├── .gitignore ├── 3d_eval.py ├── README.md ├── grid_script ├── mellow.py ├── modeltests.py ├── old_stack ├── crnn.py ├── imitate_msgs │ ├── msg │ │ └── GripperStamped.msg │ └── package.xml ├── models │ ├── crnn.py │ └── imitation_net.py ├── publishers │ ├── rightGripperPublisher.py │ ├── rightPosePublisher.py │ ├── rightVelPublisher.py │ └── toggletest.py ├── simulation │ └── gym-imitate │ │ ├── README.md │ │ ├── imitate_env.egg-info │ │ └── PKG-INFO │ │ ├── imitate_gym │ │ ├── __init__.py │ │ ├── buttons_data.py │ │ ├── crnn.py │ │ ├── envs │ │ │ ├── __init__.py │ │ │ ├── assets │ │ │ │ ├── buttons │ │ │ │ │ ├── buttons.xml │ │ │ │ │ ├── buttons_3x3.xml │ │ │ │ │ ├── robot.xml │ │ │ │ │ └── shared.xml │ │ │ │ ├── stls │ │ │ │ │ ├── .get │ │ │ │ │ ├── fetch │ │ │ │ │ │ ├── base_link_collision.stl │ │ │ │ │ │ ├── bellows_link_collision.stl │ │ │ │ │ │ ├── elbow_flex_link_collision.stl │ │ │ │ │ │ ├── estop_link.stl │ │ │ │ │ │ ├── forearm_roll_link_collision.stl │ │ │ │ │ │ ├── gripper_link.stl │ │ │ │ │ │ ├── head_pan_link_collision.stl │ │ │ │ │ │ ├── head_tilt_link_collision.stl │ │ │ │ │ │ ├── l_wheel_link_collision.stl │ │ │ │ │ │ ├── laser_link.stl │ │ │ │ │ │ ├── r_wheel_link_collision.stl │ │ │ │ │ │ ├── shoulder_lift_link_collision.stl │ │ │ │ │ │ ├── shoulder_pan_link_collision.stl │ │ │ │ │ │ ├── torso_fixed_link.stl │ │ │ │ │ │ ├── torso_lift_link_collision.stl │ │ │ │ │ │ ├── upperarm_roll_link_collision.stl │ │ │ │ │ │ ├── wrist_flex_link_collision.stl │ │ │ │ │ │ └── wrist_roll_link_collision.stl │ │ │ │ │ └── hand │ │ │ │ │ │ ├── F1.stl │ │ │ │ │ │ ├── F2.stl │ │ │ │ │ │ ├── F3.stl │ │ │ │ │ │ ├── TH1_z.stl │ │ │ │ │ │ ├── TH2_z.stl │ │ │ │ │ │ ├── TH3_z.stl │ │ │ │ │ │ ├── forearm_electric.stl │ │ │ │ │ │ ├── forearm_electric_cvx.stl │ │ │ │ │ │ ├── knuckle.stl │ │ │ │ │ │ ├── lfmetacarpal.stl │ │ │ │ │ │ ├── palm.stl │ │ │ │ │ │ └── wrist.stl │ │ │ │ └── textures │ │ │ │ │ ├── block.png │ │ │ │ │ └── block_hidden.png │ │ │ ├── buttons.py │ │ │ └── 
imitate_env.py │ │ └── gym_eval.py │ │ └── setup.py ├── test_baxter.py └── util │ ├── old │ ├── crnn.py │ ├── imitate_msgs │ │ ├── msg │ │ │ └── GripperStamped.msg │ │ └── package.xml │ ├── publishers │ │ ├── rightGripperPublisher.py │ │ ├── rightPosePublisher.py │ │ ├── rightVelPublisher.py │ │ └── toggletest.py │ ├── simulation │ │ ├── 2dsim.py │ │ └── gym-imitate │ │ │ ├── README.md │ │ │ ├── imitate_env.egg-info │ │ │ └── PKG-INFO │ │ │ ├── imitate_gym │ │ │ ├── __init__.py │ │ │ ├── buttons_data.py │ │ │ ├── crnn.py │ │ │ ├── envs │ │ │ │ ├── __init__.py │ │ │ │ ├── assets │ │ │ │ │ ├── buttons │ │ │ │ │ │ ├── buttons.xml │ │ │ │ │ │ ├── buttons_3x3.xml │ │ │ │ │ │ ├── robot.xml │ │ │ │ │ │ └── shared.xml │ │ │ │ │ ├── stls │ │ │ │ │ │ ├── .get │ │ │ │ │ │ ├── fetch │ │ │ │ │ │ │ ├── base_link_collision.stl │ │ │ │ │ │ │ ├── bellows_link_collision.stl │ │ │ │ │ │ │ ├── elbow_flex_link_collision.stl │ │ │ │ │ │ │ ├── estop_link.stl │ │ │ │ │ │ │ ├── forearm_roll_link_collision.stl │ │ │ │ │ │ │ ├── gripper_link.stl │ │ │ │ │ │ │ ├── head_pan_link_collision.stl │ │ │ │ │ │ │ ├── head_tilt_link_collision.stl │ │ │ │ │ │ │ ├── l_wheel_link_collision.stl │ │ │ │ │ │ │ ├── laser_link.stl │ │ │ │ │ │ │ ├── r_wheel_link_collision.stl │ │ │ │ │ │ │ ├── shoulder_lift_link_collision.stl │ │ │ │ │ │ │ ├── shoulder_pan_link_collision.stl │ │ │ │ │ │ │ ├── torso_fixed_link.stl │ │ │ │ │ │ │ ├── torso_lift_link_collision.stl │ │ │ │ │ │ │ ├── upperarm_roll_link_collision.stl │ │ │ │ │ │ │ ├── wrist_flex_link_collision.stl │ │ │ │ │ │ │ └── wrist_roll_link_collision.stl │ │ │ │ │ │ └── hand │ │ │ │ │ │ │ ├── F1.stl │ │ │ │ │ │ │ ├── F2.stl │ │ │ │ │ │ │ ├── F3.stl │ │ │ │ │ │ │ ├── TH1_z.stl │ │ │ │ │ │ │ ├── TH2_z.stl │ │ │ │ │ │ │ ├── TH3_z.stl │ │ │ │ │ │ │ ├── forearm_electric.stl │ │ │ │ │ │ │ ├── forearm_electric_cvx.stl │ │ │ │ │ │ │ ├── knuckle.stl │ │ │ │ │ │ │ ├── lfmetacarpal.stl │ │ │ │ │ │ │ ├── palm.stl │ │ │ │ │ │ │ └── wrist.stl │ │ │ │ │ └── textures │ │ │ │ 
│ │ ├── block.png │ │ │ │ │ │ └── block_hidden.png │ │ │ │ ├── buttons.py │ │ │ │ └── imitate_env.py │ │ │ └── gym_eval.py │ │ │ └── setup.py │ └── test_baxter.py │ ├── parse_data.py │ ├── record.py │ └── toggler.py ├── sim_eval.py ├── sim_eval0.py ├── simulation └── sim.py ├── src ├── datasets.py ├── loss_func.py ├── model.py └── model0.py ├── superval.py ├── superval_3dval.py ├── superval_results_compile.py ├── temprun.sh ├── train.py ├── util ├── parse_data.py └── plot_loss.py ├── venv_tool └── verify_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Miscellaneous 2 | *.swp 3 | 4 | # Compiled data 5 | *.csv 6 | *.lmdb 7 | *.lmdb-lock 8 | 9 | # Generated data 10 | *.png 11 | *.txt 12 | *.pt 13 | 14 | # Model checkpoints 15 | *.tar 16 | 17 | # Pycache 18 | *.pyc 19 | 20 | # Auto-generated grid files 21 | *.e1* 22 | *.o1* 23 | -------------------------------------------------------------------------------- /3d_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import cv2 3 | import time 4 | import csv 5 | import os 6 | import sys 7 | import rospy 8 | import itertools 9 | import numpy as np 10 | #from tf.transformations import euler_from_quaternion 11 | from cv_bridge import CvBridge 12 | from namedlist import namedlist 13 | from std_msgs.msg import Int64, String 14 | from sensor_msgs.msg import CompressedImage, Image, JointState 15 | from geometry_msgs.msg import Twist, Pose, TwistStamped, PoseStamped, Vector3 16 | import torch 17 | from model import Model 18 | import argparse 19 | import copy 20 | from matplotlib import pyplot as plt 21 | import signal 22 | 23 | class ImitateEval: 24 | def __init__(self, weights): 25 | self.bridge = CvBridge() 26 | self.Data = namedlist('Data', ['pose', 'rgb', 'depth']) 27 | self.data = self.Data(pose=None, rgb=None, depth=None) 28 | self.is_start = True 29 | 30 | checkpoint = torch.load(weights, 
map_location="cpu") 31 | self.model = Model(**checkpoint['kwargs']) 32 | self.model.load_state_dict(checkpoint["model_state_dict"]) 33 | self.model.eval() 34 | 35 | def change_start(self): 36 | radius = 0.07 37 | # Publisher for the movement and the starting pose 38 | self.movement_publisher = rospy.Publisher('/iiwa/CollisionAwareMotion', Pose, queue_size=10) 39 | self.target_start = Pose() 40 | self.target_start.position.x = -0.15 + np.random.rand()*2*radius - radius # -0.10757 41 | self.target_start.position.y = 0.455 + np.random.rand()*2*radius - radius # 0.4103 42 | self.target_start.position.z = 1.015 43 | self.target_start.orientation.x = 0.0 44 | self.target_start.orientation.y = 0.0 45 | self.target_start.orientation.z = 0.7071068 46 | self.target_start.orientation.w = 0.7071068 47 | 48 | 49 | def move_to_button(self, tau, tolerance): 50 | self.init_listeners() 51 | rospy.Rate(5).sleep() 52 | rate = rospy.Rate(15) 53 | 54 | pose_to_move = copy.deepcopy(self.target_start) 55 | eof = [] 56 | while not rospy.is_shutdown(): 57 | if None not in self.data: 58 | # Position from the CartesianPose Topic!! 
59 | pos = self.data.pose.position 60 | pos = [pos.x, pos.y, pos.z] 61 | if self.is_start: 62 | for _ in range(5): 63 | eof += pos 64 | self.is_start = False 65 | else: 66 | eof = pos + eof[:-3] 67 | 68 | eof_input = torch.from_numpy(np.array(eof)).type(torch.FloatTensor) 69 | eof_input = eof_input.unsqueeze(0)#.zero_() 70 | 71 | tau_input = torch.Tensor(tau).to(eof_input).view(1, -1) 72 | 73 | rgb = self.process_images(self.data.rgb, True) 74 | depth = self.process_images(self.data.depth, False) 75 | 76 | # print("RGB min: {}, RGB max: {}".format(np.amin(rgb), np.amax(rgb))) 77 | # print("Depth min: {}, Depth max: {}".format(np.amin(depth), np.amax(depth))) 78 | print("EOF: {}".format(eof_input)) 79 | print("Tau: {}".format(tau_input)) 80 | 81 | torch.save(rgb, "/home/amazon/Desktop/rgb_tensor.pt") 82 | torch.save(depth, "/home/amazon/Desktop/depth_tensor.pt") 83 | torch.save(eof, "/home/amazon/Desktop/eof_tensor.pt") 84 | torch.save(tau_input, "/home/amazon/Desktop/tau.pt") 85 | 86 | with torch.no_grad(): 87 | out, aux = self.model(rgb, depth, eof_input, tau_input) 88 | torch.save(out, "/home/amazon/Desktop/out.pt") 89 | torch.save(aux, "/home/amazon/Desktop/aux.pt") 90 | out = out.squeeze() 91 | x_cartesian = out[0].item() 92 | y_cartesian = out[1].item() 93 | z_cartesian = out[2].item() 94 | print("X:{}, Y:{}, Z:{}".format(x_cartesian, y_cartesian, z_cartesian)) 95 | print("Aux: {}".format(aux)) 96 | # This new pose is the previous pose + the deltas output by the net, adjusted for discrepancy in frame 97 | # It used to be: 98 | # pose_to_move.position.x += -y_cartesian 99 | # pose_to_move.position.y += x_cartesian 100 | # pose_to_move.position.z += z_cartesian 101 | 102 | pose_to_move.position.x -= y_cartesian 103 | pose_to_move.position.y += x_cartesian 104 | pose_to_move.position.z += z_cartesian 105 | #print(pose_to_move) 106 | 107 | # Publish to Kuka!!!! 
108 | for i in range(10): 109 | self.movement_publisher.publish(pose_to_move) 110 | rospy.Rate(10).sleep() 111 | 112 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String) 113 | # End publisher 114 | 115 | self.data = self.Data(pose=None,rgb=None,depth=None) 116 | rate.sleep() 117 | 118 | #print(pose_to_move.position.y, -pose_to_move.position.x, pose_to_move.position.z) 119 | #print(distance(tau, (pose_to_move.position.y, -pose_to_move.position.x, pose_to_move.position.z))) 120 | if distance(tau, (pose_to_move.position.y, -pose_to_move.position.x, pose_to_move.position.z)) < tolerance: 121 | break 122 | 123 | def process_images(self, img_msg, is_it_rgb): 124 | crop_right=586 125 | crop_lower=386 126 | img = self.bridge.compressed_imgmsg_to_cv2(img_msg, desired_encoding="passthrough") 127 | if(is_it_rgb): 128 | img = img[:,:,::-1] 129 | # Does this crop work? 130 | #rgb = img[0:386, 0:586] 131 | #rgb = img.crop((0, 0, crop_right, crop_lower)) 132 | rgb = cv2.resize(img, (160,120)) 133 | rgb = np.array(rgb).astype(np.float32) 134 | 135 | rgb = 2*((rgb - np.amin(rgb))/(np.amax(rgb)-np.amin(rgb)))-1 136 | 137 | 138 | rgb = torch.from_numpy(rgb).type(torch.FloatTensor) 139 | if is_it_rgb: 140 | rgb = rgb.view(1, rgb.shape[0], rgb.shape[1], rgb.shape[2]).permute(0, 3, 1, 2) 141 | else: 142 | rgb = rgb.view(1, 1, rgb.shape[0], rgb.shape[1]) 143 | #plt.imshow(rgb[0,0] / 2 + .5) 144 | # plt.show() 145 | 146 | return rgb 147 | 148 | def move_to_start(self): 149 | # Publish starting position to Kuka!!!! 
150 | for i in range(10): 151 | self.movement_publisher.publish(self.target_start) 152 | rospy.Rate(10).sleep() 153 | 154 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String) 155 | # End publisher 156 | 157 | def init_listeners(self): 158 | # The Topics we are Subscribing to for data 159 | self.right_arm_pose = rospy.Subscriber('/iiwa/state/CartesianPose', PoseStamped, self.pose_callback) 160 | self.rgb_state_sub = rospy.Subscriber('/camera3/camera/color/image_rect_color/compressed', CompressedImage, self.rgb_callback) 161 | self.depth_state_sub = rospy.Subscriber('/camera3/camera/depth/image_rect_raw/compressed', CompressedImage, self.depth_callback) 162 | 163 | def unsubscribe(self): 164 | self.right_arm_pose.unregister() 165 | self.rgb_state_sub.unregister() 166 | self.depth_state_sub.unregister() 167 | 168 | def pose_callback(self, pose): 169 | if None in self.data: 170 | self.data.pose = pose.pose 171 | 172 | def rgb_callback(self, rgb): 173 | if None in self.data: 174 | self.data.rgb = rgb 175 | 176 | def depth_callback(self, depth): 177 | if None in self.data: 178 | self.data.depth = depth 179 | 180 | 181 | def translate_tau(button): 182 | b_0 = int(button[0]) 183 | b_1 = int(button[1]) 184 | tau = [.558-.069*b_0, .22-.063*b_1, .22] 185 | return tau 186 | 187 | 188 | def distance(a, b): 189 | a = [a[0]-.035, a[1], .94] 190 | print(a) 191 | return np.sqrt(np.sum([np.abs(aa - bb) for aa, bb in zip(a,b)])) 192 | 193 | 194 | def get_tau(): 195 | button = input('Please enter a button for the robot to try to press (e.g. 
"0,0", "1,2"): ') 196 | return button + (0,) 197 | tau = translate_tau(button) 198 | return tau 199 | 200 | 201 | def main(weights, tolerance): 202 | agent = ImitateEval(weights) 203 | agent.change_start() 204 | agent.move_to_start() 205 | tau = get_tau() 206 | agent.move_to_button(tau, tolerance) 207 | 208 | 209 | def sighandler(signal, frame): 210 | raise Exception('op') 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = argparse.ArgumentParser(description="Arguments for evaluating imitation net") 215 | parser.add_argument('-w', '--weights', required=True, help='Filepath for model weightsself.') 216 | parser.add_argument('-t', '--tolerance', default=0, type=float, help='Tolerance for button presses.') 217 | args = parser.parse_args() 218 | rospy.init_node('eval_imitation', log_level=rospy.DEBUG) 219 | 220 | signal.signal(signal.SIGINT, sighandler) 221 | 222 | cont = False 223 | while True: 224 | try: 225 | main(args.weights, args.tolerance) 226 | except Exception: 227 | if cont: 228 | break 229 | cont = True 230 | continue 231 | cont = False 232 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Targetable Visuomotor Imitation Learning 2 | 3 | Code repository for the paper "Learning Deep Parameterized Skills for Re-Targetable Visuomotor Control" by the H2R Lab in collaboration with MERL. 
4 | 5 | Read the paper here: https://arxiv.org/abs/1910.10628 6 | -------------------------------------------------------------------------------- /grid_script: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /data/people/shastin1/venv/bin/activate 4 | python train.py "$@" 5 | -------------------------------------------------------------------------------- /mellow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | def mellowmax2(self ,beta): 5 | c = np.max(self.Q[0]) 6 | mm = c + np.log((1/self.a_size)*np.sum(np.exp(self.omega * (self.Q[0] - c ))))/self.omega 7 | b = 0 8 | for a in self.Q[0]: 9 | b+=np.exp(beta * (a-mm))*(a-mm) 10 | return b 11 | 12 | def mellowmax(x, dim=0, beta=6, omega=3): 13 | c = torch.max(x, dim=dim, keepdim=True)[0] 14 | mm = c + torch.log((1/x.size(dim))*torch.sum(torch.exp(omega * (x - c)))) / omega 15 | return torch.exp(beta * (x - mm)) * (x - mm) 16 | -------------------------------------------------------------------------------- /old_stack/imitate_msgs/msg/GripperStamped.msg: -------------------------------------------------------------------------------- 1 | std_msgs/Header header 2 | std_msgs/Int64 data 3 | -------------------------------------------------------------------------------- /old_stack/imitate_msgs/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | imitate_msgs 4 | 0.0.0 5 | The imitate_msgs package 6 | 7 | 8 | 9 | 10 | jonathanchang 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | catkin 52 | std_msgs 53 | std_msgs 54 | std_msgs 55 | message_generation 56 | message_runtime 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 
-------------------------------------------------------------------------------- /old_stack/publishers/rightGripperPublisher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import roslib 3 | import rospy 4 | import tf 5 | from tf.transformations import euler_from_quaternion 6 | import numpy as np 7 | import message_filters 8 | from std_msgs.msg import Int64, Header 9 | from sensor_msgs.msg import JointState 10 | from imitate_msgs.msg import GripperStamped 11 | 12 | 13 | is_open = None 14 | 15 | def callback(data): 16 | global is_open 17 | position = np.array(data.position) - 0.0016998404159380444 18 | if position.sum() < 0.05: 19 | is_open = 1 20 | else: 21 | is_open = 0 22 | 23 | # Publisher From Here 24 | rospy.init_node('movo_right_gripper') 25 | listener = rospy.Subscriber("/movo/right_gripper/joint_states", JointState, callback) 26 | 27 | # Give time for initialization 28 | rospy.Rate(1).sleep() 29 | 30 | robot_in_map = rospy.Publisher('/movo/right_gripper/gripper_is_open', GripperStamped, queue_size=1) 31 | 32 | rate = rospy.Rate(100) 33 | while not rospy.is_shutdown(): 34 | h = Header() 35 | h.stamp = rospy.Time.now() 36 | msg = GripperStamped() 37 | msg.header = h 38 | msg.data = Int64(is_open) 39 | robot_in_map.publish(msg) 40 | rate.sleep() 41 | -------------------------------------------------------------------------------- /old_stack/publishers/rightPosePublisher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import roslib 3 | import rospy 4 | import tf 5 | from tf.transformations import euler_from_quaternion 6 | from geometry_msgs.msg import PoseStamped, Pose, Point, Quaternion 7 | from std_msgs.msg import Header 8 | 9 | rospy.init_node('movo_right_tf_listener_pose') 10 | 11 | listener = tf.TransformListener() 12 | 13 | robot_in_map = rospy.Publisher('tf/right_arm_pose', PoseStamped, queue_size=1) 14 | 15 | rate = 
rospy.Rate(100) 16 | while not rospy.is_shutdown(): 17 | try: 18 | (trans, rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0)) 19 | h = Header() 20 | h.stamp = rospy.Time.now() 21 | msg = PoseStamped() 22 | msg.header = h 23 | msg.pose = Pose(Point(*trans), Quaternion(*rot)) 24 | 25 | except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException): 26 | continue 27 | # Publish the message 28 | robot_in_map.publish(msg) 29 | rate.sleep() 30 | -------------------------------------------------------------------------------- /old_stack/publishers/rightVelPublisher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import roslib 4 | import rospy 5 | import tf 6 | from tf.transformations import euler_from_quaternion 7 | from geometry_msgs.msg import Twist, TwistStamped, Vector3 8 | from std_msgs.msg import Header 9 | 10 | rospy.init_node('movo_right_tf_listener') 11 | 12 | listener = tf.TransformListener() 13 | 14 | robot_in_map = rospy.Publisher('tf/right_arm_vels', TwistStamped, queue_size=1) 15 | 16 | prev_time = rospy.get_time() 17 | curr_time = rospy.get_time() 18 | prev_rot, prev_trans = None, None 19 | 20 | rate = rospy.Rate(100) 21 | while not rospy.is_shutdown(): 22 | try: 23 | curr_time = rospy.get_time() 24 | (curr_trans, curr_rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0)) 25 | # First Iteration 26 | if not prev_rot and not prev_trans: 27 | prev_trans, prev_rot = curr_trans, curr_rot 28 | continue 29 | delta = float(curr_time)-float(prev_time) 30 | lin_vel = (np.array(curr_trans) - np.array(prev_trans))/delta 31 | ang_vel = (np.array(euler_from_quaternion(curr_rot)) - np.array(euler_from_quaternion(prev_rot)))/delta 32 | h = Header() 33 | h.stamp = rospy.Time.now() 34 | msg = TwistStamped() 35 | msg.header = h 36 | msg.twist = Twist(Vector3(*lin_vel), Vector3(*ang_vel)) 37 | 38 | 
# Update 39 | prev_time = curr_time 40 | prev_trans, prev_rot = curr_trans, curr_rot 41 | 42 | except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException): 43 | continue 44 | # Publish the message 45 | robot_in_map.publish(msg) 46 | rate.sleep() 47 | -------------------------------------------------------------------------------- /old_stack/publishers/toggletest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import roslib 3 | import rospy 4 | from std_msgs.msg import Int64 5 | 6 | # Publisher From Here 7 | rospy.init_node('toggle_test') 8 | 9 | robot_in_map = rospy.Publisher('/toggle', Int64, queue_size=1) 10 | 11 | rate = rospy.Rate(100) 12 | while not rospy.is_shutdown(): 13 | rate.sleep() -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/README.md -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_env.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: imitate-env 3 | Version: 0.0.1 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='buttons-v0', 5 | entry_point='imitate_gym.envs:ButtonsEnv', 6 | ) 7 | 
-------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/buttons_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import gym 4 | import argparse 5 | import imitate_gym 6 | import numpy as np 7 | from PIL import Image 8 | 9 | # GOAL_00 = np.array([1.2,0.66,0.45]) 10 | # GOAL_01 = np.array([1.2,0.76,0.45]) 11 | # GOAL_02 = np.array([1.2,0.86,0.45]) 12 | # GOAL_10 = np.array([1.3,0.66,0.45]) 13 | # GOAL_11 = np.array([1.3,0.76,0.45]) 14 | # GOAL_12 = np.array([1.3,0.86,0.45]) 15 | # GOAL_20 = np.array([1.4,0.66,0.45]) 16 | # GOAL_21 = np.array([1.4,0.76,0.45]) 17 | # GOAL_22 = np.array([1.4,0.86,0.45]) 18 | # GOALS = np.array([GOAL_00,GOAL_01,GOAL_02,GOAL_10,GOAL_11,GOAL_12,GOAL_20,GOAL_21,GOAL_22]) 19 | # NAMES = {0:'goal_00',1:'goal_01',2:'goal_02',3:'goal_10',4:'goal_11',5:'goal_12',6:'goal_20',7:'goal_21',8:'goal_22'} 20 | GOAL_00 = np.array([1.2,0.56,0.45]) 21 | GOAL_01 = np.array([1.2,0.66,0.45]) 22 | GOAL_02 = np.array([1.2,0.76,0.45]) 23 | GOAL_03 = np.array([1.2,0.86,0.45]) 24 | GOAL_04 = np.array([1.2,0.96,0.45]) 25 | GOAL_10 = np.array([1.3,0.56,0.45]) 26 | GOAL_11 = np.array([1.3,0.66,0.45]) 27 | GOAL_12 = np.array([1.3,0.76,0.45]) 28 | GOAL_13 = np.array([1.3,0.86,0.45]) 29 | GOAL_14 = np.array([1.3,0.96,0.45]) 30 | GOAL_20 = np.array([1.4,0.56,0.45]) 31 | GOAL_21 = np.array([1.4,0.66,0.45]) 32 | GOAL_22 = np.array([1.4,0.76,0.45]) 33 | GOAL_23 = np.array([1.4,0.86,0.45]) 34 | GOAL_24 = np.array([1.4,0.96,0.45]) 35 | GOALS = np.array([GOAL_00, 36 | GOAL_01, 37 | GOAL_02, 38 | GOAL_03, 39 | GOAL_04, 40 | GOAL_10, 41 | GOAL_11, 42 | GOAL_12, 43 | GOAL_13, 44 | GOAL_14, 45 | GOAL_20, 46 | GOAL_21, 47 | GOAL_22, 48 | GOAL_23, 49 | GOAL_24]) 50 | NAMES = {0:'goal_00', 51 | 1:'goal_01', 52 | 2:'goal_02', 53 | 3:'goal_03', 54 | 4:'goal_04', 55 | 5:'goal_10', 56 | 6:'goal_11', 57 | 7:'goal_12', 58 | 8:'goal_13', 59 | 
9:'goal_14', 60 | 10:'goal_20', 61 | 11:'goal_21', 62 | 12:'goal_22', 63 | 13:'goal_23', 64 | 14:'goal_24'} 65 | test=np.array([GOAL_24]) 66 | 67 | def calc_changes(start_pos, goal_pos, num_steps=50.): 68 | return (goal_pos-start_pos)/num_steps 69 | 70 | def next_move(curr_pos, goal_pos, slope): 71 | diffs = np.absolute(curr_pos - goal_pos) 72 | if diffs[0] <= 0.001: 73 | slope[0] = 0. 74 | if diffs[1] <= 0.001: 75 | slope[1] = 0. 76 | if diffs[2] <= 0.001: 77 | slope[2] = 0. 78 | return slope 79 | 80 | def run(env, goal, trial_dir, viz=False): 81 | if viz == False: 82 | with open(trial_dir+'vectors.txt', 'w') as f: 83 | writer = csv.writer(f) 84 | curr_pos = None 85 | slope = None 86 | goal_reached = False 87 | counter = 0 88 | traj_counter = 0 89 | other = np.array([0.,0.,0.,0.,0.]) 90 | while True: 91 | rgb = env.render('rgb_array') 92 | if curr_pos is None and goal_reached is False: 93 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 94 | curr_pos = obs['achieved_goal'] 95 | slope = calc_changes(curr_pos, goal) 96 | elif curr_pos is not None and goal_reached is False: 97 | slope = next_move(curr_pos, goal, slope) 98 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455: 99 | goal_reached = True 100 | action = np.concatenate([slope,other]) 101 | obs, _, _, _ = env.step(action) 102 | img_rgb = Image.fromarray(rgb, 'RGB') 103 | img_rgb = img_rgb.resize((160,120)) 104 | img_rgb.save(trial_dir+'{}.png'.format(traj_counter)) 105 | arr = [traj_counter] 106 | arr += [x for x in obs['observation']] 107 | arr += [0] 108 | curr_pos = obs['achieved_goal'] 109 | traj_counter += 1 110 | writer.writerow(arr) 111 | elif curr_pos is not None and goal_reached is True: 112 | # This is to simulate the terminal state where the user would stop near the end 113 | counter += 1 114 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 115 | img_rgb = Image.fromarray(rgb, 'RGB') 116 | img_rgb = img_rgb.resize((160,120)) 117 | 
img_rgb.save(trial_dir+'{}.png'.format(traj_counter)) 118 | arr = [traj_counter] 119 | arr += [x for x in obs['observation']] 120 | arr += [0] 121 | traj_counter += 1 122 | writer.writerow(arr) 123 | if counter == 20: 124 | break 125 | else: 126 | curr_pos = None 127 | slope = None 128 | goal_reached = False 129 | counter = 0 130 | other = np.array([0.,0.,0.,0.,0.]) 131 | while True: 132 | env.render() 133 | if curr_pos is None and goal_reached is False: 134 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 135 | curr_pos = obs['achieved_goal'] 136 | slope = calc_changes(curr_pos, goal) 137 | elif curr_pos is not None and goal_reached is False: 138 | slope = next_move(curr_pos, goal, slope) 139 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455: 140 | goal_reached = True 141 | action = np.concatenate([slope,other]) 142 | obs, _, _, _ = env.step(action) 143 | print("==================") 144 | print(slope) 145 | print(obs['observation'][7:10]) 146 | curr_pos = obs['achieved_goal'] 147 | print(curr_pos) 148 | elif curr_pos is not None and goal_reached is True: 149 | # This is to simulate the terminal state where the user would stop near the end 150 | counter += 1 151 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 152 | print("==================") 153 | print(slope) 154 | print(obs['observation'][7:10]) 155 | if counter == 20: 156 | break 157 | 158 | 159 | 160 | 161 | def main(task_name, num_trials=300, viz=False): 162 | if viz == False: 163 | if not os.path.exists('../../../datas/' + task_name + '/'): 164 | os.mkdir('../../../datas/' + task_name + '/') 165 | trial_counter = 0 166 | for trial in range(num_trials): 167 | env = gym.make('buttons-v0') 168 | for i, goal in enumerate(GOALS): 169 | save_folder = '../../../datas/'+task_name+'/'+NAMES[i]+'/' 170 | if viz == False: 171 | if not os.path.exists(save_folder): 172 | os.mkdir(save_folder) 173 | trial_dir = save_folder+str(trial_counter)+'/' 174 | if viz == False: 175 | os.mkdir(trial_dir) 176 | 
env.reset() 177 | run(env, goal, trial_dir, viz=viz) 178 | trial_counter+=1 179 | 180 | if __name__ == '__main__': 181 | main("buttons3x5", num_trials=1, viz=True) 182 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from imitate_gym.envs.imitate_env import ImitateEnv 2 | from imitate_gym.envs.buttons import ButtonsEnv 3 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons_3x3.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- 
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/robot.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/shared.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl 
-------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl 
-------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/buttons.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from gym import utils 4 | from imitate_gym.envs import imitate_env 5 | 6 | #MODEL_XML_PATH = os.path.join('fetch', 'pick_and_place.xml') 7 | HOME = '/home/jonathanchang/parameterized-imitation-learning/simulation/gym-imitate/imitate_gym/envs' 8 | MODEL_XML_PATH = HOME+'/assets/buttons/buttons_3x3.xml' 9 | class ButtonsEnv(imitate_env.ImitateEnv, utils.EzPickle): 10 | def __init__(self, reward_type='dense'): 11 | initial_qpos = { 12 | 'robot0:slide0': 0.405, 13 | 'robot0:slide1': 0.48, 14 | 'robot0:slide2': 0.0, 15 | 'robot0:shoulder_pan_joint': 0.0, 16 | 'robot0:shoulder_lift_joint': -0.8, 17 | 'robot0:elbow_flex_joint': 1.0 18 | } 19 | target = np.array([1.25, 0.53, 0.4]) 20 | imitate_env.ImitateEnv.__init__( 21 | self, model_path=MODEL_XML_PATH, n_substeps=25, initial_qpos=initial_qpos, reward_type=reward_type, 22 | distance_threshold=0.05, 
gripper_extra_height=0.2, target=target) 23 | utils.EzPickle.__init__(self) 24 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/envs/imitate_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import gym 4 | import mujoco_py 5 | from gym import error, spaces, utils 6 | from gym.envs.robotics import rotations, robot_env, utils 7 | 8 | # Just to make sure that input shape is correct 9 | def goal_distance(goal_a, goal_b): 10 | assert goal_a.shape == goal_b.shape 11 | return np.linalg.norm(goal_a - goal_b, axis=-1) 12 | 13 | class ImitateEnv(robot_env.RobotEnv): 14 | """ 15 | Inheritance Chain: Env -> GoalEnv -> RobotEnv -> ImitateEnv 16 | 17 | Methods from RobotEnv 18 | seed(), step(), reset(), close(), render(), _reset_sim(), 19 | _get_viewer(): This method will be important for retrieving rgb and depth from the sim 20 | """ 21 | def __init__( 22 | self, model_path, n_substeps, initial_qpos, reward_type, 23 | distance_threshold, target, gripper_extra_height 24 | ): 25 | """ 26 | Args: model_path (string): path to the environments XML file 27 | n_substeps (int): number of substeps the simulation runs on every call to step 28 | initial_qpos (dict): a dictionary of joint names and values that define the initial configuration 29 | reward_type ('sparse' or 'dense'): the reward type, i.e. 
sparse or dense 30 | distance_threshold (float): the threshold after which a goal is considered achieved 31 | target (list): the target position that we are aiming for 32 | gripper_extra_height (float): the gripper offset in position 33 | """ 34 | # n_actions = 8 for [pos_x, pos_y, pos_z, rot_x, rot_y, rot_z, rot_w, gripper] 35 | 36 | self.reward_type = reward_type 37 | self.distance_threshold = distance_threshold 38 | self.target = target 39 | self.gripper_extra_height = gripper_extra_height 40 | self._viewers = {} 41 | self.initial_qpos = initial_qpos 42 | super(ImitateEnv, self).__init__(model_path=model_path, 43 | n_substeps=n_substeps, 44 | n_actions=8, 45 | initial_qpos=initial_qpos) 46 | # Env Method 47 | # -------------------------------------------------- 48 | def render(self, mode='human'): 49 | self._render_callback() 50 | if mode == 'rgb_array': 51 | self._get_viewer(mode).render(1000,1000) 52 | rgb = self._get_viewer(mode).read_pixels(1000, 1000, depth=False) 53 | #rgbd = self._get_viewer(mode).read_pixels(1000, 1000, depth=True) 54 | #rgb = rgbd[0][::-1,:,:] 55 | #depth = rgbd[1] 56 | #depth = (depth-np.amin(depth))*(255/(np.amax(depth)-np.amin(depth))) 57 | return rgb[::-1,:,:] 58 | elif mode == 'human': 59 | self._get_viewer(mode).render() 60 | 61 | def _get_viewer(self, mode): 62 | self.viewer = self._viewers.get(mode) 63 | if self.viewer is None: 64 | if mode == 'human': 65 | self.viewer = mujoco_py.MjViewer(self.sim) 66 | elif mode == 'rgb_array': 67 | self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, device_id=-1) 68 | self._viewer_setup() 69 | self._viewers[mode] = self.viewer 70 | return self.viewer 71 | # GoalEnv Method 72 | # -------------------------------------------------- 73 | def compute_reward(self, achieved_goal, goal, info): 74 | d = goal_distance(achieved_goal, goal) 75 | if self.reward_type == 'sparse': 76 | return -(d > self.distance_threshold).astype(np.float32) 77 | else: 78 | return -d 79 | 80 | # RobotEnv Methods 81 | 
# -------------------------------------------------- 82 | 83 | def _step_callback(self): 84 | """ 85 | In the step function in the RobotEnv it does: 1) self._set_action(action) 86 | 2) self.sim.step() 87 | 3) self._step_callback() 88 | This method can be used to provide additional constraints to the actions that happen every step. 89 | Could be used to fix the gripper to be a certain orientation for example 90 | """ 91 | pass 92 | 93 | def _set_action(self, action): 94 | """ 95 | Currently, I am just returning 1 number for the gripper control because there are 2 fingers in the 96 | robot that is used in the OpenAI gym library. If I start using movo, I would need 3 fingers so 97 | perhaps increase the gripper control from 2 to 3. 98 | """ 99 | assert action.shape == (8,) 100 | action = action.copy() # ensures that we don't change the action outside of this scope 101 | pos_ctrl, rot_ctrl, gripper_ctrl = action[:3], action[3:7], action[-1] 102 | 103 | # This is where I match the gripper control to number of fingers 104 | gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl]) 105 | assert gripper_ctrl.shape == (2,) 106 | 107 | # Create the action that we want to pass into the simulation 108 | action = np.concatenate([pos_ctrl, rot_ctrl, gripper_ctrl]) 109 | 110 | # Now apply the action to the simulation 111 | utils.ctrl_set_action(self.sim, action) # note self.sim is an inherited variable 112 | utils.mocap_set_action(self.sim, action) 113 | 114 | def _get_obs(self): 115 | """ 116 | This is where I make sure to grab the observations for the position and the velocities. Potential 117 | area to grab the rgb and depth images. 
118 | """ 119 | ee = 'robot0:grip' 120 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep 121 | ee_pos = self.sim.data.get_site_xpos(ee) 122 | ee_quat = rotations.mat2quat(self.sim.data.get_site_xmat(ee)) 123 | # Position and angular velocity respectively 124 | ee_velp = self.sim.data.get_site_xvelp(ee) * dt # to remove time dependency 125 | ee_velr = self.sim.data.get_site_xvelr(ee) * dt 126 | 127 | obs = np.concatenate([ee_pos, ee_quat, ee_velp, ee_velr]) 128 | 129 | return { 130 | 'observation': obs.copy(), 131 | 'achieved_goal': ee_pos.copy(), 132 | 'desired_goal': self.goal.copy() 133 | } 134 | 135 | def _viewer_setup(self): 136 | """ 137 | This could be used to reorient the starting view frame. Will have to mess around 138 | """ 139 | body_id = self.sim.model.body_name2id('robot0:gripper_link') 140 | lookat = self.sim.data.body_xpos[body_id] 141 | for idx, value in enumerate(lookat): 142 | self.viewer.cam.lookat[idx] = value 143 | self.viewer.cam.distance = 1.7 144 | self.viewer.cam.azimuth = 180. 145 | self.viewer.cam.elevation = -30. 
146 | 147 | def _sample_goal(self): 148 | """ 149 | Instead of sampling I am currently just setting our defined goal as the "sampled" goal 150 | """ 151 | return self.target 152 | 153 | def _is_success(self, achieved_goal, desired_goal): 154 | d = goal_distance(achieved_goal, desired_goal) 155 | return (d < self.distance_threshold).astype(np.float32) 156 | 157 | def _env_setup(self, initial_qpos): 158 | # Randomize the starting position 159 | shoulder_pan_val = -0.1 + (random.random()*0.2) 160 | initial_qpos['robot0:shoulder_pan_joint'] = shoulder_pan_val 161 | for name, value in initial_qpos.items(): 162 | self.sim.data.set_joint_qpos(name, value) 163 | utils.reset_mocap_welds(self.sim) 164 | self.sim.forward() 165 | 166 | # Move end effector into position 167 | gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos('robot0:grip') 168 | gripper_rotation = np.array([1.,0.,1.,0.]) 169 | self.sim.data.set_mocap_pos('robot0:mocap', gripper_target) 170 | self.sim.data.set_mocap_quat('robot0:mocap', gripper_rotation) 171 | for _ in range(10): 172 | self.sim.step() 173 | 174 | 175 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/imitate_gym/gym_eval.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import imitate_gym 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from crnn import SpatialCRNN 7 | import sys 8 | seq_len = 10 9 | weights = sys.argv[1] 10 | tau = [[int(sys.argv[2]), int(sys.argv[3])] for _ in range(seq_len)] 11 | tau = torch.from_numpy(np.array(tau)).type(torch.FloatTensor) 12 | tau = torch.unsqueeze(tau, dim=0) 13 | env = gym.make('buttons-v0') 14 | env.reset() 15 | model = SpatialCRNN(rtype="GRU", num_layers=2, device="cpu") 16 | model.load_state_dict(torch.load(weights, map_location='cpu'), strict=False) 17 | model.eval() 18 | eof = 
None 19 | rgbs = [] 20 | while True: 21 | env.render() 22 | rgb = env.render('rgb_array') 23 | if eof is None: 24 | obs, _,_,_ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 25 | pos = obs['achieved_goal'] 26 | eof = [pos for _ in range(seq_len)] 27 | img_rgb = Image.fromarray(rgb, 'RGB') 28 | img_rgb = np.array(img_rgb.resize((160,120))) 29 | img_rgb = np.reshape(img_rgb,(3,120,160)) 30 | rgbs = [img_rgb for _ in range(seq_len)] 31 | else: 32 | img_rgb = Image.fromarray(rgb, 'RGB') 33 | img_rgb = np.array(img_rgb.resize((160,120))) 34 | img_rgb = np.reshape(img_rgb,(3,120,160)) 35 | rgbs.pop(0) 36 | rgbs.append(img_rgb) 37 | input_rgb = torch.from_numpy(np.array(rgbs)).type(torch.FloatTensor) 38 | input_rgb = torch.unsqueeze(input_rgb, dim=0) 39 | input_eof = torch.from_numpy(np.array(eof)).type(torch.FloatTensor) 40 | input_eof = torch.unsqueeze(input_eof, dim=0) 41 | vels, _ = model([input_rgb, input_eof, tau]) 42 | action = [vels[0][0].item(), vels[0][1].item(), vels[0][2].item(), 0., 0., 0., 0., 0.] 
43 | print(action) 44 | obs, _, _, _ = env.step(action) 45 | eof.pop(0) 46 | eof.append(obs['achieved_goal']) 47 | 48 | -------------------------------------------------------------------------------- /old_stack/simulation/gym-imitate/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='imitate_env', 4 | version='0.0.1', 5 | install_requires=['gym', 'mujoco-py'] 6 | ) 7 | -------------------------------------------------------------------------------- /old_stack/test_baxter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import sys 3 | import rospy 4 | import itertools 5 | import message_filters 6 | import numpy as np 7 | import time 8 | import cv2 9 | import moveit_commander 10 | from cv_bridge import CvBridge 11 | from std_msgs.msg import Int64 12 | from geometry_msgs.msg import Point, Quaternion, Pose, Twist 13 | from sensor_msgs.msg import CompressedImage, JointState 14 | 15 | class ApproxTimeSync(message_filters.ApproximateTimeSynchronizer): 16 | def add(self, msg, my_queue, my_queue_index=None): 17 | self.allow_headerless = True 18 | if hasattr(msg, 'timestamp'): 19 | stamp = msg.timestamp 20 | elif not hasattr(msg, 'header') or not hasattr(msg.header, 'stamp'): 21 | if not self.allow_headerless: 22 | rospy.logwarn("Cannot use message filters with non-stamped messages. 
" 23 | "Use the 'allow_headerless' constructor option to " 24 | "auto-assign ROS time to headerless messages.") 25 | return 26 | stamp = rospy.Time.now() 27 | else: 28 | stamp = msg.header.stamp 29 | 30 | #TODO ADD HEADER TO ALLOW HEADERLESS 31 | # http://book2code.com/ros_kinetic/source/ros_comm/message_filters/src/message_filters/__init__.y 32 | #setattr(msg, 'header', a) 33 | #msg.header.stamp = stamp 34 | #super(message_filters.ApproximateTimeSynchronizer, self).add(msg, my_queue) 35 | self.lock.acquire() 36 | my_queue[stamp] = msg 37 | while len(my_queue) > self.queue_size: 38 | del my_queue[min(my_queue)] 39 | # self.queues = [topic_0 {stamp: msg}, topic_1 {stamp: msg}, ...] 40 | if my_queue_index is None: 41 | search_queues = self.queues 42 | else: 43 | search_queues = self.queues[:my_queue_index] + \ 44 | self.queues[my_queue_index+1:] 45 | # sort and leave only reasonable stamps for synchronization 46 | stamps = [] 47 | for queue in search_queues: 48 | topic_stamps = [] 49 | for s in queue: 50 | stamp_delta = abs(s - stamp) 51 | if stamp_delta > self.slop: 52 | continue # far over the slop 53 | topic_stamps.append((s, stamp_delta)) 54 | if not topic_stamps: 55 | self.lock.release() 56 | return 57 | topic_stamps = sorted(topic_stamps, key=lambda x: x[1]) 58 | stamps.append(topic_stamps) 59 | for vv in itertools.product(*[zip(*s)[0] for s in stamps]): 60 | vv = list(vv) 61 | # insert the new message 62 | if my_queue_index is not None: 63 | vv.insert(my_queue_index, stamp) 64 | qt = list(zip(self.queues, vv)) 65 | if ( ((max(vv) - min(vv)) < self.slop) and 66 | (len([1 for q,t in qt if t not in q]) == 0) ): 67 | msgs = [q[t] for q,t in qt] 68 | self.signalMessage(*msgs) 69 | for q,t in qt: 70 | del q[t] 71 | break # fast finish after the synchronization 72 | self.lock.release() 73 | 74 | class ImitateLearner(): 75 | """ 76 | This class will evaluate the outputs of the net. 
77 | """ 78 | def __init__(self, arm='right'): 79 | rospy.init_node("{}_arm_eval".format(arm)) 80 | self.queue = [] 81 | self.rgb = None 82 | self.depth = None 83 | self.pos = None 84 | self.orient = None 85 | self.prevTime = None 86 | self.time = None 87 | self.arm = arm 88 | 89 | # Initialize Subscribers 90 | self.listener() 91 | 92 | # Enable Robot 93 | moveit_commander.roscpp_initialize(sys.argv) 94 | robot = moveit_commander.RobotCommander() 95 | self.group_arms = moveit_commander.MoveGroupCommander('upper_body') 96 | self.group_arms.set_pose_reference_frame('/base_link') 97 | self.left_ee_link = 'left_ee_link' 98 | self.right_ee_link = 'right_ee_link' 99 | 100 | # Set the rate of our evaluation 101 | rate = rospy.Rate(0.5) 102 | 103 | # Give time for initialization 104 | rospy.Rate(1).sleep() 105 | 106 | # This is to the get the time delta 107 | first_time = True 108 | while not rospy.is_shutdown(): 109 | # Todo: Connect Net 110 | if first_time: 111 | self.prevTime = time.time() 112 | first_time = False 113 | # Given the output, solve for limb joints 114 | #limb_joints = self.get_limb_joints(output) 115 | 116 | # If valid joints then move to joint 117 | #if limb_joints is not -1: 118 | # right.move_to_joint_positions(limb_joints) 119 | # self.prevTime = self.time 120 | #else: 121 | # print 'ERROR: IK solver returned -1' 122 | print(self.pos) 123 | rate.sleep() 124 | 125 | def listener(self): 126 | """ 127 | Listener for all of the topics 128 | """ 129 | print("Listener Initialized") 130 | pose_sub = message_filters.Subscriber('/tf/{}_arm_pose'.format(self.arm), Pose) 131 | twist_sub = message_filters.Subscriber('tf/{}_arm_vels'.format(self.arm), Twist) 132 | rgb_sub = message_filters.Subscriber('/kinect2/sd/image_color_rect/compressed', CompressedImage) 133 | depth_sub = message_filters.Subscriber('/kinect2/sd/image_depth_rect/compressed', CompressedImage) 134 | gripper_sub = message_filters.Subscriber('/movo/{}_gripper/gripper_is_open'.format(self.arm), 
Int64) 135 | ts = ApproxTimeSync([pose_sub, twist_sub, rgb_sub, depth_sub, gripper_sub], 1, 0.1) 136 | ts.registerCallback(self.listener_callback) 137 | 138 | def listener_callback(self, pose, twist, rgb, depth, gripper): 139 | """ 140 | This method updates the variables. 141 | """ 142 | bridge = CvBridge() 143 | self.time = time.time() 144 | self.rgb = bridge.compressed_imgmsg_to_cv2(rgb) 145 | self.depth = bridge.compressed_imgmsg_to_cv2(depth) 146 | self.pos = pose.position 147 | self.orient = pose.orientation 148 | # Create input for net. x, y, z 149 | queue_input = np.array([self.pos.x, self.pos.y, self.pos.z]) 150 | if len(self.queue) == 0: 151 | self.queue = [queue_input for i in range(5)] 152 | else: 153 | self.queue.pop(0) 154 | self.queue.append(queue_input) 155 | 156 | def get_next_pose(self, output): 157 | """ 158 | This method gets the ik_solver solution for the arm joints. 159 | """ 160 | [goal_pos, goal_orient] = self.calculate_move(np.reshape(output[0, :3], (3,)), np.reshape(output[0, 3:], (3,))) 161 | return Point(*goal_pos), Quaternion(*goal_orient) 162 | 163 | def calculate_move(self, lin, ang): 164 | """ 165 | This calculates the position and orientation (in quaterion) of the next pose given 166 | the linear and angular velocities outputted by the net. 
167 | """ 168 | delta = self.time - self.prevTime 169 | print("------------") 170 | print(delta) 171 | print("------------") 172 | #delta = 1/30 173 | # Position Update 174 | curr_pos = np.array([self.pos.x, self.pos.y, self.pos.z]) 175 | goal_pos = np.add(curr_pos, delta*np.array(lin)) 176 | # Orientation Update 177 | curr_orient = np.array([self.orient.x, self.orient.y, self.orient.z, self.orient.w]) 178 | w_ang = np.concatenate([[0], ang]) 179 | goal_orient = np.add(curr_orient, 0.5*delta*np.matmul(w_ang, np.transpose(curr_orient))) 180 | # Update the prevTime 181 | return goal_pos, goal_orient 182 | 183 | if __name__ == '__main__': 184 | learner = ImitateLearner() -------------------------------------------------------------------------------- /old_stack/util/old/imitate_msgs/msg/GripperStamped.msg: -------------------------------------------------------------------------------- 1 | std_msgs/Header header 2 | std_msgs/Int64 data 3 | -------------------------------------------------------------------------------- /old_stack/util/old/imitate_msgs/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | imitate_msgs 4 | 0.0.0 5 | The imitate_msgs package 6 | 7 | 8 | 9 | 10 | jonathanchang 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | catkin 52 | std_msgs 53 | std_msgs 54 | std_msgs 55 | message_generation 56 | message_runtime 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /old_stack/util/old/publishers/rightGripperPublisher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import roslib 3 | import rospy 4 | import tf 5 | from tf.transformations import euler_from_quaternion 6 | import numpy as np 7 | import 
message_filters 8 | from std_msgs.msg import Int64, Header 9 | from sensor_msgs.msg import JointState 10 | from imitate_msgs.msg import GripperStamped 11 | 12 | 13 | is_open = None 14 | 15 | def callback(data): 16 | global is_open 17 | position = np.array(data.position) - 0.0016998404159380444 18 | if position.sum() < 0.05: 19 | is_open = 1 20 | else: 21 | is_open = 0 22 | 23 | # Publisher From Here 24 | rospy.init_node('movo_right_gripper') 25 | listener = rospy.Subscriber("/movo/right_gripper/joint_states", JointState, callback) 26 | 27 | # Give time for initialization 28 | rospy.Rate(1).sleep() 29 | 30 | robot_in_map = rospy.Publisher('/movo/right_gripper/gripper_is_open', GripperStamped, queue_size=1) 31 | 32 | rate = rospy.Rate(100) 33 | while not rospy.is_shutdown(): 34 | h = Header() 35 | h.stamp = rospy.Time.now() 36 | msg = GripperStamped() 37 | msg.header = h 38 | msg.data = Int64(is_open) 39 | robot_in_map.publish(msg) 40 | rate.sleep() 41 | -------------------------------------------------------------------------------- /old_stack/util/old/publishers/rightPosePublisher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import roslib 3 | import rospy 4 | import tf 5 | from tf.transformations import euler_from_quaternion 6 | from geometry_msgs.msg import PoseStamped, Pose, Point, Quaternion 7 | from std_msgs.msg import Header 8 | 9 | rospy.init_node('movo_right_tf_listener_pose') 10 | 11 | listener = tf.TransformListener() 12 | 13 | robot_in_map = rospy.Publisher('tf/right_arm_pose', PoseStamped, queue_size=1) 14 | 15 | rate = rospy.Rate(100) 16 | while not rospy.is_shutdown(): 17 | try: 18 | (trans, rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0)) 19 | h = Header() 20 | h.stamp = rospy.Time.now() 21 | msg = PoseStamped() 22 | msg.header = h 23 | msg.pose = Pose(Point(*trans), Quaternion(*rot)) 24 | 25 | except (tf.LookupException, 
tf.ConnectivityException, tf.ExtrapolationException): 26 | continue 27 | # Publish the message 28 | robot_in_map.publish(msg) 29 | rate.sleep() 30 | -------------------------------------------------------------------------------- /old_stack/util/old/publishers/rightVelPublisher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import roslib 4 | import rospy 5 | import tf 6 | from tf.transformations import euler_from_quaternion 7 | from geometry_msgs.msg import Twist, TwistStamped, Vector3 8 | from std_msgs.msg import Header 9 | 10 | rospy.init_node('movo_right_tf_listener') 11 | 12 | listener = tf.TransformListener() 13 | 14 | robot_in_map = rospy.Publisher('tf/right_arm_vels', TwistStamped, queue_size=1) 15 | 16 | prev_time = rospy.get_time() 17 | curr_time = rospy.get_time() 18 | prev_rot, prev_trans = None, None 19 | 20 | rate = rospy.Rate(100) 21 | while not rospy.is_shutdown(): 22 | try: 23 | curr_time = rospy.get_time() 24 | (curr_trans, curr_rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0)) 25 | # First Iteration 26 | if not prev_rot and not prev_trans: 27 | prev_trans, prev_rot = curr_trans, curr_rot 28 | continue 29 | delta = float(curr_time)-float(prev_time) 30 | lin_vel = (np.array(curr_trans) - np.array(prev_trans))/delta 31 | ang_vel = (np.array(euler_from_quaternion(curr_rot)) - np.array(euler_from_quaternion(prev_rot)))/delta 32 | h = Header() 33 | h.stamp = rospy.Time.now() 34 | msg = TwistStamped() 35 | msg.header = h 36 | msg.twist = Twist(Vector3(*lin_vel), Vector3(*ang_vel)) 37 | 38 | # Update 39 | prev_time = curr_time 40 | prev_trans, prev_rot = curr_trans, curr_rot 41 | 42 | except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException): 43 | continue 44 | # Publish the message 45 | robot_in_map.publish(msg) 46 | rate.sleep() 47 | 
-------------------------------------------------------------------------------- /old_stack/util/old/publishers/toggletest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import roslib 3 | import rospy 4 | from std_msgs.msg import Int64 5 | 6 | # Publisher From Here 7 | rospy.init_node('toggle_test') 8 | 9 | robot_in_map = rospy.Publisher('/toggle', Int64, queue_size=1) 10 | 11 | rate = rospy.Rate(100) 12 | while not rospy.is_shutdown(): 13 | rate.sleep() -------------------------------------------------------------------------------- /old_stack/util/old/simulation/2dsim.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | import sys 4 | import csv 5 | import time 6 | import pygame 7 | from pygame.locals import * 8 | from PIL import Image 9 | import numpy as np 10 | 11 | task = sys.argv[1] 12 | if not os.path.exists('datas/' + task + '/'): 13 | os.mkdir('datas/' + task + '/') 14 | save_folder = None 15 | writer = None 16 | text_file = None 17 | 18 | """ 19 | Goal Positions: (200, 150), (400, 150), (600, 150) 20 | (200, 300), (400, 300), (600, 300) 21 | (200, 450), (400, 450), (600, 450) 22 | """ 23 | # Note that we have the goal positions listed above. 
The ones that are listed are the current ones that we are using 24 | GOAL_X = 200 25 | GOAL_Y = 150 26 | 27 | # These are magic numbers 28 | RECT_X = 60 29 | RECT_Y = 60 30 | SPACEBAR_KEY = 32 # pygame logic 31 | S_KEY = 115 32 | 33 | pygame.init() 34 | 35 | #Window 36 | screen = pygame.display.set_mode((800, 600)) 37 | pygame.display.set_caption("2D Simulation") 38 | pygame.mouse.set_visible(1) 39 | 40 | #Background 41 | background = pygame.Surface(screen.get_size()) 42 | background = background.convert() 43 | background.fill((211, 211, 211)) 44 | screen.blit(background, (0, 0)) 45 | pygame.display.flip() 46 | 47 | clock = pygame.time.Clock() 48 | 49 | run = True 50 | recording = False 51 | save_counter = 0 52 | idx = 0 53 | last_pos = None 54 | last_gripper = None 55 | prev_pos = None 56 | 57 | while run: 58 | # Note that this is the data collection speed 59 | clock.tick(60) 60 | for event in pygame.event.get(): 61 | if event.type == pygame.QUIT: 62 | run = False 63 | if event.type == pygame.KEYUP: 64 | # Note that since we are only recording when the mouse is moving, there is never a case when the velocity is 0 65 | # This may cause problems because the net cannot learn to stop when the data suggests that it never does. 
66 | # To simulate the fact that there will be a start and an end with no movement, we will save 5 instances at the beginning 67 | # and at the end 68 | if event.key == SPACEBAR_KEY: # recordings the recording 69 | recording = not recording 70 | vel = (0,0,0) 71 | if recording: 72 | folder = 'datas/'+task+'/'+str(time.time())+'/' 73 | os.mkdir(folder) 74 | save_folder = folder 75 | text_file = open(save_folder + 'vector.txt', 'w') 76 | writer = csv.writer(text_file) 77 | print("===Start Recording===") 78 | position = pygame.mouse.get_pos() 79 | buttons = pygame.mouse.get_pressed() 80 | gripper = buttons[0] 81 | for _ in range(5): 82 | now = time.time() 83 | pygame.image.save(screen, save_folder + str(idx) + "_rgb.png") 84 | depth = Image.fromarray(np.uint8(np.zeros((600,800)))) 85 | depth.save(save_folder + str(idx) + "_depth.png") 86 | # Record data 87 | writer.writerow([idx, position[0], position[1], 0, 0, 0, 0, 0, vel[0], vel[1], vel[2], 0, 0, 0, gripper, now]) 88 | idx += 1 89 | print(position, vel, gripper) 90 | if not recording: 91 | for _ in range(5): 92 | now = time.time() 93 | pygame.image.save(screen, save_folder + str(idx) + "_rgb.png") 94 | depth = Image.fromarray(np.uint8(np.zeros((600,800)))) 95 | depth.save(save_folder + str(idx) + "_depth.png") 96 | # Record data 97 | writer.writerow([idx, last_pos[0], last_pos[1], 0, 0, 0, 0, 0, vel[0], vel[1], vel[2], 0, 0, 0, last_gripper, now]) 98 | idx += 1 99 | print(last_pos, vel, last_gripper) 100 | print("---Stop Recording---") 101 | if text_file != None: 102 | text_file.close() 103 | text_file = None 104 | prev_pos = None 105 | save_counter = 0 106 | idx = 0 107 | if event.key == S_KEY: # sets the cursor postion near the relative start position 108 | print("Cursor set to position (760, 370)") 109 | pygame.mouse.set_pos([760, 370]) 110 | if event.type == pygame.MOUSEMOTION: # This is the simulation of the arm. 
Note that left click simulates the gripper status 111 | if recording: 112 | if save_counter % 5 == 0: 113 | now = time.time() 114 | pygame.image.save(screen, save_folder+str(idx)+"_rgb.png") 115 | depth = Image.fromarray(np.uint8(np.zeros((600,800)))) 116 | depth.save(save_folder + str(idx) + "_depth.png") 117 | position = event.pos 118 | print(idx) 119 | if idx == 5: 120 | vel = event.rel 121 | else: 122 | vel = np.array(position)-prev_pos 123 | gripper = event.buttons[0] 124 | writer.writerow([idx, position[0], position[1], 0, 0, 0, 0, 0, vel[0], vel[1], 0, 0, 0, 0, gripper, now]) 125 | idx += 1 126 | prev_pos = np.array(position) 127 | last_pos = position 128 | last_gripper = gripper 129 | print(now, position, vel, gripper) 130 | save_counter += 1 131 | screen.fill((211,211,211)) 132 | pygame.draw.rect(screen, (0,0,255), pygame.Rect(GOAL_X-RECT_X/2, GOAL_Y-RECT_Y/2, RECT_X, RECT_Y)) 133 | pygame.draw.circle(screen, (255,0,0), pygame.mouse.get_pos(), 20, 1) 134 | pygame.display.update() 135 | 136 | pygame.quit() -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/README.md -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_env.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: imitate-env 3 | Version: 0.0.1 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- 
/old_stack/util/old/simulation/gym-imitate/imitate_gym/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='buttons-v0', 5 | entry_point='imitate_gym.envs:ButtonsEnv', 6 | ) 7 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/buttons_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import gym 4 | import argparse 5 | import imitate_gym 6 | import numpy as np 7 | from PIL import Image 8 | 9 | # GOAL_00 = np.array([1.2,0.66,0.45]) 10 | # GOAL_01 = np.array([1.2,0.76,0.45]) 11 | # GOAL_02 = np.array([1.2,0.86,0.45]) 12 | # GOAL_10 = np.array([1.3,0.66,0.45]) 13 | # GOAL_11 = np.array([1.3,0.76,0.45]) 14 | # GOAL_12 = np.array([1.3,0.86,0.45]) 15 | # GOAL_20 = np.array([1.4,0.66,0.45]) 16 | # GOAL_21 = np.array([1.4,0.76,0.45]) 17 | # GOAL_22 = np.array([1.4,0.86,0.45]) 18 | # GOALS = np.array([GOAL_00,GOAL_01,GOAL_02,GOAL_10,GOAL_11,GOAL_12,GOAL_20,GOAL_21,GOAL_22]) 19 | # NAMES = {0:'goal_00',1:'goal_01',2:'goal_02',3:'goal_10',4:'goal_11',5:'goal_12',6:'goal_20',7:'goal_21',8:'goal_22'} 20 | GOAL_00 = np.array([1.2,0.56,0.45]) 21 | GOAL_01 = np.array([1.2,0.66,0.45]) 22 | GOAL_02 = np.array([1.2,0.76,0.45]) 23 | GOAL_03 = np.array([1.2,0.86,0.45]) 24 | GOAL_04 = np.array([1.2,0.96,0.45]) 25 | GOAL_10 = np.array([1.3,0.56,0.45]) 26 | GOAL_11 = np.array([1.3,0.66,0.45]) 27 | GOAL_12 = np.array([1.3,0.76,0.45]) 28 | GOAL_13 = np.array([1.3,0.86,0.45]) 29 | GOAL_14 = np.array([1.3,0.96,0.45]) 30 | GOAL_20 = np.array([1.4,0.56,0.45]) 31 | GOAL_21 = np.array([1.4,0.66,0.45]) 32 | GOAL_22 = np.array([1.4,0.76,0.45]) 33 | GOAL_23 = np.array([1.4,0.86,0.45]) 34 | GOAL_24 = np.array([1.4,0.96,0.45]) 35 | GOALS = np.array([GOAL_00, 36 | GOAL_01, 37 | GOAL_02, 38 | GOAL_03, 39 | GOAL_04, 40 | GOAL_10, 41 | 
GOAL_11, 42 | GOAL_12, 43 | GOAL_13, 44 | GOAL_14, 45 | GOAL_20, 46 | GOAL_21, 47 | GOAL_22, 48 | GOAL_23, 49 | GOAL_24]) 50 | NAMES = {0:'goal_00', 51 | 1:'goal_01', 52 | 2:'goal_02', 53 | 3:'goal_03', 54 | 4:'goal_04', 55 | 5:'goal_10', 56 | 6:'goal_11', 57 | 7:'goal_12', 58 | 8:'goal_13', 59 | 9:'goal_14', 60 | 10:'goal_20', 61 | 11:'goal_21', 62 | 12:'goal_22', 63 | 13:'goal_23', 64 | 14:'goal_24'} 65 | test=np.array([GOAL_24]) 66 | 67 | def calc_changes(start_pos, goal_pos, num_steps=50.): 68 | return (goal_pos-start_pos)/num_steps 69 | 70 | def next_move(curr_pos, goal_pos, slope): 71 | diffs = np.absolute(curr_pos - goal_pos) 72 | if diffs[0] <= 0.001: 73 | slope[0] = 0. 74 | if diffs[1] <= 0.001: 75 | slope[1] = 0. 76 | if diffs[2] <= 0.001: 77 | slope[2] = 0. 78 | return slope 79 | 80 | def run(env, goal, trial_dir, viz=False): 81 | if viz == False: 82 | with open(trial_dir+'vectors.txt', 'w') as f: 83 | writer = csv.writer(f) 84 | curr_pos = None 85 | slope = None 86 | goal_reached = False 87 | counter = 0 88 | traj_counter = 0 89 | other = np.array([0.,0.,0.,0.,0.]) 90 | while True: 91 | rgb = env.render('rgb_array') 92 | if curr_pos is None and goal_reached is False: 93 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 94 | curr_pos = obs['achieved_goal'] 95 | slope = calc_changes(curr_pos, goal) 96 | elif curr_pos is not None and goal_reached is False: 97 | slope = next_move(curr_pos, goal, slope) 98 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455: 99 | goal_reached = True 100 | action = np.concatenate([slope,other]) 101 | obs, _, _, _ = env.step(action) 102 | img_rgb = Image.fromarray(rgb, 'RGB') 103 | img_rgb = img_rgb.resize((160,120)) 104 | img_rgb.save(trial_dir+'{}.png'.format(traj_counter)) 105 | arr = [traj_counter] 106 | arr += [x for x in obs['observation']] 107 | arr += [0] 108 | curr_pos = obs['achieved_goal'] 109 | traj_counter += 1 110 | writer.writerow(arr) 111 | elif curr_pos is not None and goal_reached is True: 112 | # 
This is to simulate the terminal state where the user would stop near the end 113 | counter += 1 114 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 115 | img_rgb = Image.fromarray(rgb, 'RGB') 116 | img_rgb = img_rgb.resize((160,120)) 117 | img_rgb.save(trial_dir+'{}.png'.format(traj_counter)) 118 | arr = [traj_counter] 119 | arr += [x for x in obs['observation']] 120 | arr += [0] 121 | traj_counter += 1 122 | writer.writerow(arr) 123 | if counter == 20: 124 | break 125 | else: 126 | curr_pos = None 127 | slope = None 128 | goal_reached = False 129 | counter = 0 130 | other = np.array([0.,0.,0.,0.,0.]) 131 | while True: 132 | env.render() 133 | if curr_pos is None and goal_reached is False: 134 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 135 | curr_pos = obs['achieved_goal'] 136 | slope = calc_changes(curr_pos, goal) 137 | elif curr_pos is not None and goal_reached is False: 138 | slope = next_move(curr_pos, goal, slope) 139 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455: 140 | goal_reached = True 141 | action = np.concatenate([slope,other]) 142 | obs, _, _, _ = env.step(action) 143 | print("==================") 144 | print(slope) 145 | print(obs['observation'][7:10]) 146 | curr_pos = obs['achieved_goal'] 147 | print(curr_pos) 148 | elif curr_pos is not None and goal_reached is True: 149 | # This is to simulate the terminal state where the user would stop near the end 150 | counter += 1 151 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 152 | print("==================") 153 | print(slope) 154 | print(obs['observation'][7:10]) 155 | if counter == 20: 156 | break 157 | 158 | 159 | 160 | 161 | def main(task_name, num_trials=300, viz=False): 162 | if viz == False: 163 | if not os.path.exists('../../../datas/' + task_name + '/'): 164 | os.mkdir('../../../datas/' + task_name + '/') 165 | trial_counter = 0 166 | for trial in range(num_trials): 167 | env = gym.make('buttons-v0') 168 | for i, goal in enumerate(GOALS): 169 | save_folder = 
'../../../datas/'+task_name+'/'+NAMES[i]+'/' 170 | if viz == False: 171 | if not os.path.exists(save_folder): 172 | os.mkdir(save_folder) 173 | trial_dir = save_folder+str(trial_counter)+'/' 174 | if viz == False: 175 | os.mkdir(trial_dir) 176 | env.reset() 177 | run(env, goal, trial_dir, viz=viz) 178 | trial_counter+=1 179 | 180 | if __name__ == '__main__': 181 | main("buttons3x5", num_trials=1, viz=True) 182 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from imitate_gym.envs.imitate_env import ImitateEnv 2 | from imitate_gym.envs.buttons import ButtonsEnv 3 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons_3x3.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 
| 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/robot.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/shared.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl 
-------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl -------------------------------------------------------------------------------- 
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/buttons.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from 
gym import utils 4 | from imitate_gym.envs import imitate_env 5 | 6 | #MODEL_XML_PATH = os.path.join('fetch', 'pick_and_place.xml') 7 | HOME = '/home/jonathanchang/parameterized-imitation-learning/simulation/gym-imitate/imitate_gym/envs' 8 | MODEL_XML_PATH = HOME+'/assets/buttons/buttons_3x3.xml' 9 | class ButtonsEnv(imitate_env.ImitateEnv, utils.EzPickle): 10 | def __init__(self, reward_type='dense'): 11 | initial_qpos = { 12 | 'robot0:slide0': 0.405, 13 | 'robot0:slide1': 0.48, 14 | 'robot0:slide2': 0.0, 15 | 'robot0:shoulder_pan_joint': 0.0, 16 | 'robot0:shoulder_lift_joint': -0.8, 17 | 'robot0:elbow_flex_joint': 1.0 18 | } 19 | target = np.array([1.25, 0.53, 0.4]) 20 | imitate_env.ImitateEnv.__init__( 21 | self, model_path=MODEL_XML_PATH, n_substeps=25, initial_qpos=initial_qpos, reward_type=reward_type, 22 | distance_threshold=0.05, gripper_extra_height=0.2, target=target) 23 | utils.EzPickle.__init__(self) 24 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/imitate_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import gym 4 | import mujoco_py 5 | from gym import error, spaces, utils 6 | from gym.envs.robotics import rotations, robot_env, utils 7 | 8 | # Just to make sure that input shape is correct 9 | def goal_distance(goal_a, goal_b): 10 | assert goal_a.shape == goal_b.shape 11 | return np.linalg.norm(goal_a - goal_b, axis=-1) 12 | 13 | class ImitateEnv(robot_env.RobotEnv): 14 | """ 15 | Inheritance Chain: Env -> GoalEnv -> RobotEnv -> ImitateEnv 16 | 17 | Methods from RobotEnv 18 | seed(), step(), reset(), close(), render(), _reset_sim(), 19 | _get_viewer(): This method will be important for retrieving rgb and depth from the sim 20 | """ 21 | def __init__( 22 | self, model_path, n_substeps, initial_qpos, reward_type, 23 | distance_threshold, target, gripper_extra_height 24 | ): 
25 | """ 26 | Args: model_path (string): path to the environments XML file 27 | n_substeps (int): number of substeps the simulation runs on every call to step 28 | initial_qpos (dict): a dictionary of joint names and values that define the initial configuration 29 | reward_type ('sparse' or 'dense'): the reward type, i.e. sparse or dense 30 | distance_threshold (float): the threshold after which a goal is considered achieved 31 | target (list): the target position that we are aiming for 32 | gripper_extra_height (float): the gripper offset in position 33 | """ 34 | # n_actions = 8 for [pos_x, pos_y, pos_z, rot_x, rot_y, rot_z, rot_w, gripper] 35 | 36 | self.reward_type = reward_type 37 | self.distance_threshold = distance_threshold 38 | self.target = target 39 | self.gripper_extra_height = gripper_extra_height 40 | self._viewers = {} 41 | self.initial_qpos = initial_qpos 42 | super(ImitateEnv, self).__init__(model_path=model_path, 43 | n_substeps=n_substeps, 44 | n_actions=8, 45 | initial_qpos=initial_qpos) 46 | # Env Method 47 | # -------------------------------------------------- 48 | def render(self, mode='human'): 49 | self._render_callback() 50 | if mode == 'rgb_array': 51 | self._get_viewer(mode).render(1000,1000) 52 | rgb = self._get_viewer(mode).read_pixels(1000, 1000, depth=False) 53 | #rgbd = self._get_viewer(mode).read_pixels(1000, 1000, depth=True) 54 | #rgb = rgbd[0][::-1,:,:] 55 | #depth = rgbd[1] 56 | #depth = (depth-np.amin(depth))*(255/(np.amax(depth)-np.amin(depth))) 57 | return rgb[::-1,:,:] 58 | elif mode == 'human': 59 | self._get_viewer(mode).render() 60 | 61 | def _get_viewer(self, mode): 62 | self.viewer = self._viewers.get(mode) 63 | if self.viewer is None: 64 | if mode == 'human': 65 | self.viewer = mujoco_py.MjViewer(self.sim) 66 | elif mode == 'rgb_array': 67 | self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, device_id=-1) 68 | self._viewer_setup() 69 | self._viewers[mode] = self.viewer 70 | return self.viewer 71 | # GoalEnv 
Method 72 | # -------------------------------------------------- 73 | def compute_reward(self, achieved_goal, goal, info): 74 | d = goal_distance(achieved_goal, goal) 75 | if self.reward_type == 'sparse': 76 | return -(d > self.distance_threshold).astype(np.float32) 77 | else: 78 | return -d 79 | 80 | # RobotEnv Methods 81 | # -------------------------------------------------- 82 | 83 | def _step_callback(self): 84 | """ 85 | In the step function in the RobotEnv it does: 1) self._set_action(action) 86 | 2) self.sim.step() 87 | 3) self._step_callback() 88 | This method can be used to provide additional constraints to the actions that happen every step. 89 | Could be used to fix the gripper to be a certain orientation for example 90 | """ 91 | pass 92 | 93 | def _set_action(self, action): 94 | """ 95 | Currently, I am just returning 1 number for the gripper control because there are 2 fingers in the 96 | robot that is used in the OpenAI gym library. If I start using movo, I would need 3 fingers so 97 | perhaps increase the gripper control from 2 to 3. 98 | """ 99 | assert action.shape == (8,) 100 | action = action.copy() # ensures that we don't change the action outside of this scope 101 | pos_ctrl, rot_ctrl, gripper_ctrl = action[:3], action[3:7], action[-1] 102 | 103 | # This is where I match the gripper control to number of fingers 104 | gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl]) 105 | assert gripper_ctrl.shape == (2,) 106 | 107 | # Create the action that we want to pass into the simulation 108 | action = np.concatenate([pos_ctrl, rot_ctrl, gripper_ctrl]) 109 | 110 | # Now apply the action to the simulation 111 | utils.ctrl_set_action(self.sim, action) # note self.sim is an inherited variable 112 | utils.mocap_set_action(self.sim, action) 113 | 114 | def _get_obs(self): 115 | """ 116 | This is where I make sure to grab the observations for the position and the velocities. Potential 117 | area to grab the rgb and depth images. 
118 | """ 119 | ee = 'robot0:grip' 120 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep 121 | ee_pos = self.sim.data.get_site_xpos(ee) 122 | ee_quat = rotations.mat2quat(self.sim.data.get_site_xmat(ee)) 123 | # Position and angular velocity respectively 124 | ee_velp = self.sim.data.get_site_xvelp(ee) * dt # to remove time dependency 125 | ee_velr = self.sim.data.get_site_xvelr(ee) * dt 126 | 127 | obs = np.concatenate([ee_pos, ee_quat, ee_velp, ee_velr]) 128 | 129 | return { 130 | 'observation': obs.copy(), 131 | 'achieved_goal': ee_pos.copy(), 132 | 'desired_goal': self.goal.copy() 133 | } 134 | 135 | def _viewer_setup(self): 136 | """ 137 | This could be used to reorient the starting view frame. Will have to mess around 138 | """ 139 | body_id = self.sim.model.body_name2id('robot0:gripper_link') 140 | lookat = self.sim.data.body_xpos[body_id] 141 | for idx, value in enumerate(lookat): 142 | self.viewer.cam.lookat[idx] = value 143 | self.viewer.cam.distance = 1.7 144 | self.viewer.cam.azimuth = 180. 145 | self.viewer.cam.elevation = -30. 
146 | 147 | def _sample_goal(self): 148 | """ 149 | Instead of sampling I am currently just setting our defined goal as the "sampled" goal 150 | """ 151 | return self.target 152 | 153 | def _is_success(self, achieved_goal, desired_goal): 154 | d = goal_distance(achieved_goal, desired_goal) 155 | return (d < self.distance_threshold).astype(np.float32) 156 | 157 | def _env_setup(self, initial_qpos): 158 | # Randomize the starting position 159 | shoulder_pan_val = -0.1 + (random.random()*0.2) 160 | initial_qpos['robot0:shoulder_pan_joint'] = shoulder_pan_val 161 | for name, value in initial_qpos.items(): 162 | self.sim.data.set_joint_qpos(name, value) 163 | utils.reset_mocap_welds(self.sim) 164 | self.sim.forward() 165 | 166 | # Move end effector into position 167 | gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos('robot0:grip') 168 | gripper_rotation = np.array([1.,0.,1.,0.]) 169 | self.sim.data.set_mocap_pos('robot0:mocap', gripper_target) 170 | self.sim.data.set_mocap_quat('robot0:mocap', gripper_rotation) 171 | for _ in range(10): 172 | self.sim.step() 173 | 174 | 175 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/imitate_gym/gym_eval.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import imitate_gym 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | from crnn import SpatialCRNN 7 | import sys 8 | seq_len = 10 9 | weights = sys.argv[1] 10 | tau = [[int(sys.argv[2]), int(sys.argv[3])] for _ in range(seq_len)] 11 | tau = torch.from_numpy(np.array(tau)).type(torch.FloatTensor) 12 | tau = torch.unsqueeze(tau, dim=0) 13 | env = gym.make('buttons-v0') 14 | env.reset() 15 | model = SpatialCRNN(rtype="GRU", num_layers=2, device="cpu") 16 | model.load_state_dict(torch.load(weights, map_location='cpu'), strict=False) 17 | model.eval() 18 | 
eof = None 19 | rgbs = [] 20 | while True: 21 | env.render() 22 | rgb = env.render('rgb_array') 23 | if eof is None: 24 | obs, _,_,_ = env.step([0.,0.,0.,0.,0.,0.,0.,0.]) 25 | pos = obs['achieved_goal'] 26 | eof = [pos for _ in range(seq_len)] 27 | img_rgb = Image.fromarray(rgb, 'RGB') 28 | img_rgb = np.array(img_rgb.resize((160,120))) 29 | img_rgb = np.reshape(img_rgb,(3,120,160)) 30 | rgbs = [img_rgb for _ in range(seq_len)] 31 | else: 32 | img_rgb = Image.fromarray(rgb, 'RGB') 33 | img_rgb = np.array(img_rgb.resize((160,120))) 34 | img_rgb = np.reshape(img_rgb,(3,120,160)) 35 | rgbs.pop(0) 36 | rgbs.append(img_rgb) 37 | input_rgb = torch.from_numpy(np.array(rgbs)).type(torch.FloatTensor) 38 | input_rgb = torch.unsqueeze(input_rgb, dim=0) 39 | input_eof = torch.from_numpy(np.array(eof)).type(torch.FloatTensor) 40 | input_eof = torch.unsqueeze(input_eof, dim=0) 41 | vels, _ = model([input_rgb, input_eof, tau]) 42 | action = [vels[0][0].item(), vels[0][1].item(), vels[0][2].item(), 0., 0., 0., 0., 0.] 
43 | print(action) 44 | obs, _, _, _ = env.step(action) 45 | eof.pop(0) 46 | eof.append(obs['achieved_goal']) 47 | 48 | -------------------------------------------------------------------------------- /old_stack/util/old/simulation/gym-imitate/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='imitate_env', 4 | version='0.0.1', 5 | install_requires=['gym', 'mujoco-py'] 6 | ) 7 | -------------------------------------------------------------------------------- /old_stack/util/old/test_baxter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import sys 3 | import rospy 4 | import itertools 5 | import message_filters 6 | import numpy as np 7 | import time 8 | import cv2 9 | import moveit_commander 10 | from cv_bridge import CvBridge 11 | from std_msgs.msg import Int64 12 | from geometry_msgs.msg import Point, Quaternion, Pose, Twist 13 | from sensor_msgs.msg import CompressedImage, JointState 14 | 15 | class ApproxTimeSync(message_filters.ApproximateTimeSynchronizer): 16 | def add(self, msg, my_queue, my_queue_index=None): 17 | self.allow_headerless = True 18 | if hasattr(msg, 'timestamp'): 19 | stamp = msg.timestamp 20 | elif not hasattr(msg, 'header') or not hasattr(msg.header, 'stamp'): 21 | if not self.allow_headerless: 22 | rospy.logwarn("Cannot use message filters with non-stamped messages. 
" 23 | "Use the 'allow_headerless' constructor option to " 24 | "auto-assign ROS time to headerless messages.") 25 | return 26 | stamp = rospy.Time.now() 27 | else: 28 | stamp = msg.header.stamp 29 | 30 | #TODO ADD HEADER TO ALLOW HEADERLESS 31 | # http://book2code.com/ros_kinetic/source/ros_comm/message_filters/src/message_filters/__init__.y 32 | #setattr(msg, 'header', a) 33 | #msg.header.stamp = stamp 34 | #super(message_filters.ApproximateTimeSynchronizer, self).add(msg, my_queue) 35 | self.lock.acquire() 36 | my_queue[stamp] = msg 37 | while len(my_queue) > self.queue_size: 38 | del my_queue[min(my_queue)] 39 | # self.queues = [topic_0 {stamp: msg}, topic_1 {stamp: msg}, ...] 40 | if my_queue_index is None: 41 | search_queues = self.queues 42 | else: 43 | search_queues = self.queues[:my_queue_index] + \ 44 | self.queues[my_queue_index+1:] 45 | # sort and leave only reasonable stamps for synchronization 46 | stamps = [] 47 | for queue in search_queues: 48 | topic_stamps = [] 49 | for s in queue: 50 | stamp_delta = abs(s - stamp) 51 | if stamp_delta > self.slop: 52 | continue # far over the slop 53 | topic_stamps.append((s, stamp_delta)) 54 | if not topic_stamps: 55 | self.lock.release() 56 | return 57 | topic_stamps = sorted(topic_stamps, key=lambda x: x[1]) 58 | stamps.append(topic_stamps) 59 | for vv in itertools.product(*[zip(*s)[0] for s in stamps]): 60 | vv = list(vv) 61 | # insert the new message 62 | if my_queue_index is not None: 63 | vv.insert(my_queue_index, stamp) 64 | qt = list(zip(self.queues, vv)) 65 | if ( ((max(vv) - min(vv)) < self.slop) and 66 | (len([1 for q,t in qt if t not in q]) == 0) ): 67 | msgs = [q[t] for q,t in qt] 68 | self.signalMessage(*msgs) 69 | for q,t in qt: 70 | del q[t] 71 | break # fast finish after the synchronization 72 | self.lock.release() 73 | 74 | class ImitateLearner(): 75 | """ 76 | This class will evaluate the outputs of the net. 
77 | """ 78 | def __init__(self, arm='right'): 79 | rospy.init_node("{}_arm_eval".format(arm)) 80 | self.queue = [] 81 | self.rgb = None 82 | self.depth = None 83 | self.pos = None 84 | self.orient = None 85 | self.prevTime = None 86 | self.time = None 87 | self.arm = arm 88 | 89 | # Initialize Subscribers 90 | self.listener() 91 | 92 | # Enable Robot 93 | moveit_commander.roscpp_initialize(sys.argv) 94 | robot = moveit_commander.RobotCommander() 95 | self.group_arms = moveit_commander.MoveGroupCommander('upper_body') 96 | self.group_arms.set_pose_reference_frame('/base_link') 97 | self.left_ee_link = 'left_ee_link' 98 | self.right_ee_link = 'right_ee_link' 99 | 100 | # Set the rate of our evaluation 101 | rate = rospy.Rate(0.5) 102 | 103 | # Give time for initialization 104 | rospy.Rate(1).sleep() 105 | 106 | # This is to the get the time delta 107 | first_time = True 108 | while not rospy.is_shutdown(): 109 | # Todo: Connect Net 110 | if first_time: 111 | self.prevTime = time.time() 112 | first_time = False 113 | # Given the output, solve for limb joints 114 | #limb_joints = self.get_limb_joints(output) 115 | 116 | # If valid joints then move to joint 117 | #if limb_joints is not -1: 118 | # right.move_to_joint_positions(limb_joints) 119 | # self.prevTime = self.time 120 | #else: 121 | # print 'ERROR: IK solver returned -1' 122 | print(self.pos) 123 | rate.sleep() 124 | 125 | def listener(self): 126 | """ 127 | Listener for all of the topics 128 | """ 129 | print("Listener Initialized") 130 | pose_sub = message_filters.Subscriber('/tf/{}_arm_pose'.format(self.arm), Pose) 131 | twist_sub = message_filters.Subscriber('tf/{}_arm_vels'.format(self.arm), Twist) 132 | rgb_sub = message_filters.Subscriber('/kinect2/sd/image_color_rect/compressed', CompressedImage) 133 | depth_sub = message_filters.Subscriber('/kinect2/sd/image_depth_rect/compressed', CompressedImage) 134 | gripper_sub = message_filters.Subscriber('/movo/{}_gripper/gripper_is_open'.format(self.arm), 
Int64) 135 | ts = ApproxTimeSync([pose_sub, twist_sub, rgb_sub, depth_sub, gripper_sub], 1, 0.1) 136 | ts.registerCallback(self.listener_callback) 137 | 138 | def listener_callback(self, pose, twist, rgb, depth, gripper): 139 | """ 140 | This method updates the variables. 141 | """ 142 | bridge = CvBridge() 143 | self.time = time.time() 144 | self.rgb = bridge.compressed_imgmsg_to_cv2(rgb) 145 | self.depth = bridge.compressed_imgmsg_to_cv2(depth) 146 | self.pos = pose.position 147 | self.orient = pose.orientation 148 | # Create input for net. x, y, z 149 | queue_input = np.array([self.pos.x, self.pos.y, self.pos.z]) 150 | if len(self.queue) == 0: 151 | self.queue = [queue_input for i in range(5)] 152 | else: 153 | self.queue.pop(0) 154 | self.queue.append(queue_input) 155 | 156 | def get_next_pose(self, output): 157 | """ 158 | This method gets the ik_solver solution for the arm joints. 159 | """ 160 | [goal_pos, goal_orient] = self.calculate_move(np.reshape(output[0, :3], (3,)), np.reshape(output[0, 3:], (3,))) 161 | return Point(*goal_pos), Quaternion(*goal_orient) 162 | 163 | def calculate_move(self, lin, ang): 164 | """ 165 | This calculates the position and orientation (in quaterion) of the next pose given 166 | the linear and angular velocities outputted by the net. 
167 | """ 168 | delta = self.time - self.prevTime 169 | print("------------") 170 | print(delta) 171 | print("------------") 172 | #delta = 1/30 173 | # Position Update 174 | curr_pos = np.array([self.pos.x, self.pos.y, self.pos.z]) 175 | goal_pos = np.add(curr_pos, delta*np.array(lin)) 176 | # Orientation Update 177 | curr_orient = np.array([self.orient.x, self.orient.y, self.orient.z, self.orient.w]) 178 | w_ang = np.concatenate([[0], ang]) 179 | goal_orient = np.add(curr_orient, 0.5*delta*np.matmul(w_ang, np.transpose(curr_orient))) 180 | # Update the prevTime 181 | return goal_pos, goal_orient 182 | 183 | if __name__ == '__main__': 184 | learner = ImitateLearner() -------------------------------------------------------------------------------- /old_stack/util/parse_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | #import cv2 3 | import numpy as np 4 | import os 5 | import pandas as pd 6 | import csv 7 | 8 | from PIL import Image 9 | from random import shuffle 10 | 11 | import re 12 | import time 13 | import math 14 | import sys 15 | import copy 16 | import shutil 17 | from functools import cmp_to_key 18 | from tqdm import tqdm 19 | 20 | splits = {} 21 | 22 | def preprocess_images(root_diri, cases): 23 | for case in cases: 24 | dirs = [x[0] for x in os.walk(root_dir + case)] 25 | for sub_dir in dirs: 26 | for root, _, files in os.walk(sub_dir): 27 | files = sorted(files) 28 | files = files[:-1] 29 | for i in range(0, len(files), 2): 30 | depth = Image.open(root+"/"+files[i]) 31 | rgb = Image.open(root+"/"+files[i+1]) 32 | rgb = rgb.resize((160,120)) 33 | depth = depth.resize((160,120)) 34 | depth.save(root+"/"+files[i]) 35 | rgb.save(root+"/"+files[i+1]) 36 | 37 | 38 | def parse_param_2(root_dir, mode, cases, tau, seq_len, split_percen, dest, directions=None): 39 | global splits 40 | file = open(dest+"/"+ mode + "_data.csv", "w+") 41 | writer = csv.writer(file, 
delimiter=',') 42 | # For every folder, tau, and direction 43 | for case, tau in tqdm(zip(cases, tau)): 44 | #print(case) 45 | #print(tau) 46 | # Create the splits 47 | if case not in splits.keys(): 48 | dirs = [x[0] for x in os.walk(root_dir + case)][1:] 49 | shuffle(dirs) 50 | split_idx = int(math.ceil(len(dirs)*float(split_percen))) 51 | splits[case] = {"train": dirs[:split_idx], "test": dirs[split_idx:]} 52 | # Go into every subdirectory 53 | sub_dirs = splits[case] 54 | #dirs = [x[0] for x in os.walk(root_dir + case)][1:] 55 | for sub_dir in sub_dirs[mode]: 56 | for root, _, file in os.walk(sub_dir): 57 | file = sorted(file) 58 | pics = file[:-1] 59 | pics = sorted(pics, key=cmp_to_key(compare_names)) 60 | vectors = pd.read_csv(root+"/"+file[-1], header=-1) 61 | # We will start from 10 to start creating a sequential dataset (rgb, depth) 62 | # The length of the history will be 5 63 | #for i in range((seq_len*2)-2, len(pics), 2): 64 | for i in range(0, len(pics)): 65 | # First get the current directory 66 | row = [root, int(pics[i][:-4]), tau[0], tau[1]] 67 | #row = [root, int(pics[i][:-10]), tau[0], tau[1]] 68 | #depth = [root+"/"+pics[i-j] for j in range(seq_len,-1,-2)] 69 | #rgb = [root+"/"+pics[i-j+1] for j in range(seq_len,-1,-2)] 70 | #tau = [tau for _ in range(seq_len)] 71 | #print(pics) 72 | #print(i) 73 | #print([pics[i-j] for j in range(seq_len*2, -1, -2)]) 74 | if i == 0: 75 | prevs = [int(pics[i][:-4]) for _ in range(seq_len)] 76 | elif i != 0 and i < seq_len: 77 | prevs.pop(0) 78 | prevs.append(int(pics[i][:-4])) 79 | else: 80 | prevs = [int(pics[i-j][:-4]) for j in range(seq_len-1, -1, -1)] 81 | #prevs = [int(pics[i-j][:-10]) for j in range((seq_len*2)-2, -1, -2)] 82 | eof = [] 83 | for prev in prevs: 84 | pos = [float(vectors[vectors[0]==prev][j]) for j in range(1,4)] 85 | eof += pos 86 | ## Label and Gripper still stay the same 87 | label = [float(vectors[vectors[0]==float(pics[i][:-4])][j]) for j in range(8,14)] 88 | aux_label = 
[float(vectors[vectors[0]==float(pics[-1][:-4])][j]) for j in range(1,8)] 89 | #label = [float(vectors[vectors[0]==float(pics[i][:-10])][j]) for j in range(8,14)] 90 | #aux_label = [float(vectors[vectors[0]==float(pics[-2][:-10])][j]) for j in range(1,8)] 91 | #print(label) 92 | row += prevs 93 | row += label 94 | row += aux_label 95 | row += eof 96 | #gripper = repr(int(vectors[vectors[0]==float(pics[i][:-8])][14])) 97 | ## This is how we will represent the data 98 | #writer.writerow([depth, rgb, eof, label, gripper, tau]) 99 | writer.writerow(row) 100 | #print(prevs) 101 | print("{} Dta Creation Done".format(mode)) 102 | 103 | def compare_names(name1, name2): 104 | num1 = extract_number(name1) 105 | num2 = extract_number(name2) 106 | if num1 == num2: 107 | if name1 > name2: 108 | return 1 109 | else: 110 | return -1 111 | else: 112 | return num1 - num2 113 | 114 | def extract_number(name): 115 | numbers = re.findall('\d+', name) 116 | return int(numbers[0]) 117 | 118 | def clean_data(root_dir, cases): 119 | for case in cases: 120 | # Create the splits 121 | dirs = [x[0] for x in os.walk(root_dir + case)][1:] 122 | # Go into every subdirectory 123 | for sub_dir in dirs: 124 | for root, _, file in os.walk(sub_dir): 125 | print(root) 126 | file = sorted(file) 127 | vector_file = root+"/"+file[-1] 128 | vectors = csv.read_csv(vector_file, header=-1) 129 | delete_rows = [] 130 | for i in range(1, len(file), 2): 131 | label = [round(float(vectors[vectors[0]==float(file[i][:-8])][j])*10) for j in range(8,14)] 132 | if sum(label) == 0: 133 | #print(float(file[i][:-8])) 134 | #print(root+"/"+file[i]) #rgb 135 | #print(root+"/"+file[i-1]) #depth 136 | delete_rws.append(file[i][:-8]) 137 | os.remove(root+"/"+file[i]) 138 | os.remove(root+"/"+file[i-1]) 139 | if (len(file)-1)/2 != len(delete_rows): 140 | last = None 141 | last_num = 0 142 | clean_file = root+"/vector2.txt" 143 | with open(vector_file, "rb") as input, open(clean_file, "wb") as out: 144 | writer = 
csv.writer(out) 145 | for row in csv.reader(input): 146 | if row[0] not in delete_rows: 147 | writer.writerow(row) 148 | with open(clean_file, "r") as fd: 149 | last = [l for l in fd][-1] 150 | last = last.strip().split(',') 151 | last_num = int(last[0]) 152 | last[8:14] = ['0.0' for _ in range(8,14)] 153 | counter = last_num + 1 154 | with open(clean_file, "a") as fd: 155 | for _ in range(10): 156 | shutil.copy(root+"/"+str(last_num) + "_depth.png", root+ "/" + str(counter) + "_depth.png") 157 | shutil.copy(root+"/"+str(last_num) + "_rgb.png", root+"/" + str(counter) + "_rgb.png") 158 | row = copy.deepcopy(last) 159 | row[0] = str(counter) 160 | row = ",".join(row) + "\n" 161 | counter += 1 162 | fd.write(row) 163 | os.remove(vector_file) 164 | if len(os.listdir(sub_dir)) == 0: 165 | os.rmdir(sub_dir) 166 | 167 | if __name__ == '__main__': 168 | tau = [[71,77],[91,77],[81,83],[70,89],[92,89]] 169 | cases = ['/goal_00', '/goal_02', '/goal_11', '/goal_20', '/goal_22'] 170 | #tau = [[81,77], [71,83], [91,83], [81,89]] 171 | #cases = ['/goal_01', '/goal_10', '/goal_12', '/goal_21'] 172 | modes = ["train"] 173 | seq_len = 10 174 | root_dir = sys.argv[1] 175 | split_percen = sys.argv[2] 176 | dest = sys.argv[3] 177 | #clean_data(root_dir, cases) 178 | #preprocess_images(root_dir, cases) 179 | datasets = {mode: parse_param_2(root_dir, mode, cases, tau, seq_len, split_percen, dest=dest) for mode in modes} 180 | -------------------------------------------------------------------------------- /old_stack/util/record.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import cv2 3 | import time 4 | import csv 5 | import os 6 | import sys 7 | import rospy 8 | import itertools 9 | import numpy as np 10 | from cv_bridge import CvBridge 11 | from namedlist import namedlist 12 | from std_msgs.msg import Int64, String 13 | from sensor_msgs.msg import CompressedImage, Image, JointState 14 | from geometry_msgs.msg import 
Twist, Pose, TwistStamped, PoseStamped 15 | from imitate_msgs.msg import GripperStamped 16 | 17 | class ImitateRecorder(): 18 | def __init__(self, task): 19 | rospy.init_node('imitate_recorder2', log_level=rospy.DEBUG) 20 | # Create the appropriate directory in the datas for the task we are training 21 | if not os.path.exists('../datas/' + task + '/'): 22 | os.mkdir('../datas/' + task + '/') 23 | self.save_folder = None # The specific folder 24 | self.writer = None # The writer to create our txt files 25 | self.text_file = None # The file that we are writing to currently 26 | self.is_recording = False # Toggling recording 27 | self.counter = 0 # The unique key for each datapoint we save 28 | self.bridge = CvBridge() 29 | # Initialize current values 30 | self.Data = namedlist('Data', ['pose', 'twist', 'grip', 'rgb', 'depth']) 31 | self.data = self.Data(pose=None, twist=None, grip=None, rgb=None, depth=None) 32 | 33 | def toggle_collection(self, toggle): 34 | if toggle is "0": 35 | self.counter = 0 36 | self.is_recording = False 37 | self.unsubscribe() 38 | time.sleep(1) 39 | if self.text_file != None: 40 | self.text_file.close() 41 | self.text_file = None 42 | print("-----Stop Recording-----") 43 | else: 44 | save_folder = '../datas/' + task + '/' + str(time.time()) + '/' 45 | os.mkdir(save_folder) 46 | self.save_folder = save_folder 47 | self.text_file = open(save_folder + 'vectors.txt', 'w') 48 | self.writer = csv.writer(self.text_file) 49 | self.is_recording = True 50 | print("=====Start Recording=====") 51 | self.collect_data() 52 | 53 | def collect_data(self): 54 | # Initialize Listeners 55 | self.init_listeners() 56 | rospy.Rate(5).sleep() 57 | # Define the rate at which we will collect data 58 | rate = rospy.Rate(15) 59 | while not rospy.is_shutdown(): 60 | if None not in self.data: 61 | print("Data Collected!!") 62 | rgb_image = self.bridge.imgmsg_to_cv2(self.data.rgb, desired_encoding="passthrough") 63 | depth_image = 
self.bridge.imgmsg_to_cv2(self.data.depth, desired_encoding="passthrough") 64 | cv2.imwrite(self.save_folder + str(self.counter) + '_rgb.png', rgb_image) 65 | cv2.imwrite(self.save_folder + str(self.counter) + '_depth.png', depth_image) 66 | posit = self.data.pose.position 67 | orient = self.data.pose.orientation 68 | lin = self.data.twist.linear 69 | ang = self.data.twist.angular 70 | arr = [self.counter, posit.x, posit.y, posit.z, orient.w, orient.x, orient.y, orient.z, lin.x, lin.y, lin.z, ang.x, ang.y, ang.z, self.data.grip, time.time()] 71 | self.writer.writerow(arr) 72 | self.data = self.Data(pose=None, twist=None, grip=None, rgb=None, depth=None) 73 | self.counter += 1 74 | rate.sleep() 75 | 76 | def init_listeners(self): 77 | # The Topics we are Subscribing to for data 78 | self.right_arm_pose = rospy.Subscriber('/tf/right_arm_pose', PoseStamped, self.pose_callback) 79 | self.right_arm_vel = rospy.Subscriber('/tf/right_arm_vels', TwistStamped, self.vel_callback) 80 | self.rgb_state_sub = rospy.Subscriber('/kinect2/qhd/image_color_rect', Image, self.rgb_callback) 81 | self.depth_state_sub = rospy.Subscriber('/kinect2/qhd/image_depth_rect', Image, self.depth_callback) 82 | self.gripper_state_sub = rospy.Subscriber('/movo/right_gripper/gripper_is_open', GripperStamped, self.gripper_callback) 83 | 84 | def unsubscribe(self): 85 | self.right_arm_pose.unregister() 86 | self.right_arm_vel.unregister() 87 | self.rgb_state_sub.unregister() 88 | self.depth_state_sub.unregister() 89 | self.gripper_state_sub.unregister() 90 | 91 | def pose_callback(self, pose): 92 | if None in self.data: 93 | self.data.pose = pose.pose 94 | 95 | def vel_callback(self, twist): 96 | if None in self.data: 97 | self.data.twist = twist.twist 98 | 99 | def rgb_callback(self, rgb): 100 | if None in self.data: 101 | self.data.rgb = rgb 102 | 103 | def depth_callback(self, depth): 104 | if None in self.data: 105 | self.data.depth = depth 106 | 107 | def gripper_callback(self, gripper): 108 | if 
None in self.data: 109 | self.data.grip = gripper.data.data 110 | 111 | if __name__ == '__main__': 112 | task = sys.argv[1] 113 | recorder = ImitateRecorder(task) 114 | recorder.toggle_collection("1") 115 | recorder.toggle_collection("0") 116 | -------------------------------------------------------------------------------- /old_stack/util/toggler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | import sys 3 | import rospy 4 | import subprocess 5 | from std_msgs.msg import String 6 | 7 | p = None 8 | is_recording = False 9 | task = sys.argv[1] 10 | 11 | def toggle(value): 12 | global p 13 | global is_recording 14 | global task 15 | 16 | if value.data is "0": 17 | if is_recording: 18 | p.terminate() 19 | is_recording = False 20 | elif value.data is "1": 21 | if not is_recording: 22 | p = subprocess.Popen(["./record.py", task]) 23 | is_recording = True 24 | 25 | rospy.init_node('imitate_toggler', log_level=rospy.DEBUG) 26 | toggle_sub = rospy.Subscriber('/unity_learning_record', String, toggle) 27 | rospy.spin() 28 | -------------------------------------------------------------------------------- /sim_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | import sys 4 | import csv 5 | import time 6 | import pygame 7 | import argparse 8 | from pygame.locals import * 9 | from PIL import Image 10 | import numpy as np 11 | import itertools 12 | from src.model import Model 13 | import torch 14 | from simulation.sim import get_start 15 | 16 | # note we want tau to be col, row == x, y 17 | # This is like the get goal method 18 | 19 | def get_tau_(): 20 | button = input('Please enter a button for the robot to try to press (e.g. 
"00", "12"): ') 21 | return [int(b) for b in button] 22 | 23 | 24 | def distance(a, b): 25 | dist = [np.abs(a[i] - b[i]) for i in range(len(a))] 26 | dist = [e**2 for e in dist] 27 | dist = sum(dist) 28 | return np.sqrt(dist) 29 | 30 | 31 | def get_tau(goal_x, goal_y, options): 32 | return options[goal_x][goal_y] 33 | 34 | 35 | def process_images(np_array_img, is_it_rgb): 36 | img = 2*((np_array_img - np.amin(np_array_img))/(np.amax(np_array_img)-np.amin(np_array_img))) - 1 37 | img = torch.from_numpy(img).type(torch.FloatTensor).squeeze() 38 | if(is_it_rgb): 39 | img = img.permute(2, 0, 1).unsqueeze(0) 40 | else: 41 | img = img.view(1,1,img.shape[0], img.shape[1]) 42 | return img 43 | 44 | def sim(model, config): 45 | goals_x = list(range(3)) 46 | goals_y = list(range(3)) 47 | 48 | goal_pos = [[(200, 150), (200, 300), (200, 450)], 49 | [(400, 150), (400, 300), (400, 450)], 50 | [(600, 150), (600, 300), (600, 450)]] 51 | #goal_pos = [[(int(np.random.rand()*720)+40, int(np.random.rand()*520)+40) for j in goals_y] for i in goals_x] 52 | #goal_pos = [[(int((i+1)*800/(len(goals_x)+1)) - 100, int((j+1)*600/(len(goals_y)+1)) + 35) for j in goals_y] for i in goals_x] 53 | 54 | tau_opts = [[(0, 0), (0, 1), (0, 2)], 55 | [(1, 0), (1, 1), (1, 2)], 56 | [(2, 0), (2, 1), (2, 2)]] 57 | 58 | tau_opts = goal_pos 59 | 60 | 61 | # These are magic numbers 62 | RECT_X = 60 63 | RECT_Y = 60 64 | SPACEBAR_KEY = 32 # pygame logic 65 | S_KEY = pygame.K_s 66 | R_KEY = pygame.K_r 67 | ESCAPE_KEY = pygame.K_ESCAPE 68 | 69 | pygame.init() 70 | 71 | #Window 72 | screen = pygame.display.set_mode((800, 600)) 73 | pygame.display.set_caption("2D Simulation") 74 | pygame.mouse.set_visible(1) 75 | # Set the cursor 76 | curr_pos = get_start() 77 | 78 | #Background 79 | background = pygame.Surface(screen.get_size()) 80 | background = background.convert() 81 | background.fill((211, 211, 211)) 82 | screen.blit(background, (0, 0)) 83 | pygame.display.flip() 84 | 85 | clock = pygame.time.Clock() 86 | 87
| run = True 88 | eof = None 89 | rgb = None 90 | depth = None 91 | 92 | gx, gy = get_tau_()#np.random.randint(0, 3, (2,)) 93 | #tau_opts = np.random.randint(0, 255, (3,3,3)) if config.color else goal_pos 94 | tau = get_tau(gx, gy, goal_pos) 95 | a = 0 96 | while run: 97 | a += 1 98 | clock.tick(config.framerate) 99 | for event in pygame.event.get(): 100 | if event.type == pygame.QUIT: 101 | run = False 102 | break 103 | if event.type == pygame.KEYUP: 104 | if event.key == S_KEY: 105 | curr_pos = get_start() 106 | if event.key == R_KEY: 107 | gx, gy = get_tau_()#np.random.randint(0, 3, (2,)) 108 | #tau_opts = np.random.randint(0, 255, (3,3,3)) if config.color else goal_pos 109 | tau = (gx, gy)#get_tau(gx, gy, tau_opts) 110 | if event.key == ESCAPE_KEY: 111 | run = False 112 | break 113 | 114 | vanilla_rgb_string = pygame.image.tostring(screen,"RGBA",False) 115 | vanilla_rgb_pil = Image.frombytes("RGBA",(800,600),vanilla_rgb_string) 116 | resized_rgb = vanilla_rgb_pil.resize((160,120)) 117 | rgb = np.array(resized_rgb)[:,:,:3] 118 | rgb = process_images(rgb, True) 119 | if torch.any(torch.isnan(rgb)): 120 | rgb.zero_() 121 | 122 | vanilla_depth = Image.fromarray(np.uint8(np.zeros((120,160)))) 123 | depth = process_images(vanilla_depth, False).zero_() 124 | 125 | div = [200, 175] if config.normalize else [1, 1] 126 | sub = [400, 325] if config.normalize else [0, 0] 127 | norm_pos = [(curr_pos[0] - sub[0]) / div[0], (curr_pos[1] - sub[1]) / div[1]] 128 | if eof is None: 129 | eof = torch.FloatTensor([norm_pos[0], norm_pos[1], 0.0] * 5) 130 | else: 131 | eof = torch.cat([torch.FloatTensor([norm_pos[0], norm_pos[1], 0.0]), eof[0:12]]) 132 | 133 | # Calculate the trajectory 134 | in_tau = torch.FloatTensor(tau) 135 | #in_tau = torch.zeros(1) 136 | #in_tau[0] = tau[0]*3 + tau[1] 137 | out, aux = model(rgb, depth, eof.view(1, -1), in_tau.view(1, -1).to(eof)) 138 | out = out.squeeze() 139 | delta_x = out[0].item() 140 | delta_y = out[1].item() 141 | new_pos = [curr_pos[0] +
# NOTE(review) -- sim() and its helpers in this script (the file preceding sim_eval0.py):
#   * The initial goal is `tau = get_tau(gx, gy, goal_pos)`, i.e. a pixel position such as
#     (200, 150), but the R-key handler sets `tau = (gx, gy)`, i.e. grid indices. Both are
#     fed to the model through `in_tau`; confirm which encoding the checkpoint expects.
#   * `tau_opts` is reassigned to `goal_pos` (a list of lists), so `tau_opts[x,y]` in the
#     drawing loop raises TypeError whenever --color is passed. `tau_opts[x][y]` would
#     index it, but those entries are (x, y) positions, not RGB colours.
#   * process_images() min-max scales the image to [-1, 1]. A constant image (max == min)
#     divides by zero and yields NaN. The caller zeroes NaN rgb frames (`rgb.zero_()`) and
#     always zeroes the all-black depth tensor.
#   * The model is called outside torch.no_grad(), so autograd state is built every frame.
#   * get_tau_() converts each typed character with int(), so non-digit input raises
#     ValueError; pressing R re-prompts via get_tau_().
#   * distance() is the Euclidean distance between two equal-length sequences.
delta_x, curr_pos[1] + delta_y] 142 | print(eof) 143 | print(tau) 144 | print(aux) 145 | #print(get_tau(gx, gy, goal_pos)) 146 | print(out) 147 | print(new_pos) 148 | print(distance(curr_pos, new_pos)) 149 | #if (distance(curr_pos, new_pos)) < 1.5: 150 | # time.sleep(5) 151 | print('========') 152 | curr_pos = new_pos 153 | 154 | screen.fill((211,211,211)) 155 | for x, y in list(itertools.product(goals_x, goals_y)): 156 | color = tau_opts[x,y] if config.color else (0, 0, 255) 157 | #if x == gx and y == gy: 158 | # continue 159 | pygame.draw.rect(screen, color, pygame.Rect(goal_pos[x][y][0]-RECT_X/2, goal_pos[x][y][1]-RECT_Y/2, RECT_X, RECT_Y)) 160 | pygame.draw.circle(screen, (0,0,0), [int(v) for v in curr_pos], 20, 0) 161 | pygame.display.update() 162 | 163 | #if a == 200: 164 | # break 165 | pygame.quit() 166 | return 0 167 | 168 | if __name__ == '__main__': 169 | parser = argparse.ArgumentParser(description='Input to 2d simulation.') 170 | parser.add_argument('-w', '--weights', required=True, help='The path to the weights to load.') 171 | parser.add_argument('-c', '--color', dest='color', default=False, action='store_true', help='Used to activate color simulation.') 172 | parser.add_argument('-no', '--normalize', dest='normalize', default=False, action='store_true', help='Used to activate position normalization.') 173 | parser.add_argument('-f', '--framerate', default=300, type=int, help='Framerate of simulation.') 174 | args = parser.parse_args() 175 | 176 | checkpoint = torch.load(args.weights, map_location='cpu') 177 | model = Model(**checkpoint['kwargs']) 178 | model.load_state_dict(checkpoint['model_state_dict']) 179 | model.eval() 180 | 181 | sim(model, args) 182 | -------------------------------------------------------------------------------- /sim_eval0.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | import os 3 | import sys 4 | import csv 5 | import time 6 | import pygame 7 | import
argparse  # tail of `import argparse`: the dump wrapped sim_eval0.py line 7 right after `import`
from pygame.locals import *
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import itertools
from src.model import Model
from src.model0 import Model0
from src.loss_func import fix_rot
import torch
from simulation.sim import get_start, get_tau, RECT_X, RECT_Y, goals_x, goals_y, goal_pos


# note we want tau to be col, row == x, y
# This is like the get goal method
def get_tau_():
    """Ask on stdin which button to press and return it as ``[x, y]`` grid indices.

    The answer must be exactly two digits, each between 0 and 2 (the scene is a
    3x3 grid of targets). Invalid answers are reported and the prompt repeats,
    instead of crashing with ``ValueError`` or returning indices outside the grid.
    """
    while True:
        button = input('Please enter a button for the robot to try to press (e.g. "00", "12"): ').strip()
        if len(button) == 2 and all(ch in '012' for ch in button):
            return [int(b) for b in button]
        print('Invalid button %r: expected two digits between 0 and 2, e.g. "00" or "12".' % button)


def distance(a, b):
    """Return the Euclidean distance between the points ``a`` and ``b`` (equal-length sequences)."""
    return np.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def process_images(np_array_img, is_it_rgb):
    """Min-max normalise an image to [-1, 1] and return it as a ``(1, C, H, W)`` float tensor.

    Args:
        np_array_img: an ``HxWx3`` array when ``is_it_rgb`` is true, otherwise a
            single-channel ``HxW`` image (anything ``np.asarray`` accepts, e.g. a PIL image).
        is_it_rgb: selects between the two layouts above.

    A constant image has no dynamic range. Instead of dividing by zero (which
    produced an all-NaN tensor), it is mapped to all zeros.
    """
    arr = np.asarray(np_array_img, dtype=np.float64)
    lo, hi = np.amin(arr), np.amax(arr)
    if hi > lo:
        arr = 2 * ((arr - lo) / (hi - lo)) - 1
    else:
        arr = np.zeros_like(arr)
    img = torch.from_numpy(arr).type(torch.FloatTensor).squeeze()
    if is_it_rgb:
        img = img.permute(2, 0, 1)
    else:
        img = img.view(1, img.shape[0], img.shape[1])
    return img.unsqueeze(0)


def sim(model, config):
    """Run the interactive 2-D button-pressing simulation driven by ``model``.

    Each frame:
      * the scene is drawn: 9 rotated targets plus the agent square;
      * the scene is downsampled to 160x120;
      * the image is fed to the model together with the agent pose history
        (``eof``) and the chosen goal (``tau``);
      * the first three model outputs are applied to the agent as
        (dx, dy, d-rotation in degrees).

    Keys:
      * S resets the agent to ``get_start()``.
      * R re-randomises the target offsets and gives all targets one shared
        random rotation.
      * Esc, or closing the window, quits.

    Args:
        model: callable ``model(rgb, depth, eof, tau, b_print=..., print_path=...)``
            returning ``(out, aux)``.
        config: parsed CLI namespace. Reads ``framerate``, ``color``,
            ``normalize`` and ``print``.

    Returns:
        0 once the loop ends.
    """
    # Goal encodings handed to the model: entry [x][y] is the (x, y) grid index.
    tau_opts = [[(0, 0), (0, 1), (0, 2)],
                [(1, 0), (1, 1), (1, 2)],
                [(2, 0), (2, 1), (2, 2)]]

    S_KEY = pygame.K_s
    R_KEY = pygame.K_r
    ESCAPE_KEY = pygame.K_ESCAPE
    BACKGROUND = (211, 211, 211)

    pygame.init()

    # Window
    screen = pygame.display.set_mode((800, 600))
    pygame.display.set_caption("2D Simulation")
    pygame.mouse.set_visible(1)
    # Set the cursor
    curr_pos = get_start()

    # Background
    background = pygame.Surface(screen.get_size())
    background = background.convert()
    background.fill(BACKGROUND)
    screen.blit(background, (0, 0))
    pygame.display.flip()

    clock = pygame.time.Clock()

    run = True
    eof = None

    gx, gy = get_tau_()
    tau = get_tau(gx, gy, tau_opts)
    # One independent random rotation (degrees) per target. This used to branch on the
    # global `args.rotation` (a NameError unless run as a script), but both branches
    # were identical, so --rotation has no effect here.
    rect_rot = np.random.randint(0, 360, (9,))

    # Random shift applied to the whole target grid (the only_rot experiments used 150, 150).
    x_offset = np.random.randint(0, 200)
    y_offset = np.random.randint(0, 200)

    a = 0
    while run:
        a += 1
        clock.tick(config.framerate)
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False
                break
            if event.type == pygame.KEYUP:
                if event.key == S_KEY:
                    curr_pos = get_start()
                if event.key == R_KEY:
                    x_offset = np.random.randint(0, 200)
                    y_offset = np.random.randint(0, 200)
                    rect_rot = np.ones(9) * np.random.randint(0, 360)
                if event.key == ESCAPE_KEY:
                    run = False
                    break

        screen.fill(BACKGROUND)

        # Inner square stamped onto every target (green), then onto the agent (red).
        inner = pygame.Surface((RECT_Y - 10, RECT_Y - 10), pygame.SRCALPHA)
        inner.fill((0, 255, 0))

        for x, y in itertools.product(goals_x, goals_y):
            # NOTE(review): with --color this yields an (x, y) index pair rather than an RGB
            # triple, which Surface.fill will most likely reject; --color needs a real palette.
            color = tau_opts[x][y] if config.color else (0, 0, 255)
            target = pygame.Surface((RECT_X, RECT_Y), pygame.SRCALPHA)
            target.fill(color)
            target.blit(inner, (5, 5))
            target = pygame.transform.rotate(target, rect_rot[3 * x + y])
            print(rect_rot[3 * x + y])
            screen.blit(target, (goal_pos[x][y][0] + x_offset, goal_pos[x][y][1] + y_offset))

        # Surface.convert() returns a new surface; the old discarded calls were no-ops.
        agent = pygame.Surface((RECT_X, RECT_Y), pygame.SRCALPHA)
        agent.fill((0, 0, 0))
        inner.fill((255, 0, 0))
        agent.blit(inner, (5, 5))
        agent = pygame.transform.rotate(agent, int(curr_pos[2]))
        screen.blit(agent, curr_pos[:2])

        pygame.display.update()

        frame = pygame.image.tostring(screen, "RGBA", False)
        frame = Image.frombytes("RGBA", (800, 600), frame).resize((160, 120))
        rgb = process_images(np.array(frame)[:, :, :3], True)

        # There is no depth sensor in 2-D: the model gets an all-zero depth image.
        depth = torch.zeros(1, 1, 120, 160)

        div = [400, 300, 180] if config.normalize else [1, 1, 1]
        sub = [400, 300, 180] if config.normalize else [0, 0, 0]
        norm_pos = [(curr_pos[i] - sub[i]) / div[i] for i in range(3)]
        # Replace the heading with its (sin, cos) encoding.
        norm_pos = norm_pos[:2] + [np.sin(np.pi * norm_pos[2]), np.cos(np.pi * norm_pos[2])]
        if eof is None:
            eof = torch.FloatTensor(norm_pos * 5)
        else:
            # NOTE(review): the older poses are multiplied by 0, so after the first frame the
            # model only sees the current pose -- confirm this is intended.
            eof = torch.cat([torch.FloatTensor(norm_pos), eof[0:16] * 0])

        # Calculate the trajectory
        in_tau = torch.FloatTensor(tau)
        print_loc = 'eval_print/' + str(a)
        with torch.no_grad():
            out, aux = model(rgb, depth, eof.view(1, -1), in_tau.view(1, -1).to(eof),
                             b_print=config.print, print_path=print_loc)
        out = out.squeeze()
        delta_x = out[0].item()
        delta_y = out[1].item()
        delta_rot = out[2].item()
        new_pos = [curr_pos[0] + delta_x, curr_pos[1] + delta_y, (curr_pos[2] + delta_rot) % 360]

        target_rot = rect_rot[3 * gx + gy]
        print(eof.numpy())
        print(out.numpy())
        print([-1 * np.sin(new_pos[2] / 180 * np.pi), -1 * np.cos(new_pos[2] / 180 * np.pi)])
        print([-1 * np.sin((target_rot / 180) * np.pi), -1 * np.cos((target_rot / 180) * np.pi)])
        print((goal_pos[gx][gy][0] + x_offset - 400) / 400, (goal_pos[gx][gy][1] + y_offset - 300) / 300)
        print(aux.numpy())
        print('========')
        curr_pos = new_pos

    pygame.quit()
    return 0


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Input to 2d simulation.')
    parser.add_argument('-w', '--weights', required=True, help='The path to the weights to load.')
    parser.add_argument('-c', '--color', dest='color', default=False, action='store_true', help='Used to activate color simulation.')
    parser.add_argument('-no', '--normalize', dest='normalize', default=False, action='store_true', help='Used to activate position normalization.')
    parser.add_argument('-f', '--framerate', default=300, type=int, help='Framerate of simulation.')
    parser.add_argument('-r', '--rotation', default=True, dest='rotation', action='store_false', help='Used to eval rotation.')
    parser.add_argument('-p', '--print', default=False, dest='print', action='store_true', help='Flag to print activations.')
    parser.add_argument('-att', '--attention', default=False, dest='attention', action='store_true', help='Flag indicating to use attention')
    args = parser.parse_args()

    checkpoint = torch.load(args.weights, map_location='cpu')
    model_cls = Model if args.attention else Model0
    model = model_cls(**checkpoint['kwargs'])
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    sim(model, args)

# -------------------------------------------------------------------------------- /src/datasets.py: --------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset

import lmdb
from multiprocessing import Pool

import os.path as osp
import pyarrow as pa
import random


class ImitationLMDB(Dataset):
    """Read-only map-style dataset over ``<dest>/<mode>.lmdb``.

    The LMDB holds:
      * a ``__len__`` entry;
      * a ``__keys__`` entry;
      * one pyarrow-serialised record per key, containing the arrays
        ``(rgb, depth, eof, tau, aux, target)`` in that order.

    ``__getitem__`` returns them as float tensors ordered
    ``[rgb, depth, eof, tau, target, aux]``. Note that target comes before aux.
    """

    def __init__(self, dest, mode):
        super(ImitationLMDB, self).__init__()
        lmdb_file = osp.join(dest, mode + ".lmdb")
        # Open the LMDB file read-only and without locking.
        self.env = lmdb.open(lmdb_file, subdir=osp.isdir(lmdb_file),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)

        with self.env.begin(write=False) as txn:
            self.length = self.loads_pyarrow(txn.get(b'__len__'))
            self.keys = self.loads_pyarrow(txn.get(b'__keys__'))

        # Identity index map. Shuffle it (random.shuffle) for a fixed random order.
        self.shuffled = list(range(self.length))

    def loads_pyarrow(self, buf):
        """Deserialise a pyarrow-serialised buffer.

        ``pa.deserialize`` is deprecated in newer pyarrow releases (venv_tool pins
        pyarrow 0.11.1); only use it on trusted files.
        """
        return pa.deserialize(buf)

    def __getitem__(self, idx):
        key = self.keys[self.shuffled[idx]]
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(key)
        if byteflow is None:
            raise KeyError('missing LMDB record for key %r' % (key,))

        # Stored order: RGB, Depth, EOF, Tau, Aux, Target
        rgb, depth, eof, tau, aux, target = (
            torch.from_numpy(array).type(torch.FloatTensor)
            for array in self.loads_pyarrow(byteflow)[:6]
        )
        return [rgb, depth, eof, tau, target, aux]

    def __len__(self):
        return self.length

    def close(self):
        self.env.close()

# -------------------------------------------------------------------------------- /src/loss_func.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from math import pi


class LossException(Exception):
    """Raised when the behavior-cloning loss receives NaN model outputs."""


def fix_rot(goal_rot, cur_rot):
    """Return the unsigned angular difference between ``goal_rot`` and ``cur_rot``.

    Both angles are taken modulo 2, so a full turn is 2 units. The difference is
    folded into ``[0, 1]``, i.e. always measured the shorter way around.
    """
    diff = torch.abs(torch.remainder(goal_rot, 2) - torch.remainder(cur_rot, 2))
    return torch.where(diff > 1, 2 - diff, diff)


class BehaviorCloneLoss(nn.Module):
    """
    Behavior Clone Loss

    Weighted sum of four terms:
      * L2: MSE on the x/y columns (weight 2/3) plus MSE on column 2 (weight 1/3);
      * L1: over all output columns;
      * angular: mean arccos of the cosine similarity between each output row
        and its target row;
      * aux: MSE between the auxiliary prediction and its target.
    """

    def __init__(self, lamb_l2=0.01, lamb_l1=1.0, lamb_c=0.005, lamb_aux=0.0001, eps=1e-7):
        super(BehaviorCloneLoss, self).__init__()
        self.lamb_l2 = lamb_l2
        self.lamb_l1 = lamb_l1
        self.lamb_c = lamb_c
        self.lamb_aux = lamb_aux
        self.l2 = nn.MSELoss()
        self.l1 = nn.L1Loss()
        self.aux = nn.MSELoss()

        self.eps = eps

    def forward(self, out, aux_out, target, aux_target, flag=False):
        """Return the weighted loss of ``(batch, n)`` outputs ``out`` against ``target``.

        When ``flag`` is true, the individual terms are printed.

        Raises:
            LossException: if ``out`` contains NaN.
        """
        if torch.any(torch.isnan(out)):
            print(out)
            raise LossException('nan in model outputs!')

        l2_loss = self.l2(out[:, :2], target[:, :2]) * 2 / 3 + self.l2(out[:, 2], target[:, 2]) / 3
        l1_loss = self.l1(out, target)

        # Angular (arccos) loss. The row-wise dot product keeps shape (batch,) for every
        # batch size and works on non-contiguous inputs (bmm + view + squeeze did not).
        num = (target * out).sum(dim=1)
        den = torch.norm(target, p=2, dim=1) * torch.norm(out, p=2, dim=1) + self.eps
        div = num / den
        a_cos = torch.acos(torch.clamp(div, -1 + self.eps, 1 - self.eps))
        c_loss = torch.mean(a_cos)

        aux_loss = self.aux(aux_out, aux_target)

        weighted_loss = (self.lamb_l2 * l2_loss + self.lamb_l1 * l1_loss
                         + self.lamb_c * c_loss + self.lamb_aux * aux_loss)

        if flag:
            print('===============')
            print('===============')

            print(' ')
            print(' ')
            print(' ')

            print('weighted loss: %.2f' % weighted_loss)
            print('l2 loss: %.2f' % l2_loss)
            print('l1 loss: %.2f' % l1_loss)
            print('c loss: %.2f' % c_loss)
            print('aux loss: %.2f' % aux_loss)

            print(' ')

            print('L2 x: %.2f' % self.l2(out[:, 0], target[:, 0]))
            print('L2 y: %.2f' % self.l2(out[:, 1], target[:, 1]))
            print('L2 theta: %.2f' % self.l2(out[:, 2], target[:, 2]))

            if torch.isnan(c_loss):
                print('num: %s' % str(num))
                print('den: %s' % str(den))
                print('div: %s' % str(div))
                print('acos: %s' % str(a_cos))

        return weighted_loss

# -------------------------------------------------------------------------------- /superval_3dval.py: --------------------------------------------------------------------------------
#!/usr/bin/python
import cv2
import time
import csv
import os
import sys
import rospy
import itertools
import numpy as np 10 | #from tf.transformations import euler_from_quaternion 11 | from cv_bridge import CvBridge 12 | from namedlist import namedlist 13 | from std_msgs.msg import Int64, String 14 | from sensor_msgs.msg import CompressedImage, Image, JointState 15 | from geometry_msgs.msg import Twist, Pose, TwistStamped, PoseStamped, Vector3 16 | import torch 17 | from model import Model 18 | import argparse 19 | import copy 20 | from matplotlib import pyplot as plt 21 | from blob import blob 22 | 23 | 24 | class ImitateEval: 25 | def __init__(self, weights): 26 | self.bridge = CvBridge() 27 | self.Data = namedlist('Data', ['pose', 'rgb', 'depth']) 28 | self.data = self.Data(pose=None, rgb=None, depth=None) 29 | self.is_start = True 30 | 31 | checkpoint = torch.load(weights, map_location="cpu") 32 | self.model = Model(**checkpoint['kwargs']) 33 | self.model.load_state_dict(checkpoint["model_state_dict"]) 34 | self.model.eval() 35 | 36 | def change_start(self): 37 | radius = 0.07 38 | # Publisher for the movement and the starting pose 39 | self.movement_publisher = rospy.Publisher('/iiwa/CollisionAwareMotion', Pose, queue_size=10) 40 | self.target_start = Pose() 41 | self.target_start.position.x = -0.15 + np.random.rand()*2*radius - radius # -0.10757 42 | self.target_start.position.y = 0.455 + np.random.rand()*2*radius - radius # 0.4103 43 | self.target_start.position.z = 1.015 44 | self.target_start.orientation.x = 0.0 45 | self.target_start.orientation.y = 0.0 46 | self.target_start.orientation.z = 0.7071068 47 | self.target_start.orientation.w = 0.7071068 48 | 49 | 50 | def move_to_button(self, tau, tolerance): 51 | self.init_listeners() 52 | rospy.Rate(5).sleep() 53 | rate = rospy.Rate(15) 54 | 55 | pose_to_move = copy.deepcopy(self.target_start) 56 | eof = [] 57 | while not rospy.is_shutdown(): 58 | if None not in self.data: 59 | # Position from the CartesianPose Topic!! 
60 | pos = self.data.pose.position 61 | pos = [pos.x, pos.y, pos.z] 62 | if self.is_start: 63 | for _ in range(5): 64 | eof += pos 65 | self.is_start = False 66 | else: 67 | eof = pos + eof[:-3] 68 | 69 | eof_input = torch.from_numpy(np.array(eof)).type(torch.FloatTensor) 70 | eof_input = eof_input.unsqueeze(0) 71 | 72 | 73 | 74 | rgb = self.process_images(self.data.rgb, True) 75 | depth = self.process_images(self.data.depth, False) 76 | 77 | # print("RGB min: {}, RGB max: {}".format(np.amin(rgb), np.amax(rgb))) 78 | # print("Depth min: {}, Depth max: {}".format(np.amin(depth), np.amax(depth))) 79 | print("EOF: {}".format(eof_input)) 80 | print("Tau: {}".format(tau)) 81 | 82 | torch.save(rgb, "/home/amazon/Desktop/rgb_tensor.pt") 83 | torch.save(depth, "/home/amazon/Desktop/depth_tensor.pt") 84 | torch.save(eof, "/home/amazon/Desktop/eof_tensor.pt") 85 | torch.save(tau, "/home/amazon/Desktop/tau.pt") 86 | 87 | with torch.no_grad(): 88 | out, aux = self.model(rgb, depth, eof_input, tau) 89 | torch.save(out, "/home/amazon/Desktop/out.pt") 90 | torch.save(aux, "/home/amazon/Desktop/aux.pt") 91 | out = out.squeeze() 92 | x_cartesian = out[0].item() 93 | y_cartesian = out[1].item() 94 | z_cartesian = out[2].item() 95 | print("X:{}, Y:{}, Z:{}".format(x_cartesian, y_cartesian, z_cartesian)) 96 | print("Aux: {}".format(aux)) 97 | # This new pose is the previous pose + the deltas output by the net, adjusted for discrepancy in frame 98 | # It used to be: 99 | # pose_to_move.position.x += -y_cartesian 100 | # pose_to_move.position.y += x_cartesian 101 | # pose_to_move.position.z += z_cartesian 102 | 103 | pose_to_move.position.x -= y_cartesian 104 | pose_to_move.position.y += x_cartesian 105 | pose_to_move.position.z += z_cartesian 106 | #print(pose_to_move) 107 | 108 | # Publish to Kuka!!!! 
109 | for i in range(10): 110 | self.movement_publisher.publish(pose_to_move) 111 | rospy.Rate(10).sleep() 112 | 113 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String) 114 | # End publisher 115 | 116 | self.data = self.Data(pose=None,rgb=None,depth=None) 117 | rate.sleep() 118 | 119 | if distance(tau, (pose_to_move.position.x, pose_to_move.position.y, pose_to_move.position.z)) < tolerance: 120 | break 121 | 122 | def process_images(self, img_msg, is_it_rgb): 123 | crop_right=586 124 | crop_lower=386 125 | img = self.bridge.compressed_imgmsg_to_cv2(img_msg, desired_encoding="passthrough") 126 | if(is_it_rgb): 127 | img = img[:,:,::-1] 128 | # Does this crop work? 129 | #rgb = img[0:386, 0:586] 130 | #rgb = img.crop((0, 0, crop_right, crop_lower)) 131 | rgb = cv2.resize(img, (160,120)) 132 | rgb = np.array(rgb).astype(np.float32) 133 | 134 | rgb = 2*((rgb - np.amin(rgb))/(np.amax(rgb)-np.amin(rgb)))-1 135 | 136 | 137 | rgb = torch.from_numpy(rgb).type(torch.FloatTensor) 138 | if is_it_rgb: 139 | rgb = rgb.view(1, rgb.shape[0], rgb.shape[1], rgb.shape[2]).permute(0, 3, 1, 2) 140 | else: 141 | rgb = rgb.view(1, 1, rgb.shape[0], rgb.shape[1]) 142 | #plt.imshow(rgb[0,0] / 2 + .5) 143 | # plt.show() 144 | 145 | return rgb 146 | 147 | def move_to_start(self): 148 | # Publish starting position to Kuka!!!! 
149 | for i in range(10): 150 | self.movement_publisher.publish(self.target_start) 151 | rospy.Rate(10).sleep() 152 | 153 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String) 154 | # End publisher 155 | 156 | def init_listeners(self): 157 | # The Topics we are Subscribing to for data 158 | self.right_arm_pose = rospy.Subscriber('/iiwa/state/CartesianPose', PoseStamped, self.pose_callback) 159 | self.rgb_state_sub = rospy.Subscriber('/camera3/camera/color/image_rect_color/compressed', CompressedImage, self.rgb_callback) 160 | self.depth_state_sub = rospy.Subscriber('/camera3/camera/depth/image_rect_raw/compressed', CompressedImage, self.depth_callback) 161 | 162 | def unsubscribe(self): 163 | self.right_arm_pose.unregister() 164 | self.rgb_state_sub.unregister() 165 | self.depth_state_sub.unregister() 166 | 167 | def pose_callback(self, pose): 168 | if None in self.data: 169 | self.data.pose = pose.pose 170 | 171 | def rgb_callback(self, rgb): 172 | if None in self.data: 173 | self.data.rgb = rgb 174 | 175 | def depth_callback(self, depth): 176 | if None in self.data: 177 | self.data.depth = depth 178 | 179 | 180 | def translate_tau(button): 181 | b_0 = int(button[0]) 182 | b_1 = int(button[1]) 183 | 184 | tau = (-.22+.07*b_0, .56-.07*b_1, .94-.0025*b_y) 185 | return tau 186 | 187 | 188 | def distance(a, b): 189 | return np.sqrt(np.sum([np.abs(aa - bb) for aa, bb in zip(a,b)])) 190 | 191 | 192 | def get_tau(r, c): 193 | tau = translate_tau([r, c]) 194 | 195 | 196 | def main(config): 197 | for weights in list(blob(config.weights + '/*/best_checkpoint.tar', recursive=True)): 198 | agent = Agent(config.weights) 199 | agent.change_start() 200 | agent.move_to_start() 201 | rates = torch.zeros(3, 3, config.num_traj) 202 | for r in range(3): 203 | for c in range(3): 204 | for i in range(config.num_traj): 205 | tau = get_tau(r, c) 206 | rates[3, c, i] = agent.move_to_button(tau, config.tolerance) 207 | agent.change_start() 208 | agent.move_to_start() 209 
| rates = torch.sum(rates, dim=2) / config.num_traj 210 | torch.save(rates, config.weights[:config.weights.rfind('/')] + '/button_eval_percentages.pt') 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = argparse.ArgumentParser(description="Arguments for evaluating imitation net") 215 | parser.add_argument('-w', '--weights', required=True, help='Path to folder containing checkpoint directories/files.') 216 | parser.add_argument('-t', '--tolerance', default=.01, type=float, help='Tolerance for button presses.') 217 | parser.add_argument('-n', '--num_traj', default=10, type=int, help='Presses per button per arrangement.') 218 | args = parser.parse_args() 219 | rospy.init_node('eval_imitation', log_level=rospy.DEBUG) 220 | try: 221 | main(args) 222 | except KeyboardInterrupt: 223 | pass 224 | cept KeyboardInterrupt: 225 | pass 226 | -------------------------------------------------------------------------------- /superval_results_compile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from glob import glob 3 | import argparse 4 | from matplotlib import pyplot as plt 5 | import numpy as np 6 | 7 | 8 | def parse_path(path): 9 | buttons = path[path.rfind('(')+1:path.rfind(')')].split(',') 10 | buttons = [(int(button[-2]), int(button[-3])) for button in buttons if len(button) > 0] 11 | return buttons 12 | 13 | 14 | def get_pcts(button_pcts, train_btns): 15 | train_inds = list(zip(*train_btns)) 16 | train_pct = torch.mean(button_pcts[train_inds]) 17 | total_pct = torch.mean(button_pcts) 18 | test_pct = (9 * total_pct - len(train_btns) * train_pct) / (9 - len(train_btns)) 19 | return torch.stack([train_pct, test_pct, total_pct]) 20 | 21 | def errorfill(x, y, yerr, color=None, alpha_fill=0.3, ax=None, plt_label=None, linestyle='-'): 22 | ax = ax if ax is not None else plt.gca() 23 | # if color is None: 24 | # color = ax._get_lines.color_cycle.next() 25 | if np.isscalar(yerr) or len(yerr) == len(y): 26 | ymin = 
y - yerr 27 | ymax = y + yerr 28 | elif len(yerr) == 2: 29 | ymin, ymax = yerr 30 | ax.plot(x, y, color=color, label=plt_label, linestyle=linestyle) 31 | ax.fill_between(x, ymax, ymin, color=color, alpha=alpha_fill) 32 | 33 | 34 | def plot(percentages, versions): 35 | colors = ['red', 'blue', 'green', 'teal'] 36 | linestyles = ['-', '--', '-.', ':'] 37 | plt.style.use('ggplot') 38 | 39 | # MERL_percs = [11.0, 33.0, 100.0, 78.0, 100.0, 100.0, 100.0, 100.0, 100.0] 40 | # MERL_percs_errors = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 41 | # 42 | # plt.figure(0) 43 | # errorfill(range(1, 10), MERL_percs, yerr=MERL_percs_errors, color='red', plt_label='tau') 44 | # plt.xlabel('Number of holes in training set') 45 | # plt.ylabel('Successful insertion percentage') 46 | # plt.title('Average Accuracy across all insertions') 47 | # plt.legend() 48 | 49 | plt.figure(1) 50 | # plt.errorbar(range(1, 9), percentages[0, :-1, 1, 0], yerr=percentages[0, :-1, 1, 2], label='tau') 51 | # plt.errorbar(range(1, 9), percentages[1, :-1, 1, 0], yerr=percentages[1, :-1, 1, 2], label='onehot') 52 | tauerr = [percentages[0, :-1, 1, 5], percentages[0, :-1, 1, 6]] 53 | tauerr_avg = [percentages[0, :, 2, 5], percentages[0, :, 2, 6]] 54 | #errorfill(range(1, 9), percentages[0, :-1, 1, 0], yerr=tauerr, color='red', plt_label='tau for unseen') 55 | for i, name in enumerate(versions): 56 | errorfill(range(1, 10), percentages[i, :, 2, 4], yerr=[percentages[i, :, 2, 5], percentages[i, :, 2, 6]], color=colors[i], linestyle=linestyles[i], plt_label=name) 57 | #errorfill(range(1, 10), percentages[1, :, 2, 0], yerr=percentages[1, :, 2, 2], color='green', plt_label='onehot') 58 | plt.xlabel('Number of goals in the training set') 59 | plt.ylabel('Success percentage') 60 | plt.title('Median Success Percentage in Simulation') 61 | plt.legend() 62 | ''' 63 | plt.figure(2) 64 | plt.plot(range(1, 9), percentages[0, :-1, 1, 0], label='All Arrangements') 65 | #plt.plot(range(1, 9), percentages[1, :-1, 1, 0], 
label='onehot') 66 | plt.plot(range(1, 9), percentages[0, :-1, 1, 3], label='Median-centered 50%'+' of Arrangements') 67 | #plt.plot(range(1, 9), percentages[1, :-1, 1, 3], label='onehot without outliers') 68 | plt.xlabel('Number of Buttons in Train Set') 69 | plt.ylabel('Successful Press Percentage') 70 | plt.title('Average Accuracy on Untrained Buttons With Tau') 71 | plt.legend() 72 | 73 | plt.figure(3) 74 | plt.plot(range(1, 9), percentages[0, :-1, 1, 7], label='100th') 75 | #plt.bar(range(1, 9), percentages[0, :-1, 1, 6], label='75th') 76 | plt.plot(range(1, 9), percentages[0, :-1, 1, 5], label='50th') 77 | #plt.bar(range(1, 9), percentages[0, :-1, 1, 4], label='25th') 78 | plt.plot(range(1, 9), percentages[0, :-1, 1, 3], label='0th') 79 | # plt.plot(range(1, 9), percentages[0, :-1, 1, 4], label='tau without outliers') 80 | # plt.plot(range(1, 9), percentages[1, :-1, 1, 4], label='onehot without outliers') 81 | plt.xlabel('Number of Buttons in Train Set') 82 | plt.ylabel('Successful Press Percentage') 83 | plt.title('Accuracy on Untrained Buttons by Percentile') 84 | plt.legend() 85 | 86 | plt.figure(4) 87 | plt.errorbar(range(1, 10), percentages[0, :, 2, 0], yerr=percentages[0, :, 2, 2], label='tau') 88 | #plt.errorbar(range(1, 10), percentages[1, :, 2, 0], yerr=percentages[1, :, 2, 2], label='onehot') 89 | plt.xlabel('Number of Buttons in Train Set') 90 | plt.ylabel('Successful Press Percentage') 91 | plt.title('Average Accuracy on All Buttons') 92 | plt.legend() 93 | 94 | plt.figure(5) 95 | plt.plot(range(1, 10), percentages[0, :, 2, 0], label='tau') 96 | #plt.plot(range(1, 10), percentages[1, :, 2, 0], label='onehot') 97 | plt.plot(range(1, 10), percentages[0, :, 2, 3], label='tau without outliers') 98 | # plt.plot(range(1, 10), percentages[1, :, 2, 3], label='onehot without outliers') 99 | plt.xlabel('Number of Buttons in Train Set') 100 | plt.ylabel('Successful Press Percentage') 101 | plt.title('Average Accuracy on All Buttons') 102 | plt.legend() 103 | 
104 | plt.figure(6) 105 | plt.plot(range(1, 10), percentages[0, :, 2, 1], label='tau') 106 | #plt.plot(range(1, 10), percentages[1, :, 2, 1], label='onehot') 107 | plt.plot(range(1, 10), percentages[0, :, 2, 4], label='tau without outliers') 108 | #plt.plot(range(1, 10), percentages[1, :, 2, 4], label='onehot without outliers') 109 | plt.xlabel('Number of Buttons in Train Set') 110 | plt.ylabel('Successful Press Percentage') 111 | plt.title('Maximum Accuracy on All Buttons') 112 | plt.legend() 113 | ''' 114 | plt.show() 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('root') 120 | parser.add_argument('versions', nargs='+') 121 | parser.add_argument('-e', '--error_range', default=.25, type=float) 122 | args = parser.parse_args() 123 | 124 | percentages = torch.zeros(len(args.versions), 9, 3, 8) 125 | for i_version, version in enumerate(args.versions): 126 | version_root = args.root + '/' + version 127 | for count in range(1,4): 128 | count_root = version_root + '/' + str(count) 129 | sample_paths = glob(count_root + '/**/button_eval_percentages.pt', recursive=True) 130 | samples = [(torch.load(path), parse_path(path)) for path in sample_paths] 131 | sample_pcts = torch.stack([get_pcts(*sample) for sample in samples], dim=0) 132 | sorted_test = torch.stack([sample_pcts[:, i].sort()[0] for i in range(sample_pcts.size(1))], dim=1) 133 | error_low = int(sample_pcts.size(0) * args.error_range) 134 | error_high = int(sample_pcts.size(0) * (1 - args.error_range)) 135 | percentages[i_version, count - 1, :] = torch.stack([torch.mean(sample_pcts, dim=0), 136 | torch.max(sample_pcts, dim=0)[0], 137 | torch.std(sample_pcts, dim=0), 138 | sorted_test[int(sample_pcts.size(0) * 0)], 139 | sorted_test[int(sample_pcts.size(0) * .49)], 140 | sorted_test[error_low], 141 | sorted_test[error_high], 142 | sorted_test[int(sample_pcts.size(0) * .99)], 143 | ], dim=1) 144 | 145 | #print((percentages*100).int()) 146 | 
#print(percentages[1, :, 2, 0]*100) 147 | plot(percentages.numpy()*100, args.versions) 148 | -------------------------------------------------------------------------------- /temprun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in 5 6 7 8 9 4 | do 5 | cp -r 3d_$i"buttons" ~/ 6 | python train.py -d ~/3d_$i"buttons" -s 3d_$i"buttons_out" -ne 5 -lr .0005 -ub -sr 51 -opt novograd -at -device cuda:1 7 | rm -r ~/3d_$i"buttons" 8 | done 9 | -------------------------------------------------------------------------------- /util/plot_loss.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | 5 | import argparse 6 | 7 | def load_file(filename): # Read a headerless 'epoch,loss_category,loss' CSV; return [(category, curve), ...] plus a label (filename minus its last 4 chars, i.e. assumes an extension such as '.csv'). 8 | data = pd.read_csv(filename, sep=',', header=None) 9 | data.columns = ["epoch", "loss_category", "loss"] 10 | 11 | categories = data['loss_category'].unique() 12 | 13 | losses = [(category, filter_loss(data, category)) for category in categories] 14 | 15 | return losses, filename[:-4] 16 | 17 | 18 | def filter_loss(data, category): # Return a (2, k) array [epochs; losses] holding only the rows whose loss_category == category. 19 | cat_inds = data['loss_category'].values == category 20 | 21 | losses = np.stack([data['epoch'], data['loss']]) 22 | cat_losses = losses[:, cat_inds] # boolean mask over columns; same selection as indexing with np.where(cat_inds) 23 | 24 | return cat_losses 25 | 26 | 27 | def plot_file(losses, identifier=None): # Add one line per loss category to the current figure, labelled '<category> Loss(<identifier>)'. 28 | if identifier is None: 29 | label = 'Loss' 30 | else: 31 | label = 'Loss(%s)' % identifier 32 | 33 | for cat, loss in losses: 34 | plt.plot(loss[0], loss[1], label='%s %s'%(cat, label)) 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('loss_files', 40 | help='The path(s) to the loss file(s) you want to plot.', 41 | nargs='*') 42 | args = parser.parse_args() 43 | 44 | print('Loading data...') 45 | losses = [load_file(loss) for loss in args.loss_files] 46 | print('Data loaded.') 47 | 48 | for loss
in losses: 49 | plot_file(*loss) 50 | 51 | plt.xlabel('Epoch') 52 | plt.ylabel('Loss') 53 | plt.title('Train/Test Loss') 54 | plt.legend() 55 | plt.show() 56 | -------------------------------------------------------------------------------- /venv_tool: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /data/people/shastin1/venv/bin/activate 4 | pip uninstall -y pyarrow # -y: don't stop at pip's confirmation prompt (fails when run non-interactively) 5 | pip install pyarrow==0.11.1 6 | -------------------------------------------------------------------------------- /verify_data.py: -------------------------------------------------------------------------------- 1 | from src.datasets import ImitationLMDB 2 | from src.model import Model 3 | from matplotlib import pyplot as plt 4 | import torch 5 | import argparse 6 | 7 | 8 | def verify_data(config): # Print each sample's rgb/depth ranges and EOF/TAU/target/aux fields, plus model outputs when config.weights is set; optionally display rgb/depth. 9 | model = None 10 | if config.weights is not None: 11 | checkpoint = torch.load(config.weights, map_location='cpu') 12 | model = Model(**checkpoint['kwargs']) 13 | model.load_state_dict(checkpoint['model_state_dict']) 14 | model.eval() 15 | 16 | dataset = ImitationLMDB(config.data_dir, config.mode) 17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=config.shuffle) 18 | 19 | for data in dataloader: 20 | rgb = data[0].squeeze().permute(1, 2, 0) 21 | print(rgb.min(), rgb.max()) 22 | 23 | depth = data[1].squeeze() 24 | print(depth.min(), depth.max()) 25 | 26 | print('EOF: %s' % data[2].squeeze()) 27 | print('TAU: %s' % data[3].squeeze()) 28 | 29 | print('Target: %s' % data[4].squeeze()) 30 | print('Aux: %s' % data[5].squeeze()) 31 | 32 | if model is not None: 33 | with torch.no_grad(): out, aux = model(data[0], data[1], data[2], data[3]) # inference only; model.eval() alone does not disable autograd 34 | print('Model out: %s' % out.squeeze()) 35 | print('Model aux: %s' % aux.squeeze()) 36 | 37 | if config.show: 38 | plt.imshow((rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6)) 39 | plt.show() 40 | 41 | plt.imshow((depth - depth.min()) / (depth.max() - depth.min() + 1e-6)) 42 | plt.show() 43 | 
print('==========================') 45 | 46 | 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | parser = argparse.ArgumentParser(description='Qualitative data verification') 52 | parser.add_argument('-d', '--data_dir', required=True, help='Location of lmdb to pull data from.') 53 | parser.add_argument('-w', '--weights', help='Weights for model to evaluate on data.') 54 | parser.add_argument('-m', '--mode', default='test', help='Mode to evaluate data of.') 55 | parser.add_argument('-s', '--shuffle', default=False, dest='shuffle', action='store_true', help='Flag to shuffle the order in which samples are loaded.') 56 | parser.add_argument('-sh', '--show', default=False, dest='show', action='store_true', help='Flag to show visual inputs.') 57 | args = parser.parse_args() 58 | 59 | verify_data(args) 60 | --------------------------------------------------------------------------------