├── nets
│   ├── __init__.py
│   ├── resnet_utils.py
│   └── resnet_v1.py
├── data
│   ├── output
│   │   └── .gitignore
│   ├── .gitignore
│   └── test
│       ├── test0.png
│       ├── test1.png
│       ├── test2.png
│       ├── test3.png
│       ├── test4.png
│       ├── test5.png
│       └── test6.png
├── data_collection
│   ├── .gitignore
│   ├── remoteApi.so
│   ├── README.md
│   ├── ur.py
│   ├── grasp_trials_without_rule.py
│   ├── rg2.py
│   ├── corrective_grasp_trials.py
│   └── scene.py
├── .gitignore
├── models
│   └── .gitignore
├── doc
│   ├── preprocessing.png
│   ├── 000007_0000_color.png
│   ├── 000007_0000_depth.png
│   ├── 000007_0000_label.png
│   ├── 000007_0000_height_map_color.png
│   ├── 000007_0000_height_map_depth.png
│   └── the_proposed_grasp_system_pipeline.jpg
├── requirements.txt
├── download.sh
├── losses.py
├── hparams.py
├── data_processing
│   ├── create_tf_record.py
│   ├── rescale_data.py
│   ├── data_processor.py
│   ├── dataset_utils.py
│   ├── data_preprocessing.py
│   ├── mag_data_preprocessing.py
│   ├── generate_pseudo_data_v1.py
│   └── genarate_pseudo_data_v2.py
├── metagrasp_test.py
├── mag_test.py
├── models.py
├── README.md
├── metagrasp_train.py
├── mag_train.py
└── network_utils.py
/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/output/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | tfrecord*
2 | 3dnet*
--------------------------------------------------------------------------------
/data_collection/.gitignore:
--------------------------------------------------------------------------------
1 | *.ttt
2 | *.ttm
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.pyc
3 | *.tfrecord
4 | __pycache__
5 | .idea
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | *.ckpt
2 | events*
3 | checkpoint
4 | *.pbtxt
5 | *.json
6 | model.*
--------------------------------------------------------------------------------
/data/test/test0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test0.png
--------------------------------------------------------------------------------
/data/test/test1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test1.png
--------------------------------------------------------------------------------
/data/test/test2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test2.png
--------------------------------------------------------------------------------
/data/test/test3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test3.png
--------------------------------------------------------------------------------
/data/test/test4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test4.png
--------------------------------------------------------------------------------
/data/test/test5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test5.png
--------------------------------------------------------------------------------
/data/test/test6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test6.png
--------------------------------------------------------------------------------
/doc/preprocessing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/preprocessing.png
--------------------------------------------------------------------------------
/doc/000007_0000_color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_color.png
--------------------------------------------------------------------------------
/doc/000007_0000_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_depth.png
--------------------------------------------------------------------------------
/doc/000007_0000_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_label.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow>=1.9.0
2 | opencv-python>=3.4.0
3 | numpy>=1.13.3
4 | scipy>=1.0.0
5 | gdown==3.8.1
6 |
--------------------------------------------------------------------------------
/data_collection/remoteApi.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data_collection/remoteApi.so
--------------------------------------------------------------------------------
/doc/000007_0000_height_map_color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_height_map_color.png
--------------------------------------------------------------------------------
/doc/000007_0000_height_map_depth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_height_map_depth.png
--------------------------------------------------------------------------------
/doc/the_proposed_grasp_system_pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/the_proposed_grasp_system_pipeline.jpg
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | gdown https://drive.google.com/uc?id=1BCJkYBIA1wnPtmEZ_MCA9HMaZ6J6DlUg
2 |
3 | tar -zxvf metagrasp.tar.gz
4 | rm metagrasp.tar.gz
5 |
6 | cd metagrasp
7 |
8 | mv ur.ttt ../data_collection/
9 |
10 | tar -zxvf 3dnet.tar.gz
11 | rm 3dnet.tar.gz
12 | mv 3dnet ../data_collection/
13 |
14 | tar -zxvf model.tar.gz
15 | rm model.tar.gz
16 | mv metagrasp ../models/
17 |
18 | cd ..
19 | rm -rf metagrasp
20 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def create_loss(net, labels):
5 | y = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True)
6 | cross_entropy = -tf.reduce_mean(labels * tf.log(tf.clip_by_value(y, 0.001, 0.999)))
7 | return cross_entropy
8 |
9 |
10 | def create_loss_with_label_mask(net, labels, lamb):
11 | bad, good, background = tf.unstack(labels, axis=3)
12 | mask = lamb * tf.add(bad, good) + background * 0.1
13 | attention_mask = tf.stack([mask, mask, mask], axis=3)
14 | y = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True)
15 | cross_entropy = -tf.reduce_mean(attention_mask * (labels * tf.log(tf.clip_by_value(y, 0.001, 0.999))))
16 | return cross_entropy
17 |
18 |
19 | def create_loss_without_background(net, labels):
20 | bad, good, background = tf.unstack(labels, axis=3)
21 | background = tf.zeros_like(background, dtype=tf.float32)
22 | labels = tf.stack([bad, good, background], axis=3)
23 | y = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True)
24 | cross_entropy = -tf.reduce_mean(labels * tf.log(tf.clip_by_value(y, 0.001, 0.999)))
25 | return cross_entropy
26 |
--------------------------------------------------------------------------------
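Note: a minimal usage sketch (not part of the repository) of `create_loss` with dummy logits and one-hot labels; the 36x36 resolution is an assumption taken from `label_size` in `create_mag_hparams`, and the 3 channels correspond to the bad/good/background affordances.

```python
# Minimal sketch, assuming TF 1.x as pinned in requirements.txt.
import numpy as np
import tensorflow as tf

from losses import create_loss

logits = tf.placeholder(tf.float32, [1, 36, 36, 3])   # raw network output
labels = tf.placeholder(tf.float32, [1, 36, 36, 3])   # one-hot affordance labels
loss = create_loss(logits, labels)

with tf.Session() as sess:
    dummy_logits = np.random.randn(1, 36, 36, 3).astype(np.float32)
    dummy_labels = np.eye(3, dtype=np.float32)[np.random.randint(0, 3, (1, 36, 36))]
    print(sess.run(loss, feed_dict={logits: dummy_logits, labels: dummy_labels}))
```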
/hparams.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def create_mag_hparams(hparam_string=None):
5 | hparams = tf.contrib.training.HParams(learning_rate=0.001,
6 | lr_decay_step=200000,
7 | lr_decay_rate=0.77,
8 | momentum=0.99,
9 | lamb=15.0,
10 | batch_size=8,
11 | image_size=288,
12 | label_size=36,
13 | scope='mag',
14 | model_name='resnet_v1_50')
15 | if hparam_string:
16 | tf.logging.info('Parsing command line hparams: %s', hparam_string)
17 | hparams.parse(hparam_string)
18 |
19 | tf.logging.info('Final parsed hparams: %s', hparams.values())
20 | return hparams
21 |
22 |
23 | def create_metagrasp_hparams(hparam_string=None):
24 | hparams = tf.contrib.training.HParams(learning_rate=0.001,
25 | lr_decay_step=200000,
26 | lr_decay_rate=0.77,
27 | momentum=0.99,
28 | lamb=120.0,
29 | batch_size=16,
30 | image_size=288,
31 | label_size=288,
32 | scope='metagrasp',
33 | model_name='resnet_v1_50')
34 | if hparam_string:
35 | tf.logging.info('Parsing command line hparams: %s', hparam_string)
36 | hparams.parse(hparam_string)
37 |
38 | tf.logging.info('Final parsed hparams: %s', hparams.values())
39 | return hparams
40 |
41 |
--------------------------------------------------------------------------------
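Note: a usage sketch (the override string is illustrative) showing how defaults can be overridden through the comma-separated `name=value` format accepted by `HParams.parse`.

```python
# Sketch: override a couple of defaults via a comma-separated string.
from hparams import create_metagrasp_hparams

hparams = create_metagrasp_hparams('learning_rate=0.0005,batch_size=8')
print(hparams.learning_rate, hparams.batch_size)  # 0.0005 8
print(hparams.values())                           # dict of all resolved hyperparameters
```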
/data_processing/create_tf_record.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import random
3 | import math
4 |
5 | import argparse
6 | import os
7 |
8 | from data_processing.dataset_utils import *
9 |
10 | parser = argparse.ArgumentParser(description='create tensorflow record')
11 | parser.add_argument('--data_path',
12 | required=True,
13 | type=str,
14 | help='Path to data set.')
15 | parser.add_argument('--output',
16 | default='data/tfrecords',
17 | type=str,
18 | help='Path to the output files.')
19 | args = parser.parse_args()
20 |
21 | folders = ['color', 'encoded_depth', 'label_map', 'camera_height']
22 |
23 |
24 | def dict_to_tf_example(file_name):
25 | """
26 | Create tfrecord example.
27 | :param file_name: File path corresponding to the data.
28 | :return: example: Example of tfrecord.
29 | """
30 | with open(os.path.join(args.data_path, folders[0], file_name+'.png'), 'rb') as fid:
31 | encoded_color = fid.read()
32 | with open(os.path.join(args.data_path, folders[1], file_name + '.png'), 'rb') as fid:
33 | encoded_depth = fid.read()
34 | with open(os.path.join(args.data_path, folders[2], file_name + '.png'), 'rb') as fid:
35 | encoded_label_map = fid.read()
36 | # camera_height = float(np.loadtxt(os.path.join(args.data_path, folders[3], file_name + '.txt')))
37 | example = tf.train.Example(features=tf.train.Features(feature={
38 | 'image/color': bytes_feature(encoded_color),
39 | 'image/format': bytes_feature(b'png'),
40 | 'image/encoded_depth': bytes_feature(encoded_depth),
41 | 'image/label': bytes_feature(encoded_label_map),
42 | # 'image/camera_height': float_feature(camera_height),
43 | }))
44 | return example
45 |
46 |
47 | def main():
48 | if not os.path.exists(args.output):
49 | os.makedirs(args.output)
50 | os.makedirs(os.path.join(args.output, 'pos'))
51 | os.makedirs(os.path.join(args.output, 'neg'))
52 | curr_pos_records = len(os.listdir(os.path.join(args.output, 'pos')))
53 | curr_neg_records = len(os.listdir(os.path.join(args.output, 'neg')))
54 | curr_records_list = [curr_pos_records, curr_neg_records]
55 | sets = ['pos', 'neg']
56 | num_samples_per_record = 10000
57 |
58 | for curr_records, set in zip(curr_records_list, sets):
59 | with open(os.path.join(args.data_path, set+'.txt'), 'r') as f:
60 | file_list = f.readlines()
61 | random.shuffle(file_list)
62 | num_samples = len(file_list)
63 | num_records = math.ceil(num_samples / num_samples_per_record)
64 | for i, record_idx in enumerate(range(curr_records, curr_records + num_records)):
65 | writer = tf.python_io.TFRecordWriter(os.path.join(args.output,
66 | set,
67 | 'train_{}_{:04d}.tfrecord'.format(set, record_idx)))
68 | j = i * num_samples_per_record
69 | while j < (i + 1) * num_samples_per_record and j < num_samples:
70 | tf_example = dict_to_tf_example(file_list[j][:-1])
71 | writer.write(tf_example.SerializeToString())
72 | j += 1
73 | writer.close()
74 |
75 |
76 | if __name__ == '__main__':
77 | main()
78 |
79 |
--------------------------------------------------------------------------------
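Note: a decoding sketch (not part of the repository; the record path is illustrative) for reading back examples written by `dict_to_tf_example`, using the same feature keys.

```python
# Sketch, assuming TF 1.x: parse one record produced by create_tf_record.py.
import tensorflow as tf

def parse_example(serialized):
    features = tf.parse_single_example(serialized, {
        'image/color': tf.FixedLenFeature([], tf.string),
        'image/encoded_depth': tf.FixedLenFeature([], tf.string),
        'image/label': tf.FixedLenFeature([], tf.string),
    })
    color = tf.image.decode_png(features['image/color'], channels=3)
    depth = tf.image.decode_png(features['image/encoded_depth'], channels=3)
    label = tf.image.decode_png(features['image/label'], channels=3)
    return color, depth, label

# Illustrative path; actual names follow 'train_{set}_{idx:04d}.tfrecord'.
dataset = tf.data.TFRecordDataset(['data/tfrecords/pos/train_pos_0000.tfrecord'])
color, depth, label = dataset.map(parse_example).make_one_shot_iterator().get_next()
```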
/data_collection/README.md:
--------------------------------------------------------------------------------
1 | ### Collect new data in V-REP
2 |
3 | To collect new grasp data in V-REP, you should first run the scene (i.e., run ```ur.ttt```) in the simulator,
4 | then control the actions using the remote Python API.
5 |
6 | #### Step 1: Launch the V-REP simulator and load scene.
7 | You first need the V-REP 3.4 simulator, which is available [here](http://www.coppeliarobotics.com/).
8 | Then run the script to launch V-REP:
9 | ```bash
10 | cd path_to_vrep
11 | ./vrep
12 | ```
13 | Load the scene ```ur.ttt``` into V-REP; the scene file and object models are available
14 | [here](https://drive.google.com/open?id=1YQfju1x6_Kj7Hc0hPD154YmPWP3SbZ5h).
15 |
16 | #### Step 2: Run the control script to collect data.
17 | ```bash
18 | python data_collection/corrective_grasp_trials.py \
19 | --ip 127.0.0.1 \ # ip address to the vrep simulator.
20 | --port 19997 \ # port to the vrep simulator.
21 | --obj_id obj0000 \ # object handle in vrep.
22 | --num_grasp 200 \ # the number of grasp trials.
23 | --num_repeat 1 \ # the number of times to repeat if the gripper successfully grasps an object.
24 | --output data/3dnet # directory to save data.
25 | ```
26 |
27 | If you want to collect data that is not guided by the antipodal rule, run the script:
28 | ```bash
29 | python data_collection/grasp_trials_without_rule.py \
30 | --ip 127.0.0.1 \ # ip address to the vrep simulator.
31 | --port 19997 \ # port to the vrep simulator.
32 | --obj_id obj0000 \ # object handle in vrep.
33 | --num_grasp 200 \ # the number of grasp trials.
34 | --output data/3dnet # directory to save data.
35 | ```
36 |
37 | ### Data structure
38 | A grasp sample consists of an RGB image of the whole workspace,
39 | a grasp label that indicates which locations are graspable or not
40 | and by how many degrees the gripper should rotate,
41 | and object coordinates that indicate the pixel locations of the whole object in the RGB image.
42 |
43 | ### Contents of directories
44 | * **3dnet**
45 | * **0000**
46 | * **color**: The raw RGB images obtained from the vision sensor.
47 | * **000000.png**
48 | * **000001_0000.png**
49 | * ......
50 | * **depth**: The raw depth images obtained from the vision sensor.
51 | * **000000.png**
52 | * **000001_0000.png**
53 | * ......
54 | * **height_map_color**: The cropped RGB images that contain only the workspace.
55 | * **000000.png**
56 | * **000001_0000.png**
57 | * ......
58 | * **height_map_depth**: The cropped depth images that contain only the workspace.
59 | * **000000.png**
60 | * **000001_0000.png**
61 | * ......
62 | * **label**
63 | * **000000.bad.txt**: This file contains the grasp points in image space and corresponding grasp angles.
64 | * **000000.object_points.txt**: This file contains coordinates that belong to the object.
65 | * **000000.png**: The visualization of the grasp angles.
66 | * **000001_0000.good.txt**
67 | * **000001_0000.object_points.txt**
68 | * **000001_0000.png**
69 | * ......
70 | * **background_color.png**
71 | * **background_depth.png**
72 | * **crop_background_color.png**
73 | * **crop_background_depth.png**
74 | * **file_name.txt**
75 | * **0001**
76 | * ......
77 | * ......
78 |
79 |
--------------------------------------------------------------------------------
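Note: a loading sketch (paths and file names are illustrative) for one collected sample laid out as in the directory listing above; the `.good.txt`/`.bad.txt` files are the ones later read with `np.loadtxt` in `data_processing/rescale_data.py`.

```python
# Sketch of loading one collected grasp sample; file names are illustrative.
import cv2
import numpy as np

sample_dir = 'data/3dnet/0000'
color = cv2.imread(sample_dir + '/height_map_color/000001_0000.png')
depth = cv2.imread(sample_dir + '/height_map_depth/000001_0000.png', cv2.IMREAD_ANYDEPTH)
grasps = np.loadtxt(sample_dir + '/label/000001_0000.good.txt')            # grasp points and angles
object_points = np.loadtxt(sample_dir + '/label/000001_0000.object_points.txt')
```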
/data_processing/rescale_data.py:
--------------------------------------------------------------------------------
1 | from data_processing.dataset_utils import RescaleData, DataInfo
2 |
3 |
4 | import numpy as np
5 | import cv2
6 |
7 | import argparse
8 | import os
9 |
10 | parser = argparse.ArgumentParser(description='process labels')
11 | parser.add_argument('--data_dir',
12 | required=True,
13 | type=str,
14 | help='path to the data')
15 | args = parser.parse_args()
16 |
17 |
18 | def main():
19 | data_dir = os.listdir(args.data_dir)
20 | info = DataInfo()
21 | for data_id in data_dir:
22 | parent_dir = os.path.join(args.data_dir, data_id)
23 | print(parent_dir)
24 | os.mkdir(os.path.join(parent_dir, 'zoomed_height_map_color'))
25 | os.mkdir(os.path.join(parent_dir, 'zoomed_height_map_depth'))
26 | os.mkdir(os.path.join(parent_dir, 'zoomed_label'))
27 | background_depth = cv2.imread(os.path.join(parent_dir, 'background_depth.png'), cv2.IMREAD_ANYDEPTH)
28 | with open(os.path.join(parent_dir, 'file_name.txt'), 'r') as f:
29 | with open(os.path.join(parent_dir, 'zoomed_file_name.txt'), 'w') as f_zoom:
30 | file_names = f.readlines()
31 | for file_name in file_names:
32 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1]
33 | print(file_name)
34 | color = cv2.imread(os.path.join(parent_dir, 'color', file_name + '.png'))
35 | depth = cv2.imread(os.path.join(parent_dir, 'depth', file_name + '.png'), cv2.IMREAD_ANYDEPTH)
36 | try:
37 | grasp_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.good.txt'))
38 | label_flag = 'good'
39 | except IOError:
40 | grasp_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.bad.txt'))
41 | label_flag = 'bad'
42 | rd = RescaleData(color, depth, background_depth, grasp_labels, info)
43 | for i in range(7):
44 | factor = np.random.uniform(0.20/info.camera_height, 1.0)
45 | crop_color, crop_depth, label, label_points, object_points = rd.get_zoomed_data(factor)
46 | cv2.imwrite(os.path.join(parent_dir, 'zoomed_height_map_color',
47 | file_name + '.{:0.7f}.png'.format(factor)), crop_color)
48 | cv2.imwrite(os.path.join(parent_dir, 'zoomed_height_map_depth',
49 | file_name + '.{:0.7f}.png'.format(factor)), crop_depth)
50 | cv2.imwrite(os.path.join(parent_dir, 'zoomed_label', file_name + '.{:0.7f}.png'.format(factor)),
51 | label)
52 | np.savetxt(os.path.join(parent_dir, 'zoomed_label',
53 | file_name + '.{:0.7f}.'.format(factor) + label_flag + '.txt'),
54 | label_points)
55 | np.savetxt(os.path.join(parent_dir, 'zoomed_label',
56 | file_name + '.{:0.7f}.'.format(factor) + 'object_points.txt'),
57 | object_points)
58 | f_zoom.write(file_name + '.{:0.7f}\n'.format(factor))
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
64 |
65 |
--------------------------------------------------------------------------------
/metagrasp_test.py:
--------------------------------------------------------------------------------
1 | from models import *
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | import cv2
6 |
7 | import argparse
8 | import os
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '1'
10 |
11 | parser = argparse.ArgumentParser(description='network evaluating')
12 | parser.add_argument('--color_dir',
13 | default='data/test/000024.png',
14 | type=str,
15 | help='The directory where the color image can be found.')
16 | parser.add_argument('--output_dir',
17 | default='data/output',
18 | type=str,
19 | help='The directory where the output images will be saved.')
20 | parser.add_argument('--checkpoint_dir',
21 | default='models/metagrasp',
22 | type=str,
23 | help='The directory where the checkpoint can be found')
24 | args = parser.parse_args()
25 |
26 |
27 | def main():
28 | colors_n = cv2.resize(cv2.imread(args.color_dir), (200, 200))[..., ::-1]
29 | pad_size = 44
30 | pad_colors = np.zeros((288, 288, 3), dtype=np.uint8)
31 | pad_colors[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = colors_n
32 |
33 | colors_p = tf.placeholder(dtype=tf.float32, shape=[1, 288, 288, 3])
34 | colors = colors_p * tf.random_normal(colors_p.get_shape(), mean=1, stddev=0.01)
35 | colors = colors / tf.constant([255.0])
36 | colors = (colors - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225])
37 |
38 | net, end_points = metagrasp(colors,
39 | num_classes=3,
40 | num_channels=1000,
41 | is_training=False,
42 | global_pool=False,
43 | output_stride=16,
44 | spatial_squeeze=False,
45 | scope='metagrasp')
46 | probability_map = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True)
47 | probability_map = tf.image.resize_bilinear(probability_map, [288, 288])
48 | saver = tf.train.Saver()
49 | session_config = tf.ConfigProto(allow_soft_placement=True,
50 | log_device_placement=False)
51 | session_config.gpu_options.allow_growth = True
52 | sess = tf.Session(config=session_config)
53 | saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir))
54 | print('Successfully loaded model: {}.'.format(tf.train.latest_checkpoint(args.checkpoint_dir)))
55 | sess.run(tf.local_variables_initializer())
56 |
57 | outputs_a = []
58 | outputs_a_plus_c = []
59 | for i in range(16):
60 | mtx = cv2.getRotationMatrix2D((144, 144), 22.5 * i, 1)
61 | rotated_colors = cv2.warpAffine(pad_colors, mtx, (288, 288))
62 |
63 | output = sess.run(probability_map,
64 | feed_dict={colors_p: np.expand_dims(rotated_colors, 0).astype(np.float32)})
65 | outputs_a.append(output[..., 1]) # extract green channel
66 | outputs_a_plus_c.append((rotated_colors*0.7 + np.squeeze(output)*255.0*0.3).astype(np.uint8))
67 |
68 | cv2.imwrite(os.path.join(args.output_dir, 'rotated_colors_{}.png'.format(i)),
69 | rotated_colors[..., ::-1])
70 | cv2.imwrite(os.path.join(args.output_dir, 'output_{}.png'.format(i)),
71 | (np.squeeze(output)*255.0).astype(np.uint8)[..., ::-1])
72 | outputs = np.concatenate(outputs_a, axis=0)
73 | threshold = np.max(outputs) - 0.001
74 | for idx, h, w in zip(*np.where(outputs >= threshold)):
75 | cv2.circle(outputs_a_plus_c[idx], (w, h), 1, color=(0, 255, 0), thickness=5)
76 | vis_map = np.concatenate(
77 | tuple([np.concatenate(tuple(outputs_a_plus_c[4 * i:4 * (i + 1)]), axis=1) for i in range(4)]),
78 | axis=0)
79 | cv2.imshow('visualization_map', vis_map[..., ::-1])
80 | cv2.waitKey(0)
81 | cv2.imwrite(os.path.join(args.output_dir, 'visualization_map.png'), vis_map[..., ::-1])
82 |
83 | sess.close()
84 |
85 |
86 | if __name__ == '__main__':
87 | main()
88 |
89 |
--------------------------------------------------------------------------------
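Note: a post-processing sketch (not part of the repository) for the `(16, 288, 288)` stack of "good" affordance maps built in `main()` above; the best-scoring pixel across the 16 rotations gives the grasp location, and the rotation index gives the grasp angle.

```python
# Sketch: recover the grasp pixel and angle from the stacked affordance maps.
import numpy as np

def best_grasp(outputs):
    """Return (row, col, angle in degrees) of the highest-scoring grasp."""
    idx, h, w = np.unravel_index(np.argmax(outputs), outputs.shape)
    return h, w, 22.5 * idx  # 22.5 deg per rotation step, as in the test loop

# Example with a random stand-in for the real affordance stack:
print(best_grasp(np.random.rand(16, 288, 288)))
```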
/mag_test.py:
--------------------------------------------------------------------------------
1 | from models import *
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | import cv2
6 |
7 | import argparse
8 | import os
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | parser = argparse.ArgumentParser(description='network evaluating')
12 | parser.add_argument('--color_dir',
13 | default='data/test/color.png',
14 | type=str,
15 | help='The directory where the color image can be found.')
16 | parser.add_argument('--output_dir',
17 | default='data/output',
18 | type=str,
19 | help='The directory where the output images will be saved.')
20 | parser.add_argument('--checkpoint_dir',
21 | default='models/mag',
22 | type=str,
23 | help='The directory where the checkpoint can be found')
24 | args = parser.parse_args()
25 |
26 |
27 | def main():
28 | colors_n = cv2.resize(cv2.imread(args.color_dir), (200, 200))[..., ::-1]
29 | pad_size = 44
30 | pad_colors = np.zeros((288, 288, 3), dtype=np.uint8)
31 | pad_colors[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = colors_n
32 |
33 | colors_p = tf.placeholder(dtype=tf.float32, shape=[1, 288, 288, 3])
34 | colors = colors_p * tf.random_normal(colors_p.get_shape(), mean=1, stddev=0.01)
35 | colors = colors / tf.constant([255.0])
36 | colors = (colors - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225])
37 |
38 | net, end_points = mag(colors,
39 | num_classes=3,
40 | num_channels=1000,
41 | is_training=False,
42 | global_pool=False,
43 | output_stride=16,
44 | upsample_ratio=2,
45 | spatial_squeeze=False,
46 | reuse=tf.AUTO_REUSE,
47 | scope='mag')
48 | probability_map = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True)
49 | probability_map = tf.image.resize_bilinear(probability_map, [288, 288])
50 | saver = tf.train.Saver()
51 | session_config = tf.ConfigProto(allow_soft_placement=True,
52 | log_device_placement=False)
53 | session_config.gpu_options.allow_growth = True
54 | sess = tf.Session(config=session_config)
55 | saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir))
56 | print('Successfully loaded model: {}.'.format(tf.train.latest_checkpoint(args.checkpoint_dir)))
57 | sess.run(tf.local_variables_initializer())
58 |
59 | outputs_a = []
60 | outputs_a_plus_c = []
61 | for i in range(16):
62 | mtx = cv2.getRotationMatrix2D((144, 144), 22.5 * i, 1)
63 | rotated_colors = cv2.warpAffine(pad_colors, mtx, (288, 288))
64 |
65 | output = sess.run(probability_map,
66 | feed_dict={colors_p: np.expand_dims(rotated_colors, 0).astype(np.float32)})
67 | outputs_a.append(output[..., 1]) # extract green channel
68 | outputs_a_plus_c.append((rotated_colors*0.7 + np.squeeze(output)*255.0*0.3).astype(np.uint8))
69 |
70 | cv2.imwrite(os.path.join(args.output_dir, 'rotated_colors_{}.png'.format(i)),
71 | rotated_colors[..., ::-1])
72 | cv2.imwrite(os.path.join(args.output_dir, 'output_{}.png'.format(i)),
73 | (np.squeeze(output)*255.0).astype(np.uint8)[..., ::-1])
74 | outputs = np.concatenate(outputs_a, axis=0)
75 | threshold = np.max(outputs) - 0.001
76 | for idx, h, w in zip(*np.where(outputs >= threshold)):
77 | cv2.circle(outputs_a_plus_c[idx], (w, h), 1, color=(0, 255, 0), thickness=5)
78 | vis_map = np.concatenate(
79 | tuple([np.concatenate(tuple(outputs_a_plus_c[4 * i:4 * (i + 1)]), axis=1) for i in range(4)]),
80 | axis=0)
81 | cv2.imshow('visualization_map', vis_map[..., ::-1])
82 | cv2.waitKey(0)
83 | cv2.imwrite(os.path.join(args.output_dir, 'visualization_map.png'), vis_map[..., ::-1])
84 |
85 | sess.close()
86 |
87 |
88 | if __name__ == '__main__':
89 | main()
90 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.slim as slim
3 |
4 | from nets import resnet_v1
5 |
6 |
7 | def arg_scope(weight_decay=0.0005,
8 | batch_norm_decay=0.997,
9 | batch_norm_epsilon=1e-5,
10 | batch_norm_scale=True,
11 | activation_fn=tf.nn.relu,
12 | use_batch_norm=True):
13 | batch_norm_params = {
14 | 'decay': batch_norm_decay,
15 | 'epsilon': batch_norm_epsilon,
16 | 'scale': batch_norm_scale,
17 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
18 | 'fused': None, # Use fused batch norm if possible.
19 | }
20 | with slim.arg_scope(
21 | [slim.conv2d, slim.conv2d_transpose],
22 | weights_regularizer=slim.l2_regularizer(weight_decay),
23 | weights_initializer=slim.variance_scaling_initializer(),
24 | activation_fn=activation_fn,
25 | normalizer_fn=slim.batch_norm if use_batch_norm else None,
26 | normalizer_params=batch_norm_params):
27 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
28 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
29 | return arg_sc
30 |
31 |
32 | @slim.add_arg_scope
33 | def resize_bilinear(inputs,
34 | height,
35 | width,
36 | outputs_collections=None,
37 | scope=None):
38 | with tf.variable_scope(scope, 'resize', [inputs]) as sc:
39 | outputs = tf.image.resize_bilinear(inputs, [height, width], name='resize_bilinear')
40 | return slim.utils.collect_named_outputs(outputs_collections, sc.name, outputs)
41 |
42 |
43 | # multi-affordance grasping
44 | def mag(inputs,
45 | num_classes=3,
46 | num_channels=1000,
47 | is_training=True,
48 | global_pool=False,
49 | output_stride=16,
50 | upsample_ratio=2,
51 | spatial_squeeze=False,
52 | reuse=tf.AUTO_REUSE,
53 | scope='graspnet'):
54 | with tf.variable_scope(scope, 'graspnet', [inputs], reuse=reuse):
55 | with slim.arg_scope(resnet_v1.resnet_arg_scope()):
56 | net, end_points = resnet_v1.resnet_v1_50(inputs=inputs,
57 | num_classes=num_channels,
58 | is_training=is_training,
59 | global_pool=global_pool,
60 | output_stride=output_stride,
61 | spatial_squeeze=spatial_squeeze,
62 | scope='feature_extractor')
63 | with tf.variable_scope('prediction', [net]) as sc:
64 | end_points_collection = sc.original_name_scope + '_end_points'
65 | # to do: add batch normalization to the following conv layers.
66 | with slim.arg_scope([slim.conv2d],
67 | outputs_collections=end_points_collection):
68 | net = slim.conv2d(net, 512, [1, 1], scope='conv1')
69 | net = slim.conv2d(net, 128, [1, 1], scope='conv2')
70 | net = slim.conv2d(net, num_classes, [1, 1], scope='conv3')
71 | height, width = net.get_shape().as_list()[1:3]
72 | net = tf.image.resize_bilinear(net,
73 | [height * upsample_ratio, width * upsample_ratio],
74 | name='resize_bilinear')
75 | end_points.update(slim.utils.convert_collection_to_dict(end_points_collection))
76 | end_points['logits'] = net
77 | return net, end_points
78 |
79 |
80 | def metagrasp(inputs,
81 | num_classes=3,
82 | num_channels=1000,
83 | is_training=True,
84 | global_pool=False,
85 | output_stride=16,
86 | spatial_squeeze=False,
87 | reuse=tf.AUTO_REUSE,
88 | scope='metagrasp'):
89 | with tf.variable_scope(scope, 'metagrasp', [inputs], reuse=reuse):
90 | with slim.arg_scope(resnet_v1.resnet_arg_scope()):
91 | net, end_points = resnet_v1.resnet_v1_50(inputs=inputs,
92 | num_classes=num_channels,
93 | is_training=is_training,
94 | global_pool=global_pool,
95 | output_stride=output_stride,
96 | spatial_squeeze=spatial_squeeze,
97 | scope='feature_extractor')
98 | with tf.variable_scope('prediction', [net]) as sc:
99 | end_points_collection = sc.original_name_scope + '_end_points'
100 | # to do: add batch normalization to the following conv layers.
101 | with slim.arg_scope([slim.conv2d, resize_bilinear],
102 | outputs_collections=end_points_collection):
103 | net = slim.conv2d(net, 512, [3, 3], scope='conv1')
104 | height, width = net.get_shape().as_list()[1:3]
105 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear1')
106 | net = slim.conv2d(net, 256, [3, 3], scope='conv2')
107 | height, width = net.get_shape().as_list()[1:3]
108 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear2')
109 | net = slim.conv2d(net, 128, [3, 3], scope='conv3')
110 | height, width = net.get_shape().as_list()[1:3]
111 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear3')
112 | net = slim.conv2d(net, 64, [3, 3], scope='conv4')
113 | height, width = net.get_shape().as_list()[1:3]
114 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear4')
115 | net = slim.conv2d(net, num_classes, [5, 5], scope='conv5')
116 | end_points.update(slim.utils.convert_collection_to_dict(end_points_collection))
117 | end_points['logits'] = net
118 | return net, end_points
119 |
--------------------------------------------------------------------------------
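Note: a shape-check sketch (not part of the repository): build the `metagrasp` head on a 288x288 placeholder and confirm the logits come back at full resolution, consistent with `label_size=288` in `create_metagrasp_hparams`.

```python
# Sketch, assuming TF 1.x with slim and the bundled nets/resnet_v1.py.
import tensorflow as tf
from models import metagrasp

images = tf.placeholder(tf.float32, [1, 288, 288, 3])
net, end_points = metagrasp(images, num_classes=3, is_training=False)
# 288 / output_stride(16) = 18, then four x2 bilinear upsamplings -> 288.
print(net.get_shape().as_list())  # expected [1, 288, 288, 3]
```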
/data_processing/data_processor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class DataProcessor(object):
5 | def __init__(self):
6 | pass
7 |
8 | @staticmethod
9 | def get_grasp_center(grasp_labels):
10 | """
11 | Get grasp center.
12 | :param grasp_labels: A numpy array with shape [num_labels, 4].
13 | Each row represents a grasp label formulated as 2 points (i.e., (x1, y1, x2, y2)).
14 | :return: grasp_center: A numpy array with shape [num_labels, 2].
15 | Each row represents a grasp center corresponding to the grasp label.
16 | """
17 | row, _ = grasp_labels.shape
18 | grasp_center = np.empty((row, 2), dtype=np.float32)
19 | grasp_center[:, 0] = (grasp_labels[:, 0] + grasp_labels[:, 2]) / 2.0
20 | grasp_center[:, 1] = (grasp_labels[:, 1] + grasp_labels[:, 3]) / 2.0
21 | return grasp_center
22 |
23 | @staticmethod
24 | def get_grasp_angle(grasp_label):
25 | """
26 | Get grasp angle.
27 | :param grasp_label: A numpy array with shape [1, 4] which represents
28 | a grasp label formulated as 2 points (i.e., (x1, y1, x2, y2)).
29 | :return: angle_indices: A list of int with length 2. The discretized angles ranged from 0 to 15.
30 | Besides the original grasp angle, the list contains another angle flipped vertically by original one.
31 | """
32 | pt1 = grasp_label[0:2]
33 | pt2 = grasp_label[2:]
34 | angle = np.arctan2(pt2[0] - pt1[0], pt2[1] - pt1[1])
35 | if angle < 0:
36 | angle += np.pi * 2
37 | angle_indices = []
38 | angle_indices.append(int(round(angle / ((22.5 / 360.0) * np.pi * 2))))
39 | if angle >= np.pi:
40 | angle_indices.append(int(round((angle - np.pi) / ((22.5 / 360.0) * np.pi * 2))))
41 | else:
42 | angle_indices.append(int(round((angle + np.pi) / ((22.5 / 360.0) * np.pi * 2))))
43 | return angle_indices
44 |
45 | @staticmethod
46 | def rotate(points, center, angle):
47 | """
48 | Rotate points.
49 | :param points: A numpy array with shape [num_points, 2].
50 | :param center: A numpy array with shape [1, 2]. The rotation center of the points.
51 | :param angle: A float. The rotated angle represented in radian.
52 | :return: points: A numpy array with shape [num_points, 2]. The rotated points.
53 | """
54 | points = points.copy()
55 | h = center[0]
56 | w = center[1]
57 | points[:, 0] -= h
58 | points[:, 1] -= w
59 | rotate_matrix = np.array([[np.cos(angle), -np.sin(angle)],
60 | [np.sin(angle), np.cos(angle)]])
61 | points = np.dot(rotate_matrix, points.T).T
62 | points[:, 0] += h
63 | points[:, 1] += w
64 | return points
65 |
66 | @staticmethod
67 | def get_diff_depth(depth_o, depth_b):
68 | """
69 | Get difference depth image by subtracting current depth image and background depth image.
70 | :param depth_o: A numpy array with shape [height, width]. The current depth image.
71 | :param depth_b: A numpy array with shape [height, width]. The background depth image.
72 | :return: diff_depth: A numpy array with shape [height, width]. The difference depth image.
73 | """
74 | diff_depth = depth_b - depth_o
75 | diff_depth[np.where(diff_depth < 0)] = 0
76 | # diff_depth = cv2.medianBlur(diff_depth, 3)
77 | diff_depth = diff_depth.astype(np.uint16)
78 | return diff_depth
79 |
80 | @staticmethod
81 | def encode_depth(depth):
82 | """
83 | Encode depth image to RGB format.
84 | :param depth: A numpy array with shape [height, width]. The depth image.
85 | :return:
86 | """
87 | r = depth / 256 / 256
88 | g = depth / 256
89 | b = depth % 256
90 | # encoded_depth = np.stack([r, g, b], axis=2).astype(np.uint8)
91 | encoded_depth = np.stack([b, g, r], axis=2).astype(np.uint8) # use bgr order due to cv2 format
92 | return encoded_depth
93 |
94 | @staticmethod
95 | def gaussianize_label(grasp_label, grasp_centers, camera_height, is_good, expand=True):
96 | if is_good:
97 | for grasp_center in grasp_centers:
98 | if grasp_label[grasp_center[0], grasp_center[1], 0] == 255:
99 | # grasp_label[grasp_center[0], grasp_center[1], 0] = 0
100 | # grasp_label[grasp_center[0], grasp_center[1], 1] = 255
101 | pass
102 | else:
103 | left = right = 0
104 | while grasp_label[grasp_center[0], grasp_center[1]-left, 0] == 0:
105 | left += 1
106 | while grasp_label[grasp_center[0], grasp_center[1]+right, 0] == 0:
107 | right += 1
108 |
109 | def gauss(x, c, sigma):
110 | return int(255 * np.exp(-(x-c) ** 2 / (2 * sigma ** 2)))
111 | width = left + right
112 | # width = min(left, right) * 2
113 | grasp_center[1] += width // 2 - left
114 | # sigma = camera_height * width
115 | # sigma = width / (camera_height * 10.0)
116 | sigma = 0.03 / (0.4 * camera_height / 0.57) * 200.0 / 3.0
117 | for idx in range(grasp_center[1]-width//2, grasp_center[1]+width//2+1):
118 | value = gauss(idx, grasp_center[1], sigma)
119 | grasp_label[grasp_center[0], idx, 1] = value
120 | grasp_label[grasp_center[0], idx, 2] = 255 - value
121 | else:
122 | if expand:
123 | for grasp_center in grasp_centers:
124 | if grasp_label[grasp_center[0], grasp_center[1], 0] == 255:
125 | grasp_label[grasp_center[0], grasp_center[1], 0] = 0
126 | grasp_label[grasp_center[0], grasp_center[1], 2] = 255
127 | else:
128 | left = right = 0
129 | while grasp_label[grasp_center[0], grasp_center[1]-left, 0] == 0:
130 | grasp_label[grasp_center[0], grasp_center[1]-left, 2] = 255
131 | left += 1
132 | while grasp_label[grasp_center[0], grasp_center[1]+right, 0] == 0:
133 | grasp_label[grasp_center[0], grasp_center[1]+right, 2] = 255
134 | right += 1
135 | else:
136 | grasp_label[grasp_centers[:, 0], grasp_centers[:, 1], 0] = 0
137 | grasp_label[grasp_centers[:, 0], grasp_centers[:, 1], 2] = 255
138 | return grasp_label
139 |
140 |
141 |
142 |
--------------------------------------------------------------------------------
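Note: an example sketch (the input values are made up) of `DataProcessor.get_grasp_angle`, which discretizes a grasp defined by two contact points into the 16 bins of 22.5 degrees used by the rotation loop in the test scripts.

```python
# Sketch: discretize one grasp label (x1, y1, x2, y2) into angle bins.
import numpy as np
from data_processing.data_processor import DataProcessor

label = np.array([100.0, 80.0, 120.0, 80.0])   # two contact points of one grasp label
print(DataProcessor.get_grasp_angle(label))    # [4, 12]: the original bin and its flipped counterpart
```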
/README.md:
--------------------------------------------------------------------------------
1 | # MetaGrasp: Data Efficient Grasping by Affordance Interpreter Network
2 | This repository is the code for the paper
3 | **[MetaGrasp: Data Efficient Grasping by Affordance Interpreter Network](https://sysu-robotics-lab.github.io/MetaGrasp/)**
4 |
5 | [Junhao Cai](None)<sup>1</sup>,
6 | [Hui Cheng](http://sdcs.sysu.edu.cn/content/2504)<sup>1</sup>,
7 | [Zhanpeng Zhang](https://zhzhanp.github.io/)<sup>2</sup>,
8 | [Jingcheng Su](None)<sup>1</sup>
9 |
10 | <sup>1</sup>Sun Yat-sen University
11 | <sup>2</sup>SenseTime Group Limited
12 |
13 | Accepted at the International Conference on Robotics and Automation 2019 (ICRA 2019).
14 |
15 |
16 |
(figure: doc/the_proposed_grasp_system_pipeline.jpg)
17 |
18 | Fig.1. The grasp system pipeline.
19 |
20 | ### MetaGrasp: Data Efficient Grasping by Affordance Interpreter Network
21 | **Abstract** Data-driven approaches for grasping have shown significant advances recently,
22 | but they usually require large amounts of training data. To increase the efficiency
23 | of grasping data collection, this paper presents a novel grasp training system
24 | covering the whole pipeline from data collection to model inference. The system
25 | can collect effective grasp samples with a corrective strategy assisted by the antipodal
26 | grasp rule, and we design an affordance interpreter network to predict pixelwise
27 | grasp affordance maps. We define graspability, ungraspability and background as grasp
28 | affordances. The key advantage of our system is that the pixel-level affordance
29 | interpreter network, trained with only a small number of grasp samples under the antipodal
30 | rule, can achieve significant performance on totally unseen objects and backgrounds.
31 | The training samples are collected only in simulation. Extensive qualitative and
32 | quantitative experiments demonstrate the accuracy and robustness of our proposed
33 | approach. In real-world grasp experiments, we achieve a grasp success rate of 93%
34 | on a set of household items and 91% on a set of adversarial items with only about
35 | 6,300 simulated samples. We also achieve 87% accuracy in cluttered scenes.
36 | Although the model is trained using only RGB images, it also performs well when the
37 | background textures are changed, achieving up to 94% accuracy on the set of
38 | adversarial objects, which outperforms current state-of-the-art methods.
39 | **Citing**
40 | If you find this code useful in your work, please consider citing:
41 | ```bash
42 | @inproceedings{cai2019data,
43 | title={MetaGrasp: Data Efficient Grasping by Affordance Interpreter Network},
44 | author={Cai, Junhao and Cheng, Hui and Zhang, Zhanpeng and Su, Jingcheng},
45 | booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
46 | year={2019}
47 | }
48 | ```
49 |
50 | **Contact**
51 |
52 | If you have any questions or find any bugs, please contact Junhao Cai (caijh28@mail2.sysu.edu.cn).
53 | #### Requirements and Dependencies
54 | * Python 3
55 | * TensorFlow 1.9 or later
56 | * OpenCV 3.4.0 or later
57 | * NVIDIA GPU with compute capability 3.5+
58 | #### Quick Start
59 | Given an RGB image containing the whole workspace, we can obtain the affordance map using our pre-trained model by:
60 | 1. Download our pre-trained model:
61 | ```bash
62 | ./download.sh
63 | ```
64 | 2. Run the test script:
65 | ```bash
66 | python metagrasp_test.py \
67 | --color_dir data/test/test1.png \
68 | --output_dir data/output \
69 | --checkpoint_dir models/metagrasp
70 | ```
71 | The visual results are saved in the output folder.
72 |
73 | #### Training
74 | To train a new model, follow the steps below:
75 |
76 | **Step 1: Collect the dataset**
77 | You can collect a grasp dataset in V-REP using our data collection system;
78 | the details can be found in ```data_collection/README.md```.
79 |
80 | **Step 2: Preprocess the data**
81 | After data collection, you should preprocess the data before converting it to tfrecords. Run the script:
82 | ```bash
83 | python data_processing/data_preprocessing.py \
84 | --data_dir data/3dnet \
85 | --output data/training
86 | ```
87 | Then use ```data_processing/create_tf_record.py``` to generate positive and negative tfrecords.
88 | ```bash
89 | python data_processing/create_tf_record.py \
90 | --data_path data/training \
91 | --output data/tfrecords
92 | ```
93 | **Step 3: Download pre-trained model**
94 |
95 | We use [ResNet](https://arxiv.org/abs/1512.03385) weights to initialize the feature extractor of the model;
96 | the pre-trained model is available [here](http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz).
97 | ```bash
98 | mkdir models/checkpoints && cd models/checkpoints
99 | wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
100 | cd ../..
101 | ```
102 | **Step 4: Train the model**
103 |
104 | Then you can train the model with the script ```metagrasp_train.py```.
105 | For basic usage, the command will look something like this:
106 | ```bash
107 | python metagrasp_train.py \
108 | --train_log_dir models/logs_metagrasp \
109 | --dataset_dir data/tfrecords \
110 | --checkpoint_dir models/checkpoints
111 | ```
112 | You can visualize the training results with TensorBoard using the following command:
113 | ```bash
114 | tensorboard --logdir=models/logs_metagrasp --host=localhost --port=6666
115 | ```
116 |
117 | #### Data format
118 |
119 | In this section, we will first introduce the data structure of a grasp sample,
120 | then we will illustrate how to preprocess the sample so that we can obtain
121 | horizontally antipodal grasp data.
122 |
123 | **1. Data structure**
124 |
125 | A grasp sample consists of
126 | 1\) the RGB image which contains the whole workspace,
127 | 2\) the grasp point represented by image coordinates in image space,
128 | 3\) the grasp angle, and
129 | 4\) the grasp label with respect to the grasp point and the grasp angle.
130 | Note that we only assign labels to the pixels where we perform a grasp trial,
131 | so the other regions of the grasped object are not used during training.
132 | The visualization of a sample is shown below.
133 |
134 |
(figure: doc/000007_0000_color.png)
135 |
(figure: doc/000007_0000_depth.png)
136 |
137 | Fig.2. Raw image captured from vision sensor.
138 |
139 |
144 | Fig.3. Cropped images and grasp label.
145 |
146 |
147 | **2. Data after preprocessing**
148 |
149 | During preprocessing, we first rotate the RGB image and select the rotated images
150 | that satisfy the horizontally antipodal grasp rule. Then we augment the data by
151 | slightly rotating the images and adjusting the corresponding grasp angles. The whole process is shown in
152 | Fig.4.
153 |
154 |
(figure: doc/preprocessing.png)
155 |
156 | Fig.4. The data preprocessing process.
157 |
158 |
159 |
160 |
161 |
--------------------------------------------------------------------------------
/data_collection/ur.py:
--------------------------------------------------------------------------------
1 | from functools import partial, reduce
2 | from math import sqrt
3 | import time
4 |
5 | import data_collection.vrep as vrep
6 |
7 |
8 | class UR5(object):
9 | def __init__(self, client_id):
10 | """
11 | Initialization.
12 | :param client_id: An int. The client ID; refer to simxStart.
13 | """
14 | self.client_id = client_id
15 | self.joint_handles = [vrep.simxGetObjectHandle(self.client_id,
16 | 'UR5_joint{}'.format(i),
17 | vrep.simx_opmode_blocking)[1] for i in range(1, 6+1)]
18 | _, self.ik_tip_handle = vrep.simxGetObjectHandle(self.client_id,
19 | 'UR5_ik_tip',
20 | vrep.simx_opmode_blocking)
21 | _, self.ik_target_handle = vrep.simxGetObjectHandle(self.client_id,
22 | 'UR5_ik_target',
23 | vrep.simx_opmode_blocking)
24 | self.func = partial(vrep.simxCallScriptFunction,
25 | clientID=self.client_id,
26 | scriptDescription="UR5",
27 | options=vrep.sim_scripttype_childscript,
28 | operationMode=vrep.simx_opmode_blocking)
29 | self.initialization()
30 |
31 | def initialization(self):
32 | """
33 | Call script function pyInit in vrep child script.
34 | :return: None
35 | """
36 | _ = self.func(functionName='pyInit',
37 | inputInts=[],
38 | inputFloats=[],
39 | inputStrings=[],
40 | inputBuffer='')
41 | time.sleep(0.7)
42 |
43 | def wait_until_stop(self, handle, threshold=0.01):
44 | """
45 | Wait until the operation finishes.
46 | This is a delay function called in order to make sure that
47 | the operation executed has been completed.
48 | :param handle: An int. Handle of the object.
49 | :param threshold: A float. The object position threshold.
50 | If the object positions difference between two time steps is smaller than the threshold,
51 | the execution completes, otherwise the loop continues.
52 | :return: None
53 | """
54 | while True:
55 | _, pos1 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking)
56 | _, quat1 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking)
57 | time.sleep(0.7)
58 | _, pos2 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking)
59 | _, quat2 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking)
60 | pose1 = pos1 + quat1
61 | pose2 = pos2 + quat2
62 | theta = 0.5 * sqrt(reduce(lambda x, y: x + y, map(lambda x, y: (x - y) ** 2, pose1, pose2)))
63 | if theta < threshold:
64 | return
65 |
66 | def enable_ik(self, enable=0):
67 | """
68 | Call script function pyEnableIk in vrep child script.
69 | :param enable: An int. Whether to enable inverse kinematics.
70 | If enable = 1, IK is enabled; if enable = 0, it is disabled.
71 | :return: None
72 | """
73 | _ = self.func(functionName='pyEnableIk',
74 | inputInts=[enable],
75 | inputFloats=[],
76 | inputStrings=[],
77 | inputBuffer='')
78 |
79 | def move_to_joint_position(self, joint_angles):
80 | """
81 | Moves (actuates) several joints at the same time using the Reflexxes Motion Library type II or IV.
82 | Call script function pyMoveToJointPositions in vrep child script.
83 | :param joint_angles: A list of floats. The desired target angle positions of the joints.
84 | :return: None
85 | """
86 | self.enable_ik(0)
87 | _ = self.func(functionName='pyMoveToJointPositions',
88 | inputInts=[],
89 | inputFloats=joint_angles,
90 | inputStrings=[],
91 | inputBuffer='')
92 | self.wait_until_stop(self.ik_target_handle)
93 | time.sleep(0.7) # todo: remove manual time delay
94 |
95 | def move_to_object_position(self, pose):
96 | """
97 | Moves an object to a given position and/or orientation using Reflexxes Motion Library type II or IV.
98 | Call script function pyMoveToPosition in vrep child script.
99 | :param pose: A list of floats. The desired target position of the object.
100 | :return: None
101 | """
102 | self.enable_ik(1)
103 | _ = self.func(functionName='pyMoveToPosition',
104 | inputInts=[],
105 | inputFloats=pose,
106 | inputStrings=[],
107 | inputBuffer='')
108 | self.wait_until_stop(self.ik_target_handle)
109 |
110 | def get_end_effector_position(self):
111 | """
112 | Get end effector position.
113 | :return: pos: A list of floats. The position of end effector.
114 | """
115 | _, pos = vrep.simxGetObjectPosition(self.client_id,
116 | self.ik_tip_handle,
117 | -1,
118 | vrep.simx_opmode_blocking)
119 | return pos
120 |
121 | def get_end_effector_quaternion(self):
122 | """
123 | Get end effector quaternion.
124 | :return: quat: A list of floats. The angle of end effector represented by quaternion.
125 | """
126 | # _, quat = vrep.simxGetObjectQuaternion(self.client_id,
127 | # self.ik_tip_handle,
128 | # -1,
129 | # vrep.simx_opmode_blocking)
130 | _, _, quat, _, _ = self.func(functionName='pyGetObjectQuaternion',
131 | inputInts=[self.ik_tip_handle, -1],
132 | inputFloats=[],
133 | inputStrings=[],
134 | inputBuffer='')
135 | return quat
136 |
137 | def get_joint_positions(self, is_first_call=False):
138 | """
139 | Get joint angle positions.
140 | :param is_first_call: A boolean. Specify which operation mode vrep api function chooses.
141 | :return: joint_positions: A list of floats. The current UR joint angles.
142 | """
143 | if is_first_call:
144 | opmode = vrep.simx_opmode_streaming
145 | else:
146 | opmode = vrep.simx_opmode_blocking
147 | joint_positions = [vrep.simxGetJointPosition(self.client_id,
148 | joint_handle,
149 | opmode)[1] for joint_handle in self.joint_handles]
150 | return joint_positions
151 |
152 |
153 |
--------------------------------------------------------------------------------
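Note: a connection sketch (not part of the repository). It assumes a running V-REP instance with the ```ur.ttt``` scene loaded and its remote API server listening on port 19997, as described in data_collection/README.md, plus the companion `data_collection/vrep.py` bindings that `ur.py` already imports.

```python
# Sketch: connect to a running V-REP scene and query the UR5 joint angles.
import data_collection.vrep as vrep
from data_collection.ur import UR5

client_id = vrep.simxStart('127.0.0.1', 19997, True, True, 5000, 5)
if client_id != -1:
    arm = UR5(client_id)
    print(arm.get_joint_positions())
    vrep.simxFinish(client_id)
```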
/data_collection/grasp_trials_without_rule.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | from scipy import misc
3 | from data_collection.scene import Scene
4 | from functools import partial
5 | import numpy as np
6 | import cv2
7 | import argparse
8 | import os
9 |
10 | parser = argparse.ArgumentParser(description='Data collection in vrep simulator.')
11 | parser.add_argument('--ip',
12 | default='127.0.0.1',
13 | type=str,
14 | help='ip address to the vrep simulator.')
15 | parser.add_argument('--port',
16 | default=19997,
17 | type=int,
18 | help='port to the vrep simulator.')
19 | parser.add_argument('--obj_id',
20 | default='obj0',
21 | type=str,
22 | help='object name in vrep.')
23 | parser.add_argument('--num_grasp',
24 | default=200,
25 | type=int,
26 | help='the number of grasp trials.')
27 | parser.add_argument('--num_repeat',
28 | default=1,
29 | type=int,
30 | help='the number of times to repeat if the gripper successfully grasps an object.')
31 | parser.add_argument('--output',
32 | default='data/3dnet',
33 | type=str,
34 | help='directory to save data.')
35 | args = parser.parse_args()
36 |
37 | step = pi / 180.0
38 | ratio = 0.4 / 200
39 | interval = [0.23, 0.57]
40 | rand = partial(np.random.uniform, interval[0], interval[1])
41 |
42 |
43 | def main():
44 | if not os.path.exists(args.output):
45 | os.makedirs(args.output)
46 | data_id = len(os.listdir(args.output))
47 | data_path = os.path.join(args.output, '{:04d}'.format(data_id))
48 | color_path = os.path.join(data_path, 'color')
49 | depth_path = os.path.join(data_path, 'depth')
50 | height_map_color_path = os.path.join(data_path, 'height_map_color')
51 | height_map_depth_path = os.path.join(data_path, 'height_map_depth')
52 | label_path = os.path.join(data_path, 'label')
53 | os.mkdir(data_path)
54 | os.mkdir(color_path)
55 | os.mkdir(depth_path)
56 | os.mkdir(height_map_color_path)
57 | os.mkdir(height_map_depth_path)
58 | os.mkdir(label_path)
59 |
60 | s = Scene(args.ip, args.port, 'realsense')
61 | init_pos = [45*step, 10*step, 90*step, -10*step, -90*step, -45*step]
62 | top_pos = [-45*step, 10*step, 90*step, -10*step, -90*step, -45*step]
63 | # object handle
64 | obj = s.get_object_handle(args.obj_id)
65 | obj_quat = s.get_object_quaternion(obj)
66 | s.gripper.open_gripper()
67 |
68 | misc.imsave(data_path + '/background_color.png', s.background_color)
69 | cv2.imwrite(data_path + '/background_depth.png', s.background_depth)
70 | misc.imsave(data_path+'/crop_background_color.png',
71 | cv2.resize(s.background_color[45:435, 119:509, :], (200, 200)))
72 | cv2.imwrite(data_path+'/crop_background_depth.png',
73 | cv2.resize(s.background_depth[45:435, 119:509], (200, 200)))
74 |
75 | f = open(os.path.join(data_path, 'file_name.txt'), 'w')
76 |
77 | s.set_object_position(obj, [rand(), rand(), s.get_object_position(obj)[-1]])
78 |
79 | for i in range(args.num_grasp): # number of grasp trials.
80 | print(i)
81 | s.gripper.open_gripper()
82 | s.ur5.move_to_joint_position(init_pos)
83 | s.replace_object(obj, obj_quat, interval)
84 | color = s.get_color_image()
85 | depth = s.get_depth_image()
86 | center_point, object_points, reaching_height = s.get_center_from_image(depth)
87 | s.ur5.move_to_joint_position(top_pos)
88 | episode = []
89 | # move to the upward side with respect to the object
90 | init_quat = s.ur5.get_end_effector_quaternion()
91 | target_pos = s.get_object_position(obj)
92 | target_pos[0] = 0.200 + center_point[0] * ratio
93 | target_pos[1] = 0.200 + center_point[1] * ratio
94 | target_pos[2] = reaching_height
95 | target_pos[-1] += 0.1563
96 | episode.append(target_pos + init_quat)
97 | s.ur5.move_to_object_position(target_pos + init_quat)
98 | # randomly rotate gripper's angle
99 | joint_angles = s.ur5.get_joint_positions()
100 | curr_angle = joint_angles[-1] / np.pi * 180.0
101 | angle = np.random.randint(0, 180)
102 | joint_angles[-1] = (curr_angle + angle) * step
103 | s.ur5.move_to_joint_position(joint_angles)
104 | quat = s.ur5.get_end_effector_quaternion()
105 | episode.append(target_pos + quat)
106 | # move to the object downward
107 | target_pos[-1] -= 0.1563
108 | episode.append(target_pos + quat)
109 | s.ur5.move_to_object_position(target_pos + quat)
110 | # try to grasp object
111 | s.gripper.close_gripper()
112 | # determine whether the object is successfully grasped.
113 | if s.gripper.get_object_detection_status():
114 | print('grasp success.')
115 | width = s.gripper.get_gripper_width()
116 | s.gripper.open_gripper()
117 | s.ur5.move_to_object_position(episode[1])
118 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200))
119 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200))
120 | points = s.get_label(center_point, -angle * np.pi / 180.0, width, ratio)
121 | label = s.draw_label(points, 200, 200)
122 | misc.imsave(color_path + '/{:06d}.png'.format(i), color)
123 | cv2.imwrite(depth_path + '/{:06d}.png'.format(i), depth)
124 | misc.imsave(height_map_color_path + '/{:06d}.png'.format(i), crop_color)
125 | cv2.imwrite(height_map_depth_path + '/{:06d}.png'.format(i), crop_depth)
126 | misc.imsave(label_path + '/{:06d}.png'.format(i), label)
127 | np.savetxt(label_path + '/{:06d}.good.txt'.format(i), points)
128 | np.savetxt(label_path + '/{:06d}.object_points.txt'.format(i), np.asanyarray(object_points))
129 | f.write('{:06d}\n'.format(i))
130 |
131 | else:
132 | print('grasp failed')
133 | s.gripper.open_gripper()
134 | s.ur5.move_to_object_position(episode[1])
135 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200))
136 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200))
137 | points = s.get_label(center_point, -angle * np.pi / 180.0, 0.085, ratio)
138 | label = s.draw_label(points, 200, 200, (255, 0, 0))
139 | misc.imsave(color_path + '/{:06d}.png'.format(i), color)
140 | cv2.imwrite(depth_path + '/{:06d}.png'.format(i), depth)
141 | misc.imsave(height_map_color_path + '/{:06d}.png'.format(i), crop_color)
142 | cv2.imwrite(height_map_depth_path + '/{:06d}.png'.format(i), crop_depth)
143 | misc.imsave(label_path + '/{:06d}.png'.format(i), label)
144 | np.savetxt(label_path + '/{:06d}.bad.txt'.format(i), points)
145 | np.savetxt(label_path + '/{:06d}.object_points.txt'.format(i), np.asanyarray(object_points))
146 | f.write('{:06d}\n'.format(i))
147 |
148 | s.stop_simulation()
149 | f.close()
150 |
151 |
152 | if __name__ == '__main__':
153 | main()
154 |
--------------------------------------------------------------------------------
/data_processing/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import cv2
4 |
5 |
6 | def int64_feature(value):
7 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
8 |
9 |
10 | def int64_list_feature(value):
11 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
12 |
13 |
14 | def bytes_feature(value):
15 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
16 |
17 |
18 | def bytes_list_feature(value):
19 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
20 |
21 |
22 | def float_feature(value):
23 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
24 |
25 |
26 | def float_list_feature(value):
27 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
28 |
29 |
30 | def read_examples_list(path):
31 | with tf.gfile.GFile(path) as fid:
32 | lines = fid.readlines()
33 | return [line.strip().split(',') for line in lines]
34 |
35 |
36 | class DataInfo(object):
37 | camera_height = 0.57 # camera height, 0.57m
38 | original_image_height = 480 # the height of raw image captured from the camera
39 | original_image_width = 640 # the width of raw image captured from the camera
40 |     top_left_corner_h = 45  # row index of the top-left corner of the crop window in the original image.
41 |     top_left_corner_w = 119  # column index of the top-left corner of the crop window in the original image.
42 |     crop_image_size = 390  # side length of the square crop window, in pixels.
43 |     resize_image_size = 200  # side length of the crop after resizing, in pixels.
44 |
45 |
46 | class RescaleData(object):
47 | def __init__(self, color, depth, background_depth, grasp_labels, info=DataInfo()):
48 | self.info = info
49 |
50 | self.color = color
51 | self.diff_depth = background_depth - depth
52 | self.grasp_labels = grasp_labels
53 | self.grasp_center = self.get_grasp_center(self.grasp_labels)
54 | self.grasp_angle = self.get_grasp_angle(self.grasp_labels[0])
55 | self.gripper_width = self.get_gripper_width(self.grasp_labels[0])
56 |
57 | def get_zoomed_data(self, factor):
58 | zoomed_size = int(self.info.crop_image_size * factor)
59 |         keep = zoomed_size // 7  # integer division so the offsets below stay integral
60 |         half = zoomed_size // 2
61 | top, down, left, right = 1, 0, 1, 0
62 | keep1 = keep2 = keep
63 | while top > down:
64 | top = self.grasp_center[0] - half + keep1 if self.grasp_center[0] > zoomed_size else half
65 | down = self.grasp_center[0] + half - keep1 if self.info.original_image_height - self.grasp_center[0] > zoomed_size else self.info.original_image_height - half - zoomed_size % 2
66 |             keep1 //= 2
67 | while left > right:
68 | left = self.grasp_center[1] - half + keep2 if self.grasp_center[1] > zoomed_size else half
69 | right = self.grasp_center[1] + half - keep2 if self.info.original_image_width - self.grasp_center[1] > zoomed_size else self.info.original_image_width - half - zoomed_size % 2
70 |             keep2 //= 2
71 |         image_center = np.array([np.random.randint(top, down + 1, dtype=int),
72 |                                  np.random.randint(left, right + 1, dtype=int)])
73 | crop_color, crop_depth, object_points = self.zoom(image_center, zoomed_size, factor)
74 | grasp_point = self.grasp_center - (image_center - half)
75 | grasp_point = (grasp_point * 200.0 / zoomed_size).astype(np.int)
76 | label_points = self.get_label(grasp_point, self.grasp_angle, self.gripper_width, factor)
77 | label = self.draw_label(label_points, 200, 200)
78 | return crop_color, crop_depth, label, label_points, object_points
79 |
80 | def zoom(self, center, zoomed_size, factor):
81 |         half = zoomed_size // 2  # integer division: used directly as a slice index
82 | background_depth = int(self.info.camera_height * 1000 * factor)
83 | crop_color = self.color[center[0] - half:center[0] + half + zoomed_size % 2, center[1] - half:center[1] + half + zoomed_size % 2, :]
84 | crop_diff_depth = self.diff_depth[center[0] - half:center[0] + half + zoomed_size % 2, center[1] - half:center[1] + half + zoomed_size % 2]
85 | crop_depth = background_depth - crop_diff_depth
86 | crop_color = cv2.resize(crop_color, (200, 200))
87 | crop_depth = cv2.resize(crop_depth, (200, 200))
88 | points = np.stack(np.where(cv2.resize(crop_diff_depth, (200, 200)) > 0), axis=0).T
89 |
90 | return crop_color, crop_depth, points
91 |
92 | def get_grasp_center(self, grasp_labels):
93 | row, _ = grasp_labels.shape
94 | grasp_center = np.empty((row, 2), dtype=np.float32)
95 | grasp_center[:, 0] = (grasp_labels[:, 0] + grasp_labels[:, 2]) / 2.0
96 | grasp_center[:, 1] = (grasp_labels[:, 1] + grasp_labels[:, 3]) / 2.0
97 | grasp_center = 1.0 * grasp_center * self.info.crop_image_size / 200 + np.array([self.info.top_left_corner_h, self.info.top_left_corner_w])
98 | return grasp_center.mean(axis=0).astype(np.int)
99 |
100 | @staticmethod
101 | def get_label(center_point, angle, width, factor):
102 | """
103 | Obtain grasp labels. Each grasp label is represented as two points.
104 |         :param center_point: A numpy array. The pixel coordinates (row, column) of the grasp center.
105 |         :param angle: A float. The rotation angle of the gripper, in radians.
106 |         :param width: A float. The width between the two fingertips; it is divided by `factor` before use.
107 |         :return: A numpy array of int. The grasp points.
108 | """
109 | width = int(width / factor)
110 | tip_width = int(7 / factor)
111 | c_h = center_point[0]
112 | c_w = center_point[1]
113 |         h = [delta_h for delta_h in range(-tip_width // 2, tip_width // 2)]
114 |         w = [-width // 2, width // 2]
115 | points = np.asanyarray([[hh, w[0], hh, w[1]] for hh in h])
116 | rotate_matrix = np.array([[np.cos(angle), -np.sin(angle)],
117 | [np.sin(angle), np.cos(angle)]])
118 | points[:, 0:2] = np.dot(rotate_matrix, points[:, 0:2].T).T
119 | points[:, 2:] = np.dot(rotate_matrix, points[:, 2:].T).T
120 | points = points + np.asanyarray([[c_h, c_w, c_h, c_w]])
121 | points = np.floor(points).astype(np.int)
122 | return points
123 |
124 | @staticmethod
125 | def get_grasp_angle(grasp_label):
126 | """
127 | Get grasp angle.
128 | :param grasp_label: A numpy array with shape [1, 4] which represents
129 | a grasp label formulated as 2 points (i.e., (x1, y1, x2, y2)).
130 |         :return: A float. The grasp angle in radians: the negated arctan2 of the vector
131 |         pointing from the first grasp point to the second grasp point.
132 | """
133 | pt1 = grasp_label[0:2]
134 | pt2 = grasp_label[2:]
135 | angle = np.arctan2(pt2[0] - pt1[0], pt2[1] - pt1[1])
136 | return -angle
137 |
138 | @staticmethod
139 | def get_gripper_width(grasp_label):
140 | pt1 = grasp_label[0:2]
141 | pt2 = grasp_label[2:]
142 | gripper_width = np.sqrt(np.sum(np.power(pt1-pt2, 2)))
143 | return gripper_width
144 |
145 | @staticmethod
146 | def draw_label(points, width, height, color=(255, 0, 0)):
147 | """
148 | Draw labels according to the grasp points.
149 | :param points: A numpy array of float. The grasp points.
150 |         :param width: An int. The width of the label image, in pixels.
151 |         :param height: An int. The height of the label image, in pixels.
152 |         :param color: A tuple. The color used to draw the grasp lines.
153 | :return: A numpy array of uint8. The label map.
154 | """
155 | label = np.ones((height, width, 3), dtype=np.uint8) * 255
156 | for point in points:
157 | pt1 = (point[1], point[0])
158 | pt2 = (point[3], point[2])
159 | cv2.line(label, pt1, pt2, color)
160 | return label
161 |
--------------------------------------------------------------------------------
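
An illustrative sketch (not a file in the repository) of how the feature helpers defined above are combined into a serialized tf.train.Example; the feature keys below are made up for the example, and the repository root is assumed to be on PYTHONPATH (the project's real schema presumably lives in data_processing/create_tf_record.py):

from data_processing.dataset_utils import int64_feature, bytes_feature
import tensorflow as tf

# Build a toy Example with hypothetical keys.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': bytes_feature(b'\x89PNG...'),  # e.g. raw PNG bytes of a height map
    'image/height': int64_feature(200),
    'image/width': int64_feature(200),
}))
serialized = example.SerializeToString()  # bytes ready to be written by a tf.python_io.TFRecordWriter
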
/metagrasp_train.py:
--------------------------------------------------------------------------------
1 | from hparams import create_metagrasp_hparams
2 | from network_utils import *
3 | from models import metagrasp
4 | from losses import *
5 |
6 | import tensorflow as tf
7 | import tensorflow.contrib.slim as slim
8 |
9 | import argparse
10 | import os
11 | os.environ["CUDA_VISIBLE_DEVICES"] = '1'
12 |
13 | parser = argparse.ArgumentParser(description='network training')
14 | parser.add_argument('--master',
15 | default='',
16 | type=str,
17 | help='BNS name of the TensorFlow master to use')
18 | parser.add_argument('--task_id',
19 | default=0,
20 | type=int,
21 | help='The Task ID. This value is used when training with multiple workers to identify each worker.')
22 | parser.add_argument('--train_log_dir',
23 | default='models/metagrasp',
24 | type=str,
25 |                     help='Directory where to write training events and model checkpoints.')
26 | parser.add_argument('--dataset_dir',
27 | default='data/tfrecords',
28 | type=str,
29 | help='The directory where the datasets can be found.')
30 | parser.add_argument('--save_summaries_steps',
31 | default=120,
32 | type=int,
33 | help='The frequency with which summaries are saved, in seconds.')
34 | parser.add_argument('--save_interval_secs',
35 | default=600,
36 | type=int,
37 | help='The frequency with which the model is saved, in seconds.')
38 | parser.add_argument('--print_loss_steps',
39 | default=100,
40 | type=int,
41 | help='The frequency with which the losses are printed, in steps.')
42 | parser.add_argument('--num_readers',
43 | default=2,
44 | type=int,
45 | help='The number of parallel readers that read data from the dataset.')
46 | parser.add_argument('--num_steps',
47 | default=200000,
48 | type=int,
49 | help='The max number of gradient steps to take during training.')
50 | parser.add_argument('--num_preprocessing_threads',
51 | default=4,
52 | type=int,
53 | help='The number of threads used to create the batches.')
54 | parser.add_argument('--from_metagrasp_checkpoint',
55 |                     default=False,
56 |                     action='store_true',  # type=bool would treat any non-empty string, even "False", as True
57 |                     help='If set, resume from a metagrasp checkpoint instead of a classification checkpoint.')
58 | parser.add_argument('--checkpoint_dir',
59 | default='',
60 | type=str,
61 | help='The directory where the checkpoint can be found')
62 | args = parser.parse_args()
63 |
64 |
65 | def main():
66 | tf.logging.set_verbosity(tf.logging.INFO)
67 | h = create_metagrasp_hparams()
68 | for path in [args.train_log_dir]:
69 | if not tf.gfile.Exists(path):
70 | tf.gfile.MakeDirs(path)
71 | hparams_filename = os.path.join(args.train_log_dir, 'hparams.json')
72 | with tf.gfile.FastGFile(hparams_filename, 'w') as f:
73 | f.write(h.to_json())
74 | with tf.Graph().as_default():
75 | with tf.device(tf.train.replica_device_setter(args.task_id)):
76 | global_step = tf.train.get_or_create_global_step()
77 | colors_p, labels_p = get_color_dataset(args.dataset_dir + '/pos',
78 | args.num_readers,
79 | args.num_preprocessing_threads,
80 | h.image_size,
81 | h.label_size,
82 | int(h.batch_size/2))
83 | colors_n, labels_n = get_color_dataset(args.dataset_dir + '/neg',
84 | args.num_readers,
85 | args.num_preprocessing_threads,
86 | h.image_size,
87 | h.label_size,
88 | int(h.batch_size/2))
89 | colors = tf.concat([colors_p, colors_n], axis=0)
90 | labels = tf.concat([labels_p, labels_n], axis=0)
91 | net, end_points = metagrasp(colors,
92 | num_classes=3,
93 | num_channels=1000,
94 | is_training=True,
95 | global_pool=False,
96 | output_stride=16,
97 | spatial_squeeze=False,
98 | scope=h.scope)
99 | loss = create_loss_with_label_mask(net, labels, h.lamb)
100 | learning_rate = h.learning_rate
101 | if h.lr_decay_step:
102 | learning_rate = tf.train.exponential_decay(h.learning_rate,
103 | tf.train.get_or_create_global_step(),
104 | decay_steps=h.lr_decay_step,
105 | decay_rate=h.lr_decay_rate,
106 | staircase=True)
107 | tf.summary.scalar('Learning_rate', learning_rate)
108 | optimizer = tf.train.GradientDescentOptimizer(learning_rate)
109 | train_op = slim.learning.create_train_op(loss, optimizer)
110 | add_summary(colors, labels, end_points, loss, h)
111 | summary_op = tf.summary.merge_all()
112 | if not args.from_metagrasp_checkpoint:
113 | variable_map = restore_from_classification_checkpoint(
114 | scope=h.scope,
115 | model_name=h.model_name,
116 | checkpoint_exclude_scopes=['prediction'])
117 | init_saver = tf.train.Saver(variable_map)
118 |
119 | def initializer_fn(sess):
120 | init_saver.restore(sess, os.path.join(args.checkpoint_dir, h.model_name+'.ckpt'))
121 |                     tf.logging.info('Successfully loaded pretrained checkpoint.')
122 |
123 | init_fn = initializer_fn
124 |
125 | else:
126 | variable_map = restore_map()
127 | init_saver = tf.train.Saver(variable_map)
128 |
129 | def initializer_fn(sess):
130 | init_saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir))
131 |                     tf.logging.info('Successfully loaded pretrained checkpoint.')
132 |
133 | init_fn = initializer_fn
134 |
135 | session_config = tf.ConfigProto(allow_soft_placement=True,
136 | log_device_placement=False)
137 | session_config.gpu_options.allow_growth = True
138 |             saver = tf.train.Saver(keep_checkpoint_every_n_hours=args.save_interval_secs / 3600.0,  # the flag value is in seconds
139 | max_to_keep=100)
140 |
141 | slim.learning.train(train_op,
142 | logdir=args.train_log_dir,
143 | master=args.master,
144 | global_step=global_step,
145 | session_config=session_config,
146 | init_fn=init_fn,
147 | summary_op=summary_op,
148 | number_of_steps=args.num_steps,
149 | startup_delay_steps=15,
150 | save_summaries_secs=args.save_summaries_steps,
151 | saver=saver)
152 |
153 |
154 | if __name__ == '__main__':
155 | main()
156 |
--------------------------------------------------------------------------------
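
A minimal sketch of the staircase schedule that tf.train.exponential_decay produces in the training script above; the base rate, decay steps and decay rate here are made-up values standing in for the hparams:

def staircase_lr(base_lr, step, decay_steps, decay_rate):
    # With staircase=True the exponent is the integer number of completed decay periods.
    return base_lr * decay_rate ** (step // decay_steps)

for step in (0, 9999, 10000, 25000):
    print(step, staircase_lr(0.001, step, decay_steps=10000, decay_rate=0.9))
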
/mag_train.py:
--------------------------------------------------------------------------------
1 | from hparams import create_mag_hparams
2 | from network_utils import *
3 | from models import mag
4 | from losses import *
5 |
6 | import tensorflow as tf
7 | import tensorflow.contrib.slim as slim
8 |
9 | import argparse
10 | import os
11 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
12 |
13 | parser = argparse.ArgumentParser(description='network training')
14 | parser.add_argument('--master',
15 | default='',
16 | type=str,
17 | help='BNS name of the TensorFlow master to use')
18 | parser.add_argument('--task_id',
19 | default=0,
20 | type=int,
21 | help='The Task ID. This value is used when training with multiple workers to identify each worker.')
22 | parser.add_argument('--train_log_dir',
23 | default='models/',
24 | type=str,
25 |                     help='Directory where to write training events and model checkpoints.')
26 | parser.add_argument('--save_summaries_steps',
27 | default=120,
28 | type=int,
29 | help='The frequency with which summaries are saved, in seconds.')
30 | parser.add_argument('--save_interval_secs',
31 | default=600,
32 | type=int,
33 | help='The frequency with which the model is saved, in seconds.')
34 | parser.add_argument('--print_loss_steps',
35 | default=100,
36 | type=int,
37 | help='The frequency with which the losses are printed, in steps.')
38 | parser.add_argument('--dataset_dir',
39 | default='',
40 | type=str,
41 | help='The directory where the datasets can be found.')
42 | parser.add_argument('--num_readers',
43 | default=2,
44 | type=int,
45 | help='The number of parallel readers that read data from the dataset.')
46 | parser.add_argument('--num_steps',
47 | default=200000,
48 | type=int,
49 | help='The max number of gradient steps to take during training.')
50 | parser.add_argument('--num_preprocessing_threads',
51 | default=4,
52 | type=int,
53 | help='The number of threads used to create the batches.')
54 | parser.add_argument('--from_mag_checkpoint',
55 |                     default=False,
56 |                     action='store_true',  # type=bool would treat any non-empty string, even "False", as True
57 |                     help='If set, resume from a mag checkpoint instead of a classification checkpoint.')
58 | parser.add_argument('--checkpoint_dir',
59 | default='',
60 | type=str,
61 | help='The directory where the checkpoint can be found')
62 | args = parser.parse_args()
63 |
64 |
65 | def main():
66 | tf.logging.set_verbosity(tf.logging.INFO)
67 | h = create_mag_hparams()
68 | for path in [args.train_log_dir]:
69 | if not tf.gfile.Exists(path):
70 | tf.gfile.MakeDirs(path)
71 | hparams_filename = os.path.join(args.train_log_dir, 'hparams.json')
72 | with tf.gfile.FastGFile(hparams_filename, 'w') as f:
73 | f.write(h.to_json())
74 | with tf.Graph().as_default():
75 | with tf.device(tf.train.replica_device_setter(args.task_id)):
76 | global_step = tf.train.get_or_create_global_step()
77 | colors_p, labels_p = get_color_dataset(args.dataset_dir+'/pos',
78 | args.num_readers,
79 | args.num_preprocessing_threads,
80 | h.image_size,
81 | h.label_size,
82 | int(h.batch_size/2))
83 | colors_n, labels_n = get_color_dataset(args.dataset_dir + '/neg',
84 | args.num_readers,
85 | args.num_preprocessing_threads,
86 | h.image_size,
87 | h.label_size,
88 | int(h.batch_size/2))
89 | colors = tf.concat([colors_p, colors_n], axis=0)
90 | labels = tf.concat([labels_p, labels_n], axis=0)
91 | net, end_points = mag(colors,
92 | num_classes=3,
93 | num_channels=1000,
94 | is_training=True,
95 | global_pool=False,
96 | output_stride=16,
97 | upsample_ratio=2,
98 | spatial_squeeze=False,
99 | scope=h.scope)
100 | loss = create_loss_with_label_mask(net, labels, h.lamb)
101 | # loss = create_loss_without_background(net, labels)
102 | learning_rate = h.learning_rate
103 | if h.lr_decay_step:
104 | learning_rate = tf.train.exponential_decay(h.learning_rate,
105 | tf.train.get_or_create_global_step(),
106 | decay_steps=h.lr_decay_step,
107 | decay_rate=h.lr_decay_rate,
108 | staircase=True)
109 | tf.summary.scalar('Learning_rate', learning_rate)
110 | optimizer = tf.train.GradientDescentOptimizer(learning_rate)
111 | # optimizer = tf.train.MomentumOptimizer(0.001, momentum=h.momentum)
112 | train_op = slim.learning.create_train_op(loss, optimizer)
113 | add_summary(colors, labels, end_points, loss, h)
114 | summary_op = tf.summary.merge_all()
115 | if not args.from_mag_checkpoint:
116 | variable_map = restore_from_classification_checkpoint(
117 | scope=h.scope,
118 | model_name=h.model_name,
119 | checkpoint_exclude_scopes=['prediction'])
120 | init_saver = tf.train.Saver(variable_map)
121 |
122 | def initializer_fn(sess):
123 | init_saver.restore(sess, os.path.join(args.checkpoint_dir, h.model_name+'.ckpt'))
124 |                     tf.logging.info('Successfully loaded pretrained checkpoint.')
125 |
126 | init_fn = initializer_fn
127 |
128 | else:
129 | variable_map = restore_map()
130 | init_saver = tf.train.Saver(variable_map)
131 |
132 | def initializer_fn(sess):
133 | init_saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir))
134 |                     tf.logging.info('Successfully loaded pretrained checkpoint.')
135 |
136 | init_fn = initializer_fn
137 |
138 | session_config = tf.ConfigProto(allow_soft_placement=True,
139 | log_device_placement=False)
140 | session_config.gpu_options.allow_growth = True
141 |             saver = tf.train.Saver(keep_checkpoint_every_n_hours=args.save_interval_secs / 3600.0,  # the flag value is in seconds
142 | max_to_keep=100)
143 |
144 | slim.learning.train(train_op,
145 | logdir=args.train_log_dir,
146 | master=args.master,
147 | global_step=global_step,
148 | session_config=session_config,
149 | init_fn=init_fn,
150 | summary_op=summary_op,
151 | number_of_steps=args.num_steps,
152 | startup_delay_steps=15,
153 | save_summaries_secs=args.save_summaries_steps,
154 | saver=saver)
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
--------------------------------------------------------------------------------
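
The restore_from_classification_checkpoint helper used above lives in network_utils.py and is not shown in this section. As a rough sketch of the underlying idea (under the assumption that it filters variables by scope, which is the usual slim pattern), excluding a scope from restoration looks like this, with made-up scope names:

import tensorflow as tf
import tensorflow.contrib.slim as slim

# Two toy scopes so there is something to filter; the real graph is built by models.mag.
with tf.variable_scope('resnet_v1_50'):
    tf.get_variable('weights', shape=[3, 3, 3, 8])
with tf.variable_scope('prediction'):
    tf.get_variable('weights', shape=[1, 1, 8, 3])

# Keep everything except the 'prediction' head, then restore only those variables.
variables_to_restore = slim.get_variables_to_restore(exclude=['prediction'])
init_saver = tf.train.Saver(variables_to_restore)
print([v.op.name for v in variables_to_restore])  # only the resnet_v1_50/* variables
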
/data_collection/rg2.py:
--------------------------------------------------------------------------------
1 | from functools import partial, reduce
2 | from math import sqrt
3 | import data_collection.vrep as vrep
4 | import time
5 |
6 |
7 | class RG2(object):
8 | def __init__(self, client_id, script_name='UR5', motor_vel=0.11, motor_force=77):
9 | """
10 | Initialization for RG2.
11 |         :param client_id: An int. The client ID; refer to simxStart.
12 | :param script_name: A string. The lua script name in vrep.
13 | :param motor_vel: A float. Target velocity of the joint1.
14 | :param motor_force: The maximum force or torque that the joint can exert.
15 | """
16 | self.client_id = client_id
17 | self.motor_vel = motor_vel
18 | self.motor_force = motor_force
19 | _, self.attach_point = vrep.simxGetObjectHandle(self.client_id,
20 | 'RG2_attachPoint',
21 | vrep.simx_opmode_blocking)
22 | _, self.proximity_sensor = vrep.simxGetObjectHandle(self.client_id,
23 | 'RG2_attachProxSensor',
24 | vrep.simx_opmode_blocking)
25 | _, self.right_touch = vrep.simxGetObjectHandle(self.client_id,
26 | 'RG2_rightTouch',
27 | vrep.simx_opmode_blocking)
28 | _, self.left_touch = vrep.simxGetObjectHandle(self.client_id,
29 | 'RG2_leftTouch',
30 | vrep.simx_opmode_blocking)
31 | _, self.joint = vrep.simxGetObjectHandle(self.client_id,
32 | 'RG2_openCloseJoint',
33 | vrep.simx_opmode_blocking)
34 | self.func = partial(vrep.simxCallScriptFunction,
35 | clientID=self.client_id,
36 | scriptDescription=script_name,
37 | options=vrep.sim_scripttype_childscript,
38 | operationMode=vrep.simx_opmode_blocking)
39 |
40 | def open_gripper(self):
41 | """
42 | Open gripper.
43 | :return: None.
44 | """
45 | vrep.simxSetJointTargetVelocity(self.client_id,
46 | self.joint,
47 | self.motor_vel,
48 | vrep.simx_opmode_streaming)
49 | vrep.simxSetJointForce(self.client_id,
50 | self.joint,
51 | self.motor_force,
52 | vrep.simx_opmode_oneshot)
53 | self.wait_until_stop(self.right_touch)
54 |
55 | def close_gripper(self):
56 | """
57 | Close gripper.
58 | :return: None
59 | """
60 | vrep.simxSetJointTargetVelocity(self.client_id,
61 | self.joint,
62 | -self.motor_vel,
63 | vrep.simx_opmode_oneshot)
64 | vrep.simxSetJointForce(self.client_id,
65 | self.joint,
66 | self.motor_force,
67 | vrep.simx_opmode_oneshot)
68 | self.wait_until_stop(self.right_touch)
69 |
70 | def attach_object(self, object_handle):
71 | """
72 |         Attach the object to the gripper. This is an alternative way to grasp objects.
73 | :param object_handle: An int. The handle of object which is successfully grasped by gripper.
74 | :return: None
75 | """
76 | vrep.simxSetObjectParent(self.client_id,
77 | object_handle,
78 | self.attach_point,
79 | True,
80 | vrep.simx_opmode_blocking)
81 |
82 | def untie_object(self, object_handle):
83 | """
84 |         Untie the object from the gripper. This is the counterpart of attach_object.
85 | :param object_handle: An int. The handle of object which is attached to the gripper.
86 | :return: None
87 | """
88 | vrep.simxSetObjectParent(self.client_id,
89 | object_handle,
90 | -1,
91 | True,
92 | vrep.simx_opmode_blocking)
93 |
94 | def get_object_detection_status(self, threshold=0.005):
95 | """
96 |         Detect whether the gripper has grasped an object successfully.
97 |         :param threshold: A float. When the distance between the two gripper tips is larger than the threshold,
98 |         the grasp is considered successful; otherwise it is not.
99 |         :return: A boolean. The object detection status.
100 | """
101 | time.sleep(0.7)
102 | _, d_pos = vrep.simxGetObjectPosition(self.client_id,
103 | self.left_touch,
104 | self.right_touch,
105 | vrep.simx_opmode_blocking)
106 | if threshold < abs(d_pos[0]) < 0.085:
107 | half_touch = 0.01475 - 0.005
108 | # distance from proximity sensor to left touch.
109 | _, d_p2t_l = vrep.simxGetObjectPosition(self.client_id,
110 | self.left_touch,
111 | self.proximity_sensor,
112 | vrep.simx_opmode_blocking)
113 | # distance from proximity sensor to right touch.
114 | _, d_p2t_r = vrep.simxGetObjectPosition(self.client_id,
115 | self.right_touch,
116 | self.proximity_sensor,
117 | vrep.simx_opmode_blocking)
118 | _, distance = self.read_proximity_sensor()
119 | if distance < d_p2t_l[-1] + half_touch and distance < d_p2t_r[-1] + half_touch:
120 | _, d_l2r = vrep.simxGetObjectPosition(self.client_id,
121 | self.left_touch,
122 | self.right_touch,
123 | vrep.simx_opmode_blocking)
124 | diff = abs(d_l2r[-1])
125 | print(diff)
126 | if diff < 0.001:
127 | return True
128 | return False
129 |
130 | def get_gripper_width(self):
131 | """
132 | Get gripper width.
133 |         :return: A float. The distance between the two fingertips, in meters.
134 | """
135 | _, d_pos = vrep.simxGetObjectPosition(self.client_id,
136 | self.left_touch,
137 | self.right_touch,
138 | vrep.simx_opmode_blocking)
139 | return abs(d_pos[0])
140 |
141 | def read_proximity_sensor(self):
142 | """
143 | Read proximity sensor.
144 |         :return: A tuple (int, float). The sensor reading; the float component is the detected distance.
145 | """
146 | _, out_ints, out_floats, _, _ = self.func(functionName='pyReadProxSensor',
147 | inputInts=[],
148 | inputFloats=[],
149 | inputStrings=[],
150 | inputBuffer='')
151 | return out_ints[0], out_floats[0]
152 |
153 | def wait_until_stop(self, handle, threshold=0.005, time_delay=0.2):
154 | """
155 | Wait until the operation finishes.
156 | This is a delay function called in order to make sure that
157 | the operation executed has been completed.
158 |         :param handle: An int. Handle of the object.
159 |         :param threshold: A float. The pose-change threshold.
160 |         If the change in the object's pose between two time steps is smaller than the threshold,
161 |         the motion is considered finished; otherwise the loop continues.
162 |         :param time_delay: A float. How long to wait between two consecutive pose readings, in seconds.
163 | :return: None
164 | """
165 | while True:
166 | _, pos1 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking)
167 | _, quat1 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking)
168 | time.sleep(time_delay)
169 | _, pos2 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking)
170 | _, quat2 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking)
171 | pose1 = pos1 + quat1
172 | pose2 = pos2 + quat2
173 | theta = 0.5 * sqrt(reduce(lambda x, y: x + y, map(lambda x, y: (x - y) ** 2, pose1, pose2)))
174 | if theta < threshold:
175 | return
176 |
--------------------------------------------------------------------------------
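
A small sketch of the stopping criterion implemented by RG2.wait_until_stop above: it compares half the Euclidean distance between two consecutive 7-D poses (position plus quaternion) against the threshold. The poses below are made-up numbers:

from math import sqrt

def pose_delta(pose1, pose2):
    # Half the Euclidean distance between two position+quaternion vectors.
    return 0.5 * sqrt(sum((a - b) ** 2 for a, b in zip(pose1, pose2)))

pose_t0 = [0.40, 0.20, 0.300, 0.0, 0.0, 0.0, 1.0]
pose_t1 = [0.40, 0.20, 0.305, 0.0, 0.0, 0.0, 1.0]
print(pose_delta(pose_t0, pose_t1) < 0.005)  # True -> the gripper is considered stationary
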
/data_processing/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | from data_processing.data_processor import DataProcessor
2 |
3 | import numpy as np
4 | import cv2
5 |
6 | import argparse
7 | import random
8 | import os
9 |
10 | parser = argparse.ArgumentParser(description='process data')
11 | parser.add_argument('--data_dir',
12 | required=True,
13 | type=str,
14 | help='path to the data')
15 | parser.add_argument('--output',
16 | required=True,
17 | type=str,
18 |                     help='path to save the processed data')
19 | args = parser.parse_args()
20 |
21 |
22 | def main():
23 | if not os.path.exists(args.output):
24 | os.mkdir(args.output)
25 | os.mkdir(os.path.join(args.output, 'color'))
26 | os.mkdir(os.path.join(args.output, 'depth'))
27 | os.mkdir(os.path.join(args.output, 'encoded_depth'))
28 | os.mkdir(os.path.join(args.output, 'label_map'))
29 | os.mkdir(os.path.join(args.output, 'camera_height'))
30 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w')
31 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w')
32 | dp = DataProcessor()
33 |
34 | # load data
35 | data_dir = os.listdir(args.data_dir)
36 | random.shuffle(data_dir)
37 | for data_id in data_dir:
38 | parent_dir = os.path.join(args.data_dir, data_id)
39 | print(parent_dir)
40 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'),
41 | # cv2.IMREAD_ANYDEPTH).astype(np.float32)
42 | label_files = os.listdir(os.path.join(parent_dir, 'label'))
43 | with open(os.path.join(parent_dir, 'file_name.txt'), 'r') as f:
44 | file_names = f.readlines()
45 | for file_name in file_names:
46 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1]
47 | print(file_name)
48 | color = cv2.imread(os.path.join(parent_dir, 'height_map_color', file_name+'.png'))
49 | depth = cv2.imread(os.path.join(parent_dir, 'height_map_depth', file_name+'.png'),
50 | cv2.IMREAD_ANYDEPTH).astype(np.float32)
51 | # diff_depth = dp.get_diff_depth(depth, b_depth_height_map)
52 | # pad color and depth images
53 | pad_size = 44
54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7
55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color
56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7
57 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = depth
58 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten()))
59 | camera_height = background_depth_value / 1000.0
60 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.object_points.txt')) + pad_size
61 | if file_name+'.good.txt' in label_files:
62 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.good.txt')) + pad_size
63 | grasp_centers = dp.get_grasp_center(good_pixel_labels)
64 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0])
65 | for angle_idx in angle_indices:
66 | quantified_angle = 22.5 * angle_idx
67 | for i, angle in enumerate(np.arange(quantified_angle-5, quantified_angle+5, 1)):
68 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv
69 | grasp_label[..., 0] = 255
70 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
71 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int)
72 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
73 | grasp_label = cv2.medianBlur(grasp_label, 3)
74 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
75 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int)
76 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255
77 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
78 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
79 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
80 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
81 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
82 | encoded_depth = dp.encode_depth(rotated_depth)
83 | cv2.imwrite(os.path.join(args.output, 'label_map',
84 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
85 | grasp_label)
86 | cv2.imwrite(os.path.join(args.output, 'color',
87 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
88 | rotated_color)
89 | cv2.imwrite(os.path.join(args.output, 'depth',
90 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
91 | rotated_depth)
92 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
93 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
94 | encoded_depth)
95 | np.savetxt(os.path.join(args.output, 'camera_height',
96 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)),
97 | np.array([camera_height * 1000.0]))
98 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i))
99 | else:
100 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.bad.txt')) + pad_size
101 | grasp_centers = dp.get_grasp_center(bad_pixel_labels)
102 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0])
103 | for angle_idx in angle_indices:
104 | quantified_angle = 22.5 * angle_idx
105 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)):
106 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv
107 | grasp_label[..., 0] = 255
108 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
109 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int)
110 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
111 | grasp_label = cv2.medianBlur(grasp_label, 3)
112 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
113 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int)
114 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255
115 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
116 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
117 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
118 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
119 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
120 | encoded_depth = dp.encode_depth(rotated_depth)
121 | cv2.imwrite(os.path.join(args.output, 'label_map',
122 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
123 | grasp_label)
124 | cv2.imwrite(os.path.join(args.output, 'color',
125 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
126 | rotated_color)
127 | cv2.imwrite(os.path.join(args.output, 'depth',
128 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
129 | rotated_depth)
130 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
131 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
132 | encoded_depth)
133 | np.savetxt(os.path.join(args.output, 'camera_height',
134 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)),
135 | np.array([camera_height * 1000.0]))
136 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i))
137 | f_p.close()
138 | f_n.close()
139 |
140 |
141 | if __name__ == '__main__':
142 | main()
143 |
144 |
--------------------------------------------------------------------------------
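
A short sketch of the rotation augmentation used above: each grasp is assigned to a 22.5-degree angle bin by DataProcessor.get_grasp_angle, and every bin index is expanded into ten rotated copies at one-degree steps around the bin center. The bin index below is a made-up example:

import numpy as np

angle_idx = 3                        # hypothetical bin index from get_grasp_angle
quantified_angle = 22.5 * angle_idx  # center of the bin, in degrees
angles = np.arange(quantified_angle - 5, quantified_angle + 5, 1)
print(angles)                        # [62.5 63.5 ... 71.5] -> one rotated sample per angle
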
/data_processing/mag_data_preprocessing.py:
--------------------------------------------------------------------------------
1 | from data_processing.data_processor import DataProcessor
2 |
3 | import numpy as np
4 | import cv2
5 |
6 | import argparse
7 | import random
8 | import os
9 |
10 | parser = argparse.ArgumentParser(description='process data')
11 | parser.add_argument('--data_dir',
12 | required=True,
13 | type=str,
14 | help='path to the data')
15 | parser.add_argument('--output',
16 | required=True,
17 | type=str,
18 |                     help='path to save the processed data')
19 | args = parser.parse_args()
20 |
21 |
22 | def main():
23 | if not os.path.exists(args.output):
24 | os.mkdir(args.output)
25 | os.mkdir(os.path.join(args.output, 'color'))
26 | os.mkdir(os.path.join(args.output, 'depth'))
27 | os.mkdir(os.path.join(args.output, 'encoded_depth'))
28 | os.mkdir(os.path.join(args.output, 'label_map'))
29 | os.mkdir(os.path.join(args.output, 'camera_height'))
30 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w')
31 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w')
32 | dp = DataProcessor()
33 |
34 | # load data
35 | data_dir = os.listdir(args.data_dir)
36 | random.shuffle(data_dir)
37 | for data_id in data_dir:
38 | parent_dir = os.path.join(args.data_dir, data_id)
39 | print(parent_dir)
40 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'),
41 | # cv2.IMREAD_ANYDEPTH).astype(np.float32)
42 | label_files = os.listdir(os.path.join(parent_dir, 'label'))
43 | with open(os.path.join(parent_dir, 'file_name.txt'), 'r') as f:
44 | file_names = f.readlines()
45 | for file_name in file_names:
46 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1]
47 | print(file_name)
48 | color = cv2.imread(os.path.join(parent_dir, 'height_map_color', file_name+'.png'))
49 | depth = cv2.imread(os.path.join(parent_dir, 'height_map_depth', file_name+'.png'),
50 | cv2.IMREAD_ANYDEPTH).astype(np.float32)
51 | # diff_depth = dp.get_diff_depth(depth, b_depth_height_map)
52 | # pad color and depth images
53 | pad_size = 44
54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7
55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color
56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7
57 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = depth
58 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten()))
59 | camera_height = background_depth_value / 1000.0
60 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.object_points.txt')) + pad_size
61 | if file_name+'.good.txt' in label_files:
62 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.good.txt')) + pad_size
63 | grasp_centers = dp.get_grasp_center(good_pixel_labels)
64 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0])
65 | for angle_idx in angle_indices:
66 | quantified_angle = 22.5 * angle_idx
67 | for i, angle in enumerate(np.arange(quantified_angle-5, quantified_angle+5, 1)):
68 | grasp_label = np.zeros((36, 36, 3), dtype=np.uint8) # bgr for opencv
69 | grasp_label[..., 0] = 255
70 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
71 | rotated_neglect_points = np.round(rotated_neglect_points / 8.0).astype(np.int)
72 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
73 | grasp_label = cv2.medianBlur(grasp_label, 3)
74 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
75 | rotated_grasp_centers = np.round(rotated_grasp_centers / 8.0).astype(np.int)
76 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255
77 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
78 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
79 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
80 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
81 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
82 | encoded_depth = dp.encode_depth(rotated_depth)
83 | cv2.imwrite(os.path.join(args.output, 'label_map',
84 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
85 | grasp_label)
86 | cv2.imwrite(os.path.join(args.output, 'color',
87 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
88 | rotated_color)
89 | cv2.imwrite(os.path.join(args.output, 'depth',
90 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
91 | rotated_depth)
92 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
93 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
94 | encoded_depth)
95 | np.savetxt(os.path.join(args.output, 'camera_height',
96 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)),
97 | np.array([camera_height * 1000.0]))
98 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i))
99 | else:
100 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.bad.txt')) + pad_size
101 | grasp_centers = dp.get_grasp_center(bad_pixel_labels)
102 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0])
103 | for angle_idx in angle_indices:
104 | quantified_angle = 22.5 * angle_idx
105 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)):
106 | grasp_label = np.zeros((36, 36, 3), dtype=np.uint8) # bgr for opencv
107 | grasp_label[..., 0] = 255
108 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
109 | rotated_neglect_points = np.round(rotated_neglect_points / 8.0).astype(np.int)
110 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
111 | grasp_label = cv2.medianBlur(grasp_label, 3)
112 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
113 | rotated_grasp_centers = np.round(rotated_grasp_centers / 8.0).astype(np.int)
114 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255
115 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
116 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
117 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
118 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
119 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
120 | encoded_depth = dp.encode_depth(rotated_depth)
121 | cv2.imwrite(os.path.join(args.output, 'label_map',
122 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
123 | grasp_label)
124 | cv2.imwrite(os.path.join(args.output, 'color',
125 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
126 | rotated_color)
127 | cv2.imwrite(os.path.join(args.output, 'depth',
128 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
129 | rotated_depth)
130 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
131 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
132 | encoded_depth)
133 | np.savetxt(os.path.join(args.output, 'camera_height',
134 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)),
135 | np.array([camera_height * 1000.0]))
136 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i))
137 | f_p.close()
138 | f_n.close()
139 |
140 |
141 | if __name__ == '__main__':
142 | main()
143 |
144 |
--------------------------------------------------------------------------------
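
The main difference from data_preprocessing.py above is that the mag label maps are 36x36 rather than 288x288, so the rotated pixel coordinates are divided by 8 (288 / 8 = 36) before indexing. A tiny sketch of that mapping with made-up coordinates:

import numpy as np

points_288 = np.array([[144.0, 144.0], [100.0, 220.0]])  # pixel coordinates in the padded image
points_36 = np.round(points_288 / 8.0).astype(int)        # indices into the 36x36 label map
print(points_36)                                          # [[18 18] [12 28]]
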
/data_collection/corrective_grasp_trials.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | from scipy import misc
3 | from data_collection.scene import Scene
4 | from functools import partial
5 | import numpy as np
6 | import cv2
7 | import argparse
8 | import os
9 |
10 | parser = argparse.ArgumentParser(description='Data collection in vrep simulator.')
11 | parser.add_argument('--ip',
12 | default='127.0.0.1',
13 | type=str,
14 | help='ip address to the vrep simulator.')
15 | parser.add_argument('--port',
16 | default=19997,
17 |                     type=int,  # the port is passed to simxStart, which expects an integer
18 | help='port to the vrep simulator.')
19 | parser.add_argument('--obj_id',
20 | default='obj0',
21 | type=str,
22 | help='object name in vrep.')
23 | parser.add_argument('--num_grasp',
24 | default=200,
25 | type=int,
26 |                     help='the number of grasp trials.')
27 | parser.add_argument('--num_repeat',
28 | default=1,
29 | type=int,
30 |                     help='the number of times to repeat a grasp if the gripper successfully grasps an object.')
31 | parser.add_argument('--output',
32 | default='data/3dnet',
33 | type=str,
34 | help='directory to save data.')
35 | args = parser.parse_args()
36 |
37 | step = pi / 180.0
38 | ratio = 0.4 / 200
39 | interval = [0.23, 0.57]
40 | rand = partial(np.random.uniform, interval[0], interval[1])
41 |
42 |
43 | def main():
44 | if not os.path.exists(args.output):
45 | os.makedirs(args.output)
46 | data_id = len(os.listdir(args.output))
47 | data_path = os.path.join(args.output, '{:04d}'.format(data_id))
48 | color_path = os.path.join(data_path, 'color')
49 | depth_path = os.path.join(data_path, 'depth')
50 | height_map_color_path = os.path.join(data_path, 'height_map_color')
51 | height_map_depth_path = os.path.join(data_path, 'height_map_depth')
52 | label_path = os.path.join(data_path, 'label')
53 | os.mkdir(data_path)
54 | os.mkdir(color_path)
55 | os.mkdir(depth_path)
56 | os.mkdir(height_map_color_path)
57 | os.mkdir(height_map_depth_path)
58 | os.mkdir(label_path)
59 |
60 | s = Scene(args.ip, args.port, 'realsense')
61 | init_pos = [45*step, 10*step, 90*step, -10*step, -90*step, -45*step]
62 | top_pos = [-45*step, 10*step, 90*step, -10*step, -90*step, -45*step]
63 | # object handle
64 | obj = s.get_object_handle(args.obj_id)
65 | obj_quat = s.get_object_quaternion(obj)
66 | s.gripper.open_gripper()
67 |
68 | misc.imsave(data_path + '/background_color.png', s.background_color)
69 | cv2.imwrite(data_path + '/background_depth.png', s.background_depth)
70 | misc.imsave(data_path+'/crop_background_color.png',
71 | cv2.resize(s.background_color[45:435, 119:509, :], (200, 200)))
72 | cv2.imwrite(data_path+'/crop_background_depth.png',
73 | cv2.resize(s.background_depth[45:435, 119:509], (200, 200)))
74 |
75 | f = open(os.path.join(data_path, 'file_name.txt'), 'w')
76 |
77 | s.set_object_position(obj, [rand(), rand(), s.get_object_position(obj)[-1]])
78 |
79 | for i in range(args.num_grasp): # number of grasp trials.
80 | print(i)
81 | s.gripper.open_gripper()
82 | s.ur5.move_to_joint_position(init_pos)
83 | s.replace_object(obj, obj_quat, interval)
84 | color = s.get_color_image()
85 | depth = s.get_depth_image()
86 | center_point, object_points, reaching_height = s.get_center_from_image(depth)
87 | s.ur5.move_to_joint_position(top_pos)
88 | episode = []
89 |         # move to a position above the object
90 | init_quat = s.ur5.get_end_effector_quaternion()
91 | target_pos = s.get_object_position(obj)
92 | target_pos[0] = 0.200 + center_point[0] * ratio
93 | target_pos[1] = 0.200 + center_point[1] * ratio
94 | target_pos[2] = reaching_height
95 | target_pos[-1] += 0.1563
96 | episode.append(target_pos + init_quat)
97 | s.ur5.move_to_object_position(target_pos + init_quat)
98 | # randomly rotate gripper's angle
99 | joint_angles = s.ur5.get_joint_positions()
100 | curr_angle = joint_angles[-1] / np.pi * 180.0
101 | angle = np.random.randint(0, 180)
102 | joint_angles[-1] = (curr_angle + angle) * step
103 | s.ur5.move_to_joint_position(joint_angles)
104 | quat = s.ur5.get_end_effector_quaternion()
105 | episode.append(target_pos + quat)
106 |         # move down to the object
107 | target_pos[-1] -= 0.1563
108 | episode.append(target_pos + quat)
109 | s.ur5.move_to_object_position(target_pos + quat)
110 | # try to grasp object
111 | s.gripper.close_gripper()
112 | # determine whether the object is successfully grasped.
113 | if s.gripper.get_object_detection_status():
114 | print('grasp success.')
115 | s.gripper.attach_object(obj)
116 | width = s.gripper.get_gripper_width()
117 | s.gripper.open_gripper()
118 | s.gripper.untie_object(obj)
119 | s.ur5.move_to_object_position(episode[1])
120 | s.ur5.move_to_joint_position(init_pos)
121 | # if we successfully grasp the object,
122 |             # we record the relative grasp configuration and keep collecting positive samples.
123 |             for j in range(args.num_repeat):  # the number of positive samples to collect in this grasp trial.
124 | if s.replace_object(obj, obj_quat, interval):
125 | break
126 | # rerecord the pattern.
127 | color = s.get_color_image()
128 | depth = s.get_depth_image()
129 | _, object_points, reaching_height = s.get_center_from_image(depth)
130 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200))
131 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200))
132 | points = s.get_label(center_point, -angle*np.pi/180.0, width, ratio)
133 | label = s.draw_label(points, 200, 200)
134 | s.ur5.move_to_joint_position(top_pos)
135 | s.ur5.move_to_object_position(episode[0])
136 | s.ur5.move_to_object_position(episode[1])
137 | s.ur5.move_to_object_position(episode[2])
138 | s.gripper.close_gripper()
139 | if s.gripper.get_object_detection_status():
140 | s.gripper.attach_object(obj)
141 | s.ur5.move_to_object_position(episode[1])
142 | # quat = s.ur5.get_end_effector_quaternion()
143 | pos = s.ur5.get_end_effector_position()
144 | pos[0], pos[1] = rand(), rand()
145 | center_point = [int((pos[0] - 0.200) / ratio), int((pos[1] - 0.200) / ratio)]
146 | episode[0] = pos + init_quat
147 | s.ur5.move_to_object_position(pos + init_quat)
148 | joint_angles = s.ur5.get_joint_positions()
149 | curr_angle = joint_angles[-1] / np.pi * 180.0
150 | angle = np.random.randint(0, 180)
151 | joint_angles[-1] = (curr_angle + angle) * step
152 | s.ur5.move_to_joint_position(joint_angles)
153 | quat = s.ur5.get_end_effector_quaternion()
154 | pos = s.ur5.get_end_effector_position()
155 | episode[1] = pos + quat
156 | pos[-1] = episode[2][2]
157 | episode[2] = pos + quat
158 | s.ur5.move_to_object_position(pos + quat)
159 | s.gripper.open_gripper()
160 | s.gripper.untie_object(obj)
161 | s.ur5.move_to_object_position(episode[1])
162 | s.ur5.move_to_joint_position(init_pos)
163 | misc.imsave(color_path + '/{:06d}_{:04d}.png'.format(i, j), color)
164 | cv2.imwrite(depth_path + '/{:06d}_{:04d}.png'.format(i, j), depth)
165 | misc.imsave(height_map_color_path + '/{:06d}_{:04d}.png'.format(i, j), crop_color)
166 | cv2.imwrite(height_map_depth_path + '/{:06d}_{:04d}.png'.format(i, j), crop_depth)
167 | misc.imsave(label_path + '/{:06d}_{:04d}.png'.format(i, j), label)
168 | np.savetxt(label_path + '/{:06d}_{:04d}.good.txt'.format(i, j), points)
169 | np.savetxt(label_path + '/{:06d}_{:04d}.object_points.txt'.format(i, j), np.asanyarray(object_points))
170 | f.write('{:06d}_{:04d}\n'.format(i, j))
171 | else:
172 | break
173 |
174 | else:
175 | print('grasp failed')
176 | s.gripper.open_gripper()
177 | s.ur5.move_to_object_position(episode[1])
178 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200))
179 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200))
180 | points = s.get_label(center_point, -angle * np.pi / 180.0, 0.085, ratio)
181 | label = s.draw_label(points, 200, 200, (255, 0, 0))
182 | misc.imsave(color_path + '/{:06d}.png'.format(i), color)
183 | cv2.imwrite(depth_path + '/{:06d}.png'.format(i), depth)
184 | misc.imsave(height_map_color_path + '/{:06d}.png'.format(i), crop_color)
185 | cv2.imwrite(height_map_depth_path + '/{:06d}.png'.format(i), crop_depth)
186 | misc.imsave(label_path + '/{:06d}.png'.format(i), label)
187 | np.savetxt(label_path + '/{:06d}.bad.txt'.format(i), points)
188 | np.savetxt(label_path + '/{:06d}.object_points.txt'.format(i), np.asanyarray(object_points))
189 | f.write('{:06d}\n'.format(i))
190 |
191 | s.stop_simulation()
192 | f.close()
193 |
194 |
195 | if __name__ == '__main__':
196 | main()
197 |
--------------------------------------------------------------------------------
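
A sketch of the pixel-to-workspace arithmetic used above. From the script's constants, ratio = 0.4 / 200 meters per height-map pixel and the map appears to be offset by 0.200 m in both x and y; the pixel coordinates below are made up and the code simply reproduces the target_pos computation:

ratio = 0.4 / 200         # meters per pixel of the 200x200 height map
center_point = (120, 85)  # made-up pixel coordinates of the grasp center
x = 0.200 + center_point[0] * ratio
y = 0.200 + center_point[1] * ratio
print(x, y)               # approximately (0.44, 0.37): the x-y target handed to the UR5
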
/data_processing/generate_pseudo_data_v1.py:
--------------------------------------------------------------------------------
1 | from data_processing.data_processor import DataProcessor
2 |
3 | import numpy as np
4 | import cv2
5 |
6 | import argparse
7 | import random
8 | import os
9 |
10 | parser = argparse.ArgumentParser(description='process labels')
11 | parser.add_argument('--data_dir',
12 | required=True,
13 | type=str,
14 | help='path to the data')
15 | parser.add_argument('--output',
16 | required=True,
17 | type=str,
18 |                     help='path to save the processed data')
19 | args = parser.parse_args()
20 |
21 |
22 | def main():
23 | if not os.path.exists(args.output):
24 | os.mkdir(args.output)
25 | os.mkdir(os.path.join(args.output, 'color'))
26 | os.mkdir(os.path.join(args.output, 'depth'))
27 | os.mkdir(os.path.join(args.output, 'encoded_depth'))
28 | os.mkdir(os.path.join(args.output, 'label_map'))
29 | os.mkdir(os.path.join(args.output, 'camera_height'))
30 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w')
31 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w')
32 | dp = DataProcessor()
33 |
34 | # load data
35 | data_dir = os.listdir(args.data_dir)
36 | random.shuffle(data_dir)
37 | for data_id in data_dir:
38 | parent_dir = os.path.join(args.data_dir, data_id)
39 | print(parent_dir)
40 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'),
41 | # cv2.IMREAD_ANYDEPTH).astype(np.float32)
42 | label_files = os.listdir(os.path.join(parent_dir, 'zoomed_label'))
43 | with open(os.path.join(parent_dir, 'zoomed_file_name.txt'), 'r') as f:
44 | file_names = f.readlines()
45 | for file_name in file_names:
46 | print(file_name)
47 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1]
48 | color = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_color', file_name+'.png'))
49 | depth = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_depth', file_name+'.png'),
50 | cv2.IMREAD_ANYDEPTH).astype(np.float32)
51 | # diff_depth = dp.get_diff_depth(depth, b_depth_height_map)
52 | # pad color and depth images
53 | pad_size = 44
54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7
55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color
56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7
57 | # pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = diff_depth
58 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = depth
59 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten()))
60 | camera_height = background_depth_value / 1000.0
61 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.object_points.txt')) + pad_size
62 | if len(neglect_points) == 0:
63 | continue
64 | if file_name+'.good.txt' in label_files:
65 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.good.txt')) + pad_size
66 | grasp_centers = dp.get_grasp_center(good_pixel_labels)
67 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0])
68 | for angle_idx in angle_indices:
69 | quantified_angle = 22.5 * angle_idx
70 | for i, angle in enumerate(np.arange(quantified_angle-5, quantified_angle+5, 1)):
71 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv
72 | grasp_label[..., 0] = 255
73 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
74 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int)
75 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
76 | grasp_label = cv2.medianBlur(grasp_label, 3)
77 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
78 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int)
79 | grasp_label = dp.gaussianize_label(grasp_label,
80 | rotated_grasp_centers,
81 | camera_height,
82 | is_good=True)
83 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255
84 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
85 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
86 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
87 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
88 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
89 | encoded_depth = dp.encode_depth(rotated_depth)
90 | cv2.imwrite(os.path.join(args.output, 'label_map',
91 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
92 | grasp_label)
93 | cv2.imwrite(os.path.join(args.output, 'color',
94 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
95 | rotated_color)
96 | cv2.imwrite(os.path.join(args.output, 'depth',
97 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
98 | rotated_depth)
99 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
100 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
101 | encoded_depth)
102 | np.savetxt(os.path.join(args.output, 'camera_height',
103 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)),
104 | np.array([camera_height * 1000.0]))
105 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i))
106 | else:
107 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name + '.bad.txt')) + pad_size
108 | grasp_centers = dp.get_grasp_center(bad_pixel_labels)
109 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0])
110 | for angle_idx in angle_indices:
111 | quantified_angle = 22.5 * angle_idx
112 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)):
113 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv
114 | grasp_label[..., 0] = 255
115 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
116 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int)
117 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
118 | grasp_label = cv2.medianBlur(grasp_label, 3)
119 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
120 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int)
121 | grasp_label = dp.gaussianize_label(grasp_label,
122 | rotated_grasp_centers,
123 | camera_height,
124 | is_good=False)
125 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255
126 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
127 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
128 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
129 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
130 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
131 | encoded_depth = dp.encode_depth(rotated_depth)
132 | cv2.imwrite(os.path.join(args.output, 'label_map',
133 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
134 | grasp_label)
135 | cv2.imwrite(os.path.join(args.output, 'color',
136 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
137 | rotated_color)
138 | cv2.imwrite(os.path.join(args.output, 'depth',
139 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
140 | rotated_depth)
141 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
142 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
143 | encoded_depth)
144 | np.savetxt(os.path.join(args.output, 'camera_height',
145 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)),
146 | np.array([camera_height * 1000.0]))
147 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i))
148 | f_p.close()
149 | f_n.close()
150 |
151 |
152 | if __name__ == '__main__':
153 | main()
154 |
155 |
--------------------------------------------------------------------------------
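
A small sketch of the background-depth estimate shared by the preprocessing scripts above: the most frequent integer depth value (the mode of the depth histogram) is taken as the table depth in millimeters and converted to a camera height in meters. The toy depth map is made up:

import numpy as np

depth = np.array([[570, 570, 569],
                  [570, 432, 570]], dtype=np.float32)  # toy height-map depth, in mm
background_depth_value = np.argmax(np.bincount(depth.astype(int).flatten()))
camera_height = background_depth_value / 1000.0
print(background_depth_value, camera_height)           # 570 0.57
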
/data_collection/scene.py:
--------------------------------------------------------------------------------
1 | from data_collection.ur import UR5
2 | from data_collection.rg2 import RG2
3 | import numpy as np
4 | import data_collection.vrep as vrep
5 | import time
6 | import cv2
7 |
8 |
9 | class Scene(object):
10 | def __init__(self, ip='127.0.0.1',
11 | port=19997,
12 | camera_handle='realsense'):
13 | """
14 |         Initialization. Connect to the remote server and initialize the UR5 and the gripper.
15 | :param ip: The ip address where the server is located.
16 | :param port: The port number where to connected.
17 |         :param camera_handle: A string. The name of the camera object in V-REP.
18 | """
19 | vrep.simxFinish(-1)
20 | self.client_id = vrep.simxStart(ip, port, True, True, 5000, 5)
21 | vrep.simxStartSimulation(self.client_id, vrep.simx_opmode_blocking)
22 | self.ur5 = UR5(self.client_id)
23 | self.gripper = RG2(self.client_id)
24 | self.ur5.initialization()
25 | _, self.camera_handle = vrep.simxGetObjectHandle(self.client_id, camera_handle, vrep.simx_opmode_blocking)
26 | self.camera_height = self.get_object_position(self.camera_handle)[-1]
27 |
28 | self.background_color = self.get_color_image()
29 | self.background_depth = self.get_depth_image()
30 |
31 | def get_object_handle(self, handle_name):
32 | """
33 | Get object handle.
34 | :param handle_name: A string. Handle name with respect to specific object handle in vrep.
35 | :return: An int. The handle of the object.
36 | """
37 | _, obj = vrep.simxGetObjectHandle(self.client_id, handle_name, vrep.simx_opmode_blocking)
38 | return obj
39 |
40 | def set_object_position(self, object_handle, position, relative_to_object_handle=-1):
41 | """
42 | Set object position.
43 | :param object_handle: Handle of the object.
44 | :param position: The position value.
45 | :param relative_to_object_handle: Indicates relative to which reference frame the position is specified.
46 | :return: None.
47 | """
48 | vrep.simxSetObjectPosition(self.client_id,
49 | object_handle,
50 | relative_to_object_handle,
51 | position,
52 | vrep.simx_opmode_oneshot)
53 |
54 | def get_object_position(self, object_handle, relative_to_object_handle=-1):
55 | """
56 | Get object position.
57 | :param object_handle: An int. Handle of the object.
58 | :param relative_to_object_handle: An Int. Indicates relative to which reference frame the position is specified.
59 | :return: A list of float. The position of object.
60 | """
61 | _, _, pos, _, _ = self.ur5.func(functionName='pyGetObjectQPosition',
62 | inputInts=[object_handle, relative_to_object_handle],
63 | inputFloats=[],
64 | inputStrings=[],
65 | inputBuffer='')
66 | return pos
67 |
68 | def set_object_quaternion(self, object_handle, quaternion, relative_to_object_handle=-1):
69 | """
70 | Set object quaternion.
71 | :param object_handle: An int. Handle of the object.
72 | :param quaternion: A list of float. The quaternion value (x, y, z, w)
73 | :param relative_to_object_handle: An Int. Indicates relative to which reference frame the object is specified.
74 | :return: None.
75 | """
76 | vrep.simxSetObjectQuaternion(self.client_id,
77 | object_handle,
78 | relative_to_object_handle,
79 | quaternion,
80 | vrep.simx_opmode_oneshot)
81 |
82 | def get_object_quaternion(self, object_handle, relative_to_object_handle=-1):
83 | """
84 | Get object quaternion.
85 | :param object_handle: An int. Handle of the object.
86 | :param relative_to_object_handle: An int. Indicates relative to which reference frame the quaternion is specified.
87 | :return: A list of float. The quaternion of object.
88 | """
89 | _, _, quat, _, _ = self.ur5.func(functionName='pyGetObjectQuaternion',
90 | inputInts=[object_handle, relative_to_object_handle],
91 | inputFloats=[],
92 | inputStrings=[],
93 | inputBuffer='')
94 | return quat
95 |
96 | def replace_object(self, object_handle, object_quat, interval):
97 | """
98 | Re-place the object if it is out of the workspace.
99 | :param object_handle: An int. Handle of the object.
100 | :param object_quat: A list of float. The quaternion to which the object should be set.
101 | :param interval: A list of float. The interval of positions where the object may be placed.
102 | :return: A boolean indicating whether the object was re-placed.
103 | """
104 | pos = self.get_object_position(object_handle)
105 | if not interval[0] <= pos[0] <= interval[1] or not interval[0] <= pos[1] <= interval[1]:
106 | pos[0] = np.random.uniform(interval[0], interval[1])
107 | pos[1] = np.random.uniform(interval[0], interval[1])
108 | self.set_object_position(object_handle, pos)
109 | self.set_object_quaternion(object_handle, object_quat)
110 | time.sleep(0.7)
111 | return True
112 | return False
113 |
114 | def get_color_image(self):
115 | """
116 | Get color image.
117 | :return: A numpy array of uint8. The RGB image containing the whole workspace.
118 | """
119 | res, resolution, color = vrep.simxGetVisionSensorImage(self.client_id,
120 | self.camera_handle,
121 | 0,
122 | vrep.simx_opmode_blocking)
123 | return np.asanyarray(color, dtype=np.uint8).reshape(resolution[1], resolution[0], 3)[::-1, ...]
124 |
125 | def get_depth_image(self):
126 | """
127 | Get depth image.
128 | :return: A numpy array of uint16. The depth image containing the whole workspace.
129 | """
130 | res, resolution, depth = vrep.simxGetVisionSensorDepthBuffer(self.client_id,
131 | self.camera_handle,
132 | vrep.simx_opmode_blocking)
133 | depth = 1000 * np.asanyarray(depth)
134 | return depth.astype(np.uint16).reshape(resolution[1], resolution[0])[::-1, ...]
135 |
136 | def get_center_from_image(self, depth):
137 | """
138 | Get grasp position which belongs to the object in image space.
139 | :param depth: A numpy array of uint16. The depth image.
140 | :return: center_point: A numpy array of float32. The position of the object which the gripper will reach.
141 | :return: points: A numpy array of float32. The pixel positions that belong to the object in image space.
142 | :return: A float. The reaching height for the gripper, computed from the depth difference at the chosen center point.
143 | """
144 | crop_depth = depth[45:435, 119:509]
145 | crop_b_depth = self.background_depth[45:435, 119:509]
146 | resized_depth = cv2.resize(crop_depth, (200, 200)).astype(np.float32)
147 | resized_b_depth = cv2.resize(crop_b_depth, (200, 200)).astype(np.float32)
148 | diff_depth = np.abs(resized_depth - resized_b_depth)
149 | diff_depth = cv2.medianBlur(diff_depth, 5)  # medianBlur returns a new array; keep the result
150 | points = np.stack(np.where(diff_depth > 0), axis=0).T
151 | center_point = points[np.random.randint(0, points.shape[0])].tolist()
152 | reaching_height = max(0.02, diff_depth[center_point[0], center_point[1]]*0.0001 - 0.027)
153 | return center_point, points, reaching_height
154 |
155 | def stop_simulation(self):
156 | """
157 | Stop vrep simulation.
158 | :return: None.
159 | """
160 | vrep.simxStopSimulation(self.client_id, vrep.simx_opmode_oneshot_wait)
161 | vrep.simxFinish(self.client_id)
162 |
163 | @staticmethod
164 | def get_label(center_point, angle, width, ratio):
165 | """
166 | Obtain grasp labels. Each grasp label is represented as two points.
167 | :param center_point: A numpy array of float32. The position of the object which the gripper will reach.
168 | :param angle: A float. The rotation angle of the gripper, in radians.
169 | :param width: A float. The width between the two finger tips.
170 | :param ratio: A float. The ratio between the workspace width and the image width.
171 | :return: A numpy array of float. The grasp points.
172 | """
173 | p_width = width / ratio
174 | c_h = center_point[0]
175 | c_w = center_point[1]
176 | h = [delta_h for delta_h in range(-3, 4)]
177 | w = [-p_width / 2, p_width / 2]
178 | points = np.asanyarray([[hh, w[0], hh, w[1]] for hh in h])
179 | rotate_matrix = np.array([[np.cos(angle), -np.sin(angle)],
180 | [np.sin(angle), np.cos(angle)]])
181 | points[:, 0:2] = np.dot(rotate_matrix, points[:, 0:2].T).T
182 | points[:, 2:] = np.dot(rotate_matrix, points[:, 2:].T).T
183 | points = points + np.asanyarray([[c_h, c_w, c_h, c_w]])
184 | points = np.floor(points).astype(np.int)
185 | return points
186 |
187 | @staticmethod
188 | def draw_label(points, width, height, color=(0, 255, 0)):
189 | """
190 | Draw labels according to the grasp points.
191 | :param points: A numpy array of float. The grasp points.
192 | :param width: An int. The width of the image.
193 | :param height: An int. The height of the image.
194 | :param color: A tuple. The color of lines.
195 | :return: A numpy array of uint8. The label map.
196 | """
197 | label = np.ones((height, width, 3), dtype=np.uint8) * 255
198 | for point in points:
199 | pt1 = (point[1], point[0])
200 | pt2 = (point[3], point[2])
201 | cv2.line(label, pt1, pt2, color)
202 | return label
203 |
--------------------------------------------------------------------------------
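
A minimal usage sketch for the Scene class above, assuming a V-REP remote API server listening on 127.0.0.1:19997 with the UR5/RG2 grasping scene loaded and a vision sensor named 'realsense' (the defaults in the constructor). The gripper width (0.085 m) and workspace-to-image ratio (0.0025) passed to get_label are illustrative values only, not taken from the repository.

import numpy as np
from data_collection.scene import Scene

scene = Scene(ip='127.0.0.1', port=19997, camera_handle='realsense')
color = scene.get_color_image()                        # HxWx3 uint8 image
depth = scene.get_depth_image()                        # HxW uint16 depth in millimetres
center, object_points, reach_height = scene.get_center_from_image(depth)
angle = np.random.uniform(0.0, np.pi)                  # gripper rotation in radians
grasp_points = Scene.get_label(center, angle, width=0.085, ratio=0.0025)
label_map = Scene.draw_label(grasp_points, width=200, height=200)
scene.stop_simulation()
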
/data_processing/genarate_pseudo_data_v2.py:
--------------------------------------------------------------------------------
1 | from data_processing.data_processor import DataProcessor
2 |
3 | import numpy as np
4 | import cv2
5 |
6 | import argparse
7 | import random
8 | import copy
9 | import os
10 |
11 | parser = argparse.ArgumentParser(description='process labels')
12 | parser.add_argument('--data_dir',
13 | required=True,
14 | type=str,
15 | help='path to the data')
16 | parser.add_argument('--output',
17 | required=True,
18 | type=str,
19 | help='path to save the process data')
20 | args = parser.parse_args()
21 |
22 |
23 | def main():
24 | if not os.path.exists(args.output):
25 | os.mkdir(args.output)
26 | os.mkdir(os.path.join(args.output, 'color'))
27 | os.mkdir(os.path.join(args.output, 'depth'))
28 | os.mkdir(os.path.join(args.output, 'encoded_depth'))
29 | os.mkdir(os.path.join(args.output, 'label_map'))
30 | os.mkdir(os.path.join(args.output, 'camera_height'))
31 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w')
32 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w')
33 | dp = DataProcessor()
34 |
35 | # load data
36 | data_dir = os.listdir(args.data_dir)
37 | random.shuffle(data_dir)
38 | for data_id in data_dir:
39 | parent_dir = os.path.join(args.data_dir, data_id)
40 | print(parent_dir)
41 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'),
42 | # cv2.IMREAD_ANYDEPTH).astype(np.float32)
43 | label_files = os.listdir(os.path.join(parent_dir, 'zoomed_label'))
44 | with open(os.path.join(parent_dir, 'zoomed_file_name.txt'), 'r') as f:
45 | file_names = f.readlines()
46 | for file_name in file_names:
47 | print(file_name)
48 | file_name = file_name.rstrip('\r\n')  # strip the line terminator regardless of platform
49 | color = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_color', file_name+'.png'))
50 | depth = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_depth', file_name+'.png'),
51 | cv2.IMREAD_ANYDEPTH).astype(np.float32)
52 | # pad color and depth images
53 | pad_size = 44
54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7
55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color
56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7
57 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = depth
58 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten()))
59 | camera_height = background_depth_value / 1000.0
60 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.object_points.txt')) + pad_size
61 | if len(neglect_points) == 0:
62 | continue
63 | if file_name+'.good.txt' in label_files:
64 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.good.txt')) + pad_size
65 | grasp_centers = dp.get_grasp_center(good_pixel_labels)
66 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0])
67 | for angle_idx in angle_indices:
68 | quantified_angle = 22.5 * angle_idx
69 | for i, angle in enumerate(np.arange(quantified_angle-3, quantified_angle+3, 1)):
70 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv
71 | grasp_label[..., 0] = 255
72 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
73 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int)
74 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
75 | grasp_label = cv2.medianBlur(grasp_label, 3)
76 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
77 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int)
78 | for j in range(3):
79 | camera_height = np.random.uniform(0.20, 0.63)
80 | new_depth = int(camera_height * 1000.0) - (background_depth_value - depth)
81 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = new_depth
82 | gaussianized_label = dp.gaussianize_label(copy.deepcopy(grasp_label),
83 | rotated_grasp_centers,
84 | camera_height,
85 | is_good=True)
86 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255
87 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
88 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
89 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
90 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
91 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
92 | encoded_depth = dp.encode_depth(rotated_depth)
93 | cv2.imwrite(os.path.join(args.output, 'label_map',
94 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)),
95 | gaussianized_label)
96 | cv2.imwrite(os.path.join(args.output, 'color',
97 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)),
98 | rotated_color)
99 | cv2.imwrite(os.path.join(args.output, 'depth',
100 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)),
101 | rotated_depth)
102 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
103 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)),
104 | encoded_depth)
105 | np.savetxt(os.path.join(args.output, 'camera_height',
106 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.txt'.format(angle_idx, i, j)),
107 | np.array([camera_height*1000.0]))
108 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}-{:02d}\n'.format(angle_idx, i, j))
109 | else:
110 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name + '.bad.txt')) + pad_size
111 | grasp_centers = dp.get_grasp_center(bad_pixel_labels)
112 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0])
113 | for angle_idx in angle_indices:
114 | quantified_angle = 22.5 * angle_idx
115 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)):
116 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv
117 | grasp_label[..., 0] = 255
118 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2)
119 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int)
120 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0
121 | grasp_label = cv2.medianBlur(grasp_label, 3)
122 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2)
123 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int)
124 | grasp_label = dp.gaussianize_label(grasp_label,
125 | rotated_grasp_centers,
126 | camera_height,
127 | is_good=False)
128 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255
129 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1)
130 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288),
131 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
132 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288),
133 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7))
134 | encoded_depth = dp.encode_depth(rotated_depth)
135 | cv2.imwrite(os.path.join(args.output, 'label_map',
136 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
137 | grasp_label)
138 | cv2.imwrite(os.path.join(args.output, 'color',
139 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
140 | rotated_color)
141 | cv2.imwrite(os.path.join(args.output, 'depth',
142 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
143 | rotated_depth)
144 | cv2.imwrite(os.path.join(args.output, 'encoded_depth',
145 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)),
146 | encoded_depth)
147 | np.savetxt(os.path.join(args.output, 'camera_height',
148 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)),
149 | np.array([camera_height * 1000.0]))
150 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i))
151 | f_p.close()
152 | f_n.close()
153 |
154 |
155 | if __name__ == '__main__':
156 | main()
157 |
158 |
--------------------------------------------------------------------------------
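
Both generation scripts estimate the camera height from the most frequent value in the depth height map (lines 58-59 above): the mode of the depth histogram is taken as the background depth in millimetres and divided by 1000 to get the height in metres. A standalone numpy sketch of that estimate, with synthetic values chosen only for illustration:

import numpy as np

depth = np.full((200, 200), 430, dtype=np.uint16)   # flat background at 430 mm
depth[80:120, 80:120] = 395                          # an object standing 35 mm proud
background_depth_value = np.argmax(np.bincount(depth.astype(int).flatten()))
camera_height = background_depth_value / 1000.0      # metres
print(background_depth_value, camera_height)         # 430 0.43
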
/nets/resnet_utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import tensorflow as tf
3 |
4 | slim = tf.contrib.slim
5 |
6 |
7 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
8 | """A named tuple describing a ResNet block.
9 |
10 | Its parts are:
11 | scope: The scope of the `Block`.
12 | unit_fn: The ResNet unit function which takes as input a `Tensor` and
13 | returns another `Tensor` with the output of the ResNet unit.
14 | args: A list of length equal to the number of units in the `Block`. The list
15 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the
16 | block to serve as argument to unit_fn.
17 | """
18 |
19 |
20 | def subsample(inputs, factor, scope=None):
21 | """Subsamples the input along the spatial dimensions.
22 |
23 | Args:
24 | inputs: A `Tensor` of size [batch, height_in, width_in, channels].
25 | factor: The subsampling factor.
26 | scope: Optional variable_scope.
27 |
28 | Returns:
29 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the
30 | input, either intact (if factor == 1) or subsampled (if factor > 1).
31 | """
32 | if factor == 1:
33 | return inputs
34 | else:
35 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
36 |
37 |
38 | def upsample(inputs, factor, scope=None):
39 | """Upsamples the input along the spatial dimensions.
40 |
41 | Args:
42 | inputs: A `Tensor` of size [batch, height_in, width_in, channels].
43 | factor: The upsampling factor.
44 | scope: Optional variable_scope.
45 |
46 | Returns:
47 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the
48 | input, either intact (if factor == 1) or upsampled (if factor > 1).
49 | """
50 | if factor == 1:
51 | return inputs
52 | else:
53 | height, width = inputs.get_shape().as_list()[1:3]
54 | return tf.image.resize_bilinear(inputs, [height * factor, width * factor], name=scope)
55 |
56 |
57 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
58 | """Strided 2-D convolution with 'SAME' padding.
59 |
60 | When stride > 1, then we do explicit zero-padding, followed by conv2d with
61 | 'VALID' padding.
62 |
63 | Note that
64 |
65 | net = conv2d_same(inputs, num_outputs, 3, stride=stride)
66 |
67 | is equivalent to
68 |
69 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
70 | net = subsample(net, factor=stride)
71 |
72 | whereas
73 |
74 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
75 |
76 | is different when the input's height or width is even, which is why we add the
77 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
78 |
79 | Args:
80 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
81 | num_outputs: An integer, the number of output filters.
82 | kernel_size: An int with the kernel_size of the filters.
83 | stride: An integer, the output stride.
84 | rate: An integer, rate for atrous convolution.
85 | scope: Scope.
86 |
87 | Returns:
88 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with
89 | the convolution output.
90 | """
91 | if stride == 1:
92 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
93 | padding='SAME', scope=scope)
94 | else:
95 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
96 | pad_total = kernel_size_effective - 1
97 | pad_beg = pad_total // 2
98 | pad_end = pad_total - pad_beg
99 | inputs = tf.pad(inputs,
100 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
101 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
102 | rate=rate, padding='VALID', scope=scope)
103 |
104 |
105 | @slim.add_arg_scope
106 | def stack_blocks_dense(net, blocks, output_stride=None,
107 | outputs_collections=None):
108 | """Stacks ResNet `Blocks` and controls output feature density.
109 |
110 | First, this function creates scopes for the ResNet in the form of
111 | 'block_name/unit_1', 'block_name/unit_2', etc.
112 |
113 | Second, this function allows the user to explicitly control the ResNet
114 | output_stride, which is the ratio of the input to output spatial resolution.
115 | This is useful for dense prediction tasks such as semantic segmentation or
116 | object detection.
117 |
118 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a
119 | factor of 2 when transitioning between consecutive ResNet blocks. This results
120 | in a nominal ResNet output_stride equal to 8. If we set the output_stride to
121 | half the nominal network stride (e.g., output_stride=4), then we compute
122 | responses twice.
123 |
124 | Control of the output feature density is implemented by atrous convolution.
125 |
126 | Args:
127 | net: A `Tensor` of size [batch, height, width, channels].
128 | blocks: A list of length equal to the number of ResNet `Blocks`. Each
129 | element is a ResNet `Block` object describing the units in the `Block`.
130 | output_stride: If `None`, then the output will be computed at the nominal
131 | network stride. If output_stride is not `None`, it specifies the requested
132 | ratio of input to output spatial resolution, which needs to be equal to
133 | the product of unit strides from the start up to some level of the ResNet.
134 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
135 | then valid values for the output_stride are 1, 2, 6, 24 or None (which
136 | is equivalent to output_stride=24).
137 | outputs_collections: Collection to add the ResNet block output.
138 |
139 | Returns:
140 | net: Output tensor with stride equal to the specified output_stride.
141 |
142 | Raises:
143 | ValueError: If the target output_stride is not valid.
144 | """
145 | # The current_stride variable keeps track of the effective stride of the
146 | # activations. This allows us to invoke atrous convolution whenever applying
147 | # the next residual unit would result in the activations having stride larger
148 | # than the target output_stride.
149 | current_stride = 1
150 |
151 | # The atrous convolution rate parameter.
152 | rate = 1
153 |
154 | for block in blocks:
155 | with tf.variable_scope(block.scope, 'block', [net]) as sc:
156 | for i, unit in enumerate(block.args):
157 | if output_stride is not None and current_stride > output_stride:
158 | raise ValueError('The target output_stride cannot be reached.')
159 |
160 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
161 | # If we have reached the target output_stride, then we need to employ
162 | # atrous convolution with stride=1 and multiply the atrous rate by the
163 | # current unit's stride for use in subsequent layers.
164 | if output_stride is not None and current_stride == output_stride:
165 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
166 | rate *= unit.get('stride', 1)
167 |
168 | else:
169 | net = block.unit_fn(net, rate=1, **unit)
170 | current_stride *= unit.get('stride', 1)
171 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
172 |
173 | if output_stride is not None and current_stride != output_stride:
174 | raise ValueError('The target output_stride cannot be reached.')
175 |
176 | return net
177 |
178 |
179 | def resnet_arg_scope(weight_decay=0.0001,
180 | batch_norm_decay=0.997,
181 | batch_norm_epsilon=1e-5,
182 | batch_norm_scale=True,
183 | activation_fn=tf.nn.relu,
184 | use_batch_norm=True):
185 | """Defines the default ResNet arg scope.
186 |
187 | TODO(gpapan): The batch-normalization related default values above are
188 | appropriate for use in conjunction with the reference ResNet models
189 | released at https://github.com/KaimingHe/deep-residual-networks. When
190 | training ResNets from scratch, they might need to be tuned.
191 |
192 | Args:
193 | weight_decay: The weight decay to use for regularizing the model.
194 | batch_norm_decay: The moving average decay when estimating layer activation
195 | statistics in batch normalization.
196 | batch_norm_epsilon: Small constant to prevent division by zero when
197 | normalizing activations by their variance in batch normalization.
198 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
199 | activations in the batch normalization layer.
200 | activation_fn: The activation function which is used in ResNet.
201 | use_batch_norm: Whether or not to use batch normalization.
202 |
203 | Returns:
204 | An `arg_scope` to use for the resnet models.
205 | """
206 | batch_norm_params = {
207 | 'decay': batch_norm_decay,
208 | 'epsilon': batch_norm_epsilon,
209 | 'scale': batch_norm_scale,
210 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
211 | 'fused': None, # Use fused batch norm if possible.
212 | }
213 |
214 | with slim.arg_scope(
215 | [slim.conv2d],
216 | weights_regularizer=slim.l2_regularizer(weight_decay),
217 | weights_initializer=slim.variance_scaling_initializer(),
218 | activation_fn=activation_fn,
219 | normalizer_fn=slim.batch_norm if use_batch_norm else None,
220 | normalizer_params=batch_norm_params):
221 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
222 | # The following implies padding='SAME' for pool1, which makes feature
223 | # alignment easier for dense prediction tasks. This is also used in
224 | # https://github.com/facebook/fb.resnet.torch. However the accompanying
225 | # code of 'Deep Residual Learning for Image Recognition' uses
226 | # padding='VALID' for pool1. You can switch to that choice by setting
227 | # slim.arg_scope([slim.max_pool2d], padding='VALID').
228 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
229 | return arg_sc
230 |
--------------------------------------------------------------------------------
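
A small shape check for the subsample/upsample helpers above (TF 1.x graph mode assumed): subsample applies a 1x1 max pool with stride `factor`, while upsample resizes bilinearly by the same factor. The input size below is arbitrary.

import tensorflow as tf
from nets import resnet_utils

x = tf.placeholder(tf.float32, [1, 72, 72, 64])
down = resnet_utils.subsample(x, factor=2)   # 1x1 max pool, stride 2 -> [1, 36, 36, 64]
up = resnet_utils.upsample(x, factor=2)      # bilinear resize        -> [1, 144, 144, 64]
print(down.get_shape().as_list(), up.get_shape().as_list())
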
/network_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.slim as slim
3 |
4 | import os
5 |
6 |
7 | def decode_depth(encoded_depth):
8 | encoded_depth = tf.cast(encoded_depth, tf.float32)
9 | r, g, b = tf.unstack(encoded_depth, axis=2)
10 | depth = r * 65536.0 + g * 256.0 + b # decode depth image
11 | # depth = tf.div(depth, tf.constant(100.0))
12 | return depth
13 |
14 |
15 | def encode_depth(depth):
16 | # depth = tf.multiply(depth, tf.constant(10000.0))
17 | depth = tf.cast(depth, tf.int32)  # integer dtype with floor-div/mod support so all channels match
18 | r = depth // (256 * 256)
19 | g = (depth // 256) % 256
20 | b = depth % 256
21 | encoded_depth = tf.stack([r, g, b], axis=2)
22 | encoded_depth = tf.cast(encoded_depth, tf.uint8)
23 | return encoded_depth
24 |
25 |
26 | def get_dataset(dataset_dir,
27 | num_readers,
28 | num_preprocessing_threads,
29 | image_size,
30 | label_size,
31 | batch_size=1,
32 | reader=None,
33 | shuffle=True,
34 | num_epochs=None,
35 | is_training=True,
36 | is_depth=True):
37 | dataset_dir_list = [os.path.join(dataset_dir, filename)
38 | for filename in os.listdir(dataset_dir) if filename.endswith('.tfrecord')]
39 | if reader is None:
40 | reader = tf.TFRecordReader
41 | keys_to_features = {
42 | 'image/color': tf.FixedLenFeature((), tf.string, default_value=''),
43 | 'image/encoded_depth': tf.FixedLenFeature((), tf.string, default_value=''),
44 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
45 | 'image/label': tf.FixedLenFeature((), tf.string, default_value=''),
46 | 'image/camera_height': tf.FixedLenFeature((), tf.float32, default_value=0.0),
47 | }
48 | items_to_handlers = {
49 | 'color': slim.tfexample_decoder.Image(image_key='image/color',
50 | shape=(image_size, image_size, 3),
51 | channels=3),
52 | 'encoded_depth': slim.tfexample_decoder.Image(image_key='image/encoded_depth',
53 | shape=(image_size, image_size, 3),
54 | channels=3),
55 | 'label': slim.tfexample_decoder.Image(image_key='image/label',
56 | shape=(label_size, label_size, 3),
57 | channels=3),
58 | 'camera_height': slim.tfexample_decoder.Tensor(tensor_key='image/camera_height',
59 | shape=(1,)),
60 |
61 | }
62 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
63 | dataset = slim.dataset.Dataset(data_sources=dataset_dir_list,
64 | reader=reader,
65 | decoder=decoder,
66 | num_samples=3,
67 | items_to_descriptions=None)
68 | provider = slim.dataset_data_provider.DatasetDataProvider(dataset,
69 | num_readers=num_readers,
70 | shuffle=shuffle,
71 | num_epochs=num_epochs,
72 | common_queue_capacity=20 * batch_size,
73 | common_queue_min=10 * batch_size)
74 | color, encoded_depth, label, camera_height = provider.get(['color', 'encoded_depth', 'label', 'camera_height'])
75 | color = tf.cast(color, tf.float32)
76 | color = color * tf.random_normal(color.get_shape(), mean=1, stddev=0.01)
77 | color = color / tf.constant([255.0])
78 | color = (color - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225])
79 | depth = decode_depth(encoded_depth)
80 | camera_height = tf.expand_dims(tf.expand_dims(camera_height, axis=0), axis=0)
81 | camera_height = camera_height / 1000.0
82 | depth = depth * tf.random_normal(depth.get_shape(), mean=1, stddev=0.01)
83 | depth = depth / 1000.0 # (depth - tf.reduce_mean(depth)) / 1000.0
84 | if is_depth:
85 | input = tf.concat([color, tf.expand_dims(depth, axis=2)], axis=2)
86 | else:
87 | input = color
88 | label = tf.cast(label, tf.float32)
89 | label = label / 255.0
90 | if is_training:
91 | inputs, labels, camera_heights = tf.train.batch([input, label, camera_height],
92 | batch_size=batch_size,
93 | num_threads=num_preprocessing_threads,
94 | capacity=5*batch_size)
95 | else:
96 | inputs = tf.expand_dims(input, axis=0)
97 | labels = tf.expand_dims(label, axis=0)
98 | camera_heights = tf.expand_dims(camera_height, axis=0)
99 | return inputs, labels, camera_heights
100 |
101 |
102 | def get_color_dataset(dataset_dir,
103 | num_readers,
104 | num_preprocessing_threads,
105 | image_size,
106 | label_size,
107 | batch_size=1,
108 | reader=None,
109 | shuffle=True,
110 | num_epochs=None,
111 | is_training=True):
112 | dataset_dir_list = [os.path.join(dataset_dir, filename)
113 | for filename in os.listdir(dataset_dir) if filename.endswith('.tfrecord')]
114 | if reader is None:
115 | reader = tf.TFRecordReader
116 | keys_to_features = {
117 | 'image/color': tf.FixedLenFeature((), tf.string, default_value=''),
118 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
119 | 'image/label': tf.FixedLenFeature((), tf.string, default_value=''),
120 | }
121 | items_to_handlers = {
122 | 'color': slim.tfexample_decoder.Image(image_key='image/color',
123 | shape=(image_size, image_size, 3),
124 | channels=3),
125 | 'label': slim.tfexample_decoder.Image(image_key='image/label',
126 | shape=(label_size, label_size, 3),
127 | channels=3),
128 | }
129 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
130 | dataset = slim.dataset.Dataset(data_sources=dataset_dir_list,
131 | reader=reader,
132 | decoder=decoder,
133 | num_samples=3,
134 | items_to_descriptions=None)
135 | provider = slim.dataset_data_provider.DatasetDataProvider(dataset,
136 | num_readers=num_readers,
137 | shuffle=shuffle,
138 | num_epochs=num_epochs,
139 | common_queue_capacity=20 * batch_size,
140 | common_queue_min=10 * batch_size)
141 | color, label = provider.get(['color', 'label'])
142 | color = tf.cast(color, tf.float32)
143 | color = color * tf.random_normal(color.get_shape(), mean=1, stddev=0.01)
144 | color = color / tf.constant([255.0])
145 | color = (color - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225])
146 | label = tf.cast(label, tf.float32) / 255.0
147 | if is_training:
148 | colors, labels = tf.train.batch([color, label],
149 | batch_size=batch_size,
150 | num_threads=num_preprocessing_threads,
151 | capacity=5*batch_size)
152 | else:
153 | colors = tf.expand_dims(color, axis=0)
154 | labels = tf.expand_dims(label, axis=0)
155 | return colors, labels
156 |
157 |
158 | def get_depth_dataset(dataset_dir,
159 | num_readers,
160 | num_preprocessing_threads,
161 | image_size,
162 | label_size,
163 | batch_size=1,
164 | reader=None,
165 | shuffle=True,
166 | num_epochs=None,
167 | is_training=True):
168 | dataset_dir_list = [os.path.join(dataset_dir, filename)
169 | for filename in os.listdir(dataset_dir) if filename.endswith('.tfrecord')]
170 | if reader is None:
171 | reader = tf.TFRecordReader
172 | keys_to_features = {
173 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
174 | 'image/encoded_depth': tf.FixedLenFeature((), tf.string, default_value=''),
175 | 'image/label': tf.FixedLenFeature((), tf.string, default_value=''),
176 | }
177 | items_to_handlers = {
178 | 'encoded_depth': slim.tfexample_decoder.Image(image_key='image/encoded_depth',
179 | shape=(image_size, image_size, 3),
180 | channels=3),
181 | 'label': slim.tfexample_decoder.Image(image_key='image/label',
182 | shape=(label_size, label_size, 3),
183 | channels=3),
184 | }
185 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
186 | dataset = slim.dataset.Dataset(data_sources=dataset_dir_list,
187 | reader=reader,
188 | decoder=decoder,
189 | num_samples=3,
190 | items_to_descriptions=None)
191 | provider = slim.dataset_data_provider.DatasetDataProvider(dataset,
192 | num_readers=num_readers,
193 | shuffle=shuffle,
194 | num_epochs=num_epochs,
195 | common_queue_capacity=20 * batch_size,
196 | common_queue_min=10 * batch_size)
197 | encoded_depth, label = provider.get(['encoded_depth', 'label'])
198 | depth = decode_depth(encoded_depth)
199 | depth = (depth - tf.reduce_mean(depth)) / 1000.0
200 | depth = depth * tf.random_normal(depth.get_shape(), mean=1, stddev=0.01)
201 | depth = tf.stack([depth, depth, depth], axis=2)
202 | label = tf.cast(label, tf.float32) / 255.0
203 | if is_training:
204 | depths, labels = tf.train.batch([depth, label],
205 | batch_size=batch_size,
206 | num_threads=num_preprocessing_threads,
207 | capacity=5*batch_size)
208 | else:
209 | depths = tf.expand_dims(depth, axis=0)
210 | labels = tf.expand_dims(label, axis=0)
211 | return depths, labels
212 |
213 |
214 | def add_summary(inputs, labels, end_points, loss, hparams):
215 | h_b = int(hparams.batch_size/2)
216 | tf.summary.scalar(hparams.scope+'_loss', loss)
217 | tf.summary.image(hparams.scope+'_inputs_g', inputs[0:h_b])
218 | tf.summary.image(hparams.scope+'_inputs_r', inputs[h_b:])
219 | tf.summary.image(hparams.scope+'_labels_g', labels[0:h_b])
220 | tf.summary.image(hparams.scope+'_labels_r', labels[h_b:])
221 | # for i in range(1, 3):
222 | # for j in range(64):
223 | # tf.summary.image(scope + '/conv{}' + '_{}'.format(i, j),
224 | # end_points[scope + '/conv{}'.format(i)][0:1, :, :, j:j + 1])
225 | # tf.summary.image(scope + '/conv3', end_points[scope + '/conv3'])
226 | net = end_points['logits']
227 | infer_map = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True)
228 | tf.summary.image(hparams.scope+'_inference_map_g', infer_map[0:h_b])
229 | tf.summary.image(hparams.scope+'_inference_map_r', infer_map[h_b:])
230 | # variable_list = slim.get_model_variables()
231 | # for var in variable_list:
232 | # tf.summary.histogram(var.name[:-2], var)
233 |
234 |
235 | def restore_map():
236 | variable_list = slim.get_model_variables()
237 | variables_to_restore = {var.op.name: var for var in variable_list}
238 | return variables_to_restore
239 |
240 |
241 | def restore_from_classification_checkpoint(scope, model_name, checkpoint_exclude_scopes):
242 | variable_list = slim.get_model_variables(os.path.join(scope, 'feature_extractor'))
243 | for checkpoint_exclude_scope in checkpoint_exclude_scopes:
244 | variable_list = [var for var in variable_list if checkpoint_exclude_scope not in var.op.name]
245 | variables_to_restore = {}
246 | for var in variable_list:
247 | if var.name.startswith(os.path.join(scope, 'feature_extractor')):
248 | var_name = var.op.name.replace(os.path.join(scope, 'feature_extractor'), model_name)
249 | variables_to_restore[var_name] = var
250 | return variables_to_restore
251 |
--------------------------------------------------------------------------------
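
A round-trip sketch of the 24-bit depth packing that decode_depth above expects: a depth value in millimetres is split into three 8-bit channels with depth = r*65536 + g*256 + b. TF 1.x graph mode is assumed and the random depths are synthetic.

import numpy as np
import tensorflow as tf
from network_utils import decode_depth

depth_mm = np.random.randint(0, 1500, size=(288, 288))        # synthetic depth in mm
packed = np.stack([depth_mm // 65536,                          # r channel
                   (depth_mm // 256) % 256,                    # g channel
                   depth_mm % 256],                            # b channel
                  axis=2).astype(np.uint8)
decoded = decode_depth(tf.constant(packed))
with tf.Session() as sess:
    print(np.abs(sess.run(decoded) - depth_mm).max())          # expected 0.0
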
/nets/resnet_v1.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib.slim as slim
3 | from nets import resnet_utils
4 |
5 |
6 | resnet_arg_scope = resnet_utils.resnet_arg_scope
7 |
8 |
9 | @slim.add_arg_scope
10 | def bottleneck(inputs,
11 | depth,
12 | depth_bottleneck,
13 | stride,
14 | rate=1,
15 | outputs_collections=None,
16 | scope=None,
17 | use_bounded_activations=False):
18 | """Bottleneck residual unit variant with BN after convolutions.
19 |
20 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
21 | its definition. Note that we use here the bottleneck variant which has an
22 | extra bottleneck layer.
23 |
24 | When putting together two consecutive ResNet blocks that use this unit, one
25 | should use stride = 2 in the last unit of the first block.
26 |
27 | Args:
28 | inputs: A tensor of size [batch, height, width, channels].
29 | depth: The depth of the ResNet unit output.
30 | depth_bottleneck: The depth of the bottleneck layers.
31 | stride: The ResNet unit's stride. Determines the amount of downsampling of
32 | the units output compared to its input.
33 | rate: An integer, rate for atrous convolution.
34 | outputs_collections: Collection to add the ResNet unit output.
35 | scope: Optional variable_scope.
36 | use_bounded_activations: Whether or not to use bounded activations. Bounded
37 | activations better lend themselves to quantized inference.
38 |
39 | Returns:
40 | The ResNet unit's output.
41 | """
42 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
43 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
44 | if depth == depth_in:
45 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
46 | else:
47 | shortcut = slim.conv2d(
48 | inputs,
49 | depth, [1, 1],
50 | stride=stride,
51 | activation_fn=tf.nn.relu6 if use_bounded_activations else None,
52 | scope='shortcut')
53 |
54 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
55 | scope='conv1')
56 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
57 | rate=rate, scope='conv2')
58 | residual = slim.conv2d(residual, depth, [1, 1], stride=1,
59 | activation_fn=None, scope='conv3')
60 |
61 | if use_bounded_activations:
62 | # Use clip_by_value to simulate bandpass activation.
63 | residual = tf.clip_by_value(residual, -6.0, 6.0)
64 | output = tf.nn.relu6(shortcut + residual)
65 | else:
66 | output = tf.nn.relu(shortcut + residual)
67 |
68 | return slim.utils.collect_named_outputs(outputs_collections,
69 | sc.name,
70 | output)
71 |
72 |
73 | def resnet_v1(inputs,
74 | blocks,
75 | num_classes=None,
76 | is_training=True,
77 | global_pool=True,
78 | output_stride=None,
79 | include_root_block=True,
80 | spatial_squeeze=True,
81 | reuse=None,
82 | scope=None):
83 | """Generator for v1 ResNet models.
84 |
85 | This function generates a family of ResNet v1 models. See the resnet_v1_*()
86 | methods for specific model instantiations, obtained by selecting different
87 | block instantiations that produce ResNets of various depths.
88 |
89 | Training for image classification on Imagenet is usually done with [224, 224]
90 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet
91 | block for the ResNets defined in [1] that have nominal stride equal to 32.
92 | However, for dense prediction tasks we advise that one uses inputs with
93 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
94 | this case the feature maps at the ResNet output will have spatial shape
95 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
96 | and corners exactly aligned with the input image corners, which greatly
97 | facilitates alignment of the features to the image. Using as input [225, 225]
98 | images results in [8, 8] feature maps at the output of the last ResNet block.
99 |
100 | For dense prediction tasks, the ResNet needs to run in fully-convolutional
101 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
102 | have nominal stride equal to 32 and a good choice in FCN mode is to use
103 | output_stride=16 in order to increase the density of the computed features at
104 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
105 |
106 | Args:
107 | inputs: A tensor of size [batch, height_in, width_in, channels].
108 | blocks: A list of length equal to the number of ResNet blocks. Each element
109 | is a resnet_utils.Block object describing the units in the block.
110 | num_classes: Number of predicted classes for classification tasks.
111 | If 0 or None, we return the features before the logit layer.
112 | is_training: whether batch_norm layers are in training mode.
113 | global_pool: If True, we perform global average pooling before computing the
114 | logits. Set to True for image classification, False for dense prediction.
115 | output_stride: If None, then the output will be computed at the nominal
116 | network stride. If output_stride is not None, it specifies the requested
117 | ratio of input to output spatial resolution.
118 | include_root_block: If True, include the initial convolution followed by
119 | max-pooling, if False excludes it.
120 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is
121 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
122 | To use this parameter, the input images must be smaller than 300x300
123 | pixels, in which case the output logit layer does not contain spatial
124 | information and can be removed.
125 | reuse: whether or not the network and its variables should be reused. To be
126 | able to reuse 'scope' must be given.
127 | scope: Optional variable_scope.
128 |
129 | Returns:
130 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
131 | If global_pool is False, then height_out and width_out are reduced by a
132 | factor of output_stride compared to the respective height_in and width_in,
133 | else both height_out and width_out equal one. If num_classes is 0 or None,
134 | then net is the output of the last ResNet block, potentially after global
135 | average pooling. If num_classes a non-zero integer, net contains the
136 | pre-softmax activations.
137 | end_points: A dictionary from components of the network to the corresponding
138 | activation.
139 |
140 | Raises:
141 | ValueError: If the target output_stride is not valid.
142 | """
143 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
144 | end_points_collection = sc.original_name_scope + '_end_points'
145 | with slim.arg_scope([slim.conv2d, bottleneck,
146 | resnet_utils.stack_blocks_dense],
147 | outputs_collections=end_points_collection):
148 | with slim.arg_scope([slim.batch_norm], is_training=is_training):
149 | net = inputs
150 | if include_root_block:
151 | if output_stride is not None:
152 | if output_stride % 4 != 0:
153 | raise ValueError('The output_stride needs to be a multiple of 4.')
154 | output_stride /= 4
155 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
156 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
157 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
158 | # Convert end_points_collection into a dictionary of end_points.
159 | end_points = slim.utils.convert_collection_to_dict(
160 | end_points_collection)
161 |
162 | if global_pool:
163 | # Global average pooling.
164 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
165 | end_points['global_pool'] = net
166 | if num_classes:
167 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
168 | normalizer_fn=None, scope='logits')
169 | end_points[sc.name + '/logits'] = net
170 | if spatial_squeeze:
171 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
172 | end_points[sc.name + '/spatial_squeeze'] = net
173 | end_points['predictions'] = slim.softmax(net, scope='predictions')
174 | return net, end_points
175 |
176 |
177 | resnet_v1.default_image_size = 320
178 |
179 |
180 | def resnet_v1_block(scope, base_depth, num_units, stride):
181 | """Helper function for creating a resnet_v1 bottleneck block.
182 |
183 | Args:
184 | scope: The scope of the block.
185 | base_depth: The depth of the bottleneck layer for each unit.
186 | num_units: The number of units in the block.
187 | stride: The stride of the block, implemented as a stride in the last unit.
188 | All other units have stride=1.
189 |
190 | Returns:
191 | A resnet_v1 bottleneck block.
192 | """
193 | return resnet_utils.Block(scope, bottleneck, [{
194 | 'depth': base_depth * 4,
195 | 'depth_bottleneck': base_depth,
196 | 'stride': 1
197 | }] * (num_units - 1) + [{
198 | 'depth': base_depth * 4,
199 | 'depth_bottleneck': base_depth,
200 | 'stride': stride
201 | }])
202 |
203 |
204 | def resnet_v1_50(inputs,
205 | num_classes=None,
206 | is_training=True,
207 | global_pool=True,
208 | output_stride=None,
209 | spatial_squeeze=True,
210 | reuse=None,
211 | scope='resnet_v1_50'):
212 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
213 | blocks = [
214 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
215 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
216 | resnet_v1_block('block3', base_depth=256, num_units=6, stride=2),
217 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
218 | ]
219 | return resnet_v1(inputs, blocks, num_classes, is_training,
220 | global_pool=global_pool, output_stride=output_stride,
221 | include_root_block=True, spatial_squeeze=spatial_squeeze,
222 | reuse=reuse, scope=scope)
223 |
224 |
225 | resnet_v1_50.default_image_size = resnet_v1.default_image_size
226 |
227 |
228 | def resnet_v1_101(inputs,
229 | num_classes=None,
230 | is_training=True,
231 | global_pool=True,
232 | output_stride=None,
233 | spatial_squeeze=True,
234 | reuse=None,
235 | scope='resnet_v1_101'):
236 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
237 | blocks = [
238 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
239 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2),
240 | resnet_v1_block('block3', base_depth=256, num_units=23, stride=2),
241 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
242 | ]
243 | return resnet_v1(inputs, blocks, num_classes, is_training,
244 | global_pool=global_pool, output_stride=output_stride,
245 | include_root_block=True, spatial_squeeze=spatial_squeeze,
246 | reuse=reuse, scope=scope)
247 |
248 |
249 | resnet_v1_101.default_image_size = resnet_v1.default_image_size
250 |
251 |
252 | def resnet_v1_152(inputs,
253 | num_classes=None,
254 | is_training=True,
255 | global_pool=True,
256 | output_stride=None,
257 | spatial_squeeze=True,
258 | reuse=None,
259 | scope='resnet_v1_152'):
260 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
261 | blocks = [
262 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
263 | resnet_v1_block('block2', base_depth=128, num_units=8, stride=2),
264 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
265 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
266 | ]
267 | return resnet_v1(inputs, blocks, num_classes, is_training,
268 | global_pool=global_pool, output_stride=output_stride,
269 | include_root_block=True, spatial_squeeze=spatial_squeeze,
270 | reuse=reuse, scope=scope)
271 |
272 |
273 | resnet_v1_152.default_image_size = resnet_v1.default_image_size
274 |
275 |
276 | def resnet_v1_200(inputs,
277 | num_classes=None,
278 | is_training=True,
279 | global_pool=True,
280 | output_stride=None,
281 | spatial_squeeze=True,
282 | reuse=None,
283 | scope='resnet_v1_200'):
284 | """ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
285 | blocks = [
286 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2),
287 | resnet_v1_block('block2', base_depth=128, num_units=24, stride=2),
288 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2),
289 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1),
290 | ]
291 | return resnet_v1(inputs, blocks, num_classes, is_training,
292 | global_pool=global_pool, output_stride=output_stride,
293 | include_root_block=True, spatial_squeeze=spatial_squeeze,
294 | reuse=reuse, scope=scope)
295 |
296 |
297 | resnet_v1_200.default_image_size = resnet_v1.default_image_size
298 |
--------------------------------------------------------------------------------
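
A minimal sketch (TF 1.x assumed) of building the ResNet-50 v1 backbone above as a fully-convolutional feature extractor with output_stride=16, as recommended in the resnet_v1 docstring for dense prediction. The 288x288 input size is an assumption chosen to match the padded height maps used elsewhere in the repository.

import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import resnet_v1

images = tf.placeholder(tf.float32, [1, 288, 288, 3])
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
    net, end_points = resnet_v1.resnet_v1_50(images,
                                             num_classes=None,
                                             is_training=False,
                                             global_pool=False,
                                             output_stride=16)
print(net.get_shape().as_list())   # expected [1, 18, 18, 2048]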