├── F1_CellSegment.py ├── F3_FeatureExtract.py ├── F4_Visualization.py ├── Hover ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── __init__.py ├── compute_stats.py ├── config.py ├── convert_chkpt_tf2pytorch.py ├── convert_format.py ├── dataloader │ ├── __init__.py │ ├── augs.py │ ├── infer_loader.py │ └── train_loader.py ├── dataset.py ├── environment.yml ├── extract_patches.py ├── fun1.py ├── infer │ ├── __init__.py │ ├── base.py │ ├── tile.py │ └── wsi.py ├── metrics │ ├── README.md │ ├── __init__.py │ └── stats_utils.py ├── misc │ ├── __init__.py │ ├── patch_extractor.py │ ├── utils.py │ ├── viz_utils.py │ └── wsi_handler.py ├── models │ ├── __init__.py │ └── hovernet │ │ ├── __init__.py │ │ ├── net_desc.py │ │ ├── net_utils.py │ │ ├── opt.py │ │ ├── post_proc.py │ │ ├── run_desc.py │ │ ├── targets.py │ │ └── utils.py ├── requirements.txt ├── run_infer.py ├── run_tile.sh ├── run_train.py ├── run_utils │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── base.py │ │ ├── logging.py │ │ └── serialize.py │ ├── engine.py │ └── utils.py ├── run_wsi.sh ├── type_info.json └── variables_tf2pytorch.csv ├── LICENSE ├── WSIGraph.py ├── main.py ├── readme.md ├── requirements.txt └── utils_xml.py /F1_CellSegment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | def fun1(input_dir, output_dir): 5 | print(input_dir) 6 | print(output_dir) 7 | model_path = 'Hover/hovernet_fast_pannuke_type_tf2pytorch.tar' 8 | args = {'gpu':'0', 'nr_types':6, 'type_info_path':'Hover/type_info.json', 'model_mode':'fast', 9 | 'model_path':model_path, 'nr_inference_workers':8, 'nr_post_proc_workers':0, 10 | 'batch_size':16} 11 | sub_args = {'input_dir': input_dir, 12 | 'output_dir': output_dir, 13 | 'presplit_dir': None, 14 | 'cache_path':'cache', 15 | 'input_mask_dir':'', 16 | 'proc_mag':40, 17 | 'ambiguous_size':128, 18 | 'chunk_shape':10000, 19 | 'tile_shape':2048, 20 | 'save_thumb':True, 21 | 'save_mask':True} 22 | gpu_list = args.pop('gpu') 23 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 24 | 25 | nr_gpus = torch.cuda.device_count() 26 | nr_types = int(args['nr_types']) if int(args['nr_types']) > 0 else None 27 | method_args = { 28 | 'method' : { 29 | 'model_args' : { 30 | 'nr_types' : nr_types, 31 | 'mode' : args['model_mode'], 32 | }, 33 | 'model_path' : args['model_path'], 34 | }, 35 | 'type_info_path' : None if args['type_info_path'] == '' \ 36 | else args['type_info_path'], 37 | } 38 | 39 | run_args = { 40 | 'batch_size' : int(args['batch_size']) * nr_gpus, 41 | 42 | 'nr_inference_workers' : int(args['nr_inference_workers']), 43 | 'nr_post_proc_workers' : int(args['nr_post_proc_workers']), 44 | } 45 | 46 | if args['model_mode'] == 'fast': 47 | run_args['patch_input_shape'] = 256 48 | run_args['patch_output_shape'] = 164 49 | else: 50 | run_args['patch_input_shape'] = 270 51 | run_args['patch_output_shape'] = 80 52 | 53 | run_args.update({ 54 | 'input_dir' : sub_args['input_dir'], 55 | 'output_dir' : sub_args['output_dir'], 56 | 'presplit_dir' : sub_args['presplit_dir'], 57 | 'input_mask_dir' : sub_args['input_mask_dir'], 58 | 'cache_path' : sub_args['cache_path'], 59 | 60 | 'proc_mag' : int(sub_args['proc_mag']), 61 | 'ambiguous_size' : int(sub_args['ambiguous_size']), 62 | 'chunk_shape' : int(sub_args['chunk_shape']), 63 | 'tile_shape' : int(sub_args['tile_shape']), 64 | 'save_thumb' : sub_args['save_thumb'], 65 | 'save_mask' : sub_args['save_mask'], 66 | }) 67 | 68 | from Hover.infer.wsi import InferManager 69 | 
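# hand the assembled arguments to HoVer-Net's WSI inference pipeline:
# `method_args` selects the model (number of nuclei types, mode, checkpoint path),
# `run_args` controls tiling, batching, worker counts and input/output locations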
infer = InferManager(**method_args) 70 | infer.process_wsi_list(run_args) 71 | -------------------------------------------------------------------------------- /F3_FeatureExtract.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from WSIGraph import constructGraphFromDict 4 | from collections import defaultdict 5 | import numpy as np 6 | from utils_xml import get_windows 7 | 8 | def fun3(json_path, wsi_path, output_path, xml_path=None): 9 | # json_path = '../Part1_HoverNet/COWH.json' 10 | # wsi_path = '../Part1_HoverNet/COWH.ndpi' 11 | # xml_path = '../Part1_HoverNet/COWH.xml' 12 | # output_path = './' 13 | distanceThreshold = 100 14 | level = 0 15 | k = 5 16 | 17 | sample_name = os.path.basename(wsi_path).split('.')[0] 18 | with open(json_path) as fp: 19 | print(f"{'Loading json':*^30s}") 20 | nucleusInfo = json.load(fp) 21 | 22 | globalgraph, edge_info = constructGraphFromDict(wsi_path, nucleusInfo, distanceThreshold, k, level) 23 | vertex_dataframe = globalgraph.get_vertex_dataframe() 24 | centroid = np.array(vertex_dataframe['Centroid'].tolist()) 25 | 26 | if xml_path is not None: 27 | window_bbox = np.array(get_windows(xml_path)) 28 | index = np.zeros((len(centroid), len(window_bbox)), dtype=np.bool_) 29 | for i in range(len(window_bbox)): 30 | index[:, i] = ((window_bbox[i, 0, 0] 1] += current_max_id 30 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1] 31 | current_max_id = np.amax(ann) 32 | return ann 33 | 34 | 35 | #### 36 | def gaussian_blur(images, random_state, parents, hooks, max_ksize=3): 37 | """Apply Gaussian blur to input images.""" 38 | img = images[0] # aleju input batch as default (always=1 in our case) 39 | ksize = random_state.randint(0, max_ksize, size=(2,)) 40 | ksize = tuple((ksize * 2 + 1).tolist()) 41 | 42 | ret = cv2.GaussianBlur( 43 | img, ksize, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE 44 | ) 45 | ret = np.reshape(ret, img.shape) 46 | ret = ret.astype(np.uint8) 47 | return [ret] 48 | 49 | 50 | #### 51 | def median_blur(images, random_state, parents, hooks, max_ksize=3): 52 | """Apply median blur to input images.""" 53 | img = images[0] # aleju input batch as default (always=1 in our case) 54 | ksize = random_state.randint(0, max_ksize) 55 | ksize = ksize * 2 + 1 56 | ret = cv2.medianBlur(img, ksize) 57 | ret = ret.astype(np.uint8) 58 | return [ret] 59 | 60 | 61 | #### 62 | def add_to_hue(images, random_state, parents, hooks, range=None): 63 | """Perturbe the hue of input images.""" 64 | img = images[0] # aleju input batch as default (always=1 in our case) 65 | hue = random_state.uniform(*range) 66 | hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) 67 | if hsv.dtype.itemsize == 1: 68 | # OpenCV uses 0-179 for 8-bit images 69 | hsv[..., 0] = (hsv[..., 0] + hue) % 180 70 | else: 71 | # OpenCV uses 0-360 for floating point images 72 | hsv[..., 0] = (hsv[..., 0] + 2 * hue) % 360 73 | ret = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) 74 | ret = ret.astype(np.uint8) 75 | return [ret] 76 | 77 | 78 | #### 79 | def add_to_saturation(images, random_state, parents, hooks, range=None): 80 | """Perturbe the saturation of input images.""" 81 | img = images[0] # aleju input batch as default (always=1 in our case) 82 | value = 1 + random_state.uniform(*range) 83 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 84 | ret = img * value + (gray * (1 - value))[:, :, np.newaxis] 85 | ret = np.clip(ret, 0, 255) 86 | ret = ret.astype(np.uint8) 87 | return [ret] 88 | 89 | 90 | #### 91 | def add_to_contrast(images, 
random_state, parents, hooks, range=None): 92 | """Perturbe the contrast of input images.""" 93 | img = images[0] # aleju input batch as default (always=1 in our case) 94 | value = random_state.uniform(*range) 95 | mean = np.mean(img, axis=(0, 1), keepdims=True) 96 | ret = img * value + mean * (1 - value) 97 | ret = np.clip(img, 0, 255) 98 | ret = ret.astype(np.uint8) 99 | return [ret] 100 | 101 | 102 | #### 103 | def add_to_brightness(images, random_state, parents, hooks, range=None): 104 | """Perturbe the brightness of input images.""" 105 | img = images[0] # aleju input batch as default (always=1 in our case) 106 | value = random_state.uniform(*range) 107 | ret = np.clip(img + value, 0, 255) 108 | ret = ret.astype(np.uint8) 109 | return [ret] 110 | -------------------------------------------------------------------------------- /Hover/dataloader/infer_loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | import numpy as np 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | 7 | import torch 8 | import torch.utils.data as data 9 | 10 | import psutil 11 | 12 | 13 | #### 14 | class SerializeFileList(data.IterableDataset): 15 | """Read a single file as multiple patches of same shape, perform the padding beforehand.""" 16 | 17 | def __init__(self, img_list, patch_info_list, patch_size, preproc=None): 18 | super().__init__() 19 | self.patch_size = patch_size 20 | 21 | self.img_list = img_list 22 | self.patch_info_list = patch_info_list 23 | 24 | self.worker_start_img_idx = 0 25 | # * for internal worker state 26 | self.curr_img_idx = 0 27 | self.stop_img_idx = 0 28 | self.curr_patch_idx = 0 29 | self.stop_patch_idx = 0 30 | self.preproc = preproc 31 | return 32 | 33 | def __iter__(self): 34 | worker_info = torch.utils.data.get_worker_info() 35 | if worker_info is None: # single-process data loading, return the full iterator 36 | self.stop_img_idx = len(self.img_list) 37 | self.stop_patch_idx = len(self.patch_info_list) 38 | return self 39 | else: # in a worker process so split workload, return a reduced copy of self 40 | per_worker = len(self.patch_info_list) / float(worker_info.num_workers) 41 | per_worker = int(math.ceil(per_worker)) 42 | 43 | global_curr_patch_idx = worker_info.id * per_worker 44 | global_stop_patch_idx = global_curr_patch_idx + per_worker 45 | self.patch_info_list = self.patch_info_list[ 46 | global_curr_patch_idx:global_stop_patch_idx 47 | ] 48 | self.curr_patch_idx = 0 49 | self.stop_patch_idx = len(self.patch_info_list) 50 | # * check img indexer, implicit protocol in infer.py 51 | global_curr_img_idx = self.patch_info_list[0][-1] 52 | global_stop_img_idx = self.patch_info_list[-1][-1] + 1 53 | self.worker_start_img_idx = global_curr_img_idx 54 | self.img_list = self.img_list[global_curr_img_idx:global_stop_img_idx] 55 | self.curr_img_idx = 0 56 | self.stop_img_idx = len(self.img_list) 57 | return self # does it mutate source copy? 
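# NOTE: after the per-worker split in __iter__ above, __next__ only walks this worker's
# shard of patch_info_list, crops each patch from its parent image in img_list and
# applies the optional `preproc` callable before yielding (patch_data, patch_info).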
58 | 59 | def __next__(self): 60 | 61 | if self.curr_patch_idx >= self.stop_patch_idx: 62 | raise StopIteration # when there is nothing more to yield 63 | patch_info = self.patch_info_list[self.curr_patch_idx] 64 | img_ptr = self.img_list[patch_info[-1] - self.worker_start_img_idx] 65 | patch_data = img_ptr[ 66 | patch_info[0] : patch_info[0] + self.patch_size, 67 | patch_info[1] : patch_info[1] + self.patch_size, 68 | ] 69 | self.curr_patch_idx += 1 70 | if self.preproc is not None: 71 | patch_data = self.preproc(patch_data) 72 | return patch_data, patch_info 73 | 74 | 75 | #### 76 | class SerializeArray(data.Dataset): 77 | def __init__(self, mmap_array_path, patch_info_list, patch_size, preproc=None): 78 | super().__init__() 79 | self.patch_size = patch_size 80 | 81 | # use mmap as intermediate sharing, else variable will be duplicated 82 | # accross torch worker => OOM error, open in read only mode 83 | self.image = np.load(mmap_array_path, mmap_mode="r") 84 | 85 | self.patch_info_list = patch_info_list 86 | self.preproc = preproc 87 | return 88 | 89 | def __len__(self): 90 | return len(self.patch_info_list) 91 | 92 | def __getitem__(self, idx): 93 | patch_info = self.patch_info_list[idx] 94 | patch_data = self.image[ 95 | patch_info[0] : patch_info[0] + self.patch_size[0], 96 | patch_info[1] : patch_info[1] + self.patch_size[1], 97 | ] 98 | if self.preproc is not None: 99 | patch_data = self.preproc(patch_data) 100 | return patch_data, patch_info 101 | -------------------------------------------------------------------------------- /Hover/dataloader/train_loader.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import glob 3 | import os 4 | import re 5 | 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import scipy.io as sio 10 | import torch.utils.data 11 | 12 | import imgaug as ia 13 | from imgaug import augmenters as iaa 14 | from misc.utils import cropping_center 15 | 16 | from .augs import ( 17 | add_to_brightness, 18 | add_to_contrast, 19 | add_to_hue, 20 | add_to_saturation, 21 | gaussian_blur, 22 | median_blur, 23 | ) 24 | 25 | 26 | #### 27 | class FileLoader(torch.utils.data.Dataset): 28 | """Data Loader. Loads images from a file list and 29 | performs augmentation with the albumentation library. 30 | After augmentation, horizontal and vertical maps are 31 | generated. 
32 | 33 | Args: 34 | file_list: list of filenames to load 35 | input_shape: shape of the input [h,w] - defined in config.py 36 | mask_shape: shape of the output [h,w] - defined in config.py 37 | mode: 'train' or 'valid' 38 | 39 | """ 40 | 41 | # TODO: doc string 42 | 43 | def __init__( 44 | self, 45 | file_list, 46 | with_type=False, 47 | input_shape=None, 48 | mask_shape=None, 49 | mode="train", 50 | setup_augmentor=True, 51 | target_gen=None, 52 | ): 53 | assert input_shape is not None and mask_shape is not None 54 | self.mode = mode 55 | self.info_list = file_list 56 | self.with_type = with_type 57 | self.mask_shape = mask_shape 58 | self.input_shape = input_shape 59 | self.id = 0 60 | self.target_gen_func = target_gen[0] 61 | self.target_gen_kwargs = target_gen[1] 62 | if setup_augmentor: 63 | self.setup_augmentor(0, 0) 64 | return 65 | 66 | def setup_augmentor(self, worker_id, seed): 67 | self.augmentor = self.__get_augmentation(self.mode, seed) 68 | self.shape_augs = iaa.Sequential(self.augmentor[0]) 69 | self.input_augs = iaa.Sequential(self.augmentor[1]) 70 | self.id = self.id + worker_id 71 | return 72 | 73 | def __len__(self): 74 | return len(self.info_list) 75 | 76 | def __getitem__(self, idx): 77 | path = self.info_list[idx] 78 | data = np.load(path) 79 | 80 | # split stacked channel into image and label 81 | img = (data[..., :3]).astype("uint8") # RGB images 82 | ann = (data[..., 3:]).astype("int32") # instance ID map and type map 83 | 84 | if self.shape_augs is not None: 85 | shape_augs = self.shape_augs.to_deterministic() 86 | img = shape_augs.augment_image(img) 87 | ann = shape_augs.augment_image(ann) 88 | 89 | if self.input_augs is not None: 90 | input_augs = self.input_augs.to_deterministic() 91 | img = input_augs.augment_image(img) 92 | 93 | img = cropping_center(img, self.input_shape) 94 | feed_dict = {"img": img} 95 | 96 | inst_map = ann[..., 0] # HW1 -> HW 97 | if self.with_type: 98 | type_map = (ann[..., 1]).copy() 99 | type_map = cropping_center(type_map, self.mask_shape) 100 | #type_map[type_map == 5] = 1 # merge neoplastic and non-neoplastic 101 | feed_dict["tp_map"] = type_map 102 | 103 | # TODO: document hard coded assumption about #input 104 | target_dict = self.target_gen_func( 105 | inst_map, self.mask_shape, **self.target_gen_kwargs 106 | ) 107 | feed_dict.update(target_dict) 108 | 109 | return feed_dict 110 | 111 | def __get_augmentation(self, mode, rng): 112 | if mode == "train": 113 | shape_augs = [ 114 | # * order = ``0`` -> ``cv2.INTER_NEAREST`` 115 | # * order = ``1`` -> ``cv2.INTER_LINEAR`` 116 | # * order = ``2`` -> ``cv2.INTER_CUBIC`` 117 | # * order = ``3`` -> ``cv2.INTER_CUBIC`` 118 | # * order = ``4`` -> ``cv2.INTER_CUBIC`` 119 | # ! for pannuke v0, no rotation or translation, just flip to avoid mirror padding 120 | # iaa.Affine( 121 | # # scale images to 80-120% of their size, individually per axis 122 | # scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 123 | # # translate by -A to +A percent (per axis) 124 | # # translate_percent={"x": (-0.01, 0.01), "y": (-0.01, 0.01)}, 125 | # shear=(-5, 5), # shear by -5 to +5 degrees 126 | # # rotate=(-179, 179), # rotate by -179 to +179 degrees 127 | # order=0, # use nearest neighbour 128 | # backend="cv2", # opencv for fast processing 129 | # seed=rng, 130 | # ), 131 | # # set position to 'center' for center crop 132 | # # else 'uniform' for random crop 133 | 134 | # iaa.CropToFixedSize( 135 | # self.input_shape[0], self.input_shape[1], position="center" 136 | # ), 137 | # ! 
2021-11-22 138 | iaa.Resize((self.input_shape[0], self.input_shape[1]), interpolation='nearest'), 139 | iaa.Fliplr(0.5, seed=rng), 140 | iaa.Flipud(0.5, seed=rng), 141 | ] 142 | 143 | input_augs = [ 144 | iaa.OneOf( 145 | [ 146 | iaa.Lambda( 147 | seed=rng, 148 | func_images=lambda *args: gaussian_blur(*args, max_ksize=3), 149 | ), 150 | iaa.Lambda( 151 | seed=rng, 152 | func_images=lambda *args: median_blur(*args, max_ksize=3), 153 | ), 154 | iaa.AdditiveGaussianNoise( 155 | loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5 156 | ), 157 | ] 158 | ), 159 | iaa.Sequential( 160 | [ 161 | iaa.Lambda( 162 | seed=rng, 163 | func_images=lambda *args: add_to_hue(*args, range=(-8, 8)), 164 | ), 165 | iaa.Lambda( 166 | seed=rng, 167 | func_images=lambda *args: add_to_saturation( 168 | *args, range=(-0.2, 0.2) 169 | ), 170 | ), 171 | iaa.Lambda( 172 | seed=rng, 173 | func_images=lambda *args: add_to_brightness( 174 | *args, range=(-26, 26) 175 | ), 176 | ), 177 | iaa.Lambda( 178 | seed=rng, 179 | func_images=lambda *args: add_to_contrast( 180 | *args, range=(0.75, 1.25) 181 | ), 182 | ), 183 | ], 184 | random_order=True, 185 | ), 186 | ] 187 | elif mode == "valid": 188 | shape_augs = [ 189 | # set position to 'center' for center crop 190 | # else 'uniform' for random crop 191 | iaa.CropToFixedSize( 192 | self.input_shape[0], self.input_shape[1], position="center" 193 | ) 194 | ] 195 | input_augs = [] 196 | 197 | return shape_augs, input_augs 198 | -------------------------------------------------------------------------------- /Hover/dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import cv2 3 | import numpy as np 4 | import scipy.io as sio 5 | 6 | 7 | class __AbstractDataset(object): 8 | """Abstract class for interface of subsequent classes. 9 | Main idea is to encapsulate how each dataset should parse 10 | their images and annotations. 11 | 12 | """ 13 | 14 | def load_img(self, path): 15 | raise NotImplementedError 16 | 17 | def load_ann(self, path, with_type=False): 18 | raise NotImplementedError 19 | 20 | 21 | #### 22 | class __Kumar(__AbstractDataset): 23 | """Defines the Kumar dataset as originally introduced in: 24 | 25 | Kumar, Neeraj, Ruchika Verma, Sanuj Sharma, Surabhi Bhargava, Abhishek Vahadane, 26 | and Amit Sethi. "A dataset and a technique for generalized nuclear segmentation for 27 | computational pathology." IEEE transactions on medical imaging 36, no. 7 (2017): 1550-1560. 28 | 29 | """ 30 | 31 | def load_img(self, path): 32 | return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 33 | 34 | def load_ann(self, path, with_type=False): 35 | # assumes that ann is HxW 36 | assert not with_type, "Not support" 37 | ann_inst = sio.loadmat(path)["inst_map"] 38 | ann_inst = ann_inst.astype("int32") 39 | ann = np.expand_dims(ann_inst, -1) 40 | return ann 41 | 42 | 43 | #### 44 | class __CPM17(__AbstractDataset): 45 | """Defines the CPM 2017 dataset as originally introduced in: 46 | 47 | Vu, Quoc Dang, Simon Graham, Tahsin Kurc, Minh Nguyen Nhat To, Muhammad Shaban, 48 | Talha Qaiser, Navid Alemi Koohbanani et al. "Methods for segmentation and classification 49 | of digital microscopy tissue images." Frontiers in bioengineering and biotechnology 7 (2019). 
50 | 51 | """ 52 | 53 | def load_img(self, path): 54 | return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 55 | 56 | def load_ann(self, path, with_type=False): 57 | assert not with_type, "Not support" 58 | # assumes that ann is HxW 59 | ann_inst = sio.loadmat(path)["inst_map"] 60 | ann_inst = ann_inst.astype("int32") 61 | ann = np.expand_dims(ann_inst, -1) 62 | return ann 63 | 64 | 65 | #### 66 | class __CoNSeP(__AbstractDataset): 67 | """Defines the CoNSeP dataset as originally introduced in: 68 | 69 | Graham, Simon, Quoc Dang Vu, Shan E. Ahmed Raza, Ayesha Azam, Yee Wah Tsang, Jin Tae Kwak, 70 | and Nasir Rajpoot. "Hover-Net: Simultaneous segmentation and classification of nuclei in 71 | multi-tissue histology images." Medical Image Analysis 58 (2019): 101563 72 | 73 | """ 74 | 75 | def load_img(self, path): 76 | return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB) 77 | 78 | def load_ann(self, path, with_type=False): 79 | # assumes that ann is HxW 80 | ann_inst = sio.loadmat(path)["inst_map"] 81 | if with_type: 82 | ann_type = sio.loadmat(path)["type_map"] 83 | 84 | # merge classes for CoNSeP (in paper we only utilise 3 nuclei classes and background) 85 | # If own dataset is used, then the below may need to be modified 86 | ann_type[(ann_type == 3) | (ann_type == 4)] = 3 87 | ann_type[(ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 4 88 | 89 | ann = np.dstack([ann_inst, ann_type]) 90 | ann = ann.astype("int32") 91 | else: 92 | ann = np.expand_dims(ann_inst, -1) 93 | ann = ann.astype("int32") 94 | 95 | return ann 96 | 97 | 98 | 99 | 100 | 101 | #### 102 | def get_dataset(name): 103 | """Return a pre-defined dataset object associated with `name`.""" 104 | name_dict = { 105 | "kumar": lambda: __Kumar(), 106 | "cpm17": lambda: __CPM17(), 107 | "consep": lambda: __CoNSeP(), 108 | } 109 | if name.lower() in name_dict: 110 | return name_dict[name]() 111 | else: 112 | assert False, "Unknown dataset `%s`" % name 113 | 114 | 115 | #### convert PanNuke to wanted format 116 | #### 2021-11-19 117 | -------------------------------------------------------------------------------- /Hover/environment.yml: -------------------------------------------------------------------------------- 1 | name: hovernet 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.6.12 7 | - pip=20.3.1 8 | - openslide 9 | - pip: 10 | - -r file:requirements.txt 11 | - openslide-python==1.1.2 12 | -------------------------------------------------------------------------------- /Hover/extract_patches.py: -------------------------------------------------------------------------------- 1 | """extract_patches.py 2 | 3 | Patch extraction script. 4 | """ 5 | 6 | import re 7 | import glob 8 | import os 9 | import tqdm 10 | import pathlib 11 | 12 | import numpy as np 13 | 14 | from misc.patch_extractor import PatchExtractor 15 | from misc.utils import rm_n_mkdir 16 | 17 | from dataset import get_dataset 18 | 19 | # ------------------------------------------------------------------------------------- 20 | if __name__ == "__main__": 21 | 22 | # Determines whether to extract type map (only applicable to datasets with class labels). 23 | type_classification = True 24 | 25 | win_size = [540, 540] 26 | step_size = [164, 164] 27 | extract_type = "mirror" # Choose 'mirror' or 'valid'. 'mirror'- use padding at borders. 'valid'- only extract from valid regions. 28 | 29 | # Name of dataset - use Kumar, CPM17 or CoNSeP. 
30 | # This used to get the specific dataset img and ann loading scheme from dataset.py 31 | dataset_name = "consep" 32 | save_root = "dataset/training_data/%s/" % dataset_name 33 | 34 | # a dictionary to specify where the dataset path should be 35 | dataset_info = { 36 | "train": { 37 | "img": (".png", "dataset/CoNSeP/Train/Images/"), 38 | "ann": (".mat", "dataset/CoNSeP/Train/Labels/"), 39 | }, 40 | "valid": { 41 | "img": (".png", "dataset/CoNSeP/Test/Images/"), 42 | "ann": (".mat", "dataset/CoNSeP/Test/Labels/"), 43 | }, 44 | } 45 | 46 | patterning = lambda x: re.sub("([\[\]])", "[\\1]", x) 47 | parser = get_dataset(dataset_name) 48 | xtractor = PatchExtractor(win_size, step_size) 49 | for split_name, split_desc in dataset_info.items(): 50 | img_ext, img_dir = split_desc["img"] 51 | ann_ext, ann_dir = split_desc["ann"] 52 | 53 | out_dir = "%s/%s/%s/%dx%d_%dx%d/" % ( 54 | save_root, 55 | dataset_name, 56 | split_name, 57 | win_size[0], 58 | win_size[1], 59 | step_size[0], 60 | step_size[1], 61 | ) 62 | file_list = glob.glob(patterning("%s/*%s" % (ann_dir, ann_ext))) 63 | file_list.sort() # ensure same ordering across platform 64 | 65 | rm_n_mkdir(out_dir) 66 | 67 | pbar_format = "Process File: |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]" 68 | pbarx = tqdm.tqdm( 69 | total=len(file_list), bar_format=pbar_format, ascii=True, position=0 70 | ) 71 | 72 | for file_idx, file_path in enumerate(file_list): 73 | base_name = pathlib.Path(file_path).stem 74 | 75 | img = parser.load_img("%s/%s%s" % (img_dir, base_name, img_ext)) 76 | ann = parser.load_ann( 77 | "%s/%s%s" % (ann_dir, base_name, ann_ext), type_classification 78 | ) 79 | 80 | # * 81 | img = np.concatenate([img, ann], axis=-1) 82 | sub_patches = xtractor.extract(img, extract_type) 83 | 84 | pbar_format = "Extracting : |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]" 85 | pbar = tqdm.tqdm( 86 | total=len(sub_patches), 87 | leave=False, 88 | bar_format=pbar_format, 89 | ascii=True, 90 | position=1, 91 | ) 92 | 93 | for idx, patch in enumerate(sub_patches): 94 | np.save("{0}/{1}_{2:03d}.npy".format(out_dir, base_name, idx), patch) 95 | pbar.update() 96 | pbar.close() 97 | # * 98 | 99 | pbarx.update() 100 | pbarx.close() 101 | -------------------------------------------------------------------------------- /Hover/fun1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | # method_args = 4 | ''' 5 | --input_dir= Path to input data directory. Assumes the files are not nested within directory. 6 | --output_dir= Path to output directory. 7 | --presplit_dir= Path to presplit data directory. 8 | --cache_path= Path for cache. Should be placed on SSD with at least 100GB. [default: cache] 9 | --mask_dir= Path to directory containing tissue masks. 10 | Should have the same name as corresponding WSIs. [default: ''] 11 | 12 | --proc_mag= Magnification level (objective power) used for WSI processing. [default: 40] 13 | --ambiguous_size= Define ambiguous region along tiling grid to perform re-post processing. [default: 128] 14 | --chunk_shape= Shape of chunk for processing. [default: 10000] 15 | --tile_shape= Shape of tiles for processing. [default: 2048] 16 | --save_thumb To save thumb. [default: False] 17 | --save_mask To save mask. 
[default: False] 18 | ''' 19 | '''python run_infer.py \ 20 | --gpu='0' \ 21 | --nr_types=6 \ 22 | --type_info_path=type_info.json \ 23 | --batch_size=32 \ 24 | --model_mode=fast \ 25 | --model_path=/home/xujun/FUSCC/Hover/hovernet_fast_pannuke_type_tf2pytorch.tar \ 26 | --nr_inference_workers=4 \ 27 | --nr_post_proc_workers=0 \ 28 | wsi \ 29 | --input_dir=/home/xujun/FUSCC/WSI_example/WSI2 \ 30 | --output_dir=/home/xujun/FUSCC/WSI_example/pred2 \ 31 | --presplit_dir=/home/xujun/FUSCC/WSI_example/WSI_presplit2 \ 32 | --proc_mag 20 \ 33 | --save_thumb \ 34 | --save_mask 35 | ''' 36 | model_path = './hovernet_fast_pannuke_type_ft2pytorch.tar' 37 | args = {'gpu':0, 'nr_types':6, 'type_info_path':'type_info.json', 'model_mode':'fast', 38 | 'model_path':model_path, 'nr_inference_workers':8, 'nr_post_proc_workers':0, 39 | 'batch_size':16} 40 | sub_args = {} 41 | gpu_list = args.pop('gpu') 42 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 43 | 44 | nr_gpus = torch.cuda.device_count() 45 | nr_types = int(args['nr_types']) if int(args['nr_types']) > 0 else None 46 | method_args = { 47 | 'method' : { 48 | 'model_args' : { 49 | 'nr_types' : nr_types, 50 | 'mode' : args['model_mode'], 51 | }, 52 | 'model_path' : args['model_path'], 53 | }, 54 | 'type_info_path' : None if args['type_info_path'] == '' \ 55 | else args['type_info_path'], 56 | } 57 | 58 | run_args = { 59 | 'batch_size' : int(args['batch_size']) * nr_gpus, 60 | 61 | 'nr_inference_workers' : int(args['nr_inference_workers']), 62 | 'nr_post_proc_workers' : int(args['nr_post_proc_workers']), 63 | } 64 | 65 | if args['model_mode'] == 'fast': 66 | run_args['patch_input_shape'] = 256 67 | run_args['patch_output_shape'] = 164 68 | else: 69 | run_args['patch_input_shape'] = 270 70 | run_args['patch_output_shape'] = 80 71 | 72 | run_args.update({ 73 | 'input_dir' : sub_args['input_dir'], 74 | 'output_dir' : sub_args['output_dir'], 75 | 'presplit_dir' : sub_args['presplit_dir'], 76 | 'input_mask_dir' : sub_args['input_mask_dir'], 77 | 'cache_path' : sub_args['cache_path'], 78 | 79 | 'proc_mag' : int(sub_args['proc_mag']), 80 | 'ambiguous_size' : int(sub_args['ambiguous_size']), 81 | 'chunk_shape' : int(sub_args['chunk_shape']), 82 | 'tile_shape' : int(sub_args['tile_shape']), 83 | 'save_thumb' : sub_args['save_thumb'], 84 | 'save_mask' : sub_args['save_mask'], 85 | }) 86 | 87 | from Hover.infer.wsi import InferManager 88 | infer = InferManager(**method_args) 89 | infer.process_wsi_list(run_args) 90 | -------------------------------------------------------------------------------- /Hover/infer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuscc-deep-path/sc_MTOP/33ff31fbd01f37705118244da0a8df96a4f19014/Hover/infer/__init__.py -------------------------------------------------------------------------------- /Hover/infer/base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import math 5 | import multiprocessing 6 | import os 7 | import re 8 | import sys 9 | from importlib import import_module 10 | from multiprocessing import Lock, Pool 11 | 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | import torch 15 | import torch.utils.data as data 16 | import tqdm 17 | 18 | from run_utils.utils import convert_pytorch_checkpoint 19 | 20 | 21 | #### 22 | class InferManager(object): 23 | def __init__(self, **kwargs): 24 | self.run_step = None 25 | for variable, value in 
kwargs.items(): 26 | self.__setattr__(variable, value) 27 | self.__load_model() 28 | self.nr_types = self.method["model_args"]["nr_types"] 29 | # create type info name and colour 30 | 31 | # default 32 | self.type_info_dict = { 33 | None: ["no label", [0, 0, 0]], 34 | } 35 | 36 | if self.nr_types is not None and self.type_info_path is not None: 37 | self.type_info_dict = json.load(open(self.type_info_path, "r")) 38 | self.type_info_dict = { 39 | int(k): (v[0], tuple(v[1])) for k, v in self.type_info_dict.items() 40 | } 41 | # availability check 42 | for k in range(self.nr_types): 43 | if k not in self.type_info_dict: 44 | assert False, "Not detect type_id=%d defined in json." % k 45 | 46 | if self.nr_types is not None and self.type_info_path is None: 47 | cmap = plt.get_cmap("hot") 48 | colour_list = np.arange(self.nr_types, dtype=np.int32) 49 | colour_list = (cmap(colour_list)[..., :3] * 255).astype(np.uint8) 50 | # should be compatible out of the box wrt qupath 51 | self.type_info_dict = { 52 | k: (str(k), tuple(v)) for k, v in enumerate(colour_list) 53 | } 54 | return 55 | 56 | def __load_model(self): 57 | """Create the model, load the checkpoint and define 58 | associated run steps to process each data batch. 59 | 60 | """ 61 | model_desc = import_module("models.hovernet.net_desc") 62 | model_creator = getattr(model_desc, "create_model") 63 | 64 | net = model_creator(**self.method["model_args"]) 65 | saved_state_dict = torch.load(self.method["model_path"])["desc"] 66 | saved_state_dict = convert_pytorch_checkpoint(saved_state_dict) 67 | 68 | net.load_state_dict(saved_state_dict, strict=True) 69 | net = torch.nn.DataParallel(net) 70 | net = net.to("cuda") 71 | 72 | module_lib = import_module("models.hovernet.run_desc") 73 | run_step = getattr(module_lib, "infer_step") 74 | self.run_step = lambda input_batch: run_step(input_batch, net) 75 | 76 | module_lib = import_module("models.hovernet.post_proc") 77 | self.post_proc_func = getattr(module_lib, "process") 78 | return 79 | 80 | def __save_json(self, path, old_dict, mag=None): 81 | new_dict = {} 82 | for inst_id, inst_info in old_dict.items(): 83 | new_inst_info = {} 84 | for info_name, info_value in inst_info.items(): 85 | # convert to jsonable 86 | if isinstance(info_value, np.ndarray): 87 | info_value = info_value.tolist() 88 | new_inst_info[info_name] = info_value 89 | new_dict[int(inst_id)] = new_inst_info 90 | 91 | json_dict = {"mag": mag, "nuc": new_dict} # to sync the format protocol 92 | with open(path, "w") as handle: 93 | json.dump(json_dict, handle) 94 | return new_dict 95 | -------------------------------------------------------------------------------- /Hover/metrics/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Statistical Measurements for Instance Segmentation and Classification 3 | 4 | ## Description 5 | 6 | In this directory, the script `stats_utils.py` contains the statistical measurements code for instance segmentation. In order of appearance, the available measurements are AJI+, AJI, DICE2, Panoptic Quality (PQ), DICE which can be access through following functions: 7 | 8 | `get_fast_aji()`: aji ported from the matlab code but is optimised for speed **[1]**.
9 | `get_fast_aji_plus()`: extension of aggregated jaccard index that doesn't suffer from over-penalisation.<br>
10 | `get_dice_1()` and `get_dice_2()`: standard dice and ensemble dice (DICE2) **[2]** measures respectively.<br>
11 | `get_fast_dice_2()`: ensemble dice optimised for speed.<br>
12 | `get_fast_panoptic_quality()`: panoptic quality as used in **[3]**. 13 | 14 | ## Sample 15 | 16 | 17 | [Figure: "Metric" image showing the example predictions compared below] 18 |
19 | 20 | Given the predictions as above, the basic difference between AJI, AJI+ and Panoptic Quality is summarized 21 | in the following table. 22 | 23 | | | DICE2 | AJI | AJI+ | PQ | 24 | | ------------- |:------:|:------:|:------:|:------:| 25 | | Prediction A | 0.6477 | 0.4790 | 0.6375 | 0.6803 | 26 | | Prediction B | 0.9007 | 0.6414 | 0.6414 | 0.6863 | 27 | 28 | ## Processing 29 | 30 | ### Instance Segmentation 31 | 32 | To get the instance segmentation measurements, run:<br>
33 | `python compute_stats.py --mode=instance --pred_dir='pred_dir' --true_dir='true_dir'` 34 | 35 | Toggle `print_img_stats` to determine whether to show the stats for each image. 36 | 37 | ### Classification 38 | 39 | To get the classification measurements, run:<br>
40 | `python compute_stats.py --mode=type --pred_dir='pred_dir' --true_dir='true_dir'` 41 | 42 | The above calculates the classification metrics, as discussed in the evaluation metrics section of our paper. 43 | 44 | 45 | ## References 46 | **[1]** Kumar, Neeraj, Ruchika Verma, Sanuj Sharma, Surabhi Bhargava, Abhishek Vahadane, and Amit Sethi. "A dataset and a technique for generalized nuclear segmentation for computational pathology." IEEE transactions on medical imaging 36, no. 7 (2017): 1550-1560.<br>
47 | **[2]** Vu, Quoc Dang, Simon Graham, Minh Nguyen Nhat To, Muhammad Shaban, Talha Qaiser, Navid Alemi Koohbanani, Syed Ali Khurram et al. "Methods for Segmentation and Classification of Digital Microscopy Tissue Images." arXiv preprint arXiv:1810.13230 (2018).<br>
48 | **[3]** Kirillov, Alexander, Kaiming He, Ross Girshick, Carsten Rother, and Piotr Dollár. "Panoptic Segmentation." arXiv preprint arXiv:1801.00868 (2018). 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /Hover/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuscc-deep-path/sc_MTOP/33ff31fbd01f37705118244da0a8df96a4f19014/Hover/metrics/__init__.py -------------------------------------------------------------------------------- /Hover/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuscc-deep-path/sc_MTOP/33ff31fbd01f37705118244da0a8df96a4f19014/Hover/misc/__init__.py -------------------------------------------------------------------------------- /Hover/misc/patch_extractor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | from .utils import cropping_center 9 | 10 | 11 | ##### 12 | class PatchExtractor(object): 13 | """Extractor to generate patches with or without padding. 14 | Turn on debug mode to see how it is done. 15 | 16 | Args: 17 | x : input image, should be of shape HWC 18 | win_size : a tuple of (h, w) 19 | step_size : a tuple of (h, w) 20 | debug : flag to see how it is done 21 | Return: 22 | a list of sub patches, each patch has dtype same as x 23 | 24 | Examples: 25 | >>> xtractor = PatchExtractor((450, 450), (120, 120)) 26 | >>> img = np.full([1200, 1200, 3], 255, np.uint8) 27 | >>> patches = xtractor.extract(img, 'mirror') 28 | 29 | """ 30 | 31 | def __init__(self, win_size, step_size, debug=False): 32 | 33 | self.patch_type = "mirror" 34 | self.win_size = win_size 35 | self.step_size = step_size 36 | self.debug = debug 37 | self.counter = 0 38 | 39 | def __get_patch(self, x, ptx): 40 | pty = (ptx[0] + self.win_size[0], ptx[1] + self.win_size[1]) 41 | win = x[ptx[0] : pty[0], ptx[1] : pty[1]] 42 | assert ( 43 | win.shape[0] == self.win_size[0] and win.shape[1] == self.win_size[1] 44 | ), "[BUG] Incorrect Patch Size {0}".format(win.shape) 45 | if self.debug: 46 | if self.patch_type == "mirror": 47 | cen = cropping_center(win, self.step_size) 48 | cen = cen[..., self.counter % 3] 49 | cen.fill(150) 50 | cv2.rectangle(x, ptx, pty, (255, 0, 0), 2) 51 | plt.imshow(x) 52 | plt.show(block=False) 53 | plt.pause(1) 54 | plt.close() 55 | self.counter += 1 56 | return win 57 | 58 | def __extract_valid(self, x): 59 | """Extracted patches without padding, only work in case win_size > step_size. 60 | 61 | Note: to deal with the remaining portions which are at the boundary a.k.a 62 | those which do not fit when slide left->right, top->bottom), we flip 63 | the sliding direction then extract 1 patch starting from right / bottom edge. 64 | There will be 1 additional patch extracted at the bottom-right corner. 
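For example, with the toy setting at the bottom of this file (input length 1200, win_size 450, step_size 120), the regular grid starts patches at offsets 0, 120, ..., 720; since (1200 - 450) is not a multiple of 120, one extra patch anchored at offset 750 (= 1200 - 450) is extracted so the far edge is still covered.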
65 | 66 | Args: 67 | x : input image, should be of shape HWC 68 | win_size : a tuple of (h, w) 69 | step_size : a tuple of (h, w) 70 | Return: 71 | a list of sub patches, each patch is same dtype as x 72 | 73 | """ 74 | im_h = x.shape[0] 75 | im_w = x.shape[1] 76 | 77 | def extract_infos(length, win_size, step_size): 78 | flag = (length - win_size) % step_size != 0 79 | last_step = math.floor((length - win_size) / step_size) 80 | last_step = (last_step + 1) * step_size 81 | return flag, last_step 82 | 83 | h_flag, h_last = extract_infos(im_h, self.win_size[0], self.step_size[0]) 84 | w_flag, w_last = extract_infos(im_w, self.win_size[1], self.step_size[1]) 85 | 86 | sub_patches = [] 87 | #### Deal with valid block 88 | for row in range(0, h_last, self.step_size[0]): 89 | for col in range(0, w_last, self.step_size[1]): 90 | win = self.__get_patch(x, (row, col)) 91 | sub_patches.append(win) 92 | #### Deal with edge case 93 | if h_flag: 94 | row = im_h - self.win_size[0] 95 | for col in range(0, w_last, self.step_size[1]): 96 | win = self.__get_patch(x, (row, col)) 97 | sub_patches.append(win) 98 | if w_flag: 99 | col = im_w - self.win_size[1] 100 | for row in range(0, h_last, self.step_size[0]): 101 | win = self.__get_patch(x, (row, col)) 102 | sub_patches.append(win) 103 | if h_flag and w_flag: 104 | ptx = (im_h - self.win_size[0], im_w - self.win_size[1]) 105 | win = self.__get_patch(x, ptx) 106 | sub_patches.append(win) 107 | return sub_patches 108 | 109 | def __extract_mirror(self, x): 110 | """Extracted patches with mirror padding the boundary such that the 111 | central region of each patch is always within the orginal (non-padded) 112 | image while all patches' central region cover the whole orginal image. 113 | 114 | Args: 115 | x : input image, should be of shape HWC 116 | win_size : a tuple of (h, w) 117 | step_size : a tuple of (h, w) 118 | Return: 119 | a list of sub patches, each patch is same dtype as x 120 | 121 | """ 122 | diff_h = self.win_size[0] - self.step_size[0] 123 | padt = diff_h // 2 124 | padb = diff_h - padt 125 | 126 | diff_w = self.win_size[1] - self.step_size[1] 127 | padl = diff_w // 2 128 | padr = diff_w - padl 129 | 130 | pad_type = "constant" if self.debug else "reflect" 131 | x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type) 132 | sub_patches = self.__extract_valid(x) 133 | return sub_patches 134 | 135 | def extract(self, x, patch_type): 136 | patch_type = patch_type.lower() 137 | self.patch_type = patch_type 138 | if patch_type == "valid": 139 | return self.__extract_valid(x) 140 | elif patch_type == "mirror": 141 | return self.__extract_mirror(x) 142 | else: 143 | assert False, "Unknown Patch Type [%s]" % patch_type 144 | return 145 | 146 | 147 | # ---------------------------------------------------------------------------- 148 | 149 | if __name__ == "__main__": 150 | # toy example for debug 151 | # 355x355, 480x480 152 | xtractor = PatchExtractor((450, 450), (120, 120), debug=True) 153 | a = np.full([1200, 1200, 3], 255, np.uint8) 154 | xtractor.extract(a, "mirror") 155 | xtractor.extract(a, "valid") 156 | -------------------------------------------------------------------------------- /Hover/misc/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import inspect 3 | import logging 4 | import os 5 | import shutil 6 | 7 | import cv2 8 | import numpy as np 9 | from scipy import ndimage 10 | 11 | 12 | #### 13 | def normalize(mask, dtype=np.uint8): 14 | return (255 * mask / 
np.amax(mask)).astype(dtype) 15 | 16 | 17 | #### 18 | def get_bounding_box(img): 19 | """Get bounding box coordinate information.""" 20 | rows = np.any(img, axis=1) 21 | cols = np.any(img, axis=0) 22 | rmin, rmax = np.where(rows)[0][[0, -1]] 23 | cmin, cmax = np.where(cols)[0][[0, -1]] 24 | # due to python indexing, need to add 1 to max 25 | # else accessing will be 1px in the box, not out 26 | rmax += 1 27 | cmax += 1 28 | return [rmin, rmax, cmin, cmax] 29 | 30 | 31 | #### 32 | def cropping_center(x, crop_shape, batch=False): 33 | """Crop an input image at the centre. 34 | 35 | Args: 36 | x: input array 37 | crop_shape: dimensions of cropped array 38 | 39 | Returns: 40 | x: cropped array 41 | 42 | """ 43 | orig_shape = x.shape 44 | if not batch: 45 | h0 = int((orig_shape[0] - crop_shape[0]) * 0.5) 46 | w0 = int((orig_shape[1] - crop_shape[1]) * 0.5) 47 | x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]] 48 | else: 49 | h0 = int((orig_shape[1] - crop_shape[0]) * 0.5) 50 | w0 = int((orig_shape[2] - crop_shape[1]) * 0.5) 51 | x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]] 52 | return x 53 | 54 | 55 | #### 56 | def rm_n_mkdir(dir_path): 57 | """Remove and make directory.""" 58 | if os.path.isdir(dir_path): 59 | shutil.rmtree(dir_path) 60 | os.makedirs(dir_path) 61 | 62 | 63 | #### 64 | def mkdir(dir_path): 65 | """Make directory.""" 66 | if not os.path.isdir(dir_path): 67 | os.makedirs(dir_path) 68 | 69 | 70 | #### 71 | def get_inst_centroid(inst_map): 72 | """Get instance centroids given an input instance map. 73 | 74 | Args: 75 | inst_map: input instance map 76 | 77 | Returns: 78 | array of centroids 79 | 80 | """ 81 | inst_centroid_list = [] 82 | inst_id_list = list(np.unique(inst_map)) 83 | for inst_id in inst_id_list[1:]: # avoid 0 i.e background 84 | mask = np.array(inst_map == inst_id, np.uint8) 85 | inst_moment = cv2.moments(mask) 86 | inst_centroid = [ 87 | (inst_moment["m10"] / inst_moment["m00"]), 88 | (inst_moment["m01"] / inst_moment["m00"]), 89 | ] 90 | inst_centroid_list.append(inst_centroid) 91 | return np.array(inst_centroid_list) 92 | 93 | 94 | #### 95 | def center_pad_to_shape(img, size, cval=255): 96 | """Pad input image.""" 97 | # rounding down, add 1 98 | pad_h = size[0] - img.shape[0] 99 | pad_w = size[1] - img.shape[1] 100 | pad_h = (pad_h // 2, pad_h - pad_h // 2) 101 | pad_w = (pad_w // 2, pad_w - pad_w // 2) 102 | if len(img.shape) == 2: 103 | pad_shape = (pad_h, pad_w) 104 | else: 105 | pad_shape = (pad_h, pad_w, (0, 0)) 106 | img = np.pad(img, pad_shape, "constant", constant_values=cval) 107 | return img 108 | 109 | 110 | #### 111 | def color_deconvolution(rgb, stain_mat): 112 | """Apply colour deconvolution.""" 113 | log255 = np.log(255) # to base 10, not base e 114 | rgb_float = rgb.astype(np.float64) 115 | log_rgb = -((255.0 * np.log((rgb_float + 1) / 255.0)) / log255) 116 | output = np.exp(-(log_rgb @ stain_mat - 255.0) * log255 / 255.0) 117 | output[output > 255] = 255 118 | output = np.floor(output + 0.5).astype("uint8") 119 | return output 120 | 121 | 122 | #### 123 | def log_debug(msg): 124 | frame, filename, line_number, function_name, lines, index = inspect.getouterframes( 125 | inspect.currentframe() 126 | )[1] 127 | line = lines[0] 128 | indentation_level = line.find(line.lstrip()) 129 | logging.debug("{i} {m}".format(i="." 
* indentation_level, m=msg)) 130 | 131 | 132 | #### 133 | def log_info(msg): 134 | frame, filename, line_number, function_name, lines, index = inspect.getouterframes( 135 | inspect.currentframe() 136 | )[1] 137 | line = lines[0] 138 | indentation_level = line.find(line.lstrip()) 139 | logging.info("{i} {m}".format(i="." * indentation_level, m=msg)) 140 | 141 | 142 | def remove_small_objects(pred, min_size=64, connectivity=1): 143 | """Remove connected components smaller than the specified size. 144 | 145 | This function is taken from skimage.morphology.remove_small_objects, but the warning 146 | is removed when a single label is provided. 147 | 148 | Args: 149 | pred: input labelled array 150 | min_size: minimum size of instance in output array 151 | connectivity: The connectivity defining the neighborhood of a pixel. 152 | 153 | Returns: 154 | out: output array with instances removed under min_size 155 | 156 | """ 157 | out = pred 158 | 159 | if min_size == 0: # shortcut for efficiency 160 | return out 161 | 162 | if out.dtype == bool: 163 | selem = ndimage.generate_binary_structure(pred.ndim, connectivity) 164 | ccs = np.zeros_like(pred, dtype=np.int32) 165 | ndimage.label(pred, selem, output=ccs) 166 | else: 167 | ccs = out 168 | 169 | try: 170 | component_sizes = np.bincount(ccs.ravel()) 171 | except ValueError: 172 | raise ValueError( 173 | "Negative value labels are not supported. Try " 174 | "relabeling the input with `scipy.ndimage.label` or " 175 | "`skimage.morphology.label`." 176 | ) 177 | 178 | too_small = component_sizes < min_size 179 | too_small_mask = too_small[ccs] 180 | out[too_small_mask] = 0 181 | 182 | return out 183 | -------------------------------------------------------------------------------- /Hover/misc/viz_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import random 4 | import colorsys 5 | import numpy as np 6 | import itertools 7 | import matplotlib.pyplot as plt 8 | from matplotlib import cm 9 | 10 | from .utils import get_bounding_box 11 | 12 | #### 13 | def colorize(ch, vmin, vmax): 14 | """Will clamp value value outside the provided range to vmax and vmin.""" 15 | cmap = plt.get_cmap("jet") 16 | ch = np.squeeze(ch.astype("float32")) 17 | vmin = vmin if vmin is not None else ch.min() 18 | vmax = vmax if vmax is not None else ch.max() 19 | ch[ch > vmax] = vmax # clamp value 20 | ch[ch < vmin] = vmin 21 | ch = (ch - vmin) / (vmax - vmin + 1.0e-16) 22 | # take RGB from RGBA heat map 23 | ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") 24 | return ch_cmap 25 | 26 | 27 | #### 28 | def random_colors(N, bright=True): 29 | """Generate random colors. 30 | 31 | To get visually distinct colors, generate them in HSV space then 32 | convert to RGB. 33 | """ 34 | brightness = 1.0 if bright else 0.7 35 | hsv = [(i / N, 1, brightness) for i in range(N)] 36 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 37 | random.shuffle(colors) 38 | return colors 39 | 40 | 41 | #### 42 | def visualize_instances_map( 43 | input_image, inst_map, type_map=None, type_colour=None, line_thickness=2 44 | ): 45 | """Overlays segmentation results on image as contours. 
46 | 47 | Args: 48 | input_image: input image 49 | inst_map: instance mask with unique value for every object 50 | type_map: type mask with unique value for every class 51 | type_colour: a dict of {type : colour} , `type` is from 0-N 52 | and `colour` is a tuple of (R, G, B) 53 | line_thickness: line thickness of contours 54 | 55 | Returns: 56 | overlay: output image with segmentation overlay as contours 57 | """ 58 | overlay = np.copy((input_image).astype(np.uint8)) 59 | 60 | inst_list = list(np.unique(inst_map)) # get list of instances 61 | inst_list.remove(0) # remove background 62 | 63 | inst_rng_colors = random_colors(len(inst_list)) 64 | inst_rng_colors = np.array(inst_rng_colors) * 255 65 | inst_rng_colors = inst_rng_colors.astype(np.uint8) 66 | 67 | for inst_idx, inst_id in enumerate(inst_list): 68 | inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object 69 | y1, y2, x1, x2 = get_bounding_box(inst_map_mask) 70 | y1 = y1 - 2 if y1 - 2 >= 0 else y1 71 | x1 = x1 - 2 if x1 - 2 >= 0 else x1 72 | x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2 73 | y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2 74 | inst_map_crop = inst_map_mask[y1:y2, x1:x2] 75 | contours_crop = cv2.findContours( 76 | inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 77 | ) 78 | # only has 1 instance per map, no need to check #contour detected by opencv 79 | contours_crop = np.squeeze( 80 | contours_crop[0][0].astype("int32") 81 | ) # * opencv protocol format may break 82 | contours_crop += np.asarray([[x1, y1]]) # index correction 83 | if type_map is not None: 84 | type_map_crop = type_map[y1:y2, x1:x2] 85 | type_id = np.unique(type_map_crop).max() # non-zero 86 | inst_colour = type_colour[type_id] 87 | else: 88 | inst_colour = (inst_rng_colors[inst_idx]).tolist() 89 | cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness) 90 | return overlay 91 | 92 | 93 | #### 94 | def visualize_instances_dict( 95 | input_image, inst_dict, draw_dot=False, type_colour=None, line_thickness=2 96 | ): 97 | """Overlays segmentation results (dictionary) on image as contours. 
98 | 99 | Args: 100 | input_image: input image 101 | inst_dict: dict of output prediction, defined as in this library 102 | draw_dot: to draw a dot for each centroid 103 | type_colour: a dict of {type_id : (type_name, colour)} , 104 | `type_id` is from 0-N and `colour` is a tuple of (R, G, B) 105 | line_thickness: line thickness of contours 106 | """ 107 | overlay = np.copy((input_image)) 108 | 109 | inst_rng_colors = random_colors(len(inst_dict)) 110 | inst_rng_colors = np.array(inst_rng_colors) * 255 111 | inst_rng_colors = inst_rng_colors.astype(np.uint8) 112 | 113 | for idx, [inst_id, inst_info] in enumerate(inst_dict.items()): 114 | inst_contour = inst_info["contour"] 115 | if "type" in inst_info and type_colour is not None: 116 | inst_colour = type_colour[inst_info["type"]][1] 117 | else: 118 | inst_colour = (inst_rng_colors[idx]).tolist() 119 | cv2.drawContours(overlay, [inst_contour], -1, inst_colour, line_thickness) 120 | 121 | if draw_dot: 122 | inst_centroid = inst_info["centroid"] 123 | inst_centroid = tuple([int(v) for v in inst_centroid]) 124 | overlay = cv2.circle(overlay, inst_centroid, 3, (255, 0, 0), -1) 125 | return overlay 126 | 127 | 128 | #### 129 | def gen_figure( 130 | imgs_list, 131 | titles, 132 | fig_inch, 133 | shape=None, 134 | share_ax="all", 135 | show=False, 136 | colormap=plt.get_cmap("jet"), 137 | ): 138 | """Generate figure.""" 139 | num_img = len(imgs_list) 140 | if shape is None: 141 | ncols = math.ceil(math.sqrt(num_img)) 142 | nrows = math.ceil(num_img / ncols) 143 | else: 144 | nrows, ncols = shape 145 | 146 | # generate figure 147 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, sharex=share_ax, sharey=share_ax) 148 | fig.set_dpi(600) 149 | axes = [axes] if nrows == 1 else axes 150 | 151 | # not very elegant 152 | idx = 0 153 | for ax in axes: 154 | for cell in ax: 155 | cell.set_title(titles[idx]) 156 | cell.imshow(imgs_list[idx], cmap=colormap) 157 | cell.tick_params( 158 | axis="both", 159 | which="both", 160 | bottom="off", 161 | top="off", 162 | labelbottom="off", 163 | right="off", 164 | left="off", 165 | labelleft="off", 166 | ) 167 | idx += 1 168 | if idx == len(titles): 169 | break 170 | if idx == len(titles): 171 | break 172 | 173 | fig.tight_layout() 174 | return fig 175 | -------------------------------------------------------------------------------- /Hover/misc/wsi_handler.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import cv2 3 | import numpy as np 4 | from skimage import img_as_ubyte 5 | from skimage import color 6 | import xml.etree.ElementTree as et 7 | import re 8 | import subprocess 9 | 10 | import openslide 11 | 12 | 13 | class FileHandler(object): 14 | def __init__(self): 15 | """The handler is responsible for storing the processed data, parsing 16 | the metadata from original file, and reading it from storage. 17 | """ 18 | self.metadata = { 19 | ("available_mag", None), 20 | ("base_mag", None), 21 | ("vendor", None), 22 | ("mpp ", None), 23 | ("base_shape", None), 24 | } 25 | pass 26 | 27 | def __load_metadata(self): 28 | raise NotImplementedError 29 | 30 | def get_full_img(self, read_mag=None, read_mpp=None): 31 | """Only use `read_mag` or `read_mpp`, not both, prioritize `read_mpp`. 32 | 33 | `read_mpp` is in X, Y format 34 | """ 35 | raise NotImplementedError 36 | 37 | def read_region(self, coords, size): 38 | """Must call `prepare_reading` before hand. 
39 | 40 | Args: 41 | coords (tuple): (dims_x, dims_y), 42 | top left coordinates of image region at selected 43 | `read_mag` or `read_mpp` from `prepare_reading` 44 | size (tuple): (dims_x, dims_y) 45 | width and height of image region at selected 46 | `read_mag` or `read_mpp` from `prepare_reading` 47 | 48 | """ 49 | raise NotImplementedError 50 | 51 | def get_dimensions(self, read_mag=None, read_mpp=None): 52 | """Will be in X, Y.""" 53 | if read_mpp is not None: 54 | read_scale = (self.metadata["base_mpp"] / read_mpp)[0] 55 | read_mag = read_scale * self.metadata["base_mag"] 56 | scale = read_mag / self.metadata["base_mag"] 57 | # may off some pixels wrt existing mag 58 | return (self.metadata["base_shape"] * scale).astype(np.int32) 59 | 60 | def prepare_reading(self, read_mag=None, read_mpp=None, cache_path=None): 61 | """Only use `read_mag` or `read_mpp`, not both, prioritize `read_mpp`. 62 | 63 | `read_mpp` is in X, Y format. 64 | """ 65 | read_lv, scale_factor = self._get_read_info( 66 | read_mag=read_mag, read_mpp=read_mpp 67 | ) 68 | 69 | if scale_factor is None: 70 | self.image_ptr = None 71 | self.read_lv = read_lv 72 | else: 73 | np.save(cache_path, self.get_full_img(read_mag=read_mag)) 74 | self.image_ptr = np.load(cache_path, mmap_mode="r") 75 | return 76 | 77 | def _get_read_info(self, read_mag=None, read_mpp=None): 78 | if read_mpp is not None: 79 | assert read_mpp[0] == read_mpp[1], "Not supported uneven `read_mpp`" 80 | read_scale = (self.metadata["base_mpp"] / read_mpp)[0] 81 | read_mag = read_scale * self.metadata["base_mag"] 82 | 83 | hires_mag = read_mag 84 | scale_factor = None 85 | if read_mag not in self.metadata["available_mag"]: 86 | if read_mag > self.metadata["base_mag"]: 87 | scale_factor = read_mag / self.metadata["base_mag"] 88 | hires_mag = self.metadata["base_mag"] 89 | else: 90 | mag_list = np.array(self.metadata["available_mag"]) 91 | mag_list = np.sort(mag_list)[::-1] 92 | hires_mag = mag_list - read_mag 93 | # only use higher mag as base for loading 94 | hires_mag = hires_mag[hires_mag > 0] 95 | # use the immediate higher to save compuration 96 | hires_mag = mag_list[np.argmin(hires_mag)] 97 | scale_factor = read_mag / hires_mag 98 | 99 | hires_lv = self.metadata["available_mag"].index(hires_mag) 100 | return hires_lv, scale_factor 101 | 102 | 103 | class OpenSlideHandler(FileHandler): 104 | """Class for handling OpenSlide supported whole-slide images.""" 105 | 106 | def __init__(self, file_path, path_presplit = None): 107 | """file_path (string): path to single whole-slide image.""" 108 | super().__init__() 109 | 110 | self.file_ptr = openslide.OpenSlide(file_path) # load OpenSlide object 111 | self.path_presplit = path_presplit 112 | self.metadata = self.__load_metadata() 113 | 114 | # only used for cases where the read magnification is different from 115 | self.image_ptr = None # the existing modes of the read file 116 | self.read_level = None 117 | 118 | def __load_metadata(self): 119 | self.box = None 120 | if self.path_presplit: 121 | tree = et.parse(self.path_presplit) 122 | root = tree.getroot() 123 | box = np.array([[float(vertex.get('X')), float(vertex.get('Y'))] for vertex in root.find('Annotation').find('Regions').find('Region').find('Vertices').findall('Vertex')]) 124 | self.box = box = box.astype('int64') 125 | start_point = np.array([box[:, 0].min(), box[:, 1].min()]) 126 | end_point = np.array([box[:, 0].max(), box[:, 1].max()]) 127 | box_size = end_point - start_point 128 | metadata = {} 129 | 130 | wsi_properties = 
self.file_ptr.properties 131 | level_0_magnification = wsi_properties[openslide.PROPERTY_NAME_OBJECTIVE_POWER] 132 | level_0_magnification = float(level_0_magnification) 133 | 134 | downsample_level = self.file_ptr.level_downsamples 135 | magnification_level = [level_0_magnification / lv for lv in downsample_level] 136 | 137 | mpp = [ 138 | wsi_properties[openslide.PROPERTY_NAME_MPP_X], 139 | wsi_properties[openslide.PROPERTY_NAME_MPP_Y], 140 | ] 141 | mpp = np.array(mpp) 142 | 143 | metadata = [ 144 | ("available_mag", magnification_level), # highest to lowest mag 145 | ("base_mag", magnification_level[0]), 146 | ("vendor", wsi_properties[openslide.PROPERTY_NAME_VENDOR]), 147 | ("mpp ", mpp), 148 | ("base_shape", np.array(self.file_ptr.dimensions) if not self.path_presplit else np.array(box_size)), 149 | ("downsample_level", downsample_level) 150 | ] 151 | return OrderedDict(metadata) 152 | 153 | def read_region(self, coords, size): 154 | """Must call `prepare_reading` before hand. 155 | 156 | Args: 157 | coords (tuple): (dims_x, dims_y), 158 | top left coordinates of image region at selected 159 | `read_mag` or `read_mpp` from `prepare_reading` 160 | size (tuple): (dims_x, dims_y) 161 | width and height of image region at selected 162 | `read_mag` or `read_mpp` from `prepare_reading` 163 | 164 | """ 165 | if self.image_ptr is None: 166 | # convert coord from read lv to lv zero 167 | lv_0_shape = np.array(self.file_ptr.level_dimensions[0]) 168 | lv_r_shape = np.array(self.file_ptr.level_dimensions[self.read_lv]) 169 | up_sample = (lv_0_shape / lv_r_shape)[0] 170 | new_coord = [0, 0] 171 | new_coord[0] = int(coords[0] * up_sample) 172 | new_coord[1] = int(coords[1] * up_sample) 173 | if self.box: 174 | new_coord[0] = int(new_coord[0] + self.box[:, 0].min()) 175 | new_coord[1] = int(new_coord[1] + self.box[:, 1].min()) 176 | region = self.file_ptr.read_region(new_coord, self.read_lv, size) 177 | else: 178 | region = self.image_ptr[ 179 | coords[1] : coords[1] + size[1], coords[0] : coords[0] + size[0] 180 | ] 181 | return np.array(region)[..., :3] 182 | 183 | def get_full_img(self, read_mag=None, read_mpp=None): 184 | """Only use `read_mag` or `read_mpp`, not both, prioritize `read_mpp`. 185 | 186 | `read_mpp` is in X, Y format. 
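        Worked example (added for clarity): for a slide scanned at 40x with an mpp of
        0.25 um/px, requesting `read_mpp=(0.5, 0.5)` is converted by `_get_read_info`
        into read_mag = (0.25 / 0.5) * 40 = 20x. If 20x matches one of the available
        pyramid magnifications it is read directly; otherwise the closest higher
        magnification is read and resized by scale_factor = read_mag / hires_mag.
        (Note that `_get_read_info` looks up `self.metadata["base_mpp"]`, while this
        handler stores the value under the key "mpp ", so the mpp-based path appears
        to rely on a `base_mpp` entry being populated elsewhere.)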
187 | """ 188 | 189 | read_lv, scale_factor = self._get_read_info( 190 | read_mag=read_mag, read_mpp=read_mpp 191 | ) 192 | 193 | read_size = self.file_ptr.level_dimensions[read_lv] 194 | start_point = (0, 0) 195 | 196 | if self.path_presplit: 197 | box = self.box 198 | start_point = np.array([box[:, 0].min(), box[:, 1].min()]) 199 | end_point = np.array([box[:, 0].max(), box[:, 1].max()]) 200 | box_size = end_point - start_point 201 | read_size = (box_size / self.metadata['downsample_level'][read_lv]).astype('int64') 202 | if scale_factor: 203 | read_size = (read_size * scale_factor).astype('int64') 204 | temp = self.metadata['downsample_level'][read_lv] 205 | wsi_img = self.file_ptr.read_region(start_point, read_lv, read_size) 206 | wsi_img = np.array(wsi_img)[..., :3] # remove alpha channel 207 | if scale_factor is not None: 208 | # now rescale then return 209 | if scale_factor > 1.0: 210 | interp = cv2.INTER_CUBIC 211 | else: 212 | interp = cv2.INTER_LINEAR 213 | wsi_img = cv2.resize( 214 | wsi_img, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=interp 215 | ) 216 | return wsi_img 217 | 218 | 219 | def get_file_handler(path, backend, path_presplit = None): 220 | if backend in [ 221 | '.svs', '.tif', 222 | '.vms', '.vmu', '.ndpi', 223 | '.scn', '.mrxs', '.tiff', 224 | '.svslide', 225 | '.bif', 226 | ]: 227 | return OpenSlideHandler(path, path_presplit) 228 | else: 229 | assert False, "Unknown WSI format `%s`" % backend 230 | 231 | -------------------------------------------------------------------------------- /Hover/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuscc-deep-path/sc_MTOP/33ff31fbd01f37705118244da0a8df96a4f19014/Hover/models/__init__.py -------------------------------------------------------------------------------- /Hover/models/hovernet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuscc-deep-path/sc_MTOP/33ff31fbd01f37705118244da0a8df96a4f19014/Hover/models/hovernet/__init__.py -------------------------------------------------------------------------------- /Hover/models/hovernet/net_desc.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .net_utils import (DenseBlock, Net, ResidualBlock, TFSamepaddingLayer, 10 | UpSample2x) 11 | from .utils import crop_op, crop_to_shape 12 | 13 | #### 14 | class HoVerNet(Net): 15 | """Initialise HoVer-Net.""" 16 | 17 | def __init__(self, input_ch=3, nr_types=None, freeze=False, mode='original'): 18 | super().__init__() 19 | self.mode = mode 20 | self.freeze = freeze 21 | self.nr_types = nr_types 22 | self.output_ch = 3 if nr_types is None else 4 23 | 24 | assert mode == 'original' or mode == 'fast', \ 25 | 'Unknown mode `%s` for HoVerNet %s. Only support `original` or `fast`.' 
% mode 26 | 27 | module_list = [ 28 | ("/", nn.Conv2d(input_ch, 64, 7, stride=1, padding=0, bias=False)), 29 | ("bn", nn.BatchNorm2d(64, eps=1e-5)), 30 | ("relu", nn.ReLU(inplace=True)), 31 | ] 32 | if mode == 'fast': # prepend the padding for `fast` mode 33 | module_list = [("pad", TFSamepaddingLayer(ksize=7, stride=1))] + module_list 34 | 35 | self.conv0 = nn.Sequential(OrderedDict(module_list)) 36 | self.d0 = ResidualBlock(64, [1, 3, 1], [64, 64, 256], 3, stride=1) 37 | self.d1 = ResidualBlock(256, [1, 3, 1], [128, 128, 512], 4, stride=2) 38 | self.d2 = ResidualBlock(512, [1, 3, 1], [256, 256, 1024], 6, stride=2) 39 | self.d3 = ResidualBlock(1024, [1, 3, 1], [512, 512, 2048], 3, stride=2) 40 | 41 | self.conv_bot = nn.Conv2d(2048, 1024, 1, stride=1, padding=0, bias=False) 42 | 43 | def create_decoder_branch(out_ch=2, ksize=5): 44 | module_list = [ 45 | ("conva", nn.Conv2d(1024, 256, ksize, stride=1, padding=0, bias=False)), 46 | ("dense", DenseBlock(256, [1, ksize], [128, 32], 8, split=4)), 47 | ("convf", nn.Conv2d(512, 512, 1, stride=1, padding=0, bias=False),), 48 | ] 49 | u3 = nn.Sequential(OrderedDict(module_list)) 50 | 51 | module_list = [ 52 | ("conva", nn.Conv2d(512, 128, ksize, stride=1, padding=0, bias=False)), 53 | ("dense", DenseBlock(128, [1, ksize], [128, 32], 4, split=4)), 54 | ("convf", nn.Conv2d(256, 256, 1, stride=1, padding=0, bias=False),), 55 | ] 56 | u2 = nn.Sequential(OrderedDict(module_list)) 57 | 58 | module_list = [ 59 | ("conva/pad", TFSamepaddingLayer(ksize=ksize, stride=1)), 60 | ("conva", nn.Conv2d(256, 64, ksize, stride=1, padding=0, bias=False),), 61 | ] 62 | u1 = nn.Sequential(OrderedDict(module_list)) 63 | 64 | module_list = [ 65 | ("bn", nn.BatchNorm2d(64, eps=1e-5)), 66 | ("relu", nn.ReLU(inplace=True)), 67 | ("conv", nn.Conv2d(64, out_ch, 1, stride=1, padding=0, bias=True),), 68 | ] 69 | u0 = nn.Sequential(OrderedDict(module_list)) 70 | 71 | decoder = nn.Sequential( 72 | OrderedDict([("u3", u3), ("u2", u2), ("u1", u1), ("u0", u0),]) 73 | ) 74 | return decoder 75 | 76 | ksize = 5 if mode == 'original' else 3 77 | if nr_types is None: 78 | self.decoder = nn.ModuleDict( 79 | OrderedDict( 80 | [ 81 | ("np", create_decoder_branch(ksize=ksize,out_ch=2)), 82 | ("hv", create_decoder_branch(ksize=ksize,out_ch=2)), 83 | ] 84 | ) 85 | ) 86 | else: 87 | self.decoder = nn.ModuleDict( 88 | OrderedDict( 89 | [ 90 | ("tp", create_decoder_branch(ksize=ksize, out_ch=nr_types)), 91 | ("np", create_decoder_branch(ksize=ksize, out_ch=2)), 92 | ("hv", create_decoder_branch(ksize=ksize, out_ch=2)), 93 | ] 94 | ) 95 | ) 96 | 97 | self.upsample2x = UpSample2x() 98 | # TODO: pytorch still require the channel eventhough its ignored 99 | self.weights_init() 100 | 101 | def forward(self, imgs): 102 | 103 | imgs = imgs / 255.0 # to 0-1 range to match XY 104 | 105 | if self.training: 106 | d0 = self.conv0(imgs) 107 | d0 = self.d0(d0, self.freeze) 108 | with torch.set_grad_enabled(not self.freeze): 109 | d1 = self.d1(d0) 110 | d2 = self.d2(d1) 111 | d3 = self.d3(d2) 112 | d3 = self.conv_bot(d3) 113 | d = [d0, d1, d2, d3] 114 | else: 115 | d0 = self.conv0(imgs) 116 | d0 = self.d0(d0) 117 | d1 = self.d1(d0) 118 | d2 = self.d2(d1) 119 | d3 = self.d3(d2) 120 | d3 = self.conv_bot(d3) 121 | d = [d0, d1, d2, d3] 122 | 123 | # TODO: switch to `crop_to_shape` ? 
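        # Added note: the decoder branches applied below use valid (unpadded)
        # convolutions, so the upsampled decoder features end up smaller than the
        # encoder skips d[0] and d[1]. These fixed crops centre-align the skips with
        # the decoder path for the two patch geometries (270->80 for 'original',
        # 256->164 for 'fast'); the offsets appear to be precomputed constants rather
        # than derived at runtime, hence the TODO above about `crop_to_shape`.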
124 | if self.mode == 'original': 125 | d[0] = crop_op(d[0], [184, 184]) 126 | d[1] = crop_op(d[1], [72, 72]) 127 | else: 128 | d[0] = crop_op(d[0], [92, 92]) 129 | d[1] = crop_op(d[1], [36, 36]) 130 | 131 | out_dict = OrderedDict() 132 | for branch_name, branch_desc in self.decoder.items(): 133 | u3 = self.upsample2x(d[-1]) + d[-2] 134 | u3 = branch_desc[0](u3) 135 | 136 | u2 = self.upsample2x(u3) + d[-3] 137 | u2 = branch_desc[1](u2) 138 | 139 | u1 = self.upsample2x(u2) + d[-4] 140 | u1 = branch_desc[2](u1) 141 | 142 | u0 = branch_desc[3](u1) 143 | out_dict[branch_name] = u0 144 | 145 | return out_dict 146 | 147 | 148 | #### 149 | def create_model(mode=None, **kwargs): 150 | if mode not in ['original', 'fast']: 151 | assert "Unknown Model Mode %s" % mode 152 | return HoVerNet(mode=mode, **kwargs) 153 | 154 | -------------------------------------------------------------------------------- /Hover/models/hovernet/net_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from collections import OrderedDict 9 | 10 | from .utils import crop_op, crop_to_shape 11 | from config import Config 12 | 13 | 14 | #### 15 | class Net(nn.Module): 16 | """ A base class provides a common weight initialisation scheme.""" 17 | 18 | def weights_init(self): 19 | for m in self.modules(): 20 | classname = m.__class__.__name__ 21 | 22 | # ! Fixed the type checking 23 | if isinstance(m, nn.Conv2d): 24 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 25 | 26 | if "norm" in classname.lower(): 27 | nn.init.constant_(m.weight, 1) 28 | nn.init.constant_(m.bias, 0) 29 | 30 | if "linear" in classname.lower(): 31 | if m.bias is not None: 32 | nn.init.constant_(m.bias, 0) 33 | 34 | def forward(self, x): 35 | return x 36 | 37 | 38 | #### 39 | class TFSamepaddingLayer(nn.Module): 40 | """To align with tf `same` padding. 41 | 42 | Putting this before any conv layer that need padding 43 | Assuming kernel has Height == Width for simplicity 44 | """ 45 | 46 | def __init__(self, ksize, stride): 47 | super(TFSamepaddingLayer, self).__init__() 48 | self.ksize = ksize 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | if x.shape[2] % self.stride == 0: 53 | pad = max(self.ksize - self.stride, 0) 54 | else: 55 | pad = max(self.ksize - (x.shape[2] % self.stride), 0) 56 | 57 | if pad % 2 == 0: 58 | pad_val = pad // 2 59 | padding = (pad_val, pad_val, pad_val, pad_val) 60 | else: 61 | pad_val_start = pad // 2 62 | pad_val_end = pad - pad_val_start 63 | padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end) 64 | # print(x.shape, padding) 65 | x = F.pad(x, padding, "constant", 0) 66 | # print(x.shape) 67 | return x 68 | 69 | 70 | #### 71 | class DenseBlock(Net): 72 | """Dense Block as defined in: 73 | 74 | Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger. 75 | "Densely connected convolutional networks." In Proceedings of the IEEE conference 76 | on computer vision and pattern recognition, pp. 4700-4708. 2017. 77 | 78 | Only performs `valid` convolution. 79 | 80 | """ 81 | 82 | def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1): 83 | super(DenseBlock, self).__init__() 84 | assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" 85 | 86 | self.nr_unit = unit_count 87 | self.in_ch = in_ch 88 | self.unit_ch = unit_ch 89 | 90 | # ! 
For inference only so init values for batchnorm may not match tensorflow 91 | unit_in_ch = in_ch 92 | self.units = nn.ModuleList() 93 | for idx in range(unit_count): 94 | self.units.append( 95 | nn.Sequential( 96 | OrderedDict( 97 | [ 98 | ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 99 | ("preact_bna/relu", nn.ReLU(inplace=True)), 100 | ( 101 | "conv1", 102 | nn.Conv2d( 103 | unit_in_ch, 104 | unit_ch[0], 105 | unit_ksize[0], 106 | stride=1, 107 | padding=0, 108 | bias=False, 109 | ), 110 | ), 111 | ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), 112 | ("conv1/relu", nn.ReLU(inplace=True)), 113 | # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)), 114 | ( 115 | "conv2", 116 | nn.Conv2d( 117 | unit_ch[0], 118 | unit_ch[1], 119 | unit_ksize[1], 120 | groups=split, 121 | stride=1, 122 | padding=0, 123 | bias=False, 124 | ), 125 | ), 126 | ] 127 | ) 128 | ) 129 | ) 130 | unit_in_ch += unit_ch[1] 131 | 132 | self.blk_bna = nn.Sequential( 133 | OrderedDict( 134 | [ 135 | ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 136 | ("relu", nn.ReLU(inplace=True)), 137 | ] 138 | ) 139 | ) 140 | 141 | def out_ch(self): 142 | return self.in_ch + self.nr_unit * self.unit_ch[-1] 143 | 144 | def forward(self, prev_feat): 145 | for idx in range(self.nr_unit): 146 | new_feat = self.units[idx](prev_feat) 147 | prev_feat = crop_to_shape(prev_feat, new_feat) 148 | prev_feat = torch.cat([prev_feat, new_feat], dim=1) 149 | prev_feat = self.blk_bna(prev_feat) 150 | 151 | return prev_feat 152 | 153 | 154 | #### 155 | class ResidualBlock(Net): 156 | """Residual block as defined in: 157 | 158 | He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning 159 | for image recognition." In Proceedings of the IEEE conference on computer vision 160 | and pattern recognition, pp. 770-778. 2016. 161 | 162 | """ 163 | 164 | def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1): 165 | super(ResidualBlock, self).__init__() 166 | assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" 167 | 168 | self.nr_unit = unit_count 169 | self.in_ch = in_ch 170 | self.unit_ch = unit_ch 171 | 172 | # ! 
For inference only so init values for batchnorm may not match tensorflow 173 | unit_in_ch = in_ch 174 | self.units = nn.ModuleList() 175 | for idx in range(unit_count): 176 | unit_layer = [ 177 | ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 178 | ("preact/relu", nn.ReLU(inplace=True)), 179 | ( 180 | "conv1", 181 | nn.Conv2d( 182 | unit_in_ch, 183 | unit_ch[0], 184 | unit_ksize[0], 185 | stride=1, 186 | padding=0, 187 | bias=False, 188 | ), 189 | ), 190 | ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), 191 | ("conv1/relu", nn.ReLU(inplace=True)), 192 | ( 193 | "conv2/pad", 194 | TFSamepaddingLayer( 195 | ksize=unit_ksize[1], stride=stride if idx == 0 else 1 196 | ), 197 | ), 198 | ( 199 | "conv2", 200 | nn.Conv2d( 201 | unit_ch[0], 202 | unit_ch[1], 203 | unit_ksize[1], 204 | stride=stride if idx == 0 else 1, 205 | padding=0, 206 | bias=False, 207 | ), 208 | ), 209 | ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)), 210 | ("conv2/relu", nn.ReLU(inplace=True)), 211 | ( 212 | "conv3", 213 | nn.Conv2d( 214 | unit_ch[1], 215 | unit_ch[2], 216 | unit_ksize[2], 217 | stride=1, 218 | padding=0, 219 | bias=False, 220 | ), 221 | ), 222 | ] 223 | # * has bna to conclude each previous block so 224 | # * must not put preact for the first unit of this block 225 | unit_layer = unit_layer if idx != 0 else unit_layer[2:] 226 | self.units.append(nn.Sequential(OrderedDict(unit_layer))) 227 | unit_in_ch = unit_ch[-1] 228 | 229 | if in_ch != unit_ch[-1] or stride != 1: 230 | self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False) 231 | else: 232 | self.shortcut = None 233 | 234 | self.blk_bna = nn.Sequential( 235 | OrderedDict( 236 | [ 237 | ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 238 | ("relu", nn.ReLU(inplace=True)), 239 | ] 240 | ) 241 | ) 242 | 243 | # print(self.units[0]) 244 | # print(self.units[1]) 245 | # exit() 246 | 247 | def out_ch(self): 248 | return self.unit_ch[-1] 249 | 250 | def forward(self, prev_feat, freeze=False): 251 | if self.shortcut is None: 252 | shortcut = prev_feat 253 | else: 254 | shortcut = self.shortcut(prev_feat) 255 | 256 | for idx in range(0, len(self.units)): 257 | new_feat = prev_feat 258 | if self.training: 259 | with torch.set_grad_enabled(not freeze): 260 | new_feat = self.units[idx](new_feat) 261 | else: 262 | new_feat = self.units[idx](new_feat) 263 | prev_feat = new_feat + shortcut 264 | shortcut = prev_feat 265 | feat = self.blk_bna(prev_feat) 266 | return feat 267 | 268 | 269 | #### 270 | class UpSample2x(nn.Module): 271 | """Upsample input by a factor of 2. 272 | 273 | Assume input is of NCHW, port FixedUnpooling from TensorPack. 
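    Added note: the tensordot below expands every pixel into a 2x2 block, i.e. it is
    equivalent to nearest-neighbour upsampling by a factor of 2. A minimal sanity
    check (illustrative, assuming this module is importable):

        import torch
        import torch.nn.functional as F
        x = torch.arange(16.0).reshape(1, 1, 4, 4)
        assert torch.equal(UpSample2x()(x), F.interpolate(x, scale_factor=2, mode="nearest"))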
274 | """ 275 | 276 | def __init__(self): 277 | super(UpSample2x, self).__init__() 278 | # correct way to create constant within module 279 | self.register_buffer( 280 | "unpool_mat", torch.from_numpy(np.ones((2, 2), dtype="float32")) 281 | ) 282 | self.unpool_mat.unsqueeze(0) 283 | 284 | def forward(self, x): 285 | input_shape = list(x.shape) 286 | # unsqueeze is expand_dims equivalent 287 | # permute is transpose equivalent 288 | # view is reshape equivalent 289 | x = x.unsqueeze(-1) # bchwx1 290 | mat = self.unpool_mat.unsqueeze(0) # 1xshxsw 291 | ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw 292 | ret = ret.permute(0, 1, 2, 4, 3, 5) 293 | ret = ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2)) 294 | return ret 295 | 296 | -------------------------------------------------------------------------------- /Hover/models/hovernet/opt.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | from run_utils.callbacks.base import ( 4 | AccumulateRawOutput, 5 | PeriodicSaver, 6 | ProcessAccumulatedRawOutput, 7 | ScalarMovingAverage, 8 | ScheduleLr, 9 | TrackLr, 10 | VisualizeOutput, 11 | TriggerEngine, 12 | ) 13 | from run_utils.callbacks.logging import LoggingEpochOutput, LoggingGradient 14 | from run_utils.engine import Events 15 | 16 | from .targets import gen_targets, prep_sample 17 | from .net_desc import create_model 18 | from .run_desc import proc_valid_step_output, train_step, valid_step, viz_step_output 19 | 20 | 21 | # TODO: training config only ? 22 | # TODO: switch all to function name String for all option 23 | def get_config(nr_type, mode): 24 | return { 25 | # ------------------------------------------------------------------ 26 | # ! All phases have the same number of run engine 27 | # phases are run sequentially from index 0 to N 28 | "phase_list": [ 29 | { 30 | "run_info": { 31 | # may need more dynamic for each network 32 | "net": { 33 | "desc": lambda: create_model( 34 | input_ch=3, nr_types=nr_type, 35 | freeze=True, mode=mode 36 | ), 37 | "optimizer": [ 38 | optim.Adam, 39 | { # should match keyword for parameters within the optimizer 40 | "lr": 1.0e-4, # initial learning rate, 41 | "betas": (0.9, 0.999), 42 | }, 43 | ], 44 | # learning rate scheduler 45 | "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), 46 | "extra_info": { 47 | "loss": { 48 | "np": {"bce": 1, "dice": 1}, 49 | "hv": {"mse": 1, "msge": 1}, 50 | "tp": {"bce": 1, "dice": 1}, 51 | }, 52 | }, 53 | # path to load, -1 to auto load checkpoint from previous phase, 54 | # None to start from scratch 55 | "pretrained": "/data/jhan/PanNuke/ImageNet-ResNet50-Preact_pytorch.tar", 56 | # 'pretrained': None, 57 | }, 58 | }, 59 | "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, 60 | "batch_size": {"train": 16, "valid": 16,}, # engine name : value 61 | "nr_epochs": 50, 62 | }, 63 | { 64 | "run_info": { 65 | # may need more dynamic for each network 66 | "net": { 67 | "desc": lambda: create_model( 68 | input_ch=3, nr_types=nr_type, 69 | freeze=False, mode=mode 70 | ), 71 | "optimizer": [ 72 | optim.Adam, 73 | { # should match keyword for parameters within the optimizer 74 | "lr": 1.0e-4, # initial learning rate, 75 | "betas": (0.9, 0.999), 76 | }, 77 | ], 78 | # learning rate scheduler 79 | "lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25), 80 | "extra_info": { 81 | "loss": { 82 | "np": {"bce": 1, "dice": 1}, 83 | "hv": {"mse": 1, "msge": 1}, 84 | "tp": {"bce": 1, "dice": 1}, 85 | }, 86 | 
}, 87 | # path to load, -1 to auto load checkpoint from previous phase, 88 | # None to start from scratch 89 | "pretrained": -1, 90 | }, 91 | }, 92 | "target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})}, 93 | "batch_size": {"train": 4, "valid": 8,}, # batch size per gpu 94 | "nr_epochs": 50, 95 | }, 96 | ], 97 | # ------------------------------------------------------------------ 98 | # TODO: dynamically for dataset plugin selection and processing also? 99 | # all enclosed engine shares the same neural networks 100 | # as the on at the outer calling it 101 | "run_engine": { 102 | "train": { 103 | # TODO: align here, file path or what? what about CV? 104 | "dataset": "", # whats about compound dataset ? 105 | "nr_procs": 16, # number of threads for dataloader 106 | "run_step": train_step, # TODO: function name or function variable ? 107 | "reset_per_run": False, 108 | # callbacks are run according to the list order of the event 109 | "callbacks": { 110 | Events.STEP_COMPLETED: [ 111 | # LoggingGradient(), # TODO: very slow, may be due to back forth of tensor/numpy ? 112 | ScalarMovingAverage(), 113 | ], 114 | Events.EPOCH_COMPLETED: [ 115 | TrackLr(), 116 | PeriodicSaver(), 117 | VisualizeOutput(viz_step_output), 118 | LoggingEpochOutput(), 119 | TriggerEngine("valid"), 120 | ScheduleLr(), 121 | ], 122 | }, 123 | }, 124 | "valid": { 125 | "dataset": "", # whats about compound dataset ? 126 | "nr_procs": 8, # number of threads for dataloader 127 | "run_step": valid_step, 128 | "reset_per_run": True, # * to stop aggregating output etc. from last run 129 | # callbacks are run according to the list order of the event 130 | "callbacks": { 131 | Events.STEP_COMPLETED: [AccumulateRawOutput(),], 132 | Events.EPOCH_COMPLETED: [ 133 | # TODO: is there way to preload these ? 134 | ProcessAccumulatedRawOutput( 135 | lambda a: proc_valid_step_output(a, nr_types=nr_type) 136 | ), 137 | LoggingEpochOutput(), 138 | ], 139 | }, 140 | }, 141 | }, 142 | } 143 | -------------------------------------------------------------------------------- /Hover/models/hovernet/post_proc.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from scipy.ndimage import filters, measurements 5 | from scipy.ndimage.morphology import ( 6 | binary_dilation, 7 | binary_fill_holes, 8 | distance_transform_cdt, 9 | distance_transform_edt, 10 | ) 11 | 12 | from skimage.segmentation import watershed 13 | from misc.utils import get_bounding_box, remove_small_objects 14 | 15 | import warnings 16 | 17 | 18 | def noop(*args, **kargs): 19 | pass 20 | 21 | 22 | warnings.warn = noop 23 | 24 | 25 | #### 26 | def __proc_np_hv(pred): 27 | """Process Nuclei Prediction with XY Coordinate Map. 28 | 29 | Args: 30 | pred: prediction output, assuming 31 | channel 0 contain probability map of nuclei 32 | channel 1 containing the regressed X-map 33 | channel 2 containing the regressed Y-map 34 | 35 | """ 36 | pred = np.array(pred, dtype=np.float32) 37 | 38 | blb_raw = pred[..., 0] 39 | h_dir_raw = pred[..., 1] 40 | v_dir_raw = pred[..., 2] 41 | 42 | # processing 43 | blb = np.array(blb_raw >= 0.5, dtype=np.int32) 44 | 45 | blb = measurements.label(blb)[0] 46 | 47 | # blb = remove_small_objects(blb, min_size=10) 48 | # ! 
2021-11-22 49 | blb = remove_small_objects(blb, min_size=4) 50 | blb[blb > 0] = 1 # background is 0 already 51 | 52 | h_dir = cv2.normalize( 53 | h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F 54 | ) 55 | v_dir = cv2.normalize( 56 | v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F 57 | ) 58 | 59 | # sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=21) 60 | # sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=21) 61 | # ! 2021-11-22 62 | sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=11) 63 | sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=11) 64 | 65 | sobelh = 1 - ( 66 | cv2.normalize( 67 | sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F 68 | ) 69 | ) 70 | sobelv = 1 - ( 71 | cv2.normalize( 72 | sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F 73 | ) 74 | ) 75 | 76 | overall = np.maximum(sobelh, sobelv) 77 | overall = overall - (1 - blb) 78 | overall[overall < 0] = 0 79 | 80 | dist = (1.0 - overall) * blb 81 | ## nuclei values form mountains so inverse to get basins 82 | dist = -cv2.GaussianBlur(dist, (3, 3), 0) 83 | 84 | overall = np.array(overall >= 0.4, dtype=np.int32) 85 | 86 | marker = blb - overall 87 | marker[marker < 0] = 0 88 | marker = binary_fill_holes(marker).astype("uint8") 89 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) 90 | marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) 91 | marker = measurements.label(marker)[0] 92 | # marker = remove_small_objects(marker, min_size=10) 93 | # ! 2021-11-22 94 | marker = remove_small_objects(marker, min_size=4) 95 | 96 | proced_pred = watershed(dist, markers=marker, mask=blb) 97 | 98 | return proced_pred 99 | 100 | 101 | #### 102 | def process(pred_map, nr_types=None, return_centroids=False): 103 | """Post processing script for image tiles. 
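    Added note: when `nr_types` is given, channel 0 of `pred_map` is read as the type
    map and the remaining channels as the nuclei probability plus the two HV maps,
    matching the tp/np/hv concatenation order produced by `infer_step`. A typical
    call would be, e.g.,
    `pred_inst, inst_info_dict = process(pred_map, nr_types=6, return_centroids=True)`.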
104 | 105 | Args: 106 | pred_map: commbined output of tp, np and hv branches, in the same order 107 | nr_types: number of types considered at output of nc branch 108 | overlaid_img: img to overlay the predicted instances upon, `None` means no 109 | type_colour (dict) : `None` to use random, else overlay instances of a type to colour in the dict 110 | output_dtype: data type of output 111 | 112 | Returns: 113 | pred_inst: pixel-wise nuclear instance segmentation prediction 114 | pred_type_out: pixel-wise nuclear type prediction 115 | 116 | """ 117 | if nr_types is not None: 118 | pred_type = pred_map[..., :1] 119 | pred_inst = pred_map[..., 1:] 120 | pred_type = pred_type.astype(np.int32) 121 | else: 122 | pred_inst = pred_map 123 | 124 | pred_inst = np.squeeze(pred_inst) 125 | pred_inst = __proc_np_hv(pred_inst) 126 | 127 | inst_info_dict = None 128 | if return_centroids or nr_types is not None: 129 | inst_id_list = np.unique(pred_inst)[1:] # exlcude background 130 | inst_info_dict = {} 131 | for inst_id in inst_id_list: 132 | inst_map = pred_inst == inst_id 133 | # TODO: chane format of bbox output 134 | rmin, rmax, cmin, cmax = get_bounding_box(inst_map) 135 | inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) 136 | inst_map = inst_map[ 137 | inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] 138 | ] 139 | inst_map = inst_map.astype(np.uint8) 140 | inst_moment = cv2.moments(inst_map) 141 | inst_contour = cv2.findContours( 142 | inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 143 | ) 144 | # * opencv protocol format may break 145 | inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) 146 | # < 3 points dont make a contour, so skip, likely artifact too 147 | # as the contours obtained via approximation => too small or sthg 148 | if inst_contour.shape[0] < 3: 149 | continue 150 | if len(inst_contour.shape) != 2: 151 | continue # ! check for trickery shape 152 | inst_centroid = [ 153 | (inst_moment["m10"] / inst_moment["m00"]), 154 | (inst_moment["m01"] / inst_moment["m00"]), 155 | ] 156 | inst_centroid = np.array(inst_centroid) 157 | inst_contour[:, 0] += inst_bbox[0][1] # X 158 | inst_contour[:, 1] += inst_bbox[0][0] # Y 159 | inst_centroid[0] += inst_bbox[0][1] # X 160 | inst_centroid[1] += inst_bbox[0][0] # Y 161 | inst_info_dict[inst_id] = { # inst_id should start at 1 162 | "bbox": inst_bbox, 163 | "centroid": inst_centroid, 164 | "contour": inst_contour, 165 | "type_prob": None, 166 | "type": None, 167 | } 168 | 169 | if nr_types is not None: 170 | #### * Get class of each instance id, stored at index id-1 171 | for inst_id in list(inst_info_dict.keys()): 172 | rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() 173 | inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] 174 | inst_type_crop = pred_type[rmin:rmax, cmin:cmax] 175 | inst_map_crop = ( 176 | inst_map_crop == inst_id 177 | ) # TODO: duplicated operation, may be expensive 178 | inst_type = inst_type_crop[inst_map_crop] 179 | type_list, type_pixels = np.unique(inst_type, return_counts=True) 180 | type_list = list(zip(type_list, type_pixels)) 181 | type_list = sorted(type_list, key=lambda x: x[1], reverse=True) 182 | inst_type = type_list[0][0] 183 | if inst_type == 0: # ! 
pick the 2nd most dominant if exist 184 | if len(type_list) > 1: 185 | inst_type = type_list[1][0] 186 | type_dict = {v[0]: v[1] for v in type_list} 187 | type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) 188 | inst_info_dict[inst_id]["type"] = int(inst_type) 189 | inst_info_dict[inst_id]["type_prob"] = float(type_prob) 190 | 191 | # print('here') 192 | # ! WARNING: ID MAY NOT BE CONTIGUOUS 193 | # inst_id in the dict maps to the same value in the `pred_inst` 194 | return pred_inst, inst_info_dict 195 | -------------------------------------------------------------------------------- /Hover/models/hovernet/run_desc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from misc.utils import center_pad_to_shape, cropping_center 7 | from .utils import crop_to_shape, dice_loss, mse_loss, msge_loss, xentropy_loss 8 | 9 | from collections import OrderedDict 10 | 11 | #### 12 | def train_step(batch_data, run_info): 13 | # TODO: synchronize the attach protocol 14 | run_info, state_info = run_info 15 | loss_func_dict = { 16 | "bce": xentropy_loss, 17 | "dice": dice_loss, 18 | "mse": mse_loss, 19 | "msge": msge_loss, 20 | } 21 | # use 'ema' to add for EMA calculation, must be scalar! 22 | result_dict = {"EMA": {}} 23 | track_value = lambda name, value: result_dict["EMA"].update({name: value}) 24 | 25 | #### 26 | model = run_info["net"]["desc"] 27 | optimizer = run_info["net"]["optimizer"] 28 | 29 | #### 30 | imgs = batch_data["img"] 31 | true_np = batch_data["np_map"] 32 | true_hv = batch_data["hv_map"] 33 | 34 | imgs = imgs.to("cuda").type(torch.float32) # to NCHW 35 | imgs = imgs.permute(0, 3, 1, 2).contiguous() 36 | 37 | # HWC 38 | true_np = true_np.to("cuda").type(torch.int64) 39 | true_hv = true_hv.to("cuda").type(torch.float32) 40 | 41 | true_np_onehot = (F.one_hot(true_np, num_classes=2)).type(torch.float32) 42 | true_dict = { 43 | "np": true_np_onehot, 44 | "hv": true_hv, 45 | } 46 | 47 | if model.module.nr_types is not None: 48 | true_tp = batch_data["tp_map"] 49 | true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64) 50 | true_tp_onehot = F.one_hot(true_tp, num_classes=model.module.nr_types) 51 | true_tp_onehot = true_tp_onehot.type(torch.float32) 52 | true_dict["tp"] = true_tp_onehot 53 | 54 | #### 55 | model.train() 56 | model.zero_grad() # not rnn so not accumulate 57 | 58 | pred_dict = model(imgs) 59 | pred_dict = OrderedDict( 60 | [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] 61 | ) 62 | pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1) 63 | if model.module.nr_types is not None: 64 | pred_dict["tp"] = F.softmax(pred_dict["tp"], dim=-1) 65 | 66 | #### 67 | loss = 0 68 | loss_opts = run_info["net"]["extra_info"]["loss"] 69 | for branch_name in pred_dict.keys(): 70 | for loss_name, loss_weight in loss_opts[branch_name].items(): 71 | loss_func = loss_func_dict[loss_name] 72 | loss_args = [true_dict[branch_name], pred_dict[branch_name]] 73 | if loss_name == "msge": 74 | loss_args.append(true_np_onehot[..., 1]) 75 | term_loss = loss_func(*loss_args) 76 | track_value("loss_%s_%s" % (branch_name, loss_name), term_loss.cpu().item()) 77 | loss += loss_weight * term_loss 78 | 79 | track_value("overall_loss", loss.cpu().item()) 80 | # * gradient update 81 | 82 | # torch.set_printoptions(precision=10) 83 | loss.backward() 84 | optimizer.step() 85 | #### 86 | 87 | # pick 2 random sample from the 
batch for visualization 88 | sample_indices = torch.randint(0, true_np.shape[0], (2,)) 89 | 90 | imgs = (imgs[sample_indices]).byte() # to uint8 91 | imgs = imgs.permute(0, 2, 3, 1).contiguous().cpu().numpy() 92 | 93 | pred_dict["np"] = pred_dict["np"][..., 1] # return pos only 94 | pred_dict = { 95 | k: v[sample_indices].detach().cpu().numpy() for k, v in pred_dict.items() 96 | } 97 | 98 | true_dict["np"] = true_np 99 | true_dict = { 100 | k: v[sample_indices].detach().cpu().numpy() for k, v in true_dict.items() 101 | } 102 | 103 | # * Its up to user to define the protocol to process the raw output per step! 104 | result_dict["raw"] = { # protocol for contents exchange within `raw` 105 | "img": imgs, 106 | "np": (true_dict["np"], pred_dict["np"]), 107 | "hv": (true_dict["hv"], pred_dict["hv"]), 108 | } 109 | return result_dict 110 | 111 | 112 | #### 113 | def valid_step(batch_data, run_info): 114 | run_info, state_info = run_info 115 | #### 116 | model = run_info["net"]["desc"] 117 | model.eval() # infer mode 118 | 119 | #### 120 | imgs = batch_data["img"] 121 | true_np = batch_data["np_map"] 122 | true_hv = batch_data["hv_map"] 123 | 124 | imgs_gpu = imgs.to("cuda").type(torch.float32) # to NCHW 125 | imgs_gpu = imgs_gpu.permute(0, 3, 1, 2).contiguous() 126 | 127 | # HWC 128 | true_np = torch.squeeze(true_np).to("cuda").type(torch.int64) 129 | true_hv = torch.squeeze(true_hv).to("cuda").type(torch.float32) 130 | 131 | true_dict = { 132 | "np": true_np, 133 | "hv": true_hv, 134 | } 135 | 136 | if model.module.nr_types is not None: 137 | true_tp = batch_data["tp_map"] 138 | true_tp = torch.squeeze(true_tp).to("cuda").type(torch.int64) 139 | true_dict["tp"] = true_tp 140 | 141 | # -------------------------------------------------------------- 142 | with torch.no_grad(): # dont compute gradient 143 | pred_dict = model(imgs_gpu) 144 | pred_dict = OrderedDict( 145 | [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] 146 | ) 147 | pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1] 148 | if model.module.nr_types is not None: 149 | type_map = F.softmax(pred_dict["tp"], dim=-1) 150 | type_map = torch.argmax(type_map, dim=-1, keepdim=False) 151 | type_map = type_map.type(torch.float32) 152 | pred_dict["tp"] = type_map 153 | 154 | # * Its up to user to define the protocol to process the raw output per step! 
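    # Added note: `proc_valid_step_output` downstream thresholds `prob_np` at 0.5
    # against `true_np` and takes a squared error between `pred_hv` and `true_hv`,
    # so the arrays stored below are expected to be numpy with shapes (N, H, W) for
    # the np entries and (N, H, W, 2) for the hv entries.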
155 | result_dict = { # protocol for contents exchange within `raw` 156 | "raw": { 157 | "imgs": imgs.numpy(), 158 | "true_np": true_dict["np"].cpu().numpy(), 159 | "true_hv": true_dict["hv"].cpu().numpy(), 160 | "prob_np": pred_dict["np"].cpu().numpy(), 161 | "pred_hv": pred_dict["hv"].cpu().numpy(), 162 | } 163 | } 164 | if model.module.nr_types is not None: 165 | result_dict["raw"]["true_tp"] = true_dict["tp"].cpu().numpy() 166 | result_dict["raw"]["pred_tp"] = pred_dict["tp"].cpu().numpy() 167 | return result_dict 168 | 169 | 170 | #### 171 | def infer_step(batch_data, model): 172 | 173 | #### 174 | patch_imgs = batch_data 175 | 176 | patch_imgs_gpu = patch_imgs.to("cuda").type(torch.float32) # to NCHW 177 | patch_imgs_gpu = patch_imgs_gpu.permute(0, 3, 1, 2).contiguous() 178 | 179 | #### 180 | model.eval() # infer mode 181 | 182 | # -------------------------------------------------------------- 183 | with torch.no_grad(): # dont compute gradient 184 | pred_dict = model(patch_imgs_gpu) 185 | pred_dict = OrderedDict( 186 | [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] 187 | ) 188 | pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:] 189 | if "tp" in pred_dict: 190 | type_map = F.softmax(pred_dict["tp"], dim=-1) 191 | type_map = torch.argmax(type_map, dim=-1, keepdim=True) 192 | type_map = type_map.type(torch.float32) 193 | pred_dict["tp"] = type_map 194 | pred_output = torch.cat(list(pred_dict.values()), -1) 195 | 196 | # * Its up to user to define the protocol to process the raw output per step! 197 | return pred_output.cpu().numpy() 198 | 199 | 200 | #### 201 | def viz_step_output(raw_data, nr_types=None): 202 | """ 203 | `raw_data` will be implicitly provided in the similar format as the 204 | return dict from train/valid step, but may have been accumulated across N running step 205 | """ 206 | 207 | imgs = raw_data["img"] 208 | true_np, pred_np = raw_data["np"] 209 | true_hv, pred_hv = raw_data["hv"] 210 | if nr_types is not None: 211 | true_tp, pred_tp = raw_data["tp"] 212 | 213 | aligned_shape = [list(imgs.shape), list(true_np.shape), list(pred_np.shape)] 214 | aligned_shape = np.min(np.array(aligned_shape), axis=0)[1:3] 215 | 216 | cmap = plt.get_cmap("jet") 217 | 218 | def colorize(ch, vmin, vmax): 219 | """ 220 | Will clamp value value outside the provided range to vmax and vmin 221 | """ 222 | ch = np.squeeze(ch.astype("float32")) 223 | ch[ch > vmax] = vmax # clamp value 224 | ch[ch < vmin] = vmin 225 | ch = (ch - vmin) / (vmax - vmin + 1.0e-16) 226 | # take RGB from RGBA heat map 227 | ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") 228 | # ch_cmap = center_pad_to_shape(ch_cmap, aligned_shape) 229 | return ch_cmap 230 | 231 | viz_list = [] 232 | for idx in range(imgs.shape[0]): 233 | # img = center_pad_to_shape(imgs[idx], aligned_shape) 234 | img = cropping_center(imgs[idx], aligned_shape) 235 | 236 | true_viz_list = [img] 237 | # cmap may randomly fails if of other types 238 | true_viz_list.append(colorize(true_np[idx], 0, 1)) 239 | true_viz_list.append(colorize(true_hv[idx][..., 0], -1, 1)) 240 | true_viz_list.append(colorize(true_hv[idx][..., 1], -1, 1)) 241 | if nr_types is not None: # TODO: a way to pass through external info 242 | true_viz_list.append(colorize(true_tp[idx], 0, nr_types)) 243 | true_viz_list = np.concatenate(true_viz_list, axis=1) 244 | 245 | pred_viz_list = [img] 246 | # cmap may randomly fails if of other types 247 | pred_viz_list.append(colorize(pred_np[idx], 0, 1)) 248 | 
pred_viz_list.append(colorize(pred_hv[idx][..., 0], -1, 1)) 249 | pred_viz_list.append(colorize(pred_hv[idx][..., 1], -1, 1)) 250 | if nr_types is not None: 251 | pred_viz_list.append(colorize(pred_tp[idx], 0, nr_types)) 252 | pred_viz_list = np.concatenate(pred_viz_list, axis=1) 253 | 254 | viz_list.append(np.concatenate([true_viz_list, pred_viz_list], axis=0)) 255 | viz_list = np.concatenate(viz_list, axis=0) 256 | return viz_list 257 | 258 | 259 | #### 260 | from itertools import chain 261 | 262 | 263 | def proc_valid_step_output(raw_data, nr_types=None): 264 | # TODO: add auto populate from main state track list 265 | track_dict = {"scalar": {}, "image": {}} 266 | 267 | def track_value(name, value, vtype): 268 | return track_dict[vtype].update({name: value}) 269 | 270 | def _dice_info(true, pred, label): 271 | true = np.array(true == label, np.int32) 272 | pred = np.array(pred == label, np.int32) 273 | inter = (pred * true).sum() 274 | total = (pred + true).sum() 275 | return inter, total 276 | 277 | over_inter = 0 278 | over_total = 0 279 | over_correct = 0 280 | prob_np = raw_data["prob_np"] 281 | true_np = raw_data["true_np"] 282 | for idx in range(len(raw_data["true_np"])): 283 | patch_prob_np = prob_np[idx] 284 | patch_true_np = true_np[idx] 285 | patch_pred_np = np.array(patch_prob_np > 0.5, dtype=np.int32) 286 | inter, total = _dice_info(patch_true_np, patch_pred_np, 1) 287 | correct = (patch_pred_np == patch_true_np).sum() 288 | over_inter += inter 289 | over_total += total 290 | over_correct += correct 291 | nr_pixels = len(true_np) * np.size(true_np[0]) 292 | acc_np = over_correct / nr_pixels 293 | dice_np = 2 * over_inter / (over_total + 1.0e-8) 294 | track_value("np_acc", acc_np, "scalar") 295 | track_value("np_dice", dice_np, "scalar") 296 | 297 | # * TP statistic 298 | if nr_types is not None: 299 | pred_tp = raw_data["pred_tp"] 300 | true_tp = raw_data["true_tp"] 301 | for type_id in range(0, nr_types): 302 | over_inter = 0 303 | over_total = 0 304 | for idx in range(len(raw_data["true_np"])): 305 | patch_pred_tp = pred_tp[idx] 306 | patch_true_tp = true_tp[idx] 307 | inter, total = _dice_info(patch_true_tp, patch_pred_tp, type_id) 308 | over_inter += inter 309 | over_total += total 310 | dice_tp = 2 * over_inter / (over_total + 1.0e-8) 311 | track_value("tp_dice_%d" % type_id, dice_tp, "scalar") 312 | 313 | # * HV regression statistic 314 | pred_hv = raw_data["pred_hv"] 315 | true_hv = raw_data["true_hv"] 316 | 317 | over_squared_error = 0 318 | for idx in range(len(raw_data["true_np"])): 319 | patch_pred_hv = pred_hv[idx] 320 | patch_true_hv = true_hv[idx] 321 | squared_error = patch_pred_hv - patch_true_hv 322 | squared_error = squared_error * squared_error 323 | over_squared_error += squared_error.sum() 324 | mse = over_squared_error / nr_pixels 325 | track_value("hv_mse", mse, "scalar") 326 | 327 | # * 328 | imgs = raw_data["imgs"] 329 | selected_idx = np.random.randint(0, len(imgs), size=(8,)).tolist() 330 | imgs = np.array([imgs[idx] for idx in selected_idx]) 331 | true_np = np.array([true_np[idx] for idx in selected_idx]) 332 | true_hv = np.array([true_hv[idx] for idx in selected_idx]) 333 | prob_np = np.array([prob_np[idx] for idx in selected_idx]) 334 | pred_hv = np.array([pred_hv[idx] for idx in selected_idx]) 335 | viz_raw_data = {"img": imgs, "np": (true_np, prob_np), "hv": (true_hv, pred_hv)} 336 | 337 | if nr_types is not None: 338 | true_tp = np.array([true_tp[idx] for idx in selected_idx]) 339 | pred_tp = np.array([pred_tp[idx] for idx in selected_idx]) 
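    # Added note: the eight randomly selected samples gathered above are rendered by
    # `viz_step_output` into a grid (ground-truth row above the prediction row for
    # each sample) and logged under the "output" image key.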
340 | viz_raw_data["tp"] = (true_tp, pred_tp) 341 | viz_fig = viz_step_output(viz_raw_data, nr_types) 342 | track_dict["image"]["output"] = viz_fig 343 | 344 | return track_dict 345 | -------------------------------------------------------------------------------- /Hover/models/hovernet/targets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from scipy import ndimage 8 | from scipy.ndimage import measurements 9 | from skimage import morphology as morph 10 | import matplotlib.pyplot as plt 11 | 12 | from misc.utils import center_pad_to_shape, cropping_center, get_bounding_box 13 | from dataloader.augs import fix_mirror_padding 14 | 15 | 16 | #### 17 | def gen_instance_hv_map(ann, crop_shape): 18 | """Input annotation must be of original shape. 19 | 20 | The map is calculated only for instances within the crop portion 21 | but based on the original shape in original image. 22 | 23 | Perform following operation: 24 | Obtain the horizontal and vertical distance maps for each 25 | nuclear instance. 26 | 27 | """ 28 | orig_ann = ann.copy() # instance ID map 29 | fixed_ann = fix_mirror_padding(orig_ann) 30 | # re-cropping with fixed instance id map 31 | crop_ann = cropping_center(fixed_ann, crop_shape) 32 | # TODO: deal with 1 label warning 33 | # crop_ann = morph.remove_small_objects(crop_ann, min_size=30) 34 | # ! 2021-11-22 35 | crop_ann = morph.remove_small_objects(crop_ann, min_size=9) 36 | 37 | x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) 38 | y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32) 39 | 40 | inst_list = list(np.unique(crop_ann)) 41 | inst_list.remove(0) # 0 is background 42 | for inst_id in inst_list: 43 | inst_map = np.array(fixed_ann == inst_id, np.uint8) 44 | inst_box = get_bounding_box(inst_map) 45 | 46 | # expand the box by 2px 47 | # Because we first pad the ann at line 207, the bboxes 48 | # will remain valid after expansion 49 | inst_box[0] -= 2 50 | inst_box[2] -= 2 51 | inst_box[1] += 2 52 | inst_box[3] += 2 53 | 54 | inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] 55 | 56 | if inst_map.shape[0] < 2 or inst_map.shape[1] < 2: 57 | continue 58 | 59 | # instance center of mass, rounded to nearest pixel 60 | inst_com = list(measurements.center_of_mass(inst_map)) 61 | 62 | inst_com[0] = int(inst_com[0] + 0.5) 63 | inst_com[1] = int(inst_com[1] + 0.5) 64 | 65 | inst_x_range = np.arange(1, inst_map.shape[1] + 1) 66 | inst_y_range = np.arange(1, inst_map.shape[0] + 1) 67 | # shifting center of pixels grid to instance center of mass 68 | inst_x_range -= inst_com[1] 69 | inst_y_range -= inst_com[0] 70 | 71 | inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range) 72 | 73 | # remove coord outside of instance 74 | inst_x[inst_map == 0] = 0 75 | inst_y[inst_map == 0] = 0 76 | inst_x = inst_x.astype("float32") 77 | inst_y = inst_y.astype("float32") 78 | 79 | # normalize min into -1 scale 80 | if np.min(inst_x) < 0: 81 | inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0]) 82 | if np.min(inst_y) < 0: 83 | inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0]) 84 | # normalize max into +1 scale 85 | if np.max(inst_x) > 0: 86 | inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0]) 87 | if np.max(inst_y) > 0: 88 | inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0]) 89 | 90 | #### 91 | x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] 92 | x_map_box[inst_map > 0] = inst_x[inst_map > 0] 93 | 94 | 
y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]] 95 | y_map_box[inst_map > 0] = inst_y[inst_map > 0] 96 | 97 | hv_map = np.dstack([x_map, y_map]) 98 | return hv_map 99 | 100 | 101 | #### 102 | def gen_targets(ann, crop_shape, **kwargs): 103 | """Generate the targets for the network.""" 104 | hv_map = gen_instance_hv_map(ann, crop_shape) 105 | np_map = ann.copy() 106 | np_map[np_map > 0] = 1 107 | 108 | hv_map = cropping_center(hv_map, crop_shape) 109 | np_map = cropping_center(np_map, crop_shape) 110 | 111 | target_dict = { 112 | "hv_map": hv_map, 113 | "np_map": np_map, 114 | } 115 | 116 | return target_dict 117 | 118 | 119 | #### 120 | def prep_sample(data, is_batch=False, **kwargs): 121 | """ 122 | Designed to process direct output from loader 123 | """ 124 | cmap = plt.get_cmap("jet") 125 | 126 | def colorize(ch, vmin, vmax, shape): 127 | ch = np.squeeze(ch.astype("float32")) 128 | ch = ch / (vmax - vmin + 1.0e-16) 129 | # take RGB from RGBA heat map 130 | ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8") 131 | ch_cmap = center_pad_to_shape(ch_cmap, shape) 132 | return ch_cmap 133 | 134 | def prep_one_sample(data): 135 | shape_array = [np.array(v.shape[:2]) for v in data.values()] 136 | shape = np.maximum(*shape_array) 137 | viz_list = [] 138 | viz_list.append(colorize(data["np_map"], 0, 1, shape)) 139 | # map to [0,2] for better visualisation. 140 | # Note, [-1,1] is used for training. 141 | viz_list.append(colorize(data["hv_map"][..., 0] + 1, 0, 2, shape)) 142 | viz_list.append(colorize(data["hv_map"][..., 1] + 1, 0, 2, shape)) 143 | img = center_pad_to_shape(data["img"], shape) 144 | return np.concatenate([img] + viz_list, axis=1) 145 | 146 | # cmap may randomly fails if of other types 147 | if is_batch: 148 | viz_list = [] 149 | data_shape = list(data.values())[0].shape 150 | for batch_idx in range(data_shape[0]): 151 | sub_data = {k : v[batch_idx] for k, v in data.items()} 152 | viz_list.append(prep_one_sample(sub_data)) 153 | return np.concatenate(viz_list, axis=0) 154 | else: 155 | return prep_one_sample(data) 156 | -------------------------------------------------------------------------------- /Hover/models/hovernet/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from matplotlib import cm 8 | 9 | 10 | #### 11 | def crop_op(x, cropping, data_format="NCHW"): 12 | """Center crop image. 13 | 14 | Args: 15 | x: input image 16 | cropping: the substracted amount 17 | data_format: choose either `NCHW` or `NHWC` 18 | 19 | """ 20 | crop_t = cropping[0] // 2 21 | crop_b = cropping[0] - crop_t 22 | crop_l = cropping[1] // 2 23 | crop_r = cropping[1] - crop_l 24 | if data_format == "NCHW": 25 | x = x[:, :, crop_t:-crop_b, crop_l:-crop_r] 26 | else: 27 | x = x[:, crop_t:-crop_b, crop_l:-crop_r, :] 28 | return x 29 | 30 | 31 | #### 32 | def crop_to_shape(x, y, data_format="NCHW"): 33 | """Centre crop x so that x has shape of y. y dims must be smaller than x dims. 34 | 35 | Args: 36 | x: input array 37 | y: array with desired shape. 38 | 39 | """ 40 | assert ( 41 | y.shape[0] <= x.shape[0] and y.shape[1] <= x.shape[1] 42 | ), "Ensure that y dimensions are smaller than x dimensions!" 
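    # Added note: the crop amount per spatial axis is simply the size difference
    # between x and y; `crop_op` then removes floor(crop/2) pixels from the top/left
    # and the remainder from the bottom/right, i.e. a centre crop. For example,
    # cropping a 270x270 map down to 80x80 removes 95 pixels on every side.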
43 | 44 | x_shape = x.size() 45 | y_shape = y.size() 46 | if data_format == "NCHW": 47 | crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3]) 48 | else: 49 | crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2]) 50 | return crop_op(x, crop_shape, data_format) 51 | 52 | 53 | #### 54 | def xentropy_loss(true, pred, reduction="mean"): 55 | """Cross entropy loss. Assumes NHWC! 56 | 57 | Args: 58 | pred: prediction array 59 | true: ground truth array 60 | 61 | Returns: 62 | cross entropy loss 63 | 64 | """ 65 | epsilon = 10e-8 66 | # scale preds so that the class probs of each sample sum to 1 67 | pred = pred / torch.sum(pred, -1, keepdim=True) 68 | # manual computation of crossentropy 69 | pred = torch.clamp(pred, epsilon, 1.0 - epsilon) 70 | loss = -torch.sum((true * torch.log(pred)), -1, keepdim=True) 71 | loss = loss.mean() if reduction == "mean" else loss.sum() 72 | return loss 73 | 74 | 75 | #### 76 | def dice_loss(true, pred, smooth=1e-3): 77 | """`pred` and `true` must be of torch.float32. Assuming of shape NxHxWxC.""" 78 | inse = torch.sum(pred * true, (0, 1, 2)) 79 | l = torch.sum(pred, (0, 1, 2)) 80 | r = torch.sum(true, (0, 1, 2)) 81 | loss = 1.0 - (2.0 * inse + smooth) / (l + r + smooth) 82 | loss = torch.sum(loss) 83 | return loss 84 | 85 | 86 | #### 87 | def mse_loss(true, pred): 88 | """Calculate mean squared error loss. 89 | 90 | Args: 91 | true: ground truth of combined horizontal 92 | and vertical maps 93 | pred: prediction of combined horizontal 94 | and vertical maps 95 | 96 | Returns: 97 | loss: mean squared error 98 | 99 | """ 100 | loss = pred - true 101 | loss = (loss * loss).mean() 102 | return loss 103 | 104 | 105 | #### 106 | def msge_loss(true, pred, focus): 107 | """Calculate the mean squared error of the gradients of 108 | horizontal and vertical map predictions. Assumes 109 | channel 0 is Vertical and channel 1 is Horizontal. 
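    Added note: the kernels built in `get_sobel_kernel` below are Sobel-like kernels
    normalised as h / (h^2 + v^2), and the squared difference between the true and
    predicted gradient maps is averaged only over the `focus` (nuclei) region rather
    than over the whole patch.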
110 | 111 | Args: 112 | true: ground truth of combined horizontal 113 | and vertical maps 114 | pred: prediction of combined horizontal 115 | and vertical maps 116 | focus: area where to apply loss (we only calculate 117 | the loss within the nuclei) 118 | 119 | Returns: 120 | loss: mean squared error of gradients 121 | 122 | """ 123 | 124 | def get_sobel_kernel(size): 125 | """Get sobel kernel with a given size.""" 126 | assert size % 2 == 1, "Must be odd, get size=%d" % size 127 | 128 | h_range = torch.arange( 129 | -size // 2 + 1, 130 | size // 2 + 1, 131 | dtype=torch.float32, 132 | device="cuda", 133 | requires_grad=False, 134 | ) 135 | v_range = torch.arange( 136 | -size // 2 + 1, 137 | size // 2 + 1, 138 | dtype=torch.float32, 139 | device="cuda", 140 | requires_grad=False, 141 | ) 142 | h, v = torch.meshgrid(h_range, v_range) 143 | kernel_h = h / (h * h + v * v + 1.0e-15) 144 | kernel_v = v / (h * h + v * v + 1.0e-15) 145 | return kernel_h, kernel_v 146 | 147 | #### 148 | def get_gradient_hv(hv): 149 | """For calculating gradient.""" 150 | kernel_h, kernel_v = get_sobel_kernel(5) 151 | kernel_h = kernel_h.view(1, 1, 5, 5) # constant 152 | kernel_v = kernel_v.view(1, 1, 5, 5) # constant 153 | 154 | h_ch = hv[..., 0].unsqueeze(1) # Nx1xHxW 155 | v_ch = hv[..., 1].unsqueeze(1) # Nx1xHxW 156 | 157 | # can only apply in NCHW mode 158 | h_dh_ch = F.conv2d(h_ch, kernel_h, padding=2) 159 | v_dv_ch = F.conv2d(v_ch, kernel_v, padding=2) 160 | dhv = torch.cat([h_dh_ch, v_dv_ch], dim=1) 161 | dhv = dhv.permute(0, 2, 3, 1).contiguous() # to NHWC 162 | return dhv 163 | 164 | focus = (focus[..., None]).float() # assume input NHW 165 | focus = torch.cat([focus, focus], axis=-1) 166 | true_grad = get_gradient_hv(true) 167 | pred_grad = get_gradient_hv(pred) 168 | loss = pred_grad - true_grad 169 | loss = focus * (loss * loss) 170 | # artificial reduce_mean with focused region 171 | loss = loss.sum() / (focus.sum() + 1.0e-8) 172 | return loss 173 | -------------------------------------------------------------------------------- /Hover/requirements.txt: -------------------------------------------------------------------------------- 1 | docopt==0.6.2 2 | future==0.18.2 3 | imgaug==0.4.0 4 | matplotlib==3.3.0 5 | numpy==1.19.1 6 | opencv-python==4.3.0.36 7 | pandas==1.1.0 8 | pillow==7.2.0 9 | psutil==5.7.3 10 | scikit-image==0.17.2 11 | scikit-learn==0.23.1 12 | scipy==1.5.2 13 | tensorboard==2.3.0 14 | tensorboardx==2.1 15 | termcolor==1.1.0 16 | tqdm==4.48.0 17 | -------------------------------------------------------------------------------- /Hover/run_infer.py: -------------------------------------------------------------------------------- 1 | """run_infer.py 2 | 3 | Usage: 4 | run_infer.py [options] [--help] [...] 5 | run_infer.py --version 6 | run_infer.py (-h | --help) 7 | 8 | Options: 9 | -h --help Show this string. 10 | --version Show version. 11 | 12 | --gpu= GPU list. [default: 0] 13 | --nr_types= Number of nuclei types to predict. [default: 0] 14 | --type_info_path= Path to a json define mapping between type id, type name, 15 | and expected overlaid color. [default: ''] 16 | 17 | --model_path= Path to saved checkpoint. 18 | --model_mode= Original HoVer-Net or the reduced version used PanNuke and MoNuSAC, 19 | 'original' or 'fast'. [default: fast] 20 | --nr_inference_workers= Number of workers during inference. [default: 8] 21 | --nr_post_proc_workers= Number of workers during post-processing. [default: 16] 22 | --batch_size= Batch size per 1 GPU. 
[default: 32] 23 | 24 | Two command mode are `tile` and `wsi` to enter corresponding inference mode 25 | tile run the inference on tile 26 | wsi run the inference on wsi 27 | 28 | Use `run_infer.py --help` to show their options and usage. 29 | """ 30 | 31 | tile_cli = """ 32 | Arguments for processing tiles. 33 | 34 | usage: 35 | tile (--input_dir=) (--output_dir=) \ 36 | [--draw_dot] [--save_qupath] [--save_raw_map] [--mem_usage=] 37 | 38 | options: 39 | --input_dir= Path to input data directory. Assumes the files are not nested within directory. 40 | --output_dir= Path to output directory.. 41 | 42 | --mem_usage= Declare how much memory (physical + swap) should be used for caching. 43 | By default it will load as many tiles as possible till reaching the 44 | declared limit. [default: 0.2] 45 | --draw_dot To draw nuclei centroid on overlay. [default: False] 46 | --save_qupath To optionally output QuPath v0.2.3 compatible format. [default: False] 47 | --save_raw_map To save raw prediction or not. [default: False] 48 | """ 49 | 50 | wsi_cli = """ 51 | Arguments for processing wsi 52 | 53 | usage: 54 | wsi (--input_dir=) (--output_dir=) (--presplit_dir=) [--proc_mag=]\ 55 | [--cache_path=] [--input_mask_dir=] \ 56 | [--ambiguous_size=] [--chunk_shape=] [--tile_shape=] \ 57 | [--save_thumb] [--save_mask] 58 | 59 | options: 60 | --input_dir= Path to input data directory. Assumes the files are not nested within directory. 61 | --output_dir= Path to output directory. 62 | --presplit_dir= Path to presplit data directory. 63 | --cache_path= Path for cache. Should be placed on SSD with at least 100GB. [default: cache] 64 | --mask_dir= Path to directory containing tissue masks. 65 | Should have the same name as corresponding WSIs. [default: ''] 66 | 67 | --proc_mag= Magnification level (objective power) used for WSI processing. [default: 40] 68 | --ambiguous_size= Define ambiguous region along tiling grid to perform re-post processing. [default: 128] 69 | --chunk_shape= Shape of chunk for processing. [default: 10000] 70 | --tile_shape= Shape of tiles for processing. [default: 2048] 71 | --save_thumb To save thumb. [default: False] 72 | --save_mask To save mask. [default: False] 73 | """ 74 | 75 | import torch 76 | import logging 77 | import os 78 | import copy 79 | from misc.utils import log_info 80 | from docopt import docopt 81 | 82 | #------------------------------------------------------------------------------------------------------- 83 | 84 | if __name__ == '__main__': 85 | sub_cli_dict = {'tile' : tile_cli, 'wsi' : wsi_cli} 86 | args = docopt(__doc__, help=False, options_first=True, 87 | version='HoVer-Net Pytorch Inference v1.0') 88 | sub_cmd = args.pop('') 89 | sub_cmd_args = args.pop('') 90 | 91 | # ! 
TODO: where to save logging 92 | logging.basicConfig( 93 | level=logging.INFO, 94 | format='|%(asctime)s.%(msecs)03d| [%(levelname)s] %(message)s',datefmt='%Y-%m-%d|%H:%M:%S', 95 | handlers=[ 96 | logging.FileHandler("debug.log"), 97 | logging.StreamHandler() 98 | ] 99 | ) 100 | 101 | if args['--help'] and sub_cmd is not None: 102 | if sub_cmd in sub_cli_dict: 103 | print(sub_cli_dict[sub_cmd]) 104 | else: 105 | print(__doc__) 106 | exit() 107 | if args['--help'] or sub_cmd is None: 108 | print(__doc__) 109 | exit() 110 | 111 | sub_args = docopt(sub_cli_dict[sub_cmd], argv=sub_cmd_args, help=True) 112 | 113 | args.pop('--version') 114 | gpu_list = args.pop('--gpu') 115 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 116 | 117 | nr_gpus = torch.cuda.device_count() 118 | log_info('Detect #GPUS: %d' % nr_gpus) 119 | 120 | args = {k.replace('--', '') : v for k, v in args.items()} 121 | sub_args = {k.replace('--', '') : v for k, v in sub_args.items()} 122 | if args['model_path'] == None: 123 | raise Exception('A model path must be supplied as an argument with --model_path.') 124 | 125 | nr_types = int(args['nr_types']) if int(args['nr_types']) > 0 else None 126 | method_args = { 127 | 'method' : { 128 | 'model_args' : { 129 | 'nr_types' : nr_types, 130 | 'mode' : args['model_mode'], 131 | }, 132 | 'model_path' : args['model_path'], 133 | }, 134 | 'type_info_path' : None if args['type_info_path'] == '' \ 135 | else args['type_info_path'], 136 | } 137 | 138 | # *** 139 | run_args = { 140 | 'batch_size' : int(args['batch_size']) * nr_gpus, 141 | 142 | 'nr_inference_workers' : int(args['nr_inference_workers']), 143 | 'nr_post_proc_workers' : int(args['nr_post_proc_workers']), 144 | } 145 | 146 | if args['model_mode'] == 'fast': 147 | run_args['patch_input_shape'] = 256 148 | run_args['patch_output_shape'] = 164 149 | else: 150 | run_args['patch_input_shape'] = 270 151 | run_args['patch_output_shape'] = 80 152 | 153 | if sub_cmd == 'tile': 154 | run_args.update({ 155 | 'input_dir' : sub_args['input_dir'], 156 | 'output_dir' : sub_args['output_dir'], 157 | 158 | 'mem_usage' : float(sub_args['mem_usage']), 159 | 'draw_dot' : sub_args['draw_dot'], 160 | 'save_qupath' : sub_args['save_qupath'], 161 | 'save_raw_map': sub_args['save_raw_map'], 162 | }) 163 | 164 | if sub_cmd == 'wsi': 165 | run_args.update({ 166 | 'input_dir' : sub_args['input_dir'], 167 | 'output_dir' : sub_args['output_dir'], 168 | 'presplit_dir' : sub_args['presplit_dir'], 169 | 'input_mask_dir' : sub_args['input_mask_dir'], 170 | 'cache_path' : sub_args['cache_path'], 171 | 172 | 'proc_mag' : int(sub_args['proc_mag']), 173 | 'ambiguous_size' : int(sub_args['ambiguous_size']), 174 | 'chunk_shape' : int(sub_args['chunk_shape']), 175 | 'tile_shape' : int(sub_args['tile_shape']), 176 | 'save_thumb' : sub_args['save_thumb'], 177 | 'save_mask' : sub_args['save_mask'], 178 | }) 179 | # *** 180 | 181 | if sub_cmd == 'tile': 182 | from infer.tile import InferManager 183 | infer = InferManager(**method_args) 184 | infer.process_file_list(run_args) 185 | else: 186 | from infer.wsi import InferManager 187 | infer = InferManager(**method_args) 188 | infer.process_wsi_list(run_args) 189 | -------------------------------------------------------------------------------- /Hover/run_tile.sh: -------------------------------------------------------------------------------- 1 | python run_infer.py \ 2 | --gpu='1' \ 3 | --nr_types=6 \ 4 | --type_info_path=type_info.json \ 5 | --batch_size=32 \ 6 | --model_mode=fast \ 7 | 
--model_path=/media/tiger/Disk1/jhan/code/hover_net-master/hovernet_fast_pannuke_type_tf2pytorch.tar \ 8 | --nr_inference_workers=8 \ 9 | --nr_post_proc_workers=16 \ 10 | tile \ 11 | --input_dir=/media/tiger/Disk1/jhan/datasets/sliangProstate/Images/ \ 12 | --output_dir=/media/tiger/Disk1/jhan/datasets/sliangProstate/Preds/ \ 13 | --mem_usage=0.1 \ 14 | --draw_dot \ 15 | --save_qupath 16 | -------------------------------------------------------------------------------- /Hover/run_train.py: -------------------------------------------------------------------------------- 1 | """run_train.py 2 | 3 | Main HoVer-Net training script. 4 | 5 | Usage: 6 | run_train.py [--gpu=] [--view=] 7 | run_train.py (-h | --help) 8 | run_train.py --version 9 | 10 | Options: 11 | -h --help Show this string. 12 | --version Show version. 13 | --gpu= Comma separated GPU list. [default: 0,1,2,3] 14 | --view= Visualise images after augmentation. Choose 'train' or 'valid'. 15 | """ 16 | 17 | import cv2 18 | 19 | cv2.setNumThreads(0) 20 | import argparse 21 | import glob 22 | import importlib 23 | import inspect 24 | import json 25 | import os 26 | import shutil 27 | 28 | import matplotlib 29 | import numpy as np 30 | import torch 31 | from docopt import docopt 32 | from tensorboardX import SummaryWriter 33 | from torch.nn import DataParallel # TODO: switch to DistributedDataParallel 34 | from torch.utils.data import DataLoader 35 | 36 | from config import Config 37 | from dataloader.train_loader import FileLoader 38 | from misc.utils import rm_n_mkdir 39 | from run_utils.engine import RunEngine 40 | from run_utils.utils import ( 41 | check_log_dir, 42 | check_manual_seed, 43 | colored, 44 | convert_pytorch_checkpoint, 45 | ) 46 | 47 | 48 | #### have to move outside because of spawn 49 | # * must initialize augmentor per worker, else duplicated rng generators may happen 50 | def worker_init_fn(worker_id): 51 | # ! to make the seed chain reproducible, must use the torch random, not numpy 52 | # the torch rng from main thread will regenerate a base seed, which is then 53 | # copied into the dataloader each time it created (i.e start of each epoch) 54 | # then dataloader with this seed will spawn worker, now we reseed the worker 55 | worker_info = torch.utils.data.get_worker_info() 56 | # to make it more random, simply switch torch.randint to np.randint 57 | worker_seed = torch.randint(0, 2 ** 32, (1,))[0].cpu().item() + worker_id 58 | # print('Loader Worker %d Uses RNG Seed: %d' % (worker_id, worker_seed)) 59 | # retrieve the dataset copied into this worker process 60 | # then set the random seed for each augmentation 61 | worker_info.dataset.setup_augmentor(worker_id, worker_seed) 62 | return 63 | 64 | 65 | #### 66 | class TrainManager(Config): 67 | """Either used to view the dataset or to initialise the main training loop.""" 68 | 69 | def __init__(self): 70 | super().__init__() 71 | return 72 | 73 | #### 74 | def view_dataset(self, mode="train"): 75 | """ 76 | Manually change to plt.savefig or plt.show 77 | if using on headless machine or not 78 | """ 79 | self.nr_gpus = 1 80 | import matplotlib.pyplot as plt 81 | check_manual_seed(self.seed) 82 | # TODO: what if each phase want diff annotation ? 
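        # Note: this viewer is reached from the docopt CLI at the top of this file,
        # e.g. `python run_train.py --view=train`; it plots batches of augmented
        # patches for the first phase defined in the model config.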
83 | phase_list = self.model_config["phase_list"][0] 84 | target_info = phase_list["target_info"] 85 | prep_func, prep_kwargs = target_info["viz"] 86 | dataloader = self._get_datagen(2, mode, target_info["gen"]) 87 | for batch_data in dataloader: 88 | # convert from Tensor to Numpy 89 | batch_data = {k: v.numpy() for k, v in batch_data.items()} 90 | viz = prep_func(batch_data, is_batch=True, **prep_kwargs) 91 | plt.imshow(viz) 92 | plt.show() 93 | self.nr_gpus = -1 94 | return 95 | 96 | #### 97 | def _get_datagen(self, batch_size, run_mode, target_gen, nr_procs=0, fold_idx=0): 98 | nr_procs = nr_procs if not self.debug else 0 99 | 100 | # ! Hard assumption on file type 101 | file_list = [] 102 | if run_mode == "train": 103 | data_dir_list = self.train_dir_list 104 | else: 105 | data_dir_list = self.valid_dir_list 106 | for dir_path in data_dir_list: 107 | file_list.extend(glob.glob("%s/*.npy" % dir_path)) 108 | file_list.sort() # to always ensure same input ordering 109 | 110 | assert len(file_list) > 0, ( 111 | "No .npy found for `%s`, please check `%s` in `config.py`" 112 | % (run_mode, "%s_dir_list" % run_mode) 113 | ) 114 | print("Dataset %s: %d" % (run_mode, len(file_list))) 115 | input_dataset = FileLoader( 116 | file_list, 117 | mode=run_mode, 118 | with_type=self.type_classification, 119 | setup_augmentor=nr_procs == 0, 120 | target_gen=target_gen, 121 | **self.shape_info[run_mode] 122 | ) 123 | 124 | dataloader = DataLoader( 125 | input_dataset, 126 | num_workers=nr_procs, 127 | batch_size=batch_size * self.nr_gpus, 128 | shuffle=run_mode == "train", 129 | drop_last=run_mode == "train", 130 | worker_init_fn=worker_init_fn, 131 | ) 132 | return dataloader 133 | 134 | #### 135 | def run_once(self, opt, run_engine_opt, log_dir, prev_log_dir=None, fold_idx=0): 136 | """Simply run the defined run_step of the related method once.""" 137 | check_manual_seed(self.seed) 138 | 139 | log_info = {} 140 | if self.logging: 141 | # check_log_dir(log_dir) 142 | rm_n_mkdir(log_dir) 143 | 144 | tfwriter = SummaryWriter(log_dir=log_dir) 145 | json_log_file = log_dir + "/stats.json" 146 | with open(json_log_file, "w") as json_file: 147 | json.dump({}, json_file) # create empty file 148 | log_info = { 149 | "json_file": json_log_file, 150 | "tfwriter": tfwriter, 151 | } 152 | 153 | #### 154 | loader_dict = {} 155 | for runner_name, runner_opt in run_engine_opt.items(): 156 | loader_dict[runner_name] = self._get_datagen( 157 | opt["batch_size"][runner_name], 158 | runner_name, 159 | opt["target_info"]["gen"], 160 | nr_procs=runner_opt["nr_procs"], 161 | fold_idx=fold_idx, 162 | ) 163 | #### 164 | def get_last_chkpt_path(prev_phase_dir, net_name): 165 | stat_file_path = prev_phase_dir + "/stats.json" 166 | with open(stat_file_path) as stat_file: 167 | info = json.load(stat_file) 168 | epoch_list = [int(v) for v in info.keys()] 169 | last_chkpts_path = "%s/%s_epoch=%d.tar" % ( 170 | prev_phase_dir, 171 | net_name, 172 | max(epoch_list), 173 | ) 174 | return last_chkpts_path 175 | 176 | # TODO: adding way to load pretrained weight or resume the training 177 | # parsing the network and optimizer information 178 | net_run_info = {} 179 | net_info_opt = opt["run_info"] 180 | for net_name, net_info in net_info_opt.items(): 181 | assert inspect.isclass(net_info["desc"]) or inspect.isfunction( 182 | net_info["desc"] 183 | ), "`desc` must be a Class or Function which instantiate NEW objects !!!" 184 | net_desc = net_info["desc"]() 185 | 186 | # TODO: customize print-out for each run ? 
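        # Note (illustrative sketch, not part of the original flow): a layer-by-layer
        # summary of the freshly instantiated network can be printed with the helper
        # defined in run_utils/utils.py, e.g.
        #   from run_utils.utils import get_model_summary
        #   print(get_model_summary(net_desc, (3, 270, 270), device='cpu'))
        # where (3, 270, 270) mirrors the patch input shape used by the commented
        # call below.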
187 | # summary_string(net_desc, (3, 270, 270), device='cpu') 188 | 189 | pretrained_path = net_info["pretrained"] 190 | if pretrained_path is not None: 191 | if pretrained_path == -1: 192 | # * depend on logging format so may be broken if logging format has been changed 193 | pretrained_path = get_last_chkpt_path(prev_log_dir, net_name) 194 | net_state_dict = torch.load(pretrained_path)["desc"] 195 | else: 196 | chkpt_ext = os.path.basename(pretrained_path).split(".")[-1] 197 | if chkpt_ext == "npz": 198 | net_state_dict = dict(np.load(pretrained_path)) 199 | net_state_dict = { 200 | k: torch.from_numpy(v) for k, v in net_state_dict.items() 201 | } 202 | elif chkpt_ext == "tar": # ! assume same saving format we desire 203 | net_state_dict = torch.load(pretrained_path)["desc"] 204 | 205 | colored_word = colored(net_name, color="red", attrs=["bold"]) 206 | print( 207 | "Model `%s` pretrained path: %s" % (colored_word, pretrained_path) 208 | ) 209 | 210 | # load_state_dict returns (missing keys, unexpected keys) 211 | net_state_dict = convert_pytorch_checkpoint(net_state_dict) 212 | load_feedback = net_desc.load_state_dict(net_state_dict, strict=False) 213 | # * uncomment for your convenience 214 | print("Missing Variables: \n", load_feedback[0]) 215 | print("Detected Unknown Variables: \n", load_feedback[1]) 216 | 217 | # * extremely slow to pass this on DGX with 1 GPU, why (?) 218 | net_desc = DataParallel(net_desc) 219 | net_desc = net_desc.to("cuda") 220 | # print(net_desc) # * dump network definition or not? 221 | optimizer, optimizer_args = net_info["optimizer"] 222 | optimizer = optimizer(net_desc.parameters(), **optimizer_args) 223 | # TODO: expand for external aug for scheduler 224 | nr_iter = opt["nr_epochs"] * len(loader_dict["train"]) 225 | scheduler = net_info["lr_scheduler"](optimizer) 226 | net_run_info[net_name] = { 227 | "desc": net_desc, 228 | "optimizer": optimizer, 229 | "lr_scheduler": scheduler, 230 | # TODO: standardize API for external hooks 231 | "extra_info": net_info["extra_info"], 232 | } 233 | 234 | # parsing the running engine configuration 235 | assert ( 236 | "train" in run_engine_opt 237 | ), "No engine for training detected in description file" 238 | 239 | # initialize runner and attach callback afterward 240 | # * all engine shared the same network info declaration 241 | runner_dict = {} 242 | for runner_name, runner_opt in run_engine_opt.items(): 243 | runner_dict[runner_name] = RunEngine( 244 | dataloader=loader_dict[runner_name], 245 | engine_name=runner_name, 246 | run_step=runner_opt["run_step"], 247 | run_info=net_run_info, 248 | log_info=log_info, 249 | ) 250 | 251 | for runner_name, runner in runner_dict.items(): 252 | callback_info = run_engine_opt[runner_name]["callbacks"] 253 | for event, callback_list, in callback_info.items(): 254 | for callback in callback_list: 255 | if callback.engine_trigger: 256 | triggered_runner_name = callback.triggered_engine_name 257 | callback.triggered_engine = runner_dict[triggered_runner_name] 258 | runner.add_event_handler(event, callback) 259 | 260 | # retrieve main runner 261 | main_runner = runner_dict["train"] 262 | main_runner.state.logging = self.logging 263 | main_runner.state.log_dir = log_dir 264 | # start the run loop 265 | main_runner.run(opt["nr_epochs"]) 266 | 267 | print("\n") 268 | print("########################################################") 269 | print("########################################################") 270 | print("\n") 271 | return 272 | 273 | #### 274 | def run(self): 275 | 
"""Define multi-stage run or cross-validation or whatever in here.""" 276 | self.nr_gpus = torch.cuda.device_count() 277 | print('Detect #GPUS: %d' % self.nr_gpus) 278 | 279 | phase_list = self.model_config["phase_list"] 280 | engine_opt = self.model_config["run_engine"] 281 | 282 | prev_save_path = None 283 | for phase_idx, phase_info in enumerate(phase_list): 284 | if len(phase_list) == 1: 285 | save_path = self.log_dir 286 | else: 287 | save_path = self.log_dir + "/%02d/" % (phase_idx) 288 | self.run_once( 289 | phase_info, engine_opt, save_path, prev_log_dir=prev_save_path 290 | ) 291 | prev_save_path = save_path 292 | 293 | 294 | #### 295 | if __name__ == "__main__": 296 | args = docopt(__doc__, version="HoVer-Net v1.0") 297 | trainer = TrainManager() 298 | 299 | if args["--view"]: 300 | if args["--view"] != "train" and args["--view"] != "valid": 301 | raise Exception('Use "train" or "valid" for --view.') 302 | trainer.view_dataset(args["--view"]) 303 | else: 304 | os.environ["CUDA_VISIBLE_DEVICES"] = args["--gpu"] 305 | trainer.run() 306 | -------------------------------------------------------------------------------- /Hover/run_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuscc-deep-path/sc_MTOP/33ff31fbd01f37705118244da0a8df96a4f19014/Hover/run_utils/__init__.py -------------------------------------------------------------------------------- /Hover/run_utils/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fuscc-deep-path/sc_MTOP/33ff31fbd01f37705118244da0a8df96a4f19014/Hover/run_utils/callbacks/__init__.py -------------------------------------------------------------------------------- /Hover/run_utils/callbacks/base.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import json 3 | 4 | import cv2 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | from misc.utils import center_pad_to_shape, cropping_center 9 | from scipy.stats import mode as major_value 10 | from sklearn.metrics import confusion_matrix 11 | 12 | 13 | #### 14 | class BaseCallbacks(object): 15 | def __init__(self): 16 | self.engine_trigger = False 17 | 18 | def reset(self): 19 | pass 20 | 21 | def run(self, state, event): 22 | pass 23 | 24 | 25 | #### 26 | class TrackLr(BaseCallbacks): 27 | """ 28 | Add learning rate to tracking 29 | """ 30 | 31 | def __init__(self, per_n_epoch=1, per_n_step=None): 32 | super().__init__() 33 | self.per_n_epoch = per_n_epoch 34 | self.per_n_step = per_n_step 35 | 36 | def run(self, state, event): 37 | # logging learning rate, decouple into another callback? 38 | run_info = state.run_info 39 | for net_name, net_info in run_info.items(): 40 | lr = net_info["optimizer"].param_groups[0]["lr"] 41 | state.tracked_step_output["scalar"]["lr-%s" % net_name] = lr 42 | return 43 | 44 | 45 | #### 46 | class ScheduleLr(BaseCallbacks): 47 | """Trigger all scheduler.""" 48 | 49 | def __init__(self): 50 | super().__init__() 51 | 52 | def run(self, state, event): 53 | # logging learning rate, decouple into another callback? 
54 | run_info = state.run_info 55 | for net_name, net_info in run_info.items(): 56 | net_info["lr_scheduler"].step() 57 | return 58 | 59 | 60 | #### 61 | class TriggerEngine(BaseCallbacks): 62 | def __init__(self, triggered_engine_name, nr_epoch=1): 63 | self.engine_trigger = True 64 | self.triggered_engine_name = triggered_engine_name 65 | self.triggered_engine = None 66 | self.nr_epoch = nr_epoch 67 | 68 | def run(self, state, event): 69 | self.triggered_engine.run( 70 | chained=True, nr_epoch=self.nr_epoch, shared_state=state 71 | ) 72 | return 73 | 74 | 75 | #### 76 | class PeriodicSaver(BaseCallbacks): 77 | """Must declare save dir first in the shared global state of the attached engine.""" 78 | 79 | def __init__(self, per_n_epoch=1, per_n_step=None): 80 | super().__init__() 81 | self.per_n_epoch = per_n_epoch 82 | self.per_n_step = per_n_step 83 | 84 | def run(self, state, event): 85 | if not state.logging: 86 | return 87 | 88 | # TODO: add switch so that only one of [per_n_epoch / per_n_step] can run 89 | if state.curr_epoch % self.per_n_epoch != 0: 90 | return 91 | 92 | for net_name, net_info in state.run_info.items(): 93 | net_checkpoint = {} 94 | for key, value in net_info.items(): 95 | if key != "extra_info": 96 | net_checkpoint[key] = value.state_dict() 97 | torch.save( 98 | net_checkpoint, 99 | "%s/%s_epoch=%d.tar" % (state.log_dir, net_name, state.curr_epoch), 100 | ) 101 | return 102 | 103 | 104 | #### 105 | class ConditionalSaver(BaseCallbacks): 106 | """Must declare save dir first in the shared global state of the attached engine.""" 107 | 108 | def __init__(self, metric_name, comparator=">="): 109 | super().__init__() 110 | self.metric_name = metric_name 111 | self.comparator = comparator 112 | 113 | def run(self, state, event): 114 | if not state.logging: 115 | return 116 | 117 | ops = { 118 | ">": operator.gt, 119 | "<": operator.lt, 120 | ">=": operator.ge, 121 | "<=": operator.le, 122 | } 123 | op_func = ops[self.comparator] 124 | if self.comparator == ">" or self.comparator == ">=": 125 | best_value = -float("inf") 126 | else: 127 | best_value = +float("inf") 128 | 129 | # json stat log file, update and overwrite 130 | with open(state.log_info["json_file"]) as json_file: 131 | json_data = json.load(json_file) 132 | 133 | for epoch, epoch_stat in json_data.items(): 134 | epoch_value = epoch_stat[self.metric_name] 135 | if op_func(epoch_value, best_value): 136 | best_value = epoch_value 137 | 138 | current_value = json_data[str(state.curr_epoch)][self.metric_name] 139 | if not op_func(current_value, best_value): 140 | return # simply return because not satisfy 141 | 142 | print( 143 | state.curr_epoch 144 | ) # TODO: better way to track which optimal epoch is saved 145 | for net_name, net_info in state.run_info.items(): 146 | net_checkpoint = {} 147 | for key, value in net_info.items(): 148 | if key != "extra_info": 149 | net_checkpoint[key] = value.state_dict() 150 | torch.save( 151 | net_checkpoint, 152 | "%s/%s_best=[%s].tar" % (state.log_dir, net_name, self.metric_name), 153 | ) 154 | return 155 | 156 | 157 | #### 158 | class AccumulateRawOutput(BaseCallbacks): 159 | def run(self, state, event): 160 | step_output = state.step_output["raw"] 161 | accumulated_output = state.epoch_accumulated_output 162 | 163 | for key, step_value in step_output.items(): 164 | if key in accumulated_output: 165 | accumulated_output[key].extend(list(step_value)) 166 | else: 167 | accumulated_output[key] = list(step_value) 168 | return 169 | 170 | 171 | #### 172 | class 
ScalarMovingAverage(BaseCallbacks): 173 | """Calculate the running average for all scalar output of 174 | each runstep of the attached RunEngine.""" 175 | 176 | def __init__(self, alpha=0.95): 177 | super().__init__() 178 | self.alpha = alpha 179 | self.tracking_dict = {} 180 | 181 | def run(self, state, event): 182 | # TODO: protocol for dynamic key retrieval for EMA 183 | step_output = state.step_output["EMA"] 184 | 185 | for key, current_value in step_output.items(): 186 | if key in self.tracking_dict: 187 | old_ema_value = self.tracking_dict[key] 188 | # calculate the exponential moving average 189 | new_ema_value = ( 190 | old_ema_value * self.alpha + (1.0 - self.alpha) * current_value 191 | ) 192 | self.tracking_dict[key] = new_ema_value 193 | else: # init for variable which appear for the first time 194 | new_ema_value = current_value 195 | self.tracking_dict[key] = new_ema_value 196 | 197 | state.tracked_step_output["scalar"] = self.tracking_dict 198 | return 199 | 200 | 201 | #### 202 | class ProcessAccumulatedRawOutput(BaseCallbacks): 203 | def __init__(self, proc_func, per_n_epoch=1): 204 | # TODO: allow dynamically attach specific procesing for `type` 205 | super().__init__() 206 | self.per_n_epoch = per_n_epoch 207 | self.proc_func = proc_func 208 | 209 | def run(self, state, event): 210 | current_epoch = state.curr_epoch 211 | # if current_epoch % self.per_n_epoch != 0: return 212 | raw_data = state.epoch_accumulated_output 213 | track_dict = self.proc_func(raw_data) 214 | # update global shared states 215 | state.tracked_step_output = track_dict 216 | return 217 | 218 | 219 | #### 220 | class VisualizeOutput(BaseCallbacks): 221 | def __init__(self, proc_func, per_n_epoch=1): 222 | super().__init__() 223 | # TODO: option to dump viz per epoch or per n step 224 | self.per_n_epoch = per_n_epoch 225 | self.proc_func = proc_func 226 | 227 | def run(self, state, event): 228 | current_epoch = state.curr_epoch 229 | raw_output = state.step_output["raw"] 230 | viz_image = self.proc_func(raw_output) 231 | state.tracked_step_output["image"]["output"] = viz_image 232 | return 233 | -------------------------------------------------------------------------------- /Hover/run_utils/callbacks/logging.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from matplotlib.lines import Line2D 7 | from termcolor import colored 8 | 9 | from .base import BaseCallbacks 10 | from .serialize import fig2data, serialize 11 | 12 | # TODO: logging for all printed info on the terminal 13 | 14 | 15 | #### 16 | class LoggingGradient(BaseCallbacks): 17 | """Will log per each training step.""" 18 | 19 | def _pyplot_grad_flow(self, named_parameters): 20 | """Plots the gradients flowing through different layers in the net during training. 21 | "_pyplot_grad_flow(self.model.named_parameters())" to visualize the gradient flow. 22 | 23 | ! Very slow if triggered per steps because of CPU <=> GPU. 
24 | 25 | """ 26 | ave_grads = [] 27 | max_grads = [] 28 | layers = [] 29 | for n, p in named_parameters: 30 | if (p.requires_grad) and ("bias" not in n): 31 | layers.append(n) 32 | ave_grads.append(p.grad.abs().mean().cpu().item()) 33 | max_grads.append(p.grad.abs().max().cpu().item()) 34 | fig = plt.figure(figsize=(10, 10)) 35 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") 36 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") 37 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") 38 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 39 | plt.xlim(left=0, right=len(ave_grads)) 40 | # zoom in on the lower gradient regions 41 | plt.xlabel("Layers") 42 | plt.ylabel("average gradient") 43 | plt.title("Gradient flow") 44 | plt.grid(True) 45 | plt.legend( 46 | [ 47 | Line2D([0], [0], color="c", lw=4), 48 | Line2D([0], [0], color="b", lw=4), 49 | Line2D([0], [0], color="k", lw=4), 50 | ], 51 | ["max-gradient", "mean-gradient", "zero-gradient"], 52 | ) 53 | fig = np.transpose(fig2data(fig), axes=[2, 0, 1]) # HWC => CHW 54 | plt.close() 55 | return fig 56 | 57 | def run(self, state, event): 58 | 59 | if random.random() > 0.05: 60 | return 61 | curr_step = state.curr_global_step 62 | 63 | # logging the grad of all trainable parameters 64 | tfwriter = state.log_info["tfwriter"] 65 | run_info = state.run_info 66 | for net_name, net_info in run_info.items(): 67 | netdesc = net_info["desc"].module 68 | for param_name, param in netdesc.named_parameters(): 69 | param_grad = param.grad 70 | # TODO: sync test None or epislon for pytorch 1.4 vs 1.5 71 | if param_grad is None: 72 | continue 73 | tfwriter.add_histogram( 74 | "%s_grad/%s" % (net_name, param_name), 75 | param_grad.detach().cpu().numpy().flatten(), 76 | global_step=curr_step, 77 | ) # ditribute into 10 bins (np default) 78 | tfwriter.add_histogram( 79 | "%s_para/%s" % (net_name, param_name), 80 | param.detach().cpu().numpy().flatten(), 81 | global_step=curr_step, 82 | ) # ditribute into 10 bins (np default) 83 | return 84 | 85 | 86 | #### 87 | class LoggingEpochOutput(BaseCallbacks): 88 | """Must declare save dir first in the shared global state of the attached engine.""" 89 | 90 | def __init__(self, per_n_epoch=1): 91 | super().__init__() 92 | self.per_n_epoch = per_n_epoch 93 | 94 | def run(self, state, event): 95 | 96 | # only logging every n epochs also 97 | if state.curr_epoch % self.per_n_epoch != 0: 98 | return 99 | 100 | # TODO: rename to differentiate global vs local epoch 101 | if state.global_state is not None: 102 | current_epoch = str(state.global_state.curr_epoch) 103 | else: 104 | current_epoch = str(state.curr_epoch) 105 | 106 | output = state.tracked_step_output 107 | 108 | def get_serializable_values(output_format): 109 | log_dict = {} 110 | # get type and variable that is serializable 111 | # to console or other logging format (json, tensorboard) 112 | for variable_type, variable_dict in output.items(): 113 | for value_name, value in variable_dict.items(): 114 | value_name = "%s-%s" % (state.attached_engine_name, value_name) 115 | new_format = serialize(value, variable_type, output_format) 116 | if new_format is not None: 117 | log_dict[value_name] = new_format 118 | return log_dict 119 | 120 | # * Serialize to Console 121 | # align the console print output 122 | formatted_values = get_serializable_values("console") 123 | max_length = len(max(formatted_values.keys(), key=len)) 124 | for value_name, value_text in formatted_values.items(): 125 | 
value_name = colored(value_name.ljust(max_length), "green") 126 | print("------%s : %s" % (value_name, value_text)) 127 | 128 | # TODO: [CRITICAL] fix passing this between engine 129 | # if not state.logging: return 130 | 131 | # * Serialize to JSON file 132 | stat_dict = get_serializable_values("json") 133 | # json stat log file, update and overwrite 134 | with open(state.log_info["json_file"]) as json_file: 135 | json_data = json.load(json_file) 136 | 137 | if current_epoch in json_data: 138 | old_stat_dict = json_data[current_epoch] 139 | stat_dict.update(old_stat_dict) 140 | current_epoch_dict = {current_epoch: stat_dict} 141 | json_data.update(current_epoch_dict) 142 | 143 | # TODO: may corrupt 144 | with open(state.log_info["json_file"], "w") as json_file: 145 | json.dump(json_data, json_file) 146 | 147 | # * Serialize to Tensorboard 148 | tfwriter = state.log_info["tfwriter"] 149 | formatted_values = get_serializable_values("tensorboard") 150 | # ! may need to flush to force update 151 | for value_name, value in formatted_values.items(): 152 | # TODO: dynamically call this 153 | if value[0] == "scalar": 154 | tfwriter.add_scalar(value_name, value[1], current_epoch) 155 | elif value[0] == "image": 156 | tfwriter.add_image( 157 | value_name, value[1], current_epoch, dataformats="HWC" 158 | ) 159 | # tfwriter.flush() 160 | 161 | return 162 | -------------------------------------------------------------------------------- /Hover/run_utils/callbacks/serialize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | 6 | # * syn where to set this 7 | # must use 'Agg' to plot out onto image 8 | matplotlib.use("Agg") 9 | 10 | #### 11 | def fig2data(fig, dpi=180): 12 | """Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it. 13 | 14 | Args: 15 | fig: a matplotlib figure 16 | 17 | Return: a numpy 3D array of RGBA values 18 | 19 | """ 20 | buf = io.BytesIO() 21 | fig.savefig(buf, format="png", dpi=dpi) 22 | buf.seek(0) 23 | img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8) 24 | buf.close() 25 | img = cv2.imdecode(img_arr, 1) 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | return img 28 | 29 | 30 | #### 31 | class _Scalar(object): 32 | @staticmethod 33 | def to_console(value): 34 | return "%0.5f" % value 35 | 36 | @staticmethod 37 | def to_json(value): 38 | return value 39 | 40 | @staticmethod 41 | def to_tensorboard(value): 42 | return "scalar", value 43 | 44 | 45 | #### 46 | class _ConfusionMatrix(object): 47 | @staticmethod 48 | def to_console(value): 49 | value = pd.DataFrame(value) 50 | value.index.name = "True" 51 | value.columns.name = "Pred" 52 | formatted_value = value.to_string() 53 | return "\n" + formatted_value 54 | 55 | @staticmethod 56 | def to_json(value): 57 | value = pd.DataFrame(value) 58 | value.index.name = "True" 59 | value.columns.name = "Pred" 60 | value = value.unstack().rename("value").reset_index() 61 | value = pd.Series({"conf_mat": value}) 62 | formatted_value = value.to_json(orient="records") 63 | return formatted_value 64 | 65 | @staticmethod 66 | def to_tensorboard(value): 67 | def plot_confusion_matrix( 68 | cm, target_names, title="Confusion matrix", cmap=None, normalize=False 69 | ): 70 | """given a sklearn confusion matrix (cm), make a nice plot. 
71 | 72 | Args: 73 | cm: confusion matrix from sklearn.metrics.confusion_matrix 74 | 75 | target_names: given classification classes such as [0, 1, 2] 76 | the class names, for example: ['high', 'medium', 'low'] 77 | 78 | title: the text to display at the top of the matrix 79 | 80 | cmap: the gradient of the values displayed from matplotlib.pyplot.cm 81 | see http://matplotlib.org/examples/color/colormaps_reference.html 82 | plt.get_cmap('jet') or plt.cm.Blues 83 | 84 | normalize: If False, plot the raw numbers 85 | If True, plot the proportions 86 | 87 | Usage 88 | ----- 89 | plot_confusion_matrix(cm = cm, # confusion matrix created by 90 | # sklearn.metrics.confusion_matrix 91 | normalize = True, # show proportions 92 | target_names = y_labels_vals, # list of names of the classes 93 | title = best_estimator_name) # title of graph 94 | 95 | Citiation 96 | --------- 97 | http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html 98 | 99 | """ 100 | import matplotlib.pyplot as plt 101 | import numpy as np 102 | import itertools 103 | 104 | accuracy = np.trace(cm) / np.sum(cm).astype("float") 105 | misclass = 1 - accuracy 106 | 107 | if cmap is None: 108 | cmap = plt.get_cmap("Blues") 109 | 110 | plt.figure(figsize=(8, 6)) 111 | plt.imshow(cm, interpolation="nearest", cmap=cmap) 112 | plt.title(title) 113 | plt.colorbar() 114 | 115 | if target_names is not None: 116 | tick_marks = np.arange(len(target_names)) 117 | plt.xticks(tick_marks, target_names, rotation=45) 118 | plt.yticks(tick_marks, target_names) 119 | 120 | if normalize: 121 | cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] 122 | 123 | thresh = cm.max() / 1.5 if normalize else cm.max() / 2 124 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 125 | if normalize: 126 | plt.text( 127 | j, 128 | i, 129 | "{:0.4f}".format(cm[i, j]), 130 | horizontalalignment="center", 131 | color="white" if cm[i, j] > thresh else "black", 132 | ) 133 | else: 134 | plt.text( 135 | j, 136 | i, 137 | "{:,}".format(cm[i, j]), 138 | horizontalalignment="center", 139 | color="white" if cm[i, j] > thresh else "black", 140 | ) 141 | 142 | plt.tight_layout() 143 | plt.ylabel("True label") 144 | plt.xlabel( 145 | "Predicted label\naccuracy={:0.4f}; misclass={:0.4f}".format( 146 | accuracy, misclass 147 | ) 148 | ) 149 | 150 | plot_confusion_matrix(value, ["0", "1"]) 151 | img = fig2data(plt.gcf()) 152 | plt.close() 153 | return "image", img 154 | 155 | 156 | #### 157 | class _Image(object): 158 | @staticmethod 159 | def to_console(value): 160 | # TODO: add warn for not possible or sthg here 161 | return None 162 | 163 | @staticmethod 164 | def to_json(value): 165 | # TODO: add warn for not possible or sthg here 166 | return None 167 | 168 | @staticmethod 169 | def to_tensorboard(value): 170 | # TODO: add method 171 | return "image", value 172 | 173 | 174 | __converter_dict = {"scalar": _Scalar, "conf_mat": _ConfusionMatrix, "image": _Image} 175 | 176 | 177 | #### 178 | def serialize(value, input_format, output_format): 179 | converter = __converter_dict[input_format] 180 | if output_format == "console": 181 | return converter.to_console(value) 182 | elif output_format == "json": 183 | return converter.to_json(value) 184 | elif output_format == "tensorboard": 185 | return converter.to_tensorboard(value) 186 | else: 187 | assert False, "Unknown format" 188 | -------------------------------------------------------------------------------- /Hover/run_utils/engine.py: 
-------------------------------------------------------------------------------- 1 | import tqdm 2 | from enum import Enum 3 | 4 | 5 | #### 6 | class Events(Enum): 7 | EPOCH_STARTED = "epoch_started" 8 | EPOCH_COMPLETED = "epoch_completed" 9 | STEP_STARTED = "step_started" 10 | STEP_COMPLETED = "step_completed" 11 | STARTED = "started" 12 | COMPLETED = "completed" 13 | EXCEPTION_RAISED = "exception_raised" 14 | 15 | 16 | #### 17 | class State(object): 18 | """An object that is used to pass internal and user-defined state between event handlers.""" 19 | 20 | def __init__(self): 21 | # settings propagated from config 22 | self.logging = None 23 | self.log_dir = None 24 | self.log_info = None 25 | 26 | # internal variable 27 | self.curr_epoch_step = 0 # current step in epoch 28 | self.curr_global_step = 0 # current global step 29 | self.curr_epoch = 0 # current global epoch 30 | 31 | # TODO: [LOW] better document this 32 | # for outputing value that will be tracked per step 33 | # "scalar" will always be printed out and added to the tensorboard 34 | # "images" will need dedicated function to process and added to the tensorboard 35 | 36 | # ! naming should match with types supported for serialize 37 | # TODO: Need way to dynamically adding new types 38 | self.tracked_step_output = { 39 | "scalar": {}, # type : {variable_name : variablee_value} 40 | "image": {}, 41 | } 42 | # TODO: find way to known which method bind/interact with which value 43 | 44 | self.epoch_accumulated_output = {} # all output of the current epoch 45 | 46 | # TODO: soft reset for pertain variable for N epochs 47 | self.run_accumulated_output = [] # of run until reseted 48 | 49 | # holder for output returned after current runstep 50 | # * depend on the type of training i.e GAN, the updated accumulated may be different 51 | self.step_output = None 52 | 53 | self.global_state = None 54 | return 55 | 56 | def reset_variable(self): 57 | # type : {variable_name : variable_value} 58 | self.tracked_step_output = {k: {} for k in self.tracked_step_output.keys()} 59 | 60 | # TODO: [CRITICAL] refactor this 61 | if self.curr_epoch % self.pertain_n_epoch_output == 0: 62 | self.run_accumulated_output = [] 63 | 64 | self.epoch_accumulated_output = {} 65 | 66 | # * depend on the type of training i.e GAN, the updated accumulated may be different 67 | self.step_output = None # holder for output returned after current runstep 68 | return 69 | 70 | 71 | #### 72 | class RunEngine(object): 73 | """ 74 | TODO: Include docstring 75 | """ 76 | 77 | def __init__( 78 | self, 79 | engine_name=None, 80 | dataloader=None, 81 | run_step=None, 82 | run_info=None, 83 | log_info=None, # TODO: refactor this with trainer.py 84 | ): 85 | 86 | # * auto set all input as object variables 87 | self.engine_name = engine_name 88 | self.run_step = run_step 89 | self.dataloader = dataloader 90 | 91 | # * global variable/object holder shared between all event handler 92 | self.state = State() 93 | # * check if correctly referenced, not new copies 94 | self.state.attached_engine_name = engine_name # TODO: redundant? 
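        # `run_info` is the per-network dict assembled in run_once() of run_train.py
        # ({net_name: {'desc', 'optimizer', 'lr_scheduler', 'extra_info'}}); keeping it
        # on the shared state lets callbacks such as TrackLr, ScheduleLr and
        # PeriodicSaver reach the networks, optimizers and schedulers directly.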
95 | self.state.run_info = run_info 96 | self.state.log_info = log_info 97 | self.state.batch_size = dataloader.batch_size 98 | 99 | # TODO: [CRITICAL] match all the mechanism outline with opt 100 | self.state.pertain_n_epoch_output = 1 if engine_name == "valid" else 1 101 | 102 | self.event_handler_dict = {event: [] for event in Events} 103 | 104 | # TODO: think about this more 105 | # to share global state across a chain of RunEngine such as 106 | # from the engine for training to engine for validation 107 | 108 | # 109 | self.terminate = False 110 | return 111 | 112 | def __reset_state(self): 113 | # TODO: think about this more, looks too redundant 114 | new_state = State() 115 | new_state.attached_engine_name = self.state.attached_engine_name 116 | new_state.run_info = self.state.run_info 117 | new_state.log_info = self.state.log_info 118 | self.state = new_state 119 | return 120 | 121 | def __trigger_events(self, event): 122 | for callback in self.event_handler_dict[event]: 123 | callback.run(self.state, event) 124 | # TODO: exception and throwing error with name or sthg to allow trace back 125 | return 126 | 127 | # TODO: variable to indicate output dependency between handler ! 128 | def add_event_handler(self, event_name, handler): 129 | self.event_handler_dict[event_name].append(handler) 130 | 131 | # ! Put into trainer.py ? 132 | def run(self, nr_epoch=1, shared_state=None, chained=False): 133 | 134 | # TODO: refactor this 135 | if chained: 136 | self.state.curr_epoch = 0 137 | self.state.global_state = shared_state 138 | 139 | while self.state.curr_epoch < nr_epoch: 140 | self.state.reset_variable() # * reset all EMA holder per epoch 141 | 142 | if not chained: 143 | print("----------------EPOCH %d" % (self.state.curr_epoch + 1)) 144 | 145 | self.__trigger_events(Events.EPOCH_STARTED) 146 | 147 | pbar_format = ( 148 | "Processing: |{bar}| " 149 | "{n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]" 150 | ) 151 | if self.engine_name == "train": 152 | pbar_format += ( 153 | "Batch = {postfix[1][Batch]:0.5f}|" "EMA = {postfix[1][EMA]:0.5f}" 154 | ) 155 | # * changing print char may break the bar so avoid it 156 | pbar = tqdm.tqdm( 157 | total=len(self.dataloader), 158 | leave=True, 159 | initial=0, 160 | bar_format=pbar_format, 161 | ascii=True, 162 | postfix=["", dict(Batch=float("NaN"), EMA=float("NaN"))], 163 | ) 164 | else: 165 | pbar = tqdm.tqdm( 166 | total=len(self.dataloader), 167 | leave=True, 168 | bar_format=pbar_format, 169 | ascii=True, 170 | ) 171 | 172 | for data_batch in self.dataloader: 173 | self.__trigger_events(Events.STEP_STARTED) 174 | 175 | step_run_info = [ 176 | self.state.run_info, 177 | { 178 | "epoch": self.state.curr_epoch, 179 | "step": self.state.curr_global_step, 180 | }, 181 | ] 182 | step_output = self.run_step(data_batch, step_run_info) 183 | self.state.step_output = step_output 184 | 185 | self.__trigger_events(Events.STEP_COMPLETED) 186 | self.state.curr_global_step += 1 187 | self.state.curr_epoch_step += 1 188 | 189 | if self.engine_name == "train": 190 | pbar.postfix[1]["Batch"] = step_output["EMA"]["overall_loss"] 191 | pbar.postfix[1]["EMA"] = self.state.tracked_step_output["scalar"][ 192 | "overall_loss" 193 | ] 194 | pbar.update() 195 | pbar.close() # to flush out the bar before doing end of epoch reporting 196 | self.state.curr_epoch += 1 197 | self.__trigger_events(Events.EPOCH_COMPLETED) 198 | 199 | # TODO: [CRITICAL] align the protocol 200 | self.state.run_accumulated_output.append( 201 | self.state.epoch_accumulated_output 202 | ) 
203 | 204 | return 205 | 206 | -------------------------------------------------------------------------------- /Hover/run_utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from imgaug import imgaug as ia 10 | from termcolor import colored 11 | from torch.autograd import Variable 12 | 13 | 14 | #### 15 | def convert_pytorch_checkpoint(net_state_dict): 16 | variable_name_list = list(net_state_dict.keys()) 17 | is_in_parallel_mode = all(v.split(".")[0] == "module" for v in variable_name_list) 18 | if is_in_parallel_mode: 19 | colored_word = colored("WARNING", color="red", attrs=["bold"]) 20 | print( 21 | ( 22 | "%s: Detect checkpoint saved in data-parallel mode." 23 | " Converting saved model to single GPU mode." % colored_word 24 | ).rjust(80) 25 | ) 26 | net_state_dict = { 27 | ".".join(k.split(".")[1:]): v for k, v in net_state_dict.items() 28 | } 29 | return net_state_dict 30 | 31 | 32 | #### 33 | def check_manual_seed(seed): 34 | """ If manual seed is not specified, choose a 35 | random one and communicate it to the user. 36 | 37 | Args: 38 | seed: seed to check 39 | 40 | """ 41 | seed = seed or random.randint(1, 10000) 42 | random.seed(seed) 43 | np.random.seed(seed) 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed(seed) 46 | # ia.random.seed(seed) 47 | 48 | print("Using manual seed: {seed}".format(seed=seed)) 49 | return 50 | 51 | 52 | #### 53 | def check_log_dir(log_dir): 54 | """Check if log directory exists. 55 | 56 | Args: 57 | log_dir: path to logs 58 | 59 | """ 60 | if os.path.isdir(log_dir): 61 | colored_word = colored("WARNING", color="red", attrs=["bold", "blink"]) 62 | print("%s: %s exist!" % (colored_word, colored(log_dir, attrs=["underline"]))) 63 | while True: 64 | print("Select Action: d (delete) / q (quit)", end="") 65 | key = input() 66 | if key == "d": 67 | shutil.rmtree(log_dir) 68 | break 69 | elif key == "q": 70 | exit() 71 | else: 72 | color_word = colored("ERR", color="red") 73 | print("---[%s] Unrecognize Characters!" % colored_word) 74 | return 75 | 76 | 77 | def get_model_summary( 78 | model, input_size, batch_size=-1, device=torch.device("cpu"), dtypes=None 79 | ): 80 | """Reusable utility layers such as pool or upsample will also get printed, but their printed values will 81 | be corresponding to the last call. 
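    Args:
        model: network to summarise; it is expected to already live on `device`.
        input_size: input shape(s) without the batch dimension, as a tuple or a list of tuples.
        batch_size: value substituted into the reported shapes (a dummy batch of 2 is used for the actual pass).
        device: device on which the dummy input tensors are created.
        dtypes: optional list of tensor types, one per input; defaults to FloatTensor.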
82 | 83 | """ 84 | if dtypes == None: 85 | dtypes = [torch.FloatTensor] * len(input_size) 86 | 87 | summary_str = "" 88 | 89 | def register_hook(module): 90 | def hook(module, input, output): 91 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 92 | module_idx = len(summary) 93 | 94 | m_key = module.name if module.name != "" else "%s" % class_name 95 | 96 | summary[m_key] = OrderedDict() 97 | summary[m_key]["input_shape"] = list(input[0].size()) 98 | summary[m_key]["input_shape"][0] = batch_size 99 | if isinstance(output, (list, tuple)): 100 | summary[m_key]["output_shape"] = [ 101 | [-1] + list(o.size())[1:] for o in output 102 | ] 103 | elif isinstance(output, dict): 104 | summary[m_key]["output_shape"] = [ 105 | [-1] + list(o.size())[1:] for o in output.values() 106 | ] 107 | elif isinstance(output, torch.Tensor): 108 | summary[m_key]["output_shape"] = list(output.size()) 109 | summary[m_key]["output_shape"][0] = batch_size 110 | 111 | params = 0 112 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 113 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 114 | summary[m_key]["trainable"] = module.weight.requires_grad 115 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 116 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 117 | summary[m_key]["nb_params"] = params 118 | 119 | if len(list(module.children())) == 0: 120 | hooks.append(module.register_forward_hook(hook)) 121 | 122 | # multiple inputs to the network 123 | if isinstance(input_size, tuple): 124 | input_size = [input_size] 125 | 126 | # batch_size of 2 for batchnorm 127 | x = [ 128 | torch.rand(2, *in_size).type(dtype).to(device=device) 129 | for in_size, dtype in zip(input_size, dtypes) 130 | ] 131 | 132 | # create properties 133 | summary = OrderedDict() 134 | hooks = [] 135 | 136 | # create layer name according to hierachy names 137 | for name, module in model.named_modules(): 138 | module.name = name 139 | 140 | # register hook 141 | model.apply(register_hook) 142 | 143 | # make a forward pass 144 | model(*x) 145 | 146 | # aligning name to the left 147 | max_name_length = len(max(summary.keys(), key=len)) 148 | summary = [(k.ljust(max_name_length), v) for k, v in summary.items()] 149 | summary = OrderedDict(summary) 150 | 151 | # remove these hooks 152 | for h in hooks: 153 | h.remove() 154 | 155 | header_line = "{} {:>25} {:>15}".format( 156 | "Layer Name".center(max_name_length), "Output Shape", "Param #" 157 | ) 158 | summary_str += "".join("-" for _ in range(len(header_line))) + "\n" 159 | summary_str += header_line + "\n" 160 | summary_str += "".join("=" for _ in range(len(header_line))) + "\n" 161 | total_params = 0 162 | total_output = 0 163 | trainable_params = 0 164 | for layer in summary: 165 | # input_shape, output_shape, trainable, nb_params 166 | line_new = "{:>20} {:>25} {:>15}".format( 167 | layer, 168 | str(summary[layer]["output_shape"]), 169 | "{0:,}".format(summary[layer]["nb_params"]), 170 | ) 171 | total_params += summary[layer]["nb_params"] 172 | 173 | total_output += np.prod(summary[layer]["output_shape"]) 174 | if "trainable" in summary[layer]: 175 | if summary[layer]["trainable"] == True: 176 | trainable_params += summary[layer]["nb_params"] 177 | summary_str += line_new + "\n" 178 | 179 | # assume 4 bytes/number (float on cuda). 
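    # Illustrative arithmetic (assumed example, not from the original code): with
    # input_size=[(3, 270, 270)] and batch_size=2, the input-size estimate below is
    #   3*270*270 * 2 * 4 bytes / 1024**2  ~= 1.7 MB,
    # and the activation term is doubled because gradients are assumed to be stored
    # alongside the forward activations.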
180 | total_input_size = abs( 181 | np.prod(sum(input_size, ())) * batch_size * 4.0 / (1024 ** 2.0) 182 | ) 183 | total_output_size = abs( 184 | 2.0 * total_output * 4.0 / (1024 ** 2.0) 185 | ) # x2 for gradients 186 | total_params_size = abs(total_params * 4.0 / (1024 ** 2.0)) 187 | total_size = total_params_size + total_output_size + total_input_size 188 | 189 | summary_str += "".join("=" for _ in range(len(header_line))) + "\n" 190 | summary_str += "Total params: {0:,}".format(total_params) + "\n" 191 | summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n" 192 | summary_str += ( 193 | "Non-trainable params: {0:,}".format(total_params - trainable_params) + "\n" 194 | ) 195 | summary_str += "".join("-" for _ in range(len(header_line))) + "\n" 196 | summary_str += "Input size (MB): %0.2f" % total_input_size + "\n" 197 | summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n" 198 | summary_str += "Params size (MB): %0.2f" % total_params_size + "\n" 199 | summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" 200 | summary_str += "".join("-" for _ in range(len(header_line))) + "\n" 201 | return summary_str 202 | -------------------------------------------------------------------------------- /Hover/run_wsi.sh: -------------------------------------------------------------------------------- 1 | python run_infer.py \ 2 | --gpu='0' \ 3 | --nr_types=6 \ 4 | --type_info_path=type_info.json \ 5 | --batch_size=32 \ 6 | --model_mode=fast \ 7 | --model_path=/home/xujun/FUSCC/Hover/hovernet_fast_pannuke_type_tf2pytorch.tar \ 8 | --nr_inference_workers=4 \ 9 | --nr_post_proc_workers=0 \ 10 | wsi \ 11 | --input_dir=/home/xujun/FUSCC/WSI_example/WSI2 \ 12 | --output_dir=/home/xujun/FUSCC/WSI_example/pred2 \ 13 | --presplit_dir=/home/xujun/FUSCC/WSI_example/WSI_presplit2 \ 14 | --proc_mag 20 \ 15 | --save_thumb \ 16 | --save_mask 17 | -------------------------------------------------------------------------------- /Hover/type_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "0" : ["nolabe", [0 , 0, 0]], 3 | "1" : ["neopla", [255, 0, 0]], 4 | "2" : ["inflam", [0 , 255, 0]], 5 | "3" : ["connec", [0 , 0, 255]], 6 | "4" : ["necros", [255, 255, 0]], 7 | "5" : ["no-neo", [255, 165, 0]] 8 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 FUSCC DeepPath & NUIST IMIC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """main.py 2 | 3 | Usage: 4 | main.py [options] [--help] <command> [<args>...] 5 | main.py (-h | --help) 6 | 7 | Options: 8 | -h --help Show this screen 9 | 10 | Three command modes are `segment`, `feature` and `visual` to enter the corresponding processing mode 11 | segment cell segmentation from the whole slide image 12 | feature feature extraction 13 | visual visualization of the graph and segmentation 14 | 15 | Use 'main.py <command> --help' to show their options and usage. 16 | """ 17 | segment_cli = """ 18 | Arguments for cell segmentation using HoVer-Net. 19 | 20 | Usage: 21 | segment --input_dir= --output_dir= 22 | 23 | Option: 24 | --input_dir= Path to input data directory. Assumes the files are not nested within directory. 25 | --output_dir= Path to output directory. 26 | """ 27 | 28 | feature_cli = """ 29 | Arguments for feature extraction. 30 | 31 | Usage: 32 | feature --json_path= --wsi_path= --output_path= [--xml_path=] 33 | 34 | Option: 35 | --json_path= Path to the HoVer-Net output, it should be a json file. 36 | --wsi_path= Path to wsi file. 37 | --output_path= Path to output. 38 | --xml_path= Path to xml. The xml is an annotation file of ImageScope. Only extract the features within the annotation. [default: None] 39 | """ 40 | 41 | visual_cli = """ 42 | Arguments for visualization. 43 | 44 | Usage: 45 | visual --feature_path= --wsi_path= --xml_path= 46 | 47 | Option: 48 | --feature_path= Path to the feature folder, it should be a folder including the feature and edge .csv files. 49 | --wsi_path= Path to wsi file. 50 | --xml_path= Path to xml file. The xml is an annotation file of ImageScope.\ 51 | Only plot within the extent of the annotation. [default: None] 52 | """ 53 | 54 | from docopt import docopt 55 | 56 | if __name__ == '__main__': 57 | sub_cli_dict = {'segment':segment_cli, 58 | 'feature':feature_cli, 59 | 'visual':visual_cli} 60 | args = docopt(__doc__, help=False, options_first=True) 61 | sub_cmd = args.pop('<command>') 62 | sub_cmd_args = args.pop('<args>') 63 | 64 | if args['--help'] and sub_cmd is not None: 65 | if sub_cmd in sub_cli_dict: 66 | print(sub_cli_dict[sub_cmd]) 67 | else: 68 | print(__doc__) 69 | exit() 70 | if args['--help'] or sub_cmd is None: 71 | print(__doc__) 72 | exit() 73 | 74 | sub_args = docopt(sub_cli_dict[sub_cmd], argv=sub_cmd_args, help=True) 75 | sub_args = {k.replace('--', '') : v for k, v in sub_args.items()} 76 | print(sub_args) 77 | if sub_cmd=='segment': 78 | from F1_CellSegment import fun1 79 | import sys 80 | sys.path.append('Hover') 81 | fun1(**sub_args) 82 | elif sub_cmd=='feature': 83 | from F3_FeatureExtract import fun3 84 | if sub_args['xml_path'] == 'None': 85 | sub_args['xml_path'] = None 86 | fun3(**sub_args) 87 | elif sub_cmd=='visual': 88 | from F4_Visualization import fun4 89 | fun4(**sub_args) -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Single Cell Morphological and Topological Profiling Based on Digital Pathology 2 | 3 | ## Description 4 | 5 | sc-MTOP is an analysis framework based on deep learning and computational pathology. 
It consists of two steps: 1) Nuclear segmentation and classification; 2) Feature extraction. This framework aims to characterize the tumor ecosystem diversity at the single-cell level. We have established a [demo](http://sc-mtop.biosolver.cn/) website to demonstrate the functions. 6 | 7 | This is the official PyTorch implementation of sc-MTOP. Note that only the nuclear segmentation and classification step supports batch processing. 8 | 9 | 10 | 1. `F1_CellSegment.py` for nuclear segmentation and classification: 11 | 12 | This step employs [HoVer-Net](https://github.com/vqdang/hover_net) for simultaneous nuclear segmentation and classification. The model is pre-trained on the PanNuke dataset and can be downloaded from [url](https://drive.google.com/file/d/1SbSArI3KOOWHxRlxnjchO7_MbWzB4lNR/view). 13 | 14 | Provide your WSI files as input. We use `.ndpi` WSI files in our work, and theoretically it supports all WSI file formats allowed by HoVer-Net. This step outputs a `.json` file including all information on nuclear segmentation and classification for each sample. 15 | 16 | 17 | 2. `F3_FeatureExtract.py` for feature extraction: 18 | 19 | This step extracts morphological, texture and topological features for individual tumor, inflammatory and stroma cells, which are the main cellular components of the breast cancer ecosystem. 20 | 21 | Provide your WSI files and the corresponding `.json` files output by the segmentation step as input. A region of interest (ROI) can be defined using an `.xml` annotation file generated by the [ImageScope](https://www.leicabiosystems.com/zh/digital-pathology/manage/aperio-imagescope/) software. For each sample, the feature extraction step outputs a folder containing four `.csv` data files. For each type of tumor, inflammatory and stroma cells, one `.csv` file stores the features for all cells belonging to this type, and each cell is identified by a unique cell ID together with the centroid’s spatial coordinates. The other `.csv` file stores the edge information for this sample and characterizes each edge by the connected cell IDs. 22 | 23 | 3. `F4_Visualization.py` for visualization: 24 | 25 | We provide an additional function for the visualization of the nuclear segmentation results and the nuclear graph. 26 | 27 | Provide the WSI files, the corresponding feature files output by the feature extraction step and an `.xml` annotation file defining the ROI. The visualization results will be written into the annotation file and can be viewed using the ImageScope software. Note that ImageScope may fail to open the annotation file if your ROI is too large. 28 | 29 | ## Requirements 30 | ### Packages and versions 31 | The required packages are listed in the file `requirements.txt`. 32 | ### Operating systems 33 | The code has been tested on Windows and on Ubuntu 16.04.7 LTS. The installation may differ between operating systems because of some packages. 34 | ### Hardware 35 | The code involves deep learning-based neural network inference, so it needs a GPU with more than 8GB of video memory. HoVer-Net needs an SSD with at least 100GB free for its cache. The RAM requirement depends on the data size; we suggest more than 128GB. The code has been tested on a GeForce GTX 2080Ti NVIDIA GPU with 128GB RAM. 
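A quick way to confirm that PyTorch can actually see your GPU before launching the pipeline (an optional sanity check; it assumes the environment from the Installation section below is already set up) is:
```
python -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"
```
If this prints `False 0`, check the NVIDIA driver and the CUDA build of `torch` before running the segmentation step.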
36 | 37 | ## Installation 38 | To install the environment, run the following command in the terminal: 39 | ``` 40 | pip install -r requirements.txt 41 | ``` 42 | The code requires the package `openslide python`, but its installation differs between Linux and Windows. Please follow the [official documentation](https://openslide.org/api/python/) to install it, and import it in Python to make sure it works correctly. 43 | The pre-trained HoVer-Net model is not provided in the source code due to its file size. You can download it following the [Description](#hovernet) or from our [release](https://github.com/fuscc-deep-path/sc_MTOP/releases/download/Demo/hovernet_fast_pannuke_type_tf2pytorch.tar). 44 | 45 | ## Repository Structure 46 | `Hover`: the implementation of HoVer-Net, cloned from the official [implementation](https://github.com/vqdang/hover_net). 47 | `main.py`: the main entry point. 48 | `F1_CellSegment.py`: nuclear segmentation and classification by calling `Hover`. 49 | `F3_FeatureExtract.py`: feature extraction by calling `WSIGraph.py`. 50 | `F4_Visualization.py`: visualization by calling `utils_xml.py`. 51 | `utils_xml.py`: defines helper tools used for visualization. 52 | `WSIGraph.py`: defines the feature extraction process. 53 | 54 | ## Usage Demo 55 | Here is a demo of using it in a bash terminal on Ubuntu. Some commands may not work in other terminals. 56 | To run the whole demo, you should first get the demo data and the pre-trained model parameters. Download them with the following commands: 57 | Download the pre-trained network parameters 58 | ``` 59 | wget --no-check-certificate --content-disposition -P ./Hover https://github.com/fuscc-deep-path/sc_MTOP/releases/download/Demo/hovernet_fast_pannuke_type_tf2pytorch.tar 60 | ``` 61 | Download the demo data 62 | ``` 63 | mkdir -p {wsi,xml,fun_fig} 64 | wget --no-check-certificate --content-disposition -P ./wsi https://github.com/fuscc-deep-path/sc_MTOP/releases/download/Demo/Example001.ndpi 65 | wget --no-check-certificate --content-disposition -P ./xml https://github.com/fuscc-deep-path/sc_MTOP/releases/download/Demo/Example001.xml 66 | ``` 67 | Nuclear segmentation and classification -- This step takes almost 2 hours with a 2080Ti GPU and an SSD. 68 | ``` 69 | python main.py segment --input_dir='./wsi' --output_dir='./output' 70 | ``` 71 | Feature extraction -- This step takes almost 40 minutes with 128GB RAM and 8 processes. 72 | ``` 73 | python main.py feature --json_path='./output/json/Example001.json' --wsi_path='./wsi/Example001.ndpi' --output_path='./feature' 74 | ``` 75 | Visualization -- the output will be in the 'fun_fig' directory 76 | ``` 77 | python main.py visual --feature_path='./feature/Example001' --wsi_path='./wsi/Example001.ndpi' --xml_path='./xml/Example001.xml' 78 | ``` 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | docopt==0.6.2 2 | imgaug==0.4.0 3 | matplotlib==3.6.2 4 | numpy==1.23.5 5 | opencv_python==4.6.0.66 6 | openslide_python==1.2.0 7 | pandas==1.5.2 8 | psutil==5.9.4 9 | python_igraph==0.10.2 10 | scikit_image==0.19.2 11 | scikit_learn==1.2.0 12 | scipy==1.9.3 13 | tensorboardX==2.5.1 14 | termcolor==2.1.1 15 | torch==1.12.1+cu102 16 | tqdm==4.64.1 17 | --extra-index-url https://download.pytorch.org/whl/cu102 18 | --------------------------------------------------------------------------------