├── .idea
│   ├── misc.xml
│   ├── modules.xml
│   ├── semantic_segmentation_contest_deeplabv3.iml
│   ├── vcs.xml
│   └── workspace.xml
├── DataGenerate
│   └── __pycache__
│       ├── GetDataset.cpython-36.pyc
│       └── __init__.cpython-36.pyc
├── GeneratingBatchSize
│   ├── GetDataset.py
│   └── __init__.py
├── GeneratingDatasets
│   └── get_new_dataset.py
├── NET
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-36.pyc
│   ├── aaf
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── layers.cpython-36.pyc
│   │   ├── layers.py
│   │   └── losses.py
│   ├── deeplab_v3.py
│   ├── deeplabv3_DA.py
│   ├── deeplabv3_plus.py
│   ├── pspnet.py
│   ├── resnet_v2
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── resnet_utils.cpython-36.pyc
│   │   │   └── resnet_v2.cpython-36.pyc
│   │   ├── resnet_utils.py
│   │   └── resnet_v2.py
│   ├── resnet_v2_psp
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── resnet_utils.cpython-36.pyc
│   │   │   └── resnet_v2.cpython-36.pyc
│   │   ├── resnet_utils.py
│   │   └── resnet_v2.py
│   └── self_attention_layers
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   └── self_attention_layers.cpython-36.pyc
│       └── self_attention_layers.py
├── README.md
├── eval.py
├── resource
│   ├── 1.png
│   ├── 2.png
│   └── 语义分割比赛进展.pptx
├── test1000_stride_400.py
├── test_stride_400.py
├── tools_aaf.py
├── tools_deeplabv3.py
├── tools_deeplabv3_DA.py
├── tools_deeplabv3plus.py
├── tools_psp.py
├── train_aaf.py
├── train_deeplabv3.py
├── train_deeplabv3_DA.py
├── train_deeplabv3plus.py
├── train_deeplabv3plus_4chanel.py
├── train_psp.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   └── __init__.cpython-36.pyc
    └── preprocessing.py
/DataGenerate/__pycache__/GetDataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/DataGenerate/__pycache__/GetDataset.cpython-36.pyc
--------------------------------------------------------------------------------
/DataGenerate/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/DataGenerate/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/GeneratingBatchSize/GetDataset.py:
--------------------------------------------------------------------------------
1 | from utils import preprocessing
2 | import tensorflow as tf
3 | import os
4 |
5 | # Randomly crop or pad a [_HEIGHT, _WIDTH] section of the image and label.
6 | _HEIGHT = 1000
7 | _WIDTH = 1000
8 |
9 | # image channels
10 | _DEPTH = 3
11 |
12 | # Randomly scale the image and label
13 | _MIN_SCALE = 0.5
14 | _MAX_SCALE = 2.0
15 |
16 | # label value to ignore
17 | _IGNORE_LABEL = 255
18 |
19 | _NUM_IMAGES = {
20 | 'train': 5000,
21 | 'validation': 500
22 | }
23 | def get_filenames(is_training, data_dir):
24 | """Return a list of filenames.
25 |
26 | Args:
27 | is_training: A boolean denoting whether the input is for training.
28 | data_dir: path to the directory containing the input data.
29 |
30 | Returns:
31 | A list of file names.
32 | """
33 | if is_training:
34 | return [os.path.join(data_dir, 'train3.tfrecord')]
35 | else:
36 | return [os.path.join(data_dir, 'val3.tfrecord')]
37 |
38 |
39 | def parse_record(raw_record):
40 | """Parse PASCAL image and label from a tf record."""
41 | features = tf.parse_single_example(
42 | raw_record, features={
43 | 'data': tf.FixedLenFeature([], tf.string),
44 | 'label': tf.FixedLenFeature([], tf.string),
45 | })
46 |
47 | image = tf.decode_raw(features['data'], tf.uint8)
48 | image = tf.reshape(image, [1000, 1000, 4])
49 | image = tf.cast(image, dtype=tf.float32)
50 |
51 | image = tf.gather(image, [0, 1, 2], axis=2)
52 | # subtract the per-image channel mean
53 | image_reshape = tf.reshape(image, [1000000, 3])
54 | mean = tf.reduce_mean(image_reshape, 0)
55 | image = image - mean
56 | label = tf.decode_raw(features['label'], tf.uint8)
57 | label = tf.reshape(label, [1000, 1000, 1])
58 | label = tf.cast(label, dtype=tf.int32)
59 |
60 | return image, label
61 |
62 | def preprocess_image(image, label, is_training):
63 | """Preprocess a single image of layout [height, width, depth]."""
64 | if is_training:
65 | # Randomly scale the image and label.
66 | image, label = preprocessing.random_rescale_image_and_label(image, label, _MIN_SCALE, _MAX_SCALE)
67 |
68 | # Randomly crop or pad a [_HEIGHT, _WIDTH] section of the image and label.
69 | image, label = preprocessing.random_crop_or_pad_image_and_label(image, label, _HEIGHT, _WIDTH, _IGNORE_LABEL)
70 |
71 | # Randomly flip the image and label horizontally.
72 | image, label = preprocessing.random_flip_left_right_image_and_label(
73 | image, label)
74 |
75 | image.set_shape([_HEIGHT, _WIDTH, 3])
76 | label.set_shape([_HEIGHT, _WIDTH, 1])
77 |
78 | #image = preprocessing.mean_image_subtraction(image)
79 |
80 | return image, label
81 |
82 | def train_or_eval_input_fn(is_training, data_dir, batch_size, num_epochs=None):
83 | """Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
84 |
85 | Args:
86 | is_training: A boolean denoting whether the input is for training.
87 | data_dir: The directory containing the input data.
88 | batch_size: The number of samples per batch.
89 | num_epochs: The number of epochs to repeat the dataset.
90 |
91 | Returns:
92 | A tuple of images and labels.
93 | """
94 | dataset = tf.data.Dataset.from_tensor_slices(get_filenames(is_training, data_dir))
95 | dataset = dataset.flat_map(tf.data.TFRecordDataset)
96 |
97 | if is_training:
98 | dataset = dataset.shuffle(buffer_size=500)
99 |
100 | dataset = dataset.map(parse_record)
101 | dataset = dataset.map(
102 | lambda image, label: preprocess_image(image, label, is_training))
103 | dataset = dataset.prefetch(batch_size)
104 |
105 | # We call repeat after shuffling, rather than before, to prevent separate
106 | # epochs from blending together.
107 | dataset = dataset.repeat(num_epochs)
108 | dataset = dataset.batch(batch_size)
109 |
110 | return dataset
111 |
112 | def eval_or_test_input_fn(image_filenames, label_filenames=None, batch_size=1):
113 | """An input function for evaluation and inference.
114 |
115 | Args:
116 | image_filenames: The file names for the inferred images.
117 | label_filenames: The file names for the ground-truth labels.
118 | batch_size: The number of samples per batch. Needs to be 1
119 | when the images have different sizes.
120 |
121 | Returns:
122 | A tuple of images and labels.
123 | """
124 | # Reads an image from a file, decodes it into a dense tensor
125 | def _parse_function(filename, is_label):
126 | if not is_label:
127 | image_filename, label_filename = filename, None
128 | else:
129 | image_filename, label_filename = filename
130 |
131 | image_string = tf.read_file(image_filename)
132 | image = tf.image.decode_image(image_string)
133 | image = tf.to_float(tf.image.convert_image_dtype(image, dtype=tf.uint8))
134 | image.set_shape([None, None, 4])
135 |
136 | image = preprocessing.mean_image_subtraction(image)
137 |
138 | if not is_label:
139 | return image
140 | else:
141 | label_string = tf.read_file(label_filename)
142 | label = tf.image.decode_image(label_string)
143 | label = tf.to_int32(tf.image.convert_image_dtype(label, dtype=tf.uint8))
144 | label.set_shape([None, None, 1])
145 |
146 | return image, label
147 |
148 | if label_filenames is None:
149 | input_filenames = image_filenames
150 | else:
151 | input_filenames = (image_filenames, label_filenames)
152 |
153 | dataset = tf.data.Dataset.from_tensor_slices(input_filenames)
154 | if label_filenames is None:
155 | dataset = dataset.map(lambda x: _parse_function(x, False))
156 | else:
157 | dataset = dataset.map(lambda x, y: _parse_function((x, y), True))
158 | dataset = dataset.prefetch(batch_size)
159 | dataset = dataset.batch(batch_size)
160 | iterator = dataset.make_one_shot_iterator()
161 |
162 | if label_filenames is None:
163 | images = iterator.get_next()
164 | labels = None
165 | else:
166 | images, labels = iterator.get_next()
167 |
168 | return images, labels
169 |
--------------------------------------------------------------------------------
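
Usage sketch (illustrative only, not a file in the repository): consuming the
pipeline above. It assumes TensorFlow 1.x and that train3.tfrecord was written
to ../DatasetNew/train by GeneratingDatasets/get_new_dataset.py; the batch size
and epoch count are arbitrary.

import tensorflow as tf
from GeneratingBatchSize import GetDataset

dataset = GetDataset.train_or_eval_input_fn(
    is_training=True, data_dir='../DatasetNew/train',
    batch_size=4, num_epochs=1)
images, labels = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    img, lab = sess.run([images, labels])
    print(img.shape, lab.shape)  # (4, 1000, 1000, 3) (4, 1000, 1000, 1)

--------------------------------------------------------------------------------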
/GeneratingBatchSize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/GeneratingBatchSize/__init__.py
--------------------------------------------------------------------------------
/GeneratingDatasets/get_new_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from tqdm import tqdm
4 | import random
5 | import tensorflow as tf
6 | import sys
7 | from libtiff import TIFF
8 | from scipy import io
9 | img_w = 1000
10 | img_h = 1000
11 | train_sets = ['train/1', 'train/2', 'train/3', 'train/4',
12 | 'train/5', 'train/6', 'train/7', 'train/8']
13 | val_sets = ['val/1', 'val/2']
14 |
15 | class0 = np.array([0, 0, 0])
16 | class1 = np.array([0, 200, 0])
17 | class2 = np.array([150, 250, 0])
18 | class3 = np.array([150, 200, 150])
19 | class4 = np.array([200, 0, 200])
20 | class5 = np.array([150, 0, 250])
21 | class6 = np.array([150, 150, 250])
22 | class7 = np.array([250, 200, 0])
23 | class8 = np.array([200, 200, 0])
24 | class9 = np.array([200, 0, 0])
25 | class10 = np.array([250, 0, 150])
26 | class11 = np.array([200, 150, 150])
27 | class12 = np.array([250, 150, 150])
28 | class13 = np.array([0, 0, 200])
29 | class14 = np.array([0, 150, 200])
30 | class15 = np.array([0, 200, 250])
31 |
32 | def creat_dataset(image_num=100000, image_sets=train_sets, type='train', mode='original'):
33 | print('creating dataset...')
34 | image_each = image_num / len(image_sets)
35 | g_count = 0
36 | for i in tqdm(range(len(image_sets))):
37 | count = 0
38 | tif = TIFF.open('../DatasetOrigin/' + image_sets[i] + '.tif', mode='r') # 4 channels
39 | src_img = tif.read_image()
40 | #src_img_new = cv2.imread('../DatasetOrigin/' + image_sets[i] + '.tif', cv2.IMREAD_UNCHANGED) # 4 channels
41 | #label_img = cv2.imread('../DatasetOrigin/' + image_sets[i] + '_label.tif', cv2.IMREAD_COLOR) # 3 channels
42 | tif_label = TIFF.open('../DatasetOrigin/' + image_sets[i] + '_label.tif', mode='r')
43 | label_img = tif_label.read_image()
44 | X_height, X_width, _ = src_img.shape
45 | while count < image_each:
46 | random_width = random.randint(0, X_width - img_w - 1)
47 | random_height = random.randint(0, X_height - img_h - 1)
48 | src_roi = src_img[random_height: random_height + img_h, random_width: random_width + img_w, :]
49 | label_roi = label_img[random_height: random_height + img_h, random_width: random_width + img_w]
50 |
51 | #visualize = np.zeros((256, 256)).astype(np.uint8)
52 | #visualize = label_roi * 50
53 |
54 | if type == "train":
55 | #cv2.imwrite(('../DatasetNew/train/images/%d.tif' % g_count), src_roi)
56 | io.savemat('../DatasetNew/train/images/%d.mat' % g_count, {"feature": src_roi})
57 | #cv2.imwrite(('../DatasetNew/train/labels/%d.png' % g_count), label_roi)
58 | io.savemat('../DatasetNew/train/labels/%d.mat' % g_count, {"feature": label_roi})
59 | else:
60 | #cv2.imwrite(('../DatasetNew/val/images/%d.tif' % g_count), src_roi)
61 | io.savemat('../DatasetNew/val/images/%d.mat' % g_count, {"feature": src_roi})
62 | #cv2.imwrite(('../DatasetNew/val/labels/%d.png' % g_count), label_roi)
63 | io.savemat('../DatasetNew/val/labels/%d.mat' % g_count, {"feature": label_roi})
64 | count += 1
65 | g_count += 1
66 | def _bytes_feature(value):
67 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
68 | def _datas_to_tfexample(data, label):
69 | return tf.train.Example(features=tf.train.Features(feature={
70 | 'data': _bytes_feature(data),
71 | 'label': _bytes_feature(label)
72 | }))
73 |
74 |
75 | def to_tfrecord_train(train_filename, length):
76 | tfrecord_writer = tf.python_io.TFRecordWriter(train_filename)
77 | for key in tqdm(range(length)):
78 | #img_data = cv2.imread('../DatasetNew/train/images/%d.tif' % key, cv2.IMREAD_UNCHANGED)
79 | img_data = io.loadmat('../DatasetNew/train/images/%d.mat' % key)["feature"] # uint8
80 | label_data = cv2.imread('../DatasetNew/train/labels_2d/%d.png' % key, cv2.IMREAD_GRAYSCALE) # uint8
81 | img_data = img_data.tobytes()
82 | label_data = label_data.tobytes()
84 | # build the Example proto
84 | example = _datas_to_tfexample(img_data, label_data)
85 | tfrecord_writer.write(example.SerializeToString())
86 | #sys.stdout.write("\r>> Converting image" + output_filename)
87 | tfrecord_writer.close()
88 | sys.stdout.flush()
89 | def to_tfrecord_val(val_filename, length):
90 | tfrecord_writer = tf.python_io.TFRecordWriter(val_filename)
91 | for key in tqdm(range(length)):
92 | #img_data = cv2.imread('../DatasetNew/val/images/%d.tif' % key, cv2.IMREAD_UNCHANGED)
93 | img_data = io.loadmat('../DatasetNew/val/images/%d.mat' % key)["feature"]
94 | label_data = cv2.imread('../DatasetNew/val/labels_2d/%d.png' % key, cv2.IMREAD_GRAYSCALE)
95 | img_data = img_data.tobytes()
96 | label_data = label_data.tobytes()
98 | # build the Example proto
98 | example = _datas_to_tfexample(img_data, label_data)
99 | tfrecord_writer.write(example.SerializeToString())
100 | #sys.stdout.write("\r>> Converting image" + output_filename)
101 | tfrecord_writer.close()
102 | sys.stdout.flush()
103 |
104 | def encode_labels(image_color_data):
105 | image_data_RGB = image_color_data
106 | height, width, channel = image_data_RGB.shape
107 | label_seg = np.zeros([height, width], dtype=np.int8)
108 | label_seg[(image_data_RGB == class0).all(axis=2)] = 0
109 | label_seg[(image_data_RGB == class1).all(axis=2)] = 1
110 | label_seg[(image_data_RGB == class2).all(axis=2)] = 2
111 | label_seg[(image_data_RGB == class3).all(axis=2)] = 3
112 | label_seg[(image_data_RGB == class4).all(axis=2)] = 4
113 | label_seg[(image_data_RGB == class5).all(axis=2)] = 5
114 | label_seg[(image_data_RGB == class6).all(axis=2)] = 6
115 | label_seg[(image_data_RGB == class7).all(axis=2)] = 7
116 | label_seg[(image_data_RGB == class8).all(axis=2)] = 8
117 | label_seg[(image_data_RGB == class9).all(axis=2)] = 9
118 | label_seg[(image_data_RGB == class10).all(axis=2)] = 10
119 | label_seg[(image_data_RGB == class11).all(axis=2)] = 11
120 | label_seg[(image_data_RGB == class12).all(axis=2)] = 12
121 | label_seg[(image_data_RGB == class13).all(axis=2)] = 13
122 | label_seg[(image_data_RGB == class14).all(axis=2)] = 14
123 | label_seg[(image_data_RGB == class15).all(axis=2)] = 15
124 | return label_seg
125 |
126 |
127 | if __name__ == "__main__":
128 | creat_dataset(image_num=500, image_sets=val_sets, type='val', mode='original')
129 | creat_dataset(image_num=5000, image_sets=train_sets, type='train', mode='original')
130 |
131 | # convert the RGB color labels to single-channel class-index labels
132 | for key in tqdm(range(5000)):
133 | src_img = io.loadmat('../DatasetNew/train/labels/%d.mat' % key)
134 | new_data = encode_labels(src_img["feature"])
135 | result = cv2.imwrite("../DatasetNew/train/labels_2d/%d.png" % key, new_data)
136 | for key in tqdm(range(500)):
137 | src_img = io.loadmat('../DatasetNew/val/labels/%d.mat' % key)
138 | new_data = encode_labels(src_img["feature"])
139 | result = cv2.imwrite('../DatasetNew/val/labels_2d/%d.png' % key, new_data)
140 | to_tfrecord_train(train_filename="../DatasetNew/train/train3.tfrecord", length=5000)
141 | to_tfrecord_val(val_filename="../DatasetNew/val/val3.tfrecord", length=500)
--------------------------------------------------------------------------------
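
A tiny worked example for encode_labels above (illustrative only; assumes the
module's dependencies such as cv2, libtiff and tensorflow are importable).
Each RGB triple from the 16-color palette maps to its class index.

import numpy as np
from GeneratingDatasets.get_new_dataset import encode_labels

# 1x2 label patch: one pixel of class1 (0,200,0), one of class13 (0,0,200).
tiny = np.array([[[0, 200, 0], [0, 0, 200]]], dtype=np.uint8)
print(encode_labels(tiny))  # [[ 1 13]]

--------------------------------------------------------------------------------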
/NET/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/__init__.py
--------------------------------------------------------------------------------
/NET/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/aaf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/aaf/__init__.py
--------------------------------------------------------------------------------
/NET/aaf/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/aaf/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/aaf/__pycache__/layers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/aaf/__pycache__/layers.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/aaf/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def eightway_activation(x):
5 | """Retrieves neighboring pixels/features on the eight corners from
6 | a 3x3 patch.
7 |
8 | Args:
9 | x: A tensor of size [batch_size, height_in, width_in, channels]
10 |
11 | Returns:
12 | A tensor of size [batch_size, height_in, width_in, channels, 8]
13 | """
14 | # Get the number of channels in the input.
15 | shape_x = x.get_shape().as_list()
16 | if len(shape_x) != 4:
17 | raise ValueError('Only support for 4-D tensors!')
18 |
19 | # Pad at the margin.
20 | x = tf.pad(x,
21 | paddings=[[0,0],[1,1],[1,1],[0,0]],
22 | mode='SYMMETRIC')
23 | # Get eight neighboring pixels/features.
24 | x_groups = [
25 | x[:, 1:-1, :-2, :], # left
26 | x[:, 1:-1, 2:, :], # right
27 | x[:, :-2, 1:-1, :], # up
28 | x[:, 2:, 1:-1, :], # down
29 | x[:, :-2, :-2, :], # left-up
30 | x[:, 2:, :-2, :], # left-down
31 | x[:, :-2, 2:, :], # right-up
32 | x[:, 2:, 2:, :] # right-down
33 | ]
34 | output = [
35 | tf.expand_dims(c, axis=-1) for c in x_groups
36 | ]
37 | output = tf.concat(output, axis=-1)
38 |
39 | return output
40 |
41 |
42 | def eightcorner_activation(x, size):
43 | """Retrieves neighboring pixels one the eight corners from a
44 | (2*size+1)x(2*size+1) patch.
45 |
46 | Args:
47 | x: A tensor of size [batch_size, height_in, width_in, channels]
48 | size: A number indicating the half size of a patch.
49 |
50 | Returns:
51 | A tensor of size [batch_size, height_in, width_in, channels, 8]
52 | """
53 | # Get the number of channels in the input.
54 | shape_x = x.get_shape().as_list()
55 | if len(shape_x) != 4:
56 | raise ValueError('Only support for 4-D tensors!')
57 | n, h, w, c = shape_x
58 | h = 500 # NOTE: the spatial size is hard-coded here (the static shape
59 | w = 500 # can be unknown); adjust if the feature map is not 500x500.
60 | # Pad at the margin.
61 | p = size
62 | x_pad = tf.pad(x,
63 | paddings=[[0,0],[p,p],[p,p],[0,0]],
64 | mode='CONSTANT',
65 | constant_values=0)
66 |
67 | # Get eight corner pixels/features in the patch.
68 | x_groups = []
69 | for st_y in range(0,2*size+1,size):
70 | for st_x in range(0,2*size+1,size):
71 | if st_y == size and st_x == size:
72 | # Ignore the center pixel/feature.
73 | continue
74 |
75 | x_neighbor = x_pad[:, st_y:st_y+h, st_x:st_x+w, :]
76 | x_groups.append(x_neighbor)
77 |
78 | output = [tf.expand_dims(c, axis=-1) for c in x_groups]
79 | output = tf.concat(output, axis=-1)
80 |
81 | return output
82 |
83 |
84 | def ignores_from_label(labels, num_classes, size):
85 | """Retrieves ignorable pixels from the ground-truth labels.
86 |
87 | This function returns a binary map in which 1 denotes ignored pixels
88 | and 0 means not ignored ones. For those ignored pixels, they are not
89 | only the pixels with label value >= num_classes, but also the
90 | corresponding neighboring pixels, which are on the eight corners
91 | from a (2*size+1)x(2*size+1) patch.
92 |
93 | Args:
94 | labels: A tensor of size [batch_size, height_in, width_in], indicating
95 | semantic segmentation ground-truth labels.
96 | num_classes: A number indicating the total number of valid classes. The
97 | labels range from 0 to (num_classes-1), and any value >= num_classes
98 | would be ignored.
99 | size: A number indicating the half size of a patch.
100 |
101 | Return:
102 | A tensor of size [batch_size, height_in, width_in, 8]
103 | """
104 | # Get the number of channels in the input.
105 | shape_lab = labels.get_shape().as_list()
106 | if len(shape_lab) != 3:
107 | raise ValueError('Only support for 3-D label tensors!')
108 | n, h, w = shape_lab
109 |
110 | # Retrieve ignored pixels with label value >= num_classes.
111 | ignore = tf.greater(labels, num_classes-1) # NxHxW
112 |
113 | # Pad at the margin.
114 | p = size
115 | ignore_pad = tf.pad(ignore,
116 | paddings=[[0,0],[p,p],[p,p]],
117 | mode='CONSTANT',
118 | constant_values=True)
119 |
120 | # Retrieve eight corner pixels from the center, where the center
121 | # is ignored. Note that it should be bi-directional. For example,
122 | # when computing AAF loss with top-left pixels, the ignored pixels
123 | # might be the center or the top-left ones.
124 | ignore_groups= []
125 | for st_y in range(2*size,-1,-size):
126 | for st_x in range(2*size,-1,-size):
127 | if st_y == size and st_x == size:
128 | continue
129 | ignore_neighbor = ignore_pad[:,st_y:st_y+h,st_x:st_x+w]
130 | mask = tf.logical_or(ignore_neighbor, ignore)
131 | ignore_groups.append(mask)
132 |
133 | ig = 0
134 | for st_y in range(0,2*size+1,size):
135 | for st_x in range(0,2*size+1,size):
136 | if st_y == size and st_x == size:
137 | continue
138 | ignore_neighbor = ignore_pad[:,st_y:st_y+h,st_x:st_x+w]
139 | mask = tf.logical_or(ignore_neighbor, ignore_groups[ig])
140 | ignore_groups[ig] = mask
141 | ig += 1
142 |
143 | ignore_groups = [
144 | tf.expand_dims(c, axis=-1) for c in ignore_groups
145 | ] # NxHxWx1
146 | ignore = tf.concat(ignore_groups, axis=-1) #NxHxWx8
147 |
148 | return ignore
149 |
150 |
151 | def edges_from_label(labels, size, ignore_class=255):
152 | """Retrieves edge positions from the ground-truth labels.
153 |
154 | This function computes the edge map by considering if the pixel values
155 | are equal between the center and the neighboring pixels on the eight
156 | corners from a (2*size+1)x(2*size+1) patch. Edges where any of the
157 | paired pixels has a label value >= num_classes are ignored.
158 |
159 | Args:
160 | labels: A tensor of size [batch_size, height_in, width_in, channels],
161 | indicating semantic segmentation ground-truth labels (e.g. one-hot).
162 | size: A number indicating the half size of a patch.
163 | ignore_class: A number indicating the label value to ignore.
164 |
165 | Return:
166 | A tensor of size [batch_size, height_in, width_in, channels, 8]
167 | """
168 | # Get the number of channels in the input.
169 | shape_lab = labels.get_shape().as_list()
170 | if len(shape_lab) != 4:
171 | raise ValueError('Only support for 4-D label tensors!')
172 | n, h, w, c = shape_lab
173 |
174 | # Pad at the margin.
175 | p = size
176 | labels_pad = tf.pad(
177 | labels, paddings=[[0,0],[p,p],[p,p],[0,0]],
178 | mode='CONSTANT',
179 | constant_values=ignore_class)
180 |
181 | # Get the edge by comparing label values of the center and its paired pixels.
182 | edge_groups= []
183 | for st_y in range(0,2*size+1,size):
184 | for st_x in range(0,2*size+1,size):
185 | if st_y == size and st_x == size:
186 | continue
187 | labels_neighbor = labels_pad[:,st_y:st_y+h,st_x:st_x+w]
188 | edge = tf.not_equal(labels_neighbor, labels)
189 | edge_groups.append(edge)
190 |
191 | edge_groups = [
192 | tf.expand_dims(c, axis=-1) for c in edge_groups
193 | ] # NxHxWx1x1
194 | edge = tf.concat(edge_groups, axis=-1) #NxHxWx1x8
195 |
196 | return edge
197 |
--------------------------------------------------------------------------------
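
A short sketch of eightcorner_activation above (illustrative only): it pairs
every pixel with the eight corner pixels of its surrounding
(2*size+1)x(2*size+1) window. Note the function hard-codes the spatial size to
500x500, so the placeholder below matches that assumption.

import tensorflow as tf
import NET.aaf.layers as nnx

probs = tf.placeholder(tf.float32, [1, 500, 500, 16])  # NxHxWxC class probabilities
paired = nnx.eightcorner_activation(probs, size=1)
print(paired.shape)  # (1, 500, 500, 16, 8)

--------------------------------------------------------------------------------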
/NET/aaf/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import NET.aaf.layers as nnx
4 |
5 |
6 | def affinity_loss(labels,
7 | probs,
8 | num_classes,
9 | kld_margin):
10 | """Affinity Field (AFF) loss.
11 |
12 | This function computes AFF loss. There are several components in the
13 | function:
14 | 1) extracts edges from the ground-truth labels.
15 | 2) extracts ignored pixels and their paired pixels (the neighboring
16 | pixels on the eight corners).
17 | 3) extracts neighboring pixels on the eight corners from a 3x3 patch.
18 | 4) computes KL-Divergence between center pixels and their neighboring
19 | pixels from the eight corners.
20 |
21 | Args:
22 | labels: A tensor of size [batch_size, height_in, width_in], indicating
23 | semantic segmentation ground-truth labels.
24 | probs: A tensor of size [batch_size, height_in, width_in, num_classes],
25 | indicating segmentation predictions.
26 | num_classes: A number indicating the total number of valid classes.
27 | kld_margin: A number indicating the margin for KL-Divergence at edge.
28 |
29 | Returns:
30 | Two 1-D tensors indicating the loss at edge and non-edge.
31 | """
32 | # Compute the ignore map (e.g., label 255 and its paired pixels).
33 | labels = tf.squeeze(labels, axis=-1) # NxHxW
34 | ignore = nnx.ignores_from_label(labels, num_classes, 1) # NxHxWx8
35 | not_ignore = tf.logical_not(ignore)
36 | not_ignore = tf.expand_dims(not_ignore, axis=3) # NxHxWx1x8, True where not ignored
37 |
38 | # Compute edge map.
39 | one_hot_lab = tf.one_hot(labels, depth=num_classes)
40 | edge = nnx.edges_from_label(one_hot_lab, 1, 255) # NxHxWxCx8, True where paired labels differ
41 |
42 | # Remove ignored pixels from the edge/non-edge.
43 | edge = tf.logical_and(edge, not_ignore) # NxHxWxCx8
44 | not_edge = tf.logical_and(tf.logical_not(edge), not_ignore) # NxHxWxCx8
45 |
46 | edge_indices = tf.where(tf.reshape(edge, [-1]))
47 | not_edge_indices = tf.where(tf.reshape(not_edge, [-1]))
48 |
49 | # Extract eight corner from the center in a patch as paired pixels.
50 | probs_paired = nnx.eightcorner_activation(probs, 1) # NxHxWxCx8
51 | probs = tf.expand_dims(probs, axis=-1) # NxHxWxCx1
52 | bot_epsilon = tf.constant(1e-4, name='bot_epsilon')
53 | top_epsilon = tf.constant(1.0, name='top_epsilon')
54 | neg_probs = tf.clip_by_value(
55 | 1-probs, bot_epsilon, top_epsilon)
56 | probs = tf.clip_by_value(
57 | probs, bot_epsilon, top_epsilon)
58 | neg_probs_paired= tf.clip_by_value(
59 | 1-probs_paired, bot_epsilon, top_epsilon)
60 | probs_paired = tf.clip_by_value(
61 | probs_paired, bot_epsilon, top_epsilon)
62 |
63 | # Compute KL-Divergence.
64 | kldiv = probs_paired*tf.log(probs_paired/probs)
65 | kldiv += neg_probs_paired*tf.log(neg_probs_paired/neg_probs)
66 | not_edge_loss = kldiv
67 | edge_loss = tf.maximum(0.0, kld_margin-kldiv)
68 |
69 | not_edge_loss = tf.reshape(not_edge_loss, [-1])
70 | not_edge_loss = tf.gather(not_edge_loss, not_edge_indices)
71 | edge_loss = tf.reshape(edge_loss, [-1])
72 | edge_loss = tf.gather(edge_loss, edge_indices)
73 |
74 | return edge_loss, not_edge_loss
75 |
76 |
77 | def adaptive_affinity_loss(labels,
78 | one_hot_lab,
79 | probs,
80 | size,
81 | num_classes,
82 | kld_margin,
83 | w_edge,
84 | w_not_edge):
85 | """Adaptive affinity field (AAF) loss.
86 |
87 | This function computes AAF loss. There are five steps in the function:
88 | 1) extracts edges from the ground-truth labels.
89 | 2) extracts ignored pixels and their paired pixels (usually the eight corner
90 | pixels).
91 | 3) extracts eight corner pixels/predictions from the center in a
92 | (2*size+1)x(2*size+1) patch
93 | 4) computes KL-Divergence between center pixels and their paired pixels (the
94 | eight corner).
95 | 5) imposes adaptive weightings on the loss.
96 |
97 | Args:
98 | labels: A tensor of size [batch_size, height_in, width_in], indicating
99 | semantic segmentation ground-truth labels.
100 | one_hot_lab: A tensor of size [batch_size, height_in, width_in, num_classes]
101 | which is the ground-truth labels in the form of one-hot vector.
102 | probs: A tensor of size [batch_size, height_in, width_in, num_classes],
103 | indicating segmentation predictions.
104 | size: A number indicating the half size of a patch.
105 | num_classes: A number indicating the total number of valid classes.
106 | kld_margin: A number indicating the margin for KL-Divergence at edge.
107 | w_edge: A number indicating the weighting for KL-Divergence at edge.
108 | w_not_edge: A number indicating the weighting for KL-Divergence at non-edge.
109 |
110 | Returns:
111 | Two 1-D tensors indicating the loss at edge and non-edge.
112 | """
113 | # Compute the ignore map (e.g., label 255 and its paired pixels).
114 | labels = tf.squeeze(labels, axis=-1) # NxHxW
115 | ignore = nnx.ignores_from_label(labels, num_classes, size) # NxHxWx8
116 | not_ignore = tf.logical_not(ignore)
117 | not_ignore = tf.expand_dims(not_ignore, axis=3) # NxHxWx1x8
118 |
119 | # Compute edge map.
120 | edge = nnx.edges_from_label(one_hot_lab, size, 255) # NxHxWxCx8
121 |
122 | # Remove ignored pixels from the edge/non-edge.
123 | edge = tf.logical_and(edge, not_ignore)
124 | not_edge = tf.logical_and(tf.logical_not(edge), not_ignore)
125 |
126 | edge_indices = tf.where(tf.reshape(edge, [-1]))
127 | not_edge_indices = tf.where(tf.reshape(not_edge, [-1]))
128 |
129 | # Extract eight corner from the center in a patch as paired pixels.
130 | probs_paired = nnx.eightcorner_activation(probs, size) # NxHxWxCx8
131 | probs = tf.expand_dims(probs, axis=-1) # NxHxWxCx1
132 | bot_epsilon = tf.constant(1e-4, name='bot_epsilon')
133 | top_epsilon = tf.constant(1.0, name='top_epsilon')
134 |
135 | neg_probs = tf.clip_by_value(
136 | 1-probs, bot_epsilon, top_epsilon)
137 | neg_probs_paired = tf.clip_by_value(
138 | 1-probs_paired, bot_epsilon, top_epsilon)
139 | probs = tf.clip_by_value(
140 | probs, bot_epsilon, top_epsilon)
141 | probs_paired = tf.clip_by_value(
142 | probs_paired, bot_epsilon, top_epsilon)
143 |
144 | # Compute KL-Divergence.
145 | kldiv = probs_paired*tf.log(probs_paired/probs)
146 | kldiv += neg_probs_paired*tf.log(neg_probs_paired/neg_probs)
147 | edge_loss = tf.maximum(0.0, kld_margin-kldiv)
148 | not_edge_loss = kldiv
149 |
150 | # Impose weights on edge/non-edge losses.
151 | one_hot_lab = tf.expand_dims(one_hot_lab, axis=-1)
152 | w_edge = tf.reduce_sum(w_edge*one_hot_lab, axis=3, keep_dims=True) # NxHxWx1x1
153 | w_not_edge = tf.reduce_sum(w_not_edge*one_hot_lab, axis=3, keep_dims=True) # NxHxWx1x1
154 |
155 | edge_loss *= w_edge
156 | not_edge_loss *= w_not_edge
157 |
158 | not_edge_loss = tf.reshape(not_edge_loss, [-1])
159 | not_edge_loss = tf.gather(not_edge_loss, not_edge_indices)
160 | edge_loss = tf.reshape(edge_loss, [-1])
161 | edge_loss = tf.gather(edge_loss, edge_indices)
162 |
163 | return edge_loss, not_edge_loss
164 |
--------------------------------------------------------------------------------
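
A wiring sketch for adaptive_affinity_loss above (illustrative only; the
shapes, the uniform weightings, and kld_margin=3.0 are assumptions, not values
taken from this repository). Per pixel and class, the loss hinges the binary
KL divergence KL = p'*log(p'/p) + (1-p')*log((1-p')/(1-p)) between the center
prediction p and each corner neighbour p': max(0, kld_margin - KL) on edges,
and KL itself on non-edges.

import tensorflow as tf
from NET.aaf import losses as aaf_losses

labels = tf.placeholder(tf.int32, [2, 500, 500, 1])
logits = tf.placeholder(tf.float32, [2, 500, 500, 16])
probs = tf.nn.softmax(logits)
one_hot = tf.one_hot(tf.squeeze(labels, axis=-1), depth=16)

# Uniform per-class weightings; the AAF paper learns these instead.
w_edge = tf.ones([1, 1, 1, 16, 1])
w_not_edge = tf.ones([1, 1, 1, 16, 1])

edge_loss, not_edge_loss = aaf_losses.adaptive_affinity_loss(
    labels, one_hot, probs, size=1, num_classes=16,
    kld_margin=3.0, w_edge=w_edge, w_not_edge=w_not_edge)
aaf_loss = tf.reduce_mean(edge_loss) + tf.reduce_mean(not_edge_loss)

--------------------------------------------------------------------------------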
/NET/deeplab_v3.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from NET.resnet_v2 import resnet_utils, resnet_v2
3 | slim = tf.contrib.slim
4 |
5 | # ImageNet mean statistics
6 | _R_MEAN = 123.68
7 | _G_MEAN = 116.78
8 | _B_MEAN = 103.94
9 |
10 | @slim.add_arg_scope
11 | def atrous_spatial_pyramid_pooling(net, scope, depth=256, reuse=None):
12 | """
13 | ASPP consists of (a) one 1×1 convolution and three 3×3 convolutions with rates = (6, 12, 18) when output stride = 16
14 | (all with 256 filters and batch normalization), and (b) the image-level features as described in https://arxiv.org/abs/1706.05587
15 | :param net: tensor of shape [BATCH_SIZE, HEIGHT, WIDTH, DEPTH]
16 | :param scope: scope name of the aspp layer
17 | :return: network layer with aspp applied to it.
18 | """
19 |
20 | with tf.variable_scope(scope, reuse=reuse):
21 | feature_map_size = tf.shape(net)
22 |
23 | # apply global average pooling
24 | image_level_features = tf.reduce_mean(net, [1, 2], name='image_level_global_pool', keep_dims=True)
25 | image_level_features = slim.conv2d(image_level_features, depth, [1, 1], scope="image_level_conv_1x1",
26 | activation_fn=None)
27 | image_level_features = tf.image.resize_bilinear(image_level_features, (feature_map_size[1], feature_map_size[2]))
28 |
29 | at_pool1x1 = slim.conv2d(net, depth, [1, 1], scope="conv_1x1_0", activation_fn=None)
30 |
31 | at_pool3x3_1 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_1", rate=6, activation_fn=None)
32 |
33 | at_pool3x3_2 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_2", rate=12, activation_fn=None)
34 |
35 | at_pool3x3_3 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_3", rate=18, activation_fn=None)
36 |
37 | net = tf.concat((image_level_features, at_pool1x1, at_pool3x3_1, at_pool3x3_2, at_pool3x3_3), axis=3,
38 | name="concat")
39 | net = slim.conv2d(net, depth, [1, 1], scope="conv_1x1_output", activation_fn=None)
40 | return net
41 |
42 |
43 | def deeplab_v3(inputs, args, is_training, reuse):
44 |
45 | # mean subtraction normalization
46 | #inputs = inputs - [_R_MEAN, _G_MEAN, _B_MEAN]
47 |
48 | # inputs has shape [batch, height, width, channels]
49 | with slim.arg_scope(resnet_utils.resnet_arg_scope(args.l2_regularizer, is_training,
50 | args.batch_norm_decay,
51 | args.batch_norm_epsilon)):
52 | resnet = getattr(resnet_v2, args.resnet_model)
53 | _, end_points = resnet(inputs,
54 | args.number_of_classes,
55 | is_training=is_training,
56 | global_pool=False,
57 | spatial_squeeze=False,
58 | output_stride=args.output_stride,
59 | reuse=reuse)
60 | #if is_training:
61 | # exclude = [args.resnet_model + '/logits', 'global_step']
62 | # variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
63 | # tf.train.init_from_checkpoint(args.pre_trained_model,
64 | # {v.name.split(':')[0]: v for v in variables_to_restore})
65 |
66 | with tf.variable_scope("DeepLab_v3", reuse=reuse):
67 |
68 | # get block 4 feature outputs
69 | net = end_points[args.resnet_model + '/block4']
70 |
71 | net = atrous_spatial_pyramid_pooling(net, "ASPP_layer", depth=256, reuse=reuse)
72 | net = tf.layers.dropout(net, rate=0.5, training=is_training)
73 | net = slim.conv2d(net, args.number_of_classes, [1, 1], activation_fn=None,
74 | normalizer_fn=None, scope='logits')
75 |
76 | size = tf.shape(inputs)[1:3]
77 | # resize the output logits to match the labels dimensions
78 | #net = tf.image.resize_nearest_neighbor(net, size)
79 | net = tf.image.resize_bilinear(net, size)
80 | return net
81 |
82 |
83 |
--------------------------------------------------------------------------------
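
A build sketch for deeplab_v3 above (illustrative only). The args namespace
mirrors the fields the function reads; the concrete values, and the presence
of resnet_v2_101 in NET/resnet_v2, are assumptions.

import argparse
import tensorflow as tf
from NET.deeplab_v3 import deeplab_v3

args = argparse.Namespace(
    l2_regularizer=1e-4, batch_norm_decay=0.9997, batch_norm_epsilon=1e-5,
    resnet_model='resnet_v2_101', number_of_classes=16, output_stride=16)

images = tf.placeholder(tf.float32, [None, 1000, 1000, 3])
logits = deeplab_v3(images, args, is_training=True, reuse=False)
# logits: per-pixel class scores, bilinearly resized to the input size

--------------------------------------------------------------------------------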
/NET/deeplabv3_DA.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from NET.resnet_v2_psp import resnet_utils, resnet_v2
3 | from NET.self_attention_layers.self_attention_layers import position_attention_module
4 | slim = tf.contrib.slim
5 |
6 | # ImageNet mean statistics
7 | _R_MEAN = 123.68
8 | _G_MEAN = 116.78
9 | _B_MEAN = 103.94
10 |
11 | @slim.add_arg_scope
12 | def atrous_spatial_pyramid_pooling(net, scope, depth=256, reuse=None):
13 | """
14 | ASPP consists of (a) one 1×1 convolution and three 3×3 convolutions with rates = (6, 12, 18) when output stride = 16
15 | (all with 256 filters and batch normalization), and (b) the image-level features as described in https://arxiv.org/abs/1706.05587
16 | :param net: tensor of shape [BATCH_SIZE, HEIGHT, WIDTH, DEPTH]
17 | :param scope: scope name of the aspp layer
18 | :return: network layer with aspp applied to it.
19 | """
20 |
21 | with tf.variable_scope(scope, reuse=reuse):
22 | #feature_map_size = tf.shape(net)
23 | reduce_chanel = slim.conv2d(net, depth, [3, 3], scope="reduce_chanel", activation_fn=None)
24 | position_feature = position_attention_module(reduce_chanel)
25 | # apply global average pooling
26 | #image_level_features = tf.reduce_mean(net, [1, 2], name='image_level_global_pool', keep_dims=True)
27 | #image_level_features = slim.conv2d(image_level_features, depth, [1, 1], scope="image_level_conv_1x1",
28 | # activation_fn=None)
29 | #image_level_features = tf.image.resize_bilinear(image_level_features, (feature_map_size[1], feature_map_size[2]))
30 |
31 | at_pool1x1 = slim.conv2d(net, depth, [1, 1], scope="conv_1x1_0", activation_fn=None)
32 |
33 | at_pool3x3_1 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_1", rate=6, activation_fn=None)
34 |
35 | at_pool3x3_2 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_2", rate=12, activation_fn=None)
36 |
37 | at_pool3x3_3 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_3", rate=18, activation_fn=None)
38 |
39 | net = tf.concat((position_feature, at_pool1x1, at_pool3x3_1, at_pool3x3_2, at_pool3x3_3), axis=3,
40 | name="concat")
41 | net = slim.conv2d(net, depth, [1, 1], scope="conv_1x1_output", activation_fn=None)
42 | return net
43 |
44 |
45 | def deeplabv3_DA(inputs, args, is_training, reuse):
46 |
47 | # mean subtraction normalization
48 | #inputs = inputs - [_R_MEAN, _G_MEAN, _B_MEAN]
49 |
50 | # inputs has shape [batch, height, width, channels]
51 | with slim.arg_scope(resnet_utils.resnet_arg_scope(args.l2_regularizer, is_training,
52 | args.batch_norm_decay,
53 | args.batch_norm_epsilon)):
54 | resnet = getattr(resnet_v2, args.resnet_model)
55 | _, end_points = resnet(inputs,
56 | args.number_of_classes,
57 | is_training=is_training,
58 | global_pool=False,
59 | spatial_squeeze=False,
60 | output_stride=args.output_stride,
61 | reuse=reuse)
62 | #if is_training:
63 | # exclude = [args.resnet_model + '/logits', 'global_step']
64 | # variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
65 | # tf.train.init_from_checkpoint(args.pre_trained_model,
66 | # {v.name.split(':')[0]: v for v in variables_to_restore})
67 |
68 | with tf.variable_scope("DeepLab_v3", reuse=reuse):
69 |
70 | # get block 4 feature outputs
71 | net = end_points[args.resnet_model + '/block4']
72 |
73 | net = atrous_spatial_pyramid_pooling(net, "ASPP_layer", depth=256, reuse=reuse)
74 | net = tf.layers.dropout(net, rate=0.5, training=is_training)
75 |
76 | net = slim.conv2d(net, args.number_of_classes, [1, 1], activation_fn=None,
77 | normalizer_fn=None, scope='logits')
78 |
79 | size = tf.shape(inputs)[1:3]
80 | # resize the output logits to match the labels dimensions
81 | #net = tf.image.resize_nearest_neighbor(net, size)
82 | net = tf.image.resize_bilinear(net, size)
83 | return net
84 |
85 |
86 |
--------------------------------------------------------------------------------
/NET/deeplabv3_plus.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from NET.resnet_v2 import resnet_utils, resnet_v2
3 | slim = tf.contrib.slim
4 |
5 | # ImageNet mean statistics
6 | _R_MEAN = 123.68
7 | _G_MEAN = 116.78
8 | _B_MEAN = 103.94
9 |
10 | @slim.add_arg_scope
11 | def atrous_spatial_pyramid_pooling(net, scope, depth=256, reuse=None):
12 | """
13 | ASPP consists of (a) one 1×1 convolution and three 3×3 convolutions with rates = (6, 12, 18) when output stride = 16
14 | (all with 256 filters and batch normalization), and (b) the image-level features as described in https://arxiv.org/abs/1706.05587
15 | :param net: tensor of shape [BATCH_SIZE, HEIGHT, WIDTH, DEPTH]
16 | :param scope: scope name of the aspp layer
17 | :return: network layer with aspp applied to it.
18 | """
19 |
20 | with tf.variable_scope(scope, reuse=reuse):
21 | feature_map_size = tf.shape(net)
22 |
23 | # apply global average pooling
24 | image_level_features = tf.reduce_mean(net, [1, 2], name='image_level_global_pool', keep_dims=True)
25 | image_level_features = slim.conv2d(image_level_features, depth, [1, 1], scope="image_level_conv_1x1",
26 | activation_fn=None)
27 | image_level_features = tf.image.resize_bilinear(image_level_features, (feature_map_size[1], feature_map_size[2]))
28 |
29 | at_pool1x1 = slim.conv2d(net, depth, [1, 1], scope="conv_1x1_0", activation_fn=None)
30 |
31 | at_pool3x3_1 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_1", rate=6, activation_fn=None)
32 |
33 | at_pool3x3_2 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_2", rate=12, activation_fn=None)
34 |
35 | at_pool3x3_3 = slim.conv2d(net, depth, [3, 3], scope="conv_3x3_3", rate=18, activation_fn=None)
36 |
37 | net = tf.concat((image_level_features, at_pool1x1, at_pool3x3_1, at_pool3x3_2, at_pool3x3_3), axis=3,
38 | name="concat")
39 | net = slim.conv2d(net, depth, [1, 1], scope="conv_1x1_output", activation_fn=None)
40 | return net
41 |
42 |
43 | def deeplabv3_plus(inputs, args, is_training, reuse):
44 |
45 | # mean subtraction normalization
46 | #inputs = inputs - [_R_MEAN, _G_MEAN, _B_MEAN]
47 |
48 | # inputs has shape [batch, height, width, channels]
49 | with slim.arg_scope(resnet_utils.resnet_arg_scope(args.l2_regularizer, is_training,
50 | args.batch_norm_decay,
51 | args.batch_norm_epsilon)):
52 | resnet = getattr(resnet_v2, args.resnet_model)
53 | _, end_points = resnet(inputs,
54 | args.number_of_classes,
55 | is_training=is_training,
56 | global_pool=False,
57 | spatial_squeeze=False,
58 | output_stride=args.output_stride,
59 | reuse=reuse)
60 | #if is_training:
61 | # exclude = [args.resnet_model + '/logits', 'global_step']
62 | # variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
63 | # tf.train.init_from_checkpoint(args.pre_trained_model,
64 | # {v.name.split(':')[0]: v for v in variables_to_restore})
65 |
66 | with tf.variable_scope("DeepLab_v3", reuse=reuse):
67 |
68 | # get block 4 feature outputs
69 | net = end_points[args.resnet_model + '/block4']
70 |
71 | encoder_output = atrous_spatial_pyramid_pooling(net, "ASPP_layer", depth=256, reuse=reuse)
72 |
73 | #net = slim.conv2d(net, args.number_of_classes, [1, 1], activation_fn=None,
74 | # normalizer_fn=None, scope='logits')
75 |
76 | #size = tf.shape(inputs)[1:3]
77 | # resize the output logits to match the labels dimensions
78 | #net = tf.image.resize_nearest_neighbor(net, size)
79 | #net = tf.image.resize_bilinear(net, size)
80 | with tf.variable_scope("decoder", reuse=reuse):
81 | with tf.variable_scope("low_level_features"):
82 | low_level_features = end_points[args.resnet_model + '/block1/unit_3/bottleneck_v2/conv1']
83 | low_level_features = slim.conv2d(low_level_features, 48,
84 | [1, 1], normalizer_fn=None, scope='conv_1x1')
85 | low_level_features_size = tf.shape(low_level_features)[1:3]
86 |
87 | with tf.variable_scope("upsampling_logits"):
88 | net_decode = tf.image.resize_bilinear(encoder_output, low_level_features_size, name='upsample_1')
89 | net_decode = tf.concat([net_decode, low_level_features], axis=3, name='concat')
90 | net_decode = slim.conv2d(net_decode, 256, [3, 3], normalizer_fn=None, scope='conv_3x3_1')
91 | net_decode = slim.conv2d(net_decode, 256, [3, 3], normalizer_fn=None, scope='conv_3x3_2')
92 | net_decode = slim.conv2d(net_decode, args.number_of_classes, [1, 1], activation_fn=None, normalizer_fn=None,
93 | scope='conv_1x1')
94 | size = tf.shape(inputs)[1:3]
95 | logits = tf.image.resize_bilinear(net_decode, size, name='upsample_2')
96 | return logits
97 |
98 |
99 |
--------------------------------------------------------------------------------
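
The v3+ variant keeps the ASPP encoder but adds the decoder above, fusing
block1 low-level features before the final upsample. A loss-wiring sketch
(illustrative only; same assumed args fields and values as the deeplab_v3
sketch).

import argparse
import tensorflow as tf
from NET.deeplabv3_plus import deeplabv3_plus

args = argparse.Namespace(
    l2_regularizer=1e-4, batch_norm_decay=0.9997, batch_norm_epsilon=1e-5,
    resnet_model='resnet_v2_101', number_of_classes=16, output_stride=16)

images = tf.placeholder(tf.float32, [None, 1000, 1000, 3])
labels = tf.placeholder(tf.int32, [None, 1000, 1000, 1])

logits = deeplabv3_plus(images, args, is_training=True, reuse=False)
loss = tf.losses.sparse_softmax_cross_entropy(
    labels=tf.squeeze(labels, axis=-1), logits=logits)

--------------------------------------------------------------------------------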
/NET/pspnet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from NET.resnet_v2_psp import resnet_utils, resnet_v2
3 |
4 | slim = tf.contrib.slim
5 | def psp_conv(x, kernel_size, scope_name, is_training=True):
6 | filters_in = x.get_shape()[-1]
7 | with tf.variable_scope(scope_name) as scope:
8 | kernel = tf.get_variable(
9 | name='weights',
10 | shape=[kernel_size, kernel_size, filters_in, 512],
11 | trainable=True)
12 | conv_out = tf.nn.conv2d(x, kernel, [1, 1, 1, 1], padding='SAME')
13 | bn = slim.batch_norm(conv_out, activation_fn=tf.nn.relu, is_training=is_training, scope=scope_name + "/bn")
14 | return bn
15 |
16 | def _pspnet_builder(x,
17 | is_training,
18 | args,
19 | reuse=False):
20 | """Helper function to build PSPNet model for semantic segmentation.
21 |
22 | The PSPNet model is composed of one base network (ResNet101) and
23 | one pyramid spatial pooling (PSP) module, followed by concatenation
24 | and two more convolutional layers for segmentation prediction.
25 |
26 | Args:
27 | x: A tensor of size [batch_size, height_in, width_in, channels].
28 | is_training: If the tensorflow variables defined in this network
29 | would be used for training.
30 | args: An argparse namespace carrying the ResNet model name, the number
31 | of predicted classes, and the regularization/batch-norm settings.
32 | reuse: enable/disable reuse for reusing tensorflow variables. It is
33 | useful for sharing weight parameters across two identical networks.
34 |
35 | Returns:
36 | A tensor of size [batch_size, height_in, width_in, num_classes],
37 | i.e. the logits upsampled back to the input resolution.
38 | """
39 | # Ensure that the size of the input data is valid (should be a multiple of 6x8=48).
40 | h, w = x.get_shape().as_list()[1:3] # NxHxWxC
41 | assert(h == w)
42 | with slim.arg_scope(resnet_utils.resnet_arg_scope(args.l2_regularizer, is_training,
43 | args.batch_norm_decay,
44 | args.batch_norm_epsilon)):
45 | cnn_fn = getattr(resnet_v2, args.resnet_model)
46 | # Build the base network.
47 | _, end_points = cnn_fn(inputs=x,
48 | is_training=is_training,
49 | global_pool=False,
50 | output_stride=8,
51 | spatial_squeeze=False,
52 | reuse=reuse)
53 |
54 | if is_training:
55 | exclude = [args.resnet_model + '/logits', 'global_step']
56 | variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
57 | tf.train.init_from_checkpoint(args.pre_trained_model,
58 | {v.name.split(':')[0]: v for v in variables_to_restore})
59 |
60 |
61 | x = end_points[args.resnet_model + "/block4"]
62 |
63 | with tf.variable_scope(args.resnet_model, reuse=reuse) as scope:
64 | # Build the PSP module
65 | pool_k = int(h/8) # the base network is stride 8 by default.
66 |
67 | # Build the pooling layer that produces a 1x1 output.
68 | pool1 = tf.nn.avg_pool(x,
69 | name='block5/pool1',
70 | ksize=[1, pool_k, pool_k, 1],
71 | strides=[1, pool_k, pool_k, 1],
72 | padding='VALID')
73 | pool1 = psp_conv(pool1, 1, 'block5/pool1/conv1', is_training)
74 |
75 | pool1 = tf.image.resize_bilinear(pool1, [pool_k, pool_k])
76 |
77 | # Build the pooling layer that produces a 2x2 output.
78 | pool2 = tf.nn.avg_pool(x,
79 | name='block5/pool2',
80 | ksize=[1, pool_k//2, pool_k//2, 1],
81 | strides=[1, pool_k//2, pool_k//2, 1],
82 | padding='VALID')
83 | pool2 = psp_conv(pool2, 1, 'block5/pool2/conv1', is_training)
84 |
85 | pool2 = tf.image.resize_bilinear(pool2, [pool_k, pool_k])
86 |
87 | # Build the pooling layer that produces a 3x3 output.
88 | pool3 = tf.nn.avg_pool(x,
89 | name='block5/pool3',
90 | ksize=[1, pool_k//3, pool_k//3, 1],
91 | strides=[1, pool_k//3, pool_k//3, 1],
92 | padding='VALID')
93 | pool3 = psp_conv(pool3, 1, 'block5/pool3/conv1', is_training)
94 |
95 | pool3 = tf.image.resize_bilinear(pool3, [pool_k, pool_k])
96 |
97 | # Build the pooling layer that produces a 6x6 output.
98 | pool6 = tf.nn.avg_pool(x,
99 | name='block5/pool6',
100 | ksize=[1, pool_k//6, pool_k//6, 1],
101 | strides=[1, pool_k//6, pool_k//6, 1],
102 | padding='VALID')
103 | pool6 = psp_conv(pool6, 1, 'block5/pool6/conv1', is_training)
104 |
105 | pool6 = tf.image.resize_bilinear(pool6, [pool_k, pool_k])
106 |
107 | # Fuse the pooled feature maps with its input, and generate
108 | # segmentation prediction.
109 | x = tf.concat([pool1, pool2, pool3, pool6, x],
110 | name='block5/concat',
111 | axis=3)
112 | x = psp_conv(x, 3, 'block5/conv2', is_training)
113 |
114 | x = slim.conv2d(x, args.number_of_classes, [1, 1], activation_fn=None,
115 | normalizer_fn=None, scope='block5/fc1_voc12', padding='SAME')
116 |
117 | x = tf.image.resize_bilinear(x, [h, w])
118 |
119 | return x
120 |
121 | def pspnet_resnet(x, args, is_training, reuse=False):
122 | """Helper function to build PSPNet model for semantic segmentation.
123 |
124 | The PSPNet model is composed of one base network (ResNet101) and
125 | one pyramid spatial pooling (PSP) module, followed by concatenation
126 | and two more convolutional layers for segmentation prediction.
127 |
128 | Args:
129 | x: A tensor of size [batch_size, height_in, width_in, channels].
130 | args: An argparse namespace carrying the ResNet model name, the number
131 | of predicted classes, and the regularization/batch-norm settings.
132 | is_training: If the tensorflow variables defined in this network
133 | would be used for training.
134 | reuse: enable/disable reuse for reusing tensorflow variables. It is
135 | useful for sharing weight parameters across two identical networks.
136 |
137 | Returns:
138 | A tensor of size [batch_size, height_in, width_in, num_classes],
139 | i.e. the logits upsampled back to the input resolution.
140 |
141 | """
142 | with tf.name_scope('psp') as scope:
143 | result = _pspnet_builder(x,
144 | is_training,
145 | args,
146 | reuse=reuse)
147 | return result
--------------------------------------------------------------------------------
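
A build sketch for pspnet_resnet above (illustrative only). The input must be
square with a side divisible by 48 (stride 8, then up to 6x6 pooling);
is_training=False skips the init_from_checkpoint call, so pre_trained_model
can stay a placeholder path. All values are assumptions.

import argparse
import tensorflow as tf
from NET.pspnet import pspnet_resnet

args = argparse.Namespace(
    l2_regularizer=1e-4, batch_norm_decay=0.9997, batch_norm_epsilon=1e-5,
    resnet_model='resnet_v2_101', number_of_classes=16,
    pre_trained_model='/path/to/resnet_v2_101.ckpt')  # hypothetical path

images = tf.placeholder(tf.float32, [2, 480, 480, 3])  # 480 = 48 * 10
logits = pspnet_resnet(images, args, is_training=False, reuse=False)
print(logits.shape)  # (2, 480, 480, 16)

--------------------------------------------------------------------------------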
/NET/resnet_v2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2/__init__.py
--------------------------------------------------------------------------------
/NET/resnet_v2/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/resnet_v2/__pycache__/resnet_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2/__pycache__/resnet_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/resnet_v2/__pycache__/resnet_v2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2/__pycache__/resnet_v2.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/resnet_v2/resnet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains building blocks for various versions of Residual Networks.
16 |
17 | Residual networks (ResNets) were proposed in:
18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
20 |
21 | More variants were introduced in:
22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
24 |
25 | We can obtain different ResNet variants by changing the network depth, width,
26 | and form of residual unit. This module implements the infrastructure for
27 | building them. Concrete ResNet units and full ResNet networks are implemented in
28 | the accompanying resnet_v1.py and resnet_v2.py modules.
29 |
30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
31 | implementation we subsample the output activations in the last residual unit of
32 | each block, instead of subsampling the input activations in the first residual
33 | unit of each block. The two implementations give identical results but our
34 | implementation is more memory efficient.
35 | """
36 | from __future__ import absolute_import
37 | from __future__ import division
38 | from __future__ import print_function
39 |
40 | import collections
41 | import tensorflow as tf
42 |
43 | slim = tf.contrib.slim
44 |
45 |
46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
47 | """A named tuple describing a ResNet block.
48 |
49 | Its parts are:
50 | scope: The scope of the `Block`.
51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and
52 | returns another `Tensor` with the output of the ResNet unit.
53 | args: A list of length equal to the number of units in the `Block`. The list
54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the
55 | block to serve as argument to unit_fn.
56 | """
57 |
58 |
59 | def subsample(inputs, factor, scope=None):
60 | """Subsamples the input along the spatial dimensions.
61 |
62 | Args:
63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels].
64 | factor: The subsampling factor.
65 | scope: Optional variable_scope.
66 |
67 | Returns:
68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the
69 | input, either intact (if factor == 1) or subsampled (if factor > 1).
70 | """
71 | if factor == 1:
72 | return inputs
73 | else:
74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
75 |
76 |
77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
78 | """Strided 2-D convolution with 'SAME' padding.
79 |
80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with
81 | 'VALID' padding.
82 |
83 | Note that
84 |
85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride)
86 |
87 | is equivalent to
88 |
89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
90 | net = subsample(net, factor=stride)
91 |
92 | whereas
93 |
94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
95 |
96 | is different when the input's height or width is even, which is why we add the
97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
98 |
99 | Args:
100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
101 | num_outputs: An integer, the number of output filters.
102 | kernel_size: An int with the kernel_size of the filters.
103 | stride: An integer, the output stride.
104 | rate: An integer, rate for atrous convolution.
105 | scope: Scope.
106 |
107 | Returns:
108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with
109 | the convolution output.
110 | """
111 | if stride == 1:
112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
113 | padding='SAME', scope=scope)
114 | else:
115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
116 | pad_total = kernel_size_effective - 1
117 | pad_beg = pad_total // 2
118 | pad_end = pad_total - pad_beg
119 | inputs = tf.pad(inputs,
120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
122 | rate=rate, padding='VALID', scope=scope)
123 |
124 |
125 | @slim.add_arg_scope
126 | def stack_blocks_dense(net, blocks, multi_grid, output_stride=None,
127 | outputs_collections=None):
128 | """Stacks ResNet `Blocks` and controls output feature density.
129 |
130 | First, this function creates scopes for the ResNet in the form of
131 | 'block_name/unit_1', 'block_name/unit_2', etc.
132 |
133 | Second, this function allows the user to explicitly control the ResNet
134 | output_stride, which is the ratio of the input to output spatial resolution.
135 | This is useful for dense prediction tasks such as semantic segmentation or
136 | object detection.
137 |
138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a
139 | factor of 2 when transitioning between consecutive ResNet blocks. This results
140 | in a nominal ResNet output_stride equal to 8. If we set the output_stride to
141 | half the nominal network stride (e.g., output_stride=4), then we compute
142 | responses twice.
143 |
144 | Control of the output feature density is implemented by atrous convolution.
145 |
146 | Args:
147 | net: A `Tensor` of size [batch, height, width, channels].
148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each
149 | element is a ResNet `Block` object describing the units in the `Block`.
150 | output_stride: If `None`, then the output will be computed at the nominal
151 | network stride. If output_stride is not `None`, it specifies the requested
152 | ratio of input to output spatial resolution, which needs to be equal to
153 | the product of unit strides from the start up to some level of the ResNet.
154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which
156 | is equivalent to output_stride=24).
157 | outputs_collections: Collection to add the ResNet block outputs.
158 | multi_grid: Atrous rates multiplied into the units of the last block (block4) once the target output_stride is reached.
159 | Returns:
160 | net: Output tensor with stride equal to the specified output_stride.
161 |
162 | Raises:
163 | ValueError: If the target output_stride is not valid.
164 | """
165 | # The current_stride variable keeps track of the effective stride of the
166 | # activations. This allows us to invoke atrous convolution whenever applying
167 | # the next residual unit would result in the activations having stride larger
168 | # than the target output_stride.
169 | current_stride = 1
170 |
171 | # The atrous convolution rate parameter.
172 | rate = 1
173 |
174 | for block in blocks:
175 | with tf.variable_scope(block.scope, 'block', [net]) as sc:
176 | for i, unit in enumerate(block.args):
177 | if output_stride is not None and current_stride > output_stride:
178 | raise ValueError('The target output_stride cannot be reached.')
179 |
180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
181 | # If we have reached the target output_stride, then we need to employ
182 | # atrous convolution with stride=1 and multiply the atrous rate by the
183 | # current unit's stride for use in subsequent layers.
184 | if output_stride is not None and current_stride == output_stride:
185 | # Only use atrous convolutions with multi-grid rates in the last block (block4).
186 | if block.scope == "block4":
187 | net = block.unit_fn(net, rate=rate * multi_grid[i], **dict(unit, stride=1))
188 | else:
189 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
190 | rate *= unit.get('stride', 1)
191 | else:
192 | net = block.unit_fn(net, rate=1, **unit)
193 | current_stride *= unit.get('stride', 1)
194 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
195 |
196 | if output_stride is not None and current_stride != output_stride:
197 | raise ValueError('The target output_stride cannot be reached.')
198 |
199 | return net
200 |
201 |
202 | def resnet_arg_scope(weight_decay=0.0001,
203 | is_training=True,
204 | batch_norm_decay=0.997,
205 | batch_norm_epsilon=1e-5,
206 | batch_norm_scale=True,
207 | activation_fn=tf.nn.relu,
208 | use_batch_norm=True):
209 | """Defines the default ResNet arg scope.
210 |
211 | TODO(gpapan): The batch-normalization related default values above are
212 | appropriate for use in conjunction with the reference ResNet models
213 | released at https://github.com/KaimingHe/deep-residual-networks. When
214 | training ResNets from scratch, they might need to be tuned.
215 |
216 | Args:
217 | weight_decay: The weight decay to use for regularizing the model.
218 | batch_norm_decay: The moving average decay when estimating layer activation
219 | statistics in batch normalization.
220 | batch_norm_epsilon: Small constant to prevent division by zero when
221 | normalizing activations by their variance in batch normalization.
222 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
223 | activations in the batch normalization layer.
224 | activation_fn: The activation function which is used in ResNet.
225 | use_batch_norm: Whether or not to use batch normalization.
226 | is_training: Whether the batch_norm layers are in training mode.
227 | Returns:
228 | An `arg_scope` to use for the resnet models.
229 | """
230 | batch_norm_params = {
231 | 'decay': batch_norm_decay,
232 | 'epsilon': batch_norm_epsilon,
233 | 'scale': batch_norm_scale,
234 | 'updates_collections': None,
235 | 'is_training': is_training,
236 | 'fused': True, # Use fused batch norm if possible.
237 | }
238 |
239 | with slim.arg_scope(
240 | [slim.conv2d],
241 | weights_regularizer=slim.l2_regularizer(weight_decay),
242 | weights_initializer=slim.variance_scaling_initializer(),
243 | activation_fn=activation_fn,
244 | normalizer_fn=slim.batch_norm if use_batch_norm else None,
245 | normalizer_params=batch_norm_params):
246 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
247 | # The following implies padding='SAME' for pool1, which makes feature
248 | # alignment easier for dense prediction tasks. This is also used in
249 | # https://github.com/facebook/fb.resnet.torch. However the accompanying
250 | # code of 'Deep Residual Learning for Image Recognition' uses
251 | # padding='VALID' for pool1. You can switch to that choice by setting
252 | # slim.arg_scope([slim.max_pool2d], padding='VALID').
253 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
254 | return arg_sc
255 |
--------------------------------------------------------------------------------
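To make the output_stride machinery above concrete, here is a pure-Python trace of the current_stride / rate bookkeeping (an illustrative sketch, not repo code; the block layout follows resnet_v2_50, where the last unit of each block carries the stride). With output_stride=16 and multi_grid=[1, 2, 4], block4's units end up with atrous rates 2, 4 and 8, the DeepLabv3 multi-grid schedule.

    blocks = {'block1': [1, 1, 2], 'block2': [1, 1, 1, 2],
              'block3': [1, 1, 1, 1, 1, 2], 'block4': [1, 1, 1]}
    multi_grid = [1, 2, 4]
    output_stride = 16 // 4  # resnet_v2() divides by 4 for the root conv + pool
    current_stride, rate = 1, 1
    for scope, strides in blocks.items():
        for i, s in enumerate(strides):
            if current_stride == output_stride:
                # Converted to atrous: the unit runs with stride 1 and a dilation rate.
                eff = rate * multi_grid[i] if scope == 'block4' else rate
                print('%s/unit_%d: stride=1, atrous rate=%d' % (scope, i + 1, eff))
                rate *= s  # absorb the skipped stride into the rate
            else:
                current_stride *= s
                print('%s/unit_%d: stride=%d (current_stride=%d)' % (scope, i + 1, s, current_stride))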
/NET/resnet_v2/resnet_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains definitions for the preactivation form of Residual Networks.
16 |
17 | Residual networks (ResNets) were originally proposed in:
18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
20 |
21 | The full preactivation 'v2' ResNet variant implemented in this module was
22 | introduced by:
23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
25 |
26 | The key difference of the full preactivation 'v2' variant compared to the
27 | 'v1' variant in [1] is the use of batch normalization before every weight layer.
28 |
29 | Typical use:
30 |
31 | from tensorflow.contrib.slim.nets import resnet_v2
32 |
33 | ResNet-101 for image classification into 1000 classes:
34 |
35 | # inputs has shape [batch, 224, 224, 3]
36 | with slim.arg_scope(resnet_v2.resnet_arg_scope()):
37 | net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)
38 |
39 | ResNet-101 for semantic segmentation into 21 classes:
40 |
41 | # inputs has shape [batch, 513, 513, 3]
42 | with slim.arg_scope(resnet_v2.resnet_arg_scope()):
43 | net, end_points = resnet_v2.resnet_v2_101(inputs,
44 | 21,
45 | is_training=False,
46 | global_pool=False,
47 | output_stride=16)
48 | """
49 | from __future__ import absolute_import
50 | from __future__ import division
51 | from __future__ import print_function
52 |
53 | import tensorflow as tf
54 |
55 | from NET.resnet_v2 import resnet_utils
56 |
57 | slim = tf.contrib.slim
58 | resnet_arg_scope = resnet_utils.resnet_arg_scope
59 |
60 |
61 | @slim.add_arg_scope
62 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
63 | outputs_collections=None, scope=None):
64 | """Bottleneck residual unit variant with BN before convolutions.
65 |
66 | This is the full preactivation residual unit variant proposed in [2]. See
67 | Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck
68 | variant which has an extra bottleneck layer.
69 |
70 | When putting together two consecutive ResNet blocks that use this unit, one
71 | should use stride = 2 in the last unit of the first block.
72 |
73 | Args:
74 | inputs: A tensor of size [batch, height, width, channels].
75 | depth: The depth of the ResNet unit output.
76 | depth_bottleneck: The depth of the bottleneck layers.
77 | stride: The ResNet unit's stride. Determines the amount of downsampling of
78 | the units output compared to its input.
79 | rate: An integer, rate for atrous convolution.
80 | outputs_collections: Collection to add the ResNet unit output.
81 | scope: Optional variable_scope.
82 |
83 | Returns:
84 | The ResNet unit's output.
85 | """
86 | with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
87 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
88 | preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
89 | if depth == depth_in:
90 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
91 | else:
92 | shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
93 | normalizer_fn=None, activation_fn=None,
94 | scope='shortcut')
95 |
96 | residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
97 | scope='conv1')
98 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
99 | rate=rate, scope='conv2')
100 | residual = slim.conv2d(residual, depth, [1, 1], stride=1,
101 | normalizer_fn=None, activation_fn=None,
102 | scope='conv3')
103 |
104 | output = shortcut + residual
105 |
106 | return slim.utils.collect_named_outputs(outputs_collections,
107 | sc.name,
108 | output)
109 |
110 |
111 | def resnet_v2(inputs,
112 | blocks,
113 | num_classes=None,
114 | multi_grid=None,
115 | is_training=True,
116 | global_pool=True,
117 | output_stride=None,
118 | include_root_block=True,
119 | spatial_squeeze=True,
120 | reuse=None,
121 | scope=None):
122 | """Generator for v2 (preactivation) ResNet models.
123 |
124 | This function generates a family of ResNet v2 models. See the resnet_v2_*()
125 | methods for specific model instantiations, obtained by selecting different
126 | block instantiations that produce ResNets of various depths.
127 |
128 | Training for image classification on Imagenet is usually done with [224, 224]
129 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet
130 | block for the ResNets defined in [1] that have nominal stride equal to 32.
131 | However, for dense prediction tasks we advise that one uses inputs with
132 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
133 | this case the feature maps at the ResNet output will have spatial shape
134 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
135 | and corners exactly aligned with the input image corners, which greatly
136 | facilitates alignment of the features to the image. Using as input [225, 225]
137 | images results in [8, 8] feature maps at the output of the last ResNet block.
138 |
139 | For dense prediction tasks, the ResNet needs to run in fully-convolutional
140 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
141 | have nominal stride equal to 32 and a good choice in FCN mode is to use
142 | output_stride=16 in order to increase the density of the computed features at
143 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
144 |
145 | Args:
146 | inputs: A tensor of size [batch, height_in, width_in, channels].
147 | blocks: A list of length equal to the number of ResNet blocks. Each element
148 | is a resnet_utils.Block object describing the units in the block.
149 | num_classes: Number of predicted classes for classification tasks.
150 | If 0 or None, we return the features before the logit layer.
151 | is_training: whether batch_norm layers are in training mode.
152 | global_pool: If True, we perform global average pooling before computing the
153 | logits. Set to True for image classification, False for dense prediction.
154 | output_stride: If None, then the output will be computed at the nominal
155 | network stride. If output_stride is not None, it specifies the requested
156 | ratio of input to output spatial resolution.
157 | include_root_block: If True, include the initial convolution followed by
158 | max-pooling, if False excludes it. If excluded, `inputs` should be the
159 | results of an activation-less convolution.
160 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is
161 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
162 | To use this parameter, the input images must be smaller than 300x300
163 | pixels, in which case the output logit layer does not contain spatial
164 | information and can be removed.
165 | reuse: whether or not the network and its variables should be reused. To be
166 | able to reuse 'scope' must be given.
167 | scope: Optional variable_scope.
168 | multi_grid: Atrous rates applied to the units of the last block (block4).
169 |
170 | Returns:
171 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
172 | If global_pool is False, then height_out and width_out are reduced by a
173 | factor of output_stride compared to the respective height_in and width_in,
174 | else both height_out and width_out equal one. If num_classes is 0 or None,
175 | then net is the output of the last ResNet block, potentially after global
176 | average pooling. If num_classes is a non-zero integer, net contains the
177 | pre-softmax activations.
178 | end_points: A dictionary from components of the network to the corresponding
179 | activation.
180 |
181 | Raises:
182 | ValueError: If the target output_stride is not valid.
183 | """
184 | with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
185 | end_points_collection = sc.original_name_scope + '_end_points'
186 | with slim.arg_scope([slim.conv2d, bottleneck,
187 | resnet_utils.stack_blocks_dense],
188 | outputs_collections=end_points_collection):
189 |
190 | net = inputs
191 | if include_root_block:
192 | if output_stride is not None:
193 | if output_stride % 4 != 0:
194 | raise ValueError('The output_stride needs to be a multiple of 4.')
195 | output_stride /= 4
196 | # We do not include batch normalization or activation functions in
197 | # conv1 because the first ResNet unit will perform these. Cf.
198 | # Appendix of [2].
199 | with slim.arg_scope([slim.conv2d],
200 | activation_fn=None, normalizer_fn=None):
201 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
202 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
203 | net = resnet_utils.stack_blocks_dense(net, blocks, multi_grid, output_stride)
204 | # This is needed because the pre-activation variant does not have batch
205 | # normalization or activation functions in the residual unit output. See
206 | # Appendix of [2].
207 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
208 | # Convert end_points_collection into a dictionary of end_points.
209 | end_points = slim.utils.convert_collection_to_dict(
210 | end_points_collection)
211 |
212 | if global_pool:
213 | # Global average pooling.
214 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
215 | end_points['global_pool'] = net
216 | if num_classes is not None:
217 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
218 | normalizer_fn=None, scope='logits')
219 | end_points[sc.name + '/logits'] = net
220 | if spatial_squeeze:
221 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
222 | end_points[sc.name + '/spatial_squeeze'] = net
223 | end_points['predictions'] = slim.softmax(net, scope='predictions')
224 | return net, end_points
225 |
226 |
227 | resnet_v2.default_image_size = 224
228 |
229 |
230 | def resnet_v2_block(scope, base_depth, num_units, stride):
231 | """Helper function for creating a resnet_v2 bottleneck block.
232 |
233 | Args:
234 | scope: The scope of the block.
235 | base_depth: The depth of the bottleneck layer for each unit.
236 | num_units: The number of units in the block.
237 | stride: The stride of the block, implemented as a stride in the last unit.
238 | All other units have stride=1.
239 |
240 | Returns:
241 | A resnet_v2 bottleneck block.
242 | """
243 | return resnet_utils.Block(scope, bottleneck, [{
244 | 'depth': base_depth * 4,
245 | 'depth_bottleneck': base_depth,
246 | 'stride': 1
247 | }] * (num_units - 1) + [{
248 | 'depth': base_depth * 4,
249 | 'depth_bottleneck': base_depth,
250 | 'stride': stride
251 | }])
252 |
253 |
254 |
255 |
256 |
257 | def resnet_v2_50(inputs,
258 | num_classes=None,
259 | is_training=True,
260 | multi_grid=[1, 2, 4],
261 | global_pool=True,
262 | output_stride=None,
263 | spatial_squeeze=True,
264 | reuse=None,
265 | scope='resnet_v2_50'):
266 | """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
267 | blocks = [
268 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
269 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
270 | resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
271 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
272 | ]
273 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
274 | global_pool=global_pool, output_stride=output_stride, multi_grid=multi_grid,
275 | include_root_block=True, spatial_squeeze=spatial_squeeze,
276 | reuse=reuse, scope=scope)
277 |
278 |
279 | resnet_v2_50.default_image_size = resnet_v2.default_image_size
280 |
281 |
282 | def resnet_v2_101(inputs,
283 | num_classes=None,
284 | is_training=True,
285 | multi_grid=[1, 2, 4],
286 | global_pool=True,
287 | output_stride=None,
288 | spatial_squeeze=True,
289 | reuse=None,
290 | scope='resnet_v2_101'):
291 | """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
292 | blocks = [
293 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
294 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
295 | resnet_v2_block('block3', base_depth=256, num_units=23, stride=2),
296 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
297 | ]
298 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
299 | global_pool=global_pool, output_stride=output_stride, multi_grid=multi_grid,
300 | include_root_block=True, spatial_squeeze=spatial_squeeze,
301 | reuse=reuse, scope=scope)
302 |
303 |
304 | resnet_v2_101.default_image_size = resnet_v2.default_image_size
305 |
306 |
307 | def resnet_v2_152(inputs,
308 | num_classes=None,
309 | is_training=True,
310 | multi_grid=[1, 2, 4],
311 | global_pool=True,
312 | output_stride=None,
313 | spatial_squeeze=True,
314 | reuse=None,
315 | scope='resnet_v2_152'):
316 | """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
317 | blocks = [
318 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
319 | resnet_v2_block('block2', base_depth=128, num_units=8, stride=2),
320 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
321 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
322 | ]
323 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
324 | global_pool=global_pool, output_stride=output_stride, multi_grid=multi_grid,
325 | include_root_block=True, spatial_squeeze=spatial_squeeze,
326 | reuse=reuse, scope=scope)
327 |
328 |
329 | resnet_v2_152.default_image_size = resnet_v2.default_image_size
330 |
331 |
332 | def resnet_v2_200(inputs,
333 | num_classes=None,
334 | is_training=True,
335 | multi_grid=[1, 2, 4],
336 | global_pool=True,
337 | output_stride=None,
338 | spatial_squeeze=True,
339 | reuse=None,
340 | scope='resnet_v2_200'):
341 | """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
342 | blocks = [
343 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
344 | resnet_v2_block('block2', base_depth=128, num_units=24, stride=2),
345 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
346 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
347 | ]
348 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
349 | global_pool=global_pool, output_stride=output_stride, multi_grid=multi_grid,
350 | include_root_block=True, spatial_squeeze=spatial_squeeze,
351 | reuse=reuse, scope=scope)
352 |
353 |
354 | resnet_v2_200.default_image_size = resnet_v2.default_image_size
355 |
--------------------------------------------------------------------------------
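A minimal usage sketch for the FCN mode described above (assumes TensorFlow 1.x with tf.contrib and running from the repo root; it mirrors the "Typical use" example in the docstring):

    import tensorflow as tf
    from NET.resnet_v2 import resnet_v2

    slim = tf.contrib.slim

    # Dense-prediction backbone: no global pooling, no logits, output_stride=16.
    inputs = tf.placeholder(tf.float32, [None, 513, 513, 3])
    with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training=False)):
        net, end_points = resnet_v2.resnet_v2_50(inputs,
                                                 num_classes=None,
                                                 is_training=False,
                                                 global_pool=False,
                                                 output_stride=16)
    print(net.get_shape())  # (?, 33, 33, 2048): (513 - 1) / 16 + 1 = 33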
/NET/resnet_v2_psp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2_psp/__init__.py
--------------------------------------------------------------------------------
/NET/resnet_v2_psp/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2_psp/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/resnet_v2_psp/__pycache__/resnet_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2_psp/__pycache__/resnet_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/resnet_v2_psp/__pycache__/resnet_v2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/resnet_v2_psp/__pycache__/resnet_v2.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/resnet_v2_psp/resnet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains building blocks for various versions of Residual Networks.
16 |
17 | Residual networks (ResNets) were proposed in:
18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
20 |
21 | More variants were introduced in:
22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
24 |
25 | We can obtain different ResNet variants by changing the network depth, width,
26 | and form of residual unit. This module implements the infrastructure for
27 | building them. Concrete ResNet units and full ResNet networks are implemented in
28 | the accompanying resnet_v1.py and resnet_v2.py modules.
29 |
30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
31 | implementation we subsample the output activations in the last residual unit of
32 | each block, instead of subsampling the input activations in the first residual
33 | unit of each block. The two implementations give identical results but our
34 | implementation is more memory efficient.
35 | """
36 | from __future__ import absolute_import
37 | from __future__ import division
38 | from __future__ import print_function
39 |
40 | import collections
41 | import tensorflow as tf
42 |
43 | slim = tf.contrib.slim
44 |
45 |
46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
47 | """A named tuple describing a ResNet block.
48 |
49 | Its parts are:
50 | scope: The scope of the `Block`.
51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and
52 | returns another `Tensor` with the output of the ResNet unit.
53 | args: A list of length equal to the number of units in the `Block`. The list
54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the
55 | block to serve as argument to unit_fn.
56 | """
57 |
58 |
59 | def subsample(inputs, factor, scope=None):
60 | """Subsamples the input along the spatial dimensions.
61 |
62 | Args:
63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels].
64 | factor: The subsampling factor.
65 | scope: Optional variable_scope.
66 |
67 | Returns:
68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the
69 | input, either intact (if factor == 1) or subsampled (if factor > 1).
70 | """
71 | if factor == 1:
72 | return inputs
73 | else:
74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
75 |
76 |
77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
78 | """Strided 2-D convolution with 'SAME' padding.
79 |
80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with
81 | 'VALID' padding.
82 |
83 | Note that
84 |
85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride)
86 |
87 | is equivalent to
88 |
89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
90 | net = subsample(net, factor=stride)
91 |
92 | whereas
93 |
94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
95 |
96 | is different when the input's height or width is even, which is why we add the
97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
98 |
99 | Args:
100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
101 | num_outputs: An integer, the number of output filters.
102 | kernel_size: An int with the kernel_size of the filters.
103 | stride: An integer, the output stride.
104 | rate: An integer, rate for atrous convolution.
105 | scope: Scope.
106 |
107 | Returns:
108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with
109 | the convolution output.
110 | """
111 | if stride == 1:
112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
113 | padding='SAME', scope=scope)
114 | else:
115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
116 | pad_total = kernel_size_effective - 1
117 | pad_beg = pad_total // 2
118 | pad_end = pad_total - pad_beg
119 | inputs = tf.pad(inputs,
120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
122 | rate=rate, padding='VALID', scope=scope)
123 |
124 |
125 | @slim.add_arg_scope
126 | def stack_blocks_dense(net, blocks, multi_grid, output_stride=None,
127 | outputs_collections=None):
128 | """Stacks ResNet `Blocks` and controls output feature density.
129 |
130 | First, this function creates scopes for the ResNet in the form of
131 | 'block_name/unit_1', 'block_name/unit_2', etc.
132 |
133 | Second, this function allows the user to explicitly control the ResNet
134 | output_stride, which is the ratio of the input to output spatial resolution.
135 | This is useful for dense prediction tasks such as semantic segmentation or
136 | object detection.
137 |
138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a
139 | factor of 2 when transitioning between consecutive ResNet blocks. This results
140 | in a nominal ResNet output_stride equal to 8. If we set the output_stride to
141 | half the nominal network stride (e.g., output_stride=4), then we compute
142 | responses twice.
143 |
144 | Control of the output feature density is implemented by atrous convolution.
145 |
146 | Args:
147 | net: A `Tensor` of size [batch, height, width, channels].
148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each
149 | element is a ResNet `Block` object describing the units in the `Block`.
150 | output_stride: If `None`, then the output will be computed at the nominal
151 | network stride. If output_stride is not `None`, it specifies the requested
152 | ratio of input to output spatial resolution, which needs to be equal to
153 | the product of unit strides from the start up to some level of the ResNet.
154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which
156 | is equivalent to output_stride=24).
157 | outputs_collections: Collection to add the ResNet block outputs.
158 | multi_grid: Accepted for interface compatibility; this PSP variant hard-codes the atrous rates (1, 2 and 4 for block2/3/4) instead of using it.
159 | Returns:
160 | net: Output tensor with stride equal to the specified output_stride.
161 |
162 | Raises:
163 | ValueError: If the target output_stride is not valid.
164 | """
165 | # The current_stride variable keeps track of the effective stride of the
166 | # activations. This allows us to invoke atrous convolution whenever applying
167 | # the next residual unit would result in the activations having stride larger
168 | # than the target output_stride.
169 | current_stride = 1
170 |
171 |
172 | for block in blocks:
173 | with tf.variable_scope(block.scope, 'block', [net]) as sc:
174 | for i, unit in enumerate(block.args):
175 | if output_stride is not None and current_stride > output_stride:
176 | raise ValueError('The target output_stride cannot be reached.')
177 |
178 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
179 | # If we have reached the target output_stride, then we need to employ
180 | # atrous convolution with stride=1. Unlike the resnet_v2 copy, this PSP
181 | # variant hard-codes the rate per block (block2: 1, block3: 2, block4: 4).
182 | if output_stride is not None and current_stride == output_stride:
183 | if block.scope == "block2":
184 | net = block.unit_fn(net, rate=1, **dict(unit, stride=1))
185 | elif block.scope == "block3":
186 | net = block.unit_fn(net, rate=2, **dict(unit, stride=1))
187 | else:
188 | net = block.unit_fn(net, rate=4, **dict(unit, stride=1))
189 | else:
190 | net = block.unit_fn(net, rate=1, **unit)
191 | current_stride *= unit.get('stride', 1)
192 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
193 |
194 | if output_stride is not None and current_stride != output_stride:
195 | raise ValueError('The target output_stride cannot be reached.')
196 |
197 | return net
198 |
199 |
200 | def resnet_arg_scope(weight_decay=0.0001,
201 | is_training=True,
202 | batch_norm_decay=0.997,
203 | batch_norm_epsilon=1e-5,
204 | batch_norm_scale=True,
205 | activation_fn=tf.nn.relu,
206 | use_batch_norm=True):
207 | """Defines the default ResNet arg scope.
208 |
209 | TODO(gpapan): The batch-normalization related default values above are
210 | appropriate for use in conjunction with the reference ResNet models
211 | released at https://github.com/KaimingHe/deep-residual-networks. When
212 | training ResNets from scratch, they might need to be tuned.
213 |
214 | Args:
215 | weight_decay: The weight decay to use for regularizing the model.
216 | batch_norm_decay: The moving average decay when estimating layer activation
217 | statistics in batch normalization.
218 | batch_norm_epsilon: Small constant to prevent division by zero when
219 | normalizing activations by their variance in batch normalization.
220 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
221 | activations in the batch normalization layer.
222 | activation_fn: The activation function which is used in ResNet.
223 | use_batch_norm: Whether or not to use batch normalization.
224 | is_training: Whether the batch_norm layers are in training mode.
225 | Returns:
226 | An `arg_scope` to use for the resnet models.
227 | """
228 | batch_norm_params = {
229 | 'decay': batch_norm_decay,
230 | 'epsilon': batch_norm_epsilon,
231 | 'scale': batch_norm_scale,
232 | 'updates_collections': None,
233 | 'is_training': is_training,
234 | 'fused': True, # Use fused batch norm if possible.
235 | }
236 |
237 | with slim.arg_scope(
238 | [slim.conv2d],
239 | weights_regularizer=slim.l2_regularizer(weight_decay),
240 | weights_initializer=slim.variance_scaling_initializer(),
241 | activation_fn=activation_fn,
242 | normalizer_fn=slim.batch_norm if use_batch_norm else None,
243 | normalizer_params=batch_norm_params):
244 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
245 | # The following implies padding='SAME' for pool1, which makes feature
246 | # alignment easier for dense prediction tasks. This is also used in
247 | # https://github.com/facebook/fb.resnet.torch. However the accompanying
248 | # code of 'Deep Residual Learning for Image Recognition' uses
249 | # padding='VALID' for pool1. You can switch to that choice by setting
250 | # slim.arg_scope([slim.max_pool2d], padding='VALID').
251 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
252 | return arg_sc
253 |
--------------------------------------------------------------------------------
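The only functional change relative to NET/resnet_v2/resnet_utils.py is in stack_blocks_dense: instead of deriving the atrous rate from the skipped strides, this PSP copy hard-codes one rate per block. An illustrative sketch (not repo code) of the schedule this implies when the backbone runs at output_stride=8, i.e. stack_blocks_dense receives 8 / 4 = 2 after the root conv + pool:

    schedule = {
        'block1': dict(stride=2, rate=1),  # runs normally; current_stride reaches 2
        'block2': dict(stride=1, rate=1),  # converted to atrous, hard-coded rate 1
        'block3': dict(stride=1, rate=2),  # converted to atrous, hard-coded rate 2
        'block4': dict(stride=1, rate=4),  # converted to atrous, hard-coded rate 4
    }
    for scope, cfg in schedule.items():
        print(scope, cfg)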
/NET/self_attention_layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/self_attention_layers/__init__.py
--------------------------------------------------------------------------------
/NET/self_attention_layers/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/self_attention_layers/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/self_attention_layers/__pycache__/self_attention_layers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/NET/self_attention_layers/__pycache__/self_attention_layers.cpython-36.pyc
--------------------------------------------------------------------------------
/NET/self_attention_layers/self_attention_layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | slim = tf.contrib.slim
3 |
4 | def position_attention_module(feature, wights=1):
5 | """
6 | Position Attention Module (PAM, cf. DANet); `wights` scales the attention branch before the residual add.
7 | :param feature: tensor of shape [BATCH_SIZE, WIDTH, HEIGHT, DEPTH]
8 | :return: a tensor with the same shape as `feature`.
9 | """
10 | BATCH_SIZE, WIDTH, HEIGHT, DEPTH = feature.get_shape().as_list()
11 | with tf.variable_scope("position_module"):
12 | value = slim.conv2d(feature, DEPTH, [1, 1], activation_fn=None, scope="value")
13 | value = tf.reshape(value, [-1, WIDTH*HEIGHT, DEPTH])
14 | value = tf.transpose(value, [0, 2, 1])
15 |
16 | query = slim.conv2d(feature, DEPTH, [1, 1], activation_fn=None, scope="query")
17 | key = slim.conv2d(feature, DEPTH, [1, 1], activation_fn=None, scope="key")
18 | query = tf.reshape(query, [-1, WIDTH*HEIGHT, DEPTH])
19 | key = tf.reshape(key, [-1, WIDTH*HEIGHT, DEPTH])
20 | query = tf.transpose(query, [0, 2, 1])
21 | mul_end = tf.matmul(key, query) # shape[batch_size, WIDTH*HEIGHT, WIDTH*HEIGHT]
22 | s = tf.nn.softmax(mul_end, dim=1) # shape[batch_size, WIDTH*HEIGHT, WIDTH*HEIGHT]
23 | s = tf.transpose(s, [0, 2, 1])
24 |
25 | position_ends = tf.matmul(value, s)
26 |
27 | position_ends = tf.transpose(position_ends, [0, 2, 1])
28 | position_ends = tf.reshape(position_ends, [-1, WIDTH, HEIGHT, DEPTH])
29 | result = wights * position_ends + feature
30 | return result
31 | def chanel_attention_module(feature, wights=1):
32 | """
33 | Channel Attention Module (CAM, cf. DANet); `wights` scales the attention branch before the residual add.
34 | :param feature: tensor of shape [BATCH_SIZE, WIDTH, HEIGHT, DEPTH]
35 | :return: a tensor with the same shape as `feature`.
36 | """
37 | BATCH_SIZE, WIDTH, HEIGHT, DEPTH = feature.get_shape().as_list()
38 | with tf.variable_scope("chanel_module"):
39 | value = tf.reshape(feature, [-1, WIDTH*HEIGHT, DEPTH])
40 | value = tf.transpose(value, [0, 2, 1])
41 |
42 | query = tf.reshape(feature, [-1, WIDTH*HEIGHT, DEPTH])
43 | key = tf.reshape(feature, [-1, WIDTH*HEIGHT, DEPTH])
44 | query = tf.transpose(query, [0, 2, 1])
45 | mul_end = tf.matmul(query, key) # shape[batch_size, DEPTH, DEPTH]
46 | s = tf.nn.softmax(mul_end, dim=1) # shape[batch_size, DEPTH, DEPTH]
47 | s = tf.transpose(s, [0, 2, 1])
48 |
49 | position_ends = tf.matmul(s, value) # shape[batch_size, DEPTH, WIDTH*HEIGHT]
50 |
51 | position_ends = tf.transpose(position_ends, [0, 2, 1])
52 | position_ends = tf.reshape(position_ends, [-1, WIDTH, HEIGHT, DEPTH])
53 | result = wights * position_ends + feature
54 | return result
55 | if __name__ == "__main__":
56 | b = tf.ones(shape=[1, 50, 50, 3])
57 | x = position_attention_module(b)
58 | init_op = tf.group(
59 | tf.local_variables_initializer(),
60 | tf.global_variables_initializer()
61 | )
62 | with tf.Session() as sess:
63 | sess.run(init_op)
64 | z = sess.run(x)
65 | print(z)
66 |
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
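The transposes in these modules are easy to lose track of. The following NumPy shape trace of chanel_attention_module is an illustrative sketch with plain arrays (not TF ops); it shows how a [B, W, H, D] feature map yields a [B, D, D] channel-affinity matrix that is then mapped back to the input shape:

    import numpy as np

    B, W, H, D = 2, 5, 5, 3
    feature = np.random.rand(B, W, H, D).astype(np.float32)

    flat = feature.reshape(B, W * H, D)      # value/query/key before the transposes
    query = flat.transpose(0, 2, 1)          # [B, D, W*H]
    energy = query @ flat                    # [B, D, D] channel-to-channel affinities
    s = np.exp(energy) / np.exp(energy).sum(axis=1, keepdims=True)  # softmax over dim=1
    s = s.transpose(0, 2, 1)                 # [B, D, D]
    value = flat.transpose(0, 2, 1)          # [B, D, W*H]
    out = (s @ value).transpose(0, 2, 1).reshape(B, W, H, D)
    print(out.shape)                         # (2, 5, 5, 3), same as the input feature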
/README.md:
--------------------------------------------------------------------------------
1 | # semantic_segmentation_contest_deeplabv3
2 | Semantic segmentation track of the remote-sensing image sparse representation and intelligent analysis contest. The contest went by quickly; since this was our first
3 | contest in our own research direction, we mainly used it to learn and consolidate semantic segmentation knowledge, and to gain
4 | experience for later contests. All of the code has been cleaned up and uploaded.
5 | Final result: 27th place (using only deeplabv3 + dropout).
6 | ## Results
7 |
8 | | |Method | OS | kappa |
9 | |:-----:|:------------------------------------:|:---:|:----------:|
10 | | repo | MG(1,2,4)+ASPP(6,12,18)+Image Pooling|16 | **50.662%** |
11 |
12 | Sample predictions (see resource/1.png and resource/2.png):
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | Summary:
21 | Step 1: Build the dataset (1000*1000 tiles). (We ran into a problem here: the RGB data read with cv2 gave our best score, but cv2 cannot read the infrared channel. When we switched to libtiff,
22 | we found that its RGB values differ from cv2's, and the results trained on libtiff RGB were lower than those on cv2 RGB. In the end, for lack of time, we used only the cv2 RGB data.)
23 | Step 2: Choose the network. We started with unet, which did not work well, then switched to deeplabv3, and later also tried v3+, the DA (dual attention) modules, pspnet, aaf, and so on.
24 | None of them did better; deeplabv3 gave the best results. We then studied the outputs to find problems and possible fixes.
25 | Step 3: Testing. We slide a 1000*1000 window over each image and keep only the central 400*400 of every prediction.
26 | Preprocessing the images with rolling guidance filtering was planned but not done for lack of time.
--------------------------------------------------------------------------------
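A sketch of the Step 3 tiling arithmetic (illustrative helper, not repo code; it mirrors the geometry used in test1000_stride_400.py below: 1000-pixel windows, stride 400, a 300-pixel zero border, and only the central 400*400 of each prediction kept):

    def tile_origins(h, w, window=1000, stride=400, border=300):
        """Origins of the sliding-window crops in the zero-padded image."""
        ph, pw = h + 2 * border, w + 2 * border
        rows = ((ph - window) // stride) + 1
        cols = ((pw - window) // stride) + 1
        return [(i * stride, j * stride) for i in range(rows) for j in range(cols)]

    # The kept 400x400 centers tile the image exactly when h and w are
    # multiples of the 400-pixel stride, e.g. a 4000x4000 image:
    print(len(tile_origins(4000, 4000)))  # 100 crops (10 x 10)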
/eval.py:
--------------------------------------------------------------------------------
1 | from utils import dataset_util
2 | from DataGenerate.GetDataset import eval_or_test_input_fn
3 | import tensorflow as tf
4 | import os
5 | import argparse
6 | import sys
7 | evaluation_data_list = "./VOC2012_AUG/txt/val.txt"
8 | parser = argparse.ArgumentParser()
9 |
10 | # add arguments
11 | envarg = parser.add_argument_group('Training params')
12 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
13 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
14 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
15 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
16 | envarg.add_argument('--starting_learning_rate', type=float, default=0.00001, help="initial learning rate.")
17 | envarg.add_argument('--learning_rate', type=float, default=0.00001, help="initial learning rate.")
18 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Spatial Pyramid Pooling rates")
19 | envarg.add_argument("--output_stride", type=int, default=16, help="Spatial Pyramid Pooling rates")
20 | envarg.add_argument("--gpu_id", type=int, default=0, help="Id of the GPU to be used")
21 | envarg.add_argument("--crop_size", type=int, default=513, help="Image Cropsize.")
22 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of: resnet_v2_50 or resnet_v2_101")
23 | envarg.add_argument('--learning_power', type=float, default=0.9, help='power of the "poly" learning rate policy.')
24 | envarg.add_argument("--current_best_val_loss", type=int, default=99999, help="Best validation loss value.")
25 | envarg.add_argument("--accumulated_validation_miou", type=int, default=0, help="Accumulated validation intersection over union.")
26 | args = parser.parse_args()
27 | def main():
28 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '0'
29 |
30 | # 获取数据
31 | examples = dataset_util.read_examples_list(FLAGS.evaluation_data_list)
32 | image_files = [os.path.join(FLAGS.image_data_dir, filename) + '.jpg' for filename in examples]
33 | label_files = [os.path.join(FLAGS.label_data_dir, filename) + '.png' for filename in examples]
34 |
35 | features, labels = eval_or_test_input_fn.eval_input_fn(image_files, label_files)
36 |
37 | # Manually load the latest checkpoint
38 | saver = tf.train.Saver()
39 | with tf.Session() as sess:
40 | ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
41 | saver.restore(sess, ckpt.model_checkpoint_path)
42 |
43 |
44 |
45 | if __name__ == '__main__':
46 | tf.logging.set_verbosity(tf.logging.INFO)
47 | FLAGS, unparsed = parser.parse_known_args()
48 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
--------------------------------------------------------------------------------
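eval.py stops right after restoring the checkpoint; the metric step is missing. Since kappa is the contest metric reported in the README, a minimal sketch of computing it from an accumulated confusion matrix could look like this (illustrative, assuming NumPy arrays; not part of the repo):

    import numpy as np

    def kappa_from_confusion(cm):
        """Cohen's kappa from a [num_classes, num_classes] confusion matrix."""
        cm = cm.astype(np.float64)
        total = cm.sum()
        po = np.trace(cm) / total                        # observed agreement
        pe = (cm.sum(0) * cm.sum(1)).sum() / total ** 2  # chance agreement
        return (po - pe) / (1.0 - pe)

    # usage: accumulate a confusion matrix over the validation set, then:
    cm = np.array([[50, 2], [3, 45]])
    print(kappa_from_confusion(cm))  # ~0.90 for this toy 2-class matrix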
/resource/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/resource/1.png
--------------------------------------------------------------------------------
/resource/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/resource/2.png
--------------------------------------------------------------------------------
/resource/语义分割比赛进展.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/resource/语义分割比赛进展.pptx
--------------------------------------------------------------------------------
/test1000_stride_400.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import argparse
5 | import tensorflow as tf
6 | from NET import deeplab_v3
7 | from utils.preprocessing import decode_labels
8 |
9 | checkpoint_path = "./checkpoint_1000/"
10 | image_size = 1000
11 | stride = 400
12 | def predict(TEST_SET, sess, prediction, imgs_batch):
13 | for n in range(len(TEST_SET)):
14 |
15 | path = TEST_SET[n] # load the image
16 | image = cv2.imread('/2T/tzj/semantic_segmentation_contest/DatasetNew/test/' + path, cv2.IMREAD_UNCHANGED)
17 | h, w, channel = image.shape
18 |
19 | result = np.zeros((h, w, 3), dtype=np.uint8)
20 |
21 | padding_h = h + 600
22 | padding_w = w + 600
23 | padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
24 | padding_img[300:300 + h, 300:300 + w, :] = image[:, :, 0:3]
25 | mask_whole = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
26 | for i in range(((padding_h - 1000) // stride) + 1):
27 | for j in range(((padding_w - 1000) // stride) + 1):
28 | crop = padding_img[i * stride:i * stride + image_size, j * stride:j * stride + image_size, :]
29 | ch, cw, _ = crop.shape
30 | if ch != 1000 or cw != 1000:
31 | print('invalid size!')
32 | continue
33 | img_batch = np.expand_dims(crop, axis=0)
34 | pred = sess.run(prediction, feed_dict={imgs_batch: img_batch})
35 | pred = decode_labels(pred)
36 | mask_whole[300 + i * stride: 300 + i * stride + stride, 300 + j * stride: 300 + j * stride + stride, :] = pred[0][300:700, 300:700, :]
37 | print("i:%d" % i, "j:%d" % j)
38 | result[0:h, 0:w, :] = mask_whole[300:300 + h, 300:300 + w, :]
39 | result_bgr = result[..., ::-1]
40 | cv2.imwrite("./Result1000_stride_400/" + path.split(".")[0] + "_label.tif", result_bgr)
41 | print("./Result_stride_400/" + path.split(".")[0] + "_label.tif saved")
42 |
43 | #cv2.imwrite('./predict/pre' + str(n + 1) + '.png', mask_whole[0:h, 0:w])
44 | # collect all of the test-set images
45 | TEST_SET = os.listdir('/2T/tzj/semantic_segmentation_contest/DatasetNew/test')
46 |
47 | parser = argparse.ArgumentParser()
48 |
49 | # add arguments
50 | envarg = parser.add_argument_group('Training params')
51 | # BN params
52 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
53 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
54 | envarg.add_argument('--freeze_batch_norm', type=bool, default=False, help='Freeze batch normalization parameters during the training.')
55 | # the number of classes
56 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
57 |
58 | # regularizer
59 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
60 |
61 | # for deeplabv3
62 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Spatial Pyramid Pooling rates")
63 | envarg.add_argument("--output_stride", type=int, default=16, help="Spatial Pyramid Pooling rates")
64 |
65 | # the base network
66 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of: resnet_v2_50 or resnet_v2_101")
67 |
68 | # the pre_trained model for example resnet50 101 and so on
69 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
70 | help='Path to the pre-trained model checkpoint.')
71 |
72 | # max number of batch elements to tensorboard
73 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=6,
74 | help='Max number of batch elements to generate for Tensorboard.')
75 | # poly learn_rate
76 | parser.add_argument('--initial_learning_rate', type=float, default=7e-4,
77 | help='Initial learning rate for the optimizer.')
78 |
79 | parser.add_argument('--end_learning_rate', type=float, default=1e-6,
80 | help='End learning rate for the optimizer.')
81 |
82 | parser.add_argument('--initial_global_step', type=int, default=0,
83 | help='Initial global step for controlling learning rate when fine-tuning model.')
84 | parser.add_argument('--max_iter', type=int, default=25000,
85 | help='Number of maximum iteration used for "poly" learning rate policy.')
86 | args = parser.parse_args()
87 |
88 | img_batch = tf.placeholder("float32", shape=[None, 1000, 1000, 3], name="img_batch")
89 | # build the network and the prediction op
90 |
91 | logits = deeplab_v3.deeplab_v3(img_batch, args, is_training=False, reuse=False)
92 | prediction = tf.argmax(logits, axis=3)
93 | prediction = tf.expand_dims(prediction, axis=3)
94 |
95 | init_op = tf.group(
96 | tf.local_variables_initializer(),
97 | tf.global_variables_initializer()
98 | )
99 | saver = tf.train.Saver()
100 | # run the graph
101 | config = tf.ConfigProto()
102 | config.gpu_options.allow_growth = True
103 | with tf.Session(config=config) as sess:
104 | sess.run(init_op)
105 | # restore the weights
106 | ckpt = tf.train.get_checkpoint_state(checkpoint_path)
107 | if ckpt and ckpt.model_checkpoint_path:
108 | saver.restore(sess, checkpoint_path + "model.ckpt-26")
109 | print("restored!!!")
110 | predict(TEST_SET=TEST_SET, sess=sess, prediction=prediction, imgs_batch=img_batch)
--------------------------------------------------------------------------------
/test_stride_400.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import argparse
5 | import tensorflow as tf
6 | from NET import deeplab_v3
7 | from utils.preprocessing import decode_labels
8 |
9 | checkpoint_path = "./checkpoint_1000/"
10 | image_size = 500
11 | stride = 400
12 | def predict(TEST_SET, sess, prediction, imgs_batch):
13 | for n in range(len(TEST_SET)):
14 |
15 | path = TEST_SET[n] # load the image
16 | image = cv2.imread('/2T/tzj/semantic_segmentation_contest/DatasetOrigin/val/' + path, cv2.IMREAD_UNCHANGED)
17 | h, w, channel = image.shape
18 |
19 | result = np.zeros((h, w, 3), dtype=np.uint8)
20 |
21 | padding_h = h + 100
22 | padding_w = w + 100
23 | padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
24 | padding_img[50:50 + h, 50:50 + w, :] = image[:, :, 0:3]
25 | mask_whole = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
26 | for i in range(padding_h // stride):
27 | for j in range(padding_w // stride):
28 | crop = padding_img[i * stride:i * stride + image_size, j * stride:j * stride + image_size, :]
29 | ch, cw, _ = crop.shape
30 | if ch != 500 or cw != 500:
31 | print('invalid size!')
32 | continue
33 | img_batch = np.expand_dims(crop, axis=0)
34 | pred = sess.run(prediction, feed_dict={imgs_batch: img_batch})
35 | pred = decode_labels(pred)
36 | mask_whole[50 + i * stride: 50 + i * stride + stride, 50 + j * stride: 50 + j * stride + stride, :] = pred[0][50:450, 50:450, :]
37 | print("i:%d" % i, "j:%d" % j)
38 | result[0:h, 0:w, :] = mask_whole[50:50 + h, 50:50 + w, :]
39 | result_bgr = result[..., ::-1]
40 | cv2.imwrite("./Result_stride_400/" + path.split(".")[0] + "_label.tif", result_bgr)
41 | print("./Result_stride_400/" + path.split(".")[0] + "_label.tif saved")
42 |
43 | #cv2.imwrite('./predict/pre' + str(n + 1) + '.png', mask_whole[0:h, 0:w])
44 | # collect all of the test-set images
45 | TEST_SET = os.listdir('/2T/tzj/semantic_segmentation_contest/DatasetOrigin/val')
46 |
47 | parser = argparse.ArgumentParser()
48 |
49 | # add arguments
50 | envarg = parser.add_argument_group('Training params')
51 | # BN params
52 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
53 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
54 | envarg.add_argument('--freeze_batch_norm', type=bool, default=False, help='Freeze batch normalization parameters during the training.')
55 | # the number of classes
56 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
57 |
58 | # regularizer
59 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
60 |
61 | # for deeplabv3
62 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Spatial Pyramid Pooling rates")
63 | envarg.add_argument("--output_stride", type=int, default=16, help="Spatial Pyramid Pooling rates")
64 |
65 | # the base network
66 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of: resnet_v2_50 or resnet_v2_101")
67 |
68 | # the pre_trained model for example resnet50 101 and so on
69 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
70 | help='Path to the pre-trained model checkpoint.')
71 |
72 | # max number of batch elements to tensorboard
73 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=6,
74 | help='Max number of batch elements to generate for Tensorboard.')
75 | # poly learn_rate
76 | parser.add_argument('--initial_learning_rate', type=float, default=7e-4,
77 | help='Initial learning rate for the optimizer.')
78 |
79 | parser.add_argument('--end_learning_rate', type=float, default=1e-6,
80 | help='End learning rate for the optimizer.')
81 |
82 | parser.add_argument('--initial_global_step', type=int, default=0,
83 | help='Initial global step for controlling learning rate when fine-tuning model.')
84 | parser.add_argument('--max_iter', type=int, default=25000,
85 | help='Number of maximum iteration used for "poly" learning rate policy.')
86 | args = parser.parse_args()
87 |
88 | img_batch = tf.placeholder("float32", shape=[None, 500, 500, 3], name="img_batch")
89 | # build the network and the prediction op
90 |
91 | logits = deeplab_v3.deeplab_v3(img_batch, args, is_training=False, reuse=False)
92 | prediction = tf.argmax(logits, axis=3)
93 | prediction = tf.expand_dims(prediction, axis=3)
94 |
95 | init_op = tf.group(
96 | tf.local_variables_initializer(),
97 | tf.global_variables_initializer()
98 | )
99 | saver = tf.train.Saver()
100 | # run the graph
101 | config = tf.ConfigProto()
102 | config.gpu_options.allow_growth = True
103 | with tf.Session(config=config) as sess:
104 | sess.run(init_op)
105 | # restore the weights
106 | ckpt = tf.train.get_checkpoint_state(checkpoint_path)
107 | if ckpt and ckpt.model_checkpoint_path:
108 | saver.restore(sess, checkpoint_path + "model.ckpt-10")
109 | predict(TEST_SET=TEST_SET, sess=sess, prediction=prediction, imgs_batch=img_batch)
--------------------------------------------------------------------------------
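The two test scripts share the same sliding-window logic and differ only in window geometry. A small illustrative summary (not repo code); in both, the kept center must equal the stride so that the crops tile the padded image:

    CONFIGS = {
        'test_stride_400.py': dict(image_size=500, stride=400, border=50, keep=(50, 450)),
        'test1000_stride_400.py': dict(image_size=1000, stride=400, border=300, keep=(300, 700)),
    }
    for name, cfg in CONFIGS.items():
        kept = cfg['keep'][1] - cfg['keep'][0]
        assert kept == cfg['stride'], name  # center-crop width == stride
        print(name, cfg)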
/tools_aaf.py:
--------------------------------------------------------------------------------
1 | from NET import deeplab_v3
2 | from utils import preprocessing
3 | import tensorflow as tf
4 | import numpy as np
5 | import NET.aaf.layers as nnx
6 | import NET.aaf.losses as lossx  # affinity_loss is defined in NET/aaf/losses.py
7 | _WEIGHT_DECAY = 5e-4
8 |
9 |
10 | def get_loss_pre_metrics(x, y, is_training, batch_size, args):
11 | # recover the images (as uint8) for the summary
12 | images = tf.cast(x, tf.uint8)
13 |
14 | # forward pass
15 | logits = tf.cond(is_training, true_fn=lambda: deeplab_v3.deeplab_v3(x, args, is_training=True, reuse=False),
16 | false_fn=lambda: deeplab_v3.deeplab_v3(x, args, is_training=False, reuse=True))
17 | pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)
18 |
19 | # decode the predictions
20 | pred_decoded_labels = tf.py_func(preprocessing.decode_labels, [pred_classes, batch_size, args.number_of_classes], tf.uint8)
21 |
22 |
23 |     # Decode ground-truth labels into color label images
24 | gt_decoded_labels = tf.py_func(preprocessing.decode_labels, [y, batch_size, args.number_of_classes], tf.uint8)
25 |
26 |
27 | tf.summary.image('images', tf.concat(axis=2, values=[images, gt_decoded_labels, pred_decoded_labels]),
28 | max_outputs=args.tensorboard_images_max_outputs)
29 |
30 |     # Compute the loss
31 | labels = tf.squeeze(y, axis=3) # reduce the channel dimension.
32 | logits_by_num_classes = tf.reshape(logits, [-1, args.number_of_classes])
33 | labels_flat = tf.reshape(labels, [-1, ])
34 |
35 | cross_entropy = tf.losses.sparse_softmax_cross_entropy(
36 | logits=logits_by_num_classes, labels=labels_flat)
37 |
38 | if not args.freeze_batch_norm:
39 | train_var_list = [v for v in tf.trainable_variables()]
40 | else:
41 | train_var_list = [v for v in tf.trainable_variables()
42 | if 'beta' not in v.name and 'gamma' not in v.name]
43 |
44 | global_step = tf.train.get_or_create_global_step()
45 | with tf.variable_scope("total_loss"):
46 | loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
47 | [tf.nn.l2_loss(v) for v in train_var_list])
48 | # affinity loss
49 | edge_loss, not_edge_loss = affinity_loss(labels=y, probs=logits,
50 | num_classes=args.number_of_classes,
51 | kld_margin=args.kld_margin)
52 |
53 |         dec = tf.pow(10.0, tf.cast(-global_step / args.max_iter, tf.float32))  # anneals the affinity terms from 1.0 toward 0.1
54 | aff_loss = tf.reduce_mean(edge_loss) * args.kld_lambda_1 * dec
55 | aff_loss += tf.reduce_mean(not_edge_loss) * args.kld_lambda_2 * dec
56 |
57 | total_loss = loss + aff_loss
58 | tf.summary.scalar('loss', total_loss)
59 |     # Optimizer with polynomial ("poly") learning-rate decay
60 | learning_rate = tf.train.polynomial_decay(
61 | args.initial_learning_rate,
62 | tf.cast(global_step, tf.int32) - args.initial_global_step,
63 |             args.max_iter, args.end_learning_rate, power=0.9)  # decays to end_learning_rate over max_iter steps
64 | tf.summary.scalar('learning_rate', learning_rate)
65 |
66 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
67 |
68 | # Batch norm requires update ops to be added as a dependency to the train_op
69 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
70 | with tf.control_dependencies(update_ops):
71 | train_op = optimizer.minimize(total_loss, global_step, var_list=train_var_list)
72 |
73 | # metrics
74 | preds_flat = tf.reshape(pred_classes, [-1, ])
75 | confusion_matrix = tf.confusion_matrix(labels_flat, preds_flat, num_classes=args.number_of_classes)
76 |
77 | correct_pred = tf.equal(preds_flat, labels_flat)
78 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
79 | tf.summary.scalar('accuracy', accuracy)
80 |
81 | def compute_mean_iou(total_cm, name='mean_iou'):
82 | """Compute the mean intersection-over-union via the confusion matrix."""
83 | sum_over_row = tf.to_float(tf.reduce_sum(total_cm, 0))
84 | sum_over_col = tf.to_float(tf.reduce_sum(total_cm, 1))
85 | cm_diag = tf.to_float(tf.diag_part(total_cm))
86 | denominator = sum_over_row + sum_over_col - cm_diag
87 |
88 | # The mean is only computed over classes that appear in the
89 | # label or prediction tensor. If the denominator is 0, we need to
90 | # ignore the class.
91 | num_valid_entries = tf.reduce_sum(tf.cast(
92 | tf.not_equal(denominator, 0), dtype=tf.float32))
93 |
94 | # If the value of the denominator is 0, set it to 1 to avoid
95 | # zero division.
96 | denominator = tf.where(
97 | tf.greater(denominator, 0),
98 | denominator,
99 | tf.ones_like(denominator))
100 | iou = tf.div(cm_diag, denominator)
101 |
102 | for i in range(args.number_of_classes):
103 | tf.identity(iou[i], name='train_iou_class{}'.format(i))
104 | tf.summary.scalar('train_iou_class{}'.format(i), iou[i])
105 |
106 | # If the number of valid entries is 0 (no classes) we return 0.
107 | result = tf.where(
108 | tf.greater(num_valid_entries, 0),
109 | tf.reduce_sum(iou, name=name) / num_valid_entries,
110 | 0)
111 | return result
112 |
113 | mean_iou = compute_mean_iou(confusion_matrix)
114 |
115 | tf.summary.scalar('mean_iou', mean_iou)
116 |
117 | metrics = {'px_accuracy': accuracy, 'mean_iou': mean_iou, 'confusion_matrix': confusion_matrix}
118 |
119 | return total_loss, train_op, metrics
120 |
121 |
122 | # The input is not validated;
123 | # use with care.
124 | def kappa(confusion_matrix):
125 |     """Compute Cohen's kappa coefficient from a confusion matrix."""
126 | confusion_matrix = confusion_matrix.astype(np.int64)
127 |     pe_rows = np.sum(confusion_matrix, axis=0)  # column sums: totals per predicted class
128 |     pe_cols = np.sum(confusion_matrix, axis=1)  # row sums: totals per ground-truth class
129 |     sum_total = sum(pe_cols)  # total number of samples
130 | pe = np.dot(pe_rows, pe_cols) / float(sum_total * sum_total)
131 | po = np.trace(confusion_matrix) / float(sum_total)
132 | return (po - pe) / (1 - pe)
133 |
134 | def affinity_loss(labels,
135 | probs,
136 | num_classes,
137 | kld_margin):
138 | """Affinity Field (AFF) loss.
139 |
140 | This function computes AFF loss. There are several components in the
141 | function:
142 | 1) extracts edges from the ground-truth labels.
143 | 2) extracts ignored pixels and their paired pixels (the neighboring
144 | pixels on the eight corners).
145 | 3) extracts neighboring pixels on the eight corners from a 3x3 patch.
146 | 4) computes KL-Divergence between center pixels and their neighboring
147 | pixels from the eight corners.
148 |
149 | Args:
150 |     labels: A tensor of size [batch_size, height_in, width_in, 1], indicating
151 | semantic segmentation ground-truth labels.
152 | probs: A tensor of size [batch_size, height_in, width_in, num_classes],
153 | indicating segmentation predictions.
154 | num_classes: A number indicating the total number of valid classes.
155 | kld_margin: A number indicating the margin for KL-Divergence at edge.
156 |
157 | Returns:
158 |     Two 1-D tensors indicating the loss at edge and non-edge positions.
159 | """
160 |     # Compute ignore map (e.g., labels of 255 and their paired pixels).
161 | labels = tf.squeeze(labels, axis=-1) # NxHxW
162 | ignore = nnx.ignores_from_label(labels, num_classes, 1) # NxHxWx8
163 | not_ignore = tf.logical_not(ignore)
164 |     not_ignore = tf.expand_dims(not_ignore, axis=3)  # NxHxWx1x8; True where the pixel is not ignored
165 |
166 | # Compute edge map.
167 | one_hot_lab = tf.one_hot(labels, depth=num_classes)
168 |     edge = nnx.edges_from_label(one_hot_lab, 1, 255)  # NxHxWxCx8; True where a neighbor's label differs (an edge)
169 |
170 | # Remove ignored pixels from the edge/non-edge.
171 |     edge = tf.logical_and(edge, not_ignore)  # NxHxWxCx8
172 |     not_edge = tf.logical_and(tf.logical_not(edge), not_ignore)  # NxHxWxCx8
173 |
174 | edge_indices = tf.where(tf.reshape(edge, [-1]))
175 | not_edge_indices = tf.where(tf.reshape(not_edge, [-1]))
176 |
177 | # Extract eight corner from the center in a patch as paired pixels.
178 | probs_paired = nnx.eightcorner_activation(probs, 1) # NxHxWxCx8
179 | probs = tf.expand_dims(probs, axis=-1) # NxHxWxCx1
180 | bot_epsilon = tf.constant(1e-4, name='bot_epsilon')
181 | top_epsilon = tf.constant(1.0, name='top_epsilon')
182 | neg_probs = tf.clip_by_value(
183 | 1-probs, bot_epsilon, top_epsilon)
184 | probs = tf.clip_by_value(
185 | probs, bot_epsilon, top_epsilon)
186 |     neg_probs_paired = tf.clip_by_value(
187 | 1-probs_paired, bot_epsilon, top_epsilon)
188 | probs_paired = tf.clip_by_value(
189 | probs_paired, bot_epsilon, top_epsilon)
190 |
191 | # Compute KL-Divergence.
192 | kldiv = probs_paired*tf.log(probs_paired/probs)
193 | kldiv += neg_probs_paired*tf.log(neg_probs_paired/neg_probs)
194 | not_edge_loss = kldiv
195 | edge_loss = tf.maximum(0.0, kld_margin-kldiv)
196 |
197 | not_edge_loss = tf.reshape(not_edge_loss, [-1])
198 | not_edge_loss = tf.gather(not_edge_loss, not_edge_indices)
199 | edge_loss = tf.reshape(edge_loss, [-1])
200 | edge_loss = tf.gather(edge_loss, edge_indices)
201 |
202 | return edge_loss, not_edge_loss
203 |
--------------------------------------------------------------------------------
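The kappa helper above (repeated verbatim in the other tools_* modules) computes Cohen's kappa, (po - pe) / (1 - pe), where po is the observed accuracy (trace over total) and pe the chance agreement from the marginals. A self-contained NumPy check with a made-up 2x2 confusion matrix:

    import numpy as np

    cm = np.array([[20, 5],
                   [10, 15]])  # rows: ground truth, columns: predictions
    po = np.trace(cm) / float(cm.sum())                               # 35/50 = 0.7
    pe = np.dot(cm.sum(axis=0), cm.sum(axis=1)) / float(cm.sum()**2)  # 1250/2500 = 0.5
    print((po - pe) / (1 - pe))                                       # -> 0.4
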
/tools_deeplabv3.py:
--------------------------------------------------------------------------------
1 | from NET import deeplab_v3
2 | from utils import preprocessing
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | _WEIGHT_DECAY = 5e-4
7 |
8 |
9 | def get_loss_pre_metrics(x, y, is_training, batch_size, args):
10 |     # Cast back to uint8 so the input images render correctly in summaries
11 | images = tf.cast(x, tf.uint8)
12 |
13 |     # Forward pass
14 | logits = tf.cond(is_training, true_fn=lambda: deeplab_v3.deeplab_v3(x, args, is_training=True, reuse=False),
15 | false_fn=lambda: deeplab_v3.deeplab_v3(x, args, is_training=False, reuse=True))
16 | pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)
17 |
18 |     # Decode predictions into color label images
19 | pred_decoded_labels = tf.py_func(preprocessing.decode_labels, [pred_classes, batch_size, args.number_of_classes], tf.uint8)
20 |
21 |     # Decode ground-truth labels into color label images
22 | gt_decoded_labels = tf.py_func(preprocessing.decode_labels, [y, batch_size, args.number_of_classes], tf.uint8)
23 |
24 |
25 | tf.summary.image('images', tf.concat(axis=2, values=[images, gt_decoded_labels, pred_decoded_labels]),
26 | max_outputs=args.tensorboard_images_max_outputs)
27 |
28 |     # Compute the loss
29 | labels = tf.squeeze(y, axis=3) # reduce the channel dimension.
30 | logits_by_num_classes = tf.reshape(logits, [-1, args.number_of_classes])
31 | labels_flat = tf.reshape(labels, [-1, ])
32 |
33 | cross_entropy = tf.losses.sparse_softmax_cross_entropy(
34 | logits=logits_by_num_classes, labels=labels_flat)
35 |
36 | if not args.freeze_batch_norm:
37 | train_var_list = [v for v in tf.trainable_variables()]
38 | else:
39 | train_var_list = [v for v in tf.trainable_variables()
40 |                           if 'beta' not in v.name and 'gamma' not in v.name]
41 |
42 | #train_var_list = [v for v in tf.trainable_variables()]
43 | with tf.variable_scope("total_loss"):
44 | loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
45 | [tf.nn.l2_loss(v) for v in train_var_list])
46 | tf.summary.scalar('loss', loss)
47 |
48 |     # Optimizer with polynomial ("poly") learning-rate decay
49 | global_step = tf.train.get_or_create_global_step()
50 | learning_rate = tf.train.polynomial_decay(
51 | args.initial_learning_rate,
52 | tf.cast(global_step, tf.int32) - args.initial_global_step,
53 |         args.max_iter, args.end_learning_rate, power=0.9)  # decays to end_learning_rate over max_iter steps
54 | tf.summary.scalar('learning_rate', learning_rate)
55 |
56 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
57 |
58 | # Batch norm requires update ops to be added as a dependency to the train_op
59 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
60 | with tf.control_dependencies(update_ops):
61 | train_op = optimizer.minimize(loss, global_step, var_list=train_var_list)
62 |
63 | # metrics
64 | preds_flat = tf.reshape(pred_classes, [-1, ])
65 | confusion_matrix = tf.confusion_matrix(labels_flat, preds_flat, num_classes=args.number_of_classes)
66 |
67 | correct_pred = tf.equal(preds_flat, labels_flat)
68 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
69 | tf.summary.scalar('accuracy', accuracy)
70 |
71 | def compute_mean_iou(total_cm, name='mean_iou'):
72 | """Compute the mean intersection-over-union via the confusion matrix."""
73 | sum_over_row = tf.to_float(tf.reduce_sum(total_cm, 0))
74 | sum_over_col = tf.to_float(tf.reduce_sum(total_cm, 1))
75 | cm_diag = tf.to_float(tf.diag_part(total_cm))
76 | denominator = sum_over_row + sum_over_col - cm_diag
77 |
78 | # The mean is only computed over classes that appear in the
79 | # label or prediction tensor. If the denominator is 0, we need to
80 | # ignore the class.
81 | num_valid_entries = tf.reduce_sum(tf.cast(
82 | tf.not_equal(denominator, 0), dtype=tf.float32))
83 |
84 | # If the value of the denominator is 0, set it to 1 to avoid
85 | # zero division.
86 | denominator = tf.where(
87 | tf.greater(denominator, 0),
88 | denominator,
89 | tf.ones_like(denominator))
90 | iou = tf.div(cm_diag, denominator)
91 |
92 | for i in range(args.number_of_classes):
93 | tf.identity(iou[i], name='train_iou_class{}'.format(i))
94 | tf.summary.scalar('train_iou_class{}'.format(i), iou[i])
95 |
96 | # If the number of valid entries is 0 (no classes) we return 0.
97 | result = tf.where(
98 | tf.greater(num_valid_entries, 0),
99 | tf.reduce_sum(iou, name=name) / num_valid_entries,
100 | 0)
101 | return result
102 |
103 | mean_iou = compute_mean_iou(confusion_matrix)
104 |
105 | tf.summary.scalar('mean_iou', mean_iou)
106 |
107 | metrics = {'px_accuracy': accuracy, 'mean_iou': mean_iou, 'confusion_matrix': confusion_matrix}
108 |
109 | return loss, train_op, metrics
110 |
111 |
112 | # The input is not validated;
113 | # use with care.
114 | def kappa(confusion_matrix):
115 |     """Compute Cohen's kappa coefficient from a confusion matrix."""
116 | confusion_matrix = confusion_matrix.astype(np.int64)
117 |     pe_rows = np.sum(confusion_matrix, axis=0)  # column sums: totals per predicted class
118 |     pe_cols = np.sum(confusion_matrix, axis=1)  # row sums: totals per ground-truth class
119 |     sum_total = sum(pe_cols)  # total number of samples
120 | pe = np.dot(pe_rows, pe_cols) / float(sum_total * sum_total)
121 | po = np.trace(confusion_matrix) / float(sum_total)
122 | return (po - pe) / (1 - pe)
--------------------------------------------------------------------------------
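compute_mean_iou above reads per-class IoU off the confusion matrix as diag / (row_sum + col_sum - diag) and averages only over classes whose denominator is non-zero. The same arithmetic in plain NumPy, as an illustration rather than repo code:

    import numpy as np

    cm = np.array([[8, 2],
                   [1, 9]], dtype=np.float64)
    diag = np.diag(cm)
    denom = cm.sum(axis=0) + cm.sum(axis=1) - diag  # per-class union
    iou = diag / np.where(denom > 0, denom, 1)      # guard against 0/0
    print(iou, iou[denom > 0].mean())               # [0.7273 0.75] 0.7386
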
/tools_deeplabv3_DA.py:
--------------------------------------------------------------------------------
1 | from NET import deeplabv3_DA
2 | from utils import preprocessing
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | _WEIGHT_DECAY = 5e-4
7 |
8 |
9 | def get_loss_pre_metrics(x, y, is_training, batch_size, args):
10 |     # Cast back to uint8 so the input images render correctly in summaries
11 | images = tf.cast(x, tf.uint8)
12 |
13 |     # Forward pass
14 | logits = tf.cond(is_training, true_fn=lambda: deeplabv3_DA.deeplabv3_DA(x, args, is_training=True, reuse=False),
15 | false_fn=lambda: deeplabv3_DA.deeplabv3_DA(x, args, is_training=False, reuse=True))
16 | pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)
17 |
18 |     # Decode predictions into color label images
19 | pred_decoded_labels = tf.py_func(preprocessing.decode_labels, [pred_classes, batch_size, args.number_of_classes], tf.uint8)
20 |
21 |
22 |     # Decode ground-truth labels into color label images
23 | gt_decoded_labels = tf.py_func(preprocessing.decode_labels, [y, batch_size, args.number_of_classes], tf.uint8)
24 |
25 |
26 | tf.summary.image('images', tf.concat(axis=2, values=[images, gt_decoded_labels, pred_decoded_labels]),
27 | max_outputs=args.tensorboard_images_max_outputs)
28 |
29 |     # Compute the loss
30 | labels = tf.squeeze(y, axis=3) # reduce the channel dimension.
31 | logits_by_num_classes = tf.reshape(logits, [-1, args.number_of_classes])
32 | labels_flat = tf.reshape(labels, [-1, ])
33 |
34 | cross_entropy = tf.losses.sparse_softmax_cross_entropy(
35 | logits=logits_by_num_classes, labels=labels_flat)
36 |
37 | if not args.freeze_batch_norm:
38 | train_var_list = [v for v in tf.trainable_variables()]
39 | else:
40 | train_var_list = [v for v in tf.trainable_variables()
41 | if 'beta' not in v.name and 'gamma' not in v.name]
42 |
43 | #train_var_list = [v for v in tf.trainable_variables()]
44 | with tf.variable_scope("total_loss"):
45 | loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
46 | [tf.nn.l2_loss(v) for v in train_var_list])
47 | tf.summary.scalar('loss', loss)
48 |
49 |     # Optimizer with polynomial ("poly") learning-rate decay
50 | global_step = tf.train.get_or_create_global_step()
51 | learning_rate = tf.train.polynomial_decay(
52 | args.initial_learning_rate,
53 | tf.cast(global_step, tf.int32) - args.initial_global_step,
54 |         args.max_iter, args.end_learning_rate, power=0.9)  # decays to end_learning_rate over max_iter steps
55 | tf.summary.scalar('learning_rate', learning_rate)
56 |
57 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
58 |
59 | # Batch norm requires update ops to be added as a dependency to the train_op
60 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
61 | with tf.control_dependencies(update_ops):
62 | train_op = optimizer.minimize(loss, global_step, var_list=train_var_list)
63 |
64 | # metrics
65 | preds_flat = tf.reshape(pred_classes, [-1, ])
66 | confusion_matrix = tf.confusion_matrix(labels_flat, preds_flat, num_classes=args.number_of_classes)
67 |
68 | correct_pred = tf.equal(preds_flat, labels_flat)
69 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
70 | tf.summary.scalar('accuracy', accuracy)
71 |
72 | def compute_mean_iou(total_cm, name='mean_iou'):
73 | """Compute the mean intersection-over-union via the confusion matrix."""
74 | sum_over_row = tf.to_float(tf.reduce_sum(total_cm, 0))
75 | sum_over_col = tf.to_float(tf.reduce_sum(total_cm, 1))
76 | cm_diag = tf.to_float(tf.diag_part(total_cm))
77 | denominator = sum_over_row + sum_over_col - cm_diag
78 |
79 | # The mean is only computed over classes that appear in the
80 | # label or prediction tensor. If the denominator is 0, we need to
81 | # ignore the class.
82 | num_valid_entries = tf.reduce_sum(tf.cast(
83 | tf.not_equal(denominator, 0), dtype=tf.float32))
84 |
85 | # If the value of the denominator is 0, set it to 1 to avoid
86 | # zero division.
87 | denominator = tf.where(
88 | tf.greater(denominator, 0),
89 | denominator,
90 | tf.ones_like(denominator))
91 | iou = tf.div(cm_diag, denominator)
92 |
93 | for i in range(args.number_of_classes):
94 | tf.identity(iou[i], name='train_iou_class{}'.format(i))
95 | tf.summary.scalar('train_iou_class{}'.format(i), iou[i])
96 |
97 | # If the number of valid entries is 0 (no classes) we return 0.
98 | result = tf.where(
99 | tf.greater(num_valid_entries, 0),
100 | tf.reduce_sum(iou, name=name) / num_valid_entries,
101 | 0)
102 | return result
103 |
104 | mean_iou = compute_mean_iou(confusion_matrix)
105 |
106 | tf.summary.scalar('mean_iou', mean_iou)
107 |
108 | metrics = {'px_accuracy': accuracy, 'mean_iou': mean_iou, 'confusion_matrix': confusion_matrix}
109 |
110 | return loss, train_op, metrics
111 |
112 |
113 | # The input is not validated;
114 | # use with care.
115 | def kappa(confusion_matrix):
116 |     """Compute Cohen's kappa coefficient from a confusion matrix."""
117 | confusion_matrix = confusion_matrix.astype(np.int64)
118 |     pe_rows = np.sum(confusion_matrix, axis=0)  # column sums: totals per predicted class
119 |     pe_cols = np.sum(confusion_matrix, axis=1)  # row sums: totals per ground-truth class
120 |     sum_total = sum(pe_cols)  # total number of samples
121 | pe = np.dot(pe_rows, pe_cols) / float(sum_total * sum_total)
122 | po = np.trace(confusion_matrix) / float(sum_total)
123 | return (po - pe) / (1 - pe)
--------------------------------------------------------------------------------
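Every tools_* module schedules the learning rate with tf.train.polynomial_decay at power 0.9, i.e. lr(t) = (lr0 - lr_end) * (1 - t/max_iter)**0.9 + lr_end, clamped after max_iter. A plain-Python sketch of that curve, using the defaults of the matching train script (7e-4, 1e-6, 83333):

    def poly_lr(step, lr0=7e-4, lr_end=1e-6, max_iter=83333, power=0.9):
        t = min(max(step, 0), max_iter)  # decay stops at max_iter
        return (lr0 - lr_end) * (1 - t / float(max_iter)) ** power + lr_end

    print(poly_lr(0), poly_lr(83333))  # -> 0.0007 1e-06
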
/tools_deeplabv3plus.py:
--------------------------------------------------------------------------------
1 | from NET import deeplabv3_plus
2 | from utils import preprocessing
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | _WEIGHT_DECAY = 5e-4
7 |
8 |
9 | def get_loss_pre_metrics(x, y, is_training, batch_size, args):
10 |     # Cast back to uint8 so the input images render correctly in summaries
11 | images = tf.cast(x, tf.uint8)
12 |
13 |     # Forward pass
14 | logits = tf.cond(is_training, true_fn=lambda: deeplabv3_plus.deeplabv3_plus(x, args, is_training=True, reuse=False),
15 | false_fn=lambda: deeplabv3_plus.deeplabv3_plus(x, args, is_training=False, reuse=True))
16 | pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)
17 |
18 |     # Decode predictions into color label images
19 | pred_decoded_labels = tf.py_func(preprocessing.decode_labels, [pred_classes, batch_size, args.number_of_classes], tf.uint8)
20 |
21 |     # Decode ground-truth labels into color label images
22 | gt_decoded_labels = tf.py_func(preprocessing.decode_labels, [y, batch_size, args.number_of_classes], tf.uint8)
23 |
24 |
25 | tf.summary.image('images', tf.concat(axis=2, values=[images, gt_decoded_labels, pred_decoded_labels]),
26 | max_outputs=args.tensorboard_images_max_outputs)
27 |
28 |     # Compute the loss
29 | labels = tf.squeeze(y, axis=3) # reduce the channel dimension.
30 | logits_by_num_classes = tf.reshape(logits, [-1, args.number_of_classes])
31 | labels_flat = tf.reshape(labels, [-1, ])
32 |
33 | cross_entropy = tf.losses.sparse_softmax_cross_entropy(
34 | logits=logits_by_num_classes, labels=labels_flat)
35 |
36 | if not args.freeze_batch_norm:
37 | train_var_list = [v for v in tf.trainable_variables()]
38 | else:
39 | train_var_list = [v for v in tf.trainable_variables()
40 | if 'beta' not in v.name and 'gamma' not in v.name]
41 |
42 | #train_var_list = [v for v in tf.trainable_variables()]
43 | with tf.variable_scope("total_loss"):
44 | loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
45 | [tf.nn.l2_loss(v) for v in train_var_list])
46 | tf.summary.scalar('loss', loss)
47 |
48 |     # Optimizer with polynomial ("poly") learning-rate decay
49 | global_step = tf.train.get_or_create_global_step()
50 | learning_rate = tf.train.polynomial_decay(
51 | args.initial_learning_rate,
52 | tf.cast(global_step, tf.int32) - args.initial_global_step,
53 |         args.max_iter, args.end_learning_rate, power=0.9)  # decays to end_learning_rate over max_iter steps
54 | tf.summary.scalar('learning_rate', learning_rate)
55 |
56 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
57 |
58 | # Batch norm requires update ops to be added as a dependency to the train_op
59 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
60 | with tf.control_dependencies(update_ops):
61 | train_op = optimizer.minimize(loss, global_step, var_list=train_var_list)
62 |
63 | # metrics
64 | preds_flat = tf.reshape(pred_classes, [-1, ])
65 | confusion_matrix = tf.confusion_matrix(labels_flat, preds_flat, num_classes=args.number_of_classes)
66 |
67 | correct_pred = tf.equal(preds_flat, labels_flat)
68 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
69 | tf.summary.scalar('accuracy', accuracy)
70 |
71 | def compute_mean_iou(total_cm, name='mean_iou'):
72 | """Compute the mean intersection-over-union via the confusion matrix."""
73 | sum_over_row = tf.to_float(tf.reduce_sum(total_cm, 0))
74 | sum_over_col = tf.to_float(tf.reduce_sum(total_cm, 1))
75 | cm_diag = tf.to_float(tf.diag_part(total_cm))
76 | denominator = sum_over_row + sum_over_col - cm_diag
77 |
78 | # The mean is only computed over classes that appear in the
79 | # label or prediction tensor. If the denominator is 0, we need to
80 | # ignore the class.
81 | num_valid_entries = tf.reduce_sum(tf.cast(
82 | tf.not_equal(denominator, 0), dtype=tf.float32))
83 |
84 | # If the value of the denominator is 0, set it to 1 to avoid
85 | # zero division.
86 | denominator = tf.where(
87 | tf.greater(denominator, 0),
88 | denominator,
89 | tf.ones_like(denominator))
90 | iou = tf.div(cm_diag, denominator)
91 |
92 | for i in range(args.number_of_classes):
93 | tf.identity(iou[i], name='train_iou_class{}'.format(i))
94 | tf.summary.scalar('train_iou_class{}'.format(i), iou[i])
95 |
96 | # If the number of valid entries is 0 (no classes) we return 0.
97 | result = tf.where(
98 | tf.greater(num_valid_entries, 0),
99 | tf.reduce_sum(iou, name=name) / num_valid_entries,
100 | 0)
101 | return result
102 |
103 | mean_iou = compute_mean_iou(confusion_matrix)
104 |
105 | tf.summary.scalar('mean_iou', mean_iou)
106 |
107 | metrics = {'px_accuracy': accuracy, 'mean_iou': mean_iou, 'confusion_matrix': confusion_matrix}
108 |
109 | return loss, train_op, metrics
110 |
111 |
112 | # The input is not validated;
113 | # use with care.
114 | def kappa(confusion_matrix):
115 |     """Compute Cohen's kappa coefficient from a confusion matrix."""
116 | confusion_matrix = confusion_matrix.astype(np.int64)
117 |     pe_rows = np.sum(confusion_matrix, axis=0)  # column sums: totals per predicted class
118 |     pe_cols = np.sum(confusion_matrix, axis=1)  # row sums: totals per ground-truth class
119 |     sum_total = sum(pe_cols)  # total number of samples
120 | pe = np.dot(pe_rows, pe_cols) / float(sum_total * sum_total)
121 | po = np.trace(confusion_matrix) / float(sum_total)
122 | return (po - pe) / (1 - pe)
--------------------------------------------------------------------------------
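A shared caveat of the train scripts that follow: --freeze_batch_norm is declared with type=bool, and argparse converts any non-empty string, including "False", to True, so the flag is effectively fixed at its default. A common workaround (a suggestion, not present in this repo) is a small converter:

    import argparse

    def str2bool(v):
        # accept the usual command-line spellings of a boolean
        if v.lower() in ('yes', 'true', 't', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected')

    # envarg.add_argument('--freeze_batch_norm', type=str2bool, default=True, ...)
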
/tools_psp.py:
--------------------------------------------------------------------------------
1 | from NET import pspnet
2 | from utils import preprocessing
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | _WEIGHT_DECAY = 1e-4
7 |
8 |
9 | def get_loss_pre_metrics(x, y, is_training, batch_size, args):
10 |     # Cast back to uint8 so the input images render correctly in summaries
11 | images = tf.cast(x, tf.uint8)
12 |
13 |     # Forward pass
14 | logits = tf.cond(is_training, true_fn=lambda: pspnet.pspnet_resnet(x, args, is_training=True, reuse=False),
15 | false_fn=lambda: pspnet.pspnet_resnet(x, args, is_training=False, reuse=True))
16 | pred_classes = tf.expand_dims(tf.argmax(logits, axis=3, output_type=tf.int32), axis=3)
17 |
18 |     # Decode predictions into color label images
19 | pred_decoded_labels = tf.py_func(preprocessing.decode_labels, [pred_classes, batch_size, args.number_of_classes], tf.uint8)
20 |
21 |
22 |     # Decode ground-truth labels into color label images
23 | gt_decoded_labels = tf.py_func(preprocessing.decode_labels, [y, batch_size, args.number_of_classes], tf.uint8)
24 |
25 |
26 | tf.summary.image('images', tf.concat(axis=2, values=[images, gt_decoded_labels, pred_decoded_labels]),
27 | max_outputs=args.tensorboard_images_max_outputs)
28 |
29 |     # Compute the loss
30 | labels = tf.squeeze(y, axis=3) # reduce the channel dimension.
31 | logits_by_num_classes = tf.reshape(logits, [-1, args.number_of_classes])
32 | labels_flat = tf.reshape(labels, [-1, ])
33 |
34 | cross_entropy = tf.losses.sparse_softmax_cross_entropy(
35 | logits=logits_by_num_classes, labels=labels_flat)
36 |
37 | if not args.freeze_batch_norm:
38 | train_var_list = [v for v in tf.trainable_variables()]
39 | else:
40 | train_var_list = [v for v in tf.trainable_variables()
41 | if 'beta' not in v.name and 'gamma' not in v.name]
42 |
43 | with tf.variable_scope("total_loss"):
44 | loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
45 | [tf.nn.l2_loss(v) for v in train_var_list])
46 | tf.summary.scalar('loss', loss)
47 |
48 |     # Optimizer with polynomial ("poly") learning-rate decay
49 | global_step = tf.train.get_or_create_global_step()
50 | learning_rate = tf.train.polynomial_decay(
51 | args.initial_learning_rate,
52 | tf.cast(global_step, tf.int32) - args.initial_global_step,
53 |         args.max_iter, args.end_learning_rate, power=0.9)  # decays to end_learning_rate over max_iter steps
54 | tf.summary.scalar('learning_rate', learning_rate)
55 |
56 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
57 |
58 | # Batch norm requires update ops to be added as a dependency to the train_op
59 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
60 | with tf.control_dependencies(update_ops):
61 | train_op = optimizer.minimize(loss, global_step, var_list=train_var_list)
62 |
63 | # metrics
64 | preds_flat = tf.reshape(pred_classes, [-1, ])
65 | confusion_matrix = tf.confusion_matrix(labels_flat, preds_flat, num_classes=args.number_of_classes)
66 |
67 | correct_pred = tf.equal(preds_flat, labels_flat)
68 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
69 | tf.summary.scalar('accuracy', accuracy)
70 |
71 | def compute_mean_iou(total_cm, name='mean_iou'):
72 | """Compute the mean intersection-over-union via the confusion matrix."""
73 | sum_over_row = tf.to_float(tf.reduce_sum(total_cm, 0))
74 | sum_over_col = tf.to_float(tf.reduce_sum(total_cm, 1))
75 | cm_diag = tf.to_float(tf.diag_part(total_cm))
76 | denominator = sum_over_row + sum_over_col - cm_diag
77 |
78 | # The mean is only computed over classes that appear in the
79 | # label or prediction tensor. If the denominator is 0, we need to
80 | # ignore the class.
81 | num_valid_entries = tf.reduce_sum(tf.cast(
82 | tf.not_equal(denominator, 0), dtype=tf.float32))
83 |
84 | # If the value of the denominator is 0, set it to 1 to avoid
85 | # zero division.
86 | denominator = tf.where(
87 | tf.greater(denominator, 0),
88 | denominator,
89 | tf.ones_like(denominator))
90 | iou = tf.div(cm_diag, denominator)
91 |
92 | for i in range(args.number_of_classes):
93 | tf.identity(iou[i], name='train_iou_class{}'.format(i))
94 | tf.summary.scalar('train_iou_class{}'.format(i), iou[i])
95 |
96 | # If the number of valid entries is 0 (no classes) we return 0.
97 | result = tf.where(
98 | tf.greater(num_valid_entries, 0),
99 | tf.reduce_sum(iou, name=name) / num_valid_entries,
100 | 0)
101 | return result
102 |
103 | mean_iou = compute_mean_iou(confusion_matrix)
104 |
105 | tf.summary.scalar('mean_iou', mean_iou)
106 |
107 | metrics = {'px_accuracy': accuracy, 'mean_iou': mean_iou, 'confusion_matrix': confusion_matrix}
108 |
109 | return loss, train_op, metrics
110 |
111 |
112 | # The input is not validated;
113 | # use with care.
114 | def kappa(confusion_matrix):
115 |     """Compute Cohen's kappa coefficient from a confusion matrix."""
116 | confusion_matrix = confusion_matrix.astype(np.int64)
117 |     pe_rows = np.sum(confusion_matrix, axis=0)  # column sums: totals per predicted class
118 |     pe_cols = np.sum(confusion_matrix, axis=1)  # row sums: totals per ground-truth class
119 |     sum_total = sum(pe_cols)  # total number of samples
120 | pe = np.dot(pe_rows, pe_cols) / float(sum_total * sum_total)
121 | po = np.trace(confusion_matrix) / float(sum_total)
122 | return (po - pe) / (1 - pe)
--------------------------------------------------------------------------------
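tools_aaf.py, tools_deeplabv3.py, tools_deeplabv3_DA.py, tools_deeplabv3plus.py and tools_psp.py are nearly identical; they differ only in the network constructor they call and in _WEIGHT_DECAY (5e-4 everywhere except 1e-4 here in tools_psp.py). A hypothetical refactor would inject the model function instead of duplicating the module:

    import tensorflow as tf

    def build_logits(x, args, is_training, model_fn):
        # model_fn: e.g. deeplab_v3.deeplab_v3 or pspnet.pspnet_resnet;
        # the shared loss/metric code of the tools_* modules would follow unchanged
        return tf.cond(is_training,
                       true_fn=lambda: model_fn(x, args, is_training=True, reuse=False),
                       false_fn=lambda: model_fn(x, args, is_training=False, reuse=True))
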
/train_aaf.py:
--------------------------------------------------------------------------------
1 | from GeneratingBatchSize.GetDataset import train_or_eval_input_fn
2 | import tensorflow as tf
3 | import os
4 | import argparse
5 | import tools_aaf
6 | import datetime
7 | import math
8 |
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | batch_size = 4
12 | summary_path = "./summary_aff/"
13 | checkpoint_path_voc = ""
14 | checkpoint_path = "./checkpoint_aff/"
15 | EPOCHS = 50
16 | train_set_length = 5000
17 | eval_set_length = 500
18 |
19 | parser = argparse.ArgumentParser()
20 |
21 | # Add command-line arguments
22 | envarg = parser.add_argument_group('Training params')
23 | # BN params
24 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
25 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
26 | envarg.add_argument('--freeze_batch_norm', type=bool, default=True, help='Freeze batch normalization parameters during the training.')
27 | # the number of classes
28 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
29 |
30 | # regularizer
31 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
32 |
33 | # for deeplabv3
34 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Multi-grid atrous rates for the last ResNet block.")
35 | envarg.add_argument("--output_stride", type=int, default=16, help="Ratio of input image resolution to final feature-map resolution.")
36 |
37 | # the base network
38 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of the listed resnet_v2 variants.")
39 | 
40 | # the pre-trained backbone checkpoint, e.g. resnet_v2_50 or resnet_v2_101
41 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
42 |                     help='Path to the pre-trained model checkpoint.')
43 |
44 | # max number of batch elements to tensorboard
45 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=6,
46 | help='Max number of batch elements to generate for Tensorboard.')
47 | # poly learn_rate
48 | parser.add_argument('--initial_learning_rate', type=float, default=7e-4,
49 | help='Initial learning rate for the optimizer.')
50 |
51 | parser.add_argument('--end_learning_rate', type=float, default=1e-6,
52 | help='End learning rate for the optimizer.')
53 |
54 | parser.add_argument('--initial_global_step', type=int, default=0,
55 | help='Initial global step for controlling learning rate when fine-tuning model.')
56 | parser.add_argument('--max_iter', type=int, default=62500,
57 | help='Number of maximum iteration used for "poly" learning rate policy.')
58 |
59 | # AAF parameters
60 | parser.add_argument('--kld_margin', type=float, default=3.0, help='margin for affinity loss')
61 | parser.add_argument('--kld_lambda_1', type=float, default=1.0, help='Lambda for affinity loss at edge.')
62 | parser.add_argument('--kld_lambda_2', type=float, default=1.0, help='Lambda for affinity loss not at edge.')
63 | args = parser.parse_args()
64 |
65 | def main():
66 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '0'
67 |
68 | is_train = tf.placeholder(tf.bool, shape=[])
69 | x = tf.placeholder(dtype=tf.float32, shape=[batch_size, 1000, 1000, 3], name="image_batch")
70 | y = tf.placeholder(dtype=tf.int32, shape=[batch_size, 1000, 1000, 1], name="label_batch")
71 |
72 | train_dataset = train_or_eval_input_fn(is_training=True,
73 | data_dir="./DatasetNew/train/", batch_size=batch_size)
74 | eval_dataset = train_or_eval_input_fn(is_training=False,
75 | data_dir="./DatasetNew/val/", batch_size=batch_size, num_epochs=1)
76 | iterator_train = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
77 | next_batch = iterator_train.get_next()
78 | training_init_op = iterator_train.make_initializer(train_dataset)
79 | evaling_init_op = iterator_train.make_initializer(eval_dataset)
80 |
81 | loss, train_op, metrics = tools_aaf.get_loss_pre_metrics(x, y, is_train, batch_size, args)
82 |
83 | accuracy = metrics["px_accuracy"]
84 | mean_iou = metrics["mean_iou"]
85 | confusion_matrix = metrics['confusion_matrix']
86 |
87 | summary_op = tf.summary.merge_all()
88 | init_op = tf.group(
89 | tf.local_variables_initializer(),
90 | tf.global_variables_initializer()
91 | )
92 |     # When first loading weights from deeplabv3, the logits layer must be excluded
93 | exclude = ['global_step']
94 | variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
95 |
96 | saver_first = tf.train.Saver(variables_to_restore)
97 | saver = tf.train.Saver(max_to_keep=50)
98 | summary_writer_train = tf.summary.FileWriter(summary_path + "train/")
99 | summary_writer_val = tf.summary.FileWriter(summary_path + "val/")
100 |     # Run the graph
101 | config = tf.ConfigProto()
102 | config.gpu_options.allow_growth = True
103 |     with tf.Session(config=config) as sess:
104 | sess.run(init_op, feed_dict={is_train: True})
105 | ckpt = tf.train.get_checkpoint_state(checkpoint_path_voc)
106 | if ckpt and ckpt.model_checkpoint_path:
107 | saver_first.restore(sess, ckpt.model_checkpoint_path)
108 | sess.graph.finalize()
109 |
110 | train_batches_of_epoch = int(math.ceil(train_set_length / batch_size))
111 | val_batches_of_epoch = int(math.floor(eval_set_length / batch_size))
112 | for epoch in range(EPOCHS):
113 | sess.run(training_init_op)
114 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch + 1))
115 | # step = 1 (epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)
116 | for step in range((epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)):
117 | img_batch, label_batch = sess.run(next_batch)
118 | sess.run([train_op], feed_dict={x: img_batch, y: label_batch, is_train: True})
119 |
120 | if (step + 1) % 625 == 0:
121 | loss_value, acc, m_iou, merge, con_matrix = sess.run(
122 | [loss, accuracy, mean_iou, summary_op, confusion_matrix],
123 | feed_dict={x: img_batch, y: label_batch, is_train: True})
124 | kappa = tools_aaf.kappa(con_matrix)
125 | print("{} {} loss = {:.4f}".format(datetime.datetime.now(), step + 1, loss_value))
126 | print("accuracy{}".format(acc))
127 | print("miou{}".format(m_iou))
128 | print("kappa{}".format(kappa))
129 | summary_writer_train.add_summary(merge, step + 1)
130 | saver.save(sess, checkpoint_path + "model.ckpt", epoch + 1)
131 | print("checkpoint saved")
132 |
133 |         # Validation
134 | sess.run(evaling_init_op)
135 | print("{} Start validation".format(datetime.datetime.now()))
136 | test_acc = 0.0
137 | test_miou = 0.0
138 | test_kappa = 0.0
139 | test_count = 0
140 | for tag in range(val_batches_of_epoch):
141 | img_batch, label_batch = sess.run(next_batch)
142 | acc, m_iou, merge, con_matrix = sess.run(
143 | [accuracy, mean_iou, summary_op, confusion_matrix],
144 | feed_dict={x: img_batch, y: label_batch, is_train: False})
145 | kappa = tools_aaf.kappa(con_matrix)
146 | test_kappa += kappa
147 | test_acc += acc
148 | test_miou += m_iou
149 | test_count += 1
150 | test_acc /= test_count
151 | test_miou /= test_count
152 | test_kappa /= test_count
153 | s = tf.Summary(value=[
154 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc),
155 | tf.Summary.Value(tag="validation_miou", simple_value=test_miou),
156 | tf.Summary.Value(tag="validation_kappa", simple_value=test_kappa)
157 | ])
158 | summary_writer_val.add_summary(s, epoch + 1)
159 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
160 | print("{} Validation miou = {:.4f}".format(datetime.datetime.now(), test_miou))
161 | print("{} Validation kappa = {:.4f}".format(datetime.datetime.now(), test_kappa))
162 |
163 | if __name__ == '__main__':
164 | tf.logging.set_verbosity(tf.logging.INFO)
165 | main()
--------------------------------------------------------------------------------
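In tools_aaf.py the edge and non-edge affinity terms are scaled by dec = 10**(-global_step / max_iter), so with this script's max_iter of 62500 their weight falls smoothly from 1.0 at step 0 to 0.1 at the final step. A quick check of the curve:

    for step in (0, 31250, 62500):
        print(step, 10.0 ** (-step / 62500.0))
    # -> 0 1.0, 31250 0.316..., 62500 0.1
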
/train_deeplabv3.py:
--------------------------------------------------------------------------------
1 | from GeneratingBatchSize.GetDataset import train_or_eval_input_fn
2 | import tensorflow as tf
3 | import os
4 | import argparse
5 | import tools_deeplabv3
6 | import datetime
7 | import math
8 |
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | batch_size = 4
12 | summary_path = "./summary_v3/"
13 | checkpoint_path_voc = ""
14 | checkpoint_path = "./checkpoint_v3/"
15 | EPOCHS = 100
16 | train_set_length = 5000
17 | eval_set_length = 500
18 |
19 | parser = argparse.ArgumentParser()
20 |
21 | # Add command-line arguments
22 | envarg = parser.add_argument_group('Training params')
23 | # BN params
24 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
25 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
26 | envarg.add_argument('--freeze_batch_norm', type=bool, default=True, help='Freeze batch normalization parameters during the training.')
27 | # the number of classes
28 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
29 |
30 | # regularizer
31 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
32 |
33 | # for deeplabv3
34 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Multi-grid atrous rates for the last ResNet block.")
35 | envarg.add_argument("--output_stride", type=int, default=16, help="Ratio of input image resolution to final feature-map resolution.")
36 |
37 | # the base network
38 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of the listed resnet_v2 variants.")
39 | 
40 | # the pre-trained backbone checkpoint, e.g. resnet_v2_50 or resnet_v2_101
41 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
42 |                     help='Path to the pre-trained model checkpoint.')
43 |
44 | # max number of batch elements to tensorboard
45 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=6,
46 | help='Max number of batch elements to generate for Tensorboard.')
47 | # poly learn_rate
48 | parser.add_argument('--initial_learning_rate', type=float, default=1e-4,
49 | help='Initial learning rate for the optimizer.')
50 |
51 | parser.add_argument('--end_learning_rate', type=float, default=5e-6,
52 | help='End learning rate for the optimizer.')
53 |
54 | parser.add_argument('--initial_global_step', type=int, default=0,
55 | help='Initial global step for controlling learning rate when fine-tuning model.')
56 | parser.add_argument('--max_iter', type=int, default=125000,
57 | help='Number of maximum iteration used for "poly" learning rate policy.')
58 | args = parser.parse_args()
59 |
60 | def main():
61 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '0'
62 |
63 | is_train = tf.placeholder(tf.bool, shape=[])
64 | x = tf.placeholder(dtype=tf.float32, shape=[None, 1000, 1000, 3], name="image_batch")
65 | y = tf.placeholder(dtype=tf.int32, shape=[None, 1000, 1000, 1], name="label_batch")
66 |
67 | train_dataset = train_or_eval_input_fn(is_training=True,
68 | data_dir="./DatasetNew/train/", batch_size=batch_size)
69 | eval_dataset = train_or_eval_input_fn(is_training=False,
70 | data_dir="./DatasetNew/val/", batch_size=batch_size, num_epochs=1)
71 | iterator_train = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
72 | next_batch = iterator_train.get_next()
73 | training_init_op = iterator_train.make_initializer(train_dataset)
74 | evaling_init_op = iterator_train.make_initializer(eval_dataset)
75 |
76 | loss, train_op, metrics = tools_deeplabv3.get_loss_pre_metrics(x, y, is_train, batch_size, args)
77 |
78 | accuracy = metrics["px_accuracy"]
79 | mean_iou = metrics["mean_iou"]
80 | confusion_matrix = metrics['confusion_matrix']
81 |
82 | summary_op = tf.summary.merge_all()
83 | init_op = tf.group(
84 | tf.local_variables_initializer(),
85 | tf.global_variables_initializer()
86 | )
87 |     # When first loading weights from deeplabv3, the logits layer must be excluded
88 | exclude = [args.resnet_model + '/logits', 'DeepLab_v3/logits', 'global_step']
89 | variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
90 |
91 | saver_first = tf.train.Saver(variables_to_restore)
92 | saver = tf.train.Saver(max_to_keep=100)
93 | summary_writer_train = tf.summary.FileWriter(summary_path + "train/")
94 | summary_writer_val = tf.summary.FileWriter(summary_path + "val/")
95 |     # Run the graph
96 | config = tf.ConfigProto()
97 | config.gpu_options.allow_growth = True
98 | with tf.Session(config=config) as sess:
99 | sess.run(init_op, feed_dict={is_train: True})
100 | ckpt = tf.train.get_checkpoint_state(checkpoint_path_voc)
101 | if ckpt and ckpt.model_checkpoint_path:
102 | saver_first.restore(sess, ckpt.model_checkpoint_path)
103 | sess.graph.finalize()
104 |
105 | train_batches_of_epoch = int(math.ceil(train_set_length / batch_size))
106 | val_batches_of_epoch = int(math.ceil(eval_set_length / batch_size))
107 | for epoch in range(EPOCHS):
108 | sess.run(training_init_op)
109 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch + 1))
110 | for step in range((epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)):
111 | img_batch, label_batch = sess.run(next_batch)
112 | sess.run([train_op], feed_dict={x: img_batch, y: label_batch, is_train: True})
113 | if (step + 1) % 625 == 0:
114 | loss_value, acc, m_iou, con_matrix = sess.run(
115 | [loss, accuracy, mean_iou, confusion_matrix],
116 | feed_dict={x: img_batch, y: label_batch, is_train: True})
117 | kappa = tools_deeplabv3.kappa(con_matrix)
118 | print("{} {} loss = {:.4f}".format(datetime.datetime.now(), step + 1, loss_value))
119 | print("accuracy{}".format(acc))
120 | print("miou{}".format(m_iou))
121 | print("kappa{}".format(kappa))
122 | merge = sess.run(summary_op, feed_dict={x: img_batch, y: label_batch, is_train: True})
123 | summary_writer_train.add_summary(merge, step + 1)
124 | saver.save(sess, checkpoint_path + "model.ckpt", epoch + 1)
125 | print("checkpoint saved")
126 |
127 |             # Validation
128 | sess.run(evaling_init_op)
129 | print("{} Start validation".format(datetime.datetime.now()))
130 | test_acc = 0.0
131 | test_miou = 0.0
132 | test_kappa = 0.0
133 | test_count = 0
134 | for tag in range(val_batches_of_epoch):
135 | img_batch, label_batch = sess.run(next_batch)
136 | acc, m_iou, con_matrix = sess.run(
137 | [accuracy, mean_iou, confusion_matrix],
138 | feed_dict={x: img_batch, y: label_batch, is_train: False})
139 |
140 | kappa = tools_deeplabv3.kappa(con_matrix)
141 | test_kappa += kappa
142 | test_acc += acc
143 | test_miou += m_iou
144 | test_count += 1
145 | test_acc /= test_count
146 | test_miou /= test_count
147 | test_kappa /= test_count
148 | s = tf.Summary(value=[
149 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc),
150 | tf.Summary.Value(tag="validation_miou", simple_value=test_miou),
151 | tf.Summary.Value(tag="validation_kappa", simple_value=test_kappa)
152 | ])
153 | summary_writer_val.add_summary(s, epoch + 1)
154 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
155 | print("{} Validation miou = {:.4f}".format(datetime.datetime.now(), test_miou))
156 | print("{} Validation kappa = {:.4f}".format(datetime.datetime.now(), test_kappa))
157 |
158 | if __name__ == '__main__':
159 | tf.logging.set_verbosity(tf.logging.INFO)
160 | main()
--------------------------------------------------------------------------------
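The poly-decay horizon in each train script is just epochs times steps per epoch: here EPOCHS = 100 with ceil(5000 / 4) = 1250 steps gives max_iter = 125000, and train_aaf.py's 50 * 1250 = 62500 follows the same rule. A helper to recompute it when the batch size or epoch count changes (the helper name is illustrative):

    import math

    def poly_max_iter(train_set_length, batch_size, epochs):
        return epochs * int(math.ceil(train_set_length / float(batch_size)))

    print(poly_max_iter(5000, 4, 100))  # -> 125000
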
/train_deeplabv3_DA.py:
--------------------------------------------------------------------------------
1 | from GeneratingBatchSize.GetDataset import train_or_eval_input_fn
2 | import tensorflow as tf
3 | import os
4 | import argparse
5 | import tools_deeplabv3_DA
6 | import datetime
7 | import math
8 |
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | batch_size = 3
12 | summary_path = "./summary_position/"
13 | checkpoint_path_first = "./first_checkpoint/"
14 | checkpoint_path = "./checkpoint_position/"
15 | EPOCHS = 50
16 | train_set_length = 5000
17 | eval_set_length = 500
18 |
19 | parser = argparse.ArgumentParser()
20 |
21 | # Add command-line arguments
22 | envarg = parser.add_argument_group('Training params')
23 | # BN params
24 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
25 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
26 | envarg.add_argument('--freeze_batch_norm', type=bool, default=True, help='Freeze batch normalization parameters during the training.')
27 | # the number of classes
28 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
29 |
30 | # regularizer
31 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
32 |
33 | # for deeplabv3
34 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Multi-grid atrous rates for the last ResNet block.")
35 | envarg.add_argument("--output_stride", type=int, default=16, help="Ratio of input image resolution to final feature-map resolution.")
36 |
37 | # the base network
38 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of the listed resnet_v2 variants.")
39 | 
40 | # the pre-trained backbone checkpoint, e.g. resnet_v2_50 or resnet_v2_101
41 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
42 |                     help='Path to the pre-trained model checkpoint.')
43 |
44 | # max number of batch elements to tensorboard
45 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=6,
46 | help='Max number of batch elements to generate for Tensorboard.')
47 | # poly learn_rate
48 | parser.add_argument('--initial_learning_rate', type=float, default=7e-4,
49 | help='Initial learning rate for the optimizer.')
50 |
51 | parser.add_argument('--end_learning_rate', type=float, default=1e-6,
52 | help='End learning rate for the optimizer.')
53 |
54 | parser.add_argument('--initial_global_step', type=int, default=0,
55 | help='Initial global step for controlling learning rate when fine-tuning model.')
56 | parser.add_argument('--max_iter', type=int, default=83333,
57 | help='Number of maximum iteration used for "poly" learning rate policy.')
58 | args = parser.parse_args()
59 |
60 | def main():
61 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '0'
62 |
63 | is_train = tf.placeholder(tf.bool, shape=[])
64 | x = tf.placeholder(dtype=tf.float32, shape=[None, 1000, 1000, 3], name="image_batch")
65 | y = tf.placeholder(dtype=tf.int32, shape=[None, 1000, 1000, 1], name="label_batch")
66 |
67 | train_dataset = train_or_eval_input_fn(is_training=True,
68 | data_dir="./DatasetNew/train/", batch_size=batch_size)
69 | eval_dataset = train_or_eval_input_fn(is_training=False,
70 | data_dir="./DatasetNew/val/", batch_size=batch_size, num_epochs=1)
71 | iterator_train = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
72 | next_batch = iterator_train.get_next()
73 | training_init_op = iterator_train.make_initializer(train_dataset)
74 | evaling_init_op = iterator_train.make_initializer(eval_dataset)
75 |
76 | loss, train_op, metrics = tools_deeplabv3_DA.get_loss_pre_metrics(x, y, is_train, batch_size, args)
77 |
78 | accuracy = metrics["px_accuracy"]
79 | mean_iou = metrics["mean_iou"]
80 | confusion_matrix = metrics['confusion_matrix']
81 |
82 | summary_op = tf.summary.merge_all()
83 | init_op = tf.group(
84 | tf.local_variables_initializer(),
85 | tf.global_variables_initializer()
86 | )
87 |     # When first restoring from the trained deeplabv3 checkpoint, the layers new to this model are excluded
88 | exclude = ['global_step', 'DeepLab_v3/ASPP_layer/image_level_conv_1x1', 'DeepLab_v3/ASPP_layer/reduce_chanel', 'DeepLab_v3/ASPP_layer/position_module']
89 | variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
90 |
91 | saver_first = tf.train.Saver(variables_to_restore)
92 | saver = tf.train.Saver(max_to_keep=100)
93 | summary_writer_train = tf.summary.FileWriter(summary_path + "train/")
94 | summary_writer_val = tf.summary.FileWriter(summary_path + "val/")
95 |     # Run the graph
96 | config = tf.ConfigProto()
97 | config.gpu_options.allow_growth = True
98 | with tf.Session(config=config) as sess:
99 | sess.run(init_op, feed_dict={is_train: True})
100 | ckpt = tf.train.get_checkpoint_state(checkpoint_path_first)
101 | if ckpt and ckpt.model_checkpoint_path:
102 | saver_first.restore(sess, ckpt.model_checkpoint_path)
103 | sess.graph.finalize()
104 |
105 | train_batches_of_epoch = int(math.ceil(train_set_length / batch_size))
106 | val_batches_of_epoch = int(math.ceil(eval_set_length / batch_size))
107 | for epoch in range(EPOCHS):
108 | sess.run(training_init_op)
109 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch + 1))
110 | # step = 1 (epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)
111 | for step in range((epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)):
112 | img_batch, label_batch = sess.run(next_batch)
113 | loss_value, _, acc, m_iou, con_matrix = sess.run(
114 | [loss, train_op, accuracy, mean_iou, confusion_matrix],
115 | feed_dict={x: img_batch, y: label_batch, is_train: True})
116 |
117 | if (step + 1) % 833 == 0:
118 | kappa = tools_deeplabv3_DA.kappa(con_matrix)
119 | print("{} {} loss = {:.4f}".format(datetime.datetime.now(), step + 1, loss_value))
120 | print("accuracy{}".format(acc))
121 | print("miou{}".format(m_iou))
122 | print("kappa{}".format(kappa))
123 | merge = sess.run(summary_op, feed_dict={x: img_batch, y: label_batch, is_train: True})
124 | summary_writer_train.add_summary(merge, step + 1)
125 | saver.save(sess, checkpoint_path + "model.ckpt", epoch + 1)
126 | print("checkpoint saved")
127 |
128 |             # Validation
129 | sess.run(evaling_init_op)
130 | print("{} Start validation".format(datetime.datetime.now()))
131 | test_acc = 0.0
132 | test_miou = 0.0
133 | test_kappa = 0.0
134 | test_count = 0
135 | for tag in range(val_batches_of_epoch):
136 | img_batch, label_batch = sess.run(next_batch)
137 | acc, m_iou, con_matrix = sess.run(
138 | [accuracy, mean_iou, confusion_matrix],
139 | feed_dict={x: img_batch, y: label_batch, is_train: False})
140 |
141 | kappa = tools_deeplabv3_DA.kappa(con_matrix)
142 | test_kappa += kappa
143 | test_acc += acc
144 | test_miou += m_iou
145 | test_count += 1
146 | test_acc /= test_count
147 | test_miou /= test_count
148 | test_kappa /= test_count
149 | s = tf.Summary(value=[
150 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc),
151 | tf.Summary.Value(tag="validation_miou", simple_value=test_miou),
152 | tf.Summary.Value(tag="validation_kappa", simple_value=test_kappa)
153 | ])
154 | summary_writer_val.add_summary(s, epoch + 1)
155 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
156 | print("{} Validation miou = {:.4f}".format(datetime.datetime.now(), test_miou))
157 | print("{} Validation kappa = {:.4f}".format(datetime.datetime.now(), test_kappa))
158 |
159 | if __name__ == '__main__':
160 | tf.logging.set_verbosity(tf.logging.INFO)
161 | main()
--------------------------------------------------------------------------------
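Note the floor/ceil asymmetry between the scripts: train_aaf.py gives x a fixed batch dimension, so it must use math.floor and drop the ragged final validation batch, while this script's [None, ...] placeholders tolerate math.ceil feeding a short last batch (500 = 166 * 3 + 2). If fixed-shape placeholders were reinstated here, the floor form would be the safe one:

    val_batches_of_epoch = int(math.floor(eval_set_length / batch_size))  # 166 full batches of 3
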
/train_deeplabv3plus.py:
--------------------------------------------------------------------------------
1 | from GeneratingBatchSize.GetDataset import train_or_eval_input_fn
2 | import tensorflow as tf
3 | import os
4 | import argparse
5 | import tools_deeplabv3plus
6 | import datetime
7 | import math
8 |
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | batch_size = 3
12 | summary_path = "./summary_deeplabv3plus/"
13 | checkpoint_path = "./checkpoint_deeplabv3plus/"
14 | EPOCHS = 50
15 | train_set_length = 5000
16 | eval_set_length = 500
17 |
18 | parser = argparse.ArgumentParser()
19 |
20 | # Add command-line arguments
21 | envarg = parser.add_argument_group('Training params')
22 | # BN params
23 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
24 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
25 | envarg.add_argument('--freeze_batch_norm', type=bool, default=False, help='Freeze batch normalization parameters during the training.')
26 | # the number of classes
27 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
28 |
29 | # regularizer
30 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
31 |
32 | # for deeplabv3
33 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Multi-grid atrous rates for the last ResNet block.")
34 | envarg.add_argument("--output_stride", type=int, default=16, help="Ratio of input image resolution to final feature-map resolution.")
35 |
36 | # the base network
37 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of the listed resnet_v2 variants.")
38 | 
39 | # the pre-trained backbone checkpoint, e.g. resnet_v2_50 or resnet_v2_101
40 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
41 |                     help='Path to the pre-trained model checkpoint.')
42 |
43 | # max number of batch elements to tensorboard
44 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=6,
45 | help='Max number of batch elements to generate for Tensorboard.')
46 | # poly learn_rate
47 | parser.add_argument('--initial_learning_rate', type=float, default=7e-3,
48 | help='Initial learning rate for the optimizer.')
49 |
50 | parser.add_argument('--end_learning_rate', type=float, default=1e-6,
51 | help='End learning rate for the optimizer.')
52 |
53 | parser.add_argument('--initial_global_step', type=int, default=0,
54 | help='Initial global step for controlling learning rate when fine-tuning model.')
55 | parser.add_argument('--max_iter', type=int, default=83333,
56 | help='Number of maximum iteration used for "poly" learning rate policy.')
57 | args = parser.parse_args()
58 |
59 | def main():
60 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '0'
61 |
62 | is_train = tf.placeholder(tf.bool, shape=[])
63 | x = tf.placeholder(dtype=tf.float32, shape=[None, 1000, 1000, 3], name="image_batch")
64 | y = tf.placeholder(dtype=tf.int32, shape=[None, 1000, 1000, 1], name="label_batch")
65 |
66 | train_dataset = train_or_eval_input_fn(is_training=True,
67 | data_dir="./DatasetNew/train/", batch_size=batch_size)
68 | eval_dataset = train_or_eval_input_fn(is_training=False,
69 | data_dir="./DatasetNew/val/", batch_size=batch_size, num_epochs=1)
70 | iterator_train = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
71 | next_batch = iterator_train.get_next()
72 | training_init_op = iterator_train.make_initializer(train_dataset)
73 | evaling_init_op = iterator_train.make_initializer(eval_dataset)
74 |
75 | loss, train_op, metrics = tools_deeplabv3plus.get_loss_pre_metrics(x, y, is_train, batch_size, args)
76 |
77 | accuracy = metrics["px_accuracy"]
78 | mean_iou = metrics["mean_iou"]
79 | confusion_matrix = metrics['confusion_matrix']
80 |
81 | summary_op = tf.summary.merge_all()
82 | init_op = tf.group(
83 | tf.local_variables_initializer(),
84 | tf.global_variables_initializer()
85 | )
86 | saver = tf.train.Saver(max_to_keep=40)
87 | summary_writer_train = tf.summary.FileWriter(summary_path + "train/")
88 | summary_writer_val = tf.summary.FileWriter(summary_path + "val/")
89 |     # Run the graph
90 | config = tf.ConfigProto()
91 | config.gpu_options.allow_growth = True
92 | with tf.Session(config=config) as sess:
93 | sess.run(init_op, feed_dict={is_train: True})
94 | ckpt = tf.train.get_checkpoint_state(checkpoint_path)
95 | if ckpt and ckpt.model_checkpoint_path:
96 | saver.restore(sess, ckpt.model_checkpoint_path)
97 | sess.graph.finalize()
98 | train_batches_of_epoch = int(math.ceil(train_set_length / batch_size))
99 | val_batches_of_epoch = int(math.ceil(eval_set_length / batch_size))
100 | for epoch in range(18, EPOCHS):
101 | sess.run(training_init_op)
102 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch + 1))
103 | # step = 1
104 | for step in range((epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)):
105 | img_batch, label_batch = sess.run(next_batch)
106 | sess.run([train_op], feed_dict={x: img_batch, y: label_batch, is_train: True})
107 | if (step + 1) % 833 == 0:  # periodically log metrics and write a training summary
108 | loss_value, acc, m_iou, con_matrix, merge = sess.run(
109 | [loss, accuracy, mean_iou, confusion_matrix, summary_op],
110 | feed_dict={x: img_batch, y: label_batch, is_train: True})
111 | kappa = tools_deeplabv3plus.kappa(con_matrix)
112 | print("{} {} loss = {:.4f}".format(datetime.datetime.now(), step + 1, loss_value))
113 | print("accuracy{}".format(acc))
114 | print("miou{}".format(m_iou))
115 | print("kappa{}".format(kappa))
116 | summary_writer_train.add_summary(merge, step + 1)
117 | saver.save(sess, checkpoint_path + "model.ckpt", epoch + 1)
118 | print("checkpoint saved")
119 |
120 | # validation pass
121 | sess.run(evaling_init_op)
122 | print("{} Start validation".format(datetime.datetime.now()))
123 | test_acc = 0.0
124 | test_miou = 0.0
125 | test_kappa = 0.0
126 | test_count = 0
127 | for tag in range(val_batches_of_epoch):
128 | img_batch, label_batch = sess.run(next_batch)
129 | acc, m_iou, con_matrix = sess.run(
130 | [accuracy, mean_iou, confusion_matrix],
131 | feed_dict={x: img_batch, y: label_batch, is_train: False})
132 |
133 | kappa = tools_deeplabv3plus.kappa(con_matrix)
134 | test_kappa += kappa
135 | test_acc += acc
136 | test_miou += m_iou
137 | test_count += 1
138 | test_acc /= test_count
139 | test_miou /= test_count
140 | test_kappa /= test_count
141 | s = tf.Summary(value=[
142 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc),
143 | tf.Summary.Value(tag="validation_miou", simple_value=test_miou),
144 | tf.Summary.Value(tag="validation_kappa", simple_value=test_kappa)
145 | ])
146 | summary_writer_val.add_summary(s, epoch + 1)
147 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
148 | print("{} Validation miou = {:.4f}".format(datetime.datetime.now(), test_miou))
149 | print("{} Validation kappa = {:.4f}".format(datetime.datetime.now(), test_kappa))
150 |
151 | if __name__ == '__main__':
152 | tf.logging.set_verbosity(tf.logging.INFO)
153 | main()
--------------------------------------------------------------------------------
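
Note: each of the training scripts in this dump reports Cohen's kappa via tools_*.kappa(con_matrix), implemented in the tools_* modules earlier in the dump. For reference, a minimal NumPy sketch of the standard definition of Cohen's kappa over a confusion matrix (an illustration, not necessarily the exact tools_* implementation):

    import numpy as np

    def kappa_from_confusion(cm):
        # cm[i, j] = number of pixels of true class i predicted as class j
        cm = np.asarray(cm, dtype=np.float64)
        total = cm.sum()
        p_observed = np.trace(cm) / total  # i.e. overall per-pixel accuracy
        p_expected = (cm.sum(axis=0) * cm.sum(axis=1)).sum() / total ** 2
        return (p_observed - p_expected) / (1.0 - p_expected)
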
/train_deeplabv3plus_4chanel.py:
--------------------------------------------------------------------------------
1 | from GeneratingBatchSize.GetDataset import train_or_eval_input_fn
2 | from tensorflow.python import pywrap_tensorflow
3 | import tensorflow as tf
4 | import os
5 | import argparse
6 | import tools_deeplabv3plus
7 | import datetime
8 | import math
9 | import numpy as np
10 |
11 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
12 |
13 | batch_size = 4
14 | summary_path = "./summary_deeplabv3plus_4/"
15 | checkpoint_path_first = "./checkpoint_deeplabv3plus/" # the trained three-channel deeplabv3+ checkpoint
16 | checkpoint_path = "./checkpoint_deeplabv3plus_4/"
17 | EPOCHS = 30
18 | train_set_length = 5000
19 | eval_set_length = 500
20 |
21 | parser = argparse.ArgumentParser()
22 |
23 | # add command-line arguments
24 | envarg = parser.add_argument_group('Training params')
25 | # BN params
26 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
27 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
28 | envarg.add_argument('--freeze_batch_norm', type=bool, default=False, help='Freeze batch normalization parameters during the training.')  # note: argparse's type=bool treats any non-empty string as True; rely on the default
29 | # the number of classes
30 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
31 |
32 | # regularizer
33 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
34 |
35 | # for deeplabv3
36 | envarg.add_argument("--multi_grid", type=list, default=[1, 2, 4], help="Spatial Pyramid Pooling rates")
37 | envarg.add_argument("--output_stride", type=int, default=16, help="Spatial Pyramid Pooling rates")
38 |
39 | # the base network
40 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of: resnet_v2_50 or resnet_v2_101")
41 |
42 | # the pre-trained backbone checkpoint, e.g. resnet_v2_50, resnet_v2_101, etc.
43 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
44 | help='Path to the pre-trained model checkpoint.')
45 |
46 | # max number of batch elements to tensorboard
47 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=6,
48 | help='Max number of batch elements to generate for Tensorboard.')
49 | # poly learn_rate
50 | parser.add_argument('--initial_learning_rate', type=float, default=4e-3,
51 | help='Initial learning rate for the optimizer.')
52 |
53 | parser.add_argument('--end_learning_rate', type=float, default=1e-6,
54 | help='End learning rate for the optimizer.')
55 |
56 | parser.add_argument('--initial_global_step', type=int, default=0,
57 | help='Initial global step for controlling learning rate when fine-tuning model.')
58 | parser.add_argument('--max_iter', type=int, default=37500,
59 | help='Maximum number of iterations for the "poly" learning rate policy.')
60 | args = parser.parse_args()
61 |
62 | def main():
63 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '0'
64 |
65 | is_train = tf.placeholder(tf.bool, shape=[])
66 | x = tf.placeholder(dtype=tf.float32, shape=[None, 1000, 1000, 4], name="image_batch")
67 | y = tf.placeholder(dtype=tf.int32, shape=[None, 1000, 1000, 1], name="label_batch")
68 |
69 | train_dataset = train_or_eval_input_fn(is_training=True,
70 | data_dir="./DatasetNew/train/", batch_size=batch_size)
71 | eval_dataset = train_or_eval_input_fn(is_training=False,
72 | data_dir="./DatasetNew/val/", batch_size=batch_size, num_epochs=1)
73 | iterator_train = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
74 | next_batch = iterator_train.get_next()
75 | training_init_op = iterator_train.make_initializer(train_dataset)
76 | evaling_init_op = iterator_train.make_initializer(eval_dataset)
77 |
78 | loss, train_op, metrics = tools_deeplabv3plus.get_loss_pre_metrics(x, y, is_train, batch_size, args)
79 |
80 | accuracy = metrics["px_accuracy"]
81 | mean_iou = metrics["mean_iou"]
82 | confusion_matrix = metrics['confusion_matrix']
83 |
84 | summary_op = tf.summary.merge_all()
85 | init_op = tf.group(
86 | tf.local_variables_initializer(),
87 | tf.global_variables_initializer()
88 | )
89 | # First run: restore weights from the trained three-channel deeplabv3+ checkpoint, excluding conv1 (its input-channel count changes from 3 to 4) and global_step
90 | exclude = [args.resnet_model + '/conv1', 'global_step']
91 | variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
92 |
93 | with tf.variable_scope("resnet_v2_50", reuse=True):
94 | conv1_weight_restored = tf.get_variable("conv1/weights")
95 | conv1_biase_restored = tf.get_variable("conv1/biases")
96 | # read the conv1 weight and bias values from the three-channel checkpoint
97 | checkpoint_path_restored = os.path.join(checkpoint_path_first, "model.ckpt-58")
98 | reader_restored = pywrap_tensorflow.NewCheckpointReader(checkpoint_path_restored)
99 | conv1_weight_value = reader_restored.get_tensor("resnet_v2_50/conv1/weights")
100 | conv1_weight_value_4 = np.sum(conv1_weight_value, axis=2, keepdims=True)
101 | conv1_weight_value_4 = np.true_divide(conv1_weight_value_4, 3)
102 | conv1_weight_value = np.concatenate((conv1_weight_value_4, conv1_weight_value), axis=2)
103 | conv1_biase_value = reader_restored.get_tensor("resnet_v2_50/conv1/biases")
104 | conv1_weight_op = tf.assign(conv1_weight_restored, conv1_weight_value)
105 | conv1_biase_op = tf.assign(conv1_biase_restored, conv1_biase_value)
106 |
107 |
108 | saver_first = tf.train.Saver(variables_to_restore)
109 | saver = tf.train.Saver(max_to_keep=50)
110 | summary_writer_train = tf.summary.FileWriter(summary_path + "train/")
111 | summary_writer_val = tf.summary.FileWriter(summary_path + "val/")
112 | # run the graph
113 | config = tf.ConfigProto()
114 | config.gpu_options.allow_growth = True
115 | with tf.Session(config=config) as sess:
116 | sess.run(init_op, feed_dict={is_train: True})
117 | ckpt = tf.train.get_checkpoint_state(checkpoint_path_first)
118 | if ckpt and ckpt.model_checkpoint_path:
119 | saver_first.restore(sess, checkpoint_path_first+"model.ckpt-58")
120 | sess.graph.finalize()
121 |
122 | sess.run([conv1_weight_op, conv1_biase_op])
123 |
124 | train_batches_of_epoch = int(math.ceil(train_set_length / batch_size))
125 | val_batches_of_epoch = int(math.ceil(eval_set_length / batch_size))
126 | for epoch in range(EPOCHS):
127 | sess.run(training_init_op)
128 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch + 1))
129 | # step = 1
130 | for step in range((epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)):
131 | img_batch, label_batch = sess.run(next_batch)
132 | sess.run([train_op], feed_dict={x: img_batch, y: label_batch, is_train: True})
133 | if (step + 1) % 625 == 0:  # log twice per epoch (5000 / 4 = 1250 batches per epoch)
134 | loss_value, acc, m_iou, con_matrix, merge = sess.run(
135 | [loss, accuracy, mean_iou, confusion_matrix, summary_op],
136 | feed_dict={x: img_batch, y: label_batch, is_train: True})
137 | kappa = tools_deeplabv3plus.kappa(con_matrix)
138 | print("{} {} loss = {:.4f}".format(datetime.datetime.now(), step + 1, loss_value))
139 | print("accuracy{}".format(acc))
140 | print("miou{}".format(m_iou))
141 | print("kappa{}".format(kappa))
142 | summary_writer_train.add_summary(merge, step + 1)
143 | saver.save(sess, checkpoint_path + "model.ckpt", epoch + 1)
144 | print("checkpoint saved")
145 |
146 | # validation pass
147 | sess.run(evaling_init_op)
148 | print("{} Start validation".format(datetime.datetime.now()))
149 | test_acc = 0.0
150 | test_miou = 0.0
151 | test_kappa = 0.0
152 | test_count = 0
153 | for tag in range(val_batches_of_epoch):
154 | img_batch, label_batch = sess.run(next_batch)
155 | acc, m_iou, con_matrix = sess.run(
156 | [accuracy, mean_iou, confusion_matrix],
157 | feed_dict={x: img_batch, y: label_batch, is_train: False})
158 |
159 | kappa = tools_deeplabv3plus.kappa(con_matrix)
160 | test_kappa += kappa
161 | test_acc += acc
162 | test_miou += m_iou
163 | test_count += 1
164 | test_acc /= test_count
165 | test_miou /= test_count
166 | test_kappa /= test_count
167 | s = tf.Summary(value=[
168 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc),
169 | tf.Summary.Value(tag="validation_miou", simple_value=test_miou),
170 | tf.Summary.Value(tag="validation_kappa", simple_value=test_kappa)
171 | ])
172 | summary_writer_val.add_summary(s, epoch + 1)
173 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
174 | print("{} Validation miou = {:.4f}".format(datetime.datetime.now(), test_miou))
175 | print("{} Validation kappa = {:.4f}".format(datetime.datetime.now(), test_kappa))
176 |
177 | if __name__ == '__main__':
178 | tf.logging.set_verbosity(tf.logging.INFO)
179 | main()
--------------------------------------------------------------------------------
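
Note: the block at lines 89-105 above widens conv1 from 3 to 4 input channels: the trained RGB kernels are averaged and the result is prepended along the input-channel axis (e.g. for an extra NIR band). A self-contained NumPy illustration of the same transformation (the 7x7x3x64 kernel shape is an assumption for a ResNet-style stem):

    import numpy as np

    # Stand-in for the conv1 weights read from the three-channel checkpoint,
    # laid out as [kernel_h, kernel_w, in_channels, out_channels].
    rgb_kernel = np.random.randn(7, 7, 3, 64).astype(np.float32)

    # New channel = mean of the RGB kernels (sum / 3), prepended on axis 2,
    # exactly as in the script above.
    extra_channel = np.true_divide(np.sum(rgb_kernel, axis=2, keepdims=True), 3)
    rgbn_kernel = np.concatenate((extra_channel, rgb_kernel), axis=2)
    assert rgbn_kernel.shape == (7, 7, 4, 64)
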
/train_psp.py:
--------------------------------------------------------------------------------
1 | from GeneratingBatchSize.GetDataset import train_or_eval_input_fn
2 | import tensorflow as tf
3 | import os
4 | import argparse
5 | import tools_psp
6 | import datetime
7 | import math
8 |
9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
10 |
11 | batch_size = 3
12 | summary_path = "./summary_psp/"
13 | checkpoint_path = "./checkpoint_psp/"
14 | EPOCHS = 50
15 | train_set_length = 5000
16 | eval_set_length = 500
17 |
18 | parser = argparse.ArgumentParser()
19 |
20 | # add command-line arguments
21 | envarg = parser.add_argument_group('Training params')
22 | # BN params
23 | envarg.add_argument("--batch_norm_epsilon", type=float, default=1e-5, help="batch norm epsilon argument for batch normalization")
24 | envarg.add_argument('--batch_norm_decay', type=float, default=0.9997, help='batch norm decay argument for batch normalization.')
25 | envarg.add_argument('--freeze_batch_norm', type=bool, default=False, help='Freeze batch normalization parameters during the training.')  # note: argparse's type=bool treats any non-empty string as True; rely on the default
26 | # the number of classes
27 | envarg.add_argument("--number_of_classes", type=int, default=16, help="Number of classes to be predicted.")
28 |
29 | # regularizer
30 | envarg.add_argument("--l2_regularizer", type=float, default=0.0001, help="l2 regularizer parameter.")
31 |
32 | # the base network
33 | envarg.add_argument("--resnet_model", default="resnet_v2_50", choices=["resnet_v2_50", "resnet_v2_101", "resnet_v2_152", "resnet_v2_200"], help="Resnet model to use as feature extractor. Choose one of: resnet_v2_50 or resnet_v2_101")
34 |
35 | # the pre-trained backbone checkpoint, e.g. resnet_v2_50, resnet_v2_101, etc.
36 | envarg.add_argument('--pre_trained_model', type=str, default='./pre_trained_model/resnet_v2_50/resnet_v2_50.ckpt',
37 | help='Path to the pre-trained model checkpoint.')
38 |
39 | # max number of batch elements to tensorboard
40 | parser.add_argument('--tensorboard_images_max_outputs', type=int, default=4,
41 | help='Max number of batch elements to generate for Tensorboard.')
42 | # poly learn_rate
43 | parser.add_argument('--initial_learning_rate', type=float, default=7e-3,
44 | help='Initial learning rate for the optimizer.')
45 |
46 | parser.add_argument('--end_learning_rate', type=float, default=1e-6,
47 | help='End learning rate for the optimizer.')
48 |
49 | parser.add_argument('--initial_global_step', type=int, default=0,
50 | help='Initial global step for controlling learning rate when fine-tuning model.')
51 | parser.add_argument('--max_iter', type=int, default=83333,
52 | help='Maximum number of iterations for the "poly" learning rate policy.')
53 | args = parser.parse_args()
54 |
55 | def main():
56 | os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '0'
57 |
58 | is_train = tf.placeholder(tf.bool, shape=[])
59 | x = tf.placeholder(dtype=tf.float32, shape=[None, 1000, 1000, 3], name="image_batch")
60 | y = tf.placeholder(dtype=tf.int32, shape=[None, 1000, 1000, 1], name="label_batch")
61 |
62 | train_dataset = train_or_eval_input_fn(is_training=True,
63 | data_dir="./DatasetNew/train/", batch_size=batch_size)
64 | eval_dataset = train_or_eval_input_fn(is_training=False,
65 | data_dir="./DatasetNew/val/", batch_size=batch_size, num_epochs=1)
66 | iterator_train = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
67 | next_batch = iterator_train.get_next()
68 | training_init_op = iterator_train.make_initializer(train_dataset)
69 | evaling_init_op = iterator_train.make_initializer(eval_dataset)
70 |
71 | loss, train_op, metrics = tools_psp.get_loss_pre_metrics(x, y, is_train, batch_size, args)
72 |
73 | accuracy = metrics["px_accuracy"]
74 | mean_iou = metrics["mean_iou"]
75 | confusion_matrix = metrics['confusion_matrix']
76 |
77 | summary_op = tf.summary.merge_all()
78 | init_op = tf.group(
79 | tf.local_variables_initializer(),
80 | tf.global_variables_initializer()
81 | )
82 |
83 | saver = tf.train.Saver(max_to_keep=100)
84 | summary_writer_train = tf.summary.FileWriter(summary_path + "train/")
85 | summary_writer_val = tf.summary.FileWriter(summary_path + "val/")
86 | # run the graph
87 | config = tf.ConfigProto()
88 | config.gpu_options.allow_growth = True
89 | with tf.Session(config=config) as sess:
90 | sess.run(init_op, feed_dict={is_train: True})
91 | ckpt = tf.train.get_checkpoint_state(checkpoint_path)
92 | if ckpt and ckpt.model_checkpoint_path:
93 | saver.restore(sess, ckpt.model_checkpoint_path)
94 | print("restored")
95 | sess.graph.finalize()
96 |
97 | train_batches_of_epoch = int(math.ceil(train_set_length / batch_size))
98 | val_batches_of_epoch = int(math.ceil(eval_set_length / batch_size))
99 | for epoch in range(EPOCHS):
100 | sess.run(training_init_op)
101 | print("{} Epoch number: {}".format(datetime.datetime.now(), epoch + 1))
102 | # step = 1
103 | for step in range((epoch * train_batches_of_epoch), ((epoch + 1) * train_batches_of_epoch)):
104 | img_batch, label_batch = sess.run(next_batch)
105 | sess.run([train_op], feed_dict={x: img_batch, y: label_batch, is_train: True})
106 |
107 | if (step + 1) % 833 == 0:  # log roughly twice per epoch (ceil(5000 / 3) = 1667 batches per epoch)
108 | loss_value, acc, m_iou, con_matrix = sess.run(
109 | [loss, accuracy, mean_iou, confusion_matrix],
110 | feed_dict={x: img_batch, y: label_batch, is_train: True})
111 | kappa = tools_psp.kappa(con_matrix)
112 | print("{} {} loss = {:.4f}".format(datetime.datetime.now(), step + 1, loss_value))
113 | print("accuracy{}".format(acc))
114 | print("miou{}".format(m_iou))
115 | print("kappa{}".format(kappa))
116 | merge = sess.run(summary_op, feed_dict={x: img_batch, y: label_batch, is_train: True})  # note: a second forward pass just for the summary; it could be fetched together with the metrics above
117 | summary_writer_train.add_summary(merge, step + 1)
118 | saver.save(sess, checkpoint_path + "model.ckpt", epoch + 1)
119 | print("checkpoint saved")
120 |
121 | # validation pass
122 | sess.run(evaling_init_op)
123 | print("{} Start validation".format(datetime.datetime.now()))
124 | test_acc = 0.0
125 | test_miou = 0.0
126 | test_kappa = 0.0
127 | test_count = 0
128 | for tag in range(val_batches_of_epoch):
129 | img_batch, label_batch = sess.run(next_batch)
130 | acc, m_iou, con_matrix = sess.run(
131 | [accuracy, mean_iou, confusion_matrix],
132 | feed_dict={x: img_batch, y: label_batch, is_train: False})
133 |
134 | kappa = tools_psp.kappa(con_matrix)
135 | test_kappa += kappa
136 | test_acc += acc
137 | test_miou += m_iou
138 | test_count += 1
139 | test_acc /= test_count
140 | test_miou /= test_count
141 | test_kappa /= test_count
142 | s = tf.Summary(value=[
143 | tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc),
144 | tf.Summary.Value(tag="validation_miou", simple_value=test_miou),
145 | tf.Summary.Value(tag="validation_kappa", simple_value=test_kappa)
146 | ])
147 | summary_writer_val.add_summary(s, epoch + 1)
148 | print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))
149 | print("{} Validation miou = {:.4f}".format(datetime.datetime.now(), test_miou))
150 | print("{} Validation kappa = {:.4f}".format(datetime.datetime.now(), test_kappa))
151 |
152 | if __name__ == '__main__':
153 | tf.logging.set_verbosity(tf.logging.INFO)
154 | main()
--------------------------------------------------------------------------------
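
Note: the max_iter defaults line up with the loop bounds of each script: train_psp.py runs EPOCHS * ceil(train_set_length / batch_size) = 50 * 1667 ≈ 83333 steps (up to rounding), and train_deeplabv3plus_4chanel.py 30 * 1250 = 37500, so the "poly" schedule decays over the whole run. The decay itself is implemented in the tools_* modules; the power it uses is not visible in this section, so the 0.9 below is an assumption following common DeepLab practice:

    def poly_learning_rate(step, initial_lr=7e-3, end_lr=1e-6,
                           max_iter=83333, power=0.9):
        # Polynomial ("poly") decay: lr falls from initial_lr to end_lr
        # over max_iter steps; power=0.9 is assumed, not taken from tools_*.
        step = min(step, max_iter)
        return (initial_lr - end_lr) * (1.0 - step / max_iter) ** power + end_lr
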
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tangzhenjie/semantic_segmentation_contest/df3df24296f26209950a2455ed2f7751a9e046ca/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | """Utility functions for preprocessing data sets."""
2 |
3 | from PIL import Image
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | _R_MEAN = 78.51
8 | _G_MEAN = 114.00
9 | _B_MEAN = 110.00
10 |
11 | # Colour map.
12 | label_colours = [(0, 0, 0)
13 | # 0 = other
14 | ,(0, 200, 0), (150, 250, 0), (150, 200, 150), (200, 0, 200), (150, 0, 250)
15 | # 1 = paddy field, 2 = irrigated land, 3 = dry cropland, 4 = garden plot, 5 = arboreal forest
16 | , (150, 150, 250), (250, 200, 0), (200, 200, 0), (200, 0, 0), (250, 0, 150)
17 | # 6 = shrub land, 7 = natural grassland, 8 = artificial grassland, 9 = industrial land, 10 = urban residential
18 | , (200, 150, 150), (250, 150, 150), (0, 0, 200), (0, 150, 200), (0, 200, 250)]
19 | # 11 = village residential, 12 = transportation, 13 = river, 14 = lake, 15 = pond
20 |
21 |
22 | def decode_labels(mask, num_images=1, num_classes=16):
23 | """Decode batch of segmentation masks.
24 |
25 | Args:
26 | mask: result of inference after taking argmax.
27 | num_images: number of images to decode from the batch.
28 | num_classes: number of classes to predict (including background).
29 |
30 | Returns:
31 | A batch with num_images RGB images of the same size as the input.
32 | """
33 | n, h, w, c = mask.shape
34 | assert (n >= num_images), 'Batch size %d must be greater than or equal to the number of images to save %d.' \
35 | % (n, num_images)
36 | outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
37 | for i in range(num_images):
38 | img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
39 | pixels = img.load()
40 | for j_, j in enumerate(mask[i, :, :, 0]):
41 | for k_, k in enumerate(j):
42 | if k < num_classes:
43 | pixels[k_, j_] = label_colours[k]
44 | outputs[i] = np.array(img)
45 | return outputs
46 |
47 |
48 | def mean_image_addition(image, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
49 | """Adds the given means from each image channel.
50 |
51 | For example:
52 | means = [123.68, 116.779, 103.939]
53 | image = _mean_image_subtraction(image, means)
54 |
55 | Note that the rank of `image` must be known.
56 |
57 | Args:
58 | image: a tensor of size [height, width, C].
59 | means: a C-vector of values to subtract from each channel.
60 |
61 | Returns:
62 | the centered image.
63 |
64 | Raises:
65 | ValueError: If the rank of `image` is unknown, if `image` has a rank other
66 | than three or if the number of channels in `image` doesn't match the
67 | number of values in `means`.
68 | """
69 | if image.get_shape().ndims != 3:
70 | raise ValueError('Input must be of size [height, width, C>0]')
71 | num_channels = image.get_shape().as_list()[-1]
72 | if len(means) != num_channels:
73 | raise ValueError('len(means) must match the number of channels')
74 |
75 | channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
76 | for i in range(num_channels):
77 | channels[i] += means[i]
78 | return tf.concat(axis=2, values=channels)
79 |
80 |
81 | def mean_image_subtraction(image, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
82 | """Subtracts the given means from each image channel.
83 |
84 | For example:
85 | means = [123.68, 116.779, 103.939]
86 | image = mean_image_subtraction(image, means)
87 |
88 | Note that the rank of `image` must be known.
89 |
90 | Args:
91 | image: a tensor of size [height, width, C].
92 | means: a C-vector of values to subtract from each channel.
93 |
94 | Returns:
95 | the centered image.
96 |
97 | Raises:
98 | ValueError: If the rank of `image` is unknown, if `image` has a rank other
99 | than three or if the number of channels in `image` doesn't match the
100 | number of values in `means`.
101 | """
102 | if image.get_shape().ndims != 3:
103 | raise ValueError('Input must be of size [height, width, C>0]')
104 | num_channels = image.get_shape().as_list()[-1]
105 | if len(means) != num_channels:
106 | raise ValueError('len(means) must match the number of channels')
107 |
108 | channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
109 | for i in range(num_channels):
110 | channels[i] -= means[i]
111 | return tf.concat(axis=2, values=channels)
112 |
113 |
114 | def random_rescale_image_and_label(image, label, min_scale, max_scale):
115 | """Rescale an image and label with in target scale.
116 |
117 | Rescales an image and label within the range of target scale.
118 |
119 | Args:
120 | image: 3-D Tensor of shape `[height, width, channels]`.
121 | label: 3-D Tensor of shape `[height, width, 1]`.
122 | min_scale: Min target scale.
123 | max_scale: Max target scale.
124 |
125 | Returns:
126 | Cropped and/or padded image.
127 | If `images` was 3-D, a 3-D float Tensor of shape
128 | `[new_height, new_width, channels]`.
129 | If `labels` was 3-D, a 3-D float Tensor of shape
130 | `[new_height, new_width, 1]`.
131 | """
132 | if min_scale <= 0:
133 | raise ValueError('\'min_scale\' must be greater than 0.')
134 | elif max_scale <= 0:
135 | raise ValueError('\'max_scale\' must be greater than 0.')
136 | elif min_scale >= max_scale:
137 | raise ValueError('\'max_scale\' must be greater than \'min_scale\'.')
138 |
139 | shape = tf.shape(image)
140 | height = tf.to_float(shape[0])
141 | width = tf.to_float(shape[1])
142 | scale = tf.random_uniform(
143 | [], minval=min_scale, maxval=max_scale, dtype=tf.float32)
144 | new_height = tf.to_int32(height * scale)
145 | new_width = tf.to_int32(width * scale)
146 | image = tf.image.resize_images(image, [new_height, new_width],
147 | method=tf.image.ResizeMethod.BILINEAR)
148 | # Since label classes are integers, nearest neighbor needs to be used.
149 | label = tf.image.resize_images(label, [new_height, new_width],
150 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
151 |
152 | return image, label
153 |
154 |
155 | def random_crop_or_pad_image_and_label(image, label, crop_height, crop_width, ignore_label):
156 | """Crops and/or pads an image to a target width and height.
157 |
158 | Resizes an image to a target width and height by randomly
159 | cropping the image or padding it evenly with zeros.
160 |
161 | Args:
162 | image: 3-D Tensor of shape `[height, width, channels]`.
163 | label: 3-D Tensor of shape `[height, width, 1]`.
164 | crop_height: The new height.
165 | crop_width: The new width.
166 | ignore_label: Label class to be ignored.
167 |
168 | Returns:
169 | Cropped and/or padded image.
170 | If `images` was 3-D, a 3-D float Tensor of shape
171 | `[new_height, new_width, channels]`.
172 | """
173 | label = tf.to_float(label)
174 | image_height = tf.shape(image)[0]
175 | image_width = tf.shape(image)[1]
176 | image_and_label = tf.concat([image, label], axis=2)
177 | image_and_label_pad = tf.image.pad_to_bounding_box(
178 | image_and_label, 0, 0,
179 | tf.maximum(crop_height, image_height),
180 | tf.maximum(crop_width, image_width))
181 | image_and_label_crop = tf.random_crop(
182 | image_and_label_pad, [crop_height, crop_width, 4])  # 3 image channels + 1 label channel (assumes RGB input)
183 |
184 | image_crop = image_and_label_crop[:, :, :3]
185 | label_crop = image_and_label_crop[:, :, 3:]
186 | label_crop = tf.to_int32(label_crop)
187 |
188 | return image_crop, label_crop
189 |
190 |
191 | def random_flip_left_right_image_and_label(image, label):
192 | """Randomly flip an image and label horizontally (left to right).
193 |
194 | Args:
195 | image: A 3-D tensor of shape `[height, width, channels].`
196 | label: A 3-D tensor of shape `[height, width, 1].`
197 |
198 | Returns:
199 | A 3-D tensor of the same type and shape as `image`.
200 | A 3-D tensor of the same type and shape as `label`.
201 | """
202 | uniform_random = tf.random_uniform([], 0, 1.0)
203 | mirror_cond = tf.less(uniform_random, .5)
204 | image = tf.cond(mirror_cond, lambda: tf.reverse(image, [1]), lambda: image)
205 | label = tf.cond(mirror_cond, lambda: tf.reverse(label, [1]), lambda: label)
206 |
207 | return image, label
208 |
209 |
210 | def eval_input_fn(image_filenames, label_filenames=None, batch_size=1):
211 | """An input function for evaluation and inference.
212 |
213 | Args:
214 | image_filenames: The file names for the images to run inference on.
215 | label_filenames: The file names for the ground truth labels.
216 | batch_size: The number of samples per batch. Needs to be 1
217 | when the images have different sizes.
218 |
219 | Returns:
220 | A tuple of images and labels.
221 | """
222 | # Reads an image from a file, decodes it into a dense tensor
223 | def _parse_function(filename, is_label):
224 | if not is_label:
225 | image_filename, label_filename = filename, None
226 | else:
227 | image_filename, label_filename = filename
228 |
229 | image_string = tf.read_file(image_filename)
230 | image = tf.image.decode_image(image_string)
231 | image = tf.to_float(tf.image.convert_image_dtype(image, dtype=tf.uint8))
232 | image.set_shape([None, None, 3])
233 |
234 | image = mean_image_subtraction(image)
235 |
236 | if not is_label:
237 | return image
238 | else:
239 | label_string = tf.read_file(label_filename)
240 | label = tf.image.decode_image(label_string)
241 | label = tf.to_int32(tf.image.convert_image_dtype(label, dtype=tf.uint8))
242 | label.set_shape([None, None, 1])
243 |
244 | return image, label
245 |
246 | if label_filenames is None:
247 | input_filenames = image_filenames
248 | else:
249 | input_filenames = (image_filenames, label_filenames)
250 |
251 | dataset = tf.data.Dataset.from_tensor_slices(input_filenames)
252 | if label_filenames is None:
253 | dataset = dataset.map(lambda x: _parse_function(x, False))
254 | else:
255 | dataset = dataset.map(lambda x, y: _parse_function((x, y), True))
256 | dataset = dataset.prefetch(batch_size)
257 | dataset = dataset.batch(batch_size)
258 | iterator = dataset.make_one_shot_iterator()
259 |
260 | if label_filenames is None:
261 | images = iterator.get_next()
262 | labels = None
263 | else:
264 | images, labels = iterator.get_next()
265 |
266 | return images, labels
267 |
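268 | # Illustrative usage of eval_input_fn (the file names below are hypothetical):
269 | #
270 | #   images, labels = eval_input_fn(["./img/1.png"], ["./lab/1.png"], batch_size=1)
271 | #   with tf.Session() as sess:
272 | #       img_batch, lab_batch = sess.run([images, labels])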
--------------------------------------------------------------------------------
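
Note: decode_labels in /utils/preprocessing.py colourises masks with a per-pixel Python loop, which is slow on 1000x1000 tiles. A behaviour-preserving vectorised sketch using a NumPy palette lookup (ids at or above num_classes render black, as the original leaves those pixels unset on a black canvas):

    import numpy as np

    def decode_labels_fast(mask, num_images=1, num_classes=16):
        # label_colours as defined in preprocessing.py above;
        # assumes num_classes <= len(label_colours).
        palette = np.array(label_colours, dtype=np.uint8)
        ids = mask[:num_images, :, :, 0].astype(np.int64)
        out = palette[np.where(ids < num_classes, ids, 0)]
        out[ids >= num_classes] = 0  # out-of-range ids -> black
        return out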