├── README.md ├── img ├── grasp-transformer.png └── test ├── main.py ├── main_grasp_1b.py ├── main_k_fold.py ├── models ├── __init__.py ├── common.py └── swin.py ├── requirements.txt ├── traning.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ └── __init__.cpython-38.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── cornell_data.cpython-36.pyc │ │ ├── cornell_data.cpython-37.pyc │ │ ├── grasp_data.cpython-36.pyc │ │ ├── grasp_data.cpython-37.pyc │ │ ├── jacquard_data.cpython-36.pyc │ │ └── multi_object.cpython-36.pyc │ ├── cornell_data.py │ ├── gn1b_data.py │ ├── grasp_data.py │ ├── jacquard_data.py │ ├── multi_object.py │ └── multigrasp_object.py ├── dataset_processing │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── evaluation.cpython-36.pyc │ │ ├── evaluation.cpython-37.pyc │ │ ├── grasp.cpython-36.pyc │ │ ├── grasp.cpython-37.pyc │ │ ├── image.cpython-36.pyc │ │ └── image.cpython-37.pyc │ ├── evaluation.py │ ├── generate_cornell_depth.py │ ├── grasp.py │ ├── image.py │ └── multigrasp_object.py ├── timeit.py └── visualisation │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── gridshow.cpython-36.pyc │ ├── gridshow.cpython-37.pyc │ └── gridshow.cpython-38.pyc │ └── gridshow.py ├── visualise_grasp_rectangle.py └── visulaize_heatmaps.py /README.md: -------------------------------------------------------------------------------- 1 | ## When Transformer Meets Robotic Grasping: Exploits Context for Efficient Grasping Detection 2 | 3 | PyTorch implementation of paper "When Transformer Meets Robotic Grasping: 4 | Exploits Context for Efficient Grasping Detection" 5 | 6 | ## Visualization of the architecture 7 | 8 |
9 | 
10 | This code was developed with Python 3.6 on Ubuntu 16.04. The Python requirements can be installed with:
11 | 
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 | 
16 | ## Datasets
17 | 
18 | Currently, the [Cornell Grasping Dataset](http://pr.cs.cornell.edu/grasping/rect_data/data.php),
19 | the [Jacquard Dataset](https://jacquard.liris.cnrs.fr/), and the [GraspNet 1Billion dataset](https://graspnet.net/datasets.html) are supported.
20 | 
21 | ### Cornell Grasping Dataset
22 | 1. Download and extract the [Cornell Grasping Dataset](http://pr.cs.cornell.edu/grasping/rect_data/data.php).
23 | 
24 | ### Jacquard Dataset
25 | 
26 | 1. Download and extract the [Jacquard Dataset](https://jacquard.liris.cnrs.fr/).
27 | 
28 | ### GraspNet 1Billion Dataset
29 | 
30 | 1. The dataset can be downloaded [here](https://graspnet.net/datasets.html).
31 | 2. Install the graspnetAPI following the instructions [here](https://graspnetapi.readthedocs.io/en/latest/install.html#install-api):
32 | 
33 | ```bash
34 | pip install graspnetAPI
35 | ```
36 | 3. We use the data setup from [Modified-GGCNN](https://github.com/ryanreadbooks/Modified-GGCNN).
37 | 
38 | 
39 | ## Training
40 | 
41 | Training is done by the `main.py` script.
42 | 
43 | Some basic examples:
44 | 
45 | ```bash
46 | # Train on the Cornell Dataset
47 | python main.py --dataset cornell
48 | 
49 | # k-fold training
50 | python main_k_fold.py --dataset cornell
51 | 
52 | # Train on GraspNet 1Billion
53 | python main_grasp_1b.py
54 | ```
55 | 
56 | Trained models are saved in `output/models` by default, with the validation score appended to the filename.
57 | 
58 | ## Visualization
59 | Some basic examples:
60 | ```bash
61 | # visualize grasp rectangles
62 | python visualise_grasp_rectangle.py --network <path to trained network>
63 | 
64 | # visualize heatmaps
65 | python visulaize_heatmaps.py --network <path to trained network>
66 | 
67 | ```
68 | 
69 | 
70 | 
71 | ## Running on a Robot
72 | 
73 | Our ROS implementation for running the grasping system is available at [https://github.com/USTC-ICR/SimGrasp/tree/main/SimGrasp](https://github.com/USTC-ICR/SimGrasp/tree/main/SimGrasp).
74 | 
75 | The original implementation for running experiments on a Kinova Mico arm can be found in the repository [https://github.com/dougsm/ggcnn_kinova_grasping](https://github.com/dougsm/ggcnn_kinova_grasping).
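## Loading a trained model in Python (sketch)

For quick experiments outside the provided scripts, a trained model can also be loaded directly in Python. The snippet below is only a minimal sketch: the checkpoint path is a placeholder, the input is a dummy tensor standing in for a properly preprocessed 224x224 image, and it assumes the network's forward pass returns the four prediction maps (position quality, cos, sin, width) that the training code accesses via `pred['pos']`, `pred['cos']`, `pred['sin']`, and `pred['width']`.

```python
import numpy as np
import torch

from models.common import post_process_output

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Models are saved with torch.save(net, ...), so torch.load restores the full module.
net = torch.load("output/models/<run>/epoch_XX_iou_Y.YY", map_location=device)
net.eval()

# Dummy input standing in for a preprocessed 224x224 image (B, C, H, W);
# the channel count must match the --use-depth / --use-rgb settings used for training.
x = torch.zeros((1, 3, 224, 224), device=device)

with torch.no_grad():
    pos, cos, sin, width = net(x)  # assumed output order of the four heatmaps
    q_img, ang_img, width_img = post_process_output(pos, cos, sin, width)

# The highest-quality pixel gives a grasp centre; angle and width are read at the same location.
centre = np.unravel_index(np.argmax(q_img), q_img.shape)
print("centre:", centre, "angle (rad):", ang_img[centre], "width (px):", width_img[centre])
```

Adapt the preprocessing (cropping, normalisation, depth handling) to match the dataset classes in `utils/data` before relying on the predictions.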
76 | 77 | ## Acknowledgement 78 | Code heavily inspired and modified from https://github.com/dougsm/ggcnn 79 | 80 | If you find this helpful, please cite 81 | ```bash 82 | @ARTICLE{9810182, 83 | author={Wang, Shaochen and Zhou, Zhangli and Kan, Zhen}, 84 | journal={IEEE Robotics and Automation Letters}, 85 | title={When Transformer Meets Robotic Grasping: Exploits Context for Efficient Grasp Detection}, 86 | year={2022}, 87 | volume={}, 88 | number={}, 89 | pages={1-8}, 90 | doi={10.1109/LRA.2022.3187261}} 91 | 92 | ``` 93 | -------------------------------------------------------------------------------- /img/grasp-transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/img/grasp-transformer.png -------------------------------------------------------------------------------- /img/test: -------------------------------------------------------------------------------- 1 | test 2 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import sys 4 | import argparse 5 | import logging 6 | import torch 7 | import torch.utils.data 8 | import torch.optim as optim 9 | from torchsummary import summary 10 | from traning import train, validate 11 | from utils.data import get_dataset 12 | from models.swin import SwinTransformerSys 13 | logging.basicConfig(level=logging.INFO) 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='TF-Grasp') 16 | 17 | # Network 18 | # Dataset & Data & Training 19 | parser.add_argument('--dataset', type=str,default="jacquard", help='Dataset Name ("cornell" or "jacquard")') 20 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/cornell" ,help='Path to dataset') 21 | parser.add_argument('--use-depth', type=int, default=0, help='Use Depth image for training (1/0)') 22 | parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for training (0/1)') 23 | parser.add_argument('--split', type=float, default=0.9, help='Fraction of data for training (remainder is validation)') 24 | parser.add_argument('--ds-rotate', type=float, default=0.0, 25 | help='Shift the start point of the dataset to use a different test/train split for cross validation.') 26 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers') 27 | 28 | parser.add_argument('--batch-size', type=int, default=32, help='Batch size') 29 | parser.add_argument('--vis', type=bool, default=False, help='vis') 30 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs') 31 | parser.add_argument('--batches-per-epoch', type=int, default=200, help='Batches per Epoch') 32 | parser.add_argument('--val-batches', type=int, default=32, help='Validation Batches') 33 | # Logging etc. 
34 | parser.add_argument('--description', type=str, default='', help='Training description') 35 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory') 36 | 37 | args = parser.parse_args() 38 | return args 39 | def run(): 40 | args = parse_args() 41 | dt = datetime.datetime.now().strftime('%y%m%d_%H%M') 42 | net_desc = '{}_{}'.format(dt, '_'.join(args.description.split())) 43 | 44 | save_folder = os.path.join(args.outdir, net_desc) 45 | if not os.path.exists(save_folder): 46 | os.makedirs(save_folder) 47 | # tb = tensorboardX.SummaryWriter(os.path.join(args.logdir, net_desc)) 48 | 49 | # Load Dataset 50 | logging.info('Loading {} Dataset...'.format(args.dataset.title())) 51 | Dataset = get_dataset(args.dataset) 52 | 53 | train_dataset = Dataset(args.dataset_path, start=0.0, end=args.split, ds_rotate=args.ds_rotate, 54 | random_rotate=True, random_zoom=False, 55 | include_depth=args.use_depth, include_rgb=args.use_rgb) 56 | train_data = torch.utils.data.DataLoader( 57 | train_dataset, 58 | batch_size=args.batch_size, 59 | shuffle=True, 60 | num_workers=args.num_workers 61 | ) 62 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate, 63 | random_rotate=True, random_zoom=False, 64 | include_depth=args.use_depth, include_rgb=args.use_rgb) 65 | val_data = torch.utils.data.DataLoader( 66 | val_dataset, 67 | batch_size=1, 68 | shuffle=False, 69 | num_workers=args.num_workers 70 | ) 71 | logging.info('Done') 72 | 73 | logging.info('Loading Network...') 74 | input_channels = 1*args.use_depth + 3*args.use_rgb 75 | net = SwinTransformerSys(in_chans=input_channels, embed_dim=48, num_heads=[1, 2, 4, 8]) 76 | device = torch.device("cuda:0") 77 | net = net.to(device) 78 | optimizer = optim.AdamW(net.parameters(), lr=1e-4) 79 | listy = [x * 2 for x in range(1,1000,5)] 80 | schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=listy,gamma=0.5) 81 | logging.info('Done') 82 | summary(net, (input_channels, 224, 224)) 83 | f = open(os.path.join(save_folder, 'net.txt'), 'w') 84 | sys.stdout = f 85 | summary(net, (input_channels, 224, 224)) 86 | sys.stdout = sys.__stdout__ 87 | f.close() 88 | best_iou = 0.0 89 | for epoch in range(args.epochs): 90 | logging.info('Beginning Epoch {:02d}'.format(epoch)) 91 | print("current lr:",optimizer.state_dict()['param_groups'][0]['lr']) 92 | # for i in range(5000): 93 | train_results = train(epoch, net, device, train_data, optimizer, args.batches_per_epoch, vis=args.vis) 94 | 95 | # schedule.step() 96 | # Run Validation 97 | logging.info('Validating...') 98 | test_results = validate(net, device, val_data, args.val_batches) 99 | logging.info('%d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'], 100 | test_results['correct']/(test_results['correct']+test_results['failed']))) 101 | 102 | iou = test_results['correct'] / (test_results['correct'] + test_results['failed']) 103 | if epoch%1==0 or iou>best_iou: 104 | torch.save(net, os.path.join(save_folder, 'epoch_%02d_iou_%0.2f' % (epoch, iou))) 105 | torch.save(net.state_dict(), os.path.join(save_folder, 'epoch_%02d_iou_%0.2f_statedict.pt' % (epoch, iou))) 106 | best_iou = iou 107 | 108 | 109 | if __name__ == '__main__': 110 | run() 111 | -------------------------------------------------------------------------------- /main_grasp_1b.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from doctest import FAIL_FAST 3 | import os 4 | import 
sys
5 | import argparse
6 | import logging
7 | from tqdm import tqdm
8 | import cv2
9 | 
10 | import torch
11 | import torch.utils.data
12 | import torch.optim as optim
13 | 
14 | from torchsummary import summary
15 | 
16 | 
17 | from traning import train, validate
18 | from utils.visualisation.gridshow import gridshow
19 | 
20 | from utils.dataset_processing import evaluation
21 | from utils.data import get_dataset
22 | from models.common import post_process_output
23 | from models.swin import SwinTransformerSys
24 | logging.basicConfig(level=logging.INFO)
25 | 
26 | def parse_args():
27 |     parser = argparse.ArgumentParser(description='TF-Grasp')
28 | 
29 |     # Network
30 |     # Dataset & Data & Training
31 |     parser.add_argument('--dataset', type=str, default="graspnet1b", help='Dataset Name ("cornell" or "jacquard")')
   |     # Path to the dataset root directory (read by run() when constructing the Dataset objects).
   |     parser.add_argument('--dataset-path', type=str, default="", help='Path to dataset')
32 |     parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for training (1/0)')
33 |     parser.add_argument('--use-rgb', type=int, default=0, help='Use RGB image for training (0/1)')
34 |     parser.add_argument('--split', type=float, default=1., help='Fraction of data for training (remainder is validation)')
35 |     parser.add_argument('--ds-rotate', type=float, default=0.0,
36 |                         help='Shift the start point of the dataset to use a different test/train split for cross validation.')
37 |     parser.add_argument('--num-workers', type=int, default=32, help='Dataset workers')
38 | 
39 |     parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
40 |     parser.add_argument('--epochs', type=int, default=500, help='Training epochs')
41 |     parser.add_argument('--batches-per-epoch', type=int, default=500, help='Batches per Epoch')
42 |     parser.add_argument('--val-batches', type=int, default=100, help='Validation Batches')
43 |     parser.add_argument('--output-size', type=int, default=224,
44 |                         help='the output size of the network, determining the cropped size of dataset images')
45 | 
46 |     parser.add_argument('--camera', type=str, default='realsense',
47 |                         help='Which camera\'s data to use, only effective when using graspnet1b dataset')
48 |     parser.add_argument('--scale', type=int, default=2,
49 |                         help='the scale factor for the original images, only effective when using graspnet1b dataset')
50 |     # Logging etc.
51 |     parser.add_argument('--description', type=str, default='', help='Training description')
52 |     parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory')
53 |     parser.add_argument('--logdir', type=str, default='tensorboard/', help='Log directory')
54 |     parser.add_argument('--vis', default=False, help='Visualise the training process')
55 | 
56 |     args = parser.parse_args()
57 |     return args
58 | 
59 | 
60 | def validate(net, device, val_data, batches_per_epoch, no_grasps=1):
61 |     """
62 |     Run validation.
63 | :param net: Network 64 | :param device: Torch device 65 | :param val_data: Validation Dataset 66 | :param batches_per_epoch: Number of batches to run 67 | :return: Successes, Failures and Losses 68 | """ 69 | net.eval() 70 | 71 | results = { 72 | 'correct': 0, 73 | 'failed': 0, 74 | 'loss': 0, 75 | 'losses': { 76 | 77 | } 78 | } 79 | 80 | ld = len(val_data) 81 | 82 | with torch.no_grad(): 83 | batch_idx = 0 84 | while batch_idx < batches_per_epoch: 85 | for x, y, didx, rot, zoom_factor in tqdm(val_data): 86 | batch_idx += 1 87 | if batches_per_epoch is not None and batch_idx >= batches_per_epoch: 88 | break 89 | 90 | xc = x.to(device) 91 | yc = [yy.to(device) for yy in y] 92 | lossd = net.compute_loss(xc, yc) 93 | 94 | loss = lossd['loss'] 95 | 96 | results['loss'] += loss.item()/ld 97 | for ln, l in lossd['losses'].items(): 98 | if ln not in results['losses']: 99 | results['losses'][ln] = 0 100 | results['losses'][ln] += l.item()/ld 101 | 102 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'], 103 | lossd['pred']['sin'], lossd['pred']['width']) 104 | # print("inde:",didx) 105 | 106 | s = evaluation.calculate_iou_match(q_out, ang_out, 107 | val_data.dataset.get_gtbb(didx, rot, zoom_factor), 108 | no_grasps=no_grasps, 109 | grasp_width=w_out, 110 | ) 111 | 112 | if s: 113 | results['correct'] += 1 114 | else: 115 | results['failed'] += 1 116 | 117 | 118 | return results 119 | 120 | 121 | def run(): 122 | args = parse_args() 123 | 124 | dt = datetime.datetime.now().strftime('%y%m%d_%H%M') 125 | net_desc = '{}_{}'.format(dt, '_'.join(args.description.split())) 126 | 127 | save_folder = os.path.join(args.outdir, net_desc+"_d="+str(args.use_depth+args.use_rgb)+"_scale=3") 128 | if not os.path.exists(save_folder): 129 | os.makedirs(save_folder) 130 | 131 | # Load Dataset 132 | logging.info('Loading {} Dataset...'.format(args.dataset.title())) 133 | Dataset = get_dataset(args.dataset) 134 | 135 | if args.dataset == 'graspnet1b': 136 | print("1 billion") 137 | train_dataset = Dataset( args.dataset_path, ds_rotate=args.ds_rotate, 138 | output_size=args.output_size, 139 | random_rotate=False, random_zoom=False, 140 | include_depth=args.use_depth, 141 | include_rgb=args.use_rgb, 142 | camera=args.camera, 143 | scale=args.scale, 144 | split='train') 145 | else: 146 | train_dataset = Dataset(file_path=args.dataset_path, start=0.0, end=args.split, ds_rotate=args.ds_rotate, 147 | output_size=args.output_size, 148 | random_rotate=True, random_zoom=True, 149 | include_depth=args.use_depth, 150 | include_rgb=args.use_rgb) 151 | 152 | train_data = torch.utils.data.DataLoader( 153 | train_dataset, 154 | batch_size=args.batch_size, 155 | shuffle=True, 156 | num_workers=args.num_workers, 157 | pin_memory=False 158 | ) 159 | train_validate_data = torch.utils.data.DataLoader( 160 | train_dataset, 161 | batch_size=1, 162 | shuffle=True, 163 | num_workers=args.num_workers//4, 164 | pin_memory=False 165 | ) 166 | if args.dataset == 'graspnet1b': 167 | val_dataset = Dataset(args.dataset_path, ds_rotate=args.ds_rotate, 168 | output_size=args.output_size, 169 | random_rotate=False, random_zoom=False, 170 | include_depth=args.use_depth, 171 | include_rgb=args.use_rgb, 172 | camera=args.camera, 173 | scale=args.scale, 174 | split='test_seen') 175 | val_dataset_1 = Dataset(args.dataset_path, ds_rotate=False, 176 | output_size=args.output_size, 177 | random_rotate=False, random_zoom=False, 178 | include_depth=args.use_depth, 179 | include_rgb=args.use_rgb, 180 | 
camera=args.camera, 181 | scale=args.scale, 182 | split='test_similar') 183 | val_dataset_2 = Dataset(args.dataset_path, ds_rotate=False, 184 | output_size=args.output_size, 185 | random_rotate=False, random_zoom=False, 186 | include_depth=args.use_depth, 187 | include_rgb=args.use_rgb, 188 | camera=args.camera, 189 | scale=args.scale, 190 | split='test_novel') 191 | 192 | else: 193 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate, 194 | output_size=args.output_size, 195 | random_rotate=True, random_zoom=True, 196 | include_depth=args.use_depth, 197 | include_rgb=args.use_rgb) 198 | val_data = torch.utils.data.DataLoader( 199 | val_dataset, 200 | batch_size=1, # do not modify 201 | shuffle=True, 202 | num_workers=args.num_workers // 4, 203 | pin_memory=False, 204 | ) 205 | val_data_1 = torch.utils.data.DataLoader( 206 | val_dataset_1, 207 | batch_size=1, # do not modify 208 | shuffle=True, 209 | num_workers=args.num_workers // 4, 210 | pin_memory=False 211 | ) 212 | val_data_2 = torch.utils.data.DataLoader( 213 | val_dataset_2, 214 | batch_size=1, # do not modify 215 | shuffle=True, 216 | num_workers=args.num_workers // 4, 217 | pin_memory=False 218 | ) 219 | logging.info('Done') 220 | 221 | # Load the network 222 | logging.info('Loading Network...') 223 | input_channels = 1*args.use_depth + 3*args.use_rgb 224 | print("channels:",input_channels) 225 | 226 | net=SwinTransformerSys(in_chans=input_channels,embed_dim=48,num_heads=[1, 2, 4, 8]) 227 | device = torch.device("cuda:0") 228 | net = net.to(device) 229 | optimizer = optim.AdamW(net.parameters(),lr=1e-4) 230 | logging.info('Done') 231 | summary(net, (input_channels, 224, 224)) 232 | f = open(os.path.join(save_folder, 'arch.txt'), 'w') 233 | sys.stdout = f 234 | summary(net, (input_channels, 224, 224)) 235 | sys.stdout = sys.__stdout__ 236 | f.close() 237 | 238 | best_iou = 0.0 239 | for epoch in range(args.epochs): 240 | logging.info('Beginning Epoch {:02d}'.format(epoch)) 241 | train_results = train(epoch, net, device, train_data, optimizer, args.batches_per_epoch, vis=args.vis) 242 | 243 | 244 | test_results = validate(net, device, train_validate_data, args.val_batches) 245 | logging.info(' traning %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'], 246 | test_results['correct'] / ( 247 | test_results['correct'] + test_results['failed']))) 248 | logging.info('loss/train_loss: %f'%test_results['loss']) 249 | 250 | test_results = validate(net, device, val_data, args.val_batches) 251 | logging.info(' seen %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'], 252 | test_results['correct']/(test_results['correct']+test_results['failed']))) 253 | logging.info('loss/seen_loss: %f'%test_results['loss']) 254 | 255 | test_results = validate(net, device, val_data_1, args.val_batches) 256 | logging.info('similar %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'], 257 | test_results['correct'] / (test_results['correct'] + test_results['failed']))) 258 | logging.info('loss/similar_loss: %f'%test_results['loss']) 259 | 260 | test_results = validate(net, device, val_data_2, args.val_batches,no_grasps=2) 261 | logging.info('novel %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'], 262 | test_results['correct'] / (test_results['correct'] + test_results['failed']))) 263 | logging.info('loss/novel_loss: %f'%test_results['loss']) 264 | 265 | 266 | # Save best 
performing network 267 | iou = test_results['correct'] / (test_results['correct'] + test_results['failed']) 268 | if iou > best_iou or epoch == 0 or (epoch % 10) == 0: 269 | torch.save(net, os.path.join(save_folder, 'epoch_%02d_iou_%0.2f' % (epoch, iou))) 270 | # torch.save(net.state_dict(), os.path.join(save_folder, 'epoch_%02d_iou_%0.2f_statedict.pt' % (epoch, iou))) 271 | best_iou = iou 272 | 273 | 274 | if __name__ == '__main__': 275 | run() 276 | -------------------------------------------------------------------------------- /main_k_fold.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import sys 4 | import argparse 5 | import logging 6 | import cv2 7 | import torch 8 | import torch.utils.data 9 | import torch.optim as optim 10 | from torchsummary import summary 11 | from sklearn.model_selection import KFold 12 | from traning import train, validate 13 | from utils.data import get_dataset 14 | from models.swin import SwinTransformerSys 15 | # from models.Swin_without_skipconcetion import SwinTransformerSys 16 | logging.basicConfig(level=logging.INFO) 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='TF-Grasp') 20 | 21 | # Network 22 | 23 | 24 | # Dataset & Data & Training 25 | parser.add_argument('--dataset', type=str,default="cornell", help='Dataset Name ("cornell" or "jaquard or multi")') 26 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/archive111" ,help='Path to dataset') 27 | parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for training (1/0)') 28 | parser.add_argument('--use-rgb', type=int, default=0, help='Use RGB image for training (0/1)') 29 | parser.add_argument('--split', type=float, default=0.9, help='Fraction of data for training (remainder is validation)') 30 | parser.add_argument('--ds-rotate', type=float, default=0.0, 31 | help='Shift the start point of the dataset to use a different test/train split for cross validation.') 32 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers') 33 | 34 | parser.add_argument('--batch-size', type=int, default=32, help='Batch size') 35 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs') 36 | parser.add_argument('--batches-per-epoch', type=int, default=500, help='Batches per Epoch') 37 | parser.add_argument('--val-batches', type=int, default=50, help='Validation Batches') 38 | # Logging etc. 
39 | parser.add_argument('--description', type=str, default='', help='Training description') 40 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory') 41 | 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def run(): 47 | args = parse_args() 48 | # Set-up output directories 49 | dt = datetime.datetime.now().strftime('%y%m%d_%H%M') 50 | net_desc = '{}_{}'.format(dt, '_'.join(args.description.split())) 51 | 52 | save_folder = os.path.join(args.outdir, net_desc) 53 | if not os.path.exists(save_folder): 54 | os.makedirs(save_folder) 55 | 56 | # Load Dataset 57 | logging.info('Loading {} Dataset...'.format(args.dataset.title())) 58 | Dataset = get_dataset(args.dataset) 59 | 60 | 61 | dataset = Dataset(args.dataset_path, start=0.0, end=1.0, ds_rotate=args.ds_rotate, 62 | random_rotate=True, random_zoom=True, 63 | include_depth=args.use_depth, include_rgb=args.use_rgb) 64 | k_folds = 5 65 | kfold = KFold(n_splits=k_folds, shuffle=True) 66 | logging.info('Done') 67 | logging.info('Loading Network...') 68 | input_channels = 1*args.use_depth + 3*args.use_rgb 69 | net = SwinTransformerSys(in_chans=input_channels,embed_dim=48,num_heads=[1,2,4,8]) 70 | device = torch.device("cuda:0") 71 | net = net.to(device) 72 | optimizer = optim.AdamW(net.parameters(), lr=5e-4) 73 | listy = [x *7 for x in range(1,1000,3)] 74 | schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=listy,gamma=0.6) 75 | logging.info('Done') 76 | # Print model architecture. 77 | summary(net, (input_channels, 224, 224)) 78 | f = open(os.path.join(save_folder, 'arch.txt'), 'w') 79 | sys.stdout = f 80 | summary(net, (input_channels, 224, 224)) 81 | sys.stdout = sys.__stdout__ 82 | f.close() 83 | 84 | best_iou = 0.0 85 | for epoch in range(args.epochs): 86 | accuracy=0. 
87 | for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)): 88 | 89 | train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids) 90 | test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids) 91 | trainloader = torch.utils.data.DataLoader( 92 | dataset, 93 | batch_size=args.batch_size,num_workers=args.num_workers, sampler=train_subsampler) 94 | testloader = torch.utils.data.DataLoader( 95 | dataset, 96 | batch_size=1,num_workers=args.num_workers, sampler=test_subsampler) 97 | 98 | 99 | logging.info('Beginning Epoch {:02d}'.format(epoch)) 100 | print("lr:",optimizer.state_dict()['param_groups'][0]['lr']) 101 | train_results = train(epoch, net, device, trainloader, optimizer, args.batches_per_epoch, ) 102 | schedule.step() 103 | 104 | # Run Validation 105 | logging.info('Validating...') 106 | test_results = validate(net, device, testloader, args.val_batches) 107 | logging.info('%d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'], 108 | test_results['correct']/(test_results['correct']+test_results['failed']))) 109 | 110 | 111 | iou = test_results['correct'] / (test_results['correct'] + test_results['failed']) 112 | accuracy+=iou 113 | if iou > best_iou or epoch == 0 or (epoch % 50) == 0: 114 | torch.save(net, os.path.join(save_folder, 'epoch_%02d_iou_%0.2f' % (epoch, iou))) 115 | # torch.save(net.state_dict(), os.path.join(save_folder, 'epoch_%02d_iou_%0.2f_statedict.pt' % (epoch, iou))) 116 | best_iou = iou 117 | schedule.step() 118 | print("the accuracy:",accuracy/k_folds) 119 | 120 | 121 | if __name__ == '__main__': 122 | run() 123 | 124 | 125 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from skimage.filters import gaussian 3 | 4 | 5 | def post_process_output(q_img, cos_img, sin_img, width_img): 6 | """ 7 | Post-process the raw output of the GG-CNN, convert to numpy arrays, apply filtering. 8 | :param q_img: Q output of GG-CNN (as torch Tensors) 9 | :param cos_img: cos output of GG-CNN 10 | :param sin_img: sin output of GG-CNN 11 | :param width_img: Width output of GG-CNN 12 | :return: Filtered Q output, Filtered Angle output, Filtered Width output 13 | """ 14 | q_img = q_img.data.detach().cpu().numpy().squeeze() 15 | ang_img = (torch.atan2(sin_img, cos_img) / 2.0).data.detach().cpu().numpy().squeeze() 16 | width_img = width_img.data.detach().cpu().numpy().squeeze() * 150.0 17 | 18 | q_img = gaussian(q_img, 2.0, preserve_range=True) 19 | ang_img = gaussian(ang_img, 2.0, preserve_range=True) 20 | width_img = gaussian(width_img, 1.0, preserve_range=True) 21 | 22 | return q_img, ang_img, width_img 23 | -------------------------------------------------------------------------------- /models/swin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | #from einops import rearrange 5 | #from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import torch.nn.functional as F 7 | def drop_path(x, drop_prob: float = 0., training: bool = False): 8 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
9 | 10 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 11 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 12 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 13 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 14 | 'survival rate' as the argument. 15 | 16 | """ 17 | if drop_prob == 0. or not training: 18 | return x 19 | keep_prob = 1 - drop_prob 20 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 21 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 22 | random_tensor.floor_() # binarize 23 | output = x.div(keep_prob) * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 29 | """ 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | 37 | from itertools import repeat 38 | import collections.abc 39 | 40 | 41 | # From PyTorch internals 42 | def _ntuple(n): 43 | def parse(x): 44 | # if isinstance(x, collections.abc.Iterable): 45 | # return x 46 | return tuple(repeat(x, n)) 47 | return parse 48 | 49 | 50 | to_1tuple = _ntuple(1) 51 | to_2tuple = _ntuple(2) 52 | to_3tuple = _ntuple(3) 53 | to_4tuple = _ntuple(4) 54 | 55 | 56 | import torch 57 | import math 58 | import warnings 59 | 60 | from torch.nn.init import _calculate_fan_in_and_fan_out 61 | 62 | 63 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 64 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 65 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 66 | def norm_cdf(x): 67 | # Computes standard normal cumulative distribution function 68 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 69 | 70 | if (mean < a - 2 * std) or (mean > b + 2 * std): 71 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 72 | "The distribution of values may be incorrect.", 73 | stacklevel=2) 74 | 75 | with torch.no_grad(): 76 | # Values are generated by using a truncated uniform distribution and 77 | # then using the inverse CDF for the normal distribution. 78 | # Get upper and lower cdf values 79 | l = norm_cdf((a - mean) / std) 80 | u = norm_cdf((b - mean) / std) 81 | 82 | # Uniformly fill tensor with values from [l, u], then translate to 83 | # [2l-1, 2u-1]. 84 | tensor.uniform_(2 * l - 1, 2 * u - 1) 85 | 86 | # Use inverse cdf transform for normal distribution to get truncated 87 | # standard normal 88 | tensor.erfinv_() 89 | 90 | # Transform to proper mean, std 91 | tensor.mul_(std * math.sqrt(2.)) 92 | tensor.add_(mean) 93 | 94 | # Clamp to ensure it's in the proper range 95 | tensor.clamp_(min=a, max=b) 96 | return tensor 97 | 98 | 99 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 100 | # type: (Tensor, float, float, float, float) -> Tensor 101 | r"""Fills the input Tensor with values drawn from a truncated 102 | normal distribution. The values are effectively drawn from the 103 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 104 | with values outside :math:`[a, b]` redrawn until they are within 105 | the bounds. 
The method used for generating the random values works 106 | best when :math:`a \leq \text{mean} \leq b`. 107 | Args: 108 | tensor: an n-dimensional `torch.Tensor` 109 | mean: the mean of the normal distribution 110 | std: the standard deviation of the normal distribution 111 | a: the minimum cutoff value 112 | b: the maximum cutoff value 113 | Examples: 114 | >>> w = torch.empty(3, 5) 115 | >>> nn.init.trunc_normal_(w) 116 | """ 117 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 118 | class Mlp(nn.Module): 119 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.): #nn.ReLU nn.GELU 120 | super().__init__() 121 | out_features = out_features or in_features 122 | hidden_features = hidden_features or in_features 123 | self.fc1 = nn.Linear(in_features, hidden_features) 124 | self.act = act_layer() 125 | self.fc2 = nn.Linear(hidden_features, out_features) 126 | self.drop = nn.Dropout(drop) 127 | 128 | def forward(self, x): 129 | x = self.fc1(x) 130 | x = self.act(x) 131 | x = self.drop(x) 132 | x = self.fc2(x) 133 | x = self.drop(x) 134 | return x 135 | def window_partition(x, window_size): 136 | """ 137 | Args: 138 | x: (B, H, W, C) 139 | window_size (int): window size 140 | Returns: 141 | windows: (num_windows*B, window_size, window_size, C) 142 | """ 143 | B, H, W, C = x.shape 144 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 145 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 146 | return windows 147 | def window_reverse(windows, window_size, H, W): 148 | """ 149 | Args: 150 | windows: (num_windows*B, window_size, window_size, C) 151 | window_size (int): Window size 152 | H (int): Height of image 153 | W (int): Width of image 154 | Returns: 155 | x: (B, H, W, C) 156 | """ 157 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 158 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 159 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 160 | return x 161 | 162 | class WindowAttention(nn.Module): 163 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 164 | It supports both of shifted and non-shifted window. 165 | Args: 166 | dim (int): Number of input channels. 167 | window_size (tuple[int]): The height and width of the window. 168 | num_heads (int): Number of attention heads. 169 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 170 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 171 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 172 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 173 | """ 174 | 175 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 176 | 177 | super().__init__() 178 | self.dim = dim 179 | self.window_size = window_size # Wh, Ww 180 | self.num_heads = num_heads 181 | head_dim = dim // num_heads 182 | self.scale = qk_scale or head_dim ** -0.5 183 | 184 | # define a parameter table of relative position bias 185 | self.relative_position_bias_table = nn.Parameter( 186 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 187 | 188 | # get pair-wise relative position index for each token inside the window 189 | coords_h = torch.arange(self.window_size[0]) 190 | coords_w = torch.arange(self.window_size[1]) 191 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 192 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 193 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 194 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 195 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 196 | relative_coords[:, :, 1] += self.window_size[1] - 1 197 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 198 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 199 | self.register_buffer("relative_position_index", relative_position_index) 200 | 201 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 202 | self.attn_drop = nn.Dropout(attn_drop) 203 | self.proj = nn.Linear(dim, dim) 204 | self.proj_drop = nn.Dropout(proj_drop) 205 | 206 | trunc_normal_(self.relative_position_bias_table, std=.02) 207 | self.softmax = nn.Softmax(dim=-1) 208 | 209 | def forward(self, x, mask=None): 210 | """ 211 | Args: 212 | x: input features with shape of (num_windows*B, N, C) 213 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 214 | """ 215 | B_, N, C = x.shape 216 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 217 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 218 | 219 | q = q * self.scale 220 | #attn = (q @ k.transpose(-2, -1)) 221 | attn = torch.matmul(q,k.transpose(-2, -1)) 222 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 223 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 224 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 225 | attn = attn + relative_position_bias.unsqueeze(0) 226 | 227 | if mask is not None: 228 | nW = mask.shape[0] 229 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 230 | attn = attn.view(-1, self.num_heads, N, N) 231 | attn = self.softmax(attn) 232 | else: 233 | attn = self.softmax(attn) 234 | 235 | attn = self.attn_drop(attn) 236 | 237 | # x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 238 | x = torch.matmul(attn,v).transpose(1, 2).reshape(B_, N, C) 239 | x = self.proj(x) 240 | x = self.proj_drop(x) 241 | return x 242 | 243 | def extra_repr(self) -> str: 244 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 245 | 246 | def flops(self, N): 247 | # calculate flops for 1 window with token length of N 248 | flops = 0 249 | # qkv = self.qkv(x) 250 | flops += N * self.dim * 3 * self.dim 251 | # attn = (q @ k.transpose(-2, -1)) 252 | flops += 
self.num_heads * N * (self.dim // self.num_heads) * N 253 | # x = (attn @ v) 254 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 255 | # x = self.proj(x) 256 | flops += N * self.dim * self.dim 257 | return flops 258 | 259 | class SwinTransformerBlock(nn.Module): 260 | r""" Swin Transformer Block. 261 | Args: 262 | dim (int): Number of input channels. 263 | input_resolution (tuple[int]): Input resulotion. 264 | num_heads (int): Number of attention heads. 265 | window_size (int): Window size. 266 | shift_size (int): Shift size for SW-MSA. 267 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 268 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 269 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 270 | drop (float, optional): Dropout rate. Default: 0.0 271 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 272 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 273 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 274 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 275 | """ 276 | 277 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 278 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 279 | act_layer=nn.ReLU, norm_layer=nn.LayerNorm): 280 | super().__init__() 281 | self.dim = dim 282 | self.input_resolution = input_resolution 283 | self.num_heads = num_heads 284 | self.window_size = window_size 285 | self.shift_size = shift_size 286 | self.mlp_ratio = mlp_ratio 287 | if min(self.input_resolution) <= self.window_size: 288 | # if window size is larger than input resolution, we don't partition windows 289 | self.shift_size = 0 290 | self.window_size = min(self.input_resolution) 291 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 292 | 293 | self.norm1 = norm_layer(dim) 294 | self.attn = WindowAttention( 295 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 296 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 297 | 298 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 299 | self.norm2 = norm_layer(dim) 300 | mlp_hidden_dim = int(dim * mlp_ratio) 301 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 302 | 303 | if self.shift_size > 0: 304 | # calculate attention mask for SW-MSA 305 | H, W = self.input_resolution 306 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 307 | h_slices = (slice(0, -self.window_size), 308 | slice(-self.window_size, -self.shift_size), 309 | slice(-self.shift_size, None)) 310 | w_slices = (slice(0, -self.window_size), 311 | slice(-self.window_size, -self.shift_size), 312 | slice(-self.shift_size, None)) 313 | cnt = 0 314 | for h in h_slices: 315 | for w in w_slices: 316 | img_mask[:, h, w, :] = cnt 317 | cnt += 1 318 | 319 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 320 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 321 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 322 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 323 | else: 324 | attn_mask = None 325 | 326 | self.register_buffer("attn_mask", attn_mask) 327 | 328 | def forward(self, x): 329 | H, W = self.input_resolution 330 | B, L, C = x.shape 331 | assert L == H * W, "input feature has wrong size" 332 | 333 | shortcut = x 334 | x = self.norm1(x) 335 | x = x.view(B, H, W, C) 336 | 337 | pad_l = pad_t = 0 338 | pad_r = (self.window_size - W % self.window_size) % self.window_size 339 | pad_b = (self.window_size - H % self.window_size) % self.window_size 340 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 341 | _, Hp, Wp, _ = x.shape 342 | # cyclic shift 343 | 344 | if self.shift_size > 0: 345 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 346 | else: 347 | shifted_x = x 348 | 349 | # partition windows 350 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 351 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 352 | 353 | # W-MSA/SW-MSA 354 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 355 | 356 | # merge windows 357 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 358 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 359 | 360 | # reverse cyclic shift 361 | if self.shift_size > 0: 362 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 363 | else: 364 | x = shifted_x 365 | x = x.view(B, H * W, C) 366 | 367 | # FFN 368 | x = shortcut + self.drop_path(x) 369 | x = x + self.drop_path(self.mlp(self.norm2(x))) 370 | 371 | return x 372 | 373 | def extra_repr(self) -> str: 374 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 375 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 376 | 377 | def flops(self): 378 | flops = 0 379 | H, W = self.input_resolution 380 | # norm1 381 | flops += self.dim * H * W 382 | # W-MSA/SW-MSA 383 | nW = H * W / self.window_size / self.window_size 384 | flops += nW * self.attn.flops(self.window_size * self.window_size) 385 | # mlp 386 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 387 | # norm2 388 | flops += self.dim * H * W 389 | return flops 390 | 391 | class PatchMerging(nn.Module): 392 | r""" Patch Merging Layer. 
393 | Args: 394 | input_resolution (tuple[int]): Resolution of input feature. 395 | dim (int): Number of input channels. 396 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 397 | """ 398 | 399 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 400 | super().__init__() 401 | self.input_resolution = input_resolution 402 | self.dim = dim 403 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 404 | self.norm = norm_layer(4 * dim) 405 | 406 | def forward(self, x): 407 | """ 408 | x: B, H*W, C 409 | """ 410 | H, W = self.input_resolution 411 | B, L, C = x.shape 412 | assert L == H * W, "input feature has wrong size" 413 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 414 | 415 | x = x.view(B, H, W, C) 416 | pad_input = (H % 2 == 1) or (W % 2 == 1) 417 | if pad_input: 418 | # to pad the last 3 dimensions, starting from the last dimension and moving forward. 419 | # (C_front, C_back, W_left, W_right, H_top, H_bottom) 420 | # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同 421 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 422 | 423 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 424 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 425 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 426 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 427 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 428 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 429 | 430 | x = self.norm(x) 431 | x = self.reduction(x) 432 | 433 | return x 434 | 435 | def extra_repr(self) -> str: 436 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 437 | 438 | def flops(self): 439 | H, W = self.input_resolution 440 | flops = H * W * self.dim 441 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 442 | return flops 443 | # class PatchExpand(nn.Module): 444 | # def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 445 | # super().__init__() 446 | # self.input_resolution = input_resolution 447 | # self.dim = dim 448 | # self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() 449 | # self.norm = norm_layer(dim // dim_scale) 450 | # 451 | # def forward(self, x): 452 | # """ 453 | # x: B, H*W, C 454 | # """ 455 | # H, W = self.input_resolution 456 | # x = self.expand(x) 457 | # B, L, C = x.shape 458 | # assert L == H * W, "input feature has wrong size" 459 | # 460 | # x = x.view(B, H, W, C) 461 | # x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 462 | # x = x.view(B,-1,C//4) 463 | # x= self.norm(x) 464 | # 465 | # return x 466 | class PatchExpand(nn.Module): 467 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): 468 | super().__init__() 469 | self.input_resolution = input_resolution 470 | self.dim = dim 471 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() 472 | self.norm = norm_layer(dim // dim_scale) 473 | 474 | def forward(self, x): 475 | """ 476 | x: B, H*W, C 477 | """ 478 | H, W = self.input_resolution 479 | x = self.expand(x) 480 | B, L, C = x.shape 481 | assert L == H * W, "input feature has wrong size" 482 | x = x.view(B, H, W, C) 483 | #print("x:",x.shape) 484 | x=x.reshape(B,H*2,W*2, C//4) 485 | #x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) 486 | x = x.view(B,-1,C//4) 487 | x= self.norm(x) 488 | 489 | return x 490 | 491 | class FinalPatchExpand_X4(nn.Module): 492 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): 493 | super().__init__() 494 | 
self.input_resolution = input_resolution 495 | self.dim = dim 496 | self.dim_scale = dim_scale 497 | self.expand = nn.Linear(dim, 16*dim, bias=False) 498 | self.output_dim = dim 499 | self.norm = norm_layer(self.output_dim) 500 | 501 | def forward(self, x): 502 | """ 503 | x: B, H*W, C 504 | """ 505 | H, W = self.input_resolution 506 | x = self.expand(x) 507 | B, L, C = x.shape 508 | assert L == H * W, "input feature has wrong size" 509 | 510 | x = x.view(B, H, W, C) 511 | x = x.reshape(B,H*self.dim_scale,W*self.dim_scale,C//(self.dim_scale**2)) 512 | #x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) 513 | #print(x.shape) 514 | x = x.view(B,-1,self.output_dim) 515 | x= self.norm(x) 516 | 517 | return x 518 | 519 | class BasicLayer(nn.Module): 520 | """ A basic Swin Transformer layer for one stage. 521 | Args: 522 | dim (int): Number of input channels. 523 | input_resolution (tuple[int]): Input resolution. 524 | depth (int): Number of blocks. 525 | num_heads (int): Number of attention heads. 526 | window_size (int): Local window size. 527 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 528 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 529 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 530 | drop (float, optional): Dropout rate. Default: 0.0 531 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 532 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 533 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 534 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 535 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
536 | """ 537 | 538 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 539 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 540 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 541 | 542 | super().__init__() 543 | self.dim = dim 544 | self.input_resolution = input_resolution 545 | self.depth = depth 546 | self.use_checkpoint = use_checkpoint 547 | 548 | # build blocks 549 | self.blocks = nn.ModuleList([ 550 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 551 | num_heads=num_heads, window_size=window_size, 552 | shift_size=0 if (i % 2 == 0) else window_size // 2, 553 | mlp_ratio=mlp_ratio, 554 | qkv_bias=qkv_bias, qk_scale=qk_scale, 555 | drop=drop, attn_drop=attn_drop, 556 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 557 | norm_layer=norm_layer) 558 | for i in range(depth)]) 559 | 560 | # patch merging layer 561 | if downsample is not None: 562 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 563 | else: 564 | self.downsample = None 565 | 566 | def forward(self, x): 567 | for blk in self.blocks: 568 | if self.use_checkpoint: 569 | x = checkpoint.checkpoint(blk, x) 570 | else: 571 | x = blk(x) 572 | if self.downsample is not None: 573 | x = self.downsample(x) 574 | return x 575 | def extra_repr(self) -> str: 576 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 577 | 578 | def flops(self): 579 | flops = 0 580 | for blk in self.blocks: 581 | flops += blk.flops() 582 | if self.downsample is not None: 583 | flops += self.downsample.flops() 584 | return flops 585 | 586 | class BasicLayer_up(nn.Module): 587 | """ A basic Swin Transformer layer for one stage. 588 | Args: 589 | dim (int): Number of input channels. 590 | input_resolution (tuple[int]): Input resolution. 591 | depth (int): Number of blocks. 592 | num_heads (int): Number of attention heads. 593 | window_size (int): Local window size. 594 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 595 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 596 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 597 | drop (float, optional): Dropout rate. Default: 0.0 598 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 599 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 600 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 601 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 602 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
603 | """ 604 | 605 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 606 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 607 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): 608 | 609 | super().__init__() 610 | self.dim = dim 611 | self.input_resolution = input_resolution 612 | self.depth = depth 613 | self.use_checkpoint = use_checkpoint 614 | 615 | # build blocks 616 | self.blocks = nn.ModuleList([ 617 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 618 | num_heads=num_heads, window_size=window_size, 619 | shift_size=0 if (i % 2 == 0) else window_size // 2, 620 | mlp_ratio=mlp_ratio, 621 | qkv_bias=qkv_bias, qk_scale=qk_scale, 622 | drop=drop, attn_drop=attn_drop, 623 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 624 | norm_layer=norm_layer) 625 | for i in range(depth)]) 626 | 627 | # patch merging layer 628 | if upsample is not None: 629 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) 630 | else: 631 | self.upsample = None 632 | 633 | def forward(self, x): 634 | for blk in self.blocks: 635 | if self.use_checkpoint: 636 | x = checkpoint.checkpoint(blk, x) 637 | else: 638 | x = blk(x) 639 | if self.upsample is not None: 640 | x = self.upsample(x) 641 | return x 642 | 643 | class PatchEmbed(nn.Module): 644 | r""" Image to Patch Embedding 645 | Args: 646 | img_size (int): Image size. Default: 224. 647 | patch_size (int): Patch token size. Default: 4. 648 | in_chans (int): Number of input image channels. Default: 3. 649 | embed_dim (int): Number of linear projection output channels. Default: 96. 650 | norm_layer (nn.Module, optional): Normalization layer. Default: None 651 | """ 652 | 653 | def __init__(self, img_size=300, patch_size=4, in_chans=1, embed_dim=96, norm_layer=None): 654 | super().__init__() 655 | img_size = to_2tuple(img_size) 656 | patch_size = to_2tuple(patch_size) 657 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 658 | self.img_size = img_size 659 | self.patch_size = patch_size 660 | self.patches_resolution = patches_resolution 661 | self.num_patches = patches_resolution[0] * patches_resolution[1] 662 | 663 | self.in_chans = in_chans 664 | self.embed_dim = embed_dim 665 | 666 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 667 | if norm_layer is not None: 668 | self.norm = norm_layer(embed_dim) 669 | else: 670 | self.norm = None 671 | 672 | def forward(self, x): 673 | B, C, H, W = x.shape 674 | # FIXME look at relaxing size constraints 675 | assert H == self.img_size[0] and W == self.img_size[1], \ 676 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
677 | 678 | # pad_input = (H % self.patch_size != 0) or (W % self.patch_size != 0) 679 | # if pad_input: 680 | # # to pad the last 3 dimensions, 681 | # # (W_left, W_right, H_top,H_bottom, C_front, C_back) 682 | # x = F.pad(x, (0, self.patch_size - W % self.patch_size, 683 | # 0, self.patch_size - H % self.patch_size, 684 | # 0, 0)) 685 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 686 | if self.norm is not None: 687 | x = self.norm(x) 688 | return x 689 | 690 | def flops(self): 691 | Ho, Wo = self.patches_resolution 692 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 693 | if self.norm is not None: 694 | flops += Ho * Wo * self.embed_dim 695 | return flops 696 | 697 | 698 | class SwinTransformerSys(nn.Module): 699 | r""" Swin Transformer 700 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 701 | https://arxiv.org/pdf/2103.14030 702 | Args: 703 | img_size (int | tuple(int)): Input image size. Default 224 704 | patch_size (int | tuple(int)): Patch size. Default: 4 705 | in_chans (int): Number of input image channels. Default: 3 706 | num_classes (int): Number of classes for classification head. Default: 1000 707 | embed_dim (int): Patch embedding dimension. Default: 96 708 | depths (tuple(int)): Depth of each Swin Transformer layer. 709 | num_heads (tuple(int)): Number of attention heads in different layers. 710 | window_size (int): Window size. Default: 7 711 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 712 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 713 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 714 | drop_rate (float): Dropout rate. Default: 0 715 | attn_drop_rate (float): Attention dropout rate. Default: 0 716 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 717 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 718 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 719 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 720 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 721 | """ 722 | 723 | def __init__(self, img_size=224, patch_size=4, in_chans=1, num_classes=1, 724 | embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], 725 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 726 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 727 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 728 | use_checkpoint=False, final_upsample="expand_first", **kwargs): 729 | super().__init__() 730 | 731 | print( 732 | "SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format( 733 | depths, 734 | depths_decoder, drop_path_rate, num_classes)) 735 | 736 | self.num_classes = num_classes 737 | self.num_layers = len(depths) 738 | self.embed_dim = embed_dim 739 | self.ape = ape 740 | self.patch_norm = patch_norm 741 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 742 | self.num_features_up = int(embed_dim * 2) 743 | self.mlp_ratio = mlp_ratio 744 | self.final_upsample = final_upsample 745 | 746 | # split image into non-overlapping patches 747 | self.patch_embed = PatchEmbed( 748 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 749 | norm_layer=norm_layer if self.patch_norm else None) 750 | num_patches = self.patch_embed.num_patches 751 | patches_resolution = self.patch_embed.patches_resolution 752 | self.patches_resolution = patches_resolution 753 | 754 | # absolute position embedding 755 | if self.ape: 756 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 757 | trunc_normal_(self.absolute_pos_embed, std=.02) 758 | 759 | self.pos_drop = nn.Dropout(p=drop_rate) 760 | 761 | # stochastic depth 762 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 763 | 764 | # build encoder and bottleneck layers 765 | self.layers = nn.ModuleList() 766 | for i_layer in range(self.num_layers): 767 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 768 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 769 | patches_resolution[1] // (2 ** i_layer)), 770 | depth=depths[i_layer], 771 | num_heads=num_heads[i_layer], 772 | window_size=window_size, 773 | mlp_ratio=self.mlp_ratio, 774 | qkv_bias=qkv_bias, qk_scale=qk_scale, 775 | drop=drop_rate, attn_drop=attn_drop_rate, 776 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 777 | norm_layer=norm_layer, 778 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 779 | use_checkpoint=use_checkpoint) 780 | self.layers.append(layer) 781 | 782 | # build decoder layers 783 | self.layers_up = nn.ModuleList() 784 | self.concat_back_dim = nn.ModuleList() 785 | for i_layer in range(self.num_layers): 786 | concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 787 | int(embed_dim * 2 ** ( 788 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity() 789 | if i_layer == 0: 790 | layer_up = PatchExpand( 791 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 792 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 793 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer) 794 | else: 795 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), 796 | input_resolution=( 797 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)), 798 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))), 799 | 
depth=depths[(self.num_layers - 1 - i_layer)], 800 | num_heads=num_heads[(self.num_layers - 1 - i_layer)], 801 | window_size=window_size, 802 | mlp_ratio=self.mlp_ratio, 803 | qkv_bias=qkv_bias, qk_scale=qk_scale, 804 | drop=drop_rate, attn_drop=attn_drop_rate, 805 | drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum( 806 | depths[:(self.num_layers - 1 - i_layer) + 1])], 807 | norm_layer=norm_layer, 808 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, 809 | use_checkpoint=use_checkpoint) 810 | self.layers_up.append(layer_up) 811 | self.concat_back_dim.append(concat_linear) 812 | 813 | self.norm = norm_layer(self.num_features) 814 | self.norm_up = norm_layer(self.embed_dim) 815 | 816 | if self.final_upsample == "expand_first": 817 | print("---final upsample expand_first---") 818 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size), 819 | dim_scale=4, dim=embed_dim) 820 | self.pos_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) 821 | self.cos_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) 822 | self.sin_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) 823 | self.width_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False) 824 | self.apply(self._init_weights) 825 | 826 | def _init_weights(self, m): 827 | if isinstance(m, nn.Linear): 828 | trunc_normal_(m.weight, std=.02) 829 | if isinstance(m, nn.Linear) and m.bias is not None: 830 | nn.init.constant_(m.bias, 0) 831 | elif isinstance(m, nn.LayerNorm): 832 | nn.init.constant_(m.bias, 0) 833 | nn.init.constant_(m.weight, 1.0) 834 | 835 | @torch.jit.ignore 836 | def no_weight_decay(self): 837 | return {'absolute_pos_embed'} 838 | 839 | @torch.jit.ignore 840 | def no_weight_decay_keywords(self): 841 | return {'relative_position_bias_table'} 842 | 843 | # Encoder and Bottleneck 844 | def forward_features(self, x): 845 | x = self.patch_embed(x) 846 | if self.ape: 847 | x = x + self.absolute_pos_embed 848 | x = self.pos_drop(x) 849 | x_downsample = [] 850 | 851 | for layer in self.layers: 852 | x_downsample.append(x) 853 | x = layer(x) 854 | 855 | x = self.norm(x) # B L C 856 | 857 | return x, x_downsample 858 | 859 | # Dencoder and Skip connection 860 | def forward_up_features(self, x, x_downsample): 861 | for inx, layer_up in enumerate(self.layers_up): 862 | if inx == 0: 863 | x = layer_up(x) 864 | else: 865 | x = torch.cat([x, x_downsample[3 - inx]], -1) 866 | # print(x.shape) 867 | x = self.concat_back_dim[inx](x) 868 | x = layer_up(x) 869 | 870 | x = self.norm_up(x) # B L C 871 | 872 | return x 873 | 874 | def up_x4(self, x): 875 | H, W = self.patches_resolution 876 | B, L, C = x.shape 877 | assert L == H * W, "input features has wrong size" 878 | 879 | if self.final_upsample == "expand_first": 880 | x = self.up(x) 881 | x = x.view(B, 4 * H, 4 * W, -1) 882 | x = x.permute(0, 3, 1, 2) # B,C,H,W 883 | pos_output = self.pos_output(x) 884 | cos_output = self.cos_output(x) 885 | sin_output = self.sin_output(x) 886 | width_output = self.width_output(x) 887 | 888 | 889 | return pos_output, cos_output, sin_output, width_output 890 | 891 | def forward(self, x): 892 | x, x_downsample = self.forward_features(x) 893 | x = self.forward_up_features(x, x_downsample) 894 | pos_output, cos_output, sin_output, width_output = self.up_x4(x) 895 | 896 | return pos_output, cos_output, 
sin_output, width_output 897 | def compute_loss(self, xc, yc): 898 | y_pos, y_cos, y_sin, y_width = yc 899 | pos_pred, cos_pred, sin_pred, width_pred = self(xc) 900 | # print("pos shape:",pos_pred.shape) 901 | p_loss = F.mse_loss(pos_pred, y_pos) 902 | cos_loss = F.mse_loss(cos_pred, y_cos) 903 | sin_loss = F.mse_loss(sin_pred, y_sin) 904 | width_loss = F.mse_loss(width_pred, y_width) 905 | 906 | return { 907 | 'loss': p_loss + cos_loss + sin_loss + width_loss, 908 | 'losses': { 909 | 'p_loss': p_loss, 910 | 'cos_loss': cos_loss, 911 | 'sin_loss': sin_loss, 912 | 'width_loss': width_loss 913 | }, 914 | 'pred': { 915 | 'pos': pos_pred, 916 | 'cos': cos_pred, 917 | 'sin': sin_pred, 918 | 'width': width_pred 919 | } 920 | } 921 | def flops(self): 922 | flops = 0 923 | flops += self.patch_embed.flops() 924 | for i, layer in enumerate(self.layers): 925 | flops += layer.flops() 926 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 927 | flops += self.num_features * self.num_classes 928 | return flops 929 | if __name__ == '__main__': 930 | from torchsummary import summary 931 | # model = SwinTransformerSys(in_chans=1) 932 | # summary(model, (1, 224, 224)) 933 | # model = SwinTransformerSys(in_chans=3).cuda() 934 | # 935 | # 936 | import time 937 | 938 | # from torchstat import stat 939 | from ptflops import get_model_complexity_info 940 | from thop import profile 941 | model = SwinTransformerSys(in_chans=1,embed_dim=24,num_heads=[3, 6, 12, 24]).cuda() 942 | # model = SwinTransformerSys(in_chans=1, embed_dim=12, num_heads=[1, 2, 2, 4]).cuda() 943 | # flops, params = profile(model, inputs=(input, )) 944 | # macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True, 945 | # print_per_layer_stat=True, verbose=True) 946 | # print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 947 | # print('{:<30} {:<8}'.format('Number of parameters: ', params)) 948 | # model = SwinTransformerSys(in_chans=3, embed_dim=24, num_heads=[1, 2, 4, 8]).cuda() 949 | # stat(model, (3, 224, 224)) 950 | summary(model, (1, 224, 224)) 951 | # model = SwinTransformerSys(in_chans=1, ) 952 | # model = SwinTransformerSys(in_chans=1, embed_dim=96, num_heads=[3, 6, 12, 24]) 953 | # model = SwinTransformerSys(in_chans=1, embed_dim=96, num_heads=[1, 2, 4, 8]) 954 | # model = SwinTransformerSys(in_chans=1, embed_dim=48, num_heads=[1, 2, 4, 8]) 955 | model = SwinTransformerSys(in_chans=1, embed_dim=48, num_heads=[1, 2, 4, 8]) 956 | # summary(model, (1, 224, 224)) 957 | model.eval() 958 | sum=0. 
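# Timing sketch (editorial note, not part of the original script): the loop below averages
# 20 forward passes on random CPU input. A steadier latency estimate would usually disable
# autograd and discard a first warm-up pass, e.g.:
#   with torch.no_grad():
#       _ = model(torch.rand(1, 1, 224, 224))   # warm-up
#       t0 = time.perf_counter()
#       for _ in range(20):
#           _ = model(torch.rand(1, 1, 224, 224))
#       print((time.perf_counter() - t0) / 20)
# Note that `sum` and `len` in the original loop shadow the Python built-ins.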
959 | # imge = torch.rand(1, 1, 224, 224) 960 | len=20 961 | # imge = torch.rand(1, 3, 224, 224) 962 | for i in range(len): 963 | imge = torch.rand(1, 1, 224, 224) 964 | start = time.perf_counter() 965 | a = model(imge) 966 | # time.sleep(3) 967 | end = time.perf_counter() 968 | dur = end - start 969 | sum+=dur 970 | print(sum/len) 971 | # imge = torch.rand(3, 1, 224, 224) 972 | # a = model(imge) 973 | # print(a[0].shape) 974 | # print(model) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | matplotlib 4 | scikit-image 5 | imageio 6 | torch 7 | torchvision 8 | torchsummary 9 | tensorboardX 10 | sklearn 11 | -------------------------------------------------------------------------------- /traning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.optim as optim 4 | from utils.dataset_processing import evaluation 5 | from models.common import post_process_output 6 | import logging 7 | def validate(net, device, val_data, batches_per_epoch): 8 | """ 9 | Run validation. 10 | :param net: Network 11 | :param device: Torch device 12 | :param val_data: Validation Dataset 13 | :param batches_per_epoch: Number of batches to run 14 | :return: Successes, Failures and Losses 15 | """ 16 | net.eval() 17 | 18 | results = { 19 | 'correct': 0, 20 | 'failed': 0, 21 | 'loss': 0, 22 | 'losses': { 23 | 24 | } 25 | } 26 | 27 | ld = len(val_data) 28 | 29 | with torch.no_grad(): 30 | batch_idx = 0 31 | while batch_idx < batches_per_epoch: 32 | for x, y, didx, rot, zoom_factor in val_data: 33 | batch_idx += 1 34 | if batches_per_epoch is not None and batch_idx >= batches_per_epoch: 35 | break 36 | 37 | xc = x.to(device) 38 | yc = [yy.to(device) for yy in y] 39 | lossd = net.compute_loss(xc, yc) 40 | 41 | loss = lossd['loss'] 42 | 43 | results['loss'] += loss.item()/ld 44 | for ln, l in lossd['losses'].items(): 45 | if ln not in results['losses']: 46 | results['losses'][ln] = 0 47 | results['losses'][ln] += l.item()/ld 48 | 49 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'], 50 | lossd['pred']['sin'], lossd['pred']['width']) 51 | 52 | s = evaluation.calculate_iou_match(q_out, ang_out, 53 | val_data.dataset.get_gtbb(didx, rot, zoom_factor), 54 | no_grasps=2, 55 | grasp_width=w_out, 56 | ) 57 | 58 | 59 | 60 | 61 | if s: 62 | results['correct'] += 1 63 | else: 64 | results['failed'] += 1 65 | 66 | return results 67 | 68 | 69 | def train(epoch, net, device, train_data, optimizer, batches_per_epoch, vis=False): 70 | """ 71 | Run one training epoch 72 | :param epoch: Current epoch 73 | :param net: Network 74 | :param device: Torch device 75 | :param train_data: Training Dataset 76 | :param optimizer: Optimizer 77 | :param batches_per_epoch: Data batches to train on 78 | :param vis: Visualise training progress 79 | :return: Average Losses for Epoch 80 | """ 81 | results = { 82 | 'loss': 0, 83 | 'losses': { 84 | } 85 | } 86 | 87 | net.train() 88 | 89 | batch_idx = 0 90 | # Use batches per epoch to make training on different sized datasets (cornell/jacquard) more equivalent. 
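# Usage sketch (hypothetical driver, mirroring how a main script would call these helpers;
# the loader and network names below are illustrative, not defined in this file):
#   device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#   net = SwinTransformerSys(in_chans=1).to(device)   # from models.swin
#   optimizer = optim.Adam(net.parameters())
#   for epoch in range(epochs):
#       train_results = train(epoch, net, device, train_loader, optimizer, batches_per_epoch=1000)
#       eval_results = validate(net, device, val_loader, batches_per_epoch=250)
#       acc = eval_results['correct'] / (eval_results['correct'] + eval_results['failed'])
#       logging.info('Epoch {} validation IoU accuracy: {:.4f}'.format(epoch, acc))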
91 | while batch_idx < batches_per_epoch: 92 | for x, y, _, _, _ in train_data: 93 | # print("shape:",x.shape) 94 | batch_idx += 1 95 | # if batch_idx >= batches_per_epoch: 96 | # break 97 | # print("x_0:",x[0].shape,y[0][0].shape) 98 | # plt.imshow(x[0].permute(1,2,0).numpy()) 99 | # plt.show() 100 | # plt.imshow(y[0][0][0].numpy()) 101 | # plt.show() 102 | xc = x.to(device) 103 | yc = [yy.to(device) for yy in y] 104 | # print("xc shape:",xc.shape) 105 | lossd = net.compute_loss(xc, yc) 106 | 107 | loss = lossd['loss'] 108 | 109 | if batch_idx % 10 == 0: 110 | logging.info('Epoch: {}, Batch: {}, Loss: {:0.4f}'.format(epoch, batch_idx, loss.item())) 111 | 112 | results['loss'] += loss.item() 113 | for ln, l in lossd['losses'].items(): 114 | if ln not in results['losses']: 115 | results['losses'][ln] = 0 116 | results['losses'][ln] += l.item() 117 | 118 | optimizer.zero_grad() 119 | loss.backward() 120 | optimizer.step() 121 | 122 | # Display the images 123 | if vis: 124 | imgs = [] 125 | n_img = min(4, x.shape[0]) 126 | for idx in range(n_img): 127 | imgs.extend([x[idx,].numpy().squeeze()] + [yi[idx,].numpy().squeeze() for yi in y] + [ 128 | x[idx,].numpy().squeeze()] + [pc[idx,].detach().cpu().numpy().squeeze() for pc in lossd['pred'].values()]) 129 | # gridshow('Display', imgs, 130 | # [(xc.min().item(), xc.max().item()), (0.0, 1.0), (0.0, 1.0), (-1.0, 1.0), (0.0, 1.0)] * 2 * n_img, 131 | # [cv2.COLORMAP_BONE] * 10 * n_img, 10) 132 | # cv2.waitKey(2) 133 | 134 | results['loss'] /= batch_idx 135 | for l in results['losses']: 136 | results['losses'][l] /= batch_idx 137 | 138 | return results 139 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | def get_dataset(dataset_name): 2 | if dataset_name == 'cornell': 3 | from .cornell_data import CornellDataset 4 | return CornellDataset 5 | elif dataset_name == 'jacquard': 6 | from .jacquard_data import JacquardDataset 7 | return JacquardDataset 8 | elif dataset_name == 'multi': 9 | from .multi_object import CornellDataset 10 | return CornellDataset 11 | elif 
dataset_name == 'graspnet1b': 12 | from .gn1b_data import GraspNet1BDataset 13 | return GraspNet1BDataset 14 | else: 15 | raise NotImplementedError('Dataset Type {} is Not implemented'.format(dataset_name)) -------------------------------------------------------------------------------- /utils/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/cornell_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/cornell_data.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/cornell_data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/cornell_data.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/grasp_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/grasp_data.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/grasp_data.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/grasp_data.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/jacquard_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/jacquard_data.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/__pycache__/multi_object.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/data/__pycache__/multi_object.cpython-36.pyc -------------------------------------------------------------------------------- /utils/data/cornell_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | from .grasp_data import GraspDatasetBase 5 | from utils.dataset_processing import grasp, image 6 | 7 | 8 | class CornellDataset(GraspDatasetBase): 9 | """ 10 | Dataset 
wrapper for the Cornell dataset. 11 | """ 12 | def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs): 13 | """ 14 | :param file_path: Cornell Dataset directory. 15 | :param start: If splitting the dataset, start at this fraction [0,1] 16 | :param end: If splitting the dataset, finish at this fraction 17 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 18 | :param kwargs: kwargs for GraspDatasetBase 19 | """ 20 | super(CornellDataset, self).__init__(**kwargs) 21 | 22 | graspf = glob.glob(os.path.join(file_path, '*', 'pcd*cpos.txt')) 23 | graspf.sort() 24 | l = len(graspf) 25 | if l == 0: 26 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path)) 27 | 28 | if ds_rotate: 29 | graspf = graspf[int(l*ds_rotate):] + graspf[:int(l*ds_rotate)] 30 | 31 | depthf = [f.replace('cpos.txt', 'd.tiff') for f in graspf] 32 | rgbf = [f.replace('d.tiff', 'r.png') for f in depthf] 33 | 34 | self.grasp_files = graspf[int(l*start):int(l*end)] 35 | self.depth_files = depthf[int(l*start):int(l*end)] 36 | self.rgb_files = rgbf[int(l*start):int(l*end)] 37 | 38 | def _get_crop_attrs(self, idx): 39 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx]) 40 | center = gtbbs.center 41 | left = max(0, min(center[1] - self.output_size // 2, 640 - self.output_size)) 42 | top = max(0, min(center[0] - self.output_size // 2, 480 - self.output_size)) 43 | return center, left, top 44 | 45 | def get_gtbb(self, idx, rot=0, zoom=1.0): 46 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx]) 47 | center, left, top = self._get_crop_attrs(idx) 48 | gtbbs.rotate(rot, center) 49 | gtbbs.offset((-top, -left)) 50 | gtbbs.zoom(zoom, (self.output_size//2, self.output_size//2)) 51 | return gtbbs 52 | 53 | def get_depth(self, idx, rot=0, zoom=1.0): 54 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 55 | center, left, top = self._get_crop_attrs(idx) 56 | depth_img.rotate(rot, center) 57 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 58 | depth_img.normalise() 59 | depth_img.zoom(zoom) 60 | depth_img.resize((self.output_size, self.output_size)) 61 | return depth_img.img 62 | 63 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 64 | rgb_img = image.Image.from_file(self.rgb_files[idx]) 65 | center, left, top = self._get_crop_attrs(idx) 66 | rgb_img.rotate(rot, center) 67 | rgb_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 68 | rgb_img.zoom(zoom) 69 | rgb_img.resize((self.output_size, self.output_size)) 70 | if normalise: 71 | rgb_img.normalise() 72 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 73 | return rgb_img.img 74 | -------------------------------------------------------------------------------- /utils/data/gn1b_data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import copy 4 | import glob 5 | 6 | import torch 7 | import torch.utils.data 8 | from graspnetAPI import GraspNet 9 | from .grasp_data import GraspDatasetBase 10 | from .cornell_data import CornellDataset 11 | 12 | from utils.dataset_processing import grasp, image 13 | 14 | graspnet_root = "/home/zzl/Pictures/graspnet" 15 | 16 | 17 | class GraspNet1BDataset(GraspDatasetBase): 18 | 19 | def __init__(self, file_path, camera='realsense', split='train', scale=2.0, ds_rotate=True, 20 | output_size=224, 21 | random_rotate=True, random_zoom=True, 22 
| include_depth=True, 23 | include_rgb=True, 24 | ): 25 | super(GraspNet1BDataset, self).__init__(output_size=output_size, include_depth=include_depth, include_rgb=include_rgb, random_rotate=random_rotate, 26 | random_zoom=random_zoom, input_only=False) 27 | logging.info('Graspnet root = {}'.format(graspnet_root)) 28 | logging.info('Using data from camera {}'.format(camera)) 29 | self.graspnet_root = graspnet_root 30 | self.camera = camera 31 | self.split = split 32 | self.scale = scale # the original images are h x w = 720 x 1280; scale shrinks them by this factor 33 | 34 | self._graspnet_instance = GraspNet(graspnet_root, camera, split) 35 | 36 | self.g_rgb_files = self._graspnet_instance.rgbPath # paths of the RGB images 37 | self.g_depth_files = self._graspnet_instance.depthPath # paths of the depth images 38 | self.g_rect_files = [] # paths of the rectangle grasp labels 39 | 40 | for original_rect_grasp_file in self._graspnet_instance.rectLabelPath: 41 | self.g_rect_files.append( 42 | original_rect_grasp_file 43 | .replace('rect', 'rect_cornell') 44 | .replace('.npy', '.txt') 45 | ) 46 | 47 | logging.info('Graspnet 1Billion dataset created!!') 48 | 49 | def _get_crop_attrs(self, idx, return_gtbbs=False): 50 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.g_rect_files[idx], scale=self.scale) 51 | center = gtbbs.center 52 | left = max(0, min(center[1] - self.output_size // 2, int(1280 // self.scale) - self.output_size)) 53 | top = max(0, min(center[0] - self.output_size // 2, int(720 // self.scale) - self.output_size)) 54 | if not return_gtbbs: 55 | return center, left, top 56 | else: 57 | return center, left, top, gtbbs 58 | 59 | def get_gtbb(self, idx, rot=0, zoom=1): 60 | # gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.g_rect_files[idx], scale=self.scale) 61 | center, left, top, gtbbs = self._get_crop_attrs(idx, return_gtbbs=True) 62 | gtbbs.rotate(rot, center) 63 | gtbbs.offset((-top, -left)) 64 | gtbbs.zoom(zoom, (self.output_size // 2, self.output_size // 2)) 65 | return gtbbs 66 | 67 | def get_depth(self, idx, rot=0, zoom=1.0): 68 | # the GraspNet-1B depth maps are converted to metres here 69 | depth_img = image.DepthImage.from_tiff(self.g_depth_files[idx], depth_scale=1000.0) 70 | rh, rw = int(720 // self.scale), int(1280 // self.scale) 71 | # images are read as w x h = 1280 x 720; resize to the target size 72 | depth_img.resize((rh, rw)) 73 | center, left, top = self._get_crop_attrs(idx) 74 | depth_img.rotate(rot, center) 75 | depth_img.crop((top, left), (min(rh, top + self.output_size), min(rw, left + self.output_size))) 76 | depth_img.normalise() 77 | depth_img.zoom(zoom) 78 | depth_img.resize((self.output_size, self.output_size)) 79 | return depth_img.img 80 | 81 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 82 | rgb_img = image.Image.from_file(self.g_rgb_files[idx]) 83 | rh, rw = int(720 // self.scale), int(1280 // self.scale) 84 | # images are read as w x h = 1280 x 720; resize to the target size 85 | rgb_img.resize((rh, rw)) 86 | center, left, top = self._get_crop_attrs(idx) 87 | rgb_img.rotate(rot, center) 88 | rgb_img.crop((top, left), (min(rh, top + self.output_size), min(rw, left + self.output_size))) 89 | rgb_img.zoom(zoom) 90 | rgb_img.resize((self.output_size, self.output_size)) 91 | if normalise: 92 | rgb_img.normalise() 93 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 94 | return rgb_img.img 95 | 96 | def __len__(self): 97 | return len(self.g_rect_files) -------------------------------------------------------------------------------- /utils/data/grasp_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.utils.data 5 | 6 | import
random 7 | 8 | 9 | class GraspDatasetBase(torch.utils.data.Dataset): 10 | """ 11 | An abstract dataset for training GG-CNNs in a common format. 12 | """ 13 | def __init__(self, output_size=224, include_depth=True, include_rgb=False, random_rotate=False, 14 | random_zoom=False, input_only=False): 15 | """ 16 | :param output_size: Image output size in pixels (square) 17 | :param include_depth: Whether depth image is included 18 | :param include_rgb: Whether RGB image is included 19 | :param random_rotate: Whether random rotations are applied 20 | :param random_zoom: Whether random zooms are applied 21 | :param input_only: Whether to return only the network input (no labels) 22 | """ 23 | self.output_size = output_size 24 | self.random_rotate = random_rotate 25 | self.random_zoom = random_zoom 26 | self.input_only = input_only 27 | self.include_depth = include_depth 28 | self.include_rgb = include_rgb 29 | 30 | self.grasp_files = [] 31 | 32 | if include_depth is False and include_rgb is False: 33 | raise ValueError('At least one of Depth or RGB must be specified.') 34 | 35 | @staticmethod 36 | def numpy_to_torch(s): 37 | if len(s.shape) == 2: 38 | return torch.from_numpy(np.expand_dims(s, 0).astype(np.float32)) 39 | else: 40 | return torch.from_numpy(s.astype(np.float32)) 41 | 42 | def get_gtbb(self, idx, rot=0, zoom=1.0): 43 | raise NotImplementedError() 44 | 45 | def get_depth(self, idx, rot=0, zoom=1.0): 46 | raise NotImplementedError() 47 | 48 | def get_rgb(self, idx, rot=0, zoom=1.0): 49 | raise NotImplementedError() 50 | 51 | def __getitem__(self, idx): 52 | if self.random_rotate: 53 | rotations = [0, np.pi/2, 2*np.pi/2, 3*np.pi/2] 54 | rot = random.choice(rotations) 55 | else: 56 | rot = 0.0 57 | 58 | if self.random_zoom: 59 | zoom_factor = np.random.uniform(0.5, 1.0) 60 | else: 61 | zoom_factor = 1.0 62 | 63 | # Load the depth image 64 | if self.include_depth: 65 | depth_img = self.get_depth(idx, rot, zoom_factor) 66 | 67 | # Load the RGB image 68 | if self.include_rgb: 69 | rgb_img = self.get_rgb(idx, rot, zoom_factor) 70 | 71 | # Load the grasps 72 | bbs = self.get_gtbb(idx, rot, zoom_factor) 73 | 74 | pos_img, ang_img, width_img = bbs.draw((self.output_size, self.output_size)) 75 | width_img = np.clip(width_img, 0.0, 150.0)/150.0 76 | 77 | if self.include_depth and self.include_rgb: 78 | x = self.numpy_to_torch( 79 | np.concatenate( 80 | (np.expand_dims(depth_img, 0), 81 | rgb_img), 82 | 0 83 | ) 84 | ) 85 | elif self.include_depth: 86 | x = self.numpy_to_torch(depth_img) 87 | elif self.include_rgb: 88 | x = self.numpy_to_torch(rgb_img) 89 | 90 | pos = self.numpy_to_torch(pos_img) 91 | cos = self.numpy_to_torch(np.cos(2*ang_img)) 92 | sin = self.numpy_to_torch(np.sin(2*ang_img)) 93 | width = self.numpy_to_torch(width_img) 94 | 95 | return x, (pos, cos, sin, width), idx, rot, zoom_factor 96 | 97 | def __len__(self): 98 | return len(self.grasp_files) 99 | -------------------------------------------------------------------------------- /utils/data/jacquard_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | from .grasp_data import GraspDatasetBase 5 | from utils.dataset_processing import grasp, image 6 | 7 | 8 | class JacquardDataset(GraspDatasetBase): 9 | """ 10 | Dataset wrapper for the Jacquard dataset. 11 | """ 12 | def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs): 13 | """ 14 | :param file_path: Jacquard Dataset directory. 
15 | :param start: If splitting the dataset, start at this fraction [0,1] 16 | :param end: If splitting the dataset, finish at this fraction 17 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 18 | :param kwargs: kwargs for GraspDatasetBase 19 | """ 20 | super(JacquardDataset, self).__init__(**kwargs) 21 | 22 | # graspf = glob.glob(os.path.join(file_path, '*', '*_grasps.txt')) 23 | graspf = glob.glob('/home/sam/Desktop/jacouard' + '/*/*/' + '*_grasps.txt') 24 | graspf.sort() 25 | l = len(graspf) 26 | print("len jaccquard:", l) 27 | 28 | if l == 0: 29 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path)) 30 | 31 | if ds_rotate: 32 | graspf = graspf[int(l*ds_rotate):] + graspf[:int(l*ds_rotate)] 33 | 34 | depthf = [f.replace('grasps.txt', 'perfect_depth.tiff') for f in graspf] 35 | rgbf = [f.replace('perfect_depth.tiff', 'RGB.png') for f in depthf] 36 | 37 | self.grasp_files = graspf[int(l*start):int(l*end)] 38 | self.depth_files = depthf[int(l*start):int(l*end)] 39 | self.rgb_files = rgbf[int(l*start):int(l*end)] 40 | 41 | def get_gtbb(self, idx, rot=0, zoom=1.0): 42 | gtbbs = grasp.GraspRectangles.load_from_jacquard_file(self.grasp_files[idx], scale=self.output_size / 1024.0) 43 | c = self.output_size//2 44 | gtbbs.rotate(rot, (c, c)) 45 | gtbbs.zoom(zoom, (c, c)) 46 | return gtbbs 47 | 48 | def get_depth(self, idx, rot=0, zoom=1.0): 49 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 50 | depth_img.rotate(rot) 51 | depth_img.normalise() 52 | depth_img.zoom(zoom) 53 | depth_img.resize((self.output_size, self.output_size)) 54 | return depth_img.img 55 | 56 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 57 | rgb_img = image.Image.from_file(self.rgb_files[idx]) 58 | rgb_img.rotate(rot) 59 | rgb_img.zoom(zoom) 60 | rgb_img.resize((self.output_size, self.output_size)) 61 | if normalise: 62 | rgb_img.normalise() 63 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 64 | return rgb_img.img 65 | 66 | def get_jname(self, idx): 67 | return '_'.join(self.grasp_files[idx].split(os.sep)[-1].split('_')[:-1]) -------------------------------------------------------------------------------- /utils/data/multi_object.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | from .grasp_data import GraspDatasetBase 5 | from utils.dataset_processing import grasp, image 6 | 7 | class CornellDataset(GraspDatasetBase): 8 | """ 9 | Dataset wrapper for the Cornell dataset. 10 | """ 11 | def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs): 12 | """ 13 | :param file_path: Cornell Dataset directory. 14 | :param start: If splitting the dataset, start at this fraction [0,1] 15 | :param end: If splitting the dataset, finish at this fraction 16 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first 17 | :param kwargs: kwargs for GraspDatasetBase 18 | """ 19 | super(CornellDataset, self).__init__(**kwargs) 20 | file_path="/home/sam/Desktop/multiobject/rgbd" 21 | graspf = glob.glob(os.path.join(file_path,'rgb_*.txt')) 22 | graspf.sort() 23 | l = 95 24 | if l == 0: 25 | raise FileNotFoundError('No dataset files found. 
Check path: {}'.format(file_path)) 26 | 27 | if ds_rotate: 28 | graspf = graspf[int(l*ds_rotate):] + graspf[:int(l*ds_rotate)] 29 | 30 | depthf =glob.glob(os.path.join(file_path,'depth_*.png')) 31 | depthf.sort() 32 | rgbf =glob.glob(os.path.join(file_path,'rgb*.jpg')) 33 | rgbf.sort() 34 | 35 | self.grasp_files = graspf[int(l*start):int(l*end)] 36 | self.depth_files = depthf[int(l*start):int(l*end)] 37 | self.rgb_files = rgbf[int(l*start):int(l*end)] 38 | 39 | def _get_crop_attrs(self, idx): 40 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx]) 41 | center = gtbbs.center 42 | left = max(0, min(center[1] - self.output_size // 2, 640 - self.output_size)) 43 | top = max(0, min(center[0] - self.output_size // 2, 480 - self.output_size)) 44 | return center, left, top 45 | 46 | def get_gtbb(self, idx, rot=0, zoom=1.0): 47 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx]) 48 | center, left, top = self._get_crop_attrs(idx) 49 | gtbbs.rotate(rot, center) 50 | gtbbs.offset((-top, -left)) 51 | gtbbs.zoom(zoom, (self.output_size//2, self.output_size//2)) 52 | return gtbbs 53 | 54 | def get_depth(self, idx, rot=0, zoom=1.0): 55 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx]) 56 | center, left, top = self._get_crop_attrs(idx) 57 | depth_img.rotate(rot, center) 58 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 59 | depth_img.normalise() 60 | depth_img.zoom(zoom) 61 | depth_img.resize((self.output_size, self.output_size)) 62 | return depth_img.img 63 | 64 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True): 65 | rgb_img = image.Image.from_file(self.rgb_files[idx]) 66 | center, left, top = self._get_crop_attrs(idx) 67 | rgb_img.rotate(rot, center) 68 | rgb_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size))) 69 | rgb_img.zoom(zoom) 70 | rgb_img.resize((self.output_size, self.output_size)) 71 | if normalise: 72 | rgb_img.normalise() 73 | rgb_img.img = rgb_img.img.transpose((2, 0, 1)) 74 | return rgb_img.img 75 | -------------------------------------------------------------------------------- /utils/data/multigrasp_object.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | origin_path = os.getcwd() # remember the original working directory 4 | os.chdir('/home/sam/Desktop/multiobject/rgbd') # this is my local dataset path 5 | print(os.listdir()) 6 | path='/home/sam/Desktop/multiobject/rgbd' 7 | depth = glob.glob(os.path.join(path,'depth_*.png')) 8 | depth.sort() 9 | print(depth[0:10]) 10 | 11 | rgb=glob.glob(os.path.join(path,'rgb*.jpg')) 12 | rgb.sort() 13 | print(rgb[0:10]) 14 | 15 | label=glob.glob(os.path.join(path,'rgb_*.txt')) 16 | label.sort() 17 | print(label[0:10]) 18 | 19 | from PIL import Image 20 | import matplotlib.pyplot as plt 21 | # plt.figure(figsize=(15,15)) 22 | # for i in range(9): 23 | # img = Image.open(rgb[i]) 24 | # plt.subplot(331+i) 25 | # plt.imshow(img) 26 | # plt.show() 27 | # plt.figure(figsize=(15,15)) 28 | # for i in range(9): 29 | # img = Image.open(depth[i]) 30 | # plt.subplot(331+i) 31 | # plt.imshow(img) 32 | # plt.show() 33 | 34 | # with open(label[0],'r') as f: 35 | # grasp_data = f.read() 36 | # print(grasp_data) 37 | # grasp_data = [grasp.strip() for grasp in grasp_data] # strip the trailing newlines 38 | 39 | def str2num(point): 40 | ''' 41 | :param point: string holding the coordinates of one grasp point 42 | :return: the grasp point as a pair of ints (x, y) 43 | 44 | ''' 45 | x, y = point.split() 46 | x, y = int(round(float(x))), int(round(float(y))) 47 | 48 | return (x, y) 49 | 50 | 51 | def get_rectangle(cornell_grasp_file): 52 | ''' 53 | :param cornell_grasp_file: string, path to one grasp annotation file 54 | :return: list containing the data of each grasp rectangle 55 | 56 | ''' 57 | grasp_rectangles = [] 58 | with open(cornell_grasp_file, 'r') as f: 59 | while True: 60 | grasp_rectangle = [] 61 | point0 = f.readline().strip() 62 | if not point0: 63 | break 64 | point1, point2, point3 = f.readline().strip(), f.readline().strip(), f.readline().strip() 65 | grasp_rectangle = [str2num(point0), 66 | str2num(point1), 67 | str2num(point2), 68 | str2num(point3)] 69 | grasp_rectangles.append(grasp_rectangle) 70 | 71 | return grasp_rectangles 72 | i=90 73 | grs = get_rectangle(label[i]) 74 | print(grs) 75 | import cv2 76 | import random 77 | import cv2 78 | img = cv2.imread(rgb[i]) 79 | for gr in grs: 80 | # pick a random colour for this rectangle 81 | color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 82 | # draw the rectangle edges 83 | for i in range(3): # a rectangle has four edges; the first three are drawn here 84 | img = cv2.line(img, gr[i], gr[i + 1], color, 3) 85 | img = cv2.line(img, gr[3], gr[0], color, 2) # close the rectangle with the last edge 86 | 87 | plt.figure(figsize=(10, 10)) 88 | plt.imshow(img) # cv2.imshow could display the image but kept crashing the session, so matplotlib is used instead 89 | plt.show() 90 | plt.imshow(img) 91 | plt.show() -------------------------------------------------------------------------------- /utils/dataset_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__init__.py -------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/grasp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/grasp.cpython-36.pyc
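A minimal sketch (editorial, not part of the repository) of how one four-corner rectangle from the label files parsed above maps onto the (center, angle, length) representation defined in `utils/dataset_processing/grasp.py` further below; the corner values are made up:

```python
import numpy as np

# four corners of one grasp rectangle, in (y, x) order as grasp.py stores them
pts = np.array([[115, 200], [115, 260], [145, 260], [145, 200]])

center = pts.mean(axis=0).astype(int)                            # -> [130, 230]
dy, dx = pts[1, 0] - pts[0, 0], pts[1, 1] - pts[0, 1]
angle = (np.arctan2(-dy, dx) + np.pi / 2) % np.pi - np.pi / 2    # -> 0.0 rad (jaw axis horizontal)
length = np.hypot(dx, dy)                                        # -> 60.0; exposed as GraspRectangle.length and drawn into the width map
print(center, angle, length)
```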
-------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/grasp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/grasp.cpython-37.pyc -------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/image.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/image.cpython-36.pyc -------------------------------------------------------------------------------- /utils/dataset_processing/__pycache__/image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__pycache__/image.cpython-37.pyc -------------------------------------------------------------------------------- /utils/dataset_processing/evaluation.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | from .grasp import GraspRectangles, detect_grasps 6 | plt.rcParams.update({ 7 | "text.usetex": True, 8 | "font.family": "sans-serif", 9 | "font.sans-serif": ["Helvetica"]}) 10 | matplotlib.use("TkAgg") 11 | counter=100 12 | def plot_output(rgb_img,rgb_img_1, depth_img, grasp_q_img, grasp_angle_img, no_grasps=1, grasp_width_img=None, 13 | grasp_q_img_ggcnn=None,grasp_angle_img_ggcnn=None,grasp_width_img_ggcnn=None): 14 | """ 15 | Plot the output of a GG-CNN 16 | :param rgb_img: RGB Image 17 | :param depth_img: Depth Image 18 | :param grasp_q_img: Q output of GG-CNN 19 | :param grasp_angle_img: Angle output of GG-CNN 20 | :param no_grasps: Maximum number of grasps to plot 21 | :param grasp_width_img: (optional) Width output of GG-CNN 22 | :return: 23 | """ 24 | global counter 25 | 26 | gs_1 = detect_grasps(grasp_q_img, grasp_angle_img, width_img=grasp_width_img, no_grasps=5) 27 | # print(len(gs_1)) 28 | fig = plt.figure(figsize=(10, 10)) 29 | ax = fig.add_subplot(3, 3, 1) 30 | ax.imshow(rgb_img) 31 | for g in gs_1: 32 | g.plot(ax) 33 | ax.set_title('RGB') 34 | ax.axis('off') 35 | # ax.savefig('/home/sam/compare_conv/Q_%d.pdf' % counter, bbox_inches='tight') 36 | 37 | 38 | gs_2 = detect_grasps(grasp_q_img_ggcnn, grasp_angle_img_ggcnn, width_img=grasp_width_img_ggcnn, no_grasps=5) 39 | print(len(gs_2)) 40 | # fig = plt.figure(figsize=(10, 10)) 41 | ax = fig.add_subplot(3, 3, 2) 42 | ax.imshow(rgb_img) 43 | for g in gs_2: 44 | g.plot(ax) 45 | ax.set_title('RGB') 46 | ax.axis('off') 47 | 48 | # ax.imshow(depth_img, cmap='gist_rainbow') 49 | # ax.imshow(depth_img,) 50 | # for g in gs: 51 | # g.plot(ax) 52 | # ax.set_title('Depth') 53 | # ax.axis('off') 54 | 55 | ax = fig.add_subplot(3, 3, 3) 56 | # plot = ax.imshow(grasp_width_img, cmap='prism', vmin=-0, vmax=150) 57 | # ax.set_title('q image') 58 | 59 | ax = fig.add_subplot(3, 3, 4) 60 | plot = ax.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1) #?terrain 61 | plt.colorbar(plot) 62 | ax.axis('off') 63 | ax.set_title('q image') 64 | 65 | 66 | ax = fig.add_subplot(3, 3, 5) #flag prism jet 67 | plot = ax.imshow(grasp_angle_img, 
cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2) 68 | plt.colorbar(plot) 69 | ax.axis('off') 70 | ax.set_title('angle') 71 | 72 | ax = fig.add_subplot(3, 3, 6,) 73 | plot = ax.imshow(grasp_width_img, cmap='jet', vmin=-0, vmax=150) 74 | plt.colorbar(plot) 75 | ax.set_title('width') 76 | ax.axis('off') 77 | 78 | 79 | 80 | ax = fig.add_subplot(3, 3, 7, ) 81 | plot = ax.imshow(grasp_q_img_ggcnn, cmap='jet', vmin=0, vmax=1) 82 | ax.set_title('q image') 83 | ax.axis('off') 84 | 85 | ax = fig.add_subplot(3, 3, 8) # flag prism jet 86 | plot = ax.imshow(grasp_angle_img_ggcnn, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2) 87 | ax.axis('off') 88 | ax.set_title('angle') 89 | 90 | ax = fig.add_subplot(3, 3, 9, ) 91 | plot = ax.imshow(grasp_width_img_ggcnn, cmap='jet', vmin=-0, vmax=150) 92 | ax.set_title('width') 93 | ax.axis('off') 94 | # for g in gs: 95 | # g.plot(ax) 96 | # ax.set_title('Angle') 97 | 98 | # plt.colorbar(plot) 99 | 100 | # plt.imshow(rgb_img) 101 | plt.show() 102 | if input("input") == "1": 103 | print("333") 104 | # matplotlib.use("Agg") 105 | # plt.margins(0, 0) 106 | plt.imshow(rgb_img) 107 | for g in gs_1: 108 | g.plot(plt) 109 | # plt.axis("off") 110 | # plot=plt.imshow(rgb_img) 111 | # for g in gs: 112 | # g.plot(plot) 113 | plt.axis("off") 114 | plt.savefig('/home/sam/compare_conv/RGB_1_%d.pdf'%counter, bbox_inches='tight') 115 | plt.show() 116 | plt.imshow(rgb_img) 117 | for g in gs_2: 118 | g.plot(plt) 119 | # plt.axis("off") 120 | # plot=plt.imshow(rgb_img) 121 | # for g in gs: 122 | # g.plot(plot) 123 | plt.axis("off") 124 | plt.savefig('/home/sam/compare_conv/RGB_2_%d.pdf' % counter, bbox_inches='tight') 125 | plt.show() 126 | 127 | plot1 = plt.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1) 128 | plt.axis("off") 129 | plt.colorbar(plot1) 130 | plt.savefig('/home/sam/compare_conv/Q_1_%d.pdf'%counter, bbox_inches='tight') 131 | plt.show() 132 | plot1 = plt.imshow(grasp_q_img_ggcnn, cmap="jet", vmin=0, vmax=1) 133 | plt.axis("off") 134 | plt.colorbar(plot1) 135 | plt.savefig('/home/sam/compare_conv/Q_2_%d.pdf' % counter, bbox_inches='tight') 136 | plt.show() 137 | # 138 | counter=counter+1 139 | # if input("input") == "1": 140 | # print("333") 141 | # # matplotlib.use("Agg") 142 | # # plt.margins(0, 0) 143 | # plt.imshow(rgb_img) 144 | # plt.axis("off") 145 | # plt.savefig('RGB_%d.pdf'%counter, bbox_inches='tight') 146 | # # plt.show() 147 | # 148 | # plot1=plt.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1) 149 | # plt.axis("off") 150 | # plt.colorbar(plot1) 151 | # plt.savefig('Q_%d.pdf'%counter, bbox_inches='tight') 152 | # plt.show() 153 | # 154 | # plt.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1) 155 | # plt.axis("off") 156 | # plt.savefig('Q_1_%d.pdf' % counter, bbox_inches='tight') 157 | # 158 | # 159 | # plot2=plt.imshow(grasp_angle_img, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2) 160 | # plt.axis("off") 161 | # plt.colorbar(plot2) 162 | # plt.savefig('Angle_%d.pdf'%counter, bbox_inches='tight') 163 | # plt.show() 164 | # 165 | # plt.imshow(grasp_angle_img, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2) 166 | # plt.axis("off") 167 | # plt.savefig('Angle_1_%d.pdf' % counter, bbox_inches='tight') 168 | # 169 | # plot3=plt.imshow(grasp_width_img_ggcnn, cmap='jet', vmin=-0, vmax=150) 170 | # plt.axis("off") 171 | # plt.colorbar(plot3) 172 | # plt.savefig('Width_%d.pdf'%counter, bbox_inches='tight') 173 | # plt.show() 174 | # 175 | # plt.imshow(grasp_width_img_ggcnn, cmap='jet', vmin=-0, vmax=150) 176 | # plt.axis("off") 177 | # plt.savefig('Width_1_%d.pdf' % counter, 
bbox_inches='tight') 178 | # counter=counter+1 179 | # matplotlib.use("TkAgg") 180 | # if (input())==str(1): 181 | # plt.figure(figsize=(5,5)) 182 | # plt.imshow(rgb_img) 183 | # for g in gs: 184 | # g.plot(plt) 185 | # plt.axis("off") 186 | # plt.tight_layout() 187 | # plt.show() 188 | # 189 | # plt.figure(figsize=(5, 5)) 190 | # # plt.imshow(depth_img, cmap='gist_gray') 191 | # plt.imshow(depth_img, ) 192 | # plt.axis("off") 193 | # for g in gs: 194 | # g.plot(plt) 195 | # plt.tight_layout() 196 | # plt.show() 197 | # 198 | # plt.figure(figsize=(5, 5)) 199 | # plt.imshow(grasp_q_img, cmap="terrain", vmin=0, vmax=1) 200 | # plt.axis("off") 201 | # plt.tight_layout() 202 | # plt.show() 203 | # 204 | # plt.figure(figsize=(5, 5)) 205 | # plt.imshow(grasp_angle_img, cmap="prism", vmin=-np.pi / 2, vmax=np.pi / 2) 206 | # plt.axis("off") 207 | # plt.tight_layout() 208 | # plt.show() 209 | # 210 | # plt.figure(figsize=(5, 5)) 211 | # plt.imshow(grasp_width_img, cmap='hsv', vmin=-0, vmax=150) 212 | # plt.axis("off") 213 | # plt.tight_layout() 214 | # plt.show() 215 | # 216 | # 217 | # plt.figure(figsize=(5, 5)) 218 | # plt.imshow(grasp_q_img_ggcnn, cmap="terrain", vmin=0, vmax=1) 219 | # plt.axis("off") 220 | # plt.tight_layout() 221 | # plt.show() 222 | # 223 | # plt.figure(figsize=(5, 5)) 224 | # plt.imshow(grasp_angle_img_ggcnn, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2) 225 | # plt.axis("off") 226 | # plt.tight_layout() 227 | # plt.show() 228 | # 229 | # plt.figure(figsize=(5, 5)) 230 | # plt.imshow(grasp_width_img_ggcnn, cmap='hsv', vmin=-0, vmax=150) 231 | # plt.axis("off") 232 | # plt.tight_layout() 233 | # plt.show() 234 | def calculate_iou_match(grasp_q, grasp_angle, ground_truth_bbs, no_grasps=1, grasp_width=None): 235 | """ 236 | Calculate grasp success using the IoU (Jacquard) metric (e.g. in https://arxiv.org/abs/1301.3592) 237 | A success is counted if grasp rectangle has a 25% IoU with a ground truth, and is withing 30 degrees. 238 | :param grasp_q: Q outputs of GG-CNN (Nx300x300x3) 239 | :param grasp_angle: Angle outputs of GG-CNN 240 | :param ground_truth_bbs: Corresponding ground-truth BoundingBoxes 241 | :param no_grasps: Maximum number of grasps to consider per image. 
242 | :param grasp_width: (optional) Width output from GG-CNN 243 | :return: success 244 | """ 245 | 246 | if not isinstance(ground_truth_bbs, GraspRectangles): 247 | gt_bbs = GraspRectangles.load_from_array(ground_truth_bbs) 248 | else: 249 | gt_bbs = ground_truth_bbs 250 | gs = detect_grasps(grasp_q, grasp_angle, width_img=grasp_width, no_grasps=no_grasps) 251 | for g in gs: 252 | if g.max_iou(gt_bbs) > 0.25: 253 | return True 254 | 255 | return False -------------------------------------------------------------------------------- /utils/dataset_processing/generate_cornell_depth.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | from imageio import imsave 5 | import argparse 6 | from utils.dataset_processing.image import DepthImage 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='Generate depth images from Cornell PCD files.') 11 | parser.add_argument('path', type=str, help='Path to Cornell Grasping Dataset') 12 | args = parser.parse_args() 13 | 14 | pcds = glob.glob(os.path.join(args.path, '*', 'pcd*[0-9].txt')) 15 | pcds.sort() 16 | 17 | for pcd in pcds: 18 | di = DepthImage.from_pcd(pcd, (480, 640)) 19 | di.inpaint() 20 | 21 | of_name = pcd.replace('.txt', 'd.tiff') 22 | print(of_name) 23 | imsave(of_name, di.img.astype(np.float32)) -------------------------------------------------------------------------------- /utils/dataset_processing/grasp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import matplotlib.pyplot as plt 4 | 5 | from skimage.draw import polygon 6 | from skimage.feature import peak_local_max 7 | 8 | 9 | def _gr_text_to_no(l, offset=(0, 0)): 10 | """ 11 | Transform a single point from a Cornell file line to a pair of ints. 12 | :param l: Line from Cornell grasp file (str) 13 | :param offset: Offset to apply to point positions 14 | :return: Point [y, x] 15 | """ 16 | x, y = l.split() 17 | return [int(round(float(y))) - offset[0], int(round(float(x))) - offset[1]] 18 | 19 | 20 | class GraspRectangles: 21 | """ 22 | Convenience class for loading and operating on sets of Grasp Rectangles. 23 | """ 24 | def __init__(self, grs=None): 25 | if grs: 26 | self.grs = grs 27 | else: 28 | self.grs = [] 29 | 30 | def __getitem__(self, item): 31 | return self.grs[item] 32 | 33 | def __iter__(self): 34 | return self.grs.__iter__() 35 | 36 | def __getattr__(self, attr): 37 | """ 38 | Test if GraspRectangle has the desired attr as a function and call it. 39 | """ 40 | # Delegate the call to every GraspRectangle in the set. 41 | if hasattr(GraspRectangle, attr) and callable(getattr(GraspRectangle, attr)): 42 | return lambda *args, **kwargs: list(map(lambda gr: getattr(gr, attr)(*args, **kwargs), self.grs)) 43 | else: 44 | raise AttributeError("Couldn't find function %s in GraspRectangles or GraspRectangle" % attr) 45 | 46 | @classmethod 47 | def load_from_array(cls, arr): 48 | """ 49 | Load grasp rectangles from numpy array. 50 | :param arr: Nx4x2 array, where each 4x2 array is the 4 corner pixels of a grasp rectangle. 51 | :return: GraspRectangles() 52 | """ 53 | grs = [] 54 | for i in range(arr.shape[0]): 55 | grp = arr[i, :, :].squeeze() 56 | if grp.max() == 0: 57 | break 58 | else: 59 | grs.append(GraspRectangle(grp)) 60 | return cls(grs) 61 | 62 | @classmethod 63 | def load_from_cornell_file(cls, fname): 64 | """ 65 | Load grasp rectangles from a Cornell dataset grasp file. 66 | :param fname: Path to text file.
67 | :return: GraspRectangles() 68 | """ 69 | grs = [] 70 | with open(fname) as f: 71 | while True: 72 | # Load 4 lines at a time, corners of bounding box. 73 | p0 = f.readline() 74 | if not p0: 75 | break # EOF 76 | p1, p2, p3 = f.readline(), f.readline(), f.readline() 77 | try: 78 | gr = np.array([ 79 | _gr_text_to_no(p0), 80 | _gr_text_to_no(p1), 81 | _gr_text_to_no(p2), 82 | _gr_text_to_no(p3) 83 | ]) 84 | 85 | grs.append(GraspRectangle(gr)) 86 | 87 | except ValueError: 88 | # Some files contain weird values. 89 | continue 90 | return cls(grs) 91 | 92 | @classmethod 93 | def load_from_jacquard_file(cls, fname, scale=1.0): 94 | """ 95 | Load grasp rectangles from a Jacquard dataset file. 96 | :param fname: Path to file. 97 | :param scale: Scale to apply (e.g. if resizing images) 98 | :return: GraspRectangles() 99 | """ 100 | grs = [] 101 | with open(fname) as f: 102 | for l in f: 103 | x, y, theta, w, h = [float(v) for v in l[:-1].split(';')] 104 | # index based on row, column (y,x), and the Jacquard dataset's angles are flipped around an axis. 105 | grs.append(Grasp(np.array([y, x]), -theta/180.0*np.pi, w, h).as_gr) 106 | grs = cls(grs) 107 | grs.scale(scale) 108 | return grs 109 | 110 | def append(self, gr): 111 | """ 112 | Add a grasp rectangle to this GraspRectangles object 113 | :param gr: GraspRectangle 114 | """ 115 | self.grs.append(gr) 116 | 117 | def copy(self): 118 | """ 119 | :return: A deep copy of this object and all of its GraspRectangles. 120 | """ 121 | new_grs = GraspRectangles() 122 | for gr in self.grs: 123 | new_grs.append(gr.copy()) 124 | return new_grs 125 | 126 | def show(self, ax=None, shape=None): 127 | """ 128 | Draw all GraspRectangles on a matplotlib plot. 129 | :param ax: (optional) existing axis 130 | :param shape: (optional) Plot shape if no existing axis 131 | """ 132 | if ax is None: 133 | f = plt.figure() 134 | ax = f.add_subplot(1, 1, 1) 135 | ax.imshow(np.zeros(shape)) 136 | ax.axis([0, shape[1], shape[0], 0]) 137 | self.plot(ax) 138 | plt.show() 139 | else: 140 | self.plot(ax) 141 | 142 | def draw(self, shape, position=True, angle=True, width=True): 143 | """ 144 | Plot all GraspRectangles as solid rectangles in a numpy array, e.g. as network training data. 145 | :param shape: output shape 146 | :param position: If True, Q output will be produced 147 | :param angle: If True, Angle output will be produced 148 | :param width: If True, Width output will be produced 149 | :return: Q, Angle, Width outputs (or None) 150 | """ 151 | if position: 152 | pos_out = np.zeros(shape) 153 | else: 154 | pos_out = None 155 | if angle: 156 | ang_out = np.zeros(shape) 157 | else: 158 | ang_out = None 159 | if width: 160 | width_out = np.zeros(shape) 161 | else: 162 | width_out = None 163 | 164 | for gr in self.grs: 165 | rr, cc = gr.compact_polygon_coords(shape) 166 | if position: 167 | pos_out[rr, cc] = 1.0 168 | if angle: 169 | ang_out[rr, cc] = gr.angle 170 | if width: 171 | width_out[rr, cc] = gr.length 172 | 173 | return pos_out, ang_out, width_out 174 | 175 | def to_array(self, pad_to=0): 176 | """ 177 | Convert all GraspRectangles to a single array. 
178 | :param pad_to: Length to 0-pad the array along the first dimension 179 | :return: Nx4x2 numpy array 180 | """ 181 | a = np.stack([gr.points for gr in self.grs]) 182 | if pad_to: 183 | if pad_to > len(self.grs): 184 | a = np.concatenate((a, np.zeros((pad_to - len(self.grs), 4, 2)))) 185 | return a.astype(np.int) 186 | 187 | @property 188 | def center(self): 189 | """ 190 | Compute mean center of all GraspRectangles 191 | :return: float, mean centre of all GraspRectangles 192 | """ 193 | points = [gr.points for gr in self.grs] 194 | return np.mean(np.vstack(points), axis=0).astype(np.int) 195 | 196 | 197 | class GraspRectangle: 198 | """ 199 | Representation of a grasp in the common "Grasp Rectangle" format. 200 | """ 201 | def __init__(self, points): 202 | self.points = points 203 | 204 | def __str__(self): 205 | return str(self.points) 206 | 207 | @property 208 | def angle(self): 209 | """ 210 | :return: Angle of the grasp to the horizontal. 211 | """ 212 | dx = self.points[1, 1] - self.points[0, 1] 213 | dy = self.points[1, 0] - self.points[0, 0] 214 | return (np.arctan2(-dy, dx) + np.pi/2) % np.pi - np.pi/2 215 | 216 | @property 217 | def as_grasp(self): 218 | """ 219 | :return: GraspRectangle converted to a Grasp 220 | """ 221 | return Grasp(self.center, self.angle, self.length, self.width) 222 | 223 | @property 224 | def center(self): 225 | """ 226 | :return: Rectangle center point 227 | """ 228 | return self.points.mean(axis=0).astype(np.int) 229 | 230 | @property 231 | def length(self): 232 | """ 233 | :return: Rectangle length (i.e. along the axis of the grasp) 234 | """ 235 | dx = self.points[1, 1] - self.points[0, 1] 236 | dy = self.points[1, 0] - self.points[0, 0] 237 | return np.sqrt(dx ** 2 + dy ** 2) 238 | 239 | @property 240 | def width(self): 241 | """ 242 | :return: Rectangle width (i.e. perpendicular to the axis of the grasp) 243 | """ 244 | dy = self.points[2, 1] - self.points[1, 1] 245 | dx = self.points[2, 0] - self.points[1, 0] 246 | return np.sqrt(dx ** 2 + dy ** 2) 247 | 248 | def polygon_coords(self, shape=None): 249 | """ 250 | :param shape: Output Shape 251 | :return: Indices of pixels within the grasp rectangle polygon. 252 | """ 253 | return polygon(self.points[:, 0], self.points[:, 1], shape) 254 | 255 | def compact_polygon_coords(self, shape=None): 256 | """ 257 | :param shape: Output shape 258 | :return: Indices of pixels within the centre third of the grasp rectangle. 259 | """ 260 | return Grasp(self.center, self.angle, self.length/3, self.width).as_gr.polygon_coords(shape) 261 | 262 | def iou(self, gr, angle_threshold=np.pi/6): 263 | """ 264 | Compute IoU with another grasping rectangle 265 | :param gr: GraspRectangle to compare 266 | :param angle_threshold: Maximum angle difference between GraspRectangles 267 | :return: IoU between Grasp Rectangles 268 | """ 269 | if abs((self.angle - gr.angle + np.pi/2) % np.pi - np.pi/2) > angle_threshold: 270 | return 0 271 | 272 | rr1, cc1 = self.polygon_coords() 273 | rr2, cc2 = polygon(gr.points[:, 0], gr.points[:, 1]) 274 | 275 | try: 276 | r_max = max(rr1.max(), rr2.max()) + 1 277 | c_max = max(cc1.max(), cc2.max()) + 1 278 | except: 279 | return 0 280 | 281 | canvas = np.zeros((r_max, c_max)) 282 | canvas[rr1, cc1] += 1 283 | canvas[rr2, cc2] += 1 284 | union = np.sum(canvas > 0) 285 | if union == 0: 286 | return 0 287 | intersection = np.sum(canvas == 2) 288 | return intersection/union 289 | 290 | def copy(self): 291 | """ 292 | :return: Copy of self. 
293 | """ 294 | return GraspRectangle(self.points.copy()) 295 | 296 | def offset(self, offset): 297 | """ 298 | Offset grasp rectangle 299 | :param offset: array [y, x] distance to offset 300 | """ 301 | self.points += np.array(offset).reshape((1, 2)) 302 | 303 | def rotate(self, angle, center): 304 | """ 305 | Rotate grasp rectangle 306 | :param angle: Angle to rotate (in radians) 307 | :param center: Point to rotate around (e.g. image center) 308 | """ 309 | R = np.array( 310 | [ 311 | [np.cos(-angle), np.sin(-angle)], 312 | [-1 * np.sin(-angle), np.cos(-angle)], 313 | ] 314 | ) 315 | c = np.array(center).reshape((1, 2)) 316 | self.points = ((np.dot(R, (self.points - c).T)).T + c).astype(np.int) 317 | 318 | def scale(self, factor): 319 | """ 320 | :param factor: Scale grasp rectangle by factor 321 | """ 322 | if factor == 1.0: 323 | return 324 | self.points *= factor 325 | 326 | def plot(self, ax, color=None): 327 | """ 328 | Plot grasping rectangle. 329 | :param ax: Existing matplotlib axis 330 | :param color: matplotlib color code (optional) 331 | """ 332 | points = np.vstack((self.points, self.points[0])) 333 | ax.plot(points[:, 1], points[:, 0], color=color) 334 | 335 | def zoom(self, factor, center): 336 | """ 337 | Zoom grasp rectangle by given factor. 338 | :param factor: Zoom factor 339 | :param center: Zoom center (focus point, e.g. image center) 340 | """ 341 | T = np.array( 342 | [ 343 | [1/factor, 0], 344 | [0, 1/factor] 345 | ] 346 | ) 347 | c = np.array(center).reshape((1, 2)) 348 | self.points = ((np.dot(T, (self.points - c).T)).T + c).astype(np.int) 349 | 350 | 351 | class Grasp: 352 | """ 353 | A Grasp represented by a center pixel, rotation angle and gripper width (length) 354 | """ 355 | def __init__(self, center, angle, length=60, width=30): 356 | self.center = center 357 | self.angle = angle # Positive angle means rotate anti-clockwise from horizontal. 358 | self.length = length 359 | self.width = width 360 | 361 | @property 362 | def as_gr(self): 363 | """ 364 | Convert to GraspRectangle 365 | :return: GraspRectangle representation of grasp. 366 | """ 367 | xo = np.cos(self.angle) 368 | yo = np.sin(self.angle) 369 | 370 | y1 = self.center[0] + self.length / 2 * yo 371 | x1 = self.center[1] - self.length / 2 * xo 372 | y2 = self.center[0] - self.length / 2 * yo 373 | x2 = self.center[1] + self.length / 2 * xo 374 | 375 | return GraspRectangle(np.array( 376 | [ 377 | [y1 - self.width/2 * xo, x1 - self.width/2 * yo], 378 | [y2 - self.width/2 * xo, x2 - self.width/2 * yo], 379 | [y2 + self.width/2 * xo, x2 + self.width/2 * yo], 380 | [y1 + self.width/2 * xo, x1 + self.width/2 * yo], 381 | ] 382 | ).astype(np.float)) 383 | 384 | def max_iou(self, grs): 385 | """ 386 | Return maximum IoU between self and a list of GraspRectangles 387 | :param grs: List of GraspRectangles 388 | :return: Maximum IoU with any of the GraspRectangles 389 | """ 390 | self_gr = self.as_gr 391 | max_iou = 0 392 | for gr in grs: 393 | iou = self_gr.iou(gr) 394 | max_iou = max(max_iou, iou) 395 | return max_iou 396 | 397 | def plot(self, ax, color=None): 398 | """ 399 | Plot Grasp 400 | :param ax: Existing matplotlib axis 401 | :param color: (optional) color 402 | """ 403 | self.as_gr.plot(ax, color) 404 | 405 | def to_jacquard(self, scale=1): 406 | """ 407 | Output grasp in "Jacquard Dataset Format" (https://jacquard.liris.cnrs.fr/database.php) 408 | :param scale: (optional) scale to apply to grasp 409 | :return: string in Jacquard format 410 | """ 411 | # Output in jacquard format. 
412 | return '%0.2f;%0.2f;%0.2f;%0.2f;%0.2f' % (self.center[1]*scale, self.center[0]*scale, -1*self.angle*180/np.pi, self.length*scale, self.width*scale) 413 | 414 | 415 | def detect_grasps(q_img, ang_img, width_img=None, no_grasps=1): 416 | """ 417 | Detect grasps in a GG-CNN output. 418 | :param q_img: Q image network output 419 | :param ang_img: Angle image network output 420 | :param width_img: (optional) Width image network output 421 | :param no_grasps: Max number of grasps to return 422 | :return: list of Grasps 423 | """ 424 | local_max = peak_local_max(q_img, min_distance=20, threshold_abs=0.2, num_peaks=no_grasps) 425 | 426 | grasps = [] 427 | for grasp_point_array in local_max: 428 | grasp_point = tuple(grasp_point_array) 429 | 430 | grasp_angle = ang_img[grasp_point] 431 | 432 | g = Grasp(grasp_point, grasp_angle) 433 | if width_img is not None: 434 | g.length = width_img[grasp_point] 435 | g.width = g.length/2 436 | 437 | grasps.append(g) 438 | 439 | return grasps 440 | -------------------------------------------------------------------------------- /utils/dataset_processing/image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from imageio import imread 5 | from skimage.transform import rotate, resize 6 | 7 | import warnings 8 | warnings.filterwarnings("ignore", category=UserWarning) 9 | 10 | 11 | class Image: 12 | """ 13 | Wrapper around an image with some convenient functions. 14 | """ 15 | def __init__(self, img): 16 | self.img = img 17 | 18 | def __getattr__(self, attr): 19 | # Pass along any other methods to the underlying ndarray 20 | return getattr(self.img, attr) 21 | 22 | @classmethod 23 | def from_file(cls, fname): 24 | return cls(imread(fname)) 25 | 26 | def copy(self): 27 | """ 28 | :return: Copy of self. 29 | """ 30 | return self.__class__(self.img.copy()) 31 | 32 | def crop(self, top_left, bottom_right, resize=None): 33 | """ 34 | Crop the image to a bounding box given by top left and bottom right pixels. 35 | :param top_left: tuple, top left pixel. 36 | :param bottom_right: tuple, bottom right pixel 37 | :param resize: If specified, resize the cropped image to this size 38 | """ 39 | self.img = self.img[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] 40 | if resize is not None: 41 | self.resize(resize) 42 | 43 | def cropped(self, *args, **kwargs): 44 | """ 45 | :return: Cropped copy of the image. 46 | """ 47 | i = self.copy() 48 | i.crop(*args, **kwargs) 49 | return i 50 | 51 | def normalise(self): 52 | """ 53 | Normalise the image by converting to float [0,1] and zero-centering 54 | """ 55 | self.img = self.img.astype(np.float32)/255.0 56 | self.img -= self.img.mean() 57 | 58 | def resize(self, shape): 59 | """ 60 | Resize image to shape. 61 | :param shape: New shape. 62 | """ 63 | if self.img.shape == shape: 64 | return 65 | self.img = resize(self.img, shape, preserve_range=True).astype(self.img.dtype) 66 | 67 | def resized(self, *args, **kwargs): 68 | """ 69 | :return: Resized copy of the image. 70 | """ 71 | i = self.copy() 72 | i.resize(*args, **kwargs) 73 | return i 74 | 75 | def rotate(self, angle, center=None): 76 | """ 77 | Rotate the image. 78 | :param angle: Angle (in radians) to rotate by. 79 | :param center: Center pixel to rotate if specified, otherwise image center is used. 
80 | """ 81 | if center is not None: 82 | center = (center[1], center[0]) 83 | self.img = rotate(self.img, angle/np.pi*180, center=center, mode='symmetric', preserve_range=True).astype(self.img.dtype) 84 | 85 | def rotated(self, *args, **kwargs): 86 | """ 87 | :return: Rotated copy of image. 88 | """ 89 | i = self.copy() 90 | i.rotate(*args, **kwargs) 91 | return i 92 | 93 | def show(self, ax=None, **kwargs): 94 | """ 95 | Plot the image 96 | :param ax: Existing matplotlib axis (optional) 97 | :param kwargs: kwargs to imshow 98 | """ 99 | if ax: 100 | ax.imshow(self.img, **kwargs) 101 | else: 102 | plt.imshow(self.img, **kwargs) 103 | plt.show() 104 | 105 | def zoom(self, factor): 106 | """ 107 | "Zoom" the image by cropping and resizing. 108 | :param factor: Factor to zoom by. e.g. 0.5 will keep the center 50% of the image. 109 | """ 110 | sr = int(self.img.shape[0] * (1 - factor)) // 2 111 | sc = int(self.img.shape[1] * (1 - factor)) // 2 112 | orig_shape = self.img.shape 113 | self.img = self.img[sr:self.img.shape[0] - sr, sc: self.img.shape[1] - sc].copy() 114 | self.img = resize(self.img, orig_shape, mode='symmetric', preserve_range=True).astype(self.img.dtype) 115 | 116 | def zoomed(self, *args, **kwargs): 117 | """ 118 | :return: Zoomed copy of the image. 119 | """ 120 | i = self.copy() 121 | i.zoom(*args, **kwargs) 122 | return i 123 | 124 | 125 | class DepthImage(Image): 126 | def __init__(self, img): 127 | super().__init__(img) 128 | 129 | @classmethod 130 | def from_pcd(cls, pcd_filename, shape, default_filler=0, index=None): 131 | """ 132 | Create a depth image from an unstructured PCD file. 133 | If index isn't specified, use euclidean distance, otherwise choose x/y/z=0/1/2 134 | """ 135 | img = np.zeros(shape) 136 | if default_filler != 0: 137 | img += default_filler 138 | 139 | with open(pcd_filename) as f: 140 | for l in f.readlines(): 141 | ls = l.split() 142 | 143 | if len(ls) != 5: 144 | # Not a point line in the file. 145 | continue 146 | try: 147 | # Not a number, carry on. 148 | float(ls[0]) 149 | except ValueError: 150 | continue 151 | 152 | i = int(ls[4]) 153 | r = i // shape[1] 154 | c = i % shape[1] 155 | 156 | if index is None: 157 | x = float(ls[0]) 158 | y = float(ls[1]) 159 | z = float(ls[2]) 160 | 161 | img[r, c] = np.sqrt(x ** 2 + y ** 2 + z ** 2) 162 | 163 | else: 164 | img[r, c] = float(ls[index]) 165 | 166 | return cls(img/1000.0) 167 | 168 | @classmethod 169 | def from_tiff(cls, fname): 170 | return cls(imread(fname)) 171 | 172 | def inpaint(self, missing_value=0): 173 | """ 174 | Inpaint missing values in depth image. 175 | :param missing_value: Value to fill in the depth image. 176 | """ 177 | # cv2 inpainting doesn't handle the border properly 178 | # https://stackoverflow.com/questions/25974033/inpainting-depth-map-still-a-black-image-border 179 | self.img = cv2.copyMakeBorder(self.img, 1, 1, 1, 1, cv2.BORDER_DEFAULT) 180 | mask = (self.img == missing_value).astype(np.uint8) 181 | 182 | # Scale to keep as float, but has to be in bounds -1:1 to keep opencv happy. 183 | scale = np.abs(self.img).max() 184 | self.img = self.img.astype(np.float32) / scale # Has to be float32, 64 not supported. 185 | self.img = cv2.inpaint(self.img, mask, 1, cv2.INPAINT_NS) 186 | 187 | # Back to original size and value range. 188 | self.img = self.img[1:-1, 1:-1] 189 | self.img = self.img * scale 190 | 191 | def gradients(self): 192 | """ 193 | Compute gradients of the depth image using Sobel filters. 
194 | :return: Gradients in X direction, Gradients in Y direction, Magnitude of XY gradients. 195 | """ 196 | grad_x = cv2.Sobel(self.img, cv2.CV_64F, 1, 0, borderType=cv2.BORDER_DEFAULT) 197 | grad_y = cv2.Sobel(self.img, cv2.CV_64F, 0, 1, borderType=cv2.BORDER_DEFAULT) 198 | grad = np.sqrt(grad_x ** 2 + grad_y ** 2) 199 | 200 | return DepthImage(grad_x), DepthImage(grad_y), DepthImage(grad) 201 | 202 | def normalise(self): 203 | """ 204 | Normalise by subtracting the mean and clipping to [-1, 1] 205 | """ 206 | self.img = np.clip((self.img - self.img.mean()), -1, 1) 207 | 208 | 209 | class WidthImage(Image): 210 | """ 211 | A width image is one that describes the desired gripper width at each pixel. 212 | """ 213 | def zoom(self, factor): 214 | """ 215 | "Zoom" the image by cropping and resizing. Also scales the width accordingly. 216 | :param factor: Factor to zoom by. e.g. 0.5 will keep the center 50% of the image. 217 | """ 218 | super().zoom(factor) 219 | self.img = self.img/factor 220 | 221 | def normalise(self): 222 | """ 223 | Normalise by mapping [0, 150] -> [0, 1] 224 | """ 225 | self.img = np.clip(self.img, 0, 150.0)/150.0 226 | -------------------------------------------------------------------------------- /utils/dataset_processing/multigrasp_object.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/multigrasp_object.py -------------------------------------------------------------------------------- /utils/timeit.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class TimeIt: 5 | print_output = True 6 | last_parent = None 7 | level = -1 8 | 9 | def __init__(self, s): 10 | self.s = s 11 | self.t0 = None 12 | self.t1 = None 13 | self.outputs = [] 14 | self.parent = None 15 | 16 | def __enter__(self): 17 | self.t0 = time.time() 18 | self.parent = TimeIt.last_parent 19 | TimeIt.last_parent = self 20 | TimeIt.level += 1 21 | 22 | def __exit__(self, t, value, traceback): 23 | self.t1 = time.time() 24 | st = '%s%s: %0.1fms' % (' ' * TimeIt.level, self.s, (self.t1 - self.t0)*1000) 25 | TimeIt.level -= 1 26 | 27 | if self.parent: 28 | self.parent.outputs.append(st) 29 | self.parent.outputs += self.outputs 30 | else: 31 | if TimeIt.print_output: 32 | print(st) 33 | for o in self.outputs: 34 | print(o) 35 | self.outputs = [] 36 | 37 | TimeIt.last_parent = self.parent 38 | -------------------------------------------------------------------------------- /utils/visualisation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__init__.py -------------------------------------------------------------------------------- /utils/visualisation/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/visualisation/__pycache__/__init__.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/visualisation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/visualisation/__pycache__/gridshow.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/gridshow.cpython-36.pyc -------------------------------------------------------------------------------- /utils/visualisation/__pycache__/gridshow.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/gridshow.cpython-37.pyc -------------------------------------------------------------------------------- /utils/visualisation/__pycache__/gridshow.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/gridshow.cpython-38.pyc -------------------------------------------------------------------------------- /utils/visualisation/gridshow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def gridshow(name, imgs, scales, cmaps, width, border=10): 6 | """ 7 | Display images in a grid. 8 | :param name: cv2 Window Name to update 9 | :param imgs: List of Images (np.ndarrays) 10 | :param scales: The min/max scale of images to properly scale the colormaps 11 | :param cmaps: List of cv2 Colormaps to apply 12 | :param width: Number of images in a row 13 | :param border: Border (pixels) between images. 14 | """ 15 | imgrows = [] 16 | imgcols = [] 17 | 18 | maxh = 0 19 | for i, (img, cmap, scale) in enumerate(zip(imgs, cmaps, scales)): 20 | 21 | # Scale images into range 0-1 22 | if scale is not None: 23 | img = (np.clip(img, scale[0], scale[1]) - scale[0])/(scale[1]-scale[0]) 24 | elif img.dtype == np.float: 25 | img = (img - img.min())/(img.max() - img.min() + 1e-6) 26 | 27 | # Apply colormap (if applicable) and convert to uint8 28 | if cmap is not None: 29 | try: 30 | imgc = cv2.applyColorMap((img * 255).astype(np.uint8), cmap) 31 | except: 32 | imgc = (img*255.0).astype(np.uint8) 33 | else: 34 | imgc = img 35 | 36 | if imgc.shape[0] == 3: 37 | imgc = imgc.transpose((1, 2, 0)) 38 | elif imgc.shape[0] == 4: 39 | imgc = imgc[1:, :, :].transpose((1, 2, 0)) 40 | 41 | # Arrange row of images. 
42 | maxh = max(maxh, imgc.shape[0]) 43 | imgcols.append(imgc) 44 | if i > 0 and i % width == (width-1): 45 | imgrows.append(np.hstack([np.pad(c, ((0, maxh - c.shape[0]), (border//2, border//2), (0, 0)), mode='constant') for c in imgcols])) 46 | imgcols = [] 47 | maxh = 0 48 | 49 | # Unfinished row 50 | if imgcols: 51 | imgrows.append(np.hstack([np.pad(c, ((0, maxh - c.shape[0]), (border//2, border//2), (0, 0)), mode='constant') for c in imgcols])) 52 | 53 | maxw = max([c.shape[1] for c in imgrows]) 54 | 55 | cv2.imshow(name, np.vstack([np.pad(r, ((border//2, border//2), (0, maxw - r.shape[1]), (0, 0)), mode='constant') for r in imgrows])) 56 | -------------------------------------------------------------------------------- /visualise_grasp_rectangle.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import matplotlib.pyplot as plt 5 | import torch.utils.data 6 | from utils.data import get_dataset 7 | from utils.dataset_processing.grasp import detect_grasps,GraspRectangles 8 | from models.common import post_process_output 9 | import cv2 10 | import matplotlib 11 | plt.rcParams.update({ 12 | "text.usetex": True, 13 | "font.family": "sans-serif", 14 | "font.sans-serif": ["Helvetica"]}) 15 | matplotlib.use("TkAgg") 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='Evaluate GG-CNN') 20 | 21 | # Network 22 | parser.add_argument('--network', type=str,default="./output/models/220623_1311_/epoch_08_iou_0.97", help='Path to saved network to evaluate') 23 | 24 | # Dataset & Data & Training 25 | parser.add_argument('--dataset', type=str, default="cornell",help='Dataset Name ("cornell" or "jaquard")') 26 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/archive111" ,help='Path to dataset') 27 | parser.add_argument('--use-depth', type=int, default=0, help='Use Depth image for training (1/0)') 28 | parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for training (0/1)') 29 | parser.add_argument('--split', type=float, default=0.9, 30 | help='Fraction of data for training (remainder is validation)') 31 | parser.add_argument('--ds-rotate', type=float, default=0.0, 32 | help='Shift the start point of the dataset to use a different test/train split for cross validation.') 33 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers') 34 | 35 | parser.add_argument('--batch-size', type=int, default=1, help='Batch size') 36 | parser.add_argument('--vis', type=bool, default=False, help='vis') 37 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs') 38 | parser.add_argument('--batches-per-epoch', type=int, default=200, help='Batches per Epoch') 39 | parser.add_argument('--val-batches', type=int, default=32, help='Validation Batches') 40 | # Logging etc. 
41 | parser.add_argument('--description', type=str, default='', help='Training description') 42 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory') 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | if __name__ == '__main__': 48 | args = parse_args() 49 | print(args.network) 50 | print(args.use_rgb,args.use_depth) 51 | net = torch.load(args.network) 52 | # net_ggcnn = torch.load('./output/models/211112_1458_/epoch_30_iou_0.75') 53 | device = torch.device("cuda:0") 54 | Dataset = get_dataset(args.dataset) 55 | 56 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate, 57 | random_rotate=True, random_zoom=False, 58 | include_depth=args.use_depth, include_rgb=args.use_rgb) 59 | val_data = torch.utils.data.DataLoader( 60 | val_dataset, 61 | batch_size=1, 62 | shuffle=True, 63 | num_workers=args.num_workers 64 | ) 65 | results = { 66 | 'correct': 0, 67 | 'failed': 0, 68 | 'loss': 0, 69 | 'losses': { 70 | 71 | } 72 | } 73 | ld = len(val_data) 74 | with torch.no_grad(): 75 | batch_idx = 0 76 | fig = plt.figure(figsize=(20, 10)) 77 | # ax = fig.add_subplot(5, 5, 1) 78 | # while batch_idx < 100: 79 | for id,(x, y, didx, rot, zoom_factor) in enumerate( val_data): 80 | # batch_idx += 1 81 | if id>24: 82 | break 83 | print(id) 84 | print(x.shape) 85 | xc = x.to(device) 86 | yc = [yy.to(device) for yy in y] 87 | lossd = net.compute_loss(xc, yc) 88 | 89 | loss = lossd['loss'] 90 | 91 | results['loss'] += loss.item() / ld 92 | for ln, l in lossd['losses'].items(): 93 | if ln not in results['losses']: 94 | results['losses'][ln] = 0 95 | results['losses'][ln] += l.item() / ld 96 | 97 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'], 98 | lossd['pred']['sin'], lossd['pred']['width']) 99 | gs_1 = detect_grasps(q_out, ang_out, width_img=w_out, no_grasps=1) 100 | rgb_img=val_dataset.get_rgb(didx, rot, zoom_factor, normalise=False) 101 | # print(rgb_img) 102 | ax = fig.add_subplot(5, 5, id+1) 103 | ax.imshow(rgb_img) 104 | ax.axis('off') 105 | for g in gs_1: 106 | g.plot(ax) 107 | plt.show() 108 | 109 | # s = evaluation.calculate_iou_match(q_out, ang_out, 110 | # val_data.dataset.get_gtbb(didx, rot, zoom_factor), 111 | # no_grasps=2, 112 | # grasp_width=w_out, 113 | # ) 114 | # 115 | # if s: 116 | # results['correct'] += 1 117 | # else: 118 | # results['failed'] += 1 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /visulaize_heatmaps.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import torch.utils.data 6 | from utils.data import get_dataset 7 | from utils.dataset_processing.grasp import detect_grasps,GraspRectangles 8 | from models.common import post_process_output 9 | import cv2 10 | import matplotlib 11 | plt.rcParams.update({ 12 | "text.usetex": True, 13 | "font.family": "sans-serif", 14 | "font.sans-serif": ["Helvetica"]}) 15 | matplotlib.use("TkAgg") 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='Evaluate GG-CNN') 20 | 21 | # Network 22 | parser.add_argument('--network', type=str,default="./output/models/220623_1311_/epoch_08_iou_0.97", help='Path to saved network to evaluate') 23 | 24 | # Dataset & Data & Training 25 | parser.add_argument('--dataset', type=str, default="cornell",help='Dataset Name 
("cornell" or "jaquard")') 26 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/archive111" ,help='Path to dataset') 27 | parser.add_argument('--use-depth', type=int, default=0, help='Use Depth image for training (1/0)') 28 | parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for training (0/1)') 29 | parser.add_argument('--split', type=float, default=0.9, 30 | help='Fraction of data for training (remainder is validation)') 31 | parser.add_argument('--ds-rotate', type=float, default=0.0, 32 | help='Shift the start point of the dataset to use a different test/train split for cross validation.') 33 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers') 34 | 35 | parser.add_argument('--batch-size', type=int, default=1, help='Batch size') 36 | parser.add_argument('--vis', type=bool, default=False, help='vis') 37 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs') 38 | parser.add_argument('--batches-per-epoch', type=int, default=200, help='Batches per Epoch') 39 | parser.add_argument('--val-batches', type=int, default=32, help='Validation Batches') 40 | # Logging etc. 41 | parser.add_argument('--description', type=str, default='', help='Training description') 42 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory') 43 | 44 | args = parser.parse_args() 45 | return args 46 | 47 | if __name__ == '__main__': 48 | args = parse_args() 49 | print(args.network) 50 | print(args.use_rgb,args.use_depth) 51 | net = torch.load(args.network) 52 | # net_ggcnn = torch.load('./output/models/211112_1458_/epoch_30_iou_0.75') 53 | device = torch.device("cuda:0") 54 | Dataset = get_dataset(args.dataset) 55 | 56 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate, 57 | random_rotate=True, random_zoom=False, 58 | include_depth=args.use_depth, include_rgb=args.use_rgb) 59 | val_data = torch.utils.data.DataLoader( 60 | val_dataset, 61 | batch_size=1, 62 | shuffle=True, 63 | num_workers=args.num_workers 64 | ) 65 | results = { 66 | 'correct': 0, 67 | 'failed': 0, 68 | 'loss': 0, 69 | 'losses': { 70 | 71 | } 72 | } 73 | ld = len(val_data) 74 | with torch.no_grad(): 75 | batch_idx = 0 76 | # fig = plt.figure(figsize=(10, 10)) 77 | # ax = fig.add_subplot(1, 4, 1) 78 | # while batch_idx < 100: 79 | for id,(x, y, didx, rot, zoom_factor) in enumerate( val_data): 80 | # batch_idx += 1 81 | 82 | print(id) 83 | print(x.shape) 84 | xc = x.to(device) 85 | yc = [yy.to(device) for yy in y] 86 | lossd = net.compute_loss(xc, yc) 87 | 88 | loss = lossd['loss'] 89 | 90 | results['loss'] += loss.item() / ld 91 | for ln, l in lossd['losses'].items(): 92 | if ln not in results['losses']: 93 | results['losses'][ln] = 0 94 | results['losses'][ln] += l.item() / ld 95 | 96 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'], 97 | lossd['pred']['sin'], lossd['pred']['width']) 98 | gs_1 = detect_grasps(q_out, ang_out, width_img=w_out, no_grasps=1) 99 | rgb_img=val_dataset.get_rgb(didx, rot, zoom_factor, normalise=False) 100 | 101 | fig = plt.figure(figsize=(10, 10)) 102 | ax = fig.add_subplot(1, 4, 1) 103 | ax.imshow(rgb_img) 104 | 105 | ax = fig.add_subplot(1, 4, 2) 106 | plot = ax.imshow(q_out, cmap="jet", vmin=0, vmax=1) # ?terrain 107 | plt.colorbar(plot) 108 | ax.axis('off') 109 | ax.set_title('q image') 110 | 111 | ax = fig.add_subplot(1, 4, 3) # flag prism jet 112 | plot = ax.imshow(ang_out, cmap="hsv", 
vmin=-np.pi / 2, vmax=np.pi / 2) 113 | plt.colorbar(plot) 114 | ax.axis('off') 115 | ax.set_title('angle') 116 | 117 | ax = fig.add_subplot(1, 4, 4) 118 | plot = ax.imshow(w_out, cmap='jet', vmin=0, vmax=150) 119 | plt.colorbar(plot) 120 | ax.set_title('width') 121 | ax.axis('off') 122 | # print(rgb_img) 123 | 124 | plt.savefig('RGB_1_%d.pdf' % 1, bbox_inches='tight')  # Save before show(): once the figure window is closed, savefig would write an empty canvas. 125 | plt.show() --------------------------------------------------------------------------------
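
As a quick orientation to how the dataset-processing utilities above fit together, the short sketch below loads one Cornell annotation file with `GraspRectangles.load_from_cornell_file`, rasterises it into the Q/angle/width maps used as training targets, then runs `detect_grasps` on those maps and scores the result with the same 25% IoU rule used in `evaluation.py`. It assumes it is run from the repository root so that `utils` is importable; the annotation path `pcd0100cpos.txt` is only a placeholder for a real file from the extracted Cornell dataset.

```python
# Minimal sketch of the grasp-processing utilities (run from the repository root).
from utils.dataset_processing.grasp import GraspRectangles, detect_grasps

# Load ground-truth rectangles from a Cornell annotation file (placeholder path).
gt_bbs = GraspRectangles.load_from_cornell_file('pcd0100cpos.txt')

# Rasterise them into the Q, angle and width maps used as network training targets.
pos_img, ang_img, width_img = gt_bbs.draw((480, 640))

# Treat those maps as if they were network output and extract the best grasp,
# then score it against the ground truth with the 25% IoU success criterion.
for g in detect_grasps(pos_img, ang_img, width_img=width_img, no_grasps=1):
    iou = g.max_iou(gt_bbs)
    print('max IoU %.2f -> %s' % (iou, 'success' if iou > 0.25 else 'failed'))
```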