├── .gitignore ├── LICENSE ├── README.md ├── _config.yml ├── cleanup.sh ├── evaluate.py ├── inference ├── __init__.py ├── grasp_generator.py ├── models │ ├── __init__.py │ ├── ggcnn.py │ ├── grasp_model.py │ ├── grconvnet.py │ ├── grconvnet2.py │ ├── grconvnet3.py │ ├── grconvnet4.py │ └── ragt │ │ ├── Anchor.py │ │ ├── mobile_vit.py │ │ ├── model_config.py │ │ ├── ragt.py │ │ └── transformer.py └── post_process.py ├── models └── grasp_det_seg │ ├── __init__.py │ ├── _version.py │ ├── algos │ ├── __init__.py │ ├── detection.py │ ├── fpn.py │ ├── rpn.py │ └── semantic_seg.py │ ├── config │ ├── __init__.py │ ├── config.py │ └── defaults │ │ └── det_seg_OCID.ini │ ├── data_OCID │ ├── OCID_class_dict.py │ ├── __init__.py │ ├── dataset.py │ ├── misc.py │ ├── sampler.py │ └── transform.py │ ├── models │ ├── __init__.py │ ├── det_seg.py │ └── resnet.py │ ├── modules │ ├── __init__.py │ ├── fpn.py │ ├── heads │ │ ├── __init__.py │ │ ├── fpn.py │ │ └── rpn.py │ ├── losses.py │ ├── misc.py │ └── residual.py │ └── utils │ ├── __init__.py │ ├── bbx │ ├── __init__.py │ ├── _backend.pyi │ └── bbx.py │ ├── logging.py │ ├── meters.py │ ├── misc.py │ ├── nms │ ├── __init__.py │ ├── _backend.pyi │ └── nms.py │ ├── parallel │ ├── __init__.py │ ├── data_parallel.py │ ├── packed_sequence.py │ └── scatter_gather.py │ ├── roi_sampling │ ├── __init__.py │ ├── _backend.pyi │ └── functions.py │ ├── scheduler.py │ ├── sequence.py │ └── snapshot.py ├── requirements.txt ├── script ├── eval_cornell_seen.sh ├── eval_cornell_unseen.sh ├── eval_grasp_anything_seen.sh ├── eval_grasp_anything_unseen.sh ├── eval_jacquard_seen.sh └── eval_jacquard_unseen.sh ├── split ├── control_manage_cornell.py ├── control_manage_jacquard.py ├── cornell │ ├── seen.obj │ └── unseen.obj ├── grasp-anything │ ├── seen.obj │ └── unseen.obj └── jacquard │ ├── seen.obj │ └── unseen.obj ├── train_network.py ├── train_network_grasp_det_seg.py ├── utils ├── data │ ├── __init__.py │ ├── camera_data.py │ ├── cornell_data.py │ ├── grasp_anything_data.py │ ├── grasp_data.py │ ├── jacquard_data.py │ ├── ocid_grasp_data.py │ └── vmrd_data.py ├── dataset_processing │ ├── evaluation.py │ ├── generate_cornell_depth.py │ ├── grasp.py │ ├── image.py │ └── mask.py ├── get_cornell.sh ├── get_jacquard.sh ├── timeit.py └── visualisation │ ├── gridshow.py │ └── plot.py └── weights ├── model_cornell ├── model_grasp_anything ├── model_jacquard ├── model_ocid └── model_vmrd /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | /docker/ 106 | 107 | results/ 108 | 109 | saved_data/ 110 | 111 | *.pyc 112 | 113 | data/* 114 | logs/ 115 | trained-models/grasp_det_seg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023, An Dinh Vuong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Grasp-Anything 2 | This is the repository of the paper "Grasp-Anything: Large-scale Grasp Dataset from Foundation Models" 3 | ## Table of contents 4 | 1. [Installation](#installation) 5 | 1. [Datasets](#datasets) 6 | 1. [Training](#training) 7 | 1. [Testing](#testing) 8 | 9 | ## Installation 10 | - Create a virtual environment 11 | ```bash 12 | $ conda create -n granything python=3.9 13 | $ conda activate granything 14 | ``` 15 | 16 | - Install pytorch 17 | ```bash 18 | $ conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 19 | $ pip install -r requirements.txt 20 | ``` 21 | 22 | ## Datasets 23 | Our dataset can be accessed via [this link](https://airvlab.github.io/grasp-anything/docs/download/). 
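After downloading, place the data where the training commands below expect it. The layout sketched here is only an assumption based on the `--dataset-path` values used in this README (e.g. `data/cornell`) and the dataset names under `split/`; adjust it to wherever you extract the archives:
```
data/
├── cornell/
├── jacquard/
└── grasp-anything/
```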
24 | 25 | ## Training 26 | We use GR-ConvNet as our default deep network. To train GR-ConvNet on a given dataset, use the following command: 27 | ```bash 28 | $ python train_network.py --dataset <dataset> --dataset-path <dataset path> --description <description> --use-depth 0 29 | ``` 30 | For example, to train GR-ConvNet on Cornell, use the following command: 31 | ```bash 32 | $ python train_network.py --dataset cornell --dataset-path data/cornell --description training_cornell --use-depth 0 33 | ``` 34 | We also provide training for other baselines; use the following command: 35 | ```bash 36 | $ python train_network.py --dataset <dataset> --dataset-path <dataset path> --description <description> --use-depth 0 --network <network name> 37 | ``` 38 | For instance, to train GG-CNN on Cornell, use the following command: 39 | ```bash 40 | $ python train_network.py --dataset cornell --dataset-path data/cornell/ --description training_ggcnn_on_cornell --use-depth 0 --network ggcnn 41 | ``` 42 | 43 | ## Testing 44 | For testing, apply similar commands to evaluate different baselines on different datasets (a complete example command is given at the end of this README): 45 | ```bash 46 | $ python evaluate.py --network <path to pretrained model> --dataset <dataset> --dataset-path data/<dataset> --iou-eval 47 | ``` 48 | Important note: `<path to pretrained model>` is the path to a checkpoint produced by the training procedure; checkpoints are stored under the `logs/` directory, and you can select whichever one you want to evaluate. You do not need to specify the network architecture, as the codebase detects it automatically from the saved model. Pretrained weights are available at [this link](https://drive.google.com/file/d/1OXVFXqv0rgxiVLz89tnSj0Xb-20ZJ4fH/view?usp=sharing). 49 | 50 | 51 | ## Acknowledgement 52 | Our codebase is developed based on [Kumra et al.](https://github.com/skumra/robotic-grasping).
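For completeness, a full evaluation run following the Testing instructions above could look like the command below; the checkpoint folder and file names under `logs/` are purely illustrative (they depend on your own training run), and the remaining flags correspond to the options defined in `evaluate.py`:
```bash
$ python evaluate.py --network logs/<timestamp>_training_cornell/epoch_30_iou_0.97 \
                     --dataset cornell --dataset-path data/cornell --use-depth 0 --iou-eval
```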
53 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /cleanup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | find logs/ -maxdepth 1 -type d | grep -v ^\\.$ | xargs -n 1 du -s | while read size name ; do if [ $size -le 10485 ] ; then echo rm -rf $name ; fi done 4 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import time 4 | 5 | import numpy as np 6 | import torch.utils.data 7 | 8 | from hardware.device import get_device 9 | from inference.post_process import post_process_output 10 | from utils.data import get_dataset 11 | from utils.dataset_processing import evaluation, grasp 12 | from utils.visualisation.plot import save_results 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description='Evaluate networks') 19 | 20 | # Network 21 | parser.add_argument('--network', metavar='N', type=str, nargs='+', 22 | help='Path to saved networks to evaluate') 23 | parser.add_argument('--input-size', type=int, default=224, 24 | help='Input image size for the network') 25 | 26 | # Dataset 27 | parser.add_argument('--dataset', type=str, 28 | help='Dataset Name ("cornell" or "jaquard")') 29 | parser.add_argument('--dataset-path', type=str, 30 | help='Path to dataset') 31 | parser.add_argument('--use-depth', type=int, default=1, 32 | help='Use Depth image for evaluation (1/0)') 33 | parser.add_argument('--use-rgb', type=int, default=1, 34 | help='Use RGB image for evaluation (1/0)') 35 | parser.add_argument('--augment', action='store_true', 36 | help='Whether data augmentation should be applied') 37 | parser.add_argument('--split', type=float, default=0.01, 38 | help='Fraction of data for training (remainder is validation)') 39 | parser.add_argument('--ds-shuffle', action='store_true', default=False, 40 | help='Shuffle the dataset') 41 | parser.add_argument('--ds-rotate', type=float, default=0.0, 42 | help='Shift the start point of the dataset to use a different test/train split') 43 | parser.add_argument('--num-workers', type=int, default=8, 44 | help='Dataset workers') 45 | 46 | # Evaluation 47 | parser.add_argument('--n-grasps', type=int, default=1, 48 | help='Number of grasps to consider per image') 49 | parser.add_argument('--iou-threshold', type=float, default=0.25, 50 | help='Threshold for IOU matching') 51 | parser.add_argument('--iou-eval', action='store_true', 52 | help='Compute success based on IoU metric.') 53 | parser.add_argument('--jacquard-output', action='store_true', 54 | help='Jacquard-dataset style output') 55 | 56 | # Misc. 
57 | parser.add_argument('--vis', action='store_true', 58 | help='Visualise the network output') 59 | parser.add_argument('--cpu', dest='force_cpu', action='store_true', default=False, 60 | help='Force code to run in CPU mode') 61 | parser.add_argument('--random-seed', type=int, default=123, 62 | help='Random seed for numpy') 63 | parser.add_argument('--seen', type=int, default=1, 64 | help='Flag for using seen classes, only work for Grasp-Anything dataset') 65 | 66 | args = parser.parse_args() 67 | 68 | if args.jacquard_output and args.dataset != 'jacquard': 69 | raise ValueError('--jacquard-output can only be used with the --dataset jacquard option.') 70 | if args.jacquard_output and args.augment: 71 | raise ValueError('--jacquard-output can not be used with data augmentation.') 72 | 73 | return args 74 | 75 | 76 | if __name__ == '__main__': 77 | args = parse_args() 78 | 79 | # Get the compute device 80 | device = get_device(args.force_cpu) 81 | 82 | # Load Dataset 83 | logging.info('Loading {} Dataset...'.format(args.dataset.title())) 84 | Dataset = get_dataset(args.dataset) 85 | test_dataset = Dataset(args.dataset_path, 86 | output_size=args.input_size, 87 | ds_rotate=args.ds_rotate, 88 | random_rotate=args.augment, 89 | random_zoom=args.augment, 90 | include_depth=args.use_depth, 91 | include_rgb=args.use_rgb, 92 | seen=args.seen) 93 | 94 | indices = list(range(test_dataset.length)) 95 | split = int(np.floor(args.split * test_dataset.length)) 96 | if args.ds_shuffle: 97 | np.random.seed(args.random_seed) 98 | np.random.shuffle(indices) 99 | val_indices = indices[split:] 100 | val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_indices) 101 | logging.info('Validation size: {}'.format(len(val_indices))) 102 | 103 | test_data = torch.utils.data.DataLoader( 104 | test_dataset, 105 | batch_size=1, 106 | num_workers=args.num_workers, 107 | sampler=val_sampler 108 | ) 109 | logging.info('Done') 110 | 111 | for network in args.network: 112 | logging.info('\nEvaluating model {}'.format(network)) 113 | 114 | # Load Network 115 | net = torch.load(network) 116 | 117 | results = {'correct': 0, 'failed': 0} 118 | 119 | if args.jacquard_output: 120 | jo_fn = network + '_jacquard_output.txt' 121 | with open(jo_fn, 'w') as f: 122 | pass 123 | 124 | start_time = time.time() 125 | 126 | with torch.no_grad(): 127 | for idx, (x, y, didx, rot, zoom) in enumerate(test_data): 128 | xc = x.to(device) 129 | yc = [yi.to(device) for yi in y] 130 | lossd = net.compute_loss(xc, yc) 131 | 132 | q_img, ang_img, width_img = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'], 133 | lossd['pred']['sin'], lossd['pred']['width']) 134 | 135 | if args.iou_eval: 136 | s = evaluation.calculate_iou_match(q_img, ang_img, test_data.dataset.get_gtbb(didx, rot, zoom), 137 | no_grasps=args.n_grasps, 138 | grasp_width=width_img, 139 | threshold=args.iou_threshold 140 | ) 141 | if s: 142 | results['correct'] += 1 143 | else: 144 | results['failed'] += 1 145 | 146 | if args.jacquard_output: 147 | grasps = grasp.detect_grasps(q_img, ang_img, width_img=width_img, no_grasps=1) 148 | with open(jo_fn, 'a') as f: 149 | for g in grasps: 150 | f.write(test_data.dataset.get_jname(didx) + '\n') 151 | f.write(g.to_jacquard(scale=1024 / 300) + '\n') 152 | 153 | if args.vis: 154 | save_results( 155 | rgb_img=test_data.dataset.get_rgb(didx, rot, zoom, normalise=False), 156 | depth_img=test_data.dataset.get_depth(didx, rot, zoom), 157 | grasp_q_img=q_img, 158 | grasp_angle_img=ang_img, 159 | no_grasps=args.n_grasps, 160 | 
grasp_width_img=width_img 161 | ) 162 | 163 | avg_time = (time.time() - start_time) / len(test_data) 164 | logging.info('Average evaluation time per image: {}ms'.format(avg_time * 1000)) 165 | 166 | if args.iou_eval: 167 | logging.info('IOU Results: %d/%d = %f' % (results['correct'], 168 | results['correct'] + results['failed'], 169 | results['correct'] / (results['correct'] + results['failed']))) 170 | 171 | if args.jacquard_output: 172 | logging.info('Jacquard output saved to {}'.format(jo_fn)) 173 | 174 | del net 175 | torch.cuda.empty_cache() 176 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/inference/__init__.py -------------------------------------------------------------------------------- /inference/grasp_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | from hardware.camera import RealSenseCamera 9 | from hardware.device import get_device 10 | from inference.post_process import post_process_output 11 | from utils.data.camera_data import CameraData 12 | from utils.dataset_processing.grasp import detect_grasps 13 | from utils.visualisation.plot import plot_grasp 14 | 15 | 16 | class GraspGenerator: 17 | def __init__(self, saved_model_path, cam_id, visualize=False): 18 | self.saved_model_path = saved_model_path 19 | self.camera = RealSenseCamera(device_id=cam_id) 20 | 21 | self.saved_model_path = saved_model_path 22 | self.model = None 23 | self.device = None 24 | 25 | self.cam_data = CameraData(include_depth=True, include_rgb=True) 26 | 27 | # Connect to camera 28 | self.camera.connect() 29 | 30 | # Load camera pose and depth scale (from running calibration) 31 | self.cam_pose = np.loadtxt('saved_data/camera_pose.txt', delimiter=' ') 32 | self.cam_depth_scale = np.loadtxt('saved_data/camera_depth_scale.txt', delimiter=' ') 33 | 34 | homedir = os.path.join(os.path.expanduser('~'), "grasp-comms") 35 | self.grasp_request = os.path.join(homedir, "grasp_request.npy") 36 | self.grasp_available = os.path.join(homedir, "grasp_available.npy") 37 | self.grasp_pose = os.path.join(homedir, "grasp_pose.npy") 38 | 39 | if visualize: 40 | self.fig = plt.figure(figsize=(10, 10)) 41 | else: 42 | self.fig = None 43 | 44 | def load_model(self): 45 | print('Loading model... 
') 46 | self.model = torch.load(self.saved_model_path) 47 | # Get the compute device 48 | self.device = get_device(force_cpu=False) 49 | 50 | def generate(self): 51 | # Get RGB-D image from camera 52 | image_bundle = self.camera.get_image_bundle() 53 | rgb = image_bundle['rgb'] 54 | depth = image_bundle['aligned_depth'] 55 | x, depth_img, rgb_img = self.cam_data.get_data(rgb=rgb, depth=depth) 56 | 57 | # Predict the grasp pose using the saved model 58 | with torch.no_grad(): 59 | xc = x.to(self.device) 60 | pred = self.model.predict(xc) 61 | 62 | q_img, ang_img, width_img = post_process_output(pred['pos'], pred['cos'], pred['sin'], pred['width']) 63 | grasps = detect_grasps(q_img, ang_img, width_img) 64 | 65 | # Get grasp position from model output 66 | pos_z = depth[grasps[0].center[0] + self.cam_data.top_left[0], grasps[0].center[1] + self.cam_data.top_left[1]] * self.cam_depth_scale - 0.04 67 | pos_x = np.multiply(grasps[0].center[1] + self.cam_data.top_left[1] - self.camera.intrinsics.ppx, 68 | pos_z / self.camera.intrinsics.fx) 69 | pos_y = np.multiply(grasps[0].center[0] + self.cam_data.top_left[0] - self.camera.intrinsics.ppy, 70 | pos_z / self.camera.intrinsics.fy) 71 | 72 | if pos_z == 0: 73 | return 74 | 75 | target = np.asarray([pos_x, pos_y, pos_z]) 76 | target.shape = (3, 1) 77 | print('target: ', target) 78 | 79 | # Convert camera to robot coordinates 80 | camera2robot = self.cam_pose 81 | target_position = np.dot(camera2robot[0:3, 0:3], target) + camera2robot[0:3, 3:] 82 | target_position = target_position[0:3, 0] 83 | 84 | # Convert camera to robot angle 85 | angle = np.asarray([0, 0, grasps[0].angle]) 86 | angle.shape = (3, 1) 87 | target_angle = np.dot(camera2robot[0:3, 0:3], angle) 88 | 89 | # Concatenate grasp pose with grasp angle 90 | grasp_pose = np.append(target_position, target_angle[2]) 91 | 92 | print('grasp_pose: ', grasp_pose) 93 | 94 | np.save(self.grasp_pose, grasp_pose) 95 | 96 | if self.fig: 97 | plot_grasp(fig=self.fig, rgb_img=self.cam_data.get_rgb(rgb, False), grasps=grasps, save=True) 98 | 99 | def run(self): 100 | while True: 101 | if np.load(self.grasp_request): 102 | self.generate() 103 | np.save(self.grasp_request, 0) 104 | np.save(self.grasp_available, 1) 105 | else: 106 | time.sleep(0.1) 107 | -------------------------------------------------------------------------------- /inference/models/__init__.py: -------------------------------------------------------------------------------- 1 | def get_network(network_name): 2 | network_name = network_name.lower() 3 | # Original GR-ConvNet 4 | if network_name == 'grconvnet': 5 | from .grconvnet import GenerativeResnet 6 | return GenerativeResnet 7 | # Configurable GR-ConvNet with multiple dropouts 8 | elif network_name == 'grconvnet2': 9 | from .grconvnet2 import GenerativeResnet 10 | return GenerativeResnet 11 | # Configurable GR-ConvNet with dropout at the end 12 | elif network_name == 'grconvnet3': 13 | from .grconvnet3 import GenerativeResnet 14 | return GenerativeResnet 15 | # Inverted GR-ConvNet 16 | elif network_name == 'grconvnet4': 17 | from .grconvnet4 import GenerativeResnet 18 | return GenerativeResnet 19 | elif network_name == 'ragt': 20 | from .ragt.ragt import RAGT 21 | return RAGT 22 | elif network_name == 'ggcnn': 23 | from .ggcnn import GGCNN 24 | return GGCNN 25 | else: 26 | raise NotImplementedError('Network {} is not implemented'.format(network_name)) 27 | -------------------------------------------------------------------------------- /inference/models/ggcnn.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from inference.models.grasp_model import GraspModel 5 | 6 | filter_sizes = [32, 16, 8, 8, 16, 32] 7 | kernel_sizes = [9, 5, 3, 3, 5, 9] 8 | strides = [3, 2, 2, 2, 2, 3] 9 | 10 | 11 | class GGCNN(GraspModel): 12 | """ 13 | GG-CNN 14 | Equivalent to the Keras Model used in the RSS Paper (https://arxiv.org/abs/1804.05172) 15 | """ 16 | def __init__(self, input_channels=4, output_channels=1, channel_size=32, dropout=False, prob=0.0): 17 | super().__init__() 18 | self.conv1 = nn.Conv2d(input_channels, filter_sizes[0], kernel_sizes[0], stride=strides[0], padding=3) 19 | self.conv2 = nn.Conv2d(filter_sizes[0], filter_sizes[1], kernel_sizes[1], stride=strides[1], padding=2) 20 | self.conv3 = nn.Conv2d(filter_sizes[1], filter_sizes[2], kernel_sizes[2], stride=strides[2], padding=1) 21 | self.convt1 = nn.ConvTranspose2d(filter_sizes[2], filter_sizes[3], kernel_sizes[3], stride=strides[3], padding=1, output_padding=1) 22 | self.convt2 = nn.ConvTranspose2d(filter_sizes[3], filter_sizes[4], kernel_sizes[4], stride=strides[4], padding=2, output_padding=1) 23 | self.convt3 = nn.ConvTranspose2d(filter_sizes[4], filter_sizes[5], kernel_sizes[5], stride=strides[5], padding=5, output_padding=1) 24 | 25 | self.pos_output = nn.Conv2d(filter_sizes[5], 1, kernel_size=2) 26 | self.cos_output = nn.Conv2d(filter_sizes[5], 1, kernel_size=2) 27 | self.sin_output = nn.Conv2d(filter_sizes[5], 1, kernel_size=2) 28 | self.width_output = nn.Conv2d(filter_sizes[5], 1, kernel_size=2) 29 | 30 | for m in self.modules(): 31 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 32 | nn.init.xavier_uniform_(m.weight, gain=1) 33 | 34 | def forward(self, x): 35 | x = F.relu(self.conv1(x)) 36 | x = F.relu(self.conv2(x)) 37 | x = F.relu(self.conv3(x)) 38 | x = F.relu(self.convt1(x)) 39 | x = F.relu(self.convt2(x)) 40 | x = F.relu(self.convt3(x)) 41 | 42 | pos_output = self.pos_output(x) 43 | cos_output = self.cos_output(x) 44 | sin_output = self.sin_output(x) 45 | width_output = self.width_output(x) 46 | 47 | return pos_output, cos_output, sin_output, width_output 48 | 49 | def compute_loss(self, xc, yc): 50 | y_pos, y_cos, y_sin, y_width = yc 51 | pos_pred, cos_pred, sin_pred, width_pred = self(xc) 52 | 53 | p_loss = F.mse_loss(pos_pred, y_pos) 54 | cos_loss = F.mse_loss(cos_pred, y_cos) 55 | sin_loss = F.mse_loss(sin_pred, y_sin) 56 | width_loss = F.mse_loss(width_pred, y_width) 57 | 58 | return { 59 | 'loss': p_loss + cos_loss + sin_loss + width_loss, 60 | 'losses': { 61 | 'p_loss': p_loss, 62 | 'cos_loss': cos_loss, 63 | 'sin_loss': sin_loss, 64 | 'width_loss': width_loss 65 | }, 66 | 'pred': { 67 | 'pos': pos_pred, 68 | 'cos': cos_pred, 69 | 'sin': sin_pred, 70 | 'width': width_pred 71 | } 72 | } -------------------------------------------------------------------------------- /inference/models/grasp_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class GraspModel(nn.Module): 6 | """ 7 | An abstract model for grasp network in a common format.
8 | """ 9 | 10 | def __init__(self): 11 | super(GraspModel, self).__init__() 12 | 13 | def forward(self, x_in): 14 | raise NotImplementedError() 15 | 16 | def compute_loss(self, xc, yc): 17 | y_pos, y_cos, y_sin, y_width = yc 18 | pos_pred, cos_pred, sin_pred, width_pred = self(xc) 19 | 20 | p_loss = F.smooth_l1_loss(pos_pred, y_pos) 21 | cos_loss = F.smooth_l1_loss(cos_pred, y_cos) 22 | sin_loss = F.smooth_l1_loss(sin_pred, y_sin) 23 | width_loss = F.smooth_l1_loss(width_pred, y_width) 24 | 25 | return { 26 | 'loss': p_loss + cos_loss + sin_loss + width_loss, 27 | 'losses': { 28 | 'p_loss': p_loss, 29 | 'cos_loss': cos_loss, 30 | 'sin_loss': sin_loss, 31 | 'width_loss': width_loss 32 | }, 33 | 'pred': { 34 | 'pos': pos_pred, 35 | 'cos': cos_pred, 36 | 'sin': sin_pred, 37 | 'width': width_pred 38 | } 39 | } 40 | 41 | def predict(self, xc): 42 | pos_pred, cos_pred, sin_pred, width_pred = self(xc) 43 | return { 44 | 'pos': pos_pred, 45 | 'cos': cos_pred, 46 | 'sin': sin_pred, 47 | 'width': width_pred 48 | } 49 | 50 | 51 | class ResidualBlock(nn.Module): 52 | """ 53 | A residual block with dropout option 54 | """ 55 | 56 | def __init__(self, in_channels, out_channels, kernel_size=3): 57 | super(ResidualBlock, self).__init__() 58 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1) 59 | self.bn1 = nn.BatchNorm2d(in_channels) 60 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1) 61 | self.bn2 = nn.BatchNorm2d(in_channels) 62 | 63 | def forward(self, x_in): 64 | x = self.bn1(self.conv1(x_in)) 65 | x = F.relu(x) 66 | x = self.bn2(self.conv2(x)) 67 | return x + x_in 68 | -------------------------------------------------------------------------------- /inference/models/grconvnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from inference.models.grasp_model import GraspModel, ResidualBlock 5 | 6 | 7 | class GenerativeResnet(GraspModel): 8 | 9 | def __init__(self, input_channels=1, dropout=False, prob=0.0, channel_size=32): 10 | super(GenerativeResnet, self).__init__() 11 | self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=9, stride=1, padding=4) 12 | self.bn1 = nn.BatchNorm2d(32) 13 | 14 | self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1) 15 | self.bn2 = nn.BatchNorm2d(64) 16 | 17 | self.conv3 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) 18 | self.bn3 = nn.BatchNorm2d(128) 19 | 20 | self.res1 = ResidualBlock(128, 128) 21 | self.res2 = ResidualBlock(128, 128) 22 | self.res3 = ResidualBlock(128, 128) 23 | self.res4 = ResidualBlock(128, 128) 24 | self.res5 = ResidualBlock(128, 128) 25 | 26 | self.conv4 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=1) 27 | self.bn4 = nn.BatchNorm2d(64) 28 | 29 | self.conv5 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=2, output_padding=1) 30 | self.bn5 = nn.BatchNorm2d(32) 31 | 32 | self.conv6 = nn.ConvTranspose2d(32, 32, kernel_size=9, stride=1, padding=4) 33 | 34 | self.pos_output = nn.Conv2d(32, 1, kernel_size=2) 35 | self.cos_output = nn.Conv2d(32, 1, kernel_size=2) 36 | self.sin_output = nn.Conv2d(32, 1, kernel_size=2) 37 | self.width_output = nn.Conv2d(32, 1, kernel_size=2) 38 | 39 | self.dropout1 = nn.Dropout(p=prob) 40 | 41 | for m in self.modules(): 42 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 43 | nn.init.xavier_uniform_(m.weight, gain=1) 44 | 45 | def forward(self, x_in): 46 | x = 
F.relu(self.bn1(self.conv1(x_in))) 47 | x = F.relu(self.bn2(self.conv2(x))) 48 | x = F.relu(self.bn3(self.conv3(x))) 49 | x = self.res1(x) 50 | x = self.res2(x) 51 | x = self.res3(x) 52 | x = self.res4(x) 53 | x = self.res5(x) 54 | x = F.relu(self.bn4(self.conv4(x))) 55 | x = F.relu(self.bn5(self.conv5(x))) 56 | x = self.conv6(x) 57 | 58 | pos_output = self.pos_output(self.dropout1(x)) 59 | cos_output = self.cos_output(self.dropout1(x)) 60 | sin_output = self.sin_output(self.dropout1(x)) 61 | width_output = self.width_output(self.dropout1(x)) 62 | 63 | return pos_output, cos_output, sin_output, width_output 64 | -------------------------------------------------------------------------------- /inference/models/grconvnet2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from inference.models.grasp_model import GraspModel, ResidualBlock 5 | 6 | 7 | class GenerativeResnet(GraspModel): 8 | 9 | def __init__(self, input_channels=4, output_channels=1, channel_size=32, dropout=False, prob=0.0): 10 | super(GenerativeResnet, self).__init__() 11 | self.conv1 = nn.Conv2d(input_channels, channel_size, kernel_size=9, stride=1, padding=4) 12 | self.bn1 = nn.BatchNorm2d(channel_size) 13 | 14 | self.conv2 = nn.Conv2d(channel_size, channel_size * 2, kernel_size=4, stride=2, padding=1) 15 | self.bn2 = nn.BatchNorm2d(channel_size * 2) 16 | 17 | self.conv3 = nn.Conv2d(channel_size * 2, channel_size * 4, kernel_size=4, stride=2, padding=1) 18 | self.bn3 = nn.BatchNorm2d(channel_size * 4) 19 | 20 | self.res1 = ResidualBlock(channel_size * 4, channel_size * 4) 21 | self.res2 = ResidualBlock(channel_size * 4, channel_size * 4) 22 | self.res3 = ResidualBlock(channel_size * 4, channel_size * 4) 23 | self.res4 = ResidualBlock(channel_size * 4, channel_size * 4) 24 | self.res5 = ResidualBlock(channel_size * 4, channel_size * 4) 25 | 26 | self.conv4 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, kernel_size=4, stride=2, padding=1, 27 | output_padding=1) 28 | self.bn4 = nn.BatchNorm2d(channel_size * 2) 29 | 30 | self.conv5 = nn.ConvTranspose2d(channel_size * 2, channel_size, kernel_size=4, stride=2, padding=2, 31 | output_padding=1) 32 | self.bn5 = nn.BatchNorm2d(channel_size) 33 | 34 | self.conv6 = nn.ConvTranspose2d(channel_size, channel_size, kernel_size=9, stride=1, padding=4) 35 | 36 | self.pos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 37 | self.cos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 38 | self.sin_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 39 | self.width_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 40 | 41 | self.dropout = dropout 42 | self.dropout_pos = nn.Dropout(p=prob) 43 | self.dropout_cos = nn.Dropout(p=prob) 44 | self.dropout_sin = nn.Dropout(p=prob) 45 | self.dropout_wid = nn.Dropout(p=prob) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 49 | nn.init.xavier_uniform_(m.weight, gain=1) 50 | 51 | def forward(self, x_in): 52 | x = F.relu(self.bn1(self.conv1(x_in))) 53 | x = F.relu(self.bn2(self.conv2(x))) 54 | x = F.relu(self.bn3(self.conv3(x))) 55 | x = self.res1(x) 56 | x = self.res2(x) 57 | x = self.res3(x) 58 | x = self.res4(x) 59 | x = self.res5(x) 60 | x = F.relu(self.bn4(self.conv4(x))) 61 | x = F.relu(self.bn5(self.conv5(x))) 62 | x = self.conv6(x) 63 | 
64 | if self.dropout: 65 | pos_output = self.pos_output(self.dropout_pos(x)) 66 | cos_output = self.cos_output(self.dropout_cos(x)) 67 | sin_output = self.sin_output(self.dropout_sin(x)) 68 | width_output = self.width_output(self.dropout_wid(x)) 69 | else: 70 | pos_output = self.pos_output(x) 71 | cos_output = self.cos_output(x) 72 | sin_output = self.sin_output(x) 73 | width_output = self.width_output(x) 74 | 75 | return pos_output, cos_output, sin_output, width_output 76 | -------------------------------------------------------------------------------- /inference/models/grconvnet3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from inference.models.grasp_model import GraspModel, ResidualBlock 5 | 6 | 7 | class GenerativeResnet(GraspModel): 8 | 9 | def __init__(self, input_channels=4, output_channels=1, channel_size=32, dropout=False, prob=0.0): 10 | super(GenerativeResnet, self).__init__() 11 | self.conv1 = nn.Conv2d(input_channels, channel_size, kernel_size=9, stride=1, padding=4) 12 | self.bn1 = nn.BatchNorm2d(channel_size) 13 | 14 | self.conv2 = nn.Conv2d(channel_size, channel_size * 2, kernel_size=4, stride=2, padding=1) 15 | self.bn2 = nn.BatchNorm2d(channel_size * 2) 16 | 17 | self.conv3 = nn.Conv2d(channel_size * 2, channel_size * 4, kernel_size=4, stride=2, padding=1) 18 | self.bn3 = nn.BatchNorm2d(channel_size * 4) 19 | 20 | self.res1 = ResidualBlock(channel_size * 4, channel_size * 4) 21 | self.res2 = ResidualBlock(channel_size * 4, channel_size * 4) 22 | self.res3 = ResidualBlock(channel_size * 4, channel_size * 4) 23 | self.res4 = ResidualBlock(channel_size * 4, channel_size * 4) 24 | self.res5 = ResidualBlock(channel_size * 4, channel_size * 4) 25 | 26 | self.conv4 = nn.ConvTranspose2d(channel_size * 4, channel_size * 2, kernel_size=4, stride=2, padding=1, 27 | output_padding=1) 28 | self.bn4 = nn.BatchNorm2d(channel_size * 2) 29 | 30 | self.conv5 = nn.ConvTranspose2d(channel_size * 2, channel_size, kernel_size=4, stride=2, padding=2, 31 | output_padding=1) 32 | self.bn5 = nn.BatchNorm2d(channel_size) 33 | 34 | self.conv6 = nn.ConvTranspose2d(channel_size, channel_size, kernel_size=9, stride=1, padding=4) 35 | 36 | self.pos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 37 | self.cos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 38 | self.sin_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 39 | self.width_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 40 | 41 | self.dropout = dropout 42 | self.dropout_pos = nn.Dropout(p=prob) 43 | self.dropout_cos = nn.Dropout(p=prob) 44 | self.dropout_sin = nn.Dropout(p=prob) 45 | self.dropout_wid = nn.Dropout(p=prob) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 49 | nn.init.xavier_uniform_(m.weight, gain=1) 50 | 51 | def forward(self, x_in): 52 | x = F.relu(self.bn1(self.conv1(x_in))) 53 | x = F.relu(self.bn2(self.conv2(x))) 54 | x = F.relu(self.bn3(self.conv3(x))) 55 | x = self.res1(x) 56 | x = self.res2(x) 57 | x = self.res3(x) 58 | x = self.res4(x) 59 | x = self.res5(x) 60 | x = F.relu(self.bn4(self.conv4(x))) 61 | x = F.relu(self.bn5(self.conv5(x))) 62 | x = self.conv6(x) 63 | 64 | if self.dropout: 65 | pos_output = self.pos_output(self.dropout_pos(x)) 66 | cos_output = self.cos_output(self.dropout_cos(x)) 67 | 
sin_output = self.sin_output(self.dropout_sin(x)) 68 | width_output = self.width_output(self.dropout_wid(x)) 69 | else: 70 | pos_output = self.pos_output(x) 71 | cos_output = self.cos_output(x) 72 | sin_output = self.sin_output(x) 73 | width_output = self.width_output(x) 74 | 75 | return pos_output, cos_output, sin_output, width_output 76 | -------------------------------------------------------------------------------- /inference/models/grconvnet4.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from inference.models.grasp_model import GraspModel, ResidualBlock 5 | 6 | 7 | class GenerativeResnet(GraspModel): 8 | 9 | def __init__(self, input_channels=4, output_channels=1, channel_size=32, dropout=False, prob=0.0): 10 | super(GenerativeResnet, self).__init__() 11 | self.conv1 = nn.Conv2d(input_channels, channel_size, kernel_size=9, stride=1, padding=4) 12 | self.bn1 = nn.BatchNorm2d(channel_size) 13 | 14 | self.conv2 = nn.Conv2d(channel_size, channel_size // 2, kernel_size=4, stride=2, padding=1) 15 | self.bn2 = nn.BatchNorm2d(channel_size // 2) 16 | 17 | self.conv3 = nn.Conv2d(channel_size // 2, channel_size // 4, kernel_size=4, stride=2, padding=1) 18 | self.bn3 = nn.BatchNorm2d(channel_size // 4) 19 | 20 | self.res1 = ResidualBlock(channel_size // 4, channel_size // 4) 21 | self.res2 = ResidualBlock(channel_size // 4, channel_size // 4) 22 | self.res3 = ResidualBlock(channel_size // 4, channel_size // 4) 23 | self.res4 = ResidualBlock(channel_size // 4, channel_size // 4) 24 | self.res5 = ResidualBlock(channel_size // 4, channel_size // 4) 25 | 26 | self.conv4 = nn.ConvTranspose2d(channel_size // 4, channel_size // 2, kernel_size=4, stride=2, padding=1, 27 | output_padding=1) 28 | self.bn4 = nn.BatchNorm2d(channel_size // 2) 29 | 30 | self.conv5 = nn.ConvTranspose2d(channel_size // 2, channel_size, kernel_size=4, stride=2, padding=2, 31 | output_padding=1) 32 | self.bn5 = nn.BatchNorm2d(channel_size) 33 | 34 | self.conv6 = nn.ConvTranspose2d(channel_size, channel_size, kernel_size=9, stride=1, padding=4) 35 | 36 | self.pos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 37 | self.cos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 38 | self.sin_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 39 | self.width_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 40 | 41 | self.dropout = dropout 42 | self.dropout_pos = nn.Dropout(p=prob) 43 | self.dropout_cos = nn.Dropout(p=prob) 44 | self.dropout_sin = nn.Dropout(p=prob) 45 | self.dropout_wid = nn.Dropout(p=prob) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 49 | nn.init.xavier_uniform_(m.weight, gain=1) 50 | 51 | def forward(self, x_in): 52 | x = F.relu(self.bn1(self.conv1(x_in))) 53 | x = F.relu(self.bn2(self.conv2(x))) 54 | x = F.relu(self.bn3(self.conv3(x))) 55 | x = self.res1(x) 56 | x = self.res2(x) 57 | x = self.res3(x) 58 | x = self.res4(x) 59 | x = self.res5(x) 60 | x = F.relu(self.bn4(self.conv4(x))) 61 | x = F.relu(self.bn5(self.conv5(x))) 62 | x = self.conv6(x) 63 | 64 | if self.dropout: 65 | pos_output = self.pos_output(self.dropout_pos(x)) 66 | cos_output = self.cos_output(self.dropout_cos(x)) 67 | sin_output = self.sin_output(self.dropout_sin(x)) 68 | width_output = self.width_output(self.dropout_wid(x)) 69 | else: 70 
| pos_output = self.pos_output(x) 71 | cos_output = self.cos_output(x) 72 | sin_output = self.sin_output(x) 73 | width_output = self.width_output(x) 74 | 75 | return pos_output, cos_output, sin_output, width_output 76 | -------------------------------------------------------------------------------- /inference/models/ragt/Anchor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | anchor_thetas = [x*0.2094 for x in range(15)] 5 | # anchor width 6 | anchor_w = 85.72 7 | # anchor height 8 | anchor_h = 19.15 9 | # number of anchors per grid cell 10 | num_anchors = 3 11 | # number of downsampling steps to the output layer 12 | times_of_down_sampling = 5 13 | # input image size 14 | img_size = 416 15 | # prevent the angle offset from being exactly zero 16 | Anchor_eps = 0.000001 17 | 18 | 19 | field_of_grid_cell = 2 ** times_of_down_sampling 20 | num_grid_cell = int(img_size / field_of_grid_cell) 21 | theta_margin = 180 / num_anchors 22 | 23 | 24 | # if __name__ == '__main__': 25 | # print(field_of_grid_cell) 26 | # print(anchor_thetas) 27 | # print(3//0.2094) 28 | # a = np.arange(16).reshape((4, 4)) 29 | # print(33.2%16.4) 30 | # a = np.arange(10*26*26*15*6).reshape((10, 26, 26, 15, 6)) 31 | # b = [] 32 | # for i in a: 33 | # b.append(i) 34 | # # b = np.array(b) 35 | # print(type(b[0])) 36 | # # print(b.shape) 37 | # c = np.arange(16).reshape((4, 4)) 38 | # d = [] 39 | # for i in c: 40 | # d.append(i) 41 | # print(d) 42 | # print(np.array(d)) 43 | -------------------------------------------------------------------------------- /inference/models/ragt/model_config.py: -------------------------------------------------------------------------------- 1 | def get_config(mode: str = "xx_small") -> dict: 2 | if mode == "xx_small": 3 | mv2_exp_mult = 2 4 | config = { 5 | "layer1": { 6 | "out_channels": 16, 7 | "expand_ratio": mv2_exp_mult, 8 | "num_blocks": 1, 9 | "stride": 1, 10 | "block_type": "mv2", 11 | }, 12 | "layer2": { 13 | "out_channels": 24, 14 | "expand_ratio": mv2_exp_mult, 15 | "num_blocks": 3, 16 | "stride": 2, 17 | "block_type": "mv2", 18 | }, 19 | "layer3": { # 28x28 20 | "out_channels": 48, 21 | "transformer_channels": 64, 22 | "ffn_dim": 128, 23 | "transformer_blocks": 2, 24 | "patch_h": 2, # 8, 25 | "patch_w": 2, # 8, 26 | "stride": 2, 27 | "mv_expand_ratio": mv2_exp_mult, 28 | "num_heads": 4, 29 | "block_type": "mobilevit", 30 | }, 31 | "layer4": { # 14x14 32 | "out_channels": 64, 33 | "transformer_channels": 80, 34 | "ffn_dim": 160, 35 | "transformer_blocks": 4, 36 | "patch_h": 2, # 4, 37 | "patch_w": 2, # 4, 38 | "stride": 2, 39 | "mv_expand_ratio": mv2_exp_mult, 40 | "num_heads": 4, 41 | "block_type": "mobilevit", 42 | }, 43 | "layer5": { # 7x7 44 | "out_channels": 80, 45 | "transformer_channels": 96, 46 | "ffn_dim": 192, 47 | "transformer_blocks": 3, 48 | "patch_h": 2, 49 | "patch_w": 2, 50 | "stride": 2, 51 | "mv_expand_ratio": mv2_exp_mult, 52 | "num_heads": 4, 53 | "block_type": "mobilevit", 54 | }, 55 | "last_layer_exp_factor": 4, 56 | "cls_dropout": 0.1 57 | } 58 | elif mode == "x_small": 59 | mv2_exp_mult = 4 60 | config = { 61 | "layer1": { 62 | "out_channels": 32, 63 | "expand_ratio": mv2_exp_mult, 64 | "num_blocks": 1, 65 | "stride": 1, 66 | "block_type": "mv2", 67 | }, 68 | "layer2": { 69 | "out_channels": 48, 70 | "expand_ratio": mv2_exp_mult, 71 | "num_blocks": 3, 72 | "stride": 2, 73 | "block_type": "mv2", 74 | }, 75 | "layer3": { # 28x28 76 | "out_channels": 64, 77 | "transformer_channels": 96, 78 | "ffn_dim": 192, 79 | "transformer_blocks": 2, 80 | "patch_h": 2, 81 | "patch_w": 2, 82 | "stride": 2, 83 | 
"mv_expand_ratio": mv2_exp_mult, 84 | "num_heads": 4, 85 | "block_type": "mobilevit", 86 | }, 87 | "layer4": { # 14x14 88 | "out_channels": 80, 89 | "transformer_channels": 120, 90 | "ffn_dim": 240, 91 | "transformer_blocks": 4, 92 | "patch_h": 2, 93 | "patch_w": 2, 94 | "stride": 2, 95 | "mv_expand_ratio": mv2_exp_mult, 96 | "num_heads": 4, 97 | "block_type": "mobilevit", 98 | }, 99 | "layer5": { # 7x7 100 | "out_channels": 96, 101 | "transformer_channels": 144, 102 | "ffn_dim": 288, 103 | "transformer_blocks": 3, 104 | "patch_h": 2, 105 | "patch_w": 2, 106 | "stride": 2, 107 | "mv_expand_ratio": mv2_exp_mult, 108 | "num_heads": 4, 109 | "block_type": "mobilevit", 110 | }, 111 | "last_layer_exp_factor": 4, 112 | "cls_dropout": 0.1 113 | } 114 | elif mode == "small": 115 | mv2_exp_mult = 4 116 | config = { 117 | "layer1": { 118 | "out_channels": 32, 119 | "expand_ratio": mv2_exp_mult, 120 | "num_blocks": 1, 121 | "stride": 1, 122 | "block_type": "mv2", 123 | }, 124 | "layer2": { 125 | "out_channels": 64, 126 | "expand_ratio": mv2_exp_mult, 127 | "num_blocks": 3, 128 | "stride": 2, 129 | "block_type": "mv2", 130 | }, 131 | "layer3": { # 28x28 132 | "out_channels": 96, 133 | "transformer_channels": 144, 134 | "ffn_dim": 288, 135 | "transformer_blocks": 2, 136 | "patch_h": 2, 137 | "patch_w": 2, 138 | "stride": 2, 139 | "mv_expand_ratio": mv2_exp_mult, 140 | "num_heads": 4, 141 | "block_type": "mobilevit", 142 | }, 143 | "layer4": { # 14x14 144 | "out_channels": 128, 145 | "transformer_channels": 192, 146 | "ffn_dim": 384, 147 | "transformer_blocks": 4, 148 | "patch_h": 2, 149 | "patch_w": 2, 150 | "stride": 2, 151 | "mv_expand_ratio": mv2_exp_mult, 152 | "num_heads": 4, 153 | "block_type": "mobilevit", 154 | }, 155 | "layer5": { # 7x7 156 | "out_channels": 160, 157 | "transformer_channels": 240, 158 | "ffn_dim": 480, 159 | "transformer_blocks": 3, 160 | "patch_h": 2, 161 | "patch_w": 2, 162 | "stride": 2, 163 | "mv_expand_ratio": mv2_exp_mult, 164 | "num_heads": 4, 165 | "block_type": "mobilevit", 166 | }, 167 | "last_layer_exp_factor": 4, 168 | "cls_dropout": 0.1 169 | } 170 | else: 171 | raise NotImplementedError 172 | 173 | for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]: 174 | config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0}) 175 | 176 | return config 177 | -------------------------------------------------------------------------------- /inference/models/ragt/ragt.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from inference.models.grasp_model import GraspModel, ResidualBlock 5 | from .mobile_vit import get_model 6 | 7 | 8 | class RAGT(GraspModel): 9 | 10 | def __init__(self, input_channels=4, output_channels=1, channel_size=18, dropout=False, prob=0.0): 11 | super(RAGT, self).__init__() 12 | self.mobile_vit = get_model() 13 | 14 | # Upsampling layers to increase spatial dimensions 15 | self.upsample_layers = nn.Sequential( 16 | nn.Upsample(scale_factor=33, mode='bilinear', align_corners=False), 17 | nn.ReLU() 18 | ) 19 | 20 | self.pos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 21 | self.cos_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 22 | self.sin_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 23 | self.width_output = nn.Conv2d(in_channels=channel_size, out_channels=output_channels, kernel_size=2) 24 | 25 | 
self.dropout = dropout 26 | self.dropout_pos = nn.Dropout(p=prob) 27 | self.dropout_cos = nn.Dropout(p=prob) 28 | self.dropout_sin = nn.Dropout(p=prob) 29 | self.dropout_wid = nn.Dropout(p=prob) 30 | 31 | for m in self.modules(): 32 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 33 | nn.init.xavier_uniform_(m.weight, gain=1) 34 | 35 | def forward(self, x_in): 36 | x = self.mobile_vit(x_in) 37 | x = self.upsample_layers(x) 38 | x = x[:,:,:225, :225] 39 | 40 | if self.dropout: 41 | pos_output = self.pos_output(self.dropout_pos(x)) 42 | cos_output = self.cos_output(self.dropout_cos(x)) 43 | sin_output = self.sin_output(self.dropout_sin(x)) 44 | width_output = self.width_output(self.dropout_wid(x)) 45 | else: 46 | pos_output = self.pos_output(x) 47 | cos_output = self.cos_output(x) 48 | sin_output = self.sin_output(x) 49 | width_output = self.width_output(x) 50 | 51 | return pos_output, cos_output, sin_output, width_output 52 | -------------------------------------------------------------------------------- /inference/models/ragt/transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | 8 | class MultiHeadAttention(nn.Module): 9 | """ 10 | This layer applies a multi-head self- or cross-attention as described in 11 | `Attention is all you need `_ paper 12 | 13 | Args: 14 | embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})` 15 | num_heads (int): Number of heads in multi-head attention 16 | attn_dropout (float): Attention dropout. Default: 0.0 17 | bias (bool): Use bias or not. Default: ``True`` 18 | 19 | Shape: 20 | - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches, 21 | and :math:`C_{in}` is input embedding dim 22 | - Output: same shape as the input 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim: int, 29 | num_heads: int, 30 | attn_dropout: float = 0.0, 31 | bias: bool = True, 32 | *args, 33 | **kwargs 34 | ) -> None: 35 | super().__init__() 36 | if embed_dim % num_heads != 0: 37 | raise ValueError( 38 | "Embedding dim must be divisible by number of heads in {}. 
Got: embed_dim={} and num_heads={}".format( 39 | self.__class__.__name__, embed_dim, num_heads 40 | ) 41 | ) 42 | 43 | self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias) 44 | 45 | self.attn_dropout = nn.Dropout(p=attn_dropout) 46 | self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias) 47 | 48 | self.head_dim = embed_dim // num_heads 49 | self.scaling = self.head_dim ** -0.5 50 | self.softmax = nn.Softmax(dim=-1) 51 | self.num_heads = num_heads 52 | self.embed_dim = embed_dim 53 | 54 | def forward(self, x_q: Tensor) -> Tensor: 55 | # [N, P, C] 56 | b_sz, n_patches, in_channels = x_q.shape 57 | 58 | # self-attention 59 | # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc 60 | qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1) 61 | 62 | # [N, P, 3, h, c] -> [N, h, 3, P, C] 63 | qkv = qkv.transpose(1, 3).contiguous() 64 | 65 | # [N, h, 3, P, C] -> [N, h, P, C] x 3 66 | query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] 67 | 68 | query = query * self.scaling 69 | 70 | # [N h, P, c] -> [N, h, c, P] 71 | key = key.transpose(-1, -2) 72 | 73 | # QK^T 74 | # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P] 75 | attn = torch.matmul(query, key) 76 | attn = self.softmax(attn) 77 | attn = self.attn_dropout(attn) 78 | 79 | # weighted sum 80 | # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c] 81 | out = torch.matmul(attn, value) 82 | 83 | # [N, h, P, c] -> [N, P, h, c] -> [N, P, C] 84 | out = out.transpose(1, 2).reshape(b_sz, n_patches, -1) 85 | out = self.out_proj(out) 86 | 87 | return out 88 | 89 | 90 | class TransformerEncoder(nn.Module): 91 | """ 92 | This class defines the pre-norm `Transformer encoder `_ 93 | Args: 94 | embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})` 95 | ffn_latent_dim (int): Inner dimension of the FFN 96 | num_heads (int) : Number of heads in multi-head attention. Default: 8 97 | attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0 98 | dropout (float): Dropout rate. Default: 0.0 99 | ffn_dropout (float): Dropout between FFN layers. 
Default: 0.0 100 | 101 | Shape: 102 | - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches, 103 | and :math:`C_{in}` is input embedding dim 104 | - Output: same shape as the input 105 | """ 106 | 107 | def __init__( 108 | self, 109 | embed_dim: int, 110 | ffn_latent_dim: int, 111 | num_heads: Optional[int] = 8, 112 | attn_dropout: Optional[float] = 0.0, 113 | dropout: Optional[float] = 0.0, 114 | ffn_dropout: Optional[float] = 0.0, 115 | *args, 116 | **kwargs 117 | ) -> None: 118 | 119 | super().__init__() 120 | 121 | attn_unit = MultiHeadAttention( 122 | embed_dim, 123 | num_heads, 124 | attn_dropout=attn_dropout, 125 | bias=True 126 | ) 127 | 128 | self.pre_norm_mha = nn.Sequential( 129 | nn.LayerNorm(embed_dim), 130 | attn_unit, 131 | nn.Dropout(p=dropout) 132 | ) 133 | 134 | self.pre_norm_ffn = nn.Sequential( 135 | nn.LayerNorm(embed_dim), 136 | nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True), 137 | nn.SiLU(), 138 | nn.Dropout(p=ffn_dropout), 139 | nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True), 140 | nn.Dropout(p=dropout) 141 | ) 142 | self.embed_dim = embed_dim 143 | self.ffn_dim = ffn_latent_dim 144 | self.ffn_dropout = ffn_dropout 145 | self.std_dropout = dropout 146 | 147 | def forward(self, x: Tensor) -> Tensor: 148 | # multi-head attention 149 | res = x 150 | x = self.pre_norm_mha(x) 151 | x = x + res 152 | 153 | # feed forward network 154 | x = x + self.pre_norm_ffn(x) 155 | return x 156 | -------------------------------------------------------------------------------- /inference/post_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from skimage.filters import gaussian 3 | 4 | 5 | def post_process_output(q_img, cos_img, sin_img, width_img): 6 | """ 7 | Post-process the raw output of the network, convert to numpy arrays, apply filtering. 
8 | :param q_img: Q output of network (as torch Tensors) 9 | :param cos_img: cos output of network 10 | :param sin_img: sin output of network 11 | :param width_img: Width output of network 12 | :return: Filtered Q output, Filtered Angle output, Filtered Width output 13 | """ 14 | q_img = q_img.cpu().numpy().squeeze() 15 | ang_img = (torch.atan2(sin_img, cos_img) / 2.0).cpu().numpy().squeeze() 16 | width_img = width_img.cpu().numpy().squeeze() * 150.0 17 | 18 | q_img = gaussian(q_img, 2.0, preserve_range=True) 19 | ang_img = gaussian(ang_img, 2.0, preserve_range=True) 20 | width_img = gaussian(width_img, 1.0, preserve_range=True) 21 | 22 | return q_img, ang_img, width_img 23 | -------------------------------------------------------------------------------- /models/grasp_det_seg/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import version as __version__ 2 | from models.grasp_det_seg.config import load_config 3 | import torch 4 | import models.grasp_det_seg.models as models 5 | from models.grasp_det_seg.algos.detection import PredictionGenerator, ProposalMatcher, DetectionLoss 6 | from models.grasp_det_seg.algos.fpn import RPNAlgoFPN, DetectionAlgoFPN 7 | from models.grasp_det_seg.algos.rpn import AnchorMatcher, ProposalGenerator, RPNLoss 8 | from models.grasp_det_seg.algos.semantic_seg import SemanticSegAlgo, SemanticSegLoss 9 | from models.grasp_det_seg.config import load_config 10 | from models.grasp_det_seg.models.det_seg import DetSegNet, NETWORK_INPUTS 11 | from models.grasp_det_seg.modules.fpn import FPN, FPNBody 12 | from models.grasp_det_seg.modules.heads import RPNHead, FPNSemanticHeadDeeplab, FPNROIHead 13 | from models.grasp_det_seg.utils.misc import norm_act_from_config, freeze_params, NORM_LAYERS, OTHER_LAYERS 14 | 15 | 16 | def make_config(args): 17 | args.config = 'models/grasp_det_seg/config/defaults/det_seg_OCID.ini' 18 | print("Loading configuration from %s", args.config) 19 | 20 | conf = load_config(args.config, args.config) 21 | 22 | return conf 23 | 24 | def make_model(config): 25 | body_config = config["body"] 26 | fpn_config = config["fpn"] 27 | rpn_config = config["rpn"] 28 | roi_config = config["roi"] 29 | sem_config = config["sem"] 30 | general_config = config["general"] 31 | classes = {"total": int(general_config["num_things"]) + int(general_config["num_stuff"]), "stuff": 32 | int(general_config["num_stuff"]), "thing": int(general_config["num_things"]), 33 | "semantic": int(general_config["num_semantic"])} 34 | # BN + activation 35 | norm_act_static, norm_act_dynamic = norm_act_from_config(body_config) 36 | 37 | # Create backbone 38 | print("Creating backbone model %s", body_config["body"]) 39 | body_fn = models.__dict__["net_" + body_config["body"]] 40 | body_params = body_config.getstruct("body_params") if body_config.get("body_params") else {} 41 | body = body_fn(norm_act=norm_act_static, **body_params) 42 | if body_config.get("weights"): 43 | body.load_state_dict(torch.load(body_config["weights"], map_location="cpu")) 44 | 45 | # Freeze parameters 46 | for n, m in body.named_modules(): 47 | for mod_id in range(1, body_config.getint("num_frozen") + 1): 48 | if ("mod%d" % mod_id) in n: 49 | freeze_params(m) 50 | 51 | body_channels = body_config.getstruct("out_channels") 52 | 53 | # Create FPN 54 | fpn_inputs = fpn_config.getstruct("inputs") 55 | fpn = FPN([body_channels[inp] for inp in fpn_inputs], 56 | fpn_config.getint("out_channels"), 57 | fpn_config.getint("extra_scales"), 58 | 
norm_act_static, 59 | fpn_config["interpolation"]) 60 | body = FPNBody(body, fpn, fpn_inputs) 61 | 62 | # Create RPN 63 | proposal_generator = ProposalGenerator(rpn_config.getfloat("nms_threshold"), 64 | rpn_config.getint("num_pre_nms_train"), 65 | rpn_config.getint("num_post_nms_train"), 66 | rpn_config.getint("num_pre_nms_val"), 67 | rpn_config.getint("num_post_nms_val"), 68 | rpn_config.getint("min_size")) 69 | anchor_matcher = AnchorMatcher(rpn_config.getint("num_samples"), 70 | rpn_config.getfloat("pos_ratio"), 71 | rpn_config.getfloat("pos_threshold"), 72 | rpn_config.getfloat("neg_threshold"), 73 | rpn_config.getfloat("void_threshold")) 74 | rpn_loss = RPNLoss(rpn_config.getfloat("sigma")) 75 | rpn_algo = RPNAlgoFPN( 76 | proposal_generator, anchor_matcher, rpn_loss, 77 | rpn_config.getint("anchor_scale"), rpn_config.getstruct("anchor_ratios"), 78 | fpn_config.getstruct("out_strides"), rpn_config.getint("fpn_min_level"), rpn_config.getint("fpn_levels")) 79 | rpn_head = RPNHead( 80 | fpn_config.getint("out_channels"), len(rpn_config.getstruct("anchor_ratios")), 1, 81 | rpn_config.getint("hidden_channels"), norm_act_dynamic) 82 | 83 | # Create detection network 84 | prediction_generator = PredictionGenerator(roi_config.getfloat("nms_threshold"), 85 | roi_config.getfloat("score_threshold"), 86 | roi_config.getint("max_predictions")) 87 | proposal_matcher = ProposalMatcher(classes, 88 | roi_config.getint("num_samples"), 89 | roi_config.getfloat("pos_ratio"), 90 | roi_config.getfloat("pos_threshold"), 91 | roi_config.getfloat("neg_threshold_hi"), 92 | roi_config.getfloat("neg_threshold_lo"), 93 | roi_config.getfloat("void_threshold")) 94 | roi_loss = DetectionLoss(roi_config.getfloat("sigma")) 95 | roi_size = roi_config.getstruct("roi_size") 96 | roi_algo = DetectionAlgoFPN( 97 | prediction_generator, proposal_matcher, roi_loss, classes, roi_config.getstruct("bbx_reg_weights"), 98 | roi_config.getint("fpn_canonical_scale"), roi_config.getint("fpn_canonical_level"), roi_size, 99 | roi_config.getint("fpn_min_level"), roi_config.getint("fpn_levels")) 100 | roi_head = FPNROIHead(fpn_config.getint("out_channels"), classes, roi_size, norm_act=norm_act_dynamic) 101 | 102 | # Create semantic segmentation network 103 | sem_loss = SemanticSegLoss(ohem=sem_config.getfloat("ohem")) 104 | sem_algo = SemanticSegAlgo(sem_loss, classes["semantic"]) 105 | sem_head = FPNSemanticHeadDeeplab(fpn_config.getint("out_channels"), 106 | sem_config.getint("fpn_min_level"), 107 | sem_config.getint("fpn_levels"), 108 | classes["semantic"], 109 | pooling_size=sem_config.getstruct("pooling_size"), 110 | norm_act=norm_act_static) 111 | 112 | return DetSegNet(body, rpn_head, roi_head, sem_head, rpn_algo, roi_algo, sem_algo, classes) -------------------------------------------------------------------------------- /models/grasp_det_seg/_version.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # file generated by setuptools_scm 3 | # don't change, don't track in version control 4 | version = '0.1.dev0' 5 | version_tuple = (0, 1, 'dev0') 6 | -------------------------------------------------------------------------------- /models/grasp_det_seg/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/models/grasp_det_seg/algos/__init__.py -------------------------------------------------------------------------------- 
/models/grasp_det_seg/algos/semantic_seg.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import torch 4 | import torch.nn.functional as functional 5 | 6 | from models.grasp_det_seg.utils.parallel import PackedSequence 7 | from models.grasp_det_seg.utils.sequence import pack_padded_images 8 | 9 | 10 | class SemanticSegLoss: 11 | """Semantic segmentation loss 12 | 13 | Parameters 14 | ---------- 15 | ohem : float or None 16 | Online hard example mining fraction, or `None` to disable OHEM 17 | ignore_index : int 18 | Index of the void class 19 | """ 20 | 21 | def __init__(self, ohem=None, ignore_index=255): 22 | if ohem is not None and (ohem <= 0 or ohem > 1): 23 | raise ValueError("ohem should be in (0, 1]") 24 | self.ohem = ohem 25 | self.ignore_index = ignore_index 26 | 27 | def __call__(self, sem_logits, sem): 28 | """Compute the semantic segmentation loss 29 | """ 30 | sem_loss = [] 31 | for sem_logits_i, sem_i in zip(sem_logits, sem): 32 | sem_loss_i = functional.cross_entropy( 33 | sem_logits_i.unsqueeze(0), sem_i.unsqueeze(0), ignore_index=self.ignore_index, reduction="none") 34 | sem_loss_i = sem_loss_i.view(-1) 35 | 36 | if self.ohem is not None and self.ohem != 1: 37 | top_k = int(ceil(sem_loss_i.numel() * self.ohem)) 38 | if top_k != sem_loss_i.numel(): 39 | sem_loss_i, _ = sem_loss_i.topk(top_k) 40 | 41 | sem_loss.append(sem_loss_i.mean()) 42 | 43 | return sum(sem_loss) / len(sem_logits) 44 | 45 | 46 | class SemanticSegAlgo: 47 | """Semantic segmentation algorithm 48 | """ 49 | 50 | def __init__(self, loss, num_classes, ignore_index=255): 51 | self.loss = loss 52 | self.num_classes = num_classes 53 | self.ignore_index = ignore_index 54 | 55 | @staticmethod 56 | def _pack_logits(sem_logits, valid_size, img_size): 57 | sem_logits = functional.interpolate(sem_logits, size=img_size, mode="bilinear", align_corners=False) 58 | return pack_padded_images(sem_logits, valid_size) 59 | 60 | def _confusion_matrix(self, sem_pred, sem): 61 | confmat = sem[0].new_zeros(self.num_classes * self.num_classes, dtype=torch.float) 62 | 63 | for sem_pred_i, sem_i in zip(sem_pred, sem): 64 | valid = sem_i != self.ignore_index 65 | if valid.any(): 66 | sem_pred_i = sem_pred_i[valid] 67 | sem_i = sem_i[valid] 68 | 69 | confmat.index_add_( 70 | 0, sem_i.view(-1) * self.num_classes + sem_pred_i.view(-1), confmat.new_ones(sem_i.numel())) 71 | 72 | return confmat.view(self.num_classes, self.num_classes) 73 | 74 | @staticmethod 75 | def _logits(head, x, valid_size, img_size): 76 | sem_logits, sem_feats = head(x) 77 | return sem_logits,SemanticSegAlgo._pack_logits(sem_logits, valid_size, img_size), sem_feats 78 | 79 | def training(self, head, x, sem, valid_size, img_size): 80 | """Given input features and ground truth compute semantic segmentation loss, confusion matrix and prediction 81 | """ 82 | # Compute logits and prediction 83 | sem_logits_low_res, sem_logits, sem_feats = self._logits(head, x, valid_size, img_size) 84 | sem_pred = PackedSequence([sem_logits_i.max(dim=0)[1] for sem_logits_i in sem_logits]) 85 | sem_pred_low_res = PackedSequence([sem_logits_low_res_i.max(dim=0)[1].float() for sem_logits_low_res_i in sem_logits_low_res]) 86 | 87 | # Compute loss and confusion matrix 88 | sem_loss = self.loss(sem_logits, sem) 89 | conf_mat = self._confusion_matrix(sem_pred, sem) 90 | 91 | return sem_loss, conf_mat, sem_pred,sem_logits,sem_logits_low_res,sem_pred_low_res,sem_feats 92 | 93 | def inference(self, head, x, valid_size, 
img_size): 94 | """Given input features compute semantic segmentation prediction 95 | """ 96 | sem_logits_low_res, sem_logits, sem_feats = self._logits(head, x, valid_size, img_size) 97 | sem_pred = PackedSequence([sem_logits_i.max(dim=0)[1] for sem_logits_i in sem_logits]) 98 | sem_pred_low_res = PackedSequence([sem_logits_low_res_i.max(dim=0)[1].float() for sem_logits_low_res_i in sem_logits_low_res]) 99 | 100 | return sem_pred, sem_feats, sem_pred_low_res 101 | 102 | 103 | def confusion_matrix(sem_pred, sem, num_classes, ignore_index=255): 104 | confmat = sem_pred.new_zeros(num_classes * num_classes, dtype=torch.float) 105 | 106 | valid = sem != ignore_index 107 | if valid.any(): 108 | sem_pred = sem_pred[valid] 109 | sem = sem[valid] 110 | 111 | confmat.index_add_(0, sem.view(-1) * num_classes + sem_pred.view(-1), confmat.new_ones(sem.numel())) 112 | 113 | return confmat.view(num_classes, num_classes) 114 | -------------------------------------------------------------------------------- /models/grasp_det_seg/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import load_config 2 | -------------------------------------------------------------------------------- /models/grasp_det_seg/config/config.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import configparser 3 | from os import path, listdir 4 | 5 | _CONVERTERS = { 6 | "struct": ast.literal_eval 7 | } 8 | 9 | def load_config(config_file, defaults_file): 10 | parser = configparser.ConfigParser(allow_no_value=True, converters=_CONVERTERS) 11 | parser.read([defaults_file, config_file]) 12 | return parser 13 | -------------------------------------------------------------------------------- /models/grasp_det_seg/config/defaults/det_seg_OCID.ini: -------------------------------------------------------------------------------- 1 | # GENERAL NOTE: the fields denoted as meta-info are not actual configuration parameters. Instead, they are used to 2 | # describe some characteristic of a network module that needs to be accessible from some other module but is hard to 3 | # determine in a generic way from within the code. A typical example is the total output stride of the network body. 4 | # These should be properly configured by the user to match the actual properties of the network. 
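# Hedged note (not part of the original defaults file): load_config in
# models/grasp_det_seg/config/config.py registers ast.literal_eval as the "struct" converter,
# so every option read with getstruct() must be a valid Python literal. For example:
#   conf = load_config(ini_path, ini_path)
#   conf["fpn"].getstruct("inputs")          # -> ["mod2", "mod3", "mod4", "mod5"]
#   conf["rpn"].getstruct("anchor_ratios")   # -> (1.0, 0.1, 0.4, 0.7, 1.2)
# Scalar options are read with getint() / getfloat() / getboolean() as usual.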
5 | 6 | [general] 7 | # Number of epochs between validations 8 | val_interval = 25 9 | # Number of steps before outputting a log entry 10 | log_interval = 10 11 | cudnn_benchmark = no 12 | num_classes = 4 13 | num_stuff = 0 14 | num_things = 4 15 | # 0 - 31 16 | num_semantic = 4 17 | 18 | 19 | [body] 20 | # Architecture for the body 21 | body = resnet101 22 | # Path to pre-trained ImageNet weights 23 | weights = trained-models/grasp_det_seg/resnet101 24 | # Normalization mode: 25 | # -- bn: in-place batch norm everywhere 26 | # -- syncbn: synchronized in-place batch norm everywhere 27 | # -- syncbn+bn: synchronized in-place batch norm in the static part of the network, in-place batch norm everywhere else 28 | # -- gn: group norm everywhere 29 | # -- syncbn+gn: synchronized in-place batch norm in the static part of the network, group norm everywhere else 30 | # -- off: do not normalize activations (scale and bias are kept) 31 | normalization_mode = syncbn 32 | # Activation: 'leaky_relu' or 'elu' 33 | activation = leaky_relu 34 | activation_slope = 0.01 35 | # Group norm parameters 36 | gn_groups = 16 37 | # Additional parameters for the body 38 | body_params = {} 39 | # Number of frozen modules: in [1, 5] 40 | num_frozen = 2 41 | # Wether to freeze BN modules 42 | bn_frozen = yes 43 | # Meta-info 44 | out_channels = {"mod1": 64, "mod2": 256, "mod3": 512, "mod4": 1024, "mod5": 2048} 45 | out_strides = {"mod1": 4, "mod2": 4, "mod3": 8, "mod4": 16, "mod5": 32} 46 | 47 | [fpn] 48 | out_channels = 256 49 | extra_scales = 0 50 | interpolation = nearest 51 | # Input settings 52 | inputs = ["mod2", "mod3", "mod4", "mod5"] 53 | # Meta-info 54 | out_strides = (4, 8, 16, 32) 55 | 56 | [rpn] 57 | hidden_channels = 256 58 | stride = 1 59 | # Anchor settings 60 | anchor_ratios = (1., 0.1, 0.4, 0.7, 1.2) 61 | anchor_scale = 2 62 | # Proposal settings 63 | nms_threshold = 0.7 64 | num_pre_nms_train = 12000 65 | num_post_nms_train = 2000 66 | num_pre_nms_val = 6000 67 | num_post_nms_val = 300 68 | min_size = 16 69 | # Anchor matcher settings 70 | num_samples = 256 71 | pos_ratio = .5 72 | pos_threshold = .7 73 | neg_threshold = .3 74 | void_threshold = 0.7 75 | # FPN-specific settings 76 | fpn_min_level = 0 77 | fpn_levels = 3 78 | # Loss settings 79 | sigma = 3. 80 | 81 | [roi] 82 | roi_size = (14, 14) 83 | # Matcher settings 84 | num_samples = 128 85 | pos_ratio = .25 86 | pos_threshold = .5 87 | neg_threshold_hi = .5 88 | neg_threshold_lo = 0. 89 | void_threshold = 0.7 90 | void_is_background = no 91 | # Prediction generator settings 92 | nms_threshold = 0.3 93 | score_threshold = 0.05 94 | max_predictions = 100 95 | # FPN-specific settings 96 | fpn_min_level = 0 97 | fpn_levels = 4 98 | fpn_canonical_scale = 224 99 | fpn_canonical_level = 2 100 | # Loss settings 101 | sigma = 1. 102 | bbx_reg_weights = (10., 10., 5., 5.) 
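# Hedged cross-reference (not part of the original defaults file): make_model in
# models/grasp_det_seg/__init__.py consumes this [roi] section roughly as
#   PredictionGenerator(nms_threshold, score_threshold, max_predictions)
#   ProposalMatcher(classes, num_samples, pos_ratio, pos_threshold,
#                   neg_threshold_hi, neg_threshold_lo, void_threshold)
#   DetectionLoss(sigma)
#   DetectionAlgoFPN(..., bbx_reg_weights, fpn_canonical_scale, fpn_canonical_level,
#                    roi_size, fpn_min_level, fpn_levels)
#   FPNROIHead(fpn out_channels, classes, roi_size)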
103 | 104 | [sem] 105 | fpn_min_level = 0 106 | fpn_levels = 4 107 | pooling_size = (64, 64) 108 | # Loss settings 109 | ohem = .25 110 | 111 | [optimizer] 112 | lr = 0.03 113 | weight_decay = 0.0001 114 | weight_decay_norm = yes 115 | momentum = 0.9 116 | nesterov = yes 117 | # obj, bbx, roi_cls, roi_bbx, sem 118 | loss_weights = (1., 1., 1., 1.,.75) 119 | 120 | [scheduler] 121 | epochs = 800 122 | # Scheduler type: 'linear', 'step', 'poly' or 'multistep' 123 | type = poly 124 | # When to update the learning rate: 'batch', 'epoch' 125 | update_mode = batch 126 | # Additional parameters for the scheduler 127 | # -- linear 128 | # from: initial lr multiplier 129 | # to: final lr multiplier 130 | # -- step 131 | # step_size: number of steps between lr decreases 132 | # gamma: multiplicative factor 133 | # -- poly 134 | # gamma: exponent of the polynomial 135 | # -- multistep 136 | # milestones: step indicies where the lr decreases will be triggered 137 | params = {"gamma": 0.9} 138 | burn_in_steps = 500 139 | burn_in_start = 0.333 140 | 141 | [dataloader] 142 | # Absolute path to the project 143 | root_path = ./GraspDetSeg_CNN 144 | # Image size parameters 145 | shortest_size = 224 146 | longest_max_size = 224 147 | # Batch size 148 | train_batch_size = 10 149 | val_batch_size = 1 150 | # Augmentation parameters 151 | rgb_mean = (0.485, 0.456, 0.406) 152 | rgb_std = (0.229, 0.224, 0.225) 153 | random_flip = no 154 | random_scale = None 155 | rotate_and_scale = True 156 | # Number of worker threads 157 | num_workers = 6 158 | # Subsets 159 | train_set = training_0 160 | val_set = validation_0 161 | test_set = validation_0 -------------------------------------------------------------------------------- /models/grasp_det_seg/data_OCID/OCID_class_dict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cls_names = { 4 | 'background' : '0', 5 | 'apple' : '1', 6 | 'ball' : '2', 7 | 'banana' : '3', 8 | 'bell_pepper' : '4', 9 | 'binder' : '5', 10 | 'bowl' : '6', 11 | 'cereal_box' : '7', 12 | 'coffee_mug' : '8', 13 | 'flashlight' : '9', 14 | 'food_bag' : '10', 15 | 'food_box' : '11', 16 | 'food_can' : '12', 17 | 'glue_stick' : '13', 18 | 'hand_towel' : '14', 19 | 'instant_noodles' : '15', 20 | 'keyboard' : '16', 21 | 'kleenex' : '17', 22 | 'lemon' : '18', 23 | 'lime' : '19', 24 | 'marker' : '20', 25 | 'orange' : '21', 26 | 'peach' : '22', 27 | 'pear' : '23', 28 | 'potato' : '24', 29 | 'shampoo' : '25', 30 | 'soda_can' : '26', 31 | 'sponge' : '27', 32 | 'stapler' : '28', 33 | 'tomato' : '29', 34 | 'toothpaste' : '30', 35 | 'unknown' : '31' 36 | } 37 | 38 | colors = { 39 | '0': np.array([0, 0, 0]), 40 | '1': np.array([ 211, 47, 47 ]), 41 | '2': np.array([ 0, 255, 0]), 42 | '3': np.array([123, 31, 162]), 43 | '4': np.array([ 81, 45, 168 ]), 44 | '5': np.array([ 48, 63, 159 ]), 45 | '6': np.array([25, 118, 210]), 46 | '7': np.array([ 2, 136, 209 ]), 47 | '8': np.array([ 153, 51, 102 ]), 48 | '9': np.array([ 0, 121, 107 ]), 49 | '10': np.array([ 56, 142, 60 ]), 50 | '11': np.array([ 104, 159, 56 ]), 51 | '12': np.array([ 175, 180, 43 ]), 52 | '13': np.array([ 251, 192, 45 ]), 53 | '14': np.array([ 255, 160, 0 ]), 54 | '15': np.array([ 245, 124, 0 ]), 55 | '16': np.array([ 230, 74, 25 ]), 56 | '17': np.array([ 93, 64, 55 ]), 57 | '18': np.array([ 97, 97, 97 ]), 58 | '19': np.array([ 84, 110, 122 ]), 59 | '20': np.array([ 255, 255, 102]), 60 | '21': np.array([ 0, 151, 167 ]), 61 | '22': np.array([ 153, 255, 102 ]), 62 | '23': np.array([ 51, 
255, 102 ]), 63 | '24': np.array([ 0, 255, 255 ]), 64 | '25': np.array([ 255, 255, 255 ]), 65 | '26': np.array([ 255, 204, 204 ]), 66 | '27': np.array([ 153, 102, 0 ]), 67 | '28': np.array([ 204, 255, 204 ]), 68 | '29': np.array([ 204, 255, 0 ]), 69 | '30': np.array([ 255, 0, 255 ]), 70 | '31': np.array([ 194, 24, 91 ]), 71 | } 72 | 73 | colors_list = list(colors.values()) 74 | cls_list = list(cls_names.keys()) -------------------------------------------------------------------------------- /models/grasp_det_seg/data_OCID/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import OCIDDataset, OCIDTestDataset 2 | from .misc import iss_collate_fn, read_boxes_from_file, prepare_frcnn_format 3 | from .transform import OCIDTransform, OCIDTestTransform -------------------------------------------------------------------------------- /models/grasp_det_seg/data_OCID/dataset.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import cv2 3 | import numpy as np 4 | import torch.utils.data as data 5 | import os 6 | from PIL import Image 7 | 8 | 9 | class OCIDDataset(data.Dataset): 10 | """OCID_grasp dataset for grasp detection and semantic segmentation 11 | """ 12 | 13 | def __init__(self, data_path, root_dir, split_name, transform): 14 | super(OCIDDataset, self).__init__() 15 | self.data_path = data_path 16 | self.root_dir = root_dir 17 | self.split_name = split_name 18 | self.transform = transform 19 | 20 | self._images = self._load_split() 21 | 22 | def _load_split(self): 23 | with open(path.join(self.data_path, self.split_name + ".txt"), "r") as fid: 24 | images = [x.strip() for x in fid.readlines()] 25 | 26 | return images 27 | 28 | def _load_item(self, item): 29 | seq_path, im_name = item.split(',') 30 | sample_path = os.path.join(self.root_dir, seq_path) 31 | img_path = os.path.join(sample_path, 'rgb', im_name) 32 | mask_path = os.path.join(sample_path, 'seg_mask_labeled_combi', im_name) 33 | anno_path = os.path.join(sample_path, 'Annotations', im_name[:-4] + '.txt') 34 | img_bgr = cv2.imread(img_path) 35 | img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 36 | 37 | with open(anno_path, "r") as f: 38 | points_list = [] 39 | boxes_list = [] 40 | for count, line in enumerate(f): 41 | line = line.rstrip() 42 | [x, y] = line.split(' ') 43 | 44 | x = float(x) 45 | y = float(y) 46 | 47 | pt = (x, y) 48 | points_list.append(pt) 49 | 50 | if len(points_list) == 4: 51 | boxes_list.append(points_list) 52 | points_list = [] 53 | 54 | msk = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED) 55 | box_arry = np.asarray(boxes_list) 56 | return img, msk, box_arry 57 | 58 | @property 59 | def categories(self): 60 | """Category names""" 61 | return self._meta["categories"] 62 | 63 | @property 64 | def num_categories(self): 65 | """Number of categories""" 66 | return len(self.categories) 67 | 68 | @property 69 | def num_stuff(self): 70 | """Number of "stuff" categories""" 71 | return self._meta["num_stuff"] 72 | 73 | @property 74 | def num_thing(self): 75 | """Number of "thing" categories""" 76 | return self.num_categories - self.num_stuff 77 | 78 | @property 79 | def original_ids(self): 80 | """Original class id of each category""" 81 | return self._meta["original_ids"] 82 | 83 | @property 84 | def palette(self): 85 | """Default palette to be used when color-coding semantic labels""" 86 | return np.array(self._meta["palette"], dtype=np.uint8) 87 | 88 | @property 89 | def img_sizes(self): 90 | """Size of 
each image of the dataset""" 91 | return [img_desc["size"] for img_desc in self._images] 92 | 93 | @property 94 | def img_categories(self): 95 | """Categories present in each image of the dataset""" 96 | return [img_desc["cat"] for img_desc in self._images] 97 | 98 | @property 99 | def get_images(self): 100 | """Categories present in each image of the dataset""" 101 | return self._images 102 | 103 | def __len__(self): 104 | return len(self._images) 105 | 106 | def __getitem__(self, item): 107 | im_rgb, msk, bbox_infos = self._load_item(item) 108 | 109 | rec, im_size = self.transform(im_rgb, msk, bbox_infos) 110 | 111 | rec["abs_path"] = item 112 | rec["root_path"] = self.root_dir 113 | rec["im_size"] = im_size 114 | return rec 115 | 116 | def get_raw_image(self, idx): 117 | """Load a single, unmodified image with given id from the dataset""" 118 | img_file = path.join(self._img_dir, idx) 119 | if path.exists(img_file + ".png"): 120 | img_file = img_file + ".png" 121 | elif path.exists(img_file + ".jpg"): 122 | img_file = img_file + ".jpg" 123 | else: 124 | raise IOError("Cannot find any image for id {} in {}".format(idx, self._img_dir)) 125 | 126 | return Image.open(img_file) 127 | 128 | def get_image_desc(self, idx): 129 | """Look up an image descriptor given the id""" 130 | matching = [img_desc for img_desc in self._images if img_desc["id"] == idx] 131 | if len(matching) == 1: 132 | return matching[0] 133 | else: 134 | raise ValueError("No image found with id %s" % idx) 135 | 136 | 137 | class OCIDTestDataset(data.Dataset): 138 | 139 | def __init__(self, data_path, root_dir, split_name, transform): 140 | super(OCIDTestDataset, self).__init__() 141 | self.data_path = data_path 142 | self.root_dir = root_dir 143 | self.split_name = split_name 144 | self.transform = transform 145 | 146 | self._images = self._load_split() 147 | 148 | def _load_split(self): 149 | with open(path.join(self.data_path, self.split_name + ".txt"), "r") as fid: 150 | images = [x.strip() for x in fid.readlines()] 151 | return images 152 | 153 | @property 154 | def img_sizes(self): 155 | """Size of each image of the dataset""" 156 | return [img_desc["size"] for img_desc in self._images] 157 | 158 | @property 159 | def get_images(self): 160 | """Categories present in each image of the dataset""" 161 | return self._images 162 | 163 | def __len__(self): 164 | return len(self._images) 165 | 166 | def __getitem__(self, item): 167 | seq_path, im_name = item.split(',') 168 | sample_path = os.path.join(self.root_dir, seq_path) 169 | img_path = os.path.join(sample_path, 'rgb', im_name) 170 | img_bgr = cv2.imread(img_path) 171 | im_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 172 | 173 | img_, im_size = self.transform(im_rgb) 174 | 175 | return {"img": img_, 176 | "root_path": self.root_dir, 177 | "abs_path": item, 178 | "im_size": im_size 179 | } 180 | -------------------------------------------------------------------------------- /models/grasp_det_seg/data_OCID/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from grasp_det_seg.utils.parallel import PackedSequence 5 | 6 | 7 | def iss_collate_fn(items): 8 | """Collate function for ISS batches""" 9 | out = {} 10 | if len(items) > 0: 11 | for key in items[0]: 12 | out[key] = [item[key] for item in items] 13 | if isinstance(items[0][key], torch.Tensor): 14 | out[key] = PackedSequence(out[key]) 15 | return out 16 | 17 | def prepare_frcnn_format(boxes,im_size): 18 | boxes_ary = 
np.asarray(boxes) 19 | 20 | boxes_ary = np.swapaxes(boxes_ary, 1, 2) 21 | xy_ctr = np.sum(boxes_ary, axis=2) / 4 22 | x_ctr = xy_ctr[:, 0] 23 | y_ctr = xy_ctr[:, 1] 24 | width = np.sqrt(np.sum((boxes_ary[:, :, 0] - boxes_ary[:, :, 1]) ** 2, axis=1)) 25 | height = np.sqrt(np.sum((boxes_ary[:, :, 1] - boxes_ary[:, :, 2]) ** 2, axis=1)) 26 | 27 | theta = np.zeros((boxes_ary.shape[0]), dtype=np.int) 28 | theta = np.arctan((boxes_ary[:, 1, 1] - boxes_ary[:, 1, 0]) / (boxes_ary[:, 0, 0] - boxes_ary[:, 0, 1])) 29 | b = np.arctan((boxes_ary[:, 1, 0] - boxes_ary[:, 1, 1]) / (boxes_ary[:, 0, 1] - boxes_ary[:, 0, 0])) 30 | theta[np.where(boxes_ary[:, 0, 0] <= boxes_ary[:, 0, 1])] = b[np.where(boxes_ary[:, 0, 0] <= boxes_ary[:, 0, 1])] 31 | 32 | # used for fasterrcnn loss 33 | x_min = x_ctr - width / 2 34 | x_max = x_ctr + width / 2 35 | y_min = y_ctr - height / 2 36 | y_max = y_ctr + height / 2 37 | 38 | x_coords = np.vstack((x_min, x_max)) 39 | y_coords = np.vstack((y_min, y_max)) 40 | 41 | mat = np.asarray((np.all(x_coords > im_size[1], axis=0), np.all(x_coords < 0, axis=0), 42 | np.all(y_coords > im_size[0], axis=0), np.all(y_coords < 0, axis=0))) 43 | 44 | fail = np.any(mat, axis=0) 45 | correct_idx = np.where(fail == False) 46 | theta_deg = np.rad2deg(theta) + 90 47 | cls = (np.round((theta_deg) / (180 / 18))).astype(int) 48 | cls[np.where(cls == 18)] = 0 49 | 50 | ret_value = (boxes_ary[correct_idx], theta_deg[correct_idx],cls[correct_idx]) 51 | return ret_value 52 | 53 | def read_boxes_from_file(gt_path,delta_xy): 54 | with open(gt_path)as f: 55 | points_list = [] 56 | box_list = [] 57 | for count, line in enumerate(f): 58 | line = line.rstrip() 59 | [x, y] = line.split(' ') 60 | x = float(x) - int(delta_xy[0]) 61 | y = float(y) - int(delta_xy[1]) 62 | 63 | pt = (x, y) 64 | points_list.append(pt) 65 | 66 | if len(points_list) == 4: 67 | box_list.append(points_list) 68 | points_list = [] 69 | return box_list 70 | -------------------------------------------------------------------------------- /models/grasp_det_seg/data_OCID/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import distributed 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class ARBatchSampler(Sampler): 9 | def __init__(self, data_source, batch_size, drop_last=False, epoch=0): 10 | super(ARBatchSampler, self).__init__(data_source) 11 | self.data_source = data_source 12 | self.batch_size = batch_size 13 | self.drop_last = drop_last 14 | self._epoch = epoch 15 | 16 | # Split images by orientation 17 | self.img_sets = self.data_source.get_images 18 | 19 | def _split_images(self, indices): 20 | # returns lists of [im_id, aspect_ratio] 21 | 22 | img_sizes = self.data_source.img_sizes 23 | img_sets = [[], []] 24 | for img_id in indices: 25 | aspect_ratio = img_sizes[img_id][0] / img_sizes[img_id][1] 26 | if aspect_ratio < 1: 27 | img_sets[0].append({"id": img_id, "ar": aspect_ratio}) 28 | else: 29 | img_sets[1].append({"id": img_id, "ar": aspect_ratio}) 30 | 31 | return img_sets 32 | 33 | def _generate_batches(self): 34 | g = torch.Generator() 35 | g.manual_seed(self._epoch) 36 | 37 | self.img_sets = [self.img_sets[i] for i in torch.randperm(len(self.img_sets), generator=g)] 38 | 39 | batches = [] 40 | leftover = [] 41 | batch = [] 42 | for img in self.img_sets: 43 | batch.append(img) 44 | if len(batch) == self.batch_size: 45 | batches.append(batch) 46 | batch = [] 47 | leftover += batch 48 | 49 | if not self.drop_last: 50 | 
batch = [] 51 | for img in leftover: 52 | batch.append(img) 53 | if len(batch) == self.batch_size: 54 | batches.append(batch) 55 | batch = [] 56 | 57 | if len(batch) != 0: 58 | batches.append(batch) 59 | 60 | return batches 61 | 62 | def set_epoch(self, epoch): 63 | self._epoch = epoch 64 | 65 | def __len__(self): 66 | if self.drop_last: 67 | return len(self.img_sets) // self.batch_size 68 | else: 69 | return (len(self.img_sets) + self.batch_size - 1) // self.batch_size 70 | 71 | 72 | def __iter__(self): 73 | batches = self._generate_batches() 74 | for batch in batches: 75 | batch = sorted(batch, key=lambda i: i["ar"]) 76 | batch = [i["id"] for i in batch] 77 | yield batch 78 | 79 | 80 | class DistributedARBatchSampler(ARBatchSampler): 81 | def __init__(self, data_source, batch_size, num_replicas=None, rank=None, drop_last=False, epoch=0): 82 | super(DistributedARBatchSampler, self).__init__(data_source, batch_size, drop_last, epoch) 83 | 84 | # Automatically get world size and rank if not provided 85 | if num_replicas is None: 86 | num_replicas = distributed.get_world_size() 87 | if rank is None: 88 | rank = distributed.get_rank() 89 | 90 | self.num_replicas = num_replicas 91 | self.rank = rank 92 | 93 | tot_batches = super(DistributedARBatchSampler, self).__len__() 94 | self.num_batches = int(math.ceil(tot_batches / self.num_replicas)) 95 | 96 | def __len__(self): 97 | return self.num_batches 98 | 99 | def __iter__(self): 100 | batches = self._generate_batches() 101 | 102 | g = torch.Generator() 103 | g.manual_seed(self._epoch) 104 | indices = list(torch.randperm(len(batches), generator=g)) 105 | 106 | # add extra samples to make it evenly divisible 107 | indices += indices[:(self.num_batches * self.num_replicas - len(indices))] 108 | assert len(indices) == self.num_batches * self.num_replicas 109 | 110 | # subsample 111 | offset = self.num_batches * self.rank 112 | indices = indices[offset:offset + self.num_batches] 113 | assert len(indices) == self.num_batches 114 | 115 | for idx in indices: 116 | yield batches[idx] 117 | -------------------------------------------------------------------------------- /models/grasp_det_seg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /models/grasp_det_seg/models/det_seg.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from models.grasp_det_seg.utils.sequence import pad_packed_images 8 | from inference.models.grasp_model import GraspModel 9 | from utils.dataset_processing.grasp import Grasp 10 | from utils.dataset_processing.grasp import GraspRectangles 11 | 12 | def _gr_text_to_no(l, offset=(0, 0)): 13 | """ 14 | Transform a single point from a Cornell file line to a pair of ints. 
15 | :param l: Line from Cornell grasp file (str) 16 | :param offset: Offset to apply to point positions 17 | :return: Point [y, x] 18 | """ 19 | x, y = l.split() 20 | return [int(round(float(y))) - offset[0], int(round(float(x))) - offset[1]] 21 | 22 | NETWORK_INPUTS = ["img", "msk", "bbx"] 23 | 24 | class DetSegNet(GraspModel): 25 | def __init__(self, 26 | body, 27 | rpn_head, 28 | roi_head, 29 | sem_head, 30 | rpn_algo, 31 | detection_algo, 32 | semantic_seg_algo, 33 | classes): 34 | super(DetSegNet, self).__init__() 35 | self.num_stuff = classes["stuff"] 36 | 37 | # Modules 38 | self.body = body 39 | self.rpn_head = rpn_head 40 | self.roi_head = roi_head 41 | self.sem_head = sem_head 42 | 43 | # Algorithms 44 | self.rpn_algo = rpn_algo 45 | self.detection_algo = detection_algo 46 | self.semantic_seg_algo = semantic_seg_algo 47 | 48 | def _prepare_inputs(self, msk, cat, iscrowd, bbx): 49 | cat_out, iscrowd_out, bbx_out, ids_out, sem_out = [], [], [], [], [] 50 | for msk_i, cat_i, iscrowd_i, bbx_i in zip(msk, cat, iscrowd, bbx): 51 | msk_i = msk_i.squeeze(0) 52 | thing = (cat_i >= self.num_stuff) & (cat_i != 255) 53 | valid = thing & ~iscrowd_i 54 | 55 | if valid.any().item(): 56 | cat_out.append(cat_i[valid]) 57 | bbx_out.append(bbx_i[valid]) 58 | ids_out.append(torch.nonzero(valid)) 59 | else: 60 | cat_out.append(None) 61 | bbx_out.append(None) 62 | ids_out.append(None) 63 | 64 | if iscrowd_i.any().item(): 65 | iscrowd_i = iscrowd_i & thing 66 | iscrowd_out.append(iscrowd_i[msk_i]) 67 | else: 68 | iscrowd_out.append(None) 69 | 70 | sem_out.append(cat_i[msk_i]) 71 | 72 | return cat_out, iscrowd_out, bbx_out, ids_out, sem_out 73 | 74 | def _convert_to_bb(self, bbx_pred, angle_pred): 75 | grs = [] 76 | # bbxs = torch.cat(tuple(bbx_pred), dim=0) 77 | # angle_preds = torch.cat(tuple(angle_pred), dim=0) 78 | bbxs = bbx_pred[0] 79 | angle_preds = angle_pred[0] 80 | for i, bbx in enumerate(bbxs): 81 | x, y, w, h = bbx.tolist() 82 | theta = angle_preds[i].item() 83 | grs.append(Grasp(np.array([y, x]), -theta / 180.0 * np.pi, w, h).as_gr) 84 | grs = GraspRectangles(grs) 85 | # grs = grs.scale(scale) 86 | return grs 87 | 88 | def _numpy_to_torch(self, s): 89 | if len(s.shape) == 2: 90 | return torch.from_numpy(np.expand_dims(s, 0).astype(np.float32)) 91 | else: 92 | return torch.from_numpy(s.astype(np.float32)) 93 | 94 | 95 | def forward(self, img, msk=None, cat=None, iscrowd=None, bbx=None, do_loss=False, do_prediction=True): 96 | # Pad the input images 97 | output_size = img.shape[-1] 98 | device = img.device 99 | img.requires_grad_(True) 100 | img, valid_size = pad_packed_images(img) 101 | img_size = img.shape[-2:] 102 | 103 | # Convert ground truth to the internal format 104 | if do_loss: 105 | sem, _ = pad_packed_images(msk) 106 | msk, _ = pad_packed_images(msk) 107 | 108 | # Run network body 109 | x = self.body(img) 110 | 111 | # RPN part 112 | if do_loss: 113 | obj_loss, bbx_loss, proposals = self.rpn_algo.training( 114 | self.rpn_head, x, bbx, iscrowd, valid_size, training=self.training, do_inference=True) 115 | elif do_prediction: 116 | proposals = self.rpn_algo.inference(self.rpn_head, x, valid_size, self.training) 117 | obj_loss, bbx_loss = None, None 118 | else: 119 | obj_loss, bbx_loss, proposals = None, None, None 120 | 121 | # ROI part 122 | if do_loss: 123 | roi_cls_loss, roi_bbx_loss = self.detection_algo.training( 124 | self.roi_head, x, proposals, bbx, cat, iscrowd, img_size) 125 | else: 126 | roi_cls_loss, roi_bbx_loss = None, None 127 | if do_prediction: 128 | bbx_pred, 
cls_pred, obj_pred = self.detection_algo.inference( 129 | self.roi_head, x, proposals, valid_size, img_size) 130 | else: 131 | bbx_pred, cls_pred, obj_pred = None, None, None 132 | 133 | # Segmentation part 134 | # if do_loss: 135 | # sem_loss, conf_mat, sem_pred,sem_logits,sem_logits_low_res, sem_pred_low_res, sem_feats =\ 136 | # self.semantic_seg_algo.training(self.sem_head, x, sem, valid_size, img_size) 137 | # elif do_prediction: 138 | # sem_pred,sem_feats,_ = self.semantic_seg_algo.inference(self.sem_head, x, valid_size, img_size) 139 | # sem_loss, conf_mat = None, None 140 | # else: 141 | # sem_loss, conf_mat, sem_pred, sem_feats = None, None, None, None 142 | 143 | grs = self._convert_to_bb(bbx_pred, angle_pred=obj_pred) 144 | pos_img, ang_img, width_img = grs.draw((output_size, output_size)) 145 | pos = self._numpy_to_torch(pos_img).to(device) 146 | cos = self._numpy_to_torch(np.cos(2 * ang_img)).to(device) 147 | sin = self._numpy_to_torch(np.sin(2 * ang_img)).to(device) 148 | width = self._numpy_to_torch(width_img).to(device) 149 | 150 | pos.requires_grad_(True) 151 | cos.requires_grad_(True) 152 | sin.requires_grad_(True) 153 | width.requires_grad_(True) 154 | return pos, cos, sin, width 155 | # Prepare outputs 156 | # loss = OrderedDict([ 157 | # ("obj_loss", obj_loss), 158 | # ("bbx_loss", bbx_loss), 159 | # ("roi_cls_loss", roi_cls_loss), 160 | # ("roi_bbx_loss", roi_bbx_loss), 161 | # ("sem_loss", sem_loss) 162 | # ]) 163 | # pred = OrderedDict([ 164 | # ("bbx_pred", bbx_pred), 165 | # ("cls_pred", cls_pred), 166 | # ("obj_pred", obj_pred), 167 | # ("sem_pred", sem_pred) 168 | # ]) 169 | # conf = OrderedDict([ 170 | # ("sem_conf", conf_mat) 171 | # ]) 172 | # return loss, pred, conf 173 | -------------------------------------------------------------------------------- /models/grasp_det_seg/models/resnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import torch.nn as nn 6 | # from inplace_abn import ABN 7 | 8 | from models.grasp_det_seg.modules.misc import GlobalAvgPool2d 9 | from models.grasp_det_seg.modules.residual import ResidualBlock 10 | from models.grasp_det_seg.utils.misc import try_index 11 | 12 | 13 | class ResNet(nn.Module): 14 | """Standard residual network 15 | 16 | Parameters 17 | ---------- 18 | structure : list of int 19 | Number of residual blocks in each of the four modules of the network 20 | bottleneck : bool 21 | If `True` use "bottleneck" residual blocks with 3 convolutions, otherwise use standard blocks 22 | norm_act : callable or list of callable 23 | Function to create normalization / activation Module. 
If a list is passed it should have four elements, one for 24 | each module of the network 25 | classes : int 26 | If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end 27 | of the network 28 | dilation : int or list of int 29 | List of dilation factors for the four modules of the network, or `1` to ignore dilation 30 | dropout : list of float or None 31 | If present, specifies the amount of dropout to apply in the blocks of each of the four modules of the network 32 | caffe_mode : bool 33 | If `True`, use bias in the first convolution for compatibility with the Caffe pretrained models 34 | """ 35 | 36 | def __init__(self, 37 | structure, 38 | bottleneck, 39 | norm_act=nn.BatchNorm2d, 40 | classes=0, 41 | dilation=1, 42 | dropout=None, 43 | caffe_mode=False): 44 | super(ResNet, self).__init__() 45 | self.structure = structure 46 | self.bottleneck = bottleneck 47 | self.dilation = dilation 48 | self.dropout = dropout 49 | self.caffe_mode = caffe_mode 50 | 51 | if len(structure) != 4: 52 | raise ValueError("Expected a structure with four values") 53 | if dilation != 1 and len(dilation) != 4: 54 | raise ValueError("If dilation is not 1 it must contain four values") 55 | 56 | # Initial layers 57 | layers = [ 58 | ("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=caffe_mode)), 59 | ("bn1", nn.BatchNorm2d(64)) 60 | ] 61 | if try_index(dilation, 0) == 1: 62 | layers.append(("pool1", nn.MaxPool2d(3, stride=2, padding=1))) 63 | self.mod1 = nn.Sequential(OrderedDict(layers)) 64 | 65 | # Groups of residual blocks 66 | in_channels = 64 67 | if self.bottleneck: 68 | channels = (64, 64, 256) 69 | else: 70 | channels = (64, 64) 71 | for mod_id, num in enumerate(structure): 72 | mod_dropout = None 73 | if self.dropout is not None: 74 | if self.dropout[mod_id] is not None: 75 | mod_dropout = partial(nn.Dropout, p=self.dropout[mod_id]) 76 | 77 | # Create blocks for module 78 | blocks = [] 79 | for block_id in range(num): 80 | stride, dil = self._stride_dilation(dilation, mod_id, block_id) 81 | blocks.append(( 82 | "block%d" % (block_id + 1), 83 | ResidualBlock(in_channels, channels, norm_act=norm_act, 84 | stride=stride, dilation=dil, dropout=mod_dropout) 85 | )) 86 | 87 | # Update channels and p_keep 88 | in_channels = channels[-1] 89 | 90 | # Create module 91 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 92 | 93 | # Double the number of channels for the next module 94 | channels = [c * 2 for c in channels] 95 | 96 | # Pooling and predictor 97 | if classes != 0: 98 | self.classifier = nn.Sequential(OrderedDict([ 99 | ("avg_pool", GlobalAvgPool2d()), 100 | ("fc", nn.Linear(in_channels, classes)) 101 | ])) 102 | 103 | @staticmethod 104 | def _stride_dilation(dilation, mod_id, block_id): 105 | d = try_index(dilation, mod_id) 106 | s = 2 if d == 1 and block_id == 0 and mod_id > 0 else 1 107 | return s, d 108 | 109 | def forward(self, x): 110 | outs = OrderedDict() 111 | 112 | outs["mod1"] = self.mod1(x) 113 | outs["mod2"] = self.mod2(outs["mod1"]) 114 | outs["mod3"] = self.mod3(outs["mod2"]) 115 | outs["mod4"] = self.mod4(outs["mod3"]) 116 | outs["mod5"] = self.mod5(outs["mod4"]) 117 | 118 | if hasattr(self, "classifier"): 119 | outs["classifier"] = self.classifier(outs["mod5"]) 120 | 121 | return outs 122 | 123 | 124 | _NETS = { 125 | "18": {"structure": [2, 2, 2, 2], "bottleneck": False}, 126 | "34": {"structure": [3, 4, 6, 3], "bottleneck": False}, 127 | "50": {"structure": [3, 4, 6, 3], "bottleneck": True}, 128 | 
"101": {"structure": [3, 4, 23, 3], "bottleneck": True}, 129 | "152": {"structure": [3, 8, 36, 3], "bottleneck": True}, 130 | } 131 | 132 | __all__ = [] 133 | for name, params in _NETS.items(): 134 | net_name = "net_resnet" + name 135 | setattr(sys.modules[__name__], net_name, partial(ResNet, **params)) 136 | __all__.append(net_name) 137 | -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/models/grasp_det_seg/modules/__init__.py -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/fpn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | # from inplace_abn import ABN 6 | 7 | 8 | class FPN(nn.Module): 9 | """Feature Pyramid Network module 10 | 11 | Parameters 12 | ---------- 13 | in_channels : sequence of int 14 | Number of feature channels in each of the input feature levels 15 | out_channels : int 16 | Number of output feature channels (same for each level) 17 | extra_scales : int 18 | Number of extra low-resolution scales 19 | norm_act : callable 20 | Function to create normalization + activation modules 21 | interpolation : str 22 | Interpolation mode to use when up-sampling, see `torch.nn.functional.interpolate` 23 | """ 24 | 25 | def __init__(self, in_channels, out_channels=256, extra_scales=0, norm_act=nn.BatchNorm2d, interpolation="nearest"): 26 | super(FPN, self).__init__() 27 | self.interpolation = interpolation 28 | 29 | # Lateral connections and output convolutions 30 | self.lateral = nn.ModuleList([ 31 | self._make_lateral(channels, out_channels, norm_act) for channels in in_channels 32 | ]) 33 | self.output = nn.ModuleList([ 34 | self._make_output(out_channels, norm_act) for _ in in_channels 35 | ]) 36 | 37 | if extra_scales > 0: 38 | self.extra = nn.ModuleList([ 39 | self._make_extra(in_channels[-1] if i == 0 else out_channels, out_channels, norm_act) 40 | for i in range(extra_scales) 41 | ]) 42 | 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | gain = nn.init.calculate_gain("relu", self.lateral[0].bn.weight) 47 | for mod in self.modules(): 48 | if isinstance(mod, nn.Conv2d): 49 | nn.init.xavier_normal_(mod.weight, gain) 50 | elif isinstance(mod, nn.BatchNorm2d): 51 | nn.init.constant_(mod.weight, 1.) 52 | if hasattr(mod, "bias") and mod.bias is not None: 53 | nn.init.constant_(mod.bias, 0.) 
54 | 55 | @staticmethod 56 | def _make_lateral(input_channels, hidden_channels, norm_act): 57 | return nn.Sequential(OrderedDict([ 58 | ("conv", nn.Conv2d(input_channels, hidden_channels, 1, bias=False)), 59 | ("bn", nn.BatchNorm2d(hidden_channels)) 60 | ])) 61 | 62 | @staticmethod 63 | def _make_output(channels, norm_act): 64 | return nn.Sequential(OrderedDict([ 65 | ("conv", nn.Conv2d(channels, channels, 3, padding=1, bias=False)), 66 | ("bn", nn.BatchNorm2d(channels)) 67 | ])) 68 | 69 | @staticmethod 70 | def _make_extra(input_channels, out_channels, norm_act): 71 | return nn.Sequential(OrderedDict([ 72 | ("conv", nn.Conv2d(input_channels, out_channels, 3, stride=2, padding=1, bias=False)), 73 | ("bn", norm_act(out_channels)) 74 | ])) 75 | 76 | def forward(self, xs): 77 | """Feature Pyramid Network module 78 | 79 | Parameters 80 | ---------- 81 | xs : sequence of torch.Tensor 82 | The input feature maps, tensors with shapes N x C_i x H_i x W_i 83 | 84 | Returns 85 | ------- 86 | ys : sequence of torch.Tensor 87 | The output feature maps, tensors with shapes N x K x H_i x W_i 88 | """ 89 | ys = [] 90 | interp_params = {"mode": self.interpolation} 91 | if self.interpolation == "bilinear": 92 | interp_params["align_corners"] = False 93 | 94 | # Build pyramid 95 | for x_i, lateral_i in zip(xs[::-1], self.lateral[::-1]): 96 | x_i = lateral_i(x_i) 97 | if len(ys) > 0: 98 | x_i = x_i + functional.interpolate(ys[0], size=x_i.shape[-2:], **interp_params) 99 | ys.insert(0, x_i) 100 | 101 | # Compute outputs 102 | ys = [output_i(y_i) for y_i, output_i in zip(ys, self.output)] 103 | 104 | # Compute extra outputs if necessary 105 | if hasattr(self, "extra"): 106 | y = xs[-1] 107 | for extra_i in self.extra: 108 | y = extra_i(y) 109 | ys.append(y) 110 | 111 | return ys 112 | 113 | 114 | class FPNBody(nn.Module): 115 | """Wrapper for a backbone network and an FPN module 116 | 117 | Parameters 118 | ---------- 119 | backbone : torch.nn.Module 120 | Backbone network, which takes a batch of images and produces a dictionary of intermediate features 121 | fpn : torch.nn.Module 122 | FPN module, which takes a list of intermediate features and produces a list of outputs 123 | fpn_inputs : iterable 124 | An iterable producing the names of the intermediate features to take from the backbone's output and pass 125 | to the FPN 126 | """ 127 | 128 | def __init__(self, backbone, fpn, fpn_inputs=()): 129 | super(FPNBody, self).__init__() 130 | self.fpn_inputs = fpn_inputs 131 | 132 | self.backbone = backbone 133 | self.fpn = fpn 134 | 135 | def forward(self, x): 136 | x = self.backbone(x) 137 | xs = [x[fpn_input] for fpn_input in self.fpn_inputs] 138 | return self.fpn(xs) 139 | -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn import FPNROIHead, FPNSemanticHeadDeeplab 2 | from .rpn import RPNHead 3 | -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/heads/fpn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as functional 6 | # from inplace_abn import ABN 7 | from models.grasp_det_seg.utils.misc import try_index 8 | 9 | class FPNROIHead(nn.Module): 10 | """ROI head module for FPN 11 | """ 12 | 13 | def __init__(self, in_channels, 
classes, roi_size, hidden_channels=1024, norm_act=nn.BatchNorm2d): 14 | super(FPNROIHead, self).__init__() 15 | 16 | self.fc = nn.Sequential(OrderedDict([ 17 | ("fc1", nn.Linear(int(roi_size[0] * roi_size[1] * in_channels / 4), hidden_channels, bias=False)), 18 | ("bn1",nn.BatchNorm1d(hidden_channels)), 19 | ("fc2", nn.Linear(hidden_channels, hidden_channels, bias=False)), 20 | ("bn2", nn.BatchNorm1d(hidden_channels)) 21 | ])) 22 | self.roi_cls = nn.Linear(hidden_channels, classes["thing"] + 1) 23 | self.roi_bbx = nn.Linear(hidden_channels, classes["thing"] * 4) 24 | 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | gain = nn.init.calculate_gain("relu", self.fc.bn1.weight) 29 | 30 | for name, mod in self.named_modules(): 31 | if isinstance(mod, nn.Linear): 32 | if "roi_cls" in name: 33 | nn.init.xavier_normal_(mod.weight, .01) 34 | elif "roi_bbx" in name: 35 | nn.init.xavier_normal_(mod.weight, .001) 36 | else: 37 | nn.init.xavier_normal_(mod.weight, gain) 38 | elif isinstance(mod, nn.BatchNorm2d): 39 | nn.init.constant_(mod.weight, 1.) 40 | 41 | if hasattr(mod, "bias") and mod.bias is not None: 42 | nn.init.constant_(mod.bias, 0.) 43 | 44 | def forward(self, x): 45 | """ROI head module for FPN 46 | """ 47 | x = functional.avg_pool2d(x, 2) 48 | 49 | # Run head 50 | x = self.fc(x.view(x.size(0), -1)) 51 | return self.roi_cls(x), self.roi_bbx(x).view(x.size(0), -1, 4) 52 | 53 | class FPNSemanticHeadDeeplab(nn.Module): 54 | """Semantic segmentation head for FPN-style networks, extending Deeplab v3 for FPN bodies""" 55 | 56 | class _MiniDL(nn.Module): 57 | def __init__(self, in_channels, out_channels, dilation, pooling_size, norm_act): 58 | super(FPNSemanticHeadDeeplab._MiniDL, self).__init__() 59 | self.pooling_size = pooling_size 60 | 61 | self.conv1_3x3 = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False) 62 | self.conv1_dil = nn.Conv2d(in_channels, out_channels, 3, dilation=dilation, padding=dilation, bias=False) 63 | self.conv1_glb = nn.Conv2d(in_channels, out_channels, 1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(out_channels * 3) 65 | 66 | self.conv2 = nn.Conv2d(out_channels * 3, out_channels, 1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(out_channels) 68 | 69 | def _global_pooling(self, x): 70 | pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]), 71 | min(try_index(self.pooling_size, 1), x.shape[3])) 72 | padding = ( 73 | (pooling_size[1] - 1) // 2, 74 | (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1, 75 | (pooling_size[0] - 1) // 2, 76 | (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1 77 | ) 78 | 79 | pool = functional.avg_pool2d(x, pooling_size, stride=1) 80 | pool = functional.pad(pool, pad=padding, mode="replicate") 81 | return pool 82 | 83 | def forward(self, x): 84 | x = torch.cat([ 85 | self.conv1_3x3(x), 86 | self.conv1_dil(x), 87 | self.conv1_glb(self._global_pooling(x)), 88 | ], dim=1) 89 | x = self.bn1(x) 90 | x = self.conv2(x) 91 | x = self.bn2(x) 92 | return x 93 | 94 | def __init__(self, 95 | in_channels, 96 | min_level, 97 | levels, 98 | num_classes, 99 | hidden_channels=128, 100 | dilation=6, 101 | pooling_size=(64, 64), 102 | norm_act=nn.BatchNorm2d, 103 | interpolation="bilinear"): 104 | super(FPNSemanticHeadDeeplab, self).__init__() 105 | self.min_level = min_level 106 | self.levels = levels 107 | self.interpolation = interpolation 108 | 109 | self.output = nn.ModuleList([ 110 | self._MiniDL(in_channels, hidden_channels, dilation, 
pooling_size, norm_act) for _ in range(levels) 111 | ]) 112 | self.conv_sem = nn.Conv2d(hidden_channels * levels, num_classes, 1) 113 | 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | gain = nn.init.calculate_gain("relu", self.output[0].bn1.weight) 118 | for name, mod in self.named_modules(): 119 | if isinstance(mod, nn.Conv2d): 120 | if "conv_sem" not in name: 121 | nn.init.xavier_normal_(mod.weight, gain) 122 | else: 123 | nn.init.xavier_normal_(mod.weight, .1) 124 | elif isinstance(mod, nn.BatchNorm2d): 125 | nn.init.constant_(mod.weight, 1.) 126 | if hasattr(mod, "bias") and mod.bias is not None: 127 | nn.init.constant_(mod.bias, 0.) 128 | 129 | def forward(self, xs): 130 | xs = xs[self.min_level:self.min_level + self.levels] 131 | 132 | ref_size = xs[0].shape[-2:] 133 | interp_params = {"mode": self.interpolation} 134 | if self.interpolation == "bilinear": 135 | interp_params["align_corners"] = False 136 | 137 | for i, output in enumerate(self.output): 138 | xs[i] = output(xs[i]) 139 | if i > 0: 140 | xs[i] = functional.interpolate(xs[i], size=ref_size, **interp_params) 141 | 142 | xs_feats = torch.cat(xs, dim=1) 143 | xs = self.conv_sem(xs_feats) 144 | 145 | return xs,xs_feats 146 | -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/heads/rpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | # from inplace_abn import ABN 4 | 5 | class RPNHead(nn.Module): 6 | """RPN head module 7 | 8 | Parameters 9 | ---------- 10 | in_channels : int 11 | Number of channels in the input feature map 12 | num_anchors : int 13 | Number of anchors predicted at each spatial location 14 | stride : int 15 | Stride of the internal convolutions 16 | hidden_channels : int 17 | Number of channels in the internal intermediate feature map 18 | norm_act : callable 19 | Function to create normalization + activation modules 20 | """ 21 | 22 | def __init__(self, in_channels, num_anchors, stride=1, hidden_channels=255, norm_act=nn.BatchNorm2d): 23 | super(RPNHead, self).__init__() 24 | 25 | self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, padding=1, stride=stride, bias=False) 26 | self.bn1 = nn.BatchNorm2d(hidden_channels) 27 | self.conv_obj = nn.Conv2d(hidden_channels, num_anchors, 1) 28 | self.conv_bbx = nn.Conv2d(hidden_channels, num_anchors * 4, 1) 29 | 30 | self.reset_parameters() 31 | 32 | def reset_parameters(self): 33 | activation = "relu" 34 | activation_param = self.bn1.weight 35 | 36 | # Hidden convolution 37 | gain = nn.init.calculate_gain(activation, activation_param) 38 | nn.init.xavier_normal_(self.conv1.weight, gain) 39 | self.bn1.reset_parameters() 40 | 41 | # Classifiers 42 | for m in [self.conv_obj, self.conv_bbx]: 43 | nn.init.xavier_normal_(m.weight, .01) 44 | nn.init.constant_(m.bias, 0) 45 | 46 | def forward(self, x): 47 | """RPN head module 48 | """ 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | return self.conv_obj(x), self.conv_bbx(x) 52 | -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.grasp_det_seg.utils.parallel import PackedSequence 4 | 5 | 6 | def smooth_l1(x1, x2, sigma): 7 | """Smooth L1 loss""" 8 | sigma2 = sigma ** 2 9 | 10 | diff = x1 - x2 11 | abs_diff = diff.abs() 12 | 13 | mask = (abs_diff.detach() < (1. 
/ sigma2)).float() 14 | return mask * (sigma2 / 2.) * diff ** 2 + (1 - mask) * (abs_diff - 0.5 / sigma2) 15 | 16 | 17 | def ohem_loss(loss, ohem=None): 18 | if isinstance(loss, torch.Tensor): 19 | loss = loss.view(loss.size(0), -1) 20 | if ohem is None: 21 | return loss.mean() 22 | 23 | top_k = min(max(int(ohem * loss.size(1)), 1), loss.size(1)) 24 | if top_k != loss.size(1): 25 | loss, _ = loss.topk(top_k, dim=1) 26 | 27 | return loss.mean() 28 | elif isinstance(loss, PackedSequence): 29 | if ohem is None: 30 | return sum(loss_i.mean() for loss_i in loss) / len(loss) 31 | 32 | loss_out = loss.data.new_zeros(()) 33 | for loss_i in loss: 34 | loss_i = loss_i.view(-1) 35 | 36 | top_k = min(max(int(ohem * loss_i.numel()), 1), loss_i.numel()) 37 | if top_k != loss_i.numel(): 38 | loss_i, _ = loss_i.topk(top_k, dim=0) 39 | 40 | loss_out += loss_i.mean() 41 | 42 | return loss_out / len(loss) 43 | -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | # from inplace_abn import ABN 6 | 7 | 8 | class GlobalAvgPool2d(nn.Module): 9 | """Global average pooling over the input's spatial dimensions""" 10 | 11 | def __init__(self): 12 | super(GlobalAvgPool2d, self).__init__() 13 | 14 | def forward(self, inputs): 15 | in_size = inputs.size() 16 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 17 | 18 | 19 | class Interpolate(nn.Module): 20 | """nn.Module wrapper to nn.functional.interpolate""" 21 | 22 | def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): 23 | super(Interpolate, self).__init__() 24 | self.size = size 25 | self.scale_factor = scale_factor 26 | self.mode = mode 27 | self.align_corners = align_corners 28 | 29 | def forward(self, x): 30 | return functional.interpolate(x, self.size, self.scale_factor, self.mode, self.align_corners) 31 | 32 | 33 | class ActivatedAffine(nn.BatchNorm2d): 34 | """Drop-in replacement for ABN which performs inference-mode BN + activation""" 35 | 36 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", 37 | activation_param=0.01): 38 | super(ActivatedAffine, self).__init__(num_features, eps, momentum, affine, activation, activation_param) 39 | 40 | @staticmethod 41 | def _broadcast_shape(x): 42 | out_size = [] 43 | for i, s in enumerate(x.size()): 44 | if i != 1: 45 | out_size.append(1) 46 | else: 47 | out_size.append(s) 48 | return out_size 49 | 50 | def forward(self, x): 51 | inv_var = torch.rsqrt(self.running_var + self.eps) 52 | if self.affine: 53 | alpha = self.weight * inv_var 54 | beta = self.bias - self.running_mean * alpha 55 | else: 56 | alpha = inv_var 57 | beta = - self.running_mean * alpha 58 | 59 | x.mul_(alpha.view(self._broadcast_shape(x))) 60 | x.add_(beta.view(self._broadcast_shape(x))) 61 | 62 | if self.activation == "relu": 63 | return functional.relu(x, inplace=True) 64 | elif self.activation == "leaky_relu": 65 | return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True) 66 | elif self.activation == "elu": 67 | return functional.elu(x, alpha=self.activation_param, inplace=True) 68 | elif self.activation == "identity": 69 | return x 70 | else: 71 | raise RuntimeError("Unknown activation function {}".format(self.activation)) 72 | 73 | 74 | class ActivatedGroupNorm(nn.BatchNorm2d): 75 | """GroupNorm 
+ activation function compatible with the ABN interface""" 76 | 77 | def __init__(self, num_channels, num_groups, eps=1e-5, affine=True, activation="leaky_relu", activation_param=0.01): 78 | super(ActivatedGroupNorm, self).__init__(num_channels, eps, affine=affine, activation=activation, 79 | activation_param=activation_param) 80 | self.num_groups = num_groups 81 | 82 | # Delete running mean and var since they are not used here 83 | delattr(self, "running_mean") 84 | delattr(self, "running_var") 85 | 86 | def reset_parameters(self): 87 | if self.affine: 88 | nn.init.constant_(self.weight, 1) 89 | nn.init.constant_(self.bias, 0) 90 | 91 | def forward(self, x): 92 | x = functional.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 93 | 94 | if self.activation == "relu": 95 | return functional.relu(x, inplace=True) 96 | elif self.activation == "leaky_relu": 97 | return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True) 98 | elif self.activation == "elu": 99 | return functional.elu(x, alpha=self.activation_param, inplace=True) 100 | elif self.activation == "identity": 101 | return x 102 | else: 103 | raise RuntimeError("Unknown activation function {}".format(self.activation)) 104 | -------------------------------------------------------------------------------- /models/grasp_det_seg/modules/residual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | # from inplace_abn import ABN 6 | 7 | 8 | class ResidualBlock(nn.Module): 9 | """Configurable residual block 10 | 11 | Parameters 12 | ---------- 13 | in_channels : int 14 | Number of input channels. 15 | channels : list of int 16 | Number of channels in the internal feature maps. Can either have two or three elements: if three construct 17 | a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then 18 | `3 x 3` then `1 x 1` convolutions. 19 | stride : int 20 | Stride of the first `3 x 3` convolution 21 | dilation : int 22 | Dilation to apply to the `3 x 3` convolutions. 23 | groups : int 24 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with 25 | bottleneck blocks. 26 | norm_act : callable 27 | Function to create normalization / activation Module. 28 | dropout: callable 29 | Function to create Dropout Module. 
30 | """ 31 | 32 | def __init__(self, 33 | in_channels, 34 | channels, 35 | stride=1, 36 | dilation=1, 37 | groups=1, 38 | norm_act=nn.BatchNorm2d, 39 | dropout=None): 40 | super(ResidualBlock, self).__init__() 41 | 42 | # Check parameters for inconsistencies 43 | if len(channels) != 2 and len(channels) != 3: 44 | raise ValueError("channels must contain either two or three values") 45 | if len(channels) == 2 and groups != 1: 46 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 47 | 48 | is_bottleneck = len(channels) == 3 49 | need_proj_conv = stride != 1 or in_channels != channels[-1] 50 | 51 | if not is_bottleneck: 52 | bn2 = norm_act(channels[1]) 53 | bn2.activation = "identity" 54 | layers = [ 55 | ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False, 56 | dilation=dilation)), 57 | ("bn1", norm_act(channels[0])), 58 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, 59 | dilation=dilation)), 60 | ("bn2", bn2) 61 | ] 62 | if dropout is not None: 63 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 64 | else: 65 | bn3 = nn.BatchNorm2d(channels[2]) 66 | # bn3.activation = "identity" 67 | layers = [ 68 | ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)), 69 | ("bn1", nn.BatchNorm2d(channels[0])), 70 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=stride, padding=dilation, bias=False, 71 | groups=groups, dilation=dilation)), 72 | ("bn2", nn.BatchNorm2d(channels[1])), 73 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)), 74 | ("bn3", bn3) 75 | ] 76 | if dropout is not None: 77 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 78 | self.convs = nn.Sequential(OrderedDict(layers)) 79 | 80 | if need_proj_conv: 81 | self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 82 | self.proj_bn = nn.BatchNorm2d(channels[-1]) 83 | self.proj_bn.activation = "identity" 84 | 85 | def forward(self, x): 86 | if hasattr(self, "proj_conv"): 87 | residual = self.proj_conv(x) 88 | residual = self.proj_bn(residual) 89 | else: 90 | residual = x 91 | 92 | x = self.convs(x) + residual 93 | 94 | return x 95 | 96 | # if self.convs.bn1.activation == "relu": 97 | # return functional.relu(x, inplace=True) 98 | # elif self.convs.bn1.activation == "leaky_relu": 99 | # return functional.leaky_relu(x, negative_slope=self.convs.bn1.activation_param, inplace=True) 100 | # elif self.convs.bn1.activation == "elu": 101 | # return functional.elu(x, alpha=self.convs.bn1.activation_param, inplace=True) 102 | # elif self.convs.bn1.activation == "identity": 103 | # return x 104 | # else: 105 | # raise RuntimeError("Unknown activation function {}".format(self.activation)) 106 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/models/grasp_det_seg/utils/__init__.py -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/bbx/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbx import * 2 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/bbx/_backend.pyi: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_boxes(mask: torch.Tensor, n_instances: int) -> torch.Tensor: ... 5 | 6 | 7 | def mask_count(bbx: torch.Tensor, int_mask: torch.Tensor) -> torch.Tensor: ... 8 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from math import log10 3 | from os import path 4 | 5 | from .meters import AverageMeter 6 | 7 | _NAME = "GraspDetSeg_CNN" 8 | 9 | 10 | def _current_total_formatter(current, total): 11 | width = int(log10(total)) + 1 12 | return ("[{:" + str(width) + "}/{:" + str(width) + "}]").format(current, total) 13 | 14 | 15 | def init(log_dir, name): 16 | logger = logging.getLogger(_NAME) 17 | logger.setLevel(logging.DEBUG) 18 | 19 | # Set console logging 20 | console_handler = logging.StreamHandler() 21 | console_formatter = logging.Formatter(fmt="%(asctime)s %(message)s", datefmt="%H:%M:%S") 22 | console_handler.setFormatter(console_formatter) 23 | console_handler.setLevel(logging.DEBUG) 24 | logger.addHandler(console_handler) 25 | 26 | # Setup file logging 27 | file_handler = logging.FileHandler(path.join(log_dir, name + ".log"), mode="w") 28 | file_formatter = logging.Formatter(fmt="%(levelname).1s %(asctime)s %(message)s", datefmt="%y-%m-%d %H:%M:%S") 29 | file_handler.setFormatter(file_formatter) 30 | file_handler.setLevel(logging.INFO) 31 | logger.addHandler(file_handler) 32 | 33 | 34 | def get_logger(): 35 | return logging.getLogger(_NAME) 36 | 37 | 38 | def iteration(summary, phase, global_step, epoch, num_epochs, step, num_steps, values, multiple_lines=False): 39 | logger = get_logger() 40 | 41 | # Build message and write summary 42 | msg = _current_total_formatter(epoch, num_epochs) + " " + _current_total_formatter(step, num_steps) 43 | for k, v in values.items(): 44 | if isinstance(v, AverageMeter): 45 | msg += ("\n" if multiple_lines else "") + "\t{}={:.3f} ({:.3f})".format(k, v.value.item(), v.mean.item()) 46 | if summary is not None: 47 | summary.add_scalar("{}/{}".format(phase, k), v.value.item(), global_step) 48 | else: 49 | msg += ("\n" if multiple_lines else "") + "\t{}={:.3f}".format(k, v) 50 | if summary is not None: 51 | summary.add_scalar("{}/{}".format(phase, k), v, global_step) 52 | 53 | # Write log 54 | logger.info(msg) 55 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/meters.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | class Meter: 7 | def __init__(self): 8 | self._states = OrderedDict() 9 | 10 | def register_state(self, name, tensor): 11 | if name not in self._states and isinstance(tensor, torch.Tensor): 12 | self._states[name] = tensor 13 | 14 | def __getattr__(self, item): 15 | if "_states" in self.__dict__: 16 | _states = self.__dict__["_states"] 17 | if item in _states: 18 | return _states[item] 19 | return self.__dict__[item] 20 | 21 | def reset(self): 22 | for state in self._states.values(): 23 | state.zero_() 24 | 25 | def state_dict(self): 26 | return dict(self._states) 27 | 28 | def load_state_dict(self, state_dict): 29 | for k, v in state_dict.items(): 30 | if k in self._states: 31 | self._states[k].copy_(v) 32 | else: 33 | raise KeyError("Unexpected key {} in state dict when loading {} from state
dict" 34 | .format(k, self.__class__.__name__)) 35 | 36 | 37 | class ConstantMeter(Meter): 38 | def __init__(self, shape): 39 | super(ConstantMeter, self).__init__() 40 | self.register_state("last", torch.zeros(shape, dtype=torch.float32)) 41 | 42 | def update(self, value): 43 | self.last.copy_(value) 44 | 45 | @property 46 | def value(self): 47 | return self.last 48 | 49 | 50 | class AverageMeter(ConstantMeter): 51 | def __init__(self, shape, momentum=1.): 52 | super(AverageMeter, self).__init__(shape) 53 | self.register_state("sum", torch.zeros(shape, dtype=torch.float32)) 54 | self.register_state("count", torch.tensor(0, dtype=torch.float32)) 55 | self.momentum = momentum 56 | 57 | def update(self, value): 58 | super(AverageMeter, self).update(value) 59 | self.sum.mul_(self.momentum).add_(value) 60 | self.count.mul_(self.momentum).add_(1.) 61 | 62 | @property 63 | def mean(self): 64 | if self.count.item() == 0: 65 | return torch.tensor(0.) 66 | else: 67 | return self.sum / self.count.clamp(min=1) 68 | 69 | 70 | class ConfusionMatrixMeter(AverageMeter): 71 | def __init__(self, num_classes, momentum=1.): 72 | super(ConfusionMatrixMeter, self).__init__((num_classes, num_classes), momentum) 73 | 74 | @property 75 | def iou(self): 76 | mean_conf = self.mean 77 | return mean_conf.diag() / (mean_conf.sum(dim=0) + mean_conf.sum(dim=1) - mean_conf.diag()) 78 | 79 | @property 80 | def precision(self): 81 | return self.mean.diag() * torch.clamp(1. / self.mean.sum(dim=0), max=1.) 82 | 83 | @property 84 | def recall(self): 85 | return self.mean.diag() * torch.clamp(1. / self.mean.sum(dim=1), max=1.) 86 | 87 | 88 | class PanopticMeter(AverageMeter): 89 | def panoptic(self): 90 | return None if self.sum is None else \ 91 | self.sum[0] / (self.sum[1] + 0.5 * self.sum[2] + 0.5 * self.sum[3]) 92 | 93 | @property 94 | def avg(self): 95 | panoptic = self.panoptic() 96 | return 0 if panoptic is None else panoptic.mean() 97 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/misc.py: -------------------------------------------------------------------------------- 1 | import io 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | # from inplace_abn import InPlaceABN, InPlaceABNSync, ABN 9 | 10 | from models.grasp_det_seg.modules.misc import ActivatedAffine, ActivatedGroupNorm 11 | from . 
import scheduler as lr_scheduler 12 | 13 | NORM_LAYERS = [nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d, nn.GroupNorm] 14 | OTHER_LAYERS = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d] 15 | 16 | 17 | class Empty(Exception): 18 | """Exception to facilitate handling of empty predictions, annotations etc.""" 19 | pass 20 | 21 | 22 | def try_index(scalar_or_list, i): 23 | try: 24 | return scalar_or_list[i] 25 | except TypeError: 26 | return scalar_or_list 27 | 28 | 29 | def config_to_string(config): 30 | with io.StringIO() as sio: 31 | config.write(sio) 32 | config_str = sio.getvalue() 33 | return config_str 34 | 35 | 36 | def scheduler_from_config(scheduler_config, optimizer, epoch_length): 37 | assert scheduler_config["type"] in ("linear", "step", "poly", "multistep") 38 | 39 | params = scheduler_config.getstruct("params") 40 | 41 | if scheduler_config["type"] == "linear": 42 | if scheduler_config["update_mode"] == "batch": 43 | count = epoch_length * scheduler_config.getint("epochs") 44 | else: 45 | count = scheduler_config.getint("epochs") 46 | 47 | beta = float(params["from"]) 48 | alpha = float(params["to"] - beta) / count 49 | 50 | scheduler = lr_scheduler.LambdaLR(optimizer, lambda it: it * alpha + beta) 51 | elif scheduler_config["type"] == "step": 52 | scheduler = lr_scheduler.StepLR(optimizer, params["step_size"], params["gamma"]) 53 | elif scheduler_config["type"] == "poly": 54 | if scheduler_config["update_mode"] == "batch": 55 | count = epoch_length * scheduler_config.getint("epochs") 56 | else: 57 | count = scheduler_config.getint("epochs") 58 | scheduler = lr_scheduler.LambdaLR(optimizer, lambda it: (1 - float(it) / count) ** params["gamma"]) 59 | elif scheduler_config["type"] == "multistep": 60 | scheduler = lr_scheduler.MultiStepLR(optimizer, params["milestones"], params["gamma"]) 61 | else: 62 | raise ValueError("Unrecognized scheduler type {}, valid options: 'linear', 'step', 'poly', 'multistep'" 63 | .format(scheduler_config["type"])) 64 | 65 | if scheduler_config.getint("burn_in_steps") != 0: 66 | scheduler = lr_scheduler.BurnInLR(scheduler, 67 | scheduler_config.getint("burn_in_steps"), 68 | scheduler_config.getfloat("burn_in_start")) 69 | 70 | return scheduler 71 | 72 | 73 | def norm_act_from_config(body_config): 74 | """Make normalization + activation function from configuration 75 | 76 | Available normalization modes are: 77 | - `bn`: Standard In-Place Batch Normalization 78 | - `syncbn`: Synchronized In-Place Batch Normalization 79 | - `syncbn+bn`: Synchronized In-Place Batch Normalization in the "static" part of the network, Standard In-Place 80 | Batch Normalization in the "dynamic" parts 81 | - `gn`: Group Normalization 82 | - `syncbn+gn`: Synchronized In-Place Batch Normalization in the "static" part of the network, Group Normalization 83 | in the "dynamic" parts 84 | - `off`: No normalization (preserve scale and bias parameters) 85 | 86 | The "static" part of the network includes the backbone, FPN and semantic segmentation components, while the 87 | "dynamic" part of the network includes the RPN, detection and instance segmentation components. Note that this 88 | distinction is due to historical reasons and for back-compatibility with the CVPR2019 pre-trained models. 
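As a minimal sketch (the field values below are illustrative, not the shipped defaults), a configuration section containing::

    normalization_mode = gn
    activation = leaky_relu
    activation_slope = 0.01
    gn_groups = 16

makes both returned factories build `ActivatedGroupNorm` modules, so that `norm_act_static(64)` is equivalent to `ActivatedGroupNorm(64, num_groups=16, activation="leaky_relu", activation_param=0.01)`.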
89 | 90 | Parameters 91 | ---------- 92 | body_config 93 | Configuration object containing the following fields: `normalization_mode`, `activation`, `activation_slope` 94 | and `gn_groups` 95 | 96 | Returns 97 | ------- 98 | norm_act_static : callable 99 | Function that returns norm_act modules for the static parts of the network 100 | norm_act_dynamic : callable 101 | Function that returns norm_act modules for the dynamic parts of the network 102 | """ 103 | mode = body_config["normalization_mode"] 104 | activation = body_config["activation"] 105 | slope = body_config.getfloat("activation_slope") 106 | groups = body_config.getint("gn_groups") 107 | 108 | if mode == "bn": 109 | norm_act_static = norm_act_dynamic = partial(nn.BatchNorm2d, activation=activation, activation_param=slope) 110 | elif mode == "syncbn": 111 | norm_act_static = norm_act_dynamic = partial(nn.BatchNorm2d, activation=activation, activation_param=slope) 112 | elif mode == "syncbn+bn": 113 | norm_act_static = partial(nn.BatchNorm2d, activation=activation, activation_param=slope) 114 | norm_act_dynamic = partial(nn.BatchNorm2d, activation=activation, activation_param=slope) 115 | elif mode == "gn": 116 | norm_act_static = norm_act_dynamic = partial( 117 | ActivatedGroupNorm, num_groups=groups, activation=activation, activation_param=slope) 118 | elif mode == "syncbn+gn": 119 | norm_act_static = partial(nn.BatchNorm2d, activation=activation, activation_param=slope) 120 | norm_act_dynamic = partial(ActivatedGroupNorm, num_groups=groups, activation=activation, activation_param=slope) 121 | elif mode == "off": 122 | norm_act_static = norm_act_dynamic = partial(ActivatedAffine, activation=activation, activation_param=slope) 123 | else: 124 | raise ValueError("Unrecognized normalization_mode {}, valid options: 'bn', 'syncbn', 'syncbn+bn', 'gn', " 125 | "'syncbn+gn', 'off'".format(mode)) 126 | 127 | return norm_act_static, norm_act_dynamic 128 | 129 | 130 | def freeze_params(module): 131 | """Freeze all parameters of the given module""" 132 | for p in module.parameters(): 133 | p.requires_grad_(False) 134 | 135 | 136 | def all_reduce_losses(losses): 137 | """Coalesced mean all reduce over a dictionary of 0-dimensional tensors""" 138 | names, values = [], [] 139 | for k, v in losses.items(): 140 | names.append(k) 141 | values.append(v) 142 | 143 | # Peform the actual coalesced all_reduce 144 | values = torch.cat([v.view(1) for v in values], dim=0) 145 | dist.all_reduce(values, dist.ReduceOp.SUM) 146 | values.div_(dist.get_world_size()) 147 | values = torch.chunk(values, values.size(0), dim=0) 148 | 149 | # Reconstruct the dictionary 150 | return OrderedDict((k, v.view(())) for k, v in zip(names, values)) 151 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/nms/__init__.py: -------------------------------------------------------------------------------- 1 | from .nms import nms 2 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/nms/_backend.pyi: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def nms(bbx: torch.Tensor, scores: torch.Tensor, threshold: float, n_max: int) -> torch.Tensor: ... 
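# A minimal usage sketch of the pure-PyTorch fallback defined in nms.py, which mirrors
# this interface (the boxes, scores and resulting indices are made-up illustrative values):
#
#   import torch
#   from models.grasp_det_seg.utils.nms import nms
#
#   bbx = torch.tensor([[0., 0., 10., 10.],
#                       [1., 1., 11., 11.],      # overlaps the first box with IoU > 0.5
#                       [20., 20., 30., 30.]])
#   scores = torch.tensor([0.9, 0.8, 0.7])
#   keep = nms(bbx, scores, threshold=0.5)       # -> tensor([0, 2])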
5 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/nms/nms.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torchvision 5 | from torch import Tensor 6 | from torchvision.extension import _assert_has_ops 7 | 8 | # from ..utils import _log_api_usage_once 9 | # from ._box_convert import _box_cxcywh_to_xyxy, _box_xywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xyxy_to_xywh 10 | # from ._utils import _upcast 11 | 12 | from torchvision.ops.boxes import box_iou 13 | 14 | def nms(bboxes: torch.Tensor, scores: torch.Tensor, threshold: float=0.5, n_max=-1) -> torch.Tensor: 15 | order = torch.argsort(-scores) 16 | indices = torch.arange(bboxes.shape[0], device=bboxes.device) 17 | keep = torch.ones_like(indices, dtype=torch.bool, device=bboxes.device) 18 | for i in indices: 19 | if keep[i]: 20 | bbox = bboxes[order[i]] 21 | iou = box_iou(bbox[None,...],(bboxes[order[i + 1:]]) * keep[i + 1:][...,None]) 22 | overlapped = torch.nonzero(iou > threshold) 23 | keep[overlapped[:, 1] + i + 1] = 0 24 | return order[keep] if n_max <= 0 else order[keep][:n_max] 25 | 26 | # def nms(bbx, scores, threshold=0.5, n_max=-1): 27 | # """Perform non-maxima suppression 28 | 29 | # Select up to n_max bounding boxes from bbx, giving priorities to bounding boxes with greater scores. Each selected 30 | # bounding box suppresses all other not yet selected boxes that intersect it by more than the given threshold. 31 | 32 | # Parameters 33 | # ---------- 34 | # bbx : torch.Tensor 35 | # A tensor of bounding boxes with shape N x 4 36 | # scores : torch.Tensor 37 | # A tensor of bounding box scores with shape N 38 | # threshold : float 39 | # The minimum iou value for a pair of bounding boxes to be considered a match 40 | # n_max : int 41 | # Maximum number of bounding boxes to select.
If n_max <= 0, keep all surviving boxes 42 | 43 | # Returns 44 | # ------- 45 | # selection : torch.Tensor 46 | # A tensor with the indices of the selected boxes 47 | 48 | # """ 49 | # selection = _backend.nms(bbx, scores, threshold, n_max) 50 | # return selection.to(device=bbx.device) 51 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import DistributedDataParallel 2 | from .packed_sequence import PackedSequence 3 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DistributedDataParallel as TorchDistributedDataParallel 2 | 3 | from .scatter_gather import scatter_kwargs, gather 4 | 5 | 6 | class DistributedDataParallel(TorchDistributedDataParallel): 7 | """`nn.parallel.DistributedDataParallel` extension which can handle `PackedSequence`s""" 8 | 9 | def scatter(self, inputs, kwargs, device_ids): 10 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 11 | 12 | def gather(self, outputs, output_device): 13 | return gather(outputs, output_device, dim=self.dim) 14 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/parallel/packed_sequence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _all_same(lst): 5 | return not lst or lst.count(lst[0]) == len(lst) 6 | 7 | 8 | class PackedSequence: 9 | def __init__(self, *args): 10 | if len(args) == 1 and isinstance(args[0], list): 11 | tensors = args[0] 12 | else: 13 | tensors = args 14 | 15 | # Check if all input are tensors of the same type and device 16 | for tensor in tensors: 17 | if tensor is not None and not isinstance(tensor, torch.Tensor): 18 | raise TypeError("All args must be tensors") 19 | if not _all_same([tensor.dtype for tensor in tensors if tensor is not None]): 20 | raise TypeError("All tensors must have the same type") 21 | if not _all_same([tensor.device for tensor in tensors if tensor is not None]): 22 | raise TypeError("All tensors must reside on the same device") 23 | self._tensors = tensors 24 | 25 | # Check useful properties of the sequence 26 | self._compatible = _all_same([tensor.shape[1:] for tensor in self._tensors if tensor is not None]) 27 | self._all_none = all([tensor is None for tensor in self._tensors]) 28 | 29 | def __add__(self, other): 30 | if not isinstance(other, PackedSequence): 31 | raise TypeError("other must be a PackedSequence") 32 | return PackedSequence(self._tensors + other._tensors) 33 | 34 | def __iadd__(self, other): 35 | if not isinstance(other, PackedSequence): 36 | raise TypeError("other must be a PackedSequence") 37 | self._tensors += other._tensors 38 | return self 39 | 40 | def __len__(self): 41 | return self._tensors.__len__() 42 | 43 | def __getitem__(self, item): 44 | if isinstance(item, slice): 45 | return PackedSequence(*self._tensors.__getitem__(item)) 46 | else: 47 | return self._tensors.__getitem__(item) 48 | 49 | def __iter__(self): 50 | return self._tensors.__iter__() 51 | 52 | def cuda(self, device=None, non_blocking=False): 53 | self._tensors = [ 54 | tensor.cuda(device, non_blocking) if tensor is not None else None 55 | for tensor in self._tensors 56 | ] 57 | 
return self 58 | 59 | def cpu(self): 60 | self._tensors = [ 61 | tensor.cpu() if tensor is not None else None 62 | for tensor in self._tensors 63 | ] 64 | return self 65 | 66 | @property 67 | def all_none(self): 68 | return self._all_none 69 | 70 | @property 71 | def dtype(self): 72 | if self.all_none: 73 | return None 74 | return next(tensor.dtype for tensor in self._tensors if tensor is not None) 75 | 76 | @property 77 | def device(self): 78 | if self.all_none: 79 | return None 80 | return next(tensor.device for tensor in self._tensors if tensor is not None) 81 | 82 | @property 83 | def contiguous(self): 84 | if not self._compatible: 85 | raise ValueError("The tensors in the sequence are not compatible for contiguous view") 86 | if self.all_none: 87 | return None, None 88 | 89 | packed_tensors = [] 90 | packed_idx = [] 91 | for i, tensor in enumerate(self._tensors): 92 | if tensor is not None: 93 | packed_tensors.append(tensor) 94 | packed_idx.append(tensor.new_full((tensor.size(0),), i, dtype=torch.long)) 95 | 96 | return torch.cat(packed_tensors, dim=0), torch.cat(packed_idx, dim=0) 97 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/parallel/scatter_gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel._functions import Scatter, Gather 3 | 4 | from .packed_sequence import PackedSequence 5 | 6 | 7 | def scatter(inputs, target_gpus, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | 14 | def scatter_map(obj): 15 | if isinstance(obj, torch.Tensor): 16 | return Scatter.apply(target_gpus, None, dim, obj) 17 | if isinstance(obj, tuple) and len(obj) > 0: 18 | return list(zip(*map(scatter_map, obj))) 19 | if isinstance(obj, list) and len(obj) > 0: 20 | return list(map(list, zip(*map(scatter_map, obj)))) 21 | if isinstance(obj, dict) and len(obj) > 0: 22 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 23 | if isinstance(obj, PackedSequence): 24 | return packed_sequence_scatter(obj, target_gpus) 25 | return [obj for _ in target_gpus] 26 | 27 | # After scatter_map is called, a scatter_map cell will exist. This cell 28 | # has a reference to the actual function scatter_map, which has references 29 | # to a closure that has a reference to the scatter_map cell (because the 30 | # fn is recursive). To avoid this reference cycle, we set the function to 31 | # None, clearing the cell 32 | try: 33 | return scatter_map(inputs) 34 | finally: 35 | scatter_map = None 36 | 37 | 38 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): 39 | r"""Scatter with support for kwargs dictionary""" 40 | inputs = scatter(inputs, target_gpus, dim) if inputs else [] 41 | kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] 42 | if len(inputs) < len(kwargs): 43 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 44 | elif len(kwargs) < len(inputs): 45 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 46 | inputs = tuple(inputs) 47 | kwargs = tuple(kwargs) 48 | return inputs, kwargs 49 | 50 | 51 | def gather(outputs, target_device, dim=0): 52 | r""" 53 | Gathers tensors from different GPUs on a specified device 54 | (-1 means the CPU). 
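For illustration (assuming each GPU returned a dictionary of batched tensors; the key name "sem" is made up), `gather([{"sem": s0}, {"sem": s1}], target_device=0)` produces a single dictionary whose "sem" entry is `s0` and `s1` concatenated along `dim` on device 0, while `PackedSequence` outputs are instead merged with `packed_sequence_gather`.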
55 | """ 56 | 57 | def gather_map(outputs): 58 | out = outputs[0] 59 | if isinstance(out, torch.Tensor): 60 | return Gather.apply(target_device, dim, *outputs) 61 | if out is None: 62 | return None 63 | if isinstance(out, dict): 64 | if not all((len(out) == len(d) for d in outputs)): 65 | raise ValueError('All dicts must have the same number of keys') 66 | return type(out)(((k, gather_map([d[k] for d in outputs])) 67 | for k in out)) 68 | if isinstance(out, PackedSequence): 69 | return packed_sequence_gather(outputs, target_device) 70 | return type(out)(map(gather_map, zip(*outputs))) 71 | 72 | # Recursive function calls like this create reference cycles. 73 | # Setting the function to None clears the refcycle. 74 | try: 75 | return gather_map(outputs) 76 | finally: 77 | gather_map = None 78 | 79 | 80 | def packed_sequence_scatter(seq, target_gpus): 81 | # Find chunks 82 | k, m = divmod(len(seq), len(target_gpus)) 83 | limits = [(i * k + min(i, m), (i + 1) * k + min(i + 1, m)) for i in range(len(target_gpus))] 84 | outs = [] 85 | for device, (i, j) in zip(target_gpus, limits): 86 | outs.append(seq[i:j].cuda(device)) 87 | return outs 88 | 89 | 90 | def packed_sequence_gather(seqs, target_device): 91 | out = seqs[0].cuda(target_device) 92 | for i in range(1, len(seqs)): 93 | out += seqs[i].cuda(target_device) 94 | return out 95 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/roi_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import roi_sampling 2 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/roi_sampling/_backend.pyi: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | class PaddingMode: 7 | Zero = ... 8 | Border = ... 9 | 10 | 11 | class Interpolation: 12 | Bilinear = ... 13 | Nearest = ... 14 | 15 | 16 | def roi_sampling_forward( 17 | x: torch.Tensor, bbx: torch.Tensor, idx: torch.Tensor, out_size: Tuple[int, int], 18 | interpolation: Interpolation, padding: PaddingMode, valid_mask: bool) -> Tuple[torch.Tensor, torch.Tensor]: ... 19 | 20 | 21 | def roi_sampling_backward( 22 | dy: torch.Tensor, bbx: torch.Tensor, idx: torch.Tensor, in_size: Tuple[int, int, int], 23 | interpolation: Interpolation, padding: PaddingMode) -> torch.Tensor: ... 24 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/roi_sampling/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.autograd as autograd 4 | from torch.autograd.function import once_differentiable 5 | 6 | # from . 
import _backend 7 | 8 | # _INTERPOLATION = {"bilinear": _backend.Interpolation.Bilinear, "nearest": _backend.Interpolation.Nearest} 9 | # _PADDING = {"zero": _backend.PaddingMode.Zero, "border": _backend.PaddingMode.Border} 10 | 11 | 12 | class ROISampling(autograd.Function): 13 | @staticmethod 14 | def forward(ctx, x, bbx, idx, roi_size, interpolation, padding, valid_mask): 15 | ctx.save_for_backward(bbx, idx) 16 | ctx.input_shape = (x.size(0), x.size(2), x.size(3)) 17 | ctx.valid_mask = valid_mask 18 | 19 | try: 20 | ctx.interpolation = _INTERPOLATION[interpolation] 21 | except KeyError: 22 | raise ValueError("Unknown interpolation {}".format(interpolation)) 23 | try: 24 | ctx.padding = _PADDING[padding] 25 | except KeyError: 26 | raise ValueError("Unknown padding {}".format(padding)) 27 | 28 | y, mask = _backend.roi_sampling_forward(x, bbx, idx, roi_size, ctx.interpolation, ctx.padding, valid_mask) 29 | 30 | if not torch.is_floating_point(x): 31 | ctx.mark_non_differentiable(y) 32 | if valid_mask: 33 | ctx.mark_non_differentiable(mask) 34 | return y, mask 35 | else: 36 | return y 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, *args): 41 | if ctx.valid_mask: 42 | dy, _ = args 43 | else: 44 | dy = args[0] 45 | 46 | assert torch.is_floating_point(dy), "ROISampling.backward is only defined for floating point types" 47 | bbx, idx = ctx.saved_tensors 48 | 49 | dx = _backend.roi_sampling_backward(dy, bbx, idx, ctx.input_shape, ctx.interpolation, ctx.padding) 50 | return dx, None, None, None, None, None, None 51 | 52 | 53 | def roi_sampling(x, bbx, idx, roi_size, interpolation='bilinear', padding_mode='zeros'): 54 | """ 55 | Performs RoI sampling on the input feature map `x` based on the bounding boxes `bbx`. 56 | 57 | Arguments: 58 | x (torch.Tensor): Input feature map with shape (N, C, H, W). 59 | bbx (torch.Tensor): RoI bounding boxes with shape (M, 4), where M is the number of RoIs, 60 | and each RoI is represented by (xmin, ymin, xmax, ymax). 61 | idx (torch.Tensor): The corresponding indices of the RoIs in the input feature map `x`. 62 | Should be a 1D tensor of integers with shape (M,). 63 | roi_size (int): The output size of the RoI. 64 | interpolation (str): The interpolation method to use for resizing. Default is 'bilinear'. 65 | padding_mode (str): The padding mode to use for resizing. Default is 'zeros'. 66 | 67 | Returns: 68 | torch.Tensor: Output RoI feature map with shape (M, C, roi_size, roi_size). 
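Example (illustrative shapes and values only; this simplified version crops with integer rounding and resizes with adaptive average pooling, so `interpolation` and `padding_mode` are accepted but not used)::

    x = torch.randn(2, 256, 64, 64)              # feature maps for a batch of 2 images
    bbx = torch.tensor([[4., 4., 20., 28.],      # (xmin, ymin, xmax, ymax)
                        [10., 6., 40., 40.]])
    idx = torch.tensor([0, 1])                   # image index of each RoI
    rois = roi_sampling(x, bbx, idx, roi_size=(7, 7))   # -> (2, 256, 7, 7)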
69 | """ 70 | num_rois = bbx.size(0) 71 | output = [] 72 | 73 | for i in range(num_rois): 74 | roi_idx = int(idx[i]) 75 | roi = bbx[i] 76 | xmin, ymin, xmax, ymax = roi 77 | 78 | # Calculate RoI coordinates in integer format 79 | roi_xmin = int(torch.round(xmin).item()) 80 | roi_ymin = int(torch.round(ymin).item()) 81 | roi_xmax = int(torch.round(xmax).item()) 82 | roi_ymax = int(torch.round(ymax).item()) 83 | 84 | # Crop the RoI region from the input feature map 85 | roi_feature = x[roi_idx:roi_idx + 1, :, roi_ymin:roi_ymax, roi_xmin:roi_xmax] 86 | 87 | # Resize the RoI to the desired output size using RoIAlign 88 | roi_feature = F.adaptive_avg_pool2d(roi_feature, roi_size) 89 | 90 | output.append(roi_feature) 91 | 92 | return torch.cat(output, dim=0) 93 | 94 | # def roi_sampling(x, bbx, idx, roi_size, interpolation="bilinear", padding="border", valid_mask=False): 95 | # """Sample ROIs from a batch of images using bi-linear interpolation 96 | 97 | # ROIs are sampled from the input by bi-linear interpolation, using the following equations to transform from 98 | # ROI coordinates to image coordinates: 99 | 100 | # y_img = y0 + y_roi / h_roi * (y1 - y0), for y_roi in range(0, h_roi) 101 | # x_img = x0 + x_roi / w_roi * (x1 - x0), for x_roi in range(0, w_roi) 102 | 103 | # where `(h_roi, w_roi)` is the shape of the ROI and `(y0, x0, y1, x1)` are its bounding box coordinates on the image 104 | 105 | # Parameters 106 | # ---------- 107 | # x : torch.Tensor 108 | # A tensor with shape N x C x H x W containing a batch of images to sample from 109 | # bbx : torch.Tensor 110 | # A tensor with shape K x 4 containing the bounding box coordinates of the ROIs in "corners" format 111 | # idx : torch.Tensor 112 | # A tensor with shape K containing the batch indices of the image each ROI should be sampled from 113 | # roi_size : tuple of int 114 | # The size `(h_roi, w_roi)` of the output ROIs 115 | # interpolation : str 116 | # Sampling mode, one of "bilinear" or "nearest" 117 | # padding : str 118 | # Padding mode, one of "border" or "zero" 119 | # valid_mask : bool 120 | # If `True` also return a mask tensor that indicates which points of the outputs where sampled from within the 121 | # valid region of the input 122 | 123 | # Returns 124 | # ------- 125 | # y : torch.Tensor 126 | # A tensor with shape K x C x h_roi x w_roi containing the sampled ROIs 127 | # mask : torch.Tensor 128 | # Optional output returned only when valid_mask is `True`: a mask tensor with shape K x h_roi x w_roi, whose 129 | # entries are `!= 0` where the corresponding location in `y` was sampled from within the limits of the input image 130 | # """ 131 | # return ROISampling.apply(x, bbx, idx, roi_size, interpolation, padding, valid_mask) 132 | 133 | 134 | __all__ = ["roi_sampling"] 135 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/sequence.py: -------------------------------------------------------------------------------- 1 | from .parallel import PackedSequence 2 | 3 | 4 | def pad_packed_images(packed_images, pad_value=0., snap_size_to=None): 5 | """Assemble a padded tensor for a `PackedSequence` of images with different spatial sizes 6 | 7 | This method allows any standard convnet to operate on a `PackedSequence` of images as a batch 8 | 9 | Parameters 10 | ---------- 11 | packed_images : PackedSequence 12 | A PackedSequence containing N tensors with different spatial sizes H_i, W_i. The tensors can be either 2D or 3D. 
13 | If they are 3D, they must all have the same number of channels C. 14 | pad_value : float or int 15 | Value used to fill the padded areas 16 | snap_size_to : int or None 17 | If not None, chose the spatial sizes of the padded tensor to be multiples of this 18 | 19 | Returns 20 | ------- 21 | padded_images : torch.Tensor 22 | A tensor with shape N x C x H x W or N x H x W, where `H = max_i H_i` and `W = max_i W_i` containing the images 23 | of the sequence aligned to the top left corner and padded with `pad_value` 24 | sizes : list of tuple of int 25 | A list with the original spatial sizes of the input images 26 | """ 27 | # if packed_images.all_none: 28 | # raise ValueError("at least one image in packed_images should be non-None") 29 | 30 | reference_img = next(img for img in packed_images if img is not None) 31 | max_size = reference_img.shape[-2:] 32 | ndims = len(reference_img.shape) 33 | chn = reference_img.shape[0] if ndims == 3 else 0 34 | 35 | # Check the shapes and find maximum spatial size 36 | for img in packed_images: 37 | if img is not None: 38 | if len(img.shape) != 3 and len(img.shape) != 2: 39 | raise ValueError("The input sequence must contain 2D or 3D tensors") 40 | if len(img.shape) != ndims: 41 | raise ValueError("All tensors in the input sequence must have the same number of dimensions") 42 | if ndims == 3 and img.shape[0] != chn: 43 | raise ValueError("3D tensors must all have the same number of channels") 44 | max_size = [max(s1, s2) for s1, s2 in zip(max_size, img.shape[-2:])] 45 | 46 | # Optional size snapping 47 | if snap_size_to is not None: 48 | max_size = [(s + snap_size_to - 1) // snap_size_to * snap_size_to for s in max_size] 49 | 50 | if ndims == 3: 51 | padded_images = reference_img.new_full([len(packed_images), chn] + max_size, pad_value) 52 | else: 53 | padded_images = reference_img.new_full([len(packed_images)] + max_size, pad_value) 54 | 55 | sizes = [] 56 | for i, tensor in enumerate(packed_images): 57 | if tensor is not None: 58 | if ndims == 3: 59 | padded_images[i, :, :tensor.shape[1], :tensor.shape[2]] = tensor 60 | sizes.append(tensor.shape[1:]) 61 | else: 62 | padded_images[i, :tensor.shape[0], :tensor.shape[1]] = tensor 63 | sizes.append(tensor.shape) 64 | else: 65 | sizes.append((0, 0)) 66 | 67 | return padded_images, sizes 68 | 69 | 70 | def pack_padded_images(padded_images, sizes): 71 | """Inverse function of `pad_packed_images`, refer to that for details""" 72 | images = [] 73 | for img, size in zip(padded_images, sizes): 74 | if img.dim() == 2: 75 | images.append(img[:int(size[0]), :int(size[1])]) 76 | else: 77 | images.append(img[:, :int(size[0]), :int(size[1])]) 78 | 79 | return PackedSequence([img.contiguous() for img in images]) 80 | -------------------------------------------------------------------------------- /models/grasp_det_seg/utils/snapshot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .misc import config_to_string 4 | 5 | 6 | def save_snapshot(file, config, epoch, last_score, best_score, global_step, **kwargs): 7 | data = { 8 | "config": config_to_string(config), 9 | "state_dict": dict(kwargs), 10 | "training_meta": { 11 | "epoch": epoch, 12 | "last_score": last_score, 13 | "best_score": best_score, 14 | "global_step": global_step 15 | } 16 | } 17 | torch.save(data, file) 18 | 19 | 20 | def pre_train_from_snapshots(model, snapshots, modules): 21 | for snapshot in snapshots: 22 | if ":" in snapshot: 23 | module_name, snapshot = snapshot.split(":") 24 | 
else: 25 | module_name = None 26 | 27 | snapshot = torch.load(snapshot, map_location="cpu") 28 | state_dict = snapshot["state_dict"] 29 | 30 | if module_name is None: 31 | for module_name in modules: 32 | if module_name in state_dict: 33 | _load_pretraining_dict(getattr(model, module_name), state_dict[module_name]) 34 | else: 35 | if module_name in modules: 36 | _load_pretraining_dict(getattr(model, module_name), state_dict[module_name]) 37 | else: 38 | raise ValueError("Unrecognized network module {}".format(module_name)) 39 | 40 | 41 | def resume_from_snapshot(model, snapshot, modules): 42 | snapshot = torch.load(snapshot, map_location="cpu") 43 | state_dict = snapshot["state_dict"] 44 | 45 | for module in modules: 46 | if module in state_dict: 47 | _load_pretraining_dict(getattr(model, module), state_dict[module]) 48 | else: 49 | raise KeyError("The given snapshot does not contain a state_dict for module '{}'".format(module)) 50 | 51 | return snapshot 52 | 53 | 54 | def _load_pretraining_dict(model, state_dict): 55 | """Load state dictionary from a pre-training snapshot 56 | 57 | This is an even less strict version of `model.load_state_dict(..., False)`, which also ignores parameters from 58 | `state_dict` that don't have the same shapes as the corresponding ones in `model`. This is useful when loading 59 | from pre-trained models that are trained on different datasets. 60 | 61 | Parameters 62 | ---------- 63 | model : torch.nn.Model 64 | Target model 65 | state_dict : dict 66 | Dictionary of model parameters 67 | """ 68 | model_sd = model.state_dict() 69 | 70 | for k, v in model_sd.items(): 71 | if k in state_dict: 72 | if v.shape != state_dict[k].shape: 73 | del state_dict[k] 74 | 75 | model.load_state_dict(state_dict, False) 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | matplotlib 4 | scikit-image 5 | imageio 6 | torch 7 | torchvision 8 | torchsummary 9 | tensorboardX 10 | pyrealsense2 11 | Pillow -------------------------------------------------------------------------------- /script/eval_cornell_seen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder_path="$1" # Replace '/path/to/your/folder' with the actual folder path 4 | 5 | if [ ! -d "$folder_path" ]; then 6 | echo "Folder $folder_path not found." 7 | exit 1 8 | fi 9 | 10 | pattern="epoch_*" 11 | for file in "$folder_path"/$pattern; do 12 | if [ -f "$file" ]; then 13 | echo "Running command with file: $file" 14 | python evaluate.py --dataset cornell --dataset-path data/cornell_seen/archive/ --iou-eval --use-depth 0 --seen 0 --split 0.9 --network "$file" # Execute the command with the file as a parameter 15 | fi 16 | done 17 | 18 | echo "All files processed." 19 | -------------------------------------------------------------------------------- /script/eval_cornell_unseen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder_path="$1" # Replace '/path/to/your/folder' with the actual folder path 4 | 5 | if [ ! -d "$folder_path" ]; then 6 | echo "Folder $folder_path not found." 
7 | exit 1 8 | fi 9 | 10 | pattern="epoch_*" 11 | for file in "$folder_path"/$pattern; do 12 | if [ -f "$file" ]; then 13 | echo "Running command with file: $file" 14 | python evaluate.py --dataset cornell --dataset-path data/cornell_unseen/archive/ --iou-eval --use-depth 0 --seen 0 --split 0.01 --network "$file" # Execute the command with the file as a parameter 15 | fi 16 | done 17 | 18 | echo "All files processed." 19 | -------------------------------------------------------------------------------- /script/eval_grasp_anything_seen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder_path="$1" # Replace '/path/to/your/folder' with the actual folder path 4 | 5 | if [ ! -d "$folder_path" ]; then 6 | echo "Folder $folder_path not found." 7 | exit 1 8 | fi 9 | 10 | pattern="epoch_*" 11 | for file in "$folder_path"/$pattern; do 12 | if [ -f "$file" ]; then 13 | echo "Running command with file: $file" 14 | python evaluate.py --dataset grasp-anything --dataset-path data/grasp-anything/ --iou-eval --use-depth 0 --seen 1 --split 0.99 --network "$file" # Execute the command with the file as a parameter 15 | fi 16 | done 17 | 18 | echo "All files processed." 19 | -------------------------------------------------------------------------------- /script/eval_grasp_anything_unseen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder_path="$1" # Replace '/path/to/your/folder' with the actual folder path 4 | 5 | if [ ! -d "$folder_path" ]; then 6 | echo "Folder $folder_path not found." 7 | exit 1 8 | fi 9 | 10 | pattern="epoch_*" 11 | for file in "$folder_path"/$pattern; do 12 | if [ -f "$file" ]; then 13 | echo "Running command with file: $file" 14 | python evaluate.py --dataset grasp-anything --dataset-path data/grasp-anything/ --iou-eval --use-depth 0 --seen 0 --split 0.01 --network "$file" # Execute the command with the file as a parameter 15 | fi 16 | done 17 | 18 | echo "All files processed." 19 | -------------------------------------------------------------------------------- /script/eval_jacquard_seen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder_path="$1" # Replace '/path/to/your/folder' with the actual folder path 4 | 5 | if [ ! -d "$folder_path" ]; then 6 | echo "Folder $folder_path not found." 7 | exit 1 8 | fi 9 | 10 | pattern="epoch_*" 11 | for file in "$folder_path"/$pattern; do 12 | if [ -f "$file" ]; then 13 | echo "Running command with file: $file" 14 | python evaluate.py --dataset jacquard --dataset-path data/jacquard_seen/0/ --iou-eval --use-depth 0 --seen 0 --split 0.9 --network "$file" # Execute the command with the file as a parameter 15 | fi 16 | done 17 | 18 | echo "All files processed." 19 | -------------------------------------------------------------------------------- /script/eval_jacquard_unseen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | folder_path="$1" # Replace '/path/to/your/folder' with the actual folder path 4 | 5 | if [ ! -d "$folder_path" ]; then 6 | echo "Folder $folder_path not found." 
7 | exit 1 8 | fi 9 | 10 | pattern="epoch_*" 11 | for file in "$folder_path"/$pattern; do 12 | if [ -f "$file" ]; then 13 | echo "Running command with file: $file" 14 | python evaluate.py --dataset jacquard --dataset-path data/jacquard_unseen/0/ --iou-eval --use-depth 0 --seen 0 --split 0.01 --network "$file" # Execute the command with the file as a parameter 15 | fi 16 | done 17 | 18 | echo "All files processed." 19 | -------------------------------------------------------------------------------- /split/control_manage_cornell.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import shutil 4 | import pickle 5 | 6 | def copy_files_with_structure(source_dir, destination_dir, filter_func=None): 7 | # Create the destination directory if it doesn't exist 8 | if not os.path.exists(destination_dir): 9 | os.makedirs(destination_dir) 10 | 11 | for root, _, files in os.walk(source_dir): 12 | # Get the relative path of the current directory within the source directory 13 | relative_path = os.path.relpath(root, source_dir) 14 | # Create the corresponding directory structure in the destination directory 15 | dest_dir = os.path.join(destination_dir, relative_path) 16 | 17 | # Create the directory in the destination directory 18 | if not os.path.exists(dest_dir): 19 | os.makedirs(dest_dir) 20 | 21 | # Copy each file in the current directory to the corresponding directory in the destination directory 22 | for file in files: 23 | source_file_path = os.path.join(root, file) 24 | dest_file_path = os.path.join(dest_dir, file) 25 | 26 | # Use the filter function if provided to check if the file should be copied 27 | if filter_func is None or filter_func(source_file_path): 28 | shutil.copy2(source_file_path, dest_file_path) # Use shutil.copy2 to preserve metadata 29 | 30 | # Example usage: 31 | def seen_filter_cornell(file_path): 32 | idx = file_path.split('/')[-1].split('.')[0] 33 | idx = idx[:7] 34 | return idx in seen_filters 35 | 36 | def unseen_filter_cornell(file_path): 37 | idx = file_path.split('/')[-1].split('.')[0] 38 | idx = idx[:7] 39 | return idx in unseen_filters 40 | 41 | cornell_path = 'data/cornell' 42 | cornell_filter_path = 'split/cornell' 43 | source_directory = "data/cornell" 44 | seen_destination_directory = "data/cornell_seen" 45 | unseen_destination_directory = "data/cornell_unseen" 46 | 47 | with open(os.path.join(cornell_filter_path, 'seen.obj'), 'rb') as f: 48 | seen_filters = pickle.load(f) 49 | 50 | with open(os.path.join(cornell_filter_path, 'unseen.obj'), 'rb') as f: 51 | unseen_filters = pickle.load(f) 52 | 53 | copy_files_with_structure(source_directory, seen_destination_directory, filter_func=seen_filter_cornell) 54 | copy_files_with_structure(source_directory, unseen_destination_directory, filter_func=unseen_filter_cornell) 55 | -------------------------------------------------------------------------------- /split/control_manage_jacquard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import shutil 4 | import pickle 5 | 6 | def copy_files_with_structure(source_dir, destination_dir, filter_func=None): 7 | # Create the destination directory if it doesn't exist 8 | if not os.path.exists(destination_dir): 9 | os.makedirs(destination_dir) 10 | 11 | for root, _, files in os.walk(source_dir): 12 | # Get the relative path of the current directory within the source directory 13 | relative_path = os.path.relpath(root, source_dir) 14 | # Create the 
corresponding directory structure in the destination directory 15 | dest_dir = os.path.join(destination_dir, relative_path) 16 | 17 | # Create the directory in the destination directory 18 | if not os.path.exists(dest_dir): 19 | os.makedirs(dest_dir) 20 | 21 | # Copy each file in the current directory to the corresponding directory in the destination directory 22 | for file in files: 23 | source_file_path = os.path.join(root, file) 24 | dest_file_path = os.path.join(dest_dir, file) 25 | 26 | # Use the filter function if provided to check if the file should be copied 27 | if filter_func is None or filter_func(source_file_path): 28 | shutil.copy2(source_file_path, dest_file_path) # Use shutil.copy2 to preserve metadata 29 | 30 | # Example usage: 31 | def seen_filter_jacquard(file_path): 32 | idx = file_path.split('/')[-1].split('.')[0] 33 | idx = "_".join(idx.split('_')[:2]) 34 | return idx in seen_filters 35 | 36 | def unseen_filter_jacquard(file_path): 37 | idx = file_path.split('/')[-1].split('.')[0] 38 | idx = "_".join(idx.split('_')[:2]) 39 | return idx in unseen_filters 40 | 41 | jacquard_path = 'data/jacquard' 42 | jacquard_filter_path = 'split/jacquard' 43 | source_directory = "data/jacquard" 44 | seen_destination_directory = "data/jacquard_seen" 45 | unseen_destination_directory = "data/jacquard_unseen" 46 | 47 | with open(os.path.join(jacquard_filter_path, 'seen.obj'), 'rb') as f: 48 | seen_filters = pickle.load(f) 49 | 50 | with open(os.path.join(jacquard_filter_path, 'unseen.obj'), 'rb') as f: 51 | unseen_filters = pickle.load(f) 52 | 53 | copy_files_with_structure(source_directory, seen_destination_directory, filter_func=seen_filter_jacquard) 54 | copy_files_with_structure(source_directory, unseen_destination_directory, filter_func=unseen_filter_jacquard) 55 | -------------------------------------------------------------------------------- /split/cornell/seen.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/split/cornell/seen.obj -------------------------------------------------------------------------------- /split/cornell/unseen.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/split/cornell/unseen.obj -------------------------------------------------------------------------------- /split/grasp-anything/seen.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/split/grasp-anything/seen.obj -------------------------------------------------------------------------------- /split/grasp-anything/unseen.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/split/grasp-anything/unseen.obj -------------------------------------------------------------------------------- /split/jacquard/seen.obj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/split/jacquard/seen.obj -------------------------------------------------------------------------------- /split/jacquard/unseen.obj: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/split/jacquard/unseen.obj -------------------------------------------------------------------------------- /utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | def get_dataset(dataset_name): 2 | if dataset_name == 'cornell': 3 | from .cornell_data import CornellDataset 4 | return CornellDataset 5 | elif dataset_name == 'jacquard': 6 | from .jacquard_data import JacquardDataset 7 | return JacquardDataset 8 | elif dataset_name == 'grasp-anything': 9 | from .grasp_anything_data import GraspAnythingDataset 10 | return GraspAnythingDataset 11 | elif dataset_name == 'vmrd': 12 | from .vmrd_data import VMRDDataset 13 | return VMRDDataset 14 | elif dataset_name == 'ocid': 15 | from .ocid_grasp_data import OCIDGraspDataset 16 | return OCIDGraspDataset 17 | else: 18 | raise NotImplementedError('Dataset Type {} is Not implemented'.format(dataset_name)) 19 | -------------------------------------------------------------------------------- /utils/data/camera_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from utils.dataset_processing import image 5 | 6 | 7 | class CameraData: 8 | """ 9 | Dataset wrapper for the camera data. 10 | """ 11 | def __init__(self, 12 | width=640, 13 | height=480, 14 | output_size=224, 15 | include_depth=True, 16 | include_rgb=True 17 | ): 18 | """ 19 | :param output_size: Image output size in pixels (square) 20 | :param include_depth: Whether depth image is included 21 | :param include_rgb: Whether RGB image is included 22 | """ 23 | self.output_size = output_size 24 | self.include_depth = include_depth 25 | self.include_rgb = include_rgb 26 | 27 | if include_depth is False and include_rgb is False: 28 | raise ValueError('At least one of Depth or RGB must be specified.') 29 | 30 | left = (width - output_size) // 2 31 | top = (height - output_size) // 2 32 | right = (width + output_size) // 2 33 | bottom = (height + output_size) // 2 34 | 35 | self.bottom_right = (bottom, right) 36 | self.top_left = (top, left) 37 | 38 | @staticmethod 39 | def numpy_to_torch(s): 40 | if len(s.shape) == 2: 41 | return torch.from_numpy(np.expand_dims(s, 0).astype(np.float32)) 42 | else: 43 | return torch.from_numpy(s.astype(np.float32)) 44 | 45 | def get_depth(self, img): 46 | depth_img = image.Image(img) 47 | depth_img.crop(bottom_right=self.bottom_right, top_left=self.top_left) 48 | depth_img.normalise() 49 | # depth_img.resize((self.output_size, self.output_size)) 50 | depth_img.img = depth_img.img.transpose((2, 0, 1)) 51 | return depth_img.img 52 | 53 | def get_rgb(self, img, norm=True): 54 | rgb_img = image.Image(img) 55 | rgb_img.crop(bottom_right=self.bottom_right, top_left=self.top_left) 56 | # rgb_img.resize((self.output_size, self.output_size)) 57 | if norm: 58 | rgb_img.normalise() 59 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 60 | return rgb_img.img 61 | 62 | def get_data(self, rgb=None, depth=None): 63 | depth_img = None 64 | rgb_img = None 65 | # Load the depth image 66 | if self.include_depth: 67 | depth_img = self.get_depth(img=depth) 68 | 69 | # Load the RGB image 70 | if self.include_rgb: 71 | rgb_img = self.get_rgb(img=rgb) 72 | 73 | if self.include_depth and self.include_rgb: 74 | x = self.numpy_to_torch( 75 | np.concatenate( 76 | 
(np.expand_dims(depth_img, 0), 77 | np.expand_dims(rgb_img, 0)), 78 | 1 79 | ) 80 | ) 81 | elif self.include_depth: 82 | x = self.numpy_to_torch(depth_img) 83 | elif self.include_rgb: 84 | x = self.numpy_to_torch(np.expand_dims(rgb_img, 0)) 85 | 86 | return x, depth_img, rgb_img 87 | -------------------------------------------------------------------------------- /utils/data/cornell_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | from utils.dataset_processing import grasp, image 5 | from .grasp_data import GraspDatasetBase 6 | 7 | 8 | class CornellDataset(GraspDatasetBase): 9 | """ 10 | Dataset wrapper for the Cornell dataset. 11 | """ 12 | 13 | def __init__(self, file_path, ds_rotate=0, **kwargs): 14 | """ 15 | :param file_path: Cornell Dataset directory. 16 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 17 | :param kwargs: kwargs for GraspDatasetBase 18 | """ 19 | super(CornellDataset, self).__init__(**kwargs) 20 | 21 | self.grasp_files = glob.glob(os.path.join(file_path, '*', 'pcd*cpos.txt')) 22 | self.grasp_files.sort() 23 | self.length = len(self.grasp_files) 24 | 25 | if self.length == 0: 26 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path)) 27 | 28 | if ds_rotate: 29 | self.grasp_files = self.grasp_files[int(self.length * ds_rotate):] + self.grasp_files[ 30 | :int(self.length * ds_rotate)] 31 | 32 | self.depth_files = [f.replace('cpos.txt', 'd.tiff') for f in self.grasp_files] 33 | self.rgb_files = [f.replace('d.tiff', 'r.png') for f in self.depth_files] 34 | 35 | def _get_crop_attrs(self, idx): 36 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx]) 37 | center = gtbbs.center 38 | left = max(0, min(center[1] - self.output_size // 2, 640 - self.output_size)) 39 | top = max(0, min(center[0] - self.output_size // 2, 480 - self.output_size)) 40 | return center, left, top 41 | 42 | def get_gtbb(self, idx, rot=0, zoom=1.0): 43 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx]) 44 | center, left, top = self._get_crop_attrs(idx) 45 | gtbbs.rotate(rot, center) 46 | gtbbs.offset((-top, -left)) 47 | gtbbs.zoom(zoom, (self.output_size // 2, self.output_size // 2)) 48 | return gtbbs 49 | 50 | def get_depth(self, idx, rot=0, zoom=1.0): 51 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 52 | center, left, top = self._get_crop_attrs(idx) 53 | depth_img.rotate(rot, center) 54 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 55 | depth_img.normalise() 56 | depth_img.zoom(zoom) 57 | depth_img.resize((self.output_size, self.output_size)) 58 | return depth_img.img 59 | 60 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 61 | rgb_img = image.Image.from_file(self.rgb_files[idx]) 62 | center, left, top = self._get_crop_attrs(idx) 63 | rgb_img.rotate(rot, center) 64 | rgb_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 65 | rgb_img.zoom(zoom) 66 | rgb_img.resize((self.output_size, self.output_size)) 67 | if normalise: 68 | rgb_img.normalise() 69 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 70 | return rgb_img.img 71 | -------------------------------------------------------------------------------- /utils/data/grasp_anything_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | 5 | import 
pickle 6 | import torch 7 | 8 | from utils.dataset_processing import grasp, image, mask 9 | from .grasp_data import GraspDatasetBase 10 | 11 | 12 | class GraspAnythingDataset(GraspDatasetBase): 13 | """ 14 | Dataset wrapper for the Grasp-Anything dataset. 15 | """ 16 | 17 | def __init__(self, file_path, ds_rotate=0, **kwargs): 18 | """ 19 | :param file_path: Grasp-Anything Dataset directory. 20 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 21 | :param kwargs: kwargs for GraspDatasetBase 22 | """ 23 | super(GraspAnythingDataset, self).__init__(**kwargs) 24 | 25 | self.grasp_files = glob.glob(os.path.join(file_path, 'grasp_label_positive', '*.pt')) 26 | self.prompt_files = glob.glob(os.path.join(file_path, 'scene_description', '*.pkl')) 27 | self.rgb_files = glob.glob(os.path.join(file_path, 'image', '*.jpg')) 28 | # self.mask_files = glob.glob(os.path.join(file_path, 'mask', '*.npy')) 29 | 30 | if kwargs["seen"]: 31 | with open(os.path.join('split/grasp-anything/seen.obj'), 'rb') as f: 32 | idxs = pickle.load(f) 33 | 34 | self.grasp_files = list(filter(lambda x: x.split('/')[-1].split('.')[0] in idxs, self.grasp_files)) 35 | else: 36 | with open(os.path.join('split/grasp-anything/unseen.obj'), 'rb') as f: 37 | idxs = pickle.load(f) 38 | 39 | self.grasp_files = list(filter(lambda x: x.split('/')[-1].split('.')[0] in idxs, self.grasp_files)) 40 | 41 | self.grasp_files.sort() 42 | self.prompt_files.sort() 43 | self.rgb_files.sort() 44 | # self.mask_files.sort() 45 | 46 | self.length = len(self.grasp_files) 47 | 48 | if self.length == 0: 49 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path)) 50 | 51 | if ds_rotate: 52 | self.grasp_files = self.grasp_files[int(self.length * ds_rotate):] + self.grasp_files[ 53 | :int(self.length * ds_rotate)] 54 | 55 | 56 | def _get_crop_attrs(self, idx): 57 | gtbbs = grasp.GraspRectangles.load_from_grasp_anything_file(self.grasp_files[idx]) 58 | center = gtbbs.center 59 | left = max(0, min(center[1] - self.output_size // 2, 416 - self.output_size)) 60 | top = max(0, min(center[0] - self.output_size // 2, 416 - self.output_size)) 61 | return center, left, top 62 | 63 | def get_gtbb(self, idx, rot=0, zoom=1.0): 64 | # Jacquard try 65 | gtbbs = grasp.GraspRectangles.load_from_grasp_anything_file(self.grasp_files[idx], scale=self.output_size / 416.0) 66 | 67 | c = self.output_size // 2 68 | gtbbs.rotate(rot, (c, c)) 69 | gtbbs.zoom(zoom, (c, c)) 70 | 71 | # Cornell try 72 | # gtbbs = grasp.GraspRectangles.load_from_grasp_anything_file(self.grasp_files[idx]) 73 | # center, left, top = self._get_crop_attrs(idx) 74 | # gtbbs.rotate(rot, center) 75 | # gtbbs.offset((-top, -left)) 76 | # gtbbs.zoom(zoom, (self.output_size // 2, self.output_size // 2)) 77 | return gtbbs 78 | 79 | def get_depth(self, idx, rot=0, zoom=1.0): 80 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 81 | center, left, top = self._get_crop_attrs(idx) 82 | depth_img.rotate(rot, center) 83 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 84 | depth_img.normalise() 85 | depth_img.zoom(zoom) 86 | depth_img.resize((self.output_size, self.output_size)) 87 | return depth_img.img 88 | 89 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 90 | # mask_file = self.grasp_files[idx].replace("positive_grasp", "mask").replace(".pt", ".npy") 91 | # mask_img = mask.Mask.from_file(mask_file) 92 | rgb_file = re.sub(r"_\d{1}\.pt", ".jpg", self.grasp_files[idx]) 93 
| rgb_file = rgb_file.replace("grasp_label_positive", "image") 94 | rgb_img = image.Image.from_file(rgb_file) 95 | # rgb_img = image.Image.mask_out_image(rgb_img, mask_img) 96 | 97 | # Jacquard try 98 | rgb_img.rotate(rot) 99 | rgb_img.zoom(zoom) 100 | rgb_img.resize((self.output_size, self.output_size)) 101 | if normalise: 102 | rgb_img.normalise() 103 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 104 | return rgb_img.img 105 | 106 | # Cornell try 107 | # center, left, top = self._get_crop_attrs(idx) 108 | # rgb_img.rotate(rot, center) 109 | # rgb_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 110 | # rgb_img.zoom(zoom) 111 | # rgb_img.resize((self.output_size, self.output_size)) 112 | # if normalise: 113 | # rgb_img.normalise() 114 | # rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 115 | # return rgb_img.img 116 | -------------------------------------------------------------------------------- /utils/data/grasp_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data 6 | 7 | 8 | class GraspDatasetBase(torch.utils.data.Dataset): 9 | """ 10 | An abstract dataset for training networks in a common format. 11 | """ 12 | 13 | def __init__(self, output_size=224, include_depth=True, include_rgb=False, random_rotate=False, 14 | random_zoom=False, input_only=False, seen=True): 15 | """ 16 | :param output_size: Image output size in pixels (square) 17 | :param include_depth: Whether depth image is included 18 | :param include_rgb: Whether RGB image is included 19 | :param random_rotate: Whether random rotations are applied 20 | :param random_zoom: Whether random zooms are applied 21 | :param input_only: Whether to return only the network input (no labels) 22 | """ 23 | self.output_size = output_size 24 | self.random_rotate = random_rotate 25 | self.random_zoom = random_zoom 26 | self.input_only = input_only 27 | self.include_depth = include_depth 28 | self.include_rgb = include_rgb 29 | 30 | self.grasp_files = [] 31 | 32 | if include_depth is False and include_rgb is False: 33 | raise ValueError('At least one of Depth or RGB must be specified.') 34 | 35 | @staticmethod 36 | def numpy_to_torch(s): 37 | if len(s.shape) == 2: 38 | return torch.from_numpy(np.expand_dims(s, 0).astype(np.float32)) 39 | else: 40 | return torch.from_numpy(s.astype(np.float32)) 41 | 42 | def get_gtbb(self, idx, rot=0, zoom=1.0): 43 | raise NotImplementedError() 44 | 45 | def get_depth(self, idx, rot=0, zoom=1.0): 46 | raise NotImplementedError() 47 | 48 | def get_rgb(self, idx, rot=0, zoom=1.0): 49 | raise NotImplementedError() 50 | 51 | def __getitem__(self, idx): 52 | if self.random_rotate: 53 | rotations = [0, np.pi / 2, 2 * np.pi / 2, 3 * np.pi / 2] 54 | rot = random.choice(rotations) 55 | else: 56 | rot = 0.0 57 | 58 | if self.random_zoom: 59 | zoom_factor = np.random.uniform(0.5, 1.0) 60 | else: 61 | zoom_factor = 1.0 62 | 63 | # Load the depth image 64 | if self.include_depth: 65 | depth_img = self.get_depth(idx, rot, zoom_factor) 66 | 67 | # Load the RGB image 68 | if self.include_rgb: 69 | rgb_img = self.get_rgb(idx, rot, zoom_factor) 70 | 71 | # Load the grasps 72 | bbs = self.get_gtbb(idx, rot, zoom_factor) 73 | 74 | pos_img, ang_img, width_img = bbs.draw((self.output_size, self.output_size)) 75 | width_img = np.clip(width_img, 0.0, self.output_size / 2) / (self.output_size / 2) 76 | 77 | if self.include_depth and self.include_rgb: 78 | x = 
self.numpy_to_torch( 79 | np.concatenate( 80 | (np.expand_dims(depth_img, 0), 81 | rgb_img), 82 | 0 83 | ) 84 | ) 85 | elif self.include_depth: 86 | x = self.numpy_to_torch(depth_img) 87 | elif self.include_rgb: 88 | x = self.numpy_to_torch(rgb_img) 89 | 90 | pos = self.numpy_to_torch(pos_img) 91 | cos = self.numpy_to_torch(np.cos(2 * ang_img)) 92 | sin = self.numpy_to_torch(np.sin(2 * ang_img)) 93 | width = self.numpy_to_torch(width_img) 94 | 95 | return x, (pos, cos, sin, width), idx, rot, zoom_factor 96 | 97 | def __len__(self): 98 | return len(self.grasp_files) 99 | -------------------------------------------------------------------------------- /utils/data/jacquard_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | from utils.dataset_processing import grasp, image 5 | from .grasp_data import GraspDatasetBase 6 | 7 | 8 | class JacquardDataset(GraspDatasetBase): 9 | """ 10 | Dataset wrapper for the Jacquard dataset. 11 | """ 12 | 13 | def __init__(self, file_path, ds_rotate=0, **kwargs): 14 | """ 15 | :param file_path: Jacquard Dataset directory. 16 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 17 | :param kwargs: kwargs for GraspDatasetBase 18 | """ 19 | super(JacquardDataset, self).__init__(**kwargs) 20 | 21 | self.grasp_files = glob.glob(os.path.join(file_path, '*', '*_grasps.txt')) 22 | self.grasp_files.sort() 23 | self.length = len(self.grasp_files) 24 | 25 | if self.length == 0: 26 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path)) 27 | 28 | if ds_rotate: 29 | self.grasp_files = self.grasp_files[int(self.length * ds_rotate):] + self.grasp_files[ 30 | :int(self.length * ds_rotate)] 31 | 32 | self.depth_files = [f.replace('grasps.txt', 'perfect_depth.tiff') for f in self.grasp_files] 33 | self.rgb_files = [f.replace('perfect_depth.tiff', 'RGB.png') for f in self.depth_files] 34 | 35 | def get_gtbb(self, idx, rot=0, zoom=1.0): 36 | gtbbs = grasp.GraspRectangles.load_from_jacquard_file(self.grasp_files[idx], scale=self.output_size / 1024.0) 37 | c = self.output_size // 2 38 | gtbbs.rotate(rot, (c, c)) 39 | gtbbs.zoom(zoom, (c, c)) 40 | return gtbbs 41 | 42 | def get_depth(self, idx, rot=0, zoom=1.0): 43 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 44 | depth_img.rotate(rot) 45 | depth_img.normalise() 46 | depth_img.zoom(zoom) 47 | depth_img.resize((self.output_size, self.output_size)) 48 | return depth_img.img 49 | 50 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 51 | rgb_img = image.Image.from_file(self.rgb_files[idx]) 52 | rgb_img.rotate(rot) 53 | rgb_img.zoom(zoom) 54 | rgb_img.resize((self.output_size, self.output_size)) 55 | if normalise: 56 | rgb_img.normalise() 57 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 58 | return rgb_img.img 59 | 60 | def get_jname(self, idx): 61 | return '_'.join(self.grasp_files[idx].split(os.sep)[-1].split('_')[:-1]) 62 | -------------------------------------------------------------------------------- /utils/data/ocid_grasp_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | from utils.dataset_processing import grasp, image 5 | from .grasp_data import GraspDatasetBase 6 | 7 | 8 | class OCIDGraspDataset(GraspDatasetBase): 9 | """ 10 | Dataset wrapper for the Cornell dataset. 
11 | """ 12 | 13 | def __init__(self, file_path, ds_rotate=0, **kwargs): 14 | """ 15 | :param file_path: Cornell Dataset directory. 16 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 17 | :param kwargs: kwargs for GraspDatasetBase 18 | """ 19 | super(OCIDGraspDataset, self).__init__(**kwargs) 20 | 21 | self.grasp_files = glob.glob(os.path.join(file_path, '*', '*', '*', '*', '*', 'Annotations', '*.txt')) 22 | self.grasp_files.sort() 23 | self.length = len(self.grasp_files) 24 | 25 | if self.length == 0: 26 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path)) 27 | 28 | if ds_rotate: 29 | self.grasp_files = self.grasp_files[int(self.length * ds_rotate):] + self.grasp_files[ 30 | :int(self.length * ds_rotate)] 31 | 32 | def _get_crop_attrs(self, idx): 33 | gtbbs = grasp.GraspRectangles.load_from_ocid_grasp_file(self.grasp_files[idx]) 34 | center = gtbbs.center 35 | left = max(0, min(center[1] - self.output_size // 2, 640 - self.output_size)) 36 | top = max(0, min(center[0] - self.output_size // 2, 480 - self.output_size)) 37 | return center, left, top 38 | 39 | def get_gtbb(self, idx, rot=0, zoom=1.0): 40 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx]) 41 | center, left, top = self._get_crop_attrs(idx) 42 | gtbbs.rotate(rot, center) 43 | gtbbs.offset((-top, -left)) 44 | gtbbs.zoom(zoom, (self.output_size // 2, self.output_size // 2)) 45 | return gtbbs 46 | 47 | def get_depth(self, idx, rot=0, zoom=1.0): 48 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 49 | center, left, top = self._get_crop_attrs(idx) 50 | depth_img.rotate(rot, center) 51 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 52 | depth_img.normalise() 53 | depth_img.zoom(zoom) 54 | depth_img.resize((self.output_size, self.output_size)) 55 | return depth_img.img 56 | 57 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 58 | rgb_file = self.grasp_files[idx].replace("Annotations", "rgb").replace("txt", "png") 59 | rgb_img = image.Image.from_file(rgb_file) 60 | center, left, top = self._get_crop_attrs(idx) 61 | rgb_img.rotate(rot, center) 62 | rgb_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 63 | rgb_img.zoom(zoom) 64 | rgb_img.resize((self.output_size, self.output_size)) 65 | if normalise: 66 | rgb_img.normalise() 67 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 68 | return rgb_img.img 69 | -------------------------------------------------------------------------------- /utils/data/vmrd_data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | 5 | import pickle 6 | import torch 7 | 8 | from utils.dataset_processing import grasp, image, mask 9 | from .grasp_data import GraspDatasetBase 10 | 11 | 12 | class VMRDDataset(GraspDatasetBase): 13 | """ 14 | Dataset wrapper for the Grasp-Anything dataset. 15 | """ 16 | 17 | def __init__(self, file_path, ds_rotate=0, **kwargs): 18 | """ 19 | :param file_path: Grasp-Anything Dataset directory. 
20 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 21 | :param kwargs: kwargs for GraspDatasetBase 22 | """ 23 | super(VMRDDataset, self).__init__(**kwargs) 24 | 25 | self.grasp_files = glob.glob(os.path.join(file_path, 'Grasps', '*.txt')) 26 | self.rgb_files = glob.glob(os.path.join(file_path, 'JPEGImages', '*.jpg')) 27 | self.length = len(self.grasp_files) 28 | 29 | if self.length == 0: 30 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path)) 31 | 32 | if ds_rotate: 33 | self.grasp_files = self.grasp_files[int(self.length * ds_rotate):] + self.grasp_files[ 34 | :int(self.length * ds_rotate)] 35 | 36 | 37 | def _get_crop_attrs(self, idx): 38 | gtbbs = grasp.GraspRectangles.load_from_vmrd_file(self.grasp_files[idx]) 39 | center = gtbbs.center 40 | left = max(0, min(center[1] - self.output_size // 2, 1008 - self.output_size)) 41 | top = max(0, min(center[0] - self.output_size // 2, 756 - self.output_size)) 42 | return center, left, top 43 | 44 | def get_gtbb(self, idx, rot=0, zoom=1.0): 45 | # Jacquard try 46 | gtbbs = grasp.GraspRectangles.load_from_vmrd_file(self.grasp_files[idx]) 47 | center, left, top = self._get_crop_attrs(idx) 48 | gtbbs.rotate(rot, center) 49 | gtbbs.offset((-top, -left)) 50 | gtbbs.zoom(zoom, (self.output_size // 2, self.output_size // 2)) 51 | return gtbbs 52 | 53 | def get_depth(self, idx, rot=0, zoom=1.0): 54 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 55 | center, left, top = self._get_crop_attrs(idx) 56 | depth_img.rotate(rot, center) 57 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 58 | depth_img.normalise() 59 | depth_img.zoom(zoom) 60 | depth_img.resize((self.output_size, self.output_size)) 61 | return depth_img.img 62 | 63 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 64 | rgb_file = self.grasp_files[idx].replace("Grasps", "JPEGImages").replace("txt", "jpg") 65 | rgb_img = image.Image.from_file(rgb_file) 66 | 67 | # Cornell try 68 | center, left, top = self._get_crop_attrs(idx) 69 | rgb_img.rotate(rot, center) 70 | rgb_img.crop((top, left), (min(756, top + self.output_size), min(1008, left + self.output_size))) 71 | rgb_img.zoom(zoom) 72 | rgb_img.resize((self.output_size, self.output_size)) 73 | if normalise: 74 | rgb_img.normalise() 75 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 76 | 77 | return rgb_img.img 78 | -------------------------------------------------------------------------------- /utils/dataset_processing/evaluation.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | warnings.filterwarnings("ignore") 7 | 8 | from .grasp import GraspRectangles, detect_grasps 9 | 10 | 11 | def plot_output(fig, rgb_img, grasp_q_img, grasp_angle_img, depth_img=None, no_grasps=1, grasp_width_img=None): 12 | """ 13 | Plot the output of a network 14 | :param fig: Figure to plot the output 15 | :param rgb_img: RGB Image 16 | :param depth_img: Depth Image 17 | :param grasp_q_img: Q output of network 18 | :param grasp_angle_img: Angle output of network 19 | :param no_grasps: Maximum number of grasps to plot 20 | :param grasp_width_img: (optional) Width output of network 21 | :return: 22 | """ 23 | gs = detect_grasps(grasp_q_img, grasp_angle_img, width_img=grasp_width_img, no_grasps=no_grasps) 24 | 25 | plt.ion() 26 | plt.clf() 27 | ax = fig.add_subplot(2, 2, 1) 28 | 
ax.imshow(rgb_img) 29 | for g in gs: 30 | g.plot(ax) 31 | ax.set_title('RGB') 32 | ax.axis('off') 33 | 34 | if depth_img is not None: 35 | ax = fig.add_subplot(2, 2, 2) 36 | ax.imshow(depth_img, cmap='gray') 37 | for g in gs: 38 | g.plot(ax) 39 | ax.set_title('Depth') 40 | ax.axis('off') 41 | 42 | ax = fig.add_subplot(2, 2, 3) 43 | plot = ax.imshow(grasp_q_img, cmap='jet', vmin=0, vmax=1) 44 | ax.set_title('Q') 45 | ax.axis('off') 46 | plt.colorbar(plot) 47 | 48 | ax = fig.add_subplot(2, 2, 4) 49 | plot = ax.imshow(grasp_angle_img, cmap='hsv', vmin=-np.pi / 2, vmax=np.pi / 2) 50 | ax.set_title('Angle') 51 | ax.axis('off') 52 | plt.colorbar(plot) 53 | plt.pause(0.1) 54 | fig.canvas.draw() 55 | 56 | 57 | def calculate_iou_match(grasp_q, grasp_angle, ground_truth_bbs, no_grasps=1, grasp_width=None, threshold=0.25): 58 | """ 59 | Calculate grasp success using the IoU (Jacquard) metric (e.g. in https://arxiv.org/abs/1301.3592) 60 | A success is counted if the grasp rectangle has at least a 25% IoU with a ground-truth rectangle and is within 30 degrees of its orientation. 61 | :param grasp_q: Q outputs of network (Nx300x300x3) 62 | :param grasp_angle: Angle outputs of network 63 | :param ground_truth_bbs: Corresponding ground-truth BoundingBoxes 64 | :param no_grasps: Maximum number of grasps to consider per image. 65 | :param grasp_width: (optional) Width output from network 66 | :param threshold: Threshold for IoU matching. A grasp counts as a match when its IoU exceeds this threshold. 67 | :return: success 68 | """ 69 | 70 | if not isinstance(ground_truth_bbs, GraspRectangles): 71 | gt_bbs = GraspRectangles.load_from_array(ground_truth_bbs) 72 | else: 73 | gt_bbs = ground_truth_bbs 74 | gs = detect_grasps(grasp_q, grasp_angle, width_img=grasp_width, no_grasps=no_grasps) 75 | for g in gs: 76 | if g.max_iou(gt_bbs) > threshold: 77 | return True 78 | else:  # for-else: runs only if no detected grasp exceeded the threshold 79 | return False 80 | -------------------------------------------------------------------------------- /utils/dataset_processing/generate_cornell_depth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import numpy as np 6 | from imageio import imsave 7 | 8 | from utils.dataset_processing.image import DepthImage 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description='Generate depth images from Cornell PCD files.') 12 | parser.add_argument('path', type=str, help='Path to Cornell Grasping Dataset') 13 | args = parser.parse_args() 14 | 15 | pcds = glob.glob(os.path.join(args.path, '*', 'pcd*[0-9].txt')) 16 | pcds.sort() 17 | 18 | for pcd in pcds: 19 | di = DepthImage.from_pcd(pcd, (480, 640)) 20 | di.inpaint() 21 | 22 | of_name = pcd.replace('.txt', 'd.tiff') 23 | print(of_name) 24 | imsave(of_name, di.img.astype(np.float32)) 25 | -------------------------------------------------------------------------------- /utils/dataset_processing/image.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from imageio import imread 7 | import imageio 8 | from skimage.transform import rotate, resize 9 | 10 | warnings.filterwarnings("ignore", category=UserWarning) 11 | 12 | 13 | class Image: 14 | """ 15 | Wrapper around an image with some convenient functions. 
16 | """ 17 | 18 | def __init__(self, img): 19 | self.img = img 20 | 21 | def __getattr__(self, attr): 22 | # Pass along any other methods to the underlying ndarray 23 | return getattr(self.img, attr) 24 | 25 | @classmethod 26 | def from_file(cls, fname): 27 | return cls(imread(fname)) 28 | 29 | def copy(self): 30 | """ 31 | :return: Copy of self. 32 | """ 33 | return self.__class__(self.img.copy()) 34 | 35 | @classmethod 36 | def mask_out_image(cls, image, mask): 37 | # Apply the mask to the image 38 | masked_image = np.array(image) 39 | masked_image[:, :, 0] = masked_image[:, :, 0] * mask + 255 * (1-mask) 40 | masked_image[:, :, 1] = masked_image[:, :, 1] * mask + 255 * (1-mask) 41 | masked_image[:, :, 2] = masked_image[:, :, 2] * mask + 255 * (1-mask) 42 | 43 | return cls(imageio.core.util.Array(masked_image)) 44 | 45 | def crop(self, top_left, bottom_right, resize=None): 46 | """ 47 | Crop the image to a bounding box given by top left and bottom right pixels. 48 | :param top_left: tuple, top left pixel. 49 | :param bottom_right: tuple, bottom right pixel 50 | :param resize: If specified, resize the cropped image to this size 51 | """ 52 | self.img = self.img[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] 53 | if resize is not None: 54 | self.resize(resize) 55 | 56 | def cropped(self, *args, **kwargs): 57 | """ 58 | :return: Cropped copy of the image. 59 | """ 60 | i = self.copy() 61 | i.crop(*args, **kwargs) 62 | return i 63 | 64 | def normalise(self): 65 | """ 66 | Normalise the image by converting to float [0,1] and zero-centering 67 | """ 68 | self.img = self.img.astype(np.float32) / 255.0 69 | self.img -= self.img.mean() 70 | 71 | def resize(self, shape): 72 | """ 73 | Resize image to shape. 74 | :param shape: New shape. 75 | """ 76 | if self.img.shape == shape: 77 | return 78 | self.img = resize(self.img, shape, preserve_range=True).astype(self.img.dtype) 79 | 80 | def resized(self, *args, **kwargs): 81 | """ 82 | :return: Resized copy of the image. 83 | """ 84 | i = self.copy() 85 | i.resize(*args, **kwargs) 86 | return i 87 | 88 | def rotate(self, angle, center=None): 89 | """ 90 | Rotate the image. 91 | :param angle: Angle (in radians) to rotate by. 92 | :param center: Center pixel to rotate if specified, otherwise image center is used. 93 | """ 94 | if center is not None: 95 | center = (center[1], center[0]) 96 | self.img = rotate(self.img, angle / np.pi * 180, center=center, mode='symmetric', preserve_range=True).astype( 97 | self.img.dtype) 98 | 99 | def rotated(self, *args, **kwargs): 100 | """ 101 | :return: Rotated copy of image. 102 | """ 103 | i = self.copy() 104 | i.rotate(*args, **kwargs) 105 | return i 106 | 107 | def show(self, ax=None, **kwargs): 108 | """ 109 | Plot the image 110 | :param ax: Existing matplotlib axis (optional) 111 | :param kwargs: kwargs to imshow 112 | """ 113 | if ax: 114 | ax.imshow(self.img, **kwargs) 115 | else: 116 | plt.imshow(self.img, **kwargs) 117 | plt.show() 118 | 119 | def zoom(self, factor): 120 | """ 121 | "Zoom" the image by cropping and resizing. 122 | :param factor: Factor to zoom by. e.g. 0.5 will keep the center 50% of the image. 
123 | """ 124 | sr = int(self.img.shape[0] * (1 - factor)) // 2 125 | sc = int(self.img.shape[1] * (1 - factor)) // 2 126 | orig_shape = self.img.shape 127 | self.img = self.img[sr:self.img.shape[0] - sr, sc: self.img.shape[1] - sc].copy() 128 | self.img = resize(self.img, orig_shape, mode='symmetric', preserve_range=True).astype(self.img.dtype) 129 | 130 | def zoomed(self, *args, **kwargs): 131 | """ 132 | :return: Zoomed copy of the image. 133 | """ 134 | i = self.copy() 135 | i.zoom(*args, **kwargs) 136 | return i 137 | 138 | 139 | class DepthImage(Image): 140 | def __init__(self, img): 141 | super().__init__(img) 142 | 143 | @classmethod 144 | def from_pcd(cls, pcd_filename, shape, default_filler=0, index=None): 145 | """ 146 | Create a depth image from an unstructured PCD file. 147 | If index isn't specified, use euclidean distance, otherwise choose x/y/z=0/1/2 148 | """ 149 | img = np.zeros(shape) 150 | if default_filler != 0: 151 | img += default_filler 152 | 153 | with open(pcd_filename) as f: 154 | for l in f.readlines(): 155 | ls = l.split() 156 | 157 | if len(ls) != 5: 158 | # Not a point line in the file. 159 | continue 160 | try: 161 | # Not a number, carry on. 162 | float(ls[0]) 163 | except ValueError: 164 | continue 165 | 166 | i = int(ls[4]) 167 | r = i // shape[1] 168 | c = i % shape[1] 169 | 170 | if index is None: 171 | x = float(ls[0]) 172 | y = float(ls[1]) 173 | z = float(ls[2]) 174 | 175 | img[r, c] = np.sqrt(x ** 2 + y ** 2 + z ** 2) 176 | 177 | else: 178 | img[r, c] = float(ls[index]) 179 | 180 | return cls(img / 1000.0) 181 | 182 | @classmethod 183 | def from_tiff(cls, fname): 184 | return cls(imread(fname)) 185 | 186 | def inpaint(self, missing_value=0): 187 | """ 188 | Inpaint missing values in depth image. 189 | :param missing_value: Value to fill in teh depth image. 190 | """ 191 | # cv2 inpainting doesn't handle the border properly 192 | # https://stackoverflow.com/questions/25974033/inpainting-depth-map-still-a-black-image-border 193 | self.img = cv2.copyMakeBorder(self.img, 1, 1, 1, 1, cv2.BORDER_DEFAULT) 194 | mask = (self.img == missing_value).astype(np.uint8) 195 | 196 | # Scale to keep as float, but has to be in bounds -1:1 to keep opencv happy. 197 | scale = np.abs(self.img).max() 198 | self.img = self.img.astype(np.float32) / scale # Has to be float32, 64 not supported. 199 | self.img = cv2.inpaint(self.img, mask, 1, cv2.INPAINT_NS) 200 | 201 | # Back to original size and value range. 202 | self.img = self.img[1:-1, 1:-1] 203 | self.img = self.img * scale 204 | 205 | def gradients(self): 206 | """ 207 | Compute gradients of the depth image using Sobel filtesr. 208 | :return: Gradients in X direction, Gradients in Y diretion, Magnitude of XY gradients. 209 | """ 210 | grad_x = cv2.Sobel(self.img, cv2.CV_64F, 1, 0, borderType=cv2.BORDER_DEFAULT) 211 | grad_y = cv2.Sobel(self.img, cv2.CV_64F, 0, 1, borderType=cv2.BORDER_DEFAULT) 212 | grad = np.sqrt(grad_x ** 2 + grad_y ** 2) 213 | 214 | return DepthImage(grad_x), DepthImage(grad_y), DepthImage(grad) 215 | 216 | def normalise(self): 217 | """ 218 | Normalise by subtracting the mean and clippint [-1, 1] 219 | """ 220 | self.img = np.clip((self.img - self.img.mean()), -1, 1) 221 | 222 | 223 | class WidthImage(Image): 224 | """ 225 | A width image is one that describes the desired gripper width at each pixel. 226 | """ 227 | 228 | def zoom(self, factor): 229 | """ 230 | "Zoom" the image by cropping and resizing. Also scales the width accordingly. 231 | :param factor: Factor to zoom by. e.g. 
0.5 will keep the center 50% of the image. 232 | """ 233 | super().zoom(factor) 234 | self.img = self.img / factor 235 | 236 | def normalise(self): 237 | """ 238 | Normalise by mapping [0, 150] -> [0, 1] 239 | """ 240 | self.img = np.clip(self.img, 0, 150.0) / 150.0 241 | -------------------------------------------------------------------------------- /utils/dataset_processing/mask.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | warnings.filterwarnings("ignore", category=UserWarning) 6 | 7 | class Mask: 8 | """ 9 | Wrapper around a mask with some convenient functions. 10 | """ 11 | 12 | def __init__(self, mask): 13 | self.mask = mask 14 | 15 | def __getattr__(self, attr): 16 | # Pass along any other methods to the underlying ndarray 17 | return getattr(self.mask, attr) 18 | 19 | @classmethod 20 | def from_file(cls, fname): 21 | return np.load(fname, mmap_mode='r')  # returns the raw memory-mapped array rather than a Mask instance -------------------------------------------------------------------------------- /utils/get_cornell.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | END=9 4 | for ((i=1;i<=END;i++)); do 5 | wget --tries=5 http://pr.cs.cornell.edu/grasping/rect_data/temp/data0$i.tar.gz 6 | tar xvzf data0$i.tar.gz 7 | rm data0$i.tar.gz 8 | done 9 | 10 | wget --tries=5 http://pr.cs.cornell.edu/grasping/rect_data/temp/data10.tar.gz 11 | tar xvzf data10.tar.gz 12 | rm data10.tar.gz 13 | -------------------------------------------------------------------------------- /utils/get_jacquard.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | END=11 4 | for ((i=0;i<=END;i++)); do 5 | wget --tries=5 https://jacquard.liris.cnrs.fr/data/Download/Jacquard_Dataset_$i.zip 6 | unzip Jacquard_Dataset_$i.zip 7 | rm Jacquard_Dataset_$i.zip 8 | done -------------------------------------------------------------------------------- /utils/timeit.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TimeIt: 5 | """ 6 | Print nested timing information. 7 | """ 8 | print_output = True 9 | last_parent = None 10 | level = -1 11 | 12 | def __init__(self, s): 13 | self.s = s 14 | self.t0 = None 15 | self.t1 = None 16 | self.outputs = [] 17 | self.parent = None 18 | 19 | def __enter__(self): 20 | self.t0 = time.time() 21 | self.parent = TimeIt.last_parent 22 | TimeIt.last_parent = self 23 | TimeIt.level += 1 24 | 25 | def __exit__(self, t, value, traceback): 26 | self.t1 = time.time() 27 | st = '%s%s: %0.1fms' % (' ' * TimeIt.level, self.s, (self.t1 - self.t0) * 1000) 28 | TimeIt.level -= 1 29 | 30 | if self.parent: 31 | self.parent.outputs.append(st) 32 | self.parent.outputs += self.outputs 33 | else: 34 | if TimeIt.print_output: 35 | print(st) 36 | for o in self.outputs: 37 | print(o) 38 | self.outputs = [] 39 | 40 | TimeIt.last_parent = self.parent 41 | -------------------------------------------------------------------------------- /utils/visualisation/gridshow.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def gridshow(name, imgs, scales, cmaps, width, border=10): 6 | """ 7 | Display images in a grid. 
8 | :param name: cv2 Window Name to update 9 | :param imgs: List of Images (np.ndarrays) 10 | :param scales: The min/max scale of images to properly scale the colormaps 11 | :param cmaps: List of cv2 Colormaps to apply 12 | :param width: Number of images in a row 13 | :param border: Border (pixels) between images. 14 | """ 15 | imgrows = [] 16 | imgcols = [] 17 | 18 | maxh = 0 19 | for i, (img, cmap, scale) in enumerate(zip(imgs, cmaps, scales)): 20 | 21 | # Scale images into range 0-1 22 | if scale is not None: 23 | img = (np.clip(img, scale[0], scale[1]) - scale[0]) / (scale[1] - scale[0]) 24 | elif np.issubdtype(img.dtype, np.floating):  # np.float alias was removed in newer NumPy releases 25 | img = (img - img.min()) / (img.max() - img.min() + 1e-6) 26 | 27 | # Apply colormap (if applicable) and convert to uint8 28 | if cmap is not None: 29 | try: 30 | imgc = cv2.applyColorMap((img * 255).astype(np.uint8), cmap) 31 | except: 32 | imgc = (img * 255.0).astype(np.uint8) 33 | else: 34 | imgc = img 35 | 36 | if imgc.shape[0] == 3: 37 | imgc = imgc.transpose((1, 2, 0)) 38 | elif imgc.shape[0] == 4: 39 | imgc = imgc[1:, :, :].transpose((1, 2, 0)) 40 | 41 | # Arrange row of images. 42 | maxh = max(maxh, imgc.shape[0]) 43 | imgcols.append(imgc) 44 | if i > 0 and i % width == (width - 1): 45 | imgrows.append(np.hstack( 46 | [np.pad(c, ((0, maxh - c.shape[0]), (border // 2, border // 2), (0, 0)), mode='constant') for c in 47 | imgcols])) 48 | imgcols = [] 49 | maxh = 0 50 | 51 | # Unfinished row 52 | if imgcols: 53 | imgrows.append(np.hstack( 54 | [np.pad(c, ((0, maxh - c.shape[0]), (border // 2, border // 2), (0, 0)), mode='constant') for c in 55 | imgcols])) 56 | 57 | maxw = max([c.shape[1] for c in imgrows]) 58 | 59 | cv2.imshow(name, np.vstack( 60 | [np.pad(r, ((border // 2, border // 2), (0, maxw - r.shape[1]), (0, 0)), mode='constant') for r in imgrows])) 61 | -------------------------------------------------------------------------------- /utils/visualisation/plot.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from datetime import datetime 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | from utils.dataset_processing.grasp import detect_grasps 8 | 9 | warnings.filterwarnings("ignore") 10 | 11 | 12 | def plot_results( 13 | fig, 14 | rgb_img, 15 | grasp_q_img, 16 | grasp_angle_img, 17 | depth_img=None, 18 | no_grasps=1, 19 | grasp_width_img=None 20 | ): 21 | """ 22 | Plot the output of a network 23 | :param fig: Figure to plot the output 24 | :param rgb_img: RGB Image 25 | :param depth_img: Depth Image 26 | :param grasp_q_img: Q output of network 27 | :param grasp_angle_img: Angle output of network 28 | :param no_grasps: Maximum number of grasps to plot 29 | :param grasp_width_img: (optional) Width output of network 30 | :return: 31 | """ 32 | gs = detect_grasps(grasp_q_img, grasp_angle_img, width_img=grasp_width_img, no_grasps=no_grasps) 33 | 34 | plt.ion() 35 | plt.clf() 36 | ax = fig.add_subplot(2, 3, 1) 37 | ax.imshow(rgb_img) 38 | ax.set_title('RGB') 39 | ax.axis('off') 40 | 41 | if depth_img is not None: 42 | ax = fig.add_subplot(2, 3, 2) 43 | ax.imshow(depth_img, cmap='gray') 44 | ax.set_title('Depth') 45 | ax.axis('off') 46 | 47 | ax = fig.add_subplot(2, 3, 3) 48 | ax.imshow(rgb_img) 49 | for g in gs: 50 | g.plot(ax) 51 | ax.set_title('Grasp') 52 | ax.axis('off') 53 | 54 | ax = fig.add_subplot(2, 3, 4) 55 | plot = ax.imshow(grasp_q_img, cmap='jet', vmin=0, vmax=1) 56 | ax.set_title('Q') 57 | ax.axis('off') 58 | plt.colorbar(plot) 59 | 60 | ax = fig.add_subplot(2, 3, 5) 
61 | plot = ax.imshow(grasp_angle_img, cmap='hsv', vmin=-np.pi / 2, vmax=np.pi / 2) 62 | ax.set_title('Angle') 63 | ax.axis('off') 64 | plt.colorbar(plot) 65 | 66 | ax = fig.add_subplot(2, 3, 6) 67 | plot = ax.imshow(grasp_width_img, cmap='jet', vmin=0, vmax=100) 68 | ax.set_title('Width') 69 | ax.axis('off') 70 | plt.colorbar(plot) 71 | 72 | plt.pause(0.1) 73 | fig.canvas.draw() 74 | 75 | 76 | def plot_grasp( 77 | fig, 78 | grasps=None, 79 | save=False, 80 | rgb_img=None, 81 | grasp_q_img=None, 82 | grasp_angle_img=None, 83 | no_grasps=1, 84 | grasp_width_img=None 85 | ): 86 | """ 87 | Plot the output grasp of a network 88 | :param fig: Figure to plot the output 89 | :param grasps: grasp pose(s) 90 | :param save: Bool for saving the plot 91 | :param rgb_img: RGB Image 92 | :param grasp_q_img: Q output of network 93 | :param grasp_angle_img: Angle output of network 94 | :param no_grasps: Maximum number of grasps to plot 95 | :param grasp_width_img: (optional) Width output of network 96 | :return: 97 | """ 98 | if grasps is None: 99 | grasps = detect_grasps(grasp_q_img, grasp_angle_img, width_img=grasp_width_img, no_grasps=no_grasps) 100 | 101 | plt.ion() 102 | plt.clf() 103 | 104 | ax = plt.subplot(111) 105 | ax.imshow(rgb_img) 106 | for g in grasps: 107 | g.plot(ax) 108 | ax.set_title('Grasp') 109 | ax.axis('off') 110 | 111 | plt.pause(0.1) 112 | fig.canvas.draw() 113 | 114 | if save: 115 | time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') 116 | fig.savefig('results/{}.png'.format(time)) 117 | 118 | 119 | def save_results(rgb_img, grasp_q_img, grasp_angle_img, depth_img=None, no_grasps=1, grasp_width_img=None): 120 | """ 121 | Plot the output of a network 122 | :param rgb_img: RGB Image 123 | :param depth_img: Depth Image 124 | :param grasp_q_img: Q output of network 125 | :param grasp_angle_img: Angle output of network 126 | :param no_grasps: Maximum number of grasps to plot 127 | :param grasp_width_img: (optional) Width output of network 128 | :return: 129 | """ 130 | gs = detect_grasps(grasp_q_img, grasp_angle_img, width_img=grasp_width_img, no_grasps=no_grasps) 131 | 132 | fig = plt.figure(figsize=(10, 10)) 133 | plt.ion() 134 | plt.clf() 135 | ax = plt.subplot(111) 136 | ax.imshow(rgb_img) 137 | ax.set_title('RGB') 138 | ax.axis('off') 139 | fig.savefig('results/rgb.png') 140 | 141 | if depth_img.any(): 142 | fig = plt.figure(figsize=(10, 10)) 143 | plt.ion() 144 | plt.clf() 145 | ax = plt.subplot(111) 146 | ax.imshow(depth_img, cmap='gray') 147 | for g in gs: 148 | g.plot(ax) 149 | ax.set_title('Depth') 150 | ax.axis('off') 151 | fig.savefig('results/depth.png') 152 | 153 | fig = plt.figure(figsize=(10, 10)) 154 | plt.ion() 155 | plt.clf() 156 | ax = plt.subplot(111) 157 | ax.imshow(rgb_img) 158 | for g in gs: 159 | g.plot(ax) 160 | ax.set_title('Grasp') 161 | ax.axis('off') 162 | fig.savefig('results/grasp.png') 163 | 164 | fig = plt.figure(figsize=(10, 10)) 165 | plt.ion() 166 | plt.clf() 167 | ax = plt.subplot(111) 168 | plot = ax.imshow(grasp_q_img, cmap='jet', vmin=0, vmax=1) 169 | ax.set_title('Q') 170 | ax.axis('off') 171 | plt.colorbar(plot) 172 | fig.savefig('results/quality.png') 173 | 174 | fig = plt.figure(figsize=(10, 10)) 175 | plt.ion() 176 | plt.clf() 177 | ax = plt.subplot(111) 178 | plot = ax.imshow(grasp_angle_img, cmap='hsv', vmin=-np.pi / 2, vmax=np.pi / 2) 179 | ax.set_title('Angle') 180 | ax.axis('off') 181 | plt.colorbar(plot) 182 | fig.savefig('results/angle.png') 183 | 184 | fig = plt.figure(figsize=(10, 10)) 185 | plt.ion() 186 | plt.clf() 187 | ax = 
plt.subplot(111) 188 | plot = ax.imshow(grasp_width_img, cmap='jet', vmin=0, vmax=100) 189 | ax.set_title('Width') 190 | ax.axis('off') 191 | plt.colorbar(plot) 192 | fig.savefig('results/width.png') 193 | 194 | fig.canvas.draw() 195 | plt.close(fig) 196 | -------------------------------------------------------------------------------- /weights/model_cornell: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/weights/model_cornell -------------------------------------------------------------------------------- /weights/model_grasp_anything: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/weights/model_grasp_anything -------------------------------------------------------------------------------- /weights/model_jacquard: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/weights/model_jacquard -------------------------------------------------------------------------------- /weights/model_ocid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/weights/model_ocid -------------------------------------------------------------------------------- /weights/model_vmrd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fsoft-AIC/Grasp-Anything/d7755f43c5518bd6590b25021054f862e65bddd5/weights/model_vmrd --------------------------------------------------------------------------------
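
The dataset wrappers above all implement the `GraspDatasetBase` interface, so they can be driven directly by a PyTorch `DataLoader`. The following sketch is not part of the repository; it only illustrates the expected call pattern, and the dataset path is a placeholder for wherever Grasp-Anything has been downloaded.

```python
# Illustrative only: minimal loop over GraspAnythingDataset (path is a placeholder).
import torch

from utils.data.grasp_anything_data import GraspAnythingDataset

dataset = GraspAnythingDataset(
    'data/grasp-anything',   # placeholder: directory containing image/, grasp_label_positive/, scene_description/
    output_size=224,
    include_depth=False,     # Grasp-Anything is RGB-only, so depth is disabled
    include_rgb=True,
    random_rotate=True,
    random_zoom=True,
    seen=True,               # filters grasp files against split/grasp-anything/seen.obj
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

for x, (pos, cos, sin, width), idx, rot, zoom_factor in loader:
    # x: (B, 3, 224, 224) RGB tensor; the targets are the position map, the
    # cos(2*theta)/sin(2*theta) angle maps, and the normalised width map drawn in __getitem__.
    print(x.shape, pos.shape, cos.shape, sin.shape, width.shape)
    break
```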
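
Likewise, `calculate_iou_match` in `utils/dataset_processing/evaluation.py` expects post-processed quality, angle and width maps together with the ground-truth `GraspRectangles` returned by a wrapper's `get_gtbb`. A hedged sketch follows, with random arrays standing in for real network outputs purely so the call signature is clear; the dataset path is again a placeholder.

```python
# Illustrative only: random maps stand in for post-processed network outputs.
import numpy as np

from utils.data.grasp_anything_data import GraspAnythingDataset
from utils.dataset_processing.evaluation import calculate_iou_match

dataset = GraspAnythingDataset('data/grasp-anything', include_depth=False, include_rgb=True, seen=True)

q_img = np.random.rand(224, 224)                                  # grasp quality map
ang_img = np.random.uniform(-np.pi / 2, np.pi / 2, (224, 224))    # grasp angle map
width_img = np.random.uniform(0.0, 100.0, (224, 224))             # grasp width map

success = calculate_iou_match(
    q_img,
    ang_img,
    dataset.get_gtbb(0),     # ground-truth GraspRectangles for the first sample
    no_grasps=1,
    grasp_width=width_img,
    threshold=0.25,          # the standard 25% IoU criterion
)
print('IoU match:', success)
```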