├── .gitignore
├── 3d_eval.py
├── README.md
├── grid_script
├── mellow.py
├── modeltests.py
├── old_stack
│   ├── crnn.py
│   ├── imitate_msgs
│   │   ├── msg
│   │   │   └── GripperStamped.msg
│   │   └── package.xml
│   ├── models
│   │   ├── crnn.py
│   │   └── imitation_net.py
│   ├── publishers
│   │   ├── rightGripperPublisher.py
│   │   ├── rightPosePublisher.py
│   │   ├── rightVelPublisher.py
│   │   └── toggletest.py
│   ├── simulation
│   │   └── gym-imitate
│   │       ├── README.md
│   │       ├── imitate_env.egg-info
│   │       │   └── PKG-INFO
│   │       ├── imitate_gym
│   │       │   ├── __init__.py
│   │       │   ├── buttons_data.py
│   │       │   ├── crnn.py
│   │       │   ├── envs
│   │       │   │   ├── __init__.py
│   │       │   │   ├── assets
│   │       │   │   │   ├── buttons
│   │       │   │   │   │   ├── buttons.xml
│   │       │   │   │   │   ├── buttons_3x3.xml
│   │       │   │   │   │   ├── robot.xml
│   │       │   │   │   │   └── shared.xml
│   │       │   │   │   ├── stls
│   │       │   │   │   │   ├── .get
│   │       │   │   │   │   ├── fetch
│   │       │   │   │   │   │   ├── base_link_collision.stl
│   │       │   │   │   │   │   ├── bellows_link_collision.stl
│   │       │   │   │   │   │   ├── elbow_flex_link_collision.stl
│   │       │   │   │   │   │   ├── estop_link.stl
│   │       │   │   │   │   │   ├── forearm_roll_link_collision.stl
│   │       │   │   │   │   │   ├── gripper_link.stl
│   │       │   │   │   │   │   ├── head_pan_link_collision.stl
│   │       │   │   │   │   │   ├── head_tilt_link_collision.stl
│   │       │   │   │   │   │   ├── l_wheel_link_collision.stl
│   │       │   │   │   │   │   ├── laser_link.stl
│   │       │   │   │   │   │   ├── r_wheel_link_collision.stl
│   │       │   │   │   │   │   ├── shoulder_lift_link_collision.stl
│   │       │   │   │   │   │   ├── shoulder_pan_link_collision.stl
│   │       │   │   │   │   │   ├── torso_fixed_link.stl
│   │       │   │   │   │   │   ├── torso_lift_link_collision.stl
│   │       │   │   │   │   │   ├── upperarm_roll_link_collision.stl
│   │       │   │   │   │   │   ├── wrist_flex_link_collision.stl
│   │       │   │   │   │   │   └── wrist_roll_link_collision.stl
│   │       │   │   │   │   └── hand
│   │       │   │   │   │       ├── F1.stl
│   │       │   │   │   │       ├── F2.stl
│   │       │   │   │   │       ├── F3.stl
│   │       │   │   │   │       ├── TH1_z.stl
│   │       │   │   │   │       ├── TH2_z.stl
│   │       │   │   │   │       ├── TH3_z.stl
│   │       │   │   │   │       ├── forearm_electric.stl
│   │       │   │   │   │       ├── forearm_electric_cvx.stl
│   │       │   │   │   │       ├── knuckle.stl
│   │       │   │   │   │       ├── lfmetacarpal.stl
│   │       │   │   │   │       ├── palm.stl
│   │       │   │   │   │       └── wrist.stl
│   │       │   │   │   └── textures
│   │       │   │   │       ├── block.png
│   │       │   │   │       └── block_hidden.png
│   │       │   │   ├── buttons.py
│   │       │   │   └── imitate_env.py
│   │       │   └── gym_eval.py
│   │       └── setup.py
│   ├── test_baxter.py
│   └── util
│       ├── old
│       │   ├── crnn.py
│       │   ├── imitate_msgs
│       │   │   ├── msg
│       │   │   │   └── GripperStamped.msg
│       │   │   └── package.xml
│       │   ├── publishers
│       │   │   ├── rightGripperPublisher.py
│       │   │   ├── rightPosePublisher.py
│       │   │   ├── rightVelPublisher.py
│       │   │   └── toggletest.py
│       │   ├── simulation
│       │   │   ├── 2dsim.py
│       │   │   └── gym-imitate
│       │   │       ├── README.md
│       │   │       ├── imitate_env.egg-info
│       │   │       │   └── PKG-INFO
│       │   │       ├── imitate_gym
│       │   │       │   ├── __init__.py
│       │   │       │   ├── buttons_data.py
│       │   │       │   ├── crnn.py
│       │   │       │   ├── envs
│       │   │       │   │   ├── __init__.py
│       │   │       │   │   ├── assets
│       │   │       │   │   │   ├── buttons
│       │   │       │   │   │   │   ├── buttons.xml
│       │   │       │   │   │   │   ├── buttons_3x3.xml
│       │   │       │   │   │   │   ├── robot.xml
│       │   │       │   │   │   │   └── shared.xml
│       │   │       │   │   │   ├── stls
│       │   │       │   │   │   │   ├── .get
│       │   │       │   │   │   │   ├── fetch
│       │   │       │   │   │   │   │   ├── base_link_collision.stl
│       │   │       │   │   │   │   │   ├── bellows_link_collision.stl
│       │   │       │   │   │   │   │   ├── elbow_flex_link_collision.stl
│       │   │       │   │   │   │   │   ├── estop_link.stl
│       │   │       │   │   │   │   │   ├── forearm_roll_link_collision.stl
│       │   │       │   │   │   │   │   ├── gripper_link.stl
│       │   │       │   │   │   │   │   ├── head_pan_link_collision.stl
│       │   │       │   │   │   │   │   ├── head_tilt_link_collision.stl
│       │   │       │   │   │   │   │   ├── l_wheel_link_collision.stl
│       │   │       │   │   │   │   │   ├── laser_link.stl
│       │   │       │   │   │   │   │   ├── r_wheel_link_collision.stl
│       │   │       │   │   │   │   │   ├── shoulder_lift_link_collision.stl
│       │   │       │   │   │   │   │   ├── shoulder_pan_link_collision.stl
│       │   │       │   │   │   │   │   ├── torso_fixed_link.stl
│       │   │       │   │   │   │   │   ├── torso_lift_link_collision.stl
│       │   │       │   │   │   │   │   ├── upperarm_roll_link_collision.stl
│       │   │       │   │   │   │   │   ├── wrist_flex_link_collision.stl
│       │   │       │   │   │   │   │   └── wrist_roll_link_collision.stl
│       │   │       │   │   │   │   └── hand
│       │   │       │   │   │   │       ├── F1.stl
│       │   │       │   │   │   │       ├── F2.stl
│       │   │       │   │   │   │       ├── F3.stl
│       │   │       │   │   │   │       ├── TH1_z.stl
│       │   │       │   │   │   │       ├── TH2_z.stl
│       │   │       │   │   │   │       ├── TH3_z.stl
│       │   │       │   │   │   │       ├── forearm_electric.stl
│       │   │       │   │   │   │       ├── forearm_electric_cvx.stl
│       │   │       │   │   │   │       ├── knuckle.stl
│       │   │       │   │   │   │       ├── lfmetacarpal.stl
│       │   │       │   │   │   │       ├── palm.stl
│       │   │       │   │   │   │       └── wrist.stl
│       │   │       │   │   │   └── textures
│       │   │       │   │   │       ├── block.png
│       │   │       │   │   │       └── block_hidden.png
│       │   │       │   │   ├── buttons.py
│       │   │       │   │   └── imitate_env.py
│       │   │       │   └── gym_eval.py
│       │   │       └── setup.py
│       │   └── test_baxter.py
│       ├── parse_data.py
│       ├── record.py
│       └── toggler.py
├── sim_eval.py
├── sim_eval0.py
├── simulation
│   └── sim.py
├── src
│   ├── datasets.py
│   ├── loss_func.py
│   ├── model.py
│   └── model0.py
├── superval.py
├── superval_3dval.py
├── superval_results_compile.py
├── temprun.sh
├── train.py
├── util
│   ├── parse_data.py
│   └── plot_loss.py
├── venv_tool
└── verify_data.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Miscellaneous
2 | *.swp
3 |
4 | # Compiled data
5 | *.csv
6 | *.lmdb
7 | *.lmdb-lock
8 |
9 | # Generated data
10 | *.png
11 | *.txt
12 | *.pt
13 |
14 | # Model checkpoints
15 | *.tar
16 |
17 | # Pycache
18 | *.pyc
19 |
20 | # Auto-generated grid files
21 | *.e1*
22 | *.o1*
23 |
--------------------------------------------------------------------------------
/3d_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import cv2
3 | import time
4 | import csv
5 | import os
6 | import sys
7 | import rospy
8 | import itertools
9 | import numpy as np
10 | #from tf.transformations import euler_from_quaternion
11 | from cv_bridge import CvBridge
12 | from namedlist import namedlist
13 | from std_msgs.msg import Int64, String
14 | from sensor_msgs.msg import CompressedImage, Image, JointState
15 | from geometry_msgs.msg import Twist, Pose, TwistStamped, PoseStamped, Vector3
16 | import torch
17 | from model import Model
18 | import argparse
19 | import copy
20 | from matplotlib import pyplot as plt
21 | import signal
22 |
23 | class ImitateEval:
24 | def __init__(self, weights):
25 | self.bridge = CvBridge()
26 | self.Data = namedlist('Data', ['pose', 'rgb', 'depth'])
27 | self.data = self.Data(pose=None, rgb=None, depth=None)
28 | self.is_start = True
29 |
30 | checkpoint = torch.load(weights, map_location="cpu")
31 | self.model = Model(**checkpoint['kwargs'])
32 | self.model.load_state_dict(checkpoint["model_state_dict"])
33 | self.model.eval()
34 |
35 | def change_start(self):
36 | radius = 0.07
37 | # Publisher for the movement and the starting pose
38 | self.movement_publisher = rospy.Publisher('/iiwa/CollisionAwareMotion', Pose, queue_size=10)
39 | self.target_start = Pose()
40 | self.target_start.position.x = -0.15 + np.random.rand()*2*radius - radius # -0.10757
41 | self.target_start.position.y = 0.455 + np.random.rand()*2*radius - radius # 0.4103
42 | self.target_start.position.z = 1.015
43 | self.target_start.orientation.x = 0.0
44 | self.target_start.orientation.y = 0.0
45 | self.target_start.orientation.z = 0.7071068
46 | self.target_start.orientation.w = 0.7071068
47 |
48 |
49 | def move_to_button(self, tau, tolerance):
50 | self.init_listeners()
51 | rospy.Rate(5).sleep()
52 | rate = rospy.Rate(15)
53 |
54 | pose_to_move = copy.deepcopy(self.target_start)
55 | eof = []
56 | while not rospy.is_shutdown():
57 | if None not in self.data:
58 | # Position from the CartesianPose Topic!!
59 | pos = self.data.pose.position
60 | pos = [pos.x, pos.y, pos.z]
61 | if self.is_start:
62 | for _ in range(5):
63 | eof += pos
64 | self.is_start = False
65 | else:
66 |                     eof = pos + eof[:-3]  # newest xyz first; drop the oldest triple
67 |
68 | eof_input = torch.from_numpy(np.array(eof)).type(torch.FloatTensor)
69 | eof_input = eof_input.unsqueeze(0)#.zero_()
70 |
71 | tau_input = torch.Tensor(tau).to(eof_input).view(1, -1)
72 |
73 | rgb = self.process_images(self.data.rgb, True)
74 | depth = self.process_images(self.data.depth, False)
75 |
76 | # print("RGB min: {}, RGB max: {}".format(np.amin(rgb), np.amax(rgb)))
77 | # print("Depth min: {}, Depth max: {}".format(np.amin(depth), np.amax(depth)))
78 | print("EOF: {}".format(eof_input))
79 | print("Tau: {}".format(tau_input))
80 |
81 | torch.save(rgb, "/home/amazon/Desktop/rgb_tensor.pt")
82 | torch.save(depth, "/home/amazon/Desktop/depth_tensor.pt")
83 | torch.save(eof, "/home/amazon/Desktop/eof_tensor.pt")
84 | torch.save(tau_input, "/home/amazon/Desktop/tau.pt")
85 |
86 | with torch.no_grad():
87 | out, aux = self.model(rgb, depth, eof_input, tau_input)
88 | torch.save(out, "/home/amazon/Desktop/out.pt")
89 | torch.save(aux, "/home/amazon/Desktop/aux.pt")
90 | out = out.squeeze()
91 | x_cartesian = out[0].item()
92 | y_cartesian = out[1].item()
93 | z_cartesian = out[2].item()
94 | print("X:{}, Y:{}, Z:{}".format(x_cartesian, y_cartesian, z_cartesian))
95 | print("Aux: {}".format(aux))
96 | # This new pose is the previous pose + the deltas output by the net, adjusted for discrepancy in frame
97 | # It used to be:
98 | # pose_to_move.position.x += -y_cartesian
99 | # pose_to_move.position.y += x_cartesian
100 | # pose_to_move.position.z += z_cartesian
101 |
102 | pose_to_move.position.x -= y_cartesian
103 | pose_to_move.position.y += x_cartesian
104 | pose_to_move.position.z += z_cartesian
105 | #print(pose_to_move)
106 |
107 | # Publish to Kuka!!!!
108 | for i in range(10):
109 | self.movement_publisher.publish(pose_to_move)
110 | rospy.Rate(10).sleep()
111 |
112 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String)
113 | # End publisher
114 |
115 | self.data = self.Data(pose=None,rgb=None,depth=None)
116 | rate.sleep()
117 |
118 | #print(pose_to_move.position.y, -pose_to_move.position.x, pose_to_move.position.z)
119 | #print(distance(tau, (pose_to_move.position.y, -pose_to_move.position.x, pose_to_move.position.z)))
120 | if distance(tau, (pose_to_move.position.y, -pose_to_move.position.x, pose_to_move.position.z)) < tolerance:
121 | break
122 |
123 | def process_images(self, img_msg, is_it_rgb):
124 | crop_right=586
125 | crop_lower=386
126 | img = self.bridge.compressed_imgmsg_to_cv2(img_msg, desired_encoding="passthrough")
127 | if(is_it_rgb):
128 | img = img[:,:,::-1]
129 | # Does this crop work?
130 | #rgb = img[0:386, 0:586]
131 | #rgb = img.crop((0, 0, crop_right, crop_lower))
132 | rgb = cv2.resize(img, (160,120))
133 | rgb = np.array(rgb).astype(np.float32)
134 |
135 |         rgb = 2*((rgb - np.amin(rgb))/(np.amax(rgb)-np.amin(rgb)))-1  # rescale to [-1, 1]
136 |
137 |
138 | rgb = torch.from_numpy(rgb).type(torch.FloatTensor)
139 | if is_it_rgb:
140 | rgb = rgb.view(1, rgb.shape[0], rgb.shape[1], rgb.shape[2]).permute(0, 3, 1, 2)
141 | else:
142 | rgb = rgb.view(1, 1, rgb.shape[0], rgb.shape[1])
143 | #plt.imshow(rgb[0,0] / 2 + .5)
144 | # plt.show()
145 |
146 | return rgb
147 |
148 | def move_to_start(self):
149 | # Publish starting position to Kuka!!!!
150 | for i in range(10):
151 | self.movement_publisher.publish(self.target_start)
152 | rospy.Rate(10).sleep()
153 |
154 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String)
155 | # End publisher
156 |
157 | def init_listeners(self):
158 | # The Topics we are Subscribing to for data
159 | self.right_arm_pose = rospy.Subscriber('/iiwa/state/CartesianPose', PoseStamped, self.pose_callback)
160 | self.rgb_state_sub = rospy.Subscriber('/camera3/camera/color/image_rect_color/compressed', CompressedImage, self.rgb_callback)
161 | self.depth_state_sub = rospy.Subscriber('/camera3/camera/depth/image_rect_raw/compressed', CompressedImage, self.depth_callback)
162 |
163 | def unsubscribe(self):
164 | self.right_arm_pose.unregister()
165 | self.rgb_state_sub.unregister()
166 | self.depth_state_sub.unregister()
167 |
168 | def pose_callback(self, pose):
169 | if None in self.data:
170 | self.data.pose = pose.pose
171 |
172 | def rgb_callback(self, rgb):
173 | if None in self.data:
174 | self.data.rgb = rgb
175 |
176 | def depth_callback(self, depth):
177 | if None in self.data:
178 | self.data.depth = depth
179 |
180 |
181 | def translate_tau(button):
182 | b_0 = int(button[0])
183 | b_1 = int(button[1])
184 | tau = [.558-.069*b_0, .22-.063*b_1, .22]
185 | return tau
186 |
187 |
188 | def distance(a, b):
189 | a = [a[0]-.035, a[1], .94]
190 | print(a)
191 |     return np.sqrt(np.sum([np.abs(aa - bb) for aa, bb in zip(a,b)]))  # sqrt of the summed absolute differences, not a Euclidean norm
192 |
193 |
194 | def get_tau():
195 |     # Python 2 input() evaluates the typed text, so "0,0" arrives as a tuple.
196 |     button = input('Please enter a button for the robot to try to press (e.g. "0,0", "1,2"): ')
197 |     # Raw (row, col, 0) indices serve as tau; translate_tau() is left unused here.
198 |     return button + (0,)
199 |
200 |
201 | def main(weights, tolerance):
202 | agent = ImitateEval(weights)
203 | agent.change_start()
204 | agent.move_to_start()
205 | tau = get_tau()
206 | agent.move_to_button(tau, tolerance)
207 |
208 |
209 | def sighandler(signal, frame):
210 | raise Exception('op')
211 |
212 |
213 | if __name__ == "__main__":
214 | parser = argparse.ArgumentParser(description="Arguments for evaluating imitation net")
215 |     parser.add_argument('-w', '--weights', required=True, help='Filepath for model weights.')
216 | parser.add_argument('-t', '--tolerance', default=0, type=float, help='Tolerance for button presses.')
217 | args = parser.parse_args()
218 | rospy.init_node('eval_imitation', log_level=rospy.DEBUG)
219 |
220 | signal.signal(signal.SIGINT, sighandler)
221 |
222 | cont = False
223 | while True:
224 | try:
225 | main(args.weights, args.tolerance)
226 | except Exception:
227 | if cont:
228 | break
229 | cont = True
230 | continue
231 | cont = False
232 |
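233 | # Usage sketch (hypothetical checkpoint path):
234 | #   python 3d_eval.py -w /path/to/checkpoint.tar -t 0.05
235 | # Two consecutive Ctrl-C interrupts are needed to leave the retry loop above.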
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Targetable Visuomotor Imitation Learning
2 |
3 | Code repository for the paper "Learning Deep Parameterized Skills for Re-Targetable Visuomotor Control" by the H2R Lab in collaboration with MERL.
4 |
5 | Read the paper here: https://arxiv.org/abs/1910.10628
6 |
--------------------------------------------------------------------------------
/grid_script:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source /data/people/shastin1/venv/bin/activate
4 | python train.py "$@"
5 |
--------------------------------------------------------------------------------
/mellow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.nn import functional as F
4 |
5 | # Legacy NumPy version; expects to be bound to an object exposing Q, a_size and omega.
6 | def mellowmax2(self, beta):
7 |     c = np.max(self.Q[0])
8 |     mm = c + np.log((1.0/self.a_size)*np.sum(np.exp(self.omega * (self.Q[0] - c))))/self.omega
9 |     b = 0
10 |     for a in self.Q[0]:
11 |         b += np.exp(beta * (a-mm))*(a-mm)
12 |     return b
13 |
14 | def mellowmax(x, dim=0, beta=6, omega=3):
15 |     # Numerically stable mellowmax: mm = c + log(mean(exp(omega*(x - c))))/omega
16 |     c = torch.max(x, dim=dim, keepdim=True)[0]
17 |     mm = c + torch.log((1.0/x.size(dim))*torch.sum(torch.exp(omega * (x - c)), dim=dim, keepdim=True)) / omega
18 |     # Returns the per-element root-finding terms for beta, not mm itself
19 |     return torch.exp(beta * (x - mm)) * (x - mm)
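20 |
21 | # Minimal usage sketch (assumes q is a 1-D tensor of action values):
22 | #   q = torch.tensor([1.0, 2.0, 3.0])
23 | #   terms = mellowmax(q, dim=0)   # per-action terms of the beta root-finding objective
24 | #   residual = terms.sum()        # the mellowmax policy picks beta so that residual == 0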
16 |
--------------------------------------------------------------------------------
/old_stack/imitate_msgs/msg/GripperStamped.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 | std_msgs/Int64 data
3 |
--------------------------------------------------------------------------------
/old_stack/imitate_msgs/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package>
3 |   <name>imitate_msgs</name>
4 |   <version>0.0.0</version>
5 |   <description>The imitate_msgs package</description>
6 |   <maintainer email="todo@todo.todo">jonathanchang</maintainer>
7 |   <license>TODO</license>
8 |
9 |   <buildtool_depend>catkin</buildtool_depend>
10 |   <build_depend>std_msgs</build_depend>
11 |   <build_depend>message_generation</build_depend>
12 |   <run_depend>std_msgs</run_depend>
13 |   <run_depend>message_runtime</run_depend>
14 | </package>
--------------------------------------------------------------------------------
/old_stack/publishers/rightGripperPublisher.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import roslib
3 | import rospy
4 | import tf
5 | from tf.transformations import euler_from_quaternion
6 | import numpy as np
7 | import message_filters
8 | from std_msgs.msg import Int64, Header
9 | from sensor_msgs.msg import JointState
10 | from imitate_msgs.msg import GripperStamped
11 |
12 |
13 | is_open = None
14 |
15 | def callback(data):
16 | global is_open
17 | position = np.array(data.position) - 0.0016998404159380444
18 | if position.sum() < 0.05:
19 | is_open = 1
20 | else:
21 | is_open = 0
22 |
23 | # Publisher From Here
24 | rospy.init_node('movo_right_gripper')
25 | listener = rospy.Subscriber("/movo/right_gripper/joint_states", JointState, callback)
26 |
27 | # Give time for initialization
28 | rospy.Rate(1).sleep()
29 |
30 | robot_in_map = rospy.Publisher('/movo/right_gripper/gripper_is_open', GripperStamped, queue_size=1)
31 |
32 | rate = rospy.Rate(100)
33 | while not rospy.is_shutdown():
34 |     if is_open is None:
35 |         # No joint_states message has arrived yet; don't publish Int64(None)
36 |         rate.sleep()
37 |         continue
38 |     h = Header()
39 |     h.stamp = rospy.Time.now()
40 |     msg = GripperStamped()
41 |     msg.header = h
42 |     msg.data = Int64(is_open)
43 |     robot_in_map.publish(msg)
44 |     rate.sleep()
45 |
--------------------------------------------------------------------------------
/old_stack/publishers/rightPosePublisher.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import roslib
3 | import rospy
4 | import tf
5 | from tf.transformations import euler_from_quaternion
6 | from geometry_msgs.msg import PoseStamped, Pose, Point, Quaternion
7 | from std_msgs.msg import Header
8 |
9 | rospy.init_node('movo_right_tf_listener_pose')
10 |
11 | listener = tf.TransformListener()
12 |
13 | robot_in_map = rospy.Publisher('tf/right_arm_pose', PoseStamped, queue_size=1)
14 |
15 | rate = rospy.Rate(100)
16 | while not rospy.is_shutdown():
17 | try:
18 | (trans, rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0))
19 | h = Header()
20 | h.stamp = rospy.Time.now()
21 | msg = PoseStamped()
22 | msg.header = h
23 | msg.pose = Pose(Point(*trans), Quaternion(*rot))
24 |
25 | except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException):
26 | continue
27 | # Publish the message
28 | robot_in_map.publish(msg)
29 | rate.sleep()
30 |
--------------------------------------------------------------------------------
/old_stack/publishers/rightVelPublisher.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import numpy as np
3 | import roslib
4 | import rospy
5 | import tf
6 | from tf.transformations import euler_from_quaternion
7 | from geometry_msgs.msg import Twist, TwistStamped, Vector3
8 | from std_msgs.msg import Header
9 |
10 | rospy.init_node('movo_right_tf_listener')
11 |
12 | listener = tf.TransformListener()
13 |
14 | robot_in_map = rospy.Publisher('tf/right_arm_vels', TwistStamped, queue_size=1)
15 |
16 | prev_time = rospy.get_time()
17 | curr_time = rospy.get_time()
18 | prev_rot, prev_trans = None, None
19 |
20 | rate = rospy.Rate(100)
21 | while not rospy.is_shutdown():
22 | try:
23 | curr_time = rospy.get_time()
24 | (curr_trans, curr_rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0))
25 | # First Iteration
26 | if not prev_rot and not prev_trans:
27 | prev_trans, prev_rot = curr_trans, curr_rot
28 | continue
29 | delta = float(curr_time)-float(prev_time)
30 | lin_vel = (np.array(curr_trans) - np.array(prev_trans))/delta
31 | ang_vel = (np.array(euler_from_quaternion(curr_rot)) - np.array(euler_from_quaternion(prev_rot)))/delta
32 | h = Header()
33 | h.stamp = rospy.Time.now()
34 | msg = TwistStamped()
35 | msg.header = h
36 | msg.twist = Twist(Vector3(*lin_vel), Vector3(*ang_vel))
37 |
38 | # Update
39 | prev_time = curr_time
40 | prev_trans, prev_rot = curr_trans, curr_rot
41 |
42 | except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException):
43 | continue
44 | # Publish the message
45 | robot_in_map.publish(msg)
46 | rate.sleep()
47 |
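48 | # Example of the finite difference above: at ~100 Hz, delta ~ 0.01 s, so a 1 mm
49 | # translation in x between lookups yields lin_vel[0] = 0.001 / 0.01 = 0.1 m/s.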
--------------------------------------------------------------------------------
/old_stack/publishers/toggletest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import roslib
3 | import rospy
4 | from std_msgs.msg import Int64
5 |
6 | # Publisher From Here
7 | rospy.init_node('toggle_test')
8 |
9 | robot_in_map = rospy.Publisher('/toggle', Int64, queue_size=1)
10 |
11 | # Keep the node alive; note that this test stub never actually publishes on /toggle.
12 | rate = rospy.Rate(100)
13 | while not rospy.is_shutdown():
14 |     rate.sleep()
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/README.md
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_env.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: imitate-env
3 | Version: 0.0.1
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id='buttons-v0',
5 | entry_point='imitate_gym.envs:ButtonsEnv',
6 | )
7 |
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/buttons_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import gym
4 | import argparse
5 | import imitate_gym
6 | import numpy as np
7 | from PIL import Image
8 |
9 | # GOAL_00 = np.array([1.2,0.66,0.45])
10 | # GOAL_01 = np.array([1.2,0.76,0.45])
11 | # GOAL_02 = np.array([1.2,0.86,0.45])
12 | # GOAL_10 = np.array([1.3,0.66,0.45])
13 | # GOAL_11 = np.array([1.3,0.76,0.45])
14 | # GOAL_12 = np.array([1.3,0.86,0.45])
15 | # GOAL_20 = np.array([1.4,0.66,0.45])
16 | # GOAL_21 = np.array([1.4,0.76,0.45])
17 | # GOAL_22 = np.array([1.4,0.86,0.45])
18 | # GOALS = np.array([GOAL_00,GOAL_01,GOAL_02,GOAL_10,GOAL_11,GOAL_12,GOAL_20,GOAL_21,GOAL_22])
19 | # NAMES = {0:'goal_00',1:'goal_01',2:'goal_02',3:'goal_10',4:'goal_11',5:'goal_12',6:'goal_20',7:'goal_21',8:'goal_22'}
20 | GOAL_00 = np.array([1.2,0.56,0.45])
21 | GOAL_01 = np.array([1.2,0.66,0.45])
22 | GOAL_02 = np.array([1.2,0.76,0.45])
23 | GOAL_03 = np.array([1.2,0.86,0.45])
24 | GOAL_04 = np.array([1.2,0.96,0.45])
25 | GOAL_10 = np.array([1.3,0.56,0.45])
26 | GOAL_11 = np.array([1.3,0.66,0.45])
27 | GOAL_12 = np.array([1.3,0.76,0.45])
28 | GOAL_13 = np.array([1.3,0.86,0.45])
29 | GOAL_14 = np.array([1.3,0.96,0.45])
30 | GOAL_20 = np.array([1.4,0.56,0.45])
31 | GOAL_21 = np.array([1.4,0.66,0.45])
32 | GOAL_22 = np.array([1.4,0.76,0.45])
33 | GOAL_23 = np.array([1.4,0.86,0.45])
34 | GOAL_24 = np.array([1.4,0.96,0.45])
35 | GOALS = np.array([GOAL_00,
36 | GOAL_01,
37 | GOAL_02,
38 | GOAL_03,
39 | GOAL_04,
40 | GOAL_10,
41 | GOAL_11,
42 | GOAL_12,
43 | GOAL_13,
44 | GOAL_14,
45 | GOAL_20,
46 | GOAL_21,
47 | GOAL_22,
48 | GOAL_23,
49 | GOAL_24])
50 | NAMES = {0:'goal_00',
51 | 1:'goal_01',
52 | 2:'goal_02',
53 | 3:'goal_03',
54 | 4:'goal_04',
55 | 5:'goal_10',
56 | 6:'goal_11',
57 | 7:'goal_12',
58 | 8:'goal_13',
59 | 9:'goal_14',
60 | 10:'goal_20',
61 | 11:'goal_21',
62 | 12:'goal_22',
63 | 13:'goal_23',
64 | 14:'goal_24'}
65 | test=np.array([GOAL_24])
66 |
67 | def calc_changes(start_pos, goal_pos, num_steps=50.):
68 | return (goal_pos-start_pos)/num_steps
69 |
70 | def next_move(curr_pos, goal_pos, slope):
71 | diffs = np.absolute(curr_pos - goal_pos)
72 | if diffs[0] <= 0.001:
73 | slope[0] = 0.
74 | if diffs[1] <= 0.001:
75 | slope[1] = 0.
76 | if diffs[2] <= 0.001:
77 | slope[2] = 0.
78 | return slope
79 |
80 | def run(env, goal, trial_dir, viz=False):
81 | if viz == False:
82 | with open(trial_dir+'vectors.txt', 'w') as f:
83 | writer = csv.writer(f)
84 | curr_pos = None
85 | slope = None
86 | goal_reached = False
87 | counter = 0
88 | traj_counter = 0
89 | other = np.array([0.,0.,0.,0.,0.])
90 | while True:
91 | rgb = env.render('rgb_array')
92 | if curr_pos is None and goal_reached is False:
93 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
94 | curr_pos = obs['achieved_goal']
95 | slope = calc_changes(curr_pos, goal)
96 | elif curr_pos is not None and goal_reached is False:
97 | slope = next_move(curr_pos, goal, slope)
98 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455:
99 | goal_reached = True
100 | action = np.concatenate([slope,other])
101 | obs, _, _, _ = env.step(action)
102 | img_rgb = Image.fromarray(rgb, 'RGB')
103 | img_rgb = img_rgb.resize((160,120))
104 | img_rgb.save(trial_dir+'{}.png'.format(traj_counter))
105 | arr = [traj_counter]
106 | arr += [x for x in obs['observation']]
107 | arr += [0]
108 | curr_pos = obs['achieved_goal']
109 | traj_counter += 1
110 | writer.writerow(arr)
111 | elif curr_pos is not None and goal_reached is True:
112 | # This is to simulate the terminal state where the user would stop near the end
113 | counter += 1
114 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
115 | img_rgb = Image.fromarray(rgb, 'RGB')
116 | img_rgb = img_rgb.resize((160,120))
117 | img_rgb.save(trial_dir+'{}.png'.format(traj_counter))
118 | arr = [traj_counter]
119 | arr += [x for x in obs['observation']]
120 | arr += [0]
121 | traj_counter += 1
122 | writer.writerow(arr)
123 | if counter == 20:
124 | break
125 | else:
126 | curr_pos = None
127 | slope = None
128 | goal_reached = False
129 | counter = 0
130 | other = np.array([0.,0.,0.,0.,0.])
131 | while True:
132 | env.render()
133 | if curr_pos is None and goal_reached is False:
134 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
135 | curr_pos = obs['achieved_goal']
136 | slope = calc_changes(curr_pos, goal)
137 | elif curr_pos is not None and goal_reached is False:
138 | slope = next_move(curr_pos, goal, slope)
139 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455:
140 | goal_reached = True
141 | action = np.concatenate([slope,other])
142 | obs, _, _, _ = env.step(action)
143 | print("==================")
144 | print(slope)
145 | print(obs['observation'][7:10])
146 | curr_pos = obs['achieved_goal']
147 | print(curr_pos)
148 | elif curr_pos is not None and goal_reached is True:
149 | # This is to simulate the terminal state where the user would stop near the end
150 | counter += 1
151 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
152 | print("==================")
153 | print(slope)
154 | print(obs['observation'][7:10])
155 | if counter == 20:
156 | break
157 |
158 |
159 |
160 |
161 | def main(task_name, num_trials=300, viz=False):
162 |     if viz == False:
163 |         if not os.path.exists('../../../datas/' + task_name + '/'):
164 |             os.mkdir('../../../datas/' + task_name + '/')
165 |     trial_counter = 0
166 |     for trial in range(num_trials):
167 |         env = gym.make('buttons-v0')
168 |         for i, goal in enumerate(GOALS):
169 |             save_folder = '../../../datas/'+task_name+'/'+NAMES[i]+'/'
170 |             # Assign unconditionally so run() gets a value in viz mode too
171 |             trial_dir = save_folder+str(trial_counter)+'/'
172 |             if viz == False:
173 |                 if not os.path.exists(save_folder):
174 |                     os.mkdir(save_folder)
175 |                 os.mkdir(trial_dir)
176 |             env.reset()
177 |             run(env, goal, trial_dir, viz=viz)
178 |             trial_counter+=1
179 |
180 | if __name__ == '__main__':
181 | main("buttons3x5", num_trials=1, viz=True)
182 |
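183 | # Worked example of the stepping math (illustrative start position):
184 | #   calc_changes(np.array([1.0, 0.5, 0.5]), GOAL_00) -> array([0.004, 0.0012, -0.001])
185 | #   i.e. each of the ~50 steps covers 1/50th of the offset to the goal button.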
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from imitate_gym.envs.imitate_env import ImitateEnv
2 | from imitate_gym.envs.buttons import ButtonsEnv
3 |
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons.xml:
--------------------------------------------------------------------------------
(XML markup was stripped by this export; file content not recoverable)
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons_3x3.xml:
--------------------------------------------------------------------------------
(XML markup was stripped by this export; file content not recoverable)
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/robot.xml:
--------------------------------------------------------------------------------
(XML markup was stripped by this export; file content not recoverable)
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/buttons/shared.xml:
--------------------------------------------------------------------------------
(XML markup was stripped by this export; file content not recoverable)
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/buttons.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from gym import utils
4 | from imitate_gym.envs import imitate_env
5 |
6 | #MODEL_XML_PATH = os.path.join('fetch', 'pick_and_place.xml')
7 | HOME = os.path.dirname(os.path.abspath(__file__))  # replaces the old hard-coded home directory
8 | MODEL_XML_PATH = os.path.join(HOME, 'assets', 'buttons', 'buttons_3x3.xml')
9 | class ButtonsEnv(imitate_env.ImitateEnv, utils.EzPickle):
10 | def __init__(self, reward_type='dense'):
11 | initial_qpos = {
12 | 'robot0:slide0': 0.405,
13 | 'robot0:slide1': 0.48,
14 | 'robot0:slide2': 0.0,
15 | 'robot0:shoulder_pan_joint': 0.0,
16 | 'robot0:shoulder_lift_joint': -0.8,
17 | 'robot0:elbow_flex_joint': 1.0
18 | }
19 | target = np.array([1.25, 0.53, 0.4])
20 | imitate_env.ImitateEnv.__init__(
21 | self, model_path=MODEL_XML_PATH, n_substeps=25, initial_qpos=initial_qpos, reward_type=reward_type,
22 | distance_threshold=0.05, gripper_extra_height=0.2, target=target)
23 | utils.EzPickle.__init__(self)
24 |
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/envs/imitate_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import gym
4 | import mujoco_py
5 | from gym import error, spaces
6 | from gym.envs.robotics import rotations, robot_env, utils
7 |
8 | # Just to make sure that input shape is correct
9 | def goal_distance(goal_a, goal_b):
10 | assert goal_a.shape == goal_b.shape
11 | return np.linalg.norm(goal_a - goal_b, axis=-1)
12 |
13 | class ImitateEnv(robot_env.RobotEnv):
14 | """
15 | Inheritance Chain: Env -> GoalEnv -> RobotEnv -> ImitateEnv
16 |
17 | Methods from RobotEnv
18 | seed(), step(), reset(), close(), render(), _reset_sim(),
19 | _get_viewer(): This method will be important for retrieving rgb and depth from the sim
20 | """
21 | def __init__(
22 | self, model_path, n_substeps, initial_qpos, reward_type,
23 | distance_threshold, target, gripper_extra_height
24 | ):
25 | """
26 | Args: model_path (string): path to the environments XML file
27 | n_substeps (int): number of substeps the simulation runs on every call to step
28 | initial_qpos (dict): a dictionary of joint names and values that define the initial configuration
29 | reward_type ('sparse' or 'dense'): the reward type, i.e. sparse or dense
30 | distance_threshold (float): the threshold after which a goal is considered achieved
31 | target (list): the target position that we are aiming for
32 | gripper_extra_height (float): the gripper offset in position
33 | """
34 | # n_actions = 8 for [pos_x, pos_y, pos_z, rot_x, rot_y, rot_z, rot_w, gripper]
35 |
36 | self.reward_type = reward_type
37 | self.distance_threshold = distance_threshold
38 | self.target = target
39 | self.gripper_extra_height = gripper_extra_height
40 | self._viewers = {}
41 | self.initial_qpos = initial_qpos
42 | super(ImitateEnv, self).__init__(model_path=model_path,
43 | n_substeps=n_substeps,
44 | n_actions=8,
45 | initial_qpos=initial_qpos)
46 | # Env Method
47 | # --------------------------------------------------
48 | def render(self, mode='human'):
49 | self._render_callback()
50 | if mode == 'rgb_array':
51 | self._get_viewer(mode).render(1000,1000)
52 | rgb = self._get_viewer(mode).read_pixels(1000, 1000, depth=False)
53 | #rgbd = self._get_viewer(mode).read_pixels(1000, 1000, depth=True)
54 | #rgb = rgbd[0][::-1,:,:]
55 | #depth = rgbd[1]
56 | #depth = (depth-np.amin(depth))*(255/(np.amax(depth)-np.amin(depth)))
57 | return rgb[::-1,:,:]
58 | elif mode == 'human':
59 | self._get_viewer(mode).render()
60 |
61 | def _get_viewer(self, mode):
62 | self.viewer = self._viewers.get(mode)
63 | if self.viewer is None:
64 | if mode == 'human':
65 | self.viewer = mujoco_py.MjViewer(self.sim)
66 | elif mode == 'rgb_array':
67 | self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, device_id=-1)
68 | self._viewer_setup()
69 | self._viewers[mode] = self.viewer
70 | return self.viewer
71 | # GoalEnv Method
72 | # --------------------------------------------------
73 | def compute_reward(self, achieved_goal, goal, info):
74 | d = goal_distance(achieved_goal, goal)
75 | if self.reward_type == 'sparse':
76 | return -(d > self.distance_threshold).astype(np.float32)
77 | else:
78 | return -d
79 |
80 | # RobotEnv Methods
81 | # --------------------------------------------------
82 |
83 | def _step_callback(self):
84 | """
85 | In the step function in the RobotEnv it does: 1) self._set_action(action)
86 | 2) self.sim.step()
87 | 3) self._step_callback()
88 | This method can be used to provide additional constraints to the actions that happen every step.
89 | Could be used to fix the gripper to be a certain orientation for example
90 | """
91 | pass
92 |
93 | def _set_action(self, action):
94 | """
95 | Currently, I am just returning 1 number for the gripper control because there are 2 fingers in the
96 | robot that is used in the OpenAI gym library. If I start using movo, I would need 3 fingers so
97 | perhaps increase the gripper control from 2 to 3.
98 | """
99 | assert action.shape == (8,)
100 | action = action.copy() # ensures that we don't change the action outside of this scope
101 | pos_ctrl, rot_ctrl, gripper_ctrl = action[:3], action[3:7], action[-1]
102 |
103 | # This is where I match the gripper control to number of fingers
104 | gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
105 | assert gripper_ctrl.shape == (2,)
106 |
107 | # Create the action that we want to pass into the simulation
108 | action = np.concatenate([pos_ctrl, rot_ctrl, gripper_ctrl])
109 |
110 | # Now apply the action to the simulation
111 | utils.ctrl_set_action(self.sim, action) # note self.sim is an inherited variable
112 | utils.mocap_set_action(self.sim, action)
113 |
114 | def _get_obs(self):
115 | """
116 | This is where I make sure to grab the observations for the position and the velocities. Potential
117 | area to grab the rgb and depth images.
118 | """
119 | ee = 'robot0:grip'
120 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep
121 | ee_pos = self.sim.data.get_site_xpos(ee)
122 | ee_quat = rotations.mat2quat(self.sim.data.get_site_xmat(ee))
123 | # Position and angular velocity respectively
124 | ee_velp = self.sim.data.get_site_xvelp(ee) * dt # to remove time dependency
125 | ee_velr = self.sim.data.get_site_xvelr(ee) * dt
126 |
127 | obs = np.concatenate([ee_pos, ee_quat, ee_velp, ee_velr])
128 |
129 | return {
130 | 'observation': obs.copy(),
131 | 'achieved_goal': ee_pos.copy(),
132 | 'desired_goal': self.goal.copy()
133 | }
134 |
135 | def _viewer_setup(self):
136 | """
137 | This could be used to reorient the starting view frame. Will have to mess around
138 | """
139 | body_id = self.sim.model.body_name2id('robot0:gripper_link')
140 | lookat = self.sim.data.body_xpos[body_id]
141 | for idx, value in enumerate(lookat):
142 | self.viewer.cam.lookat[idx] = value
143 | self.viewer.cam.distance = 1.7
144 | self.viewer.cam.azimuth = 180.
145 | self.viewer.cam.elevation = -30.
146 |
147 | def _sample_goal(self):
148 | """
149 | Instead of sampling I am currently just setting our defined goal as the "sampled" goal
150 | """
151 | return self.target
152 |
153 | def _is_success(self, achieved_goal, desired_goal):
154 | d = goal_distance(achieved_goal, desired_goal)
155 | return (d < self.distance_threshold).astype(np.float32)
156 |
157 | def _env_setup(self, initial_qpos):
158 | # Randomize the starting position
159 | shoulder_pan_val = -0.1 + (random.random()*0.2)
160 | initial_qpos['robot0:shoulder_pan_joint'] = shoulder_pan_val
161 | for name, value in initial_qpos.items():
162 | self.sim.data.set_joint_qpos(name, value)
163 | utils.reset_mocap_welds(self.sim)
164 | self.sim.forward()
165 |
166 | # Move end effector into position
167 | gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos('robot0:grip')
168 | gripper_rotation = np.array([1.,0.,1.,0.])
169 | self.sim.data.set_mocap_pos('robot0:mocap', gripper_target)
170 | self.sim.data.set_mocap_quat('robot0:mocap', gripper_rotation)
171 | for _ in range(10):
172 | self.sim.step()
173 |
174 |
175 |
176 |
177 |
178 |
179 |
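180 | # Worked example of the reward computation (target taken from ButtonsEnv, sample gripper position):
181 | #   goal_distance(np.array([1.2, 0.66, 0.45]), np.array([1.25, 0.53, 0.4]))
182 | #     = sqrt(0.05**2 + 0.13**2 + 0.05**2) ~= 0.148
183 | #   'dense' reward -> -0.148; 'sparse' -> -1.0 (0.148 > distance_threshold of 0.05)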
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/imitate_gym/gym_eval.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import imitate_gym
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from crnn import SpatialCRNN
7 | import sys
8 | seq_len = 10
9 | weights = sys.argv[1]
10 | tau = [[int(sys.argv[2]), int(sys.argv[3])] for _ in range(seq_len)]
11 | tau = torch.from_numpy(np.array(tau)).type(torch.FloatTensor)
12 | tau = torch.unsqueeze(tau, dim=0)
13 | env = gym.make('buttons-v0')
14 | env.reset()
15 | model = SpatialCRNN(rtype="GRU", num_layers=2, device="cpu")
16 | model.load_state_dict(torch.load(weights, map_location='cpu'), strict=False)
17 | model.eval()
18 | eof = None
19 | rgbs = []
20 | while True:
21 | env.render()
22 | rgb = env.render('rgb_array')
23 | if eof is None:
24 | obs, _,_,_ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
25 | pos = obs['achieved_goal']
26 | eof = [pos for _ in range(seq_len)]
27 |         img_rgb = Image.fromarray(rgb, 'RGB')
28 |         img_rgb = np.array(img_rgb.resize((160,120)))
29 |         img_rgb = np.transpose(img_rgb, (2,0,1))  # HWC -> CHW; a plain reshape would scramble pixels
30 |         rgbs = [img_rgb for _ in range(seq_len)]
31 |     else:
32 |         img_rgb = Image.fromarray(rgb, 'RGB')
33 |         img_rgb = np.array(img_rgb.resize((160,120)))
34 |         img_rgb = np.transpose(img_rgb, (2,0,1))  # HWC -> CHW
35 |         rgbs.pop(0)
36 |         rgbs.append(img_rgb)
37 | input_rgb = torch.from_numpy(np.array(rgbs)).type(torch.FloatTensor)
38 | input_rgb = torch.unsqueeze(input_rgb, dim=0)
39 | input_eof = torch.from_numpy(np.array(eof)).type(torch.FloatTensor)
40 | input_eof = torch.unsqueeze(input_eof, dim=0)
41 | vels, _ = model([input_rgb, input_eof, tau])
42 | action = [vels[0][0].item(), vels[0][1].item(), vels[0][2].item(), 0., 0., 0., 0., 0.]
43 | print(action)
44 | obs, _, _, _ = env.step(action)
45 | eof.pop(0)
46 | eof.append(obs['achieved_goal'])
47 |
48 |
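49 | # Usage sketch (hypothetical weights file; the two integers select the target button):
50 | #   python gym_eval.py weights.pt 1 2
51 | # The rollout loop runs until the process is interrupted.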
--------------------------------------------------------------------------------
/old_stack/simulation/gym-imitate/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='imitate_env',
4 | version='0.0.1',
5 | install_requires=['gym', 'mujoco-py']
6 | )
7 |
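8 | # Editable install so that `import imitate_gym` registers buttons-v0 with gym:
9 | #   cd old_stack/simulation/gym-imitate && pip install -e .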
--------------------------------------------------------------------------------
/old_stack/test_baxter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import sys
3 | import rospy
4 | import itertools
5 | import message_filters
6 | import numpy as np
7 | import time
8 | import cv2
9 | import moveit_commander
10 | from cv_bridge import CvBridge
11 | from std_msgs.msg import Int64
12 | from geometry_msgs.msg import Point, Quaternion, Pose, Twist
13 | from sensor_msgs.msg import CompressedImage, JointState
14 |
15 | class ApproxTimeSync(message_filters.ApproximateTimeSynchronizer):
16 | def add(self, msg, my_queue, my_queue_index=None):
17 | self.allow_headerless = True
18 | if hasattr(msg, 'timestamp'):
19 | stamp = msg.timestamp
20 | elif not hasattr(msg, 'header') or not hasattr(msg.header, 'stamp'):
21 | if not self.allow_headerless:
22 | rospy.logwarn("Cannot use message filters with non-stamped messages. "
23 | "Use the 'allow_headerless' constructor option to "
24 | "auto-assign ROS time to headerless messages.")
25 | return
26 | stamp = rospy.Time.now()
27 | else:
28 | stamp = msg.header.stamp
29 |
30 | #TODO ADD HEADER TO ALLOW HEADERLESS
31 | # http://book2code.com/ros_kinetic/source/ros_comm/message_filters/src/message_filters/__init__.y
32 | #setattr(msg, 'header', a)
33 | #msg.header.stamp = stamp
34 | #super(message_filters.ApproximateTimeSynchronizer, self).add(msg, my_queue)
35 | self.lock.acquire()
36 | my_queue[stamp] = msg
37 | while len(my_queue) > self.queue_size:
38 | del my_queue[min(my_queue)]
39 | # self.queues = [topic_0 {stamp: msg}, topic_1 {stamp: msg}, ...]
40 | if my_queue_index is None:
41 | search_queues = self.queues
42 | else:
43 | search_queues = self.queues[:my_queue_index] + \
44 | self.queues[my_queue_index+1:]
45 | # sort and leave only reasonable stamps for synchronization
46 | stamps = []
47 | for queue in search_queues:
48 | topic_stamps = []
49 | for s in queue:
50 | stamp_delta = abs(s - stamp)
51 | if stamp_delta > self.slop:
52 | continue # far over the slop
53 | topic_stamps.append((s, stamp_delta))
54 | if not topic_stamps:
55 | self.lock.release()
56 | return
57 | topic_stamps = sorted(topic_stamps, key=lambda x: x[1])
58 | stamps.append(topic_stamps)
59 | for vv in itertools.product(*[list(zip(*s))[0] for s in stamps]): # list() since zip is an iterator on Python 3
60 | vv = list(vv)
61 | # insert the new message
62 | if my_queue_index is not None:
63 | vv.insert(my_queue_index, stamp)
64 | qt = list(zip(self.queues, vv))
65 | if ( ((max(vv) - min(vv)) < self.slop) and
66 | (len([1 for q,t in qt if t not in q]) == 0) ):
67 | msgs = [q[t] for q,t in qt]
68 | self.signalMessage(*msgs)
69 | for q,t in qt:
70 | del q[t]
71 | break # fast finish after the synchronization
72 | self.lock.release()
73 |
74 | class ImitateLearner():
75 | """
76 | This class will evaluate the outputs of the net.
77 | """
78 | def __init__(self, arm='right'):
79 | rospy.init_node("{}_arm_eval".format(arm))
80 | self.queue = []
81 | self.rgb = None
82 | self.depth = None
83 | self.pos = None
84 | self.orient = None
85 | self.prevTime = None
86 | self.time = None
87 | self.arm = arm
88 |
89 | # Initialize Subscribers
90 | self.listener()
91 |
92 | # Enable Robot
93 | moveit_commander.roscpp_initialize(sys.argv)
94 | robot = moveit_commander.RobotCommander()
95 | self.group_arms = moveit_commander.MoveGroupCommander('upper_body')
96 | self.group_arms.set_pose_reference_frame('/base_link')
97 | self.left_ee_link = 'left_ee_link'
98 | self.right_ee_link = 'right_ee_link'
99 |
100 | # Set the rate of our evaluation
101 | rate = rospy.Rate(0.5)
102 |
103 | # Give time for initialization
104 | rospy.Rate(1).sleep()
105 |
106 | # This is to get the time delta
107 | first_time = True
108 | while not rospy.is_shutdown():
109 | # Todo: Connect Net
110 | if first_time:
111 | self.prevTime = time.time()
112 | first_time = False
113 | # Given the output, solve for limb joints
114 | #limb_joints = self.get_limb_joints(output)
115 |
116 | # If valid joints then move to joint
117 | #if limb_joints is not -1:
118 | # right.move_to_joint_positions(limb_joints)
119 | # self.prevTime = self.time
120 | #else:
121 | # print 'ERROR: IK solver returned -1'
122 | print(self.pos)
123 | rate.sleep()
124 |
125 | def listener(self):
126 | """
127 | Listener for all of the topics
128 | """
129 | print("Listener Initialized")
130 | pose_sub = message_filters.Subscriber('/tf/{}_arm_pose'.format(self.arm), Pose)
131 | twist_sub = message_filters.Subscriber('tf/{}_arm_vels'.format(self.arm), Twist)
132 | rgb_sub = message_filters.Subscriber('/kinect2/sd/image_color_rect/compressed', CompressedImage)
133 | depth_sub = message_filters.Subscriber('/kinect2/sd/image_depth_rect/compressed', CompressedImage)
134 | gripper_sub = message_filters.Subscriber('/movo/{}_gripper/gripper_is_open'.format(self.arm), Int64)
135 | ts = ApproxTimeSync([pose_sub, twist_sub, rgb_sub, depth_sub, gripper_sub], 1, 0.1)
136 | ts.registerCallback(self.listener_callback)
137 |
138 | def listener_callback(self, pose, twist, rgb, depth, gripper):
139 | """
140 | This method updates the variables.
141 | """
142 | bridge = CvBridge()
143 | self.time = time.time()
144 | self.rgb = bridge.compressed_imgmsg_to_cv2(rgb)
145 | self.depth = bridge.compressed_imgmsg_to_cv2(depth)
146 | self.pos = pose.position
147 | self.orient = pose.orientation
148 | # Create input for net. x, y, z
149 | queue_input = np.array([self.pos.x, self.pos.y, self.pos.z])
150 | if len(self.queue) == 0:
151 | self.queue = [queue_input for i in range(5)]
152 | else:
153 | self.queue.pop(0)
154 | self.queue.append(queue_input)
155 |
156 | def get_next_pose(self, output):
157 | """
158 | This method gets the ik_solver solution for the arm joints.
159 | """
160 | [goal_pos, goal_orient] = self.calculate_move(np.reshape(output[0, :3], (3,)), np.reshape(output[0, 3:], (3,)))
161 | return Point(*goal_pos), Quaternion(*goal_orient)
162 |
163 | def calculate_move(self, lin, ang):
164 | """
165 | This calculates the position and orientation (as a quaternion) of the next pose given
166 | the linear and angular velocities output by the net.
167 | """
168 | delta = self.time - self.prevTime
169 | print("------------")
170 | print(delta)
171 | print("------------")
172 | #delta = 1/30
173 | # Position Update
174 | curr_pos = np.array([self.pos.x, self.pos.y, self.pos.z])
175 | goal_pos = np.add(curr_pos, delta*np.array(lin))
176 | # Orientation Update
177 | curr_orient = np.array([self.orient.x, self.orient.y, self.orient.z, self.orient.w])
178 | # Quaternion integration: q_next = q + 0.5*delta*Omega(ang)*q, with q stored as [x, y, z, w]
179 | wx, wy, wz = ang
180 | omega = np.array([[0., -wz, wy, wx],
181 | [wz, 0., -wx, wy],
182 | [-wy, wx, 0., wz],
183 | [-wx, -wy, -wz, 0.]])
184 | goal_orient = np.add(curr_orient, 0.5*delta*np.matmul(omega, curr_orient))
185 | goal_orient = goal_orient/np.linalg.norm(goal_orient) # renormalize to a unit quaternion
186 | return goal_pos, goal_orient
187 |
188 | if __name__ == '__main__':
189 | learner = ImitateLearner()
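A quick numeric sanity check of the quaternion update in `calculate_move`, as a standalone sketch (values are illustrative):

    import numpy as np
    # Identity orientation, yaw rate 1 rad/s, delta = 0.1 s
    q = np.array([0., 0., 0., 1.])                # [x, y, z, w]
    wx, wy, wz, delta = 0., 0., 1., 0.1
    omega = np.array([[0., -wz, wy, wx],
                      [wz, 0., -wx, wy],
                      [-wy, wx, 0., wz],
                      [-wx, -wy, -wz, 0.]])
    q_next = q + 0.5*delta*omega.dot(q)
    q_next /= np.linalg.norm(q_next)              # ~[0., 0., 0.0499, 0.9988] -> ~0.1 rad of yaw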
--------------------------------------------------------------------------------
/old_stack/util/old/imitate_msgs/msg/GripperStamped.msg:
--------------------------------------------------------------------------------
1 | std_msgs/Header header
2 | std_msgs/Int64 data
3 |
--------------------------------------------------------------------------------
/old_stack/util/old/imitate_msgs/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="2">
3 | <name>imitate_msgs</name>
4 | <version>0.0.0</version>
5 | <description>The imitate_msgs package</description>
6 | <maintainer email="jonathanchang@todo.todo">jonathanchang</maintainer>
7 | <license>TODO</license>
8 | <buildtool_depend>catkin</buildtool_depend>
9 | <build_depend>std_msgs</build_depend>
10 | <build_export_depend>std_msgs</build_export_depend>
11 | <exec_depend>std_msgs</exec_depend>
12 | <build_depend>message_generation</build_depend>
13 | <exec_depend>message_runtime</exec_depend>
14 | <export>
15 | </export>
16 | </package>
17 |
--------------------------------------------------------------------------------
/old_stack/util/old/publishers/rightGripperPublisher.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import roslib
3 | import rospy
4 | import tf
5 | from tf.transformations import euler_from_quaternion
6 | import numpy as np
7 | import message_filters
8 | from std_msgs.msg import Int64, Header
9 | from sensor_msgs.msg import JointState
10 | from imitate_msgs.msg import GripperStamped
11 |
12 |
13 | is_open = None
14 |
15 | def callback(data):
16 | global is_open
17 | position = np.array(data.position) - 0.0016998404159380444
18 | if position.sum() < 0.05:
19 | is_open = 1
20 | else:
21 | is_open = 0
22 |
23 | # Publisher From Here
24 | rospy.init_node('movo_right_gripper')
25 | listener = rospy.Subscriber("/movo/right_gripper/joint_states", JointState, callback)
26 |
27 | # Give time for initialization
28 | rospy.Rate(1).sleep()
29 |
30 | robot_in_map = rospy.Publisher('/movo/right_gripper/gripper_is_open', GripperStamped, queue_size=1)
31 |
32 | rate = rospy.Rate(100)
33 | while not rospy.is_shutdown():
34 | if is_open is None: # no joint_states message received yet
35 | rate.sleep()
36 | continue
37 | h = Header()
38 | h.stamp = rospy.Time.now()
39 | msg = GripperStamped()
40 | msg.header = h
41 | msg.data = Int64(is_open)
42 | robot_in_map.publish(msg)
43 | rate.sleep()
44 |
--------------------------------------------------------------------------------
/old_stack/util/old/publishers/rightPosePublisher.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import roslib
3 | import rospy
4 | import tf
5 | from tf.transformations import euler_from_quaternion
6 | from geometry_msgs.msg import PoseStamped, Pose, Point, Quaternion
7 | from std_msgs.msg import Header
8 |
9 | rospy.init_node('movo_right_tf_listener_pose')
10 |
11 | listener = tf.TransformListener()
12 |
13 | robot_in_map = rospy.Publisher('tf/right_arm_pose', PoseStamped, queue_size=1)
14 |
15 | rate = rospy.Rate(100)
16 | while not rospy.is_shutdown():
17 | try:
18 | (trans, rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0))
19 | h = Header()
20 | h.stamp = rospy.Time.now()
21 | msg = PoseStamped()
22 | msg.header = h
23 | msg.pose = Pose(Point(*trans), Quaternion(*rot))
24 |
25 | except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException):
26 | continue
27 | # Publish the message
28 | robot_in_map.publish(msg)
29 | rate.sleep()
30 |
--------------------------------------------------------------------------------
/old_stack/util/old/publishers/rightVelPublisher.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import numpy as np
3 | import roslib
4 | import rospy
5 | import tf
6 | from tf.transformations import euler_from_quaternion
7 | from geometry_msgs.msg import Twist, TwistStamped, Vector3
8 | from std_msgs.msg import Header
9 |
10 | rospy.init_node('movo_right_tf_listener')
11 |
12 | listener = tf.TransformListener()
13 |
14 | robot_in_map = rospy.Publisher('tf/right_arm_vels', TwistStamped, queue_size=1)
15 |
16 | prev_time = rospy.get_time()
17 | curr_time = rospy.get_time()
18 | prev_rot, prev_trans = None, None
19 |
20 | rate = rospy.Rate(100)
21 | while not rospy.is_shutdown():
22 | try:
23 | curr_time = rospy.get_time()
24 | (curr_trans, curr_rot) = listener.lookupTransform('/base_link', '/right_gripper_base_link', rospy.Time(0))
25 | # First Iteration
26 | if not prev_rot and not prev_trans:
27 | prev_trans, prev_rot = curr_trans, curr_rot
28 | prev_time = curr_time # anchor the first finite-difference window here
29 | continue
30 | delta = float(curr_time)-float(prev_time)
31 | lin_vel = (np.array(curr_trans) - np.array(prev_trans))/delta
32 | ang_vel = (np.array(euler_from_quaternion(curr_rot)) - np.array(euler_from_quaternion(prev_rot)))/delta
33 | h = Header()
34 | h.stamp = rospy.Time.now()
35 | msg = TwistStamped()
36 | msg.header = h
37 | msg.twist = Twist(Vector3(*lin_vel), Vector3(*ang_vel))
38 |
39 | # Update
40 | prev_time = curr_time
41 | prev_trans, prev_rot = curr_trans, curr_rot
42 |
43 | except (tf.LookupException, tf.ConnectivityException, tf.ExtrapolationException):
44 | continue
45 | # Publish the message
46 | robot_in_map.publish(msg)
47 | rate.sleep()
48 |
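The twist above is a plain finite difference between consecutive tf lookups; a self-contained sketch with made-up numbers (note that differencing `euler_from_quaternion` outputs can wrap near +/-pi, which this publisher does not guard against):

    import numpy as np
    # Two consecutive gripper positions sampled 0.01 s apart (illustrative values)
    prev_trans = np.array([0.50, 0.10, 0.30])
    curr_trans = np.array([0.51, 0.10, 0.29])
    delta = 0.01
    lin_vel = (curr_trans - prev_trans) / delta   # ~[1., 0., -1.] m/s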
--------------------------------------------------------------------------------
/old_stack/util/old/publishers/toggletest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import roslib
3 | import rospy
4 | from std_msgs.msg import Int64
5 |
6 | # Publisher From Here
7 | rospy.init_node('toggle_test')
8 |
9 | robot_in_map = rospy.Publisher('/toggle', Int64, queue_size=1)
10 |
11 | rate = rospy.Rate(100)
12 | while not rospy.is_shutdown():
13 | rate.sleep()
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/2dsim.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import os
3 | import sys
4 | import csv
5 | import time
6 | import pygame
7 | from pygame.locals import *
8 | from PIL import Image
9 | import numpy as np
10 |
11 | task = sys.argv[1]
12 | if not os.path.exists('datas/' + task + '/'):
13 | os.mkdir('datas/' + task + '/')
14 | save_folder = None
15 | writer = None
16 | text_file = None
17 |
18 | """
19 | Goal Positions: (200, 150), (400, 150), (600, 150)
20 | (200, 300), (400, 300), (600, 300)
21 | (200, 450), (400, 450), (600, 450)
22 | """
23 | # The goal positions are listed above; GOAL_X/GOAL_Y below select the one currently in use.
24 | GOAL_X = 200
25 | GOAL_Y = 150
26 |
27 | # These are magic numbers
28 | RECT_X = 60
29 | RECT_Y = 60
30 | SPACEBAR_KEY = 32 # pygame logic
31 | S_KEY = 115
32 |
33 | pygame.init()
34 |
35 | #Window
36 | screen = pygame.display.set_mode((800, 600))
37 | pygame.display.set_caption("2D Simulation")
38 | pygame.mouse.set_visible(1)
39 |
40 | #Background
41 | background = pygame.Surface(screen.get_size())
42 | background = background.convert()
43 | background.fill((211, 211, 211))
44 | screen.blit(background, (0, 0))
45 | pygame.display.flip()
46 |
47 | clock = pygame.time.Clock()
48 |
49 | run = True
50 | recording = False
51 | save_counter = 0
52 | idx = 0
53 | last_pos = None
54 | last_gripper = None
55 | prev_pos = None
56 |
57 | while run:
58 | # Note that this is the data collection speed
59 | clock.tick(60)
60 | for event in pygame.event.get():
61 | if event.type == pygame.QUIT:
62 | run = False
63 | if event.type == pygame.KEYUP:
64 | # Note that since we are only recording when the mouse is moving, there is never a case when the velocity is 0
65 | # This may cause problems because the net cannot learn to stop when the data suggests that it never does.
66 | # To simulate the fact that there will be a start and an end with no movement, we will save 5 instances at the beginning
67 | # and at the end
68 | if event.key == SPACEBAR_KEY: # toggles the recording
69 | recording = not recording
70 | vel = (0,0,0)
71 | if recording:
72 | folder = 'datas/'+task+'/'+str(time.time())+'/'
73 | os.mkdir(folder)
74 | save_folder = folder
75 | text_file = open(save_folder + 'vector.txt', 'w')
76 | writer = csv.writer(text_file)
77 | print("===Start Recording===")
78 | position = pygame.mouse.get_pos()
79 | buttons = pygame.mouse.get_pressed()
80 | gripper = buttons[0]
81 | for _ in range(5):
82 | now = time.time()
83 | pygame.image.save(screen, save_folder + str(idx) + "_rgb.png")
84 | depth = Image.fromarray(np.uint8(np.zeros((600,800))))
85 | depth.save(save_folder + str(idx) + "_depth.png")
86 | # Record data
87 | writer.writerow([idx, position[0], position[1], 0, 0, 0, 0, 0, vel[0], vel[1], vel[2], 0, 0, 0, gripper, now])
88 | idx += 1
89 | print(position, vel, gripper)
90 | if not recording:
91 | for _ in range(5):
92 | now = time.time()
93 | pygame.image.save(screen, save_folder + str(idx) + "_rgb.png")
94 | depth = Image.fromarray(np.uint8(np.zeros((600,800))))
95 | depth.save(save_folder + str(idx) + "_depth.png")
96 | # Record data
97 | writer.writerow([idx, last_pos[0], last_pos[1], 0, 0, 0, 0, 0, vel[0], vel[1], vel[2], 0, 0, 0, last_gripper, now])
98 | idx += 1
99 | print(last_pos, vel, last_gripper)
100 | print("---Stop Recording---")
101 | if text_file is not None:
102 | text_file.close()
103 | text_file = None
104 | prev_pos = None
105 | save_counter = 0
106 | idx = 0
107 | if event.key == S_KEY: # sets the cursor position near the relative start position
108 | print("Cursor set to position (760, 370)")
109 | pygame.mouse.set_pos([760, 370])
110 | if event.type == pygame.MOUSEMOTION: # This is the simulation of the arm. Note that left click simulates the gripper status
111 | if recording:
112 | if save_counter % 5 == 0:
113 | now = time.time()
114 | pygame.image.save(screen, save_folder+str(idx)+"_rgb.png")
115 | depth = Image.fromarray(np.uint8(np.zeros((600,800))))
116 | depth.save(save_folder + str(idx) + "_depth.png")
117 | position = event.pos
118 | print(idx)
119 | if idx == 5:
120 | vel = event.rel
121 | else:
122 | vel = np.array(position)-prev_pos
123 | gripper = event.buttons[0]
124 | writer.writerow([idx, position[0], position[1], 0, 0, 0, 0, 0, vel[0], vel[1], 0, 0, 0, 0, gripper, now])
125 | idx += 1
126 | prev_pos = np.array(position)
127 | last_pos = position
128 | last_gripper = gripper
129 | print(now, position, vel, gripper)
130 | save_counter += 1
131 | screen.fill((211,211,211))
132 | pygame.draw.rect(screen, (0,0,255), pygame.Rect(GOAL_X-RECT_X/2, GOAL_Y-RECT_Y/2, RECT_X, RECT_Y))
133 | pygame.draw.circle(screen, (255,0,0), pygame.mouse.get_pos(), 20, 1)
134 | pygame.display.update()
135 |
136 | pygame.quit()
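Each `vector.txt` row written above has a fixed 16-column layout; a minimal reader sketch (the column meaning is inferred from the `writerow` calls, with the z/quaternion and angular-velocity slots always zero in this 2-D sim; the path is a placeholder):

    import csv
    # Columns: idx, x, y, z, qx, qy, qz, qw, vx, vy, vz, wx, wy, wz, gripper, timestamp
    with open('datas/some_task/some_trial/vector.txt') as f:
        for row in csv.reader(f):
            idx, x, y = int(row[0]), float(row[1]), float(row[2])
            vx, vy = float(row[8]), float(row[9])
            gripper, stamp = int(row[14]), float(row[15])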
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/README.md
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_env.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: imitate-env
3 | Version: 0.0.1
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id='buttons-v0',
5 | entry_point='imitate_gym.envs:ButtonsEnv',
6 | )
7 |
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/buttons_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import gym
4 | import argparse
5 | import imitate_gym
6 | import numpy as np
7 | from PIL import Image
8 |
9 | # GOAL_00 = np.array([1.2,0.66,0.45])
10 | # GOAL_01 = np.array([1.2,0.76,0.45])
11 | # GOAL_02 = np.array([1.2,0.86,0.45])
12 | # GOAL_10 = np.array([1.3,0.66,0.45])
13 | # GOAL_11 = np.array([1.3,0.76,0.45])
14 | # GOAL_12 = np.array([1.3,0.86,0.45])
15 | # GOAL_20 = np.array([1.4,0.66,0.45])
16 | # GOAL_21 = np.array([1.4,0.76,0.45])
17 | # GOAL_22 = np.array([1.4,0.86,0.45])
18 | # GOALS = np.array([GOAL_00,GOAL_01,GOAL_02,GOAL_10,GOAL_11,GOAL_12,GOAL_20,GOAL_21,GOAL_22])
19 | # NAMES = {0:'goal_00',1:'goal_01',2:'goal_02',3:'goal_10',4:'goal_11',5:'goal_12',6:'goal_20',7:'goal_21',8:'goal_22'}
20 | GOAL_00 = np.array([1.2,0.56,0.45])
21 | GOAL_01 = np.array([1.2,0.66,0.45])
22 | GOAL_02 = np.array([1.2,0.76,0.45])
23 | GOAL_03 = np.array([1.2,0.86,0.45])
24 | GOAL_04 = np.array([1.2,0.96,0.45])
25 | GOAL_10 = np.array([1.3,0.56,0.45])
26 | GOAL_11 = np.array([1.3,0.66,0.45])
27 | GOAL_12 = np.array([1.3,0.76,0.45])
28 | GOAL_13 = np.array([1.3,0.86,0.45])
29 | GOAL_14 = np.array([1.3,0.96,0.45])
30 | GOAL_20 = np.array([1.4,0.56,0.45])
31 | GOAL_21 = np.array([1.4,0.66,0.45])
32 | GOAL_22 = np.array([1.4,0.76,0.45])
33 | GOAL_23 = np.array([1.4,0.86,0.45])
34 | GOAL_24 = np.array([1.4,0.96,0.45])
35 | GOALS = np.array([GOAL_00,
36 | GOAL_01,
37 | GOAL_02,
38 | GOAL_03,
39 | GOAL_04,
40 | GOAL_10,
41 | GOAL_11,
42 | GOAL_12,
43 | GOAL_13,
44 | GOAL_14,
45 | GOAL_20,
46 | GOAL_21,
47 | GOAL_22,
48 | GOAL_23,
49 | GOAL_24])
50 | NAMES = {0:'goal_00',
51 | 1:'goal_01',
52 | 2:'goal_02',
53 | 3:'goal_03',
54 | 4:'goal_04',
55 | 5:'goal_10',
56 | 6:'goal_11',
57 | 7:'goal_12',
58 | 8:'goal_13',
59 | 9:'goal_14',
60 | 10:'goal_20',
61 | 11:'goal_21',
62 | 12:'goal_22',
63 | 13:'goal_23',
64 | 14:'goal_24'}
65 | test=np.array([GOAL_24])
66 |
67 | def calc_changes(start_pos, goal_pos, num_steps=50.):
68 | return (goal_pos-start_pos)/num_steps
69 |
70 | def next_move(curr_pos, goal_pos, slope):
71 | diffs = np.absolute(curr_pos - goal_pos)
72 | if diffs[0] <= 0.001:
73 | slope[0] = 0.
74 | if diffs[1] <= 0.001:
75 | slope[1] = 0.
76 | if diffs[2] <= 0.001:
77 | slope[2] = 0.
78 | return slope
79 |
80 | def run(env, goal, trial_dir, viz=False):
81 | if not viz:
82 | with open(trial_dir+'vectors.txt', 'w') as f:
83 | writer = csv.writer(f)
84 | curr_pos = None
85 | slope = None
86 | goal_reached = False
87 | counter = 0
88 | traj_counter = 0
89 | other = np.array([0.,0.,0.,0.,0.])
90 | while True:
91 | rgb = env.render('rgb_array')
92 | if curr_pos is None and goal_reached is False:
93 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
94 | curr_pos = obs['achieved_goal']
95 | slope = calc_changes(curr_pos, goal)
96 | elif curr_pos is not None and goal_reached is False:
97 | slope = next_move(curr_pos, goal, slope)
98 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455:
99 | goal_reached = True
100 | action = np.concatenate([slope,other])
101 | obs, _, _, _ = env.step(action)
102 | img_rgb = Image.fromarray(rgb, 'RGB')
103 | img_rgb = img_rgb.resize((160,120))
104 | img_rgb.save(trial_dir+'{}.png'.format(traj_counter))
105 | arr = [traj_counter]
106 | arr += [x for x in obs['observation']]
107 | arr += [0]
108 | curr_pos = obs['achieved_goal']
109 | traj_counter += 1
110 | writer.writerow(arr)
111 | elif curr_pos is not None and goal_reached is True:
112 | # This is to simulate the terminal state where the user would stop near the end
113 | counter += 1
114 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
115 | img_rgb = Image.fromarray(rgb, 'RGB')
116 | img_rgb = img_rgb.resize((160,120))
117 | img_rgb.save(trial_dir+'{}.png'.format(traj_counter))
118 | arr = [traj_counter]
119 | arr += [x for x in obs['observation']]
120 | arr += [0]
121 | traj_counter += 1
122 | writer.writerow(arr)
123 | if counter == 20:
124 | break
125 | else:
126 | curr_pos = None
127 | slope = None
128 | goal_reached = False
129 | counter = 0
130 | other = np.array([0.,0.,0.,0.,0.])
131 | while True:
132 | env.render()
133 | if curr_pos is None and goal_reached is False:
134 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
135 | curr_pos = obs['achieved_goal']
136 | slope = calc_changes(curr_pos, goal)
137 | elif curr_pos is not None and goal_reached is False:
138 | slope = next_move(curr_pos, goal, slope)
139 | if sum(np.absolute(slope)) < 0.008 and curr_pos[2] < 0.455:
140 | goal_reached = True
141 | action = np.concatenate([slope,other])
142 | obs, _, _, _ = env.step(action)
143 | print("==================")
144 | print(slope)
145 | print(obs['observation'][7:10])
146 | curr_pos = obs['achieved_goal']
147 | print(curr_pos)
148 | elif curr_pos is not None and goal_reached is True:
149 | # This is to simulate the terminal state where the user would stop near the end
150 | counter += 1
151 | obs, _, _, _ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
152 | print("==================")
153 | print(slope)
154 | print(obs['observation'][7:10])
155 | if counter == 20:
156 | break
157 |
158 |
159 |
160 |
161 | def main(task_name, num_trials=300, viz=False):
162 | if not viz:
163 | if not os.path.exists('../../../datas/' + task_name + '/'):
164 | os.mkdir('../../../datas/' + task_name + '/')
165 | trial_counter = 0
166 | for trial in range(num_trials):
167 | env = gym.make('buttons-v0')
168 | for i, goal in enumerate(GOALS):
169 | save_folder = '../../../datas/'+task_name+'/'+NAMES[i]+'/'
170 | if not viz:
171 | if not os.path.exists(save_folder):
172 | os.mkdir(save_folder)
173 | trial_dir = save_folder+str(trial_counter)+'/'
174 | if not viz:
175 | os.mkdir(trial_dir)
176 | env.reset()
177 | run(env, goal, trial_dir, viz=viz)
178 | trial_counter+=1
179 |
180 | if __name__ == '__main__':
181 | main("buttons3x5", num_trials=1, viz=True)
182 |
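A worked example of the stepping logic above: `calc_changes` fixes a constant per-step displacement toward the goal, and `next_move` zeroes each axis once it is within 1 mm, so the commanded velocity decays axis-by-axis as the gripper arrives (numbers are illustrative):

    import numpy as np
    start = np.array([1.34, 0.75, 0.55])
    goal = np.array([1.40, 0.96, 0.45])
    slope = (goal - start) / 50.   # calc_changes -> [0.0012, 0.0042, -0.002] per step
    # Once every axis is within 0.001 of the goal, next_move returns [0., 0., 0.]; together
    # with the height check curr_pos[2] < 0.455, that marks the goal as reached in run().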
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from imitate_gym.envs.imitate_env import ImitateEnv
2 | from imitate_gym.envs.buttons import ButtonsEnv
3 |
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons.xml
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons_3x3.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/buttons_3x3.xml
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/robot.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/robot.xml
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/shared.xml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/buttons/shared.xml
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/.get
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/base_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/bellows_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/elbow_flex_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/estop_link.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/forearm_roll_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/gripper_link.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_pan_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/head_tilt_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/l_wheel_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/laser_link.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/r_wheel_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_lift_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/shoulder_pan_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_fixed_link.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/torso_lift_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/upperarm_roll_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_flex_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/fetch/wrist_roll_link_collision.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F1.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F2.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/F3.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH1_z.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH2_z.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/TH3_z.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/forearm_electric_cvx.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/knuckle.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/lfmetacarpal.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/palm.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/stls/hand/wrist.stl
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block.png
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/h2r/parameterized-imitation-learning/b73d50f7824305c0bf35526c4271a02f5c813f2f/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/assets/textures/block_hidden.png
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/buttons.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from gym import utils
4 | from imitate_gym.envs import imitate_env
5 |
6 | #MODEL_XML_PATH = os.path.join('fetch', 'pick_and_place.xml')
7 | HOME = '/home/jonathanchang/parameterized-imitation-learning/simulation/gym-imitate/imitate_gym/envs'
8 | MODEL_XML_PATH = HOME+'/assets/buttons/buttons_3x3.xml'
9 | class ButtonsEnv(imitate_env.ImitateEnv, utils.EzPickle):
10 | def __init__(self, reward_type='dense'):
11 | initial_qpos = {
12 | 'robot0:slide0': 0.405,
13 | 'robot0:slide1': 0.48,
14 | 'robot0:slide2': 0.0,
15 | 'robot0:shoulder_pan_joint': 0.0,
16 | 'robot0:shoulder_lift_joint': -0.8,
17 | 'robot0:elbow_flex_joint': 1.0
18 | }
19 | target = np.array([1.25, 0.53, 0.4])
20 | imitate_env.ImitateEnv.__init__(
21 | self, model_path=MODEL_XML_PATH, n_substeps=25, initial_qpos=initial_qpos, reward_type=reward_type,
22 | distance_threshold=0.05, gripper_extra_height=0.2, target=target)
23 | utils.EzPickle.__init__(self)
24 |
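Instantiating this environment goes through the gym registry (the 'buttons-v0' id is bound to ButtonsEnv in `imitate_gym/__init__.py`); a minimal sketch:

    import gym
    import imitate_gym  # side effect: registers buttons-v0
    env = gym.make('buttons-v0')
    obs = env.reset()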
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/envs/imitate_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import gym
4 | import mujoco_py
5 | from gym import error, spaces # utils is provided by gym.envs.robotics below
6 | from gym.envs.robotics import rotations, robot_env, utils
7 |
8 | # Just to make sure that input shape is correct
9 | def goal_distance(goal_a, goal_b):
10 | assert goal_a.shape == goal_b.shape
11 | return np.linalg.norm(goal_a - goal_b, axis=-1)
12 |
13 | class ImitateEnv(robot_env.RobotEnv):
14 | """
15 | Inheritance Chain: Env -> GoalEnv -> RobotEnv -> ImitateEnv
16 |
17 | Methods from RobotEnv
18 | seed(), step(), reset(), close(), render(), _reset_sim(),
19 | _get_viewer(): This method will be important for retrieving rgb and depth from the sim
20 | """
21 | def __init__(
22 | self, model_path, n_substeps, initial_qpos, reward_type,
23 | distance_threshold, target, gripper_extra_height
24 | ):
25 | """
26 | Args: model_path (string): path to the environment's XML file
27 | n_substeps (int): number of substeps the simulation runs on every call to step
28 | initial_qpos (dict): a dictionary of joint names and values that define the initial configuration
29 | reward_type ('sparse' or 'dense'): the reward type, i.e. sparse or dense
30 | distance_threshold (float): the threshold after which a goal is considered achieved
31 | target (list): the target position that we are aiming for
32 | gripper_extra_height (float): the gripper offset in position
33 | """
34 | # n_actions = 8 for [pos_x, pos_y, pos_z, rot_x, rot_y, rot_z, rot_w, gripper]
35 |
36 | self.reward_type = reward_type
37 | self.distance_threshold = distance_threshold
38 | self.target = target
39 | self.gripper_extra_height = gripper_extra_height
40 | self._viewers = {}
41 | self.initial_qpos = initial_qpos
42 | super(ImitateEnv, self).__init__(model_path=model_path,
43 | n_substeps=n_substeps,
44 | n_actions=8,
45 | initial_qpos=initial_qpos)
46 | # Env Method
47 | # --------------------------------------------------
48 | def render(self, mode='human'):
49 | self._render_callback()
50 | if mode == 'rgb_array':
51 | self._get_viewer(mode).render(1000,1000)
52 | rgb = self._get_viewer(mode).read_pixels(1000, 1000, depth=False)
53 | #rgbd = self._get_viewer(mode).read_pixels(1000, 1000, depth=True)
54 | #rgb = rgbd[0][::-1,:,:]
55 | #depth = rgbd[1]
56 | #depth = (depth-np.amin(depth))*(255/(np.amax(depth)-np.amin(depth)))
57 | return rgb[::-1,:,:]
58 | elif mode == 'human':
59 | self._get_viewer(mode).render()
60 |
61 | def _get_viewer(self, mode):
62 | self.viewer = self._viewers.get(mode)
63 | if self.viewer is None:
64 | if mode == 'human':
65 | self.viewer = mujoco_py.MjViewer(self.sim)
66 | elif mode == 'rgb_array':
67 | self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, device_id=-1)
68 | self._viewer_setup()
69 | self._viewers[mode] = self.viewer
70 | return self.viewer
71 | # GoalEnv Method
72 | # --------------------------------------------------
73 | def compute_reward(self, achieved_goal, goal, info):
74 | d = goal_distance(achieved_goal, goal)
75 | if self.reward_type == 'sparse':
76 | return -(d > self.distance_threshold).astype(np.float32)
77 | else:
78 | return -d
79 |
80 | # RobotEnv Methods
81 | # --------------------------------------------------
82 |
83 | def _step_callback(self):
84 | """
85 | In the step function in the RobotEnv it does: 1) self._set_action(action)
86 | 2) self.sim.step()
87 | 3) self._step_callback()
88 | This method can be used to provide additional constraints to the actions that happen every step.
89 | Could be used to fix the gripper to be a certain orientation for example
90 | """
91 | pass
92 |
93 | def _set_action(self, action):
94 | """
95 | Currently, I am just returning 1 number for the gripper control because the robot used in the
96 | OpenAI gym library has 2 fingers. If I start using Movo, I would need 3 fingers, so perhaps
97 | increase the gripper control from 2 to 3.
98 | """
99 | assert action.shape == (8,)
100 | action = action.copy() # ensures that we don't change the action outside of this scope
101 | pos_ctrl, rot_ctrl, gripper_ctrl = action[:3], action[3:7], action[-1]
102 |
103 | # This is where I match the gripper control to number of fingers
104 | gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
105 | assert gripper_ctrl.shape == (2,)
106 |
107 | # Create the action that we want to pass into the simulation
108 | action = np.concatenate([pos_ctrl, rot_ctrl, gripper_ctrl])
109 |
110 | # Now apply the action to the simulation
111 | utils.ctrl_set_action(self.sim, action) # note self.sim is an inherited variable
112 | utils.mocap_set_action(self.sim, action)
113 |
114 | def _get_obs(self):
115 | """
116 | This is where I make sure to grab the observations for the position and the velocities. Potential
117 | area to grab the rgb and depth images.
118 | """
119 | ee = 'robot0:grip'
120 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep
121 | ee_pos = self.sim.data.get_site_xpos(ee)
122 | ee_quat = rotations.mat2quat(self.sim.data.get_site_xmat(ee))
123 | # Position and angular velocity respectively
124 | ee_velp = self.sim.data.get_site_xvelp(ee) * dt # to remove time dependency
125 | ee_velr = self.sim.data.get_site_xvelr(ee) * dt
126 |
127 | obs = np.concatenate([ee_pos, ee_quat, ee_velp, ee_velr])
128 |
129 | return {
130 | 'observation': obs.copy(),
131 | 'achieved_goal': ee_pos.copy(),
132 | 'desired_goal': self.goal.copy()
133 | }
134 |
135 | def _viewer_setup(self):
136 | """
137 | This could be used to reorient the starting view frame. Will have to mess around
138 | """
139 | body_id = self.sim.model.body_name2id('robot0:gripper_link')
140 | lookat = self.sim.data.body_xpos[body_id]
141 | for idx, value in enumerate(lookat):
142 | self.viewer.cam.lookat[idx] = value
143 | self.viewer.cam.distance = 1.7
144 | self.viewer.cam.azimuth = 180.
145 | self.viewer.cam.elevation = -30.
146 |
147 | def _sample_goal(self):
148 | """
149 | Instead of sampling I am currently just setting our defined goal as the "sampled" goal
150 | """
151 | return self.target
152 |
153 | def _is_success(self, achieved_goal, desired_goal):
154 | d = goal_distance(achieved_goal, desired_goal)
155 | return (d < self.distance_threshold).astype(np.float32)
156 |
157 | def _env_setup(self, initial_qpos):
158 | # Randomize the starting position
159 | shoulder_pan_val = -0.1 + (random.random()*0.2)
160 | initial_qpos['robot0:shoulder_pan_joint'] = shoulder_pan_val
161 | for name, value in initial_qpos.items():
162 | self.sim.data.set_joint_qpos(name, value)
163 | utils.reset_mocap_welds(self.sim)
164 | self.sim.forward()
165 |
166 | # Move end effector into position
167 | gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos('robot0:grip')
168 | gripper_rotation = np.array([1.,0.,1.,0.])
169 | self.sim.data.set_mocap_pos('robot0:mocap', gripper_target)
170 | self.sim.data.set_mocap_quat('robot0:mocap', gripper_rotation)
171 | for _ in range(10):
172 | self.sim.step()
173 |
174 |
175 |
176 |
177 |
178 |
179 |
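A worked example of compute_reward above, using the distance_threshold=0.05 that buttons.py passes in (positions are illustrative):

    import numpy as np
    achieved = np.array([1.25, 0.61, 0.40])
    goal = np.array([1.25, 0.53, 0.40])
    d = np.linalg.norm(achieved - goal)  # 0.08
    dense_reward = -d                    # -0.08
    sparse_reward = -float(d > 0.05)     # -1.0; becomes 0.0 once d <= threshold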
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/imitate_gym/gym_eval.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import imitate_gym
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from crnn import SpatialCRNN
7 | import sys
8 | seq_len = 10
9 | weights = sys.argv[1]
10 | tau = [[int(sys.argv[2]), int(sys.argv[3])] for _ in range(seq_len)]
11 | tau = torch.from_numpy(np.array(tau)).type(torch.FloatTensor)
12 | tau = torch.unsqueeze(tau, dim=0)
13 | env = gym.make('buttons-v0')
14 | env.reset()
15 | model = SpatialCRNN(rtype="GRU", num_layers=2, device="cpu")
16 | model.load_state_dict(torch.load(weights, map_location='cpu'), strict=False)
17 | model.eval()
18 | eof = None
19 | rgbs = []
20 | while True:
21 | env.render()
22 | rgb = env.render('rgb_array')
23 | if eof is None:
24 | obs, _,_,_ = env.step([0.,0.,0.,0.,0.,0.,0.,0.])
25 | pos = obs['achieved_goal']
26 | eof = [pos for _ in range(seq_len)]
27 | img_rgb = Image.fromarray(rgb, 'RGB')
28 | img_rgb = np.array(img_rgb.resize((160,120)))
29 | img_rgb = np.transpose(img_rgb, (2,0,1)) # HWC -> CHW; reshape would scramble the pixel layout
30 | rgbs = [img_rgb for _ in range(seq_len)]
31 | else:
32 | img_rgb = Image.fromarray(rgb, 'RGB')
33 | img_rgb = np.array(img_rgb.resize((160,120)))
34 | img_rgb = np.transpose(img_rgb, (2,0,1)) # HWC -> CHW; reshape would scramble the pixel layout
35 | rgbs.pop(0)
36 | rgbs.append(img_rgb)
37 | input_rgb = torch.from_numpy(np.array(rgbs)).type(torch.FloatTensor)
38 | input_rgb = torch.unsqueeze(input_rgb, dim=0)
39 | input_eof = torch.from_numpy(np.array(eof)).type(torch.FloatTensor)
40 | input_eof = torch.unsqueeze(input_eof, dim=0)
41 | vels, _ = model([input_rgb, input_eof, tau])
42 | action = [vels[0][0].item(), vels[0][1].item(), vels[0][2].item(), 0., 0., 0., 0., 0.]
43 | print(action)
44 | obs, _, _, _ = env.step(action)
45 | eof.pop(0)
46 | eof.append(obs['achieved_goal'])
47 |
48 |
--------------------------------------------------------------------------------
/old_stack/util/old/simulation/gym-imitate/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(name='imitate_env',
4 | version='0.0.1',
5 | install_requires=['gym', 'mujoco-py']
6 | )
7 |
--------------------------------------------------------------------------------
/old_stack/util/old/test_baxter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import sys
3 | import rospy
4 | import itertools
5 | import message_filters
6 | import numpy as np
7 | import time
8 | import cv2
9 | import moveit_commander
10 | from cv_bridge import CvBridge
11 | from std_msgs.msg import Int64
12 | from geometry_msgs.msg import Point, Quaternion, Pose, Twist
13 | from sensor_msgs.msg import CompressedImage, JointState
14 |
15 | class ApproxTimeSync(message_filters.ApproximateTimeSynchronizer):
16 | def add(self, msg, my_queue, my_queue_index=None):
17 | self.allow_headerless = True
18 | if hasattr(msg, 'timestamp'):
19 | stamp = msg.timestamp
20 | elif not hasattr(msg, 'header') or not hasattr(msg.header, 'stamp'):
21 | if not self.allow_headerless:
22 | rospy.logwarn("Cannot use message filters with non-stamped messages. "
23 | "Use the 'allow_headerless' constructor option to "
24 | "auto-assign ROS time to headerless messages.")
25 | return
26 | stamp = rospy.Time.now()
27 | else:
28 | stamp = msg.header.stamp
29 |
30 | #TODO ADD HEADER TO ALLOW HEADERLESS
31 | # http://book2code.com/ros_kinetic/source/ros_comm/message_filters/src/message_filters/__init__.y
32 | #setattr(msg, 'header', a)
33 | #msg.header.stamp = stamp
34 | #super(message_filters.ApproximateTimeSynchronizer, self).add(msg, my_queue)
35 | self.lock.acquire()
36 | my_queue[stamp] = msg
37 | while len(my_queue) > self.queue_size:
38 | del my_queue[min(my_queue)]
39 | # self.queues = [topic_0 {stamp: msg}, topic_1 {stamp: msg}, ...]
40 | if my_queue_index is None:
41 | search_queues = self.queues
42 | else:
43 | search_queues = self.queues[:my_queue_index] + \
44 | self.queues[my_queue_index+1:]
45 | # sort and leave only reasonable stamps for synchronization
46 | stamps = []
47 | for queue in search_queues:
48 | topic_stamps = []
49 | for s in queue:
50 | stamp_delta = abs(s - stamp)
51 | if stamp_delta > self.slop:
52 | continue # far over the slop
53 | topic_stamps.append((s, stamp_delta))
54 | if not topic_stamps:
55 | self.lock.release()
56 | return
57 | topic_stamps = sorted(topic_stamps, key=lambda x: x[1])
58 | stamps.append(topic_stamps)
59 | for vv in itertools.product(*[list(zip(*s))[0] for s in stamps]): # list() since zip is an iterator on Python 3
60 | vv = list(vv)
61 | # insert the new message
62 | if my_queue_index is not None:
63 | vv.insert(my_queue_index, stamp)
64 | qt = list(zip(self.queues, vv))
65 | if ( ((max(vv) - min(vv)) < self.slop) and
66 | (len([1 for q,t in qt if t not in q]) == 0) ):
67 | msgs = [q[t] for q,t in qt]
68 | self.signalMessage(*msgs)
69 | for q,t in qt:
70 | del q[t]
71 | break # fast finish after the synchronization
72 | self.lock.release()
73 |
74 | class ImitateLearner():
75 | """
76 | This class will evaluate the outputs of the net.
77 | """
78 | def __init__(self, arm='right'):
79 | rospy.init_node("{}_arm_eval".format(arm))
80 | self.queue = []
81 | self.rgb = None
82 | self.depth = None
83 | self.pos = None
84 | self.orient = None
85 | self.prevTime = None
86 | self.time = None
87 | self.arm = arm
88 |
89 | # Initialize Subscribers
90 | self.listener()
91 |
92 | # Enable Robot
93 | moveit_commander.roscpp_initialize(sys.argv)
94 | robot = moveit_commander.RobotCommander()
95 | self.group_arms = moveit_commander.MoveGroupCommander('upper_body')
96 | self.group_arms.set_pose_reference_frame('/base_link')
97 | self.left_ee_link = 'left_ee_link'
98 | self.right_ee_link = 'right_ee_link'
99 |
100 | # Set the rate of our evaluation
101 | rate = rospy.Rate(0.5)
102 |
103 | # Give time for initialization
104 | rospy.Rate(1).sleep()
105 |
106 | # This is to the get the time delta
107 | first_time = True
108 | while not rospy.is_shutdown():
109 | # Todo: Connect Net
110 | if first_time:
111 | self.prevTime = time.time()
112 | first_time = False
113 | # Given the output, solve for limb joints
114 | #limb_joints = self.get_limb_joints(output)
115 |
116 | # If valid joints then move to joint
117 | #if limb_joints is not -1:
118 | # right.move_to_joint_positions(limb_joints)
119 | # self.prevTime = self.time
120 | #else:
121 | # print 'ERROR: IK solver returned -1'
122 | print(self.pos)
123 | rate.sleep()
124 |
125 | def listener(self):
126 | """
127 | Listener for all of the topics
128 | """
129 | print("Listener Initialized")
130 | pose_sub = message_filters.Subscriber('/tf/{}_arm_pose'.format(self.arm), Pose)
131 |         twist_sub = message_filters.Subscriber('/tf/{}_arm_vels'.format(self.arm), Twist)
132 | rgb_sub = message_filters.Subscriber('/kinect2/sd/image_color_rect/compressed', CompressedImage)
133 | depth_sub = message_filters.Subscriber('/kinect2/sd/image_depth_rect/compressed', CompressedImage)
134 | gripper_sub = message_filters.Subscriber('/movo/{}_gripper/gripper_is_open'.format(self.arm), Int64)
135 | ts = ApproxTimeSync([pose_sub, twist_sub, rgb_sub, depth_sub, gripper_sub], 1, 0.1)
136 | ts.registerCallback(self.listener_callback)
137 |
138 | def listener_callback(self, pose, twist, rgb, depth, gripper):
139 | """
140 | This method updates the variables.
141 | """
142 | bridge = CvBridge()
143 | self.time = time.time()
144 | self.rgb = bridge.compressed_imgmsg_to_cv2(rgb)
145 | self.depth = bridge.compressed_imgmsg_to_cv2(depth)
146 | self.pos = pose.position
147 | self.orient = pose.orientation
148 | # Create input for net. x, y, z
149 | queue_input = np.array([self.pos.x, self.pos.y, self.pos.z])
150 | if len(self.queue) == 0:
151 | self.queue = [queue_input for i in range(5)]
152 | else:
153 | self.queue.pop(0)
154 | self.queue.append(queue_input)
155 |
156 | def get_next_pose(self, output):
157 | """
158 | This method gets the ik_solver solution for the arm joints.
159 | """
160 | [goal_pos, goal_orient] = self.calculate_move(np.reshape(output[0, :3], (3,)), np.reshape(output[0, 3:], (3,)))
161 | return Point(*goal_pos), Quaternion(*goal_orient)
162 |
163 | def calculate_move(self, lin, ang):
164 |         """
165 |         This calculates the position and orientation (as a quaternion) of the next pose given
166 |         the linear and angular velocities output by the net.
167 |         """
168 | delta = self.time - self.prevTime
169 | print("------------")
170 | print(delta)
171 | print("------------")
172 | #delta = 1/30
173 | # Position Update
174 | curr_pos = np.array([self.pos.x, self.pos.y, self.pos.z])
175 | goal_pos = np.add(curr_pos, delta*np.array(lin))
176 |         # Orientation Update: integrate q_dot = 0.5 * (omega quaternion) * q, w-last convention
177 |         curr_orient = np.array([self.orient.x, self.orient.y, self.orient.z, self.orient.w])
178 |         w_ang = np.array(ang)
179 |         goal_orient = curr_orient + 0.5*delta*np.concatenate([curr_orient[3]*w_ang + np.cross(w_ang, curr_orient[:3]), [-np.dot(w_ang, curr_orient[:3])]])
180 |         # Note: prevTime is updated by the caller once the move executes
181 | return goal_pos, goal_orient
182 |
183 | if __name__ == '__main__':
184 | learner = ImitateLearner()
--------------------------------------------------------------------------------
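
A note on calculate_move above: the orientation step integrates the quaternion kinematics q_dot = 0.5 * (omega quaternion) * q. A minimal standalone sketch of just that step, assuming the file's w-last [x, y, z, w] layout (the helper name quat_step is ours):

    import numpy as np

    def quat_step(q, omega, dt):
        """One Euler step of q_dot = 0.5 * omega (x) q, with q = [x, y, z, w]."""
        v, w = q[:3], q[3]
        dv = w * omega + np.cross(omega, v)   # vector part of omega (x) q
        dw = -np.dot(omega, v)                # scalar part of omega (x) q
        q_new = q + 0.5 * dt * np.concatenate([dv, [dw]])
        return q_new / np.linalg.norm(q_new)  # renormalize onto the unit sphere

    # identity orientation rotating at 1 rad/s about z for 0.1 s -> ~[0, 0, 0.05, 0.999]
    print(quat_step(np.array([0., 0., 0., 1.]), np.array([0., 0., 1.]), 0.1))
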
/old_stack/util/parse_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | #import cv2
3 | import numpy as np
4 | import os
5 | import pandas as pd
6 | import csv
7 |
8 | from PIL import Image
9 | from random import shuffle
10 |
11 | import re
12 | import time
13 | import math
14 | import sys
15 | import copy
16 | import shutil
17 | from functools import cmp_to_key
18 | from tqdm import tqdm
19 |
20 | splits = {}
21 |
22 | def preprocess_images(root_dir, cases):
23 | for case in cases:
24 | dirs = [x[0] for x in os.walk(root_dir + case)]
25 | for sub_dir in dirs:
26 | for root, _, files in os.walk(sub_dir):
27 | files = sorted(files)
28 | files = files[:-1]
29 | for i in range(0, len(files), 2):
30 | depth = Image.open(root+"/"+files[i])
31 | rgb = Image.open(root+"/"+files[i+1])
32 | rgb = rgb.resize((160,120))
33 | depth = depth.resize((160,120))
34 | depth.save(root+"/"+files[i])
35 | rgb.save(root+"/"+files[i+1])
36 |
37 |
38 | def parse_param_2(root_dir, mode, cases, tau, seq_len, split_percen, dest, directions=None):
39 | global splits
40 | file = open(dest+"/"+ mode + "_data.csv", "w+")
41 | writer = csv.writer(file, delimiter=',')
42 | # For every folder, tau, and direction
43 | for case, tau in tqdm(zip(cases, tau)):
44 | #print(case)
45 | #print(tau)
46 | # Create the splits
47 | if case not in splits.keys():
48 | dirs = [x[0] for x in os.walk(root_dir + case)][1:]
49 | shuffle(dirs)
50 | split_idx = int(math.ceil(len(dirs)*float(split_percen)))
51 | splits[case] = {"train": dirs[:split_idx], "test": dirs[split_idx:]}
52 | # Go into every subdirectory
53 | sub_dirs = splits[case]
54 | #dirs = [x[0] for x in os.walk(root_dir + case)][1:]
55 | for sub_dir in sub_dirs[mode]:
56 | for root, _, file in os.walk(sub_dir):
57 | file = sorted(file)
58 | pics = file[:-1]
59 | pics = sorted(pics, key=cmp_to_key(compare_names))
60 |             vectors = pd.read_csv(root+"/"+file[-1], header=None)
61 | # We will start from 10 to start creating a sequential dataset (rgb, depth)
62 | # The length of the history will be 5
63 | #for i in range((seq_len*2)-2, len(pics), 2):
64 | for i in range(0, len(pics)):
65 | # First get the current directory
66 | row = [root, int(pics[i][:-4]), tau[0], tau[1]]
67 | #row = [root, int(pics[i][:-10]), tau[0], tau[1]]
68 | #depth = [root+"/"+pics[i-j] for j in range(seq_len,-1,-2)]
69 | #rgb = [root+"/"+pics[i-j+1] for j in range(seq_len,-1,-2)]
70 | #tau = [tau for _ in range(seq_len)]
71 | #print(pics)
72 | #print(i)
73 | #print([pics[i-j] for j in range(seq_len*2, -1, -2)])
74 | if i == 0:
75 | prevs = [int(pics[i][:-4]) for _ in range(seq_len)]
76 | elif i != 0 and i < seq_len:
77 | prevs.pop(0)
78 | prevs.append(int(pics[i][:-4]))
79 | else:
80 | prevs = [int(pics[i-j][:-4]) for j in range(seq_len-1, -1, -1)]
81 | #prevs = [int(pics[i-j][:-10]) for j in range((seq_len*2)-2, -1, -2)]
82 | eof = []
83 | for prev in prevs:
84 | pos = [float(vectors[vectors[0]==prev][j]) for j in range(1,4)]
85 | eof += pos
86 | ## Label and Gripper still stay the same
87 | label = [float(vectors[vectors[0]==float(pics[i][:-4])][j]) for j in range(8,14)]
88 | aux_label = [float(vectors[vectors[0]==float(pics[-1][:-4])][j]) for j in range(1,8)]
89 | #label = [float(vectors[vectors[0]==float(pics[i][:-10])][j]) for j in range(8,14)]
90 | #aux_label = [float(vectors[vectors[0]==float(pics[-2][:-10])][j]) for j in range(1,8)]
91 | #print(label)
92 | row += prevs
93 | row += label
94 | row += aux_label
95 | row += eof
96 | #gripper = repr(int(vectors[vectors[0]==float(pics[i][:-8])][14]))
97 | ## This is how we will represent the data
98 | #writer.writerow([depth, rgb, eof, label, gripper, tau])
99 | writer.writerow(row)
100 | #print(prevs)
101 |     print("{} Data Creation Done".format(mode))
102 |
103 | def compare_names(name1, name2):
104 | num1 = extract_number(name1)
105 | num2 = extract_number(name2)
106 | if num1 == num2:
107 | if name1 > name2:
108 | return 1
109 | else:
110 | return -1
111 | else:
112 | return num1 - num2
113 |
114 | def extract_number(name):
115 |     numbers = re.findall(r'\d+', name)
116 | return int(numbers[0])
117 |
118 | def clean_data(root_dir, cases):
119 | for case in cases:
120 | # Create the splits
121 | dirs = [x[0] for x in os.walk(root_dir + case)][1:]
122 | # Go into every subdirectory
123 | for sub_dir in dirs:
124 | for root, _, file in os.walk(sub_dir):
125 | print(root)
126 | file = sorted(file)
127 | vector_file = root+"/"+file[-1]
128 |                 vectors = pd.read_csv(vector_file, header=None)
129 | delete_rows = []
130 | for i in range(1, len(file), 2):
131 | label = [round(float(vectors[vectors[0]==float(file[i][:-8])][j])*10) for j in range(8,14)]
132 | if sum(label) == 0:
133 | #print(float(file[i][:-8]))
134 | #print(root+"/"+file[i]) #rgb
135 | #print(root+"/"+file[i-1]) #depth
136 |                     delete_rows.append(file[i][:-8])
137 | os.remove(root+"/"+file[i])
138 | os.remove(root+"/"+file[i-1])
139 | if (len(file)-1)/2 != len(delete_rows):
140 | last = None
141 | last_num = 0
142 | clean_file = root+"/vector2.txt"
143 | with open(vector_file, "rb") as input, open(clean_file, "wb") as out:
144 | writer = csv.writer(out)
145 | for row in csv.reader(input):
146 | if row[0] not in delete_rows:
147 | writer.writerow(row)
148 | with open(clean_file, "r") as fd:
149 | last = [l for l in fd][-1]
150 | last = last.strip().split(',')
151 | last_num = int(last[0])
152 | last[8:14] = ['0.0' for _ in range(8,14)]
153 | counter = last_num + 1
154 | with open(clean_file, "a") as fd:
155 | for _ in range(10):
156 | shutil.copy(root+"/"+str(last_num) + "_depth.png", root+ "/" + str(counter) + "_depth.png")
157 | shutil.copy(root+"/"+str(last_num) + "_rgb.png", root+"/" + str(counter) + "_rgb.png")
158 | row = copy.deepcopy(last)
159 | row[0] = str(counter)
160 | row = ",".join(row) + "\n"
161 | counter += 1
162 | fd.write(row)
163 | os.remove(vector_file)
164 | if len(os.listdir(sub_dir)) == 0:
165 | os.rmdir(sub_dir)
166 |
167 | if __name__ == '__main__':
168 | tau = [[71,77],[91,77],[81,83],[70,89],[92,89]]
169 | cases = ['/goal_00', '/goal_02', '/goal_11', '/goal_20', '/goal_22']
170 | #tau = [[81,77], [71,83], [91,83], [81,89]]
171 | #cases = ['/goal_01', '/goal_10', '/goal_12', '/goal_21']
172 | modes = ["train"]
173 | seq_len = 10
174 | root_dir = sys.argv[1]
175 | split_percen = sys.argv[2]
176 | dest = sys.argv[3]
177 | #clean_data(root_dir, cases)
178 | #preprocess_images(root_dir, cases)
179 | datasets = {mode: parse_param_2(root_dir, mode, cases, tau, seq_len, split_percen, dest=dest) for mode in modes}
180 |
--------------------------------------------------------------------------------
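
A note on parse_param_2 above: for every frame it emits the seq_len most recent frame ids, padding with the first frame until enough history exists. The windowing logic in isolation, condensing the file's three branches into one sliding window (frame_history is our name):

    def frame_history(ids, seq_len):
        """For each index i, return the seq_len most recent ids, padded with ids[0]."""
        rows, prevs = [], []
        for i, cur in enumerate(ids):
            prevs = [cur] * seq_len if i == 0 else prevs[1:] + [cur]
            rows.append(list(prevs))
        return rows

    print(frame_history([0, 1, 2, 3], 3))
    # [[0, 0, 0], [0, 0, 1], [0, 1, 2], [1, 2, 3]]
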
/old_stack/util/record.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import cv2
3 | import time
4 | import csv
5 | import os
6 | import sys
7 | import rospy
8 | import itertools
9 | import numpy as np
10 | from cv_bridge import CvBridge
11 | from namedlist import namedlist
12 | from std_msgs.msg import Int64, String
13 | from sensor_msgs.msg import CompressedImage, Image, JointState
14 | from geometry_msgs.msg import Twist, Pose, TwistStamped, PoseStamped
15 | from imitate_msgs.msg import GripperStamped
16 |
17 | class ImitateRecorder():
18 | def __init__(self, task):
19 | rospy.init_node('imitate_recorder2', log_level=rospy.DEBUG)
20 | # Create the appropriate directory in the datas for the task we are training
21 | if not os.path.exists('../datas/' + task + '/'):
22 | os.mkdir('../datas/' + task + '/')
23 | self.save_folder = None # The specific folder
24 | self.writer = None # The writer to create our txt files
25 | self.text_file = None # The file that we are writing to currently
26 | self.is_recording = False # Toggling recording
27 | self.counter = 0 # The unique key for each datapoint we save
28 | self.bridge = CvBridge()
29 | # Initialize current values
30 | self.Data = namedlist('Data', ['pose', 'twist', 'grip', 'rgb', 'depth'])
31 | self.data = self.Data(pose=None, twist=None, grip=None, rgb=None, depth=None)
32 |
33 | def toggle_collection(self, toggle):
34 |         if toggle == "0":
35 | self.counter = 0
36 | self.is_recording = False
37 | self.unsubscribe()
38 | time.sleep(1)
39 |             if self.text_file is not None:
40 | self.text_file.close()
41 | self.text_file = None
42 | print("-----Stop Recording-----")
43 | else:
44 |             save_folder = '../datas/' + task + '/' + str(time.time()) + '/'  # relies on the module-level 'task' set in __main__
45 | os.mkdir(save_folder)
46 | self.save_folder = save_folder
47 | self.text_file = open(save_folder + 'vectors.txt', 'w')
48 | self.writer = csv.writer(self.text_file)
49 | self.is_recording = True
50 | print("=====Start Recording=====")
51 | self.collect_data()
52 |
53 | def collect_data(self):
54 | # Initialize Listeners
55 | self.init_listeners()
56 | rospy.Rate(5).sleep()
57 | # Define the rate at which we will collect data
58 | rate = rospy.Rate(15)
59 | while not rospy.is_shutdown():
60 | if None not in self.data:
61 | print("Data Collected!!")
62 | rgb_image = self.bridge.imgmsg_to_cv2(self.data.rgb, desired_encoding="passthrough")
63 | depth_image = self.bridge.imgmsg_to_cv2(self.data.depth, desired_encoding="passthrough")
64 | cv2.imwrite(self.save_folder + str(self.counter) + '_rgb.png', rgb_image)
65 | cv2.imwrite(self.save_folder + str(self.counter) + '_depth.png', depth_image)
66 | posit = self.data.pose.position
67 | orient = self.data.pose.orientation
68 | lin = self.data.twist.linear
69 | ang = self.data.twist.angular
70 | arr = [self.counter, posit.x, posit.y, posit.z, orient.w, orient.x, orient.y, orient.z, lin.x, lin.y, lin.z, ang.x, ang.y, ang.z, self.data.grip, time.time()]
71 | self.writer.writerow(arr)
72 | self.data = self.Data(pose=None, twist=None, grip=None, rgb=None, depth=None)
73 | self.counter += 1
74 | rate.sleep()
75 |
76 | def init_listeners(self):
77 | # The Topics we are Subscribing to for data
78 | self.right_arm_pose = rospy.Subscriber('/tf/right_arm_pose', PoseStamped, self.pose_callback)
79 | self.right_arm_vel = rospy.Subscriber('/tf/right_arm_vels', TwistStamped, self.vel_callback)
80 | self.rgb_state_sub = rospy.Subscriber('/kinect2/qhd/image_color_rect', Image, self.rgb_callback)
81 | self.depth_state_sub = rospy.Subscriber('/kinect2/qhd/image_depth_rect', Image, self.depth_callback)
82 | self.gripper_state_sub = rospy.Subscriber('/movo/right_gripper/gripper_is_open', GripperStamped, self.gripper_callback)
83 |
84 | def unsubscribe(self):
85 | self.right_arm_pose.unregister()
86 | self.right_arm_vel.unregister()
87 | self.rgb_state_sub.unregister()
88 | self.depth_state_sub.unregister()
89 | self.gripper_state_sub.unregister()
90 |
91 | def pose_callback(self, pose):
92 | if None in self.data:
93 | self.data.pose = pose.pose
94 |
95 | def vel_callback(self, twist):
96 | if None in self.data:
97 | self.data.twist = twist.twist
98 |
99 | def rgb_callback(self, rgb):
100 | if None in self.data:
101 | self.data.rgb = rgb
102 |
103 | def depth_callback(self, depth):
104 | if None in self.data:
105 | self.data.depth = depth
106 |
107 | def gripper_callback(self, gripper):
108 | if None in self.data:
109 | self.data.grip = gripper.data.data
110 |
111 | if __name__ == '__main__':
112 | task = sys.argv[1]
113 | recorder = ImitateRecorder(task)
114 | recorder.toggle_collection("1")
115 | recorder.toggle_collection("0")
116 |
--------------------------------------------------------------------------------
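
A note on record.py above: each CSV row is only written once every field of the Data namedlist is fresh (None not in self.data), after which all fields reset to None, so a row always pairs pose, twist, gripper state, and both images from the same tick. The gating pattern without ROS, using the same namedlist package:

    from namedlist import namedlist

    Data = namedlist('Data', ['pose', 'twist', 'grip', 'rgb', 'depth'])
    data = Data(pose=None, twist=None, grip=None, rgb=None, depth=None)

    def try_write(data):
        if None in data:   # namedlist supports list-style membership tests
            return False   # at least one stream has not reported yet
        # ... write the synchronized row here, then reset the fields ...
        return True

    data.pose, data.twist, data.grip, data.rgb, data.depth = 1, 2, 3, 4, 5
    assert try_write(data)
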
/old_stack/util/toggler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import sys
3 | import rospy
4 | import subprocess
5 | from std_msgs.msg import String
6 |
7 | p = None
8 | is_recording = False
9 | task = sys.argv[1]
10 |
11 | def toggle(value):
12 | global p
13 | global is_recording
14 | global task
15 |
16 |     if value.data == "0":
17 |         if is_recording:
18 |             p.terminate()
19 |             is_recording = False
20 |     elif value.data == "1":
21 | if not is_recording:
22 | p = subprocess.Popen(["./record.py", task])
23 | is_recording = True
24 |
25 | rospy.init_node('imitate_toggler', log_level=rospy.DEBUG)
26 | toggle_sub = rospy.Subscriber('/unity_learning_record', String, toggle)
27 | rospy.spin()
28 |
--------------------------------------------------------------------------------
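
A note on toggler.py above: recording is driven entirely by "0"/"1" strings on /unity_learning_record. A minimal sketch of a publisher that exercises it (topic and message type come from the file; the node name and timings are ours):

    #!/usr/bin/python
    import rospy
    from std_msgs.msg import String

    rospy.init_node('imitate_toggle_pub')
    pub = rospy.Publisher('/unity_learning_record', String, queue_size=1)
    rospy.sleep(1.0)          # give the subscriber time to connect
    pub.publish(String('1'))  # start recording
    rospy.sleep(5.0)
    pub.publish(String('0'))  # stop recording
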
/sim_eval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import os
3 | import sys
4 | import csv
5 | import time
6 | import pygame
7 | import argparse
8 | from pygame.locals import *
9 | from PIL import Image
10 | import numpy as np
11 | import itertools
12 | from src.model import Model
13 | import torch
14 | from simulation.sim import get_start
15 |
16 | # note we want tau to be col, row == x, y
17 | # This is like the get goal method
18 |
19 | def get_tau_():
20 | button = input('Please enter a button for the robot to try to press (e.g. "00", "12"): ')
21 | return [int(b) for b in button]
22 |
23 |
24 | def distance(a, b):
25 | dist = [np.abs(a[i] - b[i]) for i in range(len(a))]
26 | dist = [e**2 for e in dist]
27 | dist = sum(dist)
28 | return np.sqrt(dist)
29 |
30 |
31 | def get_tau(goal_x, goal_y, options):
32 | return options[goal_x][goal_y]
33 |
34 |
35 | def process_images(np_array_img, is_it_rgb):
36 | img = 2*((np_array_img - np.amin(np_array_img))/(np.amax(np_array_img)-np.amin(np_array_img))) - 1
37 | img = torch.from_numpy(img).type(torch.FloatTensor).squeeze()
38 | if(is_it_rgb):
39 | img = img.permute(2, 0, 1).unsqueeze(0)
40 | else:
41 | img = img.view(1,1,img.shape[0], img.shape[1])
42 | return img
43 |
44 | def sim(model, config):
45 | goals_x = list(range(3))
46 | goals_y = list(range(3))
47 |
48 | goal_pos = [[(200, 150), (200, 300), (200, 450)],
49 | [(400, 150), (400, 300), (400, 450)],
50 | [(600, 150), (600, 300), (600, 450)]]
51 | #goal_pos = [[(int(np.random.rand()*720)+40, int(np.random.rand()*520)+40) for j in goals_y] for i in goals_x]
52 | #goal_pos = [[(int((i+1)*800/(len(goals_x)+1)) - 100, int((j+1)*600/(len(goals_y)+1)) + 35) for j in goals_y] for i in goals_x]
53 |
54 | tau_opts = [[(0, 0), (0, 1), (0, 2)],
55 | [(1, 0), (1, 1), (1, 2)],
56 | [(2, 0), (2, 1), (2, 2)]]
57 |
58 | tau_opts = goal_pos
59 |
60 |
61 | # These are magic numbers
62 | RECT_X = 60
63 | RECT_Y = 60
64 | SPACEBAR_KEY = 32 # pygame logic
65 | S_KEY = pygame.K_s
66 | R_KEY = pygame.K_r
67 | ESCAPE_KEY = pygame.K_ESCAPE
68 |
69 | pygame.init()
70 |
71 | #Window
72 | screen = pygame.display.set_mode((800, 600))
73 | pygame.display.set_caption("2D Simulation")
74 | pygame.mouse.set_visible(1)
75 | # Set the cursor
76 | curr_pos = get_start()
77 |
78 | #Background
79 | background = pygame.Surface(screen.get_size())
80 | background = background.convert()
81 | background.fill((211, 211, 211))
82 | screen.blit(background, (0, 0))
83 | pygame.display.flip()
84 |
85 | clock = pygame.time.Clock()
86 |
87 | run = True
88 | eof = None
89 | rgb = None
90 | depth = None
91 |
92 | gx, gy = get_tau_()#np.random.randint(0, 3, (2,))
93 | #tau_opts = np.random.randint(0, 255, (3,3,3)) if config.color else goal_pos
94 | tau = get_tau(gx, gy, goal_pos)
95 | a = 0
96 | while run:
97 | a += 1
98 | clock.tick(config.framerate)
99 | for event in pygame.event.get():
100 | if event.type == pygame.QUIT:
101 | run = False
102 | break
103 | if event.type == pygame.KEYUP:
104 | if event.key == S_KEY:
105 | curr_pos = get_start()
106 | if event.key == R_KEY:
107 | gx, gy = get_tau_()#np.random.randint(0, 3, (2,))
108 | #tau_opts = np.random.randint(0, 255, (3,3,3)) if config.color else goal_pos
109 | tau = (gx, gy)#get_tau(gx, gy, tau_opts)
110 | if event.key == ESCAPE_KEY:
111 | run = False
112 | break
113 |
114 | vanilla_rgb_string = pygame.image.tostring(screen,"RGBA",False)
115 | vanilla_rgb_pil = Image.frombytes("RGBA",(800,600),vanilla_rgb_string)
116 | resized_rgb = vanilla_rgb_pil.resize((160,120))
117 | rgb = np.array(resized_rgb)[:,:,:3]
118 | rgb = process_images(rgb, True)
119 | if torch.any(torch.isnan(rgb)):
120 | rgb.zero_()
121 |
122 | vanilla_depth = Image.fromarray(np.uint8(np.zeros((120,160))))
123 | depth = process_images(vanilla_depth, False).zero_()
124 |
125 | div = [200, 175] if config.normalize else [1, 1]
126 | sub = [400, 325] if config.normalize else [0, 0]
127 | norm_pos = [(curr_pos[0] - sub[0]) / div[0], (curr_pos[1] - sub[1]) / div[1]]
128 | if eof is None:
129 | eof = torch.FloatTensor([norm_pos[0], norm_pos[1], 0.0] * 5)
130 | else:
131 | eof = torch.cat([torch.FloatTensor([norm_pos[0], norm_pos[1], 0.0]), eof[0:12]])
132 |
133 | # Calculate the trajectory
134 | in_tau = torch.FloatTensor(tau)
135 | #in_tau = torch.zeros(1)
136 | #in_tau[0] = tau[0]*3 + tau[1]
137 | out, aux = model(rgb, depth, eof.view(1, -1), in_tau.view(1, -1).to(eof))
138 | out = out.squeeze()
139 | delta_x = out[0].item()
140 | delta_y = out[1].item()
141 | new_pos = [curr_pos[0] + delta_x, curr_pos[1] + delta_y]
142 | print(eof)
143 | print(tau)
144 | print(aux)
145 | #print(get_tau(gx, gy, goal_pos))
146 | print(out)
147 | print(new_pos)
148 | print(distance(curr_pos, new_pos))
149 | #if (distance(curr_pos, new_pos)) < 1.5:
150 | # time.sleep(5)
151 | print('========')
152 | curr_pos = new_pos
153 |
154 | screen.fill((211,211,211))
155 | for x, y in list(itertools.product(goals_x, goals_y)):
156 |         color = tau_opts[x][y] if config.color else (0, 0, 255)
157 | #if x == gx and y == gy:
158 | # continue
159 | pygame.draw.rect(screen, color, pygame.Rect(goal_pos[x][y][0]-RECT_X/2, goal_pos[x][y][1]-RECT_Y/2, RECT_X, RECT_Y))
160 | pygame.draw.circle(screen, (0,0,0), [int(v) for v in curr_pos], 20, 0)
161 | pygame.display.update()
162 |
163 | #if a == 200:
164 | # break
165 | pygame.quit()
166 | return 0
167 |
168 | if __name__ == '__main__':
169 | parser = argparse.ArgumentParser(description='Input to 2d simulation.')
170 | parser.add_argument('-w', '--weights', required=True, help='The path to the weights to load.')
171 | parser.add_argument('-c', '--color', dest='color', default=False, action='store_true', help='Used to activate color simulation.')
172 | parser.add_argument('-no', '--normalize', dest='normalize', default=False, action='store_true', help='Used to activate position normalization.')
173 | parser.add_argument('-f', '--framerate', default=300, type=int, help='Framerate of simulation.')
174 | args = parser.parse_args()
175 |
176 | checkpoint = torch.load(args.weights, map_location='cpu')
177 | model = Model(**checkpoint['kwargs'])
178 | model.load_state_dict(checkpoint['model_state_dict'])
179 | model.eval()
180 |
181 | sim(model, args)
182 |
--------------------------------------------------------------------------------
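
A note on process_images above: frames are min-max scaled to [-1, 1], and a constant frame makes the denominator zero, which is why the caller zeroes NaN tensors afterwards. A sketch that folds the guard into the normalization itself (to_unit_range and eps are our names):

    import numpy as np
    import torch

    def to_unit_range(img, eps=1e-6):
        """Min-max scale to [-1, 1], degrading to zeros for constant images."""
        lo, hi = np.amin(img), np.amax(img)
        if hi - lo < eps:
            return torch.zeros(img.shape, dtype=torch.float32)
        return torch.from_numpy(2 * (img - lo) / (hi - lo) - 1).float()

    print(to_unit_range(np.zeros((2, 2))))        # all zeros, no NaNs
    print(to_unit_range(np.array([[0., 255.]])))  # [[-1., 1.]]
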
/sim_eval0.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | import os
3 | import sys
4 | import csv
5 | import time
6 | import pygame
7 | import argparse
8 | from pygame.locals import *
9 | from PIL import Image
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import itertools
13 | from src.model import Model
14 | from src.model0 import Model0
15 | from src.loss_func import fix_rot
16 | import torch
17 | from simulation.sim import get_start, get_tau, RECT_X, RECT_Y, goals_x, goals_y, goal_pos
18 |
19 | # note we want tau to be col, row == x, y
20 | # This is like the get goal method
21 | def get_tau_():
22 | button = input('Please enter a button for the robot to try to press (e.g. "00", "12"): ')
23 | return [int(b) for b in button]
24 |
25 |
26 | def distance(a, b):
27 | dist = [np.abs(a[i] - b[i]) for i in range(len(a))]
28 | dist = [e**2 for e in dist]
29 | dist = sum(dist)
30 | return np.sqrt(dist)
31 |
32 |
33 | def process_images(np_array_img, is_it_rgb):
34 | #try:
35 | img = 2*((np_array_img - np.amin(np_array_img))/(np.amax(np_array_img)-np.amin(np_array_img))) - 1
36 | #except:
37 | # img = np.zeros_like(np_array_img)
38 | img = torch.from_numpy(img).type(torch.FloatTensor).squeeze()
39 | if(is_it_rgb):
40 | img = img.permute(2, 0, 1)
41 | else:
42 | img = img.view(1, img.shape[0], img.shape[1])
43 | return img.unsqueeze(0)
44 |
45 | def sim(model, config):
46 |
47 | tau_opts = [[(0, 0), (0, 1), (0, 2)],
48 | [(1, 0), (1, 1), (1, 2)],
49 | [(2, 0), (2, 1), (2, 2)]]
50 |
51 |     #tau_opts = [[(1,0), (2,0), (3,0)],
52 |
53 | # [(3,1), (1,1), (2,1)],
54 | # [(2,2), (3,2), (1,2)]]
55 |
56 |
57 | # These are magic numbers
58 | SPACEBAR_KEY = 32 # pygame logic
59 | S_KEY = pygame.K_s
60 | R_KEY = pygame.K_r
61 | ESCAPE_KEY = pygame.K_ESCAPE
62 |
63 | pygame.init()
64 |
65 | #Window
66 | screen = pygame.display.set_mode((800, 600))
67 | pygame.display.set_caption("2D Simulation")
68 | pygame.mouse.set_visible(1)
69 | # Set the cursor
70 | curr_pos = get_start()
71 |
72 | #Background
73 | background = pygame.Surface(screen.get_size())
74 | background = background.convert()
75 | background.fill((211, 211, 211))
76 | screen.blit(background, (0, 0))
77 | pygame.display.flip()
78 |
79 | clock = pygame.time.Clock()
80 |
81 | run = True
82 | eof = None
83 | rgb = None
84 | depth = None
85 |
86 | gx, gy = get_tau_()#np.random.randint(0, 3, (2,))
87 | #tau_opts = np.random.randint(0, 255, (3,3,3)) if config.color else goal_pos
88 | tau = get_tau(gx, gy, tau_opts)#(gx, gy)
89 | if args.rotation:
90 | # rect_rot = np.ones(9) * np.random.randint(0,360)
91 | rect_rot = np.random.randint(0,360, (9,))
92 | else:
93 | # rect_rot = np.ones(9) * np.random.randint(0,360)
94 | # rect_rot = np.ones(9) * 35
95 | rect_rot = np.random.randint(0,360, (9,))
96 |
97 | x_offset = np.random.randint(0, 200)
98 | y_offset = np.random.randint(0, 200)
99 |
100 | # This is the values used for the only_rot experiments
101 | # x_offset = 150
102 | # y_offset = 150
103 |
104 |
105 | _gx, _gy = get_tau(gx, gy, goal_pos)
106 | _gx += x_offset
107 | _gy += y_offset
108 | a = 0
109 | while run:
110 | a += 1
111 | clock.tick(config.framerate)
112 | for event in pygame.event.get():
113 | if event.type == pygame.QUIT:
114 | run = False
115 | break
116 | if event.type == pygame.KEYUP:
117 | if event.key == S_KEY:
118 | curr_pos = get_start()
119 | #curr_pos[2] = np.random.randint(0, 360)
120 | if event.key == R_KEY:
121 | x_offset = np.random.randint(0, 200)
122 | y_offset = np.random.randint(0, 200)
123 | rect_rot = np.ones(9) * np.random.randint(0,360)
124 | #rect_rot = np.random.randint(0,360, (9,))
125 |
126 | # gx, gy = get_tau_()#np.random.randint(0, 3, (2,))
127 | #tau_opts = np.random.randint(0, 255, (3,3,3)) if config.color else goal_pos
128 | tau = (gx, gy)#get_tau(gx, gy, tau_opts)
129 | if event.key == ESCAPE_KEY:
130 | run = False
131 | break
132 |
133 |
134 |
135 |
136 | screen.fill((211,211,211))
137 |
138 | surf2 = pygame.Surface((RECT_Y-10, RECT_Y-10), pygame.SRCALPHA)
139 | surf2.fill((0, 255, 0))
140 |
141 | #for obstacle in obstacles:
142 | # pygame.draw.rect(screen, (255, 0, 0), pygame.Rect(*obstacle))
143 | for x, y in list(itertools.product(goals_x, goals_y)):
144 |             color = tau_opts[x][y] if config.color else (0, 0, 255)
145 | #surf = pygame.Surface((RECT_X, RECT_Y), pygame.SRCALPHA)
146 | #surf.fill(color)
147 | #surf = pygame.transform.rotate(surf, rect_rot[3*x+y])
148 | #surf.convert()
149 | #screen.blit(surf, (goal_pos[x][y][0] + x_offset, goal_pos[x][y][1] + y_offset))
150 |
151 |             color = tau_opts[x][y] if config.color else (0, 0, 255)
152 | #surf = pygame.Surface()
153 | surf = pygame.Surface((RECT_X, RECT_Y), pygame.SRCALPHA)
154 | surf.fill(color)
155 | surf.blit(surf2, (5, 5))
156 | surf = pygame.transform.rotate(surf, rect_rot[3*x+y])
157 | print(rect_rot[3*x+y])
158 | surf.convert()
159 | screen.blit(surf, (goal_pos[x][y][0] + x_offset, goal_pos[x][y][1] + y_offset))
160 |
161 | # THIS IS THE OLD CODE FOR MULTICOLOR SQUARES
162 | # pygame.draw.rect(screen, color, pygame.Rect(goal_pos[x][y][0]-RECT_X/2 + x_offset, goal_pos[x][y][1]-RECT_Y/2 + y_offset, RECT_X, RECT_Y))
163 | # pygame.draw.rect(screen, (255, 0, 0), pygame.Rect(goal_pos[x][y][0]-RECT_X/4 + x_offset, goal_pos[x][y][1]-RECT_Y/4 + y_offset, RECT_X/2, RECT_Y/2))
164 | # pygame.draw.rect(screen, (0, 255, 0), pygame.Rect(goal_pos[x][y][0]-RECT_X/8 + x_offset, goal_pos[x][y][1]-RECT_Y/8 + y_offset, RECT_X/4, RECT_Y/4))
165 |
166 |
167 | #surf = pygame.Surface((RECT_X, RECT_Y), pygame.SRCALPHA)
168 | #surf.fill((0,0,0))
169 | #surf = pygame.transform.rotate(surf, int(curr_pos[2]))
170 | #surf = pygame.transform.rotate(surf, 90)
171 | #surf.convert()
172 | #screen.blit(surf, curr_pos[:2])
173 |
174 | surf = pygame.Surface((RECT_X, RECT_Y), pygame.SRCALPHA)
175 | surf.fill((0,0,0))
176 | surf2.fill((255, 0, 0))
177 | surf.blit(surf2, (5, 5))
178 | surf = pygame.transform.rotate(surf, int(curr_pos[2]))
179 | surf.convert()
180 | screen.blit(surf, curr_pos[:2])
181 |
182 | # OLD CIRCULAR AGENT
183 | # pygame.draw.circle(screen, (0,0,0), [int(v) for v in curr_pos[:2]], 20, 0)
184 | pygame.display.update()
185 |
186 |
187 |
188 |
189 | vanilla_rgb_string = pygame.image.tostring(screen,"RGBA",False)
190 | vanilla_rgb_pil = Image.frombytes("RGBA",(800,600),vanilla_rgb_string)
191 | resized_rgb = vanilla_rgb_pil.resize((160,120))
192 | rgb = np.array(resized_rgb)[:,:,:3]
193 | rgb = process_images(rgb, True)
194 | if torch.any(torch.isnan(rgb)):
195 | rgb.zero_()
196 |
197 | vanilla_depth = Image.fromarray(np.uint8(np.zeros((120,160))))
198 | depth = process_images(vanilla_depth, False).zero_()
199 |
200 | div = [400, 300, 180] if config.normalize else [1, 1, 1]
201 | sub = [400, 300, 180] if config.normalize else [0, 0, 0]
202 | norm_pos = [(curr_pos[i] - sub[i]) / div[i] for i in range(3)]
203 | norm_pos = norm_pos[:2] + [np.sin(np.pi * norm_pos[2]), np.cos(np.pi * norm_pos[2])]
204 | if eof is None:
205 | eof = torch.FloatTensor(norm_pos * 5)
206 | else:
207 | eof = torch.cat([torch.FloatTensor(norm_pos), eof[0:16] * 0])
208 |
209 | # Calculate the trajectory
210 | in_tau = torch.FloatTensor(tau)
211 | print_loc = 'eval_print/' + str(a)
212 | out = None
213 | aux = None
214 | with torch.no_grad():
215 | out, aux = model(rgb, depth, eof.view(1, -1), in_tau.view(1, -1).to(eof), b_print=config.print, print_path=print_loc)#, aux_in = torch.rand(1,4))
216 | out = out.squeeze()
217 | delta_x = out[0].item()
218 | delta_y = out[1].item()
219 | delta_rot = out[2].item()
220 |
221 | # This block is a relic from when the network output sin and cos values
222 | # sin_cos = out[2:4]
223 | # mag = torch.sqrt(sin_cos[0]**2 + sin_cos[1]**2)
224 | # sin_cos = sin_cos / mag
225 | # delta_rot = (torch.atan2(sin_cos[0], sin_cos[1]).item() / 3.14159) * 180
226 | new_pos = [curr_pos[0] + delta_x, curr_pos[1] + delta_y, (curr_pos[2] + delta_rot) % 360]
227 | # sin_cos = aux.squeeze()[2:4]
228 | # mag = torch.sqrt(sin_cos[0]**2 + sin_cos[1]**2)
229 | # sin_cos = sin_cos / mag
230 | # aux_rot = (torch.atan2(sin_cos[0] , sin_cos[1]).view(-1, 1) / 3.14159) - 1
231 |
232 | print(eof.numpy())
233 | print(out.numpy())
234 | print([-1*np.sin(new_pos[2] / 180 * np.pi), -1*np.cos(new_pos[2] / 180 * np.pi)])
235 | print([-1*np.sin((rect_rot[3*gx+gy] / 180) * np.pi), -1*np.cos((rect_rot[3*gx+gy] / 180) * np.pi)])
236 | print((goal_pos[gx][gy][0] + x_offset - 400) / 400, (goal_pos[gx][gy][1] + y_offset - 300) / 300)
237 | print(aux.numpy())
238 | # print([-1*np.sin(rect_rot[3*gx+gy] / 180 * 3.14159), -1*np.cos(rect_rot[3*gx+gy] / 180 * 3.14159)])
239 | # print(fix_rot(torch.FloatTensor([rect_rot[3*gx+gy] / 180 - 1]).view(-1, 1), aux_rot))
240 | # print(fix_rot(torch.FloatTensor([rect_rot[3*gx+gy] / 180 - 1]).view(-1, 1), torch.FloatTensor([new_pos[2]/180 - 1]).view(-1, 1)))
241 | # print(fix_rot(torch.FloatTensor([new_pos[2]/180 - 1]).view(-1, 1), aux_rot))
242 | # print(eof)
243 | # print(tau)
244 | # print(aux)
245 | # print((_gx, _gy))
246 | # print(out)
247 | # print(new_pos)
248 | # print(distance(curr_pos, new_pos))
249 | #if (distance(curr_pos, new_pos)) < 1.5:
250 | # time.sleep(5)
251 | print('========')
252 | curr_pos = new_pos
253 |
254 | #if a == 200:
255 | # break
256 | pygame.quit()
257 | return 0
258 |
259 | if __name__ == '__main__':
260 | parser = argparse.ArgumentParser(description='Input to 2d simulation.')
261 | parser.add_argument('-w', '--weights', required=True, help='The path to the weights to load.')
262 | parser.add_argument('-c', '--color', dest='color', default=False, action='store_true', help='Used to activate color simulation.')
263 | parser.add_argument('-no', '--normalize', dest='normalize', default=False, action='store_true', help='Used to activate position normalization.')
264 | parser.add_argument('-f', '--framerate', default=300, type=int, help='Framerate of simulation.')
265 | parser.add_argument('-r', '--rotation', default=True, dest='rotation', action='store_false', help='Used to eval rotation.')
266 | parser.add_argument('-p', '--print', default=False, dest='print', action='store_true', help='Flag to print activations.')
267 | parser.add_argument('-att', '--attention', default=False, dest='attention', action='store_true', help='Flag indicating to use attention')
268 | args = parser.parse_args()
269 |
270 | checkpoint = torch.load(args.weights, map_location='cpu')
271 | if args.attention:
272 | model = Model(**checkpoint['kwargs'])
273 | else:
274 | model = Model0(**checkpoint['kwargs'])
275 | model.load_state_dict(checkpoint['model_state_dict'])
276 | model.eval()
277 |
278 | sim(model, args)
279 |
--------------------------------------------------------------------------------
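
A note on the eof construction above: sim_eval0.py normalizes (x, y, theta) and then replaces the angle with (sin(pi*t), cos(pi*t)) so the network never sees the 0/360 degree wrap. The encoding in isolation, with the same div/sub constants as the file (encode_state is our name):

    import numpy as np

    def encode_state(pos, div=(400, 300, 180), sub=(400, 300, 180)):
        """Normalize (x, y, theta_deg) and expand theta into sin/cos."""
        n = [(pos[i] - sub[i]) / div[i] for i in range(3)]
        return n[:2] + [np.sin(np.pi * n[2]), np.cos(np.pi * n[2])]

    print(encode_state([400, 300, 0]))    # [0.0, 0.0, ~0.0, -1.0]
    print(encode_state([400, 300, 360]))  # same angle features: 0 and 360 coincide
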
/src/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 | import lmdb
5 | from multiprocessing import Pool
6 |
7 | import os.path as osp
8 | import pyarrow as pa
9 | import random
10 |
11 | class ImitationLMDB(Dataset):
12 | def __init__(self, dest, mode):
13 | super(ImitationLMDB, self).__init__()
14 | lmdb_file = osp.join(dest, mode+".lmdb")
15 | # Open the LMDB file
16 | self.env = lmdb.open(lmdb_file, subdir=osp.isdir(lmdb_file),
17 | readonly=True, lock=False,
18 | readahead=False, meminit=False)
19 |
20 | with self.env.begin(write=False) as txn:
21 | self.length = self.loads_pyarrow(txn.get(b'__len__'))
22 | self.keys = self.loads_pyarrow(txn.get(b'__keys__'))
23 |
24 | self.shuffled = [i for i in range(self.length)]
25 | #random.shuffle(self.shuffled)
26 |
27 | def loads_pyarrow(self, buf):
28 | return pa.deserialize(buf)
29 |
30 | def __getitem__(self, idx):
31 | rgb, depth, eof, tau, aux, target = None, None, None, None, None, None
32 | index = self.shuffled[idx]
33 | env = self.env
34 | with env.begin(write=False) as txn:
35 | byteflow = txn.get(self.keys[index])
36 |
37 | # RGB, Depth, EOF, Tau, Aux, Target
38 | unpacked = self.loads_pyarrow(byteflow)
39 |
40 | # load data
41 | rgb = torch.from_numpy(unpacked[0]).type(torch.FloatTensor)
42 | depth = torch.from_numpy(unpacked[1]).type(torch.FloatTensor)
43 | eof = torch.from_numpy(unpacked[2]).type(torch.FloatTensor)
44 | tau = torch.from_numpy(unpacked[3]).type(torch.FloatTensor)
45 | aux = torch.from_numpy(unpacked[4]).type(torch.FloatTensor)
46 | target = torch.from_numpy(unpacked[5]).type(torch.FloatTensor)
47 |
48 | return [rgb, depth, eof, tau, target, aux]
49 |
50 | def __len__(self):
51 | return self.length
52 |
53 | def close(self):
54 | self.env.close()
55 |
--------------------------------------------------------------------------------
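
A minimal sketch of consuming ImitationLMDB with a DataLoader; the path is a placeholder, and the unpacking follows __getitem__'s [rgb, depth, eof, tau, target, aux] ordering:

    import torch
    from src.datasets import ImitationLMDB

    dataset = ImitationLMDB('/path/to/lmdb_dir', 'train')  # expects train.lmdb inside
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

    for rgb, depth, eof, tau, target, aux in loader:
        print(rgb.shape, target.shape)  # inspect one batch
        break

    dataset.close()
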
/src/loss_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from math import pi
4 |
5 | class LossException(Exception):
6 | def __init__(self, *args, **kwargs):
7 | super(LossException, self).__init__(*args, **kwargs)
8 |
9 |
10 | def fix_rot(goal_rot, cur_rot):
11 | true_rot = torch.abs(torch.remainder(goal_rot, 2) - torch.remainder(cur_rot, 2))
12 | mask = true_rot > 1
13 | true_rot[mask] = 2 - true_rot[mask]
14 |
15 | return true_rot
16 |
17 | class BehaviorCloneLoss(nn.Module):
18 | """
19 | Behavior Clone Loss
20 | """
21 | def __init__(self, lamb_l2=0.01, lamb_l1=1.0, lamb_c=0.005, lamb_aux=0.0001, eps=1e-7):
22 | super(BehaviorCloneLoss, self).__init__()
23 | self.lamb_l2 = lamb_l2
24 | self.lamb_l1 = lamb_l1
25 | self.lamb_c = lamb_c
26 | self.lamb_aux = lamb_aux
27 | self.l2 = nn.MSELoss()
28 | self.l1 = nn.L1Loss()
29 | self.aux = nn.MSELoss()
30 |
31 | self.eps = eps
32 |
33 | def forward(self, out, aux_out, target, aux_target, flag=False):
34 | if torch.any(torch.isnan(out)):
35 | print(out)
36 | raise LossException('nan in model outputs!')
37 |
38 | '''
39 | x = out[:, 0:1] * torch.sin(out[:, 1:2]) * torch.cos(out[:, 2:3])
40 | y = out[:, 0:1] * torch.sin(out[:, 1:2]) * torch.sin(out[:, 2:3])
41 | z = out[:, 0:1] * torch.cos(out[:, 1:2])
42 |
43 | out = torch.cat([x, y, z, out[:, 3:]], dim=1)
44 | '''
45 | l2_loss = self.l2(out[:, :2], target[:, :2]) * 2 / 3 + self.l2(out[:, 2], target[:, 2]) / 3
46 | l1_loss = self.l1(out, target)
47 |
48 | # For the arccos loss
49 | bs, n = out.shape
50 | num = torch.bmm(target.view(bs,1,n), out.view(bs,n,1)).squeeze()
51 | den = torch.norm(target,p=2,dim=1) * torch.norm(out,p=2,dim=1) + self.eps
52 | div = num / den
53 | a_cos = torch.acos(torch.clamp(div, -1 + self.eps, 1 - self.eps))
54 | c_loss = torch.mean(a_cos)
55 | # For the aux loss
56 | aux_loss = self.aux(aux_out, aux_target)
57 |
58 | weighted_loss = self.lamb_l2*l2_loss + self.lamb_l1*l1_loss + self.lamb_c*c_loss + self.lamb_aux*aux_loss
59 |
60 | if flag:#torch.isnan(weighted_loss):
61 | #print(out)
62 | print('===============')
63 | print('===============')
64 | #print(target)
65 |
66 | print(' ')
67 | print(' ')
68 | print(' ')
69 |
70 | print('weighted loss: %.2f' % weighted_loss)
71 | print('l2 loss: %.2f' % l2_loss)
72 | print('l1 loss: %.2f' % l1_loss)
73 | print('c loss: %.2f' % c_loss)
74 | print('aux loss: %.2f' % aux_loss)
75 |
76 | print(' ')
77 |
78 | print('L2 x: %.2f' % self.l2(out[:, 0], target[:, 0]))
79 | print('L2 y: %.2f' % self.l2(out[:, 1], target[:, 1]))
80 | print('L2 theta: %.2f' % self.l2(out[:, 2], target[:, 2]))
81 |
82 | if torch.isnan(c_loss):
83 | print('num: %s' % str(num))
84 | print('den: %s' % str(den))
85 | print('div: %s' % str(div))
86 | print('acos: %s' % str(a_cos))
87 |
88 | #raise LossException('Loss is nan!')
89 |
90 | return weighted_loss
91 |
--------------------------------------------------------------------------------
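
A note on fix_rot above: it works in half-turn units (1.0 == 180 degrees), reduces both angles mod 2, and folds any difference over a half turn back, so the distance between 350 and 10 degrees comes out as 20 degrees rather than 340. A worked check:

    import torch
    from src.loss_func import fix_rot

    goal = torch.tensor([[350.0 / 180]])  # 350 degrees in half-turn units
    cur = torch.tensor([[10.0 / 180]])    # 10 degrees
    print(fix_rot(goal, cur) * 180)       # tensor([[20.]]), wrap-aware difference
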
/superval_3dval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | import cv2
3 | import time
4 | import csv
5 | import os
6 | import sys
7 | import rospy
8 | import itertools
9 | import numpy as np
10 | #from tf.transformations import euler_from_quaternion
11 | from cv_bridge import CvBridge
12 | from namedlist import namedlist
13 | from std_msgs.msg import Int64, String
14 | from sensor_msgs.msg import CompressedImage, Image, JointState
15 | from geometry_msgs.msg import Twist, Pose, TwistStamped, PoseStamped, Vector3
16 | import torch
17 | from model import Model
18 | import argparse
19 | import copy
20 | from matplotlib import pyplot as plt
21 | from glob import glob
22 |
23 |
24 | class ImitateEval:
25 | def __init__(self, weights):
26 | self.bridge = CvBridge()
27 | self.Data = namedlist('Data', ['pose', 'rgb', 'depth'])
28 | self.data = self.Data(pose=None, rgb=None, depth=None)
29 | self.is_start = True
30 |
31 | checkpoint = torch.load(weights, map_location="cpu")
32 | self.model = Model(**checkpoint['kwargs'])
33 | self.model.load_state_dict(checkpoint["model_state_dict"])
34 | self.model.eval()
35 |
36 | def change_start(self):
37 | radius = 0.07
38 | # Publisher for the movement and the starting pose
39 | self.movement_publisher = rospy.Publisher('/iiwa/CollisionAwareMotion', Pose, queue_size=10)
40 | self.target_start = Pose()
41 | self.target_start.position.x = -0.15 + np.random.rand()*2*radius - radius # -0.10757
42 | self.target_start.position.y = 0.455 + np.random.rand()*2*radius - radius # 0.4103
43 | self.target_start.position.z = 1.015
44 | self.target_start.orientation.x = 0.0
45 | self.target_start.orientation.y = 0.0
46 | self.target_start.orientation.z = 0.7071068
47 | self.target_start.orientation.w = 0.7071068
48 |
49 |
50 | def move_to_button(self, tau, tolerance):
51 | self.init_listeners()
52 | rospy.Rate(5).sleep()
53 | rate = rospy.Rate(15)
54 |
55 | pose_to_move = copy.deepcopy(self.target_start)
56 | eof = []
57 | while not rospy.is_shutdown():
58 | if None not in self.data:
59 | # Position from the CartesianPose Topic!!
60 | pos = self.data.pose.position
61 | pos = [pos.x, pos.y, pos.z]
62 | if self.is_start:
63 | for _ in range(5):
64 | eof += pos
65 | self.is_start = False
66 | else:
67 | eof = pos + eof[:-3]
68 |
69 | eof_input = torch.from_numpy(np.array(eof)).type(torch.FloatTensor)
70 | eof_input = eof_input.unsqueeze(0)
71 |
72 |
73 |
74 | rgb = self.process_images(self.data.rgb, True)
75 | depth = self.process_images(self.data.depth, False)
76 |
77 | # print("RGB min: {}, RGB max: {}".format(np.amin(rgb), np.amax(rgb)))
78 | # print("Depth min: {}, Depth max: {}".format(np.amin(depth), np.amax(depth)))
79 | print("EOF: {}".format(eof_input))
80 | print("Tau: {}".format(tau))
81 |
82 | torch.save(rgb, "/home/amazon/Desktop/rgb_tensor.pt")
83 | torch.save(depth, "/home/amazon/Desktop/depth_tensor.pt")
84 | torch.save(eof, "/home/amazon/Desktop/eof_tensor.pt")
85 | torch.save(tau, "/home/amazon/Desktop/tau.pt")
86 |
87 | with torch.no_grad():
88 | out, aux = self.model(rgb, depth, eof_input, tau)
89 | torch.save(out, "/home/amazon/Desktop/out.pt")
90 | torch.save(aux, "/home/amazon/Desktop/aux.pt")
91 | out = out.squeeze()
92 | x_cartesian = out[0].item()
93 | y_cartesian = out[1].item()
94 | z_cartesian = out[2].item()
95 | print("X:{}, Y:{}, Z:{}".format(x_cartesian, y_cartesian, z_cartesian))
96 | print("Aux: {}".format(aux))
97 | # This new pose is the previous pose + the deltas output by the net, adjusted for discrepancy in frame
98 | # It used to be:
99 | # pose_to_move.position.x += -y_cartesian
100 | # pose_to_move.position.y += x_cartesian
101 | # pose_to_move.position.z += z_cartesian
102 |
103 | pose_to_move.position.x -= y_cartesian
104 | pose_to_move.position.y += x_cartesian
105 | pose_to_move.position.z += z_cartesian
106 | #print(pose_to_move)
107 |
108 | # Publish to Kuka!!!!
109 | for i in range(10):
110 | self.movement_publisher.publish(pose_to_move)
111 | rospy.Rate(10).sleep()
112 |
113 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String)
114 | # End publisher
115 |
116 | self.data = self.Data(pose=None,rgb=None,depth=None)
117 | rate.sleep()
118 |
119 | if distance(tau, (pose_to_move.position.x, pose_to_move.position.y, pose_to_move.position.z)) < tolerance:
120 |                     break
121 |         return 1.0  # assumed success signal consumed by the superval loop
122 | def process_images(self, img_msg, is_it_rgb):
123 | crop_right=586
124 | crop_lower=386
125 | img = self.bridge.compressed_imgmsg_to_cv2(img_msg, desired_encoding="passthrough")
126 | if(is_it_rgb):
127 | img = img[:,:,::-1]
128 | # Does this crop work?
129 | #rgb = img[0:386, 0:586]
130 | #rgb = img.crop((0, 0, crop_right, crop_lower))
131 | rgb = cv2.resize(img, (160,120))
132 | rgb = np.array(rgb).astype(np.float32)
133 |
134 | rgb = 2*((rgb - np.amin(rgb))/(np.amax(rgb)-np.amin(rgb)))-1
135 |
136 |
137 | rgb = torch.from_numpy(rgb).type(torch.FloatTensor)
138 | if is_it_rgb:
139 | rgb = rgb.view(1, rgb.shape[0], rgb.shape[1], rgb.shape[2]).permute(0, 3, 1, 2)
140 | else:
141 | rgb = rgb.view(1, 1, rgb.shape[0], rgb.shape[1])
142 | #plt.imshow(rgb[0,0] / 2 + .5)
143 | # plt.show()
144 |
145 | return rgb
146 |
147 | def move_to_start(self):
148 | # Publish starting position to Kuka!!!!
149 | for i in range(10):
150 | self.movement_publisher.publish(self.target_start)
151 | rospy.Rate(10).sleep()
152 |
153 | rospy.wait_for_message("/iiwa/CollisionAwareExecutionStatus", String)
154 | # End publisher
155 |
156 | def init_listeners(self):
157 | # The Topics we are Subscribing to for data
158 | self.right_arm_pose = rospy.Subscriber('/iiwa/state/CartesianPose', PoseStamped, self.pose_callback)
159 | self.rgb_state_sub = rospy.Subscriber('/camera3/camera/color/image_rect_color/compressed', CompressedImage, self.rgb_callback)
160 | self.depth_state_sub = rospy.Subscriber('/camera3/camera/depth/image_rect_raw/compressed', CompressedImage, self.depth_callback)
161 |
162 | def unsubscribe(self):
163 | self.right_arm_pose.unregister()
164 | self.rgb_state_sub.unregister()
165 | self.depth_state_sub.unregister()
166 |
167 | def pose_callback(self, pose):
168 | if None in self.data:
169 | self.data.pose = pose.pose
170 |
171 | def rgb_callback(self, rgb):
172 | if None in self.data:
173 | self.data.rgb = rgb
174 |
175 | def depth_callback(self, depth):
176 | if None in self.data:
177 | self.data.depth = depth
178 |
179 |
180 | def translate_tau(button):
181 | b_0 = int(button[0])
182 | b_1 = int(button[1])
183 |
184 |     tau = (-.22+.07*b_0, .56-.07*b_1, .94-.0025*b_1)  # original referenced undefined 'b_y'; column index assumed here
185 | return tau
186 |
187 |
188 | def distance(a, b):
189 |     return np.sqrt(np.sum([(aa - bb) ** 2 for aa, bb in zip(a, b)]))
190 |
191 |
192 | def get_tau(r, c):
193 |     return translate_tau([r, c])
194 |
195 |
196 | def main(config):
197 |     for weights in list(glob(config.weights + '/*/best_checkpoint.tar', recursive=True)):
198 |         agent = ImitateEval(weights)
199 | agent.change_start()
200 | agent.move_to_start()
201 | rates = torch.zeros(3, 3, config.num_traj)
202 | for r in range(3):
203 | for c in range(3):
204 | for i in range(config.num_traj):
205 | tau = get_tau(r, c)
206 |                 rates[r, c, i] = agent.move_to_button(tau, config.tolerance)
207 | agent.change_start()
208 | agent.move_to_start()
209 | rates = torch.sum(rates, dim=2) / config.num_traj
210 | torch.save(rates, config.weights[:config.weights.rfind('/')] + '/button_eval_percentages.pt')
211 |
212 |
213 | if __name__ == "__main__":
214 | parser = argparse.ArgumentParser(description="Arguments for evaluating imitation net")
215 | parser.add_argument('-w', '--weights', required=True, help='Path to folder containing checkpoint directories/files.')
216 | parser.add_argument('-t', '--tolerance', default=.01, type=float, help='Tolerance for button presses.')
217 | parser.add_argument('-n', '--num_traj', default=10, type=int, help='Presses per button per arrangement.')
218 | args = parser.parse_args()
219 | rospy.init_node('eval_imitation', log_level=rospy.DEBUG)
220 | try:
221 | main(args)
222 | except KeyboardInterrupt:
223 | pass
226 |
--------------------------------------------------------------------------------
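
A note on the tolerance test above: a press counts as successful once the commanded pose is within tolerance of tau under Euclidean distance (matching the form used in sim_eval.py). In isolation, with an illustrative pose:

    import numpy as np

    def distance(a, b):
        return np.sqrt(np.sum([(aa - bb) ** 2 for aa, bb in zip(a, b)]))

    tau = (-0.22, 0.56, 0.94)          # a button position in the translate_tau frame
    pose = (-0.215, 0.558, 0.941)
    print(distance(tau, pose) < 0.01)  # True: within the default tolerance
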
/superval_results_compile.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from glob import glob
3 | import argparse
4 | from matplotlib import pyplot as plt
5 | import numpy as np
6 |
7 |
8 | def parse_path(path):
9 | buttons = path[path.rfind('(')+1:path.rfind(')')].split(',')
10 | buttons = [(int(button[-2]), int(button[-3])) for button in buttons if len(button) > 0]
11 | return buttons
12 |
13 |
14 | def get_pcts(button_pcts, train_btns):
15 | train_inds = list(zip(*train_btns))
16 | train_pct = torch.mean(button_pcts[train_inds])
17 | total_pct = torch.mean(button_pcts)
18 | test_pct = (9 * total_pct - len(train_btns) * train_pct) / (9 - len(train_btns))
19 | return torch.stack([train_pct, test_pct, total_pct])
20 |
21 | def errorfill(x, y, yerr, color=None, alpha_fill=0.3, ax=None, plt_label=None, linestyle='-'):
22 | ax = ax if ax is not None else plt.gca()
23 | # if color is None:
24 | # color = ax._get_lines.color_cycle.next()
25 | if np.isscalar(yerr) or len(yerr) == len(y):
26 | ymin = y - yerr
27 | ymax = y + yerr
28 | elif len(yerr) == 2:
29 | ymin, ymax = yerr
30 | ax.plot(x, y, color=color, label=plt_label, linestyle=linestyle)
31 | ax.fill_between(x, ymax, ymin, color=color, alpha=alpha_fill)
32 |
33 |
34 | def plot(percentages, versions):
35 | colors = ['red', 'blue', 'green', 'teal']
36 | linestyles = ['-', '--', '-.', ':']
37 | plt.style.use('ggplot')
38 |
39 | # MERL_percs = [11.0, 33.0, 100.0, 78.0, 100.0, 100.0, 100.0, 100.0, 100.0]
40 | # MERL_percs_errors = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
41 | #
42 | # plt.figure(0)
43 | # errorfill(range(1, 10), MERL_percs, yerr=MERL_percs_errors, color='red', plt_label='tau')
44 | # plt.xlabel('Number of holes in training set')
45 | # plt.ylabel('Successful insertion percentage')
46 | # plt.title('Average Accuracy across all insertions')
47 | # plt.legend()
48 |
49 | plt.figure(1)
50 | # plt.errorbar(range(1, 9), percentages[0, :-1, 1, 0], yerr=percentages[0, :-1, 1, 2], label='tau')
51 | # plt.errorbar(range(1, 9), percentages[1, :-1, 1, 0], yerr=percentages[1, :-1, 1, 2], label='onehot')
52 | tauerr = [percentages[0, :-1, 1, 5], percentages[0, :-1, 1, 6]]
53 | tauerr_avg = [percentages[0, :, 2, 5], percentages[0, :, 2, 6]]
54 | #errorfill(range(1, 9), percentages[0, :-1, 1, 0], yerr=tauerr, color='red', plt_label='tau for unseen')
55 | for i, name in enumerate(versions):
56 | errorfill(range(1, 10), percentages[i, :, 2, 4], yerr=[percentages[i, :, 2, 5], percentages[i, :, 2, 6]], color=colors[i], linestyle=linestyles[i], plt_label=name)
57 | #errorfill(range(1, 10), percentages[1, :, 2, 0], yerr=percentages[1, :, 2, 2], color='green', plt_label='onehot')
58 | plt.xlabel('Number of goals in the training set')
59 | plt.ylabel('Success percentage')
60 | plt.title('Median Success Percentage in Simulation')
61 | plt.legend()
62 | '''
63 | plt.figure(2)
64 | plt.plot(range(1, 9), percentages[0, :-1, 1, 0], label='All Arrangements')
65 | #plt.plot(range(1, 9), percentages[1, :-1, 1, 0], label='onehot')
66 | plt.plot(range(1, 9), percentages[0, :-1, 1, 3], label='Median-centered 50%'+' of Arrangements')
67 | #plt.plot(range(1, 9), percentages[1, :-1, 1, 3], label='onehot without outliers')
68 | plt.xlabel('Number of Buttons in Train Set')
69 | plt.ylabel('Successful Press Percentage')
70 | plt.title('Average Accuracy on Untrained Buttons With Tau')
71 | plt.legend()
72 |
73 | plt.figure(3)
74 | plt.plot(range(1, 9), percentages[0, :-1, 1, 7], label='100th')
75 | #plt.bar(range(1, 9), percentages[0, :-1, 1, 6], label='75th')
76 | plt.plot(range(1, 9), percentages[0, :-1, 1, 5], label='50th')
77 | #plt.bar(range(1, 9), percentages[0, :-1, 1, 4], label='25th')
78 | plt.plot(range(1, 9), percentages[0, :-1, 1, 3], label='0th')
79 | # plt.plot(range(1, 9), percentages[0, :-1, 1, 4], label='tau without outliers')
80 | # plt.plot(range(1, 9), percentages[1, :-1, 1, 4], label='onehot without outliers')
81 | plt.xlabel('Number of Buttons in Train Set')
82 | plt.ylabel('Successful Press Percentage')
83 | plt.title('Accuracy on Untrained Buttons by Percentile')
84 | plt.legend()
85 |
86 | plt.figure(4)
87 | plt.errorbar(range(1, 10), percentages[0, :, 2, 0], yerr=percentages[0, :, 2, 2], label='tau')
88 | #plt.errorbar(range(1, 10), percentages[1, :, 2, 0], yerr=percentages[1, :, 2, 2], label='onehot')
89 | plt.xlabel('Number of Buttons in Train Set')
90 | plt.ylabel('Successful Press Percentage')
91 | plt.title('Average Accuracy on All Buttons')
92 | plt.legend()
93 |
94 | plt.figure(5)
95 | plt.plot(range(1, 10), percentages[0, :, 2, 0], label='tau')
96 | #plt.plot(range(1, 10), percentages[1, :, 2, 0], label='onehot')
97 | plt.plot(range(1, 10), percentages[0, :, 2, 3], label='tau without outliers')
98 | # plt.plot(range(1, 10), percentages[1, :, 2, 3], label='onehot without outliers')
99 | plt.xlabel('Number of Buttons in Train Set')
100 | plt.ylabel('Successful Press Percentage')
101 | plt.title('Average Accuracy on All Buttons')
102 | plt.legend()
103 |
104 | plt.figure(6)
105 | plt.plot(range(1, 10), percentages[0, :, 2, 1], label='tau')
106 | #plt.plot(range(1, 10), percentages[1, :, 2, 1], label='onehot')
107 | plt.plot(range(1, 10), percentages[0, :, 2, 4], label='tau without outliers')
108 | #plt.plot(range(1, 10), percentages[1, :, 2, 4], label='onehot without outliers')
109 | plt.xlabel('Number of Buttons in Train Set')
110 | plt.ylabel('Successful Press Percentage')
111 | plt.title('Maximum Accuracy on All Buttons')
112 | plt.legend()
113 | '''
114 | plt.show()
115 |
116 |
117 | if __name__ == '__main__':
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument('root')
120 | parser.add_argument('versions', nargs='+')
121 | parser.add_argument('-e', '--error_range', default=.25, type=float)
122 | args = parser.parse_args()
123 |
124 | percentages = torch.zeros(len(args.versions), 9, 3, 8)
125 | for i_version, version in enumerate(args.versions):
126 | version_root = args.root + '/' + version
127 | for count in range(1,4):
128 | count_root = version_root + '/' + str(count)
129 | sample_paths = glob(count_root + '/**/button_eval_percentages.pt', recursive=True)
130 | samples = [(torch.load(path), parse_path(path)) for path in sample_paths]
131 | sample_pcts = torch.stack([get_pcts(*sample) for sample in samples], dim=0)
132 | sorted_test = torch.stack([sample_pcts[:, i].sort()[0] for i in range(sample_pcts.size(1))], dim=1)
133 | error_low = int(sample_pcts.size(0) * args.error_range)
134 | error_high = int(sample_pcts.size(0) * (1 - args.error_range))
135 | percentages[i_version, count - 1, :] = torch.stack([torch.mean(sample_pcts, dim=0),
136 | torch.max(sample_pcts, dim=0)[0],
137 | torch.std(sample_pcts, dim=0),
138 | sorted_test[int(sample_pcts.size(0) * 0)],
139 | sorted_test[int(sample_pcts.size(0) * .49)],
140 | sorted_test[error_low],
141 | sorted_test[error_high],
142 | sorted_test[int(sample_pcts.size(0) * .99)],
143 | ], dim=1)
144 |
145 | #print((percentages*100).int())
146 | #print(percentages[1, :, 2, 0]*100)
147 | plot(percentages.numpy()*100, args.versions)
148 |
--------------------------------------------------------------------------------
/temprun.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for i in 5 6 7 8 9
4 | do
5 | cp -r 3d_$i"buttons" ~/
6 | python train.py -d ~/3d_$i"buttons" -s 3d_$i"buttons_out" -ne 5 -lr .0005 -ub -sr 51 -opt novograd -at -device cuda:1
7 | rm -r ~/3d_$i"buttons"
8 | done
9 |
--------------------------------------------------------------------------------
/util/plot_loss.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import numpy as np
3 | import pandas as pd
4 |
5 | import argparse
6 |
7 | def load_file(filename):
8 | data = pd.read_csv(filename, sep=',', header=None)
9 | data.columns = ["epoch", "loss_category", "loss"]
10 |
11 | categories = data['loss_category'].unique()
12 |
13 | losses = [(category, filter_loss(data, category)) for category in categories]
14 |
15 | return losses, filename[:-4]
16 |
17 |
18 | def filter_loss(data, category):
19 | cat_inds = data['loss_category'].values == category
20 |
21 | losses = np.stack([data['epoch'], data['loss']])
22 | cat_losses = losses[(slice(None),) + np.where(cat_inds)]
23 |
24 | return cat_losses
25 |
26 |
27 | def plot_file(losses, identifier=None):
28 | if identifier is None:
29 | label = 'Loss'
30 | else:
31 | label = 'Loss(%s)' % identifier
32 |
33 | for cat, loss in losses:
34 | plt.plot(loss[0], loss[1], label='%s %s'%(cat, label))
35 |
36 |
37 | if __name__ == '__main__':
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument('loss_files',
40 | help='The path(s) to the loss file(s) you want to plot.',
41 | nargs='*')
42 | args = parser.parse_args()
43 |
44 | print('Loading data...')
45 | losses = [load_file(loss) for loss in args.loss_files]
46 | print('Data loaded.')
47 |
48 | for loss in losses:
49 | plot_file(*loss)
50 |
51 | plt.xlabel('Epoch')
52 | plt.ylabel('Loss')
53 | plt.title('Train/Test Loss')
54 | plt.legend()
55 | plt.show()
56 |
--------------------------------------------------------------------------------
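
A note on plot_loss.py above: it expects a headerless CSV of (epoch, loss_category, loss) rows, one file per run. A sketch that writes a compatible file (the filename and loss values are made up):

    import csv

    with open('run1.csv', 'w') as f:
        writer = csv.writer(f)
        for epoch in range(3):
            writer.writerow([epoch, 'train', 1.0 / (epoch + 1)])
            writer.writerow([epoch, 'test', 1.2 / (epoch + 1)])

    # then plot it with: python util/plot_loss.py run1.csv
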
/venv_tool:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source /data/people/shastin1/venv/bin/activate
4 | pip uninstall pyarrow
5 | pip install pyarrow==0.11.1
6 |
--------------------------------------------------------------------------------
/verify_data.py:
--------------------------------------------------------------------------------
1 | from src.datasets import ImitationLMDB
2 | from src.model import Model
3 | from matplotlib import pyplot as plt
4 | import torch
5 | import argparse
6 |
7 |
8 | def verify_data(config):
9 | model = None
10 | if config.weights is not None:
11 | checkpoint = torch.load(config.weights, map_location='cpu')
12 | model = Model(**checkpoint['kwargs'])
13 | model.load_state_dict(checkpoint['model_state_dict'])
14 | model.eval()
15 |
16 | dataset = ImitationLMDB(config.data_dir, config.mode)
17 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=config.shuffle)
18 |
19 | for data in dataloader:
20 | rgb = data[0].squeeze().permute(1, 2, 0)
21 | print(rgb.min(), rgb.max())
22 |
23 | depth = data[1].squeeze()
24 | print(depth.min(), depth.max())
25 |
26 | print('EOF: %s' % data[2].squeeze())
27 | print('TAU: %s' % data[3].squeeze())
28 |
29 | print('Target: %s' % data[4].squeeze())
30 | print('Aux: %s' % data[5].squeeze())
31 |
32 | if model is not None:
33 | out, aux = model(data[0], data[1], data[2], data[3])
34 | print('Model out: %s' % out.squeeze())
35 | print('Model aux: %s' % aux.squeeze())
36 |
37 | if config.show:
38 | plt.imshow((rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-6))
39 | plt.show()
40 |
41 | plt.imshow((depth - depth.min()) / (depth.max() - depth.min() + 1e-6))
42 | plt.show()
43 |
44 | print('==========================')
45 |
46 |
47 |
48 |
49 |
50 | if __name__ == '__main__':
51 | parser = argparse.ArgumentParser(description='Qualitative data verification')
52 | parser.add_argument('-d', '--data_dir', required=True, help='Location of lmdb to pull data from.')
53 | parser.add_argument('-w', '--weights', help='Weights for model to evaluate on data.')
54 | parser.add_argument('-m', '--mode', default='test', help='Mode to evaluate data of.')
55 | parser.add_argument('-s', '--shuffle', default=False, dest='shuffle', action='store_true', help='Weights for model to evaluate on data.')
56 | parser.add_argument('-sh', '--show', default=False, dest='show', action='store_true', help='Flag to show visual inputs.')
57 | args = parser.parse_args()
58 |
59 | verify_data(args)
60 |
--------------------------------------------------------------------------------
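
Typical usage of verify_data.py, matching the flags above: python verify_data.py -d <lmdb_dir> -m test -sh to eyeball inputs, adding -w <checkpoint> to also print model outputs for each sample.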