├── results
│   ├── result_1.png
│   └── results_2.png
├── __pycache__
│   └── config.cpython-36.pyc
├── misc
│   ├── __pycache__
│   │   ├── utils.cpython-36.pyc
│   │   └── viz_utils.cpython-36.pyc
│   ├── utils.py
│   ├── viz_utils.py
│   └── patch_extractor.py
├── loader
│   ├── __pycache__
│   │   ├── augs.cpython-36.pyc
│   │   └── loader.cpython-36.pyc
│   ├── loader.py
│   └── augs.py
├── model
│   ├── __pycache__
│   │   ├── sonnet.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── utils.py
│   └── sonnet.py
├── opt
│   ├── __pycache__
│   │   └── hyperconfig.cpython-36.pyc
│   └── hyperconfig.py
├── metrics
│   ├── __pycache__
│   │   └── stats_utils.cpython-36.pyc
│   └── stats_utils.py
├── postproc
│   ├── __pycache__
│   │   └── post_sonnet.cpython-36.pyc
│   └── post_sonnet.py
├── requirements.txt
├── extract_patches.py
├── process.py
├── README.md
├── config.py
├── infer.py
├── train.py
└── compute_stats.py
/results/result_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QuIIL/Sonnet/HEAD/results/result_1.png
--------------------------------------------------------------------------------
/results/results_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QuIIL/Sonnet/HEAD/results/results_2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit_image==0.15.0
2 | pandas==1.1.0
3 | matplotlib==3.3.4
4 | tensorflow==1.12.0
5 | tensorpack==0.9.0.1
6 | opencv_python_headless==3.4.8.29
7 | numpy==1.19.2
8 | scipy==1.1.0
9 | scikit_learn==0.24.2
10 |
--------------------------------------------------------------------------------
/opt/hyperconfig.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | sonnet = {
4 | 'train_input_shape' : [270, 270],
5 | 'train_mask_shape' : [76, 76],
6 | 'infer_input_shape' : [270, 270],
7 | 'infer_mask_shape' : [76, 76],
8 |
9 | 'training_phase' : [
10 | {
11 | 'nr_epochs': 50,
12 | 'manual_parameters' : {
13 | # tuple(initial value, schedule)
14 | 'learning_rate': (1.0e-4, [('25', 1.0e-5)]),
15 | },
16 | 'pretrained_path' : './ImageNet_pretrained_EfficientB0.npz',
17 | 'train_batch_size' : 8,
18 | 'infer_batch_size' : 16,
19 |
20 | 'model_flags' : {
21 | 'freeze_en' : True
22 | }
23 | },
24 |
25 | {
26 | 'nr_epochs': 25,
27 | 'manual_parameters' : {
28 | # tuple(initial value, schedule)
29 | 'learning_rate': (1.0e-4, [('25', 1.0e-5)]),
30 | },
31 | # path to load, -1 to auto load checkpoint from previous phase,
32 | # None to start from scratch
33 | 'pretrained_path' : -1,
34 | 'train_batch_size' : 4,
35 | 'infer_batch_size' : 8,
36 |
37 | 'model_flags' : {
38 | 'freeze_en' : False
39 | }
40 | },
41 |
42 | {
43 | 'nr_epochs': 25,
44 | 'manual_parameters' : {
45 | # tuple(initial value, schedule)
46 | 'learning_rate': (1.0e-5, [('25', 1.0e-5)]),
47 | },
48 | # path to load, -1 to auto load checkpoint from previous phase,
49 | # None to start from scratch
50 | 'pretrained_path' : -1,
51 | 'train_batch_size' : 4,
52 | 'infer_batch_size' : 8,
53 |
54 | 'model_flags' : {
55 | 'freeze_en' : False
56 | }
57 | }
58 | ],
59 |
60 | 'loss_term' : {'bce' : 1, 'dice' : 1},
61 |
62 | 'train_optimizer' : tf.train.AdamOptimizer,
63 |
64 | 'inf_auto_metric' : 'valid_dice',
65 | 'inf_auto_comparator' : '>',
66 | 'inf_batch_size' : 4,
67 | }
--------------------------------------------------------------------------------
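
The `manual_parameters` entries above pair an initial value with an epoch-indexed schedule, i.e. `(initial_value, [(epoch, new_value), ...])`. As a minimal sketch (not part of the repository, assuming the schedule is applied per epoch as with tensorpack's `ScheduledHyperParamSetter`), the learning rate in effect at a given epoch can be resolved like this:

```
def resolve_lr(schedule_tuple, epoch):
    """Resolve the learning rate at `epoch` for an
    (initial_value, [(epoch_str, value), ...]) tuple."""
    lr, schedule = schedule_tuple
    for epoch_mark, value in schedule:
        if epoch >= int(epoch_mark):
            lr = value
    return lr

# phase 1 above: start at 1e-4, drop to 1e-5 from epoch 25 onward
assert resolve_lr((1.0e-4, [('25', 1.0e-5)]), 10) == 1.0e-4
assert resolve_lr((1.0e-4, [('25', 1.0e-5)]), 30) == 1.0e-5
```
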
/misc/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import glob
3 | import os
4 | import shutil
5 |
6 | import cv2
7 | import numpy as np
8 |
9 |
10 | ####
11 | def normalize(mask, dtype=np.uint8):
12 | return (255 * mask / np.amax(mask)).astype(dtype)
13 |
14 | ####
15 | def bounding_box(img):
16 | rows = np.any(img, axis=1)
17 | cols = np.any(img, axis=0)
18 | rmin, rmax = np.where(rows)[0][[0, -1]]
19 | cmin, cmax = np.where(cols)[0][[0, -1]]
20 | # due to python indexing, need to add 1 to max
21 | # else accessing will be 1px in the box, not out
22 | rmax += 1
23 | cmax += 1
24 | return [rmin, rmax, cmin, cmax]
25 |
26 | ####
27 | def cropping_center(x, crop_shape, batch=False):
28 | orig_shape = x.shape
29 | if not batch:
30 | h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
31 | w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
32 | x = x[h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]]
33 | else:
34 | h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
35 | w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
36 | x = x[:,h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]]
37 | return x
38 |
39 | ####
40 | def rm_n_mkdir(dir_path):
41 | if (os.path.isdir(dir_path)):
42 | shutil.rmtree(dir_path)
43 | os.makedirs(dir_path)
44 |
45 | ####
46 | def get_files(data_dir_list, data_ext):
47 | """
48 | Given a list of directories containing data with extention 'data_ext',
49 | generate a list of paths for all files within these directories
50 | """
51 |
52 | data_files = []
53 | for sub_dir in data_dir_list:
54 | files_list = glob.glob(sub_dir + '/*'+ data_ext)
55 | files_list.sort() # ensure same order
56 | data_files.extend(files_list)
57 |
58 | return data_files
59 |
60 | ####
61 | def get_inst_centroid(inst_map):
62 | inst_centroid_list = []
63 | inst_id_list = list(np.unique(inst_map))
64 | for inst_id in inst_id_list[1:]: # avoid 0 i.e background
65 | mask = np.array(inst_map == inst_id, np.uint8)
66 | inst_moment = cv2.moments(mask)
67 | inst_centroid = [(inst_moment["m10"] / inst_moment["m00"]),
68 | (inst_moment["m01"] / inst_moment["m00"])]
69 | inst_centroid_list.append(inst_centroid)
70 | return np.array(inst_centroid_list)
71 |
72 | ###
73 | def change_keys_dict(a_dict):
74 | for i in a_dict.copy():
75 | if 'bias' in i:
76 | new_key = i.replace('bias', 'b')
77 |             a_dict[new_key] = a_dict.pop(i)
78 | np.savez('new_dict', **a_dict)
--------------------------------------------------------------------------------
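
A quick sanity check of the helpers above (illustrative only; assumes the repository root is on `sys.path`):

```
import numpy as np
from misc.utils import bounding_box, cropping_center

mask = np.zeros((8, 8), np.uint8)
mask[2:5, 3:6] = 1
print(bounding_box(mask))            # [2, 5, 3, 6] -- the maxima are exclusive

img = np.arange(36).reshape(6, 6)
print(cropping_center(img, (2, 2)))  # the central 2x2 block
```
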
/misc/viz_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import math
4 | import random
5 | import colorsys
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from matplotlib import cm
9 |
10 | from .utils import bounding_box
11 |
12 | ####
13 | def random_colors(N, bright=True):
14 | """
15 | Generate random colors.
16 | To get visually distinct colors, generate them in HSV space then
17 | convert to RGB.
18 | """
19 | brightness = 1.0 if bright else 0.7
20 | hsv = [(i / N, 1, brightness) for i in range(N)]
21 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
22 | random.shuffle(colors)
23 | return colors
24 |
25 | ####
26 | def visualize_instances(mask, canvas=None, color=None):
27 | """
28 | Args:
29 |         mask: 2D array (HW) of instance IDs
30 |     Return:
31 |         Image with the instances overlaid
32 | """
33 |
34 | canvas = np.full(mask.shape + (3,), 200, dtype=np.uint8) \
35 | if canvas is None else np.copy(canvas)
36 |
37 | insts_list = list(np.unique(mask))
38 | insts_list.remove(0) # remove background
39 |
40 | inst_colors = random_colors(len(insts_list))
41 | inst_colors = np.array(inst_colors) * 255
42 |
43 | for idx, inst_id in enumerate(insts_list):
44 | inst_color = color[idx] if color is not None else inst_colors[idx]
45 | inst_map = np.array(mask == inst_id, np.uint8)
46 | y1, y2, x1, x2 = bounding_box(inst_map)
47 | y1 = y1 - 2 if y1 - 2 >= 0 else y1
48 | x1 = x1 - 2 if x1 - 2 >= 0 else x1
49 | x2 = x2 + 2 if x2 + 2 <= mask.shape[1] - 1 else x2
50 | y2 = y2 + 2 if y2 + 2 <= mask.shape[0] - 1 else y2
51 | inst_map_crop = inst_map[y1:y2, x1:x2]
52 | inst_canvas_crop = canvas[y1:y2, x1:x2]
53 |         contours = cv2.findContours(inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]  # [-2] works for both cv2 3.x and 4.x
54 |         cv2.drawContours(inst_canvas_crop, contours, -1, inst_color, 2)
55 | canvas[y1:y2, x1:x2] = inst_canvas_crop
56 | return canvas
57 |
58 | ####
59 | def gen_figure(imgs_list, titles, fig_inch, shape=None,
60 | share_ax='all', show=False, colormap=plt.get_cmap('jet')):
61 |
62 | num_img = len(imgs_list)
63 | if shape is None:
64 | ncols = math.ceil(math.sqrt(num_img))
65 | nrows = math.ceil(num_img / ncols)
66 | else:
67 | nrows, ncols = shape
68 |
69 | # generate figure
70 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
71 | sharex=share_ax, sharey=share_ax)
72 | axes = [axes] if nrows == 1 else axes
73 |
74 | # not very elegant
75 | idx = 0
76 | for ax in axes:
77 | for cell in ax:
78 | cell.set_title(titles[idx])
79 | cell.imshow(imgs_list[idx], cmap=colormap)
80 |             cell.tick_params(axis='both',
81 |                             which='both',
82 |                             bottom=False,
83 |                             top=False,
84 |                             labelbottom=False,
85 |                             right=False,
86 |                             left=False,
87 |                             labelleft=False)
88 | idx += 1
89 | if idx == len(titles):
90 | break
91 | if idx == len(titles):
92 | break
93 |
94 | fig.tight_layout()
95 | return fig
96 | ####
97 |
--------------------------------------------------------------------------------
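
A minimal usage sketch for `visualize_instances` (illustrative; the toy instance map below is made up):

```
import numpy as np
from misc.viz_utils import visualize_instances

# toy instance map: two labelled squares on a zero background
inst_map = np.zeros((64, 64), np.int32)
inst_map[10:20, 10:20] = 1
inst_map[30:45, 30:45] = 2
overlay = visualize_instances(inst_map)  # grey canvas with one coloured contour per instance
```
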
/extract_patches.py:
--------------------------------------------------------------------------------
1 |
2 | import glob
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | import scipy.io as sio
8 | from scipy.ndimage import measurements
9 |
10 |
11 | from misc.patch_extractor import PatchExtractor
12 | from misc.utils import rm_n_mkdir
13 |
14 | from config import Config
15 |
16 | ###########################################################################
17 | if __name__ == '__main__':
18 |
19 | cfg = Config()
20 |
21 | extract_type = 'mirror' # 'valid' for fcn8 segnet etc.
22 | # 'mirror' for u-net etc.
23 |     # check the patch_extractor.py 'main' to see the difference
24 |
25 |     # original size (win size) - input size - output size (step size)
26 | # 540x540 - 270x270 - 76x76 sonnet
27 | step_size = [76, 76] # should match self.train_mask_shape (config.py)
28 |     win_size = [540, 540] # should be at least twice as large as
29 | # self.train_base_shape (config.py) to reduce
30 | # the padding effect during augmentation
31 |
32 | xtractor = PatchExtractor(win_size, step_size)
33 |
34 | ### Paths to data - these need to be modified according to where the original data is stored
35 | img_ext = '.tif'
36 | img_dir = '/media/tandoan/Data/data/Monusac/Test(split)/Images'
37 | ann_dir = '/media/tandoan/Data/data/Monusac/Test(split)/Labels(mat)'
38 | ####
39 | out_dir = "/media/tandoan/Data/data/Monusac/%dx%d_%dx%d" % \
40 | (win_size[0], win_size[1], step_size[0], step_size[1])
41 |
42 | file_list = glob.glob('%s/*%s' % (img_dir, img_ext))
43 | file_list.sort()
44 |
45 | rm_n_mkdir(out_dir)
46 | for filename in file_list:
47 | filename = os.path.basename(filename)
48 | basename = filename.split('.')[0]
49 | print(filename)
50 |
51 | img = cv2.imread(img_dir + '/' + basename + img_ext)
52 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
53 |
54 | if cfg.type_classification:
55 | # assumes that ann is HxWx2 (nuclei class labels are available at index 1 of C)
56 | ann = sio.loadmat(ann_dir + '/' + basename + '.mat')
57 | ann_inst = ann['inst_map']
58 | ann_type = ann['type_map']
59 |
60 |             # merge classes for CoNSeP (in the paper we only utilise 3 nuclei classes and background)
61 |             # If your own dataset is used, the below may need to be modified
62 | if cfg.data_type == 'glysac':
63 | ann_type[(ann_type == 1) | (ann_type == 2) | (ann_type == 9) | (ann_type == 10)] = 1
64 | ann_type[(ann_type == 4) | (ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 2
65 | ann_type[(ann_type == 8) | (ann_type == 3)] = 3
66 | elif cfg.data_type == 'consep':
67 | ann_type[(ann_type == 3) | (ann_type == 4)] = 3
68 | ann_type[(ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 4
69 | assert np.max(ann_type) <= cfg.nr_types-1, \
70 |                     "Only %d types of nuclei are defined for training "\
71 | "but there are %d types found in the input image." % (cfg.nr_types, np.max(ann_type))
72 |
73 | ann = np.dstack([ann_inst, ann_type])
74 | ann = ann.astype('int32')
75 |         else: ann = sio.loadmat(ann_dir + '/' + basename + '.mat')['inst_map'][..., None].astype('int32')  # segmentation-only label
76 | img = np.concatenate([img, ann], axis=-1)
77 |
78 | sub_patches = xtractor.extract(img, extract_type)
79 | for idx, patch in enumerate(sub_patches):
80 | np.save("{0}/{1}_{2:03d}.npy".format(out_dir, basename, idx), patch)
--------------------------------------------------------------------------------
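
A minimal sketch of how one of the saved patches can be inspected; the channel split mirrors `loader/loader.py`, and the file name below is hypothetical:

```
import numpy as np

patch = np.load('sample_000.npy')     # one of the .npy files written above
img = patch[..., :3].astype('uint8')  # RGB image
ann = patch[..., 3:]                  # [inst_map, type_map] ground truth
print(img.shape, ann.shape)           # e.g. (540, 540, 3) (540, 540, 2)
```
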
/postproc/post_sonnet.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from numpy.random import default_rng
4 | from scipy.ndimage import filters, measurements, find_objects
5 | from scipy.ndimage.morphology import (binary_dilation, binary_fill_holes,
6 | distance_transform_cdt,
7 | distance_transform_edt)
8 | from skimage.morphology import remove_small_objects, binary_erosion
9 | import tensorflow as tf
10 | from matplotlib import cm
11 | from skimage import measure
12 | from skimage.segmentation import watershed
13 | import matplotlib.pyplot as plt
14 | from sklearn.utils import shuffle
15 | from collections import OrderedDict
16 |
17 |
18 | def colorize(value, vmin=None, vmax=None, cmap=None):
19 | """
20 | Arguments:
21 | - value: input tensor, NHWC ('channels_last')
22 | - vmin: the minimum value of the range used for normalization.
23 | (Default: value minimum)
24 | - vmax: the maximum value of the range used for normalization.
25 | (Default: value maximum)
26 |       - cmap: a valid cmap name for use with matplotlib's `get_cmap`.
27 | (Default: 'gray')
28 | Example usage:
29 | ```
30 | output = tf.random_uniform(shape=[256, 256, 1])
31 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis')
32 | tf.summary.image('output', output_color)
33 | ```
34 |
35 | Returns a 3D tensor of shape [height, width, 3], uint8.
36 | """
37 |
38 | # normalize
39 | if vmin is None:
40 | vmin = tf.reduce_min(value, axis=[1, 2])
41 | vmin = tf.reshape(vmin, [-1, 1, 1])
42 | if vmax is None:
43 | vmax = tf.reduce_max(value, axis=[1, 2])
44 | vmax = tf.reshape(vmax, [-1, 1, 1])
45 | value = (value - vmin) / (vmax - vmin) # vmin..vmax
46 |
47 | # squeeze last dim if it exists
48 | # NOTE: will throw error if use get_shape()
49 | # value = tf.squeeze(value)
50 |
51 | # quantize
52 | value = tf.round(value * 255)
53 | indices = tf.cast(value, np.int32)
54 |
55 | # gather
56 | colormap = cm.get_cmap(cmap if cmap is not None else 'gray')
57 | colors = colormap(np.arange(256))[:, :3]
58 | colors = tf.constant(colors, dtype=tf.float32)
59 | value = tf.gather(colors, indices)
60 | value = tf.cast(value * 255, tf.uint8)
61 | return value
62 |
63 | def proc_np_ord(pred, pred_ord):
64 | """
65 |     Process the nuclei prediction with the ordinal map
66 |
67 | Args:
68 | pred: prediction output (NP branch)
69 | pred_ord: ordinal prediction output (ordinal branch)
70 | """
71 |
72 | blb_raw = pred
73 |
74 | pred_ord = np.squeeze(pred_ord)
75 | distance = -pred_ord
76 | marker = np.copy(pred_ord)
77 | marker[marker <= 4] = 0
78 | marker[marker > 4] = 1
79 | marker = binary_dilation(marker, iterations=1)
80 | # marker = binary_erosion(marker)
81 | # marker = binary_erosion(marker)
82 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5, 5))
83 | # marker = cv2.morphologyEx(np.float32(marker), cv2.MORPH_OPEN, kernel)
84 | marker = measurements.label(marker)[0]
85 | marker = remove_small_objects(marker, min_size=10)
86 |
87 | # Processing
88 | blb = np.copy(blb_raw)
89 | blb[blb >= 0.5] = 1
90 | blb[blb < 0.5] = 0
91 |
92 | blb = measurements.label(blb)[0]
93 | blb = remove_small_objects(blb, min_size=10)
94 | blb[blb > 0] = 1 # background is 0 already
95 |
96 | markers = marker * blb
97 |
98 | proced_pred = watershed(distance, markers, mask=blb)
99 |
100 | return proced_pred
--------------------------------------------------------------------------------
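
A toy end-to-end call of `proc_np_ord` (illustrative only; the synthetic maps below stand in for real network outputs):

```
import numpy as np
from postproc.post_sonnet import proc_np_ord

# one synthetic round nucleus
yy, xx = np.mgrid[:32, :32]
dist = np.hypot(yy - 16, xx - 16)
pred = (dist < 10).astype(np.float32)           # NP-branch probability map
pred_ord = np.clip(8 - dist // 2, 0, 8) * pred  # fake ordinal levels, 8 at the centre
inst_map = proc_np_ord(pred, pred_ord)          # labelled instance map via watershed
print(np.unique(inst_map))                      # [0 1]
```
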
/process.py:
--------------------------------------------------------------------------------
1 |
2 | import glob
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | import scipy.io as sio
8 | from scipy.ndimage import filters, measurements
9 | from scipy.ndimage.morphology import (binary_dilation, binary_fill_holes,
10 | distance_transform_cdt,
11 | distance_transform_edt)
12 | from skimage.morphology import remove_small_objects, watershed
13 | import matplotlib.pyplot as plt
14 |
15 | import postproc.post_sonnet
16 |
17 | from config import Config
18 |
19 | from misc.viz_utils import visualize_instances
20 | from misc.utils import get_inst_centroid
21 | from metrics.stats_utils import remap_label
22 |
23 | ##########
24 |
25 | ## ! WARNING:
26 | ## check the prediction channels, wrong ordering will break the code !
27 | ## the prediction channels ordering should match the ones produced in augs.py
28 |
29 | cfg = Config()
30 |
31 |
32 | pred_dir = cfg.inf_output_dir
33 | proc_dir = pred_dir + '/_proc'
34 |
35 | file_list = glob.glob('%s/*.mat' % (pred_dir))
36 | file_list.sort() # ensure same order
37 |
38 | if not os.path.isdir(proc_dir):
39 | os.makedirs(proc_dir)
40 |
41 | for filename in file_list:
42 | filename = os.path.basename(filename)
43 | basename = filename.split('.')[0]
44 | print(pred_dir, basename, end=' ', flush=True)
45 |
46 |
47 | pred_mat = sio.loadmat('%s/%s.mat' % (pred_dir, basename))
48 | pred = np.squeeze(pred_mat['result'])
49 | pred_ord = np.squeeze(pred_mat['result-ord'])
50 |
51 | if hasattr(cfg, 'type_classification') and cfg.type_classification:
52 | pred_inst = pred[...,cfg.nr_types:]
53 | pred_type = pred[...,:cfg.nr_types]
54 |
55 | pred_inst = np.squeeze(pred_inst)
56 | pred_type = np.argmax(pred_type, axis=-1)
57 |     else: pred_inst = pred  # no type channels to split off
58 |
59 | pred_inst = postproc.post_sonnet.proc_np_ord(pred_inst, pred_ord)
60 |
61 | # ! will be extremely slow on WSI/TMA so it's advisable to comment this out
62 | # * remap once so that further processing faster (metrics calculation, etc.)
63 | pred_inst = remap_label(pred_inst, by_size=True)
64 |
65 |
66 | if cfg.type_classification:
67 | #### * Get class of each instance id, stored at index id-1
68 | pred_id_list = list(np.unique(pred_inst))[1:] # exclude background ID
69 | pred_inst_type = np.full(len(pred_id_list), 0, dtype=np.int32)
70 | for idx, inst_id in enumerate(pred_id_list):
71 | inst_type = pred_type[(pred_inst == inst_id)&(pred_ord>=5)]
72 | type_list, type_pixels = np.unique(inst_type, return_counts=True)
73 | type_list = list(zip(type_list, type_pixels))
74 | type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
75 | try:
76 | inst_type = type_list[0][0]
77 | except IndexError:
78 | inst_type = 0
79 | if inst_type == 0: # ! pick the 2nd most dominant if exist
80 | if len(type_list) > 1:
81 | if (type_list[1][1] / type_list[0][1]) > 0.5:
82 | inst_type = type_list[1][0]
83 | else:
84 | inst_type = type_list[0][0]
85 | print('[Warn] Instance has `background` type')
86 | else:
87 | print('[Warn] Instance has `background` type' )
88 | pred_inst_type[idx] = inst_type
89 | pred_inst_centroid = get_inst_centroid(pred_inst)
90 |
91 |     mat_dict = {'inst_map' : pred_inst}
92 |     if cfg.type_classification:
93 |         mat_dict['inst_type'] = pred_inst_type[:, None]
94 |         mat_dict['inst_centroid'] = pred_inst_centroid
95 |         mat_dict['type_map'] = pred_type
96 |     sio.savemat('%s/%s.mat' % (proc_dir, basename), mat_dict)
97 |
98 | ##
99 | print('FINISH')
100 |
--------------------------------------------------------------------------------
/loader/loader.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from tensorpack.dataflow import (AugmentImageComponent, AugmentImageComponents,
6 | BatchData, BatchDataByShape, CacheData,
7 | PrefetchDataZMQ, RNGDataFlow, RepeatedData)
8 |
9 | ####
10 | class DatasetSerial(RNGDataFlow):
11 | """
12 | Produce ``(image, label)`` pair, where
13 | ``image`` has shape HWC and is RGB, has values in range [0-255].
14 |
15 |     ``label`` is a float image of shape (H, W, C). The number of channels C depends
16 | on `self.model_mode` within `config.py`
17 |
18 | If self.model_mode is 'sonnet':
19 |             channel 0 contains the binary nuclei map, values are either 0 (background) or 1 (nuclei)
20 |             channel 1 contains the type map
21 |             channel 2 contains the ordinal map
22 | """
23 |
24 | def __init__(self, path_list):
25 | self.path_list = path_list
26 | ##
27 | def size(self):
28 | return len(self.path_list)
29 | ##
30 | def get_data(self):
31 | idx_list = list(range(0, len(self.path_list)))
32 | random.shuffle(idx_list)
33 | for idx in idx_list:
34 |
35 | data = np.load(self.path_list[idx])
36 |
37 | # split stacked channel into image and label
38 | img = data[...,:3] # RGB images
39 | ann = data[...,3:] # ann map
40 |
41 | img = img.astype('uint8')
42 | yield [img, ann]
43 |
44 |
45 | ####
46 | def valid_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=8, nr_procs=1):
47 | ### augment both the input and label
48 | ds = ds if shape_aug is None else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)
49 | ### augment just the input
50 | ds = ds if input_aug is None else AugmentImageComponent(ds, input_aug, index=0, copy=False)
51 | ### augment just the output
52 | ds = ds if label_aug is None else AugmentImageComponent(ds, label_aug, index=1, copy=True)
53 | #
54 | ds = BatchData(ds, batch_size, remainder=True)
55 | ds = CacheData(ds) # cache all inference images
56 | # ds = DataThread(batch_size, ds)
57 | # ds = QueueInput(ds)
58 | return ds
59 |
60 | ####
61 | def train_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=4, nr_procs=4):
62 | ### augment both the input and label
63 | ds = ds if shape_aug is None else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)
64 | ### augment just the input i.e index 0 within each yield of DatasetSerial
65 | ds = ds if input_aug is None else AugmentImageComponent(ds, input_aug, index=0, copy=False)
66 | ### augment just the output i.e index 1 within each yield of DatasetSerial
67 | ds = ds if label_aug is None else AugmentImageComponent(ds, label_aug, index=1, copy=True)
68 | #
69 | ds = BatchDataByShape(ds, batch_size, idx=0)
70 | ds = PrefetchDataZMQ(ds, nr_procs)
71 | return ds
72 |
73 | ####
74 | def visualize(datagen, batch_size, view_size=4):
75 | """
76 |     Read a batch from 'datagen' and display 'view_size' images
77 |     and their corresponding ground truth
78 | """
79 | def prep_imgs(img, ann):
80 | cmap = plt.get_cmap('viridis')
81 | # cmap may randomly fails if of other types
82 | ann = ann.astype('float32')
83 | ann_chs = np.dsplit(ann, ann.shape[-1])
84 | for i, ch in enumerate(ann_chs):
85 | ch = np.squeeze(ch)
86 |             # scale to the 0-1 range, else
87 |             # cmap may behave stupidly
88 |             ch = (ch - np.min(ch)) / (np.max(ch) - np.min(ch) + 1.0e-16)
89 | # take RGB from RGBA heat map
90 | ann_chs[i] = cmap(ch)[...,:3]
91 | img = img.astype('float32') / 255.0
92 | prepped_img = np.concatenate([img] + ann_chs, axis=1)
93 | return prepped_img
94 |
95 |     assert view_size <= batch_size, 'Number of displayed images must be <= batch size'
96 | ds = RepeatedData(datagen, -1)
97 | ds.reset_state()
98 | for imgs, segs in ds.get_data():
99 | for idx in range (0, view_size):
100 | displayed_img = prep_imgs(imgs[idx], segs[idx])
101 | plt.subplot(view_size, 1, idx+1)
102 | plt.imshow(displayed_img)
103 | plt.show()
104 | return
105 | ###
--------------------------------------------------------------------------------
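
A minimal sketch of wiring these pieces together (paths are hypothetical; augmentors are omitted here, while `config.py` supplies them during real training):

```
from loader.loader import DatasetSerial, train_generator
from misc.utils import get_files

train_paths = get_files(['/path/to/train/540x540_76x76'], '.npy')
ds = DatasetSerial(train_paths)
ds = train_generator(ds, batch_size=4, nr_procs=4)  # yields batches of [img, ann]
```
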
/README.md:
--------------------------------------------------------------------------------
1 | # SONNET: A self-guided ordinal regression neural network for segmentation and classification of nuclei in large-scale multi-tissue histology images
2 |
3 | ## Overview
4 |
5 | This repository contains a TensorFlow implementation of SONNET: a self-guided ordinal regression neural network that performs simultaneous nuclei segmentation and classification. By introducing a distance-decreasing discretization strategy, the network can detect nuclear pixels (inner pixels) and high-uncertainty regions (outer pixels). A self-guided training strategy is applied to the high-uncertainty regions to improve the final outcome.
6 | As part of this research, we introduce a new dataset for Gastric Lymphocyte Segmentation And Classification (GLySAC), which includes 59 H&E-stained image tiles of size 1000x1000 pixels. More details can be found in the SONNET paper.
7 |
8 | The repository includes:
9 | - Source code of SONNET.
10 | - Training code for SONNET.
11 | - Datasets employed in the SONNET paper.
12 | - Pretrained weights for the SONNET encoder.
13 | - Evaluation on nuclei segmentation metrics (DICE, AJI, DQ, SQ, PQ) and nuclei classification metrics (Fd and F1 scores).
14 |
15 | ## Installation
16 | ```
17 | conda create --name sonnet python=3.6
18 | conda activate sonnet
19 | pip install -r requirements.txt
20 | ```
21 |
22 | ## Dataset
23 | Download the CoNSeP dataset from this [link](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/hovernet/).
24 | Download the PanNuke dataset from this [link](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke).
25 | Download the MoNuSAC and GLySAC datasets from this [link](https://drive.google.com/drive/folders/1p0Yt2w8MTcaZJU3bdh0fAtTrPWin1-zb?usp=sharing).
26 |
27 |
28 | ## Step by Step Instruction
29 | Applying the model for nuclei segmentation and classification requires four steps:
30 |
31 | ### Step 1: Extracting the original data into patches
32 | To train the model, the data first needs to be extracted into patches. Set the dataset name in `config.py`, then run:
33 | `python extract_patches.py`
34 | Each extracted patch is a numpy array whose channels are stacked as [RGB, inst, type], where RGB is the input image, inst is the instance (foreground/background) ground truth map, and type is the nuclei type ground truth map.
35 |
36 | ### Step 2: Training the model
37 | Download the pretrained weights of the encoder used in SONNET from this [link](https://drive.google.com/drive/folders/1p0Yt2w8MTcaZJU3bdh0fAtTrPWin1-zb?usp=sharing).
38 | Before training the network:
39 | - Set the training dataset and validation dataset paths in `config.py`.
40 | - Set the path to the pretrained encoder weights in `config.py`.
41 | - Set the log path in `config.py`.
42 | - Adjust the hyperparameters used in the training process as needed in `opt/hyperconfig.py`.
43 | To train the network with GPUs 0 and 1:
44 | ```
45 | python train.py --gpu='0,1'
46 | ```
47 | ### Step 3: Inference
48 | Before testing the network on the test dataset:
49 | - Set the path to the test dataset, the trained model weights, and the output directory in `config.py`.
50 | Generate the network predictions with:
51 | ```
52 | python infer.py --gpu='0'
53 | ```
54 | Note that inference supports a single GPU only.
55 | To obtain the final outcome, i.e. the instance map and the type of each nucleus, run:
56 | ```
57 | python process.py
58 | ```
59 | ### Step 4: Calculate the metrics
60 | To calculate the metrics used in the SONNET paper, run:
61 | - instance segmentation: `python compute_stats.py --mode=instance --pred_dir='pred_dir' --true_dir='true_dir'`
62 | - type classification: `python compute_stats.py --mode=type --pred_dir='pred_dir' --true_dir='true_dir'`
63 |
64 | ## Visual Results
65 |
66 | Example segmentation and classification results are provided in `results/result_1.png` and `results/results_2.png`.
67 |
76 | The type of each nucleus is represented by a color:
77 | - Pink for Epithelial.
78 | - Yellow for Lymphocyte.
79 | - Blue for Miscellaneous.
80 |
81 | ## Requirements
82 | Python 3.6, TensorFlow 1.12, and other common packages listed in `requirements.txt`.
83 |
84 | ## Citation
85 | The datasets that we used in SONNET are from these papers:
86 | - CoNSeP: [paper](https://www.sciencedirect.com/science/article/pii/S1361841519301045).
87 | - MoNuSAC: [paper](https://ieeexplore.ieee.org/document/9446924).
88 | - PanNuke: [paper](https://arxiv.org/abs/2003.10778).
89 |
90 | If you use any of these datasets in your research, please cite the corresponding paper.
91 |
92 |
--------------------------------------------------------------------------------
/loader/augs.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import cv2
4 | import numpy as np
5 |
6 |
7 | from scipy.ndimage import measurements
8 | from scipy.ndimage.filters import gaussian_filter
9 | from scipy.ndimage.interpolation import affine_transform, map_coordinates
10 | from scipy.ndimage.morphology import (distance_transform_cdt,
11 | distance_transform_edt)
12 | from skimage import morphology as morph
13 |
14 | from tensorpack.dataflow.imgaug import ImageAugmentor
15 | from tensorpack.utils.utils import get_rng
16 |
17 | from misc.utils import cropping_center, bounding_box, get_inst_centroid
18 |
19 | ####
20 | class GenInstance(ImageAugmentor):
21 | def __init__(self, crop_shape=None):
22 | super(GenInstance, self).__init__()
23 | self.crop_shape = crop_shape
24 |
25 | def reset_state(self):
26 | self.rng = get_rng(self)
27 |
28 | def _fix_mirror_padding(self, ann):
29 | """
30 | Deal with duplicated instances due to mirroring in interpolation
31 | during shape augmentation (scale, rotation etc.)
32 | """
33 | current_max_id = np.amax(ann)
34 | inst_list = list(np.unique(ann))
35 | inst_list.remove(0) # 0 is background
36 | for inst_id in inst_list:
37 | inst_map = np.array(ann == inst_id, np.uint8)
38 | remapped_ids = measurements.label(inst_map)[0]
39 | remapped_ids[remapped_ids > 1] += int(current_max_id)
40 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
41 | current_max_id = np.amax(ann)
42 | return ann
43 | ####
44 |
45 |
46 | ####
47 | class GenInstanceOrd(GenInstance):
48 | """
49 |     Generate an ordinal distance map based on the instance map.
50 |     First, the Euclidean distance map is calculated. Then, the ordinal map is generated by discretizing the Euclidean distance map.
51 | """
52 |
53 | def _augment(self, img, _):
54 | img = np.copy(img)
55 | orig_ann = img[...,0].astype(np.int32) # instance ID map
56 | fixed_ann = self._fix_mirror_padding(orig_ann)
57 | # re-cropping with fixed instance id map
58 | crop_ann = cropping_center(fixed_ann, self.crop_shape)
59 | with warnings.catch_warnings():
60 | warnings.simplefilter("ignore")
61 | crop_ann = morph.remove_small_objects(crop_ann, min_size=7)
62 |
63 | inst_list = list(np.unique(crop_ann))
64 | # inst_centroid = get_inst_centroid(crop_ann)
65 | inst_list.remove(0)
66 | mask = np.zeros_like(fixed_ann, dtype=np.float32)
67 | for inst_id in inst_list:
68 | inst_id_map = np.copy(np.array(fixed_ann == inst_id, dtype=np.uint8))
69 | M = cv2.moments(inst_id_map)
70 | cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])
71 | inst_id_map[fixed_ann != inst_id] = 2
72 | inst_id_map[int(cy), int(cx)] = 0
73 | inst_id_map = distance_transform_edt(inst_id_map)
74 | inst_id_map[fixed_ann != inst_id] = 0
75 | max_val = np.max(inst_id_map)
76 | inst_id_map = inst_id_map / max_val
77 | mask[fixed_ann == inst_id] = inst_id_map[fixed_ann == inst_id]
78 |
79 | def gen_ord(euc_map):
80 | lut_gt = [1, 0.83, 0.68, 0.54, 0.41, 0.29, 0.19, 0.09, 0]
81 | zeros = np.zeros_like(euc_map)
82 | ones = np.ones_like(euc_map)
83 | decoded_label = np.full(euc_map.shape, 0, dtype=np.float32)
84 | for k in range(8):
85 | if k != 7:
86 | decoded_label += np.where((euc_map <= lut_gt[k]) & (euc_map > lut_gt[k+1]), ones * (k + 1), zeros)
87 | else:
88 | decoded_label += np.where((euc_map <= lut_gt[k]) & (euc_map >= lut_gt[k+1]), ones * (k + 1), zeros)
89 | return decoded_label
90 |
91 | ord_map = gen_ord(mask)
92 | img = img.astype('float32')
93 | img = np.dstack([img, ord_map])
94 |
95 | return img
96 | ####
97 |
98 | class GaussianBlur(ImageAugmentor):
99 | """ Gaussian blur the image with random window size"""
100 | def __init__(self, max_size=3):
101 | """
102 | Args:
103 | max_size (int): max possible Gaussian window size would be 2 * max_size + 1
104 | """
105 | super(GaussianBlur, self).__init__()
106 | self.max_size = max_size
107 |
108 | def _get_augment_params(self, img):
109 | sx, sy = self.rng.randint(1, self.max_size, size=(2,))
110 | sx = sx * 2 + 1
111 | sy = sy * 2 + 1
112 | return sx, sy
113 |
114 | def _augment(self, img, s):
115 | return np.reshape(cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0,
116 | borderType=cv2.BORDER_REPLICATE), img.shape)
117 |
118 | ####
119 | class BinarizeLabel(ImageAugmentor):
120 | """ Convert labels to binary maps"""
121 | def __init__(self):
122 | super(BinarizeLabel, self).__init__()
123 |
124 | def _get_augment_params(self, img):
125 | return None
126 |
127 | def _augment(self, img, s):
128 | img = np.copy(img)
129 | arr = img[...,0]
130 | arr[arr > 0] = 1
131 | return img
132 |
133 | ####
134 | class MedianBlur(ImageAugmentor):
135 | """ Median blur the image with random window size"""
136 | def __init__(self, max_size=3):
137 | """
138 | Args:
139 | max_size (int): max possible window size
140 | would be 2 * max_size + 1
141 | """
142 | super(MedianBlur, self).__init__()
143 | self.max_size = max_size
144 |
145 | def _get_augment_params(self, img):
146 | s = self.rng.randint(1, self.max_size)
147 | s = s * 2 + 1
148 | return s
149 |
150 | def _augment(self, img, ksize):
151 | return cv2.medianBlur(img, ksize)
152 |
153 |
--------------------------------------------------------------------------------
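
The `lut_gt` table inside `gen_ord` above discretizes a normalized distance-to-centroid in [0, 1] into 8 ordinal levels (level 8 lies nearest the centroid, level 1 at the boundary). A standalone restatement of that lookup, for illustration:

```
lut_gt = [1, 0.83, 0.68, 0.54, 0.41, 0.29, 0.19, 0.09, 0]

def ord_level(d):
    """Ordinal level of a normalized distance d, matching gen_ord."""
    for k in range(8):
        upper, lower = lut_gt[k], lut_gt[k + 1]
        if d <= upper and (d > lower or (k == 7 and d >= lower)):
            return k + 1
    return 0  # d > 1 does not occur for normalized maps

print(ord_level(0.90))  # 1 -> near the nucleus boundary
print(ord_level(0.05))  # 8 -> near the centroid
```
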
/config.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import cv2
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorpack import imgaug
7 |
8 | from loader.augs import (BinarizeLabel, GaussianBlur,
9 | GenInstanceOrd, MedianBlur)
10 |
11 | ####
12 | class Config(object):
13 | def __init__(self, ):
14 |
15 | self.seed = 9
16 | self.model_type = 'sonnet'
17 | self.data_type = 'consep'
18 |
19 | self.type_classification = True
20 | self.nr_types = 5
21 | self.nr_classes = 2 # Nuclei Pixels vs Background
22 |
23 |         # define your nuclei type names here; please ensure it contains
24 |         # the same amount as defined in `self.nr_types`. ID 0 is reserved
25 |         # for background, so please don't use it as an ID
26 | if self.data_type == 'consep':
27 | self.nuclei_type_dict = {
28 | 'Miscellaneous': 1, # ! Please ensure the matching ID is unique
29 | 'Inflammatory' : 2,
30 | 'Epithelial' : 3,
31 | 'Spindle' : 4,
32 | }
33 | elif self.data_type == 'monusac':
34 | self.nuclei_type_dict ={
35 | 'Epithelial' : 1,
36 | 'Lymphocyte' : 2,
37 | 'Macrophages': 3,
38 | 'Neutrophil' : 4
39 | }
40 | elif self.data_type == 'pannuke':
41 | self.nuclei_type_dict ={
42 | 'Neoplastic' : 1,
43 | 'Inflammatory' : 2,
44 | 'Connective': 3,
45 | 'Dead' : 4,
46 | 'Non-Neoplastic Epithelial' : 5
47 | }
48 | else:
49 | self.nuclei_type_dict ={
50 | 'Other' : 1,
51 | 'Lymphocyte' : 2,
52 | 'Epithelial' : 3
53 | }
54 | assert len(self.nuclei_type_dict.values()) == self.nr_types - 1
55 |
56 | #### Dynamically setting the config file into variable
57 |         config_file = importlib.import_module('opt.hyperconfig')
58 | config_dict = config_file.__getattribute__(self.model_type)
59 |
60 | for variable, value in config_dict.items():
61 | self.__setattr__(variable, value)
62 | #### Training data
63 |
64 | # patches are stored as numpy arrays with N channels
65 | # ordering as [Image][Nuclei Pixels][Nuclei Type][Additional Map]
66 | # Ex: with type_classification=True
67 | # HoVer-Net: RGB - Nuclei Pixels - Type Map - Horizontal and Vertical Map
68 | # Ex: with type_classification=False
69 | # Dist : RGB - Nuclei Pixels - Distance Map
70 | if self.data_type != 'pannuke':
71 | data_code_dict = {
72 | 'sonnet' : '540x540_76x76',
73 | }
74 | else:
75 | data_code_dict = {
76 | 'sonnet' : '270x270_76x76',
77 | }
78 |
79 | self.data_ext = '.npy'
80 | # list of directories containing validation patches.
81 | # For both train and valid directories, a comma separated list of directories can be used
82 | self.train_dir = ['/media/tandoan/data2/CoNSeP/Train/%s/' % data_code_dict[self.model_type]]
83 |         # train/test split has already been applied
84 | self.valid_dir = ['/home/tandoan/work/PanNuke/Valid/%s' % data_code_dict[self.model_type]]
85 |
86 | # number of processes for parallel processing input
87 | self.nr_procs_train = 8
88 | self.nr_procs_valid = 4
89 |
90 | self.input_norm = True # normalize RGB to 0-1 range
91 |
92 | ####
93 | exp_id = 'v1.0'
94 | model_id = '%s' % self.model_type
95 | self.model_name = '%s/%s' % (exp_id, model_id)
96 | # loading chkpts in tensorflow, the path must not contain extra '/'
97 | self.log_path = '/media/tandoan/data2/logs/logs_test'
98 | self.save_dir = '%s/%s' % (self.log_path, self.model_name) # log file destination
99 |
100 |         #### Info for running inference
101 | self.inf_auto_find_chkpt = False
102 | # path to checkpoints will be used for inference, replace accordingly
103 | self.inf_model_path = '/media/tandoan/data2/logs/logs_focalnet_noguide_consep/v1.0/focalnet/02/model-39650.index'
104 |
105 | # output will have channel ordering as [Nuclei Type][Nuclei Pixels][Additional]
106 | # where [Nuclei Type] will be used for getting the type of each instance
107 | # while [Nuclei Pixels][Additional] will be used for extracting instance
108 |
109 | self.inf_imgs_ext = '.png'
110 | self.inf_data_dir = '/media/tandoan/data2/CoNSeP/Test/Images'
111 | self.inf_output_dir = 'output/test/'
112 |
113 |         # for inference during evaluation mode i.e run by infer.py
114 | self.eval_inf_input_tensor_names = ['images']
115 |
116 |
117 |
118 | # for inference during training mode i.e run by trainer.py
119 | if self.model_type == 'sonnet':
120 | self.train_inf_output_tensor_names = ['predmap-coded', 'truemap-coded']
121 | self.eval_inf_output_tensor_names = ['predmap-coded', 'predmap-ord']
122 |
123 |
124 | def get_model(self, phase=1):
125 | if phase!=2:
126 | model_constructor = importlib.import_module('model.sonnet')
127 | model_constructor = model_constructor.Sonnet
128 | else:
129 | model_constructor = importlib.import_module('model.sonnet_v2')
130 | model_constructor = model_constructor.Sonnet_phase2
131 | return model_constructor # NOTE return alias, not object
132 |
133 |
134 | # refer to https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html for
135 | # information on how to modify the augmentation parameters
136 | def get_train_augmentors(self, input_shape, output_shape, view=False):
137 | if self.data_type != 'pannuke':
138 | shape_augs = [
139 | imgaug.Affine(
140 | shear=5, # in degree
141 | scale=(0.8, 1.2),
142 | rotate_max_deg=179,
143 | translate_frac=(0.01, 0.01),
144 | interp=cv2.INTER_NEAREST,
145 | border=cv2.BORDER_CONSTANT),
146 | imgaug.Flip(vert=True),
147 | imgaug.Flip(horiz=True),
148 | imgaug.CenterCrop(input_shape),
149 | ]
150 | else:
151 | shape_augs =[
152 | imgaug.Flip(vert=True),
153 | imgaug.Flip(horiz=True),
154 | ]
155 |
156 | input_augs = [
157 | imgaug.RandomApplyAug(
158 | imgaug.RandomChooseAug(
159 | [
160 | GaussianBlur(),
161 | MedianBlur(),
162 | imgaug.GaussianNoise(),
163 | ]
164 | ), 0.5),
165 | # standard color augmentation
166 | imgaug.RandomOrderAug(
167 | [imgaug.Hue((-8, 8), rgb=True),
168 | imgaug.Saturation(0.2, rgb=True),
169 | imgaug.Brightness(26, clip=True),
170 | imgaug.Contrast((0.75, 1.25), clip=True),
171 | ]),
172 | imgaug.ToUint8(),
173 | ]
174 |
175 | label_augs = []
176 | if self.model_type == 'sonnet':
177 | label_augs = [GenInstanceOrd(crop_shape=output_shape)]
178 |
179 | if not self.type_classification:
180 | label_augs.append(BinarizeLabel())
181 |
182 | if not view:
183 | label_augs.append(imgaug.CenterCrop(output_shape))
184 |
185 | return shape_augs, input_augs, label_augs
186 |
187 | def get_valid_augmentors(self, input_shape, output_shape, view=False):
188 | shape_augs = [
189 | imgaug.CenterCrop(input_shape),
190 | ]
191 |
192 | input_augs = None
193 |
194 | label_augs = []
195 | if self.model_type == 'sonnet':
196 | label_augs = [GenInstanceOrd(crop_shape=output_shape)]
197 | label_augs.append(BinarizeLabel())
198 |
199 | if not view:
200 | label_augs.append(imgaug.CenterCrop(output_shape))
201 |
202 | return shape_augs, input_augs, label_augs
203 |
--------------------------------------------------------------------------------
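
A minimal sketch of how the dynamic attributes come together (`train_input_shape` and `train_mask_shape` are injected from `opt/hyperconfig.py` by the loop above):

```
from config import Config

cfg = Config()
print(cfg.train_input_shape, cfg.train_mask_shape)  # [270, 270] [76, 76]
shape_augs, input_augs, label_augs = cfg.get_train_augmentors(
    cfg.train_input_shape, cfg.train_mask_shape)
```
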
/misc/patch_extractor.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import os
4 |
5 | import cv2
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import glob
9 | from scipy.ndimage import measurements
10 |
11 | from .utils import cropping_center
12 | from .utils import rm_n_mkdir
13 |
14 | #####
15 | class PatchExtractor(object):
16 | """
17 | Extractor to generate patches with or without padding.
18 | Turn on debug mode to see how it is done.
19 |
20 | Args:
21 | x : input image, should be of shape HWC
22 | win_size : a tuple of (h, w)
23 | step_size : a tuple of (h, w)
24 | debug : flag to see how it is done
25 | Return:
26 | a list of sub patches, each patch has dtype same as x
27 |
28 | Examples:
29 | >>> xtractor = PatchExtractor((450, 450), (120, 120))
30 | >>> img = np.full([1200, 1200, 3], 255, np.uint8)
31 | >>> patches = xtractor.extract(img, 'mirror')
32 | """
33 | def __init__(self, win_size, step_size, debug=False):
34 |
35 | self.patch_type = 'mirror'
36 | self.win_size = win_size
37 | self.step_size = step_size
38 | self.debug = debug
39 | self.counter = 0
40 |
41 | def __get_patch(self, x, ptx):
42 | pty = (ptx[0]+self.win_size[0],
43 | ptx[1]+self.win_size[1])
44 | win = x[ptx[0]:pty[0],
45 | ptx[1]:pty[1]]
46 | assert win.shape[0] == self.win_size[0] and \
47 | win.shape[1] == self.win_size[1], \
48 | '[BUG] Incorrect Patch Size {0}'.format(win.shape)
49 | if self.debug:
50 | if self.patch_type == 'mirror':
51 | cen = cropping_center(win, self.step_size)
52 | cen = cen[...,self.counter % 3]
53 | cen.fill(150)
54 | cv2.rectangle(x,ptx,pty,(255,0,0),2)
55 | plt.imshow(x)
56 | plt.show(block=False)
57 | plt.pause(1)
58 | plt.close()
59 | self.counter += 1
60 | return win
61 |
62 | def __extract_valid(self, x):
63 | """
64 |         Extract patches without padding; only works when win_size > step_size
65 |
66 |         Note: to deal with the remaining portions at the boundary (i.e.
67 |         those which do not fit when sliding left->right, top->bottom), we flip
68 |         the sliding direction then extract 1 patch starting from the right / bottom edge.
69 | There will be 1 additional patch extracted at the bottom-right corner
70 |
71 | Args:
72 | x : input image, should be of shape HWC
73 | win_size : a tuple of (h, w)
74 | step_size : a tuple of (h, w)
75 | Return:
76 | a list of sub patches, each patch is same dtype as x
77 | """
78 |
79 | im_h = x.shape[0]
80 | im_w = x.shape[1]
81 |
82 | def extract_infos(length, win_size, step_size):
83 | flag = (length - win_size) % step_size != 0
84 | last_step = math.floor((length - win_size) / step_size)
85 | last_step = (last_step + 1) * step_size
86 | return flag, last_step
87 |
88 | h_flag, h_last = extract_infos(im_h, self.win_size[0], self.step_size[0])
89 | w_flag, w_last = extract_infos(im_w, self.win_size[1], self.step_size[1])
90 |
91 | sub_patches = []
92 | #### Deal with valid block
93 | for row in range(0, h_last, self.step_size[0]):
94 | for col in range(0, w_last, self.step_size[1]):
95 | win = self.__get_patch(x, (row, col))
96 | sub_patches.append(win)
97 | #### Deal with edge case
98 | if h_flag:
99 | row = im_h - self.win_size[0]
100 | for col in range(0, w_last, self.step_size[1]):
101 | win = self.__get_patch(x, (row, col))
102 | sub_patches.append(win)
103 | if w_flag:
104 | col = im_w - self.win_size[1]
105 | for row in range(0, h_last, self.step_size[0]):
106 | win = self.__get_patch(x, (row, col))
107 | sub_patches.append(win)
108 | if h_flag and w_flag:
109 | ptx = (im_h - self.win_size[0], im_w - self.win_size[1])
110 | win = self.__get_patch(x, ptx)
111 | sub_patches.append(win)
112 | return sub_patches
113 |
114 | def __extract_mirror(self, x):
115 | """
116 |         Extract patches with mirror padding at the boundary such that the
117 |         central region of each patch is always within the original (non-padded)
118 |         image while all patches' central regions cover the whole original image
119 |
120 | Args:
121 | x : input image, should be of shape HWC
122 | win_size : a tuple of (h, w)
123 | step_size : a tuple of (h, w)
124 | Return:
125 | a list of sub patches, each patch is same dtype as x
126 | """
127 |
128 | diff_h = self.win_size[0] - self.step_size[0]
129 | padt = diff_h // 2
130 | padb = diff_h - padt
131 |
132 | diff_w = self.win_size[1] - self.step_size[1]
133 | padl = diff_w // 2
134 | padr = diff_w - padl
135 |
136 | pad_type = 'constant' if self.debug else 'reflect'
137 | x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type)
138 | sub_patches = self.__extract_valid(x)
139 | return sub_patches
140 |
141 | def extract(self, x, patch_type):
142 | patch_type = patch_type.lower()
143 | self.patch_type = patch_type
144 | if patch_type == 'valid':
145 | return self.__extract_valid(x)
146 | elif patch_type == 'mirror':
147 | return self.__extract_mirror(x)
148 | else:
149 | assert False, 'Unknown Patch Type [%s]' % patch_type
150 | return
151 |
152 | class Padding_image(object):
153 | """
154 | Padding Images to reach the minimum size using `mirror` method. Use for Monusac dataset.
155 | For HoverNet, win_size is 540x540
156 | """
157 | def __init__(self, win_size):
158 | self.win_size = win_size
159 |
160 | def pad(self, img_dir, ann_dir, save_img_dir, save_ann_dir, pad_type='reflect'):
161 | file_list = glob.glob(img_dir + '/*.tif')
162 | file_list.sort()
163 | rm_n_mkdir(save_img_dir)
164 | rm_n_mkdir(save_ann_dir)
165 | for filename in file_list:
166 | print(filename)
167 | filename = os.path.basename(filename)
168 | basename = filename.split('.')[0]
169 | img = cv2.imread(img_dir + '/' + basename + '.tif')
170 | ann = np.load(ann_dir + '/' + basename + '.npy')
171 | padt, padb, padl, padr = 0, 0, 0, 0
172 |             if img.shape[0] < self.win_size[0]:
173 | diff_h = self.win_size[0] - img.shape[0]
174 | padt = diff_h // 2
175 | padb = diff_h - padt
176 |             if img.shape[1] < self.win_size[1]:
177 | diff_w = self.win_size[1] - img.shape[1]
178 | padl = diff_w // 2
179 | padr = diff_w - padl
180 | img = np.lib.pad(img, ((padt, padb), (padl, padr), (0,0)), pad_type)
181 | ann = np.lib.pad(ann, ((padt, padb), (padl, padr), (0,0)), pad_type)
182 | inst_map = ann[...,0]
183 | inst_map = self._fix_mirror_padding(inst_map)
184 |
185 | cv2.imwrite(save_img_dir + '/' + basename + '.tif', img)
186 | np.save(save_ann_dir + '/' + basename, ann)
187 |
188 |
189 | def _fix_mirror_padding(self, ann):
190 | """
191 | Deal with duplicated instances due to mirroring in interpolation
192 | during shape augmentation (scale, rotation etc.)
193 | """
194 | current_max_id = np.amax(ann)
195 | inst_list = list(np.unique(ann))
196 | inst_list.remove(0) # 0 is background
197 | for inst_id in inst_list:
198 | inst_map = np.array(ann == inst_id, np.uint8)
199 | remapped_ids = measurements.label(inst_map)[0]
200 | remapped_ids[remapped_ids > 1] += current_max_id
201 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
202 | current_max_id = np.amax(ann)
203 | return ann
204 |
205 |
206 |
207 | #####
208 |
209 | ###########################################################################
210 |
211 | if __name__ == '__main__':
212 | # toy example for debug
213 | # 355x355, 480x480
214 | xtractor = PatchExtractor((450, 450), (120, 120), debug=True)
215 | a = np.full([1200, 1200, 3], 255, np.uint8)
216 | xtractor.extract(a, 'mirror')
217 | xtractor.extract(a, 'valid')
218 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import math
4 | import os
5 | from collections import deque
6 |
7 | import cv2
8 | import numpy as np
9 | from scipy import io as sio
10 | import matplotlib.pyplot as plt
11 | from skimage import measure
12 | from scipy.ndimage import find_objects
13 |
14 | from tensorpack.predict import OfflinePredictor, PredictConfig
15 | from tensorpack.tfutils.sessinit import get_model_loader
16 |
17 | from config import Config
18 | from misc.utils import rm_n_mkdir
19 |
20 | import json
21 | import operator
22 |
23 |
24 | ####
25 | def get_best_chkpts(path, metric_name, comparator='>'):
26 | """
27 | Return the best checkpoint according to some criteria.
28 |     Note that it will only return a valid path, so any checkpoint that has been
29 |     removed won't be returned (i.e. it moves to the next one that satisfies the
30 |     criteria, such as the second best etc.)
31 | Args:
32 |         path: directory containing all checkpoints, including the "stats.json" file
33 | """
34 | stat_file = path + '/02' + '/stats.json'
35 | ops = {
36 | '>': operator.gt,
37 | '<': operator.lt,
38 | }
39 |
40 | op_func = ops[comparator]
41 | with open(stat_file) as f:
42 | info = json.load(f)
43 |
44 | if comparator == '>':
45 | best_value = -float("inf")
46 | else:
47 | best_value = +float("inf")
48 |
49 |     best_chkpt, selected_stat = None, None
50 | for epoch_stat in info:
51 | epoch_value = epoch_stat[metric_name]
52 | if op_func(epoch_value, best_value):
53 | chkpt_path = "%s/02/model-%d.index" % (path, epoch_stat['global_step'])
54 | if os.path.isfile(chkpt_path):
55 | selected_stat = epoch_stat
56 | best_value = epoch_value
57 | best_chkpt = chkpt_path
58 | return best_chkpt, selected_stat
59 |
60 |
61 | ####
62 | class Inferer(Config):
63 |
64 | def __gen_prediction(self, x, predictor):
65 | """
66 | Using 'predictor' to generate the prediction of image 'x'
67 |
68 | Args:
69 | x : input image to be segmented. It will be split into patches
70 | to run the prediction upon before being assembled back
71 | """
72 | step_size = [40, 40]
73 | msk_size = self.infer_mask_shape
74 | win_size = self.infer_input_shape
75 |
76 |
77 |
78 | def get_last_steps(length, step_size):
79 | nr_step = math.ceil((length - step_size) / step_size)
80 | last_step = (nr_step + 1) * step_size
81 | return int(last_step), int(nr_step + 1)
82 |
83 | im_h = x.shape[0]
84 | im_w = x.shape[1]
85 |
86 | padt_img, padb_img = 0, 0
87 | padl_img, padr_img = 0, 0
88 |
89 | # pad if image size smaller than msk_size (for monusac dataset)
90 | if im_h < msk_size[0]:
91 | diff_h_img = msk_size[0] - im_h
92 | padt_img = diff_h_img // 2
93 | padb_img = diff_h_img - padt_img
94 | if im_w < msk_size[1]:
95 | diff_w_img = msk_size[1] - im_w
96 | padl_img = diff_w_img // 2
97 | padr_img = diff_w_img - padl_img
98 | x_pad = np.lib.pad(x, ((padt_img, padb_img), (padl_img, padr_img), (0, 0)), 'reflect')
99 | im_h_pad = x_pad.shape[0]
100 | im_w_pad = x_pad.shape[1]
101 |
102 |
103 | last_h, nr_step_h = get_last_steps(im_h_pad, step_size[0])
104 | last_w, nr_step_w = get_last_steps(im_w_pad, step_size[1])
105 | diff_h = win_size[0] - step_size[0]
106 | padt = diff_h // 2
107 | padb = last_h + win_size[0] - im_h_pad
108 |
109 |
110 | diff_w = win_size[1] - step_size[1]
111 | padl = diff_w // 2
112 | padr = last_w + win_size[1] - im_w_pad
113 |
114 |
115 | x_pad = np.lib.pad(x_pad, ((padt, padb), (padl, padr), (0, 0)), 'reflect')
116 |
117 |
118 | #### TODO: optimize this
119 | sub_patches = []
120 |         # generating sub-patches from the original image
121 | for row in range(0, last_h, step_size[0]):
122 | for col in range (0, last_w, step_size[1]):
123 | win = x_pad[row:row+win_size[0],
124 | col:col+win_size[1]]
125 | sub_patches.append(win)
126 | pred_coded = deque()
127 | pred_ord = deque()
128 | while len(sub_patches) > self.inf_batch_size:
129 | mini_batch = sub_patches[:self.inf_batch_size]
130 | sub_patches = sub_patches[self.inf_batch_size:]
131 | mini_output = predictor(mini_batch)
132 | mini_coded = mini_output[0][:, 18:58, 18:58,:]
133 | mini_ord = mini_output[1][:, 18:58, 18:58, :]
134 | mini_coded = np.split(mini_coded, self.inf_batch_size, axis=0)
135 | pred_coded.extend(mini_coded)
136 | mini_ord = np.split(mini_ord, self.inf_batch_size, axis=0)
137 | pred_ord.extend(mini_ord)
138 | if len(sub_patches) != 0:
139 | mini_output = predictor(sub_patches)
140 | mini_coded = mini_output[0][:, 18:58, 18:58,:]
141 | mini_ord = mini_output[1][:, 18:58, 18:58, :]
142 | mini_coded = np.split(mini_coded, len(sub_patches), axis=0)
143 | pred_coded.extend(mini_coded)
144 | mini_ord = np.split(mini_ord, len(sub_patches), axis=0)
145 | pred_ord.extend(mini_ord)
146 |
147 | #### Assemble back into full image
148 | output_patch_shape = np.squeeze(pred_coded[0]).shape
149 | ch = 1 if len(output_patch_shape) == 2 else output_patch_shape[-1]
150 |
151 | #### Assemble back into full image
152 | pred_coded = np.squeeze(np.array(pred_coded))
153 | pred_coded = np.reshape(pred_coded, (nr_step_h, nr_step_w) + pred_coded.shape[1:])
154 |
155 | pred_coded = np.transpose(pred_coded, [0, 2, 1, 3, 4]) if ch != 1 else \
156 | np.transpose(pred_coded, [0, 2, 1, 3])
157 | pred_coded = np.reshape(pred_coded, (pred_coded.shape[0] * pred_coded.shape[1],
158 | pred_coded.shape[2] * pred_coded.shape[3], ch))
159 | pred_coded = np.squeeze(pred_coded[:im_h_pad, :im_w_pad]) # just crop back to original size
160 | pred_coded = pred_coded[padt_img:padt_img+im_h, padl_img:padl_img+im_w]
161 |
162 | pred_ord = np.squeeze(np.array(pred_ord))
163 | pred_ord = np.reshape(pred_ord, (nr_step_h, nr_step_w) + pred_ord.shape[1:])
164 | pred_ord = np.transpose(pred_ord, [0, 2, 1, 3])
165 | pred_ord = np.reshape(pred_ord, (pred_ord.shape[0] * pred_ord.shape[1], pred_ord.shape[2] * pred_ord.shape[3]))
166 | pred_ord = np.squeeze(pred_ord[:im_h_pad, :im_w_pad])
167 | pred_ord = pred_ord[padt_img:padt_img+im_h, padl_img:padl_img+im_w]
168 |
169 | return pred_coded, pred_ord
170 |
171 | ####
172 | def run(self):
173 |
174 | if self.inf_auto_find_chkpt:
175 |             print('-----Auto Selecting Checkpoint Based On "%s" Through "%s" Comparison' % \
176 | (self.inf_auto_metric, self.inf_auto_comparator))
177 | model_path, stat = get_best_chkpts(self.save_dir, self.inf_auto_metric, self.inf_auto_comparator)
178 | print('Selecting: %s' % model_path)
179 | print('With The Following Statistics:')
180 | for key, value in stat.items():
181 | print('\t%s: %s' % (key, value))
182 | else:
183 | model_path = self.inf_model_path
184 |
185 | model_constructor = self.get_model()
186 | pred_config = PredictConfig(
187 | model = model_constructor(),
188 | session_init = get_model_loader(model_path),
189 | input_names = self.eval_inf_input_tensor_names,
190 | output_names = self.eval_inf_output_tensor_names)
191 | predictor = OfflinePredictor(pred_config)
192 |
193 | save_dir = self.inf_output_dir
194 | file_list = glob.glob('%s/*%s' % (self.inf_data_dir, self.inf_imgs_ext))
195 | file_list.sort() # ensure same order
196 |
197 | rm_n_mkdir(save_dir)
198 | for filename in file_list:
199 | filename = os.path.basename(filename)
200 | basename = filename.split('.')[0]
201 | print(self.inf_data_dir, basename, end=' ', flush=True)
202 |
203 | ##
204 | if self.data_type != 'pannuke':
205 | img = cv2.imread(self.inf_data_dir + '/' + filename)
206 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
207 | else:
208 | img = np.load(self.inf_data_dir + '/' + filename)
209 | ##
210 | pred_coded, pred_ord = self.__gen_prediction(img, predictor)
211 | sio.savemat('%s/%s.mat' % (save_dir, basename), {'result':[pred_coded], 'result-ord':[pred_ord]})
212 | print('FINISH')
213 |
214 |
215 |
216 |
217 | ####
218 | if __name__ == '__main__':
219 | parser = argparse.ArgumentParser()
220 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
221 | args = parser.parse_args()
222 |
223 | if args.gpu:
224 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
225 |
226 | inferer = Inferer()
227 | inferer.run()
228 |
--------------------------------------------------------------------------------
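
A note on the output format: for each input image, Inferer.run() above saves a .mat file whose 'result' and 'result-ord' keys hold the coded and ordinal prediction maps, cropped back to the input size. A minimal sketch of loading one back (the path is hypothetical; the two keys come from the sio.savemat call in run()):

import scipy.io as sio

out = sio.loadmat('output/image_01.mat')  # hypothetical output path
pred_coded = out['result'][0]             # coded prediction map
pred_ord = out['result-ord'][0]           # ordinal prediction map
print(pred_coded.shape, pred_ord.shape)
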
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import json
4 | import os
5 | import random
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 | from tensorpack import Inferencer, logger
10 | from tensorpack.callbacks import (DataParallelInferenceRunner, ModelSaver,
11 | MaxSaver, ScheduledHyperParamSetter, RunOp)
12 | from tensorpack.tfutils import SaverRestore, get_model_loader
13 | from tensorpack.train import (SyncMultiGPUTrainerParameterServer, TrainConfig,
14 | launch_train_with_config)
15 |
16 | import loader.loader as loader
17 | from config import Config
18 | from misc.utils import get_files, rm_n_mkdir
19 |
20 | import matplotlib.pyplot as plt
21 |
22 |
23 | class StatCollector(Inferencer, Config):
24 | """
25 | Accumulate output of inference during training.
26 | After the inference finishes, calculate the statistics
27 | """
28 | def __init__(self, prefix='valid'):
29 | super(StatCollector, self).__init__()
30 | self.prefix = prefix
31 |
32 | def _get_fetches(self):
33 | return self.train_inf_output_tensor_names
34 |
35 | def _before_inference(self):
36 |
37 | self.over_inter_np = 0
38 | self.over_total_np = 0
39 | self.over_correct_np = 0
40 | self.nr_pixels = 0
41 | self.over_type_dict = {}
42 | for type_name, type_id in self.nuclei_type_dict.items():
43 | self.over_type_dict['fdetect_inter_%s' % (type_name)] = 0
44 | self.over_type_dict['fdetect_total_%s' % (type_name)] = 0
45 |
46 | def _on_fetches(self, outputs):
47 | pred, true = outputs
48 |
49 | def _dice_info(true, pred, label):
50 | true = np.array(true == label, np.int32)
51 | pred = np.array(pred == label, np.int32)
52 | inter = (pred * true).sum()
53 | total = (pred + true).sum()
54 | return inter, total
55 |
56 | def _fdetect_info(true, pred, label):
57 | tp_dt = ((true == label)&(pred == label)).sum()
58 | tn_dt = ((true != label)&(true != 0)&(pred != label)&(pred != 0)).sum()
59 | fp_dt = ((true != label)&(true != 0)&(pred == label)).sum()
60 | fn_dt = ((true == label)&(pred != label)&(pred != 0)).sum()
61 | fp_d = ((true == 0)&(pred != 0)).sum()
62 | fn_d = ((true != 0)&(pred == 0)).sum()
63 | inter = 2 * (tp_dt + tn_dt)
64 | total = 2 * (tp_dt + tn_dt + fp_dt + fn_dt) + fp_d + fn_d
65 | return inter, total
66 |
67 | pred_type = pred[...,:self.nr_types]
68 | pred_inst = pred[...,self.nr_types:]
69 | true_type = true[...,1]
70 |
71 | self.nr_pixels += np.size(true[...,:1])
72 | pred_np = pred_inst[...,0]
73 | true_np = true[...,0]
74 | pred_np[pred_np >= 0.5] = 1.0
75 | pred_np[pred_np < 0.5] = 0.0
76 | correct = (pred_np == true_np).sum()
77 | self.over_correct_np += correct
78 | inter, total = _dice_info(true_np, pred_np, 1)
79 | self.over_inter_np += inter
80 | self.over_total_np += total
81 |
82 |
83 | pred_type = np.argmax(pred_type, axis=-1)
84 | for type_name, type_id in self.nuclei_type_dict.items():
85 | inter, total = _fdetect_info(true_type, pred_type, type_id)
86 | self.over_type_dict['fdetect_inter_%s' % (type_name)] += inter
87 | self.over_type_dict['fdetect_total_%s' % (type_name)] += total
88 |
89 | def _after_inference(self):
90 |
91 | stat_dict = {}
92 |
93 | stat_dict[self.prefix + '_acc' ] = self.over_correct_np / self.nr_pixels
94 | stat_dict[self.prefix + '_dice'] = 2 * self.over_inter_np / (self.over_total_np + 1.0e-8)
95 |
96 | if self.type_classification:
97 | for type_name, type_id in self.nuclei_type_dict.items():
98 | stat_dict['%s_fdetect_%s' % (self.prefix, type_name)] = (self.over_type_dict['fdetect_inter_%s' % (type_name)] + 1.0e-8) / (self.over_type_dict['fdetect_total_%s' % (type_name)] + 1.0e-8)
99 |
100 | return stat_dict
101 |
102 |
103 | ####
104 |
105 | ###########################################
106 | class Trainer(Config):
107 | ####
108 | def get_datagen(self, batch_size, mode='train', view=False):
109 | train_set = get_files(self.train_dir, self.data_ext)
110 | test_set = get_files(self.valid_dir, self.data_ext)
111 | if mode == 'train':
112 | augmentors = self.get_train_augmentors(
113 | self.train_input_shape,
114 | self.train_mask_shape,
115 | view)
116 | data_files = train_set
117 | data_generator = loader.train_generator
118 | nr_procs = self.nr_procs_train
119 | else:
120 | augmentors = self.get_valid_augmentors(
121 | self.infer_input_shape,
122 | self.infer_mask_shape,
123 | view)
124 | data_files = test_set
125 | data_generator = loader.valid_generator
126 | nr_procs = self.nr_procs_valid
127 |
128 | # set nr_proc=1 for viewing to ensure clean ctrl-z
129 | nr_procs = 1 if view else nr_procs
130 | dataset = loader.DatasetSerial(data_files)
131 | datagen = data_generator(dataset,
132 | shape_aug=augmentors[0],
133 | input_aug=augmentors[1],
134 | label_aug=augmentors[2],
135 | batch_size=batch_size,
136 | nr_procs=nr_procs)
137 |
138 | return datagen
139 | ####
140 | def view_dataset(self, mode='train'):
141 | assert mode == 'train' or mode == 'valid', "Invalid view mode"
142 | datagen = self.get_datagen(4, mode=mode, view=True)
143 | loader.visualize(datagen, 4)
144 | return
145 | ####
146 | def run_once(self, opt, idx, sess_init=None, save_dir=None):
147 | ####
148 | train_datagen = self.get_datagen(opt['train_batch_size'], mode='train')
149 | valid_datagen = self.get_datagen(opt['infer_batch_size'], mode='valid')
150 |
151 | ###### must be called before ModelSaver
152 | if save_dir is None:
153 | logger.set_logger_dir(self.save_dir)
154 | else:
155 | logger.set_logger_dir(save_dir)
156 |
157 | ######
158 | model_flags = opt['model_flags']
159 | model = self.get_model(phase=idx)(**model_flags)
160 | ######
161 | callbacks=[
162 | ModelSaver(max_to_keep=opt['nr_epochs']),
163 | ]
164 | callbacks.append(RunOp(tf.tables_initializer(), run_as_trigger=False))
165 | for param_name, param_info in opt['manual_parameters'].items():
166 | model.add_manual_variable(param_name, param_info[0])
167 | callbacks.append(ScheduledHyperParamSetter(param_name, param_info[1]))
168 | # multi-GPU inference (with mandatory queue prefetch)
169 | infs = [StatCollector()]
170 | callbacks.append(DataParallelInferenceRunner(
171 | valid_datagen, infs, list(range(nr_gpus))))
172 | callbacks.append(MaxSaver('valid_dice'))
173 |
174 | ######
175 | steps_per_epoch = train_datagen.size() // nr_gpus
176 |
177 | config = TrainConfig(
178 | model = model,
179 | callbacks = callbacks ,
180 | dataflow = train_datagen ,
181 | steps_per_epoch = steps_per_epoch,
182 | max_epoch = opt['nr_epochs'],
183 | )
184 | config.session_init = sess_init
185 |
186 | launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpus))
187 | tf.reset_default_graph() # remove the entire graph in case of multiple runs
188 | return
189 | ####
190 | def run(self):
191 | def get_last_chkpt_path(prev_phase_dir):
192 | stat_file_path = prev_phase_dir + 'stats.json'
193 | with open(stat_file_path) as stat_file:
194 | info = json.load(stat_file)
195 | chkpt_list = [epoch_stat['global_step'] for epoch_stat in info]
196 | last_chkpts_path = "%smodel-%d.index" % (prev_phase_dir, max(chkpt_list))
197 | return last_chkpts_path
198 |
199 | phase_opts = self.training_phase
200 | if len(phase_opts) > 1: # NOTE: assumes a multi-phase schedule; a single-phase config would fall through without training
201 | for idx, opt in enumerate(phase_opts):
202 | random.seed(self.seed)
203 | np.random.seed(self.seed)
204 | tf.random.set_random_seed(self.seed)
205 | log_dir = '%s/%02d/' % (self.save_dir, idx)
206 | init_weights = None; pretrained_path = opt['pretrained_path'] # stays None => train from fresh weights
207 | if pretrained_path == -1:
208 | pretrained_path = get_last_chkpt_path(prev_log_dir)
209 | init_weights = SaverRestore(pretrained_path, ignore=['learning_rate'])
210 | elif pretrained_path is not None:
211 | init_weights = get_model_loader(pretrained_path)
212 | prev_log_dir = log_dir
213 | self.run_once(opt, idx, sess_init=init_weights, save_dir=log_dir)
214 | return
215 | ####
216 | ####
217 |
218 | ###########################################################################
219 |
220 | if __name__ == '__main__':
221 | parser = argparse.ArgumentParser()
222 | parser.add_argument('--gpu', help="comma separated list of GPU(s) to use.")
223 | parser.add_argument('--view', help="view the dataset; accepts either 'train' or 'valid' as input")
224 | args = parser.parse_args()
225 |
226 | trainer = Trainer()
227 | if args.view:
228 | trainer.view_dataset(mode=args.view)
229 | else:
230 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
231 | nr_gpus = len(args.gpu.split(','))
232 | trainer.run()
233 |
--------------------------------------------------------------------------------
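
For reference, a minimal NumPy sketch (toy values, outside tensorpack) of the running dice bookkeeping that StatCollector performs across _on_fetches and _after_inference above:

import numpy as np

pred_np = np.array([[0.7, 0.2], [0.6, 0.1]])    # toy nuclei-pixel probabilities
true_np = np.array([[1, 0], [1, 1]], np.int32)  # toy binary ground truth
pred_bin = (pred_np >= 0.5).astype(np.int32)    # same 0.5 threshold as _on_fetches
inter = (pred_bin * true_np).sum()              # accumulated as over_inter_np
total = (pred_bin + true_np).sum()              # accumulated as over_total_np
dice = 2 * inter / (total + 1.0e-8)             # formula from _after_inference
print(dice)                                     # 0.8 for these toy values
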
/compute_stats.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cProfile as profile
3 | import glob
4 | import os
5 |
6 | import cv2
7 | import numpy as np
8 | import scipy.io as sio
9 | import pandas as pd
10 |
11 | from metrics.stats_utils import *
12 | from config import Config
13 |
14 | cfg = Config()
15 |
16 | def run_nuclei_type_stat(pred_dir, true_dir, type_uid_list=None, exhaustive=True):
17 | """
18 | GT must be exhaustively annotated for instance location (detection)
19 |
20 | Args:
21 | true_dir, pred_dir: Directories containing a .mat annotation file for each image.
22 | Each .mat must contain:
23 |
24 | --`inst_centroid`: Nx2, contains the centroid-of-mass
25 | coordinates (X, Y) of N instances
26 | --`inst_type` : Nx1, the type of the instance at each index
27 |
28 | `inst_centroid` and `inst_type` must be aligned, and each
29 | index must be associated with the same instance
30 |
31 | type_uid_list : list of nuclei-type IDs for which the score should be calculated.
32 | Defaults to `None`, meaning every nuclei type available in the GT.
33 |
34 | exhaustive : Flag to indicate whether GT is exhaustively labelled
35 | for instance types
36 | """
37 | ###
38 | file_list = glob.glob(pred_dir + '*.mat')
39 | file_list.sort() # ensure same order
40 | paired_all = [] # unique matched index pair
41 | unpaired_true_all = [] # the index must exist in `true_inst_type_all` and unique
42 | unpaired_pred_all = [] # the index must exist in `pred_inst_type_all` and unique
43 | true_inst_type_all = [] # each index is 1 independent data point
44 | pred_inst_type_all = [] # each index is 1 independent data point
45 | for file_idx, filename in enumerate(file_list[:]):
46 | filename = os.path.basename(filename)
47 | basename = filename.split('.')[0]
48 | true_info = sio.loadmat(true_dir + basename + '.mat')
49 | # don't squeeze; there may be only 1 instance
50 | true_centroid = (true_info['inst_centroid']).astype('float32')
51 | true_inst_type = (true_info['inst_type']).astype('int32')
52 |
53 | if true_centroid.shape[0] != 0:
54 | true_inst_type = true_inst_type[:,0]
55 | else: # no instance at all
56 | true_centroid = np.array([[0, 0]])
57 | true_inst_type = np.array([0])
58 |
59 | # * convert the GT types for CoNSeP and GLySAC
60 | if cfg.data_type == 'consep':
61 | true_inst_type[(true_inst_type == 3) | (true_inst_type == 4)] = 3
62 | true_inst_type[(true_inst_type == 5) | (true_inst_type == 6) | (true_inst_type == 7)] = 4
63 | if cfg.data_type == 'glysac':
64 | true_inst_type[(true_inst_type == 1) | (true_inst_type == 2) | (true_inst_type == 9) | (true_inst_type == 10)] = 1
65 | true_inst_type[(true_inst_type == 4) | (true_inst_type == 5) | (true_inst_type == 6) | (true_inst_type == 7)] = 2
66 | true_inst_type[(true_inst_type == 8) | (true_inst_type == 3)] = 3
67 |
68 | pred_info = sio.loadmat(pred_dir + basename + '.mat')
69 | # don't squeeze; there may be only 1 instance
70 | pred_centroid = (pred_info['inst_centroid']).astype('float32')
71 | pred_inst_type = (pred_info['inst_type']).astype('int32')
72 |
73 | if pred_centroid.shape[0] != 0:
74 | pred_inst_type = pred_inst_type[:,0]
75 | else: # no instance at all
76 | pred_centroid = np.array([[0, 0]])
77 | pred_inst_type = np.array([0])
78 |
79 | # ! if pairing 1000 vs 1000 instances takes longer than 1 min, something is wrong with the coordinates
80 | paired, unpaired_true, unpaired_pred = pair_coordinates(true_centroid, pred_centroid, 12)
81 |
82 | # * Aggregate information
83 | # get the offset, as each index represents 1 independent instance
84 | true_idx_offset = true_idx_offset + true_inst_type_all[-1].shape[0] if file_idx != 0 else 0
85 | pred_idx_offset = pred_idx_offset + pred_inst_type_all[-1].shape[0] if file_idx != 0 else 0
86 | true_inst_type_all.append(true_inst_type)
87 | pred_inst_type_all.append(pred_inst_type)
88 |
89 | # increment the pairing index statistic
90 | if paired.shape[0] != 0: # ! sanity
91 | paired[:,0] += true_idx_offset
92 | paired[:,1] += pred_idx_offset
93 | paired_all.append(paired)
94 |
95 | unpaired_true += true_idx_offset
96 | unpaired_pred += pred_idx_offset
97 | unpaired_true_all.append(unpaired_true)
98 | unpaired_pred_all.append(unpaired_pred)
99 |
100 | paired_all = np.concatenate(paired_all, axis=0)
101 | unpaired_true_all = np.concatenate(unpaired_true_all, axis=0)
102 | unpaired_pred_all = np.concatenate(unpaired_pred_all, axis=0)
103 | true_inst_type_all = np.concatenate(true_inst_type_all, axis=0)
104 | pred_inst_type_all = np.concatenate(pred_inst_type_all, axis=0)
105 |
106 | paired_true_type = true_inst_type_all[paired_all[:,0]]
107 | paired_pred_type = pred_inst_type_all[paired_all[:,1]]
108 | unpaired_true_type = true_inst_type_all[unpaired_true_all]
109 | unpaired_pred_type = pred_inst_type_all[unpaired_pred_all]
110 |
111 | ###
112 | def _f1_type(paired_true, paired_pred, unpaired_true, unpaired_pred, type_id, w):
113 | type_samples = (paired_true == type_id) | (paired_pred == type_id)
114 |
115 | paired_true = paired_true[type_samples]
116 | paired_pred = paired_pred[type_samples]
117 |
118 | tp_dt = ((paired_true == type_id) & (paired_pred == type_id)).sum()
119 | tn_dt = ((paired_true != type_id) & (paired_pred != type_id)).sum()
120 | fp_dt = ((paired_true != type_id) & (paired_pred == type_id)).sum()
121 | fn_dt = ((paired_true == type_id) & (paired_pred != type_id)).sum()
122 |
123 | if not exhaustive:
124 | ignore = (paired_true == -1).sum()
125 | fp_dt -= ignore
126 |
127 | summary_1 = {}
128 | p_summ_1 = np.copy(paired_true)
129 | val, cnt = np.unique(p_summ_1[(paired_true != type_id) & (paired_pred == type_id)], return_counts=True)
130 | for name, num in zip(val, cnt):
131 | summary_1[name] = summary_1.get(name, 0) + num
132 |
133 | summary = {}
134 | p_summ = np.copy(paired_pred)
135 | val, cnt = np.unique(p_summ[(paired_true == type_id) & (paired_pred != type_id)], return_counts=True)
136 | for name, num in zip(val, cnt):
137 | summary[name] = summary.get(name, 0) + num
138 |
139 | fp_d = (unpaired_pred == type_id).sum()
140 | fn_d = (unpaired_true == type_id).sum()
141 |
142 |
143 | f1_type = (2 * (tp_dt + tn_dt)) / \
144 | (2 * (tp_dt + tn_dt) + w[0] * fp_dt + w[1] * fn_dt \
145 | + w[2] * fp_d + w[3] * fn_d)
146 | return f1_type
147 |
148 | # overall
149 | # * quite meaningless for a non-exhaustively annotated dataset
150 | w = [1, 1]
151 | tp_d = paired_pred_type.shape[0]
152 | fp_d = unpaired_pred_type.shape[0]
153 | fn_d = unpaired_true_type.shape[0]
154 |
155 | tp_tn_dt = (paired_pred_type == paired_true_type).sum()
156 | fp_fn_dt = (paired_pred_type != paired_true_type).sum()
157 |
158 | if not exhaustive:
159 | ignore = (paired_true_type == -1).sum()
160 | fp_fn_dt -= ignore
161 |
162 | acc_type = tp_tn_dt / (tp_tn_dt + fp_fn_dt)
163 | f1_d = 2 * tp_d / (2 * tp_d + w[0] * fp_d + w[1] * fn_d)
164 |
165 | w = [2, 2, 1, 1]
166 |
167 | if type_uid_list is None:
168 | type_uid_list = np.unique(true_inst_type_all).tolist()
169 |
170 |
171 | results_list = [f1_d, acc_type]
172 | for type_uid in type_uid_list:
173 | f1_type = _f1_type(paired_true_type, paired_pred_type,
174 | unpaired_true_type, unpaired_pred_type, type_uid, w) # order matches the _f1_type signature
175 | results_list.append(f1_type)
176 |
177 | np.set_printoptions(formatter={'float': '{: 0.5f}'.format})
178 | print(np.array(results_list))
179 | return
180 |
181 | def run_nuclei_inst_stat(pred_dir, true_dir, print_img_stats=False, ext='.mat'):
182 | # print stats of each image
183 | print(pred_dir)
184 |
185 | file_list = glob.glob('%s/*%s' % (pred_dir, ext))
186 | file_list.sort() # ensure same order
187 | metrics = [[], [], [], [], [], []]
188 | for filename in file_list[:]:
189 | filename = os.path.basename(filename)
190 | basename = filename.split('.')[0]
191 |
192 | true = sio.loadmat(true_dir + basename + '.mat')
193 | true = (true['inst_map']).astype('int32')
194 |
195 | pred = sio.loadmat(pred_dir + basename + '.mat')
196 | pred = (pred['inst_map']).astype('int32')
197 |
198 | # to ensure that the instance numbering is contiguous
199 | pred = remap_label(pred, by_size=False)
200 | true = remap_label(true, by_size=False)
201 |
202 | pq_info = get_fast_pq(true, pred, match_iou=0.5)[0]
203 | metrics[0].append(get_dice_1(true, pred))
204 | metrics[1].append(pq_info[0]) # dq
205 | metrics[2].append(pq_info[1]) # sq
206 | metrics[3].append(pq_info[2]) # pq
207 | metrics[4].append(get_fast_aji_plus(true, pred))
208 | metrics[5].append(get_fast_aji(true, pred))
209 |
210 | if print_img_stats:
211 | print(basename, end="\t")
212 | for scores in metrics:
213 | print("%f " % scores[-1], end=" ")
214 | print()
215 | ####
216 | metrics = np.array(metrics)
217 | metrics_avg = np.mean(metrics, axis=-1)
218 | np.set_printoptions(formatter={'float': '{: 0.5f}'.format})
219 | print(metrics_avg)
220 | metrics_avg = list(metrics_avg)
221 | return metrics
222 |
223 | if __name__ == '__main__':
224 | parser = argparse.ArgumentParser()
225 | parser.add_argument('--mode', help="mode to run the measurement: "
226 | "`type` for nuclei instance type classification or "
227 | "`instance` for nuclei instance segmentation",
228 | nargs='?', default='instance', const='instance')
229 | parser.add_argument('--pred_dir', help="path to the prediction directory", nargs='?', default='', const='')
230 | parser.add_argument('--true_dir', help="path to the ground-truth directory", nargs='?', default='', const='')
231 | args = parser.parse_args()
232 |
233 | if args.mode == 'instance':
234 | run_nuclei_inst_stat(args.pred_dir, args.true_dir, print_img_stats=False)
235 | if args.mode == 'type':
236 | run_nuclei_type_stat(args.pred_dir, args.true_dir)
237 |
238 |
--------------------------------------------------------------------------------
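
As a quick illustration of the scoring above: _f1_type combines classification errors among paired instances (fp_dt, fn_dt, weighted by 2) with detection errors among unpaired instances (fp_d, fn_d, weighted by 1). A minimal sketch with hypothetical counts:

# hypothetical counts for one nuclei type
tp_dt, tn_dt = 40, 30  # correctly typed paired instances
fp_dt, fn_dt = 5, 7    # mistyped paired instances
fp_d, fn_d = 3, 4      # unpaired (detection) errors
w = [2, 2, 1, 1]       # weights set in run_nuclei_type_stat
f1_type = (2 * (tp_dt + tn_dt)) / (
    2 * (tp_dt + tn_dt) + w[0] * fp_dt + w[1] * fn_dt + w[2] * fp_d + w[3] * fn_d)
print(f1_type)  # ~0.819
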
/model/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import numpy as np
4 |
5 | import tensorflow as tf
6 |
7 | from tensorpack import *
8 | from tensorpack.tfutils.symbolic_functions import *
9 | from tensorpack.tfutils.summary import *
10 |
11 | from matplotlib import cm
12 | from config import Config
13 | import scipy.io as sio
14 | import glob
15 |
16 | cfg = Config()
17 | ####
18 | def resize_op(x, height_factor=None, width_factor=None, size=None,
19 | interp='bicubic', data_format='channels_last'):
20 | """
21 | Resize by a factor if `size=None` else resize to `size`
22 | """
23 | original_shape = x.get_shape().as_list()
24 | if size is not None:
25 | if data_format == 'channels_first':
26 | x = tf.transpose(x, [0, 2, 3, 1])
27 | if interp == 'bicubic':
28 | x = tf.image.resize_bicubic(x, size)
29 | elif interp == 'bilinear':
30 | x = tf.image.resize_bilinear(x, size)
31 | else:
32 | x = tf.image.resize_nearest_neighbor(x, size)
33 | x = tf.transpose(x, [0, 3, 1, 2])
34 | x.set_shape((None,
35 | original_shape[1] if original_shape[1] is not None else None,
36 | size[0], size[1]))
37 | else:
38 | if interp == 'bicubic':
39 | x = tf.image.resize_bicubic(x, size)
40 | elif interp == 'bilinear':
41 | x = tf.image.resize_bilinear(x, size)
42 | else:
43 | x = tf.image.resize_nearest_neighbor(x, size)
44 | x.set_shape((None,
45 | size[0], size[1],
46 | original_shape[3] if original_shape[3] is not None else None))
47 | else:
48 | if data_format == 'channels_first':
49 | new_shape = tf.cast(tf.shape(x)[2:], tf.float32)
50 | new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('float32'))
51 | new_shape = tf.cast(new_shape, tf.int32)
52 | x = tf.transpose(x, [0, 2, 3, 1])
53 | if interp == 'bicubic':
54 | x = tf.image.resize_bicubic(x, new_shape)
55 | elif interp == 'bilinear':
56 | x = tf.image.resize_bilinear(x, new_shape)
57 | else:
58 | x = tf.image.resize_nearest_neighbor(x, new_shape)
59 | x = tf.transpose(x, [0, 3, 1, 2])
60 | x.set_shape((None,
61 | original_shape[1] if original_shape[1] is not None else None,
62 | int(original_shape[2] * height_factor) if original_shape[2] is not None else None,
63 | int(original_shape[3] * width_factor) if original_shape[3] is not None else None))
64 | else:
65 | original_shape = x.get_shape().as_list()
66 | new_shape = tf.cast(tf.shape(x)[1:3], tf.float32)
67 | new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('float32'))
68 | new_shape = tf.cast(new_shape, tf.int32)
69 | if interp == 'bicubic':
70 | x = tf.image.resize_bicubic(x, new_shape)
71 | elif interp == 'bilinear':
72 | x = tf.image.resize_bilinear(x, new_shape)
73 | else:
74 | x = tf.image.resize_nearest_neighbor(x, new_shape)
75 | x.set_shape((None,
76 | int(original_shape[1] * height_factor) if original_shape[1] is not None else None,
77 | int(original_shape[2] * width_factor) if original_shape[2] is not None else None,
78 | original_shape[3] if original_shape[3] is not None else None))
79 | return x
80 |
81 | ####
82 | def crop_op(x, cropping, data_format='channels_first'):
83 | """
84 | Center crop image
85 | Args:
86 | cropping is the subtracted portion
87 | """
88 | crop_t = cropping[0] // 2
89 | crop_b = cropping[0] - crop_t
90 | crop_l = cropping[1] // 2
91 | crop_r = cropping[1] - crop_l
92 | if data_format == 'channels_first':
93 | x = x[:,:,crop_t:-crop_b,crop_l:-crop_r]
94 | else:
95 | x = x[:,crop_t:-crop_b,crop_l:-crop_r]
96 | return x
97 | ####
98 |
99 | def label_smoothing(label, factor=0.2):
100 | """
101 | Soften one-hot labels (shrink the hard label by `factor` and redistribute the mass uniformly across classes; prevents overconfidence and improves calibration)
102 | Args:
103 | label: one-hot format
104 | Return:
105 | Soft Label
106 | """
107 | label *= (1-factor)
108 | label += factor/(label.get_shape().as_list()[-1])
109 | return label
110 |
111 | ####
112 | def categorical_crossentropy(output, target):
113 | """
114 | categorical cross-entropy; accepts probabilities, not logits
115 | """
116 | # scale preds so that the class probs of each sample sum to 1
117 | output /= tf.reduce_sum(output,
118 | reduction_indices=len(output.get_shape()) - 1,
119 | keepdims=True)
120 | # manual computation of crossentropy
121 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype)
122 | output = tf.clip_by_value(output, epsilon, 1. - epsilon)
123 | return - tf.reduce_sum(target * tf.log(output),
124 | reduction_indices=len(output.get_shape()) - 1)
125 |
126 | def categorical_crossentropy_modified(output, target):
127 | """
128 | weighted categorical cross-entropy; accepts probabilities, not logits
129 | """
130 | w = np.array([0.8, 1.2]) # change according to the dataset; w[0] and w[1] are the weights for background and foreground, respectively
131 | output /= tf.reduce_sum(output,
132 | reduction_indices=len(output.get_shape()) - 1,
133 | keepdims=True)
134 | # manual computation of crossentropy
135 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype)
136 | output = tf.clip_by_value(output, epsilon, 1. - epsilon)
137 | loss = tf.zeros_like(output[...,0])
138 | for i in range(len(w)):
139 | loss += w[i] * target[...,i] * tf.log(output[...,i])
140 | return -loss
141 | ####
142 |
143 | def check_weight_loss(train_dir):
144 | '''
145 | Calculate the class weights used by focal_loss_modified()
146 | Args:
147 | train_dir: training directory
148 | Return:
149 | w: numpy array
150 | '''
151 | file_list = glob.glob(train_dir + '/*.mat')
152 | N = {}
153 | for file in file_list:
154 | type_map = sio.loadmat(file)['type_map']
155 | if cfg.data_type == 'consep':
156 | type_map[(type_map == 3) | (type_map == 4)] = 3
157 | type_map[(type_map == 5) | (type_map == 6) | (type_map == 7)] = 4
158 | elif cfg.data_type == 'glysac':
159 | type_map[(type_map == 1) | (type_map == 2) | (type_map == 9) | (type_map == 10)] = 1
160 | type_map[(type_map == 4) | (type_map == 5) | (type_map == 6) | (type_map == 7)] = 2
161 | type_map[(type_map == 8) | (type_map == 3)] = 3
162 | val, cnt = np.unique(type_map, return_counts=True)
163 | for idx, type_id in enumerate(val):
164 | N[type_id] = N.get(type_id, 0) + cnt[idx]
165 | N = sorted(N.items())
166 | N = [val for key, val in N]
167 | c = len(N)
168 | N = np.array(N)
169 | w = np.power(N[0]/N, 1/3)
170 | w = w/w.sum() * c
171 | return w
172 |
173 |
174 | ####
175 | def focal_loss_modified(output, target, gamma=1):
176 | w = np.array([0.3, 2.2, 0.96, 0.73, 0.81]) # calculated by check_weight_loss(); modify according to the dataset
177 | output /= tf.reduce_sum(output, reduction_indices=len(output.get_shape()) - 1,
178 | keepdims=True)
179 | # manual computation of focal loss
180 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype)
181 | output = tf.clip_by_value(output, epsilon, 1. - epsilon)
182 | loss = tf.zeros_like(output[...,0])
183 | for i in range(len(w)):
184 | loss += w[i] * target[...,i] * tf.pow((1 - output[...,i]), gamma) * tf.log(output[...,i])
185 | return -loss
186 |
187 |
188 | ####
189 | def focal_loss(output, target, gamma=1):
190 | """
191 | categorical focal loss; accepts probabilities, not logits
192 | Parameters:
193 | -----------
194 | output: Predict type of each class [N, H, W, nr_classes]
195 | target: Ground Truth in one-hot encoding [N, H, W, nr_classes]
196 | """
197 | # scale preds so that the class probs of each sample sum to 1
198 | output /= tf.reduce_sum(output, reduction_indices=len(output.get_shape()) - 1,
199 | keepdims=True)
200 | # manual computation of focal loss
201 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype)
202 | output = tf.clip_by_value(output, epsilon, 1. - epsilon)
203 | return -tf.reduce_sum(target * tf.pow((1- output), gamma) * tf.log(output), reduction_indices=len(output.get_shape()) - 1)
204 |
205 |
206 |
207 | ####
208 | def dice_loss(output, target, loss_type='sorensen', axis=None, smooth=1e-3):
209 | """Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity
210 | of two batch of data, usually be used for binary image segmentation
211 | i.e. labels are binary. The coefficient between 0 to 1, 1 means totally match.
212 |
213 | Parameters
214 | -----------
215 | output : Tensor
216 | A distribution with shape: [batch_size, ....], (any dimensions).
217 | target : Tensor
218 | The target distribution, format the same with `output`.
219 | loss_type : str
220 | ``jaccard`` or ``sorensen``, default is ``sorensen``.
221 | axis : tuple of int
222 | All dimensions are reduced, default ``None`` (reduce over all dimensions).
223 | smooth : float
224 | This small value will be added to the numerator and denominator.
225 | - If both output and target are empty, it makes sure dice is 1.
226 | - If either output or target is empty (all pixels are background),
227 | dice = ``smooth/(small_value + smooth)``; if smooth is very small,
228 | dice is close to 0 (even when the image values are below the threshold),
229 | so in that case a higher smooth gives a higher dice.
230 |
231 | Examples
232 | ---------
233 | >>> loss = dice_loss(outputs, y_)
234 | """
235 | target = tf.squeeze(tf.cast(target, tf.float32))
236 | output = tf.squeeze(tf.cast(output, tf.float32))
237 |
238 | inse = tf.reduce_sum(output * target, axis=axis)
239 | if loss_type == 'jaccard':
240 | l = tf.reduce_sum(output * output, axis=axis)
241 | r = tf.reduce_sum(target * target, axis=axis)
242 | elif loss_type == 'sorensen':
243 | l = tf.reduce_sum(output, axis=axis)
244 | r = tf.reduce_sum(target, axis=axis)
245 | else:
246 | raise Exception("Unknown loss_type")
247 | # already flatten
248 | dice = 1.0 - (2. * inse + smooth) / (l + r + smooth)
249 | ##
250 | return dice
251 |
252 | ####
253 |
254 | def colorize(value, vmin=None, vmax=None, cmap=None):
255 | """
256 | Arguments:
257 | - value: input tensor, NHWC ('channels_last')
258 | - vmin: the minimum value of the range used for normalization.
259 | (Default: value minimum)
260 | - vmax: the maximum value of the range used for normalization.
261 | (Default: value maximum)
262 | - cmap: a valid cmap name for use with matplotlib's `get_cmap`.
263 | (Default: 'gray')
264 | Example usage:
265 | ```
266 | output = tf.random_uniform(shape=[1, 256, 256, 1])
267 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis')
268 | tf.summary.image('output', output_color)
269 | ```
270 |
271 | Returns a 4D tensor of shape [batch, height, width, 3], uint8.
272 | """
273 |
274 | # normalize
275 | if vmin is None:
276 | vmin = tf.reduce_min(value, axis=[1,2])
277 | vmin = tf.reshape(vmin, [-1, 1, 1])
278 | if vmax is None:
279 | vmax = tf.reduce_max(value, axis=[1,2])
280 | vmax = tf.reshape(vmax, [-1, 1, 1])
281 | value = (value - vmin) / (vmax - vmin) # vmin..vmax
282 |
283 | # squeeze last dim if it exists
284 | # NOTE: will throw error if use get_shape()
285 | # value = tf.squeeze(value)
286 |
287 | # quantize
288 | value = tf.round(value * 255)
289 | indices = tf.cast(value, np.int32)
290 |
291 | # gather
292 | colormap = cm.get_cmap(cmap if cmap is not None else 'gray')
293 | colors = colormap(np.arange(256))[:, :3]
294 | colors = tf.constant(colors, dtype=tf.float32)
295 | value = tf.gather(colors, indices)
296 | value = tf.cast(value * 255, tf.uint8)
297 | return value
298 | ####
299 | def make_image(x, cy, cx, scale_y, scale_x):
300 | """
301 | Take 1st image from x and turn channels representations
302 | into 2D image, with cx number of channels in x-axis and
303 | cy number of channels in y-axis
304 | """
305 | # norm x for better visual
306 | x = tf.transpose(x,(0,2,3,1)) # NHWC
307 | max_x = tf.reduce_max(x, axis=-1, keep_dims=True)
308 | min_x = tf.reduce_min(x, axis=-1, keep_dims=True)
309 | x = 255 * (x - min_x) / (max_x - min_x)
310 | ###
311 | x_shape = tf.shape(x)
312 | channels = x_shape[-1]
313 | iy , ix = x_shape[1], x_shape[2]
314 | ###
315 | x = tf.slice(x,(0,0,0,0),(1,-1,-1,-1))
316 | x = tf.reshape(x,(iy,ix,channels))
317 | ix += 4
318 | iy += 4
319 | x = tf.image.resize_image_with_crop_or_pad(x, iy, ix)
320 | x = tf.reshape(x,(iy,ix,cy,cx))
321 | x = tf.transpose(x,(2,0,3,1)) #cy,iy,cx,ix
322 | x = tf.reshape(x,(1,cy*iy,cx*ix,1))
323 | x = resize_op(x, scale_y, scale_x)
324 | return tf.cast(x, tf.uint8)
325 | ####
326 |
327 | ####
328 | def gen_disc_labels(nrof_labels, min, max):
329 | min += 1
330 | max += 1
331 |
332 | labels = []
333 | keys = []
334 | for k in range(nrof_labels + 1):
335 | disc_label = math.exp(math.log(min) + (math.log(max / min) * k) / nrof_labels)
336 | labels.append(disc_label - 1)
337 | keys.append(k)
338 |
339 |
340 | return labels, tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(keys, labels, value_dtype=tf.float64), 0)
341 |
342 | ####
343 | def replacenan(t):
344 | return tf.where(tf.is_nan(t), tf.zeros_like(t), t)
345 |
--------------------------------------------------------------------------------
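
The weighting rule in check_weight_loss above is w_i = (N_0 / N_i)^(1/3), renormalised so the weights sum to the number of classes, which gives rarer classes larger weights. A minimal sketch with hypothetical pixel counts:

import numpy as np

# hypothetical per-class pixel counts; class 0 is background
N = np.array([9.0e6, 2.0e5, 8.0e5, 6.0e5, 5.0e5])
w = np.power(N[0] / N, 1 / 3)  # cube-root rule from check_weight_loss
w = w / w.sum() * len(N)       # renormalise so the weights sum to len(N)
print(w)                       # candidate values for focal_loss_modified's `w`
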
/metrics/stats_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import linear_sum_assignment
3 |
4 |
5 | #####--------------------------Optimized for Speed
6 | def get_fast_aji(true, pred):
7 | """
8 | AJI version distributed by MoNuSeg; has no permutation problem but suffers from
9 | over-penalisation, similar to DICE2.
10 |
11 | Fast computation requires instance IDs to be in contiguous ordering, i.e. [1, 2, 3, 4]
12 | not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
13 | effect on the result.
14 | """
15 | true = np.copy(true)
16 | pred = np.copy(pred)
17 | true_id_list = list(np.unique(true))
18 | pred_id_list = list(np.unique(pred))
19 |
20 | true_masks = [None,]
21 | for t in true_id_list[1:]:
22 | t_mask = np.array(true == t, np.uint8)
23 | true_masks.append(t_mask)
24 |
25 | pred_masks = [None,]
26 | for p in pred_id_list[1:]:
27 | p_mask = np.array(pred == p, np.uint8)
28 | pred_masks.append(p_mask)
29 |
30 | # prefill with value
31 | pairwise_inter = np.zeros([len(true_id_list) -1,
32 | len(pred_id_list) -1], dtype=np.float64)
33 | pairwise_union = np.zeros([len(true_id_list) -1,
34 | len(pred_id_list) -1], dtype=np.float64)
35 |
36 | # caching pairwise
37 | for true_id in true_id_list[1:]: # 0-th is background
38 | t_mask = true_masks[true_id]
39 | pred_true_overlap = pred[t_mask > 0]
40 | pred_true_overlap_id = np.unique(pred_true_overlap)
41 | pred_true_overlap_id = list(pred_true_overlap_id)
42 | for pred_id in pred_true_overlap_id:
43 | if pred_id == 0: # ignore
44 | continue # overlapping background
45 | p_mask = pred_masks[pred_id]
46 | total = (t_mask + p_mask).sum()
47 | inter = (t_mask * p_mask).sum()
48 | pairwise_inter[true_id-1, pred_id-1] = inter
49 | pairwise_union[true_id-1, pred_id-1] = total - inter
50 | #
51 | pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
52 | # pair of pred that gives the highest iou for each true; don't care
53 | # about reusing a pred instance multiple times
54 | paired_pred = np.argmax(pairwise_iou, axis=1)
55 | pairwise_iou = np.max(pairwise_iou, axis=1)
56 | # exclude those that don't have any intersection
57 | paired_true = np.nonzero(pairwise_iou > 0.0)[0]
58 | paired_pred = paired_pred[paired_true]
59 | overall_inter = (pairwise_inter[paired_true, paired_pred]).sum()
60 | overall_union = (pairwise_union[paired_true, paired_pred]).sum()
61 | #
62 | paired_true = (list(paired_true + 1)) # index to instance ID
63 | paired_pred = (list(paired_pred + 1))
64 | # add all unpaired GT and Prediction into the union
65 | unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true])
66 | unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred])
67 | for true_id in unpaired_true:
68 | overall_union += true_masks[true_id].sum()
69 | for pred_id in unpaired_pred:
70 | overall_union += pred_masks[pred_id].sum()
71 | aji_score = overall_inter / overall_union
72 | return aji_score
73 | #####
74 | def get_fast_aji_plus(true, pred):
75 | """
76 | AJI+, an AJI version with maximal unique pairing to obtain the overall intersection.
77 | Every prediction instance is paired with at most 1 GT instance (1 to 1) mapping, unlike AJI
78 | where a prediction instance can be paired against many GT instances (1 to many).
79 | Remaining unpaired GT and Prediction instances will be added to the overall union.
80 | The 1 to 1 mapping prevents AJI's over-penalisation from happening.
81 |
82 | Fast computation requires instance IDs to be in contiguous ordering, i.e. [1, 2, 3, 4]
83 | not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
84 | effect on the result.
85 | """
86 | true = np.copy(true) # ? do we need this
87 | pred = np.copy(pred)
88 | true_id_list = list(np.unique(true))
89 | pred_id_list = list(np.unique(pred))
90 |
91 | true_masks = [None,]
92 | for t in true_id_list[1:]:
93 | t_mask = np.array(true == t, np.uint8)
94 | true_masks.append(t_mask)
95 |
96 | pred_masks = [None,]
97 | for p in pred_id_list[1:]:
98 | p_mask = np.array(pred == p, np.uint8)
99 | pred_masks.append(p_mask)
100 |
101 | # prefill with value
102 | pairwise_inter = np.zeros([len(true_id_list) -1,
103 | len(pred_id_list) -1], dtype=np.float64)
104 | pairwise_union = np.zeros([len(true_id_list) -1,
105 | len(pred_id_list) -1], dtype=np.float64)
106 |
107 | # caching pairwise
108 | for true_id in true_id_list[1:]: # 0-th is background
109 | t_mask = true_masks[true_id]
110 | pred_true_overlap = pred[t_mask > 0]
111 | pred_true_overlap_id = np.unique(pred_true_overlap)
112 | pred_true_overlap_id = list(pred_true_overlap_id)
113 | for pred_id in pred_true_overlap_id:
114 | if pred_id == 0: # ignore
115 | continue # overlapping background
116 | p_mask = pred_masks[pred_id]
117 | total = (t_mask + p_mask).sum()
118 | inter = (t_mask * p_mask).sum()
119 | pairwise_inter[true_id-1, pred_id-1] = inter
120 | pairwise_union[true_id-1, pred_id-1] = total - inter
121 | #
122 | pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
123 | #### Munkres pairing to find maximal unique pairing
124 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
125 | ### extract the paired cost and remove invalid pair
126 | paired_iou = pairwise_iou[paired_true, paired_pred]
127 | # now select all those paired with iou != 0.0 i.e have intersection
128 | paired_true = paired_true[paired_iou > 0.0]
129 | paired_pred = paired_pred[paired_iou > 0.0]
130 | paired_inter = pairwise_inter[paired_true, paired_pred]
131 | paired_union = pairwise_union[paired_true, paired_pred]
132 | paired_true = (list(paired_true + 1)) # index to instance ID
133 | paired_pred = (list(paired_pred + 1))
134 | overall_inter = paired_inter.sum()
135 | overall_union = paired_union.sum()
136 | # add all unpaired GT and Prediction into the union
137 | unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true])
138 | unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred])
139 | for true_id in unpaired_true:
140 | overall_union += true_masks[true_id].sum()
141 | for pred_id in unpaired_pred:
142 | overall_union += pred_masks[pred_id].sum()
143 | #
144 | aji_score = overall_inter / overall_union
145 | return aji_score
146 | #####
147 | def get_fast_pq(true, pred, match_iou=0.5):
148 | """
149 | `match_iou` is the IoU threshold level to determine the pairing between
150 | GT instances `p` and prediction instances `g`. `p` and `g` is a pair
151 | if IoU > `match_iou`. However, pair of `p` and `g` must be unique
152 | (1 prediction instance to 1 GT instance mapping).
153 |
154 | If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching
155 | in bipartite graphs) is calculated to find the maximal amount of unique pairings.
156 |
157 | If `match_iou` >= 0.5, all IoU(p,g) > 0.5 pairing is proven to be unique and
158 | the number of pairs is also maximal.
159 |
160 | Fast computation requires instance IDs to be in contiguous ordering,
161 | i.e. [1, 2, 3, 4] not [2, 3, 6, 10]. Please call `remap_label` beforehand;
162 | the `by_size` flag has no effect on the result.
163 |
164 | Returns:
165 | [dq, sq, pq]: measurement statistic
166 |
167 | [paired_true, paired_pred, unpaired_true, unpaired_pred]:
168 | pairing information to perform measurement
169 |
170 | """
171 | assert match_iou >= 0.0, "Can't be negative"
172 |
173 | true = np.copy(true)
174 | pred = np.copy(pred)
175 | true_id_list = list(np.unique(true))
176 | pred_id_list = list(np.unique(pred))
177 |
178 | true_masks = [None,]
179 | for t in true_id_list[1:]:
180 | t_mask = np.array(true == t, np.uint8)
181 | true_masks.append(t_mask)
182 |
183 | pred_masks = [None,]
184 | for p in pred_id_list[1:]:
185 | p_mask = np.array(pred == p, np.uint8)
186 | pred_masks.append(p_mask)
187 |
188 | # prefill with value
189 | pairwise_iou = np.zeros([len(true_id_list) -1,
190 | len(pred_id_list) -1], dtype=np.float64)
191 |
192 | # caching pairwise iou
193 | for true_id in true_id_list[1:]: # 0-th is background
194 | t_mask = true_masks[true_id]
195 | pred_true_overlap = pred[t_mask > 0]
196 | pred_true_overlap_id = np.unique(pred_true_overlap)
197 | pred_true_overlap_id = list(pred_true_overlap_id)
198 | for pred_id in pred_true_overlap_id:
199 | if pred_id == 0: # ignore
200 | continue # overlapping background
201 | p_mask = pred_masks[pred_id]
202 | total = (t_mask + p_mask).sum()
203 | inter = (t_mask * p_mask).sum()
204 | iou = inter / (total - inter)
205 | pairwise_iou[true_id-1, pred_id-1] = iou
206 | #
207 | if match_iou >= 0.5:
208 | paired_iou = pairwise_iou[pairwise_iou > match_iou]
209 | pairwise_iou[pairwise_iou <= match_iou] = 0.0
210 | paired_true, paired_pred = np.nonzero(pairwise_iou)
211 | paired_iou = pairwise_iou[paired_true, paired_pred]
212 | paired_true += 1 # index is instance id - 1
213 | paired_pred += 1 # hence return back to original
214 | else: # * Exhaustive maximal unique pairing
215 | #### Munkres pairing with scipy library
216 | # the algorithm returns (row indices, matched column indices);
217 | # if there are multiple equal costs in a row, the index of the first
218 | # occurrence is returned, thus the unique pairing is ensured
219 | # negate the IoU so that maximising it becomes a minimisation problem
220 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
221 | ### extract the paired cost and remove invalid pair
222 | paired_iou = pairwise_iou[paired_true, paired_pred]
223 |
224 | # now select those above threshold level
225 | # paired with iou = 0.0 i.e no intersection => FP or FN
226 | paired_true = list(paired_true[paired_iou > match_iou] + 1)
227 | paired_pred = list(paired_pred[paired_iou > match_iou] + 1)
228 | paired_iou = paired_iou[paired_iou > match_iou]
229 |
230 | # get the actual FP and FN
231 | unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]
232 | unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]
233 | # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred))
234 |
235 | #
236 | tp = len(paired_true)
237 | fp = len(unpaired_pred)
238 | fn = len(unpaired_true)
239 | # get the F1-score i.e DQ
240 | dq = tp / (tp + 0.5 * fp + 0.5 * fn)
241 | # get the SQ; no pair has 0 IoU, so it has no impact
242 | sq = paired_iou.sum() / (tp + 1.0e-6)
243 |
244 | return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred]
245 |
246 | #####
247 | def get_fast_dice_2(true, pred):
248 | """
249 | Ensemble dice
250 | """
251 | true = np.copy(true)
252 | pred = np.copy(pred)
253 | true_id = list(np.unique(true))
254 | pred_id = list(np.unique(pred))
255 |
256 | overall_total = 0
257 | overall_inter = 0
258 |
259 | true_masks = [np.zeros(true.shape)]
260 | for t in true_id[1:]:
261 | t_mask = np.array(true == t, np.uint8)
262 | true_masks.append(t_mask)
263 |
264 | pred_masks = [np.zeros(true.shape)]
265 | for p in pred_id[1:]:
266 | p_mask = np.array(pred == p, np.uint8)
267 | pred_masks.append(p_mask)
268 |
269 | for true_idx in range(1, len(true_id)):
270 | t_mask = true_masks[true_idx]
271 | pred_true_overlap = pred[t_mask > 0]
272 | pred_true_overlap_id = np.unique(pred_true_overlap)
273 | pred_true_overlap_id = list(pred_true_overlap_id)
274 | try: # blindly remove background
275 | pred_true_overlap_id.remove(0)
276 | except ValueError:
277 | pass # just means there is no background
278 | for pred_idx in pred_true_overlap_id:
279 | p_mask = pred_masks[pred_idx]
280 | total = (t_mask + p_mask).sum()
281 | inter = (t_mask * p_mask).sum()
282 | overall_total += total
283 | overall_inter += inter
284 |
285 | return 2 * overall_inter / overall_total
286 | #####
287 |
288 | #####--------------------------As pseudocode
289 | def get_dice_1(true, pred):
290 | """
291 | Traditional dice
292 | """
293 | # cast to binary 1st
294 | true = np.copy(true)
295 | pred = np.copy(pred)
296 | true[true > 0] = 1
297 | pred[pred > 0] = 1
298 | inter = true * pred
299 | denom = true + pred
300 | return 2.0 * np.sum(inter) / np.sum(denom)
301 | ####
302 | def get_dice_2(true, pred):
303 | true = np.copy(true)
304 | pred = np.copy(pred)
305 | true_id = list(np.unique(true))
306 | pred_id = list(np.unique(pred))
307 | # remove background aka id 0
308 | true_id.remove(0)
309 | pred_id.remove(0)
310 |
311 | total_markup = 0
312 | total_intersect = 0
313 | for t in true_id:
314 | t_mask = np.array(true == t, np.uint8)
315 | for p in pred_id:
316 | p_mask = np.array(pred == p, np.uint8)
317 | intersect = p_mask * t_mask
318 | if intersect.sum() > 0:
319 | total_intersect += intersect.sum()
320 | total_markup += (t_mask.sum() + p_mask.sum())
321 | return 2 * total_intersect / total_markup
322 | #####
323 | def remap_label(pred, by_size=False):
324 | """
325 | Rename all instance IDs so that they are contiguous, i.e. [0, 1, 2, 3]
326 | not [0, 2, 4, 6]. The ordering of instances (which one comes first)
327 | is preserved unless by_size=True, in which case the instances are reordered
328 | so that bigger nuclei get smaller IDs
329 |
330 | Args:
331 | pred : the 2d array containing instances, where each instance is marked
332 | by a non-zero integer
333 | by_size : rename so that larger nuclei get smaller IDs (listed on top)
334 | """
335 | pred_id = list(np.unique(pred))
336 | pred_id.remove(0)
337 | if len(pred_id) == 0:
338 | return pred # no label
339 | if by_size:
340 | pred_size = []
341 | for inst_id in pred_id:
342 | size = (pred == inst_id).sum()
343 | pred_size.append(size)
344 | # sort the id by size in descending order
345 | pair_list = zip(pred_id, pred_size)
346 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
347 | pred_id, pred_size = zip(*pair_list)
348 |
349 | new_pred = np.zeros(pred.shape, np.int32)
350 | for idx, inst_id in enumerate(pred_id):
351 | new_pred[pred == inst_id] = idx + 1
352 | return new_pred
353 | #####
354 | def pair_coordinates(setA, setB, radius):
355 | """
356 | Use the Munkres (Kuhn-Munkres) algorithm to find the optimal
357 | unique pairing (largest possible match) when pairing points in set B
358 | against points in set A, using distance as the cost function
359 |
360 | Args:
361 | setA, setB: np.array (float32) of size Nx2 containing the XY coordinates
362 | of N different points
363 | radius: valid area around a point in setA to consider
364 | a given coordinate in setB a candidate for match
365 | Return:
366 | pairing: an array of index pairs, where the point at index pairing[i, 0]
367 | in set A is paired with the point at index pairing[i, 1]
368 | in set B
369 | unpairedA, unpairedB: remaining unpaired points in set A and set B
370 | """
371 |
372 | # * Euclidean distance as the cost matrix
373 | setA_tile = np.expand_dims(setA, axis=1)
374 | setB_tile = np.expand_dims(setB, axis=0)
375 | setA_tile = np.repeat(setA_tile, setB.shape[0], axis=1)
376 | setB_tile = np.repeat(setB_tile, setA.shape[0], axis=0)
377 | pair_distance = (setA_tile - setB_tile) ** 2
378 | # set A is row, and set B is paired against set A
379 | pair_distance = np.sqrt(np.sum(pair_distance, axis=-1))
380 |
381 | # * Munkres pairing with scipy library
382 | # the algorithm returns (row indices, matched column indices);
383 | # if there are multiple equal costs in a row, the index of the first
384 | # occurrence is returned, thus the unique pairing is ensured
385 | indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
386 |
387 | # extract the paired cost and remove instances
388 | # outside of designated radius
389 | pair_cost = pair_distance[indicesA, paired_indicesB]
390 |
391 | pairedA = indicesA[pair_cost <= radius]
392 | pairedB = paired_indicesB[pair_cost <= radius]
393 |
394 | unpairedA = [idx for idx in range(setA.shape[0]) if idx not in list(pairedA)]
395 | unpairedB = [idx for idx in range(setB.shape[0]) if idx not in list(pairedB)]
396 |
397 | pairing = np.array(list(zip(pairedA, pairedB)))
398 | unpairedA = np.array(unpairedA, dtype=np.int64)
399 | unpairedB = np.array(unpairedB, dtype=np.int64)
400 |
401 | return pairing, unpairedA, unpairedB
402 |
--------------------------------------------------------------------------------
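
A minimal toy example (6x6 instance maps, hypothetical labels) wiring remap_label and get_fast_pq from this file together, in the same order compute_stats.py applies them per image:

import numpy as np
from metrics.stats_utils import remap_label, get_fast_pq

true = np.zeros((6, 6), np.int32); true[0:3, 0:3] = 2; true[3:6, 3:6] = 5
pred = np.zeros((6, 6), np.int32); pred[0:3, 0:4] = 1; pred[4:6, 4:6] = 3
true = remap_label(true)  # IDs {2, 5} become contiguous {1, 2}
pred = remap_label(pred)
(dq, sq, pq), _ = get_fast_pq(true, pred, match_iou=0.5)
print(dq, sq, pq)         # 0.5, ~0.75, ~0.375: one of the two instances matches
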
/model/sonnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import string
4 | import collections
5 |
6 | import tensorflow as tf
7 | import cv2
8 |
9 | from tensorpack import *
10 | from tensorpack.models import BatchNorm, BNReLU, Conv2D, GlobalAvgPooling, Dropout
11 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary, add_activation_summary
12 | from tensorpack.tfutils.scope_utils import under_name_scope, auto_reuse_variable_scope
13 | from tensorpack.tfutils import optimizer
14 | from tensorpack.tfutils.gradproc import GlobalNormClip
15 |
16 |
17 | import sys
18 | sys.path.append("..") # adds higher directory to python modules path.
19 |
20 | from .utils import *
21 |
22 |
23 | try: # HACK: import beyond current level, may need to restructure
24 | from config import Config
25 | except ImportError:
26 | assert False, 'Failed to import config.py'
27 |
28 | ####
29 | def upsample2x(name, x):
30 | """
31 | Nearest neighbor up-sampling
32 | """
33 | return FixedUnPooling(
34 | name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'),
35 | data_format='channels_first')
36 |
37 | def res_blk(name, l, ch, ksize, count, split=1, strides=1):
38 | ch_in = l.get_shape().as_list()
39 | with tf.variable_scope(name):
40 | for i in range(0, count):
41 | with tf.variable_scope('block' + str(i)):
42 | x = l if i == 0 else BNReLU('preact', l)
43 | x = Conv2D('conv1', x, ch[0], ksize[0], activation=BNReLU)
44 | x = Conv2D('conv2', x, ch[1], ksize[1], split=split,
45 | strides=strides if i == 0 else 1, activation=BNReLU)
46 | x = Conv2D('conv3', x, ch[2], ksize[2], activation=tf.identity)
47 | if (strides != 1 or ch_in[1] != ch[2]) and i == 0:
48 | l = Conv2D('convshortcut', l, ch[2], 1, strides=strides)
49 | l = l + x
50 | # end of each group need an extra activation
51 | l = BNReLU('bnlast',l)
52 | return l
53 |
54 | def dense_blk(name, l, ch, ksize, count, split=1, padding='valid'):
55 | with tf.variable_scope(name):
56 | for i in range(0, count):
57 | with tf.variable_scope('blk/' + str(i)):
58 | x = BNReLU('preact_bna', l)
59 | x = Conv2D('conv1', x, ch[0], ksize[0], padding=padding, activation=BNReLU)
60 | x = Conv2D('conv2', x, ch[1], ksize[1], padding=padding, split=split)
61 | ##
62 | if padding == 'valid':
63 | x_shape = x.get_shape().as_list()
64 | l_shape = l.get_shape().as_list()
65 | l = crop_op(l, (l_shape[2] - x_shape[2],
66 | l_shape[3] - x_shape[3]))
67 |
68 | l = tf.concat([l, x], axis=1)
69 | l = BNReLU('blk_bna', l)
70 | return l
71 |
72 | ####
73 | @layer_register(log_shape=True)
74 | def resize_bilinear(i, size, align_corners=False):
75 | ret = tf.transpose(i, [0, 2, 3, 1])
76 | ret = tf.image.resize_bilinear(ret, size=[size, size], align_corners=align_corners)
77 | ret = tf.transpose(ret, [0, 3, 1, 2])
78 | return tf.identity(ret, name='output')
79 |
80 | ####
81 | def resize_nearest_neighbor(i, size):
82 | ret = tf.transpose(i, (0, 2, 3, 1))
83 | ret = tf.image.resize_nearest_neighbor(ret, (size, size))
84 | ret = tf.transpose(ret, (0, 3, 1, 2))
85 | return ret
86 |
87 | ####
88 | @layer_register(log_shape=True)
89 | def DepthConv(x, out_channel, kernel_shape, padding='SAME', stride=1,
90 | W_init=None, activation=tf.identity):
91 | in_shape = x.get_shape().as_list()
92 | in_channel = in_shape[1]
93 | assert out_channel % in_channel == 0, (out_channel, in_channel)
94 | channel_mult = out_channel // in_channel
95 |
96 | if W_init is None:
97 | W_init = tf.variance_scaling_initializer(scale=2.0, mode='fan_out')
98 | kernel_shape = [kernel_shape, kernel_shape]
99 | filter_shape = kernel_shape + [in_channel, channel_mult]
100 |
101 | W = tf.get_variable('W', filter_shape, initializer=W_init)
102 | conv = tf.nn.depthwise_conv2d(x, W, [1, 1, stride, stride], padding=padding, data_format='NCHW')
103 | return activation(conv, name='output')
104 |
105 |
106 | BlockArgs = collections.namedtuple('BlockArgs', [
107 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
108 | 'expand_ratio', 'id_skip', 'strides', 'se_ratio'
109 | ])
110 |
111 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
112 |
113 | DEFAULT_BLOCKS_ARGS = [
114 | BlockArgs(kernel_size=3, num_repeat=1, input_filters=32, output_filters=16,
115 | expand_ratio=1, id_skip=True, strides=[1, 1], se_ratio=0.25),
116 | BlockArgs(kernel_size=3, num_repeat=2, input_filters=16, output_filters=24,
117 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
118 | BlockArgs(kernel_size=5, num_repeat=2, input_filters=24, output_filters=40,
119 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
120 | BlockArgs(kernel_size=3, num_repeat=3, input_filters=40, output_filters=80,
121 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
122 | BlockArgs(kernel_size=5, num_repeat=3, input_filters=80, output_filters=112,
123 | expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25),
124 | BlockArgs(kernel_size=5, num_repeat=4, input_filters=112, output_filters=192,
125 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
126 | BlockArgs(kernel_size=3, num_repeat=1, input_filters=192, output_filters=320,
127 | expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25)
128 | ]
129 |
130 |
131 | def round_filters(filters, width_coefficient, depth_divisor):
132 | """Round number of filters based on width multiplier"""
133 | filters *= width_coefficient
134 | new_filters = int(filters + depth_divisor / 2) // depth_divisor * depth_divisor
135 | new_filters = max(depth_divisor, new_filters)
136 | # Make sure that round down does not go down by more than 10%
137 | if new_filters < 0.9 * filters:
138 | new_filters += depth_divisor
139 | return int(new_filters)
140 |
141 | def round_repeats(repeats, depth_coefficient):
142 | """Round number of repeats based on depth multiplier"""
143 | return int(math.ceil(depth_coefficient * repeats))
144 |
145 | def mb_conv_block(inputs, block_args, activation, drop_rate=None, freeze_en=False, prefix='',):
146 | """Mobile Inverted Residual Bottleneck."""
147 | has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
148 | bn_axis = 1
149 | # For naming convolutional, batchnorm layers to match pretrained weights
150 | num_conv = ''
151 | num_batch = 0
152 | # Expansion phase
153 | filters = block_args.input_filters * block_args.expand_ratio
154 | if block_args.expand_ratio != 1:
155 | x = Conv2D(prefix + '/conv2d', inputs, filters, 1, padding='same', use_bias=False)
156 | num_conv = '_1'
157 | x = BatchNorm(prefix + '/tpu_batch_normalization', x)
158 | num_batch += 1
159 | x = activation(x)
160 | else:
161 | x = inputs
162 | # Depthwise convolution
163 | x = DepthConv(prefix+'/depthwise_conv2d', x, x.get_shape().as_list()[1], block_args.kernel_size, stride=block_args.strides[0])
164 | if num_batch != 0:
165 | x = BatchNorm(prefix + '/tpu_batch_normalization' + '_' + str(num_batch), x)
166 | else:
167 | x = BatchNorm(prefix + '/tpu_batch_normalization', x)
168 | num_batch += 1
169 | x = activation(x)
170 |
171 | # Squeeze and Excitation phase
172 | if has_se:
173 | num_reduced_filters = max(1, int(block_args.input_filters * block_args.se_ratio))
174 | se_tensor = GlobalAvgPooling(prefix+'/se/squeeze', x, data_format='NCHW')
175 | target_shape = [-1, filters, 1, 1]
176 | se_tensor = tf.reshape(se_tensor, target_shape, name=prefix + '/se/reshape')
177 |
178 | se_tensor = Conv2D(prefix + '/se/conv2d', se_tensor, num_reduced_filters, 1, activation=activation, padding='same', use_bias=True)
179 | se_tensor = Conv2D(prefix + '/se/conv2d_1', se_tensor, filters, 1, activation=tf.sigmoid, padding='same', use_bias=True)
180 | x = tf.multiply(x, se_tensor, name=prefix + '/se/excite')
181 | # Output phase
182 | x = Conv2D(prefix + '/conv2d' + num_conv, x, block_args.output_filters, 1, padding='same', use_bias=False)
183 | x = tf.stop_gradient(x) if freeze_en else x
184 | x = BatchNorm(prefix + '/tpu_batch_normalization' + '_' + str(num_batch), x)
185 | if block_args.id_skip and all(s==1 for s in block_args.strides) and block_args.input_filters == block_args.output_filters:
186 | if drop_rate and (drop_rate > 0):
187 | x = Dropout(x, rate=drop_rate, noise_shape=(tf.shape(x)[0], 1, 1, 1), name=prefix + 'drop')
188 | x = tf.math.add(x, inputs, name=prefix + 'add')
189 | x = tf.stop_gradient(x) if freeze_en else x
190 | return x
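    | # Note on the skip branch above: Dropout uses noise_shape=(batch, 1, 1, 1), so an
    | # entire residual branch is dropped (or kept) per example -- drop-connect /
    | # stochastic depth rather than per-pixel dropout. drop_rate grows linearly with
    | # the block index in EfficientNet() below.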
191 |
192 | def EfficientNet(i, width_coefficient, depth_coefficient, default_resolution, dropout_rate=0.2, drop_connect_rate=0.2,
193 | depth_divisor=8, blocks_args=DEFAULT_BLOCKS_ARGS, freeze_en=False, model_name='efficientnet'):
194 | """Instantiates the EfficientNet architecture using given scaling coefficient.
195 | # Arguments:
196 | width_coefficient: float, scaling coefficient for network width
197 | depth_coefficient: float, scaling coefficient for network depth
198 | default_resolution: int, default resolution of input
199 | dropout_rate: float, dropout rate before final classifier layer.
200 | drop_connect_rate: float, dropout rate at skip connection
201 | depth_divisor: int, unit for rounding the number of filters.
202 | blocks_args: a list of BlockArgs used to construct the block modules.
203 | model_name: string, model name.
204 | # Returns:
205 | EfficientNet as a backbone encoder.
206 | """
207 |
208 | bn_axis = 1
209 | activation = tf.nn.swish
210 | with tf.variable_scope(model_name):
211 | with tf.variable_scope('stem'):
212 | # Build stem
213 | x = Conv2D('conv2d', i, round_filters(32, width_coefficient, depth_divisor), 3, strides=(1, 1), padding='same', use_bias=False)
214 | x = BatchNorm('tpu_batch_normalization', x)
215 | x = activation(x)
216 |
217 | # Build blocks
218 | num_blocks_total = sum(block_args.num_repeat for block_args in blocks_args)
219 | block_num = 0
220 | ret = []
221 | for idx, block_args in enumerate(blocks_args):
222 | assert block_args.num_repeat > 0
223 | # Update block input and output filters based on depth multiplier.
224 | block_args = block_args._replace(
225 | input_filters=round_filters(block_args.input_filters, width_coefficient, depth_divisor),
226 | output_filters=round_filters(block_args.output_filters, width_coefficient, depth_divisor),
227 | num_repeat=round_repeats(block_args.num_repeat, depth_coefficient)
228 | )
229 | # The first block needs to take care of stride and filter size increase.
230 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
231 | x = mb_conv_block(x, block_args, activation=activation, drop_rate=drop_rate, freeze_en=freeze_en, prefix='blocks_{}'.format(block_num))
232 | if block_num in (0, 2, 4, 10):
233 | ret.append(x)
234 | x = tf.stop_gradient(x) if freeze_en else x
235 | block_num += 1
236 | if block_args.num_repeat > 1:
237 | block_args = block_args._replace(input_filters=block_args.output_filters, strides=[1, 1])
238 | for bidx in range(block_args.num_repeat - 1):
239 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
240 | block_prefix = 'blocks_{}'.format(block_num)
241 | x = mb_conv_block(x, block_args, activation=activation, drop_rate=drop_rate, freeze_en=freeze_en, prefix=block_prefix)
242 | if block_num in (0, 2, 4, 10):
243 | ret.append(x)
244 | x = tf.stop_gradient(x) if freeze_en else x
245 | block_num += 1
246 |
247 | # Build head
248 | with tf.variable_scope('head'):
249 | x = Conv2D('conv2d', x, round_filters(1280, width_coefficient, depth_divisor), 1, padding='same', use_bias=False)
250 | x = tf.stop_gradient(x) if freeze_en else x
251 | x = Conv2D('conv_bot', x, 1024, 1)
252 | ret.append(x)
253 | return ret
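    | # With the default (B0) block args above, `ret` holds five tensors: the outputs of
    | # blocks 0, 2, 4 and 10 (at strides 1, 2, 4 and 8 relative to the input, since the
    | # stem here uses stride 1, unlike the stride-2 stem of the original EfficientNet)
    | # plus the 1024-channel head at stride 16. The decoder below consumes them as
    | # i[0] and i[-4..-1], FPN-style.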
254 |
255 |
256 | ###
257 | def EfficientNetB0(x, freeze_en):
258 | return EfficientNet(x, 1.0, 1.0, 224, 0.2, freeze_en=freeze_en, model_name='efficientnet-b0')
259 |
260 | ###
261 | def EfficientNetB1(x, freeze_en):
262 | return EfficientNet(x, 1.0, 1.1, 240, 0.2, freeze_en=freeze_en, model_name='efficientnet-b1')
263 |
264 | ###
265 | def EfficientNetB2(x, freeze_en):
266 | return EfficientNet(x, 1.1, 1.2, 260, 0.3, freeze_en=freeze_en, model_name='efficientnet-b2')
267 |
268 | ###
269 | def EfficientNetB3(x, freeze_en):
270 | return EfficientNet(x, 1.2, 1.4, 300, 0.3, freeze_en=freeze_en, model_name='efficientnet-b3')
271 | ###
272 | def EfficientNetB4(x, freeze_en):
273 | return EfficientNet(x, 1.4, 1.8, 380, 0.4, freeze_en=freeze_en, model_name='efficientnet-b4')
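    | # The (width, depth, resolution, dropout) tuples above follow the EfficientNet
    | # compound-scaling table (Tan & Le, 2019). Hypothetical usage sketch, assuming
    | # an NCHW tensor such as the one produced by the transpose in build_graph below:
    | #   feats = EfficientNetB0(i, freeze_en=False)  # list of 5 multi-scale features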
274 |
275 | ####
276 | def decoder(name, i):
277 | pad = 'valid'
278 | with tf.variable_scope(name):
279 | with tf.variable_scope('S5'):
280 | u5 = Conv2D('bottleneck', i[-1], 256, 1, strides=1, activation=BNReLU)
281 | u5_x2 = upsample2x('deux', u5)
282 | with tf.variable_scope('S4'):
283 | u4 = Conv2D('bottleneck', i[-2], 256, 1, strides=1, activation=BNReLU)
284 | u4_add = tf.add_n([u4, u5_x2])
285 | u4 = Conv2D('conva', u4_add, 256, 5, strides=1, padding=pad)
286 | with tf.variable_scope('_add_1'):
287 | u4_10 = BNReLU('preact', u4)
288 | u4_11 = Conv2D('_1', u4_10, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
289 | u4_12 = Conv2D('_2', u4_11, 32, 3, padding=pad, dilation_rate=(1, 1))
290 | u4_12 = crop_op(u4_12, (6, 6))
291 | with tf.variable_scope('_add_2'):
292 | u4_20 = BNReLU('preact', u4)
293 | u4_21 = Conv2D('_1', u4_20, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
294 | u4_22 = Conv2D('_2', u4_21, 32, 5, padding=pad, dilation_rate=(1, 1))
295 | u4_22 = crop_op(u4_22, (4, 4))
296 | with tf.variable_scope('_add_3'):
297 | u4_30 = BNReLU('preact', u4)
298 | u4_31 = Conv2D('_1', u4_30, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
299 | u4_32 = Conv2D('_2', u4_31, 32, 3, padding=pad, dilation_rate=(2, 2))
300 | u4_32 = crop_op(u4_32, (4, 4))
301 | with tf.variable_scope('_add_4'):
302 | u4_40 = BNReLU('preact', u4)
303 | u4_41 = Conv2D('_1', u4_40, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
304 | u4_42 = Conv2D('_2', u4_41, 32, 5, padding=pad, dilation_rate=(2, 2), activation=BNReLU)
305 | u4 = crop_op(u4, (8, 8))
306 | u4_cat = tf.concat([u4, u4_12, u4_22, u4_32, u4_42], axis=1, name='_cat')
307 | u4_cat = BNReLU('outcat_BNReLU', u4_cat)
308 | u4_cat = Conv2D('_outcat', u4_cat, 256, 3, padding=pad, strides=1, activation=BNReLU)
309 | u4_x2 = upsample2x('deux', u4_cat)
310 |
311 | with tf.variable_scope('S3'):
312 | u3 = Conv2D('bottleneck', i[-3], 256, 1, strides=1, activation=BNReLU)
313 | u3 = crop_op(u3, (28, 28))
314 | u3_add = tf.add_n([u3, u4_x2])
315 | u3 = Conv2D('conva', u3_add, 256, 5, strides=1, padding=pad)
316 | with tf.variable_scope('_add_1'):
317 | u3_10 = BNReLU('preact', u3)
318 | u3_11 = Conv2D('_1', u3_10, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
319 | u3_12 = Conv2D('_2', u3_11, 32, 3, padding=pad, dilation_rate=(1, 1))
320 | u3_12 = crop_op(u3_12, (6, 6))
321 | with tf.variable_scope('_add_2'):
322 | u3_20 = BNReLU('preact', u3)
323 | u3_21 = Conv2D('_1', u3_20, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
324 | u3_22 = Conv2D('_2', u3_21, 32, 5, padding=pad, dilation_rate=(1, 1))
325 | u3_22 = crop_op(u3_22, (4, 4))
326 | with tf.variable_scope('_add_3'):
327 | u3_30 = BNReLU('preact', u3)
328 | u3_31 = Conv2D('_1', u3_30, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
329 | u3_32 = Conv2D('_2', u3_31, 32, 3, padding=pad, dilation_rate=(2, 2))
330 | u3_32 = crop_op(u3_32, (4, 4))
331 | with tf.variable_scope('_add_4'):
332 | u3_40 = BNReLU('preact', u3)
333 | u3_41 = Conv2D('_1', u3_40, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
334 | u3_42 = Conv2D('_2', u3_41, 32, 5, padding=pad, dilation_rate=(2, 2))
335 |
336 | u3 = crop_op(u3, (8, 8))
337 | u3_cat = tf.concat([u3, u3_12, u3_22, u3_32, u3_42], axis=1, name='_cat')
338 | u3_cat = BNReLU('outcat_BNReLU', u3_cat)
339 | u3_cat = Conv2D('_outcat', u3_cat, 256, 3, padding=pad, strides=1, activation=BNReLU)
340 | u3_x2 = upsample2x('deux', u3_cat)
341 |
342 | with tf.variable_scope('S2'):
343 | u2 = Conv2D('bottleneck', i[-4], 256, 1, strides=1, activation=BNReLU)
344 | u2 = crop_op(u2, (83, 83))
345 | u2_add = tf.add_n([u2, u3_x2])
346 | u2 = Conv2D('conva', u2_add, 256, 5, strides=1, padding=pad)
347 | with tf.variable_scope('_add_1'):
348 | u2_10 = BNReLU('preact', u2)
349 | u2_11 = Conv2D('_1', u2_10, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
350 | u2_12 = Conv2D('_2', u2_11, 32, 3, padding=pad, dilation_rate=(1, 1))
351 | u2_12 = crop_op(u2_12, (6, 6))
352 | with tf.variable_scope('_add_2'):
353 | u2_20 = BNReLU('preact', u2)
354 | u2_21 = Conv2D('_1', u2_20, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
355 | u2_22 = Conv2D('_2', u2_21, 32, 5, padding=pad, dilation_rate=(1, 1))
356 | u2_22 = crop_op(u2_22, (4, 4))
357 | with tf.variable_scope('_add_3'):
358 | u2_30 = BNReLU('preact', u2)
359 | u2_31 = Conv2D('_1', u2_30, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
360 | u2_32 = Conv2D('_2', u2_31, 32, 3, padding=pad, dilation_rate=(2, 2))
361 | u2_32 = crop_op(u2_32, (4, 4))
362 | with tf.variable_scope('_add_4'):
363 | u2_40 = BNReLU('preact', u2)
364 | u2_41 = Conv2D('_1', u2_40, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU)
365 | u2_42 = Conv2D('_2', u2_41, 32, 5, padding=pad, dilation_rate=(2, 2))
366 |
367 | u2 = crop_op(u2, (8, 8))
368 | u2_cat = tf.concat([u2, u2_12, u2_22, u2_32, u2_42], axis=1, name='_cat')
369 | u2_cat = BNReLU('outcat_BNReLU', u2_cat)
370 | u2_cat = Conv2D('_outcat', u2_cat, 256, 3, padding='same', strides=1, activation=BNReLU)
371 |
372 |
373 | with tf.variable_scope('Pyramid'):
374 | p1 = Conv2D('p1', u2_cat, 128, 5, padding='same', activation=BNReLU)
375 | p1 = Conv2D('p1_1', p1, 128, 5, padding='same', activation=BNReLU)
376 | p2 = Conv2D('p2', u3_cat, 128, 5, padding=pad, activation=BNReLU)
377 | p2 = Conv2D('p2_1', p2, 128, 3, padding=pad, activation=BNReLU)
378 | p2_x2 = upsample2x('deux_p2', p2)
379 | p3 = crop_op(u4_cat, (2, 2))
380 | p3 = Conv2D('p3', p3, 128, 5, padding=pad, activation=BNReLU)
381 | p3 = Conv2D('p3_1', p3, 128, 5, padding=pad, activation=BNReLU)
382 | p3_x2 = upsample2x('deux_p3', p3)
383 | p3_x4 = upsample2x('quatre_p3', p3_x2)
384 | p4 = crop_op(u5, (4, 4))
385 | p4 = Conv2D('p4', p4, 128, 5, padding=pad, activation=BNReLU)
386 | p4 = Conv2D('p4_1', p4, 128, 5, padding=pad, activation=BNReLU)
387 | p4_x2 = upsample2x('deux_p4', p4)
388 | p4_x4 = upsample2x('quatre_p4', p4_x2)
389 | p4_x8 = upsample2x('huit_p4', p4_x4)
390 | p_cat = tf.concat([p1, p2_x2, p3_x4, p4_x8], axis=1)
391 | p_cat = Conv2D('_outcat_1', p_cat, 256, 5, padding='same', activation=BNReLU)
392 | p_cat = Conv2D('_outcat_2', p_cat, 256, 5, padding='same', activation=BNReLU)
393 | p_cat_x2 = upsample2x('deux_cat', p_cat)
394 |
395 | i_5 = crop_op(i[0], (190, 190))
396 | i_5 = Conv2D('i_5', i_5, 256, 1, activation=BNReLU)
397 | p_cat = tf.add_n([i_5, p_cat_x2])
398 | p_0 = Conv2D('p_0', p_cat, 128, 5, strides=1, padding='valid')
399 |
400 | return p_0
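    | # The four parallel '_add_*' branches in each stage use different kernel sizes and
    | # dilations under 'valid' padding, so each loses a different amount of border.
    | # Assuming crop_op(x, (h, w)) trims h rows / w cols in total (as its usage here
    | # suggests), every branch ends up 8 px smaller than its stage input, e.g. in S4:
    | # 3x3 valid (-2) + crop (6, 6) = -8, while 5x5 with dilation 2 (effective 9x9, -8)
    | # needs no crop; u4 itself is cropped by (8, 8) before the concat.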
401 |
402 |
403 | class Model(ModelDesc, Config):
404 | def __init__(self, freeze_en=False):
405 | super(Model, self).__init__()
406 | assert tf.test.is_gpu_available()
407 | self.freeze_en = freeze_en
408 | self.data_format = 'NCHW'
409 |
410 |
411 | def _get_inputs(self):
412 | return [InputDesc(tf.float32, [None] + self.train_input_shape + [3], 'images'),
413 | InputDesc(tf.float32, [None] + self.train_mask_shape + [None], 'truemap-coded')]
414 |
415 | # For nodes that receive manually-set values, such as the learning rate.
416 | def add_manual_variable(self, name, init_value, summary=True):
417 | var = tf.get_variable(name, initializer=init_value, trainable=False)
418 | if summary:
419 | tf.summary.scalar(name + '-summary', var)
420 | return
421 |
422 | def optimizer(self):
423 | with tf.variable_scope("", reuse=True):
424 | lr = tf.get_variable('learning_rate')
425 | opt = self.train_optimizer(learning_rate=lr)
426 | return optimizer.apply_grad_processors(opt, [GlobalNormClip(1.)])
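    | # GlobalNormClip(1.) rescales the gradients so their global L2 norm is at most 1
    | # before the optimizer step; the learning rate is read from the 'learning_rate'
    | # variable, presumably registered via add_manual_variable() during training setup.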
427 |
428 |
429 | class Sonnet(Model):
430 |
431 | def build_graph(self, inputs, truemap_coded):
432 | images = inputs
433 | orig_imgs = images
434 |
435 |
436 |
437 | if hasattr(self, 'type_classification') and self.type_classification:
438 | true_type = truemap_coded[..., 1]
439 | true_type = tf.cast(true_type, tf.int32)
440 | true_type = tf.identity(true_type, name='truemap-type')
441 | one_type = tf.one_hot(true_type, self.nr_types, axis=-1)
442 | true_type = tf.expand_dims(true_type, axis=-1)
443 |
444 | true_np = tf.cast(true_type > 0, tf.int32)
445 | true_np = tf.identity(true_np, name='truemap-np')
446 | one_np = tf.one_hot(tf.squeeze(true_np, axis=-1), 2, axis=-1)
447 |
448 |
449 | else:
450 | true_np = truemap_coded[..., 0]
451 | true_np = tf.cast(true_np, tf.int32)
452 | one_np = tf.one_hot(true_np, 2, axis=-1)
453 | true_np = tf.expand_dims(true_np, axis=-1)
454 | true_np = tf.identity(true_np, name='truemap-np')
455 |
456 |
457 | true_ord = truemap_coded[..., -1]
458 | true_ord = tf.expand_dims(true_ord, axis=-1)
459 | true_ord = tf.identity(true_ord, name='true-ord')
460 |
461 | ####
462 | with argscope(Conv2D, activation=tf.identity, use_bias=False,
463 | W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')),\
464 | argscope([Conv2D, BatchNorm], data_format=self.data_format):
465 | i = tf.transpose(images, [0, 3, 1, 2])
466 | i = i if not self.input_norm else i / 255.0
467 |
468 | ####
469 | d = EfficientNetB0(i, self.freeze_en)
470 |
471 | np_feat = decoder('np', d)
472 | np_feat = tf.identity(np_feat, name='np_feat')
473 | npx = BNReLU('preact_out_np', np_feat)
474 |
475 |
476 | ordi_feat = decoder('ordi', d)
477 | ordi = BNReLU('preact_out_ordi', ordi_feat)
478 |
479 | if self.type_classification:
480 | tp_feat = decoder('tp', d)
481 | tp = BNReLU('preact_out_tp', tp_feat)
482 |
483 | # Nuclei Type Pixels (NT)
484 | logi_class = Conv2D('conv_out_tp', tp, self.nr_types, 1, use_bias=True, activation=tf.identity)
485 | logi_class = tf.transpose(logi_class, [0, 2, 3, 1])
486 | soft_class = tf.nn.softmax(logi_class, axis=-1)
487 |
488 | ### Nuclei Pixels (NF)
489 | logi_np = Conv2D('conv_out_np', npx, 2, 1, use_bias=True, activation=tf.identity)
490 | logi_np = tf.transpose(logi_np, [0, 2, 3, 1])
491 | soft_np = tf.nn.softmax(logi_np, axis=-1)
492 | prob_np = tf.identity(soft_np[...,1], name='predmap-prob-np')
493 | prob_np = tf.expand_dims(prob_np, axis=-1)
494 |
495 |
496 | ### Ordinal (NO)
497 | logi_ord = Conv2D('conv_out_ord', ordi, 16, 1, use_bias=True, activation=tf.identity)
498 | logi_ord_t = tf.transpose(logi_ord, [0, 2, 3, 1])
499 | N, C, H, W = logi_ord.get_shape().as_list()
500 | ord_num = C // 2
501 | logi_ord = tf.reshape(logi_ord_t, shape=[-1, H, W, ord_num, 2])
502 | prob_ord = tf.nn.softmax(logi_ord, axis=-1)
503 | prob_ord = tf.identity(prob_ord, name='prob_ord')
504 | nn_out_labels = tf.reduce_sum(tf.argmax(prob_ord, axis=-1, output_type=tf.int32), axis=3, keepdims=True)
505 | pred_ord = tf.identity(nn_out_labels, name='predmap-ord')
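    | # Decoding: the 16 output channels form ord_num = 8 binary classifiers, each a
    | # (p_not_k, p_k) pair after the reshape. Summing the per-classifier argmaxes
    | # counts how many rank thresholds each pixel exceeds, which is its ordinal
    | # label, e.g. if exactly the first 4 classifiers fire, the predicted class is 4.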
506 |
507 | ### encoded so that inference can extract all output at once
508 | predmap_coded = tf.concat([soft_class, prob_np], axis=-1, name='predmap-coded')
509 |
510 |
511 |
512 |
513 |
514 | def loss_ord(prob_ord, true_ord, name=None):
515 | (ord_n_pk, ord_pk) = tf.unstack(prob_ord, axis=-1)
516 | epsilon = tf.convert_to_tensor(10e-8, ord_n_pk.dtype.base_dtype)
517 | ord_n_pk, ord_pk = tf.clip_by_value(ord_n_pk, epsilon, 1 - epsilon), tf.clip_by_value(ord_pk, epsilon, 1 - epsilon)
518 | ord_log_n_pk = tf.log(ord_n_pk)
519 | ord_log_pk = tf.log(ord_pk)
520 | (N, H, W, C) = ord_log_pk.get_shape().as_list()
521 | foreground_mask = tf.reshape(tf.sequence_mask(true_ord, C), shape=[-1, H, W, C])
522 | sum_of_p = tf.reduce_sum(tf.where(foreground_mask, ord_log_pk, ord_log_n_pk), axis=3)
523 | loss = -tf.reduce_mean(sum_of_p)
524 | loss = tf.identity(loss, name=name)
525 | return loss
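    | # Worked example (illustrative): for a pixel with true_ord = 3 and C = 8,
    | # sequence_mask(3, 8) = [1, 1, 1, 0, 0, 0, 0, 0], so the loss accumulates
    | # log p_k for k <= 3 and log(1 - p_k) for k > 3 -- the standard decomposition
    | # of ordinal regression into C binary "rank >= k" tasks.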
526 |
527 |
528 |
529 | ####
530 | if get_current_tower_context().is_training:
531 | #---- LOSS ----#
532 | loss = 0
533 |
534 | for term, weight in self.loss_term.items():
535 | if term == 'bce':
536 | term_loss = categorical_crossentropy_modified(soft_np, one_np)
537 | term_loss = tf.reduce_mean(term_loss, name='loss-bce')
538 | elif term == 'dice':
539 | term_loss = dice_loss(soft_np[...,0], one_np[...,0]) \
540 | + dice_loss(soft_np[...,1], one_np[...,1])
541 | term_loss = tf.identity(term_loss, name='loss-dice')
542 | elif term == 'ord':
543 | term_loss = loss_ord(prob_ord, true_ord, name='loss-ord')
544 | else:
545 | assert False, 'Unsupported loss term: %s' % term
546 | add_moving_summary(term_loss)
547 | loss += term_loss * weight
548 |
549 | if self.type_classification:
550 | term_loss = focal_loss_modified(soft_class, one_type)
551 | term_loss = tf.reduce_mean(term_loss, name='loss-classification')
552 | add_moving_summary(term_loss)
553 | loss = loss + term_loss
554 |
555 |
556 |
557 |
558 | self.cost = tf.identity(loss, name='overall-loss')
559 | add_moving_summary(self.cost)
560 | ####
561 |
562 | add_param_summary(('.*/W', ['histogram'])) # monitor W
563 |
564 | return
565 |
566 |
567 | class Sonnet_phase2(Model):
568 |
569 | def build_graph(self, inputs, truemap_coded):
570 | images = inputs
571 | orig_imgs = images
572 |
573 | if hasattr(self, 'type_classification') and self.type_classification:
574 | true_type = truemap_coded[..., 1]
575 | true_type = tf.cast(true_type, tf.int32)
576 | true_type = tf.identity(true_type, name='truemap-type')
577 | one_type = tf.one_hot(true_type, self.nr_types, axis=-1)
578 | true_type = tf.expand_dims(true_type, axis=-1)
579 |
580 | true_np = tf.cast(true_type > 0, tf.int32)
581 | true_np = tf.identity(true_np, name='truemap-np')
582 | one_np = tf.one_hot(tf.squeeze(true_np, axis=-1), 2, axis=-1)
583 |
584 |
585 | else:
586 | true_np = truemap_coded[..., 0]
587 | true_np = tf.cast(true_np, tf.int32)
588 | one_np = tf.one_hot(true_np, 2, axis=-1)
589 | true_np = tf.expand_dims(true_np, axis=-1)
590 | true_np = tf.identity(true_np, name='truemap-np')
591 |
592 |
593 | true_ord = truemap_coded[..., -1]
594 | true_ord = tf.expand_dims(true_ord, axis=-1)
595 | true_ord = tf.identity(true_ord, name='true-ord')
596 |
597 | ####
598 | with argscope(Conv2D, activation=tf.identity, use_bias=False,
599 | W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')),\
600 | argscope([Conv2D, BatchNorm], data_format=self.data_format):
601 | i = tf.transpose(images, [0, 3, 1, 2])
602 | i = i if not self.input_norm else i / 255.0
603 |
604 | ####
605 | d = EfficientNetB0(i, self.freeze_en)
606 |
607 | np_feat = decoder('np', d)
608 | np_feat = tf.identity(np_feat, name='np_feat')
609 | npx = BNReLU('preact_out_np', np_feat)
610 |
611 |
612 | ordi_feat = decoder('ordi', d)
613 | ordi = BNReLU('preact_out_ordi', ordi_feat)
614 |
615 | if self.type_classification:
616 | tp_feat = decoder('tp', d)
617 | tp = BNReLU('preact_out_tp', tp_feat)
618 |
619 | # Nuclei Type Pixels (NT)
620 | logi_class = Conv2D('conv_out_tp', tp, self.nr_types, 1, use_bias=True, activation=tf.identity)
621 | logi_class = tf.transpose(logi_class, [0, 2, 3, 1])
622 | soft_class = tf.nn.softmax(logi_class, axis=-1)
623 |
624 | ### Nuclei Pixels (NF)
625 | logi_np = Conv2D('conv_out_np', npx, 2, 1, use_bias=True, activation=tf.identity)
626 | logi_np = tf.transpose(logi_np, [0, 2, 3, 1])
627 | soft_np = tf.nn.softmax(logi_np, axis=-1)
628 | prob_np = tf.identity(soft_np[...,1], name='predmap-prob-np')
629 | prob_np = tf.expand_dims(prob_np, axis=-1)
630 |
631 |
632 | ### Ordinal (NO)
633 | logi_ord = Conv2D('conv_out_ord', ordi, 16, 1, use_bias=True, activation=tf.identity)
634 | logi_ord_t = tf.transpose(logi_ord, [0, 2, 3, 1])
635 | N, C, H, W = logi_ord.get_shape().as_list()
636 | ord_num = C // 2
637 | logi_ord = tf.reshape(logi_ord_t, shape=[-1, H, W, ord_num, 2])
638 | prob_ord = tf.nn.softmax(logi_ord, axis=-1)
639 | prob_ord = tf.identity(prob_ord, name='prob_ord')
640 | nn_out_labels = tf.reduce_sum(tf.argmax(prob_ord, axis=-1, output_type=tf.int32), axis=3, keepdims=True) # prediction output (sum to take the exact class, e.g.class 4)
641 | pred_ord = tf.identity(nn_out_labels, name='predmap-ord') # [N, 76, 76, 1]
642 |
643 | ### encoded so that inference can extract all output at once
644 | predmap_coded = tf.concat([soft_class, prob_np], axis=-1, name='predmap-coded')
645 |
646 |
647 |
648 |
649 |
650 | def loss_ord(prob_ord, true_ord, name=None):
651 | weight_map = tf.squeeze(true_ord, axis=-1)
652 | weight_map_1 = tf.where(
653 | tf.equal(weight_map, 1),
654 | 2 * tf.cast(tf.equal(weight_map, 1), tf.float32),
655 | tf.cast(tf.equal(weight_map, 2), tf.float32)
656 | )
657 | (ord_n_pk, ord_pk) = tf.unstack(prob_ord, axis=-1)
658 | epsilon = tf.convert_to_tensor(10e-8, ord_n_pk.dtype.base_dtype)
659 | ord_n_pk, ord_pk = tf.clip_by_value(ord_n_pk, epsilon, 1 - epsilon), tf.clip_by_value(ord_pk, epsilon, 1 - epsilon)
660 | ord_log_n_pk = tf.log(ord_n_pk)
661 | ord_log_pk = tf.log(ord_pk)
662 | (N, H, W, C) = ord_log_pk.get_shape().as_list()
663 | foreground_mask = tf.reshape(tf.sequence_mask(true_ord, C), shape=[-1, H, W, C])
664 | sum_of_p = tf.reduce_sum(tf.where(foreground_mask, ord_log_pk, ord_log_n_pk), axis=3)
665 | sum_of_p += sum_of_p * weight_map_1
666 | loss = -tf.reduce_mean(sum_of_p)
667 | loss = tf.identity(loss, name=name)
668 | return loss
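    | # Difference from the phase-1 loss_ord in Sonnet above: sum_of_p is rescaled per
    | # pixel by (1 + weight_map_1), i.e. x3 where the ground-truth ordinal label is 1,
    | # x2 where it is 2, and x1 elsewhere -- up-weighting pixels with small ordinal
    | # labels (presumably those near object boundaries).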
669 |
670 |
671 |
672 | ####
673 | if get_current_tower_context().is_training:
674 | weight_map = tf.cast(tf.squeeze(pred_ord, axis=-1), tf.float32) # [N, 76, 76]
675 | weight_map_1 = tf.where(
676 | tf.equal(weight_map, 1),
677 | 2 * tf.cast(tf.equal(weight_map, 1), tf.float32),
678 | tf.cast(tf.equal(weight_map, 2), tf.float32)
679 | )
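    | # Unlike loss_ord above (which weights by the ground truth), this map is built
    | # from the network's own pred_ord: the bce and classification terms below get an
    | # extra (1 + weight_map_1) factor, x3 / x2 where a pixel is *predicted* as rank
    | # 1 / 2; the dice and ord terms are left unweighted here.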
680 | #---- LOSS ----#
681 | loss = 0
682 |
683 | for term, weight in self.loss_term.items():
684 | if term == 'bce':
685 | term_loss = categorical_crossentropy_modified(soft_np, one_np)
686 | term_loss += weight_map_1 * term_loss
687 | term_loss = tf.reduce_mean(term_loss, name='loss-bce')
688 | elif term == 'dice':
689 | term_loss = dice_loss(soft_np[...,0], one_np[...,0]) \
690 | + dice_loss(soft_np[...,1], one_np[...,1])
691 | term_loss = tf.identity(term_loss, name='loss-dice')
692 | elif term == 'ord':
693 | term_loss = loss_ord(prob_ord, true_ord, name='loss-ord')
694 | else:
695 | assert False, 'Unsupported loss term: %s' % term
696 | add_moving_summary(term_loss)
697 | loss += term_loss * weight
698 |
699 | if self.type_classification:
700 | term_loss = focal_loss_modified(soft_class, one_type)
701 | term_loss += weight_map_1 * term_loss
702 | term_loss = tf.reduce_mean(term_loss, name='loss-classification')
703 | add_moving_summary(term_loss)
704 | loss = loss + term_loss
705 |
706 | self.cost = tf.identity(loss, name='overall-loss')
707 | add_moving_summary(self.cost)
708 | ####
709 | return
--------------------------------------------------------------------------------