├── .gitignore ├── README.md ├── asset ├── banner.gif ├── banner.mp4 ├── hardware_setup.jpg └── method-overview.gif ├── constraint_solver ├── action_parser.py ├── draw_arrow.py ├── main_infer_trans.py └── spatial_solver │ ├── __init__.py │ ├── constraint.py │ ├── elements.py │ ├── parser.py │ ├── solver.py │ └── transform.py ├── demo.py ├── graspnet ├── .gitignore ├── command_demo.sh ├── command_test.sh ├── command_train.sh ├── dataset │ ├── bullet_dataset.py │ ├── command_generate_tolerance_label.sh │ ├── generate_tolerance_label.py │ └── graspnet_dataset.py ├── demo.py ├── demo_filter_grasp.py ├── doc │ ├── example_data │ │ ├── color.png │ │ ├── demo_result.png │ │ ├── depth.png │ │ ├── meta.mat │ │ └── workspace_mask.png │ └── teaser.png ├── knn │ ├── knn_modules.py │ ├── setup.py │ └── src │ │ ├── cpu │ │ ├── knn_cpu.cpp │ │ └── vision.h │ │ ├── cuda │ │ ├── knn.cu │ │ └── vision.h │ │ ├── knn.h │ │ └── vision.cpp ├── main_to_ros.py ├── models │ ├── backbone.py │ ├── graspnet.py │ ├── loss.py │ └── modules.py ├── pointnet2 │ ├── _ext_src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── cylinder_query.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── cylinder_query.cpp │ │ │ ├── cylinder_query_gpu.cu │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── pointnet2_modules.py │ ├── pointnet2_utils.py │ ├── pytorch_utils.py │ └── setup.py ├── requirements.txt ├── test.py ├── train.py ├── utils │ ├── collision_detector.py │ ├── data_utils.py │ ├── grasp_projection.py │ ├── label_generation.py │ ├── loss_utils.py │ ├── motion_planner.py │ └── ros_adapter.py └── visualize.py ├── real_world ├── CMakeLists.txt ├── README.md ├── behavior │ └── behavior.py ├── functional_grasp │ └── functional_grasp.py ├── generate_point_cloud │ ├── CMakeLists.txt │ ├── launch │ │ └── generate_point_cloud.launch │ └── scripts │ │ ├── generate_point_cloud.py │ │ └── visualize.py ├── get_point_cloud │ ├── CMakeLists.txt │ ├── launch │ │ └── get_point_cloud.launch │ └── scripts │ │ └── get_point_cloud.py ├── package.xml ├── point_cloud_snapshot │ └── point_cloud_snapshot.py ├── robot_homing │ ├── CMakeLists.txt │ ├── launch │ │ └── robot_homing.launch │ └── scripts │ │ └── robot_homing.py ├── robot_states_monitor │ ├── CMakeLists.txt │ ├── launch │ │ └── robot_states_monitor.launch │ └── scripts │ │ └── robot_states_monitor.py └── trajectory_recorder │ ├── CMakeLists.txt │ ├── launch │ └── trajectory_recorder.launch │ ├── scripts │ └── trajectory_recorder.py │ └── trajectory_recorder.py └── som_gpt4v ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── assets ├── method2_xyz.png ├── som_bench_bottom.jpg ├── som_bench_upper.jpg ├── som_gpt4v_demo.mp4 ├── som_logo.png ├── som_toolbox_interface.jpg └── teaser.png ├── benchmark └── README.md ├── camera_extrinsic2.npy ├── configs ├── seem_focall_unicl_lang_v1.yaml └── semantic_sam_only_sa-1b_swinL.yaml ├── dataset_tasks.py ├── demo_functional_grasp.py ├── demo_gpt4v_som.py ├── demo_som.py ├── download_ckpt.sh ├── draw_arrow.py ├── examples ├── gpt-4v-som-example.jpg ├── ironing_man.jpg ├── ironing_man_som.png └── som_logo.png ├── first_prompt ├── prompt1.png ├── prompt1.txt ├── prompt2.png ├── prompt2.txt ├── prompt3.png ├── prompt3.txt └── prompt4.txt ├── 
gpt4v.py ├── gpt4v_azure.py ├── main_behavior.py ├── main_constraint.py ├── main_grasp.py ├── mask_filters.py ├── ops ├── functions │ ├── __init__.py │ └── ms_deform_attn_func.py ├── make.sh ├── modules │ ├── __init__.py │ └── ms_deform_attn.py ├── setup.py ├── src │ ├── cpu │ │ ├── ms_deform_attn_cpu.cpp │ │ └── ms_deform_attn_cpu.h │ ├── cuda │ │ ├── ms_deform_attn_cuda.cu │ │ ├── ms_deform_attn_cuda.h │ │ └── ms_deform_im2col_cuda.cuh │ ├── ms_deform_attn.h │ └── vision.cpp └── test.py ├── second_prompt ├── prompt1.png ├── prompt1.txt ├── prompt2.png ├── prompt2.txt ├── prompt3.png ├── prompt3.txt └── prompt4.txt └── task_adapter ├── sam ├── __init__.py └── tasks │ ├── __Init__.py │ ├── inference_sam_m2m_auto.py │ └── inference_sam_m2m_interactive.py ├── seem ├── __init__.py └── tasks │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── inference_seem_interactive.py │ ├── inference_seem_pano.py │ └── interactive_seem_m2m_auto.py ├── semantic_sam └── tasks │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── inference_semsam_m2m_auto.py │ ├── interactive_idino_1o1_box.py │ ├── interactive_idino_m2m.py │ └── interactive_predictor.py └── utils ├── lfq_visualizer.py ├── visualizer.py ├── visualizer_method_2.py ├── visualizer_method_2_1.py ├── visualizer_method_2_2.py └── visualizer_method_first_seg.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoPa: General Robotic Manipulation through Spatial Constraints of Parts with Foundation Models
2 | 
3 | [[Project page]](https://copa-2024.github.io/)
4 | [[Paper]](https://arxiv.org/abs/2403.08248)
5 | 
6 | [Haoxu Huang*](https://bpsoda.github.io/), [Fanqi Lin*](https://fanqi-lin.github.io/), Yingdong Hu, Shengjie Wang, Yang Gao
7 | 
8 | This repository is the official implementation of the paper "CoPa: General Robotic Manipulation through Spatial Constraints of Parts with Foundation Models".
9 | 
10 | ![](asset/banner.gif)
11 | ![](asset/method-overview.gif)
12 | 
13 | ## Get Started
14 | Install SoM following the [instructions](https://github.com/microsoft/SoM#rocket-quick-start).
15 | Install [graspnetAPI](https://github.com/graspnet/graspnetAPI).
16 | Download the [examples](https://drive.google.com/drive/folders/1YjDU89-EOh2v5XIJjVr7vA4yATCuYrjM?usp=sharing) and place them under the `data` directory.
17 | Run the demo with a task name (see `demo.py`):
18 | ```console
19 | $ python demo.py <task>  # e.g. "flower" or "button"
20 | ```
21 | 
22 | ## Real-World Deployment
23 | Please follow the instructions in `real_world/README.md`.
24 | 
25 | ## Acknowledgement
26 | - Our grounding module is adapted from [SoM](https://github.com/microsoft/SoM).
27 | - We use [GraspNet](https://graspnet.net/) for grasp candidate generation.
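
## Demo Data Layout
The quick demo reads and writes per-task files under `data/<task>/`, using the paths referenced in `demo.py`: `pointcloud1.npy`, `mask.npy`, and `grasp_candidates.npy` feed the grasp-filtering stage, while `grasp.npy`, `spatial_data.pkl`, `constraint_response.txt`, and `transform.npy` are produced as the pipeline runs. The downloaded examples are expected to already follow this layout.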
28 | -------------------------------------------------------------------------------- /asset/banner.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/asset/banner.gif -------------------------------------------------------------------------------- /asset/banner.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/asset/banner.mp4 -------------------------------------------------------------------------------- /asset/hardware_setup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/asset/hardware_setup.jpg -------------------------------------------------------------------------------- /asset/method-overview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/asset/method-overview.gif -------------------------------------------------------------------------------- /constraint_solver/action_parser.py: -------------------------------------------------------------------------------- 1 | class ActionParser(): 2 | def __init__(self): 3 | self.actions = [] 4 | 5 | def parse(self, text): 6 | lines = text.split("\n") 7 | description = self.locate_actions(lines) 8 | 9 | for line in description: 10 | self.parse_action(line) 11 | 12 | return self.actions 13 | 14 | def locate_actions(self, lines): 15 | for i, line in enumerate(lines): 16 | line = line.strip() 17 | if line == "": 18 | start = i 19 | elif line == "": 20 | end = i 21 | return lines[start+1:end] 22 | 23 | def parse_action(self, line): 24 | line = line.strip().strip('.').lower() 25 | words = line.replace('\'', ' ').split() 26 | if "move" in line: 27 | if "left" in line: 28 | action = "move_left" 29 | elif "right" in line: 30 | action = "move_right" 31 | elif "forward" in line: 32 | action = "move_forward" 33 | elif "backward" in line: 34 | action = "move_backward" 35 | elif "up" in line: 36 | action = "move_up" 37 | elif "down" in line: 38 | action = "move_down" 39 | for i, word in enumerate(words): 40 | if word == 'cm': 41 | distance = float(words[i-1]) / 100 42 | self.actions.append((action, distance)) 43 | elif "rotate" in line: 44 | action = "rotate" 45 | for i, word in enumerate(words): 46 | if word == 'degree': 47 | angle = float(words[i-1]) 48 | self.actions.append((action, angle)) 49 | elif "open" in line: 50 | self.actions.append(("open", 0)) 51 | elif "close" in line: 52 | self.actions.append(("close", 0)) 53 | -------------------------------------------------------------------------------- /constraint_solver/draw_arrow.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | 4 | def draw_geometries(pcds): 5 | """ 6 | Draw Geometries 7 | Args: 8 | - pcds (): [pcd1,pcd2,...] 
9 | """ 10 | o3d.visualization.draw_geometries(pcds) 11 | 12 | def get_o3d_FOR(origin=[0, 0, 0],size=10): 13 | """ 14 | Create a FOR that can be added to the open3d point cloud 15 | """ 16 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( 17 | size=size) 18 | mesh_frame.translate(origin) 19 | return(mesh_frame) 20 | 21 | def vector_magnitude(vec): 22 | """ 23 | Calculates a vector's magnitude. 24 | Args: 25 | - vec (): 26 | """ 27 | magnitude = np.sqrt(np.sum(vec**2)) 28 | return(magnitude) 29 | 30 | 31 | def calculate_zy_rotation_for_arrow(vec): 32 | """ 33 | Calculates the rotations required to go from the vector vec to the 34 | z axis vector of the original FOR. The first rotation that is 35 | calculated is over the z axis. This will leave the vector vec on the 36 | XZ plane. Then, the rotation over the y axis. 37 | 38 | Returns the angles of rotation over axis z and y required to 39 | get the vector vec into the same orientation as axis z 40 | of the original FOR 41 | 42 | Args: 43 | - vec (): 44 | """ 45 | gamma = np.arctan2(vec[1], vec[0]) 46 | # Check if z component is negative 47 | if vec[2] < 0: 48 | gamma = np.pi + gamma # Add 180 degrees 49 | Rz = np.array([[np.cos(gamma), -np.sin(gamma), 0], 50 | [np.sin(gamma), np.cos(gamma), 0], 51 | [0, 0, 1]]) 52 | # Rotate vec to calculate next rotation 53 | vec = Rz.T @ vec.reshape(-1, 1) 54 | vec = vec.reshape(-1) 55 | # Rotation over y axis of the FOR 56 | beta = np.arctan2(vec[0], vec[2]) 57 | Ry = np.array([[np.cos(beta), 0, np.sin(beta)], 58 | [0, 1, 0], 59 | [-np.sin(beta), 0, np.cos(beta)]]) 60 | return Rz, Ry 61 | 62 | def create_arrow(scale=10): 63 | """ 64 | Create an arrow in for Open3D 65 | """ 66 | cone_height = scale*0.2 67 | cylinder_height = scale*0.8 68 | cone_radius = scale/10 69 | cylinder_radius = scale/20 70 | mesh_frame = o3d.geometry.TriangleMesh.create_arrow(cone_radius=cone_radius, 71 | cone_height=cone_height, 72 | cylinder_radius=cylinder_radius, 73 | cylinder_height=cylinder_height) 74 | return(mesh_frame) 75 | 76 | def get_arrow(origin=[0, 0, 0], end=None, vec=None): 77 | """ 78 | Creates an arrow from an origin point to an end point, 79 | or create an arrow from a vector vec starting from origin. 80 | Args: 81 | - end (): End point. [x,y,z] 82 | - vec (): Vector. 
[i,j,k] 83 | """ 84 | scale = 0.1 85 | Ry = Rz = np.eye(3) 86 | T = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 87 | T[:3, -1] = origin 88 | if end is not None: 89 | vec = np.array(end) - np.array(origin) 90 | elif vec is not None: 91 | vec = np.array(vec) 92 | if end is not None or vec is not None: 93 | scale = vector_magnitude(vec) * scale 94 | Rz, Ry = calculate_zy_rotation_for_arrow(vec) 95 | mesh = create_arrow(scale) 96 | # Create the arrow 97 | mesh.rotate(Ry, center=np.array([0, 0, 0])) 98 | mesh.rotate(Rz, center=np.array([0, 0, 0])) 99 | mesh.translate(origin) 100 | return(mesh) 101 | 102 | 103 | if __name__ == '__main__': 104 | 105 | # Create a Cartesian Frame of Reference 106 | FOR = get_o3d_FOR() 107 | # Create an arrow from point (5,5,5) to point (10,10,10) 108 | # arrow = get_arrow([5,5,5],[10,10,10]) 109 | 110 | # Create an arrow representing vector vec, starting at (5,5,5) 111 | # arrow = get_arrow([5,5,5],vec=[5,5,5]) 112 | 113 | # Create an arrow in the same place as the z axis 114 | arrow = get_arrow(origin=[0,1,1], vec=[0, np.sqrt(2)/2, np.sqrt(2)/2]) 115 | arrow2 = get_arrow(origin=[0,1,1], vec=[0, -np.sqrt(2)/2, -np.sqrt(2)/2]) 116 | arrow2.paint_uniform_color([1, 0, 0]) 117 | # Draw everything 118 | draw_geometries([FOR,arrow,arrow2]) -------------------------------------------------------------------------------- /constraint_solver/main_infer_trans.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from pathlib import Path 4 | import numpy as np 5 | import open3d as o3d 6 | 7 | from spatial_solver import ConstraintsParser, Plane, transform, create_se3, create_x 8 | from draw_arrow import get_arrow 9 | from action_parser import ActionParser 10 | 11 | def create_elements(spatial_data, add_table=False): 12 | elements = {} 13 | for i, (ratio, center, norm) in enumerate(spatial_data): 14 | print('Element {}: ratio {:2f}, d1 {}, d2 {}'.format(i+1, ratio, center, norm)) 15 | if ratio < 3: 16 | elements[str(i+1)] = Plane.from_numpy(np.stack([center, norm])) 17 | else: # line connect start point (center) and end point (norm) 18 | elements[str(i+1)] = Plane.from_numpy(np.stack([norm, 19 | (norm - center) / np.linalg.norm(norm - center)])) 20 | if add_table: 21 | elements['table'] = Plane.from_numpy(np.stack([np.array([0.5, 0, 0.07]), np.array([0, 0, 1])])) 22 | print(elements) 23 | return elements 24 | 25 | def visualize_transform(elements, se3, transform_obj_id=None): 26 | arrows = [] 27 | for k, elem in elements.items(): 28 | arrow = get_arrow(origin=elem.p.to_numpy(), vec=elem.n.to_numpy()) 29 | arrow.paint_uniform_color([1, 0, 0]) 30 | # print(elem) 31 | if transform_obj_id is None or k == transform_obj_id: 32 | arrow.paint_uniform_color([0, 0, 1]) 33 | trans_elem = transform(elem, se3) 34 | # print(trans_elem) 35 | trans_arrow = get_arrow(origin=trans_elem.p.to_numpy(), vec=trans_elem.n.to_numpy()) 36 | trans_arrow.paint_uniform_color([0, 1, 0]) 37 | arrows.append(trans_arrow) 38 | arrows.append(arrow) 39 | frame_axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0,0,0]) 40 | o3d.visualization.draw_geometries([*arrows, frame_axis]) 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--spatial_data', type=Path, default=Path('spatial_data.pkl')) 47 | parser.add_argument('--constraints', type=Path, default=Path('response.txt')) 48 | parser.add_argument('--output', type=Path, 
default=Path('transform.npy')) 49 | args = parser.parse_args() 50 | 51 | # read data and constraints 52 | with open(args.spatial_data, 'rb') as f: 53 | spatial_data = pickle.load(f) 54 | with open(args.constraints, 'r') as f: 55 | constraints = f.read() 56 | 57 | elements = create_elements(spatial_data, add_table=True) 58 | constraints_parser = ConstraintsParser(elements) 59 | solver = constraints_parser.parse(constraints) 60 | # print(constraints_parser) 61 | solver.dist_factor = 0.02 # minimize distance between start and end pose 62 | result = solver.solve() 63 | se3 = create_se3(result.x) 64 | np.save(args.output, se3) 65 | 66 | subsequent_actions = ActionParser().parse(constraints) 67 | pickle.dump(subsequent_actions, open(args.output.parent / 'action.pkl', 'wb')) 68 | 69 | visualize_transform(elements, se3, '1') -------------------------------------------------------------------------------- /constraint_solver/spatial_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .constraint import Constraint 2 | from .elements import * 3 | from .solver import SpatialSolver, create_se3, create_x 4 | from .parser import ConstraintsParser 5 | from .transform import transform -------------------------------------------------------------------------------- /constraint_solver/spatial_solver/elements.py: -------------------------------------------------------------------------------- 1 | """Definition of 3D geometry elements: point, line, plane, etc.""" 2 | import dataclasses 3 | import numpy as np 4 | 5 | 6 | @dataclasses.dataclass 7 | class Point: 8 | """A point in 3D space.""" 9 | 10 | x: float 11 | y: float 12 | z: float 13 | 14 | @classmethod 15 | def from_numpy(cls, arr): 16 | """Create a point from a numpy array.""" 17 | assert arr.shape == (3,) or arr.shape == (4,), "Array must have shape (3,) or (4,)" 18 | return cls(*arr[:3]) 19 | 20 | def to_numpy(self): 21 | """Convert the point to a numpy array.""" 22 | return np.array([self.x, self.y, self.z]) 23 | 24 | def to_homogeneous(self): 25 | """Convert the point to homogeneous coordinates. Shape: (4,)""" 26 | return np.array([self.x, self.y, self.z, 1]) 27 | 28 | 29 | @dataclasses.dataclass 30 | class Line: 31 | """A line in 3D space.""" 32 | 33 | p1: Point 34 | p2: Point 35 | 36 | @classmethod 37 | def from_numpy(cls, arr): 38 | """Create a line from a numpy array.""" 39 | assert arr.shape == (2, 3) or arr.shape == (2, 4), "Array must have shape (2, 3) or (2, 4)" 40 | return cls(Point.from_numpy(arr[0]), Point.from_numpy(arr[1])) 41 | 42 | def to_numpy(self): 43 | """Convert the line to a numpy array.""" 44 | return np.array([self.p1.to_numpy(), self.p2.to_numpy()]) 45 | 46 | def to_homogeneous(self): 47 | """Convert the line to homogeneous coordinates. Shape: (3, 2)""" 48 | return np.array([self.p1.to_homogeneous(), self.p2.to_homogeneous()]).T 49 | 50 | def to_vector(self): 51 | """Convert the line to a vector.""" 52 | return Vector(self.p2.x - self.p1.x, self.p2.y - self.p1.y, self.p2.z - self.p1.z) 53 | 54 | 55 | class Vector(Point): 56 | """A vector in 3D space.""" 57 | 58 | pass 59 | 60 | 61 | @dataclasses.dataclass 62 | class Plane: 63 | """A plane in 3D space. 64 | Described by a central point and a normal vector. 
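    For example, Plane.from_numpy(np.stack([np.zeros(3), np.array([0., 0., 1.])]))
    describes the z = 0 plane with an upward-pointing normal (the same pattern
    main_infer_trans.py uses to build the table plane).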
65 | """ 66 | p: Point 67 | n: Vector 68 | 69 | @classmethod 70 | def from_numpy(cls, arr): 71 | """Create a plane from a numpy array.""" 72 | assert arr.shape == (2, 3) or arr.shape == (2, 4), "Array must have shape (2, 3) or (2, 4)" 73 | return cls(Point.from_numpy(arr[0,:]), Vector.from_numpy(arr[1,:])) 74 | 75 | def to_numpy(self): 76 | """Convert the plane to a numpy array. Shape: (2, 3)""" 77 | return np.array([self.p.to_numpy(), self.n.to_numpy()]) 78 | 79 | def to_homogeneous(self): 80 | """Convert the plane to homogeneous coordinates. Shape: (4, 2)""" 81 | return np.array([self.p.to_homogeneous(), self.n.to_homogeneous()]).T 82 | -------------------------------------------------------------------------------- /constraint_solver/spatial_solver/parser.py: -------------------------------------------------------------------------------- 1 | ### Constraints parser ### 2 | # Constraints include: 3 | # 1. Vector A and Vector B are on the same line, with the same/opposite direction. 4 | # 2. The target position of Point A is x cm along Vector B from Point C's current position. 5 | # 3. Vector A is parallel to the table surface. 6 | # 4. Vector A is perpendicular to the table surface, pointing downward/upward. 7 | # 5. Point A is x cm above the table surface. 8 | 9 | from .solver import SpatialSolver 10 | from .constraint import Constraint 11 | from .elements import * 12 | 13 | 14 | class ConstraintsParser(): 15 | """Parser for constraints.""" 16 | 17 | def __init__(self, elements): 18 | """Initialize the parser. 19 | elements: a dict of elements (id: element) 20 | """ 21 | self.elements = elements 22 | self.solver = SpatialSolver() 23 | 24 | def parse(self, text): 25 | """Parse the constraints. 26 | text: the paragraph of constraints 27 | """ 28 | lines = text.split("\n") 29 | description = self.locate_constraints(lines) 30 | 31 | for line in description: 32 | self.parse_constraint(line) 33 | 34 | return self.solver 35 | 36 | def locate_constraints(self, lines): 37 | """Locate the description of the constraints. 38 | The description is wrapped in .... 39 | lines: a list of lines 40 | """ 41 | for i, line in enumerate(lines): 42 | line = line.strip() 43 | if line == "": 44 | start = i 45 | elif line == "": 46 | end = i 47 | return lines[start+1:end] 48 | 49 | def parse_constraint(self, line) -> SpatialSolver: 50 | """Parse a line of description. 
51 | line: a line of description 52 | """ 53 | if 'Move' in line: 54 | return 55 | line = line.strip().strip('.') 56 | words = line.replace('\'', ' ').split() 57 | elems, distance = self.parse_elements(words) 58 | if "on the same line" in line: # constraint 1 59 | constraints = [Constraint(elems[0], elems[1], "colinear")] 60 | if "same direction" in line: 61 | constraints.append(Constraint(elems[0].n, elems[1].n, "equal")) 62 | else: # opposite direction 63 | constraints.append(Constraint(elems[0].n, elems[1].n, "inverse")) 64 | 65 | elif "cm along" in line: # constraint 2 66 | new_plane = Plane(elems[2].p, elems[1].n) 67 | constraints = [Constraint(elems[0].p, new_plane, "distance", distance=distance), 68 | Constraint(elems[0].p, new_plane, "online")] 69 | 70 | elif "parallel to the table surface" in line: # constraint 3 71 | if "Vector" in line: 72 | constraints = [Constraint(elems[0].n, elems[1].n, "perpendicular")] 73 | else: # Plane 74 | constraints = [Constraint(elems[0].n, elems[1].n, "parallel")] 75 | 76 | elif "perpendicular to the table surface" in line: # constraint 4 77 | if "downward" in line: 78 | constraints = [Constraint(elems[0].n, elems[1].n, "inverse")] 79 | else: # upward 80 | constraints = [Constraint(elems[0].n, elems[1].n, "equal")] 81 | 82 | elif "above the table surface" in line: # constraint 5 83 | constraints = [Constraint(elems[0].p, elems[1], "distance", distance=distance)] 84 | 85 | else: 86 | raise NotImplementedError(f"Constraint type {line} not implemented") 87 | 88 | self.solver.add_constraints(constraints) 89 | 90 | 91 | def parse_elements(self, words): 92 | """Parse the elements in the description. 93 | words: a list of words 94 | """ 95 | elems = [] 96 | distance = None 97 | for i, word in enumerate(words): 98 | if word in ['Point', 'Vector', 'Surface']: 99 | elem_id = words[i+1] 100 | elems.append(self.elements[elem_id]) 101 | if word == 'table': 102 | elems.append(self.elements['table']) 103 | if word == 'cm': 104 | distance = float(words[i-1]) / 100 105 | assert len(elems) > 0, "No element found in the description" 106 | return elems, distance 107 | 108 | def __str__(self) -> str: 109 | return f"ConstraintsParser({self.elements})\n {self.solver}" -------------------------------------------------------------------------------- /constraint_solver/spatial_solver/solver.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import minimize, Bounds, NonlinearConstraint 2 | from scipy.spatial.transform import Rotation as R 3 | import numpy as np 4 | 5 | from .elements import * 6 | from .constraint import * 7 | 8 | 9 | def quaternion_norm_constraint(x): 10 | # 假设四元数是x中的前四个参数 11 | quaternion = x[:4] 12 | return np.linalg.norm(quaternion) - 1 13 | 14 | def create_se3(x): 15 | se3 = np.zeros([4,4]) 16 | se3[:3,:3] = R.from_quat(x[:4]).as_matrix() 17 | se3[:3,3] = x[4:] 18 | se3[3,3] = 1 19 | return se3 20 | 21 | def create_x(se3): 22 | x = np.zeros(7) 23 | x[:4] = R.from_matrix(se3[:3,:3]).as_quat() 24 | x[4:] = se3[:3,3] 25 | return x 26 | 27 | class SpatialSolver: 28 | """Spatial solver for 3D geometry constraints. 29 | Attributes: 30 | constraints: a list of constraints 31 | """ 32 | 33 | def __init__(self, constraints=None, dist_factor=0): 34 | """Initialize the spatial solver. 
35 | constraints: a list of constraints 36 | """ 37 | self.constraints = constraints or [] 38 | self.dist_factor = dist_factor 39 | 40 | def solve(self, max_iter=10000, tol=1e-6, initial_guess=[0,0,0,1,0,0,0], bounds=None): 41 | """Solve the constraints. 42 | max_iter: maximum number of iterations 43 | tol: tolerance for convergence 44 | """ 45 | result = minimize(self.objective, initial_guess, method='trust-constr', # BFGS, trust-constr 46 | bounds=bounds, options={'maxiter': max_iter}, tol=tol) 47 | # report convergence 48 | if result.success: 49 | print("Converged!") 50 | else: 51 | print("Failed to converge! Try increasing max_iter or tol.") 52 | print(f"Current solution: {result.x}") 53 | normalized_result = result.x 54 | normalized_result[:4] /= np.linalg.norm(normalized_result[:4]) 55 | print(f"Normalized current solution: {normalized_result}") 56 | result_se3 = create_se3(result.x) 57 | print(f"Current solution SE3: \n{result_se3}") 58 | for constraint in self.constraints: 59 | constraint.test_objective(result_se3) 60 | solution_dist = np.linalg.norm(result.x - np.array([0,0,0,1,0,0,0])) 61 | print(f"Distance between start and end pose: {solution_dist}") 62 | return result 63 | 64 | def objective(self, x): 65 | """Compute the objective function of the constraints. 66 | x: the current solution 67 | """ 68 | se3 = create_se3(x) 69 | obj = 0 70 | for constraint in self.constraints: 71 | obj += constraint.objective(se3) 72 | 73 | # distance between start and end pose 74 | if self.dist_factor > 0: 75 | obj += np.linalg.norm(x - np.array([0,0,0,1,0,0,0])) * self.dist_factor 76 | return obj 77 | 78 | def add_constraints(self, constraints): 79 | """Add constraints to the solver. 80 | constraints: a list of constraints 81 | """ 82 | self.constraints.extend(constraints) 83 | 84 | def __str__(self) -> str: 85 | """Return a string representation of the solver.""" 86 | return f"SpatialSolver({self.constraints})" 87 | -------------------------------------------------------------------------------- /constraint_solver/spatial_solver/transform.py: -------------------------------------------------------------------------------- 1 | from .elements import * 2 | 3 | 4 | def transform(elem, se3): 5 | """Transform a geometry element by a SE3 transform. 6 | elem: the geometry element 7 | se3: the SE3 transform 8 | """ 9 | if isinstance(elem, Vector): 10 | return Vector.from_numpy(se3 @ elem.to_homogeneous() - se3[:, 3]) 11 | elif isinstance(elem, Line): 12 | return Line.from_numpy((se3 @ elem.to_homogeneous()).T) 13 | elif isinstance(elem, Point): 14 | return Point.from_numpy(se3 @ elem.to_homogeneous()) 15 | elif isinstance(elem, Plane): 16 | return Plane(Point.from_numpy(se3 @ elem.p.to_homogeneous()), Vector.from_numpy(se3 @ elem.n.to_homogeneous() - se3[:, 3])) 17 | else: 18 | raise TypeError(f"Cannot transform {elem} of type {type(elem)}") -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser(description='Quick demo of CoPa.') 5 | parser.add_argument('task', type=str, help='Task to run (e.g. 
"flower", "button")') 6 | args = parser.parse_args() 7 | task = args.task 8 | 9 | # Task-oriented Grasping 10 | ## Ground grasp part 11 | subprocess.run(['python', 'som_gpt4v/main_grasp.py', task]) 12 | 13 | ## Generate and filter grasp candidates 14 | subprocess.run(['python', 'graspnet/demo_filter_grasp.py', 15 | '--candidates_path', f'data/{task}/grasp_candidates.npy', 16 | '--pointcloud_path', f'data/{task}/pointcloud1.npy', 17 | '--mask_path', f'data/{task}/mask.npy', 18 | '--output_path', f'data/{task}/grasp.npy',]) 19 | ## Uncomment and replace above if you want to generate grasp candidates 20 | # subprocess.run(['python', 'graspnet/main_to_ros.py', 21 | # '--checkpoint_path', 'graspnet/checkpoint-rs.tar', 22 | # '--pointcloud_path', f'data/{task}/pointcloud1.npy', 23 | # '--mask_path', f'data/{task}/mask.npy', 24 | # '--output_path', f'data/{task}/grasp.npy',]) 25 | 26 | # Task-aware Motion Planning 27 | ## Extract geometric elements 28 | subprocess.run(['python', 'som_gpt4v/main_behavior.py', 29 | '--image_base_dir', 'data', 30 | '--task', task, 31 | '--base_output_dir', 'data',]) 32 | ## Generate spatial constraints 33 | subprocess.run(['python', 'som_gpt4v/main_constraint.py', 34 | '--instruction_base_dir', 'data', 35 | '--results_image_base_dir', 'data', 36 | '--task', task,]) 37 | ## Generate target pose 38 | subprocess.run(['python', 'constraint_solver/main_infer_trans.py', 39 | '--spatial_data', f'data/{task}/spatial_data.pkl', 40 | '--constraints', f'data/{task}/constraint_response.txt', 41 | '--output', f'data/{task}/transform.npy',]) -------------------------------------------------------------------------------- /graspnet/.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/** 2 | *.ipynb 3 | **/.ipynb_checkpoints/** 4 | *.npy 5 | *.npz 6 | **/.vscode/** 7 | **/grasp_label*/** 8 | **/log*/** 9 | **/dump*/** 10 | **/build/** 11 | *.o 12 | *.so 13 | *.egg 14 | **/*.egg-info/** 15 | logs 16 | dataset/tolerance 17 | dataset/scenes -------------------------------------------------------------------------------- /graspnet/command_demo.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python demo.py --checkpoint_path logs/log_kn/checkpoint.tar 2 | -------------------------------------------------------------------------------- /graspnet/command_test.sh: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=0 python test.py --dump_dir logs/dump_rs --checkpoint_path logs/log_rs/checkpoint.tar --camera realsense --dataset_root /data/Benchmark/graspnet 2 | # CUDA_VISIBLE_DEVICES=0 python test.py --dump_dir logs/dump_kn --checkpoint_path logs/log_kn/checkpoint.tar --camera kinect --dataset_root /data/Benchmark/graspnet 3 | CUDA_VISIBLE_DEVICES=0 python test.py --dump_dir logs/dump_rs --checkpoint_path logs/log_rs/checkpoint-rs.tar --camera realsense --dataset_root dataset/ 4 | -------------------------------------------------------------------------------- /graspnet/command_train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py --camera realsense --log_dir logs/log_rs --batch_size 2 --dataset_root /data/Benchmark/graspnet 2 | # CUDA_VISIBLE_DEVICES=0 python train.py --camera kinect --log_dir logs/log_kn --batch_size 2 --dataset_root /data/Benchmark/graspnet 3 | 
-------------------------------------------------------------------------------- /graspnet/dataset/command_generate_tolerance_label.sh: -------------------------------------------------------------------------------- 1 | python generate_tolerance_label.py --dataset_root /data/Benchmark/graspnet --num_workers 50 2 | -------------------------------------------------------------------------------- /graspnet/dataset/generate_tolerance_label.py: -------------------------------------------------------------------------------- 1 | """ Tolerance label generation. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import time 9 | import argparse 10 | import multiprocessing as mp 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | ROOT_DIR = os.path.dirname(BASE_DIR) 14 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 15 | from data_utils import compute_point_dists 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dataset_root', required=True, help='Dataset root') 19 | parser.add_argument('--pos_ratio_thresh', type=float, default=0.8, help='Threshold of positive neighbor ratio[default: 0.8]') 20 | parser.add_argument('--mu_thresh', type=float, default=0.55, help='Threshold of friction coefficient[default: 0.55]') 21 | parser.add_argument('--num_workers', type=int, default=50, help='Worker number[default: 50]') 22 | cfgs = parser.parse_args() 23 | 24 | save_path = 'tolerance' 25 | 26 | V = 300 27 | A = 12 28 | D = 4 29 | radius_list = [0.001 * x for x in range(51)] 30 | 31 | def manager(obj_name, pool_size=8): 32 | # load models 33 | label_path = '{}_labels.npz'.format(obj_name) 34 | label = np.load(os.path.join(cfgs.dataset_root, 'grasp_label', label_path)) 35 | points = label['points'] 36 | scores = label['scores'] 37 | 38 | # create dict 39 | tolerance = mp.Manager().dict() 40 | dists = compute_point_dists(points, points) 41 | params = params = (scores, dists) 42 | 43 | # assign works 44 | pool = [] 45 | process_cnt = 0 46 | work_list = [x for x in range(len(points))] 47 | for _ in range(pool_size): 48 | point_ind = work_list.pop(0) 49 | pool.append(mp.Process(target=worker, args=(obj_name, point_ind, params, tolerance))) 50 | [p.start() for p in pool] 51 | 52 | # refill 53 | while len(work_list) > 0: 54 | for ind, p in enumerate(pool): 55 | if not p.is_alive(): 56 | pool.pop(ind) 57 | point_ind = work_list.pop(0) 58 | p = mp.Process(target=worker, args=(obj_name, point_ind, params, tolerance)) 59 | p.start() 60 | pool.append(p) 61 | process_cnt += 1 62 | print('{}/{}'.format(process_cnt, len(points))) 63 | break 64 | while len(pool) > 0: 65 | for ind, p in enumerate(pool): 66 | if not p.is_alive(): 67 | pool.pop(ind) 68 | process_cnt += 1 69 | print('{}/{}'.format(process_cnt, len(points))) 70 | break 71 | 72 | # save tolerance 73 | if not os.path.exists(save_path): 74 | os.mkdir(save_path) 75 | saved_tolerance = [None for _ in range(len(points))] 76 | for i in range(len(points)): 77 | saved_tolerance[i] = tolerance[i] 78 | saved_tolerance = np.array(saved_tolerance) 79 | np.save('{}/{}_tolerance.npy'.format(save_path, obj_name), saved_tolerance) 80 | 81 | def worker(obj_name, point_ind, params, tolerance): 82 | scores, dists = params 83 | tmp_tolerance = np.zeros([V, A, D], dtype=np.float32) 84 | tic = time.time() 85 | for r in radius_list: 86 | dist_mask = (dists[point_ind] <= r) 87 | scores_in_ball = scores[dist_mask] 88 | pos_ratio = ((scores_in_ball > 0) & (scores_in_ball <= cfgs.mu_thresh)).mean(axis=0) 89 | 
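        # For each candidate radius r: pos_ratio[v, a, d] is the fraction of neighbouring
        # points (within r of this point) whose score is positive and no larger than
        # mu_thresh; entries that keep this ratio >= pos_ratio_thresh get their tolerance
        # raised to r, and the loop stops at the first radius where none qualify.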
tolerance_mask = (pos_ratio >= cfgs.pos_ratio_thresh) 90 | if tolerance_mask.sum() == 0: 91 | break 92 | tmp_tolerance[tolerance_mask] = r 93 | tolerance[point_ind] = tmp_tolerance 94 | toc = time.time() 95 | print("{}: point {} time".format(obj_name, point_ind), toc - tic) 96 | 97 | if __name__ == '__main__': 98 | obj_list = ['%03d' % x for x in range(88)] 99 | for obj_name in obj_list: 100 | p = mp.Process(target=manager, args=(obj_name, cfgs.num_workers)) 101 | p.start() 102 | p.join() -------------------------------------------------------------------------------- /graspnet/demo.py: -------------------------------------------------------------------------------- 1 | """ Demo to show prediction results. 2 | Author: chenxi-wang 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import open3d as o3d 9 | import argparse 10 | import importlib 11 | import scipy.io as scio 12 | from PIL import Image 13 | 14 | import torch 15 | from graspnetAPI import GraspGroup 16 | 17 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 18 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 19 | sys.path.append(os.path.join(ROOT_DIR, 'dataset')) 20 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 21 | 22 | from graspnet import GraspNet, pred_decode 23 | from graspnet_dataset import GraspNetDataset 24 | from collision_detector import ModelFreeCollisionDetector 25 | from data_utils import CameraInfo, create_point_cloud_from_depth_image 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path') 29 | parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]') 30 | parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]') 31 | parser.add_argument('--collision_thresh', type=float, default=0.01, help='Collision Threshold in collision detection [default: 0.01]') 32 | parser.add_argument('--voxel_size', type=float, default=0.01, help='Voxel Size to process point clouds before collision detection [default: 0.01]') 33 | cfgs = parser.parse_args() 34 | 35 | 36 | def get_net(): 37 | # Init the model 38 | net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4, 39 | cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=False) 40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 41 | net.to(device) 42 | # Load checkpoint 43 | checkpoint = torch.load(cfgs.checkpoint_path) 44 | net.load_state_dict(checkpoint['model_state_dict']) 45 | start_epoch = checkpoint['epoch'] 46 | print("-> loaded checkpoint %s (epoch: %d)"%(cfgs.checkpoint_path, start_epoch)) 47 | # set model to eval mode 48 | net.eval() 49 | return net 50 | 51 | def get_and_process_data(data_dir): 52 | # load data 53 | color = np.array(Image.open(os.path.join(data_dir, 'color.png')), dtype=np.float32) / 255.0 54 | depth = np.array(Image.open(os.path.join(data_dir, 'depth.png'))) 55 | workspace_mask = np.array(Image.open(os.path.join(data_dir, 'workspace_mask.png'))) 56 | meta = scio.loadmat(os.path.join(data_dir, 'meta.mat')) 57 | intrinsic = meta['intrinsic_matrix'] 58 | factor_depth = meta['factor_depth'] 59 | 60 | # generate cloud 61 | camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 62 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 63 | 64 | # get valid points 65 | mask = (workspace_mask & (depth > 0)) 66 | cloud_masked 
= cloud[mask] 67 | color_masked = color[mask] 68 | 69 | # sample points 70 | if len(cloud_masked) >= cfgs.num_point: 71 | idxs = np.random.choice(len(cloud_masked), cfgs.num_point, replace=False) 72 | else: 73 | idxs1 = np.arange(len(cloud_masked)) 74 | idxs2 = np.random.choice(len(cloud_masked), cfgs.num_point-len(cloud_masked), replace=True) 75 | idxs = np.concatenate([idxs1, idxs2], axis=0) 76 | cloud_sampled = cloud_masked[idxs] 77 | color_sampled = color_masked[idxs] 78 | 79 | # convert data 80 | cloud = o3d.geometry.PointCloud() 81 | cloud.points = o3d.utility.Vector3dVector(cloud_masked.astype(np.float32)) 82 | cloud.colors = o3d.utility.Vector3dVector(color_masked.astype(np.float32)) 83 | end_points = dict() 84 | cloud_sampled = torch.from_numpy(cloud_sampled[np.newaxis].astype(np.float32)) 85 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 86 | cloud_sampled = cloud_sampled.to(device) 87 | end_points['point_clouds'] = cloud_sampled 88 | end_points['cloud_colors'] = color_sampled 89 | 90 | return end_points, cloud 91 | 92 | def get_grasps(net, end_points): 93 | # Forward pass 94 | with torch.no_grad(): 95 | end_points = net(end_points) 96 | grasp_preds = pred_decode(end_points) 97 | gg_array = grasp_preds[0].detach().cpu().numpy() 98 | gg = GraspGroup(gg_array) 99 | return gg 100 | 101 | def collision_detection(gg, cloud): 102 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size) 103 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 104 | gg = gg[~collision_mask] 105 | return gg 106 | 107 | def vis_grasps(gg, cloud): 108 | gg.nms() 109 | gg.sort_by_score() 110 | gg = gg[:50] 111 | grippers = gg.to_open3d_geometry_list() 112 | o3d.visualization.draw_geometries([cloud, *grippers]) 113 | 114 | def demo(data_dir): 115 | net = get_net() 116 | end_points, cloud = get_and_process_data(data_dir) 117 | gg = get_grasps(net, end_points) 118 | if cfgs.collision_thresh > 0: 119 | gg = collision_detection(gg, np.array(cloud.points)) 120 | vis_grasps(gg, cloud) 121 | 122 | if __name__=='__main__': 123 | data_dir = 'doc/example_data' 124 | demo(data_dir) 125 | -------------------------------------------------------------------------------- /graspnet/doc/example_data/color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/graspnet/doc/example_data/color.png -------------------------------------------------------------------------------- /graspnet/doc/example_data/demo_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/graspnet/doc/example_data/demo_result.png -------------------------------------------------------------------------------- /graspnet/doc/example_data/depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/graspnet/doc/example_data/depth.png -------------------------------------------------------------------------------- /graspnet/doc/example_data/meta.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/graspnet/doc/example_data/meta.mat 
-------------------------------------------------------------------------------- /graspnet/doc/example_data/workspace_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/graspnet/doc/example_data/workspace_mask.png -------------------------------------------------------------------------------- /graspnet/doc/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/graspnet/doc/teaser.png -------------------------------------------------------------------------------- /graspnet/knn/knn_modules.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import gc 3 | import operator as op 4 | import functools 5 | import torch 6 | from torch.autograd import Variable, Function 7 | from knn_pytorch import knn_pytorch 8 | # import knn_pytorch 9 | def knn(ref, query, k=1): 10 | """ Compute k nearest neighbors for each query point. 11 | """ 12 | device = ref.device 13 | ref = ref.float().to(device) 14 | query = query.float().to(device) 15 | inds = torch.empty(query.shape[0], k, query.shape[2]).long().to(device) 16 | knn_pytorch.knn(ref, query, inds) 17 | return inds 18 | -------------------------------------------------------------------------------- /graspnet/knn/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from setuptools import find_packages 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import CUDA_HOME 10 | from torch.utils.cpp_extension import CppExtension 11 | from torch.utils.cpp_extension import CUDAExtension 12 | 13 | requirements = ["torch", "torchvision"] 14 | 15 | 16 | def get_extensions(): 17 | this_dir = os.path.dirname(os.path.abspath(__file__)) 18 | extensions_dir = os.path.join(this_dir, "src") 19 | 20 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 21 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 22 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 23 | 24 | sources = main_file + source_cpu 25 | extension = CppExtension 26 | 27 | extra_compile_args = {"cxx": []} 28 | define_macros = [] 29 | 30 | if torch.cuda.is_available() and CUDA_HOME is not None: 31 | extension = CUDAExtension 32 | sources += source_cuda 33 | define_macros += [("WITH_CUDA", None)] 34 | extra_compile_args["nvcc"] = [ 35 | "-DCUDA_HAS_FP16=1", 36 | "-D__CUDA_NO_HALF_OPERATORS__", 37 | "-D__CUDA_NO_HALF_CONVERSIONS__", 38 | "-D__CUDA_NO_HALF2_OPERATORS__", 39 | ] 40 | 41 | sources = [os.path.join(extensions_dir, s) for s in sources] 42 | 43 | include_dirs = [extensions_dir] 44 | 45 | ext_modules = [ 46 | extension( 47 | "knn_pytorch.knn_pytorch", 48 | sources, 49 | include_dirs=include_dirs, 50 | define_macros=define_macros, 51 | extra_compile_args=extra_compile_args, 52 | ) 53 | ] 54 | 55 | return ext_modules 56 | 57 | 58 | setup( 59 | name="knn_pytorch", 60 | version="0.1", 61 | author="foolyc", 62 | url="https://github.com/foolyc/torchKNN", 63 | description="KNN implement in Pytorch 1.0 including both cpu version and gpu version", 64 | ext_modules=get_extensions(), 65 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 66 | ) 67 | 
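A minimal usage sketch for the `knn_pytorch` extension built by this `setup.py` (assumptions: the extension has been installed, e.g. via `python setup.py install` run inside `graspnet/knn`, and `graspnet/knn` is on `PYTHONPATH` so that `knn_modules` is importable):

```python
import torch
from knn_modules import knn  # thin wrapper around knn_pytorch.knn

ref = torch.rand(2, 3, 1024)   # (batch, dim, num_reference_points)
query = torch.rand(2, 3, 256)  # (batch, dim, num_query_points)

# Returns a (batch, k, num_query_points) long tensor of neighbour indices into ref;
# runs on GPU when the inputs are CUDA tensors, otherwise uses the CPU kernel.
inds = knn(ref, query, k=4)
```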
-------------------------------------------------------------------------------- /graspnet/knn/src/cpu/knn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "cpu/vision.h" 2 | 3 | 4 | void knn_cpu(float* ref_dev, int ref_width, float* query_dev, int query_width, 5 | int height, int k, float* dist_dev, long* ind_dev, long* ind_buf) 6 | { 7 | // Compute all the distances 8 | for(int query_idx = 0;query_idx dist_dev[query_idx * ref_width + j + 1]) 31 | { 32 | temp_value = dist_dev[query_idx * ref_width + j]; 33 | dist_dev[query_idx * ref_width + j] = dist_dev[query_idx * ref_width + j + 1]; 34 | dist_dev[query_idx * ref_width + j + 1] = temp_value; 35 | temp_idx = ind_buf[j]; 36 | ind_buf[j] = ind_buf[j + 1]; 37 | ind_buf[j + 1] = temp_idx; 38 | } 39 | 40 | } 41 | 42 | for(int i = 0;i < k;i++) 43 | ind_dev[query_idx + i * query_width] = ind_buf[i]; 44 | #if DEBUG 45 | for(int i = 0;i < ref_width;i++) 46 | printf("%d, ", ind_buf[i]); 47 | printf("\n"); 48 | #endif 49 | 50 | } 51 | 52 | 53 | 54 | 55 | 56 | } -------------------------------------------------------------------------------- /graspnet/knn/src/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | void knn_cpu(float* ref_dev, int ref_width, 5 | float* query_dev, int query_width, 6 | int height, int k, float* dist_dev, long* ind_dev, long* ind_buf); -------------------------------------------------------------------------------- /graspnet/knn/src/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | void knn_device(float* ref_dev, int ref_width, 6 | float* query_dev, int query_width, 7 | int height, int k, float* dist_dev, long* ind_dev, cudaStream_t stream); -------------------------------------------------------------------------------- /graspnet/knn/src/knn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "cpu/vision.h" 3 | 4 | #ifdef WITH_CUDA 5 | #include "cuda/vision.h" 6 | #include 7 | extern THCState *state; 8 | #endif 9 | 10 | 11 | 12 | int knn(at::Tensor& ref, at::Tensor& query, at::Tensor& idx) 13 | { 14 | 15 | // TODO check dimensions 16 | long batch, ref_nb, query_nb, dim, k; 17 | batch = ref.size(0); 18 | dim = ref.size(1); 19 | k = idx.size(1); 20 | ref_nb = ref.size(2); 21 | query_nb = query.size(2); 22 | 23 | float *ref_dev = ref.data(); 24 | float *query_dev = query.data(); 25 | long *idx_dev = idx.data(); 26 | 27 | 28 | 29 | 30 | if (ref.type().is_cuda()) { 31 | #ifdef WITH_CUDA 32 | // TODO raise error if not compiled with CUDA 33 | float *dist_dev = (float*)THCudaMalloc(state, ref_nb * query_nb * sizeof(float)); 34 | 35 | for (int b = 0; b < batch; b++) 36 | { 37 | // knn_device(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 38 | // dist_dev, idx_dev + b * k * query_nb, THCState_getCurrentStream(state)); 39 | knn_device(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 40 | dist_dev, idx_dev + b * k * query_nb, c10::cuda::getCurrentCUDAStream()); 41 | } 42 | THCudaFree(state, dist_dev); 43 | cudaError_t err = cudaGetLastError(); 44 | if (err != cudaSuccess) 45 | { 46 | printf("error in knn: %s\n", cudaGetErrorString(err)); 47 | THError("aborting"); 48 | } 49 | return 1; 50 | #else 51 | AT_ERROR("Not compiled with GPU support"); 52 | #endif 53 | } 54 | 55 | 
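  // CPU fallback: allocate dense distance / index scratch buffers and let knn_cpu
  // (src/cpu/knn_cpu.cpp) brute-force the k nearest neighbours for each query point
  // in every batch element.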
56 | float *dist_dev = (float*)malloc(ref_nb * query_nb * sizeof(float)); 57 | long *ind_buf = (long*)malloc(ref_nb * sizeof(long)); 58 | for (int b = 0; b < batch; b++) { 59 | knn_cpu(ref_dev + b * dim * ref_nb, ref_nb, query_dev + b * dim * query_nb, query_nb, dim, k, 60 | dist_dev, idx_dev + b * k * query_nb, ind_buf); 61 | } 62 | 63 | free(dist_dev); 64 | free(ind_buf); 65 | 66 | return 1; 67 | 68 | } 69 | -------------------------------------------------------------------------------- /graspnet/knn/src/vision.cpp: -------------------------------------------------------------------------------- 1 | #include "knn.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("knn", &knn, "k-nearest neighbors"); 5 | } 6 | -------------------------------------------------------------------------------- /graspnet/models/backbone.py: -------------------------------------------------------------------------------- 1 | """ PointNet2 backbone for feature learning. 2 | Author: Charles R. Qi 3 | """ 4 | import os 5 | import sys 6 | import torch 7 | import torch.nn as nn 8 | 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | ROOT_DIR = os.path.dirname(BASE_DIR) 11 | sys.path.append(ROOT_DIR) 12 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2')) 13 | 14 | from pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule 15 | 16 | class Pointnet2Backbone(nn.Module): 17 | r""" 18 | Backbone network for point cloud feature learning. 19 | Based on Pointnet++ single-scale grouping network. 20 | 21 | Parameters 22 | ---------- 23 | input_feature_dim: int 24 | Number of input channels in the feature descriptor for each point. 25 | e.g. 3 for RGB. 26 | """ 27 | def __init__(self, input_feature_dim=0): 28 | super().__init__() 29 | 30 | self.sa1 = PointnetSAModuleVotes( 31 | npoint=2048, 32 | radius=0.04, 33 | nsample=64, 34 | mlp=[input_feature_dim, 64, 64, 128], 35 | use_xyz=True, 36 | normalize_xyz=True 37 | ) 38 | 39 | self.sa2 = PointnetSAModuleVotes( 40 | npoint=1024, 41 | radius=0.1, 42 | nsample=32, 43 | mlp=[128, 128, 128, 256], 44 | use_xyz=True, 45 | normalize_xyz=True 46 | ) 47 | 48 | self.sa3 = PointnetSAModuleVotes( 49 | npoint=512, 50 | radius=0.2, 51 | nsample=16, 52 | mlp=[256, 128, 128, 256], 53 | use_xyz=True, 54 | normalize_xyz=True 55 | ) 56 | 57 | self.sa4 = PointnetSAModuleVotes( 58 | npoint=256, 59 | radius=0.3, 60 | nsample=16, 61 | mlp=[256, 128, 128, 256], 62 | use_xyz=True, 63 | normalize_xyz=True 64 | ) 65 | 66 | self.fp1 = PointnetFPModule(mlp=[256+256,256,256]) 67 | self.fp2 = PointnetFPModule(mlp=[256+256,256,256]) 68 | 69 | def _break_up_pc(self, pc): 70 | xyz = pc[..., 0:3].contiguous() 71 | features = ( 72 | pc[..., 3:].transpose(1, 2).contiguous() 73 | if pc.size(-1) > 3 else None 74 | ) 75 | 76 | return xyz, features 77 | 78 | def forward(self, pointcloud: torch.cuda.FloatTensor, end_points=None): 79 | r""" 80 | Forward pass of the network 81 | 82 | Parameters 83 | ---------- 84 | pointcloud: Variable(torch.cuda.FloatTensor) 85 | (B, N, 3 + input_feature_dim) tensor 86 | Point cloud to run predicts on 87 | Each point in the point-cloud MUST 88 | be formated as (x, y, z, features...) 
89 | 90 | Returns 91 | ---------- 92 | end_points: {XXX_xyz, XXX_features, XXX_inds} 93 | XXX_xyz: float32 Tensor of shape (B,K,3) 94 | XXX_features: float32 Tensor of shape (B,D,K) 95 | XXX_inds: int64 Tensor of shape (B,K) values in [0,N-1] 96 | """ 97 | if not end_points: end_points = {} 98 | batch_size = pointcloud.shape[0] 99 | 100 | xyz, features = self._break_up_pc(pointcloud) 101 | end_points['input_xyz'] = xyz 102 | end_points['input_features'] = features 103 | 104 | # --------- 4 SET ABSTRACTION LAYERS --------- 105 | xyz, features, fps_inds = self.sa1(xyz, features) 106 | end_points['sa1_inds'] = fps_inds 107 | end_points['sa1_xyz'] = xyz 108 | end_points['sa1_features'] = features 109 | 110 | xyz, features, fps_inds = self.sa2(xyz, features) # this fps_inds is just 0,1,...,1023 111 | end_points['sa2_inds'] = fps_inds 112 | end_points['sa2_xyz'] = xyz 113 | end_points['sa2_features'] = features 114 | 115 | xyz, features, fps_inds = self.sa3(xyz, features) # this fps_inds is just 0,1,...,511 116 | end_points['sa3_xyz'] = xyz 117 | end_points['sa3_features'] = features 118 | 119 | xyz, features, fps_inds = self.sa4(xyz, features) # this fps_inds is just 0,1,...,255 120 | end_points['sa4_xyz'] = xyz 121 | end_points['sa4_features'] = features 122 | 123 | # --------- 2 FEATURE UPSAMPLING LAYERS -------- 124 | features = self.fp1(end_points['sa3_xyz'], end_points['sa4_xyz'], end_points['sa3_features'], end_points['sa4_features']) 125 | features = self.fp2(end_points['sa2_xyz'], end_points['sa3_xyz'], end_points['sa2_features'], features) 126 | end_points['fp2_features'] = features 127 | end_points['fp2_xyz'] = end_points['sa2_xyz'] 128 | num_seed = end_points['fp2_xyz'].shape[1] 129 | end_points['fp2_inds'] = end_points['sa1_inds'][:,0:num_seed] # indices among the entire input point clouds 130 | 131 | return features, end_points['fp2_xyz'], end_points -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/include/cylinder_query.h: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #pragma once 4 | #include 5 | 6 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 7 | const int nsample); 8 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
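//
// Sampling ops exposed through bindings.cpp (contiguous CUDA tensors only):
//   gather_points(points (B, C, N) float, idx (B, npoints) int)  -> (B, C, npoints) float
//   gather_points_grad(grad_out (B, C, npoints) float, idx, n)   -> (B, C, n) float
//   furthest_point_sampling(points (B, N, 3) float, nsamples)    -> (B, nsamples) int indices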
5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.type().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data(), 31 | xyz.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
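//
// Launch layout: one block per batch element, threads striding over the m
// query centers. For each center the kernel scans the n points, keeps up to
// nsample indices whose squared distance is below radius^2, and pads the
// remaining slots with the first hit so downstream grouping always reads
// nsample valid entries.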
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
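//
// pybind11 bindings: every op declared in the headers included below is
// exposed on the pointnet2._ext module built by pointnet2/setup.py.
// Minimal usage sketch from Python (assumes the extension has been built and
// xyz is a contiguous float32 CUDA tensor of shape (B, N, 3)):
//
//   from pointnet2 import _ext
//   fps_idx = _ext.furthest_point_sampling(xyz, 1024)  # (B, 1024) int32 indices
//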
5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | #include "cylinder_query.h" 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("gather_points", &gather_points); 14 | m.def("gather_points_grad", &gather_points_grad); 15 | m.def("furthest_point_sampling", &furthest_point_sampling); 16 | 17 | m.def("three_nn", &three_nn); 18 | m.def("three_interpolate", &three_interpolate); 19 | m.def("three_interpolate_grad", &three_interpolate_grad); 20 | 21 | m.def("ball_query", &ball_query); 22 | 23 | m.def("group_points", &group_points); 24 | m.def("group_points_grad", &group_points_grad); 25 | 26 | m.def("cylinder_query", &cylinder_query); 27 | } 28 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/cylinder_query.cpp: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include "cylinder_query.h" 4 | #include "utils.h" 5 | 6 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 7 | int nsample, const float *new_xyz, 8 | const float *xyz, const float *rot, int *idx); 9 | 10 | at::Tensor cylinder_query(at::Tensor new_xyz, at::Tensor xyz, at::Tensor rot, const float radius, const float hmin, const float hmax, 11 | const int nsample) { 12 | CHECK_CONTIGUOUS(new_xyz); 13 | CHECK_CONTIGUOUS(xyz); 14 | CHECK_CONTIGUOUS(rot); 15 | CHECK_IS_FLOAT(new_xyz); 16 | CHECK_IS_FLOAT(xyz); 17 | CHECK_IS_FLOAT(rot); 18 | 19 | if (new_xyz.type().is_cuda()) { 20 | CHECK_CUDA(xyz); 21 | CHECK_CUDA(rot); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.type().is_cuda()) { 29 | query_cylinder_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, hmin, hmax, nsample, new_xyz.data(), 31 | xyz.data(), rot.data(), idx.data()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/cylinder_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Author: chenxi-wang 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | __global__ void query_cylinder_point_kernel(int b, int n, int m, float radius, float hmin, float hmax, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | const float *__restrict__ rot, 14 | int *__restrict__ idx) { 15 | int batch_index = blockIdx.x; 16 | xyz += batch_index * n * 3; 17 | new_xyz += batch_index * m * 3; 18 | rot += batch_index * m * 9; 19 | idx += m * nsample * batch_index; 20 | 21 | int index = threadIdx.x; 22 | int stride = blockDim.x; 23 | 24 | float radius2 = radius * radius; 25 | for (int j = index; j < m; j += stride) { 26 | float new_x = new_xyz[j * 3 + 0]; 27 | float new_y = new_xyz[j * 3 + 1]; 28 | float new_z = new_xyz[j * 3 + 2]; 29 | float r0 = rot[j * 9 + 0]; 30 | float r1 = rot[j * 9 + 1]; 31 | float r2 = rot[j * 9 + 2]; 32 | float r3 = rot[j * 9 + 3]; 33 | float r4 = rot[j * 9 + 4]; 34 | float r5 = rot[j * 9 + 5]; 35 | float r6 = rot[j * 9 + 6]; 36 | float r7 = rot[j * 9 + 7]; 37 | float r8 = rot[j * 9 + 8]; 38 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 39 | float x = xyz[k * 3 + 0] 
- new_x; 40 | float y = xyz[k * 3 + 1] - new_y; 41 | float z = xyz[k * 3 + 2] - new_z; 42 | float x_rot = r0 * x + r3 * y + r6 * z; 43 | float y_rot = r1 * x + r4 * y + r7 * z; 44 | float z_rot = r2 * x + r5 * y + r8 * z; 45 | float d2 = y_rot * y_rot + z_rot * z_rot; 46 | if (d2 < radius2 && x_rot > hmin && x_rot < hmax) { 47 | if (cnt == 0) { 48 | for (int l = 0; l < nsample; ++l) { 49 | idx[j * nsample + l] = k; 50 | } 51 | } 52 | idx[j * nsample + cnt] = k; 53 | ++cnt; 54 | } 55 | } 56 | } 57 | } 58 | 59 | void query_cylinder_point_kernel_wrapper(int b, int n, int m, float radius, float hmin, float hmax, 60 | int nsample, const float *new_xyz, 61 | const float *xyz, const float *rot, int *idx) { 62 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 63 | query_cylinder_point_kernel<<>>( 64 | b, n, m, radius, hmin, hmax, nsample, new_xyz, xyz, rot, idx); 65 | 66 | CUDA_CHECK_ERRORS(); 67 | } 68 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.type().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.type().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data(), 34 | idx.data(), output.data()); 35 | } else { 36 | TORCH_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.type().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.type().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data(), idx.data(), output.data()); 60 | } else { 61 | TORCH_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
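//
// Host-side wrappers for feature propagation:
//   three_nn(unknowns (B, n, 3), knows (B, m, 3)) -> {dist2 (B, n, 3), idx (B, n, 3)}
//       squared distances and indices of the 3 nearest known points per query point.
//   three_interpolate(points (B, C, m), idx (B, n, 3), weight (B, n, 3)) -> (B, C, n)
//       weighted sum of the 3 selected features; the weights are computed by the caller.
//   three_interpolate_grad scatters output gradients back onto the m source points.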
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.type().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.type().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data(), knows.data(), 39 | dist2.data(), idx.data()); 40 | } else { 41 | TORCH_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.type().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.type().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), points.size(1), points.size(2), idx.size(1), 68 | points.data(), idx.data(), weight.data(), 69 | output.data()); 70 | } else { 71 | TORCH_CHECK(false, "CPU not supported"); 72 | } 73 | 74 | return output; 75 | } 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 77 | at::Tensor weight, const int m) { 78 | CHECK_CONTIGUOUS(grad_out); 79 | CHECK_CONTIGUOUS(idx); 80 | CHECK_CONTIGUOUS(weight); 81 | CHECK_IS_FLOAT(grad_out); 82 | CHECK_IS_INT(idx); 83 | CHECK_IS_FLOAT(weight); 84 | 85 | if (grad_out.type().is_cuda()) { 86 | CHECK_CUDA(idx); 87 | CHECK_CUDA(weight); 88 | } 89 | 90 | at::Tensor output = 91 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 92 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 93 | 94 | if (grad_out.type().is_cuda()) { 95 | three_interpolate_grad_kernel_wrapper( 96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 97 | grad_out.data(), idx.data(), weight.data(), 98 | output.data()); 99 | } else { 100 | TORCH_CHECK(false, "CPU not supported"); 101 | } 102 | 103 | return output; 104 | } 105 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
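//
// three_nn_kernel brute-forces all m known points per query while keeping a
// running top-3 (best1 <= best2 <= best3). three_interpolate_kernel and its
// gradient use one block per batch element with a 2-D thread block striding
// over the c * n output entries; the gradient kernel accumulates with atomicAdd.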
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: unknown(b, n, 3) known(b, m, 3) 13 | // output: dist2(b, n, 3), idx(b, n, 3) 14 | __global__ void three_nn_kernel(int b, int n, int m, 15 | const float *__restrict__ unknown, 16 | const float *__restrict__ known, 17 | float *__restrict__ dist2, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | unknown += batch_index * n * 3; 21 | known += batch_index * m * 3; 22 | dist2 += batch_index * n * 3; 23 | idx += batch_index * n * 3; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | for (int j = index; j < n; j += stride) { 28 | float ux = unknown[j * 3 + 0]; 29 | float uy = unknown[j * 3 + 1]; 30 | float uz = unknown[j * 3 + 2]; 31 | 32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 33 | int besti1 = 0, besti2 = 0, besti3 = 0; 34 | for (int k = 0; k < m; ++k) { 35 | float x = known[k * 3 + 0]; 36 | float y = known[k * 3 + 1]; 37 | float z = known[k * 3 + 2]; 38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 39 | if (d < best1) { 40 | best3 = best2; 41 | besti3 = besti2; 42 | best2 = best1; 43 | besti2 = besti1; 44 | best1 = d; 45 | besti1 = k; 46 | } else if (d < best2) { 47 | best3 = best2; 48 | besti3 = besti2; 49 | best2 = d; 50 | besti2 = k; 51 | } else if (d < best3) { 52 | best3 = d; 53 | besti3 = k; 54 | } 55 | } 56 | dist2[j * 3 + 0] = best1; 57 | dist2[j * 3 + 1] = best2; 58 | dist2[j * 3 + 2] = best3; 59 | 60 | idx[j * 3 + 0] = besti1; 61 | idx[j * 3 + 1] = besti2; 62 | idx[j * 3 + 2] = besti3; 63 | } 64 | } 65 | 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 67 | const float *known, float *dist2, int *idx) { 68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 69 | three_nn_kernel<<>>(b, n, m, unknown, known, 70 | dist2, idx); 71 | 72 | CUDA_CHECK_ERRORS(); 73 | } 74 | 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 76 | // output: out(b, c, n) 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 78 | const float *__restrict__ points, 79 | const int *__restrict__ idx, 80 | const float *__restrict__ weight, 81 | float *__restrict__ out) { 82 | int batch_index = blockIdx.x; 83 | points += batch_index * m * c; 84 | 85 | idx += batch_index * n * 3; 86 | weight += batch_index * n * 3; 87 | 88 | out += batch_index * n * c; 89 | 90 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 91 | const int stride = blockDim.y * blockDim.x; 92 | for (int i = index; i < c * n; i += stride) { 93 | const int l = i / n; 94 | const int j = i % n; 95 | float w1 = weight[j * 3 + 0]; 96 | float w2 = weight[j * 3 + 1]; 97 | float w3 = weight[j * 3 + 2]; 98 | 99 | int i1 = idx[j * 3 + 0]; 100 | int i2 = idx[j * 3 + 1]; 101 | int i3 = idx[j * 3 + 2]; 102 | 103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 104 | points[l * m + i3] * w3; 105 | } 106 | } 107 | 108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 109 | const float *points, const int *idx, 110 | const float *weight, float *out) { 111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 112 | three_interpolate_kernel<<>>( 113 | b, c, m, n, points, idx, weight, out); 114 | 115 | CUDA_CHECK_ERRORS(); 116 | } 117 | 118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 119 | // output: grad_points(b, c, m) 120 | 121 | __global__ void three_interpolate_grad_kernel( 122 | int b, int c, int n, int m, const float *__restrict__ grad_out, 123 | const int 
*__restrict__ idx, const float *__restrict__ weight, 124 | float *__restrict__ grad_points) { 125 | int batch_index = blockIdx.x; 126 | grad_out += batch_index * n * c; 127 | idx += batch_index * n * 3; 128 | weight += batch_index * n * 3; 129 | grad_points += batch_index * m * c; 130 | 131 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 132 | const int stride = blockDim.y * blockDim.x; 133 | for (int i = index; i < c * n; i += stride) { 134 | const int l = i / n; 135 | const int j = i % n; 136 | float w1 = weight[j * 3 + 0]; 137 | float w2 = weight[j * 3 + 1]; 138 | float w3 = weight[j * 3 + 2]; 139 | 140 | int i1 = idx[j * 3 + 0]; 141 | int i2 = idx[j * 3 + 1]; 142 | int i3 = idx[j * 3 + 2]; 143 | 144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 147 | } 148 | } 149 | 150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 151 | const float *grad_out, 152 | const int *idx, const float *weight, 153 | float *grad_points) { 154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 155 | three_interpolate_grad_kernel<<>>( 156 | b, c, n, m, grad_out, idx, weight, grad_points); 157 | 158 | CUDA_CHECK_ERRORS(); 159 | } 160 | -------------------------------------------------------------------------------- /graspnet/pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "sampling.h" 7 | #include "utils.h" 8 | 9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 10 | const float *points, const int *idx, 11 | float *out); 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs); 19 | 20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 21 | CHECK_CONTIGUOUS(points); 22 | CHECK_CONTIGUOUS(idx); 23 | CHECK_IS_FLOAT(points); 24 | CHECK_IS_INT(idx); 25 | 26 | if (points.type().is_cuda()) { 27 | CHECK_CUDA(idx); 28 | } 29 | 30 | at::Tensor output = 31 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 32 | at::device(points.device()).dtype(at::ScalarType::Float)); 33 | 34 | if (points.type().is_cuda()) { 35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 36 | idx.size(1), points.data(), 37 | idx.data(), output.data()); 38 | } else { 39 | TORCH_CHECK(false, "CPU not supported"); 40 | } 41 | 42 | return output; 43 | } 44 | 45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 46 | const int n) { 47 | CHECK_CONTIGUOUS(grad_out); 48 | CHECK_CONTIGUOUS(idx); 49 | CHECK_IS_FLOAT(grad_out); 50 | CHECK_IS_INT(idx); 51 | 52 | if (grad_out.type().is_cuda()) { 53 | CHECK_CUDA(idx); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 58 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (grad_out.type().is_cuda()) { 61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 62 | idx.size(1), grad_out.data(), 63 | idx.data(), output.data()); 64 | } else { 65 | TORCH_CHECK(false, 
"CPU not supported"); 66 | } 67 | 68 | return output; 69 | } 70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 71 | CHECK_CONTIGUOUS(points); 72 | CHECK_IS_FLOAT(points); 73 | 74 | at::Tensor output = 75 | torch::zeros({points.size(0), nsamples}, 76 | at::device(points.device()).dtype(at::ScalarType::Int)); 77 | 78 | at::Tensor tmp = 79 | torch::full({points.size(0), points.size(1)}, 1e10, 80 | at::device(points.device()).dtype(at::ScalarType::Float)); 81 | 82 | if (points.type().is_cuda()) { 83 | furthest_point_sampling_kernel_wrapper( 84 | points.size(0), points.size(1), nsamples, points.data(), 85 | tmp.data(), output.data()); 86 | } else { 87 | TORCH_CHECK(false, "CPU not supported"); 88 | } 89 | 90 | return output; 91 | } 92 | -------------------------------------------------------------------------------- /graspnet/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os 10 | ROOT = os.path.dirname(os.path.abspath(__file__)) 11 | 12 | _ext_src_root = "_ext_src" 13 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 14 | "{}/src/*.cu".format(_ext_src_root) 15 | ) 16 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 17 | 18 | setup( 19 | name='pointnet2', 20 | ext_modules=[ 21 | CUDAExtension( 22 | name='pointnet2._ext', 23 | sources=_ext_sources, 24 | extra_compile_args={ 25 | "cxx": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 26 | "nvcc": ["-O2", "-I{}".format("{}/{}/include".format(ROOT, _ext_src_root))], 27 | }, 28 | ) 29 | ], 30 | cmdclass={ 31 | 'build_ext': BuildExtension 32 | } 33 | ) 34 | -------------------------------------------------------------------------------- /graspnet/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.6 2 | tensorboard==2.3 3 | numpy 4 | scipy 5 | open3d>=0.8 6 | Pillow 7 | tqdm 8 | -------------------------------------------------------------------------------- /graspnet/test.py: -------------------------------------------------------------------------------- 1 | """ Testing for GraspNet baseline model. 
""" 2 | 3 | import os 4 | import sys 5 | import numpy as np 6 | import argparse 7 | import time 8 | 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from graspnetAPI import GraspGroup, GraspNetEval 12 | 13 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 14 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 15 | sys.path.append(os.path.join(ROOT_DIR, 'dataset')) 16 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 17 | 18 | from graspnet import GraspNet, pred_decode 19 | from graspnet_dataset import GraspNetDataset, collate_fn 20 | from bullet_dataset import BulletDataset 21 | from collision_detector import ModelFreeCollisionDetector 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--dataset_root', required=True, help='Dataset root') 25 | parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path') 26 | parser.add_argument('--dump_dir', required=True, help='Dump dir to save outputs') 27 | parser.add_argument('--camera', required=True, help='Camera split [realsense/kinect]') 28 | parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]') 29 | parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]') 30 | parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]') 31 | parser.add_argument('--collision_thresh', type=float, default=0.01, help='Collision Threshold in collision detection [default: 0.01]') 32 | parser.add_argument('--voxel_size', type=float, default=0.01, help='Voxel Size to process point clouds before collision detection [default: 0.01]') 33 | parser.add_argument('--num_workers', type=int, default=30, help='Number of workers used in evaluation [default: 30]') 34 | cfgs = parser.parse_args() 35 | 36 | # ------------------------------------------------------------------------- GLOBAL CONFIG BEG 37 | if not os.path.exists(cfgs.dump_dir): os.mkdir(cfgs.dump_dir) 38 | 39 | # Init datasets and dataloaders 40 | def my_worker_init_fn(worker_id): 41 | np.random.seed(np.random.get_state()[1][0] + worker_id) 42 | pass 43 | 44 | # Create Dataset and Dataloader 45 | TEST_DATASET = BulletDataset(cfgs.dataset_root, valid_obj_idxs=None, grasp_labels=None, split='test', camera=cfgs.camera, num_points=cfgs.num_point, remove_outlier=True, augment=False, load_label=False) 46 | 47 | print(len(TEST_DATASET)) 48 | SCENE_LIST = TEST_DATASET.scene_list() 49 | TEST_DATALOADER = DataLoader(TEST_DATASET, batch_size=cfgs.batch_size, shuffle=False, 50 | num_workers=4, worker_init_fn=my_worker_init_fn, collate_fn=collate_fn) 51 | print(len(TEST_DATALOADER)) 52 | # Init the model 53 | net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4, 54 | cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=False) 55 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 56 | net.to(device) 57 | # Load checkpoint 58 | checkpoint = torch.load(cfgs.checkpoint_path) 59 | net.load_state_dict(checkpoint['model_state_dict']) 60 | start_epoch = checkpoint['epoch'] 61 | print("-> loaded checkpoint %s (epoch: %d)"%(cfgs.checkpoint_path, start_epoch)) 62 | 63 | 64 | # ------------------------------------------------------------------------- GLOBAL CONFIG END 65 | 66 | def inference(): 67 | batch_interval = 100 68 | stat_dict = {} # collect statistics 69 | # set model to eval mode (for bn and dp) 70 | net.eval() 71 | tic = time.time() 72 | for batch_idx, batch_data in 
enumerate(TEST_DATALOADER): 73 | for key in batch_data: 74 | if 'list' in key: 75 | for i in range(len(batch_data[key])): 76 | for j in range(len(batch_data[key][i])): 77 | batch_data[key][i][j] = batch_data[key][i][j].to(device) 78 | else: 79 | batch_data[key] = batch_data[key].to(device) 80 | 81 | # Forward pass 82 | with torch.no_grad(): 83 | end_points = net(batch_data) 84 | grasp_preds = pred_decode(end_points) 85 | 86 | # Dump results for evaluation 87 | for i in range(cfgs.batch_size): 88 | data_idx = batch_idx * cfgs.batch_size + i 89 | preds = grasp_preds[i].detach().cpu().numpy() 90 | gg = GraspGroup(preds) 91 | 92 | # collision detection 93 | if cfgs.collision_thresh > 0: 94 | cloud, _ = TEST_DATASET.get_data(data_idx, return_raw_cloud=True) 95 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size) 96 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 97 | gg = gg[~collision_mask] 98 | 99 | # save grasps 100 | save_dir = os.path.join(cfgs.dump_dir, SCENE_LIST[data_idx], cfgs.camera) 101 | save_path = os.path.join(save_dir, str(data_idx%256).zfill(4)+'.npy') 102 | if not os.path.exists(save_dir): 103 | os.makedirs(save_dir) 104 | gg.save_npy(save_path) 105 | 106 | if batch_idx % batch_interval == 0: 107 | toc = time.time() 108 | print('Eval batch: %d, time: %fs'%(batch_idx, (toc-tic)/batch_interval)) 109 | tic = time.time() 110 | 111 | def evaluate(): 112 | ge = GraspNetEval(root=cfgs.dataset_root, camera=cfgs.camera, split='test') 113 | res, ap = ge.eval_all(cfgs.dump_dir, proc=cfgs.num_workers) 114 | save_dir = os.path.join(cfgs.dump_dir, 'ap_{}.npy'.format(cfgs.camera)) 115 | np.save(save_dir, res) 116 | 117 | if __name__=='__main__': 118 | inference() 119 | # evaluate() 120 | -------------------------------------------------------------------------------- /graspnet/utils/grasp_projection.py: -------------------------------------------------------------------------------- 1 | from graspnetAPI.grasp import GraspGroup, Grasp 2 | import numpy as np 3 | import copy 4 | import cv2 5 | 6 | import pdb 7 | 8 | 9 | def filter_grasp_in_mask(grasp_group, mask, camera, extrinsics=None, camera_s=1000.0, expansion=5): 10 | '''Project the grasp pose group to the image plane and filter the grasps in the mask. 11 | Args: 12 | grasp_group (GraspGroup): the grasp pose group. 13 | mask (np.ndarray): the mask of the image. 14 | camera (np.ndarray or str): the intrinsic matrix or 'kinect' or 'realsense'. 15 | camera_s (float): the camera scale. 
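        extrinsics (np.ndarray, optional): 4x4 transform applied to the grasp
            poses before projection onto the image plane.
        expansion (int): intended dilation kernel size for the fallback below;
            the current code dilates with a fixed 5x5 kernel.
    Returns:
        GraspGroup: the grasps whose projected centers fall inside the
            (possibly dilated) mask.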
16 | ''' 17 | # Project the grasp pose group to the image plane 18 | rect_grasp_group = project_to_image_plane(grasp_group, camera, extrinsics, camera_s) 19 | # Filter the grasps in the mask 20 | mask = mask.astype(np.uint8) 21 | filtered_grasp_group = GraspGroup() 22 | # pdb.set_trace() 23 | for i in range(rect_grasp_group.shape[0]): 24 | rect_grasp_pt = rect_grasp_group[i,:] # (x, y) 25 | if (0 < rect_grasp_pt[1] < mask.shape[0] and 26 | 0 < rect_grasp_pt[0] < mask.shape[1] and 27 | mask[int(rect_grasp_pt[1]), int(rect_grasp_pt[0])] > 0): 28 | filtered_grasp_group.add(grasp_group[i]) 29 | if len(filtered_grasp_group) < 5: 30 | # binary expansion 31 | mask = cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1) 32 | # mask = cv2.dilate(mask, np.ones((50, 50), np.uint8), iterations=1) 33 | # pdb.set_trace() 34 | for i in range(rect_grasp_group.shape[0]): 35 | rect_grasp_pt = rect_grasp_group[i,:] # (x, y) 36 | if (0 < rect_grasp_pt[1] < mask.shape[0] and 37 | 0 < rect_grasp_pt[0] < mask.shape[1] and 38 | mask[int(rect_grasp_pt[1]), int(rect_grasp_pt[0])] > 0): 39 | filtered_grasp_group.add(grasp_group[i]) 40 | return filtered_grasp_group 41 | 42 | def project_to_image_plane(grasp_group, camera, extrinsics=None, camera_s=1000.0): 43 | """Project the 6DoF grasp pose group to the image plane.""" 44 | if extrinsics is not None: 45 | grasp_group = copy.deepcopy(grasp_group) 46 | grasp_group = grasp_group.transform(extrinsics) 47 | if isinstance(camera, str): 48 | return grasp_group.to_rect_grasp_group(camera).center_points 49 | elif isinstance(camera, np.ndarray): 50 | translations = grasp_group.translations 51 | if len(camera.shape) == 1: 52 | fx = camera[0] 53 | fy = camera[4] 54 | cx = camera[2] 55 | cy = camera[5] 56 | else: 57 | fx, fy = camera[0, 0], camera[1, 1] 58 | cx, cy = camera[0, 2], camera[1, 2] 59 | # z = translations[:, 2] * camera_s 60 | coords_x = translations[:, 0] / translations[:, 2] * fx + cx 61 | coords_y = translations[:, 1] / translations[:, 2] * fy + cy 62 | return np.stack([coords_x, coords_y], axis=-1) 63 | 64 | def look_at(camera_center, point, upward): 65 | def normalize(v): 66 | return v / np.linalg.norm(v) 67 | forward = normalize(point - camera_center) 68 | right = normalize(np.cross(upward, forward)) 69 | up = np.cross(forward, right) 70 | rotation = np.array([right, up, forward]).T 71 | translation = -rotation @ camera_center 72 | extrinsics = np.zeros((4, 4)) 73 | extrinsics[:3, :3] = rotation 74 | extrinsics[:3, 3] = translation 75 | extrinsics[3, 3] = 1 76 | return extrinsics 77 | 78 | 79 | if __name__ == '__main__': 80 | # Create a dummy grasp group 81 | grasp_group = GraspGroup() 82 | grasp_group.add(Grasp()) 83 | grasp_group.add(Grasp()) 84 | grasp_group.add(Grasp()) 85 | 86 | # Create a dummy mask 87 | mask = np.zeros((10, 10)) 88 | mask[0:5, 0:5] = 1 89 | 90 | intrinsics = np.array([[1000, 0, 5], [0, 1000, 5], [0, 0, 1]]) 91 | extrinsics = np.eye(4) 92 | 93 | # Call the function under test 94 | filtered_grasp_group = filter_grasp_in_mask(grasp_group, mask, intrinsics, extrinsics, camera_s=1000.0) 95 | 96 | # Check the filtered grasp group 97 | print(len(filtered_grasp_group)) -------------------------------------------------------------------------------- /graspnet/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | """ Tools for loss computation. 
2 | Author: chenxi-wang 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | GRASP_MAX_WIDTH = 0.1 9 | GRASP_MAX_TOLERANCE = 0.05 10 | THRESH_GOOD = 0.7 11 | THRESH_BAD = 0.1 12 | 13 | def transform_point_cloud(cloud, transform, format='4x4'): 14 | """ Transform points to new coordinates with transformation matrix. 15 | 16 | Input: 17 | cloud: [torch.FloatTensor, (N,3)] 18 | points in original coordinates 19 | transform: [torch.FloatTensor, (3,3)/(3,4)/(4,4)] 20 | transformation matrix, could be rotation only or rotation+translation 21 | format: [string, '3x3'/'3x4'/'4x4'] 22 | the shape of transformation matrix 23 | '3x3' --> rotation matrix 24 | '3x4'/'4x4' --> rotation matrix + translation matrix 25 | 26 | Output: 27 | cloud_transformed: [torch.FloatTensor, (N,3)] 28 | points in new coordinates 29 | """ 30 | if not (format == '3x3' or format == '4x4' or format == '3x4'): 31 | raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.') 32 | if format == '3x3': 33 | cloud_transformed = torch.matmul(transform, cloud.T).T 34 | elif format == '4x4' or format == '3x4': 35 | ones = cloud.new_ones(cloud.size(0), device=cloud.device).unsqueeze(-1) 36 | cloud_ = torch.cat([cloud, ones], dim=1) 37 | cloud_transformed = torch.matmul(transform, cloud_.T).T 38 | cloud_transformed = cloud_transformed[:, :3] 39 | return cloud_transformed 40 | 41 | def generate_grasp_views(N=300, phi=(np.sqrt(5)-1)/2, center=np.zeros(3), r=1): 42 | """ View sampling on a unit sphere using Fibonacci lattices. 43 | Ref: https://arxiv.org/abs/0912.4540 44 | 45 | Input: 46 | N: [int] 47 | number of sampled views 48 | phi: [float] 49 | constant for view coordinate calculation, different phi's bring different distributions, default: (sqrt(5)-1)/2 50 | center: [np.ndarray, (3,), np.float32] 51 | sphere center 52 | r: [float] 53 | sphere radius 54 | 55 | Output: 56 | views: [torch.FloatTensor, (N,3)] 57 | sampled view coordinates 58 | """ 59 | views = [] 60 | for i in range(N): 61 | zi = (2 * i + 1) / N - 1 62 | xi = np.sqrt(1 - zi**2) * np.cos(2 * i * np.pi * phi) 63 | yi = np.sqrt(1 - zi**2) * np.sin(2 * i * np.pi * phi) 64 | views.append([xi, yi, zi]) 65 | views = r * np.array(views) + center 66 | return torch.from_numpy(views.astype(np.float32)) 67 | 68 | def batch_viewpoint_params_to_matrix(batch_towards, batch_angle): 69 | """ Transform approach vectors and in-plane rotation angles to rotation matrices. 
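    The (normalized) approach vector becomes the x-axis of the grasp frame and
    the in-plane angle is applied as a rotation about that axis, i.e. the
    result is R = [axis_x | axis_y | axis_z] @ Rx(angle).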
70 | 71 | Input: 72 | batch_towards: [torch.FloatTensor, (N,3)] 73 | approach vectors in batch 74 | batch_angle: [torch.floatTensor, (N,)] 75 | in-plane rotation angles in batch 76 | 77 | Output: 78 | batch_matrix: [torch.floatTensor, (N,3,3)] 79 | rotation matrices in batch 80 | """ 81 | axis_x = batch_towards 82 | ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device) 83 | zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device) 84 | axis_y = torch.stack([-axis_x[:,1], axis_x[:,0], zeros], dim=-1) 85 | mask_y = (torch.norm(axis_y, dim=-1) == 0) 86 | axis_y[mask_y,1] = 1 87 | axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True) 88 | axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True) 89 | axis_z = torch.cross(axis_x, axis_y) 90 | sin = torch.sin(batch_angle) 91 | cos = torch.cos(batch_angle) 92 | R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1) 93 | R1 = R1.reshape([-1,3,3]) 94 | R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1) 95 | batch_matrix = torch.matmul(R2, R1) 96 | return batch_matrix 97 | 98 | def huber_loss(error, delta=1.0): 99 | """ 100 | Args: 101 | error: Torch tensor (d1,d2,...,dk) 102 | Returns: 103 | loss: Torch tensor (d1,d2,...,dk) 104 | 105 | x = error = pred - gt or dist(pred,gt) 106 | 0.5 * |x|^2 if |x|<=d 107 | 0.5 * d^2 + d * (|x|-d) if |x|>d 108 | Author: Charles R. Qi 109 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 110 | """ 111 | abs_error = torch.abs(error) 112 | quadratic = torch.clamp(abs_error, max=delta) 113 | linear = (abs_error - quadratic) 114 | loss = 0.5 * quadratic**2 + delta * linear 115 | return loss -------------------------------------------------------------------------------- /graspnet/utils/motion_planner.py: -------------------------------------------------------------------------------- 1 | # Collistion free motion planning for the robot 2 | # powered by pybullet_planning 3 | # 4 | from pybullet_planning.pybullet_tools.utils import get_movable_joints, get_joint_positions, set_joint_positions, \ 5 | plan_joint_motion, link_from_name, INF 6 | from pybullet_planning.pybullet_tools.ikfast.ikfast import get_ik_joints, either_inverse_kinematics, check_ik_solver 7 | from pybullet_planning.pybullet_tools.ikfast.franka_panda.ik import PANDA_INFO, FRANKA_URDF 8 | 9 | def collistion_free_motion_planning(robot, joints, end_pos, obstacles, self_collisions=False, **kwargs): 10 | """ 11 | Using pybullet_planning to plan a collision free motion for the robot 12 | Alogorithm is birrt: Bi-directional Rapidly-exploring Random Tree""" 13 | tool_link = link_from_name(robot, 'panda_hand') 14 | # inverse kinematics 15 | final_conf = next(either_inverse_kinematics(robot, PANDA_INFO, tool_link, end_pos, use_pybullet=False, 16 | max_distance=INF, max_time=10, max_candidates=INF, **kwargs), None) 17 | print("final_conf: ", final_conf, flush=True) 18 | if final_conf is None: 19 | print("No IK solution found!", flush=True) 20 | return None 21 | rtn = plan_joint_motion(robot, joints, final_conf, obstacles=obstacles, 22 | self_collisions=self_collisions) 23 | if rtn is None: 24 | print("No path found!", flush=True) 25 | return None 26 | return rtn 27 | 28 | 29 | # test 30 | if __name__ == '__main__': 31 | from pybullet_planning.pybullet_tools.utils import connect, load_model, disconnect, wait_if_gui, create_box, set_point, dump_body, \ 32 | HideOutput, LockRenderer, joint_from_name, set_euler, get_euler, get_point, \ 33 | 
set_joint_position, get_joint_positions, pairwise_collision, stable_z, wait_for_duration, get_link_pose, \ 34 | link_from_name, get_pose, euler_from_quat, multiply, invert, draw_pose, unit_point, unit_quat, \ 35 | remove_debug, get_aabb, draw_aabb, get_subtree_aabb, ROOMBA_URDF, set_all_static, assign_link_colors, \ 36 | set_camera_pose, RGBA, draw_point, Pose, Point, Euler 37 | from pybullet_planning.pybullet_tools.ikfast.ikfast import get_ik_joints, either_inverse_kinematics, check_ik_solver 38 | from pybullet_planning.pybullet_tools.ikfast.franka_panda.ik import PANDA_INFO, FRANKA_URDF 39 | from pybullet_planning.pybullet_tools.utils import get_movable_joints, get_joint_positions, set_joint_positions, \ 40 | plan_joint_motion, link_from_name 41 | import numpy as np 42 | 43 | connect(use_gui=True) 44 | robot = load_model(FRANKA_URDF) 45 | joints = get_movable_joints(robot) 46 | tool_link = link_from_name(robot, 'panda_hand') 47 | print("joints: ", joints, flush=True) 48 | # end_pos = get_link_pose(robot, tool_link) 49 | sample_joint = [0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785, 0.0, 0.0] 50 | # set initial configuration 51 | set_joint_positions(robot, joints, sample_joint) 52 | end_pos = Pose(Point(0.5, 0.5, 0.5), Euler(0, 0, 0)) 53 | obstacles = [] 54 | # draw boxes as obstacles 55 | for i in range(5): 56 | box = create_box(0.1, 0.1, 0.1, color=(1, 0, 0, 0.5)) 57 | set_point(box, [0.5, 0.2, 0.3 * i]) 58 | set_euler(box, [0, 0, np.pi / 4 * i]) 59 | obstacles.append(box) 60 | # draw aabb 61 | for obstacle in obstacles: 62 | obstacle_aabb = get_subtree_aabb(obstacle) 63 | draw_aabb(obstacle_aabb) 64 | # planning 65 | ik_joints = get_ik_joints(robot, PANDA_INFO, tool_link) 66 | path = collistion_free_motion_planning(robot, ik_joints, end_pos, obstacles) 67 | print(path) 68 | # draw path 69 | for i in range(len(path) - 1): 70 | set_joint_positions(robot, ik_joints, path[i]) 71 | wait_for_duration(0.1) 72 | wait_if_gui() 73 | disconnect() -------------------------------------------------------------------------------- /graspnet/utils/ros_adapter.py: -------------------------------------------------------------------------------- 1 | # convert Grasp from graspnet to grasp message in ROS 2 | import moveit_msgs.msg 3 | import geometry_msgs.msg 4 | import rospy 5 | import copy 6 | import numpy as np 7 | from scipy.spatial.transform import Rotation as R 8 | from graspnetAPI.grasp import Grasp 9 | 10 | def graspnet_to_ros(grasp, extrinsic_mat): 11 | grasp_w = grasp_cam2world(grasp, extrinsic_mat) 12 | rot_mat = grasp_w.rotation_matrix 13 | rot_quat = R.from_matrix(rot_mat).as_quat() 14 | grasp_msg = moveit_msgs.msg.Grasp() 15 | # grasp pose 16 | grasp_msg.grasp_pose.header.frame_id = "panda_link0" 17 | orientation = geometry_msgs.msg.Quaternion() 18 | orientation.x = rot_quat[0] 19 | orientation.y = rot_quat[1] 20 | orientation.z = rot_quat[2] 21 | orientation.w = rot_quat[3] 22 | grasp_msg.grasp_pose.pose.orientation = orientation 23 | 24 | translation = grasp_w.translation # FIXME: need to compensate for the gripper height 25 | grasp_msg.grasp_pose.pose.position.x = translation[0] 26 | grasp_msg.grasp_pose.pose.position.y = translation[1] 27 | grasp_msg.grasp_pose.pose.position.z = translation[2] 28 | # pre-grasp approach 29 | grasp_msg.pre_grasp_approach.direction.header.frame_id = "panda_link0" 30 | # direction is the same as gripper orientation 31 | direction_vec = np.dot(rot_mat, np.array([1, 0, 0])) 32 | grasp_msg.pre_grasp_approach.direction.vector.x = direction_vec[0] 33 | 
grasp_msg.pre_grasp_approach.direction.vector.y = direction_vec[1] 34 | grasp_msg.pre_grasp_approach.direction.vector.z = direction_vec[2] 35 | grasp_msg.pre_grasp_approach.min_distance = 0.095 36 | grasp_msg.pre_grasp_approach.desired_distance = 0.115 37 | # post-grasp retreat 38 | grasp_msg.post_grasp_retreat.direction.header.frame_id = "panda_link0" 39 | # direction is positive z axis 40 | grasp_msg.post_grasp_retreat.direction.vector.z = 1.0 41 | grasp_msg.post_grasp_retreat.min_distance = 0.05 42 | grasp_msg.post_grasp_retreat.desired_distance = 0.10 43 | # open gripper 44 | open_gripper(grasp_msg.pre_grasp_posture) 45 | # close gripper 46 | close_gripper(grasp_msg.grasp_posture) 47 | 48 | return grasp_msg 49 | 50 | 51 | def open_gripper(pre_grasp_posture): 52 | joint_names = ['panda_finger_joint1', 'panda_finger_joint2'] 53 | joint_values = [0.04, 0.04] 54 | pre_grasp_posture.joint_names = joint_names 55 | pre_grasp_posture.points = [moveit_msgs.msg.JointTrajectoryPoint( 56 | positions=joint_values, time_from_start=rospy.Duration(0.5))] 57 | 58 | def close_gripper(grasp_posture): 59 | joint_names = ['panda_finger_joint1', 'panda_finger_joint2'] 60 | joint_values = [0.0, 0.0] 61 | grasp_posture.joint_names = joint_names 62 | grasp_posture.points = [moveit_msgs.msg.JointTrajectoryPoint( 63 | positions=joint_values, time_from_start=rospy.Duration(0.5), 64 | effort=[1.0, 1.0])] 65 | 66 | 67 | def grasp_cam2world(grasp, extrinsic_mat): 68 | # convert grasp pose from camera frame to world frame 69 | # grasp: Grasp object from graspnet 70 | # extrinsic_mat: 4x4 extrinsic matrix from camera to world 71 | # return: Grasp object in world frame 72 | rot_mat = grasp.rotation_matrix 73 | translation = grasp.translation 74 | # convert translation 75 | translation = np.dot(extrinsic_mat, np.append(translation, 1.0))[:3] 76 | # convert rotation 77 | rot_mat = np.dot(extrinsic_mat[:3, :3], rot_mat) 78 | # swap axises 79 | # rot_mat = np.dot(np.array([[0, 0, 1], [0, -1, 0], [1, 0, 0]]), rot_mat) 80 | grasp_w = copy.deepcopy(grasp) 81 | grasp_w.translation = translation 82 | grasp_w.rotation_matrix = rot_mat 83 | return grasp_w 84 | -------------------------------------------------------------------------------- /graspnet/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import open3d as o3d 5 | import argparse 6 | import importlib 7 | import scipy.io as scio 8 | from PIL import Image 9 | 10 | import torch 11 | from graspnetAPI import GraspGroup 12 | 13 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 14 | sys.path.append(os.path.join(ROOT_DIR, 'models')) 15 | sys.path.append(os.path.join(ROOT_DIR, 'dataset')) 16 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 17 | 18 | from graspnet import GraspNet, pred_decode 19 | from graspnet_dataset import GraspNetDataset 20 | from collision_detector import ModelFreeCollisionDetector 21 | from data_utils import CameraInfo, create_point_cloud_from_depth_image 22 | 23 | 24 | DATA_ROOT = 'dataset/scenes' 25 | WORKSPACE_MASK = True 26 | 27 | def get_net(cfgs): 28 | # Init the model 29 | net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4, 30 | cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=False) 31 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 32 | net.to(device) 33 | # Load checkpoint 34 | checkpoint = torch.load(cfgs.checkpoint_path) 35 | 
net.load_state_dict(checkpoint['model_state_dict']) 36 | start_epoch = checkpoint['epoch'] 37 | print("-> loaded checkpoint %s (epoch: %d)"%(cfgs.checkpoint_path, start_epoch)) 38 | # set model to eval mode 39 | net.eval() 40 | return net 41 | 42 | def get_and_process_data(index): 43 | # load data 44 | index = str(index).zfill(4) 45 | color = np.array(Image.open(os.path.join(DATA_ROOT, 'rgb', index+'.png')), dtype=np.float32) / 255.0 46 | depth = scio.loadmat(os.path.join(DATA_ROOT, 'depth', index+'.mat'))['A'] 47 | meta = scio.loadmat(os.path.join(DATA_ROOT, 'meta', index+'.mat')) 48 | if WORKSPACE_MASK: 49 | workspace_mask = scio.loadmat(os.path.join(DATA_ROOT, 'mask', index+'.mat'))['A'] 50 | else: 51 | workspace_mask = None 52 | intrinsic = meta['intrinsic_matrix'] 53 | factor_depth = meta['factor_depth'] 54 | return process_data(cfgs, color, depth, intrinsic, factor_depth, workspace_mask) 55 | 56 | def process_data(cfgs, color, depth, intrinsic, factor_depth, workspace_mask=None): 57 | # generate cloud 58 | camera = CameraInfo(640.0, 480.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth) 59 | cloud = create_point_cloud_from_depth_image(depth, camera, organized=True) 60 | 61 | # get valid points 62 | if WORKSPACE_MASK: 63 | mask = ((workspace_mask > 1) & (depth > 0)) 64 | else: 65 | mask = depth > 0 66 | cloud_masked = cloud[mask] 67 | color_masked = color[mask] 68 | 69 | # sample points 70 | if len(cloud_masked) >= cfgs.num_point: 71 | idxs = np.random.choice(len(cloud_masked), cfgs.num_point, replace=False) 72 | else: 73 | idxs1 = np.arange(len(cloud_masked)) 74 | idxs2 = np.random.choice(len(cloud_masked), cfgs.num_point-len(cloud_masked), replace=True) 75 | idxs = np.concatenate([idxs1, idxs2], axis=0) 76 | cloud_sampled = cloud_masked[idxs] 77 | color_sampled = color_masked[idxs] 78 | 79 | # convert data 80 | cloud_viz = o3d.geometry.PointCloud() 81 | cloud_viz.points = o3d.utility.Vector3dVector(cloud[depth>0].astype(np.float32)) 82 | cloud_viz.colors = o3d.utility.Vector3dVector(color[depth>0].astype(np.float32)) 83 | end_points = dict() 84 | cloud_sampled = torch.from_numpy(cloud_sampled[np.newaxis].astype(np.float32)) 85 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 86 | cloud_sampled = cloud_sampled.to(device) 87 | end_points['point_clouds'] = cloud_sampled 88 | end_points['cloud_colors'] = color_sampled 89 | 90 | return end_points, cloud_viz 91 | 92 | def get_grasps(net, end_points) -> GraspGroup: 93 | # Forward pass 94 | with torch.no_grad(): 95 | end_points = net(end_points) 96 | grasp_preds = pred_decode(end_points) 97 | gg_array = grasp_preds[0].detach().cpu().numpy() 98 | gg = GraspGroup(gg_array) 99 | return gg 100 | 101 | def collision_detection(gg, cloud, cfgs): 102 | mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size) 103 | collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh) 104 | gg = gg[~collision_mask] 105 | return gg 106 | 107 | def vis_grasps(gg, cloud): 108 | gg.nms() 109 | gg.sort_by_score() 110 | gg = gg[:50] 111 | grippers = gg.to_open3d_geometry_list() 112 | o3d.visualization.draw_geometries([cloud, *grippers]) 113 | 114 | def demo(index, cfgs): 115 | net = get_net(cfgs) 116 | end_points, cloud = get_and_process_data(index) 117 | gg = get_grasps(net, end_points) 118 | if cfgs.collision_thresh > 0: 119 | gg = collision_detection(gg, np.array(cloud.points), cfgs) 120 | vis_grasps(gg, cloud) 121 | 122 | if __name__=='__main__': 
123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path') 125 | parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]') 126 | parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]') 127 | parser.add_argument('--collision_thresh', type=float, default=0.01, help='Collision Threshold in collision detection [default: 0.01]') 128 | parser.add_argument('--voxel_size', type=float, default=0.01, help='Voxel Size to process point clouds before collision detection [default: 0.01]') 129 | cfgs = parser.parse_args() 130 | 131 | index = 1 132 | demo(index, cfgs) 133 | -------------------------------------------------------------------------------- /real_world/README.md: -------------------------------------------------------------------------------- 1 | # Real-World Deployment Instruction 2 | 3 | ## Hardware Setup 4 | - Franka Emika Panda robot with parallel jaw gripper. 5 | - Two Intel RealSense D435 RGB-D camera. 6 | - PC with Nvidia RTX3090. 7 | 8 | ![](../asset/hardware_setup.jpg) 9 | 10 | ## Installation 11 | We use [Franka ROS](https://frankaemika.github.io/docs/franka_ros.html) and [MoveIt 1](https://moveit.ai/) to control the robot, which by default uses an RRT-Connect planner for motion planning. 12 | 13 | *Installation steps to be updated.* 14 | 15 | ## Launch 16 | 17 | A launch script looks like below. Please modify according to your network configuration and file structure and fill in {camera1_serial_no}, {camera2_serial_no}. 18 | ```xml 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | ``` 48 | 49 | After all nodes launched, run grasp or motion script: 50 | ```console 51 | $ python functional_grasp/functional_grasp.py # For grasping 52 | $ python behavior/behavior.py # For post-grasp motion 53 | ``` -------------------------------------------------------------------------------- /real_world/generate_point_cloud/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | catkin_install_python(PROGRAMS 2 | scripts/generate_point_cloud.py 3 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 4 | ) 5 | 6 | install(DIRECTORY launch DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}) 7 | -------------------------------------------------------------------------------- /real_world/generate_point_cloud/launch/generate_point_cloud.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /real_world/generate_point_cloud/scripts/visualize.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | 4 | 5 | # pointcloud = np.load('/root/data/pointcloud/point_cloud.npy') 6 | pointcloud = np.load('/root/data/pointcloud/pointcloud.npy') 7 | print(pointcloud.shape) 8 | pointcloud = pointcloud.reshape(-1, 6) 9 | pcd = o3d.geometry.PointCloud() 10 | pcd.points = o3d.utility.Vector3dVector(pointcloud[:, :3]) 11 | pcd.colors = o3d.utility.Vector3dVector(pointcloud[:, 3:] / 255.0) 12 | o3d.visualization.draw_geometries_with_editing([pcd]) -------------------------------------------------------------------------------- /real_world/get_point_cloud/CMakeLists.txt: 
-------------------------------------------------------------------------------- 1 | catkin_install_python(PROGRAMS 2 | scripts/get_point_cloud.py 3 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 4 | ) 5 | 6 | install(DIRECTORY launch DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}) 7 | -------------------------------------------------------------------------------- /real_world/get_point_cloud/launch/get_point_cloud.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /real_world/get_point_cloud/scripts/get_point_cloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import rospy 4 | import numpy as np 5 | from sensor_msgs.msg import Image, CameraInfo 6 | from cv_bridge import CvBridge 7 | from pathlib import Path 8 | import cv2 9 | 10 | class PointCloudGenerator: 11 | def __init__(self): 12 | rospy.init_node('point_cloud_generator', anonymous=True) 13 | 14 | # 订阅所需的主题 15 | rospy.Subscriber('/cam_2/aligned_depth_to_color/image_raw', Image, self.depth_callback) 16 | rospy.Subscriber('/cam_2/color/image_raw', Image, self.color_callback) 17 | rospy.Subscriber('/cam_2/color/camera_info', CameraInfo, self.camera_info_callback) 18 | 19 | self.bridge = CvBridge() 20 | self.depth_image = None 21 | self.color_image = None 22 | self.camera_info = None 23 | # import pdb;pdb.set_trace() 24 | rospy.sleep(1) 25 | self.generate_point_cloud() 26 | # 将RGB保存为图片,转化为RGB 27 | # cv2.imwrite('/root/ws_moveit/src/franka_scripts/get_point_cloud/scripts/color.png', cv2.cvtColor(self.color_image, cv2.COLOR_RGB2BGR)) 28 | # cv2.imwrite('color.png', ) 29 | 30 | def depth_callback(self, data): 31 | self.depth_image = self.bridge.imgmsg_to_cv2(data) 32 | 33 | def color_callback(self, data): 34 | self.color_image = self.bridge.imgmsg_to_cv2(data, desired_encoding='rgb8') 35 | 36 | def camera_info_callback(self, data): 37 | self.camera_info = data 38 | 39 | def hole_filling(self, depth_image: np.ndarray): 40 | import matplotlib.pyplot as plt 41 | plt.imshow(depth_image) 42 | plt.show() 43 | depth_image_shape = depth_image.shape 44 | depth_image = depth_image.copy().reshape(-1) 45 | outliners = depth_image < 0.05 46 | print('outliners: ', np.shape(outliners)) 47 | outliner_idx = np.where(outliners) 48 | outliner_idx = np.squeeze(outliner_idx) 49 | inliner_idx = np.where(~outliners) 50 | inliner_idx = np.squeeze(inliner_idx) 51 | print(np.shape(outliner_idx)) 52 | depth_image[outliners] = np.interp(outliner_idx, inliner_idx, depth_image[~outliners]) 53 | depth_image = depth_image.reshape(depth_image_shape) 54 | plt.imshow(depth_image) 55 | plt.show() 56 | return depth_image 57 | 58 | def generate_point_cloud(self): 59 | if self.depth_image is None or self.color_image is None or self.camera_info is None: 60 | return None 61 | 62 | fx = self.camera_info.K[0] 63 | fy = self.camera_info.K[4] 64 | cx = self.camera_info.K[2] 65 | cy = self.camera_info.K[5] 66 | 67 | depth_image = self.hole_filling(self.depth_image) 68 | h, w = depth_image.shape 69 | point_cloud_with_color = np.zeros((h, w, 6), dtype=np.float32) 70 | 71 | for i in range(h): 72 | for j in range(w): 73 | color = self.color_image[i, j] 74 | depth = depth_image[i, j] / 1000.0 75 | # x = (i - cx) * depth / fx 76 | # y = (j - cy) * depth / fy 77 | x = (j - cx) * depth / fx 78 | y = (i - cy) * depth / fy 79 | z = depth 80 | point_cloud_with_color[i, j] = [x, y, 
z, color[0], color[1], color[2]] 81 | # visualize point_cloud_with_color 82 | # import pdb;pdb.set_trace() 83 | import open3d as o3d 84 | print(point_cloud_with_color[:, :, :3].shape) 85 | pcd = o3d.geometry.PointCloud() 86 | pcd.points = o3d.utility.Vector3dVector(point_cloud_with_color[:, :, :3].reshape(-1, 3)) 87 | pcd.colors = o3d.utility.Vector3dVector(point_cloud_with_color[:, :, 3:].reshape(-1, 3) / 255.0) 88 | task_name = 'pour_water' 89 | data_dir = Path('/root/ws_moveit/src/env_pics') / task_name 90 | data_dir.mkdir(parents=True, exist_ok=True) 91 | np.save(data_dir / f'{task_name}.npy', point_cloud_with_color) 92 | cv2.imwrite(str(data_dir / f'{task_name}.png'), cv2.cvtColor(self.color_image, cv2.COLOR_RGB2BGR)) 93 | o3d.visualization.draw_geometries([pcd]) 94 | 95 | 96 | 97 | return point_cloud_with_color 98 | 99 | if __name__ == '__main__': 100 | generator = PointCloudGenerator() 101 | rospy.spin() -------------------------------------------------------------------------------- /real_world/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | franka_scripts 4 | 0.0.0 5 | The franka_scripts package 6 | 7 | 8 | 9 | 10 | Haoxu Huang 11 | 12 | 13 | 14 | 15 | 16 | TODO 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | catkin 52 | roscpp 53 | rospy 54 | std_msgs 55 | pluginlib 56 | eigen 57 | moveit_core 58 | moveit_ros_planning 59 | moveit_ros_planning_interface 60 | moveit_ros_perception 61 | interactive_markers 62 | geometric_shapes 63 | moveit_visual_tools 64 | rviz_visual_tools 65 | pcl_ros 66 | pcl_conversions 67 | rosbag 68 | tf2_ros 69 | tf2_eigen 70 | tf2_geometry_msgs 71 | 72 | roscpp 73 | rospy 74 | std_msgs 75 | 76 | panda_moveit_config 77 | franka_description 78 | pluginlib 79 | moveit_core 80 | moveit_commander 81 | moveit_fake_controller_manager 82 | moveit_ros_planning_interface 83 | moveit_ros_perception 84 | interactive_markers 85 | moveit_visual_tools 86 | rviz_visual_tools 87 | joint_state_publisher 88 | robot_state_publisher 89 | joy 90 | pcl_ros 91 | pcl_conversions 92 | rosbag 93 | rviz 94 | tf2_ros 95 | tf2_eigen 96 | tf2_geometry_msgs 97 | xacro 98 | nodelet 99 | gazebo_ros 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /real_world/robot_homing/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | catkin_install_python(PROGRAMS 2 | scripts/robot_homing.py 3 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 4 | ) 5 | 6 | install(DIRECTORY launch DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}) 7 | -------------------------------------------------------------------------------- /real_world/robot_homing/launch/robot_homing.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /real_world/robot_homing/scripts/robot_homing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import rospy 4 | import moveit_commander 5 | import moveit_msgs.msg 6 | import geometry_msgs.msg 7 | from math import pi, tau, dist, fabs, sin, cos 8 | 9 | from std_msgs.msg import String 10 | from moveit_commander.conversions import pose_to_list 11 | 12 | 13 | def all_close(goal, actual, 
tolerance=1e-3): 14 | for i in range(len(goal)): 15 | if fabs(goal[i] - actual[i]) > tolerance: 16 | return False 17 | return True 18 | 19 | class RobotHoming(object): 20 | def __init__(self) -> None: 21 | super().__init__() 22 | # First initialize `moveit_commander`_ and a `rospy`_ node: 23 | moveit_commander.roscpp_initialize(sys.argv) 24 | rospy.init_node('robot_homing', anonymous=True) 25 | 26 | # Instantiate a `RobotCommander`_ object. This object is an interface to 27 | # kinematic model and the current state of the robot: 28 | robot = moveit_commander.RobotCommander() 29 | 30 | # Instantiate a `PlanningSceneInterface`_ object. This object is an interface 31 | # to the world surrounding the robot: 32 | scene = moveit_commander.PlanningSceneInterface() 33 | 34 | # Instantiate a `MoveGroupCommander`_ object. This object is an interface 35 | # to one group of joints. In this case the group is the joints in the left 36 | # arm. This interface can be used to plan and execute motions on the left 37 | # arm: 38 | group_name = "panda_arm" 39 | move_group = moveit_commander.MoveGroupCommander(group_name) 40 | 41 | # Create a `DisplayTrajectory`_ ROS publisher which is used to display 42 | # trajectories in Rviz: 43 | display_trajectory_publisher = rospy.Publisher( 44 | '/move_group/display_planned_path', 45 | moveit_msgs.msg.DisplayTrajectory, 46 | queue_size=20) 47 | 48 | # Getting Basic Information 49 | # We can get the name of the reference frame for this robot: 50 | planning_frame = move_group.get_planning_frame() 51 | print("============ Planning frame: %s" % planning_frame) 52 | # end effector link 53 | eef_link = move_group.get_end_effector_link() 54 | print("============ End effector link: %s" % eef_link) 55 | # We can also print the name of the end-effector link for this group: 56 | group_names = robot.get_group_names() 57 | print("============ Available Planning Groups:", robot.get_group_names()) 58 | 59 | self.robot = robot 60 | self.scene = scene 61 | self.move_group = move_group 62 | self.display_trajectory_publisher = display_trajectory_publisher 63 | self.eef_link = eef_link 64 | self.group_names = group_names 65 | # home joint angles 66 | self.home_pose = [0.000,-0.785,0.0,-1.90,0.0,1.37,0.785] 67 | 68 | def get_robot_states(self): 69 | joint_state = self.move_group.get_current_joint_values() 70 | eef_pose = self.move_group.get_current_pose().pose 71 | return joint_state, eef_pose 72 | 73 | def go_home(self): 74 | # get current state 75 | joint_state, eef_pose = self.get_robot_states() 76 | # set home pose 77 | joint_state[0:7] = self.home_pose 78 | self.move_group.go(joint_state, wait=True) 79 | self.move_group.stop() 80 | # check if reached home pose 81 | current_joints = self.move_group.get_current_joint_values() 82 | return all_close(joint_state, current_joints, 0.01) 83 | 84 | 85 | if __name__ == '__main__': 86 | robot_homing = RobotHoming() 87 | success = robot_homing.go_home() 88 | if success: 89 | rospy.loginfo("Homing successful!") 90 | else: 91 | rospy.loginfo("Homing failed!") 92 | # print current state 93 | joint_state, eef_pose = robot_homing.get_robot_states() 94 | rospy.loginfo("Joint states: {}".format(joint_state)) 95 | rospy.loginfo("End effector pose: {}".format(eef_pose)) 96 | # wait for user input 97 | input("Press enter to exit ;") -------------------------------------------------------------------------------- /real_world/robot_states_monitor/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 
catkin_install_python(PROGRAMS 2 | scripts/robot_states_monitor.py 3 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 4 | ) 5 | 6 | install(DIRECTORY launch DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}) 7 | -------------------------------------------------------------------------------- /real_world/robot_states_monitor/launch/robot_states_monitor.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /real_world/robot_states_monitor/scripts/robot_states_monitor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 4 | import copy 5 | import rospy 6 | import moveit_commander 7 | import moveit_msgs.msg 8 | import geometry_msgs.msg 9 | from math import pi, tau, dist, fabs, sin, cos 10 | 11 | from std_msgs.msg import String 12 | from moveit_commander.conversions import pose_to_list 13 | 14 | 15 | 16 | class RobotStatesMonitor(object): 17 | def __init__(self) -> None: 18 | super().__init__() 19 | # First initialize `moveit_commander`_ and a `rospy`_ node: 20 | moveit_commander.roscpp_initialize(sys.argv) 21 | rospy.init_node('robot_states_monitor', anonymous=True) 22 | 23 | 24 | # Instantiate a `RobotCommander`_ object. This object is an interface to 25 | # kinematic model and the current state of the robot: 26 | robot = moveit_commander.RobotCommander() 27 | 28 | # Instantiate a `PlanningSceneInterface`_ object. This object is an interface 29 | # to the world surrounding the robot: 30 | scene = moveit_commander.PlanningSceneInterface() 31 | 32 | # Instantiate a `MoveGroupCommander`_ object. This object is an interface 33 | # to one group of joints. In this case the group is the joints in the left 34 | # arm. 
This interface can be used to plan and execute motions on the left 35 | # arm: 36 | group_name = "panda_arm" 37 | move_group = moveit_commander.MoveGroupCommander(group_name) 38 | 39 | # Create a `DisplayTrajectory`_ ROS publisher which is used to display 40 | # trajectories in Rviz: 41 | display_trajectory_publisher = rospy.Publisher( 42 | '/move_group/display_planned_path', 43 | moveit_msgs.msg.DisplayTrajectory, 44 | queue_size=20) 45 | 46 | # Getting Basic Information 47 | # We can get the name of the reference frame for this robot: 48 | planning_frame = move_group.get_planning_frame() 49 | print("============ Planning frame: %s" % planning_frame) 50 | # end effector link 51 | eef_link = move_group.get_end_effector_link() 52 | print("============ End effector link: %s" % eef_link) 53 | # We can also print the name of the end-effector link for this group: 54 | group_names = robot.get_group_names() 55 | print("============ Available Planning Groups:", robot.get_group_names()) 56 | 57 | self.robot = robot 58 | self.scene = scene 59 | self.move_group = move_group 60 | self.display_trajectory_publisher = display_trajectory_publisher 61 | self.eef_link = eef_link 62 | self.group_names = group_names 63 | 64 | def get_robot_states(self): 65 | joint_state = self.move_group.get_current_joint_values() 66 | eef_pose = self.move_group.get_current_pose().pose 67 | return joint_state, eef_pose 68 | 69 | def monitor(self, duration=0.1): 70 | while not rospy.is_shutdown(): 71 | joint_state, eef_pose = self.get_robot_states() 72 | print("Joint states: ", joint_state) 73 | print("End effector pose: ", eef_pose) 74 | rospy.sleep(duration) 75 | 76 | 77 | if __name__ == '__main__': 78 | try: 79 | robot_states_monitor = RobotStatesMonitor() 80 | robot_states_monitor.monitor() 81 | except rospy.ROSInterruptException: 82 | pass -------------------------------------------------------------------------------- /real_world/trajectory_recorder/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | catkin_install_python(PROGRAMS 2 | scripts/trajectory_recorder.py 3 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 4 | ) 5 | 6 | install(DIRECTORY launch DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION}) 7 | -------------------------------------------------------------------------------- /real_world/trajectory_recorder/launch/trajectory_recorder.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /real_world/trajectory_recorder/scripts/trajectory_recorder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import copy 4 | import time 5 | import rospy 6 | import atexit 7 | import moveit_commander 8 | import moveit_msgs.msg 9 | import geometry_msgs.msg 10 | import pickle 11 | import numpy as np 12 | import dataclasses 13 | from scipy.spatial.transform import Rotation as R 14 | from math import pi, tau, dist, fabs, sin, cos 15 | 16 | from std_msgs.msg import String 17 | from moveit_commander.conversions import pose_to_list 18 | 19 | def pose_msg_to_matrix(pose_msg): 20 | # convert pose msg to matrix 21 | pose_matrix = np.zeros((4, 4)) 22 | pose_matrix[0:3, 0:3] = R.from_quat( 23 | [pose_msg.orientation.x, 24 | pose_msg.orientation.y, 25 | pose_msg.orientation.z, 26 | pose_msg.orientation.w]).as_matrix() 27 | pose_matrix[0:3, 3] = [pose_msg.position.x, 28 | 
pose_msg.position.y, 29 | pose_msg.position.z] 30 | pose_matrix[3, 3] = 1.0 31 | return pose_matrix 32 | 33 | 34 | @dataclasses.dataclass 35 | class RobotStateFrame: 36 | timestamp: int # in milliseconds 37 | joint_state: list 38 | eef_pose: np.ndarray 39 | 40 | def asdict(self): 41 | return dataclasses.asdict(self) 42 | 43 | 44 | class RobotStatesRecorder(object): 45 | def __init__(self, frame_rate=10, max_frames=36000) -> None: 46 | super().__init__() 47 | # First initialize `moveit_commander`_ and a `rospy`_ node: 48 | moveit_commander.roscpp_initialize(sys.argv) 49 | rospy.init_node('robot_traj_recorder', anonymous=True) 50 | 51 | 52 | # Instantiate a `RobotCommander`_ object. This object is an interface to 53 | # kinematic model and the current state of the robot: 54 | robot = moveit_commander.RobotCommander() 55 | 56 | # Instantiate a `PlanningSceneInterface`_ object. This object is an interface 57 | # to the world surrounding the robot: 58 | scene = moveit_commander.PlanningSceneInterface() 59 | 60 | # Instantiate a `MoveGroupCommander`_ object. This object is an interface 61 | # to one group of joints. In this case the group is the joints in the left 62 | # arm. This interface can be used to plan and execute motions on the left 63 | # arm: 64 | group_name = "panda_arm" 65 | move_group = moveit_commander.MoveGroupCommander(group_name) 66 | 67 | # record command subscriber 68 | rospy.Subscriber('/robot_traj_recorder/command', String, self.command_callback) 69 | 70 | # Getting Basic Information 71 | # We can get the name of the reference frame for this robot: 72 | planning_frame = move_group.get_planning_frame() 73 | print("============ Planning frame: %s" % planning_frame) 74 | # end effector link 75 | eef_link = move_group.get_end_effector_link() 76 | print("============ End effector link: %s" % eef_link) 77 | # We can also print the name of the end-effector link for this group: 78 | group_names = robot.get_group_names() 79 | print("============ Available Planning Groups:", robot.get_group_names()) 80 | 81 | self.robot = robot 82 | self.scene = scene 83 | self.move_group = move_group 84 | self.eef_link = eef_link 85 | self.group_names = group_names 86 | self.recording = False 87 | self.recording_frames = [] 88 | self.frame_rate = rospy.Rate(frame_rate) 89 | self.max_frames = max_frames 90 | 91 | def get_robot_states(self): 92 | joint_state = self.move_group.get_current_joint_values() 93 | eef_pose = self.move_group.get_current_pose() 94 | return joint_state, eef_pose 95 | 96 | def command_callback(self, data): 97 | command = data.data 98 | if command == 'start': 99 | self.start_recording() 100 | elif command == 'stop': 101 | self.stop_recording() 102 | elif command == 'clear': 103 | self.clear_recording() 104 | 105 | def start_recording(self): 106 | if self.recording: 107 | rospy.logwarn("Already recording...") 108 | self.recording = True 109 | self.recording_frames = [] 110 | rospy.loginfo("Start recording...") 111 | 112 | def stop_recording(self): 113 | if not self.recording: 114 | rospy.logwarn("Not recording...") 115 | return 116 | self.recording = False 117 | rospy.loginfo("Stop recording...") 118 | # save to file 119 | pickle.dump(self.recording_frames, open('/root/data/trajectory_record_{}.pkl'.format(time.strftime("%H-%M-%S")), 'wb')) 120 | rospy.loginfo("{} frames have been saved to file...".format(len(self.recording_frames))) 121 | 122 | def clear_recording(self): 123 | self.recording_frames = [] 124 | rospy.loginfo("Clear recording...") 125 | 126 | def record(self): 127 | 
if self.recording: 128 | if len(self.recording_frames) >= self.max_frames: 129 | rospy.logwarn("Recording frames exceed max frames, stop recording...") 130 | self.stop_recording() 131 | return 132 | joint_state, eef_pose = self.get_robot_states() 133 | timestamp = eef_pose.header.stamp.to_nsec() // 1000000 134 | eef_pose_mat = pose_msg_to_matrix(eef_pose.pose) 135 | robot_state_frame = RobotStateFrame(timestamp, joint_state, eef_pose_mat) 136 | self.recording_frames.append(robot_state_frame.asdict()) 137 | self.frame_rate.sleep() 138 | 139 | 140 | 141 | if __name__ == '__main__': 142 | recorder = RobotStatesRecorder() 143 | @atexit.register 144 | def shutdown(): 145 | rospy.loginfo("Shutting down...") 146 | recorder.stop_recording() 147 | rospy.loginfo("Shutdown complete.") 148 | 149 | while not rospy.is_shutdown(): 150 | recorder.record() -------------------------------------------------------------------------------- /real_world/trajectory_recorder/trajectory_recorder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import copy 4 | import time 5 | import rospy 6 | import moveit_commander 7 | import moveit_msgs.msg 8 | import geometry_msgs.msg 9 | import pickle 10 | import numpy as np 11 | import dataclasses 12 | from scipy.spatial.transform import Rotation as R 13 | from math import pi, tau, dist, fabs, sin, cos 14 | 15 | from std_msgs.msg import String 16 | from moveit_commander.conversions import pose_to_list 17 | 18 | def pose_msg_to_matrix(pose_msg): 19 | # convert pose msg to matrix 20 | pose_matrix = np.zeros((4, 4)) 21 | pose_matrix[0:3, 0:3] = R.from_quat( 22 | [pose_msg.orientation.x, 23 | pose_msg.orientation.y, 24 | pose_msg.orientation.z, 25 | pose_msg.orientation.w]).as_matrix() 26 | pose_matrix[0:3, 3] = [pose_msg.position.x, 27 | pose_msg.position.y, 28 | pose_msg.position.z] 29 | pose_matrix[3, 3] = 1.0 30 | return pose_matrix 31 | 32 | 33 | @dataclasses.dataclass 34 | class RobotStateFrame: 35 | timestamp: int # in milliseconds 36 | joint_state: list 37 | eef_pose: np.ndarray 38 | 39 | class RobotStatesRecorder(object): 40 | def __init__(self, frame_rate=10, max_frames=36000) -> None: 41 | super().__init__() 42 | # First initialize `moveit_commander`_ and a `rospy`_ node: 43 | moveit_commander.roscpp_initialize(sys.argv) 44 | rospy.init_node('robot_traj_recorder', anonymous=True) 45 | 46 | 47 | # Instantiate a `RobotCommander`_ object. This object is an interface to 48 | # kinematic model and the current state of the robot: 49 | robot = moveit_commander.RobotCommander() 50 | 51 | # Instantiate a `PlanningSceneInterface`_ object. This object is an interface 52 | # to the world surrounding the robot: 53 | scene = moveit_commander.PlanningSceneInterface() 54 | 55 | # Instantiate a `MoveGroupCommander`_ object. This object is an interface 56 | # to one group of joints. In this case the group is the joints in the left 57 | # arm. 
This interface can be used to plan and execute motions on the left 58 | # arm: 59 | group_name = "panda_arm" 60 | move_group = moveit_commander.MoveGroupCommander(group_name) 61 | 62 | # record command subscriber 63 | rospy.Subscriber('/robot_traj_recorder/command', String, self.command_callback) 64 | 65 | # Getting Basic Information 66 | # We can get the name of the reference frame for this robot: 67 | planning_frame = move_group.get_planning_frame() 68 | print("============ Planning frame: %s" % planning_frame) 69 | # end effector link 70 | eef_link = move_group.get_end_effector_link() 71 | print("============ End effector link: %s" % eef_link) 72 | # We can also print the name of the end-effector link for this group: 73 | group_names = robot.get_group_names() 74 | print("============ Available Planning Groups:", robot.get_group_names()) 75 | 76 | self.robot = robot 77 | self.scene = scene 78 | self.move_group = move_group 79 | self.eef_link = eef_link 80 | self.group_names = group_names 81 | self.recording = False 82 | self.frame_rate = rospy.Rate(frame_rate) 83 | self.max_frames = max_frames 84 | 85 | def get_robot_states(self): 86 | joint_state = self.move_group.get_current_joint_values() 87 | eef_pose = self.move_group.get_current_pose() 88 | return joint_state, eef_pose 89 | 90 | def command_callback(self, data): 91 | command = data.data 92 | if command == 'start': 93 | self.start_recording() 94 | elif command == 'stop': 95 | self.stop_recording() 96 | elif command == 'clear': 97 | self.clear_recording() 98 | 99 | def start_recording(self): 100 | if self.recording: 101 | rospy.logwarn("Already recording...") 102 | self.recording = True 103 | self.recording_frames = [] 104 | rospy.loginfo("Start recording...") 105 | 106 | def stop_recording(self): 107 | if not self.recording: 108 | rospy.logwarn("Not recording...") 109 | self.recording = False 110 | rospy.loginfo("Stop recording...") 111 | # save to file 112 | pickle.dump(self.recording_frames, open('/root/data/trajectory_record_{}.pkl'.format(time.strftime("%H-%M-%S")), 'wb')) 113 | rospy.loginfo("{} frames have been saved to file...".format(len(self.recording_frames))) 114 | 115 | def clear_recording(self): 116 | self.recording_frames = [] 117 | rospy.loginfo("Clear recording...") 118 | 119 | def record(self): 120 | if self.recording: 121 | if len(self.recording_frames) >= self.max_frames: 122 | rospy.logwarn("Recording frames exceed max frames, stop recording...") 123 | self.stop_recording() 124 | return 125 | joint_state, eef_pose = self.get_robot_states() 126 | timestamp = eef_pose.header.stamp.to_nsec() // 1000000 127 | eef_pose_mat = pose_msg_to_matrix(eef_pose.pose) 128 | robot_state_frame = RobotStateFrame(timestamp, joint_state, eef_pose_mat) 129 | self.recording_frames.append(robot_state_frame) 130 | self.frame_rate.sleep() 131 | 132 | 133 | 134 | if __name__ == '__main__': 135 | recorder = RobotStatesRecorder() 136 | while not rospy.is_shutdown(): 137 | recorder.record() -------------------------------------------------------------------------------- /som_gpt4v/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | 
.installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /som_gpt4v/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /som_gpt4v/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /som_gpt4v/SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. 
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /som_gpt4v/SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /som_gpt4v/assets/method2_xyz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/assets/method2_xyz.png -------------------------------------------------------------------------------- /som_gpt4v/assets/som_bench_bottom.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/assets/som_bench_bottom.jpg -------------------------------------------------------------------------------- /som_gpt4v/assets/som_bench_upper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/assets/som_bench_upper.jpg -------------------------------------------------------------------------------- /som_gpt4v/assets/som_gpt4v_demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/assets/som_gpt4v_demo.mp4 -------------------------------------------------------------------------------- /som_gpt4v/assets/som_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/assets/som_logo.png -------------------------------------------------------------------------------- /som_gpt4v/assets/som_toolbox_interface.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/assets/som_toolbox_interface.jpg -------------------------------------------------------------------------------- /som_gpt4v/assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/assets/teaser.png -------------------------------------------------------------------------------- /som_gpt4v/benchmark/README.md: -------------------------------------------------------------------------------- 1 | # SoM-Bench: Evaluating Visual Grounding with Visual Prompting 2 | 3 | We build a new benchmark called SoM-Bench to evaluate the visual grounding capability of LLMs with visual prompting. 
4 | 5 | ## Dataset 6 | 7 | | Vision Taks | Source | #Images | #Instances | Marks | Metric | Data 8 | | -------- | -------- | -------- | -------- | -------- | -------- | -------- | 9 | | Open-Vocab Segmentation | [COCO](https://cocodataset.org/#home) | 100 | 567 | Numeric IDs and Masks | Precision | [Download](https://github.com/microsoft/SoM/releases/download/v1.0/coco_ovseg.zip) 10 | | Open-Vocab Segmentation | [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/) | 100 | 488 | Numeric IDs and Masks | Precision | [Download](https://github.com/microsoft/SoM/releases/download/v1.0/ade20k_ovseg.zip) 11 | | Phrase Grounding | [Flickr30K](https://shannon.cs.illinois.edu/DenotationGraph/) | 100 | 274 | Numeric IDs and Masks and Boxes | Recall @ 1 | [Download](https://github.com/microsoft/SoM/releases/download/v1.0/flickr30k_grounding.zip) 12 | | Referring Comprehension | [RefCOCO](https://github.com/lichengunc/refer) | 100 | 177 | Numeric IDs and Masks | ACC @ 0.5 | [Download](https://github.com/microsoft/SoM/releases/download/v1.0/refcocog_refseg.zip) 13 | | Referring Segmentation | [RefCOCO](https://github.com/lichengunc/refer) | 100 | 177 | Numeric IDs and Masks | mIoU | [Download](https://github.com/microsoft/SoM/releases/download/v1.0/refcocog_refseg.zip) 14 | 15 | ## Dataset Structure 16 | 17 | ### Open-Vocab Segmentation on COCO 18 | 19 | We provide COCO in the following structure: 20 | 21 | ``` 22 | coco_ovseg 23 | ├── som_images 24 | ├── 000000000285_0.jpg 25 | ├── 000000000872_0.jpg 26 | |── 000000000872_5.jpg 27 | ├── ... 28 | ├── 000000002153_5.jpg 29 | └── 000000002261_0.jpg 30 | ``` 31 | 32 | For some of the samples, the regions are very dense, so we split the regions into multiple groups of size 5,. For example, `000000000872_0.jpg` has 5 regions, and `000000000872_5.jpg` has the other 5 regions. Note that you can use the image_id to track the original image. 33 | 34 | We used the following language prompt for the task: 35 | ``` 36 | I have labeled a bright numeric ID at the center for each visual object in the image. Please enumerate their names. You must answer by selecting from the following names: [COCO Vocabulary] 37 | ``` 38 | 39 | ### Open-Vocab Segmentation on ADE20K 40 | 41 | ``` 42 | ade20k_ovseg 43 | ├── som_images 44 | ├── ADE_val_00000001_0.jpg 45 | ├── ADE_val_00000001_5.jpg 46 | |── ADE_val_00000011_5.jpg 47 | ├── ... 48 | ├── ADE_val_00000039_5.jpg 49 | └── ADE_val_00000040_0.jpg 50 | ``` 51 | Similar to COCO, the regions in ADE20K are also very dense, so we split the regions into multiple groups of size 5,. For example, `ADE_val_00000001_0.jpg` has 5 regions, and `ADE_val_00000001_5.jpg` has the other 5 regions. Note that you can use the image_id to track the original image. 52 | 53 | We used the following language prompt for the task: 54 | ``` 55 | I have labeled a bright numeric ID at the center for each visual object in the image. Please enumerate their names. You must answer by selecting from the following names: [ADE20K Vocabulary] 56 | ``` 57 | 58 | ### Phrase Grounding on Flickr30K 59 | 60 | ``` 61 | flickr30k_grounding 62 | ├── som_images 63 | ├── 14868339.jpg 64 | ├── 14868339_wbox.jpg 65 | |── 14868339.json 66 | ├── ... 67 | ├── 302740416.jpg 68 | |── 319185571_wbox.jpg 69 | └── 302740416.json 70 | ``` 71 | 72 | For Flickr30K, we provide the image with numeric IDs and masks, and also the image with additional bounding boxes. The json file containing the ground truth bounding boxes and the corresponding phrases. 
Note that the bounding boxes are in the format of [x1, y1, x2, y2]. 73 | 74 | We used the following language prompt for the task: 75 | ``` 76 | I have labeled a bright numeric ID at the center for each visual object in the image. Given the image showing a man in glasses holding a piece of paper, find the corresponding regions for a man in glasses, a piece of paper. 77 | ``` 78 | 79 | ### Referring Expression Comprehension and Segmentation on RefCOCOg 80 | 81 | ``` 82 | refcocog_refseg 83 | ├── som_images 84 | ├── 000000000795.jpg 85 | |── 000000000795.json 86 | ├── ... 87 | |── 000000007852.jpg 88 | └── 000000007852.json 89 | ``` 90 | 91 | For RefCOCOg, we provide the image with numeric IDs and masks, and also the json file containing the referring expressions and the corresponding referring ids. 92 | 93 | We used the following language prompt for the task: 94 | ``` 95 | I have labeled a bright numeric ID at the center for each visual object in the image. Please tell me the IDs for: The laptop behind the beer bottle; Laptop turned on. 96 | ``` 97 | -------------------------------------------------------------------------------- /som_gpt4v/camera_extrinsic2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/camera_extrinsic2.npy -------------------------------------------------------------------------------- /som_gpt4v/dataset_tasks.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | import numpy as np 4 | 5 | # Path to the dataset 6 | dataset_path = Path('dataset') 7 | 8 | class DatasetTasks: 9 | def __init__(self, task_name) -> None: 10 | self.task_name = task_name 11 | self.dataset_path = dataset_path / task_name 12 | if not self.dataset_path.exists(): 13 | raise ValueError(f"Dataset {task_name} not found at {self.dataset_path}") 14 | self.data_list = list((self.dataset_path / 'pointcloud').glob('*.npy')) 15 | 16 | def __getitem__(self, idx): 17 | data_name = self.data_list[idx].stem 18 | color1 = Image.open(self.dataset_path / 'color' / f'{data_name}_color1.png') 19 | color2 = Image.open(self.dataset_path / 'color' / f'{data_name}_color2.png') 20 | depth1 = np.load(self.dataset_path / 'depth' / f'{data_name}_depth1.npy') 21 | depth2 = np.load(self.dataset_path / 'depth' / f'{data_name}_depth2.npy') 22 | pointcloud = np.load(self.dataset_path / 'pointcloud' / f'{data_name}.npy') 23 | return color1, color2, depth1, depth2, pointcloud 24 | 25 | def __len__(self): 26 | return len(self.data_list) 27 | 28 | 29 | # Example usage 30 | if __name__ == '__main__': 31 | dataset = DatasetTasks('dataset_button') 32 | print(len(dataset)) 33 | for i in range(len(dataset)): 34 | color1, color2, depth1, depth2, pointcloud = dataset[i] 35 | print(color1.size, color2.size, depth1.shape, depth2.shape, pointcloud.shape) 36 | break -------------------------------------------------------------------------------- /som_gpt4v/download_ckpt.sh: -------------------------------------------------------------------------------- 1 | # wget https://github.com/UX-Decoder/Semantic-SAM/releases/download/checkpoint/swinl_only_sam_many2many.pth 2 | wget https://huggingface.co/xdecoder/SEEM/resolve/main/seem_focall_v1.pt 3 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 4 | -------------------------------------------------------------------------------- 
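The files above (`dataset_tasks.py`, `camera_extrinsic2.npy`, `download_ckpt.sh`) are typically used together when replaying recorded scenes. Below is a minimal sketch of how a recorded point cloud might be moved from the camera frame into the robot base frame. It assumes that `camera_extrinsic2.npy` stores a 4x4 camera-to-base homogeneous transform and that the stored point clouds are XYZRGB arrays reshapeable to (N, 6), as produced by `get_point_cloud.py`; both the transform direction and the `'dataset_button'` task folder (taken from the example in `dataset_tasks.py`) are assumptions, not guarantees.

```python
import numpy as np
from dataset_tasks import DatasetTasks

# Assumption: camera_extrinsic2.npy holds a 4x4 camera-to-base transform.
T_cam_to_base = np.load('camera_extrinsic2.npy')

dataset = DatasetTasks('dataset_button')     # task folder name from the example above
_, _, _, _, pointcloud = dataset[0]          # XYZRGB point cloud

pc = pointcloud.reshape(-1, 6)
xyz_cam = pc[:, :3]
rgb = pc[:, 3:]

# Apply the homogeneous transform to every point.
xyz_h = np.concatenate([xyz_cam, np.ones((len(xyz_cam), 1))], axis=1)  # (N, 4)
xyz_base = (T_cam_to_base @ xyz_h.T).T[:, :3]

print(xyz_base.shape, rgb.shape)
```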
/som_gpt4v/draw_arrow.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | 4 | def draw_geometries(pcds): 5 | """ 6 | Draw Geometries 7 | Args: 8 | - pcds (): [pcd1,pcd2,...] 9 | """ 10 | o3d.visualization.draw_geometries(pcds) 11 | 12 | def get_o3d_FOR(origin=[0, 0, 0],size=10): 13 | """ 14 | Create a FOR that can be added to the open3d point cloud 15 | """ 16 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame( 17 | size=size) 18 | mesh_frame.translate(origin) 19 | return(mesh_frame) 20 | 21 | def vector_magnitude(vec): 22 | """ 23 | Calculates a vector's magnitude. 24 | Args: 25 | - vec (): 26 | """ 27 | magnitude = np.sqrt(np.sum(vec**2)) 28 | return(magnitude) 29 | 30 | 31 | def calculate_zy_rotation_for_arrow(vec): 32 | """ 33 | Calculates the rotations required to go from the vector vec to the 34 | z axis vector of the original FOR. The first rotation that is 35 | calculated is over the z axis. This will leave the vector vec on the 36 | XZ plane. Then, the rotation over the y axis. 37 | 38 | Returns the angles of rotation over axis z and y required to 39 | get the vector vec into the same orientation as axis z 40 | of the original FOR 41 | 42 | Args: 43 | - vec (): 44 | """ 45 | # Rotation over z axis of the FOR 46 | gamma = np.arctan(vec[1]/vec[0]) 47 | Rz = np.array([[np.cos(gamma),-np.sin(gamma),0], 48 | [np.sin(gamma),np.cos(gamma),0], 49 | [0,0,1]]) 50 | # Rotate vec to calculate next rotation 51 | vec = Rz.T@vec.reshape(-1,1) 52 | vec = vec.reshape(-1) 53 | # Rotation over y axis of the FOR 54 | beta = np.arctan(vec[0]/vec[2]) 55 | Ry = np.array([[np.cos(beta),0,np.sin(beta)], 56 | [0,1,0], 57 | [-np.sin(beta),0,np.cos(beta)]]) 58 | return(Rz, Ry) 59 | 60 | def create_arrow(scale=10): 61 | """ 62 | Create an arrow in for Open3D 63 | """ 64 | cone_height = scale*0.2 65 | cylinder_height = scale*0.8 66 | cone_radius = scale/10 67 | cylinder_radius = scale/20 68 | mesh_frame = o3d.geometry.TriangleMesh.create_arrow(cone_radius=cone_radius, 69 | cone_height=cone_height, 70 | cylinder_radius=cylinder_radius, 71 | cylinder_height=cylinder_height) 72 | return(mesh_frame) 73 | 74 | def get_arrow(origin=[0, 0, 0], end=None, vec=None): 75 | """ 76 | Creates an arrow from an origin point to an end point, 77 | or create an arrow from a vector vec starting from origin. 78 | Args: 79 | - end (): End point. [x,y,z] 80 | - vec (): Vector. 
[i,j,k] 81 | """ 82 | scale = 1 83 | Ry = Rz = np.eye(3) 84 | T = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) 85 | T[:3, -1] = origin 86 | if end is not None: 87 | vec = np.array(end) - np.array(origin) 88 | elif vec is not None: 89 | vec = np.array(vec) 90 | if end is not None or vec is not None: 91 | scale = vector_magnitude(vec) 92 | Rz, Ry = calculate_zy_rotation_for_arrow(vec) 93 | mesh = create_arrow(scale) 94 | # Create the arrow 95 | mesh.rotate(Ry, center=np.array([0, 0, 0])) 96 | mesh.rotate(Rz, center=np.array([0, 0, 0])) 97 | mesh.translate(origin) 98 | return(mesh) 99 | 100 | 101 | if __name__ == '__main__': 102 | 103 | # Create a Cartesian Frame of Reference 104 | FOR = get_o3d_FOR() 105 | # Create an arrow from point (5,5,5) to point (10,10,10) 106 | # arrow = get_arrow([5,5,5],[10,10,10]) 107 | 108 | # Create an arrow representing vector vec, starting at (5,5,5) 109 | # arrow = get_arrow([5,5,5],vec=[5,5,5]) 110 | 111 | # Create an arrow in the same place as the z axis 112 | arrow = get_arrow(origin=[0,1,1], vec=[0, np.sqrt(2)/2, np.sqrt(2)/2]) 113 | 114 | # Draw everything 115 | draw_geometries([FOR,arrow]) -------------------------------------------------------------------------------- /som_gpt4v/examples/gpt-4v-som-example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/examples/gpt-4v-som-example.jpg -------------------------------------------------------------------------------- /som_gpt4v/examples/ironing_man.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/examples/ironing_man.jpg -------------------------------------------------------------------------------- /som_gpt4v/examples/ironing_man_som.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/examples/ironing_man_som.png -------------------------------------------------------------------------------- /som_gpt4v/examples/som_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/examples/som_logo.png -------------------------------------------------------------------------------- /som_gpt4v/first_prompt/prompt1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/first_prompt/prompt1.png -------------------------------------------------------------------------------- /som_gpt4v/first_prompt/prompt1.txt: -------------------------------------------------------------------------------- 1 | Given a segmented photo, answer the question as if you are a robot with a parallel jaw gripper (as shown in the image). Some objects in the segmented photo are labeled with numbers. You need to find all objects related to the task and refer to them with the corresponding numbers. At the same time, you need to ignore distracting objects that are irrelevant to the task. 2 | You can first analyze the task, point out the objects related to the task, and finally give the label corresponding to the object. 
3 | To help you understand, I will show you several examples. 4 | 5 | Example 1: 6 | Instruction: Sweep the paper ball into the dustpan with a brush. -------------------------------------------------------------------------------- /som_gpt4v/first_prompt/prompt2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/first_prompt/prompt2.png -------------------------------------------------------------------------------- /som_gpt4v/first_prompt/prompt2.txt: -------------------------------------------------------------------------------- 1 | 2 | In this task, I need to sweep paper balls into a dustpan using a brush. Since the paper balls simply follow the movement of the brush, the object to be identified is the brush (Label 7). 3 | All other objects, including the dustpan, water cup, and parts of the robotic arm, are distractions. 4 | 5 | Object Label: [7] 6 | 7 | Example 2: 8 | Instruction: Put the book into the white shelf. -------------------------------------------------------------------------------- /som_gpt4v/first_prompt/prompt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/first_prompt/prompt3.png -------------------------------------------------------------------------------- /som_gpt4v/first_prompt/prompt3.txt: -------------------------------------------------------------------------------- 1 | 2 | In this task, I need to identify the book and the white bookshelf. The book is labeled as Label 5, and the bookshelf as Label 3. 3 | Other items such as cups and the desk are considered distractions. 4 | 5 | Object Label: [3, 5] 6 | 7 | Example 3: 8 | Instruction: Pound garlic in wooden jar with stick. -------------------------------------------------------------------------------- /som_gpt4v/first_prompt/prompt4.txt: -------------------------------------------------------------------------------- 1 | 2 | In this task, I need to identify the wooden stick and the wooden jar, labeled as 8 and 4 respectively. 3 | Other items, such as parts of the robotic arm, blue blocks, dolls, desks, etc., are considered distractions. 4 | 5 | Object Label: [4, 8] 6 | 7 | This is the new task. 8 | Instruction: -------------------------------------------------------------------------------- /som_gpt4v/gpt4v_azure.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | import base64 5 | import traceback 6 | from io import BytesIO 7 | 8 | api_base = ... 9 | deployment_name = ... 
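# NOTE: api_base and deployment_name above are intentionally left as placeholders and
# would be filled in with your own Azure OpenAI resource endpoint and GPT-4V deployment
# name, e.g. (illustrative values only):
#   api_base = "https://<your-resource-name>.openai.azure.com/"
#   deployment_name = "<your-gpt4v-deployment>"
# Keep the trailing slash on api_base, since base_url below is built by plain string
# concatenation ("{api_base}openai/deployments/{deployment_name}").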
10 | API_KEY = os.environ.get('AZURE_API_KEY') 11 | 12 | base_url = f"{api_base}openai/deployments/{deployment_name}" 13 | headers = { 14 | "Content-Type": "application/json", 15 | "api-key": API_KEY 16 | } 17 | endpoint = f"{base_url}/chat/completions?api-version=2023-12-01-preview" 18 | 19 | def encode_image_from_file(image_path): 20 | with open(image_path, "rb") as image_file: 21 | return base64.b64encode(image_file.read()).decode('ascii') 22 | 23 | def encode_image_from_pil(image): 24 | buffered = BytesIO() 25 | image.save(buffered, format="JPEG") 26 | return base64.b64encode(buffered.getvalue()).decode('ascii') 27 | 28 | def encode_image_base64(img): 29 | with open(img, "rb") as image_file: 30 | encoded_string = base64.b64encode(image_file.read()) 31 | return encoded_string.decode('ascii') 32 | 33 | chat_history = None 34 | 35 | def clear_history(): 36 | global chat_history 37 | chat_history = None 38 | 39 | def prepare_inputs_multi_image(messages, images): 40 | content = [] 41 | for message, image in zip(messages, images): 42 | content.append({"type": "text", "text": message}) 43 | encode_function = encode_image_from_file if isinstance(image, str) else encode_image_from_pil 44 | content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_function(image)}"}}) 45 | 46 | if chat_history is None: 47 | payload = { 48 | "messages": [ 49 | { 50 | "role": "system", 51 | "content": '- For any marks mentioned in your answer, please highlight them with [].' 52 | }, 53 | { 54 | "role": "user", 55 | "content": content 56 | } 57 | ], 58 | "max_tokens": 800 59 | } 60 | else: 61 | payload = chat_history 62 | payload['messages'].append({ 63 | "role": "user", 64 | "content": content 65 | }) 66 | 67 | return payload 68 | 69 | def request_gpt4v_multi_image_azure(messages, images): 70 | global chat_history 71 | payload = prepare_inputs_multi_image(messages, images) 72 | response = requests.post(endpoint, headers=headers, data=json.dumps(payload)) 73 | response = json.loads(response.text) 74 | try: 75 | res = response['choices'][0]["message"]["content"] 76 | except: 77 | print(response) 78 | traceback.print_exc() 79 | return None 80 | 81 | chat_history = payload 82 | chat_history['messages'].append({ 83 | "role": "assistant", 84 | "content": res, 85 | }) 86 | return res 87 | 88 | def prepare_inputs_multi_image_behavior(messages, images): 89 | content = [] 90 | for message, image in zip(messages, images): 91 | content.append({"type": "text", "text": message}) 92 | encode_function = encode_image_from_file if isinstance(image, str) else encode_image_from_pil 93 | content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_function(image)}"}}) 94 | 95 | payload = { 96 | "messages": [ 97 | { 98 | "role": "system", 99 | "content": 'Your response must strictly adhere to the specified format.' 
100 | }, 101 | { 102 | "role": "user", 103 | "content": content 104 | } 105 | ], 106 | "max_tokens": 800 107 | } 108 | 109 | return payload 110 | 111 | def request_gpt4v_multi_image_behavior_azure(messages, images): 112 | payload = prepare_inputs_multi_image_behavior(messages, images) 113 | response = requests.post(endpoint, headers=headers, data=json.dumps(payload)) 114 | response = json.loads(response.text) 115 | res = response['choices'][0]["message"]["content"] 116 | return res -------------------------------------------------------------------------------- /som_gpt4v/main_constraint.py: -------------------------------------------------------------------------------- 1 | from gpt4v import request_gpt4v_multi_image_behavior 2 | import traceback 3 | import os 4 | import argparse 5 | 6 | def gpt4v_response(instruction, image, prompt_path): 7 | for i in range(3): 8 | try: 9 | prompts, visions = [], [] 10 | for j in range(4): 11 | with open(f'{prompt_path}/prompt{j + 1}.txt', 'r') as f: 12 | prompts.append(f.read()) 13 | if j < 3: 14 | visions.append(f'{prompt_path}/prompt{j + 1}.png') 15 | prompts[-1] += instruction 16 | visions.append(image) 17 | res = request_gpt4v_multi_image_behavior(prompts, visions) 18 | return res 19 | except Exception as e: 20 | traceback.print_exc() 21 | continue 22 | 23 | parser = argparse.ArgumentParser(description='Process some instructions and images.') 24 | parser.add_argument('--instruction_base_dir', type=str, required=True, help='Base directory for instructions') 25 | parser.add_argument('--results_image_base_dir', type=str, required=True, help='Base directory for result images') 26 | parser.add_argument('--task', type=str, required=True, help='Task name') 27 | 28 | args = parser.parse_args() 29 | 30 | instruction_base_dir = args.instruction_base_dir 31 | results_image_base_dir = args.results_image_base_dir 32 | task = args.task 33 | 34 | with open(os.path.join(instruction_base_dir, task, 'instruction.txt'), 'r') as f: 35 | instruction = f.read() 36 | response = gpt4v_response(instruction, os.path.join(results_image_base_dir, task, 'final_output_add_mask.png'), 'behavior') 37 | 38 | print(response) 39 | with open(os.path.join(results_image_base_dir, task, 'constraint_response.txt'), 'w') as f: 40 | f.write(response) -------------------------------------------------------------------------------- /som_gpt4v/mask_filters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | def corner_filter(mask): 5 | '''If the mask contains a corner, return False''' 6 | return not(mask[0, 0] or mask[0, -1] or mask[-1, 0] or mask[-1, -1]) 7 | 8 | def area_filter(mask, area_thresh): 9 | '''If the mask area is smaller than area_thresh, return False''' 10 | return mask.sum() > area_thresh 11 | 12 | def intersection_filter(mask1, mask2, intersection_thresh): 13 | '''If the intersection of mask1 and mask2 is smaller than intersection_thresh, return False''' 14 | # resize mask2 to the same size as mask1 15 | mask2 = np.array(Image.fromarray(mask2).resize(mask1.shape[1::-1])) 16 | # calculate the intersection area 17 | intersection = np.logical_and(mask1, mask2).sum() 18 | return intersection / mask1.sum() > intersection_thresh 19 | 20 | def get_mask_filter(corner, area, intersection, area_thresh=100, mask2=None, intersection_thresh=0.1): 21 | def mask_filter(output): 22 | mask = output['segmentation'] 23 | if corner and not corner_filter(mask): 24 | return False 25 | if area and not 
area_filter(mask, area_thresh): 26 | return False 27 | if intersection and not intersection_filter(mask, mask2, intersection_thresh): 28 | return False 29 | return True 30 | return mask_filter 31 | 32 | 33 | if __name__ == '__main__': 34 | mask = np.zeros((10, 20), dtype=bool) 35 | mask[2:8, 5:15] = True 36 | output = {'segmentation': mask} 37 | mask2 = np.zeros((20, 40), dtype=bool) 38 | mask2[4:16, 20:40] = True 39 | flt = get_mask_filter(True, True, True, area_thresh=10, mask2=mask2) 40 | print(flt(output)) -------------------------------------------------------------------------------- /som_gpt4v/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /som_gpt4v/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd mask2former/modeling/pixel_decoder/ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /som_gpt4v/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install 14 | -------------------------------------------------------------------------------- /som_gpt4v/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /som_gpt4v/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 
40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /som_gpt4v/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /som_gpt4v/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /som_gpt4v/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /som_gpt4v/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /som_gpt4v/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /som_gpt4v/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. 
All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | 
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 | print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /som_gpt4v/second_prompt/prompt1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/second_prompt/prompt1.png -------------------------------------------------------------------------------- /som_gpt4v/second_prompt/prompt1.txt: -------------------------------------------------------------------------------- 1 | I have selected the objects you just chose, performed an additional segmentation on them, and marked them on the photo. Your task now is to identify the numbers corresponding to the lines or surfaces needed to complete the task from these numbers. 2 | To help you understand, I will show you several examples. 3 | 4 | Example 1: 5 | Instruction: Sweep the paper ball into the dustpan with a brush. -------------------------------------------------------------------------------- /som_gpt4v/second_prompt/prompt2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/second_prompt/prompt2.png -------------------------------------------------------------------------------- /som_gpt4v/second_prompt/prompt2.txt: -------------------------------------------------------------------------------- 1 | 2 | In this task, we are required to sweep paper balls into a dustpan using a brush. 3 | Since the only object manipulated by the robotic arm is the brush, with its most important part being the bristles (marked as number 1), it's essential to retain number 1. 4 | Numbers 5 and 7 represent the handle of the brush, which is smaller and does not need to interact with the paper balls. 5 | Number 10 is a minor corner of the brush and is not significant. 6 | Numbers 6, 8, and 9 belong to the paper balls and only need to move with the brush, so they are not necessary to keep. 7 | 8 | Object Label: [1] 9 | 10 | Example 2: 11 | Instruction: Put the book into the white shelf. 
-------------------------------------------------------------------------------- /som_gpt4v/second_prompt/prompt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/second_prompt/prompt3.png -------------------------------------------------------------------------------- /som_gpt4v/second_prompt/prompt3.txt: -------------------------------------------------------------------------------- 1 | 2 | In this task, we are required to insert books into a white bookshelf, necessitating the retention of some key numbers on both the books and the shelf. 3 | Numbers 11, 12, and 13 are parts of a book, with number 11 being the book's large, prominent side, which we will retain. 4 | Numbers 9 and 14, representing the table surface, will be discarded. 5 | Numbers 5 and 8 are parts of a cup, irrelevant to the task, and thus discarded. 6 | Numbers 1, 2, 4, 6, and 7 are on the white bookshelf, where we will retain the clearer and more noticeable bottom surface 4 and side surface 1. 7 | 8 | Object Label: [1, 4, 11] 9 | 10 | Example 3: 11 | Instruction: Pound garlic in wooden jar with stick. 12 | -------------------------------------------------------------------------------- /som_gpt4v/second_prompt/prompt4.txt: -------------------------------------------------------------------------------- 1 | 2 | In this task, we are to pound garlic in a wooden jar using a stick, similar to inserting the stick into the jar. It's necessary to retain important numbers on both the stick and the jar. 3 | Number 3 represents the stick, which we will keep. 4 | Number 2 represents the mouth of the jar, and number 1 represents the body of the jar, with the mouth (number 2) being more important. 5 | 6 | Object Label: [2, 3] 7 | 8 | This is the new task. 
9 | Instruction: -------------------------------------------------------------------------------- /som_gpt4v/task_adapter/sam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/task_adapter/sam/__init__.py -------------------------------------------------------------------------------- /som_gpt4v/task_adapter/sam/tasks/__Init__.py: -------------------------------------------------------------------------------- 1 | from .inference_sam_m2m_auto import * 2 | from .inference_sam_m2m_interactive import * -------------------------------------------------------------------------------- /som_gpt4v/task_adapter/seem/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoxuHuang/copa/cb688e42c7a95e5310a034b5baecbf6476d791f8/som_gpt4v/task_adapter/seem/__init__.py -------------------------------------------------------------------------------- /som_gpt4v/task_adapter/seem/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_seem_m2m_auto import * 2 | from .inference_seem_pano import * 3 | from .inference_seem_interactive import * -------------------------------------------------------------------------------- /som_gpt4v/task_adapter/seem/tasks/interactive_seem_m2m_auto.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Semantic-SAM: Segment and Recognize Anything at Any Granularity 3 | # Copyright (c) 2023 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Hao Zhang (hzhangcx@connect.ust.hk) 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import numpy as np 10 | from torchvision import transforms 11 | from task_adapter.utils.visualizer import Visualizer 12 | from typing import Tuple 13 | from PIL import Image 14 | from detectron2.data import MetadataCatalog 15 | import matplotlib.pyplot as plt 16 | import cv2 17 | import io 18 | from .automatic_mask_generator import SeemAutomaticMaskGenerator 19 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 20 | 21 | def interactive_seem_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']): 22 | t = [] 23 | t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC)) 24 | transform1 = transforms.Compose(t) 25 | image_ori = transform1(image) 26 | 27 | image_ori = np.asarray(image_ori) 28 | images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda() 29 | 30 | mask_generator = SeemAutomaticMaskGenerator(model) 31 | outputs = mask_generator.generate(images) 32 | 33 | from task_adapter.utils.visualizer import Visualizer 34 | visual = Visualizer(image_ori, metadata=metadata) 35 | sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True) 36 | label = 1 37 | for ann in sorted_anns: 38 | mask = ann['segmentation'] 39 | color_mask = np.random.random((1, 3)).tolist()[0] 40 | # color_mask = [int(c*255) for c in color_mask] 41 | demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode) 42 | label += 1 43 | im = demo.get_image() 44 | 45 | # fig=plt.figure(figsize=(10, 10)) 46 | # plt.imshow(image_ori) 47 | # show_anns(outputs) 48 | # fig.canvas.draw() 49 | # 
im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb()) 50 | return im 51 | 52 | 53 | def remove_small_regions( 54 | mask: np.ndarray, area_thresh: float, mode: str 55 | ) -> Tuple[np.ndarray, bool]: 56 | """ 57 | Removes small disconnected regions and holes in a mask. Returns the 58 | mask and an indicator of if the mask has been modified. 59 | """ 60 | import cv2 # type: ignore 61 | 62 | assert mode in ["holes", "islands"] 63 | correct_holes = mode == "holes" 64 | working_mask = (correct_holes ^ mask).astype(np.uint8) 65 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 66 | sizes = stats[:, -1][1:] # Row 0 is background label 67 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 68 | if len(small_regions) == 0: 69 | return mask, False 70 | fill_labels = [0] + small_regions 71 | if not correct_holes: 72 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 73 | # If every region is below threshold, keep largest 74 | if len(fill_labels) == 0: 75 | fill_labels = [int(np.argmax(sizes)) + 1] 76 | mask = np.isin(regions, fill_labels) 77 | return mask, True 78 | 79 | def show_anns(anns): 80 | if len(anns) == 0: 81 | return 82 | sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 83 | ax = plt.gca() 84 | ax.set_autoscale_on(False) 85 | polygons = [] 86 | color = [] 87 | for ann in sorted_anns: 88 | m = ann['segmentation'] 89 | img = np.ones((m.shape[0], m.shape[1], 3)) 90 | color_mask = np.random.random((1, 3)).tolist()[0] 91 | for i in range(3): 92 | img[:,:,i] = color_mask[i] 93 | ax.imshow(np.dstack((img, m*0.35))) -------------------------------------------------------------------------------- /som_gpt4v/task_adapter/semantic_sam/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .interactive_idino_m2m import interactive_infer_image as interactive_infer_image_idino_m2m 2 | from .interactive_idino_m2m import interactive_infer_image_semantic, interactive_infer_image_3l 3 | from .inference_semsam_m2m_auto import inference_semsam_m2m_auto, inference_semsam_m2m_auto_remove_robot_arm 4 | from .interactive_idino_1o1_box import interactive_infer_image_box as interactive_infer_image_idino_m2m_box 5 | from .automatic_mask_generator import prompt_switch 6 | from .interactive_predictor import SemanticSAMPredictor -------------------------------------------------------------------------------- /som_gpt4v/task_adapter/semantic_sam/tasks/interactive_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision import transforms 4 | from task_adapter.utils.visualizer import Visualizer 5 | from typing import Tuple 6 | from PIL import Image 7 | from detectron2.data import MetadataCatalog 8 | metadata = MetadataCatalog.get('coco_2017_train_panoptic') 9 | 10 | 11 | class SemanticSAMPredictor: 12 | def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100): 13 | """ 14 | thresh: iou thresh to filter low confidence objects 15 | text_size: resize the input image short edge for the model to process 16 | hole_scale: fill in small holes as in SAM 17 | island_scale: remove small regions as in SAM 18 | """ 19 | self.model = model 20 | self.thresh = thresh 21 | self.text_size = text_size 22 | self.hole_scale = hole_scale 23 | self.island_scale = island_scale 24 | self.point = None 25 | 26 | def predict(self, image_ori, image,
point=None): 27 | """ 28 | produce up to 6 prediction results for each click 29 | """ 30 | width = image_ori.shape[0] 31 | height = image_ori.shape[1] 32 | 33 | data = {"image": image, "height": height, "width": width} 34 | # import ipdb; ipdb.set_trace() 35 | if point is None: 36 | point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda() 37 | else: 38 | point = torch.tensor(point).cuda() 39 | point_ = point 40 | point = point_.clone() 41 | point[0, 0] = point_[0, 0] 42 | point[0, 1] = point_[0, 1] 43 | # point = point[:, [1, 0]] 44 | point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1) 45 | 46 | self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point)) 47 | 48 | data['targets'] = [dict()] 49 | data['targets'][0]['points'] = point 50 | data['targets'][0]['pb'] = point.new_tensor([0.]) 51 | 52 | batch_inputs = [data] 53 | masks, ious = self.model.model.evaluate_demo(batch_inputs) 54 | 55 | return masks, ious 56 | 57 | def process_multi_mask(self, masks, ious, image_ori): 58 | pred_masks_poses = masks 59 | reses = [] 60 | ious = ious[0, 0] 61 | ids = torch.argsort(ious, descending=True) 62 | 63 | text_res = '' 64 | mask_ls = [] 65 | ious_res = [] 66 | areas = [] 67 | for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])): 68 | iou = round(float(iou), 2) 69 | texts = f'{iou}' 70 | mask = (pred_masks_pos > 0.0).cpu().numpy() 71 | area = mask.sum() 72 | conti = False 73 | if iou < self.thresh: 74 | conti = True 75 | for m in mask_ls: 76 | if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95: 77 | conti = True 78 | break 79 | if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []: 80 | conti = False 81 | if conti: 82 | continue 83 | ious_res.append(iou) 84 | mask_ls.append(mask) 85 | areas.append(area) 86 | mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes") 87 | mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands") 88 | mask = mask.astype(float) 89 | out_txt = texts 90 | visual = Visualizer(image_ori, metadata=metadata) 91 | color = [0., 0., 1.0] 92 | demo = visual.draw_binary_mask(mask, color=color, text=texts) 93 | res = demo.get_image() 94 | point_x0 = max(0, int(self.point[0, 0]) - 3) 95 | point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3) 96 | point_y0 = max(0, int(self.point[0, 1]) - 3) 97 | point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3) 98 | res[point_y0:point_y1, point_x0:point_x1, 0] = 255 99 | res[point_y0:point_y1, point_x0:point_x1, 1] = 0 100 | res[point_y0:point_y1, point_x0:point_x1, 2] = 0 101 | reses.append(Image.fromarray(res)) 102 | text_res = text_res + ';' + out_txt 103 | ids = list(torch.argsort(torch.tensor(areas), descending=False)) 104 | ids = [int(i) for i in ids] 105 | 106 | torch.cuda.empty_cache() 107 | 108 | return reses, [reses[i] for i in ids] 109 | 110 | def predict_masks(self, image_ori, image, point=None): 111 | masks, ious = self.predict(image_ori, image, point) 112 | return self.process_multi_mask(masks, ious, image_ori) 113 | 114 | @staticmethod 115 | def remove_small_regions( 116 | mask: np.ndarray, area_thresh: float, mode: str 117 | ) -> Tuple[np.ndarray, bool]: 118 | """ 119 | Removes small disconnected regions and holes in a mask. Returns the 120 | mask and an indicator of if the mask has been modified.
121 | """ 122 | import cv2 # type: ignore 123 | 124 | assert mode in ["holes", "islands"] 125 | correct_holes = mode == "holes" 126 | working_mask = (correct_holes ^ mask).astype(np.uint8) 127 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 128 | sizes = stats[:, -1][1:] # Row 0 is background label 129 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 130 | if len(small_regions) == 0: 131 | return mask, False 132 | fill_labels = [0] + small_regions 133 | if not correct_holes: 134 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 135 | # If every region is below threshold, keep largest 136 | if len(fill_labels) == 0: 137 | fill_labels = [int(np.argmax(sizes)) + 1] 138 | mask = np.isin(regions, fill_labels) 139 | return mask, True 140 | --------------------------------------------------------------------------------