├── README.md
├── img
│   ├── grasp-transformer.png
│   └── test
├── main.py
├── main_grasp_1b.py
├── main_k_fold.py
├── models
│   ├── __init__.py
│   ├── common.py
│   └── swin.py
├── requirements.txt
├── traning.py
├── utils
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   └── __init__.cpython-38.pyc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── cornell_data.cpython-36.pyc
│   │   │   ├── cornell_data.cpython-37.pyc
│   │   │   ├── grasp_data.cpython-36.pyc
│   │   │   ├── grasp_data.cpython-37.pyc
│   │   │   ├── jacquard_data.cpython-36.pyc
│   │   │   └── multi_object.cpython-36.pyc
│   │   ├── cornell_data.py
│   │   ├── gn1b_data.py
│   │   ├── grasp_data.py
│   │   ├── jacquard_data.py
│   │   ├── multi_object.py
│   │   └── multigrasp_object.py
│   ├── dataset_processing
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── evaluation.cpython-36.pyc
│   │   │   ├── evaluation.cpython-37.pyc
│   │   │   ├── grasp.cpython-36.pyc
│   │   │   ├── grasp.cpython-37.pyc
│   │   │   ├── image.cpython-36.pyc
│   │   │   └── image.cpython-37.pyc
│   │   ├── evaluation.py
│   │   ├── generate_cornell_depth.py
│   │   ├── grasp.py
│   │   ├── image.py
│   │   └── multigrasp_object.py
│   ├── timeit.py
│   └── visualisation
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── gridshow.cpython-36.pyc
│       │   ├── gridshow.cpython-37.pyc
│       │   └── gridshow.cpython-38.pyc
│       └── gridshow.py
├── visualise_grasp_rectangle.py
└── visulaize_heatmaps.py
/README.md:
--------------------------------------------------------------------------------
1 | ## When Transformer Meets Robotic Grasping: Exploits Context for Efficient Grasp Detection
2 |
3 | PyTorch implementation of the paper "When Transformer Meets Robotic Grasping:
4 | Exploits Context for Efficient Grasp Detection".
5 |
6 | ## Visualization of the architecture
7 |
8 | ![The architecture of TF-Grasp](img/grasp-transformer.png)
9 |
10 | This code was developed with Python 3.6 on Ubuntu 16.04. Python requirements can be installed with:
11 |
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 |
16 | ## Datasets
17 |
18 | Currently, the [Cornell Grasping Dataset](http://pr.cs.cornell.edu/grasping/rect_data/data.php),
19 | the [Jacquard Dataset](https://jacquard.liris.cnrs.fr/), and the [GraspNet 1Billion](https://graspnet.net/datasets.html) dataset are supported.
20 |
21 | ### Cornell Grasping Dataset
22 | 1. Download and extract the [Cornell Grasping Dataset](http://pr.cs.cornell.edu/grasping/rect_data/data.php).
23 |
24 | ### Jacquard Dataset
25 |
26 | 1. Download and extract the [Jacquard Dataset](https://jacquard.liris.cnrs.fr/).
27 |
28 | ### GraspNet 1Billion dataset
29 |
30 | 1. The dataset can be downloaded [here](https://graspnet.net/datasets.html).
31 | 2. Install the graspnetAPI following the instructions [here](https://graspnetapi.readthedocs.io/en/latest/install.html#install-api):
32 |
33 | ```bash
34 | pip install graspnetAPI
35 | ```
36 | 3. We follow the data setting used in [Modified-GGCNN](https://github.com/ryanreadbooks/Modified-GGCNN).
37 |
38 |
39 | ## Training
40 |
41 | Training is done with the `main.py`, `main_k_fold.py` and `main_grasp_1b.py` scripts.
42 |
43 | Some basic examples:
44 |
45 | ```bash
46 | # Train on the Cornell Grasping Dataset
47 | python main.py --dataset cornell --dataset-path <path-to-cornell>
48 |
49 | # 5-fold cross-validation training on Cornell
50 | python main_k_fold.py --dataset cornell --dataset-path <path-to-cornell>
51 |
52 | # Train on the GraspNet 1Billion dataset
53 | python main_grasp_1b.py --dataset-path <path-to-graspnet-root>
54 | ```
55 |
56 | Trained models are saved in `output/models` by default, with the epoch number and validation IoU appended to the filename.
57 |
58 | ## Visualize
59 | Some basic examples:
60 | ```bash
61 | # visualise grasp rectangles
62 | python visualise_grasp_rectangle.py --network <path-to-trained-model>
63 |
64 | # visualise heatmaps
65 | python visulaize_heatmaps.py --network <path-to-trained-model>
66 |
67 | ```
68 |
69 |
70 |
71 | ## Running on a Robot
72 |
73 | For our ROS implementation of the grasping system, see [https://github.com/USTC-ICR/SimGrasp/tree/main/SimGrasp](https://github.com/USTC-ICR/SimGrasp/tree/main/SimGrasp).
74 |
75 | The original implementation for running experiments on a Kinova Mico arm can be found in the repository [https://github.com/dougsm/ggcnn_kinova_grasping](https://github.com/dougsm/ggcnn_kinova_grasping).
76 |
77 | ## Acknowledgement
78 | The code is heavily inspired by and modified from https://github.com/dougsm/ggcnn.
79 |
80 | If you find this work helpful, please cite:
81 | ```bibtex
82 | @ARTICLE{9810182,
83 | author={Wang, Shaochen and Zhou, Zhangli and Kan, Zhen},
84 | journal={IEEE Robotics and Automation Letters},
85 | title={When Transformer Meets Robotic Grasping: Exploits Context for Efficient Grasp Detection},
86 | year={2022},
87 | volume={},
88 | number={},
89 | pages={1-8},
90 | doi={10.1109/LRA.2022.3187261}}
91 |
92 | ```
93 |
--------------------------------------------------------------------------------
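Note (not part of the original repository): `main.py` saves the entire network module with `torch.save(net, ...)`, so a checkpoint can be reloaded directly with `torch.load` for the visualisation scripts above. The sketch below is a minimal illustration under two assumptions: the checkpoint path is hypothetical, and the forward pass is assumed to return the four prediction maps (pos, cos, sin, width) that `compute_loss` exposes in `main_grasp_1b.py`.

```python
# Hedged sketch: load a checkpoint written by main.py and post-process its raw
# outputs into grasp maps. The path and the forward-pass signature are assumptions.
import torch
from models.common import post_process_output

checkpoint = 'output/models/<run>/epoch_10_iou_0.95'   # hypothetical checkpoint name
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

net = torch.load(checkpoint, map_location=device)      # the whole module was saved, not a state_dict
net.eval()

x = torch.zeros(1, 3, 224, 224, device=device)         # stand-in for a preprocessed 224x224 RGB crop
with torch.no_grad():
    pos, cos, sin, width = net(x)                      # assumed output order (pos, cos, sin, width)

q_img, ang_img, width_img = post_process_output(pos, cos, sin, width)
print(q_img.shape, ang_img.shape, width_img.shape)     # smoothed NumPy maps
```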
/img/grasp-transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/img/grasp-transformer.png
--------------------------------------------------------------------------------
/img/test:
--------------------------------------------------------------------------------
1 | test
2 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import sys
4 | import argparse
5 | import logging
6 | import torch
7 | import torch.utils.data
8 | import torch.optim as optim
9 | from torchsummary import summary
10 | from traning import train, validate
11 | from utils.data import get_dataset
12 | from models.swin import SwinTransformerSys
13 | logging.basicConfig(level=logging.INFO)
14 | def parse_args():
15 | parser = argparse.ArgumentParser(description='TF-Grasp')
16 |
17 | # Network
18 | # Dataset & Data & Training
19 | parser.add_argument('--dataset', type=str,default="jacquard", help='Dataset Name ("cornell" or "jacquard")')
20 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/cornell" ,help='Path to dataset')
21 | parser.add_argument('--use-depth', type=int, default=0, help='Use Depth image for training (1/0)')
22 | parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for training (0/1)')
23 | parser.add_argument('--split', type=float, default=0.9, help='Fraction of data for training (remainder is validation)')
24 | parser.add_argument('--ds-rotate', type=float, default=0.0,
25 | help='Shift the start point of the dataset to use a different test/train split for cross validation.')
26 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers')
27 |
28 | parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
29 |     parser.add_argument('--vis', action='store_true', help='Visualise the training process')
30 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs')
31 | parser.add_argument('--batches-per-epoch', type=int, default=200, help='Batches per Epoch')
32 | parser.add_argument('--val-batches', type=int, default=32, help='Validation Batches')
33 | # Logging etc.
34 | parser.add_argument('--description', type=str, default='', help='Training description')
35 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory')
36 |
37 | args = parser.parse_args()
38 | return args
39 | def run():
40 | args = parse_args()
41 | dt = datetime.datetime.now().strftime('%y%m%d_%H%M')
42 | net_desc = '{}_{}'.format(dt, '_'.join(args.description.split()))
43 |
44 | save_folder = os.path.join(args.outdir, net_desc)
45 | if not os.path.exists(save_folder):
46 | os.makedirs(save_folder)
47 | # tb = tensorboardX.SummaryWriter(os.path.join(args.logdir, net_desc))
48 |
49 | # Load Dataset
50 | logging.info('Loading {} Dataset...'.format(args.dataset.title()))
51 | Dataset = get_dataset(args.dataset)
52 |
53 | train_dataset = Dataset(args.dataset_path, start=0.0, end=args.split, ds_rotate=args.ds_rotate,
54 | random_rotate=True, random_zoom=False,
55 | include_depth=args.use_depth, include_rgb=args.use_rgb)
56 | train_data = torch.utils.data.DataLoader(
57 | train_dataset,
58 | batch_size=args.batch_size,
59 | shuffle=True,
60 | num_workers=args.num_workers
61 | )
62 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate,
63 | random_rotate=True, random_zoom=False,
64 | include_depth=args.use_depth, include_rgb=args.use_rgb)
65 | val_data = torch.utils.data.DataLoader(
66 | val_dataset,
67 | batch_size=1,
68 | shuffle=False,
69 | num_workers=args.num_workers
70 | )
71 | logging.info('Done')
72 |
73 | logging.info('Loading Network...')
74 | input_channels = 1*args.use_depth + 3*args.use_rgb
75 | net = SwinTransformerSys(in_chans=input_channels, embed_dim=48, num_heads=[1, 2, 4, 8])
76 | device = torch.device("cuda:0")
77 | net = net.to(device)
78 | optimizer = optim.AdamW(net.parameters(), lr=1e-4)
79 | listy = [x * 2 for x in range(1,1000,5)]
80 | schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=listy,gamma=0.5)
81 | logging.info('Done')
82 | summary(net, (input_channels, 224, 224))
83 | f = open(os.path.join(save_folder, 'net.txt'), 'w')
84 | sys.stdout = f
85 | summary(net, (input_channels, 224, 224))
86 | sys.stdout = sys.__stdout__
87 | f.close()
88 | best_iou = 0.0
89 | for epoch in range(args.epochs):
90 | logging.info('Beginning Epoch {:02d}'.format(epoch))
91 | print("current lr:",optimizer.state_dict()['param_groups'][0]['lr'])
92 | # for i in range(5000):
93 | train_results = train(epoch, net, device, train_data, optimizer, args.batches_per_epoch, vis=args.vis)
94 |
95 | # schedule.step()
96 | # Run Validation
97 | logging.info('Validating...')
98 | test_results = validate(net, device, val_data, args.val_batches)
99 | logging.info('%d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'],
100 | test_results['correct']/(test_results['correct']+test_results['failed'])))
101 |
102 | iou = test_results['correct'] / (test_results['correct'] + test_results['failed'])
103 | if epoch%1==0 or iou>best_iou:
104 | torch.save(net, os.path.join(save_folder, 'epoch_%02d_iou_%0.2f' % (epoch, iou)))
105 | torch.save(net.state_dict(), os.path.join(save_folder, 'epoch_%02d_iou_%0.2f_statedict.pt' % (epoch, iou)))
106 | best_iou = iou
107 |
108 |
109 | if __name__ == '__main__':
110 | run()
111 |
--------------------------------------------------------------------------------
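A side note on the scheduler set up in `main.py` (an observation, not repository code): `listy = [x * 2 for x in range(1, 1000, 5)]` yields milestones 2, 12, 22, 32, ..., at which `MultiStepLR` would halve the learning rate, although `schedule.step()` is commented out in the training loop above, so the decay never actually fires as written. The standalone sketch below just makes the intended schedule explicit.

```python
# Standalone sketch of the decay schedule main.py constructs: lr starts at 1e-4
# and would be halved at epochs 2, 12, 22, ... if the scheduler were stepped.
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
milestones = [x * 2 for x in range(1, 1000, 5)]        # 2, 12, 22, 32, ...
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)

for epoch in range(25):
    optimizer.step()                                   # placeholder for one epoch of training
    scheduler.step()
    print(f'epoch {epoch:02d}  lr = {optimizer.param_groups[0]["lr"]:.2e}')
```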
/main_grasp_1b.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | import os
4 | import sys
5 | import argparse
6 | import logging
7 | from tqdm import tqdm
8 | import cv2
9 |
10 | import torch
11 | import torch.utils.data
12 | import torch.optim as optim
13 |
14 | from torchsummary import summary
15 |
16 |
17 | from traning import train, validate
18 | from utils.visualisation.gridshow import gridshow
19 |
20 | from utils.dataset_processing import evaluation
21 | from utils.data import get_dataset
22 | from models.common import post_process_output
23 | from models.swin import SwinTransformerSys
24 | logging.basicConfig(level=logging.INFO)
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser(description='TF-Grasp')
28 |
29 |     # Dataset & Data & Training
30 |     parser.add_argument('--dataset', type=str, default="graspnet1b", help='Dataset Name ("cornell" or "jacquard")')
31 |     parser.add_argument('--dataset-path', type=str, help='Path to the GraspNet 1Billion dataset root')
32 | parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for training (1/0)')
33 | parser.add_argument('--use-rgb', type=int, default=0, help='Use RGB image for training (0/1)')
34 | parser.add_argument('--split', type=float, default=1., help='Fraction of data for training (remainder is validation)')
35 | parser.add_argument('--ds-rotate', type=float, default=0.0,
36 | help='Shift the start point of the dataset to use a different test/train split for cross validation.')
37 | parser.add_argument('--num-workers', type=int, default=32, help='Dataset workers')
38 |
39 | parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
40 | parser.add_argument('--epochs', type=int, default=500, help='Training epochs')
41 | parser.add_argument('--batches-per-epoch', type=int, default=500, help='Batches per Epoch')
42 | parser.add_argument('--val-batches', type=int, default=100, help='Validation Batches')
43 | parser.add_argument('--output-size', type=int, default=224,
44 | help='the output size of the network, determining the cropped size of dataset images')
45 |
46 | parser.add_argument('--camera', type=str, default='realsense',
47 | help='Which camera\'s data to use, only effective when using graspnet1b dataset')
48 | parser.add_argument('--scale', type=int, default=2,
49 | help='the scale factor for the original images, only effective when using graspnet1b dataset')
50 | # Logging etc.
51 | parser.add_argument('--description', type=str, default='', help='Training description')
52 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory')
53 | parser.add_argument('--logdir', type=str, default='tensorboard/', help='Log directory')
54 |     parser.add_argument('--vis', action='store_true', help='Visualise the training process')
55 |
56 | args = parser.parse_args()
57 | return args
58 |
59 |
60 | def validate(net, device, val_data, batches_per_epoch,no_grasps=1):
61 | """
62 | Run validation.
63 | :param net: Network
64 | :param device: Torch device
65 | :param val_data: Validation Dataset
66 | :param batches_per_epoch: Number of batches to run
67 | :return: Successes, Failures and Losses
68 | """
69 | net.eval()
70 |
71 | results = {
72 | 'correct': 0,
73 | 'failed': 0,
74 | 'loss': 0,
75 | 'losses': {
76 |
77 | }
78 | }
79 |
80 | ld = len(val_data)
81 |
82 | with torch.no_grad():
83 | batch_idx = 0
84 | while batch_idx < batches_per_epoch:
85 | for x, y, didx, rot, zoom_factor in tqdm(val_data):
86 | batch_idx += 1
87 | if batches_per_epoch is not None and batch_idx >= batches_per_epoch:
88 | break
89 |
90 | xc = x.to(device)
91 | yc = [yy.to(device) for yy in y]
92 | lossd = net.compute_loss(xc, yc)
93 |
94 | loss = lossd['loss']
95 |
96 | results['loss'] += loss.item()/ld
97 | for ln, l in lossd['losses'].items():
98 | if ln not in results['losses']:
99 | results['losses'][ln] = 0
100 | results['losses'][ln] += l.item()/ld
101 |
102 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'],
103 | lossd['pred']['sin'], lossd['pred']['width'])
104 | # print("inde:",didx)
105 |
106 | s = evaluation.calculate_iou_match(q_out, ang_out,
107 | val_data.dataset.get_gtbb(didx, rot, zoom_factor),
108 | no_grasps=no_grasps,
109 | grasp_width=w_out,
110 | )
111 |
112 | if s:
113 | results['correct'] += 1
114 | else:
115 | results['failed'] += 1
116 |
117 |
118 | return results
119 |
120 |
121 | def run():
122 | args = parse_args()
123 |
124 | dt = datetime.datetime.now().strftime('%y%m%d_%H%M')
125 | net_desc = '{}_{}'.format(dt, '_'.join(args.description.split()))
126 |
127 |     save_folder = os.path.join(args.outdir, net_desc+"_d="+str(args.use_depth+args.use_rgb)+"_scale="+str(args.scale))
128 | if not os.path.exists(save_folder):
129 | os.makedirs(save_folder)
130 |
131 | # Load Dataset
132 | logging.info('Loading {} Dataset...'.format(args.dataset.title()))
133 | Dataset = get_dataset(args.dataset)
134 |
135 | if args.dataset == 'graspnet1b':
136 | print("1 billion")
137 | train_dataset = Dataset( args.dataset_path, ds_rotate=args.ds_rotate,
138 | output_size=args.output_size,
139 | random_rotate=False, random_zoom=False,
140 | include_depth=args.use_depth,
141 | include_rgb=args.use_rgb,
142 | camera=args.camera,
143 | scale=args.scale,
144 | split='train')
145 | else:
146 | train_dataset = Dataset(file_path=args.dataset_path, start=0.0, end=args.split, ds_rotate=args.ds_rotate,
147 | output_size=args.output_size,
148 | random_rotate=True, random_zoom=True,
149 | include_depth=args.use_depth,
150 | include_rgb=args.use_rgb)
151 |
152 | train_data = torch.utils.data.DataLoader(
153 | train_dataset,
154 | batch_size=args.batch_size,
155 | shuffle=True,
156 | num_workers=args.num_workers,
157 | pin_memory=False
158 | )
159 | train_validate_data = torch.utils.data.DataLoader(
160 | train_dataset,
161 | batch_size=1,
162 | shuffle=True,
163 | num_workers=args.num_workers//4,
164 | pin_memory=False
165 | )
166 | if args.dataset == 'graspnet1b':
167 | val_dataset = Dataset(args.dataset_path, ds_rotate=args.ds_rotate,
168 | output_size=args.output_size,
169 | random_rotate=False, random_zoom=False,
170 | include_depth=args.use_depth,
171 | include_rgb=args.use_rgb,
172 | camera=args.camera,
173 | scale=args.scale,
174 | split='test_seen')
175 | val_dataset_1 = Dataset(args.dataset_path, ds_rotate=False,
176 | output_size=args.output_size,
177 | random_rotate=False, random_zoom=False,
178 | include_depth=args.use_depth,
179 | include_rgb=args.use_rgb,
180 | camera=args.camera,
181 | scale=args.scale,
182 | split='test_similar')
183 | val_dataset_2 = Dataset(args.dataset_path, ds_rotate=False,
184 | output_size=args.output_size,
185 | random_rotate=False, random_zoom=False,
186 | include_depth=args.use_depth,
187 | include_rgb=args.use_rgb,
188 | camera=args.camera,
189 | scale=args.scale,
190 | split='test_novel')
191 |
192 | else:
193 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate,
194 | output_size=args.output_size,
195 | random_rotate=True, random_zoom=True,
196 | include_depth=args.use_depth,
197 | include_rgb=args.use_rgb)
198 | val_data = torch.utils.data.DataLoader(
199 | val_dataset,
200 | batch_size=1, # do not modify
201 | shuffle=True,
202 | num_workers=args.num_workers // 4,
203 | pin_memory=False,
204 | )
205 | val_data_1 = torch.utils.data.DataLoader(
206 | val_dataset_1,
207 | batch_size=1, # do not modify
208 | shuffle=True,
209 | num_workers=args.num_workers // 4,
210 | pin_memory=False
211 | )
212 | val_data_2 = torch.utils.data.DataLoader(
213 | val_dataset_2,
214 | batch_size=1, # do not modify
215 | shuffle=True,
216 | num_workers=args.num_workers // 4,
217 | pin_memory=False
218 | )
219 | logging.info('Done')
220 |
221 | # Load the network
222 | logging.info('Loading Network...')
223 | input_channels = 1*args.use_depth + 3*args.use_rgb
224 | print("channels:",input_channels)
225 |
226 | net=SwinTransformerSys(in_chans=input_channels,embed_dim=48,num_heads=[1, 2, 4, 8])
227 | device = torch.device("cuda:0")
228 | net = net.to(device)
229 | optimizer = optim.AdamW(net.parameters(),lr=1e-4)
230 | logging.info('Done')
231 | summary(net, (input_channels, 224, 224))
232 | f = open(os.path.join(save_folder, 'arch.txt'), 'w')
233 | sys.stdout = f
234 | summary(net, (input_channels, 224, 224))
235 | sys.stdout = sys.__stdout__
236 | f.close()
237 |
238 | best_iou = 0.0
239 | for epoch in range(args.epochs):
240 | logging.info('Beginning Epoch {:02d}'.format(epoch))
241 | train_results = train(epoch, net, device, train_data, optimizer, args.batches_per_epoch, vis=args.vis)
242 |
243 |
244 | test_results = validate(net, device, train_validate_data, args.val_batches)
245 |         logging.info('training %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'],
246 | test_results['correct'] / (
247 | test_results['correct'] + test_results['failed'])))
248 | logging.info('loss/train_loss: %f'%test_results['loss'])
249 |
250 | test_results = validate(net, device, val_data, args.val_batches)
251 | logging.info(' seen %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'],
252 | test_results['correct']/(test_results['correct']+test_results['failed'])))
253 | logging.info('loss/seen_loss: %f'%test_results['loss'])
254 |
255 | test_results = validate(net, device, val_data_1, args.val_batches)
256 | logging.info('similar %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'],
257 | test_results['correct'] / (test_results['correct'] + test_results['failed'])))
258 | logging.info('loss/similar_loss: %f'%test_results['loss'])
259 |
260 | test_results = validate(net, device, val_data_2, args.val_batches,no_grasps=2)
261 | logging.info('novel %d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'],
262 | test_results['correct'] / (test_results['correct'] + test_results['failed'])))
263 | logging.info('loss/novel_loss: %f'%test_results['loss'])
264 |
265 |
266 | # Save best performing network
267 | iou = test_results['correct'] / (test_results['correct'] + test_results['failed'])
268 | if iou > best_iou or epoch == 0 or (epoch % 10) == 0:
269 | torch.save(net, os.path.join(save_folder, 'epoch_%02d_iou_%0.2f' % (epoch, iou)))
270 | # torch.save(net.state_dict(), os.path.join(save_folder, 'epoch_%02d_iou_%0.2f_statedict.pt' % (epoch, iou)))
271 | best_iou = iou
272 |
273 |
274 | if __name__ == '__main__':
275 | run()
276 |
--------------------------------------------------------------------------------
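A small usage note (illustration only, not repository code): the `validate` defined in this file shadows the one imported from `traning` and returns a dict with `correct`, `failed`, `loss` and per-term `losses`; the helper below reproduces the accuracy figure that the training loop logs. The loss names in the example call are invented for illustration.

```python
# Hedged sketch: summarising the dict shape returned by validate() above.
def summarise(results: dict) -> float:
    """Return the IoU-match accuracy logged by main_grasp_1b.py."""
    total = results['correct'] + results['failed']
    accuracy = results['correct'] / total if total else 0.0
    print(f"{results['correct']}/{total} = {accuracy:.4f}  (mean loss {results['loss']:.4f})")
    for name, value in results['losses'].items():      # per-term losses accumulated in validate()
        print(f"  {name}: {value:.4f}")
    return accuracy

# Example with the same dict structure (values and loss names are invented):
summarise({'correct': 87, 'failed': 13, 'loss': 0.031,
           'losses': {'p_loss': 0.012, 'cos_loss': 0.007, 'sin_loss': 0.006, 'width_loss': 0.006}})
```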
/main_k_fold.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import sys
4 | import argparse
5 | import logging
6 | import cv2
7 | import torch
8 | import torch.utils.data
9 | import torch.optim as optim
10 | from torchsummary import summary
11 | from sklearn.model_selection import KFold
12 | from traning import train, validate
13 | from utils.data import get_dataset
14 | from models.swin import SwinTransformerSys
15 | # from models.Swin_without_skipconcetion import SwinTransformerSys
16 | logging.basicConfig(level=logging.INFO)
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='TF-Grasp')
20 |
21 | # Network
22 |
23 |
24 | # Dataset & Data & Training
25 |     parser.add_argument('--dataset', type=str, default="cornell", help='Dataset Name ("cornell", "jacquard" or "multi")')
26 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/archive111" ,help='Path to dataset')
27 | parser.add_argument('--use-depth', type=int, default=1, help='Use Depth image for training (1/0)')
28 | parser.add_argument('--use-rgb', type=int, default=0, help='Use RGB image for training (0/1)')
29 | parser.add_argument('--split', type=float, default=0.9, help='Fraction of data for training (remainder is validation)')
30 | parser.add_argument('--ds-rotate', type=float, default=0.0,
31 | help='Shift the start point of the dataset to use a different test/train split for cross validation.')
32 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers')
33 |
34 | parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
35 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs')
36 | parser.add_argument('--batches-per-epoch', type=int, default=500, help='Batches per Epoch')
37 | parser.add_argument('--val-batches', type=int, default=50, help='Validation Batches')
38 | # Logging etc.
39 | parser.add_argument('--description', type=str, default='', help='Training description')
40 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory')
41 |
42 | args = parser.parse_args()
43 | return args
44 |
45 |
46 | def run():
47 | args = parse_args()
48 | # Set-up output directories
49 | dt = datetime.datetime.now().strftime('%y%m%d_%H%M')
50 | net_desc = '{}_{}'.format(dt, '_'.join(args.description.split()))
51 |
52 | save_folder = os.path.join(args.outdir, net_desc)
53 | if not os.path.exists(save_folder):
54 | os.makedirs(save_folder)
55 |
56 | # Load Dataset
57 | logging.info('Loading {} Dataset...'.format(args.dataset.title()))
58 | Dataset = get_dataset(args.dataset)
59 |
60 |
61 | dataset = Dataset(args.dataset_path, start=0.0, end=1.0, ds_rotate=args.ds_rotate,
62 | random_rotate=True, random_zoom=True,
63 | include_depth=args.use_depth, include_rgb=args.use_rgb)
64 | k_folds = 5
65 | kfold = KFold(n_splits=k_folds, shuffle=True)
66 | logging.info('Done')
67 | logging.info('Loading Network...')
68 | input_channels = 1*args.use_depth + 3*args.use_rgb
69 | net = SwinTransformerSys(in_chans=input_channels,embed_dim=48,num_heads=[1,2,4,8])
70 | device = torch.device("cuda:0")
71 | net = net.to(device)
72 | optimizer = optim.AdamW(net.parameters(), lr=5e-4)
73 | listy = [x *7 for x in range(1,1000,3)]
74 | schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=listy,gamma=0.6)
75 | logging.info('Done')
76 | # Print model architecture.
77 | summary(net, (input_channels, 224, 224))
78 | f = open(os.path.join(save_folder, 'arch.txt'), 'w')
79 | sys.stdout = f
80 | summary(net, (input_channels, 224, 224))
81 | sys.stdout = sys.__stdout__
82 | f.close()
83 |
84 | best_iou = 0.0
85 | for epoch in range(args.epochs):
86 | accuracy=0.
87 | for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
88 |
89 | train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
90 | test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
91 | trainloader = torch.utils.data.DataLoader(
92 | dataset,
93 | batch_size=args.batch_size,num_workers=args.num_workers, sampler=train_subsampler)
94 | testloader = torch.utils.data.DataLoader(
95 | dataset,
96 | batch_size=1,num_workers=args.num_workers, sampler=test_subsampler)
97 |
98 |
99 | logging.info('Beginning Epoch {:02d}'.format(epoch))
100 | print("lr:",optimizer.state_dict()['param_groups'][0]['lr'])
101 | train_results = train(epoch, net, device, trainloader, optimizer, args.batches_per_epoch, )
102 | schedule.step()
103 |
104 | # Run Validation
105 | logging.info('Validating...')
106 | test_results = validate(net, device, testloader, args.val_batches)
107 | logging.info('%d/%d = %f' % (test_results['correct'], test_results['correct'] + test_results['failed'],
108 | test_results['correct']/(test_results['correct']+test_results['failed'])))
109 |
110 |
111 | iou = test_results['correct'] / (test_results['correct'] + test_results['failed'])
112 | accuracy+=iou
113 | if iou > best_iou or epoch == 0 or (epoch % 50) == 0:
114 | torch.save(net, os.path.join(save_folder, 'epoch_%02d_iou_%0.2f' % (epoch, iou)))
115 | # torch.save(net.state_dict(), os.path.join(save_folder, 'epoch_%02d_iou_%0.2f_statedict.pt' % (epoch, iou)))
116 | best_iou = iou
117 | schedule.step()
118 | print("the accuracy:",accuracy/k_folds)
119 |
120 |
121 | if __name__ == '__main__':
122 | run()
123 |
124 |
125 |
--------------------------------------------------------------------------------
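For reference, the cross-validation pattern used in `main_k_fold.py` (KFold index splits feeding `SubsetRandomSampler`-backed loaders) is sketched below in a self-contained form with a toy dataset in place of Cornell; it is an illustration, not repository code.

```python
# Self-contained sketch of the k-fold loop structure in main_k_fold.py,
# using a tiny TensorDataset instead of the Cornell grasping dataset.
import torch
from sklearn.model_selection import KFold

dataset = torch.utils.data.TensorDataset(torch.randn(100, 1, 8, 8))   # toy stand-in
kfold = KFold(n_splits=5, shuffle=True)

for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=8,
        sampler=torch.utils.data.SubsetRandomSampler(train_ids))
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1,
        sampler=torch.utils.data.SubsetRandomSampler(test_ids))
    print(f'fold {fold}: {len(train_ids)} train / {len(test_ids)} val samples')
```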
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from skimage.filters import gaussian
3 |
4 |
5 | def post_process_output(q_img, cos_img, sin_img, width_img):
6 | """
7 | Post-process the raw output of the GG-CNN, convert to numpy arrays, apply filtering.
8 | :param q_img: Q output of GG-CNN (as torch Tensors)
9 | :param cos_img: cos output of GG-CNN
10 | :param sin_img: sin output of GG-CNN
11 | :param width_img: Width output of GG-CNN
12 | :return: Filtered Q output, Filtered Angle output, Filtered Width output
13 | """
14 | q_img = q_img.data.detach().cpu().numpy().squeeze()
15 | ang_img = (torch.atan2(sin_img, cos_img) / 2.0).data.detach().cpu().numpy().squeeze()
16 | width_img = width_img.data.detach().cpu().numpy().squeeze() * 150.0
17 |
18 | q_img = gaussian(q_img, 2.0, preserve_range=True)
19 | ang_img = gaussian(ang_img, 2.0, preserve_range=True)
20 | width_img = gaussian(width_img, 1.0, preserve_range=True)
21 |
22 | return q_img, ang_img, width_img
23 |
--------------------------------------------------------------------------------
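A usage sketch for `post_process_output` (not repository code): it takes the four raw output maps as torch tensors, converts them to NumPy, recovers the grasp angle from the sin/cos pair, rescales the width by 150, and Gaussian-smooths each map. The (1, 1, 224, 224) shapes below are assumed placeholders.

```python
# Hedged example: run post_process_output on placeholder tensors.
import torch
from models.common import post_process_output

pos = torch.rand(1, 1, 224, 224)             # grasp-quality map
cos = torch.rand(1, 1, 224, 224) * 2 - 1     # cos(2*theta) map
sin = torch.rand(1, 1, 224, 224) * 2 - 1     # sin(2*theta) map
width = torch.rand(1, 1, 224, 224)           # normalised gripper width map

q_img, ang_img, width_img = post_process_output(pos, cos, sin, width)
print(q_img.shape, ang_img.min(), width_img.max())   # smoothed (224, 224) NumPy maps
```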
/models/swin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | #from einops import rearrange
5 | #from timm.models.layers import DropPath, to_2tuple, trunc_normal_
6 | import torch.nn.functional as F
7 | def drop_path(x, drop_prob: float = 0., training: bool = False):
8 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
9 |
10 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
11 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
12 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
13 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
14 | 'survival rate' as the argument.
15 |
16 | """
17 | if drop_prob == 0. or not training:
18 | return x
19 | keep_prob = 1 - drop_prob
20 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
21 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
22 | random_tensor.floor_() # binarize
23 | output = x.div(keep_prob) * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
29 | """
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
37 | from itertools import repeat
38 | import collections.abc
39 |
40 |
41 | # From PyTorch internals
42 | def _ntuple(n):
43 | def parse(x):
44 | # if isinstance(x, collections.abc.Iterable):
45 | # return x
46 | return tuple(repeat(x, n))
47 | return parse
48 |
49 |
50 | to_1tuple = _ntuple(1)
51 | to_2tuple = _ntuple(2)
52 | to_3tuple = _ntuple(3)
53 | to_4tuple = _ntuple(4)
54 |
55 |
56 | import torch
57 | import math
58 | import warnings
59 |
60 | from torch.nn.init import _calculate_fan_in_and_fan_out
61 |
62 |
63 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
64 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
65 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
66 | def norm_cdf(x):
67 | # Computes standard normal cumulative distribution function
68 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
69 |
70 | if (mean < a - 2 * std) or (mean > b + 2 * std):
71 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
72 | "The distribution of values may be incorrect.",
73 | stacklevel=2)
74 |
75 | with torch.no_grad():
76 | # Values are generated by using a truncated uniform distribution and
77 | # then using the inverse CDF for the normal distribution.
78 | # Get upper and lower cdf values
79 | l = norm_cdf((a - mean) / std)
80 | u = norm_cdf((b - mean) / std)
81 |
82 | # Uniformly fill tensor with values from [l, u], then translate to
83 | # [2l-1, 2u-1].
84 | tensor.uniform_(2 * l - 1, 2 * u - 1)
85 |
86 | # Use inverse cdf transform for normal distribution to get truncated
87 | # standard normal
88 | tensor.erfinv_()
89 |
90 | # Transform to proper mean, std
91 | tensor.mul_(std * math.sqrt(2.))
92 | tensor.add_(mean)
93 |
94 | # Clamp to ensure it's in the proper range
95 | tensor.clamp_(min=a, max=b)
96 | return tensor
97 |
98 |
99 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
100 | # type: (Tensor, float, float, float, float) -> Tensor
101 | r"""Fills the input Tensor with values drawn from a truncated
102 | normal distribution. The values are effectively drawn from the
103 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
104 | with values outside :math:`[a, b]` redrawn until they are within
105 | the bounds. The method used for generating the random values works
106 | best when :math:`a \leq \text{mean} \leq b`.
107 | Args:
108 | tensor: an n-dimensional `torch.Tensor`
109 | mean: the mean of the normal distribution
110 | std: the standard deviation of the normal distribution
111 | a: the minimum cutoff value
112 | b: the maximum cutoff value
113 | Examples:
114 | >>> w = torch.empty(3, 5)
115 | >>> nn.init.trunc_normal_(w)
116 | """
117 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
118 | class Mlp(nn.Module):
119 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.): #nn.ReLU nn.GELU
120 | super().__init__()
121 | out_features = out_features or in_features
122 | hidden_features = hidden_features or in_features
123 | self.fc1 = nn.Linear(in_features, hidden_features)
124 | self.act = act_layer()
125 | self.fc2 = nn.Linear(hidden_features, out_features)
126 | self.drop = nn.Dropout(drop)
127 |
128 | def forward(self, x):
129 | x = self.fc1(x)
130 | x = self.act(x)
131 | x = self.drop(x)
132 | x = self.fc2(x)
133 | x = self.drop(x)
134 | return x
135 | def window_partition(x, window_size):
136 | """
137 | Args:
138 | x: (B, H, W, C)
139 | window_size (int): window size
140 | Returns:
141 | windows: (num_windows*B, window_size, window_size, C)
142 | """
143 | B, H, W, C = x.shape
144 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
145 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
146 | return windows
147 | def window_reverse(windows, window_size, H, W):
148 | """
149 | Args:
150 | windows: (num_windows*B, window_size, window_size, C)
151 | window_size (int): Window size
152 | H (int): Height of image
153 | W (int): Width of image
154 | Returns:
155 | x: (B, H, W, C)
156 | """
157 | B = int(windows.shape[0] / (H * W / window_size / window_size))
158 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
159 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
160 | return x
161 |
162 | class WindowAttention(nn.Module):
163 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
164 | It supports both of shifted and non-shifted window.
165 | Args:
166 | dim (int): Number of input channels.
167 | window_size (tuple[int]): The height and width of the window.
168 | num_heads (int): Number of attention heads.
169 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
170 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
171 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
172 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
173 | """
174 |
175 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
176 |
177 | super().__init__()
178 | self.dim = dim
179 | self.window_size = window_size # Wh, Ww
180 | self.num_heads = num_heads
181 | head_dim = dim // num_heads
182 | self.scale = qk_scale or head_dim ** -0.5
183 |
184 | # define a parameter table of relative position bias
185 | self.relative_position_bias_table = nn.Parameter(
186 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
187 |
188 | # get pair-wise relative position index for each token inside the window
189 | coords_h = torch.arange(self.window_size[0])
190 | coords_w = torch.arange(self.window_size[1])
191 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
192 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
193 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
194 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
195 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
196 | relative_coords[:, :, 1] += self.window_size[1] - 1
197 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
198 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
199 | self.register_buffer("relative_position_index", relative_position_index)
200 |
201 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
202 | self.attn_drop = nn.Dropout(attn_drop)
203 | self.proj = nn.Linear(dim, dim)
204 | self.proj_drop = nn.Dropout(proj_drop)
205 |
206 | trunc_normal_(self.relative_position_bias_table, std=.02)
207 | self.softmax = nn.Softmax(dim=-1)
208 |
209 | def forward(self, x, mask=None):
210 | """
211 | Args:
212 | x: input features with shape of (num_windows*B, N, C)
213 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
214 | """
215 | B_, N, C = x.shape
216 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
217 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
218 |
219 | q = q * self.scale
220 | #attn = (q @ k.transpose(-2, -1))
221 | attn = torch.matmul(q,k.transpose(-2, -1))
222 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
223 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
224 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
225 | attn = attn + relative_position_bias.unsqueeze(0)
226 |
227 | if mask is not None:
228 | nW = mask.shape[0]
229 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
230 | attn = attn.view(-1, self.num_heads, N, N)
231 | attn = self.softmax(attn)
232 | else:
233 | attn = self.softmax(attn)
234 |
235 | attn = self.attn_drop(attn)
236 |
237 | # x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
238 | x = torch.matmul(attn,v).transpose(1, 2).reshape(B_, N, C)
239 | x = self.proj(x)
240 | x = self.proj_drop(x)
241 | return x
242 |
243 | def extra_repr(self) -> str:
244 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
245 |
246 | def flops(self, N):
247 | # calculate flops for 1 window with token length of N
248 | flops = 0
249 | # qkv = self.qkv(x)
250 | flops += N * self.dim * 3 * self.dim
251 | # attn = (q @ k.transpose(-2, -1))
252 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
253 | # x = (attn @ v)
254 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
255 | # x = self.proj(x)
256 | flops += N * self.dim * self.dim
257 | return flops
258 |
259 | class SwinTransformerBlock(nn.Module):
260 | r""" Swin Transformer Block.
261 | Args:
262 | dim (int): Number of input channels.
263 | input_resolution (tuple[int]): Input resulotion.
264 | num_heads (int): Number of attention heads.
265 | window_size (int): Window size.
266 | shift_size (int): Shift size for SW-MSA.
267 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
268 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
269 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
270 | drop (float, optional): Dropout rate. Default: 0.0
271 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
272 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
273 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
274 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
275 | """
276 |
277 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
278 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
279 | act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
280 | super().__init__()
281 | self.dim = dim
282 | self.input_resolution = input_resolution
283 | self.num_heads = num_heads
284 | self.window_size = window_size
285 | self.shift_size = shift_size
286 | self.mlp_ratio = mlp_ratio
287 | if min(self.input_resolution) <= self.window_size:
288 | # if window size is larger than input resolution, we don't partition windows
289 | self.shift_size = 0
290 | self.window_size = min(self.input_resolution)
291 |         assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
292 |
293 | self.norm1 = norm_layer(dim)
294 | self.attn = WindowAttention(
295 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
296 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
297 |
298 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
299 | self.norm2 = norm_layer(dim)
300 | mlp_hidden_dim = int(dim * mlp_ratio)
301 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
302 |
303 | if self.shift_size > 0:
304 | # calculate attention mask for SW-MSA
305 | H, W = self.input_resolution
306 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
307 | h_slices = (slice(0, -self.window_size),
308 | slice(-self.window_size, -self.shift_size),
309 | slice(-self.shift_size, None))
310 | w_slices = (slice(0, -self.window_size),
311 | slice(-self.window_size, -self.shift_size),
312 | slice(-self.shift_size, None))
313 | cnt = 0
314 | for h in h_slices:
315 | for w in w_slices:
316 | img_mask[:, h, w, :] = cnt
317 | cnt += 1
318 |
319 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
320 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
321 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
322 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
323 | else:
324 | attn_mask = None
325 |
326 | self.register_buffer("attn_mask", attn_mask)
327 |
328 | def forward(self, x):
329 | H, W = self.input_resolution
330 | B, L, C = x.shape
331 | assert L == H * W, "input feature has wrong size"
332 |
333 | shortcut = x
334 | x = self.norm1(x)
335 | x = x.view(B, H, W, C)
336 |
337 | pad_l = pad_t = 0
338 | pad_r = (self.window_size - W % self.window_size) % self.window_size
339 | pad_b = (self.window_size - H % self.window_size) % self.window_size
340 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
341 | _, Hp, Wp, _ = x.shape
342 | # cyclic shift
343 |
344 | if self.shift_size > 0:
345 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
346 | else:
347 | shifted_x = x
348 |
349 | # partition windows
350 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
351 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
352 |
353 | # W-MSA/SW-MSA
354 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
355 |
356 | # merge windows
357 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
358 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
359 |
360 | # reverse cyclic shift
361 | if self.shift_size > 0:
362 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
363 | else:
364 | x = shifted_x
365 | x = x.view(B, H * W, C)
366 |
367 | # FFN
368 | x = shortcut + self.drop_path(x)
369 | x = x + self.drop_path(self.mlp(self.norm2(x)))
370 |
371 | return x
372 |
373 | def extra_repr(self) -> str:
374 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
375 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
376 |
377 | def flops(self):
378 | flops = 0
379 | H, W = self.input_resolution
380 | # norm1
381 | flops += self.dim * H * W
382 | # W-MSA/SW-MSA
383 | nW = H * W / self.window_size / self.window_size
384 | flops += nW * self.attn.flops(self.window_size * self.window_size)
385 | # mlp
386 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
387 | # norm2
388 | flops += self.dim * H * W
389 | return flops
390 |
391 | class PatchMerging(nn.Module):
392 | r""" Patch Merging Layer.
393 | Args:
394 | input_resolution (tuple[int]): Resolution of input feature.
395 | dim (int): Number of input channels.
396 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
397 | """
398 |
399 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
400 | super().__init__()
401 | self.input_resolution = input_resolution
402 | self.dim = dim
403 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
404 | self.norm = norm_layer(4 * dim)
405 |
406 | def forward(self, x):
407 | """
408 | x: B, H*W, C
409 | """
410 | H, W = self.input_resolution
411 | B, L, C = x.shape
412 | assert L == H * W, "input feature has wrong size"
413 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
414 |
415 | x = x.view(B, H, W, C)
416 | pad_input = (H % 2 == 1) or (W % 2 == 1)
417 | if pad_input:
418 | # to pad the last 3 dimensions, starting from the last dimension and moving forward.
419 | # (C_front, C_back, W_left, W_right, H_top, H_bottom)
420 |             # note: the tensor layout here is [B, H, W, C], so the padding order differs slightly from the official docs
421 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
422 |
423 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
424 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
425 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
426 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
427 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
428 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
429 |
430 | x = self.norm(x)
431 | x = self.reduction(x)
432 |
433 | return x
434 |
435 | def extra_repr(self) -> str:
436 | return f"input_resolution={self.input_resolution}, dim={self.dim}"
437 |
438 | def flops(self):
439 | H, W = self.input_resolution
440 | flops = H * W * self.dim
441 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
442 | return flops
443 | # class PatchExpand(nn.Module):
444 | # def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
445 | # super().__init__()
446 | # self.input_resolution = input_resolution
447 | # self.dim = dim
448 | # self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
449 | # self.norm = norm_layer(dim // dim_scale)
450 | #
451 | # def forward(self, x):
452 | # """
453 | # x: B, H*W, C
454 | # """
455 | # H, W = self.input_resolution
456 | # x = self.expand(x)
457 | # B, L, C = x.shape
458 | # assert L == H * W, "input feature has wrong size"
459 | #
460 | # x = x.view(B, H, W, C)
461 | # x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
462 | # x = x.view(B,-1,C//4)
463 | # x= self.norm(x)
464 | #
465 | # return x
466 | class PatchExpand(nn.Module):
467 | def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
468 | super().__init__()
469 | self.input_resolution = input_resolution
470 | self.dim = dim
471 | self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()
472 | self.norm = norm_layer(dim // dim_scale)
473 |
474 | def forward(self, x):
475 | """
476 | x: B, H*W, C
477 | """
478 | H, W = self.input_resolution
479 | x = self.expand(x)
480 | B, L, C = x.shape
481 | assert L == H * W, "input feature has wrong size"
482 | x = x.view(B, H, W, C)
483 | #print("x:",x.shape)
484 | x=x.reshape(B,H*2,W*2, C//4)
485 | #x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
486 | x = x.view(B,-1,C//4)
487 | x= self.norm(x)
488 |
489 | return x
490 |
491 | class FinalPatchExpand_X4(nn.Module):
492 | def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
493 | super().__init__()
494 | self.input_resolution = input_resolution
495 | self.dim = dim
496 | self.dim_scale = dim_scale
497 | self.expand = nn.Linear(dim, 16*dim, bias=False)
498 | self.output_dim = dim
499 | self.norm = norm_layer(self.output_dim)
500 |
501 | def forward(self, x):
502 | """
503 | x: B, H*W, C
504 | """
505 | H, W = self.input_resolution
506 | x = self.expand(x)
507 | B, L, C = x.shape
508 | assert L == H * W, "input feature has wrong size"
509 |
510 | x = x.view(B, H, W, C)
511 | x = x.reshape(B,H*self.dim_scale,W*self.dim_scale,C//(self.dim_scale**2))
512 | #x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))
513 | #print(x.shape)
514 | x = x.view(B,-1,self.output_dim)
515 | x= self.norm(x)
516 |
517 | return x
518 |
519 | class BasicLayer(nn.Module):
520 | """ A basic Swin Transformer layer for one stage.
521 | Args:
522 | dim (int): Number of input channels.
523 | input_resolution (tuple[int]): Input resolution.
524 | depth (int): Number of blocks.
525 | num_heads (int): Number of attention heads.
526 | window_size (int): Local window size.
527 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
528 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
529 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
530 | drop (float, optional): Dropout rate. Default: 0.0
531 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
532 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
533 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
534 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
535 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
536 | """
537 |
538 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
539 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
540 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
541 |
542 | super().__init__()
543 | self.dim = dim
544 | self.input_resolution = input_resolution
545 | self.depth = depth
546 | self.use_checkpoint = use_checkpoint
547 |
548 | # build blocks
549 | self.blocks = nn.ModuleList([
550 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
551 | num_heads=num_heads, window_size=window_size,
552 | shift_size=0 if (i % 2 == 0) else window_size // 2,
553 | mlp_ratio=mlp_ratio,
554 | qkv_bias=qkv_bias, qk_scale=qk_scale,
555 | drop=drop, attn_drop=attn_drop,
556 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
557 | norm_layer=norm_layer)
558 | for i in range(depth)])
559 |
560 | # patch merging layer
561 | if downsample is not None:
562 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
563 | else:
564 | self.downsample = None
565 |
566 | def forward(self, x):
567 | for blk in self.blocks:
568 | if self.use_checkpoint:
569 | x = checkpoint.checkpoint(blk, x)
570 | else:
571 | x = blk(x)
572 | if self.downsample is not None:
573 | x = self.downsample(x)
574 | return x
575 | def extra_repr(self) -> str:
576 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
577 |
578 | def flops(self):
579 | flops = 0
580 | for blk in self.blocks:
581 | flops += blk.flops()
582 | if self.downsample is not None:
583 | flops += self.downsample.flops()
584 | return flops
585 |
586 | class BasicLayer_up(nn.Module):
587 | """ A basic Swin Transformer layer for one stage.
588 | Args:
589 | dim (int): Number of input channels.
590 | input_resolution (tuple[int]): Input resolution.
591 | depth (int): Number of blocks.
592 | num_heads (int): Number of attention heads.
593 | window_size (int): Local window size.
594 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
595 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
596 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
597 | drop (float, optional): Dropout rate. Default: 0.0
598 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
599 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
600 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
601 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
602 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
603 | """
604 |
605 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
606 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
607 | drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
608 |
609 | super().__init__()
610 | self.dim = dim
611 | self.input_resolution = input_resolution
612 | self.depth = depth
613 | self.use_checkpoint = use_checkpoint
614 |
615 | # build blocks
616 | self.blocks = nn.ModuleList([
617 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
618 | num_heads=num_heads, window_size=window_size,
619 | shift_size=0 if (i % 2 == 0) else window_size // 2,
620 | mlp_ratio=mlp_ratio,
621 | qkv_bias=qkv_bias, qk_scale=qk_scale,
622 | drop=drop, attn_drop=attn_drop,
623 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
624 | norm_layer=norm_layer)
625 | for i in range(depth)])
626 |
627 | # patch merging layer
628 | if upsample is not None:
629 | self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
630 | else:
631 | self.upsample = None
632 |
633 | def forward(self, x):
634 | for blk in self.blocks:
635 | if self.use_checkpoint:
636 | x = checkpoint.checkpoint(blk, x)
637 | else:
638 | x = blk(x)
639 | if self.upsample is not None:
640 | x = self.upsample(x)
641 | return x
642 |
643 | class PatchEmbed(nn.Module):
644 | r""" Image to Patch Embedding
645 | Args:
646 | img_size (int): Image size. Default: 224.
647 | patch_size (int): Patch token size. Default: 4.
648 | in_chans (int): Number of input image channels. Default: 3.
649 | embed_dim (int): Number of linear projection output channels. Default: 96.
650 | norm_layer (nn.Module, optional): Normalization layer. Default: None
651 | """
652 |
653 | def __init__(self, img_size=300, patch_size=4, in_chans=1, embed_dim=96, norm_layer=None):
654 | super().__init__()
655 | img_size = to_2tuple(img_size)
656 | patch_size = to_2tuple(patch_size)
657 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
658 | self.img_size = img_size
659 | self.patch_size = patch_size
660 | self.patches_resolution = patches_resolution
661 | self.num_patches = patches_resolution[0] * patches_resolution[1]
662 |
663 | self.in_chans = in_chans
664 | self.embed_dim = embed_dim
665 |
666 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
667 | if norm_layer is not None:
668 | self.norm = norm_layer(embed_dim)
669 | else:
670 | self.norm = None
671 |
672 | def forward(self, x):
673 | B, C, H, W = x.shape
674 | # FIXME look at relaxing size constraints
675 | assert H == self.img_size[0] and W == self.img_size[1], \
676 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
677 |
678 | # pad_input = (H % self.patch_size != 0) or (W % self.patch_size != 0)
679 | # if pad_input:
680 | # # to pad the last 3 dimensions,
681 | # # (W_left, W_right, H_top,H_bottom, C_front, C_back)
682 | # x = F.pad(x, (0, self.patch_size - W % self.patch_size,
683 | # 0, self.patch_size - H % self.patch_size,
684 | # 0, 0))
685 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
686 | if self.norm is not None:
687 | x = self.norm(x)
688 | return x
689 |
690 | def flops(self):
691 | Ho, Wo = self.patches_resolution
692 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
693 | if self.norm is not None:
694 | flops += Ho * Wo * self.embed_dim
695 | return flops
696 |
697 |
698 | class SwinTransformerSys(nn.Module):
699 | r""" Swin Transformer
700 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
701 | https://arxiv.org/pdf/2103.14030
702 | Args:
703 | img_size (int | tuple(int)): Input image size. Default 224
704 | patch_size (int | tuple(int)): Patch size. Default: 4
705 | in_chans (int): Number of input image channels. Default: 3
706 | num_classes (int): Number of classes for classification head. Default: 1000
707 | embed_dim (int): Patch embedding dimension. Default: 96
708 | depths (tuple(int)): Depth of each Swin Transformer layer.
709 | num_heads (tuple(int)): Number of attention heads in different layers.
710 | window_size (int): Window size. Default: 7
711 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
712 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
713 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
714 | drop_rate (float): Dropout rate. Default: 0
715 | attn_drop_rate (float): Attention dropout rate. Default: 0
716 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
717 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
718 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
719 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
720 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
721 | """
722 |
723 | def __init__(self, img_size=224, patch_size=4, in_chans=1, num_classes=1,
724 | embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
725 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
726 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
727 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
728 | use_checkpoint=False, final_upsample="expand_first", **kwargs):
729 | super().__init__()
730 |
731 | print(
732 | "SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(
733 | depths,
734 | depths_decoder, drop_path_rate, num_classes))
735 |
736 | self.num_classes = num_classes
737 | self.num_layers = len(depths)
738 | self.embed_dim = embed_dim
739 | self.ape = ape
740 | self.patch_norm = patch_norm
741 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
742 | self.num_features_up = int(embed_dim * 2)
743 | self.mlp_ratio = mlp_ratio
744 | self.final_upsample = final_upsample
745 |
746 | # split image into non-overlapping patches
747 | self.patch_embed = PatchEmbed(
748 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
749 | norm_layer=norm_layer if self.patch_norm else None)
750 | num_patches = self.patch_embed.num_patches
751 | patches_resolution = self.patch_embed.patches_resolution
752 | self.patches_resolution = patches_resolution
753 |
754 | # absolute position embedding
755 | if self.ape:
756 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
757 | trunc_normal_(self.absolute_pos_embed, std=.02)
758 |
759 | self.pos_drop = nn.Dropout(p=drop_rate)
760 |
761 | # stochastic depth
762 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
763 |
764 | # build encoder and bottleneck layers
765 | self.layers = nn.ModuleList()
766 | for i_layer in range(self.num_layers):
767 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
768 | input_resolution=(patches_resolution[0] // (2 ** i_layer),
769 | patches_resolution[1] // (2 ** i_layer)),
770 | depth=depths[i_layer],
771 | num_heads=num_heads[i_layer],
772 | window_size=window_size,
773 | mlp_ratio=self.mlp_ratio,
774 | qkv_bias=qkv_bias, qk_scale=qk_scale,
775 | drop=drop_rate, attn_drop=attn_drop_rate,
776 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
777 | norm_layer=norm_layer,
778 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
779 | use_checkpoint=use_checkpoint)
780 | self.layers.append(layer)
781 |
782 | # build decoder layers
783 | self.layers_up = nn.ModuleList()
784 | self.concat_back_dim = nn.ModuleList()
785 | for i_layer in range(self.num_layers):
786 | concat_linear = nn.Linear(2 * int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
787 | int(embed_dim * 2 ** (
788 | self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
789 | if i_layer == 0:
790 | layer_up = PatchExpand(
791 | input_resolution=(patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
792 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
793 | dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2, norm_layer=norm_layer)
794 | else:
795 | layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers - 1 - i_layer)),
796 | input_resolution=(
797 | patches_resolution[0] // (2 ** (self.num_layers - 1 - i_layer)),
798 | patches_resolution[1] // (2 ** (self.num_layers - 1 - i_layer))),
799 | depth=depths[(self.num_layers - 1 - i_layer)],
800 | num_heads=num_heads[(self.num_layers - 1 - i_layer)],
801 | window_size=window_size,
802 | mlp_ratio=self.mlp_ratio,
803 | qkv_bias=qkv_bias, qk_scale=qk_scale,
804 | drop=drop_rate, attn_drop=attn_drop_rate,
805 | drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
806 | depths[:(self.num_layers - 1 - i_layer) + 1])],
807 | norm_layer=norm_layer,
808 | upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
809 | use_checkpoint=use_checkpoint)
810 | self.layers_up.append(layer_up)
811 | self.concat_back_dim.append(concat_linear)
812 |
813 | self.norm = norm_layer(self.num_features)
814 | self.norm_up = norm_layer(self.embed_dim)
815 |
816 | if self.final_upsample == "expand_first":
817 | print("---final upsample expand_first---")
818 | self.up = FinalPatchExpand_X4(input_resolution=(img_size // patch_size, img_size // patch_size),
819 | dim_scale=4, dim=embed_dim)
820 | self.pos_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
821 | self.cos_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
822 | self.sin_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
823 | self.width_output = nn.Conv2d(in_channels=embed_dim, out_channels=self.num_classes, kernel_size=1, bias=False)
824 | self.apply(self._init_weights)
825 |
826 | def _init_weights(self, m):
827 | if isinstance(m, nn.Linear):
828 | trunc_normal_(m.weight, std=.02)
829 | if isinstance(m, nn.Linear) and m.bias is not None:
830 | nn.init.constant_(m.bias, 0)
831 | elif isinstance(m, nn.LayerNorm):
832 | nn.init.constant_(m.bias, 0)
833 | nn.init.constant_(m.weight, 1.0)
834 |
835 | @torch.jit.ignore
836 | def no_weight_decay(self):
837 | return {'absolute_pos_embed'}
838 |
839 | @torch.jit.ignore
840 | def no_weight_decay_keywords(self):
841 | return {'relative_position_bias_table'}
842 |
843 | # Encoder and Bottleneck
844 | def forward_features(self, x):
845 | x = self.patch_embed(x)
846 | if self.ape:
847 | x = x + self.absolute_pos_embed
848 | x = self.pos_drop(x)
849 | x_downsample = []
850 |
851 | for layer in self.layers:
852 | x_downsample.append(x)
853 | x = layer(x)
854 |
855 | x = self.norm(x) # B L C
856 |
857 | return x, x_downsample
858 |
859 |     # Decoder and skip connections
860 | def forward_up_features(self, x, x_downsample):
861 | for inx, layer_up in enumerate(self.layers_up):
862 | if inx == 0:
863 | x = layer_up(x)
864 | else:
865 | x = torch.cat([x, x_downsample[3 - inx]], -1)
866 | # print(x.shape)
867 | x = self.concat_back_dim[inx](x)
868 | x = layer_up(x)
869 |
870 | x = self.norm_up(x) # B L C
871 |
872 | return x
873 |
874 | def up_x4(self, x):
875 | H, W = self.patches_resolution
876 | B, L, C = x.shape
877 |         assert L == H * W, "input features have the wrong size"
878 |
879 | if self.final_upsample == "expand_first":
880 | x = self.up(x)
881 | x = x.view(B, 4 * H, 4 * W, -1)
882 | x = x.permute(0, 3, 1, 2) # B,C,H,W
883 | pos_output = self.pos_output(x)
884 | cos_output = self.cos_output(x)
885 | sin_output = self.sin_output(x)
886 | width_output = self.width_output(x)
887 |
888 |
889 | return pos_output, cos_output, sin_output, width_output
890 |
891 | def forward(self, x):
892 | x, x_downsample = self.forward_features(x)
893 | x = self.forward_up_features(x, x_downsample)
894 | pos_output, cos_output, sin_output, width_output = self.up_x4(x)
895 |
896 | return pos_output, cos_output, sin_output, width_output
897 | def compute_loss(self, xc, yc):
898 | y_pos, y_cos, y_sin, y_width = yc
899 | pos_pred, cos_pred, sin_pred, width_pred = self(xc)
900 | # print("pos shape:",pos_pred.shape)
901 | p_loss = F.mse_loss(pos_pred, y_pos)
902 | cos_loss = F.mse_loss(cos_pred, y_cos)
903 | sin_loss = F.mse_loss(sin_pred, y_sin)
904 | width_loss = F.mse_loss(width_pred, y_width)
905 |
906 | return {
907 | 'loss': p_loss + cos_loss + sin_loss + width_loss,
908 | 'losses': {
909 | 'p_loss': p_loss,
910 | 'cos_loss': cos_loss,
911 | 'sin_loss': sin_loss,
912 | 'width_loss': width_loss
913 | },
914 | 'pred': {
915 | 'pos': pos_pred,
916 | 'cos': cos_pred,
917 | 'sin': sin_pred,
918 | 'width': width_pred
919 | }
920 | }
921 | def flops(self):
922 | flops = 0
923 | flops += self.patch_embed.flops()
924 | for i, layer in enumerate(self.layers):
925 | flops += layer.flops()
926 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
927 | flops += self.num_features * self.num_classes
928 | return flops
929 | 
930 | 
931 | if __name__ == '__main__':
932 |     import time
933 | 
934 |     from torchsummary import summary
935 | 
936 |     # Print a layer-by-layer summary of a small variant (requires a GPU).
937 |     model = SwinTransformerSys(in_chans=1, embed_dim=24, num_heads=[3, 6, 12, 24]).cuda()
938 |     summary(model, (1, 224, 224))
939 | 
940 |     # Benchmark CPU inference time of another variant on random inputs.
941 |     # Other configurations (e.g. embed_dim=96 with num_heads=[3, 6, 12, 24],
942 |     # or embed_dim=12 with num_heads=[1, 2, 2, 4]) were timed the same way.
943 |     model = SwinTransformerSys(in_chans=1, embed_dim=48, num_heads=[1, 2, 4, 8])
944 |     model.eval()
945 |     total = 0.
946 |     n_runs = 20
947 |     with torch.no_grad():
948 |         for _ in range(n_runs):
949 |             img = torch.rand(1, 1, 224, 224)
950 |             start = time.perf_counter()
951 |             model(img)
952 |             total += time.perf_counter() - start
953 |     print('average forward time: {:.4f} s'.format(total / n_runs))
--------------------------------------------------------------------------------
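For orientation, here is a minimal sketch of exercising the `SwinTransformerSys` defined above end to end. It assumes the repository root is on `PYTHONPATH` so the class imports as `models.swin.SwinTransformerSys`, and it uses the 224 x 224 input size that the default configuration asserts on:

```python
import torch
from models.swin import SwinTransformerSys  # assumes the repo root is on PYTHONPATH

model = SwinTransformerSys(in_chans=1, embed_dim=48, num_heads=[1, 2, 4, 8])
model.eval()

x = torch.rand(2, 1, 224, 224)          # a batch of two single-channel (depth) images
with torch.no_grad():
    pos, cos, sin, width = model(x)     # each map is 2 x 1 x 224 x 224
print(pos.shape, cos.shape, sin.shape, width.shape)

# compute_loss expects targets shaped like the outputs and returns the summed
# MSE loss together with the per-head losses and the raw predictions.
y = tuple(torch.rand(2, 1, 224, 224) for _ in range(4))
lossd = model.compute_loss(x, y)
print(lossd['loss'].item(), sorted(lossd['losses']))
```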
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | matplotlib
4 | scikit-image
5 | imageio
6 | torch
7 | torchvision
8 | torchsummary
9 | tensorboardX
10 | scikit-learn
11 |
--------------------------------------------------------------------------------
/traning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import torch.optim as optim
4 | from utils.dataset_processing import evaluation
5 | from models.common import post_process_output
6 | import logging
7 | def validate(net, device, val_data, batches_per_epoch):
8 | """
9 | Run validation.
10 | :param net: Network
11 | :param device: Torch device
12 | :param val_data: Validation Dataset
13 | :param batches_per_epoch: Number of batches to run
14 | :return: Successes, Failures and Losses
15 | """
16 | net.eval()
17 |
18 | results = {
19 | 'correct': 0,
20 | 'failed': 0,
21 | 'loss': 0,
22 | 'losses': {
23 |
24 | }
25 | }
26 |
27 | ld = len(val_data)
28 |
29 | with torch.no_grad():
30 | batch_idx = 0
31 | while batch_idx < batches_per_epoch:
32 | for x, y, didx, rot, zoom_factor in val_data:
33 | batch_idx += 1
34 | if batches_per_epoch is not None and batch_idx >= batches_per_epoch:
35 | break
36 |
37 | xc = x.to(device)
38 | yc = [yy.to(device) for yy in y]
39 | lossd = net.compute_loss(xc, yc)
40 |
41 | loss = lossd['loss']
42 |
43 | results['loss'] += loss.item()/ld
44 | for ln, l in lossd['losses'].items():
45 | if ln not in results['losses']:
46 | results['losses'][ln] = 0
47 | results['losses'][ln] += l.item()/ld
48 |
49 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'],
50 | lossd['pred']['sin'], lossd['pred']['width'])
51 |
52 | s = evaluation.calculate_iou_match(q_out, ang_out,
53 | val_data.dataset.get_gtbb(didx, rot, zoom_factor),
54 | no_grasps=2,
55 | grasp_width=w_out,
56 | )
57 |
58 |
59 |
60 |
61 | if s:
62 | results['correct'] += 1
63 | else:
64 | results['failed'] += 1
65 |
66 | return results
67 |
68 |
69 | def train(epoch, net, device, train_data, optimizer, batches_per_epoch, vis=False):
70 | """
71 | Run one training epoch
72 | :param epoch: Current epoch
73 | :param net: Network
74 | :param device: Torch device
75 | :param train_data: Training Dataset
76 | :param optimizer: Optimizer
77 | :param batches_per_epoch: Data batches to train on
78 | :param vis: Visualise training progress
79 | :return: Average Losses for Epoch
80 | """
81 | results = {
82 | 'loss': 0,
83 | 'losses': {
84 | }
85 | }
86 |
87 | net.train()
88 |
89 | batch_idx = 0
90 | # Use batches per epoch to make training on different sized datasets (cornell/jacquard) more equivalent.
91 | while batch_idx < batches_per_epoch:
92 | for x, y, _, _, _ in train_data:
93 | # print("shape:",x.shape)
94 | batch_idx += 1
95 | # if batch_idx >= batches_per_epoch:
96 | # break
97 | # print("x_0:",x[0].shape,y[0][0].shape)
98 | # plt.imshow(x[0].permute(1,2,0).numpy())
99 | # plt.show()
100 | # plt.imshow(y[0][0][0].numpy())
101 | # plt.show()
102 | xc = x.to(device)
103 | yc = [yy.to(device) for yy in y]
104 | # print("xc shape:",xc.shape)
105 | lossd = net.compute_loss(xc, yc)
106 |
107 | loss = lossd['loss']
108 |
109 | if batch_idx % 10 == 0:
110 | logging.info('Epoch: {}, Batch: {}, Loss: {:0.4f}'.format(epoch, batch_idx, loss.item()))
111 |
112 | results['loss'] += loss.item()
113 | for ln, l in lossd['losses'].items():
114 | if ln not in results['losses']:
115 | results['losses'][ln] = 0
116 | results['losses'][ln] += l.item()
117 |
118 | optimizer.zero_grad()
119 | loss.backward()
120 | optimizer.step()
121 |
122 | # Display the images
123 | if vis:
124 | imgs = []
125 | n_img = min(4, x.shape[0])
126 | for idx in range(n_img):
127 | imgs.extend([x[idx,].numpy().squeeze()] + [yi[idx,].numpy().squeeze() for yi in y] + [
128 | x[idx,].numpy().squeeze()] + [pc[idx,].detach().cpu().numpy().squeeze() for pc in lossd['pred'].values()])
129 | # gridshow('Display', imgs,
130 | # [(xc.min().item(), xc.max().item()), (0.0, 1.0), (0.0, 1.0), (-1.0, 1.0), (0.0, 1.0)] * 2 * n_img,
131 | # [cv2.COLORMAP_BONE] * 10 * n_img, 10)
132 | # cv2.waitKey(2)
133 |
134 | results['loss'] /= batch_idx
135 | for l in results['losses']:
136 | results['losses'][l] /= batch_idx
137 |
138 | return results
139 |
--------------------------------------------------------------------------------
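`main.py` and `main_k_fold.py` (not shown in this section) drive `train` and `validate`; the sketch below is one plausible wiring with the dataset factory in `utils/data`, not the exact script. The optimiser, learning rate, split fractions, epoch count and batch sizes are placeholder choices.

```python
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from traning import train, validate
from utils.data import get_dataset
from models.swin import SwinTransformerSys

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Dataset = get_dataset('cornell')
train_set = Dataset('path/to/cornell', start=0.0, end=0.9, random_rotate=True,
                    random_zoom=True, include_depth=True, include_rgb=False)
val_set = Dataset('path/to/cornell', start=0.9, end=1.0,
                  include_depth=True, include_rgb=False)
train_data = DataLoader(train_set, batch_size=8, shuffle=True)
val_data = DataLoader(val_set, batch_size=1, shuffle=False)  # validate() looks up ground truth per sample

net = SwinTransformerSys(in_chans=1).to(device)
optimizer = optim.AdamW(net.parameters(), lr=1e-4)

for epoch in range(50):
    train_results = train(epoch, net, device, train_data, optimizer, batches_per_epoch=1000)
    val_results = validate(net, device, val_data, batches_per_epoch=len(val_data))
    accuracy = val_results['correct'] / (val_results['correct'] + val_results['failed'])
    print('epoch {}: train loss {:.4f}, IoU accuracy {:.3f}'.format(
        epoch, train_results['loss'], accuracy))
```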
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | def get_dataset(dataset_name):
2 | if dataset_name == 'cornell':
3 | from .cornell_data import CornellDataset
4 | return CornellDataset
5 | elif dataset_name == 'jacquard':
6 | from .jacquard_data import JacquardDataset
7 | return JacquardDataset
8 | elif dataset_name == 'multi':
9 | from .multi_object import CornellDataset
10 | return CornellDataset
11 | elif dataset_name == 'graspnet1b':
12 | from .gn1b_data import GraspNet1BDataset
13 | return GraspNet1BDataset
14 | else:
15 |         raise NotImplementedError('Dataset type {} is not implemented'.format(dataset_name))
--------------------------------------------------------------------------------
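Each branch above imports its dataset module lazily, so optional dependencies such as `graspnetAPI` are only needed when `'graspnet1b'` is actually requested. A short sketch of the intended use, with a placeholder path:

```python
from utils.data import get_dataset

CornellDataset = get_dataset('cornell')
dataset = CornellDataset('path/to/cornell', start=0.0, end=0.9, include_depth=True)
x, (pos, cos, sin, width), idx, rot, zoom = dataset[0]
print(x.shape)  # 1 x 224 x 224 for a depth-only sample
```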
/utils/data/cornell_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | from .grasp_data import GraspDatasetBase
5 | from utils.dataset_processing import grasp, image
6 |
7 |
8 | class CornellDataset(GraspDatasetBase):
9 | """
10 | Dataset wrapper for the Cornell dataset.
11 | """
12 | def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs):
13 | """
14 | :param file_path: Cornell Dataset directory.
15 | :param start: If splitting the dataset, start at this fraction [0,1]
16 | :param end: If splitting the dataset, finish at this fraction
17 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first
18 | :param kwargs: kwargs for GraspDatasetBase
19 | """
20 | super(CornellDataset, self).__init__(**kwargs)
21 |
22 | graspf = glob.glob(os.path.join(file_path, '*', 'pcd*cpos.txt'))
23 | graspf.sort()
24 | l = len(graspf)
25 | if l == 0:
26 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path))
27 |
28 | if ds_rotate:
29 | graspf = graspf[int(l*ds_rotate):] + graspf[:int(l*ds_rotate)]
30 |
31 | depthf = [f.replace('cpos.txt', 'd.tiff') for f in graspf]
32 | rgbf = [f.replace('d.tiff', 'r.png') for f in depthf]
33 |
34 | self.grasp_files = graspf[int(l*start):int(l*end)]
35 | self.depth_files = depthf[int(l*start):int(l*end)]
36 | self.rgb_files = rgbf[int(l*start):int(l*end)]
37 |
38 | def _get_crop_attrs(self, idx):
39 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx])
40 | center = gtbbs.center
41 | left = max(0, min(center[1] - self.output_size // 2, 640 - self.output_size))
42 | top = max(0, min(center[0] - self.output_size // 2, 480 - self.output_size))
43 | return center, left, top
44 |
45 | def get_gtbb(self, idx, rot=0, zoom=1.0):
46 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx])
47 | center, left, top = self._get_crop_attrs(idx)
48 | gtbbs.rotate(rot, center)
49 | gtbbs.offset((-top, -left))
50 | gtbbs.zoom(zoom, (self.output_size//2, self.output_size//2))
51 | return gtbbs
52 |
53 | def get_depth(self, idx, rot=0, zoom=1.0):
54 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx])
55 | center, left, top = self._get_crop_attrs(idx)
56 | depth_img.rotate(rot, center)
57 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size)))
58 | depth_img.normalise()
59 | depth_img.zoom(zoom)
60 | depth_img.resize((self.output_size, self.output_size))
61 | return depth_img.img
62 |
63 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True):
64 | rgb_img = image.Image.from_file(self.rgb_files[idx])
65 | center, left, top = self._get_crop_attrs(idx)
66 | rgb_img.rotate(rot, center)
67 | rgb_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size)))
68 | rgb_img.zoom(zoom)
69 | rgb_img.resize((self.output_size, self.output_size))
70 | if normalise:
71 | rgb_img.normalise()
72 | rgb_img.img = rgb_img.img.transpose((2, 0, 1))
73 | return rgb_img.img
74 |
--------------------------------------------------------------------------------
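The crop computed by `_get_crop_attrs` is a 224 x 224 window centred on the mean grasp centre and clamped to the 480 x 640 Cornell image bounds; a worked example with an assumed centre at (row 240, col 320):

```python
output_size = 224
center = (240, 320)                                                    # (row, col), assumed for illustration
left = max(0, min(center[1] - output_size // 2, 640 - output_size))    # -> 208
top = max(0, min(center[0] - output_size // 2, 480 - output_size))     # -> 128
# get_depth()/get_rgb() then crop rows [128, 352) and cols [208, 432)
# before resizing to output_size x output_size.
print(top, left)
```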
/utils/data/gn1b_data.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import copy
4 | import glob
5 |
6 | import torch
7 | import torch.utils.data
8 | from graspnetAPI import GraspNet
9 | from .grasp_data import GraspDatasetBase
10 | from .cornell_data import CornellDataset
11 |
12 | from utils.dataset_processing import grasp, image
13 |
14 | graspnet_root = "/home/zzl/Pictures/graspnet"
15 |
16 |
17 | class GraspNet1BDataset(GraspDatasetBase):
18 |
19 | def __init__(self, file_path, camera='realsense', split='train', scale=2.0, ds_rotate=True,
20 | output_size=224,
21 | random_rotate=True, random_zoom=True,
22 | include_depth=True,
23 | include_rgb=True,
24 | ):
25 | super(GraspNet1BDataset, self).__init__(output_size=output_size, include_depth=include_depth, include_rgb=include_rgb, random_rotate=random_rotate,
26 | random_zoom=random_zoom, input_only=False)
27 | logging.info('Graspnet root = {}'.format(graspnet_root))
28 | logging.info('Using data from camera {}'.format(camera))
29 | self.graspnet_root = graspnet_root
30 | self.camera = camera
31 | self.split = split
32 |         self.scale = scale  # the original images are h x w = 720 x 1280; they are downscaled by this factor
33 |
34 | self._graspnet_instance = GraspNet(graspnet_root, camera, split)
35 |
36 |         self.g_rgb_files = self._graspnet_instance.rgbPath  # paths to the RGB images
37 |         self.g_depth_files = self._graspnet_instance.depthPath  # paths to the depth images
38 |         self.g_rect_files = []  # paths to the rectangle grasp label files
39 |
40 | for original_rect_grasp_file in self._graspnet_instance.rectLabelPath:
41 | self.g_rect_files.append(
42 | original_rect_grasp_file
43 | .replace('rect', 'rect_cornell')
44 | .replace('.npy', '.txt')
45 | )
46 |
47 | logging.info('Graspnet 1Billion dataset created!!')
48 |
49 | def _get_crop_attrs(self, idx, return_gtbbs=False):
50 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.g_rect_files[idx], scale=self.scale)
51 | center = gtbbs.center
52 | left = max(0, min(center[1] - self.output_size // 2, int(1280 // self.scale) - self.output_size))
53 | top = max(0, min(center[0] - self.output_size // 2, int(720 // self.scale) - self.output_size))
54 | if not return_gtbbs:
55 | return center, left, top
56 | else:
57 | return center, left, top, gtbbs
58 |
59 | def get_gtbb(self, idx, rot=0, zoom=1):
60 | # gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.g_rect_files[idx], scale=self.scale)
61 | center, left, top, gtbbs = self._get_crop_attrs(idx, return_gtbbs=True)
62 | gtbbs.rotate(rot, center)
63 | gtbbs.offset((-top, -left))
64 | gtbbs.zoom(zoom, (self.output_size // 2, self.output_size // 2))
65 | return gtbbs
66 |
67 | def get_depth(self, idx, rot=0, zoom=1.0):
68 |         # GraspNet-1Billion depth is stored in millimetres; convert to metres
69 | depth_img = image.DepthImage.from_tiff(self.g_depth_files[idx], depth_scale=1000.0)
70 | rh, rw = int(720 // self.scale), int(1280 // self.scale)
71 |         # the image is loaded at w x h = 1280 x 720; resize it to the working resolution
72 | depth_img.resize((rh, rw))
73 | center, left, top = self._get_crop_attrs(idx)
74 | depth_img.rotate(rot, center)
75 | depth_img.crop((top, left), (min(rh, top + self.output_size), min(rw, left + self.output_size)))
76 | depth_img.normalise()
77 | depth_img.zoom(zoom)
78 | depth_img.resize((self.output_size, self.output_size))
79 | return depth_img.img
80 |
81 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True):
82 | rgb_img = image.Image.from_file(self.g_rgb_files[idx])
83 | rh, rw = int(720 // self.scale), int(1280 // self.scale)
84 |         # the image is loaded at w x h = 1280 x 720; resize it to the working resolution
85 | rgb_img.resize((rh, rw))
86 | center, left, top = self._get_crop_attrs(idx)
87 | rgb_img.rotate(rot, center)
88 | rgb_img.crop((top, left), (min(rh, top + self.output_size), min(rw, left + self.output_size)))
89 | rgb_img.zoom(zoom)
90 | rgb_img.resize((self.output_size, self.output_size))
91 | if normalise:
92 | rgb_img.normalise()
93 | rgb_img.img = rgb_img.img.transpose((2, 0, 1))
94 | return rgb_img.img
95 |
96 | def __len__(self):
97 | return len(self.g_rect_files)
--------------------------------------------------------------------------------
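The loader above expects Cornell-style rectangle labels converted from the original GraspNet `rect/*.npy` annotations (following the Modified-GGCNN setup referenced in the README); only the path rewrite applied to `rectLabelPath` is illustrated here, on a hypothetical path:

```python
# Hypothetical original rectangle-label path as reported by graspnetAPI:
p = '/data/graspnet/scenes/scene_0000/realsense/rect/0000.npy'
print(p.replace('rect', 'rect_cornell').replace('.npy', '.txt'))
# -> /data/graspnet/scenes/scene_0000/realsense/rect_cornell/0000.txt
```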
/utils/data/grasp_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.utils.data
5 |
6 | import random
7 |
8 |
9 | class GraspDatasetBase(torch.utils.data.Dataset):
10 | """
11 | An abstract dataset for training GG-CNNs in a common format.
12 | """
13 | def __init__(self, output_size=224, include_depth=True, include_rgb=False, random_rotate=False,
14 | random_zoom=False, input_only=False):
15 | """
16 | :param output_size: Image output size in pixels (square)
17 | :param include_depth: Whether depth image is included
18 | :param include_rgb: Whether RGB image is included
19 | :param random_rotate: Whether random rotations are applied
20 | :param random_zoom: Whether random zooms are applied
21 | :param input_only: Whether to return only the network input (no labels)
22 | """
23 | self.output_size = output_size
24 | self.random_rotate = random_rotate
25 | self.random_zoom = random_zoom
26 | self.input_only = input_only
27 | self.include_depth = include_depth
28 | self.include_rgb = include_rgb
29 |
30 | self.grasp_files = []
31 |
32 | if include_depth is False and include_rgb is False:
33 | raise ValueError('At least one of Depth or RGB must be specified.')
34 |
35 | @staticmethod
36 | def numpy_to_torch(s):
37 | if len(s.shape) == 2:
38 | return torch.from_numpy(np.expand_dims(s, 0).astype(np.float32))
39 | else:
40 | return torch.from_numpy(s.astype(np.float32))
41 |
42 | def get_gtbb(self, idx, rot=0, zoom=1.0):
43 | raise NotImplementedError()
44 |
45 | def get_depth(self, idx, rot=0, zoom=1.0):
46 | raise NotImplementedError()
47 |
48 | def get_rgb(self, idx, rot=0, zoom=1.0):
49 | raise NotImplementedError()
50 |
51 | def __getitem__(self, idx):
52 | if self.random_rotate:
53 | rotations = [0, np.pi/2, 2*np.pi/2, 3*np.pi/2]
54 | rot = random.choice(rotations)
55 | else:
56 | rot = 0.0
57 |
58 | if self.random_zoom:
59 | zoom_factor = np.random.uniform(0.5, 1.0)
60 | else:
61 | zoom_factor = 1.0
62 |
63 | # Load the depth image
64 | if self.include_depth:
65 | depth_img = self.get_depth(idx, rot, zoom_factor)
66 |
67 | # Load the RGB image
68 | if self.include_rgb:
69 | rgb_img = self.get_rgb(idx, rot, zoom_factor)
70 |
71 | # Load the grasps
72 | bbs = self.get_gtbb(idx, rot, zoom_factor)
73 |
74 | pos_img, ang_img, width_img = bbs.draw((self.output_size, self.output_size))
75 | width_img = np.clip(width_img, 0.0, 150.0)/150.0
76 |
77 | if self.include_depth and self.include_rgb:
78 | x = self.numpy_to_torch(
79 | np.concatenate(
80 | (np.expand_dims(depth_img, 0),
81 | rgb_img),
82 | 0
83 | )
84 | )
85 | elif self.include_depth:
86 | x = self.numpy_to_torch(depth_img)
87 | elif self.include_rgb:
88 | x = self.numpy_to_torch(rgb_img)
89 |
90 | pos = self.numpy_to_torch(pos_img)
91 | cos = self.numpy_to_torch(np.cos(2*ang_img))
92 | sin = self.numpy_to_torch(np.sin(2*ang_img))
93 | width = self.numpy_to_torch(width_img)
94 |
95 | return x, (pos, cos, sin, width), idx, rot, zoom_factor
96 |
97 | def __len__(self):
98 | return len(self.grasp_files)
99 |
--------------------------------------------------------------------------------
/utils/data/jacquard_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | from .grasp_data import GraspDatasetBase
5 | from utils.dataset_processing import grasp, image
6 |
7 |
8 | class JacquardDataset(GraspDatasetBase):
9 | """
10 | Dataset wrapper for the Jacquard dataset.
11 | """
12 | def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs):
13 | """
14 | :param file_path: Jacquard Dataset directory.
15 | :param start: If splitting the dataset, start at this fraction [0,1]
16 | :param end: If splitting the dataset, finish at this fraction
17 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first
18 | :param kwargs: kwargs for GraspDatasetBase
19 | """
20 | super(JacquardDataset, self).__init__(**kwargs)
21 |
22 |         # Jacquard grasp files sit two directory levels below the dataset root.
23 |         graspf = glob.glob(os.path.join(file_path, '*', '*', '*_grasps.txt'))
24 |         graspf.sort()
25 |         l = len(graspf)
26 |         print("Jacquard grasp files found:", l)
27 |
28 | if l == 0:
29 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path))
30 |
31 | if ds_rotate:
32 | graspf = graspf[int(l*ds_rotate):] + graspf[:int(l*ds_rotate)]
33 |
34 | depthf = [f.replace('grasps.txt', 'perfect_depth.tiff') for f in graspf]
35 | rgbf = [f.replace('perfect_depth.tiff', 'RGB.png') for f in depthf]
36 |
37 | self.grasp_files = graspf[int(l*start):int(l*end)]
38 | self.depth_files = depthf[int(l*start):int(l*end)]
39 | self.rgb_files = rgbf[int(l*start):int(l*end)]
40 |
41 | def get_gtbb(self, idx, rot=0, zoom=1.0):
42 | gtbbs = grasp.GraspRectangles.load_from_jacquard_file(self.grasp_files[idx], scale=self.output_size / 1024.0)
43 | c = self.output_size//2
44 | gtbbs.rotate(rot, (c, c))
45 | gtbbs.zoom(zoom, (c, c))
46 | return gtbbs
47 |
48 | def get_depth(self, idx, rot=0, zoom=1.0):
49 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx])
50 | depth_img.rotate(rot)
51 | depth_img.normalise()
52 | depth_img.zoom(zoom)
53 | depth_img.resize((self.output_size, self.output_size))
54 | return depth_img.img
55 |
56 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True):
57 | rgb_img = image.Image.from_file(self.rgb_files[idx])
58 | rgb_img.rotate(rot)
59 | rgb_img.zoom(zoom)
60 | rgb_img.resize((self.output_size, self.output_size))
61 | if normalise:
62 | rgb_img.normalise()
63 | rgb_img.img = rgb_img.img.transpose((2, 0, 1))
64 | return rgb_img.img
65 |
66 | def get_jname(self, idx):
67 | return '_'.join(self.grasp_files[idx].split(os.sep)[-1].split('_')[:-1])
--------------------------------------------------------------------------------
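The Jacquard loader above pairs grasp annotations, depth renders and RGB images purely by filename substitution; a sketch on a hypothetical sample path:

```python
import os

g = '/data/jacquard/part_1/3b9a/2_3b9a_grasps.txt'       # hypothetical grasp file
print(g.replace('grasps.txt', 'perfect_depth.tiff'))     # matching rendered depth image
print(g.replace('grasps.txt', 'RGB.png'))                # matching RGB image
print('_'.join(g.split(os.sep)[-1].split('_')[:-1]))     # get_jname() -> '2_3b9a'
```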
/utils/data/multi_object.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | from .grasp_data import GraspDatasetBase
5 | from utils.dataset_processing import grasp, image
6 |
7 | class CornellDataset(GraspDatasetBase):
8 | """
9 | Dataset wrapper for the Cornell dataset.
10 | """
11 | def __init__(self, file_path, start=0.0, end=1.0, ds_rotate=0, **kwargs):
12 | """
13 | :param file_path: Cornell Dataset directory.
14 | :param start: If splitting the dataset, start at this fraction [0,1]
15 | :param end: If splitting the dataset, finish at this fraction
16 | :param ds_rotate: If splitting the dataset, rotate the list of items by this fraction first
17 | :param kwargs: kwargs for GraspDatasetBase
18 | """
19 | super(CornellDataset, self).__init__(**kwargs)
20 |         # file_path should point at the multi-object RGB-D directory.
21 |         graspf = glob.glob(os.path.join(file_path, 'rgb_*.txt'))
22 |         graspf.sort()
23 |         l = len(graspf)
24 | if l == 0:
25 | raise FileNotFoundError('No dataset files found. Check path: {}'.format(file_path))
26 |
27 | if ds_rotate:
28 | graspf = graspf[int(l*ds_rotate):] + graspf[:int(l*ds_rotate)]
29 |
30 | depthf =glob.glob(os.path.join(file_path,'depth_*.png'))
31 | depthf.sort()
32 | rgbf =glob.glob(os.path.join(file_path,'rgb*.jpg'))
33 | rgbf.sort()
34 |
35 | self.grasp_files = graspf[int(l*start):int(l*end)]
36 | self.depth_files = depthf[int(l*start):int(l*end)]
37 | self.rgb_files = rgbf[int(l*start):int(l*end)]
38 |
39 | def _get_crop_attrs(self, idx):
40 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx])
41 | center = gtbbs.center
42 | left = max(0, min(center[1] - self.output_size // 2, 640 - self.output_size))
43 | top = max(0, min(center[0] - self.output_size // 2, 480 - self.output_size))
44 | return center, left, top
45 |
46 | def get_gtbb(self, idx, rot=0, zoom=1.0):
47 | gtbbs = grasp.GraspRectangles.load_from_cornell_file(self.grasp_files[idx])
48 | center, left, top = self._get_crop_attrs(idx)
49 | gtbbs.rotate(rot, center)
50 | gtbbs.offset((-top, -left))
51 | gtbbs.zoom(zoom, (self.output_size//2, self.output_size//2))
52 | return gtbbs
53 |
54 | def get_depth(self, idx, rot=0, zoom=1.0):
55 | depth_img = image.DepthImage.from_tiff(self.depth_files[idx])
56 | center, left, top = self._get_crop_attrs(idx)
57 | depth_img.rotate(rot, center)
58 | depth_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size)))
59 | depth_img.normalise()
60 | depth_img.zoom(zoom)
61 | depth_img.resize((self.output_size, self.output_size))
62 | return depth_img.img
63 |
64 | def get_rgb(self, idx, rot=0, zoom=1.0, normalise=True):
65 | rgb_img = image.Image.from_file(self.rgb_files[idx])
66 | center, left, top = self._get_crop_attrs(idx)
67 | rgb_img.rotate(rot, center)
68 | rgb_img.crop((top, left), (min(480, top + self.output_size), min(640, left + self.output_size)))
69 | rgb_img.zoom(zoom)
70 | rgb_img.resize((self.output_size, self.output_size))
71 | if normalise:
72 | rgb_img.normalise()
73 | rgb_img.img = rgb_img.img.transpose((2, 0, 1))
74 | return rgb_img.img
75 |
--------------------------------------------------------------------------------
/utils/data/multigrasp_object.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | origin_path = os.getcwd()  # remember the original working directory
4 | os.chdir('/home/sam/Desktop/multiobject/rgbd')  # author's local path to the multi-object RGB-D data
5 | print(os.listdir())
6 | path='/home/sam/Desktop/multiobject/rgbd'
7 | depth = glob.glob(os.path.join(path,'depth_*.png'))
8 | depth.sort()
9 | print(depth[0:10])
10 |
11 | rgb=glob.glob(os.path.join(path,'rgb*.jpg'))
12 | rgb.sort()
13 | print(rgb[0:10])
14 |
15 | label=glob.glob(os.path.join(path,'rgb_*.txt'))
16 | label.sort()
17 | print(label[0:10])
18 |
19 | from PIL import Image
20 | import matplotlib.pyplot as plt
21 | # plt.figure(figsize=(15,15))
22 | # for i in range(9):
23 | # img = Image.open(rgb[i])
24 | # plt.subplot(331+i)
25 | # plt.imshow(img)
26 | # plt.show()
27 | # plt.figure(figsize=(15,15))
28 | # for i in range(9):
29 | # img = Image.open(depth[i])
30 | # plt.subplot(331+i)
31 | # plt.imshow(img)
32 | # plt.show()
33 |
34 | # with open(label[0],'r') as f:
35 | # grasp_data = f.read()
36 | # print(grasp_data)
37 | # grasp_data = [grasp.strip() for grasp in grasp_data]#去除末尾换行符
38 |
39 | def str2num(point):
40 | '''
41 |     :param point: string containing the coordinates of one point
42 |     :return: tuple (x, y) of the grasp point as ints
43 |
44 | '''
45 | x, y = point.split()
46 | x, y = int(round(float(x))), int(round(float(y)))
47 |
48 | return (x, y)
49 |
50 |
51 | def get_rectangle(cornell_grasp_file):
52 | '''
53 |     :param cornell_grasp_file: string, path to a grasp annotation file
54 |     :return: list of grasp rectangles, each a list of four corner points
55 |
56 | '''
57 | grasp_rectangles = []
58 | with open(cornell_grasp_file, 'r') as f:
59 | while True:
60 | grasp_rectangle = []
61 | point0 = f.readline().strip()
62 | if not point0:
63 | break
64 | point1, point2, point3 = f.readline().strip(), f.readline().strip(), f.readline().strip()
65 | grasp_rectangle = [str2num(point0),
66 | str2num(point1),
67 | str2num(point2),
68 | str2num(point3)]
69 | grasp_rectangles.append(grasp_rectangle)
70 |
71 | return grasp_rectangles
72 | i=90
73 | grs = get_rectangle(label[i])
74 | print(grs)
75 | import cv2
76 | import random
77 | 
78 | img = cv2.imread(rgb[i])
79 | for gr in grs:
80 |     # pick a random colour for this rectangle
81 |     color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
82 |     # draw the rectangle edges
83 |     for j in range(3):  # a rectangle has four edges; draw the first three here
84 |         img = cv2.line(img, gr[j], gr[j + 1], color, 3)
85 |     img = cv2.line(img, gr[3], gr[0], color, 2)  # close the rectangle with the final edge
86 |
87 | plt.figure(figsize=(10, 10))
88 | plt.imshow(img)  # cv2.imshow kept crashing the session, so matplotlib is used for display instead
89 | plt.show()
90 | plt.imshow(img)
91 | plt.show()
--------------------------------------------------------------------------------
/utils/dataset_processing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/__init__.py
--------------------------------------------------------------------------------
/utils/dataset_processing/evaluation.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | from .grasp import GraspRectangles, detect_grasps
6 | plt.rcParams.update({
7 | "text.usetex": True,
8 | "font.family": "sans-serif",
9 | "font.sans-serif": ["Helvetica"]})
10 | matplotlib.use("TkAgg")
11 | counter=100
12 | def plot_output(rgb_img,rgb_img_1, depth_img, grasp_q_img, grasp_angle_img, no_grasps=1, grasp_width_img=None,
13 | grasp_q_img_ggcnn=None,grasp_angle_img_ggcnn=None,grasp_width_img_ggcnn=None):
14 | """
15 |     Plot and compare the outputs of two networks (this model and GG-CNN) on the same image.
16 |     :param rgb_img: RGB image (rgb_img_1 is accepted but currently unused)
17 |     :param depth_img: Depth image
18 |     :param grasp_q_img: Quality (Q) map of this model
19 |     :param grasp_angle_img: Angle map of this model
20 |     :param no_grasps: Maximum number of grasps to plot
21 |     :param grasp_width_img: (optional) Width map of this model
22 |     :param grasp_q_img_ggcnn, grasp_angle_img_ggcnn, grasp_width_img_ggcnn: corresponding GG-CNN maps
23 | """
24 | global counter
25 |
26 | gs_1 = detect_grasps(grasp_q_img, grasp_angle_img, width_img=grasp_width_img, no_grasps=5)
27 | # print(len(gs_1))
28 | fig = plt.figure(figsize=(10, 10))
29 | ax = fig.add_subplot(3, 3, 1)
30 | ax.imshow(rgb_img)
31 | for g in gs_1:
32 | g.plot(ax)
33 | ax.set_title('RGB')
34 | ax.axis('off')
35 | # ax.savefig('/home/sam/compare_conv/Q_%d.pdf' % counter, bbox_inches='tight')
36 |
37 |
38 | gs_2 = detect_grasps(grasp_q_img_ggcnn, grasp_angle_img_ggcnn, width_img=grasp_width_img_ggcnn, no_grasps=5)
39 | print(len(gs_2))
40 | # fig = plt.figure(figsize=(10, 10))
41 | ax = fig.add_subplot(3, 3, 2)
42 | ax.imshow(rgb_img)
43 | for g in gs_2:
44 | g.plot(ax)
45 | ax.set_title('RGB')
46 | ax.axis('off')
47 |
48 | # ax.imshow(depth_img, cmap='gist_rainbow')
49 | # ax.imshow(depth_img,)
50 | # for g in gs:
51 | # g.plot(ax)
52 | # ax.set_title('Depth')
53 | # ax.axis('off')
54 |
55 | ax = fig.add_subplot(3, 3, 3)
56 | # plot = ax.imshow(grasp_width_img, cmap='prism', vmin=-0, vmax=150)
57 | # ax.set_title('q image')
58 |
59 | ax = fig.add_subplot(3, 3, 4)
60 | plot = ax.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1) #?terrain
61 | plt.colorbar(plot)
62 | ax.axis('off')
63 | ax.set_title('q image')
64 |
65 |
66 | ax = fig.add_subplot(3, 3, 5) #flag prism jet
67 | plot = ax.imshow(grasp_angle_img, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2)
68 | plt.colorbar(plot)
69 | ax.axis('off')
70 | ax.set_title('angle')
71 |
72 | ax = fig.add_subplot(3, 3, 6,)
73 | plot = ax.imshow(grasp_width_img, cmap='jet', vmin=-0, vmax=150)
74 | plt.colorbar(plot)
75 | ax.set_title('width')
76 | ax.axis('off')
77 |
78 |
79 |
80 | ax = fig.add_subplot(3, 3, 7, )
81 | plot = ax.imshow(grasp_q_img_ggcnn, cmap='jet', vmin=0, vmax=1)
82 | ax.set_title('q image')
83 | ax.axis('off')
84 |
85 | ax = fig.add_subplot(3, 3, 8) # flag prism jet
86 | plot = ax.imshow(grasp_angle_img_ggcnn, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2)
87 | ax.axis('off')
88 | ax.set_title('angle')
89 |
90 | ax = fig.add_subplot(3, 3, 9, )
91 | plot = ax.imshow(grasp_width_img_ggcnn, cmap='jet', vmin=-0, vmax=150)
92 | ax.set_title('width')
93 | ax.axis('off')
94 | # for g in gs:
95 | # g.plot(ax)
96 | # ax.set_title('Angle')
97 |
98 | # plt.colorbar(plot)
99 |
100 | # plt.imshow(rgb_img)
101 | plt.show()
102 | if input("input") == "1":
103 | print("333")
104 | # matplotlib.use("Agg")
105 | # plt.margins(0, 0)
106 | plt.imshow(rgb_img)
107 | for g in gs_1:
108 | g.plot(plt)
109 | # plt.axis("off")
110 | # plot=plt.imshow(rgb_img)
111 | # for g in gs:
112 | # g.plot(plot)
113 | plt.axis("off")
114 | plt.savefig('/home/sam/compare_conv/RGB_1_%d.pdf'%counter, bbox_inches='tight')
115 | plt.show()
116 | plt.imshow(rgb_img)
117 | for g in gs_2:
118 | g.plot(plt)
119 | # plt.axis("off")
120 | # plot=plt.imshow(rgb_img)
121 | # for g in gs:
122 | # g.plot(plot)
123 | plt.axis("off")
124 | plt.savefig('/home/sam/compare_conv/RGB_2_%d.pdf' % counter, bbox_inches='tight')
125 | plt.show()
126 |
127 | plot1 = plt.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1)
128 | plt.axis("off")
129 | plt.colorbar(plot1)
130 | plt.savefig('/home/sam/compare_conv/Q_1_%d.pdf'%counter, bbox_inches='tight')
131 | plt.show()
132 | plot1 = plt.imshow(grasp_q_img_ggcnn, cmap="jet", vmin=0, vmax=1)
133 | plt.axis("off")
134 | plt.colorbar(plot1)
135 | plt.savefig('/home/sam/compare_conv/Q_2_%d.pdf' % counter, bbox_inches='tight')
136 | plt.show()
137 | #
138 | counter=counter+1
139 | # if input("input") == "1":
140 | # print("333")
141 | # # matplotlib.use("Agg")
142 | # # plt.margins(0, 0)
143 | # plt.imshow(rgb_img)
144 | # plt.axis("off")
145 | # plt.savefig('RGB_%d.pdf'%counter, bbox_inches='tight')
146 | # # plt.show()
147 | #
148 | # plot1=plt.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1)
149 | # plt.axis("off")
150 | # plt.colorbar(plot1)
151 | # plt.savefig('Q_%d.pdf'%counter, bbox_inches='tight')
152 | # plt.show()
153 | #
154 | # plt.imshow(grasp_q_img, cmap="jet", vmin=0, vmax=1)
155 | # plt.axis("off")
156 | # plt.savefig('Q_1_%d.pdf' % counter, bbox_inches='tight')
157 | #
158 | #
159 | # plot2=plt.imshow(grasp_angle_img, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2)
160 | # plt.axis("off")
161 | # plt.colorbar(plot2)
162 | # plt.savefig('Angle_%d.pdf'%counter, bbox_inches='tight')
163 | # plt.show()
164 | #
165 | # plt.imshow(grasp_angle_img, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2)
166 | # plt.axis("off")
167 | # plt.savefig('Angle_1_%d.pdf' % counter, bbox_inches='tight')
168 | #
169 | # plot3=plt.imshow(grasp_width_img_ggcnn, cmap='jet', vmin=-0, vmax=150)
170 | # plt.axis("off")
171 | # plt.colorbar(plot3)
172 | # plt.savefig('Width_%d.pdf'%counter, bbox_inches='tight')
173 | # plt.show()
174 | #
175 | # plt.imshow(grasp_width_img_ggcnn, cmap='jet', vmin=-0, vmax=150)
176 | # plt.axis("off")
177 | # plt.savefig('Width_1_%d.pdf' % counter, bbox_inches='tight')
178 | # counter=counter+1
179 | # matplotlib.use("TkAgg")
180 | # if (input())==str(1):
181 | # plt.figure(figsize=(5,5))
182 | # plt.imshow(rgb_img)
183 | # for g in gs:
184 | # g.plot(plt)
185 | # plt.axis("off")
186 | # plt.tight_layout()
187 | # plt.show()
188 | #
189 | # plt.figure(figsize=(5, 5))
190 | # # plt.imshow(depth_img, cmap='gist_gray')
191 | # plt.imshow(depth_img, )
192 | # plt.axis("off")
193 | # for g in gs:
194 | # g.plot(plt)
195 | # plt.tight_layout()
196 | # plt.show()
197 | #
198 | # plt.figure(figsize=(5, 5))
199 | # plt.imshow(grasp_q_img, cmap="terrain", vmin=0, vmax=1)
200 | # plt.axis("off")
201 | # plt.tight_layout()
202 | # plt.show()
203 | #
204 | # plt.figure(figsize=(5, 5))
205 | # plt.imshow(grasp_angle_img, cmap="prism", vmin=-np.pi / 2, vmax=np.pi / 2)
206 | # plt.axis("off")
207 | # plt.tight_layout()
208 | # plt.show()
209 | #
210 | # plt.figure(figsize=(5, 5))
211 | # plt.imshow(grasp_width_img, cmap='hsv', vmin=-0, vmax=150)
212 | # plt.axis("off")
213 | # plt.tight_layout()
214 | # plt.show()
215 | #
216 | #
217 | # plt.figure(figsize=(5, 5))
218 | # plt.imshow(grasp_q_img_ggcnn, cmap="terrain", vmin=0, vmax=1)
219 | # plt.axis("off")
220 | # plt.tight_layout()
221 | # plt.show()
222 | #
223 | # plt.figure(figsize=(5, 5))
224 | # plt.imshow(grasp_angle_img_ggcnn, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2)
225 | # plt.axis("off")
226 | # plt.tight_layout()
227 | # plt.show()
228 | #
229 | # plt.figure(figsize=(5, 5))
230 | # plt.imshow(grasp_width_img_ggcnn, cmap='hsv', vmin=-0, vmax=150)
231 | # plt.axis("off")
232 | # plt.tight_layout()
233 | # plt.show()
234 | def calculate_iou_match(grasp_q, grasp_angle, ground_truth_bbs, no_grasps=1, grasp_width=None):
235 | """
236 | Calculate grasp success using the IoU (Jacquard) metric (e.g. in https://arxiv.org/abs/1301.3592)
237 |     A success is counted if the grasp rectangle has at least 25% IoU with a ground-truth rectangle and its angle is within 30 degrees of it.
238 | :param grasp_q: Q outputs of GG-CNN (Nx300x300x3)
239 | :param grasp_angle: Angle outputs of GG-CNN
240 | :param ground_truth_bbs: Corresponding ground-truth BoundingBoxes
241 | :param no_grasps: Maximum number of grasps to consider per image.
242 | :param grasp_width: (optional) Width output from GG-CNN
243 | :return: success
244 | """
245 |
246 | if not isinstance(ground_truth_bbs, GraspRectangles):
247 | gt_bbs = GraspRectangles.load_from_array(ground_truth_bbs)
248 | else:
249 | gt_bbs = ground_truth_bbs
250 | gs = detect_grasps(grasp_q, grasp_angle, width_img=grasp_width, no_grasps=no_grasps)
251 | for g in gs:
252 | if g.max_iou(gt_bbs) > 0.25:
253 | return True
254 | else:
255 | return False
--------------------------------------------------------------------------------
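`calculate_iou_match` above counts a prediction as correct when any of the top `no_grasps` detected grasps reaches an IoU above 0.25 with some ground-truth rectangle (the 30-degree angle check is handled inside `GraspRectangle.iou`). A hedged sketch of calling it directly; the label path is a placeholder and the three maps would normally come from the post-processed network output:

```python
import numpy as np
from utils.dataset_processing import evaluation, grasp

q_img = np.random.rand(224, 224)            # grasp-quality map
ang_img = np.zeros((224, 224))              # grasp-angle map (radians)
width_img = 40.0 * np.ones((224, 224))      # grasp-width map (pixels)

gt = grasp.GraspRectangles.load_from_cornell_file('path/to/pcd0100cpos.txt')  # placeholder path
success = evaluation.calculate_iou_match(q_img, ang_img, gt, no_grasps=2, grasp_width=width_img)
print('counted as a correct grasp:', success)
```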
/utils/dataset_processing/generate_cornell_depth.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import numpy as np
4 | from imageio import imsave
5 | import argparse
6 | from utils.dataset_processing.image import DepthImage
7 |
8 |
9 | if __name__ == '__main__':
10 | parser = argparse.ArgumentParser(description='Generate depth images from Cornell PCD files.')
11 | parser.add_argument('path', type=str, help='Path to Cornell Grasping Dataset')
12 | args = parser.parse_args()
13 |
14 | pcds = glob.glob(os.path.join(args.path, '*', 'pcd*[0-9].txt'))
15 | pcds.sort()
16 |
17 | for pcd in pcds:
18 | di = DepthImage.from_pcd(pcd, (480, 640))
19 | di.inpaint()
20 |
21 | of_name = pcd.replace('.txt', 'd.tiff')
22 | print(of_name)
23 | imsave(of_name, di.img.astype(np.float32))
--------------------------------------------------------------------------------
/utils/dataset_processing/grasp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import matplotlib.pyplot as plt
4 |
5 | from skimage.draw import polygon
6 | from skimage.feature import peak_local_max
7 |
8 |
9 | def _gr_text_to_no(l, offset=(0, 0)):
10 | """
11 | Transform a single point from a Cornell file line to a pair of ints.
12 | :param l: Line from Cornell grasp file (str)
13 | :param offset: Offset to apply to point positions
14 | :return: Point [y, x]
15 | """
16 | x, y = l.split()
17 | return [int(round(float(y))) - offset[0], int(round(float(x))) - offset[1]]
18 |
19 |
20 | class GraspRectangles:
21 | """
22 | Convenience class for loading and operating on sets of Grasp Rectangles.
23 | """
24 | def __init__(self, grs=None):
25 | if grs:
26 | self.grs = grs
27 | else:
28 | self.grs = []
29 |
30 | def __getitem__(self, item):
31 | return self.grs[item]
32 |
33 | def __iter__(self):
34 | return self.grs.__iter__()
35 |
36 | def __getattr__(self, attr):
37 | """
38 | Test if GraspRectangle has the desired attr as a function and call it.
39 | """
40 |         # Delegate list-wide calls to each GraspRectangle in turn.
41 | if hasattr(GraspRectangle, attr) and callable(getattr(GraspRectangle, attr)):
42 | return lambda *args, **kwargs: list(map(lambda gr: getattr(gr, attr)(*args, **kwargs), self.grs))
43 | else:
44 |             raise AttributeError("Couldn't find function %s in GraspRectangles or GraspRectangle" % attr)
45 |
46 | @classmethod
47 | def load_from_array(cls, arr):
48 | """
49 | Load grasp rectangles from numpy array.
50 | :param arr: Nx4x2 array, where each 4x2 array is the 4 corner pixels of a grasp rectangle.
51 | :return: GraspRectangles()
52 | """
53 | grs = []
54 | for i in range(arr.shape[0]):
55 | grp = arr[i, :, :].squeeze()
56 | if grp.max() == 0:
57 | break
58 | else:
59 | grs.append(GraspRectangle(grp))
60 | return cls(grs)
61 |
62 | @classmethod
63 | def load_from_cornell_file(cls, fname):
64 | """
65 | Load grasp rectangles from a Cornell dataset grasp file.
66 | :param fname: Path to text file.
67 | :return: GraspRectangles()
68 | """
69 | grs = []
70 | with open(fname) as f:
71 | while True:
72 | # Load 4 lines at a time, corners of bounding box.
73 | p0 = f.readline()
74 | if not p0:
75 | break # EOF
76 | p1, p2, p3 = f.readline(), f.readline(), f.readline()
77 | try:
78 | gr = np.array([
79 | _gr_text_to_no(p0),
80 | _gr_text_to_no(p1),
81 | _gr_text_to_no(p2),
82 | _gr_text_to_no(p3)
83 | ])
84 |
85 | grs.append(GraspRectangle(gr))
86 |
87 | except ValueError:
88 | # Some files contain weird values.
89 | continue
90 | return cls(grs)
91 |
92 | @classmethod
93 | def load_from_jacquard_file(cls, fname, scale=1.0):
94 | """
95 | Load grasp rectangles from a Jacquard dataset file.
96 | :param fname: Path to file.
97 | :param scale: Scale to apply (e.g. if resizing images)
98 | :return: GraspRectangles()
99 | """
100 | grs = []
101 | with open(fname) as f:
102 | for l in f:
103 | x, y, theta, w, h = [float(v) for v in l[:-1].split(';')]
104 | # index based on row, column (y,x), and the Jacquard dataset's angles are flipped around an axis.
105 | grs.append(Grasp(np.array([y, x]), -theta/180.0*np.pi, w, h).as_gr)
106 | grs = cls(grs)
107 | grs.scale(scale)
108 | return grs
109 |
110 | def append(self, gr):
111 | """
112 | Add a grasp rectangle to this GraspRectangles object
113 | :param gr: GraspRectangle
114 | """
115 | self.grs.append(gr)
116 |
117 | def copy(self):
118 | """
119 | :return: A deep copy of this object and all of its GraspRectangles.
120 | """
121 | new_grs = GraspRectangles()
122 | for gr in self.grs:
123 | new_grs.append(gr.copy())
124 | return new_grs
125 |
126 | def show(self, ax=None, shape=None):
127 | """
128 | Draw all GraspRectangles on a matplotlib plot.
129 | :param ax: (optional) existing axis
130 | :param shape: (optional) Plot shape if no existing axis
131 | """
132 | if ax is None:
133 | f = plt.figure()
134 | ax = f.add_subplot(1, 1, 1)
135 | ax.imshow(np.zeros(shape))
136 | ax.axis([0, shape[1], shape[0], 0])
137 | self.plot(ax)
138 | plt.show()
139 | else:
140 | self.plot(ax)
141 |
142 | def draw(self, shape, position=True, angle=True, width=True):
143 | """
144 | Plot all GraspRectangles as solid rectangles in a numpy array, e.g. as network training data.
145 | :param shape: output shape
146 | :param position: If True, Q output will be produced
147 | :param angle: If True, Angle output will be produced
148 | :param width: If True, Width output will be produced
149 | :return: Q, Angle, Width outputs (or None)
150 | """
151 | if position:
152 | pos_out = np.zeros(shape)
153 | else:
154 | pos_out = None
155 | if angle:
156 | ang_out = np.zeros(shape)
157 | else:
158 | ang_out = None
159 | if width:
160 | width_out = np.zeros(shape)
161 | else:
162 | width_out = None
163 |
164 | for gr in self.grs:
165 | rr, cc = gr.compact_polygon_coords(shape)
166 | if position:
167 | pos_out[rr, cc] = 1.0
168 | if angle:
169 | ang_out[rr, cc] = gr.angle
170 | if width:
171 | width_out[rr, cc] = gr.length
172 |
173 | return pos_out, ang_out, width_out
174 |
175 | def to_array(self, pad_to=0):
176 | """
177 | Convert all GraspRectangles to a single array.
178 | :param pad_to: Length to 0-pad the array along the first dimension
179 | :return: Nx4x2 numpy array
180 | """
181 | a = np.stack([gr.points for gr in self.grs])
182 | if pad_to:
183 | if pad_to > len(self.grs):
184 | a = np.concatenate((a, np.zeros((pad_to - len(self.grs), 4, 2))))
185 | return a.astype(np.int)
186 |
187 | @property
188 | def center(self):
189 | """
190 | Compute the mean centre of all GraspRectangles.
191 | :return: Mean centre as an integer (row, col) array.
192 | """
193 | points = [gr.points for gr in self.grs]
194 | return np.mean(np.vstack(points), axis=0).astype(np.int)
195 |
196 |
197 | class GraspRectangle:
198 | """
199 | Representation of a grasp in the common "Grasp Rectangle" format.
200 | """
201 | def __init__(self, points):
202 | self.points = points
203 |
204 | def __str__(self):
205 | return str(self.points)
206 |
207 | @property
208 | def angle(self):
209 | """
210 | :return: Angle of the grasp to the horizontal.
211 | """
212 | dx = self.points[1, 1] - self.points[0, 1]
213 | dy = self.points[1, 0] - self.points[0, 0]
214 | return (np.arctan2(-dy, dx) + np.pi/2) % np.pi - np.pi/2
215 |
216 | @property
217 | def as_grasp(self):
218 | """
219 | :return: GraspRectangle converted to a Grasp
220 | """
221 | return Grasp(self.center, self.angle, self.length, self.width)
222 |
223 | @property
224 | def center(self):
225 | """
226 | :return: Rectangle center point
227 | """
228 | return self.points.mean(axis=0).astype(np.int)
229 |
230 | @property
231 | def length(self):
232 | """
233 | :return: Rectangle length (i.e. along the axis of the grasp)
234 | """
235 | dx = self.points[1, 1] - self.points[0, 1]
236 | dy = self.points[1, 0] - self.points[0, 0]
237 | return np.sqrt(dx ** 2 + dy ** 2)
238 |
239 | @property
240 | def width(self):
241 | """
242 | :return: Rectangle width (i.e. perpendicular to the axis of the grasp)
243 | """
244 | dy = self.points[2, 1] - self.points[1, 1]
245 | dx = self.points[2, 0] - self.points[1, 0]
246 | return np.sqrt(dx ** 2 + dy ** 2)
247 |
248 | def polygon_coords(self, shape=None):
249 | """
250 | :param shape: Output Shape
251 | :return: Indices of pixels within the grasp rectangle polygon.
252 | """
253 | return polygon(self.points[:, 0], self.points[:, 1], shape)
254 |
255 | def compact_polygon_coords(self, shape=None):
256 | """
257 | :param shape: Output shape
258 | :return: Indices of pixels within the centre third of the grasp rectangle.
259 | """
260 | return Grasp(self.center, self.angle, self.length/3, self.width).as_gr.polygon_coords(shape)
261 |
262 | def iou(self, gr, angle_threshold=np.pi/6):
263 | """
264 | Compute IoU with another grasping rectangle
265 | :param gr: GraspRectangle to compare against
266 | :param angle_threshold: Maximum angle difference between GraspRectangles
267 | :return: IoU between Grasp Rectangles
268 | """
269 | if abs((self.angle - gr.angle + np.pi/2) % np.pi - np.pi/2) > angle_threshold:
270 | return 0
271 |
272 | rr1, cc1 = self.polygon_coords()
273 | rr2, cc2 = polygon(gr.points[:, 0], gr.points[:, 1])
274 |
275 | try:
276 | r_max = max(rr1.max(), rr2.max()) + 1
277 | c_max = max(cc1.max(), cc2.max()) + 1
278 | except ValueError:  # degenerate rectangle with no pixels
279 | return 0
280 |
281 | canvas = np.zeros((r_max, c_max))
282 | canvas[rr1, cc1] += 1
283 | canvas[rr2, cc2] += 1
284 | union = np.sum(canvas > 0)
285 | if union == 0:
286 | return 0
287 | intersection = np.sum(canvas == 2)
288 | return intersection/union
289 |
290 | def copy(self):
291 | """
292 | :return: Copy of self.
293 | """
294 | return GraspRectangle(self.points.copy())
295 |
296 | def offset(self, offset):
297 | """
298 | Offset grasp rectangle
299 | :param offset: array [y, x] distance to offset
300 | """
301 | self.points += np.array(offset).reshape((1, 2))
302 |
303 | def rotate(self, angle, center):
304 | """
305 | Rotate grasp rectangle
306 | :param angle: Angle to rotate (in radians)
307 | :param center: Point to rotate around (e.g. image center)
308 | """
309 | R = np.array(
310 | [
311 | [np.cos(-angle), np.sin(-angle)],
312 | [-1 * np.sin(-angle), np.cos(-angle)],
313 | ]
314 | )
315 | c = np.array(center).reshape((1, 2))
316 | self.points = ((np.dot(R, (self.points - c).T)).T + c).astype(np.int)
317 |
318 | def scale(self, factor):
319 | """
320 | :param factor: Scale grasp rectangle by factor
321 | """
322 | if factor == 1.0:
323 | return
324 | self.points *= factor
325 |
326 | def plot(self, ax, color=None):
327 | """
328 | Plot grasping rectangle.
329 | :param ax: Existing matplotlib axis
330 | :param color: matplotlib color code (optional)
331 | """
332 | points = np.vstack((self.points, self.points[0]))
333 | ax.plot(points[:, 1], points[:, 0], color=color)
334 |
335 | def zoom(self, factor, center):
336 | """
337 | Zoom grasp rectangle by given factor.
338 | :param factor: Zoom factor
339 | :param center: Zoom center (focus point, e.g. image center)
340 | """
341 | T = np.array(
342 | [
343 | [1/factor, 0],
344 | [0, 1/factor]
345 | ]
346 | )
347 | c = np.array(center).reshape((1, 2))
348 | self.points = ((np.dot(T, (self.points - c).T)).T + c).astype(np.int)
349 |
350 |
351 | class Grasp:
352 | """
353 | A Grasp represented by a centre pixel, a rotation angle and the gripper opening (stored as `length`) and width.
354 | """
355 | def __init__(self, center, angle, length=60, width=30):
356 | self.center = center
357 | self.angle = angle # Positive angle means rotate anti-clockwise from horizontal.
358 | self.length = length
359 | self.width = width
360 |
361 | @property
362 | def as_gr(self):
363 | """
364 | Convert to GraspRectangle
365 | :return: GraspRectangle representation of grasp.
366 | """
367 | xo = np.cos(self.angle)
368 | yo = np.sin(self.angle)
369 |
370 | y1 = self.center[0] + self.length / 2 * yo
371 | x1 = self.center[1] - self.length / 2 * xo
372 | y2 = self.center[0] - self.length / 2 * yo
373 | x2 = self.center[1] + self.length / 2 * xo
374 |
375 | return GraspRectangle(np.array(
376 | [
377 | [y1 - self.width/2 * xo, x1 - self.width/2 * yo],
378 | [y2 - self.width/2 * xo, x2 - self.width/2 * yo],
379 | [y2 + self.width/2 * xo, x2 + self.width/2 * yo],
380 | [y1 + self.width/2 * xo, x1 + self.width/2 * yo],
381 | ]
382 | ).astype(np.float))
383 |
384 | def max_iou(self, grs):
385 | """
386 | Return maximum IoU between self and a list of GraspRectangles
387 | :param grs: List of GraspRectangles
388 | :return: Maximum IoU with any of the GraspRectangles
389 | """
390 | self_gr = self.as_gr
391 | max_iou = 0
392 | for gr in grs:
393 | iou = self_gr.iou(gr)
394 | max_iou = max(max_iou, iou)
395 | return max_iou
396 |
397 | def plot(self, ax, color=None):
398 | """
399 | Plot Grasp
400 | :param ax: Existing matplotlib axis
401 | :param color: (optional) color
402 | """
403 | self.as_gr.plot(ax, color)
404 |
405 | def to_jacquard(self, scale=1):
406 | """
407 | Output grasp in "Jacquard Dataset Format" (https://jacquard.liris.cnrs.fr/database.php)
408 | :param scale: (optional) scale to apply to grasp
409 | :return: string in Jacquard format
410 | """
411 | # Output in jacquard format.
412 | return '%0.2f;%0.2f;%0.2f;%0.2f;%0.2f' % (self.center[1]*scale, self.center[0]*scale, -1*self.angle*180/np.pi, self.length*scale, self.width*scale)
413 |
414 |
415 | def detect_grasps(q_img, ang_img, width_img=None, no_grasps=1):
416 | """
417 | Detect grasps in a GG-CNN output.
418 | :param q_img: Q image network output
419 | :param ang_img: Angle image network output
420 | :param width_img: (optional) Width image network output
421 | :param no_grasps: Max number of grasps to return
422 | :return: list of Grasps
423 | """
424 | local_max = peak_local_max(q_img, min_distance=20, threshold_abs=0.2, num_peaks=no_grasps)
425 |
426 | grasps = []
427 | for grasp_point_array in local_max:
428 | grasp_point = tuple(grasp_point_array)
429 |
430 | grasp_angle = ang_img[grasp_point]
431 |
432 | g = Grasp(grasp_point, grasp_angle)
433 | if width_img is not None:
434 | g.length = width_img[grasp_point]
435 | g.width = g.length/2
436 |
437 | grasps.append(g)
438 |
439 | return grasps
440 |
--------------------------------------------------------------------------------
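For reference, a minimal usage sketch of the grasp classes above (not part of the repository; the image size and grasp parameters are arbitrary placeholders): it builds a synthetic `Grasp`, rasterises it into the Q/angle/width training maps with `GraspRectangles.draw`, and recovers it with `detect_grasps`.

```python
import numpy as np
from utils.dataset_processing.grasp import Grasp, GraspRectangles, detect_grasps

# Synthetic ground truth: a single grasp near the image centre.
g = Grasp(np.array([112, 112]), angle=np.pi / 6, length=60, width=30)
gt = GraspRectangles([g.as_gr])

# Rasterise into the Q/angle/width maps used as network training targets.
q_img, ang_img, width_img = gt.draw((224, 224))

# Recover the strongest grasp from the maps and score it against the ground truth.
detected = detect_grasps(q_img, ang_img, width_img=width_img, no_grasps=1)
print(detected[0].max_iou(gt.grs))  # .grs is the underlying list of GraspRectangle
```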
/utils/dataset_processing/image.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from imageio import imread
5 | from skimage.transform import rotate, resize
6 |
7 | import warnings
8 | warnings.filterwarnings("ignore", category=UserWarning)
9 |
10 |
11 | class Image:
12 | """
13 | Wrapper around an image with some convenient functions.
14 | """
15 | def __init__(self, img):
16 | self.img = img
17 |
18 | def __getattr__(self, attr):
19 | # Pass along any other methods to the underlying ndarray
20 | return getattr(self.img, attr)
21 |
22 | @classmethod
23 | def from_file(cls, fname):
24 | return cls(imread(fname))
25 |
26 | def copy(self):
27 | """
28 | :return: Copy of self.
29 | """
30 | return self.__class__(self.img.copy())
31 |
32 | def crop(self, top_left, bottom_right, resize=None):
33 | """
34 | Crop the image to a bounding box given by top left and bottom right pixels.
35 | :param top_left: tuple, top left pixel.
36 | :param bottom_right: tuple, bottom right pixel
37 | :param resize: If specified, resize the cropped image to this size
38 | """
39 | self.img = self.img[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]]
40 | if resize is not None:
41 | self.resize(resize)
42 |
43 | def cropped(self, *args, **kwargs):
44 | """
45 | :return: Cropped copy of the image.
46 | """
47 | i = self.copy()
48 | i.crop(*args, **kwargs)
49 | return i
50 |
51 | def normalise(self):
52 | """
53 | Normalise the image by converting to float [0,1] and zero-centering
54 | """
55 | self.img = self.img.astype(np.float32)/255.0
56 | self.img -= self.img.mean()
57 |
58 | def resize(self, shape):
59 | """
60 | Resize image to shape.
61 | :param shape: New shape.
62 | """
63 | if self.img.shape == shape:
64 | return
65 | self.img = resize(self.img, shape, preserve_range=True).astype(self.img.dtype)
66 |
67 | def resized(self, *args, **kwargs):
68 | """
69 | :return: Resized copy of the image.
70 | """
71 | i = self.copy()
72 | i.resize(*args, **kwargs)
73 | return i
74 |
75 | def rotate(self, angle, center=None):
76 | """
77 | Rotate the image.
78 | :param angle: Angle (in radians) to rotate by.
79 | :param center: Centre pixel to rotate around; if not specified, the image centre is used.
80 | """
81 | if center is not None:
82 | center = (center[1], center[0])
83 | self.img = rotate(self.img, angle/np.pi*180, center=center, mode='symmetric', preserve_range=True).astype(self.img.dtype)
84 |
85 | def rotated(self, *args, **kwargs):
86 | """
87 | :return: Rotated copy of image.
88 | """
89 | i = self.copy()
90 | i.rotate(*args, **kwargs)
91 | return i
92 |
93 | def show(self, ax=None, **kwargs):
94 | """
95 | Plot the image
96 | :param ax: Existing matplotlib axis (optional)
97 | :param kwargs: kwargs to imshow
98 | """
99 | if ax:
100 | ax.imshow(self.img, **kwargs)
101 | else:
102 | plt.imshow(self.img, **kwargs)
103 | plt.show()
104 |
105 | def zoom(self, factor):
106 | """
107 | "Zoom" the image by cropping and resizing.
108 | :param factor: Factor to zoom by. e.g. 0.5 will keep the center 50% of the image.
109 | """
110 | sr = int(self.img.shape[0] * (1 - factor)) // 2
111 | sc = int(self.img.shape[1] * (1 - factor)) // 2
112 | orig_shape = self.img.shape
113 | self.img = self.img[sr:self.img.shape[0] - sr, sc: self.img.shape[1] - sc].copy()
114 | self.img = resize(self.img, orig_shape, mode='symmetric', preserve_range=True).astype(self.img.dtype)
115 |
116 | def zoomed(self, *args, **kwargs):
117 | """
118 | :return: Zoomed copy of the image.
119 | """
120 | i = self.copy()
121 | i.zoom(*args, **kwargs)
122 | return i
123 |
124 |
125 | class DepthImage(Image):
126 | def __init__(self, img):
127 | super().__init__(img)
128 |
129 | @classmethod
130 | def from_pcd(cls, pcd_filename, shape, default_filler=0, index=None):
131 | """
132 | Create a depth image from an unstructured PCD file.
133 | If index isn't specified, use the Euclidean distance; otherwise use the x/y/z channel (index 0/1/2).
134 | """
135 | img = np.zeros(shape)
136 | if default_filler != 0:
137 | img += default_filler
138 |
139 | with open(pcd_filename) as f:
140 | for l in f.readlines():
141 | ls = l.split()
142 |
143 | if len(ls) != 5:
144 | # Not a point line in the file.
145 | continue
146 | try:
147 | # Skip lines that don't start with a number.
148 | float(ls[0])
149 | except ValueError:
150 | continue
151 |
152 | i = int(ls[4])
153 | r = i // shape[1]
154 | c = i % shape[1]
155 |
156 | if index is None:
157 | x = float(ls[0])
158 | y = float(ls[1])
159 | z = float(ls[2])
160 |
161 | img[r, c] = np.sqrt(x ** 2 + y ** 2 + z ** 2)
162 |
163 | else:
164 | img[r, c] = float(ls[index])
165 |
166 | return cls(img/1000.0)
167 |
168 | @classmethod
169 | def from_tiff(cls, fname):
170 | return cls(imread(fname))
171 |
172 | def inpaint(self, missing_value=0):
173 | """
174 | Inpaint missing values in depth image.
175 | :param missing_value: Value that marks missing pixels to fill in.
176 | """
177 | # cv2 inpainting doesn't handle the border properly
178 | # https://stackoverflow.com/questions/25974033/inpainting-depth-map-still-a-black-image-border
179 | self.img = cv2.copyMakeBorder(self.img, 1, 1, 1, 1, cv2.BORDER_DEFAULT)
180 | mask = (self.img == missing_value).astype(np.uint8)
181 |
182 | # Scale to keep as float, but has to be in bounds -1:1 to keep opencv happy.
183 | scale = np.abs(self.img).max()
184 | self.img = self.img.astype(np.float32) / scale # Has to be float32, 64 not supported.
185 | self.img = cv2.inpaint(self.img, mask, 1, cv2.INPAINT_NS)
186 |
187 | # Back to original size and value range.
188 | self.img = self.img[1:-1, 1:-1]
189 | self.img = self.img * scale
190 |
191 | def gradients(self):
192 | """
193 | Compute gradients of the depth image using Sobel filters.
194 | :return: Gradients in X direction, Gradients in Y direction, Magnitude of XY gradients.
195 | """
196 | grad_x = cv2.Sobel(self.img, cv2.CV_64F, 1, 0, borderType=cv2.BORDER_DEFAULT)
197 | grad_y = cv2.Sobel(self.img, cv2.CV_64F, 0, 1, borderType=cv2.BORDER_DEFAULT)
198 | grad = np.sqrt(grad_x ** 2 + grad_y ** 2)
199 |
200 | return DepthImage(grad_x), DepthImage(grad_y), DepthImage(grad)
201 |
202 | def normalise(self):
203 | """
204 | Normalise by subtracting the mean and clipping to [-1, 1]
205 | """
206 | self.img = np.clip((self.img - self.img.mean()), -1, 1)
207 |
208 |
209 | class WidthImage(Image):
210 | """
211 | A width image is one that describes the desired gripper width at each pixel.
212 | """
213 | def zoom(self, factor):
214 | """
215 | "Zoom" the image by cropping and resizing. Also scales the width accordingly.
216 | :param factor: Factor to zoom by. e.g. 0.5 will keep the center 50% of the image.
217 | """
218 | super().zoom(factor)
219 | self.img = self.img/factor
220 |
221 | def normalise(self):
222 | """
223 | Normalise by mapping [0, 150] -> [0, 1]
224 | """
225 | self.img = np.clip(self.img, 0, 150.0)/150.0
226 |
--------------------------------------------------------------------------------
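A minimal sketch of the `Image` and `DepthImage` wrappers on synthetic arrays (not part of the repository; the shapes and values are placeholders):

```python
import numpy as np
from utils.dataset_processing.image import Image, DepthImage

rgb = Image(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))
rgb.rotate(np.pi / 6)                        # angle is given in radians
rgb.crop((80, 160), (80 + 224, 160 + 224))   # (top_left, bottom_right) as (row, col)
rgb.normalise()                              # float32 in [0, 1], zero-centred

depth = DepthImage(np.random.rand(480, 640).astype(np.float32))
depth.inpaint(missing_value=0)               # fill missing pixels before taking gradients
grad_x, grad_y, grad = depth.gradients()
print(rgb.img.shape, grad.img.shape)
```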
/utils/dataset_processing/multigrasp_object.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/dataset_processing/multigrasp_object.py
--------------------------------------------------------------------------------
/utils/timeit.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
4 | class TimeIt:
5 | print_output = True
6 | last_parent = None
7 | level = -1
8 |
9 | def __init__(self, s):
10 | self.s = s
11 | self.t0 = None
12 | self.t1 = None
13 | self.outputs = []
14 | self.parent = None
15 |
16 | def __enter__(self):
17 | self.t0 = time.time()
18 | self.parent = TimeIt.last_parent
19 | TimeIt.last_parent = self
20 | TimeIt.level += 1
21 |
22 | def __exit__(self, t, value, traceback):
23 | self.t1 = time.time()
24 | st = '%s%s: %0.1fms' % (' ' * TimeIt.level, self.s, (self.t1 - self.t0)*1000)
25 | TimeIt.level -= 1
26 |
27 | if self.parent:
28 | self.parent.outputs.append(st)
29 | self.parent.outputs += self.outputs
30 | else:
31 | if TimeIt.print_output:
32 | print(st)
33 | for o in self.outputs:
34 | print(o)
35 | self.outputs = []
36 |
37 | TimeIt.last_parent = self.parent
38 |
--------------------------------------------------------------------------------
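A minimal sketch of nested `TimeIt` blocks (the section names and sleeps are placeholders). Inner timings are buffered in the parent and printed by the outermost block, so the output stays grouped:

```python
import time
from utils.timeit import TimeIt

with TimeIt('process frame'):
    with TimeIt('inference'):
        time.sleep(0.010)
    with TimeIt('post-process'):
        time.sleep(0.005)

# Prints something like:
#   process frame: 15.4ms
#    inference: 10.1ms
#    post-process: 5.2ms
# Set TimeIt.print_output = False to silence the timer globally.
```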
/utils/visualisation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__init__.py
--------------------------------------------------------------------------------
/utils/visualisation/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/visualisation/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/visualisation/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/visualisation/__pycache__/gridshow.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/gridshow.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/visualisation/__pycache__/gridshow.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/gridshow.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/visualisation/__pycache__/gridshow.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangShaoSUN/grasp-transformer/ade36864ffbb77dac07363671f6c6c6eee536bcf/utils/visualisation/__pycache__/gridshow.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/visualisation/gridshow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | def gridshow(name, imgs, scales, cmaps, width, border=10):
6 | """
7 | Display images in a grid.
8 | :param name: cv2 Window Name to update
9 | :param imgs: List of Images (np.ndarrays)
10 | :param scales: List of (min, max) ranges (or None) used to scale each image before applying its colormap
11 | :param cmaps: List of cv2 Colormaps to apply
12 | :param width: Number of images in a row
13 | :param border: Border (pixels) between images.
14 | """
15 | imgrows = []
16 | imgcols = []
17 |
18 | maxh = 0
19 | for i, (img, cmap, scale) in enumerate(zip(imgs, cmaps, scales)):
20 |
21 | # Scale images into range 0-1
22 | if scale is not None:
23 | img = (np.clip(img, scale[0], scale[1]) - scale[0])/(scale[1]-scale[0])
24 | elif img.dtype == np.float:
25 | img = (img - img.min())/(img.max() - img.min() + 1e-6)
26 |
27 | # Apply colormap (if applicable) and convert to uint8
28 | if cmap is not None:
29 | try:
30 | imgc = cv2.applyColorMap((img * 255).astype(np.uint8), cmap)
31 | except:
32 | imgc = (img*255.0).astype(np.uint8)
33 | else:
34 | imgc = img
35 |
36 | if imgc.shape[0] == 3:
37 | imgc = imgc.transpose((1, 2, 0))
38 | elif imgc.shape[0] == 4:
39 | imgc = imgc[1:, :, :].transpose((1, 2, 0))
40 |
41 | # Arrange row of images.
42 | maxh = max(maxh, imgc.shape[0])
43 | imgcols.append(imgc)
44 | if i > 0 and i % width == (width-1):
45 | imgrows.append(np.hstack([np.pad(c, ((0, maxh - c.shape[0]), (border//2, border//2), (0, 0)), mode='constant') for c in imgcols]))
46 | imgcols = []
47 | maxh = 0
48 |
49 | # Unfinished row
50 | if imgcols:
51 | imgrows.append(np.hstack([np.pad(c, ((0, maxh - c.shape[0]), (border//2, border//2), (0, 0)), mode='constant') for c in imgcols]))
52 |
53 | maxw = max([c.shape[1] for c in imgrows])
54 |
55 | cv2.imshow(name, np.vstack([np.pad(r, ((border//2, border//2), (0, maxw - r.shape[1]), (0, 0)), mode='constant') for r in imgrows]))
56 |
--------------------------------------------------------------------------------
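A minimal sketch of `gridshow` on two synthetic maps (the window name, value ranges and colormaps are arbitrary placeholders, not from the repository):

```python
import cv2
import numpy as np
from utils.visualisation.gridshow import gridshow

q = np.random.rand(224, 224).astype(np.float32)                                # e.g. a quality map in [0, 1]
ang = np.random.uniform(-np.pi / 2, np.pi / 2, (224, 224)).astype(np.float32)  # e.g. an angle map

gridshow('network outputs',
         [q, ang],
         [(0.0, 1.0), (-np.pi / 2, np.pi / 2)],   # per-image (min, max) used for scaling
         [cv2.COLORMAP_JET, cv2.COLORMAP_HSV],    # per-image cv2 colormap
         width=2)                                 # two images per row
cv2.waitKey(0)
```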
/visualise_grasp_rectangle.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 | import matplotlib.pyplot as plt
5 | import torch.utils.data
6 | from utils.data import get_dataset
7 | from utils.dataset_processing.grasp import detect_grasps,GraspRectangles
8 | from models.common import post_process_output
9 | import cv2
10 | import matplotlib
11 | plt.rcParams.update({
12 | "text.usetex": True,
13 | "font.family": "sans-serif",
14 | "font.sans-serif": ["Helvetica"]})
15 | matplotlib.use("TkAgg")
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='Evaluate GG-CNN')
20 |
21 | # Network
22 | parser.add_argument('--network', type=str,default="./output/models/220623_1311_/epoch_08_iou_0.97", help='Path to saved network to evaluate')
23 |
24 | # Dataset & Data & Training
25 | parser.add_argument('--dataset', type=str, default="cornell", help='Dataset Name ("cornell" or "jacquard")')
26 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/archive111" ,help='Path to dataset')
27 | parser.add_argument('--use-depth', type=int, default=0, help='Use Depth image for training (1/0)')
28 | parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for training (0/1)')
29 | parser.add_argument('--split', type=float, default=0.9,
30 | help='Fraction of data for training (remainder is validation)')
31 | parser.add_argument('--ds-rotate', type=float, default=0.0,
32 | help='Shift the start point of the dataset to use a different test/train split for cross validation.')
33 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers')
34 |
35 | parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
36 | parser.add_argument('--vis', type=bool, default=False, help='vis')
37 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs')
38 | parser.add_argument('--batches-per-epoch', type=int, default=200, help='Batches per Epoch')
39 | parser.add_argument('--val-batches', type=int, default=32, help='Validation Batches')
40 | # Logging etc.
41 | parser.add_argument('--description', type=str, default='', help='Training description')
42 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory')
43 |
44 | args = parser.parse_args()
45 | return args
46 |
47 | if __name__ == '__main__':
48 | args = parse_args()
49 | print(args.network)
50 | print(args.use_rgb,args.use_depth)
51 | net = torch.load(args.network)
52 | # net_ggcnn = torch.load('./output/models/211112_1458_/epoch_30_iou_0.75')
53 | device = torch.device("cuda:0")
54 | Dataset = get_dataset(args.dataset)
55 |
56 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate,
57 | random_rotate=True, random_zoom=False,
58 | include_depth=args.use_depth, include_rgb=args.use_rgb)
59 | val_data = torch.utils.data.DataLoader(
60 | val_dataset,
61 | batch_size=1,
62 | shuffle=True,
63 | num_workers=args.num_workers
64 | )
65 | results = {
66 | 'correct': 0,
67 | 'failed': 0,
68 | 'loss': 0,
69 | 'losses': {
70 |
71 | }
72 | }
73 | ld = len(val_data)
74 | with torch.no_grad():
75 | batch_idx = 0
76 | fig = plt.figure(figsize=(20, 10))
77 | # ax = fig.add_subplot(5, 5, 1)
78 | # while batch_idx < 100:
79 | for id,(x, y, didx, rot, zoom_factor) in enumerate( val_data):
80 | # batch_idx += 1
81 | if id>24:
82 | break
83 | print(id)
84 | print(x.shape)
85 | xc = x.to(device)
86 | yc = [yy.to(device) for yy in y]
87 | lossd = net.compute_loss(xc, yc)
88 |
89 | loss = lossd['loss']
90 |
91 | results['loss'] += loss.item() / ld
92 | for ln, l in lossd['losses'].items():
93 | if ln not in results['losses']:
94 | results['losses'][ln] = 0
95 | results['losses'][ln] += l.item() / ld
96 |
97 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'],
98 | lossd['pred']['sin'], lossd['pred']['width'])
99 | gs_1 = detect_grasps(q_out, ang_out, width_img=w_out, no_grasps=1)
100 | rgb_img=val_dataset.get_rgb(didx, rot, zoom_factor, normalise=False)
101 | # print(rgb_img)
102 | ax = fig.add_subplot(5, 5, id+1)
103 | ax.imshow(rgb_img)
104 | ax.axis('off')
105 | for g in gs_1:
106 | g.plot(ax)
107 | plt.show()
108 |
109 | # s = evaluation.calculate_iou_match(q_out, ang_out,
110 | # val_data.dataset.get_gtbb(didx, rot, zoom_factor),
111 | # no_grasps=2,
112 | # grasp_width=w_out,
113 | # )
114 | #
115 | # if s:
116 | # results['correct'] += 1
117 | # else:
118 | # results['failed'] += 1
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
--------------------------------------------------------------------------------
/visulaize_heatmaps.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import torch.utils.data
6 | from utils.data import get_dataset
7 | from utils.dataset_processing.grasp import detect_grasps,GraspRectangles
8 | from models.common import post_process_output
9 | import cv2
10 | import matplotlib
11 | plt.rcParams.update({
12 | "text.usetex": True,
13 | "font.family": "sans-serif",
14 | "font.sans-serif": ["Helvetica"]})
15 | matplotlib.use("TkAgg")
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='Evaluate GG-CNN')
20 |
21 | # Network
22 | parser.add_argument('--network', type=str,default="./output/models/220623_1311_/epoch_08_iou_0.97", help='Path to saved network to evaluate')
23 |
24 | # Dataset & Data & Training
25 | parser.add_argument('--dataset', type=str, default="cornell", help='Dataset Name ("cornell" or "jacquard")')
26 | parser.add_argument('--dataset-path', type=str,default="/home/sam/Desktop/archive111" ,help='Path to dataset')
27 | parser.add_argument('--use-depth', type=int, default=0, help='Use Depth image for training (1/0)')
28 | parser.add_argument('--use-rgb', type=int, default=1, help='Use RGB image for training (0/1)')
29 | parser.add_argument('--split', type=float, default=0.9,
30 | help='Fraction of data for training (remainder is validation)')
31 | parser.add_argument('--ds-rotate', type=float, default=0.0,
32 | help='Shift the start point of the dataset to use a different test/train split for cross validation.')
33 | parser.add_argument('--num-workers', type=int, default=8, help='Dataset workers')
34 |
35 | parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
36 | parser.add_argument('--vis', type=bool, default=False, help='vis')
37 | parser.add_argument('--epochs', type=int, default=2000, help='Training epochs')
38 | parser.add_argument('--batches-per-epoch', type=int, default=200, help='Batches per Epoch')
39 | parser.add_argument('--val-batches', type=int, default=32, help='Validation Batches')
40 | # Logging etc.
41 | parser.add_argument('--description', type=str, default='', help='Training description')
42 | parser.add_argument('--outdir', type=str, default='output/models/', help='Training Output Directory')
43 |
44 | args = parser.parse_args()
45 | return args
46 |
47 | if __name__ == '__main__':
48 | args = parse_args()
49 | print(args.network)
50 | print(args.use_rgb,args.use_depth)
51 | net = torch.load(args.network)
52 | # net_ggcnn = torch.load('./output/models/211112_1458_/epoch_30_iou_0.75')
53 | device = torch.device("cuda:0")
54 | Dataset = get_dataset(args.dataset)
55 |
56 | val_dataset = Dataset(args.dataset_path, start=args.split, end=1.0, ds_rotate=args.ds_rotate,
57 | random_rotate=True, random_zoom=False,
58 | include_depth=args.use_depth, include_rgb=args.use_rgb)
59 | val_data = torch.utils.data.DataLoader(
60 | val_dataset,
61 | batch_size=1,
62 | shuffle=True,
63 | num_workers=args.num_workers
64 | )
65 | results = {
66 | 'correct': 0,
67 | 'failed': 0,
68 | 'loss': 0,
69 | 'losses': {
70 |
71 | }
72 | }
73 | ld = len(val_data)
74 | with torch.no_grad():
75 | batch_idx = 0
76 | # fig = plt.figure(figsize=(10, 10))
77 | # ax = fig.add_subplot(1, 4, 1)
78 | # while batch_idx < 100:
79 | for id,(x, y, didx, rot, zoom_factor) in enumerate( val_data):
80 | # batch_idx += 1
81 |
82 | print(id)
83 | print(x.shape)
84 | xc = x.to(device)
85 | yc = [yy.to(device) for yy in y]
86 | lossd = net.compute_loss(xc, yc)
87 |
88 | loss = lossd['loss']
89 |
90 | results['loss'] += loss.item() / ld
91 | for ln, l in lossd['losses'].items():
92 | if ln not in results['losses']:
93 | results['losses'][ln] = 0
94 | results['losses'][ln] += l.item() / ld
95 |
96 | q_out, ang_out, w_out = post_process_output(lossd['pred']['pos'], lossd['pred']['cos'],
97 | lossd['pred']['sin'], lossd['pred']['width'])
98 | gs_1 = detect_grasps(q_out, ang_out, width_img=w_out, no_grasps=1)
99 | rgb_img=val_dataset.get_rgb(didx, rot, zoom_factor, normalise=False)
100 |
101 | fig = plt.figure(figsize=(10, 10))
102 | ax = fig.add_subplot(1, 4, 1)
103 | ax.imshow(rgb_img)
104 |
105 | ax = fig.add_subplot(1, 4, 2)
106 | plot = ax.imshow(q_out, cmap="jet", vmin=0, vmax=1) # ?terrain
107 | plt.colorbar(plot)
108 | ax.axis('off')
109 | ax.set_title('q image')
110 |
111 | ax = fig.add_subplot(1, 4, 3) # flag prism jet
112 | plot = ax.imshow(ang_out, cmap="hsv", vmin=-np.pi / 2, vmax=np.pi / 2)
113 | plt.colorbar(plot)
114 | ax.axis('off')
115 | ax.set_title('angle')
116 |
117 | ax = fig.add_subplot(1, 4, 4)
118 | plot = ax.imshow(w_out, cmap='jet', vmin=-0, vmax=150)
119 | plt.colorbar(plot)
120 | ax.set_title('width')
121 | ax.axis('off')
122 | # print(rgb_img)
123 |
124 | plt.savefig('RGB_1_%d.pdf' % id, bbox_inches='tight')  # save before show() so the saved figure isn't blank; one file per sample
125 | plt.show()
--------------------------------------------------------------------------------