├── .gitignore ├── LICENSE ├── README.md ├── assets ├── Architecture.png ├── Headliner.png ├── gripper_fingertip.STL ├── model-archi.gif └── view_preds.png ├── config ├── augmentation_ablation.yml ├── depth_ablation.yml ├── forcesight.yml ├── grip_force_dist_pos_effort_5_25.yml ├── pretrain_ablation.yml ├── text_cond_ablation.yml ├── vit_base_ablation.yml └── vit_small_ablation.yml ├── environment.yml ├── eval_models.sh ├── grip_force_checkpoints └── grip_force_dist_pos_effort_5_25_4 │ └── model_best.pth ├── netft ├── package-list.txt ├── prediction ├── classifier_free_guidance.py ├── deep_fusion.py ├── fuse_helper.py ├── grip_force_model.py ├── grip_force_trainer.py ├── live_model.py ├── loader.py ├── mesh_from_rgbd.py ├── models.py ├── owlvit_pytorch.py ├── owlvit_seg.py ├── realsense_owlvit.py ├── trainer.py └── view_preds.py ├── recording ├── __pycache__ │ └── ft.cpython-38.pyc ├── capture_data.py ├── capture_grip_data.py └── ft.py ├── requirements.txt ├── robot ├── kdl_client.py ├── kdl_server.py ├── kdl_wrapper.py ├── kinpy_wrapper.py ├── robot_utils.py ├── stretch_robot.urdf └── visual_servo.py ├── ros_scripts ├── joint_state_pub.py ├── ros_aruco_detect.py ├── ros_viz.py ├── ros_viz.rviz ├── urdf_viewer.launch └── urdf_viewer.rviz ├── train_models.sh └── utils ├── aruco_detect.py ├── calibration_file.xml ├── config_utils.py ├── data_aug.py ├── data_pipeline.py ├── ft_utils.py ├── pred_utils.py ├── realsense_utils.py ├── render_videos.py ├── test_aug.py ├── transform.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | data 3 | data/ 4 | checkpoints 5 | checkpoints/ 6 | librealsense/ 7 | stretch_remote/ 8 | wandb/ 9 | videos/ 10 | figures/ 11 | results/ 12 | results 13 | ft_calibration.npy 14 | makefile -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 force-sight 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ForceSight 2 | 3 | Given an RGBD image and a text prompt, ForceSight produces visual-force goals for a robot, enabling mobile manipulation in unseen environments with unseen object instances. 
4 | 5 | [Project Page](https://force-sight.github.io/) | [Paper](https://arxiv.org/abs/2309.12312) 6 | 7 | ![headliner](assets/Headliner.png) 8 | ![architecture](assets/Architecture.png) 9 | 10 | ## Installation 11 | 12 | Install the conda environment `forcesight`: 13 | 14 | ```bash 15 | conda env create -f environment.yml 16 | conda activate forcesight 17 | ``` 18 | 19 | OR 20 | 21 | Manually install the dependencies: 22 | 23 | ```bash 24 | # First create a conda environment 25 | conda create -n fs python=3.8 26 | conda activate fs 27 | ``` 28 | If manually installing dependencies, install PyTorch from [here](https://pytorch.org/get-started/locally/), then: 29 | ```bash 30 | conda install libffi 31 | pip3 install -r requirements.txt 32 | ``` 33 | 34 | ## Quick Start 35 | 36 | The following is a quick start guide for the project. The robot is not required for this part. 37 | 38 | 1. Download the dataset, model, and hardware [here](https://1drv.ms/f/s!AjebifpxoPl5hO5bu91QCJSDizws9g?e=h9AlnZ). Place the model in `checkpoints/forcesight_0/` and place the dataset in `data/`. 39 | 40 | 2. **Train a model** 41 | 42 | Skip this step if you are using a trained checkpoint. 43 | 44 | ```bash 45 | python -m prediction.trainer --config forcesight 46 | ``` 47 | 48 | 3. **Evaluate the prediction** 49 | ```bash 50 | python -m prediction.view_preds \ 51 | --config forcesight \ 52 | --folder data/test_new_objects_and_env \ 53 | --index 0 --epoch best --ignore_prefilter 54 | # --ignore_prefilter skips the prefiltering step for faster initialization 55 | ``` 56 | 57 | You will see an output plot like this: 58 | 59 | ![output](assets/view_preds.png) 60 | 61 | 4. **Show live view** 62 | 63 | This requires a [Realsense D405](https://www.intelrealsense.com/depth-camera-d405/) camera. 64 | 65 | ```bash 66 | python -m prediction.live_model --config forcesight --index 0 --epoch best --prompt "pick up the keys" 67 | ``` 68 | 69 | Press "p" to change the prompt. For more info about the key controls, please refer to [keyboard_teleop](https://github.com/force-sight/forcesight/blob/5e2720016f31da6823b3eadfaaeaa7105803b588/robot/robot_utils.py#L140) 70 | 71 | --- 72 | 73 | Beyond this point, the documentation contains more detailed information about the project. It involves the use of the [Stretch](https://hello-robot.com/stretch-2) robot and the [Realsense D405](https://www.intelrealsense.com/depth-camera-d405/) camera. 74 | 75 | ## Data collection 76 | 77 | We assume that you have a Stretch robot with a force/torque sensor mounted on the wrist. Hardware to mount an [ATI Mini45](https://www.ati-ia.com/products/ft/ft_models.aspx?id=mini45) force/torque sensor to the Stretch can be found [here](https://1drv.ms/f/s!AjebifpxoPl5hO5bu91QCJSDizws9g?e=h9AlnZ). 78 | 79 | - Requires installation of stretch_remote: 80 | ```bash 81 | git clone https://github.com/Healthcare-Robotics/stretch_remote.git 82 | cd stretch_remote 83 | pip install -e . 84 | ``` 85 | 86 | Run the stretch remote server on the robot: 87 | 1. `python3 stretch_remote/stretch_remote/robot_server.py` 88 | 2. `conda activate forcesight` 89 | 3. 
To test data collection, `cd ~/forcesight` and run: 90 | 91 | ```bash 92 | # first task 93 | # Output folder format: <prompt>_frame_<initial stage>_<final stage> 94 | python -m recording.capture_data --config <config> --stage train --folder <folder> --prompt "pick up the apple" --realsense_id <camera id> 95 | 96 | # stage 1 -> 2 97 | python -m recording.capture_data --config <config> --stage train --folder pick_up_the_apple_frame_1_2 --prompt "pick up the apple" --realsense_id <camera id> 98 | 99 | # stage 2 -> 3 100 | python -m recording.capture_data --config <config> --stage train --folder pick_up_the_apple_frame_2_3 --prompt "pick up the apple" --realsense_id <camera id> 101 | 102 | # stage 3 -> 4 103 | python -m recording.capture_data --config <config> --stage train --folder pick_up_the_apple_frame_3_4 --prompt "pick up the apple" --realsense_id <camera id> 104 | ``` 105 | 106 | **Key controls**: 107 | - `wasd` keys: up, down, front, back 108 | - `[]` keys: robot base 109 | - `ijkl` keys: wrist 110 | - `h`: home 111 | - `enter`: switch step 112 | - `space`: save frame 113 | - `backspace`: delete 114 | - `/`: randomize the position of the end effector 115 | 116 | We use a randomizer (the `if keycode == ord('/')` branch in `robot/robot_utils.py`) to quickly obtain varied data, which speeds up the data collection process. 117 | 118 | Data collection for the grip force model: 119 | ```bash 120 | python3 -m recording.capture_grip_data --bipartite 0 --config grip_force_5_21 --folder grip_force_5_25_frame_0_0 --stage train --ip 100.99.105.59 121 | ``` 122 | 123 | ### Load the new data 124 | 125 | Load the newly collected raw data with the loader to check it: 126 | 127 | ```bash 128 | python -m prediction.loader --config <config> --folder data/raw 129 | ``` 130 | 131 | ## Train a model 132 | 133 | Set up a config for each model. The config used for ForceSight is provided in [config/forcesight.yml](https://github.com/force-sight/forcesight/blob/main/config/forcesight.yml). For more details, please refer to the config files in the `config/` directory. 134 | 135 | Start the training: 136 | ```bash 137 | python -m prediction.trainer --config <config> 138 | ``` 139 | 140 | ## Train grip force model (OPTIONAL) 141 | 142 | Since a grip force measurement is not available from the robot, we train a grip force model to predict the grip force given the fingertip locations, motor effort, and motor position. A default model is provided in the `grip_force_checkpoints/` directory. 143 | 144 | **Grip force data collection** 145 | 146 | `python -m recording.capture_data --config <config> --stage raw --folder grip_force_5_25 --realsense_id <camera id> --bipartite 0` 147 | 148 | **Train the grip force prediction model** 149 | 150 | `python -m prediction.grip_force_trainer --config <config> --bipartite 0` 151 | 152 | ## Running ForceSight on a real robot 153 | 154 | After training, we can run the model on the robot. ForceSight generates kinematic and force goals, and the low-level controller then drives the robot to reach those goals. 155 | 156 | To run the robot, run `stretch_remote/stretch_remote/robot_server.py` on the robot, then run `visual_servo.py`. The `visual_servo.py` script can run on a different computer with a GPU; communication with the robot is specified by the `--ip` argument. 
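As a sketch, the two-machine workflow looks like this, assuming the `forcesight` conda environment is active on the machine running the model; `<robot ip>` is a placeholder for your robot's address:

```bash
# Terminal 1, on the robot: start the remote control server
python3 stretch_remote/stretch_remote/robot_server.py

# Terminal 2, on the GPU machine: run ForceSight with visual servoing
python -m robot.visual_servo --config forcesight --index 0 --epoch best \
    --prompt "pick up the keys" --ip <robot ip>
```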
157 | 158 | Test the model with live view and visual servoing: 159 | 160 | ```bash 161 | # Visual Servo: Press 'p' to insert prompt, 162 | # Press 't' to switch between view model and visual servoing mode 163 | # add --ros_viz arg to visualize the 3d scene on rviz 164 | python -m robot.visual_servo --config forcesight --index 0 --epoch best --prompt "pick up the keys" --ip <robot ip> 165 | ``` 166 | 167 | - `t`: switch between view model and visual servoing mode 168 | - `p`: insert prompt 169 | - `wasd[]ijkl`: control the robot (see the key controls above) 170 | - `h`: home 171 | - `c`: toggle publishing the point cloud (if using --ros_viz) 172 | 173 | If you do not have a force/torque sensor, you can still run ForceSight on the robot by passing `--use_ft 0` as an argument. If running `visual_servo.py`, set `USE_FORCE_OBJECTIVE` to `False` to ignore forces. Note that performance will suffer without the use of force goals, however. 174 | 175 | --- 176 | 177 | ## Others 178 | 179 | ### Run Realsense camera 180 | 181 | Utility scripts to run aruco detection and visualize the point cloud: 182 | 183 | ```bash 184 | # Run realsense camera 185 | python utils/realsense_utils.py 186 | python utils/realsense_utils.py --cloud 187 | 188 | # Run aruco detection with realsense 189 | python -m utils.aruco_detect --rs 190 | ``` 191 | 192 | ### Run with ROS 193 | 194 | To install rospy in the conda env, run `conda install -c conda-forge ros-rospy`. Make sure you are using Python 3.8, or follow this: https://robostack.github.io/GettingStarted.html 195 | 196 | *Note: ROS tends to be unfriendly with conda environments, so this installation will not be seamless.* 197 | 198 | ```bash 199 | roslaunch realsense2_camera rs_camera.launch enable_pointcloud:=1 infra_width:=640 200 | 201 | # view the point cloud 202 | rviz -d ros_scripts/ros_viz.rviz 203 | 204 | # then run the marker pose estimation script 205 | python3 -m ros_scripts.ros_aruco_detect 206 | 207 | # ros viz to visualize the pointcloud and contact in rviz. --rs to use realsense, --ft to use the ft sensor 208 | python -m ros_scripts.ros_viz --rs 209 | ``` 210 | 211 | **Others** 212 | 213 | Use ROS to visualize the URDF. The URDF describes the robot model and is used to compute the forward and inverse kinematics of the robot. 214 | 215 | ```bash 216 | roslaunch ros_scripts/urdf_viewer.launch model:=robot/stretch_robot.urdf 217 | ``` 218 | 219 | Test the joints of the robot: 220 | 221 | ```bash 222 | roslaunch ros_scripts/urdf_viewer.launch model:=robot/stretch_robot.urdf joints_pub:=false 223 | python3 joint_state_pub.py --joints 0 0 0 0 0 0 224 | ``` 225 | 226 | To convert the xacro to a URDF from [stretch_ros](https://github.com/hello-robot/stretch_ros/tree/master/stretch_description/urdf), run: 227 | 228 | ```bash 229 | rosrun xacro xacro src/stretch_ros/stretch_description/urdf/stretch_description.xacro -o output.urdf 230 | ``` 231 | 232 | ### Test data augmentation 233 | We tested various methods of data augmentation during pilot experiments. 234 | 235 | ```bash 236 | python -m utils.test_aug --no_gripper --data <data folder> 237 | python -m utils.test_aug --translate_pic --data <data folder> 238 | ``` 239 | 240 | --- 241 | 242 | ## Notes 243 | 244 | 1. There are some caveats when using the D405 camera with ROS. The current realsense driver doesn't support the D405, since development effort has moved to ROS 2. This fork is used: https://github.com/rjwb1/realsense-ros 245 | 2. Make sure that the image resolution matches the intrinsic parameters you use. 
Different image res for the same camera will have different ppx, ppy, fx, fy values. `rs-enumerate-devices -c` 246 | 3. To teleop the Stretch robot, use: https://github.com/Healthcare-Robotics/stretch_remote 247 | 4. if getting MESA driver error when running open3d viz in conda env, try `conda install -c conda-forge libstdcxx-ng` 248 | 5. There are 2 IK/FK solvers being used here: `kdl` and `kinpy`. Kinpy is a recent migration from KDL since kdl is dependent on ROS, which is a headache for conda installation. 249 | 250 | ## Bibliography 251 | 252 | ```bibtex 253 | @misc{collins2023forcesight, 254 | title={ForceSight: Text-Guided Mobile Manipulation with Visual-Force Goals}, 255 | author={Jeremy A. Collins and Cody Houff and You Liang Tan and Charles C. Kemp}, 256 | year={2023}, 257 | eprint={2309.12312}, 258 | archivePrefix={arXiv}, 259 | primaryClass={cs.RO} 260 | } 261 | ``` 262 | -------------------------------------------------------------------------------- /assets/Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/assets/Architecture.png -------------------------------------------------------------------------------- /assets/Headliner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/assets/Headliner.png -------------------------------------------------------------------------------- /assets/gripper_fingertip.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/assets/gripper_fingertip.STL -------------------------------------------------------------------------------- /assets/model-archi.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/assets/model-archi.gif -------------------------------------------------------------------------------- /assets/view_preds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/assets/view_preds.png -------------------------------------------------------------------------------- /config/augmentation_ablation.yml: -------------------------------------------------------------------------------- 1 | TRAIN_FOLDER: [data/multiprompt/train] 2 | TEST_FOLDER: [data/multiprompt/test] 3 | 4 | MODEL_DIR: checkpoints 5 | 6 | ROBOT_STATES: [x, y, z, roll, pitch, yaw] 7 | 8 | LEARNING_RATE: 0.00005 9 | BATCH_SIZE: 8 10 | NUM_WORKERS: 12 11 | NUM_EPOCHS: 20 12 | 13 | TRANSFORM: [] 14 | 15 | LOSS_RATIO: 5.6 16 | 17 | ACTION_DELTA_DICT: {x: 0.03, y: 0.03, z: 0.03, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 18 | 19 | IMAGE_MODEL: vit-large # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 20 | FREEZE_IMAGE_MODEL: False 21 | IMAGE_SIZE: 224 22 | USE_PATCH_FEATURES: False 23 | 24 | TEXT_MODEL: t5-large # t5-small, t5-base, t5-large, bert, clip-base, clip-large 25 | FREEZE_TEXT_MODEL: True 26 | 27 | MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 28 | 
CFG_COND_METHOD: xattn 29 | 30 | # Loss functions 31 | FINGERTIP_LOSS: L1 32 | FORCE_LOSS: L2 33 | PITCH_LOSS: L2 34 | GRIP_LOSS: L2 35 | WIDTH_LOSS: L2 36 | YAW_LOSS: L2 37 | 38 | # Loss weights 39 | LAMBDA_FINGERTIPS: 0.0 40 | LAMBDA_FORCE: 0.2 # 0.025 41 | LAMBDA_PITCH: 0.0 # 1.0 42 | LAMBDA_GRIP: 0.2 43 | LAMBDA_TIMESTEP: 0.1 44 | LAMBDA_PIXEL: 1 45 | LAMBDA_DEPTH: 50000 46 | LAMBDA_WIDTH: 0.2 47 | LAMBDA_YAW: 0.2 48 | 49 | USE_RGBD: True 50 | PRETRAINED: True 51 | 52 | CLASSIFY_TIMESTEP: True # predict current timestep 53 | CLASSIFICATION_LOSS: focal 54 | ALPHA: 1 # for focal loss 55 | NUM_TIMESTEPS: 4 56 | 57 | # new representation 58 | CLS_POINT_RADIUS: 10 59 | PIXEL_SPACE_OUTPUT: True 60 | PIXEL_SPACE_CENTROID: True 61 | PIXEL_LABEL_WEIGHT: 100 # how much we care about 1s over 0s in pixel space 62 | 63 | REMOVE_NON_VIEWABLE_TARGET: 1 # removes when all 1 point of target contacts is not viewable 64 | 65 | SUBGOAL_TEXT: named_action 66 | -------------------------------------------------------------------------------- /config/depth_ablation.yml: -------------------------------------------------------------------------------- 1 | TRAIN_FOLDER: [data/multiprompt/train] 2 | TEST_FOLDER: [data/multiprompt/test] 3 | 4 | MODEL_DIR: checkpoints 5 | 6 | ROBOT_STATES: [x, y, z, roll, pitch, yaw] 7 | 8 | LEARNING_RATE: 0.00005 9 | BATCH_SIZE: 8 10 | NUM_WORKERS: 12 11 | NUM_EPOCHS: 5 12 | 13 | TRANSFORM: [jitter] 14 | 15 | LOSS_RATIO: 5.6 16 | 17 | ACTION_DELTA_DICT: {x: 0.03, y: 0.03, z: 0.03, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 18 | 19 | IMAGE_MODEL: vit-large # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 20 | FREEZE_IMAGE_MODEL: False 21 | IMAGE_SIZE: 224 22 | USE_PATCH_FEATURES: False 23 | 24 | TEXT_MODEL: t5-large # t5-small, t5-base, t5-large, bert, clip-base, clip-large 25 | FREEZE_TEXT_MODEL: True 26 | 27 | MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 28 | CFG_COND_METHOD: xattn 29 | 30 | # Loss functions 31 | FINGERTIP_LOSS: L1 32 | FORCE_LOSS: L2 33 | PITCH_LOSS: L2 34 | GRIP_LOSS: L2 35 | WIDTH_LOSS: L2 36 | YAW_LOSS: L2 37 | 38 | # Loss weights 39 | LAMBDA_FINGERTIPS: 0.0 40 | LAMBDA_FORCE: 0.2 # 0.025 41 | LAMBDA_PITCH: 0.0 # 1.0 42 | LAMBDA_GRIP: 0.2 43 | LAMBDA_TIMESTEP: 0.1 44 | LAMBDA_PIXEL: 1 45 | LAMBDA_DEPTH: 50000 46 | LAMBDA_WIDTH: 0.2 47 | LAMBDA_YAW: 0.2 48 | 49 | USE_RGBD: False 50 | PRETRAINED: True 51 | 52 | CLASSIFY_TIMESTEP: True # predict current timestep 53 | CLASSIFICATION_LOSS: focal 54 | ALPHA: 1 # for focal loss 55 | NUM_TIMESTEPS: 4 56 | 57 | # new representation 58 | CLS_POINT_RADIUS: 10 59 | PIXEL_SPACE_OUTPUT: True 60 | PIXEL_SPACE_CENTROID: True 61 | PIXEL_LABEL_WEIGHT: 100 # how much we care about 1s over 0s in pixel space 62 | 63 | REMOVE_NON_VIEWABLE_TARGET: 1 # removes when all 1 point of target contacts is not viewable 64 | 65 | SUBGOAL_TEXT: named_action 66 | -------------------------------------------------------------------------------- /config/forcesight.yml: -------------------------------------------------------------------------------- 1 | # This is the configuration for the final model: final_tests_jitter_default_6_6 2 | # Trained on 6th June 2023 3 | 4 | TRAIN_FOLDER: [data/multiprompt/train] 5 | TEST_FOLDER: [data/multiprompt/test] 6 | 7 | MODEL_DIR: checkpoints 8 | 9 | ROBOT_STATES: [x, y, z, roll, pitch, yaw] 10 | 11 | LEARNING_RATE: 0.00005 12 | BATCH_SIZE: 8 13 | 
NUM_WORKERS: 12 14 | NUM_EPOCHS: 20 15 | 16 | TRANSFORM: [jitter] 17 | 18 | LOSS_RATIO: 5.6 19 | 20 | ACTION_DELTA_DICT: {x: 0.03, y: 0.03, z: 0.03, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 21 | 22 | IMAGE_MODEL: vit-large # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 23 | FREEZE_IMAGE_MODEL: False 24 | IMAGE_SIZE: 224 25 | USE_PATCH_FEATURES: False 26 | 27 | TEXT_MODEL: t5-large # t5-small, t5-base, t5-large, bert, clip-base, clip-large 28 | FREEZE_TEXT_MODEL: True 29 | 30 | MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 31 | CFG_COND_METHOD: xattn 32 | 33 | # Loss functions 34 | FINGERTIP_LOSS: L1 35 | FORCE_LOSS: L2 36 | PITCH_LOSS: L2 37 | GRIP_LOSS: L2 38 | WIDTH_LOSS: L2 39 | YAW_LOSS: L2 40 | 41 | # Loss weights 42 | LAMBDA_FINGERTIPS: 0.0 43 | LAMBDA_FORCE: 0.2 # 0.025 44 | LAMBDA_PITCH: 0.0 # 1.0 45 | LAMBDA_GRIP: 0.2 46 | LAMBDA_TIMESTEP: 0.1 47 | LAMBDA_PIXEL: 1 48 | LAMBDA_DEPTH: 50000 49 | LAMBDA_WIDTH: 0.2 50 | LAMBDA_YAW: 0.2 51 | 52 | USE_RGBD: True 53 | PRETRAINED: True 54 | 55 | CLASSIFY_TIMESTEP: True # predict current timestep 56 | CLASSIFICATION_LOSS: focal 57 | ALPHA: 1 # for focal loss 58 | NUM_TIMESTEPS: 4 59 | 60 | # new representation 61 | CLS_POINT_RADIUS: 10 62 | PIXEL_SPACE_OUTPUT: True 63 | PIXEL_SPACE_CENTROID: True 64 | PIXEL_LABEL_WEIGHT: 100 # how much we care about 1s over 0s in pixel space 65 | 66 | REMOVE_NON_VIEWABLE_TARGET: 1 # removes when all 1 point of target contacts is not viewable 67 | 68 | SUBGOAL_TEXT: named_action 69 | -------------------------------------------------------------------------------- /config/grip_force_dist_pos_effort_5_25.yml: -------------------------------------------------------------------------------- 1 | TRAIN_FOLDER: [data/grip_force/train] 2 | TEST_FOLDER: [data/grip_force/test] 3 | 4 | MODEL_DIR: checkpoints 5 | 6 | ROBOT_STATES: [x, y, z, roll, pitch, yaw, gripper, gripper_effort] 7 | 8 | LEARNING_RATE: 0.001 9 | BATCH_SIZE: 4 10 | NUM_WORKERS: 12 11 | NUM_EPOCHS: 100 12 | 13 | TRANSFORM: [] 14 | 15 | LOSS_RATIO: 5.6 16 | 17 | ACTION_DELTA_DICT: {x: 0.02, y: 0.02, z: 0.005, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 18 | 19 | IMAGE_MODEL: None # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 20 | # FREEZE_IMAGE_MODEL: False 21 | IMAGE_SIZE: 224 22 | # USE_PATCH_FEATURES: False 23 | 24 | # TEXT_MODEL: None # t5-small, t5-base, t5-large, bert, clip-base, clip-large 25 | # FREEZE_TEXT_MODEL: True 26 | 27 | # MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 28 | # CFG_COND_METHOD: xattn 29 | 30 | # Losses 31 | # FINGERTIP_LOSS: L1 32 | # FORCE_LOSS: L2 33 | # PITCH_LOSS: L2 34 | 35 | # Loss weights 36 | # LAMBDA_FINGERTIPS: 1.0 37 | # LAMBDA_FORCE: 0.01 # 0.025 38 | # LAMBDA_PITCH: 0.0 # 1.0 39 | 40 | # USE_RGBD: True 41 | # PRETRAINED: True 42 | 43 | CLASSIFY_TIMESTEP: False 44 | # CLASSIFICATION_LOSS: focal 45 | # ALPHA: 1 # for focal loss 46 | NUM_TIMESTEPS: 1 47 | -------------------------------------------------------------------------------- /config/pretrain_ablation.yml: -------------------------------------------------------------------------------- 1 | TRAIN_FOLDER: [data/multiprompt/train] 2 | TEST_FOLDER: [data/multiprompt/test] 3 | 4 | MODEL_DIR: checkpoints 5 | 6 | ROBOT_STATES: [x, y, z, roll, 
pitch, yaw] 7 | 8 | LEARNING_RATE: 0.00005 9 | BATCH_SIZE: 8 10 | NUM_WORKERS: 12 11 | NUM_EPOCHS: 5 12 | 13 | TRANSFORM: [] 14 | 15 | LOSS_RATIO: 5.6 16 | 17 | ACTION_DELTA_DICT: {x: 0.03, y: 0.03, z: 0.03, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 18 | 19 | IMAGE_MODEL: vit-large # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 20 | FREEZE_IMAGE_MODEL: False 21 | IMAGE_SIZE: 224 22 | USE_PATCH_FEATURES: False 23 | 24 | TEXT_MODEL: t5-large # t5-small, t5-base, t5-large, bert, clip-base, clip-large 25 | FREEZE_TEXT_MODEL: True 26 | 27 | MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 28 | CFG_COND_METHOD: xattn 29 | 30 | # Loss functions 31 | FINGERTIP_LOSS: L1 32 | FORCE_LOSS: L2 33 | PITCH_LOSS: L2 34 | GRIP_LOSS: L2 35 | WIDTH_LOSS: L2 36 | YAW_LOSS: L2 37 | 38 | # Loss weights 39 | LAMBDA_FINGERTIPS: 0.0 40 | LAMBDA_FORCE: 0.2 # 0.025 41 | LAMBDA_PITCH: 0.0 # 1.0 42 | LAMBDA_GRIP: 0.2 43 | LAMBDA_TIMESTEP: 0.1 44 | LAMBDA_PIXEL: 1 45 | LAMBDA_DEPTH: 50000 46 | LAMBDA_WIDTH: 0.2 47 | LAMBDA_YAW: 0.2 48 | 49 | USE_RGBD: True 50 | PRETRAINED: False 51 | 52 | CLASSIFY_TIMESTEP: True # predict current timestep 53 | CLASSIFICATION_LOSS: focal 54 | ALPHA: 1 # for focal loss 55 | NUM_TIMESTEPS: 4 56 | 57 | # new representation 58 | CLS_POINT_RADIUS: 10 59 | PIXEL_SPACE_OUTPUT: True 60 | PIXEL_SPACE_CENTROID: True 61 | PIXEL_LABEL_WEIGHT: 100 # how much we care about 1s over 0s in pixel space 62 | 63 | REMOVE_NON_VIEWABLE_TARGET: 1 # removes when all 1 point of target contacts is not viewable 64 | 65 | SUBGOAL_TEXT: named_action 66 | -------------------------------------------------------------------------------- /config/text_cond_ablation.yml: -------------------------------------------------------------------------------- 1 | TRAIN_FOLDER: [data/multiprompt/train] 2 | TEST_FOLDER: [data/multiprompt/test] 3 | 4 | MODEL_DIR: checkpoints 5 | 6 | ROBOT_STATES: [x, y, z, roll, pitch, yaw] 7 | 8 | LEARNING_RATE: 0.00005 9 | BATCH_SIZE: 8 10 | NUM_WORKERS: 12 11 | NUM_EPOCHS: 5 12 | 13 | TRANSFORM: [jitter] 14 | 15 | LOSS_RATIO: 5.6 16 | 17 | ACTION_DELTA_DICT: {x: 0.03, y: 0.03, z: 0.03, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 18 | 19 | IMAGE_MODEL: vit-large # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 20 | FREEZE_IMAGE_MODEL: False 21 | IMAGE_SIZE: 224 22 | USE_PATCH_FEATURES: False 23 | 24 | TEXT_MODEL: t5-large # t5-small, t5-base, t5-large, bert, clip-base, clip-large 25 | FREEZE_TEXT_MODEL: True 26 | 27 | MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 28 | CFG_COND_METHOD: xattn 29 | 30 | # Loss functions 31 | FINGERTIP_LOSS: L1 32 | FORCE_LOSS: L2 33 | PITCH_LOSS: L2 34 | GRIP_LOSS: L2 35 | WIDTH_LOSS: L2 36 | YAW_LOSS: L2 37 | 38 | # Loss weights 39 | LAMBDA_FINGERTIPS: 0.0 40 | LAMBDA_FORCE: 0.2 # 0.025 41 | LAMBDA_PITCH: 0.0 # 1.0 42 | LAMBDA_GRIP: 0.2 43 | LAMBDA_TIMESTEP: 0.1 44 | LAMBDA_PIXEL: 1 45 | LAMBDA_DEPTH: 50000 46 | LAMBDA_WIDTH: 0.2 47 | LAMBDA_YAW: 0.2 48 | 49 | USE_RGBD: True 50 | PRETRAINED: True 51 | 52 | CLASSIFY_TIMESTEP: True # predict current timestep 53 | CLASSIFICATION_LOSS: focal 54 | ALPHA: 1 # for focal loss 55 | NUM_TIMESTEPS: 4 56 | 57 | # new representation 58 | CLS_POINT_RADIUS: 10 59 | PIXEL_SPACE_OUTPUT: True 60 | PIXEL_SPACE_CENTROID: True 
61 | PIXEL_LABEL_WEIGHT: 100 # how much we care about 1s over 0s in pixel space 62 | 63 | REMOVE_NON_VIEWABLE_TARGET: 1 # removes when all 1 point of target contacts is not viewable 64 | 65 | SUBGOAL_TEXT: named_action 66 | -------------------------------------------------------------------------------- /config/vit_base_ablation.yml: -------------------------------------------------------------------------------- 1 | TRAIN_FOLDER: [data/multiprompt/train] 2 | TEST_FOLDER: [data/multiprompt/test] 3 | 4 | MODEL_DIR: checkpoints 5 | 6 | ROBOT_STATES: [x, y, z, roll, pitch, yaw] 7 | 8 | LEARNING_RATE: 0.00005 9 | BATCH_SIZE: 8 10 | NUM_WORKERS: 12 11 | NUM_EPOCHS: 20 12 | 13 | TRANSFORM: [jitter] 14 | 15 | LOSS_RATIO: 5.6 16 | 17 | ACTION_DELTA_DICT: {x: 0.03, y: 0.03, z: 0.03, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 18 | 19 | IMAGE_MODEL: vit-base # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 20 | FREEZE_IMAGE_MODEL: False 21 | IMAGE_SIZE: 224 22 | USE_PATCH_FEATURES: False 23 | 24 | TEXT_MODEL: t5-large # t5-small, t5-base, t5-large, bert, clip-base, clip-large 25 | FREEZE_TEXT_MODEL: True 26 | 27 | MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 28 | CFG_COND_METHOD: xattn 29 | 30 | # Loss functions 31 | FINGERTIP_LOSS: L1 32 | FORCE_LOSS: L2 33 | PITCH_LOSS: L2 34 | GRIP_LOSS: L2 35 | WIDTH_LOSS: L2 36 | YAW_LOSS: L2 37 | 38 | # Loss weights 39 | LAMBDA_FINGERTIPS: 0.0 40 | LAMBDA_FORCE: 0.2 # 0.025 41 | LAMBDA_PITCH: 0.0 # 1.0 42 | LAMBDA_GRIP: 0.2 43 | LAMBDA_TIMESTEP: 0.1 44 | LAMBDA_PIXEL: 1 45 | LAMBDA_DEPTH: 50000 46 | LAMBDA_WIDTH: 0.2 47 | LAMBDA_YAW: 0.2 48 | 49 | USE_RGBD: True 50 | PRETRAINED: True 51 | 52 | CLASSIFY_TIMESTEP: True # predict current timestep 53 | CLASSIFICATION_LOSS: focal 54 | ALPHA: 1 # for focal loss 55 | NUM_TIMESTEPS: 4 56 | 57 | # new representation 58 | CLS_POINT_RADIUS: 10 59 | PIXEL_SPACE_OUTPUT: True 60 | PIXEL_SPACE_CENTROID: True 61 | PIXEL_LABEL_WEIGHT: 100 # how much we care about 1s over 0s in pixel space 62 | 63 | REMOVE_NON_VIEWABLE_TARGET: 1 # removes when all 1 point of target contacts is not viewable 64 | 65 | SUBGOAL_TEXT: named_action 66 | -------------------------------------------------------------------------------- /config/vit_small_ablation.yml: -------------------------------------------------------------------------------- 1 | TRAIN_FOLDER: [data/multiprompt/train] 2 | TEST_FOLDER: [data/multiprompt/test] 3 | 4 | MODEL_DIR: checkpoints 5 | 6 | ROBOT_STATES: [x, y, z, roll, pitch, yaw] 7 | 8 | LEARNING_RATE: 0.00005 9 | BATCH_SIZE: 8 10 | NUM_WORKERS: 12 11 | NUM_EPOCHS: 20 12 | 13 | TRANSFORM: [jitter] 14 | 15 | LOSS_RATIO: 5.6 16 | 17 | ACTION_DELTA_DICT: {x: 0.03, y: 0.03, z: 0.03, roll: 0.1, pitch: 0.1, yaw: 0.05, gripper: 5, theta: 0.1} 18 | 19 | IMAGE_MODEL: vit-small # vit-tiny, vit-small, vit-base, vit-large, clip-base, clip-large, dinov2-small, dinov2-base, dinov2-large 20 | FREEZE_IMAGE_MODEL: False 21 | IMAGE_SIZE: 224 22 | USE_PATCH_FEATURES: False 23 | 24 | TEXT_MODEL: t5-large # t5-small, t5-base, t5-large, bert, clip-base, clip-large 25 | FREEZE_TEXT_MODEL: True 26 | 27 | MULTIMODAL_HEAD: classifier-free-guidance # vision-only-linear, vision-only-mlp, concat-linear-attn-mlp, concat-linear, concat-mlp 28 | CFG_COND_METHOD: xattn 29 | 30 | # Loss functions 31 | FINGERTIP_LOSS: L1 32 | FORCE_LOSS: L2 33 | PITCH_LOSS: L2 34 | GRIP_LOSS: L2 35 | WIDTH_LOSS: L2 36 | 
YAW_LOSS: L2 37 | 38 | # Loss weights 39 | LAMBDA_FINGERTIPS: 0.0 40 | LAMBDA_FORCE: 0.2 # 0.025 41 | LAMBDA_PITCH: 0.0 # 1.0 42 | LAMBDA_GRIP: 0.2 43 | LAMBDA_TIMESTEP: 0.1 44 | LAMBDA_PIXEL: 1 45 | LAMBDA_DEPTH: 50000 46 | LAMBDA_WIDTH: 0.2 47 | LAMBDA_YAW: 0.2 48 | 49 | USE_RGBD: True 50 | PRETRAINED: True 51 | 52 | CLASSIFY_TIMESTEP: True # predict current timestep 53 | CLASSIFICATION_LOSS: focal 54 | ALPHA: 1 # for focal loss 55 | NUM_TIMESTEPS: 4 56 | 57 | # new representation 58 | CLS_POINT_RADIUS: 10 59 | PIXEL_SPACE_OUTPUT: True 60 | PIXEL_SPACE_CENTROID: True 61 | PIXEL_LABEL_WEIGHT: 100 # how much we care about 1s over 0s in pixel space 62 | 63 | REMOVE_NON_VIEWABLE_TARGET: 1 # removes when all 1 point of target contacts is not viewable 64 | 65 | SUBGOAL_TEXT: named_action 66 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: forcesight 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - brotlipy=0.7.0=py310h7f8727e_1002 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2023.5.7=hbcca054_0 14 | - certifi=2023.5.7=pyhd8ed1ab_0 15 | - cffi=1.15.1=py310h5eee18b_3 16 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 17 | - cryptography=39.0.1=py310h9ce1e76_0 18 | - cuda-cudart=11.8.89=0 19 | - cuda-cupti=11.8.87=0 20 | - cuda-libraries=11.8.0=0 21 | - cuda-nvrtc=11.8.89=0 22 | - cuda-nvtx=11.8.86=0 23 | - cuda-runtime=11.8.0=0 24 | - eigen=3.4.0=h4bd325d_0 25 | - ffmpeg=4.3=hf484d3e_0 26 | - freetype=2.12.1=h4a9f257_0 27 | - giflib=5.2.1=h5eee18b_3 28 | - gmp=6.2.1=h295c915_3 29 | - gmpy2=2.1.2=py310heeb90bb_0 30 | - gnutls=3.6.15=he1e5248_0 31 | - idna=3.4=py310h06a4308_0 32 | - intel-openmp=2023.1.0=hdb19cb5_46305 33 | - jinja2=3.1.2=py310h06a4308_0 34 | - jpeg=9e=h5eee18b_1 35 | - lame=3.100=h7b6447c_0 36 | - lcms2=2.12=h3be6417_0 37 | - ld_impl_linux-64=2.38=h1181459_1 38 | - lerc=3.0=h295c915_0 39 | - libcublas=11.11.3.6=0 40 | - libcufft=10.9.0.58=0 41 | - libcufile=1.6.1.9=0 42 | - libcurand=10.3.2.106=0 43 | - libcusolver=11.4.1.48=0 44 | - libcusparse=11.7.5.86=0 45 | - libdeflate=1.17=h5eee18b_0 46 | - libffi=3.4.4=h6a678d5_0 47 | - libgcc-ng=11.2.0=h1234567_1 48 | - libgfortran-ng=11.2.0=h00389a5_1 49 | - libgfortran5=11.2.0=h1234567_1 50 | - libgomp=11.2.0=h1234567_1 51 | - libiconv=1.16=h7f8727e_2 52 | - libidn2=2.3.2=h7f8727e_0 53 | - libnpp=11.8.0.86=0 54 | - libnvjpeg=11.9.0.86=0 55 | - libopenblas=0.3.21=h043d6bf_0 56 | - libpng=1.6.39=h5eee18b_0 57 | - libstdcxx-ng=11.2.0=h1234567_1 58 | - libtasn1=4.19.0=h5eee18b_0 59 | - libtiff=4.5.0=h6a678d5_2 60 | - libunistring=0.9.10=h27cfd23_0 61 | - libuuid=1.41.5=h5eee18b_0 62 | - libwebp=1.2.4=h11a3e52_1 63 | - libwebp-base=1.2.4=h5eee18b_1 64 | - lz4-c=1.9.4=h6a678d5_0 65 | - mkl=2023.1.0=h6d00ec8_46342 66 | - mkl-service=2.4.0=py310h5eee18b_1 67 | - mkl_fft=1.3.6=py310h1128e8f_1 68 | - mkl_random=1.2.2=py310h1128e8f_1 69 | - mpc=1.1.0=h10f8cd9_1 70 | - mpfr=4.0.2=hb69a4c5_1 71 | - ncurses=6.4=h6a678d5_0 72 | - nettle=3.7.3=hbbd107a_1 73 | - networkx=2.8.4=py310h06a4308_1 74 | - numpy=1.24.3=py310h5f9d8c6_1 75 | - numpy-base=1.24.3=py310hb5e798b_1 76 | - openh264=2.1.1=h4ff587b_0 77 | - openssl=1.1.1t=h7f8727e_0 78 | - orocos-kdl=1.5.1=h122e73d_2 79 | - pip=23.0.1=py310h06a4308_0 80 | - pycparser=2.21=pyhd3eb1b0_0 81 | - pyopenssl=23.0.0=py310h06a4308_0 82 | - 
pysocks=1.7.1=py310h06a4308_0 83 | - python=3.10.11=h7a1cb2a_2 84 | - python_abi=3.10=2_cp310 85 | - pytorch=2.0.1=py3.10_cuda11.8_cudnn8.7.0_0 86 | - pytorch-cuda=11.8=h7e8668a_5 87 | - pytorch-mutex=1.0=cuda 88 | - readline=8.2=h5eee18b_0 89 | - requests=2.29.0=py310h06a4308_0 90 | - setuptools=66.0.0=py310h06a4308_0 91 | - sqlite=3.41.2=h5eee18b_0 92 | - sympy=1.11.1=py310h06a4308_0 93 | - tbb=2021.8.0=hdb19cb5_0 94 | - tk=8.6.12=h1ccaba5_0 95 | - torchaudio=2.0.2=py310_cu118 96 | - torchtriton=2.0.0=py310 97 | - torchvision=0.15.2=py310_cu118 98 | - typing_extensions=4.5.0=py310h06a4308_0 99 | - urllib3=1.26.15=py310h06a4308_0 100 | - wheel=0.38.4=py310h06a4308_0 101 | - xz=5.4.2=h5eee18b_0 102 | - zlib=1.2.13=h5eee18b_0 103 | - zstd=1.5.5=hc292b87_0 104 | - pip: 105 | - addict==2.4.0 106 | - asttokens==2.2.1 107 | - attrs==23.1.0 108 | - backcall==0.2.0 109 | - blinker==1.6.2 110 | - click==8.1.3 111 | - comm==0.1.3 112 | - configargparse==1.5.3 113 | - contourpy==1.0.7 114 | - cycler==0.11.0 115 | - dash==2.9.3 116 | - dash-core-components==2.0.0 117 | - dash-html-components==2.0.0 118 | - dash-table==5.0.0 119 | - debugpy==1.6.7 120 | - decorator==5.1.1 121 | - executing==1.2.0 122 | - fastjsonschema==2.16.3 123 | - filelock==3.12.0 124 | - flask==2.3.1 125 | - fonttools==4.39.3 126 | - fsspec==2023.4.0 127 | - huggingface-hub==0.14.1 128 | - ipykernel==6.22.0 129 | - ipython==8.12.0 130 | - ipywidgets==8.0.6 131 | - itsdangerous==2.1.2 132 | - jedi==0.18.2 133 | - joblib==1.2.0 134 | - jsonschema==4.17.3 135 | - jupyter-client==8.2.0 136 | - jupyter-core==5.3.0 137 | - jupyterlab-widgets==3.0.7 138 | - kdl-py==1.1.0 139 | - kiwisolver==1.4.4 140 | - markupsafe==2.1.2 141 | - matplotlib==3.7.1 142 | - matplotlib-inline==0.1.6 143 | - mpmath==1.2.1 144 | - nbformat==5.7.0 145 | - nest-asyncio==1.5.6 146 | - open3d==0.17.0 147 | - opencv-contrib-python==4.6.0.66 148 | - opencv-python==4.6.0.66 149 | - opencv-python-headless==4.6.0.66 150 | - packaging==23.1 151 | - pandas==2.0.1 152 | - parso==0.8.3 153 | - pexpect==4.8.0 154 | - pickleshare==0.7.5 155 | - pillow==9.5.0 156 | - platformdirs==3.4.0 157 | - plotly==5.14.1 158 | - prompt-toolkit==3.0.38 159 | - psutil==5.9.5 160 | - ptyprocess==0.7.0 161 | - pure-eval==0.2.2 162 | - pygments==2.15.1 163 | - pyparsing==3.0.9 164 | - pyquaternion==0.9.9 165 | - pyrealsense2==2.53.1.4623 166 | - pyrsistent==0.19.3 167 | - python-dateutil==2.8.2 168 | - pytz==2023.3 169 | - pyyaml==6.0 170 | - pyzmq==25.0.2 171 | - regex==2023.5.5 172 | - scikit-learn==1.2.2 173 | - scipy==1.10.1 174 | - segment-anything==1.0 175 | - sentencepiece==0.1.91 176 | - six==1.16.0 177 | - stack-data==0.6.2 178 | - tenacity==8.2.2 179 | - threadpoolctl==3.1.0 180 | - timm==0.6.13 181 | - tokenizers==0.13.3 182 | - tornado==6.3.1 183 | - tqdm==4.65.0 184 | - traitlets==5.9.0 185 | - transformers==4.28.1 186 | - typing==3.7.4.3 187 | - tzdata==2023.3 188 | - wcwidth==0.2.6 189 | - werkzeug==2.3.0 190 | - widgetsnbextension==4.0.7 191 | - zmq==0.0.0 192 | - transformations 193 | - kinpy 194 | - timm 195 | - wandb 196 | - classifier_free_guidance_pytorch 197 | - opencv-contrib-python==4.6.0.66 198 | - git+https://github.com/openai/CLIP.git 199 | -------------------------------------------------------------------------------- /eval_models.sh: -------------------------------------------------------------------------------- 1 | python -m prediction.trainer --bipartite 1 --config final_tests_vit_small_ablation_6_7 --pretrained_model_path 
checkpoints/final_tests_vit_small_ablation_6_7_0/model_latest.pth --evaluate --eval_folder data/multiprompt/test_new_objects_and_env 2 | python -m prediction.trainer --bipartite 1 --config final_tests_vit_base_ablation_6_7 --pretrained_model_path checkpoints/final_tests_vit_base_ablation_6_7_0/model_latest.pth --evaluate --eval_folder data/multiprompt/test_new_objects_and_env -------------------------------------------------------------------------------- /grip_force_checkpoints/grip_force_dist_pos_effort_5_25_4/model_best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/grip_force_checkpoints/grip_force_dist_pos_effort_5_25_4/model_best.pth -------------------------------------------------------------------------------- /netft: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/netft -------------------------------------------------------------------------------- /package-list.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | asttokens=2.2.1=pypi_0 7 | attrs=23.1.0=pypi_0 8 | beartype=0.14.0=pypi_0 9 | blas=1.0=mkl 10 | boost-cpp=1.72.0=he72f1d9_7 11 | brotlipy=0.7.0=py38h27cfd23_1003 12 | bzip2=1.0.8=h7b6447c_0 13 | c-ares=1.19.0=h5eee18b_0 14 | ca-certificates=2023.5.7=hbcca054_0 15 | catkin_pkg=0.5.2=pyhd8ed1ab_0 16 | certifi=2023.5.7=pyhd8ed1ab_0 17 | cffi=1.15.1=py38h5eee18b_3 18 | charset-normalizer=2.0.4=pyhd3eb1b0_0 19 | classifier-free-guidance-pytorch=0.2.2=pypi_0 20 | cmake=3.22.1=h1fce559_0 21 | comm=0.1.3=pypi_0 22 | console_bridge=1.0.2=h924138e_1 23 | contourpy=1.0.7=pypi_0 24 | cryptography=39.0.1=py38h9ce1e76_0 25 | cuda-cudart=11.8.89=0 26 | cuda-cupti=11.8.87=0 27 | cuda-libraries=11.8.0=0 28 | cuda-nvrtc=11.8.89=0 29 | cuda-nvtx=11.8.86=0 30 | cuda-runtime=11.8.0=0 31 | cycler=0.11.0=pypi_0 32 | debugpy=1.6.7=pypi_0 33 | decorator=5.1.1=pypi_0 34 | distro=1.8.0=pyhd8ed1ab_0 35 | docutils=0.20.1=py38h578d9bd_0 36 | einops=0.6.1=pypi_0 37 | empy=3.3.4=pyh9f0ad1d_1 38 | entrypoints=0.4=pypi_0 39 | executing=1.2.0=pypi_0 40 | expat=2.2.10=h9c3ff4c_0 41 | ffmpeg=4.3=hf484d3e_0 42 | filelock=3.12.0=pypi_0 43 | flask=2.3.1=pypi_0 44 | fonttools=4.39.3=pypi_0 45 | freetype=2.12.1=h4a9f257_0 46 | fsspec=2023.4.0=pypi_0 47 | giflib=5.2.1=h5eee18b_3 48 | gmock=1.11.0=hdb19cb5_0 49 | gmp=6.2.1=h295c915_3 50 | gmpy2=2.1.2=py38heeb90bb_0 51 | gnutls=3.6.15=he1e5248_0 52 | gtest=1.11.0=hdb19cb5_0 53 | huggingface-hub=0.14.1=pypi_0 54 | icu=70.1=h27087fc_0 55 | idna=3.4=py38h06a4308_0 56 | intel-openmp=2023.1.0=hdb19cb5_46305 57 | ipykernel=6.22.0=pypi_0 58 | ipython=8.12.0=pypi_0 59 | jedi=0.18.2=pypi_0 60 | jinja2=3.1.2=py38h06a4308_0 61 | jpeg=9e=h5eee18b_1 62 | jsonschema=4.17.3=pypi_0 63 | jupyter-client=8.2.0=pypi_0 64 | jupyter-core=5.3.0=pypi_0 65 | kdl-py=1.1.0=pypi_0 66 | keyutils=1.6.1=h166bdaf_0 67 | kiwisolver=1.4.4=pypi_0 68 | krb5=1.19.3=h3790be6_0 69 | lame=3.100=h7b6447c_0 70 | lcms2=2.12=h3be6417_0 71 | ld_impl_linux-64=2.38=h1181459_1 72 | lerc=3.0=h295c915_0 73 | libcublas=11.11.3.6=0 74 | libcufft=10.9.0.58=0 75 | libcufile=1.6.1.9=0 76 | libcurand=10.3.2.106=0 77 | libcurl=7.87.0=h91b91d3_0 78 
| libcusolver=11.4.1.48=0 79 | libcusparse=11.7.5.86=0 80 | libdeflate=1.17=h5eee18b_0 81 | libedit=3.1.20191231=he28a2e2_2 82 | libev=4.33=h516909a_1 83 | libffi=3.4.4=h6a678d5_0 84 | libgcc-ng=11.2.0=h1234567_1 85 | libgomp=11.2.0=h1234567_1 86 | libiconv=1.16=h7f8727e_2 87 | libidn2=2.3.4=h5eee18b_0 88 | libnghttp2=1.46.0=hce63b2e_0 89 | libnpp=11.8.0.86=0 90 | libnvjpeg=11.9.0.86=0 91 | libpng=1.6.39=h5eee18b_0 92 | libssh2=1.10.0=h8f2d780_0 93 | libstdcxx-ng=13.1.0=hfd8a6a1_0 94 | libtasn1=4.19.0=h5eee18b_0 95 | libtiff=4.5.0=h6a678d5_2 96 | libunistring=0.9.10=h27cfd23_0 97 | libuv=1.44.2=h5eee18b_0 98 | libwebp=1.2.4=h11a3e52_1 99 | libwebp-base=1.2.4=h5eee18b_1 100 | lz4-c=1.9.4=h6a678d5_0 101 | markupsafe=2.1.2=pypi_0 102 | matplotlib=3.7.1=pypi_0 103 | matplotlib-inline=0.1.6=pypi_0 104 | mkl=2023.1.0=h6d00ec8_46342 105 | mkl-service=2.4.0=py38h5eee18b_1 106 | mkl_fft=1.3.6=py38h417a72b_1 107 | mkl_random=1.2.2=py38h417a72b_1 108 | mpc=1.1.0=h10f8cd9_1 109 | mpfr=4.0.2=hb69a4c5_1 110 | mpmath=1.2.1=py38h06a4308_0 111 | ncurses=6.4=h6a678d5_0 112 | nest-asyncio=1.5.6=pypi_0 113 | nettle=3.7.3=hbbd107a_1 114 | networkx=2.8.4=py38h06a4308_1 115 | nose=1.3.7=py_1006 116 | numpy=1.24.3=py38hf6e8229_1 117 | numpy-base=1.24.3=py38h060ed82_1 118 | open-clip-torch=2.20.0=pypi_0 119 | opencv-contrib-python=4.6.0.66=pypi_0 120 | opencv-python=4.6.0.66=pypi_0 121 | opencv-python-headless=4.6.0.66=pypi_0 122 | openh264=2.1.1=h4ff587b_0 123 | openssl=1.1.1t=h7f8727e_0 124 | packaging=23.1=pypi_0 125 | pandas=2.0.1=pypi_0 126 | parso=0.8.3=pypi_0 127 | pexpect=4.8.0=pypi_0 128 | pillow=9.5.0=pypi_0 129 | pip=23.1.2=pypi_0 130 | pkg-config=0.29.2=h36c2ea0_1008 131 | pkgutil-resolve-name=1.3.10=pypi_0 132 | platformdirs=3.4.0=pypi_0 133 | prompt-toolkit=3.0.38=pypi_0 134 | psutil=5.9.5=pypi_0 135 | pure-eval=0.2.2=pypi_0 136 | pycparser=2.21=pyhd3eb1b0_0 137 | pygments=2.15.1=pypi_0 138 | pyopenssl=23.0.0=py38h06a4308_0 139 | pyparsing=3.0.9=pyhd8ed1ab_0 140 | pyrsistent=0.19.3=pypi_0 141 | pysocks=1.7.1=py38h06a4308_0 142 | python=3.8.16=h7a1cb2a_3 143 | python-dateutil=2.8.2=pyhd8ed1ab_0 144 | python_abi=3.8=2_cp38 145 | pytorch=2.0.1=py3.8_cuda11.8_cudnn8.7.0_0 146 | pytorch-cuda=11.8=h7e8668a_5 147 | pytorch-mutex=1.0=cuda 148 | pytz=2023.3=pypi_0 149 | pyyaml=6.0=py38h0a891b7_4 150 | pyzmq=25.0.2=pypi_0 151 | readline=8.2=h5eee18b_0 152 | regex=2023.5.5=pypi_0 153 | requests=2.29.0=py38h06a4308_0 154 | rhash=1.4.1=h3c74f83_1 155 | ros-catkin=0.7.17=py38h950e882_5 156 | ros-conda-base=0.0.2=hcb32578_2 157 | ros-conda-mutex=1.0=melodic 158 | ros-cpp-common=0.6.12=py38h794f011_5 159 | ros-environment=1.2.1=py38h950e882_2 160 | ros-gencpp=0.6.2=py38h950e882_1 161 | ros-geneus=2.2.6=py38h950e882_1 162 | ros-genlisp=0.4.16=py38h950e882_1 163 | ros-genmsg=0.5.12=py38h950e882_1 164 | ros-gennodejs=2.0.1=py38h950e882_1 165 | ros-genpy=0.6.8=py38h950e882_1 166 | ros-message-generation=0.4.0=h950e882_1 167 | ros-message-runtime=0.4.12=he1b5a44_0 168 | ros-mk=1.14.6=he1b5a44_0 169 | ros-ros=1.14.6=he1b5a44_0 170 | ros-rosbash=1.14.6=he1b5a44_0 171 | ros-rosboost-cfg=1.14.6=py38h950e882_1 172 | ros-rosbuild=1.14.6=he1b5a44_0 173 | ros-rosclean=1.14.6=py38h950e882_1 174 | ros-roscpp-serialization=0.6.12=he1b5a44_0 175 | ros-roscpp-traits=0.6.12=he1b5a44_0 176 | ros-roscreate=1.14.6=py38h950e882_1 177 | ros-roslang=1.14.6=he1b5a44_0 178 | ros-roslib=1.14.6=py38h794f011_4 179 | ros-rosmake=1.14.6=py38h950e882_1 180 | ros-rospack=2.5.3=py38hd02d5f2_1 181 | ros-rostime=0.6.12=h794f011_3 182 | 
ros-rosunit=1.14.6=py38h950e882_1 183 | rosdep=0.22.2=pyhd8ed1ab_1 184 | rosdistro=0.9.0=py38h578d9bd_0 185 | rospkg=1.5.0=pyhd8ed1ab_0 186 | scikit-learn=1.2.2=pypi_0 187 | scipy=1.10.1=pypi_0 188 | segment-anything=1.0=pypi_0 189 | sentencepiece=0.1.91=pypi_0 190 | setuptools=67.8.0=py38h06a4308_0 191 | six=1.16.0=pyh6c4a22f_0 192 | sqlite=3.41.2=h5eee18b_0 193 | stack-data=0.6.2=pypi_0 194 | stretch-remote=0.1=dev_0 195 | sympy=1.11.1=py38h06a4308_0 196 | tbb=2021.8.0=hdb19cb5_0 197 | timm=0.6.13=pypi_0 198 | tinyxml2=9.0.0=h9c3ff4c_2 199 | tk=8.6.12=h1ccaba5_0 200 | tokenizers=0.13.3=pypi_0 201 | torchaudio=2.0.2=py38_cu118 202 | torchtriton=2.0.0=py38 203 | torchvision=0.15.2=py38_cu118 204 | tornado=6.3.1=pypi_0 205 | traitlets=5.9.0=pypi_0 206 | transformers=4.28.1=pypi_0 207 | typing=3.7.4.3=pypi_0 208 | typing_extensions=4.5.0=py38h06a4308_0 209 | tzdata=2023.3=pypi_0 210 | urllib3=1.26.15=py38h06a4308_0 211 | wcwidth=0.2.6=pypi_0 212 | werkzeug=2.3.0=pypi_0 213 | wheel=0.38.4=py38h06a4308_0 214 | xz=5.2.10=h5eee18b_1 215 | yaml=0.2.5=h7f98852_2 216 | zlib=1.2.13=h5eee18b_0 217 | zstd=1.5.5=hc292b87_0 218 | -------------------------------------------------------------------------------- /prediction/deep_fusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | import timm 6 | from transformers import T5Model, T5Tokenizer, BertModel, BertTokenizer 7 | import clip 8 | from utils.config_utils import * 9 | import numpy as np 10 | from prediction.fuse_helper import * 11 | 12 | class DeepFusion(nn.Module): 13 | def __init__(self, image_model, text_model, hidden_dim=256, num_outputs=10): 14 | super(DeepFusion, self).__init__() 15 | 16 | self.image_model = image_model 17 | 18 | # t5 19 | # self.text_model = text_model.encoder 20 | 21 | # bert 22 | self.text_model = text_model 23 | 24 | # Dimensions for image and text features 25 | image_feature_dim = self.image_model.embed_dim 26 | text_feature_dim = self.text_model.embed_dim 27 | 28 | # Multi-modal fusion layer 29 | # text -> image 30 | self.t2i_attn = AttentionT2I(q_dim=image_feature_dim, # self.joint_embedding_size, 31 | k_dim=text_feature_dim, 32 | embed_dim=2048, # self.embed_dim, 33 | num_heads=8, # self.n_head, 34 | hidden_dim=1024, # self.t2i_hidden_dim, 35 | dropout=0.1, 36 | drop_path=.0, 37 | init_values=1.0 / 6, # cfg.MODEL.DYHEAD.NUM_CONVS, 38 | mode="t2i", 39 | use_layer_scale=True, 40 | clamp_min_for_underflow=True, 41 | clamp_max_for_overflow=True 42 | ) 43 | 44 | # MLP 45 | self.mlp = nn.Sequential( 46 | nn.Linear(image_feature_dim, hidden_dim), 47 | nn.ReLU(), 48 | nn.Linear(hidden_dim, num_outputs) 49 | ) 50 | 51 | def forward(self, image, text): 52 | # Image feature extraction 53 | visual_features = self.image_model(image) 54 | 55 | print('image_features', visual_features.shape) 56 | 57 | # Text feature extraction 58 | language_feature = self.text_model(text).last_hidden_state[:, 0, :] # [CLS] token 59 | 60 | print('text_features', language_feature.shape) 61 | 62 | # # Multi-modal fusion 63 | # q0, q1, q2, q3, q4 = self.t2i_attn( 64 | # visual_features[0], visual_features[1], 65 | # visual_features[2], visual_features[3], 66 | # visual_features[4], 67 | # language_feature, language_feature, 68 | # attention_mask=text["attention_mask"] 69 | # ) 70 | 71 | # fused_visual_features = [q0, q1, q2, q3, q4] 72 | 73 | # Multi-modal fusion 74 | q = self.t2i_attn( 75 | [visual_features], 76 | language_feature, 
language_feature, 77 | # attention_mask=text["attention_mask"] 78 | attention_mask=None 79 | ) 80 | 81 | print('q[0]', q[0].shape) 82 | 83 | fused_visual_features = q[0] + visual_features 84 | 85 | print('fused_visual_features', fused_visual_features.shape) 86 | 87 | # MLP 88 | x = self.mlp(fused_visual_features) 89 | 90 | output = { 91 | 'left_fingertip': x[:, 0:3], 92 | 'right_fingertip': x[:, 3:6], 93 | 'force': x[:, 6:9], 94 | 'pitch': x[:, 9] 95 | } 96 | 97 | return output -------------------------------------------------------------------------------- /prediction/grip_force_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import torch.nn as nn 5 | from utils.config_utils import * 6 | import numpy as np 7 | 8 | class GripForceMLP(nn.Module): 9 | def __init__(self, num_inputs, hidden_dim=128, num_outputs=1): 10 | super(GripForceMLP, self).__init__() 11 | self.mlp_force = nn.Sequential( 12 | nn.Linear(num_inputs, hidden_dim), 13 | nn.ReLU(), 14 | nn.Linear(hidden_dim, num_outputs) 15 | ) 16 | 17 | def forward(self, x): 18 | x = x.view(x.size(0), -1) 19 | force = self.mlp_force(x) 20 | return force -------------------------------------------------------------------------------- /prediction/grip_force_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from prediction.loader import ActAffData 6 | from tqdm import tqdm 7 | from utils.config_utils import * 8 | from utils.pred_utils import * 9 | from utils.data_pipeline import * 10 | import wandb 11 | from prediction.grip_force_model import GripForceMLP 12 | 13 | def train_epoch(model, optimizer, train_loader, criterion): 14 | model.train() 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | loss_sum = 0 17 | finger_dist_sum_old = 0 18 | finger_dist_sum = 0 19 | force_rmse_sum = 0 20 | correct_class = 0 21 | 22 | for prompt, initial_data, final_data, rgb_paths in tqdm(train_loader): 23 | # initial_data['rgb'] = initial_data['rgb'].to(device).float() 24 | # initial_data['depth'] = initial_data['depth'].to(device).float() 25 | initial_data['state'] = initial_data['state'].to(device) 26 | initial_data['left_fingertip'] = initial_data['left_fingertip'].to(device) 27 | initial_data['right_fingertip'] = initial_data['right_fingertip'].to(device) 28 | initial_data['ft'] = initial_data['ft'].to(device) 29 | 30 | # final_data['rgb'] = final_data['rgb'].to(device) 31 | # final_data['depth'] = final_data['depth'].to(device) 32 | # final_data['state'] = final_data['state'].to(device) 33 | # final_data['left_fingertip'] = final_data['left_fingertip'].to(device).float() 34 | # final_data['right_fingertip'] = final_data['right_fingertip'].to(device).float() 35 | # final_data['ft'] = final_data['ft'].to(device).float() 36 | 37 | # self.config.ROBOT_STATES= [x, y, z, roll, pitch, yaw, gripper] 38 | # initial state: torch.Size([batch_size, 7]) 39 | # print('initial state: ', initial_data['state'].shape) 40 | gripper_pos = initial_data['state'][:, config.ROBOT_STATES.index('gripper')].to(device).float() 41 | # print('gripper_pos: ', gripper_pos) 42 | gripper_pos = normalize_gripper_pos(gripper_pos) 43 | # print('normalized gripper_pos: ', gripper_pos) 44 | 45 | 46 | gripper_effort = initial_data['state'][:, config.ROBOT_STATES.index('gripper_effort')].to(device).float() 47 | # print('gripper_effort: ', 
gripper_effort) 48 | gripper_effort = normalize_gripper_effort(gripper_effort) 49 | # print('normalized gripper_effort: ', gripper_effort) 50 | 51 | fingertip_dist = torch.norm(initial_data['left_fingertip'] - initial_data['right_fingertip'], dim=1).to(device).float() 52 | # print ('fingertip_dist: ', fingertip_dist) 53 | 54 | force_norm = torch.norm(initial_data['ft'], dim=1).unsqueeze(1).to(device).float() 55 | 56 | model_input = torch.cat(( 57 | gripper_pos.unsqueeze(1), 58 | gripper_effort.unsqueeze(1), 59 | fingertip_dist.unsqueeze(1)), dim=1) 60 | # model_input = torch.cat((gripper_pos.unsqueeze(1), fingertip_dist.unsqueeze(1)), dim=1) 61 | # model_input = gripper_pos.unsqueeze(1) 62 | 63 | optimizer.zero_grad() 64 | 65 | output = model(model_input) 66 | loss = criterion(output, force_norm) 67 | loss.backward() 68 | optimizer.step() 69 | 70 | loss_sum += loss.item() 71 | 72 | return loss_sum / len(train_loader) 73 | 74 | def val_epoch(model, val_loader, criterion): 75 | model.eval() 76 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 77 | loss_sum = 0 78 | finger_dist_sum_old = 0 79 | finger_dist_sum = 0 80 | force_rmse_sum = 0 81 | correct_class = 0 82 | 83 | with torch.no_grad(): 84 | for prompt, initial_data, final_data, rgb_paths in tqdm(val_loader): 85 | # initial_data['rgb'] = initial_data['rgb'].to(device).float() 86 | # initial_data['depth'] = initial_data['depth'].to(device).float() 87 | initial_data['state'] = initial_data['state'].to(device) 88 | initial_data['left_fingertip'] = initial_data['left_fingertip'].to(device) 89 | initial_data['right_fingertip'] = initial_data['right_fingertip'].to(device) 90 | initial_data['ft'] = initial_data['ft'].to(device) 91 | 92 | # final_data['rgb'] = final_data['rgb'].to(device) 93 | # final_data['depth'] = final_data['depth'].to(device) 94 | # final_data['state'] = final_data['state'].to(device) 95 | # final_data['left_fingertip'] = final_data['left_fingertip'].to(device).float() 96 | # final_data['right_fingertip'] = final_data['right_fingertip'].to(device).float() 97 | # final_data['ft'] = final_data['ft'].to(device).float() 98 | 99 | # self.config.ROBOT_STATES= [x, y, z, roll, pitch, yaw, gripper] 100 | # initial state: torch.Size([batch_size, 7]) 101 | # print('initial state: ', initial_data['state'].shape) 102 | gripper_pos = initial_data['state'][:, config.ROBOT_STATES.index('gripper')].to(device).float() 103 | # print('gripper_pos: ', gripper_pos) 104 | gripper_pos = normalize_gripper_pos(gripper_pos) 105 | # print('normalized gripper_pos: ', gripper_pos) 106 | 107 | gripper_effort = initial_data['state'][:, config.ROBOT_STATES.index('gripper_effort')].to(device).float() 108 | # print('gripper_effort: ', gripper_effort) 109 | gripper_effort = normalize_gripper_effort(gripper_effort) 110 | # print('normalized gripper_effort: ', gripper_effort) 111 | 112 | fingertip_dist = torch.norm(initial_data['left_fingertip'] - initial_data['right_fingertip'], dim=1).to(device).float() 113 | # print ('fingertip_dist: ', fingertip_dist) 114 | 115 | force_norm = torch.norm(initial_data['ft'][:3], dim=1).unsqueeze(1).to(device).float() 116 | 117 | model_input = torch.cat((gripper_pos.unsqueeze(1), gripper_effort.unsqueeze(1), fingertip_dist.unsqueeze(1)), dim=1) 118 | # model_input = torch.cat((gripper_pos.unsqueeze(1), fingertip_dist.unsqueeze(1)), dim=1) 119 | # model_input = gripper_pos.unsqueeze(1) 120 | 121 | # print('model_input: ', model_input.shape) 122 | 123 | output = model(model_input) 124 | loss = 
criterion(output, force_norm) 125 | 126 | loss_sum += loss.item() 127 | 128 | return loss_sum / len(val_loader) 129 | 130 | def criterion(output, force_norm): 131 | # mse 132 | # print('output: ', output) 133 | # print('output shape: ', output.shape) 134 | # print('force_norm: ', force_norm) 135 | # print('force_norm shape: ', force_norm.shape) 136 | return torch.nn.functional.mse_loss(output, force_norm, reduction='mean') 137 | 138 | if __name__ == '__main__': 139 | config, args = parse_config_args() 140 | wandb.init(project='action-affordances') 141 | 142 | wandb.config.update(config) 143 | wandb.config.update(args) 144 | 145 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 146 | 147 | # model, tokenizer = load_model(config) 148 | model = GripForceMLP(num_inputs=3, num_outputs=1) 149 | model.to(device) 150 | 151 | print('model: ', model) 152 | print('number of parameters: ', sum(p.numel() for p in model.parameters())) 153 | print('number of trainable parameters: ', sum(p.numel() for p in model.parameters() if p.requires_grad)) 154 | 155 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE) 156 | 157 | train_dataset = ActAffData(config.TRAIN_FOLDER, stage='train') 158 | val_dataset = ActAffData(config.TEST_FOLDER, stage='test') 159 | train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS) 160 | val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS) 161 | 162 | # number of files in ./checkpoints that contain args.config 163 | folder_index = len([f for f in os.listdir(config.MODEL_DIR) if f.startswith(args.config)]) 164 | 165 | # creating the checkpoint folder structure 166 | if not os.path.exists(config.MODEL_DIR): 167 | os.makedirs(config.MODEL_DIR) 168 | 169 | if not os.path.exists(os.path.join(config.MODEL_DIR, '{}_{}'.format(args.config, folder_index))): 170 | os.makedirs(os.path.join(config.MODEL_DIR, '{}_{}'.format(args.config, folder_index))) 171 | 172 | wandb.run.name = '{}_{}'.format(args.config, folder_index) 173 | 174 | best_train_loss = float('inf') 175 | best_val_loss = float('inf') 176 | 177 | for epoch in range(config.NUM_EPOCHS): 178 | train_loss = train_epoch(model, optimizer, train_loader, criterion) 179 | val_loss = val_epoch(model, val_loader, criterion) 180 | 181 | wandb.log({'train_loss': train_loss, 'val_loss': val_loss}, step=epoch) 182 | 183 | print(f'Epoch {epoch} - Train Loss: {train_loss} - Val Loss: {val_loss}') 184 | 185 | # model_name = '{}_{}/model_{}'.format(args.config, folder_index, epoch) 186 | model_name = '{}_{}/model'.format(args.config, folder_index, epoch) 187 | model_path = os.path.join(config.MODEL_DIR, model_name) 188 | 189 | if val_loss < best_val_loss: 190 | 191 | torch.save(model.state_dict(), model_path + '_best.pth') 192 | print('Model saved to {}'.format(model_path + '_best.pth')) 193 | best_val_loss = val_loss 194 | 195 | torch.save(model.state_dict(), model_path + '_latest.pth') 196 | print('Model saved to {}'.format(model_path + '_latest.pth')) 197 | -------------------------------------------------------------------------------- /prediction/mesh_from_rgbd.py: -------------------------------------------------------------------------------- 1 | #!pip install torch opencv-python Pillowimport open3d as o3d 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | 6 | def fit_mesh(color_image, depth_image): 7 | # Create color and depth Open3D images 8 | color = 
o3d.geometry.Image(color_image) 9 | depth = o3d.geometry.Image(depth_image) 10 | 11 | # Create an RGBD image 12 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 13 | color, depth, depth_scale=1000.0, depth_trunc=3.0, convert_rgb_to_intensity=False 14 | ) 15 | 16 | # Define camera intrinsic parameters 17 | fx, fy, cx, cy = 525.0, 525.0, 319.5, 239.5 18 | intrinsic = o3d.camera.PinholeCameraIntrinsic(640, 480, fx, fy, cx, cy) 19 | 20 | # Generate point cloud from RGBD image 21 | pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd, intrinsic) 22 | 23 | # Load the STL mesh 24 | mesh = o3d.io.read_triangle_mesh("assets/gripper_fingertip.STL") 25 | mesh.translate(-mesh.get_center()) # Center the mesh 26 | 27 | # Calculate the average depth value from the depth image 28 | depth_array = np.asarray(depth_image) 29 | avg_depth = np.mean(depth_array[depth_array > 0]) # Exclude zero values 30 | 31 | # Compute the dynamic scale factor 32 | # You can adjust the scaling_constant based on your object size and camera setup 33 | scaling_constant = 0.001 34 | scale_factor = scaling_constant * avg_depth 35 | 36 | # Scale the mesh based on the scale factor 37 | mesh.scale(scale_factor, mesh.get_center()) 38 | 39 | # Create a PointCloud object from the mesh vertices 40 | mesh_pcd = o3d.geometry.PointCloud() 41 | mesh_pcd.points = o3d.utility.Vector3dVector(np.asarray(mesh.vertices)) 42 | 43 | # Initial transformation can be an identity matrix if you don't have an initial estimate 44 | initial_transform = np.identity(4) 45 | 46 | # Apply ICP to refine the transformation 47 | threshold = 0.01 48 | trans_icp = o3d.pipelines.registration.registration_icp( 49 | mesh_pcd, pcd, threshold, initial_transform, 50 | o3d.pipelines.registration.TransformationEstimationPointToPoint() 51 | ).transformation 52 | 53 | # Apply the transformation to the mesh 54 | mesh.transform(trans_icp) 55 | 56 | # Turning the mesh orange 57 | mesh.paint_uniform_color([1, 0.706, 0]) 58 | 59 | # Visualize the aligned mesh and point cloud 60 | o3d.visualization.draw_geometries([pcd, mesh]) 61 | 62 | -------------------------------------------------------------------------------- /prediction/owlvit_pytorch.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image, ImageDraw, ImageFont 3 | import torch 4 | 5 | from transformers import OwlViTProcessor, OwlViTForObjectDetection 6 | 7 | # everything should be on gpu 8 | 9 | processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14") 10 | model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14") 11 | 12 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 13 | model.to(device) 14 | 15 | # url = "http://images.cocodataset.org/val2017/000000039769.jpg" 16 | # image = Image.open(requests.get(url, stream=True).raw) 17 | image = Image.open("table1.jpg") #Image.open(requests.get(url, stream=True).raw) 18 | # texts = [["a photo of a cat", "a photo of a dog"]] 19 | texts = [["marker", "pencil", "mouse", "keyboard", "earphones", "sunglasses", "xbox controller"]] 20 | inputs = processor(text=texts, images=image, return_tensors="pt").to(device) 21 | outputs = model(**inputs) 22 | 23 | # Target image sizes (height, width) to rescale box predictions [batch_size, 2] 24 | target_sizes = torch.Tensor([image.size[::-1]]) 25 | # Convert outputs (bounding boxes and class logits) to COCO API 26 | results = processor.post_process(outputs=outputs, target_sizes=target_sizes) 27 
| 28 | # Retrieve predictions for the first image for the corresponding text queries 29 | i = 0 30 | text = texts[i] 31 | boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"] 32 | 33 | score_threshold = 0.3 34 | #font = ImageFont.truetype("arial.ttf", 12) 35 | draw = ImageDraw.Draw(image) 36 | 37 | print("for loop") 38 | for box, score, label in zip(boxes, scores, labels): 39 | box = [round(i, 2) for i in box.tolist()] 40 | if score >= score_threshold: 41 | draw.rectangle(box, outline="red", width=4) 42 | label_text = f"{text[label]} ({round(score.item(), 3)})" 43 | draw.text([box[0], box[1]-15], label_text, fill='black') 44 | print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}") 45 | 46 | image.show() 47 | -------------------------------------------------------------------------------- /prediction/owlvit_seg.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('big_vision/') 3 | 4 | import os 5 | import jax 6 | from matplotlib import pyplot as plt 7 | import numpy as np 8 | from scenic.projects.owl_vit import models 9 | from scenic.projects.owl_vit.configs import clip_b32 10 | from scipy.special import expit as sigmoid 11 | import skimage 12 | from skimage import io as skimage_io 13 | from skimage import transform as skimage_transform 14 | from scenic.projects.owl_vit.configs import clip_l14_with_masks 15 | 16 | 17 | 18 | text_queries = ['human face'] 19 | 20 | config = clip_l14_with_masks.get_config(init_mode='canonical_checkpoint') 21 | 22 | module = models.TextZeroShotDetectionModule( 23 | body_configs=config.model.body, 24 | mask_head_configs=config.model.mask_head, 25 | normalize=config.model.normalize, 26 | box_bias=config.model.box_bias) 27 | 28 | variables = module.load_variables(config.init_from.checkpoint_path) 29 | 30 | # Load example image: 31 | filename = os.path.join(skimage.data_dir, 'IMG_5637.PNG') 32 | image_uint8 = skimage_io.imread(filename) 33 | image = image_uint8.astype(np.float32) / 255.0 34 | 35 | # Pad to square with gray pixels on bottom and right: 36 | h, w, _ = image.shape 37 | size = max(h, w) 38 | image_padded = np.pad( 39 | image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5) 40 | 41 | # Resize to model input size: 42 | input_image = skimage.transform.resize( 43 | image_padded, 44 | (config.dataset_configs.input_size, config.dataset_configs.input_size), 45 | anti_aliasing=True) 46 | 47 | text_queries = ['human face'] 48 | tokenized_queries = np.array([ 49 | module.tokenize(q, config.dataset_configs.max_query_length) 50 | for q in text_queries 51 | ]) 52 | 53 | # Pad tokenized queries to avoid recompilation if number of queries changes: 54 | tokenized_queries = np.pad( 55 | tokenized_queries, 56 | pad_width=((0, 100 - len(text_queries)), (0, 0)), 57 | constant_values=0) 58 | 59 | jitted = jax.jit(module.apply, static_argnames=('train',)) 60 | 61 | # Resize to model input size: 62 | input_image = skimage.transform.resize( 63 | image_padded, 64 | (config.dataset_configs.input_size, config.dataset_configs.input_size), 65 | anti_aliasing=True) 66 | 67 | # Note: The model expects a batch dimension. 
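# `input_image[None, ...]` and `tokenized_queries[None, ...]` below add that
# leading batch axis; it is stripped again afterwards with
# jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions).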
68 | predictions = jitted( 69 | variables, 70 | input_image[None, ...], 71 | tokenized_queries[None, ...], 72 | train=False) 73 | 74 | # Remove batch dimension and convert to numpy: 75 | predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions ) 76 | 77 | score_threshold = 0.3 78 | 79 | logits = predictions['pred_logits'][..., :len(text_queries)] # Remove padding. 80 | scores = sigmoid(np.max(logits, axis=-1)) 81 | labels = np.argmax(predictions['pred_logits'], axis=-1) 82 | boxes = predictions['pred_boxes'] 83 | 84 | masks = [None] * len(boxes) 85 | if 'pred_masks' in predictions: 86 | masks = sigmoid(predictions['pred_masks']) 87 | 88 | fig, ax = plt.subplots(1, 1, figsize=(8, 8)) 89 | ax.imshow(input_image, extent=(0, 1, 1, 0)) 90 | ax.set_axis_off() 91 | 92 | for score, box, label, mask in zip(scores, boxes, labels, masks): 93 | if score < score_threshold: 94 | continue 95 | cx, cy, w, h = box 96 | ax.plot([cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2, cx - w / 2], 97 | [cy - h / 2, cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2], 'r') 98 | 99 | if mask is not None: 100 | mask_img = plt.cm.viridis(mask) 101 | mask_img[..., -1] = (mask > 0.5) * 0.8 102 | extent = np.array((cx - w / 2, cx + w / 2, cy + h / 2, cy - h / 2)) 103 | ax.imshow(mask_img, extent=np.clip(extent, 0, 1)) 104 | 105 | ax.text( 106 | cx - w / 2, 107 | cy + h / 2 + 0.015, 108 | f'{text_queries[label]}: {score:1.2f}', 109 | ha='left', 110 | va='top', 111 | color='red', 112 | bbox={ 113 | 'facecolor': 'white', 114 | 'edgecolor': 'red', 115 | 'boxstyle': 'square,pad=.3' 116 | }) 117 | 118 | ax.set_xlim(0, 1) 119 | ax.set_ylim(1, 0) 120 | 121 | plt.show() -------------------------------------------------------------------------------- /prediction/realsense_owlvit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import pyrealsense2 as rs 4 | from prediction.owlvit_seg import OwlViTSeg 5 | 6 | # Configure depth and color streams 7 | pipeline = rs.pipeline() 8 | config = rs.config() 9 | 10 | config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) 11 | config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) 12 | 13 | # Start streaming 14 | pipeline.start(config) 15 | 16 | owlvit = OwlViTSeg() 17 | 18 | try: 19 | while True: 20 | # Wait for a coherent pair of frames: depth and color 21 | frames = pipeline.wait_for_frames() 22 | depth_frame = frames.get_depth_frame() 23 | color_frame = frames.get_color_frame() 24 | 25 | 26 | if not depth_frame or not color_frame: 27 | continue 28 | 29 | # Convert images to numpy arrays 30 | depth_image = np.asanyarray(depth_frame.get_data()) # depth in mm 31 | color_image = np.asanyarray(color_frame.get_data()) 32 | 33 | print('depth_min', depth_image.min()) 34 | print('depth_mean', depth_image.mean()) 35 | print('depth_max', depth_image.max()) 36 | 37 | # Apply colormap on depth image (image must be converted to 8-bit per pixel first) 38 | depth_colormap = cv2.applyColorMap(cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET) 39 | 40 | # rotating images 90 degrees ccw 41 | depth_colormap = np.rot90(depth_colormap) 42 | color_image = np.rot90(color_image) 43 | 44 | # Stack both images horizontally 45 | images = np.hstack((color_image, depth_colormap)) 46 | 47 | print('color_image dtype', color_image.dtype) 48 | owlvit.segment(color_image, ['human face']) 49 | 50 | # Show images 51 | cv2.namedWindow('RealSense', cv2.WINDOW_AUTOSIZE) 52 | cv2.imshow('RealSense', 
images) 53 | cv2.waitKey(1) 54 | 55 | # # creating point cloud 56 | # points = rs.pointcloud() 57 | # points.map_to(color_frame) 58 | # pointcloud = points.calculate(depth_frame) 59 | 60 | # # displaying point cloud 61 | # pc = np.asanyarray(pointcloud.get_vertices()) 62 | # pc = pc.view(np.float32).reshape(pc.shape + (-1,)) 63 | # print(pc.shape) 64 | 65 | rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth( 66 | color_image, 67 | depth_image, 68 | depth_scale=1000.0, # Set the depth scale according to your depth image format (e.g., 1000.0 for millimeters) 69 | depth_trunc=3.0, # Set the depth truncation distance (in meters) for points that are too far away 70 | convert_rgb_to_intensity=False # Set to True if you want to convert the RGB values to intensity values 71 | ) 72 | 73 | # Default pinhole camera model: 74 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic( 75 | o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault 76 | ) 77 | 78 | # Or, specify your own camera parameters: 79 | fx, fy, cx, cy = 525.0, 525.0, 319.5, 239.5 # Replace these with your actual camera parameters 80 | width, height = 640, 480 # Replace these with your actual image dimensions 81 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) 82 | 83 | point_cloud = o3d.geometry.PointCloud.create_from_rgbd_image( 84 | rgbd_image, 85 | camera_intrinsics 86 | ) 87 | 88 | o3d.visualization.draw_geometries([point_cloud]) 89 | o3d.io.write_point_cloud('output_point_cloud.pcd', point_cloud) 90 | 91 | 92 | finally: 93 | # Stop streaming 94 | pipeline.stop() 95 | -------------------------------------------------------------------------------- /prediction/view_preds.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from prediction.models import * 4 | from prediction.loader import ActAffData 5 | from tqdm import tqdm 6 | from utils.config_utils import * 7 | from utils.pred_utils import * 8 | from utils.data_pipeline import * 9 | from utils.visualizer import * 10 | from utils.realsense_utils import PCDViewer 11 | import timm 12 | from transformers import T5Model, T5Tokenizer, BertModel, BertTokenizer 13 | import numpy as np 14 | import cv2 15 | 16 | def view_preds(folder, stage='test'): 17 | config, args = parse_config_args() 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | dataset = ActAffData(folder=folder, stage=stage) 21 | # loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0) 22 | loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) 23 | 24 | model_path = os.path.join(os.getcwd(), 'checkpoints/{}_{}/model_{}.pth'.format(args.config, args.index, args.epoch)) 25 | model, tokenizer = load_model(config, model_path) 26 | model = model.to(device) 27 | model.eval() 28 | 29 | save_folder = os.path.join(os.getcwd(), 'results/{}_{}_{}/'.format(args.config, args.index, args.epoch)) 30 | if not os.path.exists(save_folder): 31 | os.makedirs(save_folder) 32 | 33 | # pcd_view = PCDViewer(blocking=True) 34 | pcd_view = None 35 | if args.ros_viz: 36 | from ros_scripts.ros_viz import RosVizInterface 37 | pcd_view = RosVizInterface() 38 | 39 | with torch.no_grad(): 40 | for prompt, initial_data, final_data, rgb_paths in tqdm(loader): 41 | prompt = prompt 42 | initial_data['rgb'] = initial_data['rgb'].to(device) 43 | initial_data['depth'] = initial_data['depth'].to(device) 44 | 45 | final_data['state'] = 
final_data['state'].to(device) 46 | final_data['left_fingertip'] = final_data['left_fingertip'].to(device) 47 | final_data['right_fingertip'] = final_data['right_fingertip'].to(device) 48 | final_data['ft'] = final_data['ft'].to(device) 49 | 50 | # Add padding to the text input 51 | prompt_input = preprocess_prompt(config, prompt, tokenizer) 52 | 53 | # stacking rgb and depth on the channel dimension 54 | rgbd_input = torch.cat((initial_data['rgb'], initial_data['depth']), dim=1) 55 | 56 | output = model(rgbd_input, texts=prompt_input) #, cond_scale=1.) # output is has keys ['left_fingertip'], ['right_fingertip'], ['force'], and ['pitch'] 57 | output = postprocess_output(config, output, stage='metrics') 58 | visualize_datapoint(prompt, initial_data, final_data, 59 | config=config, pred=output, save_folder=save_folder, 60 | viz_3d=pcd_view) 61 | 62 | if __name__ == '__main__': 63 | config, args = parse_config_args() 64 | view_preds(folder=args.folder, stage=args.stage) 65 | -------------------------------------------------------------------------------- /recording/__pycache__/ft.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/force-sight/forcesight/671234ca9e902c138072946e8cd0f67bf01f1eba/recording/__pycache__/ft.cpython-38.pyc -------------------------------------------------------------------------------- /recording/capture_grip_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import cv2 4 | from utils.realsense_utils import RealSense, CameraType, PCDViewer 5 | 6 | from utils.config_utils import * 7 | from utils.ft_utils import * 8 | from robot.robot_utils import * 9 | from stretch_remote.robot_utils import * 10 | from recording.ft import FTCapture 11 | from stretch_remote.remote_client import RemoteClient 12 | from utils.aruco_detect import ArucoPoseEstimator, find_contact_markers 13 | import sys 14 | import time 15 | import os 16 | import json 17 | import re 18 | 19 | ''' 20 | python -m recording.capture_data --config data_collection_5_16 --stage raw --folder test_5_16 --view True --prompt "pick up the mouse" --bipartite True 21 | python -m prediction.loader --config data_collection_5_16 --stage raw --folder data/raw/mouse_5_16_1 --view True --bipartite True 22 | ''' 23 | 24 | HOME_JOINTS = {'y':0.0, 'pitch': -0.2853, 'gripper': 50, 'roll': -0.0, 'yaw': 0.0} 25 | 26 | ############################################################################## 27 | 28 | class DataCapture: 29 | def __init__(self): 30 | self.config, self.args = parse_config_args() 31 | self.data_folder = 'data' 32 | 33 | if not self.check_filename(self.args.folder): # TODO: check this 34 | print("Please enter a valid folder name ending with frame_x_y_z, when x y and z are ints") 35 | sys.exit(0) 36 | 37 | self.realsense = RealSense(select_device=self.args.realsense_id, view=self.args.view) 38 | self.ft = FTCapture() 39 | self.fingertips = None 40 | 41 | self.rc = RemoteClient(ip=self.args.ip, home_dict = HOME_JOINTS) 42 | if self.rc.get_status() is None: 43 | raise Exception('Remote client not connected') 44 | 45 | self.enable_moving = True 46 | self.stop = False 47 | 48 | self.manage_folders() 49 | 50 | cam_mat, cam_dist = self.realsense.get_camera_intrinsics(CameraType.COLOR) 51 | self.aruco_pose_estimator = ArucoPoseEstimator(cam_mat, cam_dist) 52 | 53 | self.rc.move({'roll': 0.0, 'pitch': 0.0, 'yaw': 0.0, 'gripper':0}) # leveling robot 54 | time.sleep(1) 55 | 
self.save_ft_calibration() 56 | # self.rc.home() 57 | self.status = self.rc.get_status() 58 | 59 | # saving prompt to a text file 60 | with open(os.path.join(self.data_folder, self.args.stage, self.args.folder, 'prompt.txt'), 'w') as f: 61 | f.write(self.args.prompt) 62 | 63 | self.data_index = 0 64 | self.keyframe_index_list = [] 65 | self.keyframe_step = 0 66 | self.keyframe_step_list = [] 67 | self.keyframe = False 68 | self.delete_last_keyframe = False 69 | 70 | def check_filename(self, filename): 71 | pattern = r"_frame_\d+_\d+$" 72 | return re.search(pattern, filename) 73 | 74 | def manage_folders(self): 75 | # making directories for data if they doesn't exist 76 | for folder in ['data', 'data/raw', 'data/train', 'data/test']: 77 | if not os.path.exists(folder): 78 | os.makedirs(folder) 79 | 80 | # counting the number of folders in the stage folder beginning with args.folder 81 | # folders = os.listdir(os.path.join(self.data_folder, self.args.stage)) 82 | folders = [f for f in os.listdir(os.path.join(self.data_folder, self.args.stage)) if f.startswith(self.args.folder)] 83 | 84 | if len(folders) == 0: 85 | folder_count = 0 86 | else: 87 | # folder_count = len([f for f in folders if re.match(self.args.folder, f)]) 88 | # folder_count = len([f for f in folders if f.split('/')[-1].startswith(self.args.folder)]) 89 | 90 | # using max of folder names.split('_')[-1] to get the highest folder number 91 | folder_count = max([int(f.split('_')[-1]) for f in folders if f.split('/')[-1].startswith(self.args.folder)]) + 1 92 | print('FOLDER COUNT!!!: ', folder_count) 93 | 94 | self.args.folder = self.args.folder + '_' + str(folder_count) 95 | 96 | # setting folder names as class attributes (self._folder) 97 | for name in ['rgb', 'depth', 'prompt', 'state', 'fingertips', 'ft']: 98 | folder = os.path.join(self.data_folder, self.args.stage, self.args.folder, name) 99 | setattr(self, name + '_folder', folder) 100 | 101 | # making directories for data if they doesn't exist 102 | for folder in [self.rgb_folder, self.depth_folder, self.state_folder, self.fingertips_folder, self.ft_folder]: 103 | if not os.path.exists(folder): 104 | os.makedirs(folder) 105 | 106 | def save_ft_calibration(self): 107 | self.ft_offset = self.ft.get_ft() 108 | 109 | if np.abs(self.ft_offset.mean()) < 1e-3: 110 | print('FT NOT CONNECTED') 111 | sys.exit() 112 | 113 | print('CALIBRATING FT: ', self.ft_offset) 114 | time.sleep(0.5) 115 | 116 | np.save(os.path.join(self.data_folder, self.args.stage, self.args.folder, 'ft_calibration.npy'), self.ft_offset) 117 | print('saving to ', os.path.join(self.data_folder, self.args.stage, self.args.folder, 'ft_calibration.npy')) 118 | 119 | def capture_data(self, viz_3d=None): 120 | # get data snapshots 121 | rgb_image, depth_image = self.realsense.get_rgbd_image() 122 | 123 | if viz_3d is not None: 124 | pcd = self.realsense.get_point_cloud(rgb_image, depth_image) 125 | viz_3d.display(pcd) 126 | 127 | ft_data = self.ft.get_ft() 128 | 129 | print('ft: ', ft_data - self.ft_offset) 130 | 131 | self.status = self.rc.get_status() 132 | 133 | # # get fingertip poses with aruco, return as np array 134 | # arucos = self.aruco_pose_estimator.detect(rgb_image) # 11 is left, 12 is right 135 | 136 | # ids = [x.id for x in arucos] 137 | 138 | # if 11 in ids and 12 in ids: 139 | # # left_fingertip, right_fingertip = find_contact_markers(arucos[ids.index(11)], arucos[ids.index(12)]) # TODO: move find_contact_markers to __getitem__ in loader 140 | # left_fingertip = np.array(arucos[ids.index(11)].trans) 
141 | # right_fingertip = np.array(arucos[ids.index(12)].trans) 142 | # else: 143 | # left_fingertip = -np.ones(3) 144 | # right_fingertip = -np.ones(3) 145 | 146 | detected_fingertips = self.aruco_pose_estimator.get_fingertip_poses(rgb_image) 147 | 148 | # print('left translation:', left_fingertip) 149 | # print('right translation:', right_fingertip) 150 | 151 | # need to access robot state to control with keyboard 152 | keycode = cv2.waitKey(1) & 0xFF 153 | keyboard_teleop(self.rc, self.config.ACTION_DELTA_DICT, keycode, self) 154 | 155 | image_name = str(self.data_index) + '.png' 156 | depth_name = image_name 157 | prompt_name = str(self.data_index) + '.txt' 158 | state_name = str(self.data_index) + '.txt' 159 | fingertips_name = str(self.data_index) 160 | ft_name = str(self.data_index) 161 | keyframe_name = 'keyframe_list' 162 | 163 | if self.realsense.view: 164 | disp_time = str(round(self.realsense.current_frame_time - self.realsense.first_frame_time, 3)) 165 | disp_img = self.realsense.display_rgbd_image(rgb_image, depth_image) 166 | 167 | # display time 168 | cv2.putText(disp_img, disp_time, (50,50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3, cv2.LINE_AA) 169 | 170 | # display keyframe list length 171 | cv2.putText(disp_img, 'keyframe list length: ' + str(len(self.keyframe_index_list)), (50,100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3, cv2.LINE_AA) 172 | 173 | # display keyframe step 174 | cv2.putText(disp_img, 'step: ' + str(self.keyframe_step + 1), (50,150), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3, cv2.LINE_AA) 175 | 176 | cv2.imshow("frames", disp_img) 177 | 178 | if detected_fingertips is None: 179 | print('NO FINGERTIPS DETECTED!!! IGNORE This frame \n\n') 180 | self.keyframe = False 181 | 182 | if self.keyframe: 183 | self.keyframe_index_list.append(self.data_index) 184 | self.keyframe_step_list.append(self.keyframe_step) 185 | 186 | if (self.keyframe or self.args.save_all_frames) and detected_fingertips is not None: 187 | # save data to machine 188 | if self.args.stage in ['train', 'test', 'raw']: 189 | cv2.imwrite(os.path.join(self.rgb_folder, image_name), rgb_image) # save rgb image 190 | cv2.imwrite(os.path.join(self.depth_folder, depth_name), depth_image) # save depth image 191 | 192 | with open(os.path.join(self.state_folder, state_name), 'w') as file: # save robot state 193 | file.write(str(self.status)) 194 | 195 | print('pos_dict: ', get_pos_dict(self.rc)) 196 | left_fingertip, right_fingertip = detected_fingertips 197 | 198 | # np.save(os.path.join(self.fingertips_folder, fingertips_name), self.fingertips) # save fingertip poses 199 | np.savez(os.path.join(self.fingertips_folder, fingertips_name), left=left_fingertip, right=right_fingertip) 200 | np.save(os.path.join(self.ft_folder, ft_name), ft_data) # save ft data 201 | 202 | # save keyframe list 203 | # np.save(os.path.join(self.data_folder, self.args.stage, self.args.folder, keyframe_name), self.keyframe_index_list) 204 | np.savez(os.path.join(self.data_folder, self.args.stage, self.args.folder, keyframe_name), keyframe_index_list=self.keyframe_index_list, keyframe_step_list=self.keyframe_step_list) 205 | 206 | else: 207 | print('Invalid stage argument. 
Please choose train, test, or raw') 208 | sys.exit(1) 209 | 210 | if self.delete_last_keyframe: 211 | self.keyframe_index_list.pop() 212 | self.keyframe_step_list.pop() 213 | self.delete_last_keyframe = False 214 | 215 | print('keyframe_index_list: ', self.keyframe_index_list) 216 | print('keyframe_step: ', self.keyframe_step) 217 | 218 | self.data_index += 1 219 | 220 | result = { 221 | 'rgb':rgb_image, 222 | 'depth':depth_image, 223 | 'prompt':self.args.prompt, 224 | 'state':self.status, 225 | 'fingertips':self.fingertips, 226 | 'ft_frame':ft_data, 227 | 'ft_frame_time':self.ft.current_frame_time, 228 | 'frame_time':self.realsense.current_frame_time, 229 | } 230 | 231 | print('Average FPS', self.realsense.frame_count / (time.time() - self.realsense.first_frame_time)) 232 | 233 | return result 234 | 235 | 236 | ############################################################################## 237 | 238 | if __name__ == '__main__': 239 | dc = DataCapture() 240 | delay = [] 241 | 242 | # pcd_vis = PCDViewer() 243 | while not dc.stop: 244 | data = dc.capture_data() 245 | delay.append(data['ft_frame_time'] - data['frame_time']) 246 | 247 | folder_sizes = [len(files) for r, d, files in os.walk( 248 | os.path.join(dc.data_folder, dc.args.stage, dc.args.folder))][1:] 249 | folder_names = [r.split('/')[-1] for r, d, files in os.walk( 250 | os.path.join(dc.data_folder, dc.args.stage, dc.args.folder))][1:] 251 | folder_dict = dict(zip(folder_names, folder_sizes)) 252 | 253 | print('folder sizes: ', folder_dict) 254 | 255 | if len(set(folder_sizes)) > 1: 256 | print('ERROR: not all folders have the same number of files') 257 | print('missing files in: ', [k for k, v in folder_dict.items() if v != max(folder_sizes)]) 258 | 259 | print('saved results to {}'.format(os.path.join(dc.data_folder, dc.args.stage, dc.args.folder))) 260 | print("delay avg:", np.mean(delay)) 261 | print("delay std:", np.std(delay)) 262 | print("delay max:", np.max(delay)) 263 | 264 | # depth_example = cv2.imread('data/raw/test_5_8_21/depth/1683597964_278.png', cv2.IMREAD_ANYDEPTH) 265 | # print(depth_example) 266 | # print('average depth: ', depth_example.mean()) 267 | # print('min depth: ', depth_example.min()) 268 | # print('max depth: ', depth_example.max()) 269 | # cv2.imshow('depth', depth_example) 270 | # cv2.waitKey(0) 271 | -------------------------------------------------------------------------------- /recording/ft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import subprocess 3 | import time 4 | 5 | # when the gripper is -0.3 orientated downwards 6 | EEF_PITCH_WEIGHT_OFFSET = np.array([-0.22010994, -0.02645493, 0.3561182, 0, 0, 0]) 7 | 8 | class FTCapture: 9 | def __init__(self, delay=0, ip="192.168.1.1"): 10 | print('Initializing force/torque sensor...') 11 | self.counts_per_force = 1e6 12 | self.counts_per_torque = 1e6 13 | self.first_frame_time = 0 14 | self.current_frame_time = 0 15 | self.frame_count = 0 16 | self.delay = delay # number of frames to delay before returning ft data to sync with camera in live model 17 | self.history = [] 18 | self.ip = ip 19 | 20 | def get_ft(self): 21 | (stat, output) = subprocess.getstatusoutput(f"./netft {self.ip}") 22 | output_list = output.splitlines() 23 | 24 | # resetting ft data to zero 25 | ft_values = 6*[0] 26 | 27 | # parsing output strings from network call 28 | for i in range(6): 29 | if len(output_list[i + 1][4:]) > 1: 30 | if i <= 3: 31 | ft_values[i] = int(output_list[i + 1][4:])/self.counts_per_force 
32 | else: 33 | ft_values[i] = int(output_list[i + 1][4:])/self.counts_per_torque 34 | else: 35 | ft_values[i] = 0 36 | 37 | self.current_frame_time = time.time() 38 | if self.first_frame_time == 0: 39 | self.first_frame_time = self.current_frame_time 40 | self.frame_count += 1 41 | 42 | if len(ft_values) != 6: 43 | print('Error: receiving invalid force/torque data') 44 | return None 45 | 46 | # shifting 47 | if self.delay > 0: 48 | self.history.append(ft_values) 49 | if len(self.history) > self.delay: 50 | self.history.pop(0) 51 | 52 | ft = self.history[0] 53 | ft = np.array(ft, dtype='float32') 54 | return ft 55 | 56 | else: 57 | ft = np.array(ft_values, dtype='float32') 58 | 59 | if np.all(np.abs(ft) < 1e-5): 60 | print('Error: receiving invalid force/torque data') 61 | exit() 62 | 63 | return ft 64 | 65 | class MockFTCapture: 66 | """This is used to mock the FT sensor""" 67 | def __init__(self, delay=0): 68 | pass 69 | 70 | def get_ft(self): 71 | return EEF_PITCH_WEIGHT_OFFSET 72 | 73 | if __name__ == "__main__": 74 | ft = FTCapture() 75 | start_time = time.time() 76 | 77 | while True: 78 | ft_data = ft.get_ft() 79 | current_time = time.time() - start_time 80 | print(np.round(ft_data, 4)) 81 | print('Average FPS', ft.frame_count / (time.time() - ft.first_frame_time)) 82 | print(ft.frame_count, ' frames captured') 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | asttokens==2.2.1 3 | attrs==23.1.0 4 | backcall==0.2.0 5 | blinker==1.6.2 6 | click==8.1.3 7 | comm==0.1.3 8 | configargparse==1.5.3 9 | contourpy==1.0.7 10 | cycler==0.11.0 11 | dash==2.9.3 12 | dash-core-components==2.0.0 13 | dash-html-components==2.0.0 14 | dash-table==5.0.0 15 | debugpy==1.6.7 16 | decorator==5.1.1 17 | executing==1.2.0 18 | fastjsonschema==2.16.3 19 | filelock==3.12.0 20 | flask==2.3.1 21 | fonttools==4.39.3 22 | fsspec==2023.4.0 23 | huggingface-hub==0.14.1 24 | ipykernel==6.22.0 25 | ipython==8.12.0 26 | ipywidgets==8.0.6 27 | itsdangerous==2.1.2 28 | jedi==0.18.2 29 | joblib==1.2.0 30 | jsonschema==4.17.3 31 | jupyter-client==8.2.0 32 | jupyter-core==5.3.0 33 | jupyterlab-widgets==3.0.7 34 | kdl-py==1.1.0 35 | kiwisolver==1.4.4 36 | markupsafe==2.1.2 37 | matplotlib==3.7.1 38 | matplotlib-inline==0.1.6 39 | mpmath==1.2.1 40 | nbformat==5.7.0 41 | nest-asyncio==1.5.6 42 | open3d==0.17.0 43 | opencv-contrib-python==4.6.0.66 44 | opencv-python==4.6.0.66 45 | opencv-python-headless==4.6.0.66 46 | packaging==23.1 47 | pandas==2.0.1 48 | parso==0.8.3 49 | pexpect==4.8.0 50 | pickleshare==0.7.5 51 | pillow==9.5.0 52 | platformdirs==3.4.0 53 | plotly==5.14.1 54 | prompt-toolkit==3.0.38 55 | psutil==5.9.5 56 | ptyprocess==0.7.0 57 | pure-eval==0.2.2 58 | pygments==2.15.1 59 | pyparsing==3.0.9 60 | pyquaternion==0.9.9 61 | pyrealsense2==2.53.1.4623 62 | pyrsistent==0.19.3 63 | python-dateutil==2.8.2 64 | pytz==2023.3 65 | pyyaml==6.0 66 | pyzmq==25.0.2 67 | regex==2023.5.5 68 | scikit-learn==1.2.2 69 | scipy==1.10.1 70 | segment-anything==1.0 71 | sentencepiece==0.1.91 72 | six==1.16.0 73 | stack-data==0.6.2 74 | tenacity==8.2.2 75 | threadpoolctl==3.1.0 76 | timm==0.6.13 77 | tokenizers==0.13.3 78 | tornado==6.3.1 79 | tqdm==4.65.0 80 | traitlets==5.9.0 81 | transformers==4.28.1 82 | typing==3.7.4.3 83 | tzdata==2023.3 84 | wcwidth==0.2.6 85 | werkzeug==2.3.0 86 | widgetsnbextension==4.0.7 87 | zmq==0.0.0 88 | transformations 89 | kinpy>=0.2.2 90 | 
timm>=0.6.13 91 | wandb 92 | classifier_free_guidance_pytorch 93 | opencv-contrib-python==4.6.0.66 94 | git+https://github.com/openai/CLIP.git 95 | -------------------------------------------------------------------------------- /robot/kdl_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import requests 5 | 6 | ############################################################################## 7 | 8 | def get_forward_kinematics(joint_list, url='http://localhost:8000/fk_calc'): 9 | """Return an optional list of target pose for the camera egocam_link""" 10 | if len(joint_list) != 6: 11 | print("Error! Expect 6 joints") 12 | return None 13 | 14 | headers = {'Content-Type': 'application/json'} 15 | l = [] 16 | for j in joint_list: 17 | l.append(float(j)) 18 | data = json.dumps(l) 19 | response = requests.post(url, headers=headers, data=data) 20 | 21 | if response.status_code == 200: 22 | result = response.json() 23 | return result['res'] 24 | else: 25 | print('Error:', response.status_code) 26 | return None 27 | 28 | def get_inverse_kinematics(joint_list, url='http://localhost:8000/ik_calc'): 29 | """ 30 | Return an optional list of joint angles to reach the link_grasp_center 31 | """ 32 | if len(joint_list) != 6: 33 | print("Error! Expect 6 values, xyzrpy") 34 | return None 35 | 36 | headers = {'Content-Type': 'application/json'} 37 | l = [] 38 | for j in joint_list: 39 | l.append(float(j)) 40 | data = json.dumps(l) 41 | response = requests.post(url, headers=headers, data=data) 42 | 43 | if response.status_code == 200: 44 | result = response.json() 45 | if result is not None: 46 | return result['res'] 47 | else: 48 | return None 49 | else: 50 | print('Error:', response.status_code) 51 | return None 52 | 53 | ############################################################################## 54 | if __name__ == "__main__": 55 | r = get_forward_kinematics([0]*6) 56 | print("r", r) 57 | r = get_inverse_kinematics(r) 58 | print("r", r) 59 | r = get_inverse_kinematics([1.0]*6) 60 | print("r", r) # this will fail 61 | -------------------------------------------------------------------------------- /robot/kdl_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | from aiohttp import web 5 | from robot.kdl_wrapper import KdlWrapper, KdlType 6 | from typing import Optional, Dict 7 | 8 | ############################################################################## 9 | 10 | kdl_wrapper_cam = KdlWrapper.make( 11 | KdlType.FULL_ROBOT_6DOF, 12 | "robot/stretch_robot.urdf", 13 | "egocam_link", 14 | ) 15 | 16 | kdl_wrapper_gripper = KdlWrapper.make( 17 | KdlType.FULL_ROBOT_6DOF, 18 | "robot/stretch_robot.urdf", 19 | "link_grasp_center", 20 | ) 21 | 22 | ############################################################################## 23 | 24 | async def handle_fk_calc(request): 25 | try: 26 | data = await request.json() 27 | 28 | print("Calc FK, data", data) 29 | arm_seg = data[2]/4 30 | joints = [ 31 | data[0], data[1], 32 | arm_seg, arm_seg, arm_seg, arm_seg, 33 | data[3], data[4], data[5], 34 | ] 35 | origin_pose = kdl_wrapper_cam.forward_kinematics(joints) 36 | response = {'res': origin_pose} 37 | return web.json_response(response) 38 | 39 | except json.JSONDecodeError: 40 | return web.Response(status=400, text='Bad Request: Invalid JSON data.') 41 | 42 | 43 | async def handle_ik_calc(request): 44 | try: 45 | data = await request.json() 
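# Expected payload (see robot/kdl_client.py): a JSON list of 6 floats
# [x, y, z, roll, pitch, yaw] for the link_grasp_center target. Because the
# solver below is created with combine_arm_extension=True, the response joints
# come back as [fake_x, lift, arm_extension, wrist_yaw, wrist_pitch, wrist_roll].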
46 | 47 | print("Calc IK, data", data) 48 | target_joints = kdl_wrapper_gripper.inverse_kinematics( 49 | data, combine_arm_extension=True) 50 | response = {'res': target_joints} 51 | return web.json_response(response) 52 | 53 | except json.JSONDecodeError: 54 | return web.Response(status=400, text='Bad Request: Invalid JSON data.') 55 | 56 | 57 | app = web.Application() 58 | app.router.add_post('/ik_calc', handle_ik_calc) 59 | app.router.add_post('/fk_calc', handle_fk_calc) 60 | 61 | 62 | if __name__ == "__main__": 63 | web.run_app(app, host='localhost', port=8000) 64 | -------------------------------------------------------------------------------- /robot/kdl_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from __future__ import annotations 4 | 5 | import kdl_parser_py.urdf 6 | import PyKDL as pykdl 7 | from math import pi as PI 8 | from typing import List, Optional 9 | from tf import transformations as ts 10 | from scipy.spatial.transform import Rotation as R 11 | from utils.transform import mat_to_pose, pose_to_mat 12 | import numpy as np 13 | 14 | ######################################################################## 15 | def list_to_jnt_array(l: List): 16 | """Convert a list to a KDL JntArray""" 17 | jnt_array = pykdl.JntArray(len(l)) 18 | for i in range(len(l)): 19 | jnt_array[i] = l[i] 20 | return jnt_array 21 | 22 | ######################################################################## 23 | class KdlType: 24 | WRIST_3DOF = 0 25 | WRIST_6DOF = 1 26 | FULL_ROBOT_5DOF = 2 27 | FULL_ROBOT_6DOF = 3 # this has a fake x joint with total 9 joints 28 | 29 | ######################################################################## 30 | class KdlWrapper: 31 | @staticmethod 32 | def make( 33 | type: KdlType, urdf_file: str, target_frame: str 34 | ) -> Optional[KdlWrapper]: 35 | """ 36 | Create a KdlWrapper object for the given sws type. 37 | @return KdlWrapper object or None if failed. 38 | """ 39 | kdl_wrapper = KdlWrapper() 40 | if kdl_wrapper.__create_chain__(type, urdf_file, target_frame): 41 | print("< Done init KDLWrapper >") 42 | return kdl_wrapper 43 | return None 44 | 45 | def number_of_joints(self) -> int: 46 | """ 47 | Return the number of joints in the chain. 48 | """ 49 | return self.total_joints 50 | 51 | def forward_kinematics( 52 | self, joint_values: List[float] 53 | )-> Optional[List]: 54 | """ 55 | Compute the forward kinematics for the given joint values. 56 | @joint_values: [j0, j1, ... jN] 57 | @return: [x, y, z, roll, pitch, yaw] 58 | """ 59 | if len(joint_values) != self.total_joints: 60 | print(f"Error! Expect {self.total_joints} of joints") 61 | return None 62 | joints = list_to_jnt_array(joint_values) 63 | frame_out = pykdl.Frame() 64 | result = self.fk_solver.JntToCart(joints, frame_out, -1) 65 | if result < 0: 66 | print(f"Error! Failed to compute FK [{result}]") 67 | return None 68 | 69 | pos = frame_out.p 70 | rot = frame_out.M.GetRPY() 71 | return [pos[0], pos[1], pos[2], rot[0], rot[1], rot[2]] 72 | 73 | def inverse_kinematics( 74 | self, 75 | target_position: List[float], 76 | combine_arm_extension: bool = False 77 | ) -> Optional[List]: 78 | """ 79 | Compute the inverse kinematics for the given target position. 80 | @target_position: [x, y, z, roll, pitch, yaw] 81 | @combine_arm_extension: If True, all 4 arm translation joints 82 | will be combined into a single y trans 83 | @return: [j0, j1, ... 
jN], or 84 | [fake_x, lift, arm, wrist_yaw, wrist_pitch, wrist_roll] 85 | """ 86 | if len(target_position) != 6: 87 | print("Error! Expect 6 values for target position") 88 | return None 89 | 90 | x, y, z, roll, pitch, yaw = target_position 91 | q = ts.quaternion_from_euler(roll, pitch, yaw, 'rxyz') 92 | rot = ts.quaternion_matrix(q) 93 | pos_kdl = pykdl.Vector(x, y, z) 94 | rot_kdl = pykdl.Rotation(rot[0, 0], rot[0, 1], rot[0, 2], 95 | rot[1, 0], rot[1, 1], rot[1, 2], 96 | rot[2, 0], rot[2, 1], rot[2, 2]) 97 | frame_kdl = pykdl.Frame(rot_kdl, pos_kdl) 98 | 99 | # Initial guess 100 | if self.type == KdlType.WRIST_6DOF: 101 | q_kdl = list_to_jnt_array([x, y, z, roll, pitch, yaw]) 102 | else: 103 | q_kdl = pykdl.JntArray(self.total_joints) 104 | 105 | q_kdl_out = pykdl.JntArray(self.total_joints) 106 | result = self.ik_solver.CartToJnt(q_kdl, frame_kdl, q_kdl_out) 107 | 108 | if result < 0: 109 | print(f"Error! Failed to compute IK [{result}]") 110 | return None 111 | 112 | joints = [q_kdl_out[i] for i in range(self.total_joints)] 113 | 114 | if combine_arm_extension: 115 | print("WARNING! Combining arm translation joints!") 116 | # Combine the 4 arm translation joints into a single y-trans 117 | if self.type == KdlType.FULL_ROBOT_5DOF: 118 | # return 5dof (z, y, wrist: yaw, pitch, roll) 119 | return joints[:1] + [sum(joints[1:5])] + joints[5:] 120 | elif self.type == KdlType.FULL_ROBOT_6DOF: 121 | # return 6dof (x, z, y, wrist: yaw, pitch, roll) 122 | return joints[:2] + [sum(joints[2:6])] + joints[6:] 123 | 124 | return joints 125 | 126 | ######################################################################## 127 | # Private methods 128 | ######################################################################## 129 | 130 | def __init__(self): 131 | pass 132 | 133 | def __create_chain__(self, type: KdlType, urdf_file: str, target_frame: str): 134 | self.type = type 135 | 136 | (ok, tree) = kdl_parser_py.urdf.treeFromFile(urdf_file) 137 | if not ok: 138 | print("Failed to parse urdf file") 139 | return False 140 | 141 | if type == KdlType.WRIST_3DOF or type == KdlType.WRIST_6DOF: 142 | parent_frame = "link_arm_l0" 143 | elif type == KdlType.FULL_ROBOT_5DOF or type == KdlType.FULL_ROBOT_6DOF: 144 | parent_frame = "base_link" 145 | 146 | ee_chain = tree.getChain(parent_frame, target_frame) 147 | # Add new fake joints to the end effector chain 148 | self.new_ee_chain = pykdl.Chain() 149 | 150 | # Add x-y-z translation joints 151 | if type == KdlType.WRIST_6DOF: 152 | self.new_ee_chain.addSegment( 153 | pykdl.Segment("link_fake_x", 154 | pykdl.Joint("joint_fake_x", pykdl.Joint.TransX), 155 | pykdl.Frame(pykdl.Vector(0, 0, 0))) 156 | ) 157 | self.new_ee_chain.addSegment( 158 | pykdl.Segment("link_fake_y", 159 | pykdl.Joint("joint_fake_y", pykdl.Joint.TransY), 160 | pykdl.Frame(pykdl.Vector(0, 0, 0))) 161 | ) 162 | self.new_ee_chain.addSegment( 163 | pykdl.Segment("link_fake_z", 164 | pykdl.Joint("joint_fake_z", pykdl.Joint.TransZ), 165 | pykdl.Frame(pykdl.Vector(0, 0, 0))) 166 | ) 167 | # Add x translation joint as the base 168 | elif type == KdlType.FULL_ROBOT_6DOF: 169 | self.new_ee_chain.addSegment( 170 | pykdl.Segment("link_fake_x", 171 | pykdl.Joint("joint_fake_x", pykdl.Joint.TransX), 172 | pykdl.Frame(pykdl.Vector(0, 0, 0))) 173 | ) 174 | self.new_ee_chain.addChain(ee_chain) 175 | print("New chain:") 176 | print(self.new_ee_chain.getNrOfSegments()) 177 | self.total_joints = self.new_ee_chain.getNrOfJoints() 178 | 179 | if type == KdlType.WRIST_3DOF: 180 | assert self.total_joints == 3, 
"Error in chain creation" 181 | elif type == KdlType.WRIST_6DOF: 182 | assert self.total_joints == 6, "Error in chain creation" 183 | elif type == KdlType.FULL_ROBOT_5DOF: 184 | assert self.total_joints == 8, "Error in chain creation" 185 | elif type == KdlType.FULL_ROBOT_6DOF: 186 | assert self.total_joints == 9, "Error in chain creation" 187 | 188 | self.fk_solver = pykdl.ChainFkSolverPos_recursive(self.new_ee_chain) 189 | 190 | # create joint limits 191 | if type == KdlType.WRIST_3DOF: 192 | min_joints_list = [-PI*2]*3 193 | max_joints_list = [PI*2]*3 194 | elif type == KdlType.WRIST_6DOF: 195 | min_joints_list = [-2.0]*3 + [-PI*2]*3 196 | max_joints_list = [2.0]*3 + [PI*2]*3 197 | elif type == KdlType.FULL_ROBOT_5DOF: 198 | min_joints_list = [-2.0]*5 + [-PI*2]*3 199 | max_joints_list = [2.0]*5 + [PI*2]*3 200 | elif type == KdlType.FULL_ROBOT_6DOF: 201 | min_joints_list = [-2.0]*6 + [-PI*2]*3 202 | max_joints_list = [2.0]*6 + [PI*2]*3 203 | # min_joints_list[1] = 0.22 # lift min limit of the robot 204 | 205 | assert len(min_joints_list) == self.total_joints, "Error create joint limits" 206 | assert len(max_joints_list) == self.total_joints, "Error create joint limits" 207 | 208 | min_joints = list_to_jnt_array(min_joints_list) 209 | max_joints = list_to_jnt_array(max_joints_list) 210 | 211 | self.ik_v_kdl = pykdl.ChainIkSolverVel_pinv(self.new_ee_chain) 212 | self.ik_solver = pykdl.ChainIkSolverPos_NR_JL( 213 | self.new_ee_chain, min_joints, max_joints, 214 | self.fk_solver, self.ik_v_kdl, 200, 1e-5) 215 | 216 | return True 217 | 218 | 219 | ############################################################################## 220 | 221 | if __name__ == "__main__": 222 | kdl = KdlWrapper.make( 223 | KdlType.FULL_ROBOT_6DOF, 224 | "stretch_robot.urdf", 225 | "egocam_link", 226 | ) 227 | 228 | # Joints representation (1x9): 229 | # [ 230 | # fake-x, z-lift, 231 | # y-arm1, y-arm2, y-arm3, y-arm4, 232 | # wrist_yaw, wrist_pitch, wrist_roll 233 | # ] 234 | joints = [0.3]*9 235 | origin_pose = kdl.forward_kinematics(joints) 236 | print(origin_pose) 237 | mat = pose_to_mat(origin_pose) 238 | pose = mat_to_pose(mat) 239 | print("scipy: \n", pose) 240 | -------------------------------------------------------------------------------- /robot/kinpy_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import kinpy as kp 4 | from kinpy.frame import Frame, Link, Joint 5 | from kinpy.chain import Chain, SerialChain 6 | from kinpy.transform import Transform 7 | import transformations as tf 8 | 9 | 10 | class KinpyWrapper: 11 | def __init__(self, urdf_file) -> None: 12 | 13 | chain = kp.build_chain_from_urdf(open(urdf_file).read()) 14 | 15 | fake_x_frame = Frame( 16 | name="link_fake_x_frame", 17 | link=Link("link_fake_x"), 18 | joint=Joint("joint_fake_x", axis=[1, 0, 0], joint_type="prismatic"), 19 | children=[chain._root] 20 | ) 21 | 22 | root_base = Frame( 23 | name="link_new_base_frame", 24 | link=Link("link_new_base"), 25 | joint=Joint("joint_new_base", joint_type="fixed", offset=Transform([0, 0, 0])), 26 | children=[fake_x_frame] 27 | ) 28 | 29 | new_chain = Chain(root_base) 30 | self.sc = SerialChain(new_chain, 31 | end_frame_name="egocam_link_frame", 32 | root_frame_name="link_new_base_frame" 33 | ) 34 | 35 | def forward_kinematics(self, joints) -> list: 36 | """ 37 | compute forward kinematics 38 | :param joints: list of joint values [j0, j1, ... 
jN] 39 | :return: [x, y, z, roll, pitch, yaw] 40 | """ 41 | joint_names = self.sc.get_joint_parameter_names() 42 | assert len(joints) == len(joint_names) 43 | joint_dict = {} 44 | for i in range(len(joints)): 45 | joint_dict[joint_names[i]] = joints[i] 46 | solution = self.sc.forward_kinematics(joint_dict) 47 | euler = tf.euler_from_quaternion(solution.rot) 48 | return list(solution.pos) + list(euler) 49 | 50 | ################################################################################# 51 | """ 52 | TODO: this is a bad implementation for init in a global space 53 | the reason of doing this is maintain api consistency with kdl_client.py 54 | """ 55 | __urdf_file = "robot/stretch_robot.urdf" 56 | __kinpy_wrapper = KinpyWrapper(__urdf_file) 57 | 58 | def get_forward_kinematics(joint_6dofs): 59 | """ 60 | Simple wrapped function to get forward kinematics 61 | """ 62 | assert len(joint_6dofs) == 6 63 | arm_seg = joint_6dofs[2]/4.0 64 | joints = [ 65 | joint_6dofs[0], joint_6dofs[1], 66 | arm_seg, arm_seg, arm_seg, arm_seg, 67 | joint_6dofs[3], joint_6dofs[4], joint_6dofs[5], 68 | ] 69 | return __kinpy_wrapper.forward_kinematics(joints) 70 | 71 | ################################################################################# 72 | 73 | def __test_ik(): 74 | """ 75 | Simple test to compare the forward kinematics of KDL and Kinpy 76 | """ 77 | from robot.kdl_wrapper import KdlWrapper, KdlType 78 | import numpy as np 79 | 80 | joints = [0.2]*9 81 | 82 | 83 | kdl = KdlWrapper.make(KdlType.FULL_ROBOT_6DOF, __urdf_file, "egocam_link") 84 | origin_pose = kdl.forward_kinematics(joints) 85 | 86 | pose = __kinpy_wrapper.forward_kinematics(joints) 87 | print(" ----------------------------------------------------") 88 | print(origin_pose) 89 | print(pose) 90 | assert np.allclose(origin_pose, pose), "Incorrect forward kinematics!" 
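    # For reference, a minimal usage sketch of the module-level helper
    # (hypothetical joint values): it takes [x, lift, arm_extension, wrist_yaw,
    # wrist_pitch, wrist_roll] and splits the arm extension across the four arm
    # segments before calling kinpy:
    #   pose = get_forward_kinematics([0.0, 0.6, 0.2, 0.0, 0.0, 0.0])
    #   print(pose)  # [x, y, z, roll, pitch, yaw] of the egocam_link frame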
91 | 92 | 93 | if __name__ == "__main__": 94 | __test_ik() 95 | -------------------------------------------------------------------------------- /robot/robot_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import time 4 | import sys 5 | import cv2 6 | import numpy as np 7 | import math 8 | import threading 9 | 10 | from stretch_remote.remote_client import RemoteClient 11 | from stretch_remote.robot_utils import read_robot_status 12 | 13 | from robot.kdl_client import get_inverse_kinematics 14 | import random 15 | from utils.transform import transform_coord, pose_to_mat, mat_to_pose 16 | from tkinter import Tk, simpledialog, font 17 | 18 | ############################################################################## 19 | 20 | 21 | def set_angle(angle: float) -> float: 22 | """ 23 | Ensure the angle is within the range of [-pi, pi] radian convention 24 | """ 25 | return math.atan2(math.sin(angle), math.cos(angle)) 26 | 27 | 28 | def move_to_target(target_fingertips: np.ndarray, 29 | remote_control: RemoteClient, 30 | target_pitch=None, 31 | move_factor=1.0, # 0-1.0 32 | retries=1 # hack to make sure the robot reaches the target joint state 33 | ) -> bool: 34 | """ 35 | Provide the target fingertip positions and current joints to move the robot 36 | to the target fingertip positions. 37 | This is a non-blocking request. 38 | : target_fingertips: two (3,) fingertip positions, referenced to the camera frame 39 | : return: True if the action is successful, False otherwise 40 | """ 41 | s = remote_control.get_status(compact=True) 42 | current_joints = get_joints_from_robot_status(s) 43 | current_yaw_joint = current_joints[5] 44 | # find center of fingertips 45 | grasp_center = np.mean(target_fingertips, axis=0) 46 | l_contact, r_contact = target_fingertips[0], target_fingertips[1] 47 | contact_yaw = math.atan2(r_contact[2] - l_contact[2], r_contact[0] - l_contact[0]) 48 | diff_contact_yaw = contact_yaw + current_yaw_joint 49 | 50 | if diff_contact_yaw < -0.95: 51 | print("[WARNING] YAW is", diff_contact_yaw, " clip back to -0.95 to make it more stable") 52 | diff_contact_yaw = -0.95 53 | elif diff_contact_yaw > 1.58: 54 | print("[WARNING] YAW is", diff_contact_yaw, " clip back to 1.58 to make it more stable") 55 | diff_contact_yaw = 1.58 56 | 57 | transformed_coord = transform_coord(grasp_center, current_joints) 58 | # print(" - grasp center: ", grasp_center) 59 | 60 | # pitch = current_joints[4] 61 | # yaw = current_joints[3] 62 | fix_pitch = 0.3 63 | joints = get_inverse_kinematics(np.append(transformed_coord, [fix_pitch, 0, -1.57 + diff_contact_yaw])) 64 | 65 | # TODO: need to find the exact roll pitch yaw of the target from grasp center 66 | # r1 = pose_to_mat(np.array([0, 0, 0, 0, 0, -1.57 - 0.8])) 67 | # r2 = pose_to_mat(np.array([0, 0, 0, 0.2, 0, 0])) 68 | # pose = mat_to_pose(r1@r2) 69 | # joints = get_inverse_kinematics(np.append(transformed_coord, pose[3:])) 70 | if joints is None: 71 | return False 72 | 73 | if target_pitch is not None: 74 | # vector from fingertip 0 to fingertip 1 75 | vec = target_fingertips[1] - target_fingertips[0] 76 | 77 | # NOTE: We are currently not using roll 78 | target_joints = {} 79 | target_joints["x"] = (joints[0] - current_joints[0])*move_factor + current_joints[0] 80 | target_joints["y"] = (joints[2] - current_joints[1])*move_factor + current_joints[1] 81 | target_joints["z"] = (joints[1] - current_joints[2])*move_factor + current_joints[2] - \ 82 | 0.01 # TODO: compensation, might need to fix 
this 83 | target_joints["roll"] = 0 # (joints[5] - current_joints[3])*move_factor + current_joints[3] 84 | target_joints["pitch"] = (joints[4] - current_joints[4])*move_factor + current_joints[4] 85 | target_joints["yaw"] = (joints[3] - current_joints[5])*move_factor + current_joints[5] 86 | print(f"[DEBUG TARGET JOINTS]: {joints[0]} {joints[1]} {joints[2]} {joints[3]} {joints[4]} {joints[5]}") 87 | 88 | # calculate the distance between the two fingertips to determine 89 | # the gripper action 90 | contact_dist = np.linalg.norm(target_fingertips[0] - target_fingertips[1]) 91 | if contact_dist < 0.04: 92 | print("[DEBUG] contact_dist is less than 0.04, close gripper") 93 | target_joints["gripper"] = -20 94 | elif contact_dist < 0.08: 95 | target_joints["gripper"] = 5 96 | else: 97 | target_joints["gripper"] = 100 98 | 99 | def move_robot_async(target_joints, remote_control, retries): 100 | print("Run move robot thread...") 101 | # TODO: hack since the robot is not able to reach the desired joint state 102 | target_joints_without_grasp = target_joints.copy() 103 | target_joints_without_grasp.pop("gripper") 104 | 105 | for i in range(retries): 106 | remote_control.move(target_joints_without_grasp) 107 | time.sleep(1.3) 108 | print("run grasp") 109 | remote_control.move(target_joints) 110 | 111 | s = remote_control.get_status(compact=True) 112 | joints = get_joints_from_robot_status(s) 113 | print("----------------------\n After IK action, joints: ", joints) 114 | 115 | thread = threading.Thread(target=move_robot_async, args=(target_joints, remote_control, retries)) 116 | thread.start() 117 | return True 118 | 119 | ############################################################################## 120 | 121 | 122 | def get_pos_dict(rc: RemoteClient): 123 | """ 124 | Get the current status of the robot from RemoteClient class 125 | """ 126 | robot_status = rc.get_status(compact=True) 127 | # print('robot_status: ', robot_status) 128 | # print('robot_status is not None', robot_status is not None) 129 | # while True: 130 | if robot_status is not None: 131 | pos_dict = robot_status 132 | else: 133 | print('cannot read robot status') 134 | pos_dict = None 135 | 136 | return pos_dict 137 | 138 | ############################################################################## 139 | 140 | 141 | def keyboard_teleop(rc, deltas, keycode, self=None): # enable_moving=True, stop=False): 142 | if keycode == ord('q') and hasattr(self, 'stop'): # stop 143 | self.stop = True 144 | # if keycode == ord(' ') and hasattr(self, 'enable_moving'): # toggle moving 145 | # self.enable_moving = not self.enable_moving 146 | 147 | if keycode == ord(' ') and hasattr(self, 'keyframe'): # label as keyframe 148 | self.keyframe = True 149 | else: 150 | self.keyframe = False 151 | 152 | # if backspace is pressed, remove last keyframe 153 | if keycode == 8 and len(self.keyframe_index_list) > 0: 154 | self.delete_last_keyframe = True 155 | else: 156 | self.delete_last_keyframe = False 157 | 158 | # if enter is pressed, toggle the keyframe step (0 or 1) 159 | if keycode == 13 and hasattr(self, 'keyframe_step'): 160 | self.keyframe_step = int(not self.keyframe_step) 161 | 162 | # set the prompt 163 | if keycode == ord('p') and hasattr(self, 'prompt'): 164 | # self.prompt = input("Enter new prompt: ") 165 | # Create a Tkinter root widget 166 | root = Tk() 167 | root.withdraw() # We don't want a full GUI, so keep the root window from appearing 168 | 169 | # Customize the font size 170 | myFont = font.Font(family='Times New Roman', size=20, 
weight='bold') 171 | root.option_add("*Font", myFont) 172 | 173 | # Show an input box and wait for text 174 | self.prompt = simpledialog.askstring("Input", "Enter new prompt:") 175 | 176 | # self.prompt = self.prompt # NOTE: This might break live model 177 | 178 | move_ok = (self is None or (hasattr(self, 'enable_moving') and self.enable_moving)) 179 | 180 | if move_ok and rc: 181 | if keycode == ord('h'): # drive home 182 | rc.home() 183 | elif keycode == ord(']'): # drive X 184 | rc.move({'delta_x': -deltas['x']}) 185 | elif keycode == ord('['): # drive X 186 | rc.move({'delta_x': deltas['x']}) 187 | elif keycode == ord('a'): # drive Y 188 | pos_dict = get_pos_dict(rc) 189 | rc.move({'y': pos_dict['y'] - deltas['y']}) 190 | elif keycode == ord('d'): # drive Y 191 | pos_dict = get_pos_dict(rc) 192 | rc.move({'y': pos_dict['y'] + deltas['y']}) 193 | elif keycode == ord('s'): # drive Z 194 | pos_dict = get_pos_dict(rc) 195 | rc.move({'z': pos_dict['z'] - deltas['z']}) 196 | elif keycode == ord('w'): # drive Z 197 | pos_dict = get_pos_dict(rc) 198 | rc.move({'z': pos_dict['z'] + deltas['z']}) 199 | # elif keycode == ord('u'): # drive roll 200 | # rc.move({'roll':pos_dict['roll'] - deltas['roll']}) 201 | # elif keycode == ord('o'): # drive roll 202 | # rc.move({'roll':pos_dict['roll'] + deltas['roll']}) 203 | elif keycode == ord('k'): # drive pitch 204 | pos_dict = get_pos_dict(rc) 205 | rc.move({'pitch': pos_dict['pitch'] - deltas['pitch']}) 206 | elif keycode == ord('i'): # drive pitch 207 | pos_dict = get_pos_dict(rc) 208 | rc.move({'pitch': pos_dict['pitch'] + deltas['pitch']}) 209 | elif keycode == ord('l'): # drive yaw 210 | pos_dict = get_pos_dict(rc) 211 | rc.move({'yaw': pos_dict['yaw'] - deltas['yaw']}) 212 | elif keycode == ord('j'): # drive yaw 213 | pos_dict = get_pos_dict(rc) 214 | rc.move({'yaw': pos_dict['yaw'] + deltas['yaw']}) 215 | elif keycode == ord('b'): # drive gripper 216 | pos_dict = get_pos_dict(rc) 217 | rc.move({'gripper': pos_dict['gripper'] - deltas['gripper']}) 218 | elif keycode == ord('n'): # drive gripper 219 | pos_dict = get_pos_dict(rc) 220 | rc.move({'gripper': pos_dict['gripper'] + deltas['gripper']}) 221 | # elif keycode == ord('1'): # drive theta 222 | # rc.move({'theta':deltas['theta']}) 223 | # elif keycode == ord('2'): # drive theta 224 | # rc.move({'theta':-deltas['theta']}) 225 | 226 | if keycode == ord('\\') and rc: 227 | pos_dict = get_pos_dict(rc) 228 | rc.move({'z': pos_dict['z'] + deltas['z'] * 10}) 229 | if keycode == ord('=') and rc: 230 | pos_dict = get_pos_dict(rc) 231 | rc.move({'y': pos_dict['y'] - deltas['y'] * 10}) 232 | 233 | # randomize the robot 234 | if keycode == ord('/') and rc: # randomize the robot 235 | arm_random = round(random.uniform(.1, .2), 3) # .016 236 | lift_random = round(random.uniform(.8, 1.2), 3) 237 | pitch_random = round(random.uniform(-.75, .25), 3) 238 | yaw_random = round(random.uniform(.4, -.4), 3) 239 | gripper_random = round(random.uniform(-98.0, 60.0), 1) 240 | 241 | print("arm_random: ", arm_random, " lift_random: ", lift_random, " pitch_random: ", 242 | pitch_random, " yaw_random: ", yaw_random, "gripper_random: ", gripper_random) 243 | rc.move({'y': arm_random, 'z': lift_random, 244 | 'pitch': pitch_random, 'yaw': yaw_random 245 | }) 246 | # rc.move({'gripper':gripper_random}) 247 | return keycode 248 | 249 | 250 | ############################################################################## 251 | 252 | def level_robot(rc, rpy_eps=0.1, grip_eps=0.5): 253 | rc.move({'roll': 0.0, 'pitch': 0.0, 'yaw': 
0.0, 'gripper': 0}) 254 | 255 | print("LEVELING ROBOT") 256 | time.sleep(1) 257 | 258 | robot_ok, pos_dict = rc.get_status() 259 | 260 | if not robot_ok: 261 | print("ROBOT NOT CONNECTED") 262 | return False 263 | elif abs(pos_dict['roll']) > rpy_eps or abs(pos_dict['pitch']) > rpy_eps or abs(pos_dict['yaw']) > rpy_eps or abs(pos_dict['gripper']) > grip_eps: 264 | print("ROBOT NOT LEVEL") 265 | return False 266 | else: 267 | print("ROBOT LEVELED") 268 | return True 269 | 270 | 271 | def get_joints_from_robot_status(status, joint_sequence=False): 272 | """ 273 | Utility function to convert rc robot status to joint list 274 | NOTE: this sequence is not correct, the convertion is in the transform.py 275 | """ 276 | j = [status['x'], status['z'], status['y'], status['yaw'], status['pitch'], status['roll']] 277 | print(f"[DEBUG] Proper joint angle seq: {j[0]} {j[1]} {j[2]} {j[3]} {j[4]} {j[5]}") 278 | if joint_sequence: 279 | return j 280 | return [status['x'], status['y'], status['z'], 281 | status['roll'], status['pitch'], status['yaw']] 282 | 283 | ############################################################################## 284 | 285 | 286 | if __name__ == "__main__": 287 | # TODO: make this configurable 288 | _d = { 289 | "x": 0, 290 | "y": 0.1, 291 | "z": 0.9067479377560522, 292 | "roll": 0.0, 293 | "pitch": 0.0, 294 | "yaw": 0.0 295 | } 296 | 297 | # NOTE: Use your own IP address 298 | # rc = RemoteClient("100.99.105.59", home_dict = _d) # RE2 299 | rc = RemoteClient("100.124.244.50", home_dict=_d) # RE1 300 | s = rc.get_status(compact=True) 301 | joints = get_joints_from_robot_status(s) 302 | print(" current joints: ", joints) 303 | # rc.home() 304 | 305 | # time.sleep(3) 306 | # ## testing target location 307 | l_fingertip = np.array([0, 0.1, 0.4]) 308 | r_fingertip = np.array([0.05, 0.1, 0.42]) 309 | 310 | result = move_to_target([l_fingertip, r_fingertip], rc) 311 | time.sleep(3) 312 | print(result) 313 | -------------------------------------------------------------------------------- /robot/visual_servo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from prediction.models import * 6 | from prediction.loader import ActAffData 7 | from prediction.live_model import LiveModel 8 | from utils.config_utils import * 9 | from utils.pred_utils import * 10 | from utils.data_pipeline import * 11 | from robot.robot_utils import * 12 | from utils.transform import * 13 | from utils.aruco_detect import ArucoPoseEstimator, find_contact_markers 14 | from utils.ft_utils import ft_to_cam_rotation 15 | 16 | import time 17 | import numpy as np 18 | import cv2 19 | from utils.realsense_utils import RealSense, CameraType 20 | 21 | from stretch_remote.robot_utils import * 22 | from stretch_remote.remote_client import RemoteClient 23 | from typing import Dict, Optional, Tuple 24 | 25 | """ 26 | Note: Robot Joint directions (remote client) and coordinates 27 | 28 | ___ z +ve 29 | ||0 30 | ||------C z +ve ^ 31 | || | x +ve 32 | || /x +ve | / 33 | [__] ----> y +ve y +ve <---/ Robot frame 34 | / x 35 | """ 36 | 37 | ##################################################################################### 38 | # CONFIGS 39 | 40 | USE_FORCE_OBJECTIVE = True 41 | 42 | TEMPORAL_FILTERING_WINDOW = 4 43 | CONTROL_FREQ = 16 # Hz 44 | 45 | # HAND_MOVEMENT_ERR_THRESH = 0.023 46 | HAND_MOVEMENT_ERR_THRESH = 0.015 47 | GRIPPER_MOVEMENT_ERR_THRESH = 0.02 48 | 49 | WRIST_FORCE_LAMBDA = 0.011 50 | 
GRIP_FORCE_LAMBDA = 0.013 51 | 52 | GRIPPER_EPS = 0.005 53 | GRIPPER_STATIC_OFFSET = -0.005 54 | 55 | ##################################################################################### 56 | 57 | def print_red(text): 58 | print("\033[91m {}\033[00m".format(text)) 59 | 60 | def print_green(text): 61 | print('\033[92m {}\033[00m'.format(text)) 62 | 63 | ##################################################################################### 64 | 65 | class StateError: 66 | trans = np.array([0.0, 0.0, 0.0]) 67 | wrist_force = np.array([0.0, 0.0, 0.0]) 68 | width = 0.0 69 | grip_force = 0.0 70 | 71 | @staticmethod 72 | def from_np(err): 73 | """convert from numpy array""" 74 | assert np.shape(err) == (8,) 75 | se = StateError() 76 | se.trans = err[0:3] 77 | se.wrist_force = err[3:6] 78 | se.width = err[6] 79 | se.grip_force = err[7] 80 | return se 81 | 82 | def to_np(self): 83 | """convert to numpy array""" 84 | np_arr = np.zeros(8) 85 | np_arr[:3] = self.trans 86 | np_arr[3:6] = self.wrist_force 87 | np_arr[6] = self.width 88 | np_arr[7] = self.grip_force 89 | return np_arr 90 | 91 | 92 | ##################################################################################### 93 | 94 | class VisualServo(LiveModel): 95 | def __init__(self): 96 | super().__init__() 97 | self.pos_dict = get_pos_dict(self.rc) 98 | self.rc.home() 99 | time.sleep(1) 100 | 101 | # CONFIGS 102 | self.servo_delta = np.array([0.01, 0.01, 0.01]) # x, y, z 103 | # self.force_delta = 0.015 # should be similar to above 104 | # self.grip_servo_factor = 5 105 | 106 | self.servo_min = -0.04 107 | self.servo_max = 0.04 108 | self.yaw_limit = [-1.1, 2.8] # abs yaw limit to prevent camera from hitting the wrist 109 | self.start_grasp_err_norm = 0.075 # 0.06 110 | 111 | self.prev_move_time = time.time() 112 | self.state_err_history = [] 113 | 114 | if self.args.ros_viz: 115 | from ros_scripts.ros_viz import RosVizInterface 116 | self.ros_viz = RosVizInterface() 117 | self.publish_to_rviz = True 118 | 119 | def get_temporal_error(self, err: StateError) -> StateError: 120 | self.state_err_history.append(err.to_np()) 121 | if len(self.state_err_history) > TEMPORAL_FILTERING_WINDOW: 122 | self.state_err_history.pop(0) 123 | err_np = np.array(self.state_err_history) 124 | err_np = np.mean(err_np, axis=0) 125 | return StateError.from_np(err_np) 126 | 127 | def proportional_hand_servo_delta(self, distance): 128 | if distance > 0.2: 129 | servo_d = self.servo_delta*4 130 | elif distance > 0.12: 131 | servo_d = self.servo_delta*2 132 | else: 133 | servo_d = self.servo_delta 134 | return servo_d 135 | 136 | def proportional_gripper_servo_delta(self, distance): 137 | gripper_servo = 8 138 | if distance > 0.04: 139 | servo_d = gripper_servo*3 140 | elif distance > 0.02: 141 | servo_d = gripper_servo*2 142 | else: 143 | servo_d = gripper_servo 144 | return servo_d 145 | 146 | def control_robot(self) -> Tuple[Optional[Dict], bool]: 147 | """ 148 | Return an action when the robot is in the current state, 149 | and the error norm 150 | """ 151 | next_subgoal = False 152 | # if (time.time() - self.prev_move_time) < 1.0 / CONTROL_FREQ: 153 | # return None, next_subgoal 154 | 155 | self.prev_move_time = time.time() 156 | 157 | ############################################################################ 158 | pred_l_contact, pred_r_contact = self.pred_left_fingertip, self.pred_right_fingertip 159 | curr_l_contact, curr_r_contact = self.curr_left_fingertip, self.curr_right_fingertip 160 | curr_grip_force = self.grip_force 161 | pred_grip_force = 
self.pred_grip_force 162 | curr_force = self.curr_force[:3] # these forces has been offset-ed 163 | pred_force = self.pred_force[:3] # these forces has been offset-ed 164 | 165 | if self.pred_left_fingertip is None or self.pred_right_fingertip is None: 166 | print("Detect Nothing, Do Nothing!") 167 | return None, next_subgoal 168 | 169 | pred_centroid, pred_width, pred_yaw = fingertips_to_centroid_width_yaw( 170 | pred_l_contact, pred_r_contact) 171 | 172 | curr_centroid, curr_width, curr_yaw = fingertips_to_centroid_width_yaw( 173 | curr_l_contact, curr_r_contact) 174 | 175 | print(f"[VS] current pred timestep: {self.pred_timestep_index}") 176 | print(' [VS]curr_centroid: ', curr_centroid) 177 | print(' [VS]pred_centroid: ', pred_centroid) 178 | print(f' [VS]curr vs pred grip width: {curr_width} \t | {pred_width}') 179 | print(f' [VS]curr vs pred grip force: {curr_grip_force} \t | {pred_grip_force}') 180 | print(' [VS]pred_force: ', pred_force) 181 | print(' [VS]curr force: ', curr_force) 182 | print(' [VS]curr force mag: ', np.linalg.norm(curr_force)) 183 | print('ablate_force: ', self.args.ablate_force) 184 | print('binary_grip: ', self.args.binary_grip) 185 | if np.linalg.norm(curr_force) > 30: 186 | print(' [VS] MAX FORCE REACHED!!!!') 187 | 188 | curr_cam_force = curr_force@ft_to_cam_rotation() 189 | curr_bot_force = camera_frame_to_robot_frame(self.pos_dict, curr_cam_force) 190 | pred_cam_force = pred_force@ft_to_cam_rotation() 191 | pred_bot_force = camera_frame_to_robot_frame(self.pos_dict, pred_cam_force) 192 | print(' [VS]curr robot frame force: ', curr_bot_force) 193 | 194 | ############################################################################ 195 | 196 | trans_err = camera_frame_to_robot_frame(self.pos_dict, 197 | pred_centroid - curr_centroid) 198 | state_err = StateError() 199 | state_err.trans = trans_err 200 | state_err.width = pred_width - curr_width 201 | state_err.wrist_force = curr_bot_force - pred_bot_force # opposing force direction to translation 202 | state_err.grip_force = curr_grip_force - pred_grip_force # opposing force direction to translation 203 | 204 | avg_state_err = self.get_temporal_error(state_err) 205 | # print(' [VS]delta robot frame force: ', avg_state_err.trans) 206 | # print(' [VS]delta robot frame force: ', avg_state_err.wrist_force) 207 | 208 | ############################################################################ 209 | 210 | # Move the robot, Kinemtic Objective + Force Objective 211 | if self.args.ablate_force: 212 | hand_movement = avg_state_err.trans 213 | else: 214 | hand_movement = avg_state_err.trans + WRIST_FORCE_LAMBDA*avg_state_err.wrist_force 215 | 216 | # End Effector Control 217 | if abs(hand_movement[0]) > HAND_MOVEMENT_ERR_THRESH*3 or abs(hand_movement[2]) > HAND_MOVEMENT_ERR_THRESH*3: 218 | print(" -------- first line up the robot in x and z-dir -------") 219 | hand_movement[1] = 0. 220 | 221 | servo_delta = self.proportional_hand_servo_delta(np.linalg.norm(hand_movement)) 222 | hand_movement_delta = servo_delta * hand_movement / np.linalg.norm(hand_movement) 223 | # print(' @@ hand movement: {hand_movement} | ', np.linalg.norm(hand_movement)) 224 | # print(' @@ movement trans delta: ', hand_movement_delta) 225 | 226 | # set a limit to trans_delta and check if there is NaN # make sure!! 
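# If the predicted and current centroids coincide, np.linalg.norm(hand_movement) is zero and
# the division above yields NaN; the lines below zero out any NaN entries and clip each axis
# delta to [servo_min, servo_max], so a single servo step never exceeds 0.04 m per axis.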
227 | hand_movement_delta[np.isnan(hand_movement_delta)] = 0 228 | hand_movement_delta = np.clip(hand_movement_delta, self.servo_min, self.servo_max) 229 | 230 | control_request ={ 231 | 'x': self.pos_dict['x'] + hand_movement_delta[0], 232 | 'y': self.pos_dict['y'] - hand_movement_delta[1], # joint space != cartesian space 233 | 'z': self.pos_dict['z'] + hand_movement_delta[2], 234 | 'pitch': -0.3 235 | # 'pitch': -0.15 236 | # 'pitch': -0.0 237 | } 238 | 239 | ###################################################################### 240 | # Gripper control 241 | if self.args.ablate_force and not self.args.binary_grip: # continuous grip position with no force 242 | gripper_movement = avg_state_err.width + GRIPPER_STATIC_OFFSET 243 | elif self.args.binary_grip: # binary grip, no grip force, could still have wrist force 244 | print('gripper position: ', self.pos_dict['gripper']) 245 | WIDTH_THRESH = 0.085 246 | # CLOSED_POS = -25.0 247 | CLOSED_POS = -95.0 248 | # OPEN_POS = 50.0 249 | OPEN_POS = 50.0 250 | 251 | if pred_width < WIDTH_THRESH and self.pos_dict['gripper'] > CLOSED_POS: # if pred width is small and not closed, close it 252 | gripper_movement = -1 253 | elif pred_width > WIDTH_THRESH and self.pos_dict['gripper'] < OPEN_POS: # if pred width is large and not open, open it 254 | gripper_movement = 1 255 | else: 256 | gripper_movement = 0 257 | else: # continuous grip position with force 258 | gripper_movement = avg_state_err.width + GRIP_FORCE_LAMBDA*avg_state_err.grip_force + GRIPPER_STATIC_OFFSET 259 | 260 | servo_delta = int(self.proportional_gripper_servo_delta(gripper_movement)) 261 | if gripper_movement < -GRIPPER_EPS: 262 | gripper_control_request = self.pos_dict['gripper'] - servo_delta 263 | elif gripper_movement > GRIPPER_EPS: 264 | gripper_control_request = self.pos_dict['gripper'] + servo_delta 265 | else: 266 | gripper_control_request = self.pos_dict['gripper'] # do nothing 267 | 268 | if abs(gripper_movement) < GRIPPER_EPS*4: 269 | control_request['gripper'] = gripper_control_request 270 | else: 271 | control_request = {'gripper': gripper_control_request} # only control gripper 272 | 273 | ###################################################################### 274 | 275 | if self.args.ros_viz and self.publish_to_rviz: 276 | """This visualizes the points in 3D space""" 277 | pcd = get_point_cloud(self.rgb_image, self.depth_image, Intrinsic640()) 278 | 279 | # viz_3d.display(pcd, markers=finger_tips) 280 | self.ros_viz.publish_pcd(pcd, False) 281 | pred_finger_tips = [np.squeeze(a) for a in [pred_l_contact, pred_r_contact]] 282 | curr_finger_tips = [np.squeeze(a) for a in [curr_l_contact, curr_r_contact]] 283 | 284 | # self.ros_viz.publish_grip_force(pred_grip_force, centroid) 285 | # self.ros_viz.publish_fingertips(pred_finger_tips) 286 | self.ros_viz.publish_wrist_force(pred_cam_force, pred_centroid) 287 | self.ros_viz.publish_grip_force_with_fingertips( 288 | pred_grip_force, pred_finger_tips) 289 | 290 | self.ros_viz.publish_wrist_force(curr_cam_force, curr_centroid, is_curr=True) 291 | self.ros_viz.publish_grip_force_with_fingertips( 292 | curr_grip_force, curr_finger_tips, is_curr=True) 293 | 294 | ###################################################################### 295 | # Compute error 296 | # avg_trans_norm = np.linalg.norm(avg_state_err.trans) 297 | # avg_wrist_force_norm = np.linalg.norm(avg_state_err.wrist_force) 298 | 299 | next_subgoal = True 300 | hand_movement_norm = np.linalg.norm(hand_movement) 301 | if abs(hand_movement_norm) < 
HAND_MOVEMENT_ERR_THRESH: 302 | print_green(f" @hand movement: {hand_movement} | {hand_movement_norm}") 303 | else: 304 | print_red(f" @hand movement: {hand_movement} | {hand_movement_norm}") 305 | next_subgoal = False 306 | 307 | if abs(gripper_movement) < GRIPPER_MOVEMENT_ERR_THRESH: 308 | print_green(f" @gripper movement {gripper_movement}") 309 | else: 310 | print_red(f" @gripper movement {gripper_movement}") 311 | next_subgoal = False 312 | 313 | if next_subgoal: 314 | print_green("Switching to next subgoal!") 315 | self.state_err_history = [] # reset error 316 | time.sleep(0.5) 317 | return control_request, next_subgoal 318 | 319 | ################################################################################## 320 | 321 | if __name__ == '__main__': 322 | vs = VisualServo() 323 | vs.run_model(control_func=vs.control_robot) 324 | -------------------------------------------------------------------------------- /ros_scripts/joint_state_pub.py: -------------------------------------------------------------------------------- 1 | 2 | import rospy 3 | import tf2_ros as tf 4 | import threading 5 | import json 6 | import numpy as np 7 | 8 | from math import pi, atan2 9 | from typing import List, Optional, Tuple 10 | 11 | from sensor_msgs.msg import JointState 12 | from geometry_msgs.msg import Transform, TransformStamped 13 | from std_msgs.msg import ColorRGBA, Header 14 | from geometry_msgs.msg import Pose, Vector3, Point, Quaternion 15 | import geometry_msgs.msg as gm 16 | 17 | from tf import transformations as ts 18 | from visualization_msgs.msg import Marker 19 | import json 20 | import argparse 21 | 22 | ############################################################################## 23 | 24 | def transform_stamped_msg(parent_frame, child_frame, transform): 25 | t = TransformStamped() 26 | t.transform = transform 27 | t.header.frame_id = parent_frame 28 | t.header.stamp = rospy.Time.now() 29 | t.child_frame_id = child_frame 30 | return t 31 | 32 | class Pose3D(): 33 | x: float = 0.0 34 | y: float = 0.0 35 | z: float = 0.0 36 | roll: float = 0.0 37 | pitch: float = 0.0 38 | yaw: float = 0.0 39 | 40 | def get_transform(self) -> gm.Transform: 41 | """convert pose3d to tf2 geometry msg""" 42 | t = gm.Transform() 43 | t.translation = gm.Vector3(x=self.x, y=self.y, z=self.z) 44 | q = ts.quaternion_from_euler(self.roll, self.pitch, self.yaw, 'rxyz') 45 | t.rotation = gm.Quaternion(x=q[0], y=q[1], z=q[2], w=q[3]) 46 | return t 47 | 48 | def to_list(self) -> list: 49 | """convert pose3d to list""" 50 | return [self.x, self.y, self.z, self.roll, self.pitch, self.yaw] 51 | 52 | def to_matrix(self) -> np.ndarray: 53 | """convert pose3d to transformation matrix""" 54 | t = self.get_transform() 55 | rot = t.rotation 56 | trans = t.translation 57 | return ts.translation_matrix([trans.x, trans.y, trans.z]) \ 58 | @ ts.quaternion_matrix([rot.x, rot.y, rot.z, rot.w]) 59 | 60 | ############################################################################## 61 | 62 | class JointStatePublisher(): 63 | def __init__(self, target_joints): 64 | rospy.init_node("joint_state_publisher") 65 | self.tf_buffer = tf.Buffer() 66 | self.tf_listener = tf.TransformListener(self.tf_buffer) 67 | 68 | self.tf_buffer = tf.Buffer() 69 | self.tf_listener = tf.TransformListener(self.tf_buffer) 70 | 71 | self.br = tf.TransformBroadcaster() 72 | self.broadcast_tf_lock = threading.Lock() # TODO: impl this 73 | self.human_pose_transform = None 74 | self.exercise_start_pose = None 75 | self.exercise_end_pose = None 76 | 77 | 
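# Publishes the target joints on /joint_states at 1 Hz (see the rospy.Timer below) so that
# robot_state_publisher / RViz can pose the Stretch URDF; the fake x translation is broadcast
# separately as the odom -> base_link transform.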
self.joint_states_publisher = rospy.Publisher( 78 | "/joint_states", JointState, queue_size=1) 79 | 80 | assert len(target_joints) == 6, "joints should be a list of 6 floats" 81 | self.target_joints = target_joints 82 | self.timer = rospy.Timer(rospy.Duration(1), self._joint_publisher_callback) 83 | print("joint_state_publisher initialized") 84 | 85 | def _joint_publisher_callback(self, event): 86 | """ 87 | publish robot joints for visalization test on rviz 88 | """ 89 | fake_x, lift, arm, wrist_yaw, wrist_pitch, wrist_roll = self.target_joints 90 | # since the robot is moving on x-axis, we need to adjust the yaw 91 | # to show the human skeleton correctly 92 | 93 | msg = JointState() 94 | msg.header.stamp = rospy.Time.now() 95 | msg.name = [ 96 | "joint_wrist_pitch", 97 | "joint_wrist_yaw", 98 | "joint_wrist_roll", 99 | "joint_lift", 100 | "joint_arm_l3", 101 | "joint_arm_l2", 102 | "joint_arm_l1", 103 | "joint_arm_l0", 104 | ] 105 | msg.position = [ 106 | wrist_pitch, wrist_yaw, wrist_roll, lift, arm/4, arm/4, arm/4, arm/4, 107 | ] 108 | self.joint_states_publisher.publish(msg) 109 | 110 | tf = Pose3D() 111 | tf.x = fake_x 112 | odom_to_baselink = transform_stamped_msg( 113 | "odom", "base_link", tf.get_transform()) 114 | self.br.sendTransform(odom_to_baselink) 115 | 116 | 117 | ############################################################################## 118 | 119 | if __name__ == "__main__": 120 | # provide example usage in help 121 | example = "python joint_state_pub.py --joints 0.0 0.0 0.0 0.0 0.0 0.0" 122 | print("Example usage: ", example) 123 | 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument("--joints", nargs='+', type=float, help='List of floats joints') 126 | args = parser.parse_args() 127 | 128 | print("joints: ", args.joints) 129 | jsp = JointStatePublisher(args.joints) 130 | rospy.spin() 131 | -------------------------------------------------------------------------------- /ros_scripts/ros_aruco_detect.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import cv2 4 | import numpy as np 5 | from aruco_detect import ArucoPoseEstimator, find_contact_markers 6 | from realsense_utils import DefaultIntrinsic 7 | 8 | import rospy 9 | import tf 10 | from sensor_msgs.msg import Image 11 | from visualization_msgs.msg import Marker 12 | from geometry_msgs.msg import Point, Vector3 13 | 14 | from geometry_msgs.msg import TransformStamped 15 | from cv_bridge import CvBridge 16 | 17 | 18 | ############################################################################## 19 | 20 | class ImageSubscriber: 21 | def __init__( 22 | self, 23 | aruco_pose_est, 24 | frame_ref="camera_depth_optical_frame", 25 | contact_callback=None 26 | ): 27 | """Image subscriber for aruco pose estimation 28 | 29 | Args: 30 | aruco_pose_est (_type_): _description_ 31 | frame_ref (str, optional): _description_. Defaults to "camera_depth_optical_frame". 32 | contact_callback (_type_, optional): _description_. Defaults to None. 
33 | """ 34 | self.bridge = CvBridge() 35 | self.image_sub = \ 36 | rospy.Subscriber("/camera/color/image_raw", Image, self.image_callback) 37 | self.marker_pub = \ 38 | rospy.Publisher("/contact_marker", Marker, queue_size=10) 39 | self.tf_broadcaster = tf.TransformBroadcaster() 40 | self.aruco_pose_est = aruco_pose_est 41 | self.ref_frame = frame_ref 42 | self.contact_callback = contact_callback 43 | 44 | def image_callback(self, data): 45 | try: 46 | cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8") 47 | m_list = self.aruco_pose_est.detect(cv_image, viz=True) 48 | 49 | for m in m_list: 50 | # print(f"m{m.id}", m.trans, m.rot) 51 | 52 | # Publish the transform to /tf 53 | self.tf_broadcaster.sendTransform( 54 | (m.trans[0], m.trans[1], m.trans[2]), 55 | tf.transformations.quaternion_from_euler( 56 | m.rot[0], m.rot[1], m.rot[2]), 57 | rospy.Time.now(), 58 | f"m{m.id}", 59 | self.ref_frame 60 | ) 61 | 62 | if len(m_list) == 2: 63 | coor1, coor2 = find_contact_markers(m_list[0], m_list[1]) 64 | self.publish_marker_msg([coor1, coor2]) 65 | if self.contact_callback: 66 | self.contact_callback(coor1, coor2) 67 | 68 | except Exception as e: 69 | print(e) 70 | # else: 71 | # Process the cv_image here 72 | # cv2.imshow("Image", cv_image) 73 | cv2.waitKey(1) 74 | 75 | def publish_marker_msg(self, coors, size=0.01, is_point=True): 76 | marker = Marker() 77 | marker.header.frame_id = self.ref_frame 78 | marker.action = Marker.ADD # set the marker action to ADD 79 | marker.color.a = 0.8 # set the alpha 80 | marker.color.r = 1.0 81 | marker.color.b = 1.0 82 | 83 | if is_point: 84 | marker.type = Marker.POINTS # set the marker type to POINTS 85 | marker.scale.x = size*2 86 | marker.scale.y = size*2 87 | marker.scale.z = size*2 88 | 89 | for c in coors: 90 | point = Point() 91 | point.x = c[0] 92 | point.y = c[1] 93 | point.z = c[2] 94 | marker.points.append(point) 95 | else: 96 | # TODO: use cylinder to represent the contact 97 | marker.type = Marker.CYLINDER # set the marker type to CYLINDER 98 | marker.scale.x = size # set the radius of the cylinder 99 | marker.scale.y = size 100 | # use the dist between two markers as the height of the cylinder 101 | marker.scale.z = np.linalg.norm(coors[0] - coors[1]) 102 | 103 | start = coors[0] 104 | end = coors[1] 105 | center = Point() 106 | center.x = (start[0] + end[0])/2 107 | center.y = (start[1] + end[1])/2 108 | center.z = (start[2] + end[2])/2 109 | # covert vector to quaternion 110 | vector = end - start 111 | quat = tf.transformations.quaternion_about_axis(0, vector) 112 | marker.pose.position = center 113 | marker.pose.orientation.x = quat[0] 114 | marker.pose.orientation.y = quat[1] 115 | marker.pose.orientation.z = quat[2] 116 | marker.pose.orientation.w = quat[3] 117 | 118 | self.marker_pub.publish(marker) 119 | 120 | ############################################################################## 121 | 122 | if __name__ == "__main__": 123 | default_intr = DefaultIntrinsic() 124 | cam_mat = default_intr.cam_mat() 125 | cam_dist = default_intr.cam_dist() 126 | 127 | def callback(p1, p2): 128 | print("received callback", p1, p2) 129 | 130 | aruco_pose_est = ArucoPoseEstimator( 131 | cam_mat, cam_dist, marker_size=0.0155, valid_ids=[11, 12]) 132 | 133 | rospy.init_node('gripper_pose_estimator', anonymous=True) 134 | image_subscriber = ImageSubscriber(aruco_pose_est, contact_callback=callback) 135 | rospy.spin() 136 | 137 | # Release the video capture object and close all windows 138 | cv2.destroyAllWindows() 139 | 
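The fingertip-contact geometry used by the node above can also be exercised without ROS. The following is a minimal sketch, not part of the repository: it assumes the repo root is on `PYTHONPATH` (and that its dependencies such as `pyrealsense2` are installed, since `utils/aruco_detect.py` imports them at module level), and it feeds `find_contact_markers` two made-up marker translations in place of real detections of markers 11 and 12 — the function operates on the markers' translation vectors, as in the `__main__` of `utils/aruco_detect.py`.

```python
# Hypothetical offline check of find_contact_markers (all values are made up).
import numpy as np
from utils.aruco_detect import find_contact_markers

left_marker = np.array([-0.04, 0.02, 0.30])   # fabricated translation of marker 11 (meters)
right_marker = np.array([0.04, 0.02, 0.31])   # fabricated translation of marker 12 (meters)

# Offset each marker toward the gripper contact surface (2.5 cm offset by default).
left_contact, right_contact = find_contact_markers(left_marker, right_marker)

print("left contact: ", left_contact)
print("right contact:", right_contact)
print("grip width (m):", np.linalg.norm(left_contact - right_contact))
```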
-------------------------------------------------------------------------------- /ros_scripts/ros_viz.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from scipy.spatial.transform import Rotation as R 5 | import math 6 | import numpy as np 7 | from ctypes import * # convert float to uint32 8 | 9 | import rospy 10 | import tf2_ros 11 | from sensor_msgs.msg import PointCloud2, PointField 12 | from std_msgs.msg import Header 13 | from std_msgs.msg import ColorRGBA 14 | from geometry_msgs.msg import Point, Vector3 15 | from geometry_msgs.msg import TransformStamped 16 | from visualization_msgs.msg import Marker, MarkerArray 17 | import sensor_msgs.point_cloud2 as pc2 18 | 19 | from recording.ft import FTCapture 20 | from utils.ft_utils import * 21 | from utils.transform import * 22 | from utils.realsense_utils import * 23 | 24 | # The data structure of each point in ros PointCloud2: 16 bits = x + y + z + rgb 25 | FIELDS_XYZ = [ 26 | PointField(name='x', offset=0, datatype=PointField.FLOAT32, count=1), 27 | PointField(name='y', offset=4, datatype=PointField.FLOAT32, count=1), 28 | PointField(name='z', offset=8, datatype=PointField.FLOAT32, count=1), 29 | ] 30 | FIELDS_XYZRGB = FIELDS_XYZ + \ 31 | [PointField(name='rgb', offset=12, datatype=PointField.FLOAT32, count=1)] 32 | 33 | # Bit operations 34 | BIT_MOVE_16 = 2**16 35 | BIT_MOVE_8 = 2**8 36 | 37 | ############################################################################## 38 | 39 | # Convert the datatype of point cloud from Open3D to ROS PointCloud2 (XYZRGB only) 40 | def convertCloudFromOpen3dToRos(open3d_cloud, frame_id="map", invert_color=False): 41 | """ 42 | Refer to here: 43 | https://github.com/felixchenfy/open3d_ros_pointcloud_conversion/issues/6 44 | 45 | # TODO: enable down sampling 46 | """ 47 | header = Header() 48 | header.stamp = rospy.Time.now() 49 | header.frame_id = frame_id 50 | 51 | # Set "fields" and "cloud_data" 52 | points=np.asarray(open3d_cloud.points) 53 | if not open3d_cloud.colors: # XYZ only 54 | fields=FIELDS_XYZ 55 | cloud_data=points 56 | else: # XYZ + RGB 57 | fields=FIELDS_XYZRGB 58 | # -- Change rgb color from "three float" to "one 24-byte int" 59 | # 0x00FFFFFF is white, 0x00000000 is black. 
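# Open3D stores colors as floats in [0, 1]; the block below scales them to 0-255, packs each
# point's channels into a single 24-bit integer via the BIT_MOVE_16 / BIT_MOVE_8 shifts, and
# reinterprets the packed value as float32, matching the FLOAT32 'rgb' field in FIELDS_XYZRGB.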
60 | colors = np.floor(np.asarray(open3d_cloud.colors)*255) 61 | colors = colors.astype(np.uint32) 62 | if invert_color: 63 | colors = colors[:,0] * BIT_MOVE_16 +colors[:,1] * BIT_MOVE_8 + colors[:,2] 64 | else: 65 | colors = colors[:,2] * BIT_MOVE_16 +colors[:,1] * BIT_MOVE_8 + colors[:,1] 66 | colors = colors.view(np.float32) 67 | cloud_data = np.column_stack((points, colors)) 68 | # create ros_cloud 69 | return pc2.create_cloud(header, fields, cloud_data) 70 | 71 | ############################################################################## 72 | 73 | class RosVizInterface: 74 | def __init__(self): 75 | # Create a ROS node 76 | # Run rosviz, make sure roscore is running 77 | rospy.init_node('force_sight_viz_publisher', anonymous=True) 78 | self.cloud_pub = rospy.Publisher('point_cloud_topic', PointCloud2, queue_size=10) 79 | self.fingertips_pub = rospy.Publisher("/contact_marker", Marker, queue_size=10) 80 | self.grip_force_pub = rospy.Publisher("/grip_force_marker", Marker, queue_size=10) 81 | 82 | self.wrist_force_pub = rospy.Publisher("/wrist_force_marker", Marker, queue_size=10) 83 | self.grip_force_fingers_pub = rospy.Publisher( 84 | "/grip_force_fingers_markers", MarkerArray, queue_size=10) 85 | 86 | self.curr_wrist_force_pub = rospy.Publisher("/curr_wrist_force_marker", Marker, queue_size=10) 87 | self.curr_grip_force_fingers_pub = rospy.Publisher( 88 | "/curr_grip_force_fingers_markers", MarkerArray, queue_size=10) 89 | 90 | self.ref_frame = "camera" 91 | 92 | static_broadcaster = tf2_ros.StaticTransformBroadcaster() 93 | # Create a TransformStamped message 94 | transform = TransformStamped() 95 | transform.header.stamp = rospy.Time.now() 96 | transform.header.frame_id = 'map' 97 | transform.child_frame_id = self.ref_frame 98 | 99 | # Set translation (assuming no translation) 100 | transform.transform.translation.x = 0.0 101 | transform.transform.translation.y = 0.0 102 | transform.transform.translation.z = 0.1 103 | 104 | # Set rotation 105 | x_angle = math.pi/2 106 | y_angle = math.pi 107 | z_angle = 0 108 | 109 | # from euler to rotation matrix 110 | rot1 = R.from_euler('xyz', [x_angle, y_angle, z_angle]) # arm 111 | rot2 = R.from_euler('xyz', [-math.pi/6, 0, 0]) # camera 112 | quaternion = (rot1*rot2).as_quat() 113 | 114 | transform.transform.rotation.x = quaternion[0] 115 | transform.transform.rotation.y = quaternion[1] 116 | transform.transform.rotation.z = quaternion[2] 117 | transform.transform.rotation.w = quaternion[3] 118 | 119 | # Publish the transform 120 | static_broadcaster.sendTransform(transform) 121 | 122 | 123 | def publish_pcd(self, pcd, invert_color=False): 124 | cloud_msg = convertCloudFromOpen3dToRos(pcd, self.ref_frame, invert_color) 125 | self.cloud_pub.publish(cloud_msg) 126 | 127 | 128 | def publish_fingertips(self, coors, size=0.005): 129 | marker = Marker() 130 | marker.header.frame_id = self.ref_frame 131 | marker.action = Marker.ADD # set the marker action to ADD 132 | marker.color.a = 0.8 # set the alpha 133 | marker.color.r = 1.0 134 | marker.color.b = 1.0 135 | 136 | marker.type = Marker.POINTS # set the marker type to POINTS 137 | marker.scale.x = size*2 138 | marker.scale.y = size*2 139 | marker.scale.z = size*2 140 | 141 | for c in coors: 142 | point = Point() 143 | point.x = c[0] 144 | point.y = c[1] 145 | point.z = c[2] 146 | marker.points.append(point) 147 | self.fingertips_pub.publish(marker) 148 | 149 | 150 | def publish_wrist_force(self, force_vec, origin, scale=0.05, is_curr=False): 151 | # Invert the force vector 152 | force_vec = 
-np.array(force_vec) 153 | 154 | # Convert force_vec to a unit vector 155 | force_unit_vec = force_vec / np.linalg.norm(force_vec) 156 | 157 | # Default arrow direction is along x-axis 158 | arrow_dir = np.array([1, 0, 0]) 159 | # Compute rotation axis (unit vector) 160 | rotation_axis = np.cross(arrow_dir, force_unit_vec) 161 | rotation_axis = rotation_axis / np.linalg.norm(rotation_axis) 162 | 163 | # Compute rotation angle 164 | rotation_angle = np.arccos(np.dot(arrow_dir, force_unit_vec)) 165 | 166 | # Compute quaternion from axis-angle 167 | rotation_quat = np.zeros(4) 168 | rotation_quat[0] = rotation_axis[0] * np.sin(rotation_angle / 2) 169 | rotation_quat[1] = rotation_axis[1] * np.sin(rotation_angle / 2) 170 | rotation_quat[2] = rotation_axis[2] * np.sin(rotation_angle / 2) 171 | rotation_quat[3] = np.cos(rotation_angle / 2) 172 | 173 | # Compute arrow's position at the head 174 | arrow_length = np.linalg.norm(force_vec)*scale 175 | arrow_pos = origin - arrow_length*force_unit_vec 176 | 177 | # Create a Marker message 178 | marker = Marker() 179 | marker.header.frame_id = self.ref_frame 180 | marker.type = Marker.ARROW 181 | marker.pose.position.x = arrow_pos[0] 182 | marker.pose.position.y = arrow_pos[1] 183 | marker.pose.position.z = arrow_pos[2] 184 | marker.pose.orientation.x = rotation_quat[0] 185 | marker.pose.orientation.y = rotation_quat[1] 186 | marker.pose.orientation.z = rotation_quat[2] 187 | marker.pose.orientation.w = rotation_quat[3] 188 | marker.scale.x = arrow_length # Length of the arrow 189 | marker.scale.y = 0.2*scale # Width of the arrow 190 | marker.scale.z = 0.2*scale # Height of the arrow 191 | 192 | if is_curr: 193 | marker.color.r = 1.0 194 | marker.color.g = 1.0 195 | marker.color.b = 0.0 196 | marker.color.a = 0.5 # Yellow 197 | else: 198 | marker.color.r = 1.0 199 | marker.color.g = 1.0 200 | marker.color.b = 0.0 201 | marker.color.a = 1.0 # Yellow 202 | 203 | if is_curr: 204 | self.curr_wrist_force_pub.publish(marker) 205 | else: 206 | self.wrist_force_pub.publish(marker) 207 | 208 | 209 | def publish_grip_force(self, force_magnitude, origin, force_scale=0.2, sphere_scale=.04): 210 | marker = Marker() 211 | marker.header.frame_id = self.ref_frame 212 | marker.type = Marker.SPHERE 213 | marker.action = Marker.ADD # set the marker action to ADD 214 | marker.pose.position.x = origin[0] 215 | marker.pose.position.y = origin[1] 216 | marker.pose.position.z = origin[2] 217 | marker.pose.orientation.w = 1.0 218 | marker.scale.x = sphere_scale 219 | marker.scale.y = sphere_scale 220 | marker.scale.z = sphere_scale 221 | 222 | # Calculate the color based on magnitude 223 | color = ColorRGBA() 224 | color.r = min(1.0, force_magnitude * force_scale) # Green to red transition 225 | color.g = 1.0 - color.r 226 | color.b = 0.0 227 | color.a = 0.8 228 | marker.color = color 229 | self.grip_force_pub.publish(marker) 230 | 231 | 232 | def publish_grip_force_with_fingertips(self, 233 | force_magnitude, 234 | fingertips, 235 | force_scale=0.01, 236 | is_curr=False): 237 | """ 238 | Publishes the grip force as a arrow marker with the fingertips as sphere markers 239 | """ 240 | # create 2 arrow markers from the fingertips (as origin) 241 | # to the center, which the arrow is horizontal placed on the x-axis 242 | # and the length is the force magnitude 243 | markers_msg = MarkerArray() 244 | for i in range(2): 245 | marker = Marker() 246 | marker.id = i 247 | marker.header.frame_id = self.ref_frame 248 | marker.type = Marker.ARROW 249 | marker.action = Marker.ADD 250 | 
marker.scale.x = force_magnitude*force_scale 251 | marker.scale.y = 0.01 # Adjust as needed for the arrow's width 252 | marker.scale.z = 0.015 # Adjust as needed for the arrow's height 253 | marker.pose 254 | 255 | if is_curr: 256 | # marker.color = ColorRGBA(0, 0.6, 1.0, 0.8) # Light blue opaque 257 | marker.color = ColorRGBA(0.4, 1.0, 0.1, 0.5) # Light green 258 | else: 259 | # marker.color = ColorRGBA(1.0, 0.2, 1.0, 0.8) 260 | marker.color = ColorRGBA(0.4, 1.0, 0.1, 1.0) # Light green 261 | if i == 0: 262 | marker.pose.position.x = fingertips[i][0] - force_magnitude*force_scale 263 | else: 264 | marker.pose.position.x = fingertips[i][0] + force_magnitude*force_scale 265 | marker.pose.position.y = fingertips[i][1] 266 | marker.pose.position.z = fingertips[i][2] 267 | # marker.pose.orientation.w = 1.0 268 | 269 | # flip the arrow direction 270 | if i == 0: 271 | marker.pose.orientation.w = 1.0 272 | else: 273 | quat = R.from_euler('xyz', [0, 0, np.pi]).as_quat() 274 | marker.pose.orientation.x = quat[0] 275 | marker.pose.orientation.y = quat[1] 276 | marker.pose.orientation.z = quat[2] 277 | marker.pose.orientation.w = quat[3] 278 | markers_msg.markers.append(marker) 279 | 280 | for i in range(2): 281 | marker = Marker() 282 | marker.id = i + 2 283 | marker.header.frame_id = self.ref_frame 284 | marker.type = Marker.SPHERE 285 | marker.action = Marker.ADD # set the marker action to ADD 286 | marker.pose.position.x = fingertips[i][0] 287 | marker.pose.position.y = fingertips[i][1] 288 | marker.pose.position.z = fingertips[i][2] 289 | marker.pose.orientation.w = 1.0 290 | marker.scale.x = 0.01 291 | marker.scale.y = 0.01 292 | marker.scale.z = 0.01 293 | 294 | color = ColorRGBA() 295 | if is_curr: 296 | color.r = 0.0 297 | color.g = 0.5 # Dark Green 298 | color.b = 0.0 299 | color.a = 0.5 300 | else: 301 | color.r = 0.0 302 | color.g = 0.5 # Dark Green 303 | color.b = 0.0 304 | color.a = 1.0 305 | 306 | marker.color = color 307 | markers_msg.markers.append(marker) 308 | 309 | # Publish the markers 310 | if is_curr: 311 | self.curr_grip_force_fingers_pub.publish(markers_msg) 312 | else: 313 | self.grip_force_fingers_pub.publish(markers_msg) 314 | 315 | 316 | ############################################################################## 317 | 318 | if __name__ == '__main__': 319 | parser = argparse.ArgumentParser() 320 | ros_viz = RosVizInterface() 321 | parser.add_argument('--rs', action='store_true', help='use realsense camera') 322 | parser.add_argument('--ft', action='store_true', help='use force torque sensor') 323 | args = parser.parse_args() 324 | 325 | if args.rs: 326 | rs = RealSense(select_device=None) 327 | intr = rs.get_camera_intrinsics(CameraType.COLOR) 328 | print(intr) 329 | 330 | if args.ft: 331 | ft_obj = FTCapture() 332 | calibrate_ft(ft_obj) 333 | offset = get_ft_calibration() 334 | cam_rot = ft_to_cam_rotation() 335 | 336 | while True: 337 | if args.rs: 338 | color_image, depth_image = rs.get_rgbd_image() 339 | disp_image = rs.display_rgbd_image(color_image, depth_image) 340 | pcd = rs.get_point_cloud(color_image, depth_image) 341 | ros_viz.publish_pcd(pcd) 342 | 343 | if args.ft: 344 | ft = ft_obj.get_ft() - offset 345 | force_vec = ft[:3]@cam_rot 346 | ros_viz.publish_wrist_force(force_vec, [0,0,0.2], scale=0.1) 347 | print(force_vec) 348 | 349 | ros_viz.publish_wrist_force([3,0,1], [0,0,0.2]) 350 | ros_viz.publish_fingertips([[1,0,0], [0,0,0.1]]) 351 | # Show images 352 | ros_viz.publish_grip_force(3.5, [0,0,0.2]) 353 | ros_viz.publish_grip_force_with_fingertips( 354 | 
10.5, 355 | [[-0.3, 0, 0.2], [0.3, 0, 0.2]] 356 | ) 357 | time.sleep(1) 358 | 359 | # cv2.namedWindow('RealSense', cv2.WINDOW_AUTOSIZE) 360 | # cv2.imshow('RealSense', disp_image) 361 | # cv2.waitKey(1) 362 | -------------------------------------------------------------------------------- /ros_scripts/ros_viz.rviz: -------------------------------------------------------------------------------- 1 | Panels: 2 | - Class: rviz/Displays 3 | Help Height: 78 4 | Name: Displays 5 | Property Tree Widget: 6 | Expanded: 7 | - /Global Options1 8 | - /Status1 9 | - /PointCloud21 10 | - /Marker1/Status1 11 | - /Marker1/Namespaces1 12 | - /MarkerArray2 13 | - /Marker4 14 | Splitter Ratio: 0.5 15 | Tree Height: 514 16 | - Class: rviz/Selection 17 | Name: Selection 18 | - Class: rviz/Tool Properties 19 | Expanded: 20 | - /2D Pose Estimate1 21 | - /2D Nav Goal1 22 | - /Publish Point1 23 | Name: Tool Properties 24 | Splitter Ratio: 0.5886790156364441 25 | - Class: rviz/Views 26 | Expanded: 27 | - /Current View1 28 | Name: Views 29 | Splitter Ratio: 0.5 30 | - Class: rviz/Time 31 | Name: Time 32 | SyncMode: 0 33 | SyncSource: PointCloud2 34 | Preferences: 35 | PromptSaveOnExit: true 36 | Toolbars: 37 | toolButtonStyle: 2 38 | Visualization Manager: 39 | Class: "" 40 | Displays: 41 | - Alpha: 0.5 42 | Cell Size: 1 43 | Class: rviz/Grid 44 | Color: 160; 160; 164 45 | Enabled: true 46 | Line Style: 47 | Line Width: 0.029999999329447746 48 | Value: Lines 49 | Name: Grid 50 | Normal Cell Count: 0 51 | Offset: 52 | X: 0 53 | Y: 0 54 | Z: 0 55 | Plane: XY 56 | Plane Cell Count: 10 57 | Reference Frame: 58 | Value: true 59 | - Alpha: 1 60 | Autocompute Intensity Bounds: true 61 | Autocompute Value Bounds: 62 | Max Value: 10 63 | Min Value: -10 64 | Value: true 65 | Axis: Z 66 | Channel Name: intensity 67 | Class: rviz/PointCloud2 68 | Color: 255; 255; 255 69 | Color Transformer: RGB8 70 | Decay Time: 0 71 | Enabled: true 72 | Invert Rainbow: true 73 | Max Color: 255; 255; 255 74 | Min Color: 0; 0; 0 75 | Name: PointCloud2 76 | Position Transformer: XYZ 77 | Queue Size: 10 78 | Selectable: true 79 | Size (Pixels): 3 80 | Size (m): 0.0020000000949949026 81 | Style: Flat Squares 82 | Topic: /point_cloud_topic 83 | Unreliable: false 84 | Use Fixed Frame: true 85 | Use rainbow: true 86 | Value: true 87 | - Alpha: 1 88 | Class: rviz/Axes 89 | Enabled: true 90 | Length: 0.05000000074505806 91 | Name: Axes 92 | Radius: 0.009999999776482582 93 | Reference Frame: 94 | Show Trail: false 95 | Value: true 96 | - Class: rviz/Marker 97 | Enabled: true 98 | Marker Topic: /contact_marker 99 | Name: Marker 100 | Namespaces: 101 | {} 102 | Queue Size: 100 103 | Value: true 104 | - Class: rviz/TF 105 | Enabled: true 106 | Filter (blacklist): "" 107 | Filter (whitelist): "" 108 | Frame Timeout: 15 109 | Frames: 110 | All Enabled: true 111 | camera: 112 | Value: true 113 | map: 114 | Value: true 115 | Marker Alpha: 0.5 116 | Marker Scale: 0.5 117 | Name: TF 118 | Show Arrows: true 119 | Show Axes: true 120 | Show Names: true 121 | Tree: 122 | map: 123 | camera: 124 | {} 125 | Update Interval: 0 126 | Value: true 127 | - Class: rviz/Marker 128 | Enabled: true 129 | Marker Topic: /wrist_force_marker 130 | Name: Marker 131 | Namespaces: 132 | "": true 133 | Queue Size: 100 134 | Value: true 135 | - Class: rviz/Marker 136 | Enabled: true 137 | Marker Topic: /grip_force_marker 138 | Name: Marker 139 | Namespaces: 140 | {} 141 | Queue Size: 100 142 | Value: true 143 | - Class: rviz/MarkerArray 144 | Enabled: true 145 | Marker Topic: 
/grip_force_fingers_markers 146 | Name: MarkerArray 147 | Namespaces: 148 | "": true 149 | Queue Size: 100 150 | Value: true 151 | - Class: rviz/MarkerArray 152 | Enabled: false 153 | Marker Topic: /curr_grip_force_fingers_markers 154 | Name: MarkerArray 155 | Namespaces: 156 | {} 157 | Queue Size: 100 158 | Value: false 159 | - Class: rviz/Marker 160 | Enabled: false 161 | Marker Topic: /curr_wrist_force_marker 162 | Name: Marker 163 | Namespaces: 164 | {} 165 | Queue Size: 100 166 | Value: false 167 | Enabled: true 168 | Global Options: 169 | Background Color: 255; 255; 255 170 | Default Light: true 171 | Fixed Frame: map 172 | Frame Rate: 30 173 | Name: root 174 | Tools: 175 | - Class: rviz/Interact 176 | Hide Inactive Objects: true 177 | - Class: rviz/MoveCamera 178 | - Class: rviz/Select 179 | - Class: rviz/FocusCamera 180 | - Class: rviz/Measure 181 | - Class: rviz/SetInitialPose 182 | Theta std deviation: 0.2617993950843811 183 | Topic: /initialpose 184 | X std deviation: 0.5 185 | Y std deviation: 0.5 186 | - Class: rviz/SetGoal 187 | Topic: /move_base_simple/goal 188 | - Class: rviz/PublishPoint 189 | Single click: true 190 | Topic: /clicked_point 191 | Value: true 192 | Views: 193 | Current: 194 | Class: rviz/XYOrbit 195 | Distance: 0.5225079655647278 196 | Enable Stereo Rendering: 197 | Stereo Eye Separation: 0.05999999865889549 198 | Stereo Focal Distance: 1 199 | Swap Stereo Eyes: false 200 | Value: false 201 | Field of View: 0.7853981852531433 202 | Focal Point: 203 | X: 0.10746105015277863 204 | Y: -0.10062887519598007 205 | Z: 4.470348358154297e-08 206 | Focal Shape Fixed Size: false 207 | Focal Shape Size: 0.05000000074505806 208 | Invert Z Axis: false 209 | Name: Current View 210 | Near Clip Distance: 0.009999999776482582 211 | Pitch: 0.6603981852531433 212 | Target Frame: 213 | Yaw: 0.8753992319107056 214 | Saved: ~ 215 | Window Geometry: 216 | Displays: 217 | collapsed: false 218 | Height: 811 219 | Hide Left Dock: false 220 | Hide Right Dock: false 221 | QMainWindow State: 000000ff00000000fd0000000400000000000001560000028dfc0200000008fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d0000028d000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000186000001060000030c00000261000000010000010f0000028dfc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073010000003d0000028d000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e10000019700000003000005400000003efc0100000002fb0000000800540069006d0065010000000000000540000003bc00fffffffb0000000800540069006d00650100000000000004500000000000000000000002cf0000028d00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 222 | Selection: 223 | collapsed: false 224 | Time: 225 | collapsed: false 226 
| Tool Properties: 227 | collapsed: false 228 | Views: 229 | collapsed: false 230 | Width: 1344 231 | X: 96 232 | Y: 585 233 | -------------------------------------------------------------------------------- /ros_scripts/urdf_viewer.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | 15 | 19 | 20 | 23 | 24 | 27 | 28 | 32 | -------------------------------------------------------------------------------- /train_models.sh: -------------------------------------------------------------------------------- 1 | python -m prediction.trainer --bipartite 1 --config final_tests_lr1e-4_jitter_6_6 2 | python -m prediction.trainer --bipartite 1 --config final_tests_lr1e-5_jitter_6_6 3 | -------------------------------------------------------------------------------- /utils/aruco_detect.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import cv2 5 | import cv2.aruco as aruco 6 | from typing import Any, Tuple, List 7 | import pyrealsense2 as rs 8 | import numpy as np 9 | import argparse 10 | from utils.realsense_utils import RealSense, CameraType 11 | 12 | ############################################################################## 13 | 14 | def find_contact_markers(v1, v2, offset_from_marker=0.025): 15 | """ 16 | Find the contact point of the two markers on the stretch gripper 17 | """ 18 | plane_normal = np.cross(v1, v2) 19 | contact_vec1 = np.cross(v1, plane_normal) 20 | contact_vec2 = np.cross(v2, plane_normal) 21 | 22 | # convert to unit vector then scale by offset_from_marker 23 | contact_vec1 = contact_vec1*offset_from_marker/np.linalg.norm(contact_vec1) 24 | contact_vec2 = contact_vec2*offset_from_marker/np.linalg.norm(contact_vec2) 25 | 26 | # calculate contact points (left and right) 27 | contact_p1 = v1 - contact_vec1 28 | contact_p2 = v2 + contact_vec2 29 | # print("contact_p1", contact_p1) 30 | # print("contact_p2", contact_p2) 31 | return contact_p1, contact_p2 32 | 33 | 34 | ############################################################################## 35 | 36 | def get_camera_calib(calibration_file: str) -> Tuple[np.ndarray, np.ndarray]: 37 | """ 38 | Get the camera calibration parameters from the calibration file 39 | :param calibration_file: path to the calibration file 40 | :return: camera matrix and distortion coefficients 41 | """ 42 | calibration_params = cv2.FileStorage(calibration_file, cv2.FILE_STORAGE_READ) 43 | cam_mat = calibration_params.getNode("Camera_Matrix").mat() 44 | cam_dist = calibration_params.getNode("Distortion_Coefficients").mat() 45 | print(cam_mat, cam_dist) 46 | assert cam_mat is not None and cam_dist is not None, \ 47 | "Camera calibration file is not valid" 48 | return cam_mat, cam_dist 49 | 50 | ############################################################################## 51 | 52 | class MarkerPose: 53 | id: int 54 | trans: np.ndarray # xyz 55 | rot: np.ndarray # rpy 56 | 57 | class ArucoPoseEstimator: 58 | def __init__( 59 | self, 60 | camera_matrix, 61 | dist_coeffs, 62 | marker_size=0.016, 63 | valid_ids=None, 64 | ) -> None: 65 | """ 66 | This class is used to detect aruco markers from a frame and 67 | return the pose of the marker 68 | :valid_ids: a list of valid ids to detect, if None, detect all ids 69 | """ 70 | self.aruco_dict = aruco.Dictionary_get(aruco.DICT_4X4_250) 71 | self.aruco_params = aruco.DetectorParameters_create() # TODO: tune these params 72 | 
self.camera_matrix = camera_matrix 73 | self.dist_coeffs = dist_coeffs 74 | self.marker_size = marker_size # size of the aruco marker in meters 75 | self.valid_ids = valid_ids 76 | 77 | def detect(self, frame, viz=False) -> List[MarkerPose]: 78 | """ 79 | detect aruco marker from frame and return the pose of the marker 80 | """ 81 | corners, ids, rejected = aruco.detectMarkers( 82 | frame, self.aruco_dict, parameters=self.aruco_params) 83 | 84 | if viz: 85 | if ids is not None: 86 | aruco.drawDetectedMarkers(frame, corners, ids) 87 | cv2.imshow("frame", frame) 88 | 89 | if self.valid_ids is not None: 90 | aruco_mask = np.isin(ids, self.valid_ids) 91 | aruco_mask = np.squeeze(aruco_mask) 92 | if aruco_mask.size > 0: 93 | ids = ids[aruco_mask] 94 | corners = np.array(corners)[aruco_mask] 95 | assert len(ids) == len(corners) 96 | 97 | m_list = [] 98 | if ids is not None: 99 | rvecs, tvecs, _ = aruco.estimatePoseSingleMarkers( 100 | corners, self.marker_size, 101 | self.camera_matrix, self.dist_coeffs) 102 | for i in range(len(ids)): 103 | m = MarkerPose() 104 | m.id = ids[i][0] 105 | m.trans = tvecs[i][0] 106 | m.rot = rvecs[i][0] 107 | m_list.append(m) 108 | return m_list 109 | 110 | def get_fingertip_poses(self, rgb_image): 111 | """ 112 | get fingertip poses with aruco, return as np array 113 | For our system, ID 11 is left fingertip, ID 12 is right fingertip 114 | """ 115 | arucos = self.detect(rgb_image) 116 | 117 | ids = [x.id for x in arucos] 118 | 119 | if 11 in ids and 12 in ids: 120 | left_fingertip = np.array(arucos[ids.index(11)].trans) 121 | right_fingertip = np.array(arucos[ids.index(12)].trans) 122 | return left_fingertip, right_fingertip 123 | return None 124 | 125 | ############################################################################## 126 | 127 | 128 | if __name__ == "__main__": 129 | 130 | # init argparser 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument("-l", "--load", 133 | help="load calibration file", action="store_true") 134 | parser.add_argument("--rs", help="use realsense camera", action="store_true") 135 | parser.add_argument('--realsense_id', '-rs', type=str, default=None, help='realsense id') 136 | args = parser.parse_args() 137 | 138 | # init aruco pose estimator 139 | 140 | if args.rs: 141 | device = RealSense(select_device=args.realsense_id) 142 | cam_mat, cam_dist = device.get_camera_intrinsics(CameraType.COLOR) 143 | print(cam_mat, cam_dist) 144 | else: 145 | # Initialize the video capture object 146 | cap = cv2.VideoCapture(2) 147 | calibration_file = "calibration_file.xml" 148 | cam_mat, cam_dist = get_camera_calib(calibration_file) 149 | 150 | aruco_pose_estimator = ArucoPoseEstimator(cam_mat, cam_dist) 151 | 152 | counts = {0: 0, 1: 0, 2: 0} 153 | 154 | while True: 155 | # Capture a frame from the camera 156 | if args.rs: 157 | frame = device.get_frame() 158 | ret = True 159 | else: 160 | ret, frame = cap.read() 161 | 162 | if not ret: 163 | break 164 | 165 | # # Convert the frame to grayscale 166 | # gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 167 | m_list = aruco_pose_estimator.detect(frame, viz=True) 168 | if len(m_list) == 2 and m_list[0].id == 11 and m_list[1].id == 12: 169 | coors = find_contact_markers(m_list[0].trans, m_list[1].trans) 170 | print("fingertips detected: ", coors) 171 | 172 | # Exit if the 'q' key is pressed 173 | key = cv2.waitKey(1) & 0xFF 174 | 175 | # save frame locally 176 | if key == ord('q'): 177 | # cv2.imwrite("last_frame.jpg", frame) 178 | break 179 | if key == ord('0'): 180 | print("save frame 
0") 181 | cv2.imwrite(f"frame_0_{counts[0]}.jpg", frame) 182 | # TODO: save eef transformation 183 | elif key == ord('1'): 184 | print("save frame 1") 185 | cv2.imwrite(f"frame_1_{counts[1]}.jpg", frame) 186 | # TODO: save eef transformation and action target transformation 187 | elif key == ord('2'): 188 | print("save frame 2") 189 | cv2.imwrite(f"frame_2_{counts[2]}.jpg", frame) 190 | # TODO: save eef transformation and action target transformation 191 | else: 192 | pass 193 | 194 | cv2.destroyAllWindows() 195 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import yaml 4 | import argparse 5 | from types import SimpleNamespace 6 | 7 | def load_config(config_name): 8 | config_path = os.path.join('./config', config_name + '.yml') 9 | 10 | with open(config_path, 'r') as stream: 11 | data = yaml.safe_load(stream) 12 | 13 | data_obj = SimpleNamespace(**data) 14 | data_obj.CONFIG_NAME = config_name 15 | return data_obj 16 | 17 | def parse_config_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('-cfg', '--config', type=str, default='default') 20 | parser.add_argument('--epoch', '-e', type=str, default='best', help='model epoch to load') 21 | parser.add_argument('--index', '-i', type=str, default='0', help='keeps track of training sessions using the same config') 22 | parser.add_argument('--folder', '-f', type=str, default=None, help='folder for data_capture or folder to pull data from if not live') 23 | parser.add_argument('--stage', '-s', type=str, default=None, help='train, test, or raw') 24 | parser.add_argument('--video_name', '-vname', type=str, default=False, help='video name') 25 | parser.add_argument('--speed', '-sp', type=int, default=1, help='general speed multiplier') 26 | parser.add_argument('--prompt', '-p', type=str, default='', help='text prompt for the model') 27 | parser.add_argument('--bipartite', '-bp', type=int, default=1, help='bipartite data capture') 28 | parser.add_argument('--save_all_frames', type=int, default=0, help='save non-keyframes') 29 | parser.add_argument('--ip', type=str, default='localhost', help='robot ip') 30 | parser.add_argument('--num_folders', type=int, default=0, help='number of folders to select for data experiments') 31 | parser.add_argument('--keypoints_per_folder', type=int, default=0, help='number of keypoints per folder for data experiments') 32 | 33 | # these are all technically boolean flags 34 | parser.add_argument('--record_video', '-rec', type=int, default=0, help='record video') 35 | parser.add_argument('--use_ft', type=int, default=1, help='use force torque sensor') 36 | parser.add_argument('--filter', type=int, default=0, help='delete bad data points') 37 | parser.add_argument('--view', '-v', type=int, default=1, help='view camera and graphs') 38 | parser.add_argument('--xbox', '-x', type=int, default=0, help='use xbox controller') 39 | parser.add_argument('--realsense_id', '-rs', type=str, default=None, help='realsense id') 40 | parser.add_argument('--live', '-lv', type=int, default=1, help='use camera feed instead of args.folder') 41 | 42 | # --flag that is default false 43 | parser.add_argument('--random_trans', action='store_true', help='apply data aug with random translation') 44 | parser.add_argument('--use_mock_ft', action='store_true', help='use mock ft sensor') 45 | parser.add_argument('--ignore_prefilter', action='store_true', 
help='ignore prefiltering') 46 | parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained model') 47 | parser.add_argument('--ros_viz', action='store_true', help='use ros viz') 48 | parser.add_argument('--evaluate', action='store_true', help='evaluate a trained model') 49 | parser.add_argument('--disable_auto_expose', action='store_true', help='disable realsense auto expose') 50 | parser.add_argument('--eval_folder', type=str, default=None, help='files to evaluate on when args.evaluate is True') 51 | parser.add_argument('--ablate_prompt', action='store_true', help='don\'t condition the model on text') 52 | parser.add_argument('--save_every_epoch', action='store_true', help='save a checkpoint every epoch') 53 | parser.add_argument('--binary_grip', action='store_true', help='binary gripper state') 54 | parser.add_argument('--ablate_force', action='store_true', help='dont use force goals') 55 | parser.add_argument('--ignore_robot', action='store_true', help='dont connect to the robot') 56 | 57 | args = parser.parse_args() 58 | return load_config(args.config), args 59 | -------------------------------------------------------------------------------- /utils/data_aug.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import ast 3 | import numpy as np 4 | import open3d as o3d 5 | 6 | from stretch_remote.robot_utils import read_robot_status 7 | from utils.transform import calc_tf_delta, pose_to_mat 8 | from utils.realsense_utils import pcd_from_rgbd, Intrinsic640 9 | 10 | import argparse 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | ################################################################################ 15 | 16 | def load_robot_state(data_point): 17 | with open(data_point, 'r') as f: 18 | robot_state = read_robot_status(ast.literal_eval(f.read())) 19 | return robot_state 20 | 21 | def get_joints(data_point_state): 22 | s = load_robot_state(data_point_state["state"]) 23 | return [s["x"], s["y"], s["z"], s["roll"], s["pitch"], s["yaw"]] 24 | 25 | def get_transform(data_point, ref_joint): 26 | i_joints = get_joints(data_point[1]) 27 | f_joints = get_joints(data_point[2]) 28 | i_delta_mat = calc_tf_delta(ref_joint, i_joints) 29 | f_delta_mat = calc_tf_delta(ref_joint, f_joints) 30 | return i_delta_mat, f_delta_mat 31 | 32 | def crop_imgs(image, image2, x1, y1, x2, y2, intr): 33 | """ 34 | Return cropped image and the transformation matrix 35 | :return: cropped image, transformation matrix 36 | """ 37 | # get size of image 38 | h, w, _ = image.shape 39 | # TODO: use shrink_ratio, x_movement_ratio, y_move_ratio, intr): 40 | # new_h = int(h * shrink_ratio) 41 | # new_w = int(w * shrink_ratio) 42 | 43 | # # potential movement 44 | # delta_h = h - new_h 45 | # delta_w = w - new_w 46 | 47 | rx_1 = np.tan((x1-w/2)/intr.fx) 48 | rx_2 = np.tan((x2-w/2)/intr.fx) 49 | ry_1 = np.tan((y1-h/2)/intr.fy) 50 | ry_2 = np.tan((y2-h/2)/intr.fy) 51 | # print("rx_1: {}, rx_2: {}, ry_1: {}, ry_2: {}".format(rx_1, rx_2, ry_1, ry_2)) 52 | 53 | # TODO: make it more better, convert the direction 54 | # TODO: include translation from ppx ppy 55 | avg_pitch = (rx_1 + rx_2) / 2 56 | avg_yaw = (ry_1 + ry_2) / 2 57 | mat = pose_to_mat([0, 0, 0, avg_pitch, avg_yaw, 0]) 58 | # print("avg_pitch: {}, avg_yaw: {}".format(avg_pitch, avg_yaw)) 59 | return image[y1:y2, x1:x2], image2[y1:y2, x1:x2], mat 60 | 61 | def add_noise_to_borders(image, x, y, is_rgb=False): 62 | # Add noise only to the borders of the image 63 
| if is_rgb: 64 | noise_img = np.random.random(image.shape).astype(np.float32) 65 | noise_img = (noise_img * 255).astype(np.uint8) 66 | else: 67 | # this is a depth image: get the max depth value and fill the noise image with values just below it 68 | max_depth = float(np.max(image)) 69 | noise_img = np.random.random(image.shape).astype(np.float32) 70 | noise_img = (-noise_img*20000 + max_depth).astype(np.uint16) # the current depth image is uint16 71 | 72 | image = image.copy() 73 | # replace the border pixels with noise according to the x, y translation 74 | if y > 0: 75 | image[:y, :] = noise_img[:y, :] 76 | elif y < 0: 77 | image[y:, :] = noise_img[y:, :] 78 | if x > 0: 79 | image[:, :x] = noise_img[:, :x] 80 | elif x < 0: 81 | image[:, x:] = noise_img[:, x:] 82 | return image 83 | 84 | def translate_imgs(rgb_img, depth_img, x, y, intr, 85 | rgb_noise=True, depth_noise=True): 86 | """ 87 | :arg x: x translation in pixels, +ve is right, -ve is left 88 | :arg y: y translation in pixels, +ve is down, -ve is up 89 | :return: translated rgb and depth images, transformation matrix 90 | """ 91 | rows, cols = rgb_img.shape[:2] 92 | affine_mat = np.float32([[1, 0, x], 93 | [0, 1, y]]) 94 | # Apply the translation matrix to shift the images 95 | translated_rgb_img = cv2.warpAffine( 96 | rgb_img, affine_mat, (cols, rows), borderValue=(255, 255, 255)) 97 | translated_depth_img = cv2.warpAffine( 98 | depth_img, affine_mat, (cols, rows), borderValue=float(np.max(depth_img))) 99 | 100 | if rgb_noise: 101 | translated_rgb_img = add_noise_to_borders( 102 | translated_rgb_img, x, y, is_rgb=True) 103 | if depth_noise: 104 | translated_depth_img = add_noise_to_borders( 105 | translated_depth_img, x, y, is_rgb=False) 106 | 107 | h, w, _ = rgb_img.shape 108 | ry = -np.tan(x/intr.fx) # x translation changes the y-axis rotation of the cam 109 | rx = np.tan(y/intr.fy) # y translation changes the x-axis rotation of the cam 110 | # print("rx_1: {}, rx_2: {}, ry_1: {}, ry_2: {}".format(rx_1, rx_2, ry_1, ry_2)) 111 | 112 | # TODO: include translation from ppx ppy?
or undistort the image 113 | mat = pose_to_mat([0, 0, 0, rx, ry, 0]) 114 | # print("cam roll: {}, pitch: {}".format(rx, ry)) 115 | return translated_rgb_img, translated_depth_img, mat 116 | 117 | ################################################################################ 118 | 119 | def get_pcd(data_point, intr): 120 | i_pcd = pcd_from_rgbd( 121 | data_point[1]["rgb"], data_point[1]["depth"], intr) 122 | f_pcd = pcd_from_rgbd( 123 | data_point[2]["rgb"], data_point[2]["depth"], intr) 124 | return i_pcd, f_pcd 125 | 126 | def remove_gripper(pcd): 127 | """remove a retangular region of the point cloud 128 | this is to remove the robot gripper""" 129 | invert = True # Invert the mask will then show the points to be removed 130 | pcd_np = np.asarray(pcd.points) 131 | xmin, xmax = -0.1, 0.1 132 | ymin, ymax = -0.1, 0.1 133 | zmin, zmax = 0., 0.22 134 | # Apply the condition to the points 135 | mask = np.logical_or( 136 | np.logical_or( 137 | np.logical_or(pcd_np[:, 0] < xmin, pcd_np[:, 0] > xmax), 138 | np.logical_or(pcd_np[:, 1] < ymin, pcd_np[:, 1] > ymax), 139 | ), 140 | np.logical_or(pcd_np[:, 2] < zmin, pcd_np[:, 2] > zmax), 141 | ) 142 | # Apply the mask to remove points 143 | return pcd.select_by_index(np.where(mask == invert)[0]) 144 | -------------------------------------------------------------------------------- /utils/ft_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from recording.ft import FTCapture 4 | import math 5 | 6 | def calibrate_ft(ft_obj): 7 | ft = ft_obj.get_ft() 8 | np.save('ft_calibration.npy', ft) 9 | 10 | 11 | def get_ft_calibration(): 12 | return np.load('ft_calibration.npy') 13 | 14 | 15 | def ft_to_cam_rotation(custom_pitch=(10/90)*math.pi/2): 16 | """ 17 | custom_pitch: camera pitch angle, down is -ve and up is +ve 18 | """ 19 | custom_pitch = np.array([[1, 0, 0], 20 | [0, np.cos(custom_pitch), -np.sin(custom_pitch)], 21 | [0, np.sin(custom_pitch), np.cos(custom_pitch)]]) 22 | return np.array([[0, 1, 0], 23 | [-1, 0, 0], 24 | [0, 0, 1]])@custom_pitch 25 | 26 | ############################################################################## 27 | 28 | if __name__ == '__main__': 29 | robot = None 30 | ft_obj = FTCapture() 31 | calibrate_ft(ft_obj) 32 | offset = get_ft_calibration() 33 | frame_rotation = ft_to_cam_rotation() 34 | 35 | while True: 36 | ft = ft_obj.get_ft() 37 | ft = ft - offset 38 | # print('calibrated FT: ', ft) 39 | # rounding to 3 decimal places 40 | print('calibrated FT: ', np.round(ft, 3)[:3]) 41 | print('calibrated FT in camera frame: ', np.round(ft[:3] @ frame_rotation, 3)) 42 | time.sleep(0.4) 43 | -------------------------------------------------------------------------------- /utils/pred_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from utils.realsense_utils import Intrinsic640, get_point_cloud, display_point_cloud 4 | import os 5 | import cv2 6 | import torch 7 | import numpy as np 8 | from prediction.models import * 9 | from prediction.deep_fusion import DeepFusion 10 | from prediction.classifier_free_guidance import * 11 | from prediction.grip_force_model import GripForceMLP 12 | from robot.robot_utils import * 13 | from utils.transform import * 14 | from utils.data_pipeline import normalize_gripper_pos, normalize_gripper_effort 15 | 16 | from transformers import T5Tokenizer, BertTokenizer 17 | 18 | import os 19 | import shutil 20 | from pathlib import Path 21 | 22 | 23 
| def save_img(image, img_name): 24 | # Create a new directory if it doesn't exist 25 | # Normalize float img to [0,255] and convert to uint8 26 | normalized_image = cv2.normalize(image, None, 0, 255, cv2.NORM_MINMAX) 27 | uint8_image = normalized_image.astype(np.uint8) 28 | directory = f'{os.path.expanduser("~")}/debug_imgs' 29 | if not os.path.exists(directory): 30 | os.makedirs(directory) 31 | 32 | image_path = os.path.join(directory, f'{img_name}.png') 33 | cv2.imwrite(image_path, uint8_image) 34 | 35 | 36 | def t2np(tensor): 37 | return tensor.detach().cpu().numpy()[0] 38 | 39 | 40 | def t2float(tensor): 41 | return tensor.detach().cpu().numpy().item() 42 | 43 | 44 | def create_file_index(folder): 45 | # find the index of the next file to be saved 46 | if len(os.listdir(folder)) == 0: 47 | file_index = 0 48 | else: 49 | # e.g. if folder is 'data/rgb/rgb_1_2_0.png', then the new index will be 0 + 1 = 1 50 | # file_index = max(int(name.split('_')[-1]) for name in os.listdir(folder)) + 1 51 | file_index = max(int(name.split('_')[-1].split('.')[0]) 52 | for name in os.listdir(folder)) + 1 53 | return file_index 54 | 55 | 56 | def load_model(config, checkpoint_path=None): 57 | print('loading model...') 58 | print('USE_RGBD: ', config.USE_RGBD) 59 | # load model from config 60 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 61 | 62 | if config.IMAGE_MODEL.split('-')[0] == 'vit': 63 | 64 | if hasattr(config, 'PRETRAINED'): 65 | image_model = RGBDViT(size=config.IMAGE_MODEL.split('-')[1], 66 | config=config, 67 | num_classes=0, 68 | pretrained=config.PRETRAINED).to(device) # tiny, small, base, large 69 | else: 70 | image_model = RGBDViT(size=config.IMAGE_MODEL.split('-')[1], 71 | config=config, 72 | num_classes=0, 73 | ).to(device) 74 | 75 | if config.FREEZE_IMAGE_MODEL: 76 | for param in image_model.parameters(): 77 | param.requires_grad = False 78 | if config.USE_RGBD: 79 | # train the patch embedding layer 80 | image_model.model.patch_embed.proj.requires_grad = True 81 | 82 | elif config.IMAGE_MODEL.split('-')[0] == 'clip': 83 | clip_model = RGBDCLIP( 84 | size=config.IMAGE_MODEL.split('-')[1], 85 | config=config, 86 | ).to(device) 87 | image_model, text_model = clip_model, clip_model.model.encode_text 88 | tokenizer = clip.tokenize 89 | 90 | if config.FREEZE_IMAGE_MODEL: 91 | for param in clip_model.visual.parameters(): 92 | param.requires_grad = False 93 | if config.USE_RGBD: 94 | # train the patch embedding layer 95 | clip_model.model.visual.conv1.requires_grad = True 96 | 97 | elif config.IMAGE_MODEL.split('-')[0] == 'dinov2': 98 | image_model = RGBDDinov2( 99 | size=config.IMAGE_MODEL.split('-')[1], 100 | config=config, 101 | ).to(device) 102 | 103 | if config.FREEZE_IMAGE_MODEL: 104 | for param in image_model.parameters(): 105 | param.requires_grad = False 106 | if config.USE_RGBD: 107 | # train the patch embedding layer 108 | image_model.model.patch_embed.proj.requires_grad = True 109 | else: 110 | print('Image model in config not recognized') 111 | 112 | # we don't need to load separate text models for this option 113 | if config.MULTIMODAL_HEAD != 'classifier-free-guidance': 114 | if config.TEXT_MODEL == 't5-small': 115 | # text_model = T5Model.from_pretrained("t5-small").encoder.to(device) 116 | text_model = T5(size='small').to(device) 117 | tokenizer = T5Tokenizer.from_pretrained("t5-small") 118 | elif config.TEXT_MODEL == 't5-base': 119 | # text_model = T5Model.from_pretrained("t5-base").encoder.to(device) 120 | text_model = T5(size='base').to(device) 
121 | tokenizer = T5Tokenizer.from_pretrained("t5-base") 122 | elif config.TEXT_MODEL == 't5-large': 123 | # text_model = T5Model.from_pretrained("t5-large").encoder.to(device) 124 | text_model = T5(size='large').to(device) 125 | tokenizer = T5Tokenizer.from_pretrained("t5-large") 126 | 127 | elif config.TEXT_MODEL == 'bert-base': 128 | text_model = Bert(size='base').to(device) 129 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 130 | elif config.TEXT_MODEL == 'bert-large': 131 | text_model = Bert(size='large').to(device) 132 | tokenizer = BertTokenizer.from_pretrained('bert-large-uncased') 133 | print('text model: ', text_model) 134 | else: 135 | print('Text model in config not recognized') 136 | 137 | if config.IMAGE_MODEL.split('-')[0] == 'clip' and config.FREEZE_TEXT_MODEL: 138 | for param in clip_model.model.transformer.parameters(): # freeze the text model 139 | param.requires_grad = False 140 | for param in clip_model.model.ln_final.parameters(): 141 | param.requires_grad = False 142 | for param in clip_model.model.token_embedding.parameters(): 143 | param.requires_grad = False 144 | elif config.IMAGE_MODEL.split('-')[0] == 'clip' and not config.FREEZE_TEXT_MODEL: 145 | for param in clip_model.model.transformer.parameters(): # don't freeze the text model 146 | param.requires_grad = True 147 | for param in clip_model.model.ln_final.parameters(): 148 | param.requires_grad = True 149 | for param in clip_model.model.token_embedding.parameters(): 150 | param.requires_grad = True 151 | elif config.FREEZE_TEXT_MODEL: 152 | for param in text_model.parameters(): 153 | param.requires_grad = False 154 | else: 155 | for param in text_model.parameters(): 156 | param.requires_grad = True 157 | 158 | if config.MULTIMODAL_HEAD == 'concat-linear-attn-mlp': 159 | model = ConcatLinearAttnMLP(image_model, text_model) 160 | elif config.MULTIMODAL_HEAD == 'vision-only-linear': 161 | model = VisionOnlyLinear(image_model) 162 | elif config.MULTIMODAL_HEAD == 'vision-only-mlp': 163 | model = VisionOnlyMLP(image_model) 164 | elif config.MULTIMODAL_HEAD == 'vision-only-threeheads': 165 | model = ThreeHeadMLP(image_model) 166 | elif config.MULTIMODAL_HEAD == 'deep-fusion': 167 | model = DeepFusion(image_model, text_model) 168 | elif config.MULTIMODAL_HEAD == 'classifier-free-guidance' and config.IMAGE_MODEL.split('-')[0] == 'vit': 169 | # model = ClassifierFreeGuidance(image_model, text_model) 170 | # model = ConditionedVisionTransformer(image_model, text_model='t5') # text model can only be t5 or clip for now 171 | model = ConditionedVisionTransformer( 172 | image_model, 173 | text_model='t5', # text model can only be t5 or clip for now 174 | config=config, 175 | hidden_dim=256, 176 | ) 177 | tokenizer = T5Tokenizer.from_pretrained("t5-large") 178 | elif config.MULTIMODAL_HEAD == 'classifier-free-guidance' and config.IMAGE_MODEL.split('-')[0] == 'clip': 179 | # modified cfg repo to use openai ViT-L/14 model 180 | model = ConditionedCLIP( 181 | image_model, text_model='clip', config=config, hidden_dim=256) 182 | elif config.MULTIMODAL_HEAD == 'classifier-free-guidance' and config.IMAGE_MODEL.split('-')[0] == 'dinov2': 183 | model = ConditionedDinov2( 184 | image_model, text_model='t5', config=config, hidden_dim=256) 185 | if checkpoint_path is not None: 186 | model.load_state_dict(torch.load(checkpoint_path, map_location=device)) 187 | model = model.to(device) 188 | return model, tokenizer 189 | 190 | 191 | class MovieWriter: 192 | def __init__(self, path, fps=30): 193 | self.writer = None 194 
| self.path = path 195 | self.fps = fps 196 | 197 | def write_frame(self, frame): 198 | if self.writer is None: 199 | self.mkdir(self.path, cut_filename=True) 200 | self.writer = cv2.VideoWriter(self.path, cv2.VideoWriter_fourcc( 201 | 'M', 'J', 'P', 'G'), self.fps, (frame.shape[1], frame.shape[0])) 202 | self.writer.write(frame) 203 | 204 | def mkdir(self, path, cut_filename=False): 205 | if cut_filename: 206 | path = os.path.dirname(os.path.abspath(path)) 207 | Path(path).mkdir(parents=True, exist_ok=True) 208 | 209 | def close(self): 210 | self.writer.release() 211 | 212 | 213 | def load_grip_force_model(): 214 | # to avoid multiprocessing issues since grip_force_model is tiny and running in the data loader 215 | device = torch.device('cpu') 216 | 217 | grip_force_model = GripForceMLP(num_inputs=3, num_outputs=1) 218 | grip_force_model.load_state_dict( 219 | torch.load( 220 | 'grip_force_checkpoints/grip_force_dist_pos_effort_5_25_4/model_best.pth', 221 | map_location=device)) 222 | grip_force_model.eval() 223 | return grip_force_model 224 | 225 | 226 | def run_grip_force_model(grip_force_model, pos_dict, curr_left_fingertip, curr_right_fingertip): 227 | with torch.no_grad(): 228 | curr_left_fingertip = torch.tensor( 229 | curr_left_fingertip).float().unsqueeze(0) 230 | curr_right_fingertip = torch.tensor( 231 | curr_right_fingertip).float().unsqueeze(0) 232 | gripper_pos = torch.tensor( 233 | pos_dict['gripper']).float().unsqueeze(0).unsqueeze(0) 234 | gripper_effort = torch.tensor( 235 | pos_dict['gripper_effort']).float().unsqueeze(0).unsqueeze(0) 236 | 237 | fingertip_dist = torch.norm( 238 | curr_left_fingertip - curr_right_fingertip, dim=1).float().unsqueeze(0) 239 | gripper_pos = normalize_gripper_pos(gripper_pos) 240 | gripper_effort = normalize_gripper_effort(gripper_effort) 241 | 242 | model_input = torch.cat((gripper_pos.unsqueeze( 243 | 1), gripper_effort.unsqueeze(1), fingertip_dist.unsqueeze(1)), dim=1) 244 | # model_input = torch.cat((gripper_pos.unsqueeze(1), fingertip_dist.unsqueeze(1)), dim=1) 245 | # model_input = gripper_pos.unsqueeze(1) 246 | grip_force = grip_force_model(model_input) 247 | 248 | grip_force = grip_force.reshape((-1, 1)) 249 | 250 | return grip_force 251 | 252 | 253 | def copy_folder_contents(src_folder, dest_folder): 254 | # check if destination folder exists 255 | if not os.path.exists(dest_folder): 256 | os.makedirs(dest_folder) # if not, create it 257 | 258 | for filename in os.listdir(src_folder): 259 | file_path = os.path.join(src_folder, filename) 260 | 261 | if os.path.isfile(file_path): 262 | shutil.copy(file_path, dest_folder) # copy files 263 | else: 264 | shutil.copytree(file_path, os.path.join( 265 | dest_folder, filename)) # copy directories 266 | -------------------------------------------------------------------------------- /utils/realsense_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import cv2 5 | import pyrealsense2 as rs 6 | import open3d as o3d 7 | import time 8 | from typing import Any, Tuple, List 9 | import argparse 10 | # from mesh_from_rgbd import * 11 | 12 | ############################################################################## 13 | 14 | class DefaultIntrinsic: 15 | """ 16 | This is the default intrinsic from the D405 realseense rgb camera with 848 17 | """ 18 | ppx = 415.507537841797 19 | ppy = 237.871643066406 20 | fx = 431.125 21 | fy = 430.667 22 | coeffs = [-0.053341, 0.0545209, 0.000824648, 0.000749805, 
-0.0171459] 23 | 24 | # https://github.com/IntelRealSense/librealsense/issues/3473#issuecomment-474637827 25 | depth_scale = 9.9999e-05 26 | 27 | def cam_mat(self): 28 | return camera_matrix(self) 29 | 30 | def cam_dist(self): 31 | return fisheye_distortion(self) 32 | 33 | class Intrinsic640(DefaultIntrinsic): 34 | ppx = 311.508 35 | ppy = 237.872 36 | fx = 431.125 37 | fy = 430.667 38 | coeffs = [-0.053341, 0.0545209, 0.000824648, 0.000749805, -0.0171459] 39 | 40 | def camera_matrix(intrinsics): 41 | return np.array([[intrinsics.fx, 0, intrinsics.ppx], 42 | [0, intrinsics.fy, intrinsics.ppy], 43 | [0, 0, 1]]) 44 | 45 | 46 | def fisheye_distortion(intrinsics): 47 | return np.array(intrinsics.coeffs[:4]) 48 | 49 | ############################################################################## 50 | 51 | def get_point_cloud(color_image, depth_image, intrinsics): 52 | """creating point cloud in open3d""" 53 | fx, fy = intrinsics.fx, intrinsics.fy 54 | cx, cy = intrinsics.ppx, intrinsics.ppy 55 | 56 | # Create an Open3D camera intrinsic object 57 | intrinsic = o3d.camera.PinholeCameraIntrinsic() 58 | intrinsic.set_intrinsics(640, 480, fx, fy, cx, cy) 59 | 60 | # print('color_image dtype', color_image.dtype) 61 | # print('depth_image dtype', depth_image.dtype) 62 | 63 | # check if rgb imge is dtype uint8 else convert 64 | if color_image.dtype == np.float32: 65 | print("color image is not uint8") 66 | color_image = color_image*255 67 | color_image = color_image.astype(np.uint8) 68 | 69 | # check if rgb imge is dtype uint16 else convert 70 | if depth_image.dtype == np.float32: 71 | # convert to uint16 72 | print("depth image is not uint16") 73 | depth_image = depth_image*65535.0 74 | depth_image = depth_image.astype(np.uint16) 75 | 76 | color_image = o3d.geometry.Image(color_image) 77 | depth_image = o3d.geometry.Image(depth_image) 78 | 79 | # Create an Open3D RGBDImage from the color and depth images 80 | rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth( 81 | color_image, depth_image, depth_scale=10000.0, 82 | depth_trunc=2.0, convert_rgb_to_intensity=False) 83 | 84 | # Create a point cloud from the RGBD image and camera intrinsic parameters 85 | return o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic) 86 | 87 | def pcd_from_rgbd(rgb, depth, intr): 88 | color_image = cv2.imread(rgb) 89 | depth_image = cv2.imread(depth, cv2.IMREAD_ANYDEPTH) 90 | 91 | # show both rgb and depth images 92 | # img = np.hstack((color_image, depth_image) 93 | # cv2.imshow('image', img) 94 | depth_image = depth_image.astype(np.float32) 95 | depth_image = depth_image.astype(np.int16) 96 | 97 | # Test resize, INTER_NEAREST is important to keep the depth values 98 | depth_image = cv2.resize(depth_image, (320, 240), interpolation=cv2.INTER_NEAREST) 99 | depth_image = cv2.resize(depth_image, (640, 480), interpolation=cv2.INTER_NEAREST) 100 | 101 | return get_point_cloud(color_image, depth_image, intr) 102 | 103 | ############################################################################## 104 | 105 | class PCDViewer: 106 | def __init__(self, skip_frames=10, blocking=False): 107 | """ 108 | pcd viewer to display point cloud 109 | :arg skip_frames: number of frames to skip, for live display 110 | :arg blocking: if true, will block until window is closed, this will 111 | enable mouse control when blocking 112 | """ 113 | self.vis = o3d.visualization.Visualizer() 114 | self.vis.create_window() 115 | # Set a custom viewpoint: TODO: edit this, viewing angle 116 | # rotate around z axis 117 | ctr 
= self.vis.get_view_control() 118 | ctr.set_lookat([0, 0, 0.2]) 119 | ctr.set_front([0, -1, -1]) 120 | ctr.set_up([0, -1, 0]) 121 | self.skip = skip_frames 122 | self.counter = 0 123 | self.blocking = blocking 124 | # self.pcd = o3d.geometry.PointCloud() 125 | # self.vis.add_geometry(self.pcd) 126 | 127 | def display(self, points, markers=[], ignore_skip=False): 128 | """ 129 | display point cloud and fingertip markers 130 | ignore skip will display the frame regardless of skip 131 | """ 132 | if ignore_skip or self.blocking: 133 | self.counter = 0 134 | 135 | _markers = [] 136 | for pos in markers: 137 | marker_pos = pos.reshape(3) # np.array([0.5, 0.5, 0.5]) 138 | marker = o3d.geometry.TriangleMesh.create_sphere(radius=0.005) 139 | marker.paint_uniform_color([1, 0, 0]) # Set the color to red 140 | marker.translate(marker_pos) 141 | _markers.append(marker) 142 | 143 | if self.counter%self.skip == 0: 144 | # To get viz camera control of the previous frame, then reset the view 145 | # since add geo will reset it: 146 | # https://github.com/isl-org/Open3D/issues/2264 147 | ctr = self.vis.get_view_control() 148 | view_param = ctr.convert_to_pinhole_camera_parameters() 149 | self.vis.clear_geometries() 150 | self.vis.add_geometry(points) 151 | for m in _markers: 152 | self.vis.add_geometry(m) 153 | 154 | ctr.convert_from_pinhole_camera_parameters(view_param) 155 | self.counter = 0 156 | self.counter += 1 157 | 158 | # print number of points 159 | print("Number of points: {}".format(len(points.points))) 160 | 161 | if self.blocking: 162 | self.vis.run() 163 | else: 164 | self.vis.poll_events() 165 | self.vis.update_renderer() 166 | 167 | def __del__(self): 168 | self.vis.destroy_window() 169 | 170 | 171 | def display_point_cloud(points, markers=[]): 172 | # point cloud visualization 173 | vis = o3d.visualization.Visualizer() 174 | vis.create_window() 175 | 176 | vis.clear_geometries() 177 | vis.add_geometry(points) 178 | print("Number of points: {}".format(len(points.points))) 179 | 180 | for pos in markers: 181 | marker_pos = pos.reshape(3) # np.array([0.5, 0.5, 0.5]) 182 | marker = o3d.geometry.TriangleMesh.create_sphere(radius=0.005) 183 | marker.paint_uniform_color([1, 0, 0]) # Set the color to red 184 | marker.translate(marker_pos) 185 | vis.add_geometry(marker) 186 | 187 | # Set a custom viewpoint: TODO: edit this, viewing angle 188 | # rotate around z axis 189 | ctr = vis.get_view_control() 190 | ctr.set_lookat([0, 0, 0.2]) 191 | ctr.set_front([0, -0.5, -0.5]) 192 | ctr.set_up([0, -1, 0]) 193 | 194 | # vis.run() 195 | while vis.poll_events(): 196 | vis.update_renderer() 197 | vis.destroy_window() 198 | 199 | ############################################################################## 200 | 201 | class CameraType: 202 | COLOR = rs.stream.color 203 | DEPTH = rs.stream.depth 204 | 205 | 206 | ############################################################################## 207 | class RealSense: 208 | def __init__(self, select_device="127122270519", view=False, auto_expose=True): #127122270519 209 | # Configure depth and color streams 210 | self.pipeline = rs.pipeline() 211 | self.config = rs.config() 212 | print("selecting device: {}".format(select_device)) 213 | if select_device is not None: 214 | self.config.enable_device(select_device) 215 | 216 | self.config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) 217 | self.config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) 218 | # Start streaming 219 | cfg = self.pipeline.start(self.config) 220 | 221 | color_sensor 
= cfg.get_device().query_sensors()[0] 222 | color_sensor.set_option(rs.option.enable_auto_exposure, auto_expose) 223 | 224 | # getting camera intrinsics 225 | _depth_profile = cfg.get_stream(rs.stream.depth) 226 | self.depth_intr = \ 227 | _depth_profile.as_video_stream_profile().get_intrinsics() 228 | _rgb_profile = cfg.get_stream(rs.stream.color) 229 | self.rgb_intr = \ 230 | _rgb_profile.as_video_stream_profile().get_intrinsics() 231 | 232 | self.first_frame_time = 0 233 | self.current_frame_time = 0 234 | self.frame_count = 0 235 | self.view = view 236 | 237 | def get_rgbd_image(self): 238 | """Return a pair of color and depth frame from the realsense camera.""" 239 | frames = self.pipeline.wait_for_frames() 240 | depth_frame = frames.get_depth_frame() 241 | color_frame = frames.get_color_frame() 242 | self.current_frame_time = time.time() 243 | 244 | if self.first_frame_time == 0: 245 | self.first_frame_time = self.current_frame_time 246 | 247 | if not depth_frame or not color_frame: 248 | print("frame {} was bad".format(self.frame_count)) 249 | return None 250 | 251 | self.frame_count += 1 252 | 253 | # Convert images to numpy arrays 254 | depth_image = np.asanyarray(depth_frame.get_data()) # depth in mm 255 | color_image = np.asanyarray(color_frame.get_data()) 256 | 257 | # print('depth_min', depth_image.min()) 258 | # print('depth_mean', depth_image.mean()) 259 | # print('depth_max', depth_image.max()) 260 | 261 | return color_image, depth_image 262 | 263 | def display_rgbd_image(self, color_image, depth_image): 264 | """Display color and depth images.""" 265 | 266 | # Apply colormap on depth image (image must be converted to 8-bit per pixel first) 267 | depth_colormap = cv2.applyColorMap( 268 | cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET) 269 | 270 | # Stack both images horizontally 271 | images = np.hstack((color_image, depth_colormap)) 272 | return images 273 | 274 | def get_point_cloud(self, color_image, depth_image): 275 | """creating point cloud in open3d""" 276 | # TODO: check if intrinsics should be from depth or color camera 277 | return get_point_cloud(color_image, depth_image, self.depth_intr) 278 | 279 | def get_camera_intrinsics( 280 | self, type: CameraType 281 | ) -> Tuple[np.ndarray, np.ndarray]: 282 | """ 283 | Get the camera calibration parameters from the realsense camera 284 | Ref: https://github.com/IntelRealSense/librealsense/blob/master/wrappers/python/examples/t265_stereo.py 285 | :return: camera matrix and distortion coefficients 286 | """ 287 | if type == CameraType.COLOR: 288 | mat = camera_matrix(self.rgb_intr) 289 | dist = fisheye_distortion(self.rgb_intr) 290 | elif type == CameraType.DEPTH: 291 | mat = camera_matrix(self.depth_intr) 292 | dist = fisheye_distortion(self.depth_intr) 293 | return mat, dist 294 | 295 | def get_frame(self, type: CameraType = CameraType.COLOR): 296 | """ 297 | Get a frame from the realsense camera 298 | """ 299 | frames = self.pipeline.wait_for_frames() 300 | if type == CameraType.COLOR: 301 | color_frame = frames.get_color_frame() 302 | return np.asanyarray(color_frame.get_data()) 303 | else: 304 | depth_frame = frames.get_depth_frame() 305 | depth_image = np.asanyarray(depth_frame.get_data()) 306 | depth_colormap = cv2.applyColorMap( 307 | cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET) 308 | return depth_colormap 309 | 310 | # detruct the wrapper 311 | def __del__(self): 312 | self.pipeline.stop() 313 | 314 | ############################################################################## 
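# Example: a minimal sketch of pinhole back-projection using the intrinsics
# classes above. It assumes the depth value has already been converted to
# meters from the raw depth units (see depth_scale above); the helper name
# is illustrative only and is not used elsewhere in this module.
def example_deproject_pixel(u, v, depth_m, intr=Intrinsic640()):
    """Back-project pixel (u, v) at depth `depth_m` into a 3D camera-frame point."""
    x = (u - intr.ppx) * depth_m / intr.fx
    y = (v - intr.ppy) * depth_m / intr.fy
    return np.array([x, y, depth_m])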
315 | 316 | if __name__ == '__main__': 317 | parser = argparse.ArgumentParser() 318 | parser.add_argument('--cloud', action="store_true", help='view pointcloud') 319 | parser.add_argument('-d', '--device', type=str, default=None, help='device to use') 320 | parser.add_argument('--rgb', type=str, default=None, help='path to rgb image') 321 | parser.add_argument('--depth', type=str, default=None, help='path to depth image') 322 | args = parser.parse_args() 323 | 324 | if args.rgb is not None and args.depth is not None: 325 | print('viewing pointcloud from images') 326 | 327 | pcd = pcd_from_rgbd(args.rgb, args.depth, Intrinsic640()) 328 | pcd_vis = PCDViewer(blocking=True) 329 | pcd_vis.display(pcd) 330 | 331 | # save the point cloud 332 | o3d.io.write_point_cloud('viz_pcd.ply', pcd) 333 | cv2.waitKey(0) 334 | cv2.destroyAllWindows() 335 | exit() 336 | 337 | rs = RealSense(select_device=args.device) 338 | intr = rs.get_camera_intrinsics(CameraType.COLOR) 339 | print(intr) 340 | 341 | pcd_vis = PCDViewer() 342 | while True: 343 | color_image, depth_image = rs.get_rgbd_image() 344 | disp_image = rs.display_rgbd_image(color_image, depth_image) 345 | if args.cloud: 346 | print('depth_image dtype', depth_image.dtype) 347 | 348 | pcd = rs.get_point_cloud(color_image, depth_image) 349 | pcd_vis.display(pcd) 350 | 351 | # save depth and color images 352 | # cv2.imwrite('color.png', color_image) 353 | depth_image = depth_image.astype('uint16') 354 | # cv2.imwrite('depth.png', depth_image) 355 | 356 | # Show images 357 | cv2.namedWindow('RealSense', cv2.WINDOW_AUTOSIZE) 358 | cv2.imshow('RealSense', disp_image) 359 | cv2.waitKey(1) 360 | -------------------------------------------------------------------------------- /utils/render_videos.py: -------------------------------------------------------------------------------- 1 | # python script that takes all folders in 'results/experiments_8_14/" and passes them into the following function: 2 | # python3 -m prediction.live_model --config final_tests_jitter_default_6_6 --index 0 \ 3 | # --epoch best --ip 192.168.0.148 --record_video 1 --video_name folder_name --live 0 --folder 4 | 5 | import os 6 | import subprocess 7 | 8 | root_folder = "results/experiments_8_14/" 9 | folders = os.listdir(root_folder) 10 | 11 | for folder in folders: 12 | if folder == ".DS_Store": 13 | continue 14 | folder = root_folder + folder 15 | print(folder) 16 | command = f"python3 -m prediction.live_model --config final_tests_jitter_default_6_6\ 17 | --index 0 --epoch best --ip 192.168.0.148 --record_video 1 \ 18 | --video_name " + folder + " --live 0 --folder " + folder 19 | print(command) 20 | subprocess.call(command, shell=True) 21 | print("done") 22 | -------------------------------------------------------------------------------- /utils/test_aug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import cv2 4 | import ast 5 | import numpy as np 6 | import open3d as o3d 7 | 8 | from prediction.loader import get_data 9 | from utils.transform import calc_tf_delta, pose_to_mat 10 | from utils.realsense_utils import pcd_from_rgbd, Intrinsic640 11 | from utils.data_aug import * 12 | 13 | import argparse 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | ################################################################################ 18 | 19 | if __name__ == "__main__": 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--data", type=str, 
default="data/raw/cube_data_5_13_frame_1_2_7") 23 | parser.add_argument("--no_gripper", action="store_true") 24 | parser.add_argument("--crop_pic", action="store_true") 25 | parser.add_argument("--translate_pic", action="store_true") 26 | args = parser.parse_args() 27 | 28 | # data = get_data(["data/raw/cube_data_5_13_frame_1_2_7"], shuffle=False) 29 | # data = get_data(["data/raw/mouse_data_5_18_frame_1_2_6"], shuffle=False) 30 | # data = get_data(["data/raw/mouse_data_5_18_frame_1_2_3"], shuffle=False) 31 | 32 | intr = Intrinsic640() 33 | data = get_data([args.data], shuffle=False) 34 | print("Number of samples", len(data)) 35 | print("Sample keys", data[0]) 36 | 37 | index = 1 38 | 39 | if args.crop_pic: 40 | # load image 41 | rgb_img = cv2.imread(data[index][1]["rgb"], cv2.IMREAD_COLOR) 42 | depth_img = cv2.imread(data[index][1]["depth"], cv2.IMREAD_ANYDEPTH) 43 | 44 | # crop image 45 | cropped_img, _, mat = crop_imgs(rgb_img, depth_img, 220, 140, 650, 540, intr) 46 | 47 | # show concatenated image 48 | cv2.imshow("ori", rgb_img) 49 | cv2.imshow("cropped", cropped_img) 50 | cv2.waitKey(0) 51 | cv2.destroyAllWindows() 52 | exit() 53 | 54 | if args.translate_pic: 55 | # load image 56 | rgb_img = cv2.imread(data[index][1]["rgb"], cv2.IMREAD_COLOR) 57 | depth_img = cv2.imread(data[index][1]["depth"], cv2.IMREAD_ANYDEPTH) 58 | 59 | trans_x = 100 60 | trans_y = -100 61 | rgb_img, depth_img, _ = translate_imgs(rgb_img, depth_img, trans_x, trans_y, intr) 62 | 63 | # Assuming depth_img is the depth image you want to visualize 64 | # Normalize the depth values between 0 and 1 65 | normalized_depth_img = (depth_img - np.min(depth_img)) / (np.max(depth_img) - np.min(depth_img)) 66 | 67 | # Apply a colormap to the normalized depth image 68 | colormap = plt.get_cmap('viridis') # You can choose a different colormap if desired 69 | depth_img_color = colormap(normalized_depth_img) 70 | 71 | # Display the depth image 72 | # cv2.imshow("rgb img", rgb_img) 73 | # Plot the RGB image in the first subplot 74 | fig, axs = plt.subplots(1, 2) 75 | axs[0].imshow(rgb_img) 76 | axs[0].axis('off') 77 | axs[0].set_title('RGB Image') 78 | # Plot the colorized depth image in the second subplot 79 | axs[1].imshow(depth_img_color) 80 | axs[1].axis('off') 81 | axs[1].set_title('Colorized Depth Image') 82 | plt.tight_layout() 83 | plt.show() 84 | exit() 85 | 86 | ref_joint = get_joints(data[index][1]) # initial 87 | 88 | i_delta, f_delta = get_transform(data[index], ref_joint) 89 | i_pcd, f_pcd = get_pcd(data[index], intr) 90 | i_pcd.transform(i_delta) 91 | # f_pcd.transform(f_delta) 92 | # Load point cloud 93 | # pcd = o3d.io.read_point_cloud("viz_pcd.ply") 94 | # Convert point cloud to voxel 95 | combined_pcd = i_pcd 96 | if args.no_gripper: 97 | combined_pcd = remove_gripper(combined_pcd) 98 | 99 | # Convert Open3D.o3d.geometry.PointCloud to numpy array 100 | print("number of points", np.asarray(combined_pcd.points).shape) 101 | # cropped_pcd = combined_pcd.crop(bounding_box.inversed()) 102 | # print("number of points", np.asarray(cropped_pcd.points).shape) 103 | 104 | for i in [8, 15, 20, 4, 28, 4]: 105 | i_delta, f_delta = get_transform(data[i], ref_joint) 106 | i_pcd, f_pcd = get_pcd(data[i], intr) 107 | if args.no_gripper: 108 | i_pcd = remove_gripper(i_pcd) 109 | i_pcd.transform(i_delta) 110 | # f_pcd.transform(f_delta) 111 | # Load point cloud 112 | # pcd = o3d.io.read_point_cloud("viz_pcd.ply") 113 | # Convert point cloud to voxel 114 | combined_pcd += i_pcd 115 | 116 | # print number of points 117 | 
print("number of points", np.asarray(combined_pcd.points).shape) 118 | combined_pcd = combined_pcd.voxel_down_sample(voxel_size=0.001) 119 | print("number of points", np.asarray(combined_pcd.points).shape) 120 | # TODO: remove outliers 121 | cl , ind = combined_pcd.remove_statistical_outlier(nb_neighbors=5, std_ratio=0.5) 122 | combined_pcd = combined_pcd.select_by_index(ind) 123 | print("number of points", np.asarray(combined_pcd.points).shape) 124 | 125 | # visaualize combined point cloud 126 | o3d.visualization.draw_geometries([combined_pcd]) 127 | 128 | # # Visualize voxel grid 129 | # voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(combined_pcd, voxel_size=0.001) 130 | # o3d.visualization.draw_geometries([voxel_grid]) 131 | # o3d.visualization.draw_geometries([pcd]) 132 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from scipy.spatial.transform import Rotation as R 4 | import numpy as np 5 | from typing import List, Optional, Tuple 6 | 7 | # NOTE: migrated from kdl to utils/kinpy_wrapper.py 8 | from robot.kinpy_wrapper import get_forward_kinematics 9 | 10 | ######################################################################## 11 | 12 | def pose_to_mat(pose: List[float]) -> np.ndarray: 13 | """Convert a xyzrpy list to a 4x4 matrix""" 14 | assert len(pose) == 6 15 | r = R.from_euler('xyz', [pose[3], pose[4], pose[5]], degrees=False) 16 | rot = r.as_matrix() 17 | mat = np.eye(4) 18 | mat[:3, :3] = rot 19 | mat[0, 3] = pose[0] 20 | mat[1, 3] = pose[1] 21 | mat[2, 3] = pose[2] 22 | return mat 23 | 24 | def inverse_mat(mat): 25 | return np.linalg.inv(mat) 26 | 27 | def matmul_mat(mat1, mat2): 28 | return np.matmul(mat1, mat2) 29 | 30 | def mat_to_pose(mat): 31 | "4x4 matrix to xyzrpy list, not using tf" 32 | translation = mat[:3, 3] 33 | rotation = R.from_matrix(mat[:3, :3]) 34 | return np.concatenate((translation, rotation.as_euler('xyz', degrees=False))) 35 | 36 | def tf_from_mats(mat1, mat2): 37 | """ 38 | Compute the transform from mat1 to mat2 39 | """ 40 | return matmul_mat(inverse_mat(mat1), mat2) 41 | 42 | def stretch_joint_sequence(joints): 43 | # convert robot xyzrpy to stretch custom joints sequence 44 | # NOTE: this is a hack, need a better function design 45 | return [joints[0], joints[2], joints[1], joints[5], joints[4], joints[3]] 46 | 47 | ######################################################################## 48 | 49 | def calc_tf_delta(initial_joint, final_joint): 50 | """ 51 | Calculate the transform from initial to final joint 52 | """ 53 | # this requires the kdl server to be running 54 | init_cam_pose = get_forward_kinematics( 55 | stretch_joint_sequence(final_joint)) 56 | final_cam_pose = get_forward_kinematics( 57 | stretch_joint_sequence(initial_joint)) 58 | init_cam_pose = pose_to_mat(init_cam_pose) 59 | final_cam_pose = pose_to_mat(final_cam_pose) 60 | return tf_from_mats(final_cam_pose, init_cam_pose) 61 | 62 | def get_transformed_fingertips( 63 | init_state, final_state, final_left_fingertip, final_right_fingertip, 64 | optional_camera_tf: Optional[np.array] = None 65 | ) -> Tuple[np.array, np.array]: 66 | """ 67 | Generate the transformed fingertip positions in the initial frame 68 | NOTE: this requires the kdl server to be running 69 | NOTE: the state is in the order of robot's [x, y, z, roll, pitch, yaw] 70 | :arg optional_camera_tf: the transform of the new camera after 
custom 71 | augmenting the image (e.g. translation and cropping) 72 | :return fl_coor, fr_coor: fingertip positions in the initial frame 73 | """ 74 | _f_s = final_state.cpu().numpy() 75 | _i_s = init_state.cpu().numpy() 76 | 77 | # _delta_tf = calc_tf_delta(_f_s, _i_s) 78 | init_cam_pose = get_forward_kinematics( 79 | stretch_joint_sequence(_i_s)) 80 | final_cam_pose = get_forward_kinematics( 81 | stretch_joint_sequence(_f_s)) 82 | init_cam_pose = pose_to_mat(init_cam_pose) 83 | 84 | if optional_camera_tf is not None: 85 | init_cam_pose = init_cam_pose@optional_camera_tf 86 | 87 | final_cam_pose = pose_to_mat(final_cam_pose) 88 | _delta_tf = tf_from_mats(final_cam_pose, init_cam_pose) 89 | 90 | # np add 3 more 0s to the end of the pose 91 | final_left_fingertip = np.append(final_left_fingertip, [0, 0, 0]) 92 | fl_mat = pose_to_mat(final_left_fingertip) 93 | final_right_fingertip = np.append(final_right_fingertip, [0, 0, 0]) 94 | fr_mat = pose_to_mat(final_right_fingertip) 95 | 96 | fl_diff_mat = tf_from_mats(_delta_tf, fl_mat) 97 | fr_diff_mat = tf_from_mats(_delta_tf, fr_mat) 98 | fl_coor = np.array(mat_to_pose(fl_diff_mat)[:3]) 99 | fr_coor = np.array(mat_to_pose(fr_diff_mat)[:3]) 100 | return fl_coor, fr_coor 101 | 102 | def transform_coord(coord: np.array, curr_joints, from_cam_to_world=True): 103 | """ 104 | This function is used to transform the target (e.g. fingertip) 105 | coordinates from the camera frame to the world frame or vice versa 106 | :arg coord: target coordinates [x, y, z] in cam frame if True 107 | :arg curr_joints: current joint angles [x, lift, arm, yaw, pitch, roll] 108 | """ 109 | target_pose = np.append(coord, [0, 0, 0]) 110 | target_pose = pose_to_mat(target_pose) 111 | 112 | # print(" - target_pose", target_pose) 113 | cam_pose = get_forward_kinematics(stretch_joint_sequence(curr_joints)) 114 | print(" - cam_pose", cam_pose) 115 | cam_pose = pose_to_mat(cam_pose) 116 | 117 | if from_cam_to_world: 118 | new_pose = cam_pose@target_pose 119 | else: 120 | new_pose = inverse_mat(target_pose)@cam_pose 121 | coord = np.array(mat_to_pose(new_pose)[:3]) 122 | print(" - target coord", coord) 123 | return coord 124 | 125 | def camera_frame_to_robot_frame(pos_dict, vec): 126 | # rotating -100 around robot x axis, 10 degree is camera angle 127 | rad_100 = -100 * np.pi / 180 128 | rad_100 += pos_dict['pitch'] 129 | rot_x_100 = np.array([[1, 0, 0], 130 | [0, np.cos(rad_100), -np.sin(rad_100)], 131 | [0, np.sin(rad_100), np.cos(rad_100)]]) 132 | 133 | # rotating 180 around robot z axis 134 | rot_z_180 = np.array([[np.cos(np.pi), -np.sin(np.pi), 0], 135 | [np.sin(np.pi), np.cos(np.pi), 0], 136 | [0, 0, 1]]) 137 | 138 | return rot_z_180 @ rot_x_100 @ vec 139 | 140 | def robot_frame_to_camera_frame(pos_dict, vec): 141 | # rotating 180 around robot z axis 142 | rot_z_180 = np.array([[np.cos(np.pi), -np.sin(np.pi), 0], 143 | [np.sin(np.pi), np.cos(np.pi), 0], 144 | [0, 0, 1]]) 145 | 146 | # rotating 100 around robot x axis, 10 degree is camera angle 147 | rad_100 = 100 * np.pi / 180 148 | rad_100 += pos_dict['pitch'] 149 | rot_x_100 = np.array([[1, 0, 0], 150 | [0, np.cos(rad_100), -np.sin(rad_100)], 151 | [0, np.sin(rad_100), np.cos(rad_100)]]) 152 | return rot_x_100 @ rot_z_180 @ vec 153 | --------------------------------------------------------------------------------
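The pose helpers in `utils/transform.py` compose as follows. This is a minimal usage sketch with made-up poses and a made-up 0.3 m fingertip offset, intended only to illustrate the conventions of `pose_to_mat`, `tf_from_mats`, and `mat_to_pose`:

```python
# Minimal sketch: re-express a fingertip position measured in the "final"
# camera frame in the "initial" camera frame using the helpers above.
from utils.transform import pose_to_mat, tf_from_mats, mat_to_pose

init_cam = pose_to_mat([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])    # camera pose at t0 (world frame)
final_cam = pose_to_mat([0.05, 0.0, 0.0, 0.0, 0.1, 0.0])  # camera pose at t1: +5 cm x, 0.1 rad pitch

# tf_from_mats(A, B) = inv(A) @ B, so this maps final-camera coordinates
# into initial-camera coordinates.
delta = tf_from_mats(init_cam, final_cam)

fingertip_in_final = pose_to_mat([0.0, 0.0, 0.3, 0.0, 0.0, 0.0])  # 30 cm in front of the camera
fingertip_in_init = delta @ fingertip_in_final
print(mat_to_pose(fingertip_in_init)[:3])  # ~[0.080, 0.000, 0.299]
```

Because `tf_from_mats(A, B)` is `inv(A) @ B`, post-multiplying a pose expressed in frame B re-expresses it in frame A, which is the same composition pattern used by `calc_tf_delta` and `get_transformed_fingertips`.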