├── results
│   ├── result_1.png
│   └── results_2.png
├── misc
│   ├── utils.py
│   ├── viz_utils.py
│   └── patch_extractor.py
├── loader
│   ├── loader.py
│   └── augs.py
├── model
│   ├── utils.py
│   └── sonnet.py
├── opt
│   └── hyperconfig.py
├── metrics
│   └── stats_utils.py
├── postproc
│   └── post_sonnet.py
├── requirements.txt
├── extract_patches.py
├── process.py
├── README.md
├── config.py
├── infer.py
├── train.py
└── compute_stats.py

/results/result_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QuIIL/Sonnet/HEAD/results/result_1.png
--------------------------------------------------------------------------------
/results/results_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QuIIL/Sonnet/HEAD/results/results_2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit_image==0.15.0
2 | pandas==1.1.0
3 | matplotlib==3.3.4
4 | tensorflow==1.12.0
5 | tensorpack==0.9.0.1
6 | opencv_python_headless==3.4.8.29
7 | numpy==1.19.2
8 | scipy==1.1.0
9 | scikit_learn==0.24.2
--------------------------------------------------------------------------------
/opt/hyperconfig.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | sonnet = {
4 |     'train_input_shape' : [270, 270],
5 |     'train_mask_shape'  : [76, 76],
6 |     'infer_input_shape' : [270, 270],
7 |     'infer_mask_shape'  : [76, 76],
8 | 
9 |     'training_phase' : [
10 |         {
11 |             'nr_epochs': 50,
12 |             'manual_parameters' : {
13 |                 # tuple(initial value, schedule)
14 |                 'learning_rate': (1.0e-4, [('25', 1.0e-5)]),
15 |             },
16 |             'pretrained_path'  : './ImageNet_pretrained_EfficientB0.npz',
17 |             'train_batch_size' : 8,
18 |             'infer_batch_size' : 16,
19 | 
20 |             'model_flags' : {
21 |                 'freeze_en' : True
22 |             }
23 |         },
24 | 
25 |         {
26 |             'nr_epochs': 25,
27 |             'manual_parameters' : {
28 |                 # tuple(initial value, schedule)
29 |                 'learning_rate': (1.0e-4, [('25', 1.0e-5)]),
30 |             },
31 |             # path to load, -1 to auto load checkpoint from previous phase,
32 |             # None to start from scratch
33 |             'pretrained_path'  : -1,
34 |             'train_batch_size' : 4,
35 |             'infer_batch_size' : 8,
36 | 
37 |             'model_flags' : {
38 |                 'freeze_en' : False
39 |             }
40 |         },
41 | 
42 |         {
43 |             'nr_epochs': 25,
44 |             'manual_parameters' : {
45 |                 # tuple(initial value, schedule)
46 |                 'learning_rate': (1.0e-5, [('25', 1.0e-5)]),
47 |             },
48 |             # path to load, -1 to auto load checkpoint from previous phase,
49 |             # None to start from scratch
50 |             'pretrained_path'  : -1,
51 |             'train_batch_size' : 4,
52 |             'infer_batch_size' : 8,
53 | 
54 |             'model_flags' : {
55 |                 'freeze_en' : False
56 |             }
57 |         }
58 |     ],
59 | 
60 |     'loss_term' : {'bce' : 1, 'dice' : 1},
61 | 
62 |     'train_optimizer' : tf.train.AdamOptimizer,
63 | 
64 |     'inf_auto_metric' : 'valid_dice',
65 |     'inf_auto_comparator' : '>',
66 |     'inf_batch_size' : 4,
67 | }
--------------------------------------------------------------------------------
/misc/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import glob
3 | import os
4 | import shutil
5 | 
6 | import cv2
7 | import numpy as np
8 | 
9 | 
10 | ####
11 | def normalize(mask, dtype=np.uint8):
12 |     return (255 * mask / np.amax(mask)).astype(dtype)
13 | 
14 | ####
15 | def bounding_box(img):
16 |     rows = np.any(img, axis=1)
17 |     cols = np.any(img, axis=0)
18 |     rmin, rmax = np.where(rows)[0][[0, -1]]
19 |     cmin, cmax = np.where(cols)[0][[0, -1]]
20 |     # due to python indexing, need to add 1 to max
21 |     # else accessing will be 1px in the box, not out
22 |     rmax += 1
23 |     cmax += 1
24 |     return [rmin, rmax, cmin, cmax]
25 | 
26 | ####
27 | def cropping_center(x, crop_shape, batch=False):
28 |     orig_shape = x.shape
29 |     if not batch:
30 |         h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
31 |         w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
32 |         x = x[h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]]
33 |     else:
34 |         h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
35 |         w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
36 |         x = x[:, h0:h0 + crop_shape[0], w0:w0 + crop_shape[1]]
37 |     return x
38 | 
39 | ####
40 | def rm_n_mkdir(dir_path):
41 |     if (os.path.isdir(dir_path)):
42 |         shutil.rmtree(dir_path)
43 |     os.makedirs(dir_path)
44 | 
45 | ####
46 | def get_files(data_dir_list, data_ext):
47 |     """
48 |     Given a list of directories containing data with extension 'data_ext',
49 |     generate a list of paths for all files within these directories
50 |     """
51 | 
52 |     data_files = []
53 |     for sub_dir in data_dir_list:
54 |         files_list = glob.glob(sub_dir + '/*' + data_ext)
55 |         files_list.sort() # ensure same order
56 |         data_files.extend(files_list)
57 | 
58 |     return data_files
59 | 
60 | ####
61 | def get_inst_centroid(inst_map):
62 |     inst_centroid_list = []
63 |     inst_id_list = list(np.unique(inst_map))
64 |     for inst_id in inst_id_list[1:]: # avoid 0 i.e. background
65 |         mask = np.array(inst_map == inst_id, np.uint8)
66 |         inst_moment = cv2.moments(mask)
67 |         inst_centroid = [(inst_moment["m10"] / inst_moment["m00"]),
68 |                          (inst_moment["m01"] / inst_moment["m00"])]
69 |         inst_centroid_list.append(inst_centroid)
70 |     return np.array(inst_centroid_list)
71 | 
72 | ###
73 | def change_keys_dict(a_dict):
74 |     for i in a_dict.copy():
75 |         if 'bias' in i:
76 |             new_key = i.replace('bias', 'b')
77 |             a_dict[new_key] = a_dict.pop(i)
78 |     np.savez('new_dict', **a_dict)
--------------------------------------------------------------------------------
/misc/viz_utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import cv2
3 | import math
4 | import random
5 | import colorsys
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from matplotlib import cm
9 | 
10 | from .utils import bounding_box
11 | 
12 | ####
13 | def random_colors(N, bright=True):
14 |     """
15 |     Generate random colors.
16 |     To get visually distinct colors, generate them in HSV space then
17 |     convert to RGB.
18 | """ 19 | brightness = 1.0 if bright else 0.7 20 | hsv = [(i / N, 1, brightness) for i in range(N)] 21 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 22 | random.shuffle(colors) 23 | return colors 24 | 25 | #### 26 | def visualize_instances(mask, canvas=None, color=None): 27 | """ 28 | Args: 29 | mask: array of NW 30 | Return: 31 | Image with the instance overlaid 32 | """ 33 | 34 | canvas = np.full(mask.shape + (3,), 200, dtype=np.uint8) \ 35 | if canvas is None else np.copy(canvas) 36 | 37 | insts_list = list(np.unique(mask)) 38 | insts_list.remove(0) # remove background 39 | 40 | inst_colors = random_colors(len(insts_list)) 41 | inst_colors = np.array(inst_colors) * 255 42 | 43 | for idx, inst_id in enumerate(insts_list): 44 | inst_color = color[idx] if color is not None else inst_colors[idx] 45 | inst_map = np.array(mask == inst_id, np.uint8) 46 | y1, y2, x1, x2 = bounding_box(inst_map) 47 | y1 = y1 - 2 if y1 - 2 >= 0 else y1 48 | x1 = x1 - 2 if x1 - 2 >= 0 else x1 49 | x2 = x2 + 2 if x2 + 2 <= mask.shape[1] - 1 else x2 50 | y2 = y2 + 2 if y2 + 2 <= mask.shape[0] - 1 else y2 51 | inst_map_crop = inst_map[y1:y2, x1:x2] 52 | inst_canvas_crop = canvas[y1:y2, x1:x2] 53 | contours = cv2.findContours(inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 54 | cv2.drawContours(inst_canvas_crop, contours[0], -1, inst_color, 2) 55 | canvas[y1:y2, x1:x2] = inst_canvas_crop 56 | return canvas 57 | 58 | #### 59 | def gen_figure(imgs_list, titles, fig_inch, shape=None, 60 | share_ax='all', show=False, colormap=plt.get_cmap('jet')): 61 | 62 | num_img = len(imgs_list) 63 | if shape is None: 64 | ncols = math.ceil(math.sqrt(num_img)) 65 | nrows = math.ceil(num_img / ncols) 66 | else: 67 | nrows, ncols = shape 68 | 69 | # generate figure 70 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, 71 | sharex=share_ax, sharey=share_ax) 72 | axes = [axes] if nrows == 1 else axes 73 | 74 | # not very elegant 75 | idx = 0 76 | for ax in axes: 77 | for cell in ax: 78 | cell.set_title(titles[idx]) 79 | cell.imshow(imgs_list[idx], cmap=colormap) 80 | cell.tick_params(axis='both', 81 | which='both', 82 | bottom='off', 83 | top='off', 84 | labelbottom='off', 85 | right='off', 86 | left='off', 87 | labelleft='off') 88 | idx += 1 89 | if idx == len(titles): 90 | break 91 | if idx == len(titles): 92 | break 93 | 94 | fig.tight_layout() 95 | return fig 96 | #### 97 | -------------------------------------------------------------------------------- /extract_patches.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | import scipy.io as sio 8 | from scipy.ndimage import measurements 9 | import scipy.io as sio 10 | 11 | from misc.patch_extractor import PatchExtractor 12 | from misc.utils import rm_n_mkdir 13 | 14 | from config import Config 15 | 16 | ########################################################################### 17 | if __name__ == '__main__': 18 | 19 | cfg = Config() 20 | 21 | extract_type = 'mirror' # 'valid' for fcn8 segnet etc. 22 | # 'mirror' for u-net etc. 
23 |                              # check the patch_extractor.py 'main' to see the difference
24 | 
25 |     # original size (win size) - input size - output size (step size)
26 |     # 540x540 - 270x270 - 76x76 sonnet
27 |     step_size = [76, 76]   # should match self.train_mask_shape (config.py)
28 |     win_size  = [540, 540] # should be at least twice as large as
29 |                            # self.train_base_shape (config.py) to reduce
30 |                            # the padding effect during augmentation
31 | 
32 |     xtractor = PatchExtractor(win_size, step_size)
33 | 
34 |     ### Paths to data - these need to be modified according to where the original data is stored
35 |     img_ext = '.tif'
36 |     img_dir = '/media/tandoan/Data/data/Monusac/Test(split)/Images'
37 |     ann_dir = '/media/tandoan/Data/data/Monusac/Test(split)/Labels(mat)'
38 |     ####
39 |     out_dir = "/media/tandoan/Data/data/Monusac/%dx%d_%dx%d" % \
40 |                     (win_size[0], win_size[1], step_size[0], step_size[1])
41 | 
42 |     file_list = glob.glob('%s/*%s' % (img_dir, img_ext))
43 |     file_list.sort()
44 | 
45 |     rm_n_mkdir(out_dir)
46 |     for filename in file_list:
47 |         filename = os.path.basename(filename)
48 |         basename = filename.split('.')[0]
49 |         print(filename)
50 | 
51 |         img = cv2.imread(img_dir + '/' + basename + img_ext)
52 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
53 | 
54 |         if cfg.type_classification:
55 |             # assumes that ann is HxWx2 (nuclei class labels are available at index 1 of C)
56 |             ann = sio.loadmat(ann_dir + '/' + basename + '.mat')
57 |             ann_inst = ann['inst_map']
58 |             ann_type = ann['type_map']
59 | 
60 |             # merge classes (for GLySAC, the paper only utilises 3 nuclei classes and background)
61 |             # If your own dataset is used, then the below may need to be modified
62 |             if cfg.data_type == 'glysac':
63 |                 ann_type[(ann_type == 1) | (ann_type == 2) | (ann_type == 9) | (ann_type == 10)] = 1
64 |                 ann_type[(ann_type == 4) | (ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 2
65 |                 ann_type[(ann_type == 8) | (ann_type == 3)] = 3
66 |             elif cfg.data_type == 'consep':
67 |                 ann_type[(ann_type == 3) | (ann_type == 4)] = 3
68 |                 ann_type[(ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 4
69 |             assert np.max(ann_type) <= cfg.nr_types-1, \
70 |                             "Only %d types of nuclei are defined for training "\
71 |                             "but there are %d types found in the input image."
% (cfg.nr_types, np.max(ann_type)) 72 | 73 | ann = np.dstack([ann_inst, ann_type]) 74 | ann = ann.astype('int32') 75 | 76 | img = np.concatenate([img, ann], axis=-1) 77 | 78 | sub_patches = xtractor.extract(img, extract_type) 79 | for idx, patch in enumerate(sub_patches): 80 | np.save("{0}/{1}_{2:03d}.npy".format(out_dir, basename, idx), patch) -------------------------------------------------------------------------------- /postproc/post_sonnet.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from numpy.random import default_rng 4 | from scipy.ndimage import filters, measurements, find_objects 5 | from scipy.ndimage.morphology import (binary_dilation, binary_fill_holes, 6 | distance_transform_cdt, 7 | distance_transform_edt) 8 | from skimage.morphology import remove_small_objects, binary_erosion 9 | import tensorflow as tf 10 | from matplotlib import cm 11 | from skimage import measure 12 | from skimage.morphology import remove_small_objects 13 | from skimage.segmentation import watershed 14 | import matplotlib.pyplot as plt 15 | from sklearn.utils import shuffle 16 | from collections import OrderedDict 17 | 18 | def colorize(value, vmin=None, vmax=None, cmap=None): 19 | """ 20 | Arguments: 21 | - value: input tensor, NHWC ('channels_last') 22 | - vmin: the minimum value of the range used for normalization. 23 | (Default: value minimum) 24 | - vmax: the maximum value of the range used for normalization. 25 | (Default: value maximum) 26 | - cmap: a valid cmap named for use with matplotlib's `get_cmap`. 27 | (Default: 'gray') 28 | Example usage: 29 | ``` 30 | output = tf.random_uniform(shape=[256, 256, 1]) 31 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis') 32 | tf.summary.image('output', output_color) 33 | ``` 34 | 35 | Returns a 3D tensor of shape [height, width, 3], uint8. 
36 | """ 37 | 38 | # normalize 39 | if vmin is None: 40 | vmin = tf.reduce_min(value, axis=[1, 2]) 41 | vmin = tf.reshape(vmin, [-1, 1, 1]) 42 | if vmax is None: 43 | vmax = tf.reduce_max(value, axis=[1, 2]) 44 | vmax = tf.reshape(vmax, [-1, 1, 1]) 45 | value = (value - vmin) / (vmax - vmin) # vmin..vmax 46 | 47 | # squeeze last dim if it exists 48 | # NOTE: will throw error if use get_shape() 49 | # value = tf.squeeze(value) 50 | 51 | # quantize 52 | value = tf.round(value * 255) 53 | indices = tf.cast(value, np.int32) 54 | 55 | # gather 56 | colormap = cm.get_cmap(cmap if cmap is not None else 'gray') 57 | colors = colormap(np.arange(256))[:, :3] 58 | colors = tf.constant(colors, dtype=tf.float32) 59 | value = tf.gather(colors, indices) 60 | value = tf.cast(value * 255, tf.uint8) 61 | return value 62 | 63 | def proc_np_ord(pred, pred_ord): 64 | """ 65 | Process Nuclei Prediction with The ordinal map 66 | 67 | Args: 68 | pred: prediction output (NP branch) 69 | pred_ord: ordinal prediction output (ordinal branch) 70 | """ 71 | 72 | blb_raw = pred 73 | 74 | pred_ord = np.squeeze(pred_ord) 75 | distance = -pred_ord 76 | marker = np.copy(pred_ord) 77 | marker[marker <= 4] = 0 78 | marker[marker > 4] = 1 79 | marker = binary_dilation(marker, iterations=1) 80 | # marker = binary_erosion(marker) 81 | # marker = binary_erosion(marker) 82 | # kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5, 5)) 83 | # marker = cv2.morphologyEx(np.float32(marker), cv2.MORPH_OPEN, kernel) 84 | marker = measurements.label(marker)[0] 85 | marker = remove_small_objects(marker, min_size=10) 86 | 87 | # Processing 88 | blb = np.copy(blb_raw) 89 | blb[blb >= 0.5] = 1 90 | blb[blb < 0.5] = 0 91 | 92 | blb = measurements.label(blb)[0] 93 | blb = remove_small_objects(blb, min_size=10) 94 | blb[blb > 0] = 1 # background is 0 already 95 | 96 | markers = marker * blb 97 | 98 | proced_pred = watershed(distance, markers, mask=blb) 99 | 100 | return proced_pred -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | import scipy.io as sio 8 | from scipy.ndimage import filters, measurements 9 | from scipy.ndimage.morphology import (binary_dilation, binary_fill_holes, 10 | distance_transform_cdt, 11 | distance_transform_edt) 12 | from skimage.morphology import remove_small_objects, watershed 13 | import matplotlib.pyplot as plt 14 | 15 | import postproc.post_sonnet 16 | 17 | from config import Config 18 | 19 | from misc.viz_utils import visualize_instances 20 | from misc.utils import get_inst_centroid 21 | from metrics.stats_utils import remap_label 22 | 23 | ########## 24 | 25 | ## ! WARNING: 26 | ## check the prediction channels, wrong ordering will break the code ! 
27 | ## the prediction channels ordering should match the ones produced in augs.py 28 | 29 | cfg = Config() 30 | 31 | 32 | pred_dir = cfg.inf_output_dir 33 | proc_dir = pred_dir + '/_proc' 34 | 35 | file_list = glob.glob('%s/*.mat' % (pred_dir)) 36 | file_list.sort() # ensure same order 37 | 38 | if not os.path.isdir(proc_dir): 39 | os.makedirs(proc_dir) 40 | 41 | for filename in file_list: 42 | filename = os.path.basename(filename) 43 | basename = filename.split('.')[0] 44 | print(pred_dir, basename, end=' ', flush=True) 45 | 46 | 47 | pred_mat = sio.loadmat('%s/%s.mat' % (pred_dir, basename)) 48 | pred = np.squeeze(pred_mat['result']) 49 | pred_ord = np.squeeze(pred_mat['result-ord']) 50 | 51 | if hasattr(cfg, 'type_classification') and cfg.type_classification: 52 | pred_inst = pred[...,cfg.nr_types:] 53 | pred_type = pred[...,:cfg.nr_types] 54 | 55 | pred_inst = np.squeeze(pred_inst) 56 | pred_type = np.argmax(pred_type, axis=-1) 57 | ### 58 | 59 | pred_inst = postproc.post_sonnet.proc_np_ord(pred_inst, pred_ord) 60 | 61 | # ! will be extremely slow on WSI/TMA so it's advisable to comment this out 62 | # * remap once so that further processing faster (metrics calculation, etc.) 63 | pred_inst = remap_label(pred_inst, by_size=True) 64 | 65 | 66 | if cfg.type_classification: 67 | #### * Get class of each instance id, stored at index id-1 68 | pred_id_list = list(np.unique(pred_inst))[1:] # exclude background ID 69 | pred_inst_type = np.full(len(pred_id_list), 0, dtype=np.int32) 70 | for idx, inst_id in enumerate(pred_id_list): 71 | inst_type = pred_type[(pred_inst == inst_id)&(pred_ord>=5)] 72 | type_list, type_pixels = np.unique(inst_type, return_counts=True) 73 | type_list = list(zip(type_list, type_pixels)) 74 | type_list = sorted(type_list, key=lambda x: x[1], reverse=True) 75 | try: 76 | inst_type = type_list[0][0] 77 | except IndexError: 78 | inst_type = 0 79 | if inst_type == 0: # ! pick the 2nd most dominant if exist 80 | if len(type_list) > 1: 81 | if (type_list[1][1] / type_list[0][1]) > 0.5: 82 | inst_type = type_list[1][0] 83 | else: 84 | inst_type = type_list[0][0] 85 | print('[Warn] Instance has `background` type') 86 | else: 87 | print('[Warn] Instance has `background` type' ) 88 | pred_inst_type[idx] = inst_type 89 | pred_inst_centroid = get_inst_centroid(pred_inst) 90 | 91 | sio.savemat('%s/%s.mat' % (proc_dir, basename), 92 | {'inst_map' : pred_inst, 93 | 'inst_type' : pred_inst_type[:, None], 94 | 'inst_centroid' : pred_inst_centroid, 95 | 'type_map' : pred_type 96 | }) 97 | 98 | ## 99 | print('FINISH') 100 | -------------------------------------------------------------------------------- /loader/loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from tensorpack.dataflow import (AugmentImageComponent, AugmentImageComponents, 6 | BatchData, BatchDataByShape, CacheData, 7 | PrefetchDataZMQ, RNGDataFlow, RepeatedData) 8 | 9 | #### 10 | class DatasetSerial(RNGDataFlow): 11 | """ 12 | Produce ``(image, label)`` pair, where 13 | ``image`` has shape HWC and is RGB, has values in range [0-255]. 14 | 15 | ``label`` is a float image of shape (H, W, C). 
Number of C depends 16 | on `self.model_mode` within `config.py` 17 | 18 | If self.model_mode is 'sonnet': 19 | channel 0 binary nuclei map, values are either 0 (background) or 1 (nuclei) 20 | channel 1 containing the type map 21 | channel 2 containing the ordinal map 22 | """ 23 | 24 | def __init__(self, path_list): 25 | self.path_list = path_list 26 | ## 27 | def size(self): 28 | return len(self.path_list) 29 | ## 30 | def get_data(self): 31 | idx_list = list(range(0, len(self.path_list))) 32 | random.shuffle(idx_list) 33 | for idx in idx_list: 34 | 35 | data = np.load(self.path_list[idx]) 36 | 37 | # split stacked channel into image and label 38 | img = data[...,:3] # RGB images 39 | ann = data[...,3:] # ann map 40 | 41 | img = img.astype('uint8') 42 | yield [img, ann] 43 | 44 | 45 | #### 46 | def valid_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=8, nr_procs=1): 47 | ### augment both the input and label 48 | ds = ds if shape_aug is None else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True) 49 | ### augment just the input 50 | ds = ds if input_aug is None else AugmentImageComponent(ds, input_aug, index=0, copy=False) 51 | ### augment just the output 52 | ds = ds if label_aug is None else AugmentImageComponent(ds, label_aug, index=1, copy=True) 53 | # 54 | ds = BatchData(ds, batch_size, remainder=True) 55 | ds = CacheData(ds) # cache all inference images 56 | # ds = DataThread(batch_size, ds) 57 | # ds = QueueInput(ds) 58 | return ds 59 | 60 | #### 61 | def train_generator(ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=4, nr_procs=4): 62 | ### augment both the input and label 63 | ds = ds if shape_aug is None else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True) 64 | ### augment just the input i.e index 0 within each yield of DatasetSerial 65 | ds = ds if input_aug is None else AugmentImageComponent(ds, input_aug, index=0, copy=False) 66 | ### augment just the output i.e index 1 within each yield of DatasetSerial 67 | ds = ds if label_aug is None else AugmentImageComponent(ds, label_aug, index=1, copy=True) 68 | # 69 | ds = BatchDataByShape(ds, batch_size, idx=0) 70 | ds = PrefetchDataZMQ(ds, nr_procs) 71 | return ds 72 | 73 | #### 74 | def visualize(datagen, batch_size, view_size=4): 75 | """ 76 | Read the batch from 'datagen' and display 'view_size' number of 77 | of images and their corresponding Ground Truth 78 | """ 79 | def prep_imgs(img, ann): 80 | cmap = plt.get_cmap('viridis') 81 | # cmap may randomly fails if of other types 82 | ann = ann.astype('float32') 83 | ann_chs = np.dsplit(ann, ann.shape[-1]) 84 | for i, ch in enumerate(ann_chs): 85 | ch = np.squeeze(ch) 86 | # normalize to -1 to 1 range else 87 | # cmap may behave stupidly 88 | ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16) 89 | # take RGB from RGBA heat map 90 | ann_chs[i] = cmap(ch)[...,:3] 91 | img = img.astype('float32') / 255.0 92 | prepped_img = np.concatenate([img] + ann_chs, axis=1) 93 | return prepped_img 94 | 95 | assert view_size <= batch_size, 'Number of displayed images must <= batch size' 96 | ds = RepeatedData(datagen, -1) 97 | ds.reset_state() 98 | for imgs, segs in ds.get_data(): 99 | for idx in range (0, view_size): 100 | displayed_img = prep_imgs(imgs[idx], segs[idx]) 101 | plt.subplot(view_size, 1, idx+1) 102 | plt.imshow(displayed_img) 103 | plt.show() 104 | return 105 | ### -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # SONNET: A self-guided ordinal regression neural network for segmentation and classification of nuclei in large-scale multi-tissue histology images
2 | 
3 | ## Overview
4 | 
5 | This repository contains a TensorFlow implementation of SONNET: a self-guided ordinal regression neural network that simultaneously performs nuclei segmentation and classification. By introducing a distance-decreasing discretization strategy, the network can detect nuclear pixels (inner pixels) and high-uncertainty regions (outer pixels). The self-guided training strategy is applied to the high-uncertainty regions to improve the final outcome.
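For intuition, the discretization maps each nucleus pixel's normalized distance from its centroid onto eight ordinal levels, so that levels decrease from the centroid outwards. Below is a minimal numpy sketch mirroring `gen_ord` in `loader/augs.py`; the explicit `foreground` mask is an assumption added here for clarity:

```
import numpy as np

# Decreasing thresholds from gen_ord() in loader/augs.py: a pixel whose
# normalized centroid distance falls in (LUT[k+1], LUT[k]] gets level k+1,
# so pixels closest to the nucleus centroid receive the highest level (8).
LUT = [1, 0.83, 0.68, 0.54, 0.41, 0.29, 0.19, 0.09, 0]

def to_ordinal(euc_map, foreground):
    levels = np.zeros_like(euc_map, dtype=np.int32)
    for k in range(8):
        lo, hi = LUT[k + 1], LUT[k]
        in_bin = (euc_map <= hi) & (euc_map > lo) if k != 7 \
            else (euc_map <= hi) & (euc_map >= lo)
        levels[in_bin] = k + 1
    return levels * foreground  # keep background at level 0

d = np.array([0.0, 0.05, 0.3, 0.7, 1.0])         # toy centroid distances
print(to_ordinal(d, np.ones_like(d, np.int32)))  # -> [8 8 5 2 1]
```

During post-processing (`postproc/post_sonnet.py`), pixels predicted above level 4 serve as watershed markers for the nucleus interiors.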
6 | As part of this research, we introduce a new dataset for Gastric Lymphocyte Segmentation And Classification (GLySAC), which includes 59 H&E-stained image tiles of size 1000×1000 pixels. More details can be found in the SONNET paper.
7 | 
8 | The repository includes:
9 | - Source code of SONNET.
10 | - Training code for SONNET.
11 | - Datasets employed in the SONNET paper.
12 | - Pretrained weights for the SONNET encoder.
13 | - Evaluation on nuclei segmentation metrics (DICE, AJI, DQ, SQ, PQ) and nuclei classification metrics (Fd and F1 scores).
14 | 
15 | ## Installation
16 | ```
17 | conda create --name sonnet python=3.6
18 | conda activate sonnet
19 | pip install -r requirements.txt
20 | ```
21 | 
22 | ## Dataset
23 | Download the CoNSeP dataset from this [link](https://warwick.ac.uk/fac/sci/dcs/research/tia/data/hovernet/).
24 | Download the PanNuke dataset from this [link](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke).
25 | Download the MoNuSAC and GLySAC datasets from this [link](https://drive.google.com/drive/folders/1p0Yt2w8MTcaZJU3bdh0fAtTrPWin1-zb?usp=sharing).
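The `.mat` annotation files are expected to provide `inst_map` (per-pixel instance IDs) and `type_map` (per-pixel class IDs) arrays, which are the keys that `extract_patches.py` reads. A quick sanity check, with a placeholder path:

```
import numpy as np
import scipy.io as sio

ann = sio.loadmat('/path/to/Labels(mat)/image_01.mat')  # placeholder path
inst_map = ann['inst_map']  # HxW, 0 = background, 1..N = nucleus instance IDs
type_map = ann['type_map']  # HxW, 0 = background, >0 = nuclei type IDs
print(inst_map.shape, len(np.unique(inst_map)) - 1, np.unique(type_map))
```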
26 | 
27 | 
28 | ## Step-by-Step Instructions
29 | Applying the model for nuclei segmentation and classification requires four steps:
30 | 
31 | ### Step 1: Extracting the original data into patches
32 | To train the model, the data first needs to be extracted into patches. Set the dataset name in ```config.py```, then extract the data into training patches by running:
33 | `python extract_patches.py`
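Each saved patch stacks the image and label channels into a single array (the channel layout is described below). A minimal sketch for inspecting one patch, with a placeholder filename:

```
import numpy as np

patch = np.load('/path/to/540x540_76x76/image_01_000.npy')  # placeholder
img   = patch[..., :3].astype('uint8')  # RGB image
inst  = patch[..., 3]                   # instance ground truth map
types = patch[..., 4]                   # nuclei type ground truth map
print(patch.shape)  # (540, 540, 5) when type_classification=True
```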
34 | The patches are numpy arrays whose channels are ordered as [RGB, inst, type], where RGB is the input image, inst is the instance (foreground/background) ground truth map, and type is the nuclei type ground truth map.
35 | 
36 | ### Step 2: Training the model
37 | Download the pretrained weights of the encoder used in SONNET from this [link](https://drive.google.com/drive/folders/1p0Yt2w8MTcaZJU3bdh0fAtTrPWin1-zb?usp=sharing).
38 | Before training the network:
39 | - Set the training and validation dataset paths in `config.py`.
40 | - Set the path to the pretrained encoder weights in `config.py`.
41 | - Set the log path in `config.py`.
42 | - Adjust the hyperparameters used in the training process as needed in `opt/hyperconfig.py`.
43 | To train the network with GPUs 0 and 1:
44 | ```
45 | python train.py --gpu='0,1'
46 | ```
47 | ### Step 3: Inference
48 | Before testing the network on the test dataset:
49 | - Set the path to the test dataset, the path to the trained model weights, and the path for saving the output in `config.py`.
50 | Generate the network predictions with:
51 | ```
52 | python infer.py --gpu='0'
53 | ```
54 | Note that inference supports only a single GPU.
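`infer.py` writes one `.mat` file per image containing the keys `result` (the coded prediction map) and `result-ord` (the ordinal map); these are the arrays that `process.py` reads back in the next step. To inspect them, with a placeholder filename:

```
import numpy as np
import scipy.io as sio

pred = sio.loadmat('output/test/image_01.mat')  # placeholder filename
pred_map = np.squeeze(pred['result'])      # [Nuclei Type | Nuclei Pixels]
pred_ord = np.squeeze(pred['result-ord'])  # ordinal levels, high = inner
print(pred_map.shape, pred_ord.shape)
```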
55 | To obtain the final outcome, i.e. the instance map and the type of each nucleus, run:
56 | ```
57 | python process.py
58 | ```
59 | ### Step 4: Calculate the metrics
60 | To calculate the metrics used in the SONNET paper, run:
61 | - instance segmentation: `python compute_stats.py --mode=instance --pred_dir='pred_dir' --true_dir='true_dir'`
62 | - type classification: `python compute_stats.py --mode=type --pred_dir='pred_dir' --true_dir='true_dir'`
63 | 
64 | ## Visual Results
65 | 
66 | <table>
67 | <td>
68 | <img src="results/result_1.png">
69 | </td>
70 | 
71 | <td>
72 | <img src="results/results_2.png">
73 | </td>
74 | </table>
75 | 
76 | The type of each nucleus is represented by a color:
77 | - Pink for Epithelial.
78 | - Yellow for Lymphocyte.
79 | - Blue for Miscellaneous.
80 | 
81 | ## Requirements
82 | Python 3.6, TensorFlow 1.12, and the other packages listed in `requirements.txt`.
83 | 
84 | ## Citation
85 | The datasets that we used in SONNET are from these papers:
86 | - CoNSeP: [paper](https://www.sciencedirect.com/science/article/pii/S1361841519301045).
87 | - MoNuSAC: [paper](https://ieeexplore.ieee.org/document/9446924).
88 | - PanNuke: [paper](https://arxiv.org/abs/2003.10778).
89 | 
90 | If you use any of these datasets in your research, please cite the corresponding paper.
91 | 
92 | 
--------------------------------------------------------------------------------
/loader/augs.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import cv2
4 | import numpy as np
5 | 
6 | 
7 | from scipy.ndimage import measurements
8 | from scipy.ndimage.filters import gaussian_filter
9 | from scipy.ndimage.interpolation import affine_transform, map_coordinates
10 | from scipy.ndimage.morphology import (distance_transform_cdt,
11 |                                       distance_transform_edt)
12 | from skimage import morphology as morph
13 | 
14 | from tensorpack.dataflow.imgaug import ImageAugmentor
15 | from tensorpack.utils.utils import get_rng
16 | 
17 | from misc.utils import cropping_center, bounding_box, get_inst_centroid
18 | 
19 | ####
20 | class GenInstance(ImageAugmentor):
21 |     def __init__(self, crop_shape=None):
22 |         super(GenInstance, self).__init__()
23 |         self.crop_shape = crop_shape
24 | 
25 |     def reset_state(self):
26 |         self.rng = get_rng(self)
27 | 
28 |     def _fix_mirror_padding(self, ann):
29 |         """
30 |         Deal with duplicated instances due to mirroring in interpolation
31 |         during shape augmentation (scale, rotation etc.)
32 |         """
33 |         current_max_id = np.amax(ann)
34 |         inst_list = list(np.unique(ann))
35 |         inst_list.remove(0) # 0 is background
36 |         for inst_id in inst_list:
37 |             inst_map = np.array(ann == inst_id, np.uint8)
38 |             remapped_ids = measurements.label(inst_map)[0]
39 |             remapped_ids[remapped_ids > 1] += int(current_max_id)
40 |             ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
41 |             current_max_id = np.amax(ann)
42 |         return ann
43 | ####
44 | 
45 | 
46 | ####
47 | class GenInstanceOrd(GenInstance):
48 |     """
49 |     Generate an ordinal distance map based on the instance map.
50 |     First, the euclidean distance map will be calculated.
Then, the ordinal map is generated based on the euclidean distance map 51 | """ 52 | 53 | def _augment(self, img, _): 54 | img = np.copy(img) 55 | orig_ann = img[...,0].astype(np.int32) # instance ID map 56 | fixed_ann = self._fix_mirror_padding(orig_ann) 57 | # re-cropping with fixed instance id map 58 | crop_ann = cropping_center(fixed_ann, self.crop_shape) 59 | with warnings.catch_warnings(): 60 | warnings.simplefilter("ignore") 61 | crop_ann = morph.remove_small_objects(crop_ann, min_size=7) 62 | 63 | inst_list = list(np.unique(crop_ann)) 64 | # inst_centroid = get_inst_centroid(crop_ann) 65 | inst_list.remove(0) 66 | mask = np.zeros_like(fixed_ann, dtype=np.float32) 67 | for inst_id in inst_list: 68 | inst_id_map = np.copy(np.array(fixed_ann == inst_id, dtype=np.uint8)) 69 | M = cv2.moments(inst_id_map) 70 | cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00']) 71 | inst_id_map[fixed_ann != inst_id] = 2 72 | inst_id_map[int(cy), int(cx)] = 0 73 | inst_id_map = distance_transform_edt(inst_id_map) 74 | inst_id_map[fixed_ann != inst_id] = 0 75 | max_val = np.max(inst_id_map) 76 | inst_id_map = inst_id_map / max_val 77 | mask[fixed_ann == inst_id] = inst_id_map[fixed_ann == inst_id] 78 | 79 | def gen_ord(euc_map): 80 | lut_gt = [1, 0.83, 0.68, 0.54, 0.41, 0.29, 0.19, 0.09, 0] 81 | zeros = np.zeros_like(euc_map) 82 | ones = np.ones_like(euc_map) 83 | decoded_label = np.full(euc_map.shape, 0, dtype=np.float32) 84 | for k in range(8): 85 | if k != 7: 86 | decoded_label += np.where((euc_map <= lut_gt[k]) & (euc_map > lut_gt[k+1]), ones * (k + 1), zeros) 87 | else: 88 | decoded_label += np.where((euc_map <= lut_gt[k]) & (euc_map >= lut_gt[k+1]), ones * (k + 1), zeros) 89 | return decoded_label 90 | 91 | ord_map = gen_ord(mask) 92 | img = img.astype('float32') 93 | img = np.dstack([img, ord_map]) 94 | 95 | return img 96 | #### 97 | 98 | class GaussianBlur(ImageAugmentor): 99 | """ Gaussian blur the image with random window size""" 100 | def __init__(self, max_size=3): 101 | """ 102 | Args: 103 | max_size (int): max possible Gaussian window size would be 2 * max_size + 1 104 | """ 105 | super(GaussianBlur, self).__init__() 106 | self.max_size = max_size 107 | 108 | def _get_augment_params(self, img): 109 | sx, sy = self.rng.randint(1, self.max_size, size=(2,)) 110 | sx = sx * 2 + 1 111 | sy = sy * 2 + 1 112 | return sx, sy 113 | 114 | def _augment(self, img, s): 115 | return np.reshape(cv2.GaussianBlur(img, s, sigmaX=0, sigmaY=0, 116 | borderType=cv2.BORDER_REPLICATE), img.shape) 117 | 118 | #### 119 | class BinarizeLabel(ImageAugmentor): 120 | """ Convert labels to binary maps""" 121 | def __init__(self): 122 | super(BinarizeLabel, self).__init__() 123 | 124 | def _get_augment_params(self, img): 125 | return None 126 | 127 | def _augment(self, img, s): 128 | img = np.copy(img) 129 | arr = img[...,0] 130 | arr[arr > 0] = 1 131 | return img 132 | 133 | #### 134 | class MedianBlur(ImageAugmentor): 135 | """ Median blur the image with random window size""" 136 | def __init__(self, max_size=3): 137 | """ 138 | Args: 139 | max_size (int): max possible window size 140 | would be 2 * max_size + 1 141 | """ 142 | super(MedianBlur, self).__init__() 143 | self.max_size = max_size 144 | 145 | def _get_augment_params(self, img): 146 | s = self.rng.randint(1, self.max_size) 147 | s = s * 2 + 1 148 | return s 149 | 150 | def _augment(self, img, ksize): 151 | return cv2.medianBlur(img, ksize) 152 | 153 | -------------------------------------------------------------------------------- /config.py: 
--------------------------------------------------------------------------------
1 | import importlib
2 | 
3 | import cv2
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorpack import imgaug
7 | 
8 | from loader.augs import (BinarizeLabel, GaussianBlur,
9 |                          GenInstanceOrd, MedianBlur)
10 | 
11 | ####
12 | class Config(object):
13 |     def __init__(self, ):
14 | 
15 |         self.seed = 9
16 |         self.model_type = 'sonnet'
17 |         self.data_type = 'consep'
18 | 
19 |         self.type_classification = True
20 |         self.nr_types = 5
21 |         self.nr_classes = 2 # Nuclei Pixels vs Background
22 | 
23 |         # define your nuclei type names here; please ensure it contains
24 |         # the same amount as defined in `self.nr_types`. ID 0 is reserved
25 |         # for background so please don't use it as an ID
26 |         if self.data_type == 'consep':
27 |             self.nuclei_type_dict = {
28 |                 'Miscellaneous': 1, # ! Please ensure the matching ID is unique
29 |                 'Inflammatory' : 2,
30 |                 'Epithelial'   : 3,
31 |                 'Spindle'      : 4,
32 |             }
33 |         elif self.data_type == 'monusac':
34 |             self.nuclei_type_dict = {
35 |                 'Epithelial' : 1,
36 |                 'Lymphocyte' : 2,
37 |                 'Macrophages': 3,
38 |                 'Neutrophil' : 4
39 |             }
40 |         elif self.data_type == 'pannuke':
41 |             self.nuclei_type_dict = {
42 |                 'Neoplastic'   : 1,
43 |                 'Inflammatory' : 2,
44 |                 'Connective'   : 3,
45 |                 'Dead'         : 4,
46 |                 'Non-Neoplastic Epithelial' : 5
47 |             }
48 |         else:
49 |             self.nuclei_type_dict = {
50 |                 'Other'      : 1,
51 |                 'Lymphocyte' : 2,
52 |                 'Epithelial' : 3
53 |             }
54 |         assert len(self.nuclei_type_dict.values()) == self.nr_types - 1
55 | 
56 |         #### Dynamically setting the config file into variables
57 |         config_file = importlib.import_module('opt.hyperconfig') # np_hv, np_dist
58 |         config_dict = config_file.__getattribute__(self.model_type)
59 | 
60 |         for variable, value in config_dict.items():
61 |             self.__setattr__(variable, value)
62 |         #### Training data
63 | 
64 |         # patches are stored as numpy arrays with N channels
65 |         # ordering as [Image][Nuclei Pixels][Nuclei Type][Additional Map]
66 |         # Ex: with type_classification=True
67 |         # HoVer-Net: RGB - Nuclei Pixels - Type Map - Horizontal and Vertical Map
68 |         # Ex: with type_classification=False
69 |         # Dist     : RGB - Nuclei Pixels - Distance Map
70 |         if self.data_type != 'pannuke':
71 |             data_code_dict = {
72 |                 'sonnet' : '540x540_76x76',
73 |             }
74 |         else:
75 |             data_code_dict = {
76 |                 'sonnet' : '270x270_76x76',
77 |             }
78 | 
79 |         self.data_ext = '.npy'
80 |         # list of directories containing training / validation patches.
81 | # For both train and valid directories, a comma separated list of directories can be used 82 | self.train_dir = ['/media/tandoan/data2/CoNSeP/Train/%s/' % data_code_dict[self.model_type]] 83 | # Used train_test_split alr 84 | self.valid_dir = ['/home/tandoan/work/PanNuke/Valid/%s' % data_code_dict[self.model_type]] 85 | 86 | # number of processes for parallel processing input 87 | self.nr_procs_train = 8 88 | self.nr_procs_valid = 4 89 | 90 | self.input_norm = True # normalize RGB to 0-1 range 91 | 92 | #### 93 | exp_id = 'v1.0' 94 | model_id = '%s' % self.model_type 95 | self.model_name = '%s/%s' % (exp_id, model_id) 96 | # loading chkpts in tensorflow, the path must not contain extra '/' 97 | self.log_path = '/media/tandoan/data2/logs/logs_test' 98 | self.save_dir = '%s/%s' % (self.log_path, self.model_name) # log file destination 99 | 100 | #### Info for running inferencee 101 | self.inf_auto_find_chkpt = False 102 | # path to checkpoints will be used for inference, replace accordingly 103 | self.inf_model_path = '/media/tandoan/data2/logs/logs_focalnet_noguide_consep/v1.0/focalnet/02/model-39650.index' 104 | 105 | # output will have channel ordering as [Nuclei Type][Nuclei Pixels][Additional] 106 | # where [Nuclei Type] will be used for getting the type of each instance 107 | # while [Nuclei Pixels][Additional] will be used for extracting instance 108 | 109 | self.inf_imgs_ext = '.png' 110 | self.inf_data_dir = '/media/tandoan/data2/CoNSeP/Test/Images' 111 | self.inf_output_dir = 'output/test/' 112 | 113 | # for inference during evalutaion mode i.e run by infer.py 114 | self.eval_inf_input_tensor_names = ['images'] 115 | 116 | 117 | 118 | # for inference during training mode i.e run by trainer.py 119 | if self.model_type == 'sonnet': 120 | self.train_inf_output_tensor_names = ['predmap-coded', 'truemap-coded'] 121 | self.eval_inf_output_tensor_names = ['predmap-coded', 'predmap-ord'] 122 | 123 | 124 | def get_model(self, phase=1): 125 | if phase!=2: 126 | model_constructor = importlib.import_module('model.sonnet') 127 | model_constructor = model_constructor.Sonnet 128 | else: 129 | model_constructor = importlib.import_module('model.sonnet_v2') 130 | model_constructor = model_constructor.Sonnet_phase2 131 | return model_constructor # NOTE return alias, not object 132 | 133 | 134 | # refer to https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html for 135 | # information on how to modify the augmentation parameters 136 | def get_train_augmentors(self, input_shape, output_shape, view=False): 137 | if self.data_type != 'pannuke': 138 | shape_augs = [ 139 | imgaug.Affine( 140 | shear=5, # in degree 141 | scale=(0.8, 1.2), 142 | rotate_max_deg=179, 143 | translate_frac=(0.01, 0.01), 144 | interp=cv2.INTER_NEAREST, 145 | border=cv2.BORDER_CONSTANT), 146 | imgaug.Flip(vert=True), 147 | imgaug.Flip(horiz=True), 148 | imgaug.CenterCrop(input_shape), 149 | ] 150 | else: 151 | shape_augs =[ 152 | imgaug.Flip(vert=True), 153 | imgaug.Flip(horiz=True), 154 | ] 155 | 156 | input_augs = [ 157 | imgaug.RandomApplyAug( 158 | imgaug.RandomChooseAug( 159 | [ 160 | GaussianBlur(), 161 | MedianBlur(), 162 | imgaug.GaussianNoise(), 163 | ] 164 | ), 0.5), 165 | # standard color augmentation 166 | imgaug.RandomOrderAug( 167 | [imgaug.Hue((-8, 8), rgb=True), 168 | imgaug.Saturation(0.2, rgb=True), 169 | imgaug.Brightness(26, clip=True), 170 | imgaug.Contrast((0.75, 1.25), clip=True), 171 | ]), 172 | imgaug.ToUint8(), 173 | ] 174 | 175 | label_augs = [] 176 | if self.model_type == 'sonnet': 177 | 
label_augs = [GenInstanceOrd(crop_shape=output_shape)] 178 | 179 | if not self.type_classification: 180 | label_augs.append(BinarizeLabel()) 181 | 182 | if not view: 183 | label_augs.append(imgaug.CenterCrop(output_shape)) 184 | 185 | return shape_augs, input_augs, label_augs 186 | 187 | def get_valid_augmentors(self, input_shape, output_shape, view=False): 188 | shape_augs = [ 189 | imgaug.CenterCrop(input_shape), 190 | ] 191 | 192 | input_augs = None 193 | 194 | label_augs = [] 195 | if self.model_type == 'sonnet': 196 | label_augs = [GenInstanceOrd(crop_shape=output_shape)] 197 | label_augs.append(BinarizeLabel()) 198 | 199 | if not view: 200 | label_augs.append(imgaug.CenterCrop(output_shape)) 201 | 202 | return shape_augs, input_augs, label_augs 203 | -------------------------------------------------------------------------------- /misc/patch_extractor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import os 4 | 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import glob 9 | from scipy.ndimage import measurements 10 | 11 | from .utils import cropping_center 12 | from .utils import rm_n_mkdir 13 | 14 | ##### 15 | class PatchExtractor(object): 16 | """ 17 | Extractor to generate patches with or without padding. 18 | Turn on debug mode to see how it is done. 19 | 20 | Args: 21 | x : input image, should be of shape HWC 22 | win_size : a tuple of (h, w) 23 | step_size : a tuple of (h, w) 24 | debug : flag to see how it is done 25 | Return: 26 | a list of sub patches, each patch has dtype same as x 27 | 28 | Examples: 29 | >>> xtractor = PatchExtractor((450, 450), (120, 120)) 30 | >>> img = np.full([1200, 1200, 3], 255, np.uint8) 31 | >>> patches = xtractor.extract(img, 'mirror') 32 | """ 33 | def __init__(self, win_size, step_size, debug=False): 34 | 35 | self.patch_type = 'mirror' 36 | self.win_size = win_size 37 | self.step_size = step_size 38 | self.debug = debug 39 | self.counter = 0 40 | 41 | def __get_patch(self, x, ptx): 42 | pty = (ptx[0]+self.win_size[0], 43 | ptx[1]+self.win_size[1]) 44 | win = x[ptx[0]:pty[0], 45 | ptx[1]:pty[1]] 46 | assert win.shape[0] == self.win_size[0] and \ 47 | win.shape[1] == self.win_size[1], \ 48 | '[BUG] Incorrect Patch Size {0}'.format(win.shape) 49 | if self.debug: 50 | if self.patch_type == 'mirror': 51 | cen = cropping_center(win, self.step_size) 52 | cen = cen[...,self.counter % 3] 53 | cen.fill(150) 54 | cv2.rectangle(x,ptx,pty,(255,0,0),2) 55 | plt.imshow(x) 56 | plt.show(block=False) 57 | plt.pause(1) 58 | plt.close() 59 | self.counter += 1 60 | return win 61 | 62 | def __extract_valid(self, x): 63 | """ 64 | Extracted patches without padding, only work in case win_size > step_size 65 | 66 | Note: to deal with the remaining portions which are at the boundary a.k.a 67 | those which do not fit when slide left->right, top->bottom), we flip 68 | the sliding direction then extract 1 patch starting from right / bottom edge. 
69 | There will be 1 additional patch extracted at the bottom-right corner 70 | 71 | Args: 72 | x : input image, should be of shape HWC 73 | win_size : a tuple of (h, w) 74 | step_size : a tuple of (h, w) 75 | Return: 76 | a list of sub patches, each patch is same dtype as x 77 | """ 78 | 79 | im_h = x.shape[0] 80 | im_w = x.shape[1] 81 | 82 | def extract_infos(length, win_size, step_size): 83 | flag = (length - win_size) % step_size != 0 84 | last_step = math.floor((length - win_size) / step_size) 85 | last_step = (last_step + 1) * step_size 86 | return flag, last_step 87 | 88 | h_flag, h_last = extract_infos(im_h, self.win_size[0], self.step_size[0]) 89 | w_flag, w_last = extract_infos(im_w, self.win_size[1], self.step_size[1]) 90 | 91 | sub_patches = [] 92 | #### Deal with valid block 93 | for row in range(0, h_last, self.step_size[0]): 94 | for col in range(0, w_last, self.step_size[1]): 95 | win = self.__get_patch(x, (row, col)) 96 | sub_patches.append(win) 97 | #### Deal with edge case 98 | if h_flag: 99 | row = im_h - self.win_size[0] 100 | for col in range(0, w_last, self.step_size[1]): 101 | win = self.__get_patch(x, (row, col)) 102 | sub_patches.append(win) 103 | if w_flag: 104 | col = im_w - self.win_size[1] 105 | for row in range(0, h_last, self.step_size[0]): 106 | win = self.__get_patch(x, (row, col)) 107 | sub_patches.append(win) 108 | if h_flag and w_flag: 109 | ptx = (im_h - self.win_size[0], im_w - self.win_size[1]) 110 | win = self.__get_patch(x, ptx) 111 | sub_patches.append(win) 112 | return sub_patches 113 | 114 | def __extract_mirror(self, x): 115 | """ 116 | Extracted patches with mirror padding the boundary such that the 117 | central region of each patch is always within the orginal (non-padded) 118 | image while all patches' central region cover the whole orginal image 119 | 120 | Args: 121 | x : input image, should be of shape HWC 122 | win_size : a tuple of (h, w) 123 | step_size : a tuple of (h, w) 124 | Return: 125 | a list of sub patches, each patch is same dtype as x 126 | """ 127 | 128 | diff_h = self.win_size[0] - self.step_size[0] 129 | padt = diff_h // 2 130 | padb = diff_h - padt 131 | 132 | diff_w = self.win_size[1] - self.step_size[1] 133 | padl = diff_w // 2 134 | padr = diff_w - padl 135 | 136 | pad_type = 'constant' if self.debug else 'reflect' 137 | x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type) 138 | sub_patches = self.__extract_valid(x) 139 | return sub_patches 140 | 141 | def extract(self, x, patch_type): 142 | patch_type = patch_type.lower() 143 | self.patch_type = patch_type 144 | if patch_type == 'valid': 145 | return self.__extract_valid(x) 146 | elif patch_type == 'mirror': 147 | return self.__extract_mirror(x) 148 | else: 149 | assert False, 'Unknown Patch Type [%s]' % patch_type 150 | return 151 | 152 | class Padding_image(object): 153 | """ 154 | Padding Images to reach the minimum size using `mirror` method. Use for Monusac dataset. 
155 | For HoverNet, win_size is 540x540 156 | """ 157 | def __init__(self, win_size): 158 | self.win_size = win_size 159 | 160 | def pad(self, img_dir, ann_dir, save_img_dir, save_ann_dir, pad_type='reflect'): 161 | file_list = glob.glob(img_dir + '/*.tif') 162 | file_list.sort() 163 | rm_n_mkdir(save_img_dir) 164 | rm_n_mkdir(save_ann_dir) 165 | for filename in file_list: 166 | print(filename) 167 | filename = os.path.basename(filename) 168 | basename = filename.split('.')[0] 169 | img = cv2.imread(img_dir + '/' + basename + '.tif') 170 | ann = np.load(ann_dir + '/' + basename + '.npy') 171 | padt, padb, padl, padr = 0, 0, 0, 0 172 | if img.shape[0] < 540: 173 | diff_h = self.win_size[0] - img.shape[0] 174 | padt = diff_h // 2 175 | padb = diff_h - padt 176 | if img.shape[1] < 540: 177 | diff_w = self.win_size[1] - img.shape[1] 178 | padl = diff_w // 2 179 | padr = diff_w - padl 180 | img = np.lib.pad(img, ((padt, padb), (padl, padr), (0,0)), pad_type) 181 | ann = np.lib.pad(ann, ((padt, padb), (padl, padr), (0,0)), pad_type) 182 | inst_map = ann[...,0] 183 | inst_map = self._fix_mirror_padding(inst_map) 184 | 185 | cv2.imwrite(save_img_dir + '/' + basename + '.tif', img) 186 | np.save(save_ann_dir + '/' + basename, ann) 187 | 188 | 189 | def _fix_mirror_padding(self, ann): 190 | """ 191 | Deal with duplicated instances due to mirroring in interpolation 192 | during shape augmentation (scale, rotation etc.) 193 | """ 194 | current_max_id = np.amax(ann) 195 | inst_list = list(np.unique(ann)) 196 | inst_list.remove(0) # 0 is background 197 | for inst_id in inst_list: 198 | inst_map = np.array(ann == inst_id, np.uint8) 199 | remapped_ids = measurements.label(inst_map)[0] 200 | remapped_ids[remapped_ids > 1] += current_max_id 201 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1] 202 | current_max_id = np.amax(ann) 203 | return ann 204 | 205 | 206 | 207 | ##### 208 | 209 | ########################################################################### 210 | 211 | if __name__ == '__main__': 212 | # toy example for debug 213 | # 355x355, 480x480 214 | xtractor = PatchExtractor((450, 450), (120, 120), debug=True) 215 | a = np.full([1200, 1200, 3], 255, np.uint8) 216 | xtractor.extract(a, 'mirror') 217 | xtractor.extract(a, 'valid') 218 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import math 4 | import os 5 | from collections import deque 6 | 7 | import cv2 8 | import numpy as np 9 | from scipy import io as sio 10 | import matplotlib.pyplot as plt 11 | from skimage import measure 12 | from scipy.ndimage import find_objects 13 | 14 | from tensorpack.predict import OfflinePredictor, PredictConfig 15 | from tensorpack.tfutils.sessinit import get_model_loader 16 | 17 | from config import Config 18 | from misc.utils import rm_n_mkdir 19 | 20 | import json 21 | import operator 22 | 23 | 24 | #### 25 | def get_best_chkpts(path, metric_name, comparator='>'): 26 | """ 27 | Return the best checkpoint according to some criteria. 28 | Note that it will only return valid path, so any checkpoint that has been 29 | removed wont be returned (i.e moving to next one that satisfies the criteria 30 | such as second best etc.) 
31 | Args: 32 | path: directory contains all checkpoints, including the "stats.json" file 33 | """ 34 | stat_file = path + '/02' + '/stats.json' 35 | ops = { 36 | '>': operator.gt, 37 | '<': operator.lt, 38 | } 39 | 40 | op_func = ops[comparator] 41 | with open(stat_file) as f: 42 | info = json.load(f) 43 | 44 | if comparator == '>': 45 | best_value = -float("inf") 46 | else: 47 | best_value = +float("inf") 48 | 49 | best_chkpt = None 50 | for epoch_stat in info: 51 | epoch_value = epoch_stat[metric_name] 52 | if op_func(epoch_value, best_value): 53 | chkpt_path = "%s/02/model-%d.index" % (path, epoch_stat['global_step']) 54 | if os.path.isfile(chkpt_path): 55 | selected_stat = epoch_stat 56 | best_value = epoch_value 57 | best_chkpt = chkpt_path 58 | return best_chkpt, selected_stat 59 | 60 | 61 | #### 62 | class Inferer(Config): 63 | 64 | def __gen_prediction(self, x, predictor): 65 | """ 66 | Using 'predictor' to generate the prediction of image 'x' 67 | 68 | Args: 69 | x : input image to be segmented. It will be split into patches 70 | to run the prediction upon before being assembled back 71 | """ 72 | step_size = [40, 40] 73 | msk_size = self.infer_mask_shape 74 | win_size = self.infer_input_shape 75 | 76 | 77 | 78 | def get_last_steps(length, step_size): 79 | nr_step = math.ceil((length - step_size) / step_size) 80 | last_step = (nr_step + 1) * step_size 81 | return int(last_step), int(nr_step + 1) 82 | 83 | im_h = x.shape[0] 84 | im_w = x.shape[1] 85 | 86 | padt_img, padb_img = 0, 0 87 | padl_img, padr_img = 0, 0 88 | 89 | # pad if image size smaller than msk_size (for monusac dataset) 90 | if im_h < msk_size[0]: 91 | diff_h_img = msk_size[0] - im_h 92 | padt_img = diff_h_img // 2 93 | padb_img = diff_h_img - padt_img 94 | if im_w < msk_size[1]: 95 | diff_w_img = msk_size[1] - im_w 96 | padl_img = diff_w_img // 2 97 | padr_img = diff_w_img - padl_img 98 | x_pad = np.lib.pad(x, ((padt_img, padb_img), (padl_img, padr_img), (0, 0)), 'reflect') 99 | im_h_pad = x_pad.shape[0] 100 | im_w_pad = x_pad.shape[1] 101 | 102 | 103 | last_h, nr_step_h = get_last_steps(im_h_pad, step_size[0]) 104 | last_w, nr_step_w = get_last_steps(im_w_pad, step_size[1]) 105 | diff_h = win_size[0] - step_size[0] 106 | padt = diff_h // 2 107 | padb = last_h + win_size[0] - im_h_pad 108 | 109 | 110 | diff_w = win_size[1] - step_size[1] 111 | padl = diff_w // 2 112 | padr = last_w + win_size[1] - im_w_pad 113 | 114 | 115 | x_pad = np.lib.pad(x_pad, ((padt, padb), (padl, padr), (0, 0)), 'reflect') 116 | 117 | 118 | #### TODO: optimize this 119 | sub_patches = [] 120 | # generating subpatches from orginal 121 | for row in range(0, last_h, step_size[0]): 122 | for col in range (0, last_w, step_size[1]): 123 | win = x_pad[row:row+win_size[0], 124 | col:col+win_size[1]] 125 | sub_patches.append(win) 126 | pred_coded = deque() 127 | pred_ord = deque() 128 | while len(sub_patches) > self.inf_batch_size: 129 | mini_batch = sub_patches[:self.inf_batch_size] 130 | sub_patches = sub_patches[self.inf_batch_size:] 131 | mini_output = predictor(mini_batch) 132 | mini_coded = mini_output[0][:, 18:58, 18:58,:] 133 | mini_ord = mini_output[1][:, 18:58, 18:58, :] 134 | mini_coded = np.split(mini_coded, self.inf_batch_size, axis=0) 135 | pred_coded.extend(mini_coded) 136 | mini_ord = np.split(mini_ord, self.inf_batch_size, axis=0) 137 | pred_ord.extend(mini_ord) 138 | if len(sub_patches) != 0: 139 | mini_output = predictor(sub_patches) 140 | mini_coded = mini_output[0][:, 18:58, 18:58,:] 141 | mini_ord = mini_output[1][:, 18:58, 
18:58, :] 142 | mini_coded = np.split(mini_coded, len(sub_patches), axis=0) 143 | pred_coded.extend(mini_coded) 144 | mini_ord = np.split(mini_ord, len(sub_patches), axis=0) 145 | pred_ord.extend(mini_ord) 146 | 147 | #### Assemble back into full image 148 | output_patch_shape = np.squeeze(pred_coded[0]).shape 149 | ch = 1 if len(output_patch_shape) == 2 else output_patch_shape[-1] 150 | 151 | #### Assemble back into full image 152 | pred_coded = np.squeeze(np.array(pred_coded)) 153 | pred_coded = np.reshape(pred_coded, (nr_step_h, nr_step_w) + pred_coded.shape[1:]) 154 | 155 | pred_coded = np.transpose(pred_coded, [0, 2, 1, 3, 4]) if ch != 1 else \ 156 | np.transpose(pred_coded, [0, 2, 1, 3]) 157 | pred_coded = np.reshape(pred_coded, (pred_coded.shape[0] * pred_coded.shape[1], 158 | pred_coded.shape[2] * pred_coded.shape[3], ch)) 159 | pred_coded = np.squeeze(pred_coded[:im_h_pad, :im_w_pad]) # just crop back to original size 160 | pred_coded = pred_coded[padt_img:padt_img+im_h, padl_img:padl_img+im_w] 161 | 162 | pred_ord = np.squeeze(np.array(pred_ord)) 163 | pred_ord = np.reshape(pred_ord, (nr_step_h, nr_step_w) + pred_ord.shape[1:]) 164 | pred_ord = np.transpose(pred_ord, [0, 2, 1, 3]) 165 | pred_ord = np.reshape(pred_ord, (pred_ord.shape[0] * pred_ord.shape[1], pred_ord.shape[2] * pred_ord.shape[3])) 166 | pred_ord = np.squeeze(pred_ord[:im_h_pad, :im_w_pad]) 167 | pred_ord = pred_ord[padt_img:padt_img+im_h, padl_img:padl_img+im_w] 168 | 169 | return pred_coded, pred_ord 170 | 171 | #### 172 | def run(self): 173 | 174 | if self.inf_auto_find_chkpt: 175 | print('-----Auto Selecting Checkpoint Basing On "%s" Through "%s" Comparison' % \ 176 | (self.inf_auto_metric, self.inf_auto_comparator)) 177 | model_path, stat = get_best_chkpts(self.save_dir, self.inf_auto_metric, self.inf_auto_comparator) 178 | print('Selecting: %s' % model_path) 179 | print('Having Following Statistics:') 180 | for key, value in stat.items(): 181 | print('\t%s: %s' % (key, value)) 182 | else: 183 | model_path = self.inf_model_path 184 | 185 | model_constructor = self.get_model() 186 | pred_config = PredictConfig( 187 | model = model_constructor(), 188 | session_init = get_model_loader(model_path), 189 | input_names = self.eval_inf_input_tensor_names, 190 | output_names = self.eval_inf_output_tensor_names) 191 | predictor = OfflinePredictor(pred_config) 192 | 193 | save_dir = self.inf_output_dir 194 | file_list = glob.glob('%s/*%s' % (self.inf_data_dir, self.inf_imgs_ext)) 195 | file_list.sort() # ensure same order 196 | 197 | rm_n_mkdir(save_dir) 198 | for filename in file_list: 199 | filename = os.path.basename(filename) 200 | basename = filename.split('.')[0] 201 | print(self.inf_data_dir, basename, end=' ', flush=True) 202 | 203 | ## 204 | if self.data_type != 'pannuke': 205 | img = cv2.imread(self.inf_data_dir + '/' + filename) 206 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 207 | else: 208 | img = np.load(self.inf_data_dir + '/' + filename) 209 | ## 210 | pred_coded, pred_ord = self.__gen_prediction(img, predictor) 211 | sio.savemat('%s/%s.mat' % (save_dir, basename), {'result':[pred_coded], 'result-ord':[pred_ord]}) 212 | print('FINISH') 213 | 214 | 215 | 216 | 217 | #### 218 | if __name__ == '__main__': 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') 221 | args = parser.parse_args() 222 | 223 | if args.gpu: 224 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 225 | 226 | inferer = Inferer() 227 | inferer.run() 228 | 
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import json 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorpack import Inferencer, logger 10 | from tensorpack.callbacks import (DataParallelInferenceRunner, ModelSaver, 11 | MaxSaver, ScheduledHyperParamSetter, RunOp) 12 | from tensorpack.tfutils import SaverRestore, get_model_loader 13 | from tensorpack.train import (SyncMultiGPUTrainerParameterServer, TrainConfig, 14 | launch_train_with_config) 15 | 16 | import loader.loader as loader 17 | from config import Config 18 | from misc.utils import get_files, rm_n_mkdir 19 | 20 | import matplotlib.pyplot as plt 21 | 22 | 23 | class StatCollector(Inferencer, Config): 24 | """ 25 | Accumulate output of inference during training. 26 | After the inference finishes, calculate the statistics 27 | """ 28 | def __init__(self, prefix='valid'): 29 | super(StatCollector, self).__init__() 30 | self.prefix = prefix 31 | 32 | def _get_fetches(self): 33 | return self.train_inf_output_tensor_names 34 | 35 | def _before_inference(self): 36 | 37 | self.over_inter_np = 0 38 | self.over_total_np = 0 39 | self.over_correct_np = 0 40 | self.nr_pixels = 0 41 | self.over_type_dict = {} 42 | for type_name, type_id in self.nuclei_type_dict.items(): 43 | self.over_type_dict['fdetect_inter_%s' % (type_name)] = 0 44 | self.over_type_dict['fdetect_total_%s' % (type_name)] = 0 45 | 46 | def _on_fetches(self, outputs): 47 | pred, true = outputs 48 | 49 | def _dice_info(true, pred, label): 50 | true = np.array(true == label, np.int32) 51 | pred = np.array(pred == label, np.int32) 52 | inter = (pred * true).sum() 53 | total = (pred + true).sum() 54 | return inter, total 55 | 56 | def _fdetect_info(true, pred, label): 57 | tp_dt = ((true == label)&(pred == label)).sum() 58 | tn_dt = ((true != label)&(true != 0)&(pred != label)&(pred != 0)).sum() 59 | fp_dt = ((true != label)&(true != 0)&(pred == label)).sum() 60 | fn_dt = ((true == label)&(pred != label)&(pred != 0)).sum() 61 | fp_d = ((true == 0)&(pred != 0)).sum() 62 | fn_d = ((true != 0)&(pred == 0)).sum() 63 | inter = 2 * (tp_dt + tn_dt) 64 | total = 2 * (tp_dt + tn_dt + fp_dt + fn_dt) + fp_d + fn_d 65 | return inter, total 66 | 67 | pred_type = pred[...,:self.nr_types] 68 | pred_inst = pred[...,self.nr_types:] 69 | true_type = true[...,1] 70 | 71 | self.nr_pixels += np.size(true[...,:1]) 72 | pred_np = pred_inst[...,0] 73 | true_np = true[...,0] 74 | pred_np[pred_np >= 0.5] = 1.0 75 | pred_np[pred_np < 0.5] = 0.0 76 | correct = (pred_np == true_np).sum() 77 | self.over_correct_np += correct 78 | inter, total = _dice_info(true_np, pred_np, 1) 79 | self.over_inter_np += inter 80 | self.over_total_np += total 81 | 82 | 83 | pred_type = np.argmax(pred_type, axis=-1) 84 | for type_name, type_id in self.nuclei_type_dict.items(): 85 | inter, total = _fdetect_info(true_type, pred_type, type_id) 86 | self.over_type_dict['fdetect_inter_%s' % (type_name)] += inter 87 | self.over_type_dict['fdetect_total_%s' % (type_name)] += total 88 | 89 | def _after_inference(self): 90 | 91 | stat_dict = {} 92 | 93 | stat_dict[self.prefix + '_acc' ] = self.over_correct_np / self.nr_pixels 94 | stat_dict[self.prefix + '_dice'] = 2 * self.over_inter_np / (self.over_total_np + 1.0e-8) 95 | 96 | if self.type_classification: 97 | for type_name, type_id in self.nuclei_type_dict.items(): 98 | 
stat_dict['%s_fdetect_%s' % (self.prefix, type_name)] = (self.over_type_dict['fdetect_inter_%s' % (type_name)] + 1.0e-8) / (self.over_type_dict['fdetect_total_%s' % (type_name)] + 1.0e-8) 99 | 100 | return stat_dict 101 | 102 | 103 | #### 104 | 105 | ########################################### 106 | class Trainer(Config): 107 | #### 108 | def get_datagen(self, batch_size, mode='train', view=False): 109 | train_set = get_files(self.train_dir, self.data_ext) 110 | test_set = get_files(self.valid_dir, self.data_ext) 111 | if mode == 'train': 112 | augmentors = self.get_train_augmentors( 113 | self.train_input_shape, 114 | self.train_mask_shape, 115 | view) 116 | data_files = train_set 117 | data_generator = loader.train_generator 118 | nr_procs = self.nr_procs_train 119 | else: 120 | augmentors = self.get_valid_augmentors( 121 | self.infer_input_shape, 122 | self.infer_mask_shape, 123 | view) 124 | data_files = test_set 125 | data_generator = loader.valid_generator 126 | nr_procs = self.nr_procs_valid 127 | 128 | # set nr_proc=1 for viewing to ensure clean ctrl-z 129 | nr_procs = 1 if view else nr_procs 130 | dataset = loader.DatasetSerial(data_files) 131 | datagen = data_generator(dataset, 132 | shape_aug=augmentors[0], 133 | input_aug=augmentors[1], 134 | label_aug=augmentors[2], 135 | batch_size=batch_size, 136 | nr_procs=nr_procs) 137 | 138 | return datagen 139 | #### 140 | def view_dataset(self, mode='train'): 141 | assert mode == 'train' or mode == 'valid', "Invalid view mode" 142 | datagen = self.get_datagen(4, mode='train', view=True) 143 | loader.visualize(datagen, 4) 144 | return 145 | #### 146 | def run_once(self, opt, idx, sess_init=None, save_dir=None): 147 | #### 148 | train_datagen = self.get_datagen(opt['train_batch_size'], mode='train') 149 | valid_datagen = self.get_datagen(opt['infer_batch_size'], mode='valid') 150 | 151 | ###### must be called before ModelSaver 152 | if save_dir is None: 153 | logger.set_logger_dir(self.save_dir) 154 | else: 155 | logger.set_logger_dir(save_dir) 156 | 157 | ###### 158 | model_flags = opt['model_flags'] 159 | model = self.get_model(phase=idx)(**model_flags) 160 | ###### 161 | callbacks=[ 162 | ModelSaver(max_to_keep=opt['nr_epochs']), 163 | ] 164 | callbacks.append(RunOp(tf.tables_initializer(), run_as_trigger=False)) 165 | for param_name, param_info in opt['manual_parameters'].items(): 166 | model.add_manual_variable(param_name, param_info[0]) 167 | callbacks.append(ScheduledHyperParamSetter(param_name, param_info[1])) 168 | # multi-GPU inference (with mandatory queue prefetch) 169 | infs = [StatCollector()] 170 | callbacks.append(DataParallelInferenceRunner( 171 | valid_datagen, infs, list(range(nr_gpus)))) 172 | callbacks.append(MaxSaver('valid_dice')) 173 | 174 | ###### 175 | steps_per_epoch = train_datagen.size() // nr_gpus 176 | 177 | config = TrainConfig( 178 | model = model, 179 | callbacks = callbacks , 180 | dataflow = train_datagen , 181 | steps_per_epoch = steps_per_epoch, 182 | max_epoch = opt['nr_epochs'], 183 | ) 184 | config.session_init = sess_init 185 | 186 | launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpus)) 187 | tf.reset_default_graph() # remove the entire graph in case of multiple runs 188 | return 189 | #### 190 | def run(self): 191 | def get_last_chkpt_path(prev_phase_dir): 192 | stat_file_path = prev_phase_dir + 'stats.json' 193 | with open(stat_file_path) as stat_file: 194 | info = json.load(stat_file) 195 | chkpt_list = [epoch_stat['global_step'] for epoch_stat in info] 196 | 
last_chkpts_path = "%smodel-%d.index" % (prev_phase_dir, max(chkpt_list)) 197 | return last_chkpts_path 198 | 199 | phase_opts = self.training_phase 200 | if len(phase_opts) > 1: 201 | for idx, opt in enumerate(phase_opts): 202 | random.seed(self.seed) 203 | np.random.seed(self.seed) 204 | tf.random.set_random_seed(self.seed) 205 | log_dir = '%s/%02d/' % (self.save_dir, idx) 206 | pretrained_path = opt['pretrained_path'] 207 | if pretrained_path == -1: 208 | pretrained_path = get_last_chkpt_path(prev_log_dir) 209 | init_weights = SaverRestore(pretrained_path, ignore=['learning_rate']) 210 | elif pretrained_path is not None: 211 | init_weights = get_model_loader(pretrained_path) 212 | prev_log_dir = log_dir 213 | self.run_once(opt, idx, sess_init=init_weights, save_dir=log_dir) 214 | return 215 | #### 216 | #### 217 | 218 | ########################################################################### 219 | 220 | if __name__ == '__main__': 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument('--gpu', help="comma separated list of GPU(s) to use.") 223 | parser.add_argument('--view', help="view dataset, received either 'train' or 'valid' as input") 224 | args = parser.parse_args() 225 | 226 | trainer = Trainer() 227 | if args.view: 228 | trainer.view_dataset(mode=args.view) 229 | else: 230 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 231 | nr_gpus = len(args.gpu.split(',')) 232 | trainer.run() 233 | -------------------------------------------------------------------------------- /compute_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cProfile as profile 3 | import glob 4 | import os 5 | 6 | import cv2 7 | import numpy as np 8 | import scipy.io as sio 9 | import pandas as pd 10 | 11 | from metrics.stats_utils import * 12 | from config import Config 13 | 14 | cfg = Config() 15 | 16 | def run_nuclei_type_stat(pred_dir, true_dir, type_uid_list=None, exhaustive=True): 17 | """ 18 | GT must be exhaustively annotated for instance location (detection) 19 | 20 | Args: 21 | true_dir, pred_dir: Directory contains .mat annotation for each image. 22 | Each .mat must contain: 23 | 24 | --`inst_centroid`: Nx2, contains N instance centroid 25 | of mass coordinates (X, Y) 26 | --`inst_type` : Nx1: type of each instance at each index 27 | 28 | `inst_centroid` and `inst_type` must be aligned and each 29 | index must be associated to the same instance 30 | 31 | type_uid_list : list of id for nuclei type which the score should be calculated. 32 | Default to `None` means available nuclei type in GT. 
33 | 34 | exhaustive : Flag to indicate whether GT is exhaustively labelled 35 | for instance types 36 | """ 37 | ### 38 | file_list = glob.glob(pred_dir + '*.mat') 39 | file_list.sort() # ensure same order [1] 40 | paired_all = [] # unique matched index pair 41 | unpaired_true_all = [] # the index must exist in `true_inst_type_all` and unique 42 | unpaired_pred_all = [] # the index must exist in `pred_inst_type_all` and unique 43 | true_inst_type_all = [] # each index is 1 independent data point 44 | pred_inst_type_all = [] # each index is 1 independent data point 45 | for file_idx, filename in enumerate(file_list[:]): 46 | filename = os.path.basename(filename) 47 | basename = filename.split('.')[0] 48 | true_info = sio.loadmat(true_dir + basename + '.mat') 49 | # dont squeeze, may be 1 instance exist 50 | true_centroid = (true_info['inst_centroid']).astype('float32') 51 | true_inst_type = (true_info['inst_type']).astype('int32') 52 | 53 | if true_centroid.shape[0] != 0: 54 | true_inst_type = true_inst_type[:,0] 55 | else: # no instance at all 56 | true_centroid = np.array([[0, 0]]) 57 | true_inst_type = np.array([0]) 58 | 59 | # * for converting the GT type in CoNSeP and GLySAC 60 | if cfg.data_type == 'consep': 61 | true_inst_type[(true_inst_type == 3) | (true_inst_type == 4)] = 3 62 | true_inst_type[(true_inst_type == 5) | (true_inst_type == 6) | (true_inst_type == 7)] = 4 63 | if cfg.data_type == 'glysac': 64 | true_inst_type[(true_inst_type == 1) | (true_inst_type == 2) | (true_inst_type == 9) | (true_inst_type == 10)] = 1 65 | true_inst_type[(true_inst_type == 4) | (true_inst_type == 5) | (true_inst_type == 6) | (true_inst_type == 7)] = 2 66 | true_inst_type[(true_inst_type == 8) | (true_inst_type == 3)] = 3 67 | 68 | pred_info = sio.loadmat(pred_dir + basename + '.mat') 69 | # dont squeeze, may be 1 instance exist 70 | pred_centroid = (pred_info['inst_centroid']).astype('float32') 71 | pred_inst_type = (pred_info['inst_type']).astype('int32') 72 | 73 | if pred_centroid.shape[0] != 0: 74 | pred_inst_type = pred_inst_type[:,0] 75 | else: # no instance at all 76 | pred_centroid = np.array([[0, 0]]) 77 | pred_inst_type = np.array([0]) 78 | 79 | # ! if take longer than 1min for 1000 vs 1000 pairing, sthg is wrong with coord 80 | paired, unpaired_true, unpaired_pred = pair_coordinates(true_centroid, pred_centroid, 12) 81 | 82 | # * Aggreate information 83 | # get the offset as each index represent 1 independent instance 84 | true_idx_offset = true_idx_offset + true_inst_type_all[-1].shape[0] if file_idx != 0 else 0 85 | pred_idx_offset = pred_idx_offset + pred_inst_type_all[-1].shape[0] if file_idx != 0 else 0 86 | true_inst_type_all.append(true_inst_type) 87 | pred_inst_type_all.append(pred_inst_type) 88 | 89 | # increment the pairing index statistic 90 | if paired.shape[0] != 0: # ! 
sanity 91 | paired[:,0] += true_idx_offset 92 | paired[:,1] += pred_idx_offset 93 | paired_all.append(paired) 94 | 95 | unpaired_true += true_idx_offset 96 | unpaired_pred += pred_idx_offset 97 | unpaired_true_all.append(unpaired_true) 98 | unpaired_pred_all.append(unpaired_pred) 99 | 100 | paired_all = np.concatenate(paired_all, axis=0) 101 | unpaired_true_all = np.concatenate(unpaired_true_all, axis=0) 102 | unpaired_pred_all = np.concatenate(unpaired_pred_all, axis=0) 103 | true_inst_type_all = np.concatenate(true_inst_type_all, axis=0) 104 | pred_inst_type_all = np.concatenate(pred_inst_type_all, axis=0) 105 | 106 | paired_true_type = true_inst_type_all[paired_all[:,0]] 107 | paired_pred_type = pred_inst_type_all[paired_all[:,1]] 108 | unpaired_true_type = true_inst_type_all[unpaired_true_all] 109 | unpaired_pred_type = pred_inst_type_all[unpaired_pred_all] 110 | 111 | ### 112 | def _f1_type(paired_true, paired_pred, unpaired_true, unpaired_pred, type_id, w): 113 | type_samples = (paired_true == type_id) | (paired_pred == type_id) 114 | 115 | paired_true = paired_true[type_samples] 116 | paired_pred = paired_pred[type_samples] 117 | 118 | tp_dt = ((paired_true == type_id) & (paired_pred == type_id)).sum() 119 | tn_dt = ((paired_true != type_id) & (paired_pred != type_id)).sum() 120 | fp_dt = ((paired_true != type_id) & (paired_pred == type_id)).sum() 121 | fn_dt = ((paired_true == type_id) & (paired_pred != type_id)).sum() 122 | 123 | if not exhaustive: 124 | ignore = (paired_true == -1).sum() 125 | fp_dt -= ignore 126 | 127 | summary_1 = {} 128 | p_summ_1 = np.copy(paired_true) 129 | val, cnt = np.unique(p_summ_1[(paired_true != type_id) & (paired_pred == type_id)], return_counts=True) 130 | for name, num in zip(val, cnt): 131 | summary_1[name] = summary_1.get(name, 0) + num 132 | 133 | summary = {} 134 | p_summ = np.copy(paired_pred) 135 | val, cnt = np.unique(p_summ[(paired_true == type_id) & (paired_pred != type_id)], return_counts=True) 136 | for name, num in zip(val, cnt): 137 | summary[name] = summary.get(name, 0) + num 138 | 139 | fp_d = (unpaired_pred == type_id).sum() 140 | fn_d = (unpaired_true == type_id).sum() 141 | 142 | 143 | f1_type = (2 * (tp_dt + tn_dt)) / \ 144 | (2 * (tp_dt + tn_dt) + w[0] * fp_dt + w[1] * fn_dt \ 145 | + w[2] * fp_d + w[3] * fn_d) 146 | return f1_type 147 | 148 | # overall 149 | # * quite meaningless for not exhaustive annotated dataset 150 | w = [1, 1] 151 | tp_d = paired_pred_type.shape[0] 152 | fp_d = unpaired_pred_type.shape[0] 153 | fn_d = unpaired_true_type.shape[0] 154 | 155 | tp_tn_dt = (paired_pred_type == paired_true_type).sum() 156 | fp_fn_dt = (paired_pred_type != paired_true_type).sum() 157 | 158 | if not exhaustive: 159 | ignore = (paired_true_type == -1).sum() 160 | fp_fn_dt -= ignore 161 | 162 | acc_type = tp_tn_dt / (tp_tn_dt + fp_fn_dt) 163 | f1_d = 2 * tp_d / (2 * tp_d + w[0] * fp_d + w[1] * fn_d) 164 | 165 | w = [2, 2, 1, 1] 166 | 167 | if type_uid_list is None: 168 | type_uid_list = np.unique(true_inst_type_all).tolist() 169 | 170 | 171 | results_list = [f1_d, acc_type] 172 | for type_uid in type_uid_list: 173 | f1_type = _f1_type(paired_true_type, paired_pred_type, 174 | unpaired_pred_type, unpaired_true_type, type_uid, w) 175 | results_list.append(f1_type) 176 | 177 | np.set_printoptions(formatter={'float': '{: 0.5f}'.format}) 178 | print(np.array(results_list)) 179 | return 180 | 181 | def run_nuclei_inst_stat(pred_dir, true_dir, print_img_stats=False, ext='.mat'): 182 | # print stats of each image 183 | print(pred_dir) 184 
| 185 | file_list = glob.glob('%s/*%s' % (pred_dir, ext)) 186 | file_list.sort() # ensure same order 187 | metrics = [[], [], [], [], [], []] 188 | for filename in file_list[:]: 189 | filename = os.path.basename(filename) 190 | basename = filename.split('.')[0] 191 | 192 | true = sio.loadmat(true_dir + basename + '.mat') 193 | true = (true['inst_map']).astype('int32') 194 | 195 | pred = sio.loadmat(pred_dir + basename + '.mat') 196 | pred = (pred['inst_map']).astype('int32') 197 | 198 | # to ensure that the instance numbering is contiguous 199 | pred = remap_label(pred, by_size=False) 200 | true = remap_label(true, by_size=False) 201 | 202 | pq_info = get_fast_pq(true, pred, match_iou=0.5)[0] 203 | metrics[0].append(get_dice_1(true, pred)) 204 | metrics[1].append(pq_info[0]) # dq 205 | metrics[2].append(pq_info[1]) # sq 206 | metrics[3].append(pq_info[2]) # pq 207 | metrics[4].append(get_fast_aji_plus(true, pred)) 208 | metrics[5].append(get_fast_aji(true, pred)) 209 | 210 | if print_img_stats: 211 | print(basename, end="\t") 212 | for scores in metrics: 213 | print("%f " % scores[-1], end=" ") 214 | print() 215 | #### 216 | metrics = np.array(metrics) 217 | metrics_avg = np.mean(metrics, axis=-1) 218 | np.set_printoptions(formatter={'float': '{: 0.5f}'.format}) 219 | print(metrics_avg) 220 | metrics_avg = list(metrics_avg) 221 | return metrics 222 | 223 | if __name__ == '__main__': 224 | parser = argparse.ArgumentParser() 225 | parser.add_argument('--mode', help="mode to run the measurement," 226 | "`type` for nuclei instance type classification or" 227 | "`instance` for nuclei instance segmentation", 228 | nargs='?', default='instance', const='instance') 229 | parser.add_argument('--pred_dir', help="point to output dir", nargs='?', default='', const='') 230 | parser.add_argument('--true_dir', help="point to ground truth dir", nargs='?', default='', const='') 231 | args = parser.parse_args() 232 | 233 | if args.mode == 'instance': 234 | run_nuclei_inst_stat(args.pred_dir, args.true_dir, print_img_stats=False) 235 | if args.mode == 'type': 236 | run_nuclei_type_stat(args.pred_dir, args.true_dir) 237 | 238 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import numpy as np 4 | 5 | import tensorflow as tf 6 | 7 | from tensorpack import * 8 | from tensorpack.tfutils.symbolic_functions import * 9 | from tensorpack.tfutils.summary import * 10 | 11 | from matplotlib import cm 12 | from config import Config 13 | import scipy.io as sio 14 | import glob 15 | 16 | cfg = Config() 17 | #### 18 | def resize_op(x, height_factor=None, width_factor=None, size=None, 19 | interp='bicubic', data_format='channels_last'): 20 | """ 21 | Resize by a factor if `size=None` else resize to `size` 22 | """ 23 | original_shape = x.get_shape().as_list() 24 | if size is not None: 25 | if data_format == 'channels_first': 26 | x = tf.transpose(x, [0, 2, 3, 1]) 27 | if interp == 'bicubic': 28 | x = tf.image.resize_bicubic(x, size) 29 | elif interp == 'bilinear': 30 | x = tf.image.resize_bilinear(x, size) 31 | else: 32 | x = tf.image.resize_nearest_neighbor(x, size) 33 | x = tf.transpose(x, [0, 3, 1, 2]) 34 | x.set_shape((None, 35 | original_shape[1] if original_shape[1] is not None else None, 36 | size[0], size[1])) 37 | else: 38 | if interp == 'bicubic': 39 | x = tf.image.resize_bicubic(x, size) 40 | elif interp == 'bilinear': 41 | x = tf.image.resize_bilinear(x, size) 
42 | else: 43 | x = tf.image.resize_nearest_neighbor(x, size) 44 | x.set_shape((None, 45 | size[0], size[1], 46 | original_shape[3] if original_shape[3] is not None else None)) 47 | else: 48 | if data_format == 'channels_first': 49 | new_shape = tf.cast(tf.shape(x)[2:], tf.float32) 50 | new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('float32')) 51 | new_shape = tf.cast(new_shape, tf.int32) 52 | x = tf.transpose(x, [0, 2, 3, 1]) 53 | if interp == 'bicubic': 54 | x = tf.image.resize_bicubic(x, new_shape) 55 | elif interp == 'bilinear': 56 | x = tf.image.resize_bilinear(x, new_shape) 57 | else: 58 | x = tf.image.resize_nearest_neighbor(x, new_shape) 59 | x = tf.transpose(x, [0, 3, 1, 2]) 60 | x.set_shape((None, 61 | original_shape[1] if original_shape[1] is not None else None, 62 | int(original_shape[2] * height_factor) if original_shape[2] is not None else None, 63 | int(original_shape[3] * width_factor) if original_shape[3] is not None else None)) 64 | else: 65 | original_shape = x.get_shape().as_list() 66 | new_shape = tf.cast(tf.shape(x)[1:3], tf.float32) 67 | new_shape *= tf.constant(np.array([height_factor, width_factor]).astype('float32')) 68 | new_shape = tf.cast(new_shape, tf.int32) 69 | if interp == 'bicubic': 70 | x = tf.image.resize_bicubic(x, new_shape) 71 | elif interp == 'bilinear': 72 | x = tf.image.resize_bilinear(x, new_shape) 73 | else: 74 | x = tf.image.resize_nearest_neighbor(x, new_shape) 75 | x.set_shape((None, 76 | int(original_shape[1] * height_factor) if original_shape[1] is not None else None, 77 | int(original_shape[2] * width_factor) if original_shape[2] is not None else None, 78 | original_shape[3] if original_shape[3] is not None else None)) 79 | return x 80 | 81 | #### 82 | def crop_op(x, cropping, data_format='channels_first'): 83 | """ 84 | Center crop image 85 | Args: 86 | cropping is the substracted portion 87 | """ 88 | crop_t = cropping[0] // 2 89 | crop_b = cropping[0] - crop_t 90 | crop_l = cropping[1] // 2 91 | crop_r = cropping[1] - crop_l 92 | if data_format == 'channels_first': 93 | x = x[:,:,crop_t:-crop_b,crop_l:-crop_r] 94 | else: 95 | x = x[:,crop_t:-crop_b,crop_l:-crop_r] 96 | return x 97 | #### 98 | 99 | def label_smoothing(label, factor=0.2): 100 | """ 101 | Softening label (substract the hard label by factor value, prevent overconfident and improve calibration) 102 | Args: 103 | label: one-hot format 104 | Return: 105 | Soft Label 106 | """ 107 | label *= (1-factor) 108 | label += factor/(label.get_shape().as_list()[-1]) 109 | return label 110 | 111 | #### 112 | def categorical_crossentropy(output, target): 113 | """ 114 | categorical cross-entropy, accept probabilities not logit 115 | """ 116 | # scale preds so that the class probs of each sample sum to 1 117 | output /= tf.reduce_sum(output, 118 | reduction_indices=len(output.get_shape()) - 1, 119 | keepdims=True) 120 | # manual computation of crossentropy 121 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype) 122 | output = tf.clip_by_value(output, epsilon, 1. 
- epsilon) 123 | return - tf.reduce_sum(target * tf.log(output), 124 | reduction_indices=len(output.get_shape()) - 1) 125 | 126 | def categorical_crossentropy_modified(output, target): 127 | """ 128 | categorical cross-entropy, accept probabilities not logit 129 | """ 130 | w = np.array([0.8, 1.2]) # change according to the dataset; w[0], w[1] is the weight for background and foreground, respectively 131 | output /= tf.reduce_sum(output, 132 | reduction_indices=len(output.get_shape()) - 1, 133 | keepdims=True) 134 | # manual computation of crossentropy 135 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype) 136 | output = tf.clip_by_value(output, epsilon, 1. - epsilon) 137 | loss = tf.zeros_like(output[...,0]) 138 | for i in range(len(w)): 139 | loss += w[i] * target[...,i] * tf.log(output[...,i]) 140 | return -loss 141 | #### 142 | 143 | def check_weight_loss(train_dir): 144 | ''' 145 | Calculate the weights using for focal_loss_modified() 146 | Args: 147 | train_dir: training directory 148 | Return: 149 | w: numpy array 150 | ''' 151 | file_list = glob.glob(train_dir + '/*.mat') 152 | N = {} 153 | for file in file_list: 154 | type_map = sio.loadmat(file)['type_map'] 155 | if cfg.data_type == 'consep': 156 | type_map[(type_map == 3) | (type_map == 4)] = 3 157 | type_map[(type_map == 5) | (type_map == 6) | (type_map == 7)] = 4 158 | elif cfg.data_type == 'glysac': 159 | type_map[(type_map == 1) | (type_map == 2) | (type_map == 9) | (type_map == 10)] = 1 160 | type_map[(type_map == 4) | (type_map == 5) | (type_map == 6) | (type_map == 7)] = 2 161 | type_map[(type_map == 8) | (type_map == 3)] = 3 162 | val, cnt = np.unique(type_map, return_counts=True) 163 | for idx, type_id in enumerate(val): 164 | N[type_id] = N.get(type_id, 0) + cnt[idx] 165 | N = sorted(N.items()) 166 | N = [val for key, val in N] 167 | c = len(N) 168 | N = np.array(N) 169 | w = np.power(N[0]/N, 1/3) 170 | w = w/w.sum() * c 171 | return w 172 | 173 | 174 | #### 175 | def focal_loss_modified(output, target, gamma=1): 176 | w = np.array([0.3, 2.2, 0.96, 0.73, 0.81]) # Calculated from check_weight_loss(), need to modify according to the dataset 177 | output /= tf.reduce_sum(output, reduction_indices=len(output.get_shape()) - 1, 178 | keepdims=True) 179 | # manual computation of focal loss 180 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype) 181 | output = tf.clip_by_value(output, epsilon, 1. - epsilon) 182 | loss = tf.zeros_like(output[...,0]) 183 | for i in range(len(w)): 184 | loss += w[i] * target[...,i] * tf.pow((1 - output[...,i]), gamma) * tf.log(output[...,i]) 185 | return -loss 186 | 187 | 188 | #### 189 | def focal_loss(output, target, gamma=1): 190 | """ 191 | categorical focal loss, accept probabilities not logit 192 | Parameters: 193 | ----------- 194 | output: Predict type of each class [N, H, W, nr_classes] 195 | target: Ground Truth in one-hot encoding [N, H, W, nr_classes] 196 | """ 197 | # scale preds so that the class probs of each sample sum to 1 198 | output /= tf.reduce_sum(output, reduction_indices=len(output.get_shape()) - 1, 199 | keepdims=True) 200 | # manual computation of focal loss 201 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype) 202 | output = tf.clip_by_value(output, epsilon, 1. 
- epsilon) 203 | return -tf.reduce_sum(target * tf.pow((1- output), gamma) * tf.log(output), reduction_indices=len(output.get_shape()) - 1) 204 | 205 | 206 | 207 | #### 208 | def dice_loss(output, target, loss_type='sorensen', axis=None, smooth=1e-3): 209 | """Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity 210 | of two batch of data, usually be used for binary image segmentation 211 | i.e. labels are binary. The coefficient between 0 to 1, 1 means totally match. 212 | 213 | Parameters 214 | ----------- 215 | output : Tensor 216 | A distribution with shape: [batch_size, ....], (any dimensions). 217 | target : Tensor 218 | The target distribution, format the same with `output`. 219 | loss_type : str 220 | ``jaccard`` or ``sorensen``, default is ``jaccard``. 221 | axis : tuple of int 222 | All dimensions are reduced, default ``[1,2,3]``. 223 | smooth : float 224 | This small value will be added to the numerator and denominator. 225 | - If both output and target are empty, it makes sure dice is 1. 226 | - If either output or target are empty (all pixels are background), 227 | dice = ```smooth/(small_value + smooth)``, then if smooth is very small, 228 | dice close to 0 (even the image values lower than the threshold), 229 | so in this case, higher smooth can have a higher dice. 230 | 231 | Examples 232 | --------- 233 | >>> dice_loss = dice_coe(outputs, y_) 234 | """ 235 | target = tf.squeeze(tf.cast(target, tf.float32)) 236 | output = tf.squeeze(tf.cast(output, tf.float32)) 237 | 238 | inse = tf.reduce_sum(output * target, axis=axis) 239 | if loss_type == 'jaccard': 240 | l = tf.reduce_sum(output * output, axis=axis) 241 | r = tf.reduce_sum(target * target, axis=axis) 242 | elif loss_type == 'sorensen': 243 | l = tf.reduce_sum(output, axis=axis) 244 | r = tf.reduce_sum(target, axis=axis) 245 | else: 246 | raise Exception("Unknown loss_type") 247 | # already flatten 248 | dice = 1.0 - (2. * inse + smooth) / (l + r + smooth) 249 | ## 250 | return dice 251 | 252 | #### 253 | 254 | def colorize(value, vmin=None, vmax=None, cmap=None): 255 | """ 256 | Arguments: 257 | - value: input tensor, NHWC ('channels_last') 258 | - vmin: the minimum value of the range used for normalization. 259 | (Default: value minimum) 260 | - vmax: the maximum value of the range used for normalization. 261 | (Default: value maximum) 262 | - cmap: a valid cmap named for use with matplotlib's `get_cmap`. 263 | (Default: 'gray') 264 | Example usage: 265 | ``` 266 | output = tf.random_uniform(shape=[256, 256, 1]) 267 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis') 268 | tf.summary.image('output', output_color) 269 | ``` 270 | 271 | Returns a 3D tensor of shape [height, width, 3], uint8. 
272 |     """
273 | 
274 |     # normalize
275 |     if vmin is None:
276 |         vmin = tf.reduce_min(value, axis=[1,2])
277 |         vmin = tf.reshape(vmin, [-1, 1, 1])
278 |     if vmax is None:
279 |         vmax = tf.reduce_max(value, axis=[1,2])
280 |         vmax = tf.reshape(vmax, [-1, 1, 1])
281 |     value = (value - vmin) / (vmax - vmin)  # normalize to vmin..vmax
282 | 
283 |     # squeeze last dim if it exists
284 |     # NOTE: will throw an error if using get_shape()
285 |     # value = tf.squeeze(value)
286 | 
287 |     # quantize
288 |     value = tf.round(value * 255)
289 |     indices = tf.cast(value, np.int32)
290 | 
291 |     # gather
292 |     colormap = cm.get_cmap(cmap if cmap is not None else 'gray')
293 |     colors = colormap(np.arange(256))[:, :3]
294 |     colors = tf.constant(colors, dtype=tf.float32)
295 |     value = tf.gather(colors, indices)
296 |     value = tf.cast(value * 255, tf.uint8)
297 |     return value
298 | ####
299 | def make_image(x, cy, cx, scale_y, scale_x):
300 |     """
301 |     Take the 1st image from x and tile its channel representations
302 |     into a 2D image, with cx channels along the x-axis and
303 |     cy channels along the y-axis
304 |     """
305 |     # normalize x for better visuals
306 |     x = tf.transpose(x, (0, 2, 3, 1))  # NHWC
307 |     max_x = tf.reduce_max(x, axis=-1, keep_dims=True)
308 |     min_x = tf.reduce_min(x, axis=-1, keep_dims=True)
309 |     x = 255 * (x - min_x) / (max_x - min_x)
310 |     ###
311 |     x_shape = tf.shape(x)
312 |     channels = x_shape[-1]
313 |     iy, ix = x_shape[1], x_shape[2]
314 |     ###
315 |     x = tf.slice(x, (0, 0, 0, 0), (1, -1, -1, -1))
316 |     x = tf.reshape(x, (iy, ix, channels))
317 |     ix += 4
318 |     iy += 4
319 |     x = tf.image.resize_image_with_crop_or_pad(x, iy, ix)
320 |     x = tf.reshape(x, (iy, ix, cy, cx))
321 |     x = tf.transpose(x, (2, 0, 3, 1))  # cy, iy, cx, ix
322 |     x = tf.reshape(x, (1, cy * iy, cx * ix, 1))
323 |     x = resize_op(x, scale_y, scale_x)
324 |     return tf.cast(x, tf.uint8)
325 | ####
326 | 
327 | ####
328 | def gen_disc_labels(nrof_labels, min, max):  # exponentially spaced label values plus an index -> value lookup table
329 |     min += 1
330 |     max += 1
331 | 
332 |     labels = []
333 |     keys = []
334 |     for k in range(nrof_labels + 1):
335 |         disc_label = math.exp(math.log(min) + (math.log(max / min) * k) / nrof_labels)
336 |         labels.append(disc_label - 1)
337 |         keys.append(k)
338 | 
339 | 
340 |     return labels, tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(keys, labels, value_dtype=tf.float64), 0)
341 | 
342 | ####
343 | def replacenan(t):
344 |     return tf.where(tf.is_nan(t), tf.zeros_like(t), t)
345 | 
--------------------------------------------------------------------------------
/metrics/stats_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import linear_sum_assignment
3 | 
4 | 
5 | #####--------------------------Optimized for Speed
6 | def get_fast_aji(true, pred):
7 |     """
8 |     AJI version distributed by MoNuSeg; has no permutation problem but suffers from
9 |     over-penalisation similar to DICE2
10 | 
11 |     Fast computation requires instance IDs in contiguous ordering, i.e. [1, 2, 3, 4]
12 |     not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
13 |     effect on the result.
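
    Example (illustrative sketch; `true_inst_map` and `pred_inst_map` are
    placeholder instance maps, not names used elsewhere in this repo):

        true = remap_label(true_inst_map)
        pred = remap_label(pred_inst_map)
        aji = get_fast_aji(true, pred)  # scalar in [0, 1], higher is better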
14 |     """
15 |     true = np.copy(true)
16 |     pred = np.copy(pred)
17 |     true_id_list = list(np.unique(true))
18 |     pred_id_list = list(np.unique(pred))
19 | 
20 |     true_masks = [None,]
21 |     for t in true_id_list[1:]:
22 |         t_mask = np.array(true == t, np.uint8)
23 |         true_masks.append(t_mask)
24 | 
25 |     pred_masks = [None,]
26 |     for p in pred_id_list[1:]:
27 |         p_mask = np.array(pred == p, np.uint8)
28 |         pred_masks.append(p_mask)
29 | 
30 |     # prefill with value
31 |     pairwise_inter = np.zeros([len(true_id_list) -1,
32 |                                len(pred_id_list) -1], dtype=np.float64)
33 |     pairwise_union = np.zeros([len(true_id_list) -1,
34 |                                len(pred_id_list) -1], dtype=np.float64)
35 | 
36 |     # caching pairwise
37 |     for true_id in true_id_list[1:]:  # 0-th is background
38 |         t_mask = true_masks[true_id]
39 |         pred_true_overlap = pred[t_mask > 0]
40 |         pred_true_overlap_id = np.unique(pred_true_overlap)
41 |         pred_true_overlap_id = list(pred_true_overlap_id)
42 |         for pred_id in pred_true_overlap_id:
43 |             if pred_id == 0:  # ignore
44 |                 continue  # overlapping background
45 |             p_mask = pred_masks[pred_id]
46 |             total = (t_mask + p_mask).sum()
47 |             inter = (t_mask * p_mask).sum()
48 |             pairwise_inter[true_id-1, pred_id-1] = inter
49 |             pairwise_union[true_id-1, pred_id-1] = total - inter
50 |     #
51 |     pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
52 |     # pair of pred that gives the highest iou for each true; don't care
53 |     # about reusing a pred instance multiple times
54 |     paired_pred = np.argmax(pairwise_iou, axis=1)
55 |     pairwise_iou = np.max(pairwise_iou, axis=1)
56 |     # exclude those that have no intersection
57 |     paired_true = np.nonzero(pairwise_iou > 0.0)[0]
58 |     paired_pred = paired_pred[paired_true]
59 |     overall_inter = (pairwise_inter[paired_true, paired_pred]).sum()
60 |     overall_union = (pairwise_union[paired_true, paired_pred]).sum()
61 |     #
62 |     paired_true = (list(paired_true + 1))  # index to instance ID
63 |     paired_pred = (list(paired_pred + 1))
64 |     # add all unpaired GT and prediction into the union
65 |     unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true])
66 |     unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred])
67 |     for true_id in unpaired_true:
68 |         overall_union += true_masks[true_id].sum()
69 |     for pred_id in unpaired_pred:
70 |         overall_union += pred_masks[pred_id].sum()
71 |     aji_score = overall_inter / overall_union
72 |     return aji_score
73 | #####
74 | def get_fast_aji_plus(true, pred):
75 |     """
76 |     AJI+, an AJI version with maximal unique pairing to obtain the overall intersection.
77 |     Every prediction instance is paired with at most 1 GT instance (1 to 1 mapping), unlike AJI
78 |     where a prediction instance can be paired against many GT instances (1 to many).
79 |     Remaining unpaired GT and prediction instances are added to the overall union.
80 |     The 1 to 1 mapping prevents AJI's over-penalisation from happening.
81 | 
82 |     Fast computation requires instance IDs in contiguous ordering, i.e. [1, 2, 3, 4]
83 |     not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
84 |     effect on the result.
85 |     """
86 |     true = np.copy(true)  # ?
do we need this 87 | pred = np.copy(pred) 88 | true_id_list = list(np.unique(true)) 89 | pred_id_list = list(np.unique(pred)) 90 | 91 | true_masks = [None,] 92 | for t in true_id_list[1:]: 93 | t_mask = np.array(true == t, np.uint8) 94 | true_masks.append(t_mask) 95 | 96 | pred_masks = [None,] 97 | for p in pred_id_list[1:]: 98 | p_mask = np.array(pred == p, np.uint8) 99 | pred_masks.append(p_mask) 100 | 101 | # prefill with value 102 | pairwise_inter = np.zeros([len(true_id_list) -1, 103 | len(pred_id_list) -1], dtype=np.float64) 104 | pairwise_union = np.zeros([len(true_id_list) -1, 105 | len(pred_id_list) -1], dtype=np.float64) 106 | 107 | # caching pairwise 108 | for true_id in true_id_list[1:]: # 0-th is background 109 | t_mask = true_masks[true_id] 110 | pred_true_overlap = pred[t_mask > 0] 111 | pred_true_overlap_id = np.unique(pred_true_overlap) 112 | pred_true_overlap_id = list(pred_true_overlap_id) 113 | for pred_id in pred_true_overlap_id: 114 | if pred_id == 0: # ignore 115 | continue # overlaping background 116 | p_mask = pred_masks[pred_id] 117 | total = (t_mask + p_mask).sum() 118 | inter = (t_mask * p_mask).sum() 119 | pairwise_inter[true_id-1, pred_id-1] = inter 120 | pairwise_union[true_id-1, pred_id-1] = total - inter 121 | # 122 | pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6) 123 | #### Munkres pairing to find maximal unique pairing 124 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou) 125 | ### extract the paired cost and remove invalid pair 126 | paired_iou = pairwise_iou[paired_true, paired_pred] 127 | # now select all those paired with iou != 0.0 i.e have intersection 128 | paired_true = paired_true[paired_iou > 0.0] 129 | paired_pred = paired_pred[paired_iou > 0.0] 130 | paired_inter = pairwise_inter[paired_true, paired_pred] 131 | paired_union = pairwise_union[paired_true, paired_pred] 132 | paired_true = (list(paired_true + 1)) # index to instance ID 133 | paired_pred = (list(paired_pred + 1)) 134 | overall_inter = paired_inter.sum() 135 | overall_union = paired_union.sum() 136 | # add all unpaired GT and Prediction into the union 137 | unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true]) 138 | unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred]) 139 | for true_id in unpaired_true: 140 | overall_union += true_masks[true_id].sum() 141 | for pred_id in unpaired_pred: 142 | overall_union += pred_masks[pred_id].sum() 143 | # 144 | aji_score = overall_inter / overall_union 145 | return aji_score 146 | ##### 147 | def get_fast_pq(true, pred, match_iou=0.5): 148 | """ 149 | `match_iou` is the IoU threshold level to determine the pairing between 150 | GT instances `p` and prediction instances `g`. `p` and `g` is a pair 151 | if IoU > `match_iou`. However, pair of `p` and `g` must be unique 152 | (1 prediction instance to 1 GT instance mapping). 153 | 154 | If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching 155 | in bipartite graphs) is caculated to find the maximal amount of unique pairing. 156 | 157 | If `match_iou` >= 0.5, all IoU(p,g) > 0.5 pairing is proven to be unique and 158 | the number of pairs is also maximal. 159 | 160 | Fast computation requires instance IDs are in contiguous orderding 161 | i.e [1, 2, 3, 4] not [2, 3, 6, 10]. Please call `remap_label` beforehand 162 | and `by_size` flag has no effect on the result. 
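
    Worked example (illustrative numbers): with 8 matched pairs (tp = 8)
    whose IoUs sum to 6.0, 2 unmatched predictions (fp = 2) and 1 unmatched
    GT instance (fn = 1):

        dq = 8 / (8 + 0.5 * 2 + 0.5 * 1) ~ 0.842
        sq = 6.0 / 8 = 0.75
        pq = dq * sq ~ 0.632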
163 | 164 | Returns: 165 | [dq, sq, pq]: measurement statistic 166 | 167 | [paired_true, paired_pred, unpaired_true, unpaired_pred]: 168 | pairing information to perform measurement 169 | 170 | """ 171 | assert match_iou >= 0.0, "Cant' be negative" 172 | 173 | true = np.copy(true) 174 | pred = np.copy(pred) 175 | true_id_list = list(np.unique(true)) 176 | pred_id_list = list(np.unique(pred)) 177 | 178 | true_masks = [None,] 179 | for t in true_id_list[1:]: 180 | t_mask = np.array(true == t, np.uint8) 181 | true_masks.append(t_mask) 182 | 183 | pred_masks = [None,] 184 | for p in pred_id_list[1:]: 185 | p_mask = np.array(pred == p, np.uint8) 186 | pred_masks.append(p_mask) 187 | 188 | # prefill with value 189 | pairwise_iou = np.zeros([len(true_id_list) -1, 190 | len(pred_id_list) -1], dtype=np.float64) 191 | 192 | # caching pairwise iou 193 | for true_id in true_id_list[1:]: # 0-th is background 194 | t_mask = true_masks[true_id] 195 | pred_true_overlap = pred[t_mask > 0] 196 | pred_true_overlap_id = np.unique(pred_true_overlap) 197 | pred_true_overlap_id = list(pred_true_overlap_id) 198 | for pred_id in pred_true_overlap_id: 199 | if pred_id == 0: # ignore 200 | continue # overlaping background 201 | p_mask = pred_masks[pred_id] 202 | total = (t_mask + p_mask).sum() 203 | inter = (t_mask * p_mask).sum() 204 | iou = inter / (total - inter) 205 | pairwise_iou[true_id-1, pred_id-1] = iou 206 | # 207 | if match_iou >= 0.5: 208 | paired_iou = pairwise_iou[pairwise_iou > match_iou] 209 | pairwise_iou[pairwise_iou <= match_iou] = 0.0 210 | paired_true, paired_pred = np.nonzero(pairwise_iou) 211 | paired_iou = pairwise_iou[paired_true, paired_pred] 212 | paired_true += 1 # index is instance id - 1 213 | paired_pred += 1 # hence return back to original 214 | else: # * Exhaustive maximal unique pairing 215 | #### Munkres pairing with scipy library 216 | # the algorithm return (row indices, matched column indices) 217 | # if there is multiple same cost in a row, index of first occurence 218 | # is return, thus the unique pairing is ensure 219 | # inverse pair to get high IoU as minimum 220 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou) 221 | ### extract the paired cost and remove invalid pair 222 | paired_iou = pairwise_iou[paired_true, paired_pred] 223 | 224 | # now select those above threshold level 225 | # paired with iou = 0.0 i.e no intersection => FP or FN 226 | paired_true = list(paired_true[paired_iou > match_iou] + 1) 227 | paired_pred = list(paired_pred[paired_iou > match_iou] + 1) 228 | paired_iou = paired_iou[paired_iou > match_iou] 229 | 230 | # get the actual FP and FN 231 | unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true] 232 | unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred] 233 | # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred)) 234 | 235 | # 236 | tp = len(paired_true) 237 | fp = len(unpaired_pred) 238 | fn = len(unpaired_true) 239 | # get the F1-score i.e DQ 240 | dq = tp / (tp + 0.5 * fp + 0.5 * fn) 241 | # get the SQ, no paired has 0 iou so not impact 242 | sq = paired_iou.sum() / (tp + 1.0e-6) 243 | 244 | return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred] 245 | 246 | ##### 247 | def get_fast_dice_2(true, pred): 248 | """ 249 | Ensemble dice 250 | """ 251 | true = np.copy(true) 252 | pred = np.copy(pred) 253 | true_id = list(np.unique(true)) 254 | pred_id = list(np.unique(pred)) 255 | 256 | overall_total = 0 257 | overall_inter = 0 258 | 259 
| true_masks = [np.zeros(true.shape)] 260 | for t in true_id[1:]: 261 | t_mask = np.array(true == t, np.uint8) 262 | true_masks.append(t_mask) 263 | 264 | pred_masks = [np.zeros(true.shape)] 265 | for p in pred_id[1:]: 266 | p_mask = np.array(pred == p, np.uint8) 267 | pred_masks.append(p_mask) 268 | 269 | for true_idx in range(1, len(true_id)): 270 | t_mask = true_masks[true_idx] 271 | pred_true_overlap = pred[t_mask > 0] 272 | pred_true_overlap_id = np.unique(pred_true_overlap) 273 | pred_true_overlap_id = list(pred_true_overlap_id) 274 | try: # blinly remove background 275 | pred_true_overlap_id.remove(0) 276 | except ValueError: 277 | pass # just mean no background 278 | for pred_idx in pred_true_overlap_id: 279 | p_mask = pred_masks[pred_idx] 280 | total = (t_mask + p_mask).sum() 281 | inter = (t_mask * p_mask).sum() 282 | overall_total += total 283 | overall_inter += inter 284 | 285 | return 2 * overall_inter / overall_total 286 | ##### 287 | 288 | #####--------------------------As pseudocode 289 | def get_dice_1(true, pred): 290 | """ 291 | Traditional dice 292 | """ 293 | # cast to binary 1st 294 | true = np.copy(true) 295 | pred = np.copy(pred) 296 | true[true > 0] = 1 297 | pred[pred > 0] = 1 298 | inter = true * pred 299 | denom = true + pred 300 | return 2.0 * np.sum(inter) / np.sum(denom) 301 | #### 302 | def get_dice_2(true, pred): 303 | true = np.copy(true) 304 | pred = np.copy(pred) 305 | true_id = list(np.unique(true)) 306 | pred_id = list(np.unique(pred)) 307 | # remove background aka id 0 308 | true_id.remove(0) 309 | pred_id.remove(0) 310 | 311 | total_markup = 0 312 | total_intersect = 0 313 | for t in true_id: 314 | t_mask = np.array(true == t, np.uint8) 315 | for p in pred_id: 316 | p_mask = np.array(pred == p, np.uint8) 317 | intersect = p_mask * t_mask 318 | if intersect.sum() > 0: 319 | total_intersect += intersect.sum() 320 | total_markup += (t_mask.sum() + p_mask.sum()) 321 | return 2 * total_intersect / total_markup 322 | ##### 323 | def remap_label(pred, by_size=False): 324 | """ 325 | Rename all instance id so that the id is contiguous i.e [0, 1, 2, 3] 326 | not [0, 2, 4, 6]. 
The ordering of instances (which one comes first)
327 |     is preserved unless by_size=True, in which case the instances are reordered
328 |     so that bigger nuclei get smaller IDs
329 | 
330 |     Args:
331 |         pred    : the 2D array containing instances, where each instance is marked
332 |                   by a non-zero integer
333 |         by_size : rename so that larger nuclei get smaller ids (on top)
334 |     """
335 |     pred_id = list(np.unique(pred))
336 |     pred_id.remove(0)
337 |     if len(pred_id) == 0:
338 |         return pred  # no label
339 |     if by_size:
340 |         pred_size = []
341 |         for inst_id in pred_id:
342 |             size = (pred == inst_id).sum()
343 |             pred_size.append(size)
344 |         # sort the ids by size in descending order
345 |         pair_list = zip(pred_id, pred_size)
346 |         pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
347 |         pred_id, pred_size = zip(*pair_list)
348 | 
349 |     new_pred = np.zeros(pred.shape, np.int32)
350 |     for idx, inst_id in enumerate(pred_id):
351 |         new_pred[pred == inst_id] = idx + 1
352 |     return new_pred
353 | #####
354 | def pair_coordinates(setA, setB, radius):
355 |     """
356 |     Use the Munkres (Kuhn-Munkres) algorithm to find an optimal
357 |     unique pairing (largest possible match) when pairing points in set B
358 |     against points in set A, using distance as the cost function
359 | 
360 |     Args:
361 |         setA, setB: np.array (float32) of size Nx2 containing the XY coordinates
362 |                     of N different points
363 |         radius: valid area around a point in setA within which
364 |                 a coordinate in setB is considered a candidate for a match
365 |     Return:
366 |         pairing: an array of index pairs,
367 |                  where the point at index pairing[0] in set A is paired with
368 |                  the point in set B at index pairing[1]
369 |         unpairedA, unpairedB: remaining unpaired points in set A and set B
370 |     """
371 | 
372 |     # * Euclidean distance as the cost matrix
373 |     setA_tile = np.expand_dims(setA, axis=1)
374 |     setB_tile = np.expand_dims(setB, axis=0)
375 |     setA_tile = np.repeat(setA_tile, setB.shape[0], axis=1)
376 |     setB_tile = np.repeat(setB_tile, setA.shape[0], axis=0)
377 |     pair_distance = (setA_tile - setB_tile) ** 2
378 |     # set A is rows, and set B is paired against set A
379 |     pair_distance = np.sqrt(np.sum(pair_distance, axis=-1))
380 | 
381 |     # * Munkres pairing with the scipy library
382 |     # the algorithm returns (row indices, matched column indices);
383 |     # if there are multiple equal costs in a row, the index of the first
384 |     # occurrence is returned, thus the pairing is unique
385 |     indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
386 | 
387 |     # extract the paired cost and remove instances
388 |     # outside of the designated radius
389 |     pair_cost = pair_distance[indicesA, paired_indicesB]
390 | 
391 |     pairedA = indicesA[pair_cost <= radius]
392 |     pairedB = paired_indicesB[pair_cost <= radius]
393 | 
394 |     unpairedA = [idx for idx in range(setA.shape[0]) if idx not in list(pairedA)]
395 |     unpairedB = [idx for idx in range(setB.shape[0]) if idx not in list(pairedB)]
396 | 
397 |     pairing = np.array(list(zip(pairedA, pairedB)))
398 |     unpairedA = np.array(unpairedA, dtype=np.int64)
399 |     unpairedB = np.array(unpairedB, dtype=np.int64)
400 | 
401 |     return pairing, unpairedA, unpairedB
402 | 
--------------------------------------------------------------------------------
/model/sonnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import string
4 | import collections
5 | 
6 | import tensorflow as tf
7 | import cv2
8 | 
9 | from tensorpack import *
10 | from tensorpack.models import BatchNorm, BNReLU,
Conv2D, GlobalAvgPooling, Dropout 11 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary, add_activation_summary 12 | from tensorpack.tfutils.scope_utils import under_name_scope, auto_reuse_variable_scope 13 | from tensorpack.tfutils import optimizer 14 | from tensorpack.tfutils.gradproc import GlobalNormClip 15 | 16 | 17 | import sys 18 | sys.path.append("..") # adds higher directory to python modules path. 19 | 20 | from .utils import * 21 | 22 | 23 | try: # HACK: import beyond current level, may need to restructure 24 | from config import Config 25 | except ImportError: 26 | assert False, 'Fail to import config.py' 27 | 28 | #### 29 | def upsample2x(name, x): 30 | """ 31 | Nearest neighbor up-sampling 32 | """ 33 | return FixedUnPooling( 34 | name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'), 35 | data_format='channels_first') 36 | 37 | def res_blk(name, l, ch, ksize, count, split=1, strides=1): 38 | ch_in = l.get_shape().as_list() 39 | with tf.variable_scope(name): 40 | for i in range(0, count): 41 | with tf.variable_scope('block' + str(i)): 42 | x = l if i == 0 else BNReLU('preact', l) 43 | x = Conv2D('conv1', x, ch[0], ksize[0], activation=BNReLU) 44 | x = Conv2D('conv2', x, ch[1], ksize[1], split=split, 45 | strides=strides if i == 0 else 1, activation=BNReLU) 46 | x = Conv2D('conv3', x, ch[2], ksize[2], activation=tf.identity) 47 | if (strides != 1 or ch_in[1] != ch[2]) and i == 0: 48 | l = Conv2D('convshortcut', l, ch[2], 1, strides=strides) 49 | l = l + x 50 | # end of each group need an extra activation 51 | l = BNReLU('bnlast',l) 52 | return l 53 | 54 | def dense_blk(name, l, ch, ksize, count, split=1, padding='valid'): 55 | with tf.variable_scope(name): 56 | for i in range(0, count): 57 | with tf.variable_scope('blk/' + str(i)): 58 | x = BNReLU('preact_bna', l) 59 | x = Conv2D('conv1', x, ch[0], ksize[0], padding=padding, activation=BNReLU) 60 | x = Conv2D('conv2', x, ch[1], ksize[1], padding=padding, split=split) 61 | ## 62 | if padding == 'valid': 63 | x_shape = x.get_shape().as_list() 64 | l_shape = l.get_shape().as_list() 65 | l = crop_op(l, (l_shape[2] - x_shape[2], 66 | l_shape[3] - x_shape[3])) 67 | 68 | l = tf.concat([l, x], axis=1) 69 | l = BNReLU('blk_bna', l) 70 | return l 71 | 72 | #### 73 | @layer_register(log_shape=True) 74 | def resize_bilinear(i, size, align_corners=False): 75 | ret = tf.transpose(i, [0, 2, 3, 1]) 76 | ret = tf.image.resize_bilinear(ret, size=[size, size], align_corners=align_corners) 77 | ret = tf.transpose(ret, [0, 3, 1, 2]) 78 | return tf.identity(ret, name='output') 79 | 80 | #### 81 | def resize_nearest_neighbor(i, size): 82 | ret = tf.transpose(i, (0, 2, 3, 1)) 83 | ret = tf.image.resize_nearest_neighbor(ret, (size, size)) 84 | ret = tf.transpose(ret, (0, 3, 1, 2)) 85 | return ret 86 | 87 | #### 88 | @layer_register(log_shape=True) 89 | def DepthConv(x, out_channel, kernel_shape, padding='SAME', stride=1, 90 | W_init=None, activation=tf.identity): 91 | in_shape = x.get_shape().as_list() 92 | in_channel = in_shape[1] 93 | assert out_channel % in_channel == 0, (out_channel, in_channel) 94 | channel_mult = out_channel // in_channel 95 | 96 | if W_init is None: 97 | W_init = tf.variance_scaling_initializer(scale=2.0, mode='fan_out') 98 | kernel_shape = [kernel_shape, kernel_shape] 99 | filter_shape = kernel_shape + [in_channel, channel_mult] 100 | 101 | W = tf.get_variable('W', filter_shape, initializer=W_init) 102 | conv = tf.nn.depthwise_conv2d(x, W, [1, 1, stride, stride], padding=padding, 
data_format='NCHW') 103 | return activation(conv, name='output') 104 | 105 | 106 | BlockArgs = collections.namedtuple('BlockArgs', [ 107 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 108 | 'expand_ratio', 'id_skip', 'strides', 'se_ratio' 109 | ]) 110 | 111 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 112 | 113 | DEFAULT_BLOCKS_ARGS = [ 114 | BlockArgs(kernel_size=3, num_repeat=1, input_filters=32, output_filters=16, 115 | expand_ratio=1, id_skip=True, strides=[1, 1], se_ratio=0.25), 116 | BlockArgs(kernel_size=3, num_repeat=2, input_filters=16, output_filters=24, 117 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25), 118 | BlockArgs(kernel_size=5, num_repeat=2, input_filters=24, output_filters=40, 119 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25), 120 | BlockArgs(kernel_size=3, num_repeat=3, input_filters=40, output_filters=80, 121 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25), 122 | BlockArgs(kernel_size=5, num_repeat=3, input_filters=80, output_filters=112, 123 | expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25), 124 | BlockArgs(kernel_size=5, num_repeat=4, input_filters=112, output_filters=192, 125 | expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25), 126 | BlockArgs(kernel_size=3, num_repeat=1, input_filters=192, output_filters=320, 127 | expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25) 128 | ] 129 | 130 | 131 | def round_filters(filters, width_coefficient, depth_divisor): 132 | """Round number of filters based on width multiplier""" 133 | filters *= width_coefficient 134 | new_filters = int(filters + depth_divisor / 2) // depth_divisor * depth_divisor 135 | new_filters = max(depth_divisor, new_filters) 136 | # Make sure that round down does not go down by more than 10% 137 | if new_filters < 0.9 * filters: 138 | new_filters += depth_divisor 139 | return int(new_filters) 140 | 141 | def round_repeats(repeats, depth_coefficient): 142 | """Round number of repeats based on depth multiplier""" 143 | return int(math.ceil(depth_coefficient * repeats)) 144 | 145 | def mb_conv_block(inputs, block_args, activation, drop_rate=None, freeze_en=False, prefix='',): 146 | """Mobile Inverted Residual Bottleneck.""" 147 | has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1) 148 | bn_axis = 1 149 | # For naming convolutional, batchnorm layers to match pretrained weights 150 | num_conv = '' 151 | num_batch = 0 152 | # Expansion phase 153 | filters = block_args.input_filters * block_args.expand_ratio 154 | if block_args.expand_ratio != 1: 155 | x = Conv2D(prefix + '/conv2d', inputs, filters, 1, padding='same', use_bias=False) 156 | num_conv = '_1' 157 | x = BatchNorm(prefix + '/tpu_batch_normalization', x) 158 | num_batch += 1 159 | x = activation(x) 160 | else: 161 | x = inputs 162 | # Depthwise convolution 163 | x = DepthConv(prefix+'/depthwise_conv2d', x, x.get_shape().as_list()[1], block_args.kernel_size, stride=block_args.strides[0]) 164 | if num_batch != 0: 165 | x = BatchNorm(prefix + '/tpu_batch_normalization' + '_' + str(num_batch), x) 166 | else: 167 | x = BatchNorm(prefix + '/tpu_batch_normalization', x) 168 | num_batch += 1 169 | x = activation(x) 170 | 171 | # Squeeze and Excitation phase 172 | if has_se: 173 | num_reduced_filters = max(1, int(block_args.input_filters * block_args.se_ratio)) 174 | se_tensor = GlobalAvgPooling(prefix+'/se/squeeze', x, data_format='NCHW') 175 | target_shape = [-1, filters, 1, 1] 176 | se_tensor = 
tf.reshape(se_tensor, target_shape, name=prefix + '/se/reshape') 177 | 178 | se_tensor = Conv2D(prefix + '/se/conv2d', se_tensor, num_reduced_filters, 1, activation=activation, padding='same', use_bias=True) 179 | se_tensor = Conv2D(prefix + '/se/conv2d_1', se_tensor, filters, 1, activation=tf.sigmoid, padding='same', use_bias=True) 180 | x = tf.multiply(x, se_tensor, name=prefix + '/se/excite') 181 | # Output phase 182 | x = Conv2D(prefix + '/conv2d' + num_conv, x, block_args.output_filters, 1, padding='same', use_bias=False) 183 | x = tf.stop_gradient(x) if freeze_en else x 184 | x = BatchNorm(prefix + '/tpu_batch_normalization' + '_' + str(num_batch), x) 185 | if block_args.id_skip and all(s==1 for s in block_args.strides) and block_args.input_filters == block_args.output_filters: 186 | if drop_rate and (drop_rate > 0): 187 | x = Dropout(x, rate=drop_rate, noise_shape=(tf.shape(x)[0], 1, 1, 1), name=prefix + 'drop') 188 | x = tf.math.add(x, inputs, name=prefix + 'add') 189 | x = tf.stop_gradient(x) if freeze_en else x 190 | return x 191 | 192 | def EfficientNet(i, width_coefficient, depth_coefficient, default_resolution, dropout_rate=0.2, drop_connect_rate=0.2, 193 | depth_divisor=8, block_args=DEFAULT_BLOCKS_ARGS, freeze_en=False, model_name='efficientnet'): 194 | """Instantiates the EfficientNet architecture using given scaling coefficient. 195 | # Arguments: 196 | width_coefficient: float, scaling coefficient for network width 197 | depth_coefficient: float, scaling coefficient for network depth 198 | default_resolution: int, default resolution of input 199 | dropout_rate: float, dropout rate before final classifier layer. 200 | drop_connect_rate: float, dropout rate at skip connection 201 | depth_divisor: int. 202 | block_args: A list of BlockArgs to construct block modules. 203 | model_name: string, model name. 204 | # Returns: 205 | EfficientNet as a backbone encoder. 206 | """ 207 | 208 | bn_axis = 1 209 | activation = tf.nn.swish 210 | with tf.variable_scope(model_name): 211 | with tf.variable_scope('stem'): 212 | # Build stem 213 | x = Conv2D('conv2d', i, round_filters(32, width_coefficient, depth_divisor), 3, strides=(1, 1), padding='same', use_bias=False) 214 | x = BatchNorm('tpu_batch_normalization', x) 215 | x = activation(x) 216 | 217 | # Build blocks 218 | num_blocks_total = sum(block_args.num_repeat for block_args in block_args) 219 | block_num = 0 220 | ret = [] 221 | for idx, block_args in enumerate(block_args): 222 | assert block_args.num_repeat > 0 223 | # Update block input and output filters based on depth multiplier. 224 | block_args = block_args._replace( 225 | input_filters=round_filters(block_args.input_filters, width_coefficient, depth_divisor), 226 | output_filters=round_filters(block_args.output_filters, width_coefficient, depth_divisor), 227 | num_repeat=round_repeats(block_args.num_repeat, depth_coefficient) 228 | ) 229 | # The first block needs to take care of stride and filter size increase. 
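# A worked example of the drop-connect schedule computed on the next line
# (values assume B0's 16 blocks after depth scaling and the default
# drop_connect_rate = 0.2):
#     rate_k = drop_connect_rate * k / num_blocks_total
#     block 0 -> 0.000, block 8 -> 0.100, block 15 -> 0.1875
# i.e. shallow blocks are almost never dropped; the deepest are dropped most.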
230 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total 231 | x = mb_conv_block(x, block_args, activation=activation, drop_rate=drop_rate, freeze_en=freeze_en, prefix='blocks_{}'.format(block_num)) 232 | if block_num==0 or block_num==2 or block_num==4 or block_num==10: 233 | ret.append(x) 234 | x = tf.stop_gradient(x) if freeze_en else x 235 | block_num += 1 236 | if block_args.num_repeat > 1: 237 | block_args = block_args._replace(input_filters=block_args.output_filters, strides=[1, 1]) 238 | for bidx in range(block_args.num_repeat - 1): 239 | drop_rate = drop_connect_rate * float(block_num) / num_blocks_total 240 | block_prefix = 'blocks_{}'.format(block_num) 241 | x = mb_conv_block(x, block_args, activation=activation, drop_rate = drop_rate, freeze_en=freeze_en, prefix=block_prefix) 242 | if block_num==0 or block_num == 2 or block_num==4 or block_num==10 : 243 | ret.append(x) 244 | x = tf.stop_gradient(x) if freeze_en else x 245 | block_num += 1 246 | 247 | # Build head 248 | with tf.variable_scope('head'): 249 | x = Conv2D('conv2d', x, round_filters(1280, width_coefficient, depth_divisor), 1, padding='same', use_bias=False) 250 | x = tf.stop_gradient(x) if freeze_en else x 251 | x = Conv2D('conv_bot', x, 1024, 1) 252 | ret.append(x) 253 | return ret 254 | 255 | 256 | ### 257 | def EfficientNetB0(x, freeze_en): 258 | return EfficientNet(x, 1.0, 1.0, 224, 0.2, freeze_en=freeze_en, model_name='efficientnet-b0') 259 | 260 | ### 261 | def EfficientNetB1(x, freeze_en): 262 | return EfficientNet(x, 1.0, 1.1, 240, 0.2, freeze_en=freeze_en, model_name='efficientnet-b1') 263 | 264 | ### 265 | def EfficientNetB2(x, freeze_en): 266 | return EfficientNet(x, 1.1, 1.2, 260, 0.3, freeze_en=freeze_en, model_name='efficientnet-b2') 267 | 268 | ### 269 | def EfficientNetB3(x, freeze_en): 270 | return EfficientNet(x, 1.2, 1.4, 300, 0.3, freeze_en=freeze_en, model_name='efficientnet-b3') 271 | ### 272 | def EfficientNetB4(x, freeze_en): 273 | return EfficientNet(x, 1.4, 1.8, 380, 0.4, freeze_en=freeze_en, model_name='efficientnet-b4') 274 | 275 | #### 276 | def decoder(name, i): 277 | pad = 'valid' 278 | with tf.variable_scope(name): 279 | with tf.variable_scope('S5'): 280 | u5 = Conv2D('bottleneck', i[-1], 256, 1, strides=1, activation=BNReLU) 281 | u5_x2 = upsample2x('deux', u5) 282 | with tf.variable_scope('S4'): 283 | u4 = Conv2D('bottleneck', i[-2], 256, 1, strides=1, activation=BNReLU) 284 | u4_add = tf.add_n([u4, u5_x2]) 285 | u4 = Conv2D('conva', u4_add, 256, 5, strides=1, padding=pad) 286 | with tf.variable_scope('_add_1'): 287 | u4_10 = BNReLU('preact', u4) 288 | u4_11 = Conv2D('_1', u4_10, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 289 | u4_12 = Conv2D('_2', u4_11, 32, 3, padding=pad, dilation_rate=(1, 1)) 290 | u4_12 = crop_op(u4_12, (6, 6)) 291 | with tf.variable_scope('_add_2'): 292 | u4_20 = BNReLU('preact', u4) 293 | u4_21 = Conv2D('_1', u4_20, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 294 | u4_22 = Conv2D('_2', u4_21, 32, 5, padding=pad, dilation_rate=(1, 1)) 295 | u4_22 = crop_op(u4_22, (4, 4)) 296 | with tf.variable_scope('_add_3'): 297 | u4_30 = BNReLU('preact', u4) 298 | u4_31 = Conv2D('_1', u4_30, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 299 | u4_32 = Conv2D('_2', u4_31, 32, 3, padding=pad, dilation_rate=(2, 2)) 300 | u4_32 = crop_op(u4_32, (4, 4)) 301 | with tf.variable_scope('_add_4'): 302 | u4_40 = BNReLU('preact', u4) 303 | u4_41 = Conv2D('_1', u4_40, 128, 1, padding=pad, dilation_rate=(1, 1), 
activation=BNReLU) 304 | u4_42 = Conv2D('_2', u4_41, 32, 5, padding=pad, dilation_rate=(2, 2), activation=BNReLU) 305 | u4 = crop_op(u4, (8, 8)) 306 | u4_cat = tf.concat([u4, u4_12, u4_22, u4_32, u4_42], axis=1, name='_cat') 307 | u4_cat = BNReLU('outcat_BNReLU', u4_cat) 308 | u4_cat = Conv2D('_outcat', u4_cat, 256, 3, padding=pad, strides=1, activation=BNReLU) 309 | u4_x2 = upsample2x('deux', u4_cat) 310 | 311 | with tf.variable_scope('S3'): 312 | u3 = Conv2D('bottleneck', i[-3], 256, 1, strides=1, activation=BNReLU) 313 | u3 = crop_op(u3, (28, 28)) 314 | u3_add = tf.add_n([u3, u4_x2]) 315 | u3 = Conv2D('conva', u3_add, 256, 5, strides=1, padding=pad) 316 | with tf.variable_scope('_add_1'): 317 | u3_10 = BNReLU('preact', u3) 318 | u3_11 = Conv2D('_1', u3_10, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 319 | u3_12 = Conv2D('_2', u3_11, 32, 3, padding=pad, dilation_rate=(1, 1)) 320 | u3_12 = crop_op(u3_12, (6, 6)) 321 | with tf.variable_scope('_add_2'): 322 | u3_20 = BNReLU('preact', u3) 323 | u3_21 = Conv2D('_1', u3_20, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 324 | u3_22 = Conv2D('_2', u3_21, 32, 5, padding=pad, dilation_rate=(1, 1)) 325 | u3_22 = crop_op(u3_22, (4, 4)) 326 | with tf.variable_scope('_add_3'): 327 | u3_30 = BNReLU('preact', u3) 328 | u3_31 = Conv2D('_1', u3_30, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 329 | u3_32 = Conv2D('_2', u3_31, 32, 3, padding=pad, dilation_rate=(2, 2)) 330 | u3_32 = crop_op(u3_32, (4, 4)) 331 | with tf.variable_scope('_add_4'): 332 | u3_40 = BNReLU('preact', u3) 333 | u3_41 = Conv2D('_1', u3_40, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 334 | u3_42 = Conv2D('_2', u3_41, 32, 5, padding=pad, dilation_rate=(2, 2)) 335 | 336 | u3 = crop_op(u3, (8, 8)) 337 | u3_cat = tf.concat([u3, u3_12, u3_22, u3_32, u3_42], axis=1, name='_cat') 338 | u3_cat = BNReLU('outcat_BNReLU', u3_cat) 339 | u3_cat = Conv2D('_outcat', u3_cat, 256, 3, padding=pad, strides=1, activation=BNReLU) 340 | u3_x2 = upsample2x('deux', u3_cat) 341 | 342 | with tf.variable_scope('S2'): 343 | u2 = Conv2D('bottleneck', i[-4], 256, 1, strides=1, activation=BNReLU) 344 | u2 = crop_op(u2, (83, 83)) 345 | u2_add = tf.add_n([u2, u3_x2]) 346 | u2 = Conv2D('conva', u2_add, 256, 5, strides=1, padding=pad) 347 | with tf.variable_scope('_add_1'): 348 | u2_10 = BNReLU('preact', u2) 349 | u2_11 = Conv2D('_1', u2_10, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 350 | u2_12 = Conv2D('_2', u2_11, 32, 3, padding=pad, dilation_rate=(1, 1)) 351 | u2_12 = crop_op(u2_12, (6, 6)) 352 | with tf.variable_scope('_add_2'): 353 | u2_20 = BNReLU('preact', u2) 354 | u2_21 = Conv2D('_1', u2_20, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 355 | u2_22 = Conv2D('_2', u2_21, 32, 5, padding=pad, dilation_rate=(1, 1)) 356 | u2_22 = crop_op(u2_22, (4, 4)) 357 | with tf.variable_scope('_add_3'): 358 | u2_30 = BNReLU('preact', u2) 359 | u2_31 = Conv2D('_1', u2_30, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 360 | u2_32 = Conv2D('_2', u2_31, 32, 3, padding=pad, dilation_rate=(2, 2)) 361 | u2_32 = crop_op(u2_32, (4, 4)) 362 | with tf.variable_scope('_add_4'): 363 | u2_40 = BNReLU('preact', u2) 364 | u2_41 = Conv2D('_1', u2_40, 128, 1, padding=pad, dilation_rate=(1, 1), activation=BNReLU) 365 | u2_42 = Conv2D('_2', u2_41, 32, 5, padding=pad, dilation_rate=(2, 2)) 366 | 367 | u2 = crop_op(u2, (8, 8)) 368 | u2_cat = tf.concat([u2, u2_12, u2_22, u2_32, u2_42], axis=1, name='_cat') 369 | u2_cat = 
BNReLU('outcat_BNReLU', u2_cat) 370 | u2_cat = Conv2D('_outcat', u2_cat, 256, 3, padding='same', strides=1, activation=BNReLU) 371 | 372 | 373 | with tf.variable_scope('Pyramid'): 374 | p1 = Conv2D('p1', u2_cat, 128, 5, padding='same', activation=BNReLU) 375 | p1 = Conv2D('p1_1', p1, 128, 5, padding='same', activation=BNReLU) 376 | p2 = Conv2D('p2', u3_cat, 128, 5, padding=pad, activation=BNReLU) 377 | p2 = Conv2D('p2_1', p2, 128, 3, padding=pad, activation=BNReLU) 378 | p2_x2 = upsample2x('deux_p2', p2) 379 | p3 = crop_op(u4_cat, (2, 2)) 380 | p3 = Conv2D('p3', p3, 128, 5, padding=pad, activation=BNReLU) 381 | p3 = Conv2D('p3_1', p3, 128, 5, padding=pad, activation=BNReLU) 382 | p3_x2 = upsample2x('deux_p3', p3) 383 | p3_x4 = upsample2x('quatre_p3', p3_x2) 384 | p4 = crop_op(u5, (4, 4)) 385 | p4 = Conv2D('p4', p4, 128, 5, padding=pad, activation=BNReLU) 386 | p4 = Conv2D('p4_1', p4, 128, 5, padding=pad, activation=BNReLU) 387 | p4_x2 = upsample2x('deux_p4', p4) 388 | p4_x4 = upsample2x('quatre_p4', p4_x2) 389 | p4_x8 = upsample2x('huit_p4', p4_x4) 390 | p_cat = tf.concat([p1, p2_x2, p3_x4, p4_x8], axis=1) 391 | p_cat = Conv2D('_outcat_1', p_cat, 256, 5, padding='same', activation=BNReLU) 392 | p_cat = Conv2D('_outcat_2', p_cat, 256, 5, padding='same', activation=BNReLU) 393 | p_cat_x2 = upsample2x('deux_cat', p_cat) 394 | 395 | i_5 = crop_op(i[0], (190, 190)) 396 | i_5 = Conv2D('i_5', i_5, 256, 1, activation=BNReLU) 397 | p_cat = tf.add_n([i_5, p_cat_x2]) 398 | p_0 = Conv2D('p_0', p_cat, 128, 5, strides=1, padding='valid') 399 | 400 | return p_0 401 | 402 | 403 | class Model(ModelDesc, Config): 404 | def __init__(self, freeze_en=False): 405 | super(Model, self).__init__() 406 | assert tf.test.is_gpu_available() 407 | self.freeze_en = freeze_en 408 | self.data_format = 'NCHW' 409 | 410 | 411 | def _get_inputs(self): 412 | return [InputDesc(tf.float32, [None] + self.train_input_shape + [3], 'images'), 413 | InputDesc(tf.float32, [None] + self.train_mask_shape + [None], 'truemap-coded')] 414 | 415 | # for node to receive manual info such as learning rate. 
416 | def add_manual_variable(self, name, init_value, summary=True): 417 | var = tf.get_variable(name, initializer=init_value, trainable=False) 418 | if summary: 419 | tf.summary.scalar(name + '-summary', var) 420 | return 421 | 422 | def optimizer(self): 423 | with tf.variable_scope("", reuse=True): 424 | lr = tf.get_variable('learning_rate') 425 | opt = self.train_optimizer(learning_rate=lr) 426 | return optimizer.apply_grad_processors(opt, [GlobalNormClip(1.)]) 427 | 428 | 429 | class Sonnet(Model): 430 | 431 | def build_graph(self, inputs, truemap_coded): 432 | images = inputs 433 | orig_imgs = images 434 | 435 | 436 | 437 | if hasattr(self, 'type_classification') and self.type_classification: 438 | true_type = truemap_coded[..., 1] 439 | true_type = tf.cast(true_type, tf.int32) 440 | true_type = tf.identity(true_type, name='truemap-type') 441 | one_type = tf.one_hot(true_type, self.nr_types, axis=-1) 442 | true_type = tf.expand_dims(true_type, axis=-1) 443 | 444 | true_np = tf.cast(true_type > 0, tf.int32) 445 | true_np = tf.identity(true_np, name='truemap-np') 446 | one_np = tf.one_hot(tf.squeeze(true_np, axis=-1), 2, axis=-1) 447 | 448 | 449 | else: 450 | true_np = truemap_coded[..., 0] 451 | true_np = tf.cast(true_np, tf.int32) 452 | one_np = tf.one_hot(true_np, 2, axis=-1) 453 | true_np = tf.expand_dims(true_np, axis=-1) 454 | true_np = tf.identity(true_np, name='truemap-np') 455 | 456 | 457 | true_ord = truemap_coded[...,-1] 458 | true_ord = tf.expand_dims(true_ord, axis=-1) 459 | true_ord = tf.identity(true_ord, name='true-ord') 460 | 461 | #### 462 | with argscope(Conv2D, activation=tf.identity, use_bias=False, 463 | W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')),\ 464 | argscope([Conv2D, BatchNorm], data_format=self.data_format): 465 | i = tf.transpose(images, [0, 3, 1, 2]) 466 | i = i if not self.input_norm else i / 255.0 467 | 468 | #### 469 | d = EfficientNetB0(i, self.freeze_en) 470 | 471 | np_feat = decoder('np', d) 472 | np_feat = tf.identity(np_feat, name='np_feat') 473 | npx = BNReLU('preact_out_np', np_feat) 474 | 475 | 476 | ordi_feat = decoder('ordi', d) 477 | ordi = BNReLU('preact_out_ordi', ordi_feat) 478 | 479 | if self.type_classification: 480 | tp_feat = decoder('tp', d) 481 | tp = BNReLU('preact_out_tp', tp_feat) 482 | 483 | # Nuclei Type Pixels (NT) 484 | logi_class = Conv2D('conv_out_tp', tp, self.nr_types, 1, use_bias=True, activation=tf.identity) 485 | logi_class = tf.transpose(logi_class, [0, 2, 3, 1]) 486 | soft_class = tf.nn.softmax(logi_class, axis=-1) 487 | 488 | ### Nuclei Pixels (NF) 489 | logi_np = Conv2D('conv_out_np', npx, 2, 1, use_bias=True, activation=tf.identity) 490 | logi_np = tf.transpose(logi_np, [0, 2, 3, 1]) 491 | soft_np = tf.nn.softmax(logi_np, axis=-1) 492 | prob_np = tf.identity(soft_np[...,1], name='predmap-prob-np') 493 | prob_np = tf.expand_dims(prob_np, axis=-1) 494 | 495 | 496 | ### Ordinal (NO) 497 | logi_ord = Conv2D('conv_out_ord', ordi, 16, 1, use_bias=True, activation=tf.identity) 498 | logi_ord_t = tf.transpose(logi_ord, [0, 2, 3, 1]) 499 | N, C ,H, W = logi_ord.get_shape().as_list() 500 | ord_num = int(C/2) 501 | logi_ord = tf.reshape(logi_ord_t, shape=[-1, H, W, ord_num, 2]) 502 | prob_ord = tf.nn.softmax(logi_ord, axis=-1) 503 | prob_ord = tf.identity(prob_ord, name='prob_ord') 504 | nn_out_labels = tf.reduce_sum(tf.argmax(prob_ord, axis=-1, output_type=tf.int32), axis=3, keepdims=True) 505 | pred_ord = tf.identity(nn_out_labels, name='predmap-ord') 506 | 507 | ### encoded so that inference can 
extract all output at once 508 | predmap_coded = tf.concat([soft_class, prob_np], axis=-1, name='predmap-coded') 509 | 510 | 511 | 512 | 513 | 514 | def loss_ord(prob_ord, true_ord, name=None): 515 | (ord_n_pk, ord_pk) = tf.unstack(prob_ord, axis=-1) 516 | epsilon = tf.convert_to_tensor(10e-8, ord_n_pk.dtype.base_dtype) 517 | ord_n_pk, ord_pk = tf.clip_by_value(ord_n_pk, epsilon, 1 - epsilon), tf.clip_by_value(ord_pk, epsilon, 1 - epsilon) 518 | ord_log_n_pk = tf.log(ord_n_pk) 519 | ord_log_pk = tf.log(ord_pk) 520 | (N, H, W, C) = ord_log_pk.get_shape().as_list() 521 | foreground_mask = tf.reshape(tf.sequence_mask(true_ord, C), shape=[-1, H, W, C]) 522 | sum_of_p = tf.reduce_sum(tf.where(foreground_mask, ord_log_pk, ord_log_n_pk), axis=3) 523 | loss = -tf.reduce_mean(sum_of_p) 524 | loss = tf.identity(loss, name=name) 525 | return loss 526 | 527 | 528 | 529 | #### 530 | if get_current_tower_context().is_training: 531 | #---- LOSS ----# 532 | loss = 0 533 | 534 | for term, weight in self.loss_term.items(): 535 | if term == 'bce': 536 | term_loss = categorical_crossentropy_modified(soft_np, one_np) 537 | term_loss = tf.reduce_mean(term_loss, name='loss-bce') 538 | elif term == 'dice': 539 | term_loss = dice_loss(soft_np[...,0], one_np[...,0]) \ 540 | + dice_loss(soft_np[...,1], one_np[...,1]) 541 | term_loss = tf.identity(term_loss, name='loss-dice') 542 | elif term == 'ord': 543 | term_loss = loss_ord(prob_ord, true_ord, name='loss-ord') 544 | else: 545 | assert False, 'Unsupported loss term: %s' % term 546 | add_moving_summary(term_loss) 547 | loss += term_loss * weight 548 | 549 | if self.type_classification: 550 | term_loss = focal_loss_modified(soft_class, one_type) 551 | term_loss = tf.reduce_mean(term_loss, name='loss-classification') 552 | add_moving_summary(term_loss) 553 | loss = loss + term_loss 554 | 555 | 556 | 557 | 558 | self.cost = tf.identity(loss, name='overall-loss') 559 | add_moving_summary(self.cost) 560 | #### 561 | 562 | add_param_summary(('.*/W', ['histogram'])) # monitor W 563 | 564 | return 565 | 566 | 567 | class Sonnet_phase2(Model): 568 | 569 | def build_graph(self, inputs, truemap_coded): 570 | images = inputs 571 | orig_imgs = images 572 | 573 | if hasattr(self, 'type_classification') and self.type_classification: 574 | true_type = truemap_coded[..., 1] 575 | true_type = tf.cast(true_type, tf.int32) 576 | true_type = tf.identity(true_type, name='truemap-type') 577 | one_type = tf.one_hot(true_type, self.nr_types, axis=-1) 578 | true_type = tf.expand_dims(true_type, axis=-1) 579 | 580 | true_np = tf.cast(true_type > 0, tf.int32) 581 | true_np = tf.identity(true_np, name='truemap-np') 582 | one_np = tf.one_hot(tf.squeeze(true_np, axis=-1), 2, axis=-1) 583 | 584 | 585 | else: 586 | true_np = truemap_coded[..., 0] 587 | true_np = tf.cast(true_np, tf.int32) 588 | one_np = tf.one_hot(true_np, 2, axis=-1) 589 | true_np = tf.expand_dims(true_np, axis=-1) 590 | true_np = tf.identity(true_np, name='truemap-np') 591 | 592 | 593 | true_ord = truemap_coded[...,-1] 594 | true_ord = tf.expand_dims(true_ord, axis=-1) 595 | true_ord = tf.identity(true_ord, name='true-ord') 596 | 597 | #### 598 | with argscope(Conv2D, activation=tf.identity, use_bias=False, 599 | W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')),\ 600 | argscope([Conv2D, BatchNorm], data_format=self.data_format): 601 | i = tf.transpose(images, [0, 3, 1, 2]) 602 | i = i if not self.input_norm else i / 255.0 603 | 604 | #### 605 | d = EfficientNetB0(i, self.freeze_en) 606 | 607 | np_feat = 
decoder('np', d) 608 | np_feat = tf.identity(np_feat, name='np_feat') 609 | npx = BNReLU('preact_out_np', np_feat) 610 | 611 | 612 | ordi_feat = decoder('ordi', d) 613 | ordi = BNReLU('preact_out_ordi', ordi_feat) 614 | 615 | if self.type_classification: 616 | tp_feat = decoder('tp', d) 617 | tp = BNReLU('preact_out_tp', tp_feat) 618 | 619 | # Nuclei Type Pixels (NT) 620 | logi_class = Conv2D('conv_out_tp', tp, self.nr_types, 1, use_bias=True, activation=tf.identity) 621 | logi_class = tf.transpose(logi_class, [0, 2, 3, 1]) 622 | soft_class = tf.nn.softmax(logi_class, axis=-1) 623 | 624 | ### Nuclei Pixels (NF) 625 | logi_np = Conv2D('conv_out_np', npx, 2, 1, use_bias=True, activation=tf.identity) 626 | logi_np = tf.transpose(logi_np, [0, 2, 3, 1]) 627 | soft_np = tf.nn.softmax(logi_np, axis=-1) 628 | prob_np = tf.identity(soft_np[...,1], name='predmap-prob-np') 629 | prob_np = tf.expand_dims(prob_np, axis=-1) 630 | 631 | 632 | ### Ordinal (NO) 633 | logi_ord = Conv2D('conv_out_ord', ordi, 16, 1, use_bias=True, activation=tf.identity) 634 | logi_ord_t = tf.transpose(logi_ord, [0, 2, 3, 1]) 635 | N, C ,H, W = logi_ord.get_shape().as_list() 636 | ord_num = int(C/2) 637 | logi_ord = tf.reshape(logi_ord_t, shape=[-1, H, W, ord_num, 2]) 638 | prob_ord = tf.nn.softmax(logi_ord, axis=-1) 639 | prob_ord = tf.identity(prob_ord, name='prob_ord') 640 | nn_out_labels = tf.reduce_sum(tf.argmax(prob_ord, axis=-1, output_type=tf.int32), axis=3, keepdims=True) # prediction output (sum to take the exact class, e.g.class 4) 641 | pred_ord = tf.identity(nn_out_labels, name='predmap-ord') # [N, 76, 76, 1] 642 | 643 | ### encoded so that inference can extract all output at once 644 | predmap_coded = tf.concat([soft_class, prob_np], axis=-1, name='predmap-coded') 645 | 646 | 647 | 648 | 649 | 650 | def loss_ord(prob_ord, true_ord, name=None): 651 | weight_map = tf.squeeze(true_ord, axis=-1) 652 | weight_map_1 = tf.where( 653 | tf.equal(weight_map, 1), 654 | 2 * tf.cast(tf.equal(weight_map, 1), tf.float32), 655 | tf.cast(tf.equal(weight_map, 2), tf.float32) 656 | ) 657 | (ord_n_pk, ord_pk) = tf.unstack(prob_ord, axis=-1) 658 | epsilon = tf.convert_to_tensor(10e-8, ord_n_pk.dtype.base_dtype) 659 | ord_n_pk, ord_pk = tf.clip_by_value(ord_n_pk, epsilon, 1 - epsilon), tf.clip_by_value(ord_pk, epsilon, 1 - epsilon) 660 | ord_log_n_pk = tf.log(ord_n_pk) 661 | ord_log_pk = tf.log(ord_pk) 662 | (N, H, W, C) = ord_log_pk.get_shape().as_list() 663 | foreground_mask = tf.reshape(tf.sequence_mask(true_ord, C), shape=[-1, H, W, C]) 664 | sum_of_p = tf.reduce_sum(tf.where(foreground_mask, ord_log_pk, ord_log_n_pk), axis=3) 665 | sum_of_p += sum_of_p * weight_map_1 666 | loss = -tf.reduce_mean(sum_of_p) 667 | loss = tf.identity(loss, name=name) 668 | return loss 669 | 670 | 671 | 672 | #### 673 | if get_current_tower_context().is_training: 674 | weight_map = tf.cast(tf.squeeze(pred_ord, axis=-1), tf.float32) # [N, 76, 76] 675 | weight_map_1 = tf.where( 676 | tf.equal(weight_map, 1), 677 | 2 * tf.cast(tf.equal(weight_map, 1), tf.float32), 678 | tf.cast(tf.equal(weight_map, 2), tf.float32) 679 | ) 680 | #---- LOSS ----# 681 | loss = 0 682 | 683 | for term, weight in self.loss_term.items(): 684 | if term == 'bce': 685 | term_loss = categorical_crossentropy_modified(soft_np, one_np) 686 | term_loss += weight_map_1 * term_loss 687 | term_loss = tf.reduce_mean(term_loss, name='loss-bce') 688 | elif term == 'dice': 689 | term_loss = dice_loss(soft_np[...,0], one_np[...,0]) \ 690 | + dice_loss(soft_np[...,1], one_np[...,1]) 691 | 
term_loss = tf.identity(term_loss, name='loss-dice') 692 | elif term == 'ord': 693 | term_loss = loss_ord(prob_ord, true_ord, name='loss-ord') 694 | else: 695 | assert False, 'Unsupported loss term: %s' % term 696 | add_moving_summary(term_loss) 697 | loss += term_loss * weight 698 | 699 | if self.type_classification: 700 | term_loss = focal_loss_modified(soft_class, one_type) 701 | term_loss += weight_map_1 * term_loss 702 | term_loss = tf.reduce_mean(term_loss, name='loss-classification') 703 | add_moving_summary(term_loss) 704 | loss = loss + term_loss 705 | 706 | self.cost = tf.identity(loss, name='overall-loss') 707 | add_moving_summary(self.cost) 708 | #### 709 | return --------------------------------------------------------------------------------
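The width/depth scaling used by the EfficientNetB0--B4 wrappers in model/sonnet.py is easy to sanity-check by hand. A minimal sketch (pure Python; the two helpers are restated here for illustration, mirroring round_filters and round_repeats above):

import math

def round_filters(filters, width_coefficient, depth_divisor=8):
    # Scale by the width coefficient, then snap to the nearest multiple
    # of depth_divisor without dropping below 90% of the scaled value.
    filters *= width_coefficient
    new_filters = int(filters + depth_divisor / 2) // depth_divisor * depth_divisor
    new_filters = max(depth_divisor, new_filters)
    if new_filters < 0.9 * filters:
        new_filters += depth_divisor
    return int(new_filters)

def round_repeats(repeats, depth_coefficient):
    # Depth scaling always rounds up.
    return int(math.ceil(depth_coefficient * repeats))

# B0 (1.0, 1.0) keeps the defaults.
assert round_filters(32, 1.0) == 32 and round_repeats(3, 1.0) == 3
# B2 (width 1.1, depth 1.2): 32 * 1.1 = 35.2 snaps back down to 32,
# 112 * 1.1 = 123.2 rounds to 120, and 3 repeats become ceil(3.6) = 4.
assert round_filters(32, 1.1) == 32
assert round_filters(112, 1.1) == 120
assert round_repeats(3, 1.2) == 4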
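The crop sizes inside decoder() look arbitrary but are what make the five tensors in each '_add_*' concat spatially compatible. A back-of-the-envelope check, under the assumption (crop_op is not defined in this file) that crop_op(x, (a, b)) trims a rows and b columns in total, as in similar HoVer-Net-style code: a k x k 'valid' convolution with dilation d removes d * (k - 1) pixels per axis, so every branch loses exactly 8 pixels relative to u4/u3/u2, matching the crop_op(u, (8, 8)) applied on the skip path:

# Per-axis shrinkage of each '_add_*' branch: dilation * (kernel - 1)
# for every 'valid' conv, plus the explicit crop. The 1x1 convs cost 0.
# Assumption: crop_op(x, (a, b)) removes a pixels per axis in total.
shrinkage = {
    '_add_1': 1 * (3 - 1) + 6,  # 3x3 conv, dilation 1, crop (6, 6)
    '_add_2': 1 * (5 - 1) + 4,  # 5x5 conv, dilation 1, crop (4, 4)
    '_add_3': 2 * (3 - 1) + 4,  # 3x3 conv, dilation 2, crop (4, 4)
    '_add_4': 2 * (5 - 1),      # 5x5 conv, dilation 2, no crop
    'skip':   8,                # crop_op(u, (8, 8)) on the residual path
}
assert set(shrinkage.values()) == {8}  # all five inputs to tf.concat align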
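The ordinal branch treats the 16 output channels of 'conv_out_ord' as 8 (off, on) pairs, one binary classifier per ordinal threshold; 'predmap-ord' is the count of thresholds that fire, and loss_ord's tf.sequence_mask marks which thresholds should fire for a given true level. A NumPy sketch of the same decoding for a single pixel (the function names here are illustrative, not from the repo):

import numpy as np

ORD_NUM = 16 // 2  # 16 channels -> 8 (off, on) pairs, as in conv_out_ord

def decode_pixel(logits):
    # logits: (ORD_NUM, 2), one pixel of the reshaped [N, H, W, 8, 2] tensor.
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    prob = e / e.sum(axis=-1, keepdims=True)      # softmax over (off, on)
    return int(np.argmax(prob, axis=-1).sum())    # thresholds 'on' = level

def foreground_mask(true_level):
    # Mirrors tf.sequence_mask(true_ord, C): thresholds below the level fire.
    return np.arange(ORD_NUM) < true_level

rng = np.random.default_rng(0)
level = decode_pixel(rng.normal(size=(ORD_NUM, 2)))
assert 0 <= level <= ORD_NUM               # an integer level in [0, 8]
print(foreground_mask(3))                  # [ True  True  True False ...]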
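Sonnet_phase2 differs from Sonnet mainly in its re-weighting: inside its loss_ord the weights come from the true ordinal map, while the bce and focal terms are weighted by the predicted ordinal map (pred_ord). In both cases pixels labelled 1 get weight 2, pixels labelled 2 get weight 1, everything else 0, and the per-pixel loss is scaled by (1 + weight). The tf.where construction, restated in NumPy (illustrative only):

import numpy as np

def ordinal_weight(labels):
    # tf.where(tf.equal(w, 1), 2 * cast(w == 1), cast(w == 2)):
    # label 1 -> 2.0, label 2 -> 1.0, anything else -> 0.0
    labels = np.asarray(labels, dtype=np.float32)
    return np.where(labels == 1, 2.0, (labels == 2).astype(np.float32))

labels = np.array([0, 1, 2, 3, 8])
w = ordinal_weight(labels)                        # [0., 2., 1., 0., 0.]
per_pixel = np.ones_like(w)                       # stand-in per-pixel loss
assert list(per_pixel + w * per_pixel) == [1.0, 3.0, 2.0, 1.0, 1.0]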