├── .gitignore
├── README.md
├── VLM_CaP
│   ├── VideoPerception
│   │   ├── src
│   │   │   ├── __pycache__
│   │   │   │   └── vlm_video.cpython-39.pyc
│   │   │   ├── video_cropper.py
│   │   │   ├── vlm_perception.py
│   │   │   └── vlm_video.py
│   │   ├── video_crop.ipynb
│   │   └── vlm_pic_set.ipynb
│   ├── simulation.py
│   └── src
│       ├── LMP.py
│       ├── __init__.py
│       ├── configs.py
│       ├── env.py
│       ├── grippers.py
│       ├── key.py
│       ├── prompts.py
│       ├── setup.py
│       └── vlm_video.py
├── convert_video.py
├── get_frame_by_hands.py
├── hand_landmarker.task
├── media
│   └── main.jpg
├── requirements.txt
├── track_anything.py
├── track_objects.py
└── vlm.py

/.gitignore:
--------------------------------------------------------------------------------
1 | ############################models#############################
2 | # GroundingDINO/
3 | GroundingDINO/
4 | 
5 | # inpainter/
6 | inpainter/
7 | 
8 | # segment_anything
9 | segment_anything/
10 | segment-anything-2/
11 | 
12 | # tools
13 | tools/
14 | 
15 | # Track-Anything
16 | Track-Anything/
17 | 
18 | # tracker
19 | tracker/
20 | 
21 | # pth model
22 | *.pth
23 | 
24 | # csv files
25 | *.csv
26 | 
27 | # pycache
28 | __pycache__/
29 | 
30 | # results
31 | results/
32 | 
33 | # bash code
34 | *.sh
35 | 
36 | run.py
37 | 
38 | ############################media###############################
39 | # media/
40 | media/
41 | test_media/
42 | VLM_CaP/bowl/
43 | VLM_CaP/robotiq_2f_85/
44 | VLM_CaP/ur5e/
45 | VLM_CaP/src/key.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SeeDo: Human Demo Video to Robot Action Plan via Vision Language Model
2 | 
3 | **VLM See, Robot Do (SeeDo)** is a method that uses large vision models, tracking models, and vision-language models to extract a robot action plan from a human demonstration video, focusing on long-horizon pick-and-place tasks. The action plan is then executed both in the real world and in a PyBullet simulation environment.
4 | 
5 | ![main](https://github.com/ai4ce/SeeDo/blob/main/media/main.jpg)
6 | 
7 | ## Setup Instructions
8 | 
9 | Note that SeeDo relies on GroundingDINO, SAM, and SAM2. The code has only been tested on Ubuntu 20.04, with CUDA 11.8 and PyTorch 2.3.1+cu118.
10 | 
11 | - Install SeeDo and create a new environment
12 | 
13 | ```bash
14 | git clone https://github.com/ai4ce/SeeDo
15 | conda create --name seedo python=3.10.14
16 | conda activate seedo
17 | cd SeeDo
18 | pip install -r requirements.txt
19 | ```
20 | 
21 | - Install PyTorch (only for CUDA 11.8 users)
22 | 
23 | ```bash
24 | pip install torch==2.3.1+cu118 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
25 | ```
26 | 
27 | - Install GroundingDINO, SAM, and SAM2 in the same environment
28 | 
29 | ```bash
30 | git clone https://github.com/IDEA-Research/GroundingDINO
31 | git clone https://github.com/facebookresearch/segment-anything.git
32 | git clone https://github.com/facebookresearch/segment-anything-2.git
33 | ```
34 | 
35 | - Make sure these models are installed as editable packages
36 | 
37 | ```bash
38 | cd GroundingDINO
39 | pip install -e .
40 | ```
41 | Then do the same for segment-anything and segment-anything-2.
42 | 
43 | - We have slightly modified GroundingDINO
44 | 
45 | In `GroundingDINO/groundingdino/util/inference.py`, we added a function that runs inference on an image array. Please paste the following function into `inference.py`.
46 | 
47 | ```python
48 | def load_image_from_array(image_array: np.array) -> Tuple[np.array, torch.Tensor]:
49 |     transform = T.Compose(
50 |         [
51 |             T.RandomResize([800], max_size=1333),
52 |             T.ToTensor(),
53 |             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
54 |         ]
55 |     )
56 |     image_source = Image.fromarray(image_array)
57 |     image_transformed, _ = transform(image_source, None)
58 |     return image_array, image_transformed
59 | ```
60 | 
61 | - The code still uses one checkpoint from segment-anything.
62 | 
63 | Make sure you download it into the SeeDo folder.
64 | **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)**
65 | 
66 | - Obtain an OpenAI API key and create a `key.py` file under `VLM_CaP/src`
67 | 
68 | ```bash
69 | cd VLM_CaP/src
70 | touch key.py
71 | echo 'projectkey = "YOUR_OPENAI_API_KEY"' > key.py
72 | ```
73 | 
74 | ## Pipeline
75 | 
76 | SeeDo consists of four main parts. To ensure the video is processed correctly in the subsequent steps, use `convert_video.py` to convert the video to the appropriate encoding before feeding it into the pipeline. The `convert_video.py` script accepts two parameters, `--input` and `--output`, which specify the path of your original video and the path of the converted video, respectively.
77 | 
78 | 1. **Keyframe Selection Module**
79 | 
80 |    `get_frame_by_hands.py`: this script selects key frames by tracking hand movements. It accepts two parameters.
81 | 
82 |    `--video_path` specifies the path of the input video.
83 | 
84 |    `--output_dir` designates the directory where the key frames will be saved. If `--output_dir` is not specified, the key frames are saved to `./output` by default. For debugging purposes, the hand image and the hand-speed plot are also saved in this directory.
85 | 
86 | 2. **Visual Perception Module**
87 | 
88 |    `track_objects.py`: this script tracks each object and adds a visual prompt for it. It also returns a string containing the center coordinates of each object in the key frames. The script accepts three parameters.
89 | 
90 |    `--input` is the video converted to the appropriate format.
91 | 
92 |    `--output` specifies the output path for the video with the visual prompts.
93 | 
94 |    `--key_frames` is the list of key frame indices obtained from `get_frame_by_hands.py`.
95 | 
96 |    This module returns a `box_list` string that is stored for use in the VLM Reasoning Module.
97 | 
98 | 3. **VLM Reasoning Module**
99 | 
100 |    `vlm.py`: this script performs reasoning on the key frames and generates an action list for the video. It accepts three parameters.
101 | 
102 |    `--input` is the video with the visual prompts added by the Visual Perception Module.
103 | 
104 |    `--list` is the key frame index list obtained from the Keyframe Selection Module.
105 | 
106 |    `--bbx_list` is the `box_list` string obtained from the Visual Perception Module.
107 | 
108 |    This module returns two strings: `obj_list`, representing the objects in the environment, and `action_list`, representing the actions performed on these objects.
109 | 
110 | 4. **Robot Manipulation Module**
111 | 
112 |    `simulation.py`: this script accepts three parameters: `obj_list`, `action_list`, and `output`. It first initializes a random simulation scene based on `obj_list`, then executes pick-and-place tasks according to `action_list`, and finally writes the resulting video to `output`.
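The `--action_list` string is parsed by splitting on the connector phrase `"and then"` (see `execute_actions` in `simulation.py` later in this dump), so multi-step tasks should be phrased as a chain of such clauses. A minimal sketch of that convention:

```python
# The action list is a single string; sub-tasks are separated by "and then".
action_list = "put chili on bowl and then put eggplant on glass"
tasks = [task.strip() for task in action_list.split("and then")]
print(tasks)  # ['put chili on bowl', 'put eggplant on glass']
```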
113 | 
114 |    Example usage: `python simulation.py --action_list "put chili on bowl and then put eggplant on glass" --obj_list chili carrot eggplant bowl glass --output demo2.mp4`
115 | 
116 |    Note that this part uses a modified version of the Code as Policies framework, and its successful execution depends heavily on whether the objects are already modeled and whether the corresponding execution functions for the actions are present in the prompt. We provide a series of new object models and prompts that are compatible with our defined action list. If you want to operate on unseen objects, you will need to provide the corresponding object models and modify the LMP and prompt files accordingly.
117 | 
118 |    We provide some simple object models of vegetables on Hugging Face. Download them from https://huggingface.co/datasets/ai4ce/SeeDo/tree/main/SeeDo
119 | 
120 |    The download contains an `assets.zip` file. Extract it into an `assets` folder and make sure this folder is placed under `VLM_CaP`; `VLM_CaP/assets` will then be used by `simulation.py` for simulation.
121 | 
122 |    The script writes out a video of the robot performing the series of pick-and-place tasks in simulation.
123 | 
--------------------------------------------------------------------------------
/VLM_CaP/VideoPerception/src/__pycache__/vlm_video.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai4ce/SeeDo/9b44abbca767218ab2b452ef5b8ac0958623def0/VLM_CaP/VideoPerception/src/__pycache__/vlm_video.cpython-39.pyc
--------------------------------------------------------------------------------
/VLM_CaP/VideoPerception/src/video_cropper.py:
--------------------------------------------------------------------------------
1 | # video_cropper.py
2 | import cv2
3 | 
4 | def crop_video(video_path, x, y, w, h):
5 |     cap = cv2.VideoCapture(video_path)
6 | 
7 |     if not cap.isOpened():
8 |         print("Error: Could not open video.")
9 |         return
10 | 
11 |     while True:
12 |         ret, frame = cap.read()
13 |         if not ret:
14 |             break
15 | 
16 |         # Draw a rectangle to show the proposed crop region
17 |         frame_with_rect = frame.copy()
18 |         cv2.rectangle(frame_with_rect, (x, y), (x + w, y + h), (0, 255, 0), 2)
19 | 
20 |         # Show the frame with the crop region overlaid
21 |         cv2.imshow('Frame with proposed ROI', frame_with_rect)
22 |         key = cv2.waitKey(1)
23 |         if key == ord('q'):  # press 'q' to quit
24 |             break
25 |         elif key == ord('c'):  # press 'c' to crop and display the cropped region
26 |             roi = frame[y:y+h, x:x+w]
27 |             cv2.imshow('Cropped ROI', roi)
28 | 
29 |     cap.release()
30 |     cv2.destroyAllWindows()
31 | 
32 | def crop_video_to_video(video_path, x, y, w, h, output_path):
33 |     cap = cv2.VideoCapture(video_path)
34 |     fps = cap.get(cv2.CAP_PROP_FPS)  # get the frame rate of the source video
35 |     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # use the mp4v codec
36 |     out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))  # create the VideoWriter object
37 | 
38 |     if not cap.isOpened():
39 |         print("Error: Could not open video.")
40 |         return
41 | 
42 |     while True:
43 |         ret, frame = cap.read()
44 |         if not ret:
45 |             break
46 | 
47 |         # Crop the region of interest from the frame
48 |         cropped_roi = frame[y:y+h, x:x+w]
49 |         out.write(cropped_roi)  # write the cropped frame
50 | 
51 |     cap.release()
52 |     out.release()  # release the VideoWriter object
--------------------------------------------------------------------------------
/VLM_CaP/VideoPerception/src/vlm_perception.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import requests
3 | import os
4 | 
5 | # OpenAI API Key
6 | api_key = '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'
7 | 
8 | # Function to encode the image
9 | def
encode_image(image_path): 10 | with open(image_path, "rb") as image_file: 11 | return base64.b64encode(image_file.read()).decode('utf-8') 12 | 13 | # Directory containing the images 14 | directory_path = "media/charge laptop" 15 | 16 | # Get all image paths 17 | image_paths = [os.path.join(directory_path, f) for f in os.listdir(directory_path) if f.endswith(('.png', '.jpg', '.jpeg'))] 18 | 19 | # Encode all images to base64 20 | base64_images = [encode_image(path) for path in image_paths] 21 | 22 | 23 | headers = { 24 | "Content-Type": "application/json", 25 | "Authorization": f"Bearer {api_key}" 26 | } 27 | 28 | question = '''This picture set is a human demonstration of a task. organize your answer in the following format. 29 | Description of task: description 30 | Plan decomposition: decompose the task into a few executable sub-plans 31 | Object inhand: the object to hold in hand 32 | Object unattached: the object to touch 33 | Pre-touch point: the touching part of object inhand 34 | Post-touch point: the touching part of object unattached''' 35 | 36 | # Construct messages payload with all images 37 | messages_payload = [ 38 | { 39 | "role": "user", 40 | "content": [ 41 | { 42 | "type": "text", 43 | "text": f"I want you to answer question: {question}" 44 | } 45 | ] 46 | } 47 | ] 48 | 49 | # Add all images to the payload 50 | for b64_img in base64_images: 51 | image_message = { 52 | "role": "system", 53 | "content": { 54 | "type": "image", 55 | "data": f"data:image/png;base64,{b64_img}" 56 | } 57 | } 58 | messages_payload.append(image_message) 59 | 60 | # Final payload 61 | payload = { 62 | "model": "gpt-4-1106-vision-preview", 63 | "messages": messages_payload, 64 | "max_tokens": 300 65 | } 66 | 67 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) 68 | 69 | print(response.json()) -------------------------------------------------------------------------------- /VLM_CaP/VideoPerception/src/vlm_video.py: -------------------------------------------------------------------------------- 1 | from IPython.display import display, Image, Audio 2 | 3 | import cv2 # We're using OpenCV to read video, to install !pip install opencv-python 4 | import base64 5 | import time 6 | from openai import OpenAI 7 | import os 8 | import requests 9 | 10 | def extract_frames(video_path): 11 | # 使用 OpenCV 从视频文件中提取帧 12 | video = cv2.VideoCapture(video_path) 13 | 14 | base64Frames = [] 15 | while video.isOpened(): 16 | success, frame = video.read() 17 | if not success: 18 | break 19 | _, buffer = cv2.imencode(".jpg", frame) 20 | base64Frames.append(base64.b64encode(buffer).decode("utf-8")) 21 | 22 | video.release() 23 | print(len(base64Frames), "frames read.") 24 | 25 | # 展示帧,用于调试 26 | display_handle = display(None, display_id=True) 27 | for img in base64Frames: 28 | display_handle.update(Image(data=base64.b64decode(img.encode("utf-8")))) 29 | # time.sleep(0.025) # 如果需要以动画形式展示每帧,可以取消此行注释 30 | 31 | return base64Frames # 返回包含所有帧的 base64 编码的列表 -------------------------------------------------------------------------------- /VLM_CaP/VideoPerception/video_crop.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "id": "6a98dde8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "# 在 Jupyter Notebook 中\n", 11 | "import sys\n", 12 | "sys.path.append('src') # 确保可以从 src 目录导入\n", 13 | "\n", 14 | "from video_cropper import crop_video_to_video\n", 
15 | "\n", 16 | "# 视频路径和裁剪参数\n", 17 | "video_path = 'media/HumanDemo/RGYPC_Ca2.mp4'\n", 18 | "x, y, w, h = 400, 200, 400, 400 # 调整裁剪区域的参数\n", 19 | "output_path = 'media/HumanDemo/RGYPC_Cropped_Ca2.mp4'\n", 20 | "\n", 21 | "crop_video_to_video(video_path, x, y, w, h, output_path)\n", 22 | "\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "55bcec7b", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [] 32 | } 33 | ], 34 | "metadata": { 35 | "kernelspec": { 36 | "display_name": "Python 3 (ipykernel)", 37 | "language": "python", 38 | "name": "python3" 39 | }, 40 | "language_info": { 41 | "codemirror_mode": { 42 | "name": "ipython", 43 | "version": 3 44 | }, 45 | "file_extension": ".py", 46 | "mimetype": "text/x-python", 47 | "name": "python", 48 | "nbconvert_exporter": "python", 49 | "pygments_lexer": "ipython3", 50 | "version": "3.9.13" 51 | } 52 | }, 53 | "nbformat": 4, 54 | "nbformat_minor": 5 55 | } 56 | -------------------------------------------------------------------------------- /VLM_CaP/VideoPerception/vlm_pic_set.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "f39c33f9", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "{'error': {'message': \"Invalid chat format. Expected 'content' field in all messages to be either str or list.\", 'type': 'invalid_request_error', 'param': None, 'code': None}}\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%run src/vlm_perception.py" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "id": "8dee0e0e", 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "b67b3c28", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [] 36 | } 37 | ], 38 | "metadata": { 39 | "kernelspec": { 40 | "display_name": "Python 3 (ipykernel)", 41 | "language": "python", 42 | "name": "python3" 43 | }, 44 | "language_info": { 45 | "codemirror_mode": { 46 | "name": "ipython", 47 | "version": 3 48 | }, 49 | "file_extension": ".py", 50 | "mimetype": "text/x-python", 51 | "name": "python", 52 | "nbconvert_exporter": "python", 53 | "pygments_lexer": "ipython3", 54 | "version": "3.9.13" 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 5 59 | } 60 | -------------------------------------------------------------------------------- /VLM_CaP/simulation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import numpy as np 4 | from src.env import PickPlaceEnv 5 | from src.LMP import LMP, LMP_wrapper, LMPFGen 6 | from src.configs import cfg_tabletop, lmp_tabletop_coords 7 | from src.key import projectkey 8 | from openai import OpenAI 9 | import shapely 10 | from shapely.geometry import * 11 | from shapely.affinity import * 12 | from moviepy.editor import ImageSequenceClip, concatenate_videoclips 13 | 14 | def setup_LMP(env, cfg_tabletop, openai_client): 15 | # LMP env wrapper 16 | cfg_tabletop = copy.deepcopy(cfg_tabletop) 17 | cfg_tabletop["env"] = dict() 18 | cfg_tabletop["env"]["init_objs"] = list(env.obj_name_to_id.keys()) 19 | cfg_tabletop["env"]["coords"] = lmp_tabletop_coords 20 | LMP_env = LMP_wrapper(env, cfg_tabletop) 21 | # creating APIs that the LMPs can interact with 22 | fixed_vars = {"np": np} 23 | fixed_vars.update( 24 | { 25 | 
name: eval(name) 26 | for name in shapely.geometry.__all__ + shapely.affinity.__all__ 27 | } 28 | ) 29 | variable_vars = { 30 | k: getattr(LMP_env, k) 31 | for k in [ 32 | "get_bbox", 33 | "get_obj_pos", 34 | "get_color", 35 | "is_obj_visible", 36 | "denormalize_xy", 37 | "put_first_on_second", 38 | "get_obj_names", 39 | "get_corner_name", 40 | "get_side_name", 41 | ] 42 | } 43 | variable_vars["say"] = lambda msg: print(f"robot says: {msg}") 44 | 45 | # creating the function-generating LMP 46 | lmp_fgen = LMPFGen(openai_client, cfg_tabletop["lmps"]["fgen"], fixed_vars, variable_vars) 47 | 48 | # creating other low-level LMPs 49 | variable_vars.update( 50 | { 51 | k: LMP(openai_client, k, cfg_tabletop["lmps"][k], lmp_fgen, fixed_vars, variable_vars) 52 | for k in [ 53 | "parse_obj_name", 54 | "parse_position", 55 | "parse_question", 56 | "transform_shape_pts", 57 | ] 58 | } 59 | ) 60 | 61 | # creating the LMP that deals w/ high-level language commands 62 | lmp_tabletop_ui = LMP( 63 | openai_client, 64 | "tabletop_ui", 65 | cfg_tabletop["lmps"]["tabletop_ui"], 66 | lmp_fgen, 67 | fixed_vars, 68 | variable_vars, 69 | ) 70 | 71 | return lmp_tabletop_ui 72 | 73 | def execute_actions(action_list, obj_list, env, lmp_tabletop_ui, output_path): 74 | # Split action_list into individual tasks 75 | tasks = action_list.split("and then") 76 | 77 | # List to hold all video clips 78 | video_clips = [] 79 | 80 | # Process each task separately 81 | for task in tasks: 82 | env.cache_video = [] # Clear the cache for the new task 83 | print(f"Running task: {task.strip()} and recording video...") 84 | lmp_tabletop_ui(task.strip(), f'objects = {env.object_list}') 85 | 86 | # Render the video for the task 87 | if env.cache_video: 88 | task_clip = ImageSequenceClip(env.cache_video, fps=30) 89 | video_clips.append(task_clip) 90 | 91 | # Concatenate all the task videos into one final video 92 | if video_clips: 93 | final_clip = concatenate_videoclips(video_clips, method="compose") 94 | final_clip.write_videofile(output_path, codec='libx264', bitrate="5000k", fps=30) 95 | print(f"Final video saved at {output_path}") 96 | 97 | def main(args): 98 | client = OpenAI(api_key=projectkey) 99 | # Initialize environment and LMP with passed arguments 100 | obj_list = args.obj_list 101 | action_list = args.action_list 102 | output_path = args.output # Output path for final video 103 | 104 | # Initialize environment 105 | env = PickPlaceEnv(render=True, high_res=True, high_frame_rate=False) 106 | _ = env.reset(obj_list) 107 | lmp_tabletop_ui = setup_LMP(env, cfg_tabletop, client) 108 | 109 | # Execute actions and save video 110 | execute_actions(action_list, obj_list, env, lmp_tabletop_ui, output_path) 111 | 112 | if __name__ == "__main__": 113 | # Parse arguments 114 | parser = argparse.ArgumentParser(description="Run PickPlaceEnv with LMP based on action list.") 115 | parser.add_argument('--action_list', type=str, required=True, help='String of actions separated by "and then"') 116 | parser.add_argument('--obj_list', nargs='+', required=True, help='List of object names in the environment') 117 | parser.add_argument('--output', type=str, required=True, help='Path to save the final video') 118 | 119 | args = parser.parse_args() 120 | 121 | main(args) -------------------------------------------------------------------------------- /VLM_CaP/src/LMP.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import openai 3 | import shapely 4 | import ast 5 | import astunparse 6 | from 
time import sleep 7 | import numpy as np 8 | from shapely.geometry import * 9 | from shapely.affinity import * 10 | 11 | # from openai.error import RateLimitError, APIConnectionError 12 | from pygments import highlight 13 | from pygments.lexers import PythonLexer 14 | from pygments.formatters import TerminalFormatter 15 | from src.env import COLORS 16 | 17 | # openai keys 18 | from src.key import mykey, projectkey 19 | 20 | 21 | class LMP: 22 | 23 | def __init__(self, client, name, cfg, lmp_fgen, fixed_vars, variable_vars): 24 | self._name = name 25 | self._cfg = cfg 26 | self._client = client 27 | 28 | self._base_prompt = self._cfg["prompt_text"] 29 | 30 | self._stop_tokens = list(self._cfg["stop"]) 31 | 32 | self._lmp_fgen = lmp_fgen 33 | 34 | self._fixed_vars = fixed_vars 35 | self._variable_vars = variable_vars 36 | self.exec_hist = "" 37 | 38 | def clear_exec_hist(self): 39 | self.exec_hist = "" 40 | 41 | def build_prompt(self, query, context=""): 42 | if len(self._variable_vars) > 0: 43 | variable_vars_imports_str = ( 44 | f"from utils import {', '.join(self._variable_vars.keys())}" 45 | ) 46 | else: 47 | variable_vars_imports_str = "" 48 | prompt = self._base_prompt.replace( 49 | "{variable_vars_imports}", variable_vars_imports_str 50 | ) 51 | 52 | if self._cfg["maintain_session"]: 53 | prompt += f"\n{self.exec_hist}" 54 | 55 | if context != "": 56 | prompt += f"\n{context}" 57 | 58 | use_query = f'{self._cfg["query_prefix"]}{query}{self._cfg["query_suffix"]}' 59 | prompt += f"\n{use_query}" 60 | 61 | return prompt, use_query 62 | 63 | def __call__(self, query, context="", **kwargs): 64 | prompt, use_query = self.build_prompt(query, context=context) 65 | 66 | message = [ 67 | {"role": "user", "content": prompt}, 68 | ] 69 | 70 | while True: 71 | try: 72 | code_str = self._client.chat.completions.create( 73 | messages = message, 74 | stop=self._stop_tokens, 75 | temperature=self._cfg["temperature"], 76 | model=self._cfg["engine"], 77 | max_tokens=self._cfg["max_tokens"], 78 | ).choices[0].message.content.strip() 79 | break 80 | except (openai.RateLimitError, openai.APIConnectionError) as e: 81 | print(f"OpenAI API got err {e}") 82 | print("Retrying after 10s.") 83 | sleep(10) 84 | 85 | if self._cfg["include_context"] and context != "": 86 | to_exec = f"{context}\n{code_str}" 87 | to_log = f"{context}\n{use_query}\n{code_str}" 88 | else: 89 | to_exec = code_str 90 | to_log = f"{use_query}\n{to_exec}" 91 | 92 | to_log_pretty = highlight(to_log, PythonLexer(), TerminalFormatter()) 93 | print(f"LMP {self._name} exec:\n\n{to_log_pretty}\n") 94 | 95 | new_fs = self._lmp_fgen.create_new_fs_from_code(code_str) 96 | self._variable_vars.update(new_fs) 97 | 98 | gvars = merge_dicts([self._fixed_vars, self._variable_vars]) 99 | lvars = kwargs 100 | 101 | if not self._cfg["debug_mode"]: 102 | print("to_exec: ", to_exec) 103 | exec_safe(to_exec, gvars, lvars) 104 | 105 | self.exec_hist += f"\n{to_exec}" 106 | 107 | if self._cfg["maintain_session"]: 108 | self._variable_vars.update(lvars) 109 | 110 | if self._cfg["has_return"]: 111 | return lvars[self._cfg["return_val_name"]] 112 | 113 | 114 | class LMPFGen: 115 | 116 | def __init__(self, client, cfg, fixed_vars, variable_vars): 117 | self._cfg = cfg 118 | self._client = client 119 | 120 | self._stop_tokens = list(self._cfg["stop"]) 121 | self._fixed_vars = fixed_vars 122 | self._variable_vars = variable_vars 123 | 124 | self._base_prompt = self._cfg["prompt_text"] 125 | 126 | def create_f_from_sig( 127 | self, f_name, f_sig, other_vars=None, 
fix_bugs=False, return_src=False 128 | ): 129 | print(f"Creating function: {f_sig}") 130 | 131 | use_query = f'{self._cfg["query_prefix"]}{f_sig}{self._cfg["query_suffix"]}' 132 | # prompt = f"{self._base_prompt}\n{use_query}" 133 | message = [ 134 | {"role": "user", "content": f"{self._base_prompt}\n{use_query}"} 135 | ] 136 | 137 | while True: 138 | try: 139 | f_src = self._client.chat.completions.create( 140 | messages = message, 141 | stop=self._stop_tokens, 142 | temperature=self._cfg["temperature"], 143 | model=self._cfg["engine"], 144 | max_tokens=self._cfg["max_tokens"], 145 | ).choices[0].message.content.strip() 146 | break 147 | except (openai.RateLimitError, openai.APIConnectionError) as e: 148 | print(f"OpenAI API got err {e}") 149 | print("Retrying after 10s.") 150 | sleep(10) 151 | 152 | if fix_bugs: 153 | f_src = openai.CodeEdit.create( 154 | model="code-davinci-002", 155 | input="# " + f_src, 156 | temperature=0, 157 | instruction="Fix the bug if there is one. Improve readability. Keep same inputs and outputs. Only small changes. No comments.", 158 | )["choices"][0]["text"].strip() 159 | 160 | if other_vars is None: 161 | other_vars = {} 162 | gvars = merge_dicts([self._fixed_vars, self._variable_vars, other_vars]) 163 | lvars = {} 164 | 165 | exec_safe(f_src, gvars, lvars) 166 | 167 | f = lvars[f_name] 168 | 169 | to_print = highlight( 170 | f"{use_query}\n{f_src}", PythonLexer(), TerminalFormatter() 171 | ) 172 | print(f"LMP FGEN created:\n\n{to_print}\n") 173 | 174 | if return_src: 175 | return f, f_src 176 | return f 177 | 178 | def create_new_fs_from_code( 179 | self, code_str, other_vars=None, fix_bugs=False, return_src=False 180 | ): 181 | fs, f_assigns = {}, {} 182 | f_parser = FunctionParser(fs, f_assigns) 183 | f_parser.visit(ast.parse(code_str)) 184 | for f_name, f_assign in f_assigns.items(): 185 | if f_name in fs: 186 | fs[f_name] = f_assign 187 | 188 | if other_vars is None: 189 | other_vars = {} 190 | 191 | new_fs = {} 192 | srcs = {} 193 | for f_name, f_sig in fs.items(): 194 | all_vars = merge_dicts( 195 | [self._fixed_vars, self._variable_vars, new_fs, other_vars] 196 | ) 197 | if not var_exists(f_name, all_vars): 198 | f, f_src = self.create_f_from_sig( 199 | f_name, f_sig, new_fs, fix_bugs=fix_bugs, return_src=True 200 | ) 201 | 202 | # recursively define child_fs in the function body if needed 203 | f_def_body = astunparse.unparse(ast.parse(f_src).body[0].body) 204 | child_fs, child_f_srcs = self.create_new_fs_from_code( 205 | f_def_body, other_vars=all_vars, fix_bugs=fix_bugs, return_src=True 206 | ) 207 | 208 | if len(child_fs) > 0: 209 | new_fs.update(child_fs) 210 | srcs.update(child_f_srcs) 211 | 212 | # redefine parent f so newly created child_fs are in scope 213 | gvars = merge_dicts( 214 | [self._fixed_vars, self._variable_vars, new_fs, other_vars] 215 | ) 216 | lvars = {} 217 | 218 | exec_safe(f_src, gvars, lvars) 219 | 220 | f = lvars[f_name] 221 | 222 | new_fs[f_name], srcs[f_name] = f, f_src 223 | 224 | if return_src: 225 | return new_fs, srcs 226 | return new_fs 227 | 228 | 229 | class FunctionParser(ast.NodeTransformer): 230 | 231 | def __init__(self, fs, f_assigns): 232 | super().__init__() 233 | self._fs = fs 234 | self._f_assigns = f_assigns 235 | 236 | def visit_Call(self, node): 237 | self.generic_visit(node) 238 | if isinstance(node.func, ast.Name): 239 | f_sig = astunparse.unparse(node).strip() 240 | f_name = astunparse.unparse(node.func).strip() 241 | self._fs[f_name] = f_sig 242 | return node 243 | 244 | def visit_Assign(self, 
node): 245 | self.generic_visit(node) 246 | if isinstance(node.value, ast.Call): 247 | assign_str = astunparse.unparse(node).strip() 248 | f_name = astunparse.unparse(node.value.func).strip() 249 | self._f_assigns[f_name] = assign_str 250 | return node 251 | 252 | 253 | def var_exists(name, all_vars): 254 | try: 255 | eval(name, all_vars) 256 | except: 257 | exists = False 258 | else: 259 | exists = True 260 | return exists 261 | 262 | 263 | def merge_dicts(dicts): 264 | return {k: v for d in dicts for k, v in d.items()} 265 | 266 | 267 | def exec_safe(code_str, gvars=None, lvars=None): 268 | print("savely executing code: ", code_str) 269 | banned_phrases = ["import", "__"] 270 | for phrase in banned_phrases: 271 | assert phrase not in code_str 272 | 273 | if gvars is None: 274 | gvars = {} 275 | if lvars is None: 276 | lvars = {} 277 | empty_fn = lambda *args, **kwargs: None 278 | custom_gvars = merge_dicts([gvars, {"exec": empty_fn, "eval": empty_fn}]) 279 | # print("custom_gvars: ", custom_gvars) 280 | # print("------") 281 | # print("local vars: ", lvars) 282 | # print("------") 283 | exec(code_str, custom_gvars, lvars) 284 | 285 | 286 | class LMP_wrapper: 287 | 288 | def __init__(self, env, cfg, render=False): 289 | self.env = env 290 | self._cfg = cfg 291 | self.object_names = list(self._cfg["env"]["init_objs"]) 292 | 293 | self._min_xy = np.array(self._cfg["env"]["coords"]["bottom_left"]) 294 | self._max_xy = np.array(self._cfg["env"]["coords"]["top_right"]) 295 | self._range_xy = self._max_xy - self._min_xy 296 | 297 | self._table_z = self._cfg["env"]["coords"]["table_z"] 298 | self.render = render 299 | 300 | def is_obj_visible(self, obj_name): 301 | return obj_name in self.object_names 302 | 303 | def get_obj_names(self): 304 | return self.object_names[::] 305 | 306 | def denormalize_xy(self, pos_normalized): 307 | return pos_normalized * self._range_xy + self._min_xy 308 | 309 | def get_corner_positions(self): 310 | unit_square = box(0, 0, 1, 1) 311 | normalized_corners = np.array(list(unit_square.exterior.coords))[:4] 312 | corners = np.array( 313 | ([self.denormalize_xy(corner) for corner in normalized_corners]) 314 | ) 315 | return corners 316 | 317 | def get_side_positions(self): 318 | side_xs = np.array([0, 0.5, 0.5, 1]) 319 | side_ys = np.array([0.5, 0, 1, 0.5]) 320 | normalized_side_positions = np.c_[side_xs, side_ys] 321 | side_positions = np.array( 322 | ([self.denormalize_xy(corner) for corner in normalized_side_positions]) 323 | ) 324 | return side_positions 325 | 326 | def get_obj_pos(self, obj_name): 327 | # return the xy position of the object in robot base frame 328 | return self.env.get_obj_pos(obj_name)[:2] 329 | 330 | def get_obj_position_np(self, obj_name): 331 | return self.get_pos(obj_name) 332 | 333 | def get_bbox(self, obj_name): 334 | # return the axis-aligned object bounding box in robot base frame (not in pixels) 335 | # the format is (min_x, min_y, max_x, max_y) 336 | bbox = self.env.get_bounding_box(obj_name) 337 | return bbox 338 | 339 | def get_color(self, obj_name): 340 | for color, rgb in COLORS.items(): 341 | if color in obj_name: 342 | return rgb 343 | 344 | def pick_place(self, pick_pos, place_pos): 345 | pick_pos_xyz = np.r_[pick_pos, [self._table_z]] 346 | place_pos_xyz = np.r_[place_pos, [self._table_z]] 347 | pass 348 | 349 | def put_first_on_second(self, arg1, arg2): 350 | # put the object with obj_name on top of target 351 | # target can either be another object name, or it can be an x-y position in robot base frame 352 | pick_pos = 
self.get_obj_pos(arg1) if isinstance(arg1, str) else arg1 353 | place_pos = self.get_obj_pos(arg2) if isinstance(arg2, str) else arg2 354 | self.env.step(action={"pick": pick_pos, "place": place_pos}) 355 | 356 | def get_robot_pos(self): 357 | # return robot end-effector xy position in robot base frame 358 | return self.env.get_ee_pos() 359 | 360 | def goto_pos(self, position_xy): 361 | # move the robot end-effector to the desired xy position while maintaining same z 362 | ee_xyz = self.env.get_ee_pos() 363 | position_xyz = np.concatenate([position_xy, ee_xyz[-1]]) 364 | while np.linalg.norm(position_xyz - ee_xyz) > 0.01: 365 | self.env.movep(position_xyz) 366 | self.env.step_sim_and_render() 367 | ee_xyz = self.env.get_ee_pos() 368 | 369 | def follow_traj(self, traj): 370 | for pos in traj: 371 | self.goto_pos(pos) 372 | 373 | def get_corner_positions(self): 374 | normalized_corners = np.array([[0, 1], [1, 1], [0, 0], [1, 0]]) 375 | return np.array( 376 | ([self.denormalize_xy(corner) for corner in normalized_corners]) 377 | ) 378 | 379 | def get_side_positions(self): 380 | normalized_sides = np.array([[0.5, 1], [1, 0.5], [0.5, 0], [0, 0.5]]) 381 | return np.array(([self.denormalize_xy(side) for side in normalized_sides])) 382 | 383 | def get_corner_name(self, pos): 384 | corner_positions = self.get_corner_positions() 385 | corner_idx = np.argmin(np.linalg.norm(corner_positions - pos, axis=1)) 386 | return [ 387 | "top left corner", 388 | "top right corner", 389 | "bottom left corner", 390 | "botom right corner", 391 | ][corner_idx] 392 | 393 | def get_side_name(self, pos): 394 | side_positions = self.get_side_positions() 395 | side_idx = np.argmin(np.linalg.norm(side_positions - pos, axis=1)) 396 | return ["top side", "right side", "bottom side", "left side"][side_idx] 397 | -------------------------------------------------------------------------------- /VLM_CaP/src/__init__.py: -------------------------------------------------------------------------------- 1 | # src/__init__.py 2 | -------------------------------------------------------------------------------- /VLM_CaP/src/configs.py: -------------------------------------------------------------------------------- 1 | # LMP Configs 2 | from src.prompts import * 3 | 4 | cfg_tabletop = { 5 | "lmps": { 6 | "tabletop_ui": { 7 | "prompt_text": prompt_tabletop_ui, 8 | "engine": "gpt-3.5-turbo", # for "text-davinci-003" has been deprecated, 9 | "max_tokens": 512, 10 | "temperature": 0, 11 | "query_prefix": "# ", 12 | "query_suffix": ".", 13 | "stop": ["#", "objects = ["], 14 | "maintain_session": True, 15 | "debug_mode": False, 16 | "include_context": True, 17 | "has_return": False, 18 | "return_val_name": "ret_val", 19 | }, 20 | "parse_obj_name": { 21 | "prompt_text": prompt_parse_obj_name, 22 | "engine": "gpt-3.5-turbo", # for "text-davinci-003" has been deprecated, 23 | "max_tokens": 512, 24 | "temperature": 0, 25 | "query_prefix": "# ", 26 | "query_suffix": ".", 27 | "stop": ["#", "objects = ["], 28 | "maintain_session": False, 29 | "debug_mode": False, 30 | "include_context": True, 31 | "has_return": True, 32 | "return_val_name": "ret_val", 33 | }, 34 | "parse_position": { 35 | "prompt_text": prompt_parse_position, 36 | "engine": "gpt-3.5-turbo", # for "text-davinci-003" has been deprecated, 37 | "max_tokens": 512, 38 | "temperature": 0, 39 | "query_prefix": "# ", 40 | "query_suffix": ".", 41 | "stop": ["#"], 42 | "maintain_session": False, 43 | "debug_mode": False, 44 | "include_context": True, 45 | "has_return": True, 46 | 
"return_val_name": "ret_val", 47 | }, 48 | "parse_question": { 49 | "prompt_text": prompt_parse_question, 50 | "engine": "gpt-3.5-turbo", # for "text-davinci-003" has been deprecated, 51 | "max_tokens": 512, 52 | "temperature": 0, 53 | "query_prefix": "# ", 54 | "query_suffix": ".", 55 | "stop": ["#", "objects = ["], 56 | "maintain_session": False, 57 | "debug_mode": False, 58 | "include_context": True, 59 | "has_return": True, 60 | "return_val_name": "ret_val", 61 | }, 62 | "transform_shape_pts": { 63 | "prompt_text": prompt_transform_shape_pts, 64 | "engine": "gpt-3.5-turbo", # for "text-davinci-003" has been deprecated, 65 | "max_tokens": 512, 66 | "temperature": 0, 67 | "query_prefix": "# ", 68 | "query_suffix": ".", 69 | "stop": ["#"], 70 | "maintain_session": False, 71 | "debug_mode": False, 72 | "include_context": True, 73 | "has_return": True, 74 | "return_val_name": "new_shape_pts", 75 | }, 76 | "fgen": { 77 | "prompt_text": prompt_fgen, 78 | "engine": "gpt-3.5-turbo", # for "text-davinci-003" has been deprecated, 79 | "max_tokens": 512, 80 | "temperature": 0, 81 | "query_prefix": "# define function: ", 82 | "query_suffix": ".", 83 | "stop": ["# define", "# example"], 84 | "maintain_session": False, 85 | "debug_mode": False, 86 | "include_context": True, 87 | }, 88 | } 89 | } 90 | 91 | lmp_tabletop_coords = { 92 | "top_left": (-0.3 + 0.05, -0.2 - 0.05), 93 | "top_side": (0, -0.2 - 0.05), 94 | "top_right": (0.3 - 0.05, -0.2 - 0.05), 95 | "left_side": ( 96 | -0.3 + 0.05, 97 | -0.5, 98 | ), 99 | "middle": ( 100 | 0, 101 | -0.5, 102 | ), 103 | "right_side": ( 104 | 0.3 - 0.05, 105 | -0.5, 106 | ), 107 | "bottom_left": (-0.3 + 0.05, -0.8 + 0.05), 108 | "bottom_side": (0, -0.8 + 0.05), 109 | "bottom_right": (0.3 - 0.05, -0.8 + 0.05), 110 | "table_z": 0.0, 111 | } 112 | -------------------------------------------------------------------------------- /VLM_CaP/src/env.py: -------------------------------------------------------------------------------- 1 | # pick and place environment 2 | # # Global constants: pick and place objects, colors, workspace bounds 3 | import os 4 | import pybullet 5 | import pybullet_data 6 | import numpy as np 7 | import threading 8 | import copy 9 | import openai 10 | import cv2 11 | 12 | from src.grippers import Robotiq2F85 13 | 14 | COLORS = { 15 | "blue": (78 / 255, 121 / 255, 167 / 255, 255 / 255), 16 | "red": (255 / 255, 87 / 255, 89 / 255, 255 / 255), 17 | "green": (89 / 255, 169 / 255, 79 / 255, 255 / 255), 18 | "orange": (242 / 255, 142 / 255, 43 / 255, 255 / 255), 19 | "yellow": (237 / 255, 201 / 255, 72 / 255, 255 / 255), 20 | "purple": (176 / 255, 122 / 255, 161 / 255, 255 / 255), 21 | "pink": (255 / 255, 157 / 255, 167 / 255, 255 / 255), 22 | "cyan": (118 / 255, 183 / 255, 178 / 255, 255 / 255), 23 | "brown": (156 / 255, 117 / 255, 95 / 255, 255 / 255), 24 | "gray": (186 / 255, 176 / 255, 172 / 255, 255 / 255), 25 | "white": (255 / 255, 255 / 255, 255 / 255, 255 / 255), 26 | "wooden": (255 / 255, 255 / 255, 255 / 255, 255 / 255), 27 | } 28 | 29 | CORNER_POS = { 30 | "top left corner": (-0.3 + 0.05, -0.2 - 0.05, 0), 31 | "top side": (0, -0.2 - 0.05, 0), 32 | "top right corner": (0.3 - 0.05, -0.2 - 0.05, 0), 33 | "left side": (-0.3 + 0.05, -0.5, 0), 34 | "middle": (0, -0.5, 0), 35 | "right side": (0.3 - 0.05, -0.5, 0), 36 | "bottom left corner": (-0.3 + 0.05, -0.8 + 0.05, 0), 37 | "bottom side": (0, -0.8 + 0.05, 0), 38 | "bottom right corner": (0.3 - 0.05, -0.8 + 0.05, 0), 39 | } 40 | 41 | ALL_BLOCKS = [ 42 | "blue block", 43 | "red block", 44 | 
"green block", 45 | "orange block", 46 | "yellow block", 47 | "purple block", 48 | "pink block", 49 | "cyan block", 50 | "brown block", 51 | "gray block", 52 | ] 53 | ALL_BOWLS = [ 54 | "blue bowl", 55 | "red bowl", 56 | "green bowl", 57 | "orange bowl", 58 | "yellow bowl", 59 | "purple bowl", 60 | "pink bowl", 61 | "cyan bowl", 62 | "brown bowl", 63 | "gray bowl", 64 | "white bowl", 65 | ] 66 | 67 | ALL_VEGGIES = [ 68 | "carrot", 69 | "tomato", 70 | "chili", 71 | "eggplant", 72 | "potato", 73 | "corn", 74 | "glass", 75 | "wooden block1", 76 | "wooden block2", 77 | "wooden block3", 78 | "wooden block4" 79 | ] 80 | 81 | PIXEL_SIZE = 0.00267857 82 | BOUNDS = np.float32([[-0.3, 0.3], [-0.8, -0.2], [0, 0.15]]) # X Y Z 83 | 84 | # Gym-style environment code 85 | 86 | 87 | class PickPlaceEnv: 88 | 89 | def __init__(self, render=False, high_res=False, high_frame_rate=False): 90 | self.dt = 1 / 480 91 | self.sim_step = 0 92 | 93 | # Configure and start PyBullet. 94 | # python3 -m pybullet_utils.runServer 95 | # pybullet.connect(pybullet.SHARED_MEMORY) # pybullet.GUI for local GUI. 96 | pybullet.connect(pybullet.DIRECT) # pybullet.GUI for local GUI. 97 | # pybullet.connect(pybullet.GUI) # pybullet.GUI for local GUI. 98 | pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_GUI, 0) 99 | pybullet.setPhysicsEngineParameter(enableFileCaching=0) 100 | assets_path = os.path.dirname(os.path.abspath("")) 101 | pybullet.setAdditionalSearchPath(assets_path) 102 | pybullet.setAdditionalSearchPath(pybullet_data.getDataPath()) 103 | pybullet.setTimeStep(self.dt) 104 | 105 | self.home_joints = ( 106 | np.pi / 2, 107 | -np.pi / 2, 108 | np.pi / 2, 109 | -np.pi / 2, 110 | 3 * np.pi / 2, 111 | 0, 112 | ) # Joint angles: (J0, J1, J2, J3, J4, J5). 113 | self.home_ee_euler = (np.pi, 0, np.pi) # (RX, RY, RZ) rotation in Euler angles. 114 | self.ee_link_id = 9 # Link ID of UR5 end effector. 115 | self.tip_link_id = 10 # Link ID of gripper finger tips. 116 | self.gripper = None 117 | 118 | self.render = render 119 | self.high_res = high_res 120 | self.high_frame_rate = high_frame_rate 121 | 122 | def reset(self, object_list): 123 | pybullet.resetSimulation(pybullet.RESET_USE_DEFORMABLE_WORLD) 124 | pybullet.setGravity(0, 0, -9.8) 125 | self.cache_video = [] 126 | 127 | # Temporarily disable rendering to load URDFs faster. 128 | pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_RENDERING, 0) 129 | 130 | # Add robot. 131 | pybullet.loadURDF("plane.urdf", [0, 0, -0.001]) 132 | self.robot_id = pybullet.loadURDF( 133 | "assets/ur5e/ur5e.urdf", 134 | [0, 0, 0], 135 | flags=pybullet.URDF_USE_MATERIAL_COLORS_FROM_MTL, 136 | ) 137 | self.ghost_id = pybullet.loadURDF( 138 | "assets/ur5e/ur5e.urdf", [0, 0, -10] 139 | ) # For forward kinematics. 140 | self.joint_ids = [ 141 | pybullet.getJointInfo(self.robot_id, i) 142 | for i in range(pybullet.getNumJoints(self.robot_id)) 143 | ] 144 | self.joint_ids = [ 145 | j[0] for j in self.joint_ids if j[2] == pybullet.JOINT_REVOLUTE 146 | ] 147 | 148 | # Move robot to home configuration. 149 | for i in range(len(self.joint_ids)): 150 | pybullet.resetJointState( 151 | self.robot_id, self.joint_ids[i], self.home_joints[i] 152 | ) 153 | 154 | # Add gripper. 155 | if self.gripper is not None: 156 | while self.gripper.constraints_thread.is_alive(): 157 | self.constraints_thread_active = False 158 | self.gripper = Robotiq2F85(self.robot_id, self.ee_link_id) 159 | self.gripper.release() 160 | 161 | # # Add inner white workspace. 
162 | # inner_plane_shape = pybullet.createCollisionShape( 163 | # pybullet.GEOM_BOX, halfExtents=[0.3, 0.3, 0.001] 164 | # ) 165 | # inner_plane_visual = pybullet.createVisualShape( 166 | # pybullet.GEOM_BOX, halfExtents=[0.3, 0.3, 0.001] 167 | # ) 168 | # inner_plane_id = pybullet.createMultiBody( 169 | # 0, inner_plane_shape, inner_plane_visual, basePosition=[0, -0.5, 0] 170 | # ) 171 | # pybullet.changeVisualShape(inner_plane_id, -1, rgbaColor=[1.0, 1.0, 1.0, 1.0]) # White color 172 | 173 | # Add outer black workspace as a border, placed below the inner workspace. 174 | outer_plane_shape = pybullet.createCollisionShape( 175 | pybullet.GEOM_BOX, halfExtents=[0.35, 0.35, 0.001] # Slightly larger than the inner workspace 176 | ) 177 | outer_plane_visual = pybullet.createVisualShape( 178 | pybullet.GEOM_BOX, halfExtents=[0.35, 0.35, 0.001] 179 | ) 180 | outer_plane_id = pybullet.createMultiBody( 181 | 0, outer_plane_shape, outer_plane_visual, basePosition=[0, -0.5, -0.001] # Lowered z-axis 182 | ) 183 | pybullet.changeVisualShape(outer_plane_id, -1, rgbaColor=[0.0, 0.0, 0.0, 1.0]) # Black color 184 | 185 | 186 | 187 | # Load objects according to config. 188 | self.object_list = object_list 189 | self.obj_name_to_id = {} 190 | obj_xyz = np.zeros((0, 3)) 191 | for obj_name in object_list: 192 | if("block1" in obj_name) or ("block2" in obj_name) or ("block3" in obj_name) or ("block4" in obj_name): 193 | object_type = "vegetable" 194 | # Get random position 15cm+ from other objects. 195 | while True: 196 | rand_x = np.random.uniform(BOUNDS[0, 0] + 0.1, BOUNDS[0, 1] - 0.1) 197 | rand_y = np.random.uniform(BOUNDS[1, 0] + 0.1, BOUNDS[1, 1] - 0.1) 198 | rand_xyz = np.float32([rand_x, rand_y, 0.03]).reshape(1, 3) 199 | if len(obj_xyz) == 0: 200 | obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) 201 | break 202 | else: 203 | nn_dist = np.min( 204 | np.linalg.norm(obj_xyz - rand_xyz, axis=1) 205 | ).squeeze() 206 | if nn_dist > 0.15: 207 | obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) 208 | break 209 | 210 | object_position = rand_xyz.squeeze() 211 | 212 | # load object urdf 213 | if "glass" in obj_name: 214 | object_position[2] = 0 # following the bowls configuration 215 | object_id = pybullet.loadURDF(f"assets/{obj_name}/{obj_name}.urdf", object_position, useFixedBase=1) 216 | else: 217 | object_id = pybullet.loadURDF(f"assets/{obj_name}/{obj_name}.urdf", object_position) 218 | self.obj_name_to_id[obj_name] = object_id 219 | elif ("block" in obj_name) or ("bowl" in obj_name): 220 | 221 | # Get random position 15cm+ from other objects. 
222 | while True: 223 | rand_x = np.random.uniform(BOUNDS[0, 0] + 0.1, BOUNDS[0, 1] - 0.1) 224 | rand_y = np.random.uniform(BOUNDS[1, 0] + 0.1, BOUNDS[1, 1] - 0.1) 225 | rand_xyz = np.float32([rand_x, rand_y, 0.03]).reshape(1, 3) 226 | if len(obj_xyz) == 0: 227 | obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) 228 | break 229 | else: 230 | nn_dist = np.min( 231 | np.linalg.norm(obj_xyz - rand_xyz, axis=1) 232 | ).squeeze() 233 | if nn_dist > 0.20: 234 | obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) 235 | break 236 | 237 | object_color = COLORS[obj_name.split(" ")[0]] 238 | object_type = obj_name.split(" ")[1] 239 | object_position = rand_xyz.squeeze() 240 | if object_type == "block": 241 | object_shape = pybullet.createCollisionShape( 242 | pybullet.GEOM_BOX, halfExtents=[0.02, 0.02, 0.02] # (half x, half y , half z) 243 | ) 244 | object_visual = pybullet.createVisualShape( 245 | pybullet.GEOM_BOX, halfExtents=[0.02, 0.02, 0.02] 246 | ) 247 | object_id = pybullet.createMultiBody( 248 | 0.01, object_shape, object_visual, basePosition=object_position 249 | ) 250 | elif object_type == "bowl": 251 | object_position[2] = 0 252 | object_id = pybullet.loadURDF( 253 | "assets/bowl/bowl.urdf", object_position, useFixedBase=1 254 | ) 255 | pybullet.changeVisualShape(object_id, -1, rgbaColor=object_color) 256 | self.obj_name_to_id[obj_name] = object_id 257 | 258 | elif obj_name in ALL_VEGGIES: 259 | object_type = "vegetable" 260 | 261 | # Get random position 15cm+ from other objects. 262 | while True: 263 | rand_x = np.random.uniform(BOUNDS[0, 0] + 0.1, BOUNDS[0, 1] - 0.1) 264 | rand_y = np.random.uniform(BOUNDS[1, 0] + 0.1, BOUNDS[1, 1] - 0.1) 265 | rand_xyz = np.float32([rand_x, rand_y, 0.03]).reshape(1, 3) 266 | if len(obj_xyz) == 0: 267 | obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) 268 | break 269 | else: 270 | nn_dist = np.min( 271 | np.linalg.norm(obj_xyz - rand_xyz, axis=1) 272 | ).squeeze() 273 | if nn_dist > 0.15: 274 | obj_xyz = np.concatenate((obj_xyz, rand_xyz), axis=0) 275 | break 276 | 277 | object_position = rand_xyz.squeeze() 278 | 279 | # load object urdf 280 | if "glass" in obj_name: 281 | object_position[2] = 0 # following the bowls configuration 282 | object_id = pybullet.loadURDF(f"assets/{obj_name}/{obj_name}.urdf", object_position, useFixedBase=1) 283 | else: 284 | object_id = pybullet.loadURDF(f"assets/{obj_name}/{obj_name}.urdf", object_position) 285 | self.obj_name_to_id[obj_name] = object_id 286 | 287 | # Re-enable rendering. 
288 | pybullet.configureDebugVisualizer(pybullet.COV_ENABLE_RENDERING, 1) 289 | 290 | for _ in range(200): 291 | pybullet.stepSimulation() 292 | 293 | # record object positions at reset 294 | self.init_pos = {name: self.get_obj_pos(name) for name in object_list} 295 | 296 | return self.get_observation() 297 | 298 | def servoj(self, joints): 299 | """Move to target joint positions with position control.""" 300 | pybullet.setJointMotorControlArray( 301 | bodyIndex=self.robot_id, 302 | jointIndices=self.joint_ids, 303 | controlMode=pybullet.POSITION_CONTROL, 304 | targetPositions=joints, 305 | positionGains=[0.01] * 6, 306 | ) 307 | 308 | def movep(self, position): 309 | """Move to target end effector position.""" 310 | joints = pybullet.calculateInverseKinematics( 311 | bodyUniqueId=self.robot_id, 312 | endEffectorLinkIndex=self.tip_link_id, 313 | targetPosition=position, 314 | targetOrientation=pybullet.getQuaternionFromEuler(self.home_ee_euler), 315 | maxNumIterations=100, 316 | ) 317 | self.servoj(joints) 318 | 319 | def get_ee_pos(self): 320 | ee_xyz = np.float32(pybullet.getLinkState(self.robot_id, self.tip_link_id)[0]) 321 | return ee_xyz 322 | 323 | def step(self, action=None): 324 | """Do pick and place motion primitive.""" 325 | print("action", action) 326 | pick_pos, place_pos = action["pick"].copy(), action["place"].copy() 327 | 328 | # Set fixed primitive z-heights. 329 | hover_xyz = np.float32([pick_pos[0], pick_pos[1], 0.2]) 330 | if pick_pos.shape[-1] == 2: 331 | pick_xyz = np.append(pick_pos, 0.025) 332 | else: 333 | pick_xyz = pick_pos 334 | pick_xyz[2] = 0.025 335 | if place_pos.shape[-1] == 2: 336 | place_xyz = np.append(place_pos, 0.15) 337 | else: 338 | place_xyz = place_pos 339 | place_xyz[2] = 0.15 340 | 341 | # Move to object. 342 | ee_xyz = self.get_ee_pos() 343 | while np.linalg.norm(hover_xyz - ee_xyz) > 0.01: 344 | self.movep(hover_xyz) 345 | self.step_sim_and_render() 346 | ee_xyz = self.get_ee_pos() 347 | 348 | while np.linalg.norm(pick_xyz - ee_xyz) > 0.01: 349 | self.movep(pick_xyz) 350 | self.step_sim_and_render() 351 | ee_xyz = self.get_ee_pos() 352 | 353 | # Pick up object. 354 | self.gripper.activate() 355 | for _ in range(240): 356 | self.step_sim_and_render() 357 | while np.linalg.norm(hover_xyz - ee_xyz) > 0.01: 358 | self.movep(hover_xyz) 359 | self.step_sim_and_render() 360 | ee_xyz = self.get_ee_pos() 361 | 362 | for _ in range(50): 363 | self.step_sim_and_render() 364 | 365 | # Move to place location. 366 | while np.linalg.norm(place_xyz - ee_xyz) > 0.01: 367 | self.movep(place_xyz) 368 | self.step_sim_and_render() 369 | ee_xyz = self.get_ee_pos() 370 | 371 | # Place down object. 
372 | while (not self.gripper.detect_contact()) and (place_xyz[2] > 0.03): 373 | place_xyz[2] -= 0.001 374 | self.movep(place_xyz) 375 | for _ in range(3): 376 | self.step_sim_and_render() 377 | self.gripper.release() 378 | for _ in range(240): 379 | self.step_sim_and_render() 380 | place_xyz[2] = 0.2 381 | ee_xyz = self.get_ee_pos() 382 | while np.linalg.norm(place_xyz - ee_xyz) > 0.01: 383 | self.movep(place_xyz) 384 | self.step_sim_and_render() 385 | ee_xyz = self.get_ee_pos() 386 | place_xyz = np.float32([0, -0.5, 0.2]) 387 | while np.linalg.norm(place_xyz - ee_xyz) > 0.01: 388 | self.movep(place_xyz) 389 | self.step_sim_and_render() 390 | ee_xyz = self.get_ee_pos() 391 | 392 | observation = self.get_observation() 393 | reward = self.get_reward() 394 | done = False 395 | info = {} 396 | return observation, reward, done, info 397 | 398 | def set_alpha_transparency(self, alpha: float) -> None: 399 | for id in range(20): 400 | visual_shape_data = pybullet.getVisualShapeData(id) 401 | for i in range(len(visual_shape_data)): 402 | object_id, link_index, _, _, _, _, _, rgba_color = visual_shape_data[i] 403 | rgba_color = list(rgba_color[0:3]) + [alpha] 404 | pybullet.changeVisualShape( 405 | self.robot_id, linkIndex=i, rgbaColor=rgba_color 406 | ) 407 | pybullet.changeVisualShape( 408 | self.gripper.body, linkIndex=i, rgbaColor=rgba_color 409 | ) 410 | 411 | def step_sim_and_render(self): 412 | pybullet.stepSimulation() 413 | self.sim_step += 1 414 | 415 | interval = 40 if self.high_frame_rate else 60 416 | # Render current image at 8 FPS. 417 | if self.sim_step % interval == 0 and self.render: 418 | self.cache_video.append(self.get_camera_image()) 419 | 420 | def get_camera_image(self): 421 | if not self.high_res: 422 | image_size = (240, 240) 423 | intrinsics = (120.0, 0, 120.0, 0, 120.0, 120.0, 0, 0, 1) 424 | else: 425 | image_size = (360, 360) 426 | intrinsics = (180.0, 0, 180.0, 0, 180.0, 180.0, 0, 0, 1) 427 | # color, _, _, _, _ = env.render_image(image_size, intrinsics) 428 | color, _, _, _, _ = self.render_image(image_size, intrinsics) # why env? 429 | return color 430 | 431 | def get_reward(self): 432 | return None 433 | 434 | def get_observation(self): 435 | observation = {} 436 | 437 | # Render current image. 438 | color, depth, position, orientation, intrinsics = self.render_image() 439 | 440 | # Get heightmaps and colormaps. 441 | points = self.get_pointcloud(depth, intrinsics) 442 | position = np.float32(position).reshape(3, 1) 443 | rotation = pybullet.getMatrixFromQuaternion(orientation) 444 | rotation = np.float32(rotation).reshape(3, 3) 445 | transform = np.eye(4) 446 | transform[:3, :] = np.hstack((rotation, position)) 447 | points = self.transform_pointcloud(points, transform) 448 | heightmap, colormap, xyzmap = self.get_heightmap( 449 | points, color, BOUNDS, PIXEL_SIZE 450 | ) 451 | 452 | observation["image"] = colormap 453 | observation["xyzmap"] = xyzmap 454 | 455 | return observation 456 | 457 | def render_image( 458 | self, 459 | image_size=(720, 720), 460 | intrinsics=(360.0, 0, 360.0, 0, 360.0, 360.0, 0, 0, 1), 461 | ): 462 | 463 | # Camera parameters. 464 | position = (0, -0.85, 0.4) 465 | orientation = (np.pi / 4 + np.pi / 48, np.pi, np.pi) 466 | orientation = pybullet.getQuaternionFromEuler(orientation) 467 | zrange = (0.01, 10.0) 468 | noise = True 469 | 470 | # OpenGL camera settings. 
471 | lookdir = np.float32([0, 0, 1]).reshape(3, 1) 472 | updir = np.float32([0, -1, 0]).reshape(3, 1) 473 | rotation = pybullet.getMatrixFromQuaternion(orientation) 474 | rotm = np.float32(rotation).reshape(3, 3) 475 | lookdir = (rotm @ lookdir).reshape(-1) 476 | updir = (rotm @ updir).reshape(-1) 477 | lookat = position + lookdir 478 | focal_len = intrinsics[0] 479 | znear, zfar = (0.01, 10.0) 480 | viewm = pybullet.computeViewMatrix(position, lookat, updir) 481 | fovh = (image_size[0] / 2) / focal_len 482 | fovh = 180 * np.arctan(fovh) * 2 / np.pi 483 | 484 | # Notes: 1) FOV is vertical FOV 2) aspect must be float 485 | aspect_ratio = image_size[1] / image_size[0] 486 | projm = pybullet.computeProjectionMatrixFOV(fovh, aspect_ratio, znear, zfar) 487 | 488 | # Render with OpenGL camera settings. 489 | _, _, color, depth, segm = pybullet.getCameraImage( 490 | width=image_size[1], 491 | height=image_size[0], 492 | viewMatrix=viewm, 493 | projectionMatrix=projm, 494 | shadow=1, 495 | flags=pybullet.ER_SEGMENTATION_MASK_OBJECT_AND_LINKINDEX, 496 | renderer=pybullet.ER_BULLET_HARDWARE_OPENGL, 497 | ) 498 | 499 | # Get color image. 500 | color_image_size = (image_size[0], image_size[1], 4) 501 | color = np.array(color, dtype=np.uint8).reshape(color_image_size) 502 | color = color[:, :, :3] # remove alpha channel 503 | if noise: 504 | color = np.int32(color) 505 | color += np.int32(np.random.normal(0, 3, color.shape)) 506 | color = np.uint8(np.clip(color, 0, 255)) 507 | 508 | # Get depth image. 509 | depth_image_size = (image_size[0], image_size[1]) 510 | zbuffer = np.float32(depth).reshape(depth_image_size) 511 | depth = zfar + znear - (2 * zbuffer - 1) * (zfar - znear) 512 | depth = (2 * znear * zfar) / depth 513 | if noise: 514 | depth += np.random.normal(0, 0.003, depth.shape) 515 | 516 | intrinsics = np.float32(intrinsics).reshape(3, 3) 517 | return color, depth, position, orientation, intrinsics 518 | 519 | def get_pointcloud(self, depth, intrinsics): 520 | """Get 3D pointcloud from perspective depth image. 521 | Args: 522 | depth: HxW float array of perspective depth in meters. 523 | intrinsics: 3x3 float array of camera intrinsics matrix. 524 | Returns: 525 | points: HxWx3 float array of 3D points in camera coordinates. 526 | """ 527 | height, width = depth.shape 528 | xlin = np.linspace(0, width - 1, width) 529 | ylin = np.linspace(0, height - 1, height) 530 | px, py = np.meshgrid(xlin, ylin) 531 | px = (px - intrinsics[0, 2]) * (depth / intrinsics[0, 0]) 532 | py = (py - intrinsics[1, 2]) * (depth / intrinsics[1, 1]) 533 | points = np.float32([px, py, depth]).transpose(1, 2, 0) 534 | return points 535 | 536 | def transform_pointcloud(self, points, transform): 537 | """Apply rigid transformation to 3D pointcloud. 538 | Args: 539 | points: HxWx3 float array of 3D points in camera coordinates. 540 | transform: 4x4 float array representing a rigid transformation matrix. 541 | Returns: 542 | points: HxWx3 float array of transformed 3D points. 543 | """ 544 | padding = ((0, 0), (0, 0), (0, 1)) 545 | homogen_points = np.pad(points.copy(), padding, "constant", constant_values=1) 546 | for i in range(3): 547 | points[Ellipsis, i] = np.sum(transform[i, :] * homogen_points, axis=-1) 548 | return points 549 | 550 | def get_heightmap(self, points, colors, bounds, pixel_size): 551 | """Get top-down (z-axis) orthographic heightmap image from 3D pointcloud. 552 | Args: 553 | points: HxWx3 float array of 3D points in world coordinates. 
554 | colors: HxWx3 uint8 array of values in range 0-255 aligned with points. 555 | bounds: 3x2 float array of values (rows: X,Y,Z; columns: min,max) defining 556 | region in 3D space to generate heightmap in world coordinates. 557 | pixel_size: float defining size of each pixel in meters. 558 | Returns: 559 | heightmap: HxW float array of height (from lower z-bound) in meters. 560 | colormap: HxWx3 uint8 array of backprojected color aligned with heightmap. 561 | xyzmap: HxWx3 float array of XYZ points in world coordinates. 562 | """ 563 | width = int(np.round((bounds[0, 1] - bounds[0, 0]) / pixel_size)) 564 | height = int(np.round((bounds[1, 1] - bounds[1, 0]) / pixel_size)) 565 | heightmap = np.zeros((height, width), dtype=np.float32) 566 | colormap = np.zeros((height, width, colors.shape[-1]), dtype=np.uint8) 567 | xyzmap = np.zeros((height, width, 3), dtype=np.float32) 568 | 569 | # Filter out 3D points that are outside of the predefined bounds. 570 | ix = (points[Ellipsis, 0] >= bounds[0, 0]) & ( 571 | points[Ellipsis, 0] < bounds[0, 1] 572 | ) 573 | iy = (points[Ellipsis, 1] >= bounds[1, 0]) & ( 574 | points[Ellipsis, 1] < bounds[1, 1] 575 | ) 576 | iz = (points[Ellipsis, 2] >= bounds[2, 0]) & ( 577 | points[Ellipsis, 2] < bounds[2, 1] 578 | ) 579 | valid = ix & iy & iz 580 | points = points[valid] 581 | colors = colors[valid] 582 | 583 | # Sort 3D points by z-value, which works with array assignment to simulate 584 | # z-buffering for rendering the heightmap image. 585 | iz = np.argsort(points[:, -1]) 586 | points, colors = points[iz], colors[iz] 587 | px = np.int32(np.floor((points[:, 0] - bounds[0, 0]) / pixel_size)) 588 | py = np.int32(np.floor((points[:, 1] - bounds[1, 0]) / pixel_size)) 589 | px = np.clip(px, 0, width - 1) 590 | py = np.clip(py, 0, height - 1) 591 | heightmap[py, px] = points[:, 2] - bounds[2, 0] 592 | for c in range(colors.shape[-1]): 593 | colormap[py, px, c] = colors[:, c] 594 | xyzmap[py, px, c] = points[:, c] 595 | colormap = colormap[::-1, :, :] # Flip up-down. 596 | xv, yv = np.meshgrid( 597 | np.linspace(BOUNDS[0, 0], BOUNDS[0, 1], height), 598 | np.linspace(BOUNDS[1, 0], BOUNDS[1, 1], width), 599 | ) 600 | xyzmap[:, :, 0] = xv 601 | xyzmap[:, :, 1] = yv 602 | xyzmap = xyzmap[::-1, :, :] # Flip up-down. 603 | heightmap = heightmap[::-1, :] # Flip up-down. 
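        # Summary of the back-projection above (a sketch, not part of the original
        # comments): points are binned into pixels with
        # px = floor((x - x_min) / pixel_size); because the points were sorted by
        # ascending z before the fancy-indexed assignment, the highest point in each
        # bin is written last and wins, acting as a cheap z-buffer. Illustration:
        #     heightmap[[5, 5], [7, 7]] = [0.02, 0.10]   # cell (5, 7) keeps 0.10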
604 | return heightmap, colormap, xyzmap 605 | 606 | def on_top_of(self, obj_a, obj_b): 607 | """ 608 | check if obj_a is on top of obj_b 609 | condition 1: l2 distance on xy plane is less than a threshold 610 | condition 2: obj_a is higher than obj_b 611 | """ 612 | obj_a_pos = self.get_obj_pos(obj_a) 613 | obj_b_pos = self.get_obj_pos(obj_b) 614 | xy_dist = np.linalg.norm(obj_a_pos[:2] - obj_b_pos[:2]) 615 | if obj_b in CORNER_POS: 616 | is_near = xy_dist < 0.06 617 | return is_near 618 | elif "bowl" in obj_b: 619 | is_near = xy_dist < 0.06 620 | is_higher = obj_a_pos[2] > obj_b_pos[2] 621 | return is_near and is_higher 622 | else: 623 | is_near = xy_dist < 0.04 624 | is_higher = obj_a_pos[2] > obj_b_pos[2] 625 | return is_near and is_higher 626 | 627 | def get_obj_id(self, obj_name): 628 | try: 629 | if obj_name in self.obj_name_to_id: 630 | obj_id = self.obj_name_to_id[obj_name] 631 | else: 632 | obj_name = ( 633 | obj_name.replace("circle", "bowl") 634 | .replace("square", "block") 635 | .replace("small", "") 636 | .strip() 637 | ) 638 | obj_id = self.obj_name_to_id[obj_name] 639 | except: 640 | print(f'requested_name="{obj_name}"') 641 | print(f'available_objects_and_id="{self.obj_name_to_id}') 642 | return obj_id 643 | 644 | def get_obj_pos(self, obj_name): 645 | obj_name = obj_name.replace("the", "").replace("_", " ").strip() 646 | if obj_name in CORNER_POS: 647 | position = np.float32(np.array(CORNER_POS[obj_name])) 648 | else: 649 | pick_id = self.get_obj_id(obj_name) 650 | pose = pybullet.getBasePositionAndOrientation(pick_id) 651 | position = np.float32(pose[0]) 652 | return position 653 | 654 | def get_bounding_box(self, obj_name): 655 | obj_id = self.get_obj_id(obj_name) 656 | return pybullet.getAABB(obj_id) 657 | -------------------------------------------------------------------------------- /VLM_CaP/src/grippers.py: -------------------------------------------------------------------------------- 1 | # Gripper (Robotiq 2F85) code 2 | import os 3 | from time import sleep 4 | import numpy as np 5 | import pybullet 6 | import threading 7 | 8 | class Robotiq2F85: 9 | """Gripper handling for Robotiq 2F85.""" 10 | 11 | def __init__(self, robot, tool): 12 | self.robot = robot 13 | self.tool = tool 14 | pos = [0.1339999999999999, -0.49199999999872496, 0.5] 15 | rot = pybullet.getQuaternionFromEuler([np.pi, 0, np.pi]) 16 | urdf = 'assets/robotiq_2f_85/robotiq_2f_85.urdf' 17 | self.body = pybullet.loadURDF(urdf, pos, rot) 18 | self.n_joints = pybullet.getNumJoints(self.body) 19 | self.activated = False 20 | 21 | # Connect gripper base to robot tool. 22 | pybullet.createConstraint(self.robot, tool, self.body, 0, jointType=pybullet.JOINT_FIXED, jointAxis=[0, 0, 0], parentFramePosition=[0, 0, 0], childFramePosition=[0, 0, -0.07], childFrameOrientation=pybullet.getQuaternionFromEuler([0, 0, np.pi / 2])) 23 | 24 | # Set friction coefficients for gripper fingers. 25 | for i in range(pybullet.getNumJoints(self.body)): 26 | pybullet.changeDynamics(self.body, i, lateralFriction=10.0, spinningFriction=1.0, rollingFriction=1.0, frictionAnchor=True) 27 | 28 | # Start thread to handle additional gripper constraints. 29 | self.motor_joint = 1 30 | self.constraints_thread = threading.Thread(target=self.step) 31 | self.constraints_thread.daemon = True 32 | self.constraints_thread.start() 33 | 34 | # Control joint positions by enforcing hard contraints on gripper behavior. 35 | # Set one joint as the open/close motor joint (other joints should mimic). 
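    # The 2F-85 is under-actuated: only joint 1 (the motor joint) is commanded
    # directly, and step() below runs in the daemon thread started above,
    # continuously re-targeting the remaining finger joints so they mirror it,
    # i.e. joints [6, 3, 8, 5, 10] track [q, -q, -q, q, q] for motor angle q.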
36 | def step(self): 37 | while True: 38 | try: 39 | currj = [pybullet.getJointState(self.body, i)[0] for i in range(self.n_joints)] 40 | indj = [6, 3, 8, 5, 10] 41 | targj = [currj[1], -currj[1], -currj[1], currj[1], currj[1]] 42 | pybullet.setJointMotorControlArray(self.body, indj, pybullet.POSITION_CONTROL, targj, positionGains=np.ones(5)) 43 | except: 44 | return 45 | sleep(0.001) 46 | 47 | # Close gripper fingers. 48 | def activate(self): 49 | pybullet.setJointMotorControl2(self.body, self.motor_joint, pybullet.VELOCITY_CONTROL, targetVelocity=1, force=10) 50 | self.activated = True 51 | 52 | # Open gripper fingers. 53 | def release(self): 54 | pybullet.setJointMotorControl2(self.body, self.motor_joint, pybullet.VELOCITY_CONTROL, targetVelocity=-1, force=10) 55 | self.activated = False 56 | 57 | # If activated and object in gripper: check object contact. 58 | # If activated and nothing in gripper: check gripper contact. 59 | # If released: check proximity to surface (disabled). 60 | def detect_contact(self): 61 | obj, _, ray_frac = self.check_proximity() 62 | if self.activated: 63 | empty = self.grasp_width() < 0.01 64 | cbody = self.body if empty else obj 65 | if obj == self.body or obj == 0: 66 | return False 67 | return self.external_contact(cbody) 68 | # else: 69 | # return ray_frac < 0.14 or self.external_contact() 70 | 71 | # Return if body is in contact with something other than gripper 72 | def external_contact(self, body=None): 73 | if body is None: 74 | body = self.body 75 | pts = pybullet.getContactPoints(bodyA=body) 76 | pts = [pt for pt in pts if pt[2] != self.body] 77 | return len(pts) > 0 # pylint: disable=g-explicit-length-test 78 | 79 | def check_grasp(self): 80 | while self.moving(): 81 | sleep(0.001) 82 | success = self.grasp_width() > 0.01 83 | return success 84 | 85 | def grasp_width(self): 86 | lpad = np.array(pybullet.getLinkState(self.body, 4)[0]) 87 | rpad = np.array(pybullet.getLinkState(self.body, 9)[0]) 88 | dist = np.linalg.norm(lpad - rpad) - 0.047813 89 | return dist 90 | 91 | def check_proximity(self): 92 | ee_pos = np.array(pybullet.getLinkState(self.robot, self.tool)[0]) 93 | tool_pos = np.array(pybullet.getLinkState(self.body, 0)[0]) 94 | vec = (tool_pos - ee_pos) / np.linalg.norm((tool_pos - ee_pos)) 95 | ee_targ = ee_pos + vec 96 | ray_data = pybullet.rayTest(ee_pos, ee_targ)[0] 97 | obj, link, ray_frac = ray_data[0], ray_data[1], ray_data[2] 98 | return obj, link, ray_frac -------------------------------------------------------------------------------- /VLM_CaP/src/key.py: -------------------------------------------------------------------------------- 1 | mykey = 'YOUR_API_KEY' 2 | projectkey = 'YOUR_API_KEY' -------------------------------------------------------------------------------- /VLM_CaP/src/prompts.py: -------------------------------------------------------------------------------- 1 | # prompts 2 | prompt_tabletop_ui = ''' 3 | # Python 2D robot control script 4 | import numpy as np 5 | from env_utils import put_first_on_second, get_obj_pos, get_obj_names, say, get_corner_name, get_side_name, is_obj_visible, stack_objects_in_order 6 | from plan_utils import parse_obj_name, parse_position, parse_question, transform_shape_pts 7 | 8 | objects = ['yellow block', 'green block', 'yellow bowl', 'blue block', 'blue bowl', 'green bowl', 'wooden block1', 'wooden block2'] 9 | # place the yellow block on the red block. 
10 | say('Ok - putting the yellow block on the red block') 11 | put_first_on_second('yellow block', 'red block') 12 | objects = ['yellow block', 'green block', 'yellow bowl', 'blue block', 'blue bowl', 'green bowl'] 13 | # place the wooden block1 on top of wooden block2 14 | wooden_block2_pos = get_obj_pos('wooden block2') 15 | put_first_on_second('wooden block1', wooden_block2_pos) 16 | # place the wooden block1 to the left of wooden block2 17 | wooden_block2_pos = get_obj_pos('wooden block2') 18 | wooden_block1_pos = np.array([wooden_block1_pos[0] - 0.15, wooden_block2_pos[1]]) 19 | put_first_on_second('wooden block1', wooden_block1_pos) 20 | # place the yellow block to the right of red block 21 | red_block_pos = get_obj_pos('red block') 22 | yellow_block_pos = np.array([red_block_pos[0] + 0.15, red_block_pos[1]]) 23 | put_first_on_second('yellow block', yellow_block_pos) 24 | # which block did you move. 25 | say('I moved the yellow block') 26 | objects = ['yellow block', 'green block', 'yellow bowl', 'blue block', 'blue bowl', 'green bowl'] 27 | # move the green block to the top right corner. 28 | say('Got it - putting the green block on the top right corner') 29 | corner_pos = parse_position('top right corner') 30 | put_first_on_second('green block', corner_pos) 31 | objects = ['yellow block', 'green block', 'yellow bowl', 'blue block', 'blue bowl', 'green bowl'] 32 | # stack the blue bowl on the yellow bowl on the green block. 33 | order_bottom_to_top = ['green block', 'yellow block', 'blue bowl'] 34 | say(f'Sure - stacking from top to bottom: {", ".join(order_bottom_to_top)}') 35 | stack_objects_in_order(object_names=order_bottom_to_top) 36 | objects = ['cyan block', 'white block', 'cyan bowl', 'blue block', 'blue bowl', 'white bowl'] 37 | # move the cyan block into its corresponding bowl. 38 | matches = {'cyan block': 'cyan bowl'} 39 | say('Got it - placing the cyan block on the cyan bowl') 40 | for first, second in matches.items(): 41 | put_first_on_second(first, get_obj_pos(second)) 42 | objects = ['cyan block', 'white block', 'cyan bowl', 'blue block', 'blue bowl', 'white bowl'] 43 | # place the green block to the right of the bowl that has the blue block. 44 | bowl_name = parse_obj_name('the bowl that has the blue block', f'objects = {get_obj_names()}') 45 | if bowl_name: 46 | target_pos = parse_position(f'a point 10cm to the right of the {bowl_name}') 47 | say(f'No problem - placing the green block to the right of the {bowl_name}') 48 | put_first_on_second('green block', target_pos) 49 | else: 50 | say('There are no bowls that has the blue block') 51 | objects = ['brown bowl', 'green block', 'brown block', 'green bowl', 'blue bowl', 'blue block'] 52 | # is the blue block to the right of the yellow bowl? 53 | if parse_question('is the blue block to the right of the yellow bowl?', f'objects = {get_obj_names()}'): 54 | say('yes, there is a blue block to the right of the yellow bow') 55 | else: 56 | say('no, there is\'t a blue block to the right of the yellow bow') 57 | objects = ['yellow bowl', 'blue block', 'yellow block', 'blue bowl'] 58 | # how many yellow objects are there? 59 | n_yellow_objs = parse_question('how many yellow objects are there', f'objects = {get_obj_names()}') 60 | say(f'there are {n_yellow_objs} yellow object') 61 | objects = ['pink block', 'green block', 'pink bowl', 'blue block', 'blue bowl', 'green bowl'] 62 | # move the left most block to the green bowl. 
63 | left_block_name = parse_obj_name('left most block', f'objects = {get_obj_names()}') 64 | say(f'Moving the {left_block_name} on the green bowl') 65 | put_first_on_second(left_block_name, 'green bowl') 66 | objects = ['pink block', 'green block', 'pink bowl', 'blue block', 'blue bowl', 'green bowl'] 67 | # move the other blocks to different corners. 68 | block_names = parse_obj_name(f'blocks other than the {left_block_name}', f'objects = {get_obj_names()}') 69 | corners = parse_position('the corners') 70 | say(f'Ok - moving the other {len(block_names)} blocks to different corners') 71 | for block_name, pos in zip(block_names, corners): 72 | put_first_on_second(block_name, pos) 73 | objects = ['pink block', 'green block', 'pink bowl', 'blue block', 'blue bowl', 'green bowl'] 74 | # is the pink block on the green bowl. 75 | if parse_question('is the pink block on the green bowl', f'objects = {get_obj_names()}'): 76 | say('Yes - the pink block is on the green bowl.') 77 | else: 78 | say('No - the pink block is not on the green bowl.') 79 | objects = ['pink block', 'green block', 'pink bowl', 'blue block', 'blue bowl', 'green bowl'] 80 | # what are the blocks left of the green bowl. 81 | left_block_names = parse_question('what are the blocks left of the green bowl', f'objects = {get_obj_names()}') 82 | if len(left_block_names) > 0: 83 | say(f'These blocks are left of the green bowl: {", ".join(left_block_names)}') 84 | else: 85 | say('There are no blocks left of the green bowl') 86 | objects = ['yellow block', 'green block', 'yellow bowl', 'blue block', 'blue bowl', 'green bowl'] 87 | # imagine that the bowls are different biomes on earth and imagine that the blocks are parts of a building. 88 | say('ok') 89 | objects = ['yellow block', 'green block', 'yellow bowl', 'blue block', 'blue bowl', 'green bowl'] 90 | # now build a tower in the grasslands. 91 | order_bottom_to_top = ['green bowl', 'blue block', 'green block', 'yellow block'] 92 | say('stacking the blocks on the green bowl') 93 | stack_objects_in_order(object_names=order_bottom_to_top) 94 | objects = ['yellow block', 'green block', 'yellow bowl', 'gray block', 'gray bowl', 'green bowl'] 95 | # show me what happens when the desert gets flooded by the ocean. 96 | say('putting the yellow bowl on the blue bowl') 97 | put_first_on_second('yellow bowl', 'blue bowl') 98 | objects = ['pink block', 'gray block', 'orange block'] 99 | # move all blocks 5cm toward the top. 100 | say('Ok - moving all blocks 5cm toward the top') 101 | block_names = parse_obj_name('the blocks', f'objects = {get_obj_names()}') 102 | for block_name in block_names: 103 | target_pos = parse_position(f'a point 5cm above the {block_name}') 104 | put_first_on_second(block_name, target_pos) 105 | objects = ['cyan block', 'white block', 'purple bowl', 'blue block', 'blue bowl', 'white bowl'] 106 | # make a triangle of blocks in the middle. 107 | block_names = parse_obj_name('the blocks', f'objects = {get_obj_names()}') 108 | triangle_pts = parse_position(f'a triangle with size 10cm around the middle with {len(block_names)} points') 109 | say('Making a triangle of blocks around the middle of the workspace') 110 | for block_name, pt in zip(block_names, triangle_pts): 111 | put_first_on_second(block_name, pt) 112 | objects = ['cyan block', 'white block', 'purple bowl', 'blue block', 'blue bowl', 'white bowl'] 113 | # make the triangle smaller. 
114 | triangle_pts = transform_shape_pts('scale it by 0.5x', shape_pts=triangle_pts) 115 | say('Making the triangle smaller') 116 | block_names = parse_obj_name('the blocks', f'objects = {get_obj_names()}') 117 | for block_name, pt in zip(block_names, triangle_pts): 118 | put_first_on_second(block_name, pt) 119 | objects = ['brown bowl', 'red block', 'brown block', 'red bowl', 'pink bowl', 'pink block'] 120 | # put the red block on the farthest bowl. 121 | farthest_bowl_name = parse_obj_name('the bowl farthest from the red block', f'objects = {get_obj_names()}') 122 | say(f'Putting the red block on the {farthest_bowl_name}') 123 | put_first_on_second('red block', farthest_bowl_name) 124 | '''.strip() 125 | 126 | 127 | prompt_parse_obj_name = ''' 128 | import numpy as np 129 | from env_utils import get_obj_pos, parse_position 130 | from utils import get_obj_positions_np 131 | 132 | objects = ['blue block', 'cyan block', 'purple bowl', 'gray bowl', 'brown bowl', 'pink block', 'purple block'] 133 | # the block closest to the purple bowl. 134 | block_names = ['blue block', 'cyan block', 'purple block'] 135 | block_positions = get_obj_positions_np(block_names) 136 | closest_block_idx = get_closest_idx(points=block_positions, point=get_obj_pos('purple bowl')) 137 | closest_block_name = block_names[closest_block_idx] 138 | ret_val = closest_block_name 139 | objects = ['brown bowl', 'banana', 'brown block', 'apple', 'blue bowl', 'blue block'] 140 | # the blocks. 141 | ret_val = ['brown block', 'blue block'] 142 | objects = ['brown bowl', 'banana', 'brown block', 'apple', 'blue bowl', 'blue block'] 143 | # the brown objects. 144 | ret_val = ['brown bowl', 'brown block'] 145 | objects = ['brown bowl', 'banana', 'brown block', 'apple', 'blue bowl', 'blue block'] 146 | # a fruit that's not the apple 147 | fruit_names = ['banana', 'apple'] 148 | for fruit_name in fruit_names: 149 | if fruit_name != 'apple': 150 | ret_val = fruit_name 151 | objects = ['blue block', 'cyan block', 'purple bowl', 'brown bowl', 'purple block'] 152 | # blocks above the brown bowl. 153 | block_names = ['blue block', 'cyan block', 'purple block'] 154 | brown_bowl_pos = get_obj_pos('brown bowl') 155 | use_block_names = [] 156 | for block_name in block_names: 157 | if get_obj_pos(block_name)[1] > brown_bowl_pos[1]: 158 | use_block_names.append(block_name) 159 | ret_val = use_block_names 160 | objects = ['blue block', 'cyan block', 'purple bowl', 'brown bowl', 'purple block'] 161 | # the blue block. 162 | ret_val = 'blue block' 163 | objects = ['blue block', 'cyan block', 'purple bowl', 'brown bowl', 'purple block'] 164 | # the block closest to the bottom right corner. 165 | corner_pos = parse_position('bottom right corner') 166 | block_names = ['blue block', 'cyan block', 'purple block'] 167 | block_positions = get_obj_positions_np(block_names) 168 | closest_block_idx = get_closest_idx(points=block_positions, point=corner_pos) 169 | closest_block_name = block_names[closest_block_idx] 170 | ret_val = closest_block_name 171 | objects = ['brown bowl', 'green block', 'brown block', 'green bowl', 'blue bowl', 'blue block'] 172 | # the left most block. 
173 | block_names = ['green block', 'brown block', 'blue block'] 174 | block_positions = get_obj_positions_np(block_names) 175 | left_block_idx = np.argsort(block_positions[:, 0])[0] 176 | left_block_name = block_names[left_block_idx] 177 | ret_val = left_block_name 178 | objects = ['brown bowl', 'green block', 'brown block', 'green bowl', 'blue bowl', 'blue block'] 179 | # the bowl on near the top. 180 | bowl_names = ['brown bowl', 'green bowl', 'blue bowl'] 181 | bowl_positions = get_obj_positions_np(bowl_names) 182 | top_bowl_idx = np.argsort(bowl_positions[:, 1])[-1] 183 | top_bowl_name = bowl_names[top_bowl_idx] 184 | ret_val = top_bowl_name 185 | objects = ['yellow bowl', 'purple block', 'yellow block', 'purple bowl', 'pink bowl', 'pink block'] 186 | # the third bowl from the right. 187 | bowl_names = ['yellow bowl', 'purple bowl', 'pink bowl'] 188 | bowl_positions = get_obj_positions_np(bowl_names) 189 | bowl_idx = np.argsort(bowl_positions[:, 0])[-3] 190 | bowl_name = bowl_names[bowl_idx] 191 | ret_val = bowl_name 192 | '''.strip() 193 | 194 | 195 | prompt_parse_position = ''' 196 | import numpy as np 197 | from shapely.geometry import * 198 | from shapely.affinity import * 199 | from env_utils import denormalize_xy, parse_obj_name, get_obj_names, get_obj_pos 200 | 201 | # the side farthest from the right most bowl. 202 | bowl_name = parse_obj_name('the right most bowl', f'objects = {get_obj_names()}') 203 | side_positions = np.array([denormalize_xy(pos) for pos in [[0.5, 0], [0.5, 1], [1, 0.5], [0, 0.5]]]) 204 | farthest_side_pos = get_farthest_point(points=side_positions, point=get_obj_pos(bowl_name)) 205 | ret_val = farthest_side_pos 206 | # a point above the third block from the bottom. 207 | block_name = parse_obj_name('the third block from the bottom', f'objects = {get_obj_names()}') 208 | ret_val = get_obj_pos(block_name) + [0.1, 0] 209 | # a point adjacent to the right of red block. 210 | block_name = parse_obj_name('red block', f'objects = {get_obj_names()}') 211 | block_pos = get_obj_pos(block_name) 212 | adjacent_pos = block_pos + [0.02, 0] 213 | ret_val = adjacent_pos 214 | # a point faraway to the right of red block. 215 | block_name = parse_obj_name('red block', f'objects = {get_obj_names()}') 216 | block_pos = get_obj_pos(block_name) 217 | adjacent_pos = block_pos + [0.1, 0] 218 | ret_val = adjacent_pos 219 | # a point adjacent above red block. 220 | block_name = parse_obj_name('red block', f'objects = {get_obj_names()}') 221 | block_pos = get_obj_pos(block_name) 222 | adjacent_pos = block_pos + [0, 0.02] 223 | ret_val = adjacent_pos 224 | # the bottom side. 225 | bottom_pos = denormalize_xy([0.5, 0]) 226 | ret_val = bottom_pos 227 | # the top corners. 228 | top_left_pos = denormalize_xy([0, 1]) 229 | top_right_pos = denormalize_xy([1, 1]) 230 | ret_val = [top_left_pos, top_right_pos] 231 | '''.strip() 232 | 233 | prompt_parse_question = ''' 234 | from utils import get_obj_pos, get_obj_names, parse_obj_name, bbox_contains_pt, is_obj_visible 235 | 236 | objects = ['yellow bowl', 'blue block', 'yellow block', 'blue bowl', 'fruit', 'green block', 'black bowl'] 237 | # is the blue block to the right of the yellow bowl? 238 | ret_val = get_obj_pos('blue block')[0] > get_obj_pos('yellow bowl')[0] 239 | objects = ['yellow bowl', 'blue block', 'yellow block', 'blue bowl', 'fruit', 'green block', 'black bowl'] 240 | # how many yellow objects are there? 
241 | yellow_object_names = parse_obj_name('the yellow objects', f'objects = {get_obj_names()}') 242 | ret_val = len(yellow_object_names) 243 | objects = ['pink block', 'green block', 'pink bowl', 'blue block', 'blue bowl', 'green bowl'] 244 | # is the pink block on the green bowl? 245 | ret_val = bbox_contains_pt(container_name='green bowl', obj_name='pink block') 246 | objects = ['pink block', 'green block', 'pink bowl', 'blue block', 'blue bowl', 'green bowl'] 247 | # what are the blocks left of the green bowl? 248 | block_names = parse_obj_name('the blocks', f'objects = {get_obj_names()}') 249 | green_bowl_pos = get_obj_pos('green bowl') 250 | left_block_names = [] 251 | for block_name in block_names: 252 | if get_obj_pos(block_name)[0] < green_bowl_pos[0]: 253 | left_block_names.append(block_name) 254 | ret_val = left_block_names 255 | objects = ['pink block', 'yellow block', 'pink bowl', 'blue block', 'blue bowl', 'yellow bowl'] 256 | # is the sun colored block above the blue bowl? 257 | sun_block_name = parse_obj_name('sun colored block', f'objects = {get_obj_names()}') 258 | sun_block_pos = get_obj_pos(sun_block_name) 259 | blue_bowl_pos = get_obj_pos('blue bowl') 260 | ret_val = sun_block_pos[1] > blue_bowl_pos[1] 261 | objects = ['pink block', 'yellow block', 'pink bowl', 'blue block', 'blue bowl', 'yellow bowl'] 262 | # is the green block below the blue bowl? 263 | ret_val = get_obj_pos('green block')[1] < get_obj_pos('blue bowl')[1] 264 | '''.strip() 265 | 266 | prompt_transform_shape_pts = ''' 267 | import numpy as np 268 | from utils import get_obj_pos, get_obj_names, parse_position, parse_obj_name 269 | 270 | # make it bigger by 1.5. 271 | new_shape_pts = scale_pts_around_centroid_np(shape_pts, scale_x=1.5, scale_y=1.5) 272 | # move it to the right by 10cm. 273 | new_shape_pts = translate_pts_np(shape_pts, delta=[0.1, 0]) 274 | # move it to the top by 20cm. 275 | new_shape_pts = translate_pts_np(shape_pts, delta=[0, 0.2]) 276 | # rotate it clockwise by 40 degrees. 277 | new_shape_pts = rotate_pts_around_centroid_np(shape_pts, angle=-np.deg2rad(40)) 278 | # rotate by 30 degrees and make it slightly smaller 279 | new_shape_pts = rotate_pts_around_centroid_np(shape_pts, angle=np.deg2rad(30)) 280 | new_shape_pts = scale_pts_around_centroid_np(new_shape_pts, scale_x=0.7, scale_y=0.7) 281 | # move it toward the blue block. 282 | block_name = parse_obj_name('the blue block', f'objects = {get_obj_names()}') 283 | block_pos = get_obj_pos(block_name) 284 | mean_delta = np.mean(block_pos - shape_pts, axis=1) 285 | new_shape_pts = translate_pts_np(shape_pts, mean_delta) 286 | '''.strip() 287 | 288 | prompt_fgen = ''' 289 | import numpy as np 290 | from shapely.geometry import * 291 | from shapely.affinity import * 292 | 293 | from env_utils import get_obj_pos, get_obj_names 294 | from ctrl_utils import put_first_on_second 295 | 296 | # define function: total = get_total(xs=numbers). 297 | def get_total(xs): 298 | return np.sum(xs) 299 | 300 | # define function: y = eval_line(x, slope, y_intercept=0). 301 | def eval_line(x, slope, y_intercept): 302 | return x * slope + y_intercept 303 | 304 | # define function: pt = get_pt_to_the_left(pt, dist). 305 | def get_pt_to_the_left(pt, dist): 306 | return pt + [-dist, 0] 307 | 308 | # define function: pt = get_pt_to_the_top(pt, dist). 309 | def get_pt_to_the_top(pt, dist): 310 | return pt + [0, dist] 311 | 312 | # define function line = make_line_by_length(length=x). 
313 | def make_line_by_length(length):
314 |     line = LineString([[0, 0], [length, 0]])
315 |     return line
316 | 
317 | # define function: line = make_vertical_line_by_length(length=x).
318 | def make_vertical_line_by_length(length):
319 |     line = make_line_by_length(length)
320 |     vertical_line = rotate(line, 90)
321 |     return vertical_line
322 | 
323 | # define function: pt = interpolate_line(line, t=0.5).
324 | def interpolate_line(line, t):
325 |     pt = line.interpolate(t, normalized=True)
326 |     return np.array(pt.coords[0])
327 | 
328 | # example: scale a line by 2.
329 | line = make_line_by_length(1)
330 | new_shape = scale(line, xfact=2, yfact=2)
331 | 
332 | # example: put object1 on top of object0.
333 | put_first_on_second('object1', 'object0')
334 | 
335 | # example: get the position of the first object.
336 | obj_names = get_obj_names()
337 | pos_2d = get_obj_pos(obj_names[0])
338 | '''.strip()
--------------------------------------------------------------------------------
/VLM_CaP/src/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 |     name="VLMTutor",
5 |     version="0.1",
6 |     packages=find_packages(),
7 |     install_requires=[
8 | 
9 |     ],
10 | )
11 | 
--------------------------------------------------------------------------------
/VLM_CaP/src/vlm_video.py:
--------------------------------------------------------------------------------
1 | from IPython.display import display, Image, Audio
2 | 
3 | import cv2  # We're using OpenCV to read video; to install: !pip install opencv-python
4 | import base64
5 | import time
6 | from openai import OpenAI
7 | import os
8 | import requests
9 | 
10 | def extract_frames(video_path):
11 |     # Use OpenCV to extract frames from the video file
12 |     video = cv2.VideoCapture(video_path)
13 | 
14 |     base64Frames = []
15 |     while video.isOpened():
16 |         success, frame = video.read()
17 |         if not success:
18 |             break
19 |         _, buffer = cv2.imencode(".jpg", frame)
20 |         base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
21 | 
22 |     video.release()
23 |     print(len(base64Frames), "frames read.")
24 | 
25 |     # Display the frames, for debugging
26 |     # display_handle = display(None, display_id=True)
27 |     # for img in base64Frames:
28 |     #     display_handle.update(Image(data=base64.b64decode(img.encode("utf-8"))))
29 |     #     time.sleep(0.025)  # Uncomment this line to play the frames back as an animation
30 | 
31 |     return base64Frames  # Return the list of base64-encoded frames
32 | 
33 | def extract_frame_list(frames_list):
34 |     """
35 |     Convert a list of image frames into a list of Base64-encoded strings.
36 | 
37 |     :param frames_list: List of image frames; each frame is a NumPy array (an image read by OpenCV).
38 |     :return: List of Base64-encoded strings for all frames.
39 |     """
40 |     base64Frames = []
41 | 
42 |     for frame in frames_list:
43 |         if frame is None:
44 |             continue
45 | 
46 |         # Encode the frame as JPEG
47 |         _, buffer = cv2.imencode(".jpg", frame)
48 | 
49 |         # Convert the JPEG-encoded frame to a Base64 string
50 |         base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
51 | 
52 |     print(len(base64Frames), "frames processed.")
53 | 
54 |     return base64Frames  # Return the list of Base64-encoded frames
--------------------------------------------------------------------------------
/convert_video.py:
--------------------------------------------------------------------------------
1 | import ffmpy
2 | import argparse
3 | 
4 | def convert_video_to_30fps(input_path, output_path):
5 |     """
6 |     Converts the input video file to H.264 encoded .mp4 format at 30 FPS using ffmpy.
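    Note: ffmpy is a thin wrapper that shells out to the ffmpeg binary, so ffmpeg
    must be installed and available on PATH. The flags below mean: -c:v libx264
    (H.264 video codec), -r 30 (force 30 FPS output), -crf 23 (constant-quality
    rate factor), -preset fast (encoding speed/compression trade-off).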
7 | """ 8 | ff = ffmpy.FFmpeg( 9 | inputs={input_path: None}, 10 | outputs={output_path: '-c:v libx264 -r 30 -crf 23 -preset fast'} 11 | ) 12 | ff.run() 13 | print(f"Video converted successfully to 30 FPS: {output_path}") 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser(description="Convert a video to 30 FPS with H.264 encoding.") 17 | parser.add_argument('--input', type=str, required=True, help="Path to the input video file.") 18 | parser.add_argument('--output', type=str, required=True, help="Path to the output video file.") 19 | 20 | args = parser.parse_args() 21 | 22 | # Convert the single video 23 | convert_video_to_30fps(args.input, args.output) 24 | 25 | if __name__ == '__main__': 26 | main() -------------------------------------------------------------------------------- /get_frame_by_hands.py: -------------------------------------------------------------------------------- 1 | import mediapipe as mp 2 | from mediapipe import solutions 3 | from mediapipe.framework.formats import landmark_pb2 4 | 5 | import numpy as np 6 | from scipy.ndimage import gaussian_filter 7 | from scipy.signal import find_peaks 8 | 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | import os 12 | import time 13 | from pathlib import Path 14 | from argparse import ArgumentParser 15 | import csv 16 | 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 18 | 19 | class FrameExtractor: 20 | def __init__(self, video_path, output_dir, gaussian_sigma=5, prominence=0.8, csv_file='new_wooden_block_selected_valleys.csv'): 21 | self._folder_init(video_path, output_dir) 22 | self._mediapipe_init() 23 | self._visualization_init() 24 | 25 | self.cap = cv2.VideoCapture(self.video_path) 26 | self.num_frame = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 27 | 28 | ''' 29 | Hardcoded possible handedness is a bad practice because we do not know how the Google API may change in the future. 30 | But because how hard it is to actually query category_module.Category and get all possible handedness, we will just hardcode it. 31 | ''' 32 | self.all_landmark_pos = {"Right": np.zeros((self.num_frame, 2)), "Left": np.zeros((self.num_frame, 2))} 33 | 34 | # signal processing parameters 35 | self.gaussian_sigma = gaussian_sigma 36 | self.prominence = prominence 37 | 38 | self.csv_file = csv_file 39 | 40 | def extract_frames(self): 41 | ''' 42 | Extract the frames from the video. 
43 | ''' 44 | self.analyze_video() 45 | 46 | self.all_possible_handedness = set(self.all_landmark_pos.keys()) # all possible handedness detected in the video 47 | 48 | self.all_speeds = {} 49 | 50 | for handedness in self.all_possible_handedness: 51 | print(f"Calculating the speed curve of the {handedness} hand.") 52 | self.all_speeds[handedness] = self.get_speed(self.all_landmark_pos[handedness]) 53 | print(f"Plotting the speed curve of the {handedness} hand.") 54 | self.plot_speed(self.all_speeds[handedness], handedness) 55 | print(f"Making the video of the {handedness} hand.") 56 | self.make_video(handedness) 57 | 58 | print("Deciding which hand to focus on.") 59 | self.handedness = "Right" 60 | print(f"The decided handedness is {self.handedness}.") 61 | 62 | print(f"Processing the speed curve of {self.handedness} hand.") 63 | smoothed_curve = self.process_speed_curve() 64 | 65 | print(f"Getting the peaks and valleys of the speed curve of {self.handedness} hand.") 66 | peaks, valleys = self.get_peaks_valleys(smoothed_curve) 67 | 68 | # Filter valleys based on the index difference 69 | selected_valleys = [] 70 | for i in range(len(valleys)): 71 | if i == 0 or (valleys[i] - selected_valleys[-1]) >= 15: 72 | selected_valleys.append(valleys[i]) 73 | 74 | print(f"Plotting and making videos with smoothed {self.handedness} hand speed curve.") 75 | self.plot_speed(speeds=smoothed_curve, 76 | handedness=f'Smoothed {self.handedness}', 77 | selected_frame=selected_valleys) 78 | 79 | self.make_video(f'Smoothed {self.handedness}') 80 | 81 | first_frame = self.get_frame(0) 82 | cv2.imwrite(f'{str(self.selected_folder)}/{0}.jpg', first_frame) 83 | for valley in selected_valleys: 84 | frame = self.get_frame(valley) 85 | cv2.imwrite(f'{str(self.selected_folder)}/{valley}.jpg', frame) 86 | print(f"The selected valley frames are: {selected_valleys}") 87 | 88 | for valley in valleys: 89 | frame = self.get_frame(valley) 90 | cv2.imwrite(f'{str(self.all_valleys_folder)}/{valley}.jpg', frame) 91 | print(f"All valley frames are: {valleys}") 92 | 93 | # Save selected valleys to a unified CSV 94 | self.save_selected_valleys_to_csv(selected_valleys) 95 | 96 | self.cap.release() 97 | print("All done!") 98 | 99 | def save_selected_valleys_to_csv(self, selected_valleys): 100 | ''' 101 | Save the selected valleys to a unified CSV file with the video name as the identifier. 102 | This ensures all videos append to the same CSV file. 103 | ''' 104 | with open(self.csv_file, mode='a', newline='') as file: 105 | writer = csv.writer(file) 106 | writer.writerow([self.video_name, selected_valleys]) 107 | print(f"Selected valleys for {self.video_name} saved to {self.csv_file}") 108 | 109 | def analyze_video(self): 110 | ''' 111 | Analyze the video frame by frame. 
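
        For each frame this runs the MediaPipe hand landmarker, stores the mean
        landmark position per detected hand in self.all_landmark_pos, and writes an
        annotated copy of the frame to the hand_images folder; frames where a hand
        is not detected (and any remaining zeros) end up as NaN.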
112 | ''' 113 | 114 | frame_counter = 0 115 | while self.cap.isOpened(): 116 | ret, frame = self.cap.read() 117 | if not ret: 118 | print(f"End of video at frame {frame_counter}") 119 | break 120 | 121 | # convert the BGR image to RGB because of OpenCV uses BGR while MediaPipe uses RGB 122 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 123 | 124 | with self.HandLandmarker.create_from_options(self.mp_hand_options) as landmarker: 125 | mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame_rgb) # convert the image to MediaPipe format 126 | mp_timestamp = int(round(time.time()*1000)) # get the current timestamp in milliseconds 127 | 128 | results = landmarker.detect_for_video(mp_image, mp_timestamp) # detect the hands in the frame 129 | 130 | # Draw the landmarks on the image and get the average x and y coordinates of the hand landmarks in this frame. 131 | annotated_image, avg_landmark_x, avg_landmark_y, all_handedness = self.process_frame_results(rgb_image=frame, 132 | detection_result=results) 133 | 134 | if len(all_handedness) == 0: 135 | # no hands detected in this frame 136 | self.all_landmark_pos['Right'][frame_counter] = None 137 | self.all_landmark_pos['Left'][frame_counter] = None 138 | else: 139 | for handedness in all_handedness: 140 | # collet both left and right hand landmarks 141 | self.all_landmark_pos[handedness][frame_counter] = (avg_landmark_x[handedness], avg_landmark_y[handedness]) 142 | 143 | # save the image to a folder, not as a video 144 | cv2.imwrite(f'{str(self.hand_images_folder)}/{frame_counter}.jpg', annotated_image) 145 | 146 | 147 | frame_counter += 1 148 | 149 | ''' 150 | Replace the 0 in the landmark_pos with np.nan 151 | When one hand is detected, the other hand will have 0 as the x and y coordinates. 152 | ''' 153 | for handedness in self.all_landmark_pos.keys(): 154 | self.all_landmark_pos[handedness] = np.where(self.all_landmark_pos[handedness] == 0, np.nan, self.all_landmark_pos[handedness]) 155 | 156 | def process_frame_results(self, rgb_image, detection_result): 157 | ''' 158 | Annotate the image with the hand landmarks and handedness. 159 | Also process the results to get the average x and y coordinates of the hand landmarks in this frame. 160 | 161 | Mostly copied from the MediaPipe example. Kinda messy to be honest. 162 | 163 | Input: 164 | rgb_image: np.array. The image to be annotated. 165 | detection_result: mp.tasks.vision.HandLandmarkerResult. The detection result. 166 | 167 | Return: 168 | annotated_image: np.array. The annotated image. 169 | avg_landmark_x: dict. The average x coordinate of the hand landmarks in this frame. Both left and right hand. 170 | avg_landmark_y: dict. The average y coordinate of the hand landmarks in this frame. Both left and right hand. 171 | all_handedness: set. All handedness detected in this frame. 172 | ''' 173 | 174 | hand_landmarks_list = detection_result.hand_landmarks 175 | handedness_list = detection_result.handedness 176 | annotated_image = np.copy(rgb_image) 177 | 178 | 179 | avg_landmark_x = {} # average x coordinate of the hand landmarks in this frame 180 | avg_landmark_y = {} # average y coordinate of the hand landmarks in this frame 181 | all_handedness = set() # all handedness detected in this frame 182 | 183 | # Loop through the detected hands to visualize. 184 | for idx in range(len(hand_landmarks_list)): 185 | hand_landmarks = hand_landmarks_list[idx] 186 | handedness = handedness_list[idx] 187 | 188 | # Draw the hand landmarks. 
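            # The Tasks API returns plain landmark objects, while the legacy drawing
            # utilities expect a landmark_pb2.NormalizedLandmarkList proto, hence the
            # conversion below before drawing.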
189 | hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList() # type: ignore 190 | hand_landmarks_proto.landmark.extend([ 191 | landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in hand_landmarks # type: ignore 192 | ]) 193 | solutions.drawing_utils.draw_landmarks( # type: ignore 194 | annotated_image, 195 | hand_landmarks_proto, 196 | solutions.hands.HAND_CONNECTIONS, # type: ignore 197 | solutions.drawing_styles.get_default_hand_landmarks_style(), # type: ignore 198 | solutions.drawing_styles.get_default_hand_connections_style()) # type: ignore 199 | 200 | # Get the top left corner of the detected hand's bounding box. 201 | height, width, _ = annotated_image.shape 202 | x_coordinates = [landmark.x for landmark in hand_landmarks] 203 | y_coordinates = [landmark.y for landmark in hand_landmarks] 204 | text_x = int(min(x_coordinates) * width) 205 | text_y = int(min(y_coordinates) * height) - self.MARGIN 206 | 207 | avg_landmark_x[handedness[0].category_name] = np.average(x_coordinates) * width 208 | avg_landmark_y[handedness[0].category_name] = np.average(y_coordinates) * height 209 | all_handedness.add(handedness[0].category_name) 210 | 211 | # Draw handedness (left or right hand) on the image. 212 | cv2.putText(annotated_image, f"{handedness[0].category_name}", 213 | (text_x, text_y), cv2.FONT_HERSHEY_DUPLEX, 214 | self.FONT_SIZE, self.HANDEDNESS_TEXT_COLOR, self.FONT_THICKNESS, cv2.LINE_AA) 215 | 216 | return annotated_image, avg_landmark_x, avg_landmark_y, all_handedness 217 | 218 | def decide_handedness(self): 219 | ''' 220 | Decide which hand is the one we are interested in. 221 | 222 | Return: 223 | str. The decided handedness. Either "Right" or "Left". 224 | ''' 225 | 226 | if len(self.all_possible_handedness) == 1: 227 | # only one hand detected in the video 228 | return list(self.all_possible_handedness)[0] 229 | if len(self.all_possible_handedness) == 0: 230 | # no hand detected in the video 231 | return None 232 | 233 | scoreboard = {} 234 | for handedness in self.all_possible_handedness: 235 | scoreboard[handedness] = 0 236 | 237 | # count which hand has the most non-zero landmarks 238 | max_nonzero_num = 0 239 | max_nonzero_handedness = None 240 | 241 | # count which hand has the most speed range 242 | max_speed_range = 0 243 | max_speed_range_handedness = None 244 | 245 | # count which hand has the most x and y range 246 | max_range = 0 247 | max_range_handedness = None 248 | 249 | for handedness in self.all_possible_handedness: 250 | # decide which hand has the most non-zero landmarks 251 | non_zero_count = np.count_nonzero(~np.isnan(self.all_landmark_pos[handedness])) 252 | if non_zero_count > max_nonzero_num: 253 | max_nonzero_num = non_zero_count 254 | max_nonzero_handedness = handedness 255 | 256 | # decide which hand has the most speed range 257 | speed_max = np.nanmax(self.all_speeds[handedness]) 258 | speed_min = np.nanmin(self.all_speeds[handedness]) 259 | speed_range = speed_max - speed_min 260 | if speed_range > max_speed_range: 261 | max_speed_range = speed_range 262 | max_speed_range_handedness = handedness 263 | 264 | pos_max = np.nanmax(self.all_landmark_pos[handedness]) 265 | pos_min = np.nanmin(self.all_landmark_pos[handedness]) 266 | pos_range = pos_max - pos_min 267 | if pos_range > max_range: 268 | max_range = pos_range 269 | max_range_handedness = handedness 270 | 271 | scoreboard[max_nonzero_handedness] += 1 272 | scoreboard[max_speed_range_handedness] += 1 273 | scoreboard[max_range_handedness] += 1 274 | 275 | 
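        # Each hand earns one vote per criterion (most detections, largest speed
        # range, largest position range). Example: if "Right" wins two criteria and
        # "Left" one, scoreboard == {'Right': 2, 'Left': 1} and "Right" is returned.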
# return the hand with the most votes 276 | return max(scoreboard, key=scoreboard.get) # type: ignore 277 | 278 | def process_speed_curve(self): 279 | ''' 280 | Process the speed curve of the hand so we can find peaks and valley more robustly. 281 | 282 | 1. Linearly interpolate the nan values in the speed curve. 283 | 2. Use Gaussian filter to smooth the speed curve. 284 | 285 | Return: 286 | np.array. The smoothed speed curve. 287 | ''' 288 | 289 | # linear interpolation 290 | y = self.all_speeds[self.handedness] 291 | nans, x_nans = np.isnan(y), lambda z: z.nonzero()[0] 292 | y_interpolated = y.copy() 293 | y_interpolated[nans] = np.interp(x_nans(nans), x_nans(~nans), y[~nans]) 294 | 295 | # Gaussian filter smoothing 296 | y_smoothed = gaussian_filter(y_interpolated, sigma=self.gaussian_sigma) 297 | 298 | return y_smoothed 299 | 300 | def get_peaks_valleys(self, smoothed_curve): 301 | ''' 302 | Get the peaks and valleys of the speed curve. 303 | 304 | Input: np.array. The smoothed speed curve. 305 | 306 | Return: 307 | peaks: np.array. The indices of the peaks. 308 | valleys: np.array. The indices of the valleys. 309 | ''' 310 | peaks, _ = find_peaks(smoothed_curve) 311 | valleys, _ = find_peaks(-smoothed_curve) 312 | x = np.arange(len(smoothed_curve)) 313 | 314 | plt.figure() 315 | plt.plot(x, smoothed_curve, label='Smoothed Data'); 316 | plt.plot(x[peaks], smoothed_curve[peaks], 'rx', label='Peaks'); 317 | plt.plot(x[valleys], smoothed_curve[valleys], 'go', label='Valleys'); 318 | plt.title(f'Smoothed {self.handedness} Hand Speed Peaks and Valleys') 319 | plt.savefig(f'{str(self.base_folder)}/{self.handedness}_speed_smoothed.jpg') 320 | plt.close() 321 | 322 | return peaks, valleys 323 | 324 | def get_frame(self, frame_number): 325 | ''' 326 | Get the frame of the video at the specified frame number. 327 | 328 | Input: 329 | frame_number: int. The frame number. 330 | 331 | Return: 332 | np.array. The frame of the video. 333 | ''' 334 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) 335 | ret, frame = self.cap.read() 336 | 337 | if not ret: 338 | print(f"Something went wrong when reading frame {frame_number}.") 339 | return np.zeros((self.VIDEO_HEIGHT, self.VIDEO_WIDTH, 3)) 340 | return frame 341 | 342 | def get_speed(self, coordinates): 343 | ''' 344 | Get the speed curve of the hand. 345 | ''' 346 | length = len(coordinates) 347 | speed = np.zeros(length-1) 348 | for i in range(length-1): 349 | if coordinates[i] is not None and coordinates[i+1] is not None: 350 | speed[i] = np.linalg.norm(coordinates[i] - coordinates[i+1]) 351 | else: 352 | speed[i] = None 353 | return speed 354 | 355 | def plot_speed(self, speeds, handedness, selected_frame=None): 356 | ''' 357 | Plot the speed curve of the hand. Overlay the current hand speed of i-th frame on the entire speed curve. 358 | 359 | Input: 360 | speeds: np.array. The speed curve of the hand. 361 | handedness: str. The handedness of the hand. Either "Right" or "Left". 362 | ''' 363 | for i in range(self.num_frame-1): 364 | ''' 365 | For every frame, we are plotting the entire speed curve again and again for the purpose of making a video at the end. 366 | Again, can be better if I am more fluent in matplotlib. 
367 | ''' 368 | plt.figure(); 369 | plt.plot(speeds, label=f'{handedness} Hand Speed Distribution'); 370 | if selected_frame is not None: 371 | plt.scatter(selected_frame, speeds[selected_frame], color='blue', label=f'Selected Frame {selected_frame}'); 372 | plt.scatter(i, speeds[i], color='red', label=f'Current {handedness} Hand'); 373 | plt.legend(); 374 | plt.savefig(f'{str(self.plot_folder)}/{i}_{handedness}.jpg') 375 | plt.close() 376 | 377 | def make_video(self, handedness): 378 | ''' 379 | Make a video of the hand images and the speed plot combined. 380 | 381 | Input: 382 | handedness: str. The handedness of the hand. Either "Right" or "Left". 383 | ''' 384 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') # type: ignore 385 | out = cv2.VideoWriter(f'{str(self.base_folder)}/{handedness}_combined.mp4', fourcc, 30.0, (self.VIDEO_WIDTH*2, self.VIDEO_HEIGHT)) 386 | for i in range(self.num_frame-1): 387 | frame = cv2.imread(f'{self.hand_images_folder}/{i+1}.jpg') 388 | frame = cv2.resize(frame, (self.VIDEO_WIDTH, self.VIDEO_HEIGHT)) 389 | plot = cv2.imread(f'{str(self.plot_folder)}/{i}_{handedness}.jpg') 390 | combined = cv2.hconcat([frame, plot]) 391 | 392 | out.write(combined) 393 | out.release() 394 | 395 | def _mediapipe_init(self): 396 | ''' 397 | Initialize the MediaPipe HandLandmarker. 398 | ''' 399 | 400 | BaseOptions = mp.tasks.BaseOptions 401 | self.HandLandmarker = mp.tasks.vision.HandLandmarker 402 | HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions 403 | VisionRunningMode = mp.tasks.vision.RunningMode 404 | 405 | # Create a hand landmarker instance with the video mode: 406 | self.mp_hand_options = HandLandmarkerOptions( 407 | base_options=BaseOptions(model_asset_path='./hand_landmarker.task'), 408 | running_mode=VisionRunningMode.VIDEO, 409 | num_hands=2, 410 | min_hand_detection_confidence=0.2, 411 | min_hand_presence_confidence=0.2, 412 | min_tracking_confidence=0.2, 413 | ) 414 | 415 | def _visualization_init(self): 416 | ''' 417 | Initialize the visualization parameters. 418 | ''' 419 | self.MARGIN = 10 # pixels 420 | self.FONT_SIZE = 1 421 | self.FONT_THICKNESS = 1 422 | self.HANDEDNESS_TEXT_COLOR = (88, 205, 54) # vibrant green 423 | self.VIDEO_WIDTH = 640 424 | self.VIDEO_HEIGHT = 480 425 | 426 | def _folder_init(self, video_path, output_dir): 427 | ''' 428 | Initialize all folder-related stuff. 429 | ''' 430 | output_dir = Path(output_dir) 431 | self.video_path = video_path 432 | self.video_name = Path(video_path).stem 433 | 434 | # This folder will hold all output files related to this SPECIFIC video. 435 | self.base_folder = output_dir / self.video_name 436 | 437 | # This folder will hold all annotated hand images extracted from the video. 438 | self.hand_images_folder = self.base_folder / 'hand_images' 439 | os.makedirs(self.hand_images_folder, exist_ok=True) 440 | 441 | # This folder will hold all matplotlib plots related to the hand landmarks. 442 | self.plot_folder = self.base_folder / 'plots' 443 | os.makedirs(self.plot_folder, exist_ok=True) 444 | 445 | # This folder will hold the selected frames extracted from the video. 446 | self.selected_folder = self.base_folder / 'selected_frames' 447 | os.makedirs(self.selected_folder, exist_ok=True) 448 | 449 | # This folder will hold ALL valled frames extracted from the video. 
450 | self.all_valleys_folder = self.base_folder / 'all_valleys' 451 | os.makedirs(self.all_valleys_folder, exist_ok=True) 452 | 453 | 454 | def main(args): 455 | args_parsed = args.parse_args() 456 | video_path = args_parsed.video_path 457 | output_dir = args_parsed.output_dir 458 | gaussian_sigma = args_parsed.gaussian_sigma 459 | prominence = args_parsed.prominence 460 | 461 | frame_extractor = FrameExtractor(video_path, output_dir, gaussian_sigma, prominence) 462 | frame_extractor.extract_frames() 463 | 464 | if __name__ == '__main__': 465 | args = ArgumentParser() 466 | args.add_argument('--video_path', type=str, help='The path to the video file.') 467 | args.add_argument('--output_dir', type=str, help='The path to the output directory.', default='./output') 468 | args.add_argument('--gaussian_sigma', type=int, help='The sigma value for the Gaussian filter.', default=5) 469 | args.add_argument('--prominence', type=float, help='The prominence value for the find_peaks function.', default=0.8) 470 | main(args) -------------------------------------------------------------------------------- /hand_landmarker.task: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4ce/SeeDo/9b44abbca767218ab2b452ef5b8ac0958623def0/hand_landmarker.task -------------------------------------------------------------------------------- /media/main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai4ce/SeeDo/9b44abbca767218ab2b452ef5b8ac0958623def0/media/main.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | astunparse==1.6.3 3 | cython==3.0.10 4 | diffusers==0.29.2 5 | ffmpy==0.4.0 6 | Flask==3.0.3 7 | gdown==5.2.0 8 | huggingface_hub>=0.25.1 9 | hydra-core==1.3.2 10 | iopath==0.1.10 11 | ipdb==0.13.13 12 | ipython==8.12.3 13 | matplotlib==3.9.1 14 | mediapipe==0.10.15 15 | mmcv==2.2.0 16 | mmengine==0.10.5 17 | numpy==1.26.4 18 | omegaconf==2.3.0 19 | onnxruntime==1.20.0 20 | openai==1.54.3 21 | opencv_contrib_python==4.10.0.84 22 | opencv_python==4.10.0.84 23 | opencv_python_headless==4.10.0.84 24 | Pillow==10.4.0 25 | progressbar33==2.4 26 | psutil==6.0.0 27 | pybullet==3.2.6 28 | pycocotools==2.0.8 29 | Pygments==2.15.1 30 | PyYAML==6.0.1 31 | Requests==2.32.3 32 | scipy==1.14.1 33 | setuptools==60.2.0 34 | Shapely==2.0.6 35 | scikit-image==0.24.0 36 | supervision==0.24.0 37 | termcolor==2.5.0 38 | timm==1.0.11 39 | tqdm==4.66.6 40 | transformers==4.43.3 41 | yapf==0.40.2 -------------------------------------------------------------------------------- /track_anything.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | from tqdm import tqdm 3 | from tools.interact_tools import SamControler 4 | from tracker.base_tracker import BaseTracker 5 | from inpainter.base_inpainter import BaseInpainter 6 | import numpy as np 7 | import argparse 8 | 9 | class TrackingAnything(): 10 | def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args): 11 | self.args = args 12 | self.sam_checkpoint = sam_checkpoint 13 | self.xmem_checkpoint = xmem_checkpoint 14 | self.e2fgvi_checkpoint = e2fgvi_checkpoint 15 | self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device) 16 | self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device) 17 | 
self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device) 18 | 19 | def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): 20 | mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) 21 | return mask, logit, painted_image 22 | 23 | def generator(self, images: list, template_mask:np.ndarray): 24 | 25 | masks = [] 26 | logits = [] 27 | painted_images = [] 28 | for i in tqdm(range(len(images)), desc="Tracking image"): 29 | if i ==0: 30 | mask, logit, painted_image = self.xmem.track(images[i], template_mask) 31 | masks.append(mask) 32 | logits.append(logit) 33 | painted_images.append(painted_image) 34 | 35 | else: 36 | mask, logit, painted_image = self.xmem.track(images[i]) 37 | masks.append(mask) 38 | logits.append(logit) 39 | painted_images.append(painted_image) 40 | return masks, logits, painted_images 41 | 42 | def parse_augment(): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument('--device', type=str, default="cuda:0") 45 | parser.add_argument('--sam_model_type', type=str, default="vit_h") 46 | parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications") 47 | parser.add_argument('--debug', action="store_true") 48 | parser.add_argument('--mask_save', default=False) 49 | args = parser.parse_args() 50 | 51 | if args.debug: 52 | print(args) 53 | return args 54 | 55 | if __name__ == "__main__": 56 | masks = None 57 | logits = None 58 | painted_images = None 59 | images = [] 60 | image = np.array(PIL.Image.open('/hhd3/gaoshang/truck.jpg')) 61 | args = parse_augment() 62 | images.append(image) 63 | images.append(image) 64 | 65 | mask = np.zeros_like(image)[:,:,0] 66 | mask[0,0]= 1 67 | trackany = TrackingAnything('/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth','/ssd1/gaomingqi/checkpoints/XMem-s012.pth', args) 68 | masks, logits ,painted_images= trackany.generator(images, mask) -------------------------------------------------------------------------------- /track_objects.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import copy 5 | import gc 6 | import json 7 | import re 8 | import csv 9 | import base64 10 | from io import BytesIO 11 | from collections import Counter 12 | import ast 13 | 14 | import numpy as np 15 | import torch 16 | from PIL import Image 17 | from torchvision.ops import box_convert 18 | from tqdm import tqdm 19 | import cv2 20 | import scipy.signal 21 | import matplotlib.pyplot as plt 22 | 23 | from openai import OpenAI 24 | from VLM_CaP.src.key import mykey, projectkey 25 | from diffusers import StableDiffusionInpaintPipeline 26 | from sam2.build_sam import build_sam2_video_predictor 27 | 28 | # Grounding DINO 29 | from GroundingDINO.groundingdino.models import build_model 30 | from GroundingDINO.groundingdino.util import box_ops 31 | from GroundingDINO.groundingdino.util.slconfig import SLConfig 32 | from GroundingDINO.groundingdino.util.utils import clean_state_dict 33 | from GroundingDINO.groundingdino.util.inference import ( 34 | annotate, 35 | load_image, 36 | predict, 37 | load_image_from_array, 38 | ) 39 | 40 | # Segment Anything 41 | from segment_anything import build_sam, SamPredictor 42 | 43 | # Hugging Face Hub 44 | from huggingface_hub import hf_hub_download 45 | 46 | sys.path.append(os.path.join(os.getcwd(), "GroundingDINO")) 47 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 48 | 49 | 50 | def 
image_to_base64(image): 51 | buffered = BytesIO() 52 | image.save(buffered, format="PNG") 53 | img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") 54 | return img_str 55 | 56 | 57 | def call_openai_api(prompt_messages, client): 58 | params = { 59 | "model": "gpt-4o", 60 | "messages": prompt_messages, 61 | "max_tokens": 400, 62 | "temperature": 0, 63 | } 64 | result = client.chat.completions.create(**params) 65 | return result.choices[0].message.content 66 | 67 | 68 | def get_object_list(video_path, client): 69 | # Use the first frame for encoding 70 | video = cv2.VideoCapture(video_path) 71 | 72 | base64Frames = [] 73 | frame_count = 0 74 | max_frames = 2 # Only process the first 2 frames 75 | 76 | while video.isOpened() and frame_count < max_frames: 77 | success, frame = video.read() 78 | if not success: 79 | break 80 | _, buffer = cv2.imencode(".jpg", frame) 81 | base64Frames.append(base64.b64encode(buffer).decode("utf-8")) 82 | frame_count += 1 83 | 84 | video.release() 85 | print(len(base64Frames), "frames read.") 86 | 87 | prompt_messages_state = [ 88 | { 89 | "role": "system", 90 | "content": [ 91 | "You are a visual object detector. Your task is to count and identify the objects in the provided image that are on the desk. Focus on objects classified as grasped_objects and containers.", 92 | "Do not include hand or gripper in your answer", 93 | ], 94 | }, 95 | { 96 | "role": "user", 97 | "content": [ 98 | "There are two kinds of objects, grasped_objects and containers in the environment. We only care about objects on the desk.", 99 | "You must strictly follow the rules below: Even if there are multiple objects that appear identical, you must repeat their names in your answer according to their quantity. For example, if there are three wooden blocks, you must mention 'wooden block' three times in your answer." 100 | "Be careful and accurate with the number. Do not miss or add additional object in your answer." 101 | "Based on the input picture, answer:", 102 | "1. How many objects are there in the environment?", 103 | "2. 
What are these objects?", 104 | "You should respond in the format of the following example:", 105 | "Number: 3", 106 | "Objects: red pepper, red tomato, white bowl", 107 | "Number: 4", 108 | "Objects: wooden block, wooden block, wooden block, wooden block", 109 | *map(lambda x: {"image": x, "resize": 768}, base64Frames[0:1]), 110 | ], 111 | }, 112 | ] 113 | 114 | response_state = call_openai_api(prompt_messages_state, client) 115 | return response_state 116 | 117 | 118 | def extract_num_object(response_state): 119 | # Extract number of objects 120 | num_match = re.search(r"Number: (\d+)", response_state) 121 | num = int(num_match.group(1)) if num_match else 0 122 | 123 | # Extract objects 124 | objects_match = re.search(r"Objects: (.+)", response_state) 125 | objects_list = objects_match.group(1).split(", ") if objects_match else [] 126 | 127 | # Construct object list 128 | objects = [obj for obj in objects_list] 129 | 130 | return num, objects 131 | 132 | 133 | def load_model_hf(repo_id, filename, ckpt_config_filename, device="cpu"): 134 | cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename) 135 | 136 | args = SLConfig.fromfile(cache_config_file) 137 | model = build_model(args) 138 | args.device = device 139 | 140 | cache_file = hf_hub_download(repo_id=repo_id, filename=filename) 141 | checkpoint = torch.load(cache_file, map_location="cpu") 142 | log = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) 143 | print("Model loaded from {} \n => {}".format(cache_file, log)) 144 | _ = model.eval() 145 | return model 146 | 147 | 148 | def read_video(video_path): 149 | video_capture = cv2.VideoCapture(video_path) 150 | 151 | if not video_capture.isOpened(): 152 | print("Error: Could not open video.") 153 | exit() 154 | 155 | frames = [] 156 | 157 | while True: 158 | ret, frame = video_capture.read() 159 | 160 | if not ret: 161 | break 162 | 163 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 164 | frames.append(frame) 165 | return frames 166 | 167 | 168 | def my_annotate( 169 | image_source: np.ndarray, 170 | boxes: torch.Tensor, 171 | logits: torch.Tensor, 172 | phrases, 173 | ) -> np.ndarray: 174 | h, w, _ = image_source.shape 175 | boxes = boxes * torch.Tensor([w, h, w, h]) 176 | xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy() 177 | 178 | annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR) 179 | 180 | for box, logit, phrase in zip(xyxy, logits, phrases): 181 | x1, y1, x2, y2 = map(int, box) 182 | label = f"{phrase} {logit:.2f}" 183 | 184 | # Draw bounding box 185 | cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2) 186 | 187 | # Draw label background box 188 | (text_width, text_height), _ = cv2.getTextSize( 189 | label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 190 | ) 191 | cv2.rectangle( 192 | annotated_frame, 193 | (x1, y1 - text_height - 4), 194 | (x1 + text_width, y1), 195 | (0, 255, 0), 196 | -1, 197 | ) 198 | 199 | # Draw label text 200 | cv2.putText( 201 | annotated_frame, 202 | label, 203 | (x1, y1 - 2), 204 | cv2.FONT_HERSHEY_SIMPLEX, 205 | 0.5, 206 | (0, 0, 0), 207 | 1, 208 | ) 209 | 210 | return annotated_frame 211 | 212 | 213 | def video2jpg(video_path, output_folder, sample_freq=1): 214 | os.makedirs(output_folder, exist_ok=True) 215 | 216 | cap = cv2.VideoCapture(video_path) 217 | 218 | if not cap.isOpened(): 219 | print("Error: Could not open video.") 220 | else: 221 | frame_index = 0 222 | save_index = 0 223 | while True: 224 | ret, frame = cap.read() 225 | if not ret: 226 | 
break 227 | 228 | if frame_index % sample_freq == 0: 229 | frame_filename = os.path.join(output_folder, f"{save_index:04d}.jpg") 230 | cv2.imwrite(frame_filename, frame) 231 | save_index += 1 232 | 233 | frame_index += 1 234 | 235 | cap.release() 236 | print(f"All frames have been saved to {output_folder}.") 237 | 238 | 239 | color_list = { 240 | 0: np.array([255, 0, 0]), # Red 241 | 1: np.array([0, 255, 0]), # Green 242 | 2: np.array([0, 0, 255]), # Blue 243 | 3: np.array([0, 125, 125]), # Teal 244 | 4: np.array([125, 0, 125]), # Purple 245 | 5: np.array([125, 125, 0]), # Yellow 246 | 6: np.array([255, 165, 0]), # Orange 247 | 7: np.array([255, 105, 180]), # Pink 248 | } 249 | 250 | 251 | def contour_painter( 252 | input_image, 253 | input_mask, 254 | mask_color=5, 255 | mask_alpha=0.7, 256 | contour_color=1, 257 | contour_width=3, 258 | ann_obj_id=None, 259 | ): 260 | assert ( 261 | input_image.shape[:2] == input_mask.shape 262 | ), "Different shape between image and mask" 263 | # 0: background, 1: foreground 264 | mask = np.clip(input_mask, 0, 1).astype(np.uint8) 265 | contour_radius = (contour_width - 1) // 2 266 | 267 | dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3) 268 | dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3) 269 | dist_map = dist_transform_fore - dist_transform_back 270 | contour_radius += 2 271 | contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius)) 272 | contour_mask = contour_mask / np.max(contour_mask) 273 | contour_mask[contour_mask > 0.5] = 1.0 274 | 275 | # Paint contour 276 | painted_image = input_image.copy() 277 | color = color_list[contour_color] 278 | mask = 1 - contour_mask 279 | painted_image[mask.astype(bool)] = ( 280 | painted_image[mask.astype(bool)] * (1 - 1) + color * 1 281 | ).astype("uint8") 282 | 283 | # Find the center position of the mask 284 | moments = cv2.moments(mask) 285 | if moments["m00"] != 0 and ann_obj_id is not None: 286 | cX = int(moments["m10"] / moments["m00"]) 287 | cY = int(moments["m01"] / moments["m00"]) 288 | 289 | font = cv2.FONT_HERSHEY_SIMPLEX 290 | font_scale = 1.2 291 | font_color = (0, 0, 0) 292 | font_thickness = 3 293 | cv2.putText( 294 | painted_image, 295 | str(ann_obj_id), 296 | (cX, cY), 297 | font, 298 | font_scale, 299 | font_color, 300 | font_thickness, 301 | ) 302 | 303 | return painted_image 304 | 305 | 306 | def write_video(frames, output_path, fps): 307 | if not frames: 308 | print("Error: No frames to write.") 309 | return 310 | 311 | height, width, _ = frames[0].shape 312 | 313 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 314 | video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) 315 | 316 | for frame in frames: 317 | frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 318 | video_writer.write(frame_bgr) 319 | print("Video writing completed.") 320 | video_writer.release() 321 | 322 | 323 | def process_mask_signal(mask_add, mask_min): 324 | kernel_size = 3 325 | 326 | n = len(mask_add) 327 | 328 | fig, axes = plt.subplots(n * 2, 1, figsize=(10, 5 * n * 2), sharex=True, sharey=True) 329 | axes = axes.flatten() 330 | 331 | filtered_mask_add = {} 332 | filtered_mask_min = {} 333 | min_num = 99999 334 | max_num = 0 335 | 336 | index = 0 337 | for k in mask_add.keys(): 338 | mask1 = np.array(mask_add[k]) 339 | mask2 = np.array(mask_min[k]) 340 | 341 | # Median filter 342 | filtered_data1 = scipy.signal.medfilt(mask1, kernel_size=kernel_size) 343 | filtered_data2 = scipy.signal.medfilt(mask2, kernel_size=kernel_size) 344 | 345 | 
filtered_mask_add[k] = filtered_data1 346 | filtered_mask_min[k] = filtered_data2 347 | 348 | max_num = max(max_num, max(filtered_mask_add[k])) 349 | max_num = max(max_num, max(filtered_mask_min[k])) 350 | min_num = min(min_num, min(filtered_mask_add[k])) 351 | min_num = min(min_num, min(filtered_mask_min[k])) 352 | 353 | axes[index].plot(filtered_data1, linestyle="-", color="b") 354 | axes[index + 1].plot(filtered_data2, linestyle="-", color="b") 355 | index += 2 356 | 357 | plt.show() 358 | 359 | def sigmoid(x): 360 | return 1 / (1 + np.exp(-x)) 361 | 362 | fig, axes = plt.subplots(n * 2, 1, figsize=(10, 5 * n * 2), sharex=True, sharey=True) 363 | axes = axes.flatten() 364 | index = 0 365 | 366 | for k in filtered_mask_add.keys(): 367 | filtered_mask_add[k] = ((filtered_mask_add[k] - min_num) / max_num) * 2 - 1 368 | filtered_mask_min[k] = ((filtered_mask_min[k] - min_num) / max_num) * 2 - 1 369 | 370 | filtered_mask_add[k] = sigmoid(filtered_mask_add[k] * 5) 371 | filtered_mask_min[k] = sigmoid(filtered_mask_min[k] * 5) 372 | 373 | axes[index].plot(filtered_mask_add[k], linestyle="-", color="b") 374 | axes[index + 1].plot(filtered_mask_min[k], linestyle="-", color="b") 375 | index += 2 376 | 377 | plt.show() 378 | 379 | fig, axes = plt.subplots(n, 1, figsize=(10, 5 * n), sharex=True, sharey=True) 380 | axes = axes.flatten() 381 | index = 0 382 | 383 | final_result = {} 384 | for k in filtered_mask_add.keys(): 385 | final_result[k] = filtered_mask_add[k] * filtered_mask_min[k] 386 | 387 | axes[index].plot(final_result[k], linestyle="-", color="b") 388 | index += 1 389 | 390 | plt.show() 391 | 392 | 393 | def main(input_video_path, output_video_path, key_frames): 394 | key_frames = ast.literal_eval(key_frames) 395 | 396 | # First Part: Get object list from first key_frame using VLM 397 | client = OpenAI(api_key=projectkey) 398 | 399 | ckpt_repo_id = "ShilongLiu/GroundingDINO" 400 | ckpt_filenmae = "groundingdino_swinb_cogcoor.pth" 401 | ckpt_config_filename = "GroundingDINO_SwinB.cfg.py" 402 | 403 | groundingdino_model = load_model_hf( 404 | ckpt_repo_id, ckpt_filenmae, ckpt_config_filename 405 | ) 406 | 407 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 408 | 409 | sam_checkpoint = "sam_vit_h_4b8939.pth" 410 | sam = build_sam(checkpoint=sam_checkpoint) 411 | sam.to(device=DEVICE) 412 | sam_predictor = SamPredictor(sam) 413 | 414 | if DEVICE.type == "cpu": 415 | float_type = torch.float32 416 | else: 417 | float_type = torch.float16 418 | 419 | pipe = StableDiffusionInpaintPipeline.from_pretrained( 420 | "stabilityai/stable-diffusion-2-inpainting", 421 | torch_dtype=float_type, 422 | ) 423 | 424 | if DEVICE.type != "cpu": 425 | pipe = pipe.to("cuda") 426 | 427 | video_path = input_video_path 428 | sample_freq = 16 429 | output_video_path = output_video_path 430 | 431 | frames = read_video(video_path) 432 | 433 | object_list_response = get_object_list(video_path, client) 434 | 435 | num, obj_list = extract_num_object(object_list_response) 436 | print(f"Generated prompt: {obj_list}") 437 | 438 | # Second Part: Use GroundedSAM2 to track the objects 439 | # Parameters for GroundingDINO 440 | BOX_TRESHOLD = 0.3 441 | TEXT_TRESHOLD = 0.25 442 | object_counts = Counter(obj_list) 443 | 444 | image_source, image = load_image_from_array(frames[0]) 445 | 446 | best_boxes = [] 447 | best_phrases = [] 448 | best_logits = [] 449 | 450 | # Iterate over each object and select the box with highest confidence 451 | for obj, count in object_counts.items(): 452 | boxes, logits, 
phrases = predict( 453 | model=groundingdino_model, 454 | image=image, 455 | caption=obj, 456 | box_threshold=BOX_TRESHOLD, 457 | text_threshold=TEXT_TRESHOLD, 458 | device=DEVICE, 459 | ) 460 | 461 | if boxes.shape[0] > 0: 462 | selected_count = min( 463 | count, boxes.shape[0] 464 | ) # If returned boxes are fewer than object count 465 | for i in range(selected_count): 466 | best_boxes.append(boxes[i].unsqueeze(0)) 467 | best_phrases.append(phrases[i]) 468 | best_logits.append(logits[i]) 469 | 470 | if best_boxes: 471 | best_boxes = torch.cat(best_boxes) 472 | best_logits = torch.stack(best_logits) 473 | 474 | annotated_frame = my_annotate( 475 | image_source=image_source, 476 | boxes=best_boxes, 477 | logits=best_logits, 478 | phrases=best_phrases, 479 | ) 480 | annotated_frame = annotated_frame[..., ::-1] # BGR to RGB 481 | 482 | sam_predictor.set_image(image_source) 483 | H, W, _ = image_source.shape 484 | boxes_xyxy = box_ops.box_cxcywh_to_xyxy(best_boxes) * torch.Tensor([W, H, W, H]) 485 | 486 | transformed_boxes = sam_predictor.transform.apply_boxes_torch( 487 | boxes_xyxy, image_source.shape[:2] 488 | ).to(DEVICE) 489 | masks, _, _ = sam_predictor.predict_torch( 490 | point_coords=None, 491 | point_labels=None, 492 | boxes=transformed_boxes, 493 | multimask_output=False, 494 | ) 495 | 496 | masks = masks.cpu() 497 | masks_np = masks.numpy() 498 | 499 | h, w = masks_np[0][0].shape 500 | pixel_cnt = h * w 501 | indices_to_keep = np.ones(len(masks_np), dtype=bool) 502 | for i in range(len(masks_np)): 503 | if np.sum(masks_np[i][0]) > pixel_cnt * 0.3: 504 | indices_to_keep[i] = False 505 | masks_np = masks_np[indices_to_keep] 506 | 507 | del groundingdino_model 508 | del sam 509 | del sam_predictor 510 | del pipe 511 | 512 | torch.cuda.empty_cache() 513 | gc.collect() 514 | 515 | # Use bfloat16 for the entire notebook 516 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 517 | 518 | if torch.cuda.get_device_properties(0).major >= 8: 519 | torch.backends.cuda.matmul.allow_tf32 = True 520 | torch.backends.cudnn.allow_tf32 = True 521 | 522 | sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_large.pt" 523 | model_cfg = "sam2_hiera_l.yaml" 524 | 525 | predictor = build_sam2_video_predictor( 526 | model_cfg, sam2_checkpoint, device="cuda:0" 527 | ) 528 | 529 | # First Round for sampling 530 | video_dir = ( 531 | os.path.dirname(video_path) 532 | + f"/sample_freq_{sample_freq}_" 533 | + video_path.split("/")[-1].split(".")[0] 534 | ) 535 | if not os.path.exists(video_dir): 536 | video2jpg(video_path, video_dir, sample_freq) 537 | 538 | frame_names = [ 539 | p 540 | for p in os.listdir(video_dir) 541 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 542 | ] 543 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 544 | 545 | inference_state = predictor.init_state(video_path=video_dir) 546 | predictor.reset_state(inference_state) 547 | 548 | prompts = {} # Hold all the clicks we add for visualization 549 | 550 | ann_frame_idx = 0 # The frame index we interact with 551 | ann_obj_id = 1 # Give a unique id to each object we interact with 552 | 553 | for i in range(len(masks_np)): 554 | _, out_obj_ids, out_mask_logits = predictor.add_new_mask( 555 | inference_state=inference_state, 556 | frame_idx=ann_frame_idx, 557 | obj_id=i, 558 | mask=masks_np[i][0], 559 | ) 560 | 561 | # Run propagation throughout the video and collect the results in a dict 562 | video_segments = {} # Contains the per-frame segmentation results 563 | for ( 564 | 
out_frame_idx, 565 | out_obj_ids, 566 | out_mask_logits, 567 | ) in predictor.propagate_in_video(inference_state): 568 | video_segments[out_frame_idx] = { 569 | out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() 570 | for i, out_obj_id in enumerate(out_obj_ids) 571 | } 572 | 573 | # Second Round for processing whole video 574 | del inference_state 575 | del predictor 576 | torch.cuda.empty_cache() 577 | torch.cuda.set_device(1) 578 | 579 | predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint) 580 | 581 | video_dir = os.path.dirname(video_path) + "/" + video_path.split("/")[-1].split(".")[0] 582 | if not os.path.exists(video_dir): 583 | video2jpg(video_path, video_dir, 1) 584 | 585 | frame_names = [ 586 | p 587 | for p in os.listdir(video_dir) 588 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 589 | ] 590 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 591 | 592 | inference_state = predictor.init_state(video_path=video_dir) 593 | predictor.reset_state(inference_state) 594 | 595 | prompts = {} 596 | 597 | ann_frame_idx = 0 598 | ann_obj_id = 1 599 | 600 | for frame_idx in range(0, len(frame_names), sample_freq): 601 | for k in video_segments[frame_idx // sample_freq].keys(): 602 | _, out_obj_ids, out_mask_logits = predictor.add_new_mask( 603 | inference_state=inference_state, 604 | frame_idx=frame_idx, 605 | obj_id=k, 606 | mask=video_segments[frame_idx // sample_freq][k][0], 607 | ) 608 | 609 | video_segments = {} 610 | for ( 611 | out_frame_idx, 612 | out_obj_ids, 613 | out_mask_logits, 614 | ) in predictor.propagate_in_video(inference_state): 615 | video_segments[out_frame_idx] = { 616 | out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() 617 | for i, out_obj_id in enumerate(out_obj_ids) 618 | } 619 | 620 | # Third Part: Select key frames and compute center coordinates of masks 621 | key_frame_coordinates = {} 622 | 623 | # Iterate through the key frames provided 624 | for frame_idx in key_frames: 625 | current_frame_coords = [] 626 | 627 | # Check if the frame exists in the video segments (contains mask data) 628 | if frame_idx in video_segments: 629 | # For each object in the frame, retrieve the mask 630 | for obj_id, mask in video_segments[frame_idx].items(): 631 | mask_data = mask[0] 632 | mask_indices = np.argwhere(mask_data > 0) # Get non-zero pixel indices 633 | 634 | if len(mask_indices) > 0: 635 | # Calculate the average x and y coordinates of the mask to get the center 636 | avg_y, avg_x = np.mean(mask_indices, axis=0) 637 | current_frame_coords.append( 638 | f"Object {obj_id}: ({int(avg_x)}, {int(avg_y)})" 639 | ) 640 | else: 641 | # Print a warning if the mask is empty 642 | print( 643 | f"Warning: Empty mask for object {obj_id} in frame {frame_idx}" 644 | ) 645 | 646 | # Store the coordinates for the current frame 647 | key_frame_coordinates[f"key_frame{frame_idx}"] = current_frame_coords 648 | 649 | # Initialize an empty string to store the result 650 | bbx_string = "" 651 | 652 | # Iterate through the key_frame_coordinates and generate the string 653 | for key_frame, coordinates in key_frame_coordinates.items(): 654 | coordinates_str = "\n".join(coordinates) # Join the coordinates into a single string 655 | bbx_string += f"{key_frame}\n{coordinates_str}\n\n" # Append the key frame and coordinates 656 | 657 | # Print the final bounding box string 658 | print(f"Bounding box extraction completed. 
Result:\n{bbx_string}") 659 | 660 | # Fourth Part: Append all the painted frames into a video 661 | painted_frames = [] 662 | for i in range(len(frame_names)): 663 | img = cv2.imread(os.path.join(video_dir, frame_names[i])) 664 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 665 | for k in video_segments[i].keys(): 666 | img = contour_painter( 667 | img, video_segments[i][k][0], contour_color=1, ann_obj_id=k 668 | ) 669 | painted_frames.append(img) 670 | 671 | mask_add = {} 672 | mask_min = {} 673 | for k in video_segments[i].keys(): 674 | mask_add[k] = [] 675 | mask_min[k] = [] 676 | 677 | write_video(painted_frames, output_video_path, fps=30) 678 | 679 | for i in range(len(frame_names) - 1): 680 | for k in video_segments[i].keys(): 681 | mask_before = video_segments[i][k][0].copy() 682 | mask_after = video_segments[i + 1][k][0].copy() 683 | mask_after[mask_before] = False 684 | add_cnt = np.sum(mask_after) 685 | 686 | mask_before = video_segments[i][k][0].copy() 687 | mask_after = video_segments[i + 1][k][0].copy() 688 | mask_before[mask_after] = False 689 | min_cnt = np.sum(mask_before) 690 | 691 | mask_add[k].append(add_cnt.item()) 692 | mask_min[k].append(min_cnt.item()) 693 | 694 | process_mask_signal(mask_add, mask_min) 695 | # Return the final bounding box string 696 | print(bbx_string) 697 | return bbx_string 698 | 699 | 700 | if __name__ == "__main__": 701 | parser = argparse.ArgumentParser( 702 | description="Process video with SAM and GroundingDINO." 703 | ) 704 | parser.add_argument("--input", type=str, help="Path to the input video") 705 | parser.add_argument("--output", type=str, help="Path to the output video") 706 | parser.add_argument("--key_frames", type=str, help="List of key frame indices as a string") 707 | 708 | args = parser.parse_args() 709 | 710 | main(args.input, args.output, args.key_frames) -------------------------------------------------------------------------------- /vlm.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from shapely.geometry import * 3 | from shapely.affinity import * 4 | from openai import OpenAI 5 | from VLM_CaP.src.key import projectkey 6 | import os 7 | import re 8 | import argparse 9 | import csv 10 | import ffmpy 11 | import ast 12 | from VLM_CaP.src.vlm_video import extract_frame_list # import extract_frames 13 | # set up your openai api key 14 | client = OpenAI(api_key=projectkey) 15 | # def for calling openai api with different prompts 16 | def call_openai_api(prompt_messages): 17 | params = { 18 | "model": "gpt-4o", 19 | "messages": prompt_messages, 20 | "max_tokens": 400, 21 | "temperature": 0 22 | } 23 | result = client.chat.completions.create(**params) 24 | return result.choices[0].message.content 25 | # "Notice that there might be similar objects. You are supposed to use the index annotated on the objects to distinguish between only similar objects that is hard to distinguish with language." 26 | def get_object_list(selected_frames): 27 | # first prompt to get objects in the environment 28 | prompt_messages_state = [ 29 | { 30 | "role": "system", 31 | "content": [ 32 | "You are a visual object detector. Your task is to count and identify the objects in the provided image that are on the desk. Focus on objects classified as grasped_objects and containers.", 33 | ], 34 | }, 35 | { 36 | "role": "user", 37 | "content": [ 38 | "There are two kinds of objects, grasped_objects and containers in the environment. We only care about objects on the desk. 
Do not count in hand or person as objects.", 39 | "Based on the input picture, answer:", 40 | "1. How many objects are there in the environment?", 41 | "2. What are these objects?", 42 | "You should respond in the format of the following example:", 43 | "Number: 1", 44 | "Objects: purple eggplant, red tomato, white bowl, white bowl", 45 | *map(lambda x: {"image": x, "resize": 768}, selected_frames[0:1]), # use first picture for environment objects 46 | ], 47 | }, 48 | ] 49 | response_state = call_openai_api(prompt_messages_state) 50 | return response_state 51 | def extract_num_object(response_state): 52 | # extract number of objects 53 | num_match = re.search(r"Number: (\d+)", response_state) 54 | num = int(num_match.group(1)) if num_match else 0 55 | 56 | # extract objects 57 | objects_match = re.search(r"Objects: (.+)", response_state) 58 | objects_list = objects_match.group(1).split(", ") if objects_match else [] 59 | 60 | # construct object list 61 | objects = [obj for obj in objects_list] 62 | 63 | return num, objects 64 | def extract_keywords_pick(response): 65 | try: 66 | return response.split(': ')[1] 67 | except IndexError: 68 | print("Error extracting pick keyword from response:", response) 69 | return None 70 | def extract_keywords_drop(response): 71 | try: 72 | return response.split(': ')[1] 73 | except IndexError: 74 | print("Error extracting drop keyword from response:", response) 75 | return None 76 | def extract_keywords_reference(response): 77 | try: 78 | return response.split(': ')[1] 79 | except IndexError: 80 | print("Error extracting reference object from response:", response) 81 | return None 82 | def is_frame_relevant(response): 83 | return "hand is manipulating an object" in response.lower() 84 | def parse_closest_object_and_relationship(response): 85 | pattern = r"Closest Object: ([^,]+), (.+)" 86 | match = re.search(pattern, response) 87 | if match: 88 | return match.group(1), match.group(2) 89 | print("Error parsing reference object and relationship from response:", response) 90 | return None, None 91 | def process_images(selected_frames, obj_list): 92 | string_cache = "" # cache for CaP operations 93 | i = 1 94 | while i < len(selected_frames): 95 | input_frame_pick = selected_frames[i:i+1] 96 | prompt_messages_relevance_pick = [ 97 | { 98 | "role": "system", 99 | "content": [ 100 | "You are an operations inspector. You need to check whether the hand in operation is holding an object. The objects have been outlined with contours of different colors and labeled with indexes for easier distinction." 101 | ], 102 | }, 103 | { 104 | "role": "user", 105 | "content": [ 106 | "This is a picture from a pick-and-drop task. Please determine if the hand is manipulating an object.", 107 | "Respond with 'Hand is manipulating an object' or 'Hand is not manipulating an object'.", 108 | *map(lambda x: {"image": x, "resize": 768}, input_frame_pick), 109 | ], 110 | }, 111 | ] 112 | response_relevance_pick = call_openai_api(prompt_messages_relevance_pick) 113 | print(response_relevance_pick) 114 | if not is_frame_relevant(response_relevance_pick): 115 | i += 1 116 | continue 117 | # which to pick 118 | prompt_messages_pick = [ 119 | { 120 | "role": "system", 121 | "content": [ 122 | "You are an operation inspector. You need to check which object is being picked in a pick-and-drop task. Some of the objects have been outlined with contours of different colors and labeled with indexes for easier distinction.", 123 | "The contour and index is only used to help. 
Due to limitation of vision models, the contours and index labels might not cover every objects in the environment. If you notice any unannotated objects in the demo or in the object list, make sure you name it and handle them properly.", 124 | ], 125 | }, 126 | { 127 | "role": "user", 128 | "content": [ 129 | f"This is a picture describing the pick state of a pick-and-drop task. The objects in the environment are {obj_list}. One of the objects is being picked by a human hand or robot gripper now. The objects have been outlined with contours of different colors and labeled with indexes for easier distinction.", 130 | "Based on the input picture and object list, answer:", 131 | "1. Which object is being picked", 132 | "You should respond in the format of the following example:", 133 | "Object Picked: red block", 134 | *map(lambda x: {"image": x, "resize": 768}, input_frame_pick), 135 | ], 136 | }, 137 | ] 138 | response_pick = call_openai_api(prompt_messages_pick) 139 | print(response_pick) 140 | object_picked = extract_keywords_pick(response_pick) 141 | i += 1 142 | # Ensure there is another frame for drop and relative position reasoning 143 | if i >= len(selected_frames): 144 | break 145 | # Check if the second frame (i) is relevant (i.e., hand is holding an object) 146 | input_frame_drop = selected_frames[i:i+1] 147 | # reference object 148 | prompt_messages_reference = [ 149 | { 150 | "role": "system", 151 | "content": [ 152 | "You are an operation inspector. You need to find the reference object for the placement location of the picked object in the pick-and-place process. Notice that the reference object can vary based on the task. If this is a storage task, the reference object should be the container into which the items are stored. If this is a stacking task, the reference object should be the object that best expresses the orientation of the arrangement." 153 | ], 154 | }, 155 | { 156 | "role": "user", 157 | "content": [ 158 | f"This is a picture describing the drop state of a pick-and-place task. The objects in the environment are {obj_list}. {object_picked} is being dropped by a human hand or robot gripper now.", 159 | "Based on the input picture and object list, answer:", 160 | f"1. Which object in the rest of object list do you choose as a reference object to {object_picked}", 161 | "You should respond in the format of the following example without any additional information or reason steps:", 162 | "Reference Object: red block", 163 | *map(lambda x: {"image": x, "resize": 768}, input_frame_drop), 164 | ], 165 | }, 166 | ] 167 | response_reference = call_openai_api(prompt_messages_reference) 168 | print(response_reference) 169 | object_reference = extract_keywords_reference(response_reference) 170 | # current_bbx = bbx_list[i] if i < len(bbx_list) else {} 171 | 172 | # "Due to limitation of vision models, the contours and index labels might not cover every objects in the environment. If you notice any unannotated objects in the demo or in the object list, make sure you handle them properly.", 173 | prompt_messages_relationship = [ 174 | { 175 | "role": "system", 176 | "content": [ 177 | "You are a VLMTutor. You will describe the drop state of a pick-and-drop task from a demo picture. 
You must pay close attention to the spatial relationship between the picked object and the reference object in the picture, and be correct and accurate with directions.", 178 | ], 179 | }, 180 | { 181 | "role": "user", 182 | "content": [ 183 | f"This is a picture describing the drop state of a pick-and-drop task. The objects in the environment are: {obj_list}. {object_picked} is reportedly being dropped by a human hand or robot gripper now.", 184 | f"However, the reported object might be wrong due to a bad visual prompt. If you believe the object being dropped is not {object_picked} but some other object (for example, red chili is said to be the picked object but you believe it is an orange carrot), you MUST correct it and use the right name.", 185 | # "But notice that due to limitation of vision models, the contours and index labels might not cover every objects in the environment. If you notice any unannotated objects in the demo or in the object list, make sure you mention their name and handle their spatial relationships." 186 | # "The ID is only used to help with your reasoning. You should only mention them when the objects are the same in language description. For example, when there are two white bowls, you must specify white bowl (ID:1), white bowl (ID:2) in your answer. But for different objects like vegetables, you do not need to specify their IDs." 187 | # f"To help you better understand the spatial relationship, a bounding box list is given to you. Notice that the bounding boxes of objects in the bounding box list are distinguished by labels. These labels correspond one-to-one with the labels of the objects in the image. The bounding box list is: {bbx_list}", 188 | # "The coordinates of the bounding box represent the center point of the object. The format is two coordinates (x,y). The origin of the coordinates is at the top-left corner of the image. If there are two objects A(x1, y1) and B(x2, y2), a significantly smaller x2 compared to x1 indicates that B is to the left of A; a significantly greater x2 compared to x1 indicates that B is to the right of A; a significantly smaller y2 compared to y1 indicates that B is at the back of A; a significantly greater y2 compared to y1 indicates that B is in front of A." 189 | # "Pay attention to distinguish between at the back of and on top of. If B and A has a visual gap, they are not in touch. Thus B is at the back of A. However, if they are very close, this means B and A are in contact, thus B is on top of A." 190 | # "Notice that the largest difference in corresponding coordinates often represents the most significant feature. If you have coordinates with small difference in x but large difference in y, then coordinates y will represent most significant feature. Make sure to use the picture together with coordinates." 191 | f"The picked object is being dropped somewhere near {object_reference}. Based on the input picture and the object list, answer:", 192 | f"To which position relative to the {object_reference} should the picked object be dropped? You need to mention the names of the objects in your answer.", 193 | f"There are six kinds of relative position in total, and the directions refer to the visual directions in the picture.", 194 | f"1. In (the picked object is contained in the {object_reference})", 195 | f"2. On top of (the picked object is stacked on the {object_reference}; the {object_reference} supports the picked object)", 196 | f"3. At the back of (in the demo this means the picked object is positioned farther from the viewer relative to the {object_reference})", 197 | f"4. 
In front of (in demo it means object picked is positioned closer to the viewer or relative to the {object_reference})", 198 | "5. to the left", 199 | "6. to the right", 200 | f"You must choose one relative position." 201 | "You should respond in the format of the following example without any additional information or reason steps, be sure to mention the object picked and reference object.", 202 | f"Drop yellow corn to the left of the red chili", 203 | f"Drop red chili in the white bowl", 204 | f"Drop wooden block (ID:1) to the right of the wooden block (ID:0)", 205 | *map(lambda x: {"image": x, "resize": 768}, input_frame_drop), 206 | ], 207 | }, 208 | ] 209 | response_relationship = call_openai_api(prompt_messages_relationship) 210 | print(response_relationship) 211 | string_cache += response_relationship + " and then " 212 | 213 | i += 1 214 | 215 | return string_cache 216 | def save_results_to_csv(demo_name, num, obj_list, string_cache, output_file): 217 | file_exists = os.path.exists(output_file) 218 | with open(output_file, mode='a', newline='') as file: 219 | writer = csv.writer(file) 220 | if not file_exists: 221 | writer.writerow(["demo", "object", "action list"]) 222 | 223 | writer.writerow([f"{demo_name}", f"{num} objects: {', '.join(obj_list)}", string_cache]) 224 | print(f"Results appended to {output_file}") 225 | def convert_video_to_mp4(input_path): 226 | """ 227 | Converts the input video file to H.264 encoded .mp4 format using ffmpy. 228 | The output path will be the same as the input path with '_converted' appended before the extension. 229 | """ 230 | # Get the file name without extension and append '_converted' 231 | base_name, ext = os.path.splitext(input_path) 232 | output_path = f"{base_name}_converted.mp4" 233 | # Run FFmpeg command to convert the video 234 | ff = ffmpy.FFmpeg( 235 | inputs={input_path: None}, 236 | outputs={output_path: '-c:v libx264 -crf 23 -preset fast -r 30'} 237 | ) 238 | ff.run() 239 | print(f"Video converted successfully: {output_path}") 240 | return output_path 241 | 242 | def main(input_video_path, frame_index_list, bbx_list): 243 | video_path = input_video_path 244 | # list to store key frames 245 | selected_raw_frames1 = [] 246 | # list to store key frame indexes 247 | frame_index_list = ast.literal_eval(frame_index_list) 248 | selected_frame_index = frame_index_list 249 | # Convert the video to H.264 encoded .mp4 format 250 | # converted_video_path = convert_video_to_mp4(video_path) 251 | # Open the converted video 252 | cap = cv2.VideoCapture(video_path) 253 | # Manually calculate total number of frames 254 | actual_frame_count = 0 255 | while cap.isOpened(): 256 | ret, frame = cap.read() 257 | if not ret: 258 | break 259 | actual_frame_count += 1 260 | # Reset the capture to the beginning of the video 261 | cap.set(cv2.CAP_PROP_POS_FRAMES, 0) 262 | print(f"Actual frame count: {actual_frame_count}") 263 | # Iterate through index list and get the frame list 264 | for index in selected_frame_index: 265 | if index < actual_frame_count: 266 | cap.set(cv2.CAP_PROP_POS_FRAMES, index) 267 | ret, cv2_image = cap.read() 268 | if ret: 269 | selected_raw_frames1.append(cv2_image) 270 | else: 271 | print(f"Failed to retrieve frame at index {index}") 272 | else: 273 | print(f"Frame index {index} is out of range for this video.") 274 | # Release video capture object 275 | cap.release() 276 | selected_frames1 = extract_frame_list(selected_raw_frames1) 277 | response_state = get_object_list(selected_frames1) 278 | num, obj_list = 
extract_num_object(response_state) 279 | print("Number of objects:", num) 280 | print("Available objects:", obj_list) 281 | # obj_list = "green corn, orange carrot, red pepper, white bowl, glass container" 282 | # process the key frames 283 | string_cache = process_images(selected_frames1, obj_list) 284 | # removesuffix() leaves the string unchanged if the connector is absent, so main() always returns a string 285 | my_string = string_cache.removesuffix(" and then ") 286 | print(my_string) 287 | return my_string 288 | 289 | if __name__ == "__main__": 290 | parser = argparse.ArgumentParser(description="Process video and key frame extraction.") 291 | parser.add_argument('--input', type=str, required=True, help='Input video path') 292 | parser.add_argument('--list', type=str, required=True, help='List of key frame indexes') 293 | parser.add_argument('--bbx_list', type=str, required=True, help='Bounding box string of the key frames') 294 | args = parser.parse_args() 295 | # Call the main function with arguments (bbx_list is accepted but currently unused inside main) 296 | main(args.input, args.list, args.bbx_list) --------------------------------------------------------------------------------
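For reference, the two scripts above can also be driven programmatically instead of through their CLIs: `track_objects.main(input, output, key_frames)` returns the `bbx_string` of per-key-frame object centers, and `vlm.main(input, frame_index_list, bbx_list)` returns the chained action list. Below is a minimal sketch of such a driver, assuming it sits in the SeeDo root so both modules are importable; the file names and key-frame indices are placeholders, and the key-frame list is passed as a string literal because both `main` functions parse it with `ast.literal_eval`.

```python
# pipeline_driver.py -- illustrative sketch only, not part of the repository.
# Assumes track_objects.py and vlm.py are importable from the SeeDo root.
import track_objects
import vlm

if __name__ == "__main__":
    # Hypothetical inputs: a re-encoded demo video and key-frame indices from the
    # keyframe selection step, passed as a string because main() uses ast.literal_eval.
    input_video = "demo_converted.mp4"
    prompted_video = "demo_prompted.mp4"
    key_frames = "[0, 45, 90, 135]"

    # Visual perception: track the objects, write the video with visual prompts,
    # and collect the per-key-frame object centers as a single string.
    bbx_string = track_objects.main(input_video, prompted_video, key_frames)

    # VLM reasoning: consume the prompted video, the same key-frame list, and the
    # bounding-box string; returns the actions joined by " and then ".
    action_list = vlm.main(prompted_video, key_frames, bbx_string)
    print(action_list)
```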