├── nets ├── __init__.py ├── resnet_utils.py └── resnet_v1.py ├── data ├── output │ └── .gitignore ├── .gitignore └── test │ ├── test0.png │ ├── test1.png │ ├── test2.png │ ├── test3.png │ ├── test4.png │ ├── test5.png │ └── test6.png ├── data_collection ├── .gitignore ├── remoteApi.so ├── README.md ├── ur.py ├── grasp_trials_without_rule.py ├── rg2.py ├── corrective_grasp_trials.py └── scene.py ├── .gitignore ├── models └── .gitignore ├── doc ├── preprocessing.png ├── 000007_0000_color.png ├── 000007_0000_depth.png ├── 000007_0000_label.png ├── 000007_0000_height_map_color.png ├── 000007_0000_height_map_depth.png └── the_proposed_grasp_system_pipeline.jpg ├── requirements.txt ├── download.sh ├── losses.py ├── hparams.py ├── data_processing ├── create_tf_record.py ├── rescale_data.py ├── data_processor.py ├── dataset_utils.py ├── data_preprocessing.py ├── mag_data_preprocessing.py ├── generate_pseudo_data_v1.py └── genarate_pseudo_data_v2.py ├── metagrasp_test.py ├── mag_test.py ├── models.py ├── README.md ├── metagrasp_train.py ├── mag_train.py └── network_utils.py /nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/output/.gitignore: -------------------------------------------------------------------------------- 1 | *.png -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | tfrecord* 2 | 3dnet* -------------------------------------------------------------------------------- /data_collection/.gitignore: -------------------------------------------------------------------------------- 1 | *.ttt 2 | *.ttm -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.pyc 3 | *.tfrecord 4 | __pycache__ 5 | .idea -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | *.ckpt 2 | events* 3 | checkpoint 4 | *.pbtxt 5 | *.json 6 | model.* -------------------------------------------------------------------------------- /data/test/test0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test0.png -------------------------------------------------------------------------------- /data/test/test1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test1.png -------------------------------------------------------------------------------- /data/test/test2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test2.png -------------------------------------------------------------------------------- /data/test/test3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test3.png -------------------------------------------------------------------------------- /data/test/test4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test4.png -------------------------------------------------------------------------------- /data/test/test5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test5.png -------------------------------------------------------------------------------- /data/test/test6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data/test/test6.png -------------------------------------------------------------------------------- /doc/preprocessing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/preprocessing.png -------------------------------------------------------------------------------- /doc/000007_0000_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_color.png -------------------------------------------------------------------------------- /doc/000007_0000_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_depth.png -------------------------------------------------------------------------------- /doc/000007_0000_label.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_label.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=1.9.0 2 | opencv-python>=3.4.0 3 | numpy>=1.13.3 4 | scipy>=1.0.0 5 | gdown==3.8.1 6 | -------------------------------------------------------------------------------- /data_collection/remoteApi.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/data_collection/remoteApi.so -------------------------------------------------------------------------------- /doc/000007_0000_height_map_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_height_map_color.png -------------------------------------------------------------------------------- /doc/000007_0000_height_map_depth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/000007_0000_height_map_depth.png -------------------------------------------------------------------------------- /doc/the_proposed_grasp_system_pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysu-robotics-lab/MetaGrasp/HEAD/doc/the_proposed_grasp_system_pipeline.jpg -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | gdown 
https://drive.google.com/uc?id=1BCJkYBIA1wnPtmEZ_MCA9HMaZ6J6DlUg 2 | 3 | tar -zxvf metagrasp.tar.gz 4 | rm metagrasp.tar.gz 5 | 6 | cd metagrasp 7 | 8 | mv ur.ttt ../data_collection/ 9 | 10 | tar -zxvf 3dnet.tar.gz 11 | rm 3dnet.tar.gz 12 | mv 3dnet ../data_collection/ 13 | 14 | tar -zxvf model.tar.gz 15 | rm model.tar.gz 16 | mv metagrasp ../models/ 17 | 18 | cd .. 19 | rm -rf metagrasp 20 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def create_loss(net, labels): 5 | y = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True) 6 | cross_entropy = -tf.reduce_mean(labels * tf.log(tf.clip_by_value(y, 0.001, 0.999))) 7 | return cross_entropy 8 | 9 | 10 | def create_loss_with_label_mask(net, labels, lamb): 11 | bad, good, background = tf.unstack(labels, axis=3) 12 | mask = lamb * tf.add(bad, good) + background * 0.1 13 | attention_mask = tf.stack([mask, mask, mask], axis=3) 14 | y = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True) 15 | cross_entropy = -tf.reduce_mean(attention_mask * (labels * tf.log(tf.clip_by_value(y, 0.001, 0.999)))) 16 | return cross_entropy 17 | 18 | 19 | def create_loss_without_background(net, labels): 20 | bad, good, background = tf.unstack(labels, axis=3) 21 | background = tf.zeros_like(background, dtype=tf.float32) 22 | labels = tf.stack([bad, good, background], axis=3) 23 | y = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True) 24 | cross_entropy = -tf.reduce_mean(labels * tf.log(tf.clip_by_value(y, 0.001, 0.999))) 25 | return cross_entropy 26 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def create_mag_hparams(hparam_string=None): 5 | hparams = tf.contrib.training.HParams(learning_rate=0.001, 6 | lr_decay_step=200000, 7 | lr_decay_rate=0.77, 8 | momentum=0.99, 9 | lamb=15.0, 10 | batch_size=8, 11 | image_size=288, 12 | label_size=36, 13 | scope='mag', 14 | model_name='resnet_v1_50') 15 | if hparam_string: 16 | tf.logging.info('Parsing command line hparams: %s', hparam_string) 17 | hparams.parse(hparam_string) 18 | 19 | tf.logging.info('Final parsed hparams: %s', hparams.values()) 20 | return hparams 21 | 22 | 23 | def create_metagrasp_hparams(hparam_string=None): 24 | hparams = tf.contrib.training.HParams(learning_rate=0.001, 25 | lr_decay_step=200000, 26 | lr_decay_rate=0.77, 27 | momentum=0.99, 28 | lamb=120.0, 29 | batch_size=16, 30 | image_size=288, 31 | label_size=288, 32 | scope='metagrasp', 33 | model_name='resnet_v1_50') 34 | if hparam_string: 35 | tf.logging.info('Parsing command line hparams: %s', hparam_string) 36 | hparams.parse(hparam_string) 37 | 38 | tf.logging.info('Final parsed hparams: %s', hparams.values()) 39 | return hparams 40 | 41 | -------------------------------------------------------------------------------- /data_processing/create_tf_record.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import random 3 | import math 4 | 5 | import argparse 6 | import os 7 | 8 | from data_processing.dataset_utils import * 9 | 10 | parser = argparse.ArgumentParser(description='create tensorflow record') 11 | parser.add_argument('--data_path', 12 | required=True, 13 | type=str, 14 | help='Path to data set.') 15 | 
parser.add_argument('--output', 16 | default='data/tfrecords', 17 | type=str, 18 | help='Path to the output files.') 19 | args = parser.parse_args() 20 | 21 | folders = ['color', 'encoded_depth', 'label_map', 'camera_height'] 22 | 23 | 24 | def dict_to_tf_example(file_name): 25 | """ 26 | Create tfrecord example. 27 | :param file_name: File path corresponding to the data. 28 | :return: example: Example of tfrecord. 29 | """ 30 | with open(os.path.join(args.data_path, folders[0], file_name+'.png'), 'rb') as fid: 31 | encoded_color = fid.read() 32 | with open(os.path.join(args.data_path, folders[1], file_name + '.png'), 'rb') as fid: 33 | encoded_depth = fid.read() 34 | with open(os.path.join(args.data_path, folders[2], file_name + '.png'), 'rb') as fid: 35 | encoded_label_map = fid.read() 36 | # camera_height = float(np.loadtxt(os.path.join(args.data_path, folders[3], file_name + '.txt'))) 37 | example = tf.train.Example(features=tf.train.Features(feature={ 38 | 'image/color': bytes_feature(encoded_color), 39 | 'image/format': bytes_feature(b'png'), 40 | 'image/encoded_depth': bytes_feature(encoded_depth), 41 | 'image/label': bytes_feature(encoded_label_map), 42 | # 'image/camera_height': float_feature(camera_height), 43 | })) 44 | return example 45 | 46 | 47 | def main(): 48 | if not os.path.exists(args.output): 49 | os.makedirs(args.output) 50 | os.makedirs(os.path.join(args.output, 'pos')) 51 | os.makedirs(os.path.join(args.output, 'neg')) 52 | curr_pos_records = len(os.listdir(os.path.join(args.output, 'pos'))) 53 | curr_neg_records = len(os.listdir(os.path.join(args.output, 'neg'))) 54 | curr_records_list = [curr_pos_records, curr_neg_records] 55 | sets = ['pos', 'neg'] 56 | num_samples_per_record = 10000 57 | 58 | for curr_records, set in zip(curr_records_list, sets): 59 | with open(os.path.join(args.data_path, set+'.txt'), 'r') as f: 60 | file_list = f.readlines() 61 | random.shuffle(file_list) 62 | num_samples = len(file_list) 63 | num_records = math.ceil(num_samples / num_samples_per_record) 64 | for i, record_idx in enumerate(range(curr_pos_records, curr_pos_records + num_records)): 65 | writer = tf.python_io.TFRecordWriter(os.path.join(args.output, 66 | set, 67 | 'train_{}_{:04d}.tfrecord'.format(set, record_idx))) 68 | j = i * num_samples_per_record 69 | while j < (i + 1) * num_samples_per_record and j < num_samples: 70 | tf_example = dict_to_tf_example(file_list[j][:-1]) 71 | writer.write(tf_example.SerializeToString()) 72 | j += 1 73 | writer.close() 74 | 75 | 76 | if __name__ == '__main__': 77 | main() 78 | 79 | -------------------------------------------------------------------------------- /data_collection/README.md: -------------------------------------------------------------------------------- 1 | ### Collect new data in V-REP 2 | 3 | To collect new grasp data in V-REP, we you should first run the scene (i.e., run ```ur.ttt```) in simulator, 4 | then control the actions using the remote python api. 5 | 6 | #### Step 1: Launch the V-REP simulator and load scene. 7 | You first need to have access to the V-REP 3.4 simulator, the software is available at [here](http://www.coppeliarobotics.com/). 8 | Then run the script to launch V-REP: 9 | ```bash 10 | cd path_to_vrep 11 | ./vrep 12 | ``` 13 | Load the scene ```ur.ttt``` into V-REP, the scene file and object models are available at 14 | [here](https://drive.google.com/open?id=1YQfju1x6_Kj7Hc0hPD154YmPWP3SbZ5h) 15 | 16 | #### Step 2: run control script to collect data. 
17 | ```bash 18 | python data_collection/corrective_grasp_trials.py \ 19 | --ip 127.0.0.1 \ # ip address to the vrep simulator. 20 | --port 19997 \ # port to the vrep simulator. 21 | --obj_id obj0000 \ # object handle in vrep. 22 | --num_grasp 200 \ # the number of grasp trails. 23 | --num_repeat 1 \ # the number of repeat time if the gripper successfully grasp an object. 24 | --output data/3dnet # directory to save data. 25 | ``` 26 | 27 | If you want to collect data which is not guided by antipodal rule, you can run the script: 28 | ```bash 29 | python data_collection/grasp_trials_without_rule.py \ 30 | --ip 127.0.0.1 \ # ip address to the vrep simulator. 31 | --port 19997 \ # port to the vrep simulator. 32 | --obj_id obj0000 \ # object handle in vrep. 33 | --num_grasp 200 \ # the number of grasp trails. 34 | --output data/3dnet # directory to save data. 35 | ``` 36 | 37 | ### Data structure 38 | A grasp sample consists of an RGB image of the whole workspace, 39 | the grasp label that illustrates which locations are graspable or not 40 | as well as how much degrees the gripper should rotate, 41 | and object coordinates that indicate the whole object pixel locations in the RGB image. 42 | 43 | ### Contents of directories 44 | * **3dnet** 45 | * **0000** 46 | * **color**: The raw RGB images obtained from vision sensor. 47 | * **000000.png** 48 | * **000001_0000.png** 49 | * ...... 50 | * **depth**: The raw depth images obtained from vision sensor. 51 | * **000000.png** 52 | * **000001_0000.png** 53 | * ...... 54 | * **height_map_color**: The cropped RGB images that only contain the information of workspace. 55 | * **000000.png** 56 | * **000001_0000.png** 57 | * ...... 58 | * **height_map_depth**: The cropped depth images that only contain the information of workspace. 59 | * **000000.png** 60 | * **000001_0000.png** 61 | * ...... 62 | * **label** 63 | * **000000.bad.txt**: This file contains the grasp points in image space and corresponding grasp angles. 64 | * **000000.object_points.txt**: This file contains coordinates that belong to object. 65 | * **000000.png**: The visualization of the grasp angles. 66 | * **000001_0000.good.txt** 67 | * **000001_0000.object_points.txt** 68 | * **000001_0000.png** 69 | * ...... 70 | * **background_color.png** 71 | * **background_depth.png** 72 | * **crop_background_color.png** 73 | * **crop_background_depth.png** 74 | * **file_name.txt** 75 | * **0001** 76 | * ...... 77 | * ...... 
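To make the label format above concrete, here is a minimal, hypothetical sketch (not part of the repository) that loads one sample's grasp label file and overlays the recorded grasp lines on the matching height-map image. The sample id and paths are assumptions that follow the layout above; the endpoint order matches `draw_label` in `data_processing/dataset_utils.py`.
```python
import cv2
import numpy as np

root = '3dnet/0000'                 # hypothetical data folder from the layout above
sample_id = '000001_0000'           # hypothetical sample name

color = cv2.imread('{}/height_map_color/{}.png'.format(root, sample_id))
# Each row of a *.good.txt / *.bad.txt file stores one grasp line as two endpoints
# (row1, col1, row2, col2) in the 200x200 height-map image.
points = np.atleast_2d(np.loadtxt('{}/label/{}.good.txt'.format(root, sample_id))).astype(int)

for r1, c1, r2, c2 in points:
    cv2.line(color, (c1, r1), (c2, r2), (0, 255, 0), 1)  # cv2.line expects (x, y) = (col, row)

# Grasp angle of the first line, computed as in data_processing/data_processor.py.
angle = np.arctan2(points[0, 2] - points[0, 0], points[0, 3] - points[0, 1])
print('grasp angle: {:.3f} rad'.format(angle))
cv2.imwrite('label_overlay.png', color)
```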
78 | 79 | -------------------------------------------------------------------------------- /data_processing/rescale_data.py: -------------------------------------------------------------------------------- 1 | from data_processing.dataset_utils import RescaleData, DataInfo 2 | 3 | 4 | import numpy as np 5 | import cv2 6 | 7 | import argparse 8 | import os 9 | 10 | parser = argparse.ArgumentParser(description='process labels') 11 | parser.add_argument('--data_dir', 12 | required=True, 13 | type=str, 14 | help='path to the data') 15 | args = parser.parse_args() 16 | 17 | 18 | def main(): 19 | data_dir = os.listdir(args.data_dir) 20 | info = DataInfo() 21 | for data_id in data_dir: 22 | parent_dir = os.path.join(args.data_dir, data_id) 23 | print(parent_dir) 24 | os.mkdir(os.path.join(parent_dir, 'zoomed_height_map_color')) 25 | os.mkdir(os.path.join(parent_dir, 'zoomed_height_map_depth')) 26 | os.mkdir(os.path.join(parent_dir, 'zoomed_label')) 27 | background_depth = cv2.imread(os.path.join(parent_dir, 'background_depth.png'), cv2.IMREAD_ANYDEPTH) 28 | with open(os.path.join(parent_dir, 'file_name.txt'), 'r') as f: 29 | with open(os.path.join(parent_dir, 'zoomed_file_name.txt'), 'w') as f_zoom: 30 | file_names = f.readlines() 31 | for file_name in file_names: 32 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1] 33 | print(file_name) 34 | color = cv2.imread(os.path.join(parent_dir, 'color', file_name + '.png')) 35 | depth = cv2.imread(os.path.join(parent_dir, 'depth', file_name + '.png'), cv2.IMREAD_ANYDEPTH) 36 | try: 37 | grasp_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.good.txt')) 38 | label_flag = 'good' 39 | except IOError: 40 | grasp_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.bad.txt')) 41 | label_flag = 'bad' 42 | rd = RescaleData(color, depth, background_depth, grasp_labels, info) 43 | for i in range(7): 44 | factor = np.random.uniform(0.20/info.camera_height, 1.0) 45 | crop_color, crop_depth, label, label_points, object_points = rd.get_zoomed_data(factor) 46 | cv2.imwrite(os.path.join(parent_dir, 'zoomed_height_map_color', 47 | file_name + '.{:0.7f}.png'.format(factor)), crop_color) 48 | cv2.imwrite(os.path.join(parent_dir, 'zoomed_height_map_depth', 49 | file_name + '.{:0.7f}.png'.format(factor)), crop_depth) 50 | cv2.imwrite(os.path.join(parent_dir, 'zoomed_label', file_name + '.{:0.7f}.png'.format(factor)), 51 | label) 52 | np.savetxt(os.path.join(parent_dir, 'zoomed_label', 53 | file_name + '.{:0.7f}.'.format(factor) + label_flag + '.txt'), 54 | label_points) 55 | np.savetxt(os.path.join(parent_dir, 'zoomed_label', 56 | file_name + '.{:0.7f}.'.format(factor) + 'object_points.txt'), 57 | object_points) 58 | f_zoom.write(file_name + '.{:0.7f}\n'.format(factor)) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | 64 | 65 | -------------------------------------------------------------------------------- /metagrasp_test.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import cv2 6 | 7 | import argparse 8 | import os 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 10 | 11 | parser = argparse.ArgumentParser(description='network evaluating') 12 | parser.add_argument('--color_dir', 13 | default='data/test/000024.png', 14 | type=str, 15 | help='The directory where the color image can be found.') 16 | parser.add_argument('--output_dir', 17 | default='data/output', 18 | type=str, 19 | help='The 
directory where the color image can be found.') 20 | parser.add_argument('--checkpoint_dir', 21 | default='models/metagrasp', 22 | type=str, 23 | help='The directory where the checkpoint can be found') 24 | args = parser.parse_args() 25 | 26 | 27 | def main(): 28 | colors_n = cv2.resize(cv2.imread(args.color_dir), (200, 200))[..., ::-1] 29 | pad_size = 44 30 | pad_colors = np.zeros((288, 288, 3), dtype=np.uint8) 31 | pad_colors[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = colors_n 32 | 33 | colors_p = tf.placeholder(dtype=tf.float32, shape=[1, 288, 288, 3]) 34 | colors = colors_p * tf.random_normal(colors_p.get_shape(), mean=1, stddev=0.01) 35 | colors = colors / tf.constant([255.0]) 36 | colors = (colors - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225]) 37 | 38 | net, end_points = metagrasp(colors, 39 | num_classes=3, 40 | num_channels=1000, 41 | is_training=False, 42 | global_pool=False, 43 | output_stride=16, 44 | spatial_squeeze=False, 45 | scope='metagrasp') 46 | probability_map = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True) 47 | probability_map = tf.image.resize_bilinear(probability_map, [288, 288]) 48 | saver = tf.train.Saver() 49 | session_config = tf.ConfigProto(allow_soft_placement=True, 50 | log_device_placement=False) 51 | session_config.gpu_options.allow_growth = True 52 | sess = tf.Session(config=session_config) 53 | saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir)) 54 | print('Successfully loading model: {}.'.format(tf.train.latest_checkpoint(args.checkpoint_dir))) 55 | sess.run(tf.local_variables_initializer()) 56 | 57 | outputs_a = [] 58 | outputs_a_plus_c = [] 59 | for i in range(16): 60 | mtx = cv2.getRotationMatrix2D((144, 144), 22.5 * i, 1) 61 | rotated_colors = cv2.warpAffine(pad_colors, mtx, (288, 288)) 62 | 63 | output = sess.run(probability_map, 64 | feed_dict={colors_p: np.expand_dims(rotated_colors, 0).astype(np.float32)}) 65 | outputs_a.append(output[..., 1]) # extract green channel 66 | outputs_a_plus_c.append((rotated_colors*0.7 + np.squeeze(output)*255.0*0.3).astype(np.uint8)) 67 | 68 | cv2.imwrite(os.path.join(args.output_dir, 'rotated_colors_{}.png'.format(i)), 69 | rotated_colors[..., ::-1]) 70 | cv2.imwrite(os.path.join(args.output_dir, 'output_{}.png'.format(i)), 71 | (np.squeeze(output)*255.0).astype(np.uint8)[..., ::-1]) 72 | outputs = np.concatenate(outputs_a, axis=0) 73 | threshold = np.max(outputs) - 0.001 74 | for idx, h, w in zip(*np.where(outputs >= threshold)): 75 | cv2.circle(outputs_a_plus_c[idx], (w, h), 1, color=(0, 255, 0), thickness=5) 76 | vis_map = np.concatenate( 77 | tuple([np.concatenate(tuple(outputs_a_plus_c[4 * i:4 * (i + 1)]), axis=1) for i in range(4)]), 78 | axis=0) 79 | cv2.imshow('visualization_map', vis_map[..., ::-1]) 80 | cv2.waitKey(0) 81 | cv2.imwrite(os.path.join(args.output_dir, 'visualization_map.png'), vis_map[..., ::-1]) 82 | 83 | sess.close() 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | 89 | -------------------------------------------------------------------------------- /mag_test.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import cv2 6 | 7 | import argparse 8 | import os 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 10 | 11 | parser = argparse.ArgumentParser(description='network evaluating') 12 | parser.add_argument('--color_dir', 13 | default='data/test/color.png', 14 | type=str, 15 | help='The directory 
where the color image can be found.') 16 | parser.add_argument('--output_dir', 17 | default='data/output', 18 | type=str, 19 | help='The directory where the color image can be found.') 20 | parser.add_argument('--checkpoint_dir', 21 | default='models/mag', 22 | type=str, 23 | help='The directory where the checkpoint can be found') 24 | args = parser.parse_args() 25 | 26 | 27 | def main(): 28 | colors_n = cv2.resize(cv2.imread(args.color_dir), (200, 200))[..., ::-1] 29 | pad_size = 44 30 | pad_colors = np.zeros((288, 288, 3), dtype=np.uint8) 31 | pad_colors[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = colors_n 32 | 33 | colors_p = tf.placeholder(dtype=tf.float32, shape=[1, 288, 288, 3]) 34 | colors = colors_p * tf.random_normal(colors_p.get_shape(), mean=1, stddev=0.01) 35 | colors = colors / tf.constant([255.0]) 36 | colors = (colors - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225]) 37 | 38 | net, end_points = mag(colors, 39 | num_classes=3, 40 | num_channels=1000, 41 | is_training=False, 42 | global_pool=False, 43 | output_stride=16, 44 | upsample_ratio=2, 45 | spatial_squeeze=False, 46 | reuse=tf.AUTO_REUSE, 47 | scope='mag') 48 | probability_map = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True) 49 | probability_map = tf.image.resize_bilinear(probability_map, [288, 288]) 50 | saver = tf.train.Saver() 51 | session_config = tf.ConfigProto(allow_soft_placement=True, 52 | log_device_placement=False) 53 | session_config.gpu_options.allow_growth = True 54 | sess = tf.Session(config=session_config) 55 | saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir)) 56 | print('Successfully loading model: {}.'.format(tf.train.latest_checkpoint(args.checkpoint_dir))) 57 | sess.run(tf.local_variables_initializer()) 58 | 59 | outputs_a = [] 60 | outputs_a_plus_c = [] 61 | for i in range(16): 62 | mtx = cv2.getRotationMatrix2D((144, 144), 22.5 * i, 1) 63 | rotated_colors = cv2.warpAffine(pad_colors, mtx, (288, 288)) 64 | 65 | output = sess.run(probability_map, 66 | feed_dict={colors_p: np.expand_dims(rotated_colors, 0).astype(np.float32)}) 67 | outputs_a.append(output[..., 1]) # extract green channel 68 | outputs_a_plus_c.append((rotated_colors*0.7 + np.squeeze(output)*255.0*0.3).astype(np.uint8)) 69 | 70 | cv2.imwrite(os.path.join(args.output_dir, 'rotated_colors_{}.png'.format(i)), 71 | rotated_colors[..., ::-1]) 72 | cv2.imwrite(os.path.join(args.output_dir, 'output_{}.png'.format(i)), 73 | (np.squeeze(output)*255.0).astype(np.uint8)[..., ::-1]) 74 | outputs = np.concatenate(outputs_a, axis=0) 75 | threshold = np.max(outputs) - 0.001 76 | for idx, h, w in zip(*np.where(outputs >= threshold)): 77 | cv2.circle(outputs_a_plus_c[idx], (w, h), 1, color=(0, 255, 0), thickness=5) 78 | vis_map = np.concatenate( 79 | tuple([np.concatenate(tuple(outputs_a_plus_c[4 * i:4 * (i + 1)]), axis=1) for i in range(4)]), 80 | axis=0) 81 | cv2.imshow('visualization_map', vis_map[..., ::-1]) 82 | cv2.waitKey(0) 83 | cv2.imwrite(os.path.join(args.output_dir, 'visualization_map.png'), vis_map[..., ::-1]) 84 | 85 | sess.close() 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | 4 | from nets import resnet_v1 5 | 6 | 7 | def arg_scope(weight_decay=0.0005, 8 | batch_norm_decay=0.997, 9 | batch_norm_epsilon=1e-5, 10 
| batch_norm_scale=True, 11 | activation_fn=tf.nn.relu, 12 | use_batch_norm=True): 13 | batch_norm_params = { 14 | 'decay': batch_norm_decay, 15 | 'epsilon': batch_norm_epsilon, 16 | 'scale': batch_norm_scale, 17 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 18 | 'fused': None, # Use fused batch norm if possible. 19 | } 20 | with slim.arg_scope( 21 | [slim.conv2d, slim.conv2d_transpose], 22 | weights_regularizer=slim.l2_regularizer(weight_decay), 23 | weights_initializer=slim.variance_scaling_initializer(), 24 | activation_fn=activation_fn, 25 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 26 | normalizer_params=batch_norm_params): 27 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 28 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 29 | return arg_sc 30 | 31 | 32 | @slim.add_arg_scope 33 | def resize_bilinear(inputs, 34 | height, 35 | width, 36 | outputs_collections=None, 37 | scope=None): 38 | with tf.variable_scope(scope, 'resize', [inputs]) as sc: 39 | outputs = tf.image.resize_bilinear(inputs, [height, width], name='resize_bilinear') 40 | return slim.utils.collect_named_outputs(outputs_collections, sc.name, outputs) 41 | 42 | 43 | # multi-affordance grasping 44 | def mag(inputs, 45 | num_classes=3, 46 | num_channels=1000, 47 | is_training=True, 48 | global_pool=False, 49 | output_stride=16, 50 | upsample_ratio=2, 51 | spatial_squeeze=False, 52 | reuse=tf.AUTO_REUSE, 53 | scope='graspnet'): 54 | with tf.variable_scope(scope, 'graspnet', [inputs], reuse=reuse): 55 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 56 | net, end_points = resnet_v1.resnet_v1_50(inputs=inputs, 57 | num_classes=num_channels, 58 | is_training=is_training, 59 | global_pool=global_pool, 60 | output_stride=output_stride, 61 | spatial_squeeze=spatial_squeeze, 62 | scope='feature_extractor') 63 | with tf.variable_scope('prediction', [net]) as sc: 64 | end_points_collection = sc.original_name_scope + '_end_points' 65 | # to do: add batch normalization to the following conv layers. 66 | with slim.arg_scope([slim.conv2d], 67 | outputs_collections=end_points_collection): 68 | net = slim.conv2d(net, 512, [1, 1], scope='conv1') 69 | net = slim.conv2d(net, 128, [1, 1], scope='conv2') 70 | net = slim.conv2d(net, num_classes, [1, 1], scope='conv3') 71 | height, width = net.get_shape().as_list()[1:3] 72 | net = tf.image.resize_bilinear(net, 73 | [height * upsample_ratio, width * upsample_ratio], 74 | name='resize_bilinear') 75 | end_points.update(slim.utils.convert_collection_to_dict(end_points_collection)) 76 | end_points['logits'] = net 77 | return net, end_points 78 | 79 | 80 | def metagrasp(inputs, 81 | num_classes=3, 82 | num_channels=1000, 83 | is_training=True, 84 | global_pool=False, 85 | output_stride=16, 86 | spatial_squeeze=False, 87 | reuse=tf.AUTO_REUSE, 88 | scope='metagrasp'): 89 | with tf.variable_scope(scope, 'metagrasp', [inputs], reuse=reuse): 90 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 91 | net, end_points = resnet_v1.resnet_v1_50(inputs=inputs, 92 | num_classes=num_channels, 93 | is_training=is_training, 94 | global_pool=global_pool, 95 | output_stride=output_stride, 96 | spatial_squeeze=spatial_squeeze, 97 | scope='feature_extractor') 98 | with tf.variable_scope('prediction', [net]) as sc: 99 | end_points_collection = sc.original_name_scope + '_end_points' 100 | # to do: add batch normalization to the following conv layers. 
101 | with slim.arg_scope([slim.conv2d, resize_bilinear], 102 | outputs_collections=end_points_collection): 103 | net = slim.conv2d(net, 512, [3, 3], scope='conv1') 104 | height, width = net.get_shape().as_list()[1:3] 105 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear1') 106 | net = slim.conv2d(net, 256, [3, 3], scope='conv2') 107 | height, width = net.get_shape().as_list()[1:3] 108 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear2') 109 | net = slim.conv2d(net, 128, [3, 3], scope='conv3') 110 | height, width = net.get_shape().as_list()[1:3] 111 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear3') 112 | net = slim.conv2d(net, 64, [3, 3], scope='conv4') 113 | height, width = net.get_shape().as_list()[1:3] 114 | net = resize_bilinear(net, height * 2, width * 2, scope='resize_bilinear4') 115 | net = slim.conv2d(net, num_classes, [5, 5], scope='conv5') 116 | end_points.update(slim.utils.convert_collection_to_dict(end_points_collection)) 117 | end_points['logits'] = net 118 | return net, end_points 119 | -------------------------------------------------------------------------------- /data_processing/data_processor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DataProcessor(object): 5 | def __init__(self): 6 | pass 7 | 8 | @staticmethod 9 | def get_grasp_center(grasp_labels): 10 | """ 11 | Get grasp center. 12 | :param grasp_labels: A numpy array with shape [num_labels, 4]. 13 | Each row represents a grasp label formulated as 2 points (i.e., (x1, y1, x2, y2)). 14 | :return: grasp_center: A numpy array with shape [num_labels, 2]. 15 | Each row represents a grasp center corresponding to the grasp label. 16 | """ 17 | row, _ = grasp_labels.shape 18 | grasp_center = np.empty((row, 2), dtype=np.float32) 19 | grasp_center[:, 0] = (grasp_labels[:, 0] + grasp_labels[:, 2]) / 2.0 20 | grasp_center[:, 1] = (grasp_labels[:, 1] + grasp_labels[:, 3]) / 2.0 21 | return grasp_center 22 | 23 | @staticmethod 24 | def get_grasp_angle(grasp_label): 25 | """ 26 | Get grasp angle. 27 | :param grasp_label: A numpy array with shape [1, 4] which represents 28 | a grasp label formulated as 2 points (i.e., (x1, y1, x2, y2)). 29 | :return: angle_indices: A list of int with length 2. The discretized angles ranged from 0 to 15. 30 | Besides the original grasp angle, the list contains another angle flipped vertically by original one. 31 | """ 32 | pt1 = grasp_label[0:2] 33 | pt2 = grasp_label[2:] 34 | angle = np.arctan2(pt2[0] - pt1[0], pt2[1] - pt1[1]) 35 | if angle < 0: 36 | angle += np.pi * 2 37 | angle_indices = [] 38 | angle_indices.append(int(round(angle / ((22.5 / 360.0) * np.pi * 2)))) 39 | if angle >= np.pi: 40 | angle_indices.append(int(round((angle - np.pi) / ((22.5 / 360.0) * np.pi * 2)))) 41 | else: 42 | angle_indices.append(int(round((angle + np.pi) / ((22.5 / 360.0) * np.pi * 2)))) 43 | return angle_indices 44 | 45 | @staticmethod 46 | def rotate(points, center, angle): 47 | """ 48 | Rotate points. 49 | :param points: A numpy array with shape [num_points, 2]. 50 | :param center: A numpy array with shape [1, 2]. The rotation center of the points. 51 | :param angle: A float. The rotated angle represented in radian. 52 | :return: points: A numpy array with shape [num_points, 2]. The rotated points. 
53 | """ 54 | points = points.copy() 55 | h = center[0] 56 | w = center[1] 57 | points[:, 0] -= h 58 | points[:, 1] -= w 59 | rotate_matrix = np.array([[np.cos(angle), -np.sin(angle)], 60 | [np.sin(angle), np.cos(angle)]]) 61 | points = np.dot(rotate_matrix, points.T).T 62 | points[:, 0] += h 63 | points[:, 1] += w 64 | return points 65 | 66 | @staticmethod 67 | def get_diff_depth(depth_o, depth_b): 68 | """ 69 | Get difference depth image by subtracting current depth image and background depth image. 70 | :param depth_o: A numpy array with shape [height, width]. The current depth image. 71 | :param depth_b: A numpy array with shape [height, width]. The background depth image. 72 | :return: diff_depth: A numpy array with shape [height, width]. The difference depth image. 73 | """ 74 | diff_depth = depth_b - depth_o 75 | diff_depth[np.where(diff_depth < 0)] = 0 76 | # diff_depth = cv2.medianBlur(diff_depth, 3) 77 | diff_depth = diff_depth.astype(np.uint16) 78 | return diff_depth 79 | 80 | @staticmethod 81 | def encode_depth(depth): 82 | """ 83 | Encode depth image to RGB format. 84 | :param depth: A numpy array with shape [height, width]. The depth image. 85 | :return: 86 | """ 87 | r = depth / 256 / 256 88 | g = depth / 256 89 | b = depth % 256 90 | # encoded_depth = np.stack([r, g, b], axis=2).astype(np.uint8) 91 | encoded_depth = np.stack([b, g, r], axis=2).astype(np.uint8) # use bgr order due to cv2 format 92 | return encoded_depth 93 | 94 | @staticmethod 95 | def gaussianize_label(grasp_label, grasp_centers, camera_height, is_good, expand=True): 96 | if is_good: 97 | for grasp_center in grasp_centers: 98 | if grasp_label[grasp_center[0], grasp_center[1], 0] == 255: 99 | # grasp_label[grasp_center[0], grasp_center[1], 0] = 0 100 | # grasp_label[grasp_center[0], grasp_center[1], 1] = 255 101 | pass 102 | else: 103 | left = right = 0 104 | while grasp_label[grasp_center[0], grasp_center[1]-left, 0] == 0: 105 | left += 1 106 | while grasp_label[grasp_center[0], grasp_center[1]+right, 0] == 0: 107 | right += 1 108 | 109 | def gauss(x, c, sigma): 110 | return int(255 * np.exp(-(x-c) ** 2 / (2 * sigma ** 2))) 111 | width = left + right 112 | # width = min(left, right) * 2 113 | grasp_center[1] += width / 2 - left 114 | # sigma = camera_height * width 115 | # sigma = width / (camera_height * 10.0) 116 | sigma = 0.03 / (0.4 * camera_height / 0.57) * 200.0 / 3.0 117 | for idx in range(grasp_center[1]-width/2, grasp_center[1]+width/2+1): 118 | value = gauss(idx, grasp_center[1], sigma) 119 | grasp_label[grasp_center[0], idx, 1] = value 120 | grasp_label[grasp_center[0], idx, 2] = 255 - value 121 | else: 122 | if expand: 123 | for grasp_center in grasp_centers: 124 | if grasp_label[grasp_center[0], grasp_center[1], 0] == 255: 125 | grasp_label[grasp_center[0], grasp_center[1], 0] = 0 126 | grasp_label[grasp_center[0], grasp_center[1], 2] = 255 127 | else: 128 | left = right = 0 129 | while grasp_label[grasp_center[0], grasp_center[1]-left, 0] == 0: 130 | grasp_label[grasp_center[0], grasp_center[1]-left, 2] = 255 131 | left += 1 132 | while grasp_label[grasp_center[0], grasp_center[1]+right, 0] == 0: 133 | grasp_label[grasp_center[0], grasp_center[1]+right, 2] = 255 134 | right += 1 135 | else: 136 | grasp_label[grasp_centers[:, 0], grasp_centers[:, 1], 0] = 0 137 | grasp_label[grasp_centers[:, 0], grasp_centers[:, 1], 2] = 255 138 | return grasp_label 139 | 140 | 141 | 142 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # MetaGrasp: Data Efficient Grasping by Affordance Interpreter Network 2 | This repository is the code for the paper 3 | **[MetaGrasp: Data Efficient Grasping by Affordance Interpreter Network](https://sysu-robotics-lab.github.io/MetaGrasp/)** 4 | 5 | [Junhao Cai](None)1 6 | [Hui Cheng](http://sdcs.sysu.edu.cn/content/2504)1 7 | [Zhangpeng Zhang](https://zhzhanp.github.io/)2 8 | [Jingcheng Su](None)1 9 | 10 | 1Sun Yat-sen University 11 | 2Sensetime Group Limited 12 | 13 | Accepted at International Conference on Robotics and Automation 2019 (ICRA2019) . 14 | 15 |
16 | <img src="doc/the_proposed_grasp_system_pipeline.jpg"> 
17 | 
18 | Fig.1. The grasp system pipeline. 
19 | 
20 | ### MetaGrasp: Data Efficient Grasping by Affordance Interpreter Network 
21 | **Abstract** Data-driven approaches for grasping have shown significant advances recently, 
22 | but they usually require a large amount of training data. To increase the efficiency 
23 | of grasp data collection, this paper presents a novel grasp training system 
24 | covering the whole pipeline from data collection to model inference. The system 
25 | can collect effective grasp samples with a corrective strategy assisted by an antipodal 
26 | grasp rule, and we design an affordance interpreter network to predict a pixel-wise 
27 | grasp affordance map. We define graspability, ungraspability and background as grasp 
28 | affordances. The key advantage of our system is that the pixel-level affordance 
29 | interpreter network, trained with only a small number of grasp samples under the antipodal 
30 | rule, can achieve significant performance on totally unseen objects and backgrounds. 
31 | The training samples are collected only in simulation. Extensive qualitative and 
32 | quantitative experiments demonstrate the accuracy and robustness of our proposed 
33 | approach. In real-world grasp experiments, we achieve a grasp success rate of 93% 
34 | on a set of household items and 91% on a set of adversarial items with only about 
35 | 6,300 simulated samples. We also achieve 87% accuracy in a clutter scenario. 
36 | Although the model is trained using only RGB images, it also performs well when the background 
37 | textures change, achieving up to 94% accuracy on the set of 
38 | adversarial objects, which outperforms current state-of-the-art methods. 
39 | **Citing** 
40 | If you find this code useful in your work, please consider citing: 
41 | ```bash 
42 | @inproceedings{cai2019data, 
43 | title={Data Efficient Grasping by Affordance Interpreter Network}, 
44 | author={Junhao Cai, Hui Cheng, Zhanpeng Zhang, Jingcheng Su}, 
45 | booktitle={IEEE International Conference on Robotics and Automation}, 
46 | year={2019} 
47 | } 
48 | ``` 
49 | 
50 | **Contact** 
51 | 
52 | If you have any questions or find any bugs, please contact Junhao Cai (caijh28@mail2.sysu.edu.cn). 
53 | #### Requirements and Dependencies 
54 | * Python 3 
55 | * Tensorflow 1.9 or later 
56 | * Opencv 3.4.0 or later 
57 | * NVIDIA GPU with compute capability 3.5+ 
58 | #### Quick Start 
59 | Given an RGB image containing the whole workspace, we can obtain the affordance map using our pre-trained model by: 
60 | 1. Download our pre-trained model: 
61 | ```bash 
62 | ./download.sh 
63 | ``` 
64 | 2. Run the test script: 
65 | ```bash 
66 | python metagrasp_test.py \ 
67 | --color_dir data/test/test1.png \ 
68 | --output_dir data/output \ 
69 | --checkpoint_dir models/metagrasp 
70 | ``` 
71 | The visual results are saved in the output folder. 
72 | 
73 | #### Training 
74 | To train a new model, follow the steps below: 
75 | 
76 | **Step 1: Collect the dataset** 
77 | You can collect a grasp dataset in V-REP using our data collection system; 
78 | the details can be found in ```data_collection/README.md```. 
79 | 
80 | **Step 2: Preprocess the data** 
81 | After data collection, you should preprocess the data so that it can be converted to tfrecords. Run the script: 
82 | ```bash 
83 | python data_processing/data_preprocessing.py \ 
84 | --data_dir data/3dnet \ 
85 | --output data/training 
86 | ``` 
87 | Then use ```data_processing/create_tf_record.py``` to generate positive and negative tfrecords. 
88 | ```bash 
89 | python data_processing/create_tf_record.py \ 
90 | --data_path data/training \ 
91 | --output data/tfrecords 
92 | ``` 
93 | **Step 3: Download pre-trained model** 
94 | 
95 | We use [Resnet](https://arxiv.org/abs/1512.03385) weights to initialize the feature extractor of the model; 
96 | the pre-trained checkpoint is available [here](http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz). 
97 | ```bash 
98 | mkdir models/checkpoints && cd models/checkpoints 
99 | wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz 
100 | tar -zxvf resnet_v1_50_2016_08_28.tar.gz 
101 | cd ../.. 
102 | ``` 
103 | **Step 4: Train the model** 
104 | 
105 | Then you can train the model with the script ```metagrasp_train.py```. 
106 | For basic usage the command will look something like this: 
107 | ```bash 
108 | python metagrasp_train.py \ 
109 | --train_log_dir models/logs_metagrasp \ 
110 | --dataset_dir data/tfrecords \ 
111 | --checkpoint_dir models/checkpoints 
112 | ``` 
113 | You can visualize the training results with Tensorboard using the following command: 
114 | ```bash 
115 | tensorboard --logdir=models/logs_metagrasp --host=localhost --port=6666 
116 | ``` 
117 | 
118 | #### Data format 
119 | 
120 | In this section, we first introduce the data structure of a grasp sample, 
121 | then we illustrate how to preprocess the sample so that we can obtain 
122 | horizontally antipodal grasp data. 
123 | 
124 | **1. Data structure** 
125 | 
126 | A grasp sample consists of 
127 | 1\) the RGB image which contains the whole workspace, 
128 | 2\) the grasp point represented by image coordinates in image space, 
129 | 3\) the grasp angle, and 
130 | 4\) the grasp label with respect to the grasp point and the grasp angle. 
131 | Note that we only assign labels to the pixels where we perform grasp trials, 
132 | so the other regions of the grasped object are not used during training. 
133 | The visualization of a sample is shown below. 
134 | <img src="doc/000007_0000_color.png"> 
135 | <img src="doc/000007_0000_depth.png"> 
136 | 
137 | Fig.2. Raw image captured from vision sensor. 
138 | 
139 | <img src="doc/000007_0000_height_map_color.png"> 
140 | <img src="doc/000007_0000_height_map_depth.png"> 
141 | <img src="doc/000007_0000_label.png"> 
142 | 
143 | 
144 | Fig.3. Cropped images and grasp label. 
145 | 
146 | 
147 | **2. Data after preprocessing** 
148 | 
149 | During preprocessing, we first rotate the RGB image and select the rotated images 
150 | that satisfy the horizontally antipodal grasp rule. Then we augment the data by 
151 | slightly rotating the images and the corresponding grasp angles. The whole process is shown in 
152 | Fig.4. 
154 | 155 |
156 |
Fig.4. Cropped images and grasp label.
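As a small, self-contained illustration of the angle convention used during preprocessing and at test time (16 discrete bins of 22.5 degrees, cf. `DataProcessor.get_grasp_angle` and the 16 rotations in `metagrasp_test.py`), the sketch below maps a continuous grasp angle to a bin; the helper function is ours, not part of the repository.
```python
import numpy as np

def angle_to_bin(angle_rad):
    """Map a grasp angle in radians to one of the 16 bins of 22.5 degrees."""
    angle_rad = angle_rad % (2 * np.pi)
    return int(round(angle_rad / np.deg2rad(22.5))) % 16

print(angle_to_bin(np.deg2rad(45.0)))          # -> 2
print(angle_to_bin(np.deg2rad(45.0) + np.pi))  # -> 10, the flipped (antipodal) grasp direction
```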
157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /data_collection/ur.py: -------------------------------------------------------------------------------- 1 | from functools import partial, reduce 2 | from math import sqrt 3 | import time 4 | 5 | import data_collection.vrep as vrep 6 | 7 | 8 | class UR5(object): 9 | def __init__(self, client_id): 10 | """ 11 | Initialization. 12 | :param client_id: An int. The client ID. refer to simxStart. 13 | """ 14 | self.client_id = client_id 15 | self.joint_handles = [vrep.simxGetObjectHandle(self.client_id, 16 | 'UR5_joint{}'.format(i), 17 | vrep.simx_opmode_blocking)[1] for i in range(1, 6+1)] 18 | _, self.ik_tip_handle = vrep.simxGetObjectHandle(self.client_id, 19 | 'UR5_ik_tip', 20 | vrep.simx_opmode_blocking) 21 | _, self.ik_target_handle = vrep.simxGetObjectHandle(self.client_id, 22 | 'UR5_ik_target', 23 | vrep.simx_opmode_blocking) 24 | self.func = partial(vrep.simxCallScriptFunction, 25 | clientID=self.client_id, 26 | scriptDescription="UR5", 27 | options=vrep.sim_scripttype_childscript, 28 | operationMode=vrep.simx_opmode_blocking) 29 | self.initialization() 30 | 31 | def initialization(self): 32 | """ 33 | Call script function pyInit in vrep child script. 34 | :return: None 35 | """ 36 | _ = self.func(functionName='pyInit', 37 | inputInts=[], 38 | inputFloats=[], 39 | inputStrings=[], 40 | inputBuffer='') 41 | time.sleep(0.7) 42 | 43 | def wait_until_stop(self, handle, threshold=0.01): 44 | """ 45 | Wait until the operation finishes. 46 | This is a delay function called in order to make sure that 47 | the operation executed has been completed. 48 | :param handle: An int.Handle of the object. 49 | :param threshold: A float. The object position threshold. 50 | If the object positions difference between two time steps is smaller than the threshold, 51 | the execution completes, otherwise the loop continues. 52 | :return: None 53 | """ 54 | while True: 55 | _, pos1 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking) 56 | _, quat1 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking) 57 | time.sleep(0.7) 58 | _, pos2 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking) 59 | _, quat2 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking) 60 | pose1 = pos1 + quat1 61 | pose2 = pos2 + quat2 62 | theta = 0.5 * sqrt(reduce(lambda x, y: x + y, map(lambda x, y: (x - y) ** 2, pose1, pose2))) 63 | if theta < threshold: 64 | return 65 | 66 | def enable_ik(self, enable=0): 67 | """ 68 | Call script function pyEnableIk in vrep child script. 69 | :param enable: A int. Whether to enable inverse kinematic. 70 | If enable = 1 the ur5 enables ik, and vice versa. 71 | :return: None 72 | """ 73 | _ = self.func(functionName='pyEnableIk', 74 | inputInts=[enable], 75 | inputFloats=[], 76 | inputStrings=[], 77 | inputBuffer='') 78 | 79 | def move_to_joint_position(self, joint_angles): 80 | """ 81 | Moves (actuates) several joints at the same time using the Reflexxes Motion Library type II or IV. 82 | Call script function pyMoveToJointPositions in vrep child script. 83 | :param joint_angles: A list of floats. The desired target angle positions of the joints. 
84 | :return: None 85 | """ 86 | self.enable_ik(0) 87 | _ = self.func(functionName='pyMoveToJointPositions', 88 | inputInts=[], 89 | inputFloats=joint_angles, 90 | inputStrings=[], 91 | inputBuffer='') 92 | self.wait_until_stop(self.ik_target_handle) 93 | time.sleep(0.7) # todo: remove manual time delay 94 | 95 | def move_to_object_position(self, pose): 96 | """ 97 | Moves an object to a given position and/or orientation using Reflexxes Motion Library type II or IV. 98 | Call script function pyMoveToPosition in vrep child script. 99 | :param pose: A list of floats. The desired target position of the object. 100 | :return: None 101 | """ 102 | self.enable_ik(1) 103 | _ = self.func(functionName='pyMoveToPosition', 104 | inputInts=[], 105 | inputFloats=pose, 106 | inputStrings=[], 107 | inputBuffer='') 108 | self.wait_until_stop(self.ik_target_handle) 109 | 110 | def get_end_effector_position(self): 111 | """ 112 | Get end effector position. 113 | :return: pos: A list of floats. The position of end effector. 114 | """ 115 | _, pos = vrep.simxGetObjectPosition(self.client_id, 116 | self.ik_tip_handle, 117 | -1, 118 | vrep.simx_opmode_blocking) 119 | return pos 120 | 121 | def get_end_effector_quaternion(self): 122 | """ 123 | Get end effector quaternion. 124 | :return: quat: A list of floats. The angle of end effector represented by quaternion. 125 | """ 126 | # _, quat = vrep.simxGetObjectQuaternion(self.client_id, 127 | # self.ik_tip_handle, 128 | # -1, 129 | # vrep.simx_opmode_blocking) 130 | _, _, quat, _, _ = self.func(functionName='pyGetObjectQuaternion', 131 | inputInts=[self.ik_tip_handle, -1], 132 | inputFloats=[], 133 | inputStrings=[], 134 | inputBuffer='') 135 | return quat 136 | 137 | def get_joint_positions(self, is_first_call=False): 138 | """ 139 | Get joint angle positions. 140 | :param is_first_call: A boolean. Specify which operation mode vrep api function chooses. 141 | :return: joint_positions: A list of float. Return value of UR joint angles. 
142 | """ 143 | if is_first_call: 144 | opmode = vrep.simx_opmode_streaming 145 | else: 146 | opmode = vrep.simx_opmode_blocking 147 | joint_positions = [vrep.simxGetJointPosition(self.client_id, 148 | joint_handle, 149 | opmode)[1] for joint_handle in self.joint_handles] 150 | return joint_positions 151 | 152 | 153 | -------------------------------------------------------------------------------- /data_collection/grasp_trials_without_rule.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | from scipy import misc 3 | from data_collection.scene import Scene 4 | from functools import partial 5 | import numpy as np 6 | import cv2 7 | import argparse 8 | import os 9 | 10 | parser = argparse.ArgumentParser(description='Data collection in vrep simulator.') 11 | parser.add_argument('--ip', 12 | default='127.0.0.1', 13 | type=str, 14 | help='ip address to the vrep simulator.') 15 | parser.add_argument('--port', 16 | default=19997, 17 | type=str, 18 | help='port to the vrep simulator.') 19 | parser.add_argument('--obj_id', 20 | default='obj0', 21 | type=str, 22 | help='object name in vrep.') 23 | parser.add_argument('--num_grasp', 24 | default=200, 25 | type=int, 26 | help='the number of grasp trails.') 27 | parser.add_argument('--num_repeat', 28 | default=1, 29 | type=int, 30 | help='the number of repeat time if the gripper successfully grasp an object.') 31 | parser.add_argument('--output', 32 | default='data/3dnet', 33 | type=str, 34 | help='directory to save data.') 35 | args = parser.parse_args() 36 | 37 | step = pi / 180.0 38 | ratio = 0.4 / 200 39 | interval = [0.23, 0.57] 40 | rand = partial(np.random.uniform, interval[0], interval[1]) 41 | 42 | 43 | def main(): 44 | if not os.path.exists(args.output): 45 | os.makedirs(args.output) 46 | data_id = len(os.listdir(args.output)) 47 | data_path = os.path.join(args.output, '{:04d}'.format(data_id)) 48 | color_path = os.path.join(data_path, 'color') 49 | depth_path = os.path.join(data_path, 'depth') 50 | height_map_color_path = os.path.join(data_path, 'height_map_color') 51 | height_map_depth_path = os.path.join(data_path, 'height_map_depth') 52 | label_path = os.path.join(data_path, 'label') 53 | os.mkdir(data_path) 54 | os.mkdir(color_path) 55 | os.mkdir(depth_path) 56 | os.mkdir(height_map_color_path) 57 | os.mkdir(height_map_depth_path) 58 | os.mkdir(label_path) 59 | 60 | s = Scene(args.ip, args.port, 'realsense') 61 | init_pos = [45*step, 10*step, 90*step, -10*step, -90*step, -45*step] 62 | top_pos = [-45*step, 10*step, 90*step, -10*step, -90*step, -45*step] 63 | # object handle 64 | obj = s.get_object_handle(args.obj_id) 65 | obj_quat = s.get_object_quaternion(obj) 66 | s.gripper.open_gripper() 67 | 68 | misc.imsave(data_path + '/background_color.png', s.background_color) 69 | cv2.imwrite(data_path + '/background_depth.png', s.background_depth) 70 | misc.imsave(data_path+'/crop_background_color.png', 71 | cv2.resize(s.background_color[45:435, 119:509, :], (200, 200))) 72 | cv2.imwrite(data_path+'/crop_background_depth.png', 73 | cv2.resize(s.background_depth[45:435, 119:509], (200, 200))) 74 | 75 | f = open(os.path.join(data_path, 'file_name.txt'), 'w') 76 | 77 | s.set_object_position(obj, [rand(), rand(), s.get_object_position(obj)[-1]]) 78 | 79 | for i in range(args.num_grasp): # number of grasp trials. 
80 | print(i) 81 | s.gripper.open_gripper() 82 | s.ur5.move_to_joint_position(init_pos) 83 | s.replace_object(obj, obj_quat, interval) 84 | color = s.get_color_image() 85 | depth = s.get_depth_image() 86 | center_point, object_points, reaching_height = s.get_center_from_image(depth) 87 | s.ur5.move_to_joint_position(top_pos) 88 | episode = [] 89 | # move to the upward side with respect to the object 90 | init_quat = s.ur5.get_end_effector_quaternion() 91 | target_pos = s.get_object_position(obj) 92 | target_pos[0] = 0.200 + center_point[0] * ratio 93 | target_pos[1] = 0.200 + center_point[1] * ratio 94 | target_pos[2] = reaching_height 95 | target_pos[-1] += 0.1563 96 | episode.append(target_pos + init_quat) 97 | s.ur5.move_to_object_position(target_pos + init_quat) 98 | # randomly rotate gripper's angle 99 | joint_angles = s.ur5.get_joint_positions() 100 | curr_angle = joint_angles[-1] / np.pi * 180.0 101 | angle = np.random.randint(0, 180) 102 | joint_angles[-1] = (curr_angle + angle) * step 103 | s.ur5.move_to_joint_position(joint_angles) 104 | quat = s.ur5.get_end_effector_quaternion() 105 | episode.append(target_pos + quat) 106 | # move to the object downward 107 | target_pos[-1] -= 0.1563 108 | episode.append(target_pos + quat) 109 | s.ur5.move_to_object_position(target_pos + quat) 110 | # try to grasp object 111 | s.gripper.close_gripper() 112 | # determine whether the object is successfully grasped. 113 | if s.gripper.get_object_detection_status(): 114 | print('grasp success.') 115 | width = s.gripper.get_gripper_width() 116 | s.gripper.open_gripper() 117 | s.ur5.move_to_object_position(episode[1]) 118 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200)) 119 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200)) 120 | points = s.get_label(center_point, -angle * np.pi / 180.0, width, ratio) 121 | label = s.draw_label(points, 200, 200) 122 | misc.imsave(color_path + '/{:06d}.png'.format(i), color) 123 | cv2.imwrite(depth_path + '/{:06d}.png'.format(i), depth) 124 | misc.imsave(height_map_color_path + '/{:06d}.png'.format(i), crop_color) 125 | cv2.imwrite(height_map_depth_path + '/{:06d}.png'.format(i), crop_depth) 126 | misc.imsave(label_path + '/{:06d}.png'.format(i), label) 127 | np.savetxt(label_path + '/{:06d}.good.txt'.format(i), points) 128 | np.savetxt(label_path + '/{:06d}.object_points.txt'.format(i), np.asanyarray(object_points)) 129 | f.write('{:06d}\n'.format(i)) 130 | 131 | else: 132 | print('grasp failed') 133 | s.gripper.open_gripper() 134 | s.ur5.move_to_object_position(episode[1]) 135 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200)) 136 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200)) 137 | points = s.get_label(center_point, -angle * np.pi / 180.0, 0.085, ratio) 138 | label = s.draw_label(points, 200, 200, (255, 0, 0)) 139 | misc.imsave(color_path + '/{:06d}.png'.format(i), color) 140 | cv2.imwrite(depth_path + '/{:06d}.png'.format(i), depth) 141 | misc.imsave(height_map_color_path + '/{:06d}.png'.format(i), crop_color) 142 | cv2.imwrite(height_map_depth_path + '/{:06d}.png'.format(i), crop_depth) 143 | misc.imsave(label_path + '/{:06d}.png'.format(i), label) 144 | np.savetxt(label_path + '/{:06d}.bad.txt'.format(i), points) 145 | np.savetxt(label_path + '/{:06d}.object_points.txt'.format(i), np.asanyarray(object_points)) 146 | f.write('{:06d}\n'.format(i)) 147 | 148 | s.stop_simulation() 149 | f.close() 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | 
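For reference, the script above converts a pixel in the 200x200 height map to a workspace position with `ratio = 0.4 / 200` metres per pixel and a 0.200 m offset (see `target_pos[0] = 0.200 + center_point[0] * ratio`). A minimal stand-alone restatement of that conversion (the helper name is ours, not part of the repository) might look like this:
```python
ratio = 0.4 / 200.0   # the 200-pixel crop spans a 0.4 m workspace

def pixel_to_workspace(row, col, offset=0.200):
    """Convert a height-map pixel (row, col) to a workspace (x, y) in metres,
    mirroring target_pos = 0.200 + center_point * ratio in the script above."""
    return offset + row * ratio, offset + col * ratio

print(pixel_to_workspace(100, 100))   # centre of the crop -> (0.4, 0.4)
```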
-------------------------------------------------------------------------------- /data_processing/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def int64_feature(value): 7 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 8 | 9 | 10 | def int64_list_feature(value): 11 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 12 | 13 | 14 | def bytes_feature(value): 15 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 16 | 17 | 18 | def bytes_list_feature(value): 19 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 20 | 21 | 22 | def float_feature(value): 23 | return tf.train.Feature(float_list=tf.train.FloatList(value=[value])) 24 | 25 | 26 | def float_list_feature(value): 27 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 28 | 29 | 30 | def read_examples_list(path): 31 | with tf.gfile.GFile(path) as fid: 32 | lines = fid.readlines() 33 | return [line.strip().split(',') for line in lines] 34 | 35 | 36 | class DataInfo(object): 37 | camera_height = 0.57 # camera height, 0.57m 38 | original_image_height = 480 # the height of raw image captured from the camera 39 | original_image_width = 640 # the width of raw image captured from the camera 40 | top_left_corner_h = 45 # the first height index of the original image from which the crop image is cropped. 41 | top_left_corner_w = 119 # the first width index of the original image from which the crop image is cropped. 42 | crop_image_size = 390 # crop image size 43 | resize_image_size = 200 # resize image size 44 | 45 | 46 | class RescaleData(object): 47 | def __init__(self, color, depth, background_depth, grasp_labels, info=DataInfo()): 48 | self.info = info 49 | 50 | self.color = color 51 | self.diff_depth = background_depth - depth 52 | self.grasp_labels = grasp_labels 53 | self.grasp_center = self.get_grasp_center(self.grasp_labels) 54 | self.grasp_angle = self.get_grasp_angle(self.grasp_labels[0]) 55 | self.gripper_width = self.get_gripper_width(self.grasp_labels[0]) 56 | 57 | def get_zoomed_data(self, factor): 58 | zoomed_size = int(self.info.crop_image_size * factor) 59 | keep = zoomed_size / 7 60 | half = zoomed_size / 2 61 | top, down, left, right = 1, 0, 1, 0 62 | keep1 = keep2 = keep 63 | while top > down: 64 | top = self.grasp_center[0] - half + keep1 if self.grasp_center[0] > zoomed_size else half 65 | down = self.grasp_center[0] + half - keep1 if self.info.original_image_height - self.grasp_center[0] > zoomed_size else self.info.original_image_height - half - zoomed_size % 2 66 | keep1 /= 2 67 | while left > right: 68 | left = self.grasp_center[1] - half + keep2 if self.grasp_center[1] > zoomed_size else half 69 | right = self.grasp_center[1] + half - keep2 if self.info.original_image_width - self.grasp_center[1] > zoomed_size else self.info.original_image_width - half - zoomed_size % 2 70 | keep2 /= 2 71 | image_center = np.array([np.random.randint(top, down+1, dtype=int), np.random.randint(left, right+1 72 | , dtype=int)]) 73 | crop_color, crop_depth, object_points = self.zoom(image_center, zoomed_size, factor) 74 | grasp_point = self.grasp_center - (image_center - half) 75 | grasp_point = (grasp_point * 200.0 / zoomed_size).astype(np.int) 76 | label_points = self.get_label(grasp_point, self.grasp_angle, self.gripper_width, factor) 77 | label = self.draw_label(label_points, 200, 200) 78 | return crop_color, 
crop_depth, label, label_points, object_points 79 | 80 | def zoom(self, center, zoomed_size, factor): 81 | half = zoomed_size / 2 82 | background_depth = int(self.info.camera_height * 1000 * factor) 83 | crop_color = self.color[center[0] - half:center[0] + half + zoomed_size % 2, center[1] - half:center[1] + half + zoomed_size % 2, :] 84 | crop_diff_depth = self.diff_depth[center[0] - half:center[0] + half + zoomed_size % 2, center[1] - half:center[1] + half + zoomed_size % 2] 85 | crop_depth = background_depth - crop_diff_depth 86 | crop_color = cv2.resize(crop_color, (200, 200)) 87 | crop_depth = cv2.resize(crop_depth, (200, 200)) 88 | points = np.stack(np.where(cv2.resize(crop_diff_depth, (200, 200)) > 0), axis=0).T 89 | 90 | return crop_color, crop_depth, points 91 | 92 | def get_grasp_center(self, grasp_labels): 93 | row, _ = grasp_labels.shape 94 | grasp_center = np.empty((row, 2), dtype=np.float32) 95 | grasp_center[:, 0] = (grasp_labels[:, 0] + grasp_labels[:, 2]) / 2.0 96 | grasp_center[:, 1] = (grasp_labels[:, 1] + grasp_labels[:, 3]) / 2.0 97 | grasp_center = 1.0 * grasp_center * self.info.crop_image_size / 200 + np.array([self.info.top_left_corner_h, self.info.top_left_corner_w]) 98 | return grasp_center.mean(axis=0).astype(np.int) 99 | 100 | @staticmethod 101 | def get_label(center_point, angle, width, factor): 102 | """ 103 | Obtain grasp labels. Each grasp label is represented as two points. 104 | :param center_point: A numpy array of float32. The position of the object which the gripper will reach. 105 | :param angle: A float. The rotated angle of the gripper. 106 | :param width: A float. The width between two finger tips. 107 | :return: A numpy array of float. The grasp points. 108 | """ 109 | width = int(width / factor) 110 | tip_width = int(7 / factor) 111 | c_h = center_point[0] 112 | c_w = center_point[1] 113 | h = [delta_h for delta_h in range(-tip_width/2, tip_width/2)] 114 | w = [-width / 2, width / 2] 115 | points = np.asanyarray([[hh, w[0], hh, w[1]] for hh in h]) 116 | rotate_matrix = np.array([[np.cos(angle), -np.sin(angle)], 117 | [np.sin(angle), np.cos(angle)]]) 118 | points[:, 0:2] = np.dot(rotate_matrix, points[:, 0:2].T).T 119 | points[:, 2:] = np.dot(rotate_matrix, points[:, 2:].T).T 120 | points = points + np.asanyarray([[c_h, c_w, c_h, c_w]]) 121 | points = np.floor(points).astype(np.int) 122 | return points 123 | 124 | @staticmethod 125 | def get_grasp_angle(grasp_label): 126 | """ 127 | Get grasp angle. 128 | :param grasp_label: A numpy array with shape [1, 4] which represents 129 | a grasp label formulated as 2 points (i.e., (x1, y1, x2, y2)). 130 | :return: angle_indices: A list of int with length 2. The discretized angles ranged from 0 to 15. 131 | Besides the original grasp angle, the list contains another angle flipped vertically by original one. 132 | """ 133 | pt1 = grasp_label[0:2] 134 | pt2 = grasp_label[2:] 135 | angle = np.arctan2(pt2[0] - pt1[0], pt2[1] - pt1[1]) 136 | return -angle 137 | 138 | @staticmethod 139 | def get_gripper_width(grasp_label): 140 | pt1 = grasp_label[0:2] 141 | pt2 = grasp_label[2:] 142 | gripper_width = np.sqrt(np.sum(np.power(pt1-pt2, 2))) 143 | return gripper_width 144 | 145 | @staticmethod 146 | def draw_label(points, width, height, color=(255, 0, 0)): 147 | """ 148 | Draw labels according to the grasp points. 149 | :param points: A numpy array of float. The grasp points. 150 | :param width: A float. The width of image. 151 | :param height: A float. The height of image. 152 | :param color: A tuple. 
The color of lines. 153 | :return: A numpy array of uint8. The label map. 154 | """ 155 | label = np.ones((height, width, 3), dtype=np.uint8) * 255 156 | for point in points: 157 | pt1 = (point[1], point[0]) 158 | pt2 = (point[3], point[2]) 159 | cv2.line(label, pt1, pt2, color) 160 | return label 161 | -------------------------------------------------------------------------------- /metagrasp_train.py: -------------------------------------------------------------------------------- 1 | from hparams import create_metagrasp_hparams 2 | from network_utils import * 3 | from models import metagrasp 4 | from losses import * 5 | 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | 9 | import argparse 10 | import os 11 | os.environ["CUDA_VISIBLE_DEVICES"] = '1' 12 | 13 | parser = argparse.ArgumentParser(description='network training') 14 | parser.add_argument('--master', 15 | default='', 16 | type=str, 17 | help='BNS name of the TensorFlow master to use') 18 | parser.add_argument('--task_id', 19 | default=0, 20 | type=int, 21 | help='The Task ID. This value is used when training with multiple workers to identify each worker.') 22 | parser.add_argument('--train_log_dir', 23 | default='models/metagrasp', 24 | type=str, 25 | help='Directory where to write event models.') 26 | parser.add_argument('--dataset_dir', 27 | default='data/tfrecords', 28 | type=str, 29 | help='The directory where the datasets can be found.') 30 | parser.add_argument('--save_summaries_steps', 31 | default=120, 32 | type=int, 33 | help='The frequency with which summaries are saved, in seconds.') 34 | parser.add_argument('--save_interval_secs', 35 | default=600, 36 | type=int, 37 | help='The frequency with which the model is saved, in seconds.') 38 | parser.add_argument('--print_loss_steps', 39 | default=100, 40 | type=int, 41 | help='The frequency with which the losses are printed, in steps.') 42 | parser.add_argument('--num_readers', 43 | default=2, 44 | type=int, 45 | help='The number of parallel readers that read data from the dataset.') 46 | parser.add_argument('--num_steps', 47 | default=200000, 48 | type=int, 49 | help='The max number of gradient steps to take during training.') 50 | parser.add_argument('--num_preprocessing_threads', 51 | default=4, 52 | type=int, 53 | help='The number of threads used to create the batches.') 54 | parser.add_argument('--from_metagrasp_checkpoint', 55 | default=False, 56 | type=bool, 57 | help='load checkpoint from metagrasp checkpoint or classification checkpoint.') 58 | parser.add_argument('--checkpoint_dir', 59 | default='', 60 | type=str, 61 | help='The directory where the checkpoint can be found') 62 | args = parser.parse_args() 63 | 64 | 65 | def main(): 66 | tf.logging.set_verbosity(tf.logging.INFO) 67 | h = create_metagrasp_hparams() 68 | for path in [args.train_log_dir]: 69 | if not tf.gfile.Exists(path): 70 | tf.gfile.MakeDirs(path) 71 | hparams_filename = os.path.join(args.train_log_dir, 'hparams.json') 72 | with tf.gfile.FastGFile(hparams_filename, 'w') as f: 73 | f.write(h.to_json()) 74 | with tf.Graph().as_default(): 75 | with tf.device(tf.train.replica_device_setter(args.task_id)): 76 | global_step = tf.train.get_or_create_global_step() 77 | colors_p, labels_p = get_color_dataset(args.dataset_dir + '/pos', 78 | args.num_readers, 79 | args.num_preprocessing_threads, 80 | h.image_size, 81 | h.label_size, 82 | int(h.batch_size/2)) 83 | colors_n, labels_n = get_color_dataset(args.dataset_dir + '/neg', 84 | args.num_readers, 85 | 
args.num_preprocessing_threads, 86 | h.image_size, 87 | h.label_size, 88 | int(h.batch_size/2)) 89 | colors = tf.concat([colors_p, colors_n], axis=0) 90 | labels = tf.concat([labels_p, labels_n], axis=0) 91 | net, end_points = metagrasp(colors, 92 | num_classes=3, 93 | num_channels=1000, 94 | is_training=True, 95 | global_pool=False, 96 | output_stride=16, 97 | spatial_squeeze=False, 98 | scope=h.scope) 99 | loss = create_loss_with_label_mask(net, labels, h.lamb) 100 | learning_rate = h.learning_rate 101 | if h.lr_decay_step: 102 | learning_rate = tf.train.exponential_decay(h.learning_rate, 103 | tf.train.get_or_create_global_step(), 104 | decay_steps=h.lr_decay_step, 105 | decay_rate=h.lr_decay_rate, 106 | staircase=True) 107 | tf.summary.scalar('Learning_rate', learning_rate) 108 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 109 | train_op = slim.learning.create_train_op(loss, optimizer) 110 | add_summary(colors, labels, end_points, loss, h) 111 | summary_op = tf.summary.merge_all() 112 | if not args.from_metagrasp_checkpoint: 113 | variable_map = restore_from_classification_checkpoint( 114 | scope=h.scope, 115 | model_name=h.model_name, 116 | checkpoint_exclude_scopes=['prediction']) 117 | init_saver = tf.train.Saver(variable_map) 118 | 119 | def initializer_fn(sess): 120 | init_saver.restore(sess, os.path.join(args.checkpoint_dir, h.model_name+'.ckpt')) 121 | tf.logging.info('Successfully load pretrained checkpoint.') 122 | 123 | init_fn = initializer_fn 124 | 125 | else: 126 | variable_map = restore_map() 127 | init_saver = tf.train.Saver(variable_map) 128 | 129 | def initializer_fn(sess): 130 | init_saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir)) 131 | tf.logging.info('Successfully load pretrained checkpoint.') 132 | 133 | init_fn = initializer_fn 134 | 135 | session_config = tf.ConfigProto(allow_soft_placement=True, 136 | log_device_placement=False) 137 | session_config.gpu_options.allow_growth = True 138 | saver = tf.train.Saver(keep_checkpoint_every_n_hours=args.save_interval_secs, 139 | max_to_keep=100) 140 | 141 | slim.learning.train(train_op, 142 | logdir=args.train_log_dir, 143 | master=args.master, 144 | global_step=global_step, 145 | session_config=session_config, 146 | init_fn=init_fn, 147 | summary_op=summary_op, 148 | number_of_steps=args.num_steps, 149 | startup_delay_steps=15, 150 | save_summaries_secs=args.save_summaries_steps, 151 | saver=saver) 152 | 153 | 154 | if __name__ == '__main__': 155 | main() 156 | -------------------------------------------------------------------------------- /mag_train.py: -------------------------------------------------------------------------------- 1 | from hparams import create_mag_hparams 2 | from network_utils import * 3 | from models import mag 4 | from losses import * 5 | 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | 9 | import argparse 10 | import os 11 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 12 | 13 | parser = argparse.ArgumentParser(description='network training') 14 | parser.add_argument('--master', 15 | default='', 16 | type=str, 17 | help='BNS name of the TensorFlow master to use') 18 | parser.add_argument('--task_id', 19 | default=0, 20 | type=int, 21 | help='The Task ID. 
This value is used when training with multiple workers to identify each worker.') 22 | parser.add_argument('--train_log_dir', 23 | default='models/', 24 | type=str, 25 | help='Directory where to write event models.') 26 | parser.add_argument('--save_summaries_steps', 27 | default=120, 28 | type=int, 29 | help='The frequency with which summaries are saved, in seconds.') 30 | parser.add_argument('--save_interval_secs', 31 | default=600, 32 | type=int, 33 | help='The frequency with which the model is saved, in seconds.') 34 | parser.add_argument('--print_loss_steps', 35 | default=100, 36 | type=int, 37 | help='The frequency with which the losses are printed, in steps.') 38 | parser.add_argument('--dataset_dir', 39 | default='', 40 | type=str, 41 | help='The directory where the datasets can be found.') 42 | parser.add_argument('--num_readers', 43 | default=2, 44 | type=int, 45 | help='The number of parallel readers that read data from the dataset.') 46 | parser.add_argument('--num_steps', 47 | default=200000, 48 | type=int, 49 | help='The max number of gradient steps to take during training.') 50 | parser.add_argument('--num_preprocessing_threads', 51 | default=4, 52 | type=int, 53 | help='The number of threads used to create the batches.') 54 | parser.add_argument('--from_mag_checkpoint', 55 | default=False, 56 | type=bool, 57 | help='Load checkpoint from mag checkpoint or classification checkpoint.') 58 | parser.add_argument('--checkpoint_dir', 59 | default='', 60 | type=str, 61 | help='The directory where the checkpoint can be found') 62 | args = parser.parse_args() 63 | 64 | 65 | def main(): 66 | tf.logging.set_verbosity(tf.logging.INFO) 67 | h = create_mag_hparams() 68 | for path in [args.train_log_dir]: 69 | if not tf.gfile.Exists(path): 70 | tf.gfile.MakeDirs(path) 71 | hparams_filename = os.path.join(args.train_log_dir, 'hparams.json') 72 | with tf.gfile.FastGFile(hparams_filename, 'w') as f: 73 | f.write(h.to_json()) 74 | with tf.Graph().as_default(): 75 | with tf.device(tf.train.replica_device_setter(args.task_id)): 76 | global_step = tf.train.get_or_create_global_step() 77 | colors_p, labels_p = get_color_dataset(args.dataset_dir+'/pos', 78 | args.num_readers, 79 | args.num_preprocessing_threads, 80 | h.image_size, 81 | h.label_size, 82 | int(h.batch_size/2)) 83 | colors_n, labels_n = get_color_dataset(args.dataset_dir + '/neg', 84 | args.num_readers, 85 | args.num_preprocessing_threads, 86 | h.image_size, 87 | h.label_size, 88 | int(h.batch_size/2)) 89 | colors = tf.concat([colors_p, colors_n], axis=0) 90 | labels = tf.concat([labels_p, labels_n], axis=0) 91 | net, end_points = mag(colors, 92 | num_classes=3, 93 | num_channels=1000, 94 | is_training=True, 95 | global_pool=False, 96 | output_stride=16, 97 | upsample_ratio=2, 98 | spatial_squeeze=False, 99 | scope=h.scope) 100 | loss = create_loss_with_label_mask(net, labels, h.lamb) 101 | # loss = create_loss_without_background(net, labels) 102 | learning_rate = h.learning_rate 103 | if h.lr_decay_step: 104 | learning_rate = tf.train.exponential_decay(h.learning_rate, 105 | tf.train.get_or_create_global_step(), 106 | decay_steps=h.lr_decay_step, 107 | decay_rate=h.lr_decay_rate, 108 | staircase=True) 109 | tf.summary.scalar('Learning_rate', learning_rate) 110 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 111 | # optimizer = tf.train.MomentumOptimizer(0.001, momentum=h.momentum) 112 | train_op = slim.learning.create_train_op(loss, optimizer) 113 | add_summary(colors, labels, end_points, loss, h) 114 | summary_op = 
tf.summary.merge_all() 115 | if not args.from_mag_checkpoint: 116 | variable_map = restore_from_classification_checkpoint( 117 | scope=h.scope, 118 | model_name=h.model_name, 119 | checkpoint_exclude_scopes=['prediction']) 120 | init_saver = tf.train.Saver(variable_map) 121 | 122 | def initializer_fn(sess): 123 | init_saver.restore(sess, os.path.join(args.checkpoint_dir, h.model_name+'.ckpt')) 124 | tf.logging.info('Successfully load pretrained checkpoint.') 125 | 126 | init_fn = initializer_fn 127 | 128 | else: 129 | variable_map = restore_map() 130 | init_saver = tf.train.Saver(variable_map) 131 | 132 | def initializer_fn(sess): 133 | init_saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir)) 134 | tf.logging.info('Successfully load pretrained checkpoint.') 135 | 136 | init_fn = initializer_fn 137 | 138 | session_config = tf.ConfigProto(allow_soft_placement=True, 139 | log_device_placement=False) 140 | session_config.gpu_options.allow_growth = True 141 | saver = tf.train.Saver(keep_checkpoint_every_n_hours=args.save_interval_secs, 142 | max_to_keep=100) 143 | 144 | slim.learning.train(train_op, 145 | logdir=args.train_log_dir, 146 | master=args.master, 147 | global_step=global_step, 148 | session_config=session_config, 149 | init_fn=init_fn, 150 | summary_op=summary_op, 151 | number_of_steps=args.num_steps, 152 | startup_delay_steps=15, 153 | save_summaries_secs=args.save_summaries_steps, 154 | saver=saver) 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /data_collection/rg2.py: -------------------------------------------------------------------------------- 1 | from functools import partial, reduce 2 | from math import sqrt 3 | import data_collection.vrep as vrep 4 | import time 5 | 6 | 7 | class RG2(object): 8 | def __init__(self, client_id, script_name='UR5', motor_vel=0.11, motor_force=77): 9 | """ 10 | Initialization for RG2. 11 | :param client_id: An int. The client ID. refer to simxStart. 12 | :param script_name: A string. The lua script name in vrep. 13 | :param motor_vel: A float. Target velocity of the joint1. 14 | :param motor_force: The maximum force or torque that the joint can exert. 15 | """ 16 | self.client_id = client_id 17 | self.motor_vel = motor_vel 18 | self.motor_force = motor_force 19 | _, self.attach_point = vrep.simxGetObjectHandle(self.client_id, 20 | 'RG2_attachPoint', 21 | vrep.simx_opmode_blocking) 22 | _, self.proximity_sensor = vrep.simxGetObjectHandle(self.client_id, 23 | 'RG2_attachProxSensor', 24 | vrep.simx_opmode_blocking) 25 | _, self.right_touch = vrep.simxGetObjectHandle(self.client_id, 26 | 'RG2_rightTouch', 27 | vrep.simx_opmode_blocking) 28 | _, self.left_touch = vrep.simxGetObjectHandle(self.client_id, 29 | 'RG2_leftTouch', 30 | vrep.simx_opmode_blocking) 31 | _, self.joint = vrep.simxGetObjectHandle(self.client_id, 32 | 'RG2_openCloseJoint', 33 | vrep.simx_opmode_blocking) 34 | self.func = partial(vrep.simxCallScriptFunction, 35 | clientID=self.client_id, 36 | scriptDescription=script_name, 37 | options=vrep.sim_scripttype_childscript, 38 | operationMode=vrep.simx_opmode_blocking) 39 | 40 | def open_gripper(self): 41 | """ 42 | Open gripper. 43 | :return: None. 
44 | """ 45 | vrep.simxSetJointTargetVelocity(self.client_id, 46 | self.joint, 47 | self.motor_vel, 48 | vrep.simx_opmode_streaming) 49 | vrep.simxSetJointForce(self.client_id, 50 | self.joint, 51 | self.motor_force, 52 | vrep.simx_opmode_oneshot) 53 | self.wait_until_stop(self.right_touch) 54 | 55 | def close_gripper(self): 56 | """ 57 | Close gripper. 58 | :return: None 59 | """ 60 | vrep.simxSetJointTargetVelocity(self.client_id, 61 | self.joint, 62 | -self.motor_vel, 63 | vrep.simx_opmode_oneshot) 64 | vrep.simxSetJointForce(self.client_id, 65 | self.joint, 66 | self.motor_force, 67 | vrep.simx_opmode_oneshot) 68 | self.wait_until_stop(self.right_touch) 69 | 70 | def attach_object(self, object_handle): 71 | """ 72 | Attach object to the gripper. This is an alternative to grasp objects. 73 | :param object_handle: An int. The handle of object which is successfully grasped by gripper. 74 | :return: None 75 | """ 76 | vrep.simxSetObjectParent(self.client_id, 77 | object_handle, 78 | self.attach_point, 79 | True, 80 | vrep.simx_opmode_blocking) 81 | 82 | def untie_object(self, object_handle): 83 | """ 84 | Untie object to the gripper. This is an alternative to grasp objects. 85 | :param object_handle: An int. The handle of object which is attached to the gripper. 86 | :return: None 87 | """ 88 | vrep.simxSetObjectParent(self.client_id, 89 | object_handle, 90 | -1, 91 | True, 92 | vrep.simx_opmode_blocking) 93 | 94 | def get_object_detection_status(self, threshold=0.005): 95 | """ 96 | Detect whether gripper grasp object successfully. 97 | :param threshold: A float. When distance between two gripper tips is larger than threshold, 98 | gripper successfully grasp object, vice versa. 99 | :return: A boolean. Return object detection status. 100 | """ 101 | time.sleep(0.7) 102 | _, d_pos = vrep.simxGetObjectPosition(self.client_id, 103 | self.left_touch, 104 | self.right_touch, 105 | vrep.simx_opmode_blocking) 106 | if threshold < abs(d_pos[0]) < 0.085: 107 | half_touch = 0.01475 - 0.005 108 | # distance from proximity sensor to left touch. 109 | _, d_p2t_l = vrep.simxGetObjectPosition(self.client_id, 110 | self.left_touch, 111 | self.proximity_sensor, 112 | vrep.simx_opmode_blocking) 113 | # distance from proximity sensor to right touch. 114 | _, d_p2t_r = vrep.simxGetObjectPosition(self.client_id, 115 | self.right_touch, 116 | self.proximity_sensor, 117 | vrep.simx_opmode_blocking) 118 | _, distance = self.read_proximity_sensor() 119 | if distance < d_p2t_l[-1] + half_touch and distance < d_p2t_r[-1] + half_touch: 120 | _, d_l2r = vrep.simxGetObjectPosition(self.client_id, 121 | self.left_touch, 122 | self.right_touch, 123 | vrep.simx_opmode_blocking) 124 | diff = abs(d_l2r[-1]) 125 | print(diff) 126 | if diff < 0.001: 127 | return True 128 | return False 129 | 130 | def get_gripper_width(self): 131 | """ 132 | Get gripper width. 133 | :return: Gripper width. 134 | """ 135 | _, d_pos = vrep.simxGetObjectPosition(self.client_id, 136 | self.left_touch, 137 | self.right_touch, 138 | vrep.simx_opmode_blocking) 139 | return abs(d_pos[0]) 140 | 141 | def read_proximity_sensor(self): 142 | """ 143 | Read proximity sensor. 144 | :return: 145 | """ 146 | _, out_ints, out_floats, _, _ = self.func(functionName='pyReadProxSensor', 147 | inputInts=[], 148 | inputFloats=[], 149 | inputStrings=[], 150 | inputBuffer='') 151 | return out_ints[0], out_floats[0] 152 | 153 | def wait_until_stop(self, handle, threshold=0.005, time_delay=0.2): 154 | """ 155 | Wait until the operation finishes. 
156 | This is a delay function called in order to make sure that 157 | the operation executed has been completed. 158 | :param handle: An int.Handle of the object. 159 | :param threshold: A float. The object position threshold. 160 | If the object positions difference between two time steps is smaller than the threshold, 161 | the execution completes, otherwise the loop continues. 162 | :param time_delay:A float. How much time we should wait in an execution step. 163 | :return: None 164 | """ 165 | while True: 166 | _, pos1 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking) 167 | _, quat1 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking) 168 | time.sleep(time_delay) 169 | _, pos2 = vrep.simxGetObjectPosition(self.client_id, handle, -1, vrep.simx_opmode_blocking) 170 | _, quat2 = vrep.simxGetObjectQuaternion(self.client_id, handle, -1, vrep.simx_opmode_blocking) 171 | pose1 = pos1 + quat1 172 | pose2 = pos2 + quat2 173 | theta = 0.5 * sqrt(reduce(lambda x, y: x + y, map(lambda x, y: (x - y) ** 2, pose1, pose2))) 174 | if theta < threshold: 175 | return 176 | -------------------------------------------------------------------------------- /data_processing/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | from data_processing.data_processor import DataProcessor 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | import argparse 7 | import random 8 | import os 9 | 10 | parser = argparse.ArgumentParser(description='process data') 11 | parser.add_argument('--data_dir', 12 | required=True, 13 | type=str, 14 | help='path to the data') 15 | parser.add_argument('--output', 16 | required=True, 17 | type=str, 18 | help='path to save the process data') 19 | args = parser.parse_args() 20 | 21 | 22 | def main(): 23 | if not os.path.exists(args.output): 24 | os.mkdir(args.output) 25 | os.mkdir(os.path.join(args.output, 'color')) 26 | os.mkdir(os.path.join(args.output, 'depth')) 27 | os.mkdir(os.path.join(args.output, 'encoded_depth')) 28 | os.mkdir(os.path.join(args.output, 'label_map')) 29 | os.mkdir(os.path.join(args.output, 'camera_height')) 30 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w') 31 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w') 32 | dp = DataProcessor() 33 | 34 | # load data 35 | data_dir = os.listdir(args.data_dir) 36 | random.shuffle(data_dir) 37 | for data_id in data_dir: 38 | parent_dir = os.path.join(args.data_dir, data_id) 39 | print(parent_dir) 40 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'), 41 | # cv2.IMREAD_ANYDEPTH).astype(np.float32) 42 | label_files = os.listdir(os.path.join(parent_dir, 'label')) 43 | with open(os.path.join(parent_dir, 'file_name.txt'), 'r') as f: 44 | file_names = f.readlines() 45 | for file_name in file_names: 46 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1] 47 | print(file_name) 48 | color = cv2.imread(os.path.join(parent_dir, 'height_map_color', file_name+'.png')) 49 | depth = cv2.imread(os.path.join(parent_dir, 'height_map_depth', file_name+'.png'), 50 | cv2.IMREAD_ANYDEPTH).astype(np.float32) 51 | # diff_depth = dp.get_diff_depth(depth, b_depth_height_map) 52 | # pad color and depth images 53 | pad_size = 44 54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7 55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color 56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7 57 | pad_depth[pad_size:pad_size + 200, 
pad_size:pad_size + 200] = depth 58 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten())) 59 | camera_height = background_depth_value / 1000.0 60 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.object_points.txt')) + pad_size 61 | if file_name+'.good.txt' in label_files: 62 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.good.txt')) + pad_size 63 | grasp_centers = dp.get_grasp_center(good_pixel_labels) 64 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0]) 65 | for angle_idx in angle_indices: 66 | quantified_angle = 22.5 * angle_idx 67 | for i, angle in enumerate(np.arange(quantified_angle-5, quantified_angle+5, 1)): 68 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv 69 | grasp_label[..., 0] = 255 70 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 71 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int) 72 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 73 | grasp_label = cv2.medianBlur(grasp_label, 3) 74 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 75 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int) 76 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255 77 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 78 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 79 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 80 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 81 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 82 | encoded_depth = dp.encode_depth(rotated_depth) 83 | cv2.imwrite(os.path.join(args.output, 'label_map', 84 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 85 | grasp_label) 86 | cv2.imwrite(os.path.join(args.output, 'color', 87 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 88 | rotated_color) 89 | cv2.imwrite(os.path.join(args.output, 'depth', 90 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 91 | rotated_depth) 92 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 93 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 94 | encoded_depth) 95 | np.savetxt(os.path.join(args.output, 'camera_height', 96 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)), 97 | np.array([camera_height * 1000.0])) 98 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i)) 99 | else: 100 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.bad.txt')) + pad_size 101 | grasp_centers = dp.get_grasp_center(bad_pixel_labels) 102 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0]) 103 | for angle_idx in angle_indices: 104 | quantified_angle = 22.5 * angle_idx 105 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)): 106 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv 107 | grasp_label[..., 0] = 255 108 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 109 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int) 110 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 111 | grasp_label = cv2.medianBlur(grasp_label, 3) 112 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 113 | rotated_grasp_centers = 
np.round(rotated_grasp_centers).astype(np.int) 114 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255 115 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 116 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 117 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 118 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 119 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 120 | encoded_depth = dp.encode_depth(rotated_depth) 121 | cv2.imwrite(os.path.join(args.output, 'label_map', 122 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 123 | grasp_label) 124 | cv2.imwrite(os.path.join(args.output, 'color', 125 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 126 | rotated_color) 127 | cv2.imwrite(os.path.join(args.output, 'depth', 128 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 129 | rotated_depth) 130 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 131 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 132 | encoded_depth) 133 | np.savetxt(os.path.join(args.output, 'camera_height', 134 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)), 135 | np.array([camera_height * 1000.0])) 136 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i)) 137 | f_p.close() 138 | f_n.close() 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | 144 | -------------------------------------------------------------------------------- /data_processing/mag_data_preprocessing.py: -------------------------------------------------------------------------------- 1 | from data_processing.data_processor import DataProcessor 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | import argparse 7 | import random 8 | import os 9 | 10 | parser = argparse.ArgumentParser(description='process data') 11 | parser.add_argument('--data_dir', 12 | required=True, 13 | type=str, 14 | help='path to the data') 15 | parser.add_argument('--output', 16 | required=True, 17 | type=str, 18 | help='path to save the process data') 19 | args = parser.parse_args() 20 | 21 | 22 | def main(): 23 | if not os.path.exists(args.output): 24 | os.mkdir(args.output) 25 | os.mkdir(os.path.join(args.output, 'color')) 26 | os.mkdir(os.path.join(args.output, 'depth')) 27 | os.mkdir(os.path.join(args.output, 'encoded_depth')) 28 | os.mkdir(os.path.join(args.output, 'label_map')) 29 | os.mkdir(os.path.join(args.output, 'camera_height')) 30 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w') 31 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w') 32 | dp = DataProcessor() 33 | 34 | # load data 35 | data_dir = os.listdir(args.data_dir) 36 | random.shuffle(data_dir) 37 | for data_id in data_dir: 38 | parent_dir = os.path.join(args.data_dir, data_id) 39 | print(parent_dir) 40 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'), 41 | # cv2.IMREAD_ANYDEPTH).astype(np.float32) 42 | label_files = os.listdir(os.path.join(parent_dir, 'label')) 43 | with open(os.path.join(parent_dir, 'file_name.txt'), 'r') as f: 44 | file_names = f.readlines() 45 | for file_name in file_names: 46 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1] 47 | print(file_name) 48 | color = cv2.imread(os.path.join(parent_dir, 'height_map_color', file_name+'.png')) 49 | depth = cv2.imread(os.path.join(parent_dir, 'height_map_depth', file_name+'.png'), 50 | cv2.IMREAD_ANYDEPTH).astype(np.float32) 51 | # diff_depth = 
dp.get_diff_depth(depth, b_depth_height_map) 52 | # pad color and depth images 53 | pad_size = 44 54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7 55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color 56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7 57 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = depth 58 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten())) 59 | camera_height = background_depth_value / 1000.0 60 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.object_points.txt')) + pad_size 61 | if file_name+'.good.txt' in label_files: 62 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name+'.good.txt')) + pad_size 63 | grasp_centers = dp.get_grasp_center(good_pixel_labels) 64 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0]) 65 | for angle_idx in angle_indices: 66 | quantified_angle = 22.5 * angle_idx 67 | for i, angle in enumerate(np.arange(quantified_angle-5, quantified_angle+5, 1)): 68 | grasp_label = np.zeros((36, 36, 3), dtype=np.uint8) # bgr for opencv 69 | grasp_label[..., 0] = 255 70 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 71 | rotated_neglect_points = np.round(rotated_neglect_points / 8.0).astype(np.int) 72 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 73 | grasp_label = cv2.medianBlur(grasp_label, 3) 74 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 75 | rotated_grasp_centers = np.round(rotated_grasp_centers / 8.0).astype(np.int) 76 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255 77 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 78 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 79 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 80 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 81 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 82 | encoded_depth = dp.encode_depth(rotated_depth) 83 | cv2.imwrite(os.path.join(args.output, 'label_map', 84 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 85 | grasp_label) 86 | cv2.imwrite(os.path.join(args.output, 'color', 87 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 88 | rotated_color) 89 | cv2.imwrite(os.path.join(args.output, 'depth', 90 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 91 | rotated_depth) 92 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 93 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 94 | encoded_depth) 95 | np.savetxt(os.path.join(args.output, 'camera_height', 96 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)), 97 | np.array([camera_height * 1000.0])) 98 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i)) 99 | else: 100 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'label', file_name + '.bad.txt')) + pad_size 101 | grasp_centers = dp.get_grasp_center(bad_pixel_labels) 102 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0]) 103 | for angle_idx in angle_indices: 104 | quantified_angle = 22.5 * angle_idx 105 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)): 106 | grasp_label = np.zeros((36, 36, 3), dtype=np.uint8) # bgr for opencv 107 | grasp_label[..., 0] = 255 108 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 109 | 
rotated_neglect_points = np.round(rotated_neglect_points / 8.0).astype(np.int) 110 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 111 | grasp_label = cv2.medianBlur(grasp_label, 3) 112 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 113 | rotated_grasp_centers = np.round(rotated_grasp_centers / 8.0).astype(np.int) 114 | grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255 115 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 116 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 117 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 118 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 119 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 120 | encoded_depth = dp.encode_depth(rotated_depth) 121 | cv2.imwrite(os.path.join(args.output, 'label_map', 122 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 123 | grasp_label) 124 | cv2.imwrite(os.path.join(args.output, 'color', 125 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 126 | rotated_color) 127 | cv2.imwrite(os.path.join(args.output, 'depth', 128 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 129 | rotated_depth) 130 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 131 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 132 | encoded_depth) 133 | np.savetxt(os.path.join(args.output, 'camera_height', 134 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)), 135 | np.array([camera_height * 1000.0])) 136 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i)) 137 | f_p.close() 138 | f_n.close() 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | 144 | -------------------------------------------------------------------------------- /data_collection/corrective_grasp_trials.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | from scipy import misc 3 | from data_collection.scene import Scene 4 | from functools import partial 5 | import numpy as np 6 | import cv2 7 | import argparse 8 | import os 9 | 10 | parser = argparse.ArgumentParser(description='Data collection in vrep simulator.') 11 | parser.add_argument('--ip', 12 | default='127.0.0.1', 13 | type=str, 14 | help='ip address to the vrep simulator.') 15 | parser.add_argument('--port', 16 | default=19997, 17 | type=str, 18 | help='port to the vrep simulator.') 19 | parser.add_argument('--obj_id', 20 | default='obj0', 21 | type=str, 22 | help='object name in vrep.') 23 | parser.add_argument('--num_grasp', 24 | default=200, 25 | type=int, 26 | help='the number of grasp trails.') 27 | parser.add_argument('--num_repeat', 28 | default=1, 29 | type=int, 30 | help='the number of repeat time if the gripper successfully grasp an object.') 31 | parser.add_argument('--output', 32 | default='data/3dnet', 33 | type=str, 34 | help='directory to save data.') 35 | args = parser.parse_args() 36 | 37 | step = pi / 180.0 38 | ratio = 0.4 / 200 39 | interval = [0.23, 0.57] 40 | rand = partial(np.random.uniform, interval[0], interval[1]) 41 | 42 | 43 | def main(): 44 | if not os.path.exists(args.output): 45 | os.makedirs(args.output) 46 | data_id = len(os.listdir(args.output)) 47 | data_path = os.path.join(args.output, '{:04d}'.format(data_id)) 48 | color_path = os.path.join(data_path, 'color') 49 | depth_path = os.path.join(data_path, 'depth') 50 | height_map_color_path = 
os.path.join(data_path, 'height_map_color') 51 | height_map_depth_path = os.path.join(data_path, 'height_map_depth') 52 | label_path = os.path.join(data_path, 'label') 53 | os.mkdir(data_path) 54 | os.mkdir(color_path) 55 | os.mkdir(depth_path) 56 | os.mkdir(height_map_color_path) 57 | os.mkdir(height_map_depth_path) 58 | os.mkdir(label_path) 59 | 60 | s = Scene(args.ip, args.port, 'realsense') 61 | init_pos = [45*step, 10*step, 90*step, -10*step, -90*step, -45*step] 62 | top_pos = [-45*step, 10*step, 90*step, -10*step, -90*step, -45*step] 63 | # object handle 64 | obj = s.get_object_handle(args.obj_id) 65 | obj_quat = s.get_object_quaternion(obj) 66 | s.gripper.open_gripper() 67 | 68 | misc.imsave(data_path + '/background_color.png', s.background_color) 69 | cv2.imwrite(data_path + '/background_depth.png', s.background_depth) 70 | misc.imsave(data_path+'/crop_background_color.png', 71 | cv2.resize(s.background_color[45:435, 119:509, :], (200, 200))) 72 | cv2.imwrite(data_path+'/crop_background_depth.png', 73 | cv2.resize(s.background_depth[45:435, 119:509], (200, 200))) 74 | 75 | f = open(os.path.join(data_path, 'file_name.txt'), 'w') 76 | 77 | s.set_object_position(obj, [rand(), rand(), s.get_object_position(obj)[-1]]) 78 | 79 | for i in range(args.num_grasp): # number of grasp trials. 80 | print(i) 81 | s.gripper.open_gripper() 82 | s.ur5.move_to_joint_position(init_pos) 83 | s.replace_object(obj, obj_quat, interval) 84 | color = s.get_color_image() 85 | depth = s.get_depth_image() 86 | center_point, object_points, reaching_height = s.get_center_from_image(depth) 87 | s.ur5.move_to_joint_position(top_pos) 88 | episode = [] 89 | # move to the upward side with respect to the object 90 | init_quat = s.ur5.get_end_effector_quaternion() 91 | target_pos = s.get_object_position(obj) 92 | target_pos[0] = 0.200 + center_point[0] * ratio 93 | target_pos[1] = 0.200 + center_point[1] * ratio 94 | target_pos[2] = reaching_height 95 | target_pos[-1] += 0.1563 96 | episode.append(target_pos + init_quat) 97 | s.ur5.move_to_object_position(target_pos + init_quat) 98 | # randomly rotate gripper's angle 99 | joint_angles = s.ur5.get_joint_positions() 100 | curr_angle = joint_angles[-1] / np.pi * 180.0 101 | angle = np.random.randint(0, 180) 102 | joint_angles[-1] = (curr_angle + angle) * step 103 | s.ur5.move_to_joint_position(joint_angles) 104 | quat = s.ur5.get_end_effector_quaternion() 105 | episode.append(target_pos + quat) 106 | # move to the object downward 107 | target_pos[-1] -= 0.1563 108 | episode.append(target_pos + quat) 109 | s.ur5.move_to_object_position(target_pos + quat) 110 | # try to grasp object 111 | s.gripper.close_gripper() 112 | # determine whether the object is successfully grasped. 113 | if s.gripper.get_object_detection_status(): 114 | print('grasp success.') 115 | s.gripper.attach_object(obj) 116 | width = s.gripper.get_gripper_width() 117 | s.gripper.open_gripper() 118 | s.gripper.untie_object(obj) 119 | s.ur5.move_to_object_position(episode[1]) 120 | s.ur5.move_to_joint_position(init_pos) 121 | # if we successfully grasp the object, 122 | # we can record the relative grasp configuration and successively collect the positive samples. 123 | for j in range(args.num_repeat): # the number of positive samples we want to collect in this grasp trials. 124 | if s.replace_object(obj, obj_quat, interval): 125 | break 126 | # rerecord the pattern. 
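# Note: the block below re-verifies the grasp that has just succeeded. For each repeat, the scene
# is re-captured, a 'good' pixel label is built from the recorded grasp (center_point, angle, width),
# and the gripper re-executes the same relative grasp at the object's current pose. On success the
# cropped height maps and the label are saved as a positive sample and the grasped object is carried
# to a new random position and wrist angle for the next repeat; on failure the repeat loop stops.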
127 | color = s.get_color_image() 128 | depth = s.get_depth_image() 129 | _, object_points, reaching_height = s.get_center_from_image(depth) 130 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200)) 131 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200)) 132 | points = s.get_label(center_point, -angle*np.pi/180.0, width, ratio) 133 | label = s.draw_label(points, 200, 200) 134 | s.ur5.move_to_joint_position(top_pos) 135 | s.ur5.move_to_object_position(episode[0]) 136 | s.ur5.move_to_object_position(episode[1]) 137 | s.ur5.move_to_object_position(episode[2]) 138 | s.gripper.close_gripper() 139 | if s.gripper.get_object_detection_status(): 140 | s.gripper.attach_object(obj) 141 | s.ur5.move_to_object_position(episode[1]) 142 | # quat = s.ur5.get_end_effector_quaternion() 143 | pos = s.ur5.get_end_effector_position() 144 | pos[0], pos[1] = rand(), rand() 145 | center_point = [int((pos[0] - 0.200) / ratio), int((pos[1] - 0.200) / ratio)] 146 | episode[0] = pos + init_quat 147 | s.ur5.move_to_object_position(pos + init_quat) 148 | joint_angles = s.ur5.get_joint_positions() 149 | curr_angle = joint_angles[-1] / np.pi * 180.0 150 | angle = np.random.randint(0, 180) 151 | joint_angles[-1] = (curr_angle + angle) * step 152 | s.ur5.move_to_joint_position(joint_angles) 153 | quat = s.ur5.get_end_effector_quaternion() 154 | pos = s.ur5.get_end_effector_position() 155 | episode[1] = pos + quat 156 | pos[-1] = episode[2][2] 157 | episode[2] = pos + quat 158 | s.ur5.move_to_object_position(pos + quat) 159 | s.gripper.open_gripper() 160 | s.gripper.untie_object(obj) 161 | s.ur5.move_to_object_position(episode[1]) 162 | s.ur5.move_to_joint_position(init_pos) 163 | misc.imsave(color_path + '/{:06d}_{:04d}.png'.format(i, j), color) 164 | cv2.imwrite(depth_path + '/{:06d}_{:04d}.png'.format(i, j), depth) 165 | misc.imsave(height_map_color_path + '/{:06d}_{:04d}.png'.format(i, j), crop_color) 166 | cv2.imwrite(height_map_depth_path + '/{:06d}_{:04d}.png'.format(i, j), crop_depth) 167 | misc.imsave(label_path + '/{:06d}_{:04d}.png'.format(i, j), label) 168 | np.savetxt(label_path + '/{:06d}_{:04d}.good.txt'.format(i, j), points) 169 | np.savetxt(label_path + '/{:06d}_{:04d}.object_points.txt'.format(i, j), np.asanyarray(object_points)) 170 | f.write('{:06d}_{:04d}\n'.format(i, j)) 171 | else: 172 | break 173 | 174 | else: 175 | print('grasp failed') 176 | s.gripper.open_gripper() 177 | s.ur5.move_to_object_position(episode[1]) 178 | crop_color = cv2.resize(color[45:435, 119:509, :], (200, 200)) 179 | crop_depth = cv2.resize(depth[45:435, 119:509], (200, 200)) 180 | points = s.get_label(center_point, -angle * np.pi / 180.0, 0.085, ratio) 181 | label = s.draw_label(points, 200, 200, (255, 0, 0)) 182 | misc.imsave(color_path + '/{:06d}.png'.format(i), color) 183 | cv2.imwrite(depth_path + '/{:06d}.png'.format(i), depth) 184 | misc.imsave(height_map_color_path + '/{:06d}.png'.format(i), crop_color) 185 | cv2.imwrite(height_map_depth_path + '/{:06d}.png'.format(i), crop_depth) 186 | misc.imsave(label_path + '/{:06d}.png'.format(i), label) 187 | np.savetxt(label_path + '/{:06d}.bad.txt'.format(i), points) 188 | np.savetxt(label_path + '/{:06d}.object_points.txt'.format(i), np.asanyarray(object_points)) 189 | f.write('{:06d}\n'.format(i)) 190 | 191 | s.stop_simulation() 192 | f.close() 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | -------------------------------------------------------------------------------- /data_processing/generate_pseudo_data_v1.py: 
-------------------------------------------------------------------------------- 1 | from data_processing.data_processor import DataProcessor 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | import argparse 7 | import random 8 | import os 9 | 10 | parser = argparse.ArgumentParser(description='process labels') 11 | parser.add_argument('--data_dir', 12 | required=True, 13 | type=str, 14 | help='path to the data') 15 | parser.add_argument('--output', 16 | required=True, 17 | type=str, 18 | help='path to save the process data') 19 | args = parser.parse_args() 20 | 21 | 22 | def main(): 23 | if not os.path.exists(args.output): 24 | os.mkdir(args.output) 25 | os.mkdir(os.path.join(args.output, 'color')) 26 | os.mkdir(os.path.join(args.output, 'depth')) 27 | os.mkdir(os.path.join(args.output, 'encoded_depth')) 28 | os.mkdir(os.path.join(args.output, 'label_map')) 29 | os.mkdir(os.path.join(args.output, 'camera_height')) 30 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w') 31 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w') 32 | dp = DataProcessor() 33 | 34 | # load data 35 | data_dir = os.listdir(args.data_dir) 36 | random.shuffle(data_dir) 37 | for data_id in data_dir: 38 | parent_dir = os.path.join(args.data_dir, data_id) 39 | print(parent_dir) 40 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'), 41 | # cv2.IMREAD_ANYDEPTH).astype(np.float32) 42 | label_files = os.listdir(os.path.join(parent_dir, 'zoomed_label')) 43 | with open(os.path.join(parent_dir, 'zoomed_file_name.txt'), 'r') as f: 44 | file_names = f.readlines() 45 | for file_name in file_names: 46 | print(file_name) 47 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1] 48 | color = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_color', file_name+'.png')) 49 | depth = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_depth', file_name+'.png'), 50 | cv2.IMREAD_ANYDEPTH).astype(np.float32) 51 | # diff_depth = dp.get_diff_depth(depth, b_depth_height_map) 52 | # pad color and depth images 53 | pad_size = 44 54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7 55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color 56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7 57 | # pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = diff_depth 58 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = depth 59 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten())) 60 | camera_height = background_depth_value / 1000.0 61 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.object_points.txt')) + pad_size 62 | if len(neglect_points) == 0: 63 | continue 64 | if file_name+'.good.txt' in label_files: 65 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.good.txt')) + pad_size 66 | grasp_centers = dp.get_grasp_center(good_pixel_labels) 67 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0]) 68 | for angle_idx in angle_indices: 69 | quantified_angle = 22.5 * angle_idx 70 | for i, angle in enumerate(np.arange(quantified_angle-5, quantified_angle+5, 1)): 71 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv 72 | grasp_label[..., 0] = 255 73 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 74 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int) 75 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 76 | 
grasp_label = cv2.medianBlur(grasp_label, 3) 77 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 78 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int) 79 | grasp_label = dp.gaussianize_label(grasp_label, 80 | rotated_grasp_centers, 81 | camera_height, 82 | is_good=True) 83 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255 84 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 85 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 86 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 87 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 88 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 89 | encoded_depth = dp.encode_depth(rotated_depth) 90 | cv2.imwrite(os.path.join(args.output, 'label_map', 91 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 92 | grasp_label) 93 | cv2.imwrite(os.path.join(args.output, 'color', 94 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 95 | rotated_color) 96 | cv2.imwrite(os.path.join(args.output, 'depth', 97 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 98 | rotated_depth) 99 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 100 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 101 | encoded_depth) 102 | np.savetxt(os.path.join(args.output, 'camera_height', 103 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)), 104 | np.array([camera_height * 1000.0])) 105 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i)) 106 | else: 107 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name + '.bad.txt')) + pad_size 108 | grasp_centers = dp.get_grasp_center(bad_pixel_labels) 109 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0]) 110 | for angle_idx in angle_indices: 111 | quantified_angle = 22.5 * angle_idx 112 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)): 113 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv 114 | grasp_label[..., 0] = 255 115 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 116 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int) 117 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 118 | grasp_label = cv2.medianBlur(grasp_label, 3) 119 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 120 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int) 121 | grasp_label = dp.gaussianize_label(grasp_label, 122 | rotated_grasp_centers, 123 | camera_height, 124 | is_good=False) 125 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255 126 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 127 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 128 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 129 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 130 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 131 | encoded_depth = dp.encode_depth(rotated_depth) 132 | cv2.imwrite(os.path.join(args.output, 'label_map', 133 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 134 | grasp_label) 135 | cv2.imwrite(os.path.join(args.output, 'color', 136 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 137 | rotated_color) 138 | 
cv2.imwrite(os.path.join(args.output, 'depth', 139 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 140 | rotated_depth) 141 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 142 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 143 | encoded_depth) 144 | np.savetxt(os.path.join(args.output, 'camera_height', 145 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)), 146 | np.array([camera_height * 1000.0])) 147 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i)) 148 | f_p.close() 149 | f_n.close() 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | 155 | -------------------------------------------------------------------------------- /data_collection/scene.py: -------------------------------------------------------------------------------- 1 | from data_collection.ur import UR5 2 | from data_collection.rg2 import RG2 3 | import numpy as np 4 | import data_collection.vrep as vrep 5 | import time 6 | import cv2 7 | 8 | 9 | class Scene(object): 10 | def __init__(self, ip='127.0.0.1', 11 | port=19997, 12 | camera_handle='realsense'): 13 | """ 14 | Initialization.Connect to remote server, initialize ur5 and gripper. 15 | :param ip: The ip address where the server is located. 16 | :param port: The port number where to connected. 17 | :param camera_handle: The camera handle id. 18 | """ 19 | vrep.simxFinish(-1) 20 | self.client_id = vrep.simxStart(ip, port, True, True, 5000, 5) 21 | vrep.simxStartSimulation(self.client_id, vrep.simx_opmode_blocking) 22 | self.ur5 = UR5(self.client_id) 23 | self.gripper = RG2(self.client_id) 24 | self.ur5.initialization() 25 | _, self.camera_handle = vrep.simxGetObjectHandle(self.client_id, camera_handle, vrep.simx_opmode_blocking) 26 | self.camera_height = self.get_object_position(self.camera_handle)[-1] 27 | 28 | self.background_color = self.get_color_image() 29 | self.background_depth = self.get_depth_image() 30 | 31 | def get_object_handle(self, handle_name): 32 | """ 33 | Get object handle. 34 | :param handle_name: A string. Handle name with respect to specific object handle in vrep. 35 | :return: The number of object handle. 36 | """ 37 | _, obj = vrep.simxGetObjectHandle(self.client_id, handle_name, vrep.simx_opmode_blocking) 38 | return obj 39 | 40 | def set_object_position(self, object_handle, position, relative_to_object_handle=-1): 41 | """ 42 | Set object position. 43 | :param object_handle: Handle of the object. 44 | :param position: The position value. 45 | :param relative_to_object_handle: Indicates relative to which reference frame the position is specified. 46 | :return: None. 47 | """ 48 | vrep.simxSetObjectPosition(self.client_id, 49 | object_handle, 50 | relative_to_object_handle, 51 | position, 52 | vrep.simx_opmode_oneshot) 53 | 54 | def get_object_position(self, object_handle, relative_to_object_handle=-1): 55 | """ 56 | Get object position. 57 | :param object_handle: An int. Handle of the object. 58 | :param relative_to_object_handle: An Int. Indicates relative to which reference frame the position is specified. 59 | :return: A list of float. The position of object. 60 | """ 61 | _, _, pos, _, _ = self.ur5.func(functionName='pyGetObjectQPosition', 62 | inputInts=[object_handle, relative_to_object_handle], 63 | inputFloats=[], 64 | inputStrings=[], 65 | inputBuffer='') 66 | return pos 67 | 68 | def set_object_quaternion(self, object_handle, quaternion, relative_to_object_handle=-1): 69 | """ 70 | Set object quaternion. 
71 | :param object_handle: An int. Handle of the object. 72 | :param quaternion: A list of float. The quaternion value (x, y, z, w). 73 | :param relative_to_object_handle: An int. Indicates relative to which reference frame the quaternion is specified. 74 | :return: None. 75 | """ 76 | vrep.simxSetObjectQuaternion(self.client_id, 77 | object_handle, 78 | relative_to_object_handle, 79 | quaternion, 80 | vrep.simx_opmode_oneshot) 81 | 82 | def get_object_quaternion(self, object_handle, relative_to_object_handle=-1): 83 | """ 84 | Get object quaternion. 85 | :param object_handle: An int. Handle of the object. 86 | :param relative_to_object_handle: An int. Indicates relative to which reference frame the quaternion is specified. 87 | :return: A list of float. The quaternion of the object. 88 | """ 89 | _, _, quat, _, _ = self.ur5.func(functionName='pyGetObjectQuaternion', 90 | inputInts=[object_handle, relative_to_object_handle], 91 | inputFloats=[], 92 | inputStrings=[], 93 | inputBuffer='') 94 | return quat 95 | 96 | def replace_object(self, object_handle, object_quat, interval): 97 | """ 98 | Replace the object if it is out of the workspace. 99 | :param object_handle: An int. Handle of the object. 100 | :param object_quat: A list of float. The quaternion to which the object should be reset. 101 | :param interval: A list of float. The position interval within which the object should be placed. 102 | :return: A boolean. Whether the object was replaced. 103 | """ 104 | pos = self.get_object_position(object_handle) 105 | if not interval[0] <= pos[0] <= interval[1] or not interval[0] <= pos[1] <= interval[1]: 106 | pos[0] = np.random.uniform(interval[0], interval[1]) 107 | pos[1] = np.random.uniform(interval[0], interval[1]) 108 | self.set_object_position(object_handle, pos) 109 | self.set_object_quaternion(object_handle, object_quat) 110 | time.sleep(0.7) 111 | return True 112 | return False 113 | 114 | def get_color_image(self): 115 | """ 116 | Get color image. 117 | :return: A numpy array of uint8. The RGB image containing the whole workspace. 118 | """ 119 | res, resolution, color = vrep.simxGetVisionSensorImage(self.client_id, 120 | self.camera_handle, 121 | 0, 122 | vrep.simx_opmode_blocking) 123 | return np.asanyarray(color, dtype=np.uint8).reshape(resolution[1], resolution[0], 3)[::-1, ...] 124 | 125 | def get_depth_image(self): 126 | """ 127 | Get depth image. 128 | :return: A numpy array of uint16. The depth image containing the whole workspace. 129 | """ 130 | res, resolution, depth = vrep.simxGetVisionSensorDepthBuffer(self.client_id, 131 | self.camera_handle, 132 | vrep.simx_opmode_blocking) 133 | depth = 1000 * np.asanyarray(depth) 134 | return depth.astype(np.uint16).reshape(resolution[1], resolution[0])[::-1, ...] 135 | 136 | def get_center_from_image(self, depth): 137 | """ 138 | Get a grasp position that belongs to the object in image space. 139 | :param depth: A numpy array of uint16. The depth image. 140 | :return: center_point: A list of int. The pixel position on the object which the gripper will reach. 141 | :return: points: A numpy array of int. The pixel positions which belong to the object in image space. 142 | :return: reaching_height: A float. The height at which the gripper should reach for the object. 
143 | """ 144 | crop_depth = depth[45:435, 119:509] 145 | crop_b_depth = self.background_depth[45:435, 119:509] 146 | resized_depth = cv2.resize(crop_depth, (200, 200)).astype(np.float32) 147 | resized_b_depth = cv2.resize(crop_b_depth, (200, 200)).astype(np.float32) 148 | diff_depth = np.abs(resized_depth - resized_b_depth) 149 | cv2.medianBlur(diff_depth, 5) 150 | points = np.stack(np.where(diff_depth > 0), axis=0).T 151 | center_point = points[np.random.randint(0, points.shape[0])].tolist() 152 | reaching_height = max(0.02, diff_depth[center_point[0], center_point[1]]*0.0001 - 0.027) 153 | return center_point, points, reaching_height 154 | 155 | def stop_simulation(self): 156 | """ 157 | Stop vrep simulation. 158 | :return: None. 159 | """ 160 | vrep.simxStopSimulation(self.client_id, vrep.simx_opmode_oneshot_wait) 161 | vrep.simxFinish(self.client_id) 162 | 163 | @staticmethod 164 | def get_label(center_point, angle, width, ratio): 165 | """ 166 | Obtain grasp labels. Each grasp label is represented as two points. 167 | :param center_point: A numpy array of float32. The position of the object which the gripper will reach. 168 | :param angle: A float. The rotated angle of the gripper. 169 | :param width: A float. The width between two finger tips. 170 | :param ratio: A float. The ratio of the width of workspace and the width of image. 171 | :return: A numpy array of float. The grasp points. 172 | """ 173 | p_width = width / ratio 174 | c_h = center_point[0] 175 | c_w = center_point[1] 176 | h = [delta_h for delta_h in range(-3, 4)] 177 | w = [-p_width / 2, p_width / 2] 178 | points = np.asanyarray([[hh, w[0], hh, w[1]] for hh in h]) 179 | rotate_matrix = np.array([[np.cos(angle), -np.sin(angle)], 180 | [np.sin(angle), np.cos(angle)]]) 181 | points[:, 0:2] = np.dot(rotate_matrix, points[:, 0:2].T).T 182 | points[:, 2:] = np.dot(rotate_matrix, points[:, 2:].T).T 183 | points = points + np.asanyarray([[c_h, c_w, c_h, c_w]]) 184 | points = np.floor(points).astype(np.int) 185 | return points 186 | 187 | @staticmethod 188 | def draw_label(points, width, height, color=(0, 255, 0)): 189 | """ 190 | Draw labels according to the grasp points. 191 | :param points: A numpy array of float. The grasp points. 192 | :param width: A float. The width of image. 193 | :param height: A float. The height of image. 194 | :param color: A tuple. The color of lines. 195 | :return: A numpy array of uint8. The label map. 
196 | """ 197 | label = np.ones((height, width, 3), dtype=np.uint8) * 255 198 | for point in points: 199 | pt1 = (point[1], point[0]) 200 | pt2 = (point[3], point[2]) 201 | cv2.line(label, pt1, pt2, color) 202 | return label 203 | -------------------------------------------------------------------------------- /data_processing/genarate_pseudo_data_v2.py: -------------------------------------------------------------------------------- 1 | from data_processing.data_processor import DataProcessor 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | import argparse 7 | import random 8 | import copy 9 | import os 10 | 11 | parser = argparse.ArgumentParser(description='process labels') 12 | parser.add_argument('--data_dir', 13 | required=True, 14 | type=str, 15 | help='path to the data') 16 | parser.add_argument('--output', 17 | required=True, 18 | type=str, 19 | help='path to save the process data') 20 | args = parser.parse_args() 21 | 22 | 23 | def main(): 24 | if not os.path.exists(args.output): 25 | os.mkdir(args.output) 26 | os.mkdir(os.path.join(args.output, 'color')) 27 | os.mkdir(os.path.join(args.output, 'depth')) 28 | os.mkdir(os.path.join(args.output, 'encoded_depth')) 29 | os.mkdir(os.path.join(args.output, 'label_map')) 30 | os.mkdir(os.path.join(args.output, 'camera_height')) 31 | f_p = open(os.path.join(args.output, 'pos.txt'), 'w') 32 | f_n = open(os.path.join(args.output, 'neg.txt'), 'w') 33 | dp = DataProcessor() 34 | 35 | # load data 36 | data_dir = os.listdir(args.data_dir) 37 | random.shuffle(data_dir) 38 | for data_id in data_dir: 39 | parent_dir = os.path.join(args.data_dir, data_id) 40 | print(parent_dir) 41 | # b_depth_height_map = cv2.imread(os.path.join(parent_dir, 'crop_background_depth.png'), 42 | # cv2.IMREAD_ANYDEPTH).astype(np.float32) 43 | label_files = os.listdir(os.path.join(parent_dir, 'zoomed_label')) 44 | with open(os.path.join(parent_dir, 'zoomed_file_name.txt'), 'r') as f: 45 | file_names = f.readlines() 46 | for file_name in file_names: 47 | print(file_name) 48 | file_name = file_name[:-2] if file_name[-2:] == '\r\n' else file_name[:-1] 49 | color = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_color', file_name+'.png')) 50 | depth = cv2.imread(os.path.join(parent_dir, 'zoomed_height_map_depth', file_name+'.png'), 51 | cv2.IMREAD_ANYDEPTH).astype(np.float32) 52 | # pad color and depth images 53 | pad_size = 44 54 | pad_color = np.ones((288, 288, 3), dtype=np.uint8) * 7 55 | pad_color[pad_size:pad_size + 200, pad_size:pad_size + 200, :] = color 56 | pad_depth = np.ones((288, 288), dtype=np.uint16) * 7 57 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = depth 58 | background_depth_value = np.argmax(np.bincount(depth.astype(np.int).flatten())) 59 | camera_height = background_depth_value / 1000.0 60 | neglect_points = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.object_points.txt')) + pad_size 61 | if len(neglect_points) == 0: 62 | continue 63 | if file_name+'.good.txt' in label_files: 64 | good_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name+'.good.txt')) + pad_size 65 | grasp_centers = dp.get_grasp_center(good_pixel_labels) 66 | angle_indices = dp.get_grasp_angle(good_pixel_labels[0]) 67 | for angle_idx in angle_indices: 68 | quantified_angle = 22.5 * angle_idx 69 | for i, angle in enumerate(np.arange(quantified_angle-3, quantified_angle+3, 1)): 70 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv 71 | grasp_label[..., 0] = 255 72 | rotated_neglect_points = 
dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 73 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int) 74 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 75 | grasp_label = cv2.medianBlur(grasp_label, 3) 76 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 77 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int) 78 | for j in range(3): 79 | camera_height = np.random.uniform(0.20, 0.63) 80 | new_depth = int(camera_height * 1000.0) - (background_depth_value - depth) 81 | pad_depth[pad_size:pad_size + 200, pad_size:pad_size + 200] = new_depth 82 | gaussianized_label = dp.gaussianize_label(copy.deepcopy(grasp_label), 83 | rotated_grasp_centers, 84 | camera_height, 85 | is_good=True) 86 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 1] = 255 87 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 88 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 89 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 90 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 91 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 92 | encoded_depth = dp.encode_depth(rotated_depth) 93 | cv2.imwrite(os.path.join(args.output, 'label_map', 94 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)), 95 | gaussianized_label) 96 | cv2.imwrite(os.path.join(args.output, 'color', 97 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)), 98 | rotated_color) 99 | cv2.imwrite(os.path.join(args.output, 'depth', 100 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)), 101 | rotated_depth) 102 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 103 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.png'.format(angle_idx, i, j)), 104 | encoded_depth) 105 | np.savetxt(os.path.join(args.output, 'camera_height', 106 | data_id + '-' + file_name + '-{}-{:02d}-{:02d}.txt'.format(angle_idx, i, j)), 107 | np.array([camera_height*1000.0])) 108 | f_p.write(data_id + '-' + file_name + '-{}-{:02d}-{:02d}\n'.format(angle_idx, i, j)) 109 | else: 110 | bad_pixel_labels = np.loadtxt(os.path.join(parent_dir, 'zoomed_label', file_name + '.bad.txt')) + pad_size 111 | grasp_centers = dp.get_grasp_center(bad_pixel_labels) 112 | angle_indices = dp.get_grasp_angle(bad_pixel_labels[0]) 113 | for angle_idx in angle_indices: 114 | quantified_angle = 22.5 * angle_idx 115 | for i, angle in enumerate(np.arange(quantified_angle - 5, quantified_angle + 5, 1)): 116 | grasp_label = np.zeros((288, 288, 3), dtype=np.uint8) # bgr for opencv 117 | grasp_label[..., 0] = 255 118 | rotated_neglect_points = dp.rotate(neglect_points, (144, 144), (angle / 360.0) * np.pi * 2) 119 | rotated_neglect_points = np.round(rotated_neglect_points).astype(np.int) 120 | grasp_label[rotated_neglect_points[:, 0], rotated_neglect_points[:, 1], 0] = 0 121 | grasp_label = cv2.medianBlur(grasp_label, 3) 122 | rotated_grasp_centers = dp.rotate(grasp_centers, (144, 144), (angle / 360.0) * np.pi * 2) 123 | rotated_grasp_centers = np.round(rotated_grasp_centers).astype(np.int) 124 | grasp_label = dp.gaussianize_label(grasp_label, 125 | rotated_grasp_centers, 126 | camera_height, 127 | is_good=False) 128 | # grasp_label[rotated_grasp_centers[:, 0], rotated_grasp_centers[:, 1], 2] = 255 129 | mtx = cv2.getRotationMatrix2D((144, 144), angle, 1) 130 | rotated_color = cv2.warpAffine(pad_color, mtx, (288, 288), 131 | 
borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 132 | rotated_depth = cv2.warpAffine(pad_depth, mtx, (288, 288), 133 | borderMode=cv2.BORDER_CONSTANT, borderValue=(7, 7, 7)) 134 | encoded_depth = dp.encode_depth(rotated_depth) 135 | cv2.imwrite(os.path.join(args.output, 'label_map', 136 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 137 | grasp_label) 138 | cv2.imwrite(os.path.join(args.output, 'color', 139 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 140 | rotated_color) 141 | cv2.imwrite(os.path.join(args.output, 'depth', 142 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 143 | rotated_depth) 144 | cv2.imwrite(os.path.join(args.output, 'encoded_depth', 145 | data_id + '-' + file_name + '-{}-{:02d}.png'.format(angle_idx, i)), 146 | encoded_depth) 147 | np.savetxt(os.path.join(args.output, 'camera_height', 148 | data_id + '-' + file_name + '-{}-{:02d}.txt'.format(angle_idx, i)), 149 | np.array([camera_height * 1000.0])) 150 | f_n.write(data_id + '-' + file_name + '-{}-{:02d}\n'.format(angle_idx, i)) 151 | f_p.close() 152 | f_n.close() 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | 158 | -------------------------------------------------------------------------------- /nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import tensorflow as tf 3 | 4 | slim = tf.contrib.slim 5 | 6 | 7 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 8 | """A named tuple describing a ResNet block. 9 | 10 | Its parts are: 11 | scope: The scope of the `Block`. 12 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 13 | returns another `Tensor` with the output of the ResNet unit. 14 | args: A list of length equal to the number of units in the `Block`. The list 15 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 16 | block to serve as argument to unit_fn. 17 | """ 18 | 19 | 20 | def subsample(inputs, factor, scope=None): 21 | """Subsamples the input along the spatial dimensions. 22 | 23 | Args: 24 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 25 | factor: The subsampling factor. 26 | scope: Optional variable_scope. 27 | 28 | Returns: 29 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 30 | input, either intact (if factor == 1) or subsampled (if factor > 1). 31 | """ 32 | if factor == 1: 33 | return inputs 34 | else: 35 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 36 | 37 | 38 | def upsample(inputs, factor, scope=None): 39 | """Upsamples the input along the spatial dimensions. 40 | 41 | Args: 42 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 43 | factor: The subsampling factor. 44 | scope: Optional variable_scope. 45 | 46 | Returns: 47 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 48 | input, either intact (if factor == 1) or subsampled (if factor > 1). 49 | """ 50 | if factor == 1: 51 | return inputs 52 | else: 53 | height, width = inputs.get_shape().as_list()[1:3] 54 | return tf.image.resize_bilinear(inputs, [height * factor, width * factor], name=scope) 55 | 56 | 57 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 58 | """Strided 2-D convolution with 'SAME' padding. 59 | 60 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 61 | 'VALID' padding. 
62 | 63 | Note that 64 | 65 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 66 | 67 | is equivalent to 68 | 69 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 70 | net = subsample(net, factor=stride) 71 | 72 | whereas 73 | 74 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 75 | 76 | is different when the input's height or width is even, which is why we add the 77 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 78 | 79 | Args: 80 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 81 | num_outputs: An integer, the number of output filters. 82 | kernel_size: An int with the kernel_size of the filters. 83 | stride: An integer, the output stride. 84 | rate: An integer, rate for atrous convolution. 85 | scope: Scope. 86 | 87 | Returns: 88 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 89 | the convolution output. 90 | """ 91 | if stride == 1: 92 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 93 | padding='SAME', scope=scope) 94 | else: 95 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 96 | pad_total = kernel_size_effective - 1 97 | pad_beg = pad_total // 2 98 | pad_end = pad_total - pad_beg 99 | inputs = tf.pad(inputs, 100 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 101 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 102 | rate=rate, padding='VALID', scope=scope) 103 | 104 | 105 | @slim.add_arg_scope 106 | def stack_blocks_dense(net, blocks, output_stride=None, 107 | outputs_collections=None): 108 | """Stacks ResNet `Blocks` and controls output feature density. 109 | 110 | First, this function creates scopes for the ResNet in the form of 111 | 'block_name/unit_1', 'block_name/unit_2', etc. 112 | 113 | Second, this function allows the user to explicitly control the ResNet 114 | output_stride, which is the ratio of the input to output spatial resolution. 115 | This is useful for dense prediction tasks such as semantic segmentation or 116 | object detection. 117 | 118 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 119 | factor of 2 when transitioning between consecutive ResNet blocks. This results 120 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 121 | half the nominal network stride (e.g., output_stride=4), then we compute 122 | responses twice. 123 | 124 | Control of the output feature density is implemented by atrous convolution. 125 | 126 | Args: 127 | net: A `Tensor` of size [batch, height, width, channels]. 128 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 129 | element is a ResNet `Block` object describing the units in the `Block`. 130 | output_stride: If `None`, then the output will be computed at the nominal 131 | network stride. If output_stride is not `None`, it specifies the requested 132 | ratio of input to output spatial resolution, which needs to be equal to 133 | the product of unit strides from the start up to some level of the ResNet. 134 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 135 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 136 | is equivalent to output_stride=24). 137 | outputs_collections: Collection to add the ResNet block output. 138 | 139 | Returns: 140 | net: Output tensor with stride equal to the specified output_stride. 
141 | 142 | Raises: 143 | ValueError: If the target output_stride is not valid. 144 | """ 145 | # The current_stride variable keeps track of the effective stride of the 146 | # activations. This allows us to invoke atrous convolution whenever applying 147 | # the next residual unit would result in the activations having stride larger 148 | # than the target output_stride. 149 | current_stride = 1 150 | 151 | # The atrous convolution rate parameter. 152 | rate = 1 153 | 154 | for block in blocks: 155 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 156 | for i, unit in enumerate(block.args): 157 | if output_stride is not None and current_stride > output_stride: 158 | raise ValueError('The target output_stride cannot be reached.') 159 | 160 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 161 | # If we have reached the target output_stride, then we need to employ 162 | # atrous convolution with stride=1 and multiply the atrous rate by the 163 | # current unit's stride for use in subsequent layers. 164 | if output_stride is not None and current_stride == output_stride: 165 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 166 | rate *= unit.get('stride', 1) 167 | 168 | else: 169 | net = block.unit_fn(net, rate=1, **unit) 170 | current_stride *= unit.get('stride', 1) 171 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 172 | 173 | if output_stride is not None and current_stride != output_stride: 174 | raise ValueError('The target output_stride cannot be reached.') 175 | 176 | return net 177 | 178 | 179 | def resnet_arg_scope(weight_decay=0.0001, 180 | batch_norm_decay=0.997, 181 | batch_norm_epsilon=1e-5, 182 | batch_norm_scale=True, 183 | activation_fn=tf.nn.relu, 184 | use_batch_norm=True): 185 | """Defines the default ResNet arg scope. 186 | 187 | TODO(gpapan): The batch-normalization related default values above are 188 | appropriate for use in conjunction with the reference ResNet models 189 | released at https://github.com/KaimingHe/deep-residual-networks. When 190 | training ResNets from scratch, they might need to be tuned. 191 | 192 | Args: 193 | weight_decay: The weight decay to use for regularizing the model. 194 | batch_norm_decay: The moving average decay when estimating layer activation 195 | statistics in batch normalization. 196 | batch_norm_epsilon: Small constant to prevent division by zero when 197 | normalizing activations by their variance in batch normalization. 198 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 199 | activations in the batch normalization layer. 200 | activation_fn: The activation function which is used in ResNet. 201 | use_batch_norm: Whether or not to use batch normalization. 202 | 203 | Returns: 204 | An `arg_scope` to use for the resnet models. 205 | """ 206 | batch_norm_params = { 207 | 'decay': batch_norm_decay, 208 | 'epsilon': batch_norm_epsilon, 209 | 'scale': batch_norm_scale, 210 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 211 | 'fused': None, # Use fused batch norm if possible. 
212 | } 213 | 214 | with slim.arg_scope( 215 | [slim.conv2d], 216 | weights_regularizer=slim.l2_regularizer(weight_decay), 217 | weights_initializer=slim.variance_scaling_initializer(), 218 | activation_fn=activation_fn, 219 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 220 | normalizer_params=batch_norm_params): 221 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 222 | # The following implies padding='SAME' for pool1, which makes feature 223 | # alignment easier for dense prediction tasks. This is also used in 224 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 225 | # code of 'Deep Residual Learning for Image Recognition' uses 226 | # padding='VALID' for pool1. You can switch to that choice by setting 227 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 228 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 229 | return arg_sc 230 | -------------------------------------------------------------------------------- /network_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | 4 | import os 5 | 6 | 7 | def decode_depth(encoded_depth): 8 | encoded_depth = tf.cast(encoded_depth, tf.float32) 9 | r, g, b = tf.unstack(encoded_depth, axis=2) 10 | depth = r * 65536.0 + g * 256.0 + b # decode depth image 11 | # depth = tf.div(depth, tf.constant(100.0)) 12 | return depth 13 | 14 | 15 | def encode_depth(depth): 16 | # depth = tf.multiply(depth, tf.constant(10000.0)) 17 | depth = tf.cast(depth, tf.uint16) 18 | r = depth / 256 / 256 19 | g = depth / 256 20 | b = depth % 256 21 | encoded_depth = tf.stack([r, g, b], axis=2) 22 | encoded_depth = tf.cast(encoded_depth, tf.uint8) 23 | return encoded_depth 24 | 25 | 26 | def get_dataset(dataset_dir, 27 | num_readers, 28 | num_preprocessing_threads, 29 | image_size, 30 | label_size, 31 | batch_size=1, 32 | reader=None, 33 | shuffle=True, 34 | num_epochs=None, 35 | is_training=True, 36 | is_depth=True): 37 | dataset_dir_list = [os.path.join(dataset_dir, filename) 38 | for filename in os.listdir(dataset_dir) if filename.endswith('.tfrecord')] 39 | if reader is None: 40 | reader = tf.TFRecordReader 41 | keys_to_features = { 42 | 'image/color': tf.FixedLenFeature((), tf.string, default_value=''), 43 | 'image/encoded_depth': tf.FixedLenFeature((), tf.string, default_value=''), 44 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 45 | 'image/label': tf.FixedLenFeature((), tf.string, default_value=''), 46 | 'image/camera_height': tf.FixedLenFeature((), tf.float32, default_value=0.0), 47 | } 48 | items_to_handlers = { 49 | 'color': slim.tfexample_decoder.Image(image_key='image/color', 50 | shape=(image_size, image_size, 3), 51 | channels=3), 52 | 'encoded_depth': slim.tfexample_decoder.Image(image_key='image/encoded_depth', 53 | shape=(image_size, image_size, 3), 54 | channels=3), 55 | 'label': slim.tfexample_decoder.Image(image_key='image/label', 56 | shape=(label_size, label_size, 3), 57 | channels=3), 58 | 'camera_height': slim.tfexample_decoder.Tensor(tensor_key='image/camera_height', 59 | shape=(1,)), 60 | 61 | } 62 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) 63 | dataset = slim.dataset.Dataset(data_sources=dataset_dir_list, 64 | reader=reader, 65 | decoder=decoder, 66 | num_samples=3, 67 | items_to_descriptions=None) 68 | provider = slim.dataset_data_provider.DatasetDataProvider(dataset, 69 | 
num_readers=num_readers, 70 | shuffle=shuffle, 71 | num_epochs=num_epochs, 72 | common_queue_capacity=20 * batch_size, 73 | common_queue_min=10 * batch_size) 74 | color, encoded_depth, label, camera_height = provider.get(['color', 'encoded_depth', 'label', 'camera_height']) 75 | color = tf.cast(color, tf.float32) 76 | color = color * tf.random_normal(color.get_shape(), mean=1, stddev=0.01) 77 | color = color / tf.constant([255.0]) 78 | color = (color - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225]) 79 | depth = decode_depth(encoded_depth) 80 | camera_height = tf.expand_dims(tf.expand_dims(camera_height, axis=0), axis=0) 81 | camera_height = camera_height / 1000.0 82 | depth = depth * tf.random_normal(depth.get_shape(), mean=1, stddev=0.01) 83 | depth = depth / 1000.0 # (depth - tf.reduce_mean(depth)) / 1000.0 84 | if is_depth: 85 | input = tf.concat([color, tf.expand_dims(depth, axis=2)], axis=2) 86 | else: 87 | input = color 88 | label = tf.cast(label, tf.float32) 89 | label = label / 255.0 90 | if is_training: 91 | inputs, labels, camera_heights = tf.train.batch([input, label, camera_height], 92 | batch_size=batch_size, 93 | num_threads=num_preprocessing_threads, 94 | capacity=5*batch_size) 95 | else: 96 | inputs = tf.expand_dims(input, axis=0) 97 | labels = tf.expand_dims(label, axis=0) 98 | camera_heights = tf.expand_dims(camera_height, axis=0) 99 | return inputs, labels, camera_heights 100 | 101 | 102 | def get_color_dataset(dataset_dir, 103 | num_readers, 104 | num_preprocessing_threads, 105 | image_size, 106 | label_size, 107 | batch_size=1, 108 | reader=None, 109 | shuffle=True, 110 | num_epochs=None, 111 | is_training=True): 112 | dataset_dir_list = [os.path.join(dataset_dir, filename) 113 | for filename in os.listdir(dataset_dir) if filename.endswith('.tfrecord')] 114 | if reader is None: 115 | reader = tf.TFRecordReader 116 | keys_to_features = { 117 | 'image/color': tf.FixedLenFeature((), tf.string, default_value=''), 118 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 119 | 'image/label': tf.FixedLenFeature((), tf.string, default_value=''), 120 | } 121 | items_to_handlers = { 122 | 'color': slim.tfexample_decoder.Image(image_key='image/color', 123 | shape=(image_size, image_size, 3), 124 | channels=3), 125 | 'label': slim.tfexample_decoder.Image(image_key='image/label', 126 | shape=(label_size, label_size, 3), 127 | channels=3), 128 | } 129 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) 130 | dataset = slim.dataset.Dataset(data_sources=dataset_dir_list, 131 | reader=reader, 132 | decoder=decoder, 133 | num_samples=3, 134 | items_to_descriptions=None) 135 | provider = slim.dataset_data_provider.DatasetDataProvider(dataset, 136 | num_readers=num_readers, 137 | shuffle=shuffle, 138 | num_epochs=num_epochs, 139 | common_queue_capacity=20 * batch_size, 140 | common_queue_min=10 * batch_size) 141 | color, label = provider.get(['color', 'label']) 142 | color = tf.cast(color, tf.float32) 143 | color = color * tf.random_normal(color.get_shape(), mean=1, stddev=0.01) 144 | color = color / tf.constant([255.0]) 145 | color = (color - tf.constant([0.485, 0.456, 0.406])) / tf.constant([0.229, 0.224, 0.225]) 146 | label = tf.cast(label, tf.float32) / 255.0 147 | if is_training: 148 | colors, labels = tf.train.batch([color, label], 149 | batch_size=batch_size, 150 | num_threads=num_preprocessing_threads, 151 | capacity=5*batch_size) 152 | else: 153 | colors = tf.expand_dims(color, axis=0) 154 | 
labels = tf.expand_dims(label, axis=0) 155 | return colors, labels 156 | 157 | 158 | def get_depth_dataset(dataset_dir, 159 | num_readers, 160 | num_preprocessing_threads, 161 | image_size, 162 | label_size, 163 | batch_size=1, 164 | reader=None, 165 | shuffle=True, 166 | num_epochs=None, 167 | is_training=True): 168 | dataset_dir_list = [os.path.join(dataset_dir, filename) 169 | for filename in os.listdir(dataset_dir) if filename.endswith('.tfrecord')] 170 | if reader is None: 171 | reader = tf.TFRecordReader 172 | keys_to_features = { 173 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 174 | 'image/encoded_depth': tf.FixedLenFeature((), tf.string, default_value=''), 175 | 'image/label': tf.FixedLenFeature((), tf.string, default_value=''), 176 | } 177 | items_to_handlers = { 178 | 'encoded_depth': slim.tfexample_decoder.Image(image_key='image/encoded_depth', 179 | shape=(image_size, image_size, 3), 180 | channels=3), 181 | 'label': slim.tfexample_decoder.Image(image_key='image/label', 182 | shape=(label_size, label_size, 3), 183 | channels=3), 184 | } 185 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) 186 | dataset = slim.dataset.Dataset(data_sources=dataset_dir_list, 187 | reader=reader, 188 | decoder=decoder, 189 | num_samples=3, 190 | items_to_descriptions=None) 191 | provider = slim.dataset_data_provider.DatasetDataProvider(dataset, 192 | num_readers=num_readers, 193 | shuffle=shuffle, 194 | num_epochs=num_epochs, 195 | common_queue_capacity=20 * batch_size, 196 | common_queue_min=10 * batch_size) 197 | encoded_depth, label = provider.get(['encoded_depth', 'label']) 198 | depth = decode_depth(encoded_depth) 199 | depth = (depth - tf.reduce_mean(depth)) / 1000.0 200 | depth = depth * tf.random_normal(depth.get_shape(), mean=1, stddev=0.01) 201 | depth = tf.stack([depth, depth, depth], axis=2) 202 | label = tf.cast(label, tf.float32) / 255.0 203 | if is_training: 204 | depths, labels = tf.train.batch([depth, label], 205 | batch_size=batch_size, 206 | num_threads=num_preprocessing_threads, 207 | capacity=5*batch_size) 208 | else: 209 | depths = tf.expand_dims(depth, axis=0) 210 | labels = tf.expand_dims(label, axis=0) 211 | return depths, labels 212 | 213 | 214 | def add_summary(inputs, labels, end_points, loss, hparams): 215 | h_b = int(hparams.batch_size/2) 216 | tf.summary.scalar(hparams.scope+'_loss', loss) 217 | tf.summary.image(hparams.scope+'_inputs_g', inputs[0:h_b]) 218 | tf.summary.image(hparams.scope+'_inputs_r', inputs[h_b:]) 219 | tf.summary.image(hparams.scope+'_labels_g', labels[0:h_b]) 220 | tf.summary.image(hparams.scope+'_labels_r', labels[h_b:]) 221 | # for i in range(1, 3): 222 | # for j in range(64): 223 | # tf.summary.image(scope + '/conv{}' + '_{}'.format(i, j), 224 | # end_points[scope + '/conv{}'.format(i)][0:1, :, :, j:j + 1]) 225 | # tf.summary.image(scope + '/conv3', end_points[scope + '/conv3']) 226 | net = end_points['logits'] 227 | infer_map = tf.exp(net) / tf.reduce_sum(tf.exp(net), axis=3, keepdims=True) 228 | tf.summary.image(hparams.scope+'_inference_map_g', infer_map[0:h_b]) 229 | tf.summary.image(hparams.scope+'_inference_map_r', infer_map[h_b:]) 230 | # variable_list = slim.get_model_variables() 231 | # for var in variable_list: 232 | # tf.summary.histogram(var.name[:-2], var) 233 | 234 | 235 | def restore_map(): 236 | variable_list = slim.get_model_variables() 237 | variables_to_restore = {var.op.name: var for var in variable_list} 238 | return variables_to_restore 239 | 240 | 241 | 
def restore_from_classification_checkpoint(scope, model_name, checkpoint_exclude_scopes): 242 | variable_list = slim.get_model_variables(os.path.join(scope, 'feature_extractor')) 243 | for checkpoint_exclude_scope in checkpoint_exclude_scopes: 244 | variable_list = [var for var in variable_list if checkpoint_exclude_scope not in var.op.name] 245 | variables_to_restore = {} 246 | for var in variable_list: 247 | if var.name.startswith(os.path.join(scope, 'feature_extractor')): 248 | var_name = var.op.name.replace(os.path.join(scope, 'feature_extractor'), model_name) 249 | variables_to_restore[var_name] = var 250 | return variables_to_restore 251 | -------------------------------------------------------------------------------- /nets/resnet_v1.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | from nets import resnet_utils 4 | 5 | 6 | resnet_arg_scope = resnet_utils.resnet_arg_scope 7 | 8 | 9 | @slim.add_arg_scope 10 | def bottleneck(inputs, 11 | depth, 12 | depth_bottleneck, 13 | stride, 14 | rate=1, 15 | outputs_collections=None, 16 | scope=None, 17 | use_bounded_activations=False): 18 | """Bottleneck residual unit variant with BN after convolutions. 19 | 20 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 21 | its definition. Note that we use here the bottleneck variant which has an 22 | extra bottleneck layer. 23 | 24 | When putting together two consecutive ResNet blocks that use this unit, one 25 | should use stride = 2 in the last unit of the first block. 26 | 27 | Args: 28 | inputs: A tensor of size [batch, height, width, channels]. 29 | depth: The depth of the ResNet unit output. 30 | depth_bottleneck: The depth of the bottleneck layers. 31 | stride: The ResNet unit's stride. Determines the amount of downsampling of 32 | the units output compared to its input. 33 | rate: An integer, rate for atrous convolution. 34 | outputs_collections: Collection to add the ResNet unit output. 35 | scope: Optional variable_scope. 36 | use_bounded_activations: Whether or not to use bounded activations. Bounded 37 | activations better lend themselves to quantized inference. 38 | 39 | Returns: 40 | The ResNet unit's output. 41 | """ 42 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 43 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 44 | if depth == depth_in: 45 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 46 | else: 47 | shortcut = slim.conv2d( 48 | inputs, 49 | depth, [1, 1], 50 | stride=stride, 51 | activation_fn=tf.nn.relu6 if use_bounded_activations else None, 52 | scope='shortcut') 53 | 54 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 55 | scope='conv1') 56 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 57 | rate=rate, scope='conv2') 58 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 59 | activation_fn=None, scope='conv3') 60 | 61 | if use_bounded_activations: 62 | # Use clip_by_value to simulate bandpass activation. 
63 | residual = tf.clip_by_value(residual, -6.0, 6.0) 64 | output = tf.nn.relu6(shortcut + residual) 65 | else: 66 | output = tf.nn.relu(shortcut + residual) 67 | 68 | return slim.utils.collect_named_outputs(outputs_collections, 69 | sc.name, 70 | output) 71 | 72 | 73 | def resnet_v1(inputs, 74 | blocks, 75 | num_classes=None, 76 | is_training=True, 77 | global_pool=True, 78 | output_stride=None, 79 | include_root_block=True, 80 | spatial_squeeze=True, 81 | reuse=None, 82 | scope=None): 83 | """Generator for v1 ResNet models. 84 | 85 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 86 | methods for specific model instantiations, obtained by selecting different 87 | block instantiations that produce ResNets of various depths. 88 | 89 | Training for image classification on Imagenet is usually done with [224, 224] 90 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 91 | block for the ResNets defined in [1] that have nominal stride equal to 32. 92 | However, for dense prediction tasks we advise that one uses inputs with 93 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 94 | this case the feature maps at the ResNet output will have spatial shape 95 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 96 | and corners exactly aligned with the input image corners, which greatly 97 | facilitates alignment of the features to the image. Using as input [225, 225] 98 | images results in [8, 8] feature maps at the output of the last ResNet block. 99 | 100 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 101 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 102 | have nominal stride equal to 32 and a good choice in FCN mode is to use 103 | output_stride=16 in order to increase the density of the computed features at 104 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 105 | 106 | Args: 107 | inputs: A tensor of size [batch, height_in, width_in, channels]. 108 | blocks: A list of length equal to the number of ResNet blocks. Each element 109 | is a resnet_utils.Block object describing the units in the block. 110 | num_classes: Number of predicted classes for classification tasks. 111 | If 0 or None, we return the features before the logit layer. 112 | is_training: whether batch_norm layers are in training mode. 113 | global_pool: If True, we perform global average pooling before computing the 114 | logits. Set to True for image classification, False for dense prediction. 115 | output_stride: If None, then the output will be computed at the nominal 116 | network stride. If output_stride is not None, it specifies the requested 117 | ratio of input to output spatial resolution. 118 | include_root_block: If True, include the initial convolution followed by 119 | max-pooling, if False excludes it. 120 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 121 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 122 | To use this parameter, the input images must be smaller than 300x300 123 | pixels, in which case the output logit layer does not contain spatial 124 | information and can be removed. 125 | reuse: whether or not the network and its variables should be reused. To be 126 | able to reuse 'scope' must be given. 127 | scope: Optional variable_scope. 128 | 129 | Returns: 130 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 
131 | If global_pool is False, then height_out and width_out are reduced by a 132 | factor of output_stride compared to the respective height_in and width_in, 133 | else both height_out and width_out equal one. If num_classes is 0 or None, 134 | then net is the output of the last ResNet block, potentially after global 135 | average pooling. If num_classes a non-zero integer, net contains the 136 | pre-softmax activations. 137 | end_points: A dictionary from components of the network to the corresponding 138 | activation. 139 | 140 | Raises: 141 | ValueError: If the target output_stride is not valid. 142 | """ 143 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 144 | end_points_collection = sc.original_name_scope + '_end_points' 145 | with slim.arg_scope([slim.conv2d, bottleneck, 146 | resnet_utils.stack_blocks_dense], 147 | outputs_collections=end_points_collection): 148 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 149 | net = inputs 150 | if include_root_block: 151 | if output_stride is not None: 152 | if output_stride % 4 != 0: 153 | raise ValueError('The output_stride needs to be a multiple of 4.') 154 | output_stride /= 4 155 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 156 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 157 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 158 | # Convert end_points_collection into a dictionary of end_points. 159 | end_points = slim.utils.convert_collection_to_dict( 160 | end_points_collection) 161 | 162 | if global_pool: 163 | # Global average pooling. 164 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 165 | end_points['global_pool'] = net 166 | if num_classes: 167 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 168 | normalizer_fn=None, scope='logits') 169 | end_points[sc.name + '/logits'] = net 170 | if spatial_squeeze: 171 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 172 | end_points[sc.name + '/spatial_squeeze'] = net 173 | end_points['predictions'] = slim.softmax(net, scope='predictions') 174 | return net, end_points 175 | 176 | 177 | resnet_v1.default_image_size = 320 178 | 179 | 180 | def resnet_v1_block(scope, base_depth, num_units, stride): 181 | """Helper function for creating a resnet_v1 bottleneck block. 182 | 183 | Args: 184 | scope: The scope of the block. 185 | base_depth: The depth of the bottleneck layer for each unit. 186 | num_units: The number of units in the block. 187 | stride: The stride of the block, implemented as a stride in the last unit. 188 | All other units have stride=1. 189 | 190 | Returns: 191 | A resnet_v1 bottleneck block. 192 | """ 193 | return resnet_utils.Block(scope, bottleneck, [{ 194 | 'depth': base_depth * 4, 195 | 'depth_bottleneck': base_depth, 196 | 'stride': 1 197 | }] * (num_units - 1) + [{ 198 | 'depth': base_depth * 4, 199 | 'depth_bottleneck': base_depth, 200 | 'stride': stride 201 | }]) 202 | 203 | 204 | def resnet_v1_50(inputs, 205 | num_classes=None, 206 | is_training=True, 207 | global_pool=True, 208 | output_stride=None, 209 | spatial_squeeze=True, 210 | reuse=None, 211 | scope='resnet_v1_50'): 212 | """ResNet-50 model of [1]. 
See resnet_v1() for arg and return description.""" 213 | blocks = [ 214 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 215 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 216 | resnet_v1_block('block3', base_depth=256, num_units=6, stride=2), 217 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 218 | ] 219 | return resnet_v1(inputs, blocks, num_classes, is_training, 220 | global_pool=global_pool, output_stride=output_stride, 221 | include_root_block=True, spatial_squeeze=spatial_squeeze, 222 | reuse=reuse, scope=scope) 223 | 224 | 225 | resnet_v1_50.default_image_size = resnet_v1.default_image_size 226 | 227 | 228 | def resnet_v1_101(inputs, 229 | num_classes=None, 230 | is_training=True, 231 | global_pool=True, 232 | output_stride=None, 233 | spatial_squeeze=True, 234 | reuse=None, 235 | scope='resnet_v1_101'): 236 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 237 | blocks = [ 238 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 239 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 240 | resnet_v1_block('block3', base_depth=256, num_units=23, stride=2), 241 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 242 | ] 243 | return resnet_v1(inputs, blocks, num_classes, is_training, 244 | global_pool=global_pool, output_stride=output_stride, 245 | include_root_block=True, spatial_squeeze=spatial_squeeze, 246 | reuse=reuse, scope=scope) 247 | 248 | 249 | resnet_v1_101.default_image_size = resnet_v1.default_image_size 250 | 251 | 252 | def resnet_v1_152(inputs, 253 | num_classes=None, 254 | is_training=True, 255 | global_pool=True, 256 | output_stride=None, 257 | spatial_squeeze=True, 258 | reuse=None, 259 | scope='resnet_v1_152'): 260 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" 261 | blocks = [ 262 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 263 | resnet_v1_block('block2', base_depth=128, num_units=8, stride=2), 264 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 265 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 266 | ] 267 | return resnet_v1(inputs, blocks, num_classes, is_training, 268 | global_pool=global_pool, output_stride=output_stride, 269 | include_root_block=True, spatial_squeeze=spatial_squeeze, 270 | reuse=reuse, scope=scope) 271 | 272 | 273 | resnet_v1_152.default_image_size = resnet_v1.default_image_size 274 | 275 | 276 | def resnet_v1_200(inputs, 277 | num_classes=None, 278 | is_training=True, 279 | global_pool=True, 280 | output_stride=None, 281 | spatial_squeeze=True, 282 | reuse=None, 283 | scope='resnet_v1_200'): 284 | """ResNet-200 model of [2]. See resnet_v1() for arg and return description.""" 285 | blocks = [ 286 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 287 | resnet_v1_block('block2', base_depth=128, num_units=24, stride=2), 288 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 289 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 290 | ] 291 | return resnet_v1(inputs, blocks, num_classes, is_training, 292 | global_pool=global_pool, output_stride=output_stride, 293 | include_root_block=True, spatial_squeeze=spatial_squeeze, 294 | reuse=reuse, scope=scope) 295 | 296 | 297 | resnet_v1_200.default_image_size = resnet_v1.default_image_size 298 | --------------------------------------------------------------------------------
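Usage notes (illustrative sketches, not files from the repository):

The depth maps above are stored as 3-channel PNGs and recovered in network_utils.decode_depth as r * 65536 + g * 256 + b. The NumPy round trip below restates that packing scheme; encode_depth_np and decode_depth_np are hypothetical stand-ins written for this note (the repository's encode_depth works on tensors and relies on the uint8 cast for the modulo), and the millimetre range is only a plausible example matching the 0.20–0.63 m camera heights sampled in the data-processing scripts.

import numpy as np

def encode_depth_np(depth_mm):
    # Split a depth map in millimetres across three 8-bit channels (R, G, B).
    depth_mm = depth_mm.astype(np.uint32)
    r = (depth_mm // 65536) % 256
    g = (depth_mm // 256) % 256
    b = depth_mm % 256
    return np.stack([r, g, b], axis=2).astype(np.uint8)

def decode_depth_np(encoded):
    # Recover millimetres with the same formula as network_utils.decode_depth.
    encoded = encoded.astype(np.uint32)
    return encoded[..., 0] * 65536 + encoded[..., 1] * 256 + encoded[..., 2]

depth = np.random.randint(200, 630, size=(288, 288), dtype=np.uint16)  # plausible mm range
assert np.array_equal(decode_depth_np(encode_depth_np(depth)), depth)

The second sketch shows one way the bundled nets/resnet_v1.py could be driven as a fully-convolutional feature extractor, following the guidance in the resnet_v1() docstring (num_classes=None, global_pool=False, output_stride=16). It assumes TensorFlow 1.x with tf.contrib.slim, as pinned in requirements.txt; the 4-channel placeholder mirrors the RGB-D concatenation built by network_utils.get_dataset(is_depth=True), while the placeholder name and the 288x288 input size are assumptions for illustration (288 matches the padded height maps in the data-processing scripts), not values taken from the training scripts.

import tensorflow as tf
import tensorflow.contrib.slim as slim

from nets import resnet_utils, resnet_v1

# Hypothetical RGB-D input: 3 color channels plus 1 depth channel.
rgbd = tf.placeholder(tf.float32, [1, 288, 288, 4], name='rgbd_input')

with slim.arg_scope(resnet_utils.resnet_arg_scope()):
    # num_classes=None keeps the pre-logit features; global_pool=False with
    # output_stride=16 yields a dense feature map instead of a single vector.
    net, end_points = resnet_v1.resnet_v1_50(rgbd,
                                             num_classes=None,
                                             is_training=False,
                                             global_pool=False,
                                             output_stride=16)

print(net.get_shape().as_list())  # [1, 18, 18, 2048] for a 288x288 input

Because global_pool is disabled and no logit layer is added, the returned tensor keeps its spatial layout, which is what a dense per-pixel prediction head would consume downstream.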