├── ImageSet
├── Test
│ └── img_41.png
├── Train
│ ├── 1
│ │ └── img1.png
│ ├── 2
│ │ └── img1.png
│ ├── 3
│ │ └── img1.png
│ ├── 4
│ │ └── img1.png
│ ├── 5
│ │ └── img1.png
│ └── 6
│ │ └── img1.png
└── Validation
│ ├── 1
│ └── img1.png
│ ├── 2
│ └── img1.png
│ ├── 3
│ └── img1.png
│ ├── 4
│ └── img1.png
│ ├── 5
│ └── img1.png
│ └── 6
│ └── img1.png
├── README.md
├── Results
├── bad_2_good.png
├── test1.png
├── test10.png
├── test2.png
├── test3.png
├── test4.png
├── test5.png
├── test6.png
├── test7.png
├── test8.png
└── test9.png
├── model.npz
├── read_data.py
├── test_image.py
└── train_model.py
/ImageSet/Test/img_41.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Test/img_41.png
--------------------------------------------------------------------------------
/ImageSet/Train/1/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/1/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/2/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/2/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/3/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/3/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/4/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/4/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/5/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/5/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/6/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/6/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/1/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/1/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/2/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/2/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/3/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/3/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/4/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/4/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/5/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/5/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/6/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/6/img1.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cell Detection
2 | A course project that detects cells in images using a simple fully convolutional neural network. The project is implemented with TensorFlow.
3 |
4 | # Dependencies
5 |
6 | + python 3.5.2
7 | + numpy 1.11.3
8 | + scipy 0.18.1
9 | + pillow 4.1.0
10 | + tensorflow 1.0.0
11 | + matplotlib 2.0.0
12 | + tensorlayer 1.4.1
13 |
14 | This demo has only been tested on Ubuntu 16.04.
15 |
16 | # Data Organization
17 |
18 | The dataset consists of 50 full-scale images in which the positions of the cells have been marked. Training patches are extracted from 30 of these images, validation patches from 10, and the remaining 10 images are used for testing. The full image set is not included in this repository; you can email quqixun@gmail.com to request the dataset.
19 |
20 | ### Training and Validating Data
21 |
22 | Six groups of patches are extracted from the training and validation images, according to where each patch's center lies. Each patch is 35 by 35 pixels with 3 channels.
23 | The groups are listed below, each with one sample patch. In each group, the patch center lies at:
24 | + **Group 1 - the interaction region of cells**: 
25 | + **Group 2 - non-goal cell**: 
26 | + **Group 3 - nearby region of cell's edge**: 
27 | + **Group 4 - the gap between cells**: 
28 | + **Group 5 - background**: 
29 | + **Group 6 - the center of cell**: 
30 |
31 | ### Testing Data
32 |
33 | A sample test image is shown below.
34 |
35 | 
36 |
37 | # Code Organization
38 |
39 | + **read_data.py**: Creates TFRecords files of training and validation patches for training the model. Training and validation batches are sampled randomly (shuffled) according to the batch size.
40 | + **train_model.py**: Implements a simple end-to-end convolutional neural network and trains it on the training set. The trained model is saved to the file "model.npz".
41 | + **test_image.py**: Performs pixel-wise classification of the input test image and keeps the pixels with the highest probability of being a cell center.
42 |
43 | # Usage
44 |
45 | In terminal,
46 |
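47 | + **Step 1**: run **python read_data.py** to create the TFRecords (edit the image folder path and the TFRecords file name in the script; train_model.py expects both TFRecords/train.tfrecords and TFRecords/validation.tfrecords)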
48 | + **Step 2**: run **python train_model.py** to train and save model
49 | + **Step 3**: run **python test_image.py** to test full-scale images
50 |
51 | # Result
52 |
53 | ### A good case:
54 |
55 |
56 |
57 | ### A bad case:
58 |
59 |
60 |
61 | Here is a bad case, in which several cells have not been detected. Increasing the number of training patches can solve this problem. The bad result shown above was produced by a model trained on **29,818** patches. If the data is augmented by rotation and by modifying the HSV color space, the model is likely to perform better. The better result shown below was detected by a model trained on **321,985** training patches. (This image was obtained from the MATLAB solution; data augmentation is not included in this repository.)
62 |
63 |
64 |
--------------------------------------------------------------------------------
/Results/bad_2_good.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/bad_2_good.png
--------------------------------------------------------------------------------
/Results/test1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test1.png
--------------------------------------------------------------------------------
/Results/test10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test10.png
--------------------------------------------------------------------------------
/Results/test2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test2.png
--------------------------------------------------------------------------------
/Results/test3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test3.png
--------------------------------------------------------------------------------
/Results/test4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test4.png
--------------------------------------------------------------------------------
/Results/test5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test5.png
--------------------------------------------------------------------------------
/Results/test6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test6.png
--------------------------------------------------------------------------------
/Results/test7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test7.png
--------------------------------------------------------------------------------
/Results/test8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test8.png
--------------------------------------------------------------------------------
/Results/test9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test9.png
--------------------------------------------------------------------------------
/model.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/model.npz
--------------------------------------------------------------------------------
/read_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import tensorflow as tf
5 |
6 |
def create_record(path, classes, filename, patch_size):
    """Write every image found under ``path/<class>/`` into a TFRecords file.

    Each image is converted to RGB, resized to ``patch_size`` x ``patch_size``
    and stored as raw uint8 bytes under the key ``'image'``. Its label, stored
    under ``'label'``, is the 0-based position of its class in ``classes``
    (not the class name itself).

    Args:
        path: Root directory holding one sub-directory per class.
        classes: Iterable of class sub-directory names (passed through str()).
        filename: Output TFRecords file path.
        patch_size: Side length, in pixels, every image is resized to.
    """
    writer = tf.python_io.TFRecordWriter(filename)
    try:
        for index, name in enumerate(classes):
            # os.path.join works whether or not ``path`` ends with a separator.
            class_path = os.path.join(path, str(name))
            print(class_path, index)
            # Sorted so the record order is reproducible across file systems.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                with Image.open(img_path) as src:
                    # decode_record() reshapes to 3 channels; RGBA, grayscale
                    # or palette PNGs would otherwise store a different number
                    # of bytes per pixel.
                    img = src.convert('RGB').resize((patch_size, patch_size))
                img_raw = img.tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[index])),
                    'image': tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
                writer.write(example.SerializeToString())
    finally:
        # Flush and close the record file even if an image fails to load.
        writer.close()
25 |
26 |
def decode_record(filename_queue, patch_size,
                  channel_num=3):
    """Read one serialized example from ``filename_queue`` and decode it.

    Args:
        filename_queue: Queue of TFRecords file names to read from.
        patch_size: Height and width the stored image bytes are reshaped to.
        channel_num: Number of channels the stored image bytes are reshaped to.

    Returns:
        ``(img, label)``: ``img`` is a float32 tensor of shape
        ``[patch_size, patch_size, channel_num]`` mapped from [0, 255] to
        [-0.5, 0.5]; ``label`` is an int32 scalar tensor.
    """
    schema = {
        'label': tf.FixedLenFeature([], tf.int64),
        'image': tf.FixedLenFeature([], tf.string),
    }
    _, record = tf.TFRecordReader().read(filename_queue)
    parsed = tf.parse_single_example(record, features=schema)

    pixels = tf.reshape(tf.decode_raw(parsed['image'], tf.uint8),
                        [patch_size, patch_size, channel_num])
    img = tf.cast(pixels, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(parsed['label'], tf.int32)
    return img, label
44 |
45 |
def inputs(path, batch_size, num_epochs,
           patch_size, channel_num=3,
           capacity=50000, mad=30000):
    """Build a shuffled (images, labels) batch pipeline over one TFRecords file.

    Args:
        path: TFRecords file to read.
        batch_size: Number of examples per batch.
        num_epochs: Passes over the file before the producer is exhausted;
            any falsy value (0, None) is passed on as None, i.e. no limit.
        patch_size: Forwarded to decode_record().
        channel_num: Forwarded to decode_record().
        capacity: Maximum number of elements held by the shuffle queue.
        mad: ``min_after_dequeue`` for tf.train.shuffle_batch.

    Returns:
        ``(images, labels)`` batched tensors.
    """
    epochs = num_epochs or None

    with tf.name_scope('input'):
        queue = tf.train.string_input_producer([path], num_epochs=epochs)
        example, target = decode_record(queue, patch_size, channel_num)

        images, labels = tf.train.shuffle_batch([example, target],
                                                batch_size=batch_size,
                                                num_threads=4,
                                                capacity=capacity,
                                                min_after_dequeue=mad)
        return images, labels
68 |
69 |
if __name__ == '__main__':
    # Demo: build the training TFRecords file, then read back shuffled
    # batches to check that the records decode correctly.
    path = os.getcwd() + '/ImageSet/Train/'
    classes = np.arange(1, 6 + 1, 1)
    filename = 'TFRecords/train.tfrecords'
    patch_size = 35

    # TFRecords/ is not part of the repository tree; create it so the
    # writer has a directory to write into.
    record_dir = os.path.dirname(filename)
    if record_dir:
        os.makedirs(record_dir, exist_ok=True)
    create_record(path, classes, filename, patch_size)

    channel_num = 3
    images, labels = inputs(path=filename,
                            batch_size=10,
                            num_epochs=2,
                            patch_size=patch_size,
                            channel_num=channel_num,
                            capacity=500,
                            mad=100)

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    sess = tf.Session()
    try:
        sess.run(init)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            while not coord.should_stop():
                [val, lbl] = sess.run([images, labels])
                print(val.shape, lbl)
        except tf.errors.OutOfRangeError:
            print('Out of range.')
        finally:
            coord.request_stop()

        coord.join(threads)
    finally:
        # Release the session even if initialisation or reading fails.
        sess.close()
106 |
--------------------------------------------------------------------------------
/test_image.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import tensorflow as tf
4 | import tensorlayer as tl
5 | import train_model as tm
6 | import scipy.ndimage as sn
7 | from matplotlib import pyplot as plt
8 |
9 |
# Minimum softmax probability for a pixel to count as a cell-centre candidate.
HIGH_PROB = 0.7
# Half of the (odd) patch side; images are padded by this on every side.
PATCH_RADIUS = (tm.PATCH_SIZE - 1) // 2
12 |
13 |
def load_image(img_path):
    """Load an image and prepare it for per-pixel patch classification.

    Args:
        img_path: Path of the image to load.

    Returns:
        ``(img_raw, img_pad)``: ``img_raw`` is the image as an (H, W, 3)
        uint8 RGB array. ``img_pad`` is the same image scaled to [-0.5, 0.5]
        (the same ``* (1. / 255) - 0.5`` scaling decode_record applies to
        training patches). It is symmetrically padded by PATCH_RADIUS pixels
        on both spatial axes, so a PATCH_SIZE window can be centred on every
        original pixel.
    """
    with Image.open(img_path) as img:
        # The network's first layer takes 3 channels; RGBA, grayscale or
        # palette PNGs would otherwise yield a differently shaped array.
        img_raw = np.asarray(img.convert('RGB'), dtype=np.uint8)
    img_data = img_raw * (1. / 255) - 0.5
    pad_width = ((PATCH_RADIUS, PATCH_RADIUS),
                 (PATCH_RADIUS, PATCH_RADIUS),
                 (0, 0))
    # np.pad is the public API; the np.lib.pad alias is not available in
    # newer NumPy releases.
    img_pad = np.pad(img_data, pad_width, mode='symmetric')

    return img_raw, img_pad
23 |
24 |
def strict_local_maximum(prob_map, sigma=2, threshold=None):
    """Return the coordinates of strict local maxima of a 2-D probability map.

    The map is first smoothed with a Gaussian filter (``sigma``, mirror
    boundary). A pixel is kept when both of these hold:

    * its smoothed value is strictly greater than the second-largest smoothed
      value in its 3x3 neighbourhood, i.e. it is the unique maximum there
      (ties/plateaus are rejected);
    * its raw probability exceeds ``threshold``.

    Args:
        prob_map: 2-D array of per-pixel probabilities.
        sigma: Standard deviation of the Gaussian smoothing (default 2).
        threshold: Minimum raw probability; ``None`` means HIGH_PROB.

    Returns:
        ``(rows, cols)`` tuple of index arrays, as produced by ``np.nonzero``.
    """
    if threshold is None:
        threshold = HIGH_PROB

    prob_gau = np.zeros(prob_map.shape)
    sn.gaussian_filter(prob_map, sigma,
                       output=prob_gau,
                       mode='mirror')

    # rank -2 selects the second-largest value under the 3x3 footprint.
    prob_fil = np.zeros(prob_map.shape)
    sn.rank_filter(prob_gau, -2,
                   output=prob_fil,
                   footprint=np.ones([3, 3]))

    mask = np.logical_and(prob_gau > prob_fil, prob_map > threshold)
    return np.nonzero(mask)
41 |
42 |
def plot_save_result(img_raw, idx, save_path):
    """Mark detected points in red, save the marked image and display a plot.

    Args:
        img_raw: (H, W, 3) uint8 image; it is copied, not modified.
        idx: ``(rows, cols)`` index arrays, e.g. from strict_local_maximum().
        save_path: File the pixel-marked copy is saved to.
    """
    rows, cols = idx[0], idx[1]

    marked = np.copy(img_raw)
    # Fancy indexing paints every (row, col) pair in a single assignment.
    marked[rows, cols] = [255, 0, 0]
    Image.fromarray(marked).save(save_path)

    plt.imshow(img_raw)
    plt.scatter(cols, rows, c='r', s=10)
    plt.axis('off')
    plt.show()
55 |
56 |
def test_image(img_path,
               model_path='model.npz',
               save_path='test.png'):
    """Detect cell centres in a full-size image, then save and show the result.

    Each pixel is classified from the PATCH_SIZE x PATCH_SIZE patch centred on
    it in the padded image returned by load_image(). Patches are fed one image
    row at a time, so the graph is built for a batch of ``cols`` patches. The
    softmax probability of class index 5 (group 6, the centre of a cell) is
    kept wherever it exceeds HIGH_PROB. The strict local maxima of that map
    are reported as detections.

    Args:
        img_path: Image to analyse.
        model_path: .npz parameter file written by train_model.py.
        save_path: Where plot_save_result() writes the marked image.
    """
    # Label index of the "centre of cell" group (folder 6 -> index 5).
    cell_center_class = 5

    img_raw, img_pad = load_image(img_path)

    rows = img_raw.shape[0]
    cols = img_raw.shape[1]
    test_set_shape = [cols, tm.PATCH_SIZE,
                      tm.PATCH_SIZE, tm.CHANNEL_NUM]
    print(test_set_shape)

    x = tf.placeholder(tf.float32, test_set_shape)
    net = tm.build_network(x)
    y_out = tf.reshape(net.outputs, shape=[cols, tm.CLASS_NUM])
    y_stm = tf.nn.softmax(y_out)
    print(y_stm.shape)

    prob_map = np.zeros([rows, cols])
    sess = tf.InteractiveSession()
    try:
        load_params = tl.files.load_npz(path='', name=model_path)
        tl.files.assign_params(sess, load_params, net)

        for r in range(rows):
            print("Processing NO.{} rows.".format(r + 1))
            # Row r of the image is row r + PATCH_RADIUS of img_pad, so each
            # slice below is the patch centred on pixel (r, c).
            x_ = np.stack([img_pad[r:r + tm.PATCH_SIZE,
                                   c:c + tm.PATCH_SIZE, :]
                           for c in range(cols)])

            prob = sess.run(y_stm, feed_dict={x: x_})
            center_prob = prob[:, cell_center_class]
            hits = np.where(center_prob > HIGH_PROB)[0]
            prob_map[r, hits] = center_prob[hits]
    finally:
        # Release the session even if loading the model or inference fails.
        sess.close()

    idx = strict_local_maximum(prob_map)
    plot_save_result(img_raw, idx, save_path)
96 |
97 |
if __name__ == '__main__':
    # Run detection on the sample test image shipped with the repository.
    test_image(img_path='ImageSet/Test/img_41.png',
               model_path='model.npz',
               save_path='test1.png')
102 |
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import read_data as rd
4 | import tensorflow as tf
5 | import tensorlayer as tl
6 |
7 |
# Training hyper-parameters.
NUM_EPOCHS = 10
BATCH_SIZE = 200
LEARNING_RATE = 0.001

# Problem geometry: six patch groups, 35 x 35 patches with 3 channels.
CLASS_NUM = 6
PATCH_SIZE = 35
CHANNEL_NUM = 3

# Placeholder shapes: one-hot label batches and NHWC image batches.
LABEL_SET_SHAPE = [BATCH_SIZE, CLASS_NUM]
IMAGE_SET_SHAPE = [BATCH_SIZE, PATCH_SIZE, PATCH_SIZE, CHANNEL_NUM]
19 |
20 |
def weight(shape):
    """Return a zero-mean normal initializer for a convolution kernel.

    ``shape`` is the kernel shape as passed by conv2d in this file,
    ``[height, width, in_channels, out_channels]``. The standard deviation is
    ``1 / sqrt(height * width * in_channels * CLASS_NUM)``.
    """
    fan_in = np.prod(shape[0:3])
    stddev = 1 / np.sqrt(fan_in * CLASS_NUM)
    return tf.random_normal_initializer(stddev=stddev)
24 |
25 |
def conv2d(net, shape, act=tf.nn.relu, name=None):
    """Stack a stride-1, VALID-padded convolution layer on top of ``net``.

    The kernel is initialised with weight(shape) and no bias initializer is
    given (``b_init=None``); ``act`` is the activation (ReLU by default).
    """
    layer_kwargs = dict(act=act,
                        shape=shape,
                        strides=[1, 1, 1, 1],
                        padding='VALID',
                        W_init=weight(shape),
                        b_init=None,
                        name=name)
    return tl.layers.Conv2dLayer(net, **layer_kwargs)
35 |
36 |
def max_pool(net, name=None):
    """Stack a 2x2, stride-2, VALID-padded max-pooling layer on ``net``."""
    window = [1, 2, 2, 1]
    return tl.layers.PoolLayer(net,
                               ksize=window,
                               strides=list(window),
                               padding='VALID',
                               pool=tf.nn.max_pool,
                               name=name)
44 |
45 |
def sub2ind(shape, rows, cols):
    """Convert 2-D (row, col) subscripts to flat row-major indices.

    A 0-based, row-major analogue of MATLAB's ``sub2ind``. Only ``shape[1]``
    (the number of columns) is used. NumPy array arguments are handled
    elementwise via broadcasting.
    """
    n_cols = shape[1]
    offset = rows * n_cols
    return offset + cols
48 |
49 |
def reshape_labels(labels, class_num=None):
    """One-hot encode a batch of integer class labels.

    Args:
        labels: Array-like of integer labels in ``[0, class_num)``; any shape,
            flattened in C order (the training loop passes a 1-D batch).
        class_num: Number of classes; ``None`` means CLASS_NUM.

    Returns:
        float64 array of shape ``(number_of_labels, class_num)`` with a single
        1 per row. For a batch of BATCH_SIZE labels and the default
        ``class_num`` this is LABEL_SET_SHAPE.

    Raises:
        ValueError: If any label lies outside ``[0, class_num)``. Such labels
            previously set a bit in a neighbouring row without any error.
    """
    if class_num is None:
        class_num = CLASS_NUM

    flat = np.asarray(labels, dtype=np.int64).reshape(-1)
    if flat.size and (flat.min() < 0 or flat.max() >= class_num):
        raise ValueError(
            'labels must lie in [0, {}), got min {} and max {}'.format(
                class_num, flat.min(), flat.max()))

    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(class_num)[flat]
58 |
59 |
def build_network(x):
    """Build the fully convolutional patch classifier on top of input ``x``.

    All layers use VALID padding. A PATCH_SIZE=35 input shrinks
    35 -> 30 (conv1) -> 15 (pool) -> 10 (conv2) -> 5 (pool) -> 2 (conv3)
    -> 1 (conv4), so each patch yields a 1x1xCLASS_NUM map of logits.
    Callers reshape this map to ``[batch, CLASS_NUM]``.

    Args:
        x: Input tensor of image patches with CHANNEL_NUM channels.

    Returns:
        The final TensorLayer layer; ``.outputs`` holds the logits.
    """
    net = tl.layers.InputLayer(inputs=x, name='input_layer')
    net = conv2d(net, [6, 6, CHANNEL_NUM, 30], name='conv1')
    net = max_pool(net, 'maxpool1')
    net = conv2d(net, [6, 6, 30, 50], name='conv2')
    net = max_pool(net, 'maxpool2')
    net = conv2d(net, [4, 4, 50, 500], name='conv3')
    # Identity activation: raw logits, softmax is applied by the loss/caller.
    net = conv2d(net, [2, 2, 500, CLASS_NUM], tf.identity, name='conv4')

    return net
70 |
71 |
def _print_metrics(sess, step, accuracy, loss, fd_train, fd_val):
    """Print accuracy and loss for one training and one validation batch."""
    print("----------\nStep {}:\n----------".format(step))

    # Fetch both metrics in one run so each batch needs one forward pass.
    tri_accuracy, tri_cost = sess.run([accuracy, loss], feed_dict=fd_train)
    print("Training accuracy {0:.6f}".format(tri_accuracy))
    print("Training cost is {0:.6f}".format(tri_cost))

    val_accuracy, val_cost = sess.run([accuracy, loss], feed_dict=fd_val)
    print("Validation accuracy {0:.6f}".format(val_accuracy))
    print("Validation cost is {0:.6f}".format(val_cost))


def train_model(train_set_path,
                validation_set_path,
                save_model_path):
    """Train the patch classifier from TFRecords and save its parameters.

    Training runs until the training input pipeline raises OutOfRangeError,
    which happens after NUM_EPOCHS passes over the training file. On step 1
    and every 10th step, accuracy and loss are printed for the current
    training batch and for a freshly drawn validation batch. When the loop
    ends, ``net.all_params`` is saved with tl.files.save_npz, also when the
    loop is interrupted by an exception.

    Args:
        train_set_path: TFRecords file with training patches.
        validation_set_path: TFRecords file with validation patches.
        save_model_path: Destination .npz file for the network parameters.
    """
    x = tf.placeholder(tf.float32, shape=IMAGE_SET_SHAPE)
    y = tf.placeholder(tf.float32, shape=LABEL_SET_SHAPE)

    net = build_network(x)

    # The network emits a 1x1xCLASS_NUM map per patch; flatten to logits.
    y_out = tf.reshape(net.outputs, shape=LABEL_SET_SHAPE)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y,
                                                logits=y_out))

    train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

    y_arg = tf.reshape(tf.argmax(y_out, 1), shape=[BATCH_SIZE])
    correct_prediction = tf.equal(y_arg, tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    tri_img, tri_lbl = rd.inputs(path=train_set_path,
                                 batch_size=BATCH_SIZE,
                                 num_epochs=NUM_EPOCHS,
                                 patch_size=PATCH_SIZE,
                                 channel_num=CHANNEL_NUM)

    # Validation batches are only sampled for monitoring. With no epoch limit
    # (None) this queue can never end training before the training queue does.
    val_img, val_lbl = rd.inputs(path=validation_set_path,
                                 batch_size=BATCH_SIZE,
                                 num_epochs=None,
                                 patch_size=PATCH_SIZE,
                                 channel_num=CHANNEL_NUM)

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    sess = tf.InteractiveSession()
    try:
        sess.run(init)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 1
            while not coord.should_stop():
                tris, tril = sess.run([tri_img, tri_lbl])
                fd_train = {x: tris, y: reshape_labels(tril)}

                if step % 10 == 0 or step == 1:
                    vals, vall = sess.run([val_img, val_lbl])
                    fd_val = {x: vals, y: reshape_labels(vall)}
                    _print_metrics(sess, step, accuracy, loss,
                                   fd_train, fd_val)

                sess.run(train_step, feed_dict=fd_train)
                step += 1
        except tf.errors.OutOfRangeError:
            print('---------\nTraining has stopped.')
        finally:
            coord.request_stop()
            # Persist what has been learned, even if the loop was interrupted.
            tl.files.save_npz(net.all_params, save_model_path)

        coord.join(threads)
    finally:
        sess.close()
148 |
149 |
if __name__ == '__main__':
    # Inputs are the TFRecords files written by read_data.py; the output is
    # the parameter file test_image.py loads by default.
    train_set_path = 'TFRecords/train.tfrecords'
    validation_set_path = 'TFRecords/validation.tfrecords'
    save_model_path = 'model.npz'
    train_model(train_set_path=train_set_path,
                validation_set_path=validation_set_path,
                save_model_path=save_model_path)
157 |
--------------------------------------------------------------------------------