├── ImageSet
│   ├── Test
│   │   └── img_41.png
│   ├── Train
│   │   ├── 1
│   │   │   └── img1.png
│   │   ├── 2
│   │   │   └── img1.png
│   │   ├── 3
│   │   │   └── img1.png
│   │   ├── 4
│   │   │   └── img1.png
│   │   ├── 5
│   │   │   └── img1.png
│   │   └── 6
│   │       └── img1.png
│   └── Validation
│       ├── 1
│       │   └── img1.png
│       ├── 2
│       │   └── img1.png
│       ├── 3
│       │   └── img1.png
│       ├── 4
│       │   └── img1.png
│       ├── 5
│       │   └── img1.png
│       └── 6
│           └── img1.png
├── README.md
├── Results
│   ├── bad_2_good.png
│   ├── test1.png
│   ├── test10.png
│   ├── test2.png
│   ├── test3.png
│   ├── test4.png
│   ├── test5.png
│   ├── test6.png
│   ├── test7.png
│   ├── test8.png
│   └── test9.png
├── model.npz
├── read_data.py
├── test_image.py
└── train_model.py


/ImageSet/Test/img_41.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Test/img_41.png
--------------------------------------------------------------------------------
/ImageSet/Train/1/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/1/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/2/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/2/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/3/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/3/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/4/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/4/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/5/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/5/img1.png
--------------------------------------------------------------------------------
/ImageSet/Train/6/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Train/6/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/1/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/1/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/2/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/2/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/3/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/3/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/4/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/4/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/5/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/5/img1.png
--------------------------------------------------------------------------------
/ImageSet/Validation/6/img1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/ImageSet/Validation/6/img1.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Cell Detection
A course project that detects cells in images with a simple fully convolutional neural network. The project is built with TensorFlow and TensorLayer.

# Dependencies

+ python 3.5.2
+ numpy 1.11.3
+ scipy 0.18.1
+ pillow 4.1.0
+ tensorflow 1.0.0
+ matplotlib 2.0.0
+ tensorlayer 1.4.1

This demo has been tested only on Ubuntu 16.04.

# Data Organization

The data set consists of 50 full-scale images in which the cell positions have been marked. Training patches are extracted from 30 of the images, validation patches from another 10, and the remaining 10 images are used for testing. The image set is not included in this repository; you can email quqixun@gmail.com to request it.

### Training and Validating Data

Patches are extracted from the training and validation images and sorted into six groups according to where each patch's center lies. Each patch is 35 by 35 by 3.
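
To make "a patch centered at a pixel" concrete, here is a minimal pure-Python sketch that cuts a square window around pixel `(r, c)` and mirrors the image at its borders, which is the effect of the `'symmetric'` padding that `test_image.py` applies before slicing out 35 by 35 patches. The helpers `mirror_index` and `extract_patch` are hypothetical (they are not part of this repository); the image is assumed to be a list of rows, and the patch radius (17 pixels for 35 by 35 patches) is assumed not to exceed the image size.

```python
def mirror_index(i, n):
    """Map an index in [-n, 2n) into [0, n) by symmetric reflection.

    Matches NumPy's 'symmetric' padding, where the edge pixel is repeated.
    """
    if i < 0:
        return -i - 1
    if i >= n:
        return 2 * n - i - 1
    return i


def extract_patch(img, r, c, patch_size=35):
    """Return a patch_size x patch_size window centered at pixel (r, c).

    img is a list of equal-length rows; a pixel can be any value, such as
    an RGB tuple. patch_size is assumed to be odd so the center is defined.
    """
    radius = (patch_size - 1) // 2  # 17 for 35 x 35 patches
    rows, cols = len(img), len(img[0])
    return [[img[mirror_index(rr, rows)][mirror_index(cc, cols)]
             for cc in range(c - radius, c + radius + 1)]
            for rr in range(r - radius, r + radius + 1)]


# Usage: a 3 x 3 patch at the top-left corner of a tiny 3 x 3 "image".
tiny = [['a', 'b', 'c'],
        ['d', 'e', 'f'],
        ['g', 'h', 'i']]
for row in extract_patch(tiny, 0, 0, patch_size=3):
    print(row)
# ['a', 'a', 'b']
# ['a', 'a', 'b']
# ['d', 'd', 'e']
```

Mirroring keeps every patch fully populated even at the image border, so the network always receives a complete 35 by 35 by 3 input.
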
The groups are listed below, each with one sample patch; in each group, the patch center lies on:
+ **Group 1 - the interaction region of cells**: ![alt text](https://github.com/quqixun/CellDetection/blob/master/ImageSet/Train/1/img1.png)
+ **Group 2 - a non-target cell**: ![alt text](https://github.com/quqixun/CellDetection/blob/master/ImageSet/Train/2/img1.png)
+ **Group 3 - the region near a cell's edge**: ![alt text](https://github.com/quqixun/CellDetection/blob/master/ImageSet/Train/3/img1.png)
+ **Group 4 - the gap between cells**: ![alt text](https://github.com/quqixun/CellDetection/blob/master/ImageSet/Train/4/img1.png)
+ **Group 5 - background**: ![alt text](https://github.com/quqixun/CellDetection/blob/master/ImageSet/Train/5/img1.png)
+ **Group 6 - the center of a cell**: ![alt text](https://github.com/quqixun/CellDetection/blob/master/ImageSet/Train/6/img1.png)

### Testing Data

A sample test image is shown below.

![alt text](https://github.com/quqixun/CellDetection/blob/master/ImageSet/Test/img_41.png)

# Code Organization

+ **read_data.py**: Creates TFRecords from the training and validation patches and reads them back as randomly shuffled batches of the given batch size for training.
+ **train_model.py**: Implements a simple end-to-end convolutional neural network and trains it on the training patches. The trained parameters are saved to the file "model.npz".
+ **test_image.py**: Carries out pixel-wise classification of an input test image and keeps the pixels whose probability of being a cell center exceeds 0.7 and is a local maximum.
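
The last step of `test_image.py` (`strict_local_maximum`) keeps a pixel only if its probability exceeds `HIGH_PROB` (0.7) and it is the unique maximum of its 3 by 3 neighborhood; comparing each pixel with the second-largest value in its window (SciPy's `rank_filter` with rank -2) is what makes ties fail. The dependency-free sketch below mimics that test under simplifying assumptions: it skips the Gaussian smoothing that `test_image.py` applies first, and it only examines interior pixels (with SciPy's default `'reflect'` border mode, a border pixel's window contains a copy of the pixel itself, so it cannot pass the strict test anyway). The function name `strict_peaks` is hypothetical.

```python
def strict_peaks(prob, threshold=0.7):
    """Return (row, col) of pixels that exceed threshold and are the
    unique maximum of their 3 x 3 neighborhood.

    prob is a list of equal-length rows of floats. Unlike test_image.py,
    no Gaussian smoothing is applied before the comparison.
    """
    rows, cols = len(prob), len(prob[0])
    peaks = []
    for r in range(1, rows - 1):
        for c in range(1, cols - 1):
            p = prob[r][c]
            if p <= threshold:
                continue
            neighbors = [prob[r + dr][c + dc]
                         for dr in (-1, 0, 1) for dc in (-1, 0, 1)
                         if (dr, dc) != (0, 0)]
            # Strict comparison: a tie with any neighbor disqualifies the
            # pixel, like comparing against the second-largest window value.
            if p > max(neighbors):
                peaks.append((r, c))
    return peaks


# Usage: one clear peak at (1, 1) and a tied pair of 0.8 values in column 3.
prob = [[0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.9, 0.2, 0.8, 0.0],
        [0.0, 0.3, 0.1, 0.8, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0]]
print(strict_peaks(prob))  # [(1, 1)]
```

The tied pair is rejected, which is why `test_image.py` smooths the map first: smoothing tends to break such plateaus so that one pixel per cell survives.
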

# Usage

In a terminal:

+ **Step 1**: run **python read_data.py** to create the TFRecords. Run it twice, editing the folder path and the TFRecords file name in the script each time: once from `ImageSet/Train/` into `TFRecords/train.tfrecords` and once from `ImageSet/Validation/` into `TFRecords/validation.tfrecords`, which are the paths `train_model.py` expects
+ **Step 2**: run **python train_model.py** to train the model and save it to `model.npz`
+ **Step 3**: run **python test_image.py** to test a full-scale image (edit the image path in the script to test another image)

# Result

### A good case:



### A bad case:



In this bad case, several cells are not detected. It was produced by a model trained on **29,818** patches; increasing the number of training patches helps to fix this. Augmenting the data by rotating the patches and modifying them in HSV color space is likely to improve the model further. The improved result shown below comes from a model trained on **321,985** patches. (This image was obtained with a MATLAB implementation; data augmentation is not included in this repository.)



--------------------------------------------------------------------------------
/Results/bad_2_good.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/bad_2_good.png
--------------------------------------------------------------------------------
/Results/test1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test1.png
--------------------------------------------------------------------------------
/Results/test10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test10.png
--------------------------------------------------------------------------------
/Results/test2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test2.png
--------------------------------------------------------------------------------
/Results/test3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test3.png
--------------------------------------------------------------------------------
/Results/test4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test4.png
--------------------------------------------------------------------------------
/Results/test5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test5.png
--------------------------------------------------------------------------------
/Results/test6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test6.png
--------------------------------------------------------------------------------
/Results/test7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test7.png
--------------------------------------------------------------------------------
/Results/test8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test8.png
--------------------------------------------------------------------------------
/Results/test9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/Results/test9.png
--------------------------------------------------------------------------------
/model.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quqixun/CellDetection/53d48dff351e60870d08a2e6fd9417cf302d8759/model.npz
--------------------------------------------------------------------------------
/read_data.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from PIL import Image
import tensorflow as tf


def create_record(path, classes, filename, patch_size):
    # Write every patch found under path/<class>/ into one TFRecords file.
    # The stored label is the class's position in `classes` (0-based), so
    # folder "1" gets label 0, ..., folder "6" gets label 5.
    writer = tf.python_io.TFRecordWriter(filename)
    for index, name in enumerate(classes):
        class_path = os.path.join(path, str(name))
        print(class_path, index)
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            # Force 3 channels so the raw bytes always decode to
            # patch_size x patch_size x 3 (e.g. if a PNG carries alpha).
            img = Image.open(img_path).convert('RGB')
            img = img.resize((patch_size, patch_size))
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[index])),
                'image': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())
    writer.close()


def decode_record(filename_queue, patch_size,
                  channel_num=3):
    # Parse one serialized example and rebuild the image patch.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })

    img = tf.decode_raw(features['image'], tf.uint8)
    img = tf.reshape(img, [patch_size, patch_size, channel_num])
    # Scale pixel values from [0, 255] to [-0.5, 0.5].
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)

    return img, label


def inputs(path, batch_size, num_epochs,
           patch_size, channel_num=3,
           capacity=50000, mad=30000):
    # Return shuffled batches of (images, labels) read from a TFRecords
    # file. A falsy num_epochs cycles through the data indefinitely;
    # mad is passed to shuffle_batch as min_after_dequeue.
    if not num_epochs:
        num_epochs = None

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [path], num_epochs=num_epochs)
        image, label = decode_record(filename_queue,
                                     patch_size,
                                     channel_num)

        images, labels = \
            tf.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                num_threads=4,
                capacity=capacity,
                min_after_dequeue=mad)

    return images, labels


if __name__ == '__main__':
    # Writes the training records. For the validation records, point path
    # at ImageSet/Validation/ and filename at TFRecords/validation.tfrecords.
    path = os.getcwd() + '/ImageSet/Train/'
    classes = np.arange(1, 6 + 1, 1)
    filename = 'TFRecords/train.tfrecords'
    patch_size = 35
    # Make sure the output folder exists before writing.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    create_record(path, classes, filename, patch_size)

    # Sanity check: read the records back for two epochs and print shapes.
    channel_num = 3
    images, labels = inputs(path=filename,
                            batch_size=10,
                            num_epochs=2,
                            patch_size=patch_size,
                            channel_num=channel_num,
                            capacity=500,
                            mad=100)

    # The epoch counter of string_input_producer is a local variable,
    # so local variables must be initialized too.
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    sess = tf.Session()
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        while not coord.should_stop():
            [val, l] = sess.run([images, labels])
            print(val.shape, l)
    except tf.errors.OutOfRangeError:
        print('Out of range.')
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()

--------------------------------------------------------------------------------
/test_image.py:
--------------------------------------------------------------------------------
import numpy as np
from PIL import Image
import tensorflow as tf
import tensorlayer as tl
import train_model as tm
import scipy.ndimage as sn
from matplotlib import pyplot as plt


# Minimum "cell center" probability for a pixel to be kept.
HIGH_PROB = 0.7
# Half the patch size: 17 pixels for 35 x 35 patches.
PATCH_RADIUS = int((tm.PATCH_SIZE - 1) / 2)


def load_image(img_path):
    # Return the raw RGB image and a normalized copy padded by PATCH_RADIUS
    # on every side (mirrored at the borders), so that a full patch can be
    # centered on every pixel of the original image.
    img = Image.open(img_path).convert('RGB')
    img_raw = np.asarray(img, dtype=np.uint8)
    # Same normalization as read_data.decode_record: [0, 255] -> [-0.5, 0.5].
    img_data = img_raw * (1. / 255) - 0.5
    pad_width = ((PATCH_RADIUS, PATCH_RADIUS),
                 (PATCH_RADIUS, PATCH_RADIUS), (0, 0))
    img_pad = np.pad(img_data, pad_width, 'symmetric')

    return img_raw, img_pad


def strict_local_maximum(prob_map):
    # Smooth the probability map, then keep pixels that are the unique
    # maximum of their 3 x 3 neighborhood and whose original probability
    # exceeds HIGH_PROB.
    prob_gau = np.zeros(prob_map.shape)
    sn.gaussian_filter(prob_map, 2,
                       output=prob_gau,
                       mode='mirror')

    # Rank -2 picks the second-largest value in each 3 x 3 window.
    prob_fil = np.zeros(prob_map.shape)
    sn.rank_filter(prob_gau, -2,
                   output=prob_fil,
                   footprint=np.ones([3, 3]))

    temp = np.logical_and(prob_gau > prob_fil,
                          prob_map > HIGH_PROB) * 1.
    idx = np.where(temp > 0)

    return idx


def plot_save_result(img_raw, idx, save_path):
    # Mark each detected center with a red pixel, save the image, and
    # display the detections as a scatter plot.
    img_temp = np.copy(img_raw)
    for i in range(len(idx[0])):
        img_temp[idx[0][i], idx[1][i]] = [255, 0, 0]
    Image.fromarray(img_temp).save(save_path)

    plt.imshow(img_raw)
    plt.scatter(idx[1], idx[0], c='r', s=10)
    plt.axis('off')
    plt.show()

    return


def test_image(img_path,
               model_path='model.npz',
               save_path='test.png'):
    img_raw, img_pad = load_image(img_path)

    rows = img_raw.shape[0]
    cols = img_raw.shape[1]
    # The image is classified one row at a time: each batch holds the
    # patches centered on every pixel of a single row.
    test_set_shape = [cols, tm.PATCH_SIZE,
                      tm.PATCH_SIZE, tm.CHANNEL_NUM]
    print(test_set_shape)

    x = tf.placeholder(tf.float32, test_set_shape)
    net = tm.build_network(x)
    # Each 35 x 35 patch yields a 1 x 1 x CLASS_NUM output.
    y_out = tf.reshape(net.outputs, shape=[cols, tm.CLASS_NUM])
    y_stm = tf.nn.softmax(y_out)
    print(y_stm.shape)

    sess = tf.InteractiveSession()
    load_params = tl.files.load_npz(path='', name=model_path)
    tl.files.assign_params(sess, load_params, net)

    prob_map = np.zeros([rows, cols])
    for r in range(rows):
        print("Processing NO.{} rows.".format(r + 1))
        x_ = np.zeros(test_set_shape)
        for c in range(cols):
            # In padded coordinates, this window is centered on pixel (r, c).
            x_[c] = img_pad[r:r + tm.PATCH_SIZE,
                            c:c + tm.PATCH_SIZE, :]

        prob = y_stm.eval(feed_dict={x: x_})
        # Column 5 is label 5, i.e. Group 6 (the center of a cell).
        temp = np.where(prob[:, 5] > HIGH_PROB)[0]
        prob_map[r, temp] = prob[temp, 5]

    sess.close()

    idx = strict_local_maximum(prob_map)
    plot_save_result(img_raw, idx, save_path)

    return


if __name__ == '__main__':
    test_image('ImageSet/Test/img_41.png',
               'model.npz',
               'test1.png')

--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
import time
import numpy as np
import read_data as rd
import tensorflow as tf
import tensorlayer as tl


NUM_EPOCHS = 10
BATCH_SIZE = 200
LEARNING_RATE = 0.001

CLASS_NUM = 6
PATCH_SIZE = 35
CHANNEL_NUM = 3

LABEL_SET_SHAPE = [BATCH_SIZE, CLASS_NUM]
IMAGE_SET_SHAPE = [BATCH_SIZE, PATCH_SIZE,
                   PATCH_SIZE, CHANNEL_NUM]


def weight(shape):
    # Gaussian initializer with stddev 1 / sqrt(fan_in * CLASS_NUM), where
    # fan_in = filter height * filter width * input channels.
    sd = 1 / np.sqrt(np.prod(shape[0:3]) * CLASS_NUM)
    return tf.random_normal_initializer(stddev=sd)


def conv2d(net, shape, act=tf.nn.relu, name=None):
    # Stride-1 convolution without padding and without a bias term.
    return tl.layers.Conv2dLayer(net,
                                 act=act,
                                 shape=shape,
                                 strides=[1, 1, 1, 1],
                                 padding='VALID',
                                 W_init=weight(shape),
                                 b_init=None,
                                 name=name)


def max_pool(net, name=None):
    # 2 x 2 max pooling with stride 2 (halves height and width).
    return tl.layers.PoolLayer(net,
                               ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1],
                               padding='VALID',
                               pool=tf.nn.max_pool,
                               name=name)


def sub2ind(shape, rows, cols):
    # Row-major linear index of (rows, cols) in an array of this shape.
    return rows * shape[1] + cols


def reshape_labels(labels):
    # Turn a vector of BATCH_SIZE integer labels into one-hot rows.
    lc = np.zeros(LABEL_SET_SHAPE).flatten()
    index = sub2ind(LABEL_SET_SHAPE,
                    np.arange(BATCH_SIZE),
                    np.reshape(labels, [1, BATCH_SIZE]))
    lc[index] = 1

    return np.reshape(lc, LABEL_SET_SHAPE)


def build_network(x):
    # Fully convolutional network. Spatial size for a 35 x 35 input:
    # 35 -> conv1 -> 30 -> pool -> 15 -> conv2 -> 10 -> pool -> 5
    #    -> conv3 -> 2 -> conv4 -> 1, i.e. one score vector per patch.
    net = tl.layers.InputLayer(inputs=x, name='input_layer')
    net = conv2d(net, [6, 6, 3, 30], name='conv1')
    net = max_pool(net, 'maxpool1')
    net = conv2d(net, [6, 6, 30, 50], name='conv2')
    net = max_pool(net, 'maxpool2')
    net = conv2d(net, [4, 4, 50, 500], name='conv3')
    # Linear output layer: raw class scores (logits).
    net = conv2d(net, [2, 2, 500, 6], tf.identity, name='conv4')

    return net


def train_model(train_set_path,
                validation_set_path,
                save_model_path):
    x = tf.placeholder(tf.float32, shape=IMAGE_SET_SHAPE)
    y = tf.placeholder(tf.float32, shape=LABEL_SET_SHAPE)

    net = build_network(x)

    # The network output is [BATCH_SIZE, 1, 1, CLASS_NUM]; flatten to logits.
    y_out = net.outputs
    y_out = tf.reshape(y_out, shape=LABEL_SET_SHAPE)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y,
                                                logits=y_out))

    train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

    y_arg = tf.reshape(tf.argmax(y_out, 1), shape=[BATCH_SIZE])
    correct_prediction = tf.equal(y_arg, tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    tri_img, tri_lbl = rd.inputs(path=train_set_path,
                                 batch_size=BATCH_SIZE,
                                 num_epochs=NUM_EPOCHS,
                                 patch_size=PATCH_SIZE,
                                 channel_num=CHANNEL_NUM)

    val_img, val_lbl = rd.inputs(path=validation_set_path,
                                 batch_size=BATCH_SIZE,
                                 num_epochs=NUM_EPOCHS,
                                 patch_size=PATCH_SIZE,
                                 channel_num=CHANNEL_NUM)

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    sess = tf.InteractiveSession()
    sess.run(init)

    coord = tf.train.Coordinator()
    thread = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        step = 1
        # Runs until the training queue is exhausted after NUM_EPOCHS
        # epochs, which raises OutOfRangeError.
        while not coord.should_stop():
            [tris, tril] = sess.run([tri_img, tri_lbl])
            fd_train = {x: tris, y: reshape_labels(tril)}

            # Report training and validation metrics on the first step
            # and every 10 steps after that.
            if step % 10 == 0 or step == 1:
                [vals, vall] = sess.run([val_img, val_lbl])
                fd_val = {x: vals, y: reshape_labels(vall)}

                print("----------\nStep {}:\n----------".format(step))

                tri_accuracy = accuracy.eval(feed_dict=fd_train)
                print("Training accuracy {0:.6f}".format(tri_accuracy))
                tri_cost = loss.eval(feed_dict=fd_train)
                print("Training cost is {0:.6f}".format(tri_cost))

                val_accuracy = accuracy.eval(feed_dict=fd_val)
                print("Validation accuracy {0:.6f}".format(val_accuracy))
                val_cost = loss.eval(feed_dict=fd_val)
                print("Validation cost is {0:.6f}".format(val_cost))

            sess.run(train_step, feed_dict=fd_train)
            step += 1
            # One-second pause per step (not required for training).
            time.sleep(1)

    except tf.errors.OutOfRangeError:
        print('---------\nTraining has stopped.')
    finally:
        coord.request_stop()

    tl.files.save_npz(net.all_params, save_model_path)
    coord.join(thread)
    sess.close()


if __name__ == '__main__':
    # These paths must match the TFRecords written by read_data.py.
    train_set_path = 'TFRecords/train.tfrecords'
    validation_set_path = 'TFRecords/validation.tfrecords'
    save_model_path = 'model.npz'
    train_model(train_set_path,
                validation_set_path,
                save_model_path)

--------------------------------------------------------------------------------
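
Because `build_network` in `train_model.py` uses only unpadded (`'VALID'`) stride-1 convolutions and 2 by 2, stride-2 max pooling, each 35 by 35 patch shrinks to a 1 by 1 output holding one score per group. That is why both `train_model.py` and `test_image.py` can reshape `net.outputs` to `[batch, CLASS_NUM]`. The short sketch below traces the arithmetic; the layer table is transcribed from `build_network`, while `valid_out` and `trace_sizes` are hypothetical helpers that do not exist in the repository.

```python
def valid_out(size, kernel, stride=1):
    """Output length of an unpadded ('VALID') convolution or pooling."""
    return (size - kernel) // stride + 1


# (layer name, kernel size, stride), transcribed from build_network.
LAYERS = [('conv1', 6, 1), ('maxpool1', 2, 2),
          ('conv2', 6, 1), ('maxpool2', 2, 2),
          ('conv3', 4, 1), ('conv4', 2, 1)]


def trace_sizes(size=35):
    """Return (layer, spatial size) after each layer for a square input."""
    sizes = []
    for name, kernel, stride in LAYERS:
        size = valid_out(size, kernel, stride)
        sizes.append((name, size))
    return sizes


for name, size in trace_sizes(35):
    print(name, size)
# conv1 30
# maxpool1 15
# conv2 10
# maxpool2 5
# conv3 2
# conv4 1
```

If `PATCH_SIZE` or a kernel size is changed, rerunning this trace shows whether the final size is still 1, which the reshape to `[batch, CLASS_NUM]` in both scripts requires.
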