├── .gitignore ├── LICENSE ├── README.md ├── feature_maps.gif ├── requirements.txt └── src ├── README.md ├── config.py ├── extract_patches.py ├── infer.py ├── loader ├── custom_augs.py └── loader.py ├── misc ├── patch_extractor.py ├── utils.py └── viz_utils.py ├── model ├── class_pcam │ └── graph.py ├── seg_gland │ └── graph.py ├── seg_nuc │ └── graph.py └── utils │ ├── gconv_utils.py │ ├── model_utils.py │ ├── norm_utils.py │ └── rotation_utils.py ├── opt ├── augs.py └── params.py ├── process.py ├── train.py └── viz_filters.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled/cached Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Temporaries 7 | .ipynb_checkpoints/ 8 | *.swp 9 | 10 | # Outputs 11 | *.png 12 | *.svg 13 | *.eps 14 | *.out 15 | 16 | # Model and other data 17 | *.pth 18 | *.npy 19 | .idea 20 | *.DS_Store 21 | .pytest_cache/ 22 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Simon Graham 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dense Steerable Filter CNNs for Exploiting Rotational Symmetry in Histology Images 2 | 3 | A densely connected rotation-equivariant CNN for histology image analysis.
4 | 5 | [Link](https://arxiv.org/abs/2004.03037) to the pre-print.
6 | 7 | **NEWS**: Our paper has now been published in IEEE Transactions on Medical Imaging. Find the published article [here](https://ieeexplore.ieee.org/document/9153847). 8 | 9 | ## Getting Started 10 | 11 | Environment instructions: 12 | 13 | ``` 14 | conda create --name dsf-cnn python=3.6 15 | conda activate dsf-cnn 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Repository Structure 20 | 21 | - `src/` contains executable files used to run the model. Further information on running the code can be found in the README of that directory. 22 | - `loader/` contains scripts for data loading and self-implemented augmentation functions. 23 | - `misc/` contains utility scripts. 24 | - `model/class_pcam/` contains the model architecture for DSF-CNN on the PCam dataset. 25 | - `model/seg_nuc/` contains the model architecture for DSF-CNN on the Kumar dataset. 26 | - `model/seg_gland/` contains the model architecture for DSF-CNN on the CRAG dataset. 27 | - `model/utils/` contains utility scripts for the models. 28 | - `opt/` contains scripts that define the model hyperparameters and augmentation pipeline. 29 | - `config.py` is the configuration file. Update the paths in it to match your setup. 30 | - `train.py` and `infer.py` are the training and inference scripts respectively. 31 | - `process.py` is the post-processing script for obtaining the final instances for segmentation. 32 | 33 |

34 | Segmentation 35 |

36 | 37 | ## Citation 38 | 39 | If any part of this code is used, please give appropriate citation to our paper.
40 | 41 | ``` 42 | @article{graham2020dense, 43 | title={Dense Steerable Filter CNNs for Exploiting Rotational Symmetry in Histology Images}, 44 | author={Graham, Simon and Epstein, David and Rajpoot, Nasir}, 45 | journal={arXiv preprint arXiv:2004.03037}, 46 | year={2020} 47 | } 48 | ``` 49 | 50 | ## Authors 51 | 52 | See the list of [contributors](https://github.com/simongraham/dsf-cnn/graphs/contributors) who participated in this project. 53 | 54 | ## License 55 | 56 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details 57 | -------------------------------------------------------------------------------- /feature_maps.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simongraham/dsf-cnn/69ba9fb4834e86c238ce866d3ddebfc3fe6c5c57/feature_maps.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | docopt==0.6.2 2 | tensorflow-gpu==1.12.0 3 | tensorpack==0.9.0.1 4 | scikit-image==0.14.2 5 | matplotlib===3.0.2 6 | numpy==1.15.4 7 | opencv-python==4.1.2.30 8 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | # Training and Inference Instructions 2 | 3 | ## Choosing the network 4 | 5 | The model to use and the selection of other hyperparameters is selected in `config.py`. The models available are: 6 | - Classification PCam DSF-CNN: `model/class_pcam/graph.py` 7 | - Segmentation CRAG DSF-CNN: `model/seg_gland/graph.py` 8 | - Segmentation Kumar DSF-CNN: `model/seg_nuc/graph.py` 9 | 10 | ## Modifying Hyperparameters 11 | 12 | To modify hyperparameters, refer to `opt/params.py` 13 | 14 | ## Augmentation 15 | 16 | To modify the augmentation pipeline, refer to `get_train_augmentors()` in `opt/augs.py`. 
Refer to [this webpage](https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html) for information on how to modify the augmentation parameters. 17 | 18 | ## Data Format 19 | 20 | For segmentation, store each patch as a numpy array of shape [H, W, 4] with channels [RGB, inst]. Here, inst is the instance segmentation ground truth, i.e. pixel values range from 0 to N, where 0 is background and N is the number of nuclear instances for that particular image.
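As a rough illustration of the segmentation layout just described, the sketch below stacks an RGB patch and its instance map into a single [H, W, 4] array. The helper name `make_seg_patch` and the toy 8x8 data are assumptions made for this example and are not part of the repository; the int32 cast mirrors what `extract_patches.py` does to the annotation before concatenating.

```python
import numpy as np


def make_seg_patch(rgb, inst):
    """Stack an RGB patch and its instance map into one [H, W, 4] array.

    Hypothetical helper for illustration only; not part of the repository.
    """
    if rgb.ndim != 3 or rgb.shape[-1] != 3:
        raise ValueError("rgb must have shape [H, W, 3]")
    if inst.shape != rgb.shape[:2]:
        raise ValueError("inst must have shape [H, W]")
    # int32 keeps instance ids exact and still holds 0-255 RGB values
    inst = inst.astype("int32")[..., None]
    return np.concatenate([rgb.astype("int32"), inst], axis=-1)


# Toy 8x8 patch: 0 is background, instances are labelled 1 and 2
rgb = np.zeros((8, 8, 3), dtype="uint8")
inst = np.zeros((8, 8), dtype="int32")
inst[1:3, 1:3] = 1
inst[5:7, 4:7] = 2

patch = make_seg_patch(rgb, inst)
print(patch.shape, patch[..., 3].max())
```

An array built this way can then be written with `np.save`, one `.npy` file per patch, into the directories listed in `config.py`.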
For classification, save each image patch as a 1D array of size [(H * W * C)+1], where H, W and C refer to height, width and number of channels. The final value is the label of the image (starting from 0). 21 | 22 | ## Training 23 | 24 | To train the network, the command is:
25 | 26 | `python train.py --gpu=<gpu_id>`
27 | where `gpu_id` is a comma-separated list of the GPU(s) to use for training. For example, to use GPUs 0 and 1, the command is:
28 | `python train.py --gpu='0,1'`
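For reference, `infer.py` in this repository copies the `--gpu` string into `CUDA_VISIBLE_DEVICES` and counts its comma-separated entries. The sketch below shows that pattern in isolation; it is an assumption that `train.py` (not shown here) handles the flag the same way, and the helper name `select_gpus` is hypothetical.

```python
import os


def select_gpus(gpu_arg):
    """Expose the GPUs named in a comma-separated string and count them.

    Hypothetical helper; it mirrors the --gpu handling seen in infer.py.
    """
    gpu_ids = [g.strip() for g in gpu_arg.split(",") if g.strip()]
    if not gpu_ids:
        raise ValueError("expected at least one GPU id, e.g. '0' or '0,1'")
    # Must be set before TensorFlow initialises CUDA to take effect
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
    return len(gpu_ids)


nr_gpus = select_gpus("0,1")
print(nr_gpus, os.environ["CUDA_VISIBLE_DEVICES"])
```

Once the variable is set, only the listed devices are visible to the process, and they are renumbered from 0 in the order given.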
29 | 30 | Before training, set in `config.py`: 31 | - path to the data directories 32 | - path where checkpoints will be saved 33 | - path to where the output will be saved 34 | 35 | ## Inference 36 | 37 | To generate the network predictions, the command is:
38 | `python infer.py --gpu=<gpu_id> --mode=<mode>`
39 | Currently, the inference code only supports 1 GPU. For `--mode`, use `'seg'` or `'class'`. Use `'class'` when processing PCam and `'seg'` when processing Kumar and CRAG. 40 | 41 | Before running inference, set in `config.py`: 42 | - path where the output will be saved 43 | - path to data root directories 44 | - path to model checkpoint. 45 | 46 | ## Post Processing 47 | 48 | To obtain the final segmentation, use the command:
49 | `python process.py`
50 | for post-processing the network predictions. Note, this is only for segmentation networks. 51 | 52 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration file 3 | """ 4 | 5 | import importlib 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import opt.augs as augs 10 | 11 | from model.utils.gconv_utils import get_basis_filters, get_rot_info, get_basis_params 12 | 13 | 14 | class Config(object): 15 | def __init__(self,): 16 | 17 | self.seed = 10 18 | 19 | self.model_mode = "seg_nuc" # choose seg_gland, seg_nuc, class_pcam 20 | self.filter_type = "steerable" # choose steerable or standard 21 | 22 | self.nr_orients = 8 # number of orientations for the filters 23 | 24 | #### Dynamically setting the hyper-param and aug files into variable 25 | param_file = importlib.import_module("opt.params") 26 | param_dict = param_file.__getattribute__(self.model_mode) 27 | 28 | for variable, value in param_dict.items(): 29 | self.__setattr__(variable, value) 30 | 31 | self.data_ext = ".npy" 32 | # list of directories containing training and validation files. 33 | # Each directory contains one numpy file per image. 
34 | # if self.model_mode is 'seg_nuc' or 'seg_gland', 35 | # data is of size [H,W,4] (RGB + instance label) 36 | # if self.model_mode is 'class_pcam', data is of size [(H*W*C)+1], 37 | # where the final value is the class label 38 | self.train_dir = ["/media/simon/Storage 1/Data/Nuclei/patches/kumar/train/"] 39 | self.valid_dir = ["/media/simon/Storage 1/Data/Nuclei/patches/kumar/valid/"] 40 | 41 | # nr of processes for parallel processing input 42 | self.nr_procs_train = 8 43 | self.nr_procs_valid = 4 44 | 45 | exp_id = "v1.0" 46 | # loading chkpts in tensorflow, the path must not contain extra '/' 47 | self.log_path = "/media/simon/Storage 1/dsf-cnn/checkpoints/" # log root path 48 | self.save_dir = "%s/%s/%s_%s_%s" % ( 49 | self.log_path, 50 | self.model_mode, 51 | self.filter_type, 52 | self.nr_orients, 53 | exp_id, 54 | ) # log file destination 55 | 56 | #### Info for running inference 57 | self.inf_auto_find_chkpt = False 58 | # path to the checkpoint that will be used for inference, replace accordingly 59 | self.inf_model_path = self.save_dir + "/model-xxxx.index" 60 | 61 | # paths to files for inference. Note, for PCam we use the original .h5 file.
62 | self.inf_imgs_ext = ".tif" 63 | self.inf_data_list = ["/media/simon/Storage 1/Data/Nuclei/kumar/test/Images/"] 64 | 65 | output_root = "/media/simon/Storage 1/output/" 66 | # inference output destination 67 | self.inf_output_dir = "%s/%s_%s/" % ( 68 | output_root, 69 | self.filter_type, 70 | self.nr_orients, 71 | ) 72 | 73 | if self.filter_type == "steerable": 74 | # Generate the basis filters - only need to do this once before training 75 | self.basis_filter_list = [] 76 | self.rot_matrix_list = [] 77 | for ksize in self.filter_sizes: 78 | alpha_list, beta_list, bl_list = get_basis_params(ksize) 79 | b_filters, freq_filters = get_basis_filters( 80 | alpha_list, beta_list, bl_list, ksize 81 | ) 82 | self.basis_filter_list.append(b_filters) 83 | self.rot_matrix_list.append(get_rot_info(self.nr_orients, freq_filters)) 84 | 85 | # for inference during evaluation mode, i.e. run by infer.py 86 | self.eval_inf_input_tensor_names = ["images"] 87 | self.eval_inf_output_tensor_names = ["predmap-coded"] 88 | # for inference during training mode, i.e. run by train.py 89 | self.train_inf_output_tensor_names = ["predmap-coded", "truemap-coded"] 90 | 91 | def get_model(self): 92 | model_constructor = importlib.import_module("model.%s.graph" % self.model_mode) 93 | model_constructor = model_constructor.Graph 94 | return model_constructor 95 | 96 | # refer to https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html 97 | # for information on how to modify the augmentation parameters.
98 | # Pipeline can be modified in opt/augs.py 99 | 100 | def get_train_augmentors(self, input_shape, output_shape, view=False): 101 | return augs.get_train_augmentors(self, input_shape, output_shape, view) 102 | 103 | def get_valid_augmentors(self, input_shape, output_shape, view=False): 104 | return augs.get_valid_augmentors(self, input_shape, output_shape, view) 105 | -------------------------------------------------------------------------------- /src/extract_patches.py: -------------------------------------------------------------------------------- 1 | """extract_patches.py 2 | 3 | Script for extracting patches from image tiles. The script will read 4 | an RGB image and a corresponding label and form image patches to be 5 | used by the network. 6 | """ 7 | 8 | 9 | import glob 10 | import os 11 | 12 | import cv2 13 | import numpy as np 14 | 15 | from misc.patch_extractor import PatchExtractor 16 | from misc.utils import rm_n_mkdir 17 | 18 | from config import Config 19 | 20 | ########################################################################### 21 | if __name__ == "__main__": 22 | 23 | cfg = Config() 24 | 25 | extract_type = "mirror" # 'valid' or 'mirror' 26 | # 'mirror' reflects at the borders; 'valid' doesn't. 27 | # check the patch_extractor.py 'main' to see the difference 28 | 29 | # original size (win size) - input size - output size (step size) 30 | step_size = [112, 112] 31 | # set to size of network input: 448 for glands, 256 for nuclei 32 | win_size = [448, 448] 33 | 34 | xtractor = PatchExtractor(win_size, step_size) 35 | 36 | ### Paths to data - these need to be modified according to where the original data is stored 37 | img_ext = ".png" 38 | # img_dir should contain RGB image tiles from where to extract patches. 39 | img_dir = "path/to/images/" 40 | # ann_dir should contain 2D npy image tiles, with values ranging from 0 to N. 41 | # 0 is background and then each nucleus is uniquely labelled from 1-N.
42 | ann_dir = "path/to/labels/" 43 | #### 44 | out_dir = "output_path/%dx%d_%dx%d" % ( 45 | win_size[0], 46 | win_size[1], 47 | step_size[0], 48 | step_size[1], 49 | ) 50 | 51 | file_list = glob.glob("%s/*%s" % (img_dir, img_ext)) 52 | file_list.sort() 53 | 54 | rm_n_mkdir(out_dir) 55 | for filename in file_list: 56 | filename = os.path.basename(filename) 57 | basename = filename.split(".")[0] 58 | print(filename) 59 | 60 | img = cv2.imread(img_dir + basename + img_ext) 61 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 62 | 63 | # assumes that ann is HxW 64 | ann_inst = np.load(ann_dir + basename + ".npy") 65 | ann_inst = ann_inst.astype("int32") 66 | ann = np.expand_dims(ann_inst, -1) 67 | 68 | img = np.concatenate([img, ann], axis=-1) 69 | sub_patches = xtractor.extract(img, extract_type) 70 | for idx, patch in enumerate(sub_patches): 71 | np.save("{0}/{1}_{2:03d}.npy".format(out_dir, basename, idx), patch) 72 | -------------------------------------------------------------------------------- /src/infer.py: -------------------------------------------------------------------------------- 1 | """infer.py 2 | 3 | Main inference script. 4 | 5 | Usage: 6 | infer.py [--gpu=] [--mode=] 7 | infer.py (-h | --help) 8 | infer.py --version 9 | 10 | Options: 11 | -h --help Show this string. 12 | --version Show version. 13 | --gpu= Comma separated GPU list. [default: 0] 14 | --mode= Inference mode- use either 'seg' or 'class'. 
15 | """ 16 | 17 | from docopt import docopt 18 | import argparse 19 | import glob 20 | import math 21 | import time 22 | import os 23 | import sys 24 | from collections import deque 25 | from keras.utils import HDF5Matrix 26 | 27 | import cv2 28 | import numpy as np 29 | 30 | from tensorpack.predict import OfflinePredictor, PredictConfig 31 | from tensorpack.tfutils.sessinit import get_model_loader 32 | 33 | from config import Config 34 | from misc.utils import rm_n_mkdir, cropping_center 35 | from model.utils.model_utils import crop_op 36 | 37 | import json 38 | import operator 39 | 40 | from sklearn.metrics import roc_auc_score 41 | 42 | 43 | def get_best_chkpts(path, metric_name, comparator=">"): 44 | """ 45 | Return the best checkpoint according to some criteria. 46 | Note that it will only return valid path, so any checkpoint that has been 47 | removed wont be returned (i.e moving to next one that satisfies the criteria 48 | such as second best etc.) 49 | 50 | Args: 51 | path: directory contains all checkpoints, including the "stats.json" file 52 | """ 53 | stat_file = path + "/stats.json" 54 | ops = { 55 | ">": operator.gt, 56 | "<": operator.lt, 57 | } 58 | 59 | op_func = ops[comparator] 60 | with open(stat_file) as f: 61 | info = json.load(f) 62 | 63 | if comparator == ">": 64 | best_value = -float("inf") 65 | else: 66 | best_value = +float("inf") 67 | 68 | best_chkpt = None 69 | for epoch_stat in info: 70 | epoch_value = epoch_stat[metric_name] 71 | if op_func(epoch_value, best_value): 72 | chkpt_path = "%s/model-%d.index" % (path, epoch_stat["global_step"]) 73 | if os.path.isfile(chkpt_path): 74 | selected_stat = epoch_stat 75 | best_value = epoch_value 76 | best_chkpt = chkpt_path 77 | return best_chkpt, selected_stat 78 | 79 | 80 | class InferClass(Config): 81 | def __gen_prediction(self, x, predictor): 82 | """ 83 | Using 'predictor' to generate the prediction of image 'x' 84 | 85 | Args: 86 | x : input image to be segmented. 
It will be split into patches 87 | to run the prediction upon before being assembled back 88 | tissue: tissue mask created via otsu -> only process tissue regions 89 | if it is provided 90 | """ 91 | 92 | prob = predictor(x)[0] 93 | pred = np.argmax(prob, -1) 94 | pred = np.squeeze(pred) 95 | prob = np.squeeze(prob[..., 1]) 96 | 97 | return prob, pred 98 | 99 | #### 100 | def run(self): 101 | if self.inf_auto_find_chkpt: 102 | print( 103 | '-----Auto Selecting Checkpoint Basing On "%s" Through "%s" Comparison' 104 | % (self.inf_auto_metric, self.inf_auto_comparator) 105 | ) 106 | model_path, stat = get_best_chkpts( 107 | self.save_dir, self.inf_auto_metric, self.inf_auto_comparator 108 | ) 109 | print("Selecting: %s" % model_path) 110 | print("Having Following Statistics:") 111 | for key, value in stat.items(): 112 | print("\t%s: %s" % (key, value)) 113 | else: 114 | model_path = self.inf_model_path 115 | model_constructor = self.get_model() 116 | pred_config = PredictConfig( 117 | model=model_constructor(), 118 | session_init=get_model_loader(model_path), 119 | input_names=self.eval_inf_input_tensor_names, 120 | output_names=self.eval_inf_output_tensor_names, 121 | create_graph=False) 122 | predictor = OfflinePredictor(pred_config) 123 | 124 | 125 | #### 126 | save_dir = self.inf_output_dir 127 | predict_list = [["case", "prediction"]] 128 | 129 | file_load_img = HDF5Matrix( 130 | self.inf_data_list[0] + "camelyonpatch_level_2_split_test_x.h5", "x" 131 | ) 132 | file_load_lab = HDF5Matrix( 133 | self.inf_data_list[0] + "camelyonpatch_level_2_split_test_y.h5", "y" 134 | ) 135 | 136 | true_list = [] 137 | prob_list = [] 138 | pred_list = [] 139 | 140 | num_ims = file_load_img.shape[0] 141 | last_step = math.floor(num_ims / self.inf_batch_size) 142 | last_step = self.inf_batch_size * last_step 143 | last_batch = num_ims - last_step 144 | count = 0 145 | for start_batch in range(0, last_step + 1, self.inf_batch_size): 146 | sys.stdout.write("\rProcessed (%d/%d)" % 
(start_batch, num_ims)) 147 | sys.stdout.flush() 148 | if start_batch != last_step: 149 | img = file_load_img[start_batch : start_batch + self.inf_batch_size] 150 | img = img.astype("uint8") 151 | lab = np.squeeze( 152 | file_load_lab[start_batch : start_batch + self.inf_batch_size] 153 | ) 154 | else: 155 | img = file_load_img[start_batch : start_batch + last_batch] 156 | img = img.astype("uint8") 157 | lab = np.squeeze(file_load_lab[start_batch : start_batch + last_batch]) 158 | 159 | prob, pred = self.__gen_prediction(img, predictor) 160 | 161 | for j in range(prob.shape[0]): 162 | predict_list.append([str(count), str(prob[j])]) 163 | count += 1 164 | 165 | prob_list.extend(prob) 166 | pred_list.extend(pred) 167 | true_list.extend(lab) 168 | 169 | prob_list = np.array(prob_list) 170 | pred_list = np.array(pred_list) 171 | true_list = np.array(true_list) 172 | accuracy = (pred_list == true_list).sum() / np.size(true_list) 173 | error = (pred_list != true_list).sum() / np.size(true_list) 174 | 175 | print("Accuracy (%): ", 100 * accuracy) 176 | print("Error (%): ", 100 * error) 177 | if self.model_mode == "class_pcam": 178 | auc = roc_auc_score(true_list, prob_list) 179 | print("AUC: ", auc) 180 | 181 | # Save predictions to csv 182 | rm_n_mkdir(save_dir) 183 | for result in predict_list: 184 | predict_file = open("%s/predict.csv" % save_dir, "a") 185 | predict_file.write(result[0]) 186 | predict_file.write(",") 187 | predict_file.write(result[1]) 188 | predict_file.write("\n") 189 | predict_file.close() 190 | 191 | 192 | class InferSeg(Config): 193 | def __gen_prediction(self, x, predictor): 194 | """ 195 | Using 'predictor' to generate the prediction of image 'x' 196 | 197 | Args: 198 | x : input image to be segmented.
It will be split into patches 199 | to run the prediction upon before being assembled back 200 | tissue: tissue mask created via otsu -> only process tissue regions 201 | if it is provided 202 | """ 203 | step_size = self.infer_output_shape 204 | msk_size = self.infer_output_shape 205 | win_size = self.infer_input_shape 206 | 207 | def get_last_steps(length, msk_size, step_size): 208 | nr_step = math.ceil((length - msk_size) / step_size) 209 | last_step = (nr_step + 1) * step_size 210 | return int(last_step), int(nr_step + 1) 211 | 212 | im_h = x.shape[0] 213 | im_w = x.shape[1] 214 | 215 | last_h, nr_step_h = get_last_steps(im_h, msk_size[0], step_size[0]) 216 | last_w, nr_step_w = get_last_steps(im_w, msk_size[1], step_size[1]) 217 | 218 | diff_h = win_size[0] - step_size[0] 219 | padt = diff_h // 2 220 | padb = last_h + win_size[0] - im_h 221 | 222 | diff_w = win_size[1] - step_size[1] 223 | padl = diff_w // 2 224 | padr = last_w + win_size[1] - im_w 225 | 226 | x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), "reflect") 227 | 228 | #### TODO: optimize this 229 | sub_patches = [] 230 | skipped_idx = [] 231 | # generating subpatches from orginal 232 | idx = 0 233 | for row in range(0, last_h, step_size[0]): 234 | for col in range(0, last_w, step_size[1]): 235 | win = x[row : row + win_size[0], col : col + win_size[1]] 236 | sub_patches.append(win) 237 | idx += 1 238 | 239 | pred_map = deque() 240 | while len(sub_patches) > self.inf_batch_size: 241 | mini_batch = sub_patches[: self.inf_batch_size] 242 | sub_patches = sub_patches[self.inf_batch_size :] 243 | mini_output = predictor(mini_batch)[0] 244 | if win_size[0] > msk_size[0]: 245 | mini_output = cropping_center(mini_output, (diff_h, diff_w)) 246 | mini_output = np.split(mini_output, self.inf_batch_size, axis=0) 247 | pred_map.extend(mini_output) 248 | if len(sub_patches) != 0: 249 | mini_output = predictor(sub_patches)[0] 250 | if win_size[0] > msk_size[0]: 251 | mini_output = 
cropping_center(mini_output, (diff_h, diff_w)) 252 | mini_output = np.split(mini_output, len(sub_patches), axis=0) 253 | pred_map.extend(mini_output) 254 | 255 | #### Assemble back into full image 256 | output_patch_shape = np.squeeze(pred_map[0]).shape 257 | 258 | ch = 1 if len(output_patch_shape) == 2 else output_patch_shape[-1] 259 | 260 | #### Assemble back into full image 261 | pred_map = np.squeeze(np.array(pred_map)) 262 | pred_map = np.reshape(pred_map, (nr_step_h, nr_step_w) + pred_map.shape[1:]) 263 | pred_map = ( 264 | np.transpose(pred_map, [0, 2, 1, 3, 4]) 265 | if ch != 1 266 | else np.transpose(pred_map, [0, 2, 1, 3]) 267 | ) 268 | pred_map = np.reshape( 269 | pred_map, 270 | ( 271 | pred_map.shape[0] * pred_map.shape[1], 272 | pred_map.shape[2] * pred_map.shape[3], 273 | ch, 274 | ), 275 | ) 276 | # just crop back to original size 277 | pred_map = np.squeeze(pred_map[:im_h, :im_w]) 278 | 279 | return pred_map 280 | 281 | #### 282 | def run(self): 283 | 284 | if self.inf_auto_find_chkpt: 285 | print( 286 | '-----Auto Selecting Checkpoint Basing On "%s" Through "%s" Comparison' 287 | % (self.inf_auto_metric, self.inf_auto_comparator) 288 | ) 289 | model_path, stat = get_best_chkpts( 290 | self.save_dir, self.inf_auto_metric, self.inf_auto_comparator 291 | ) 292 | print("Selecting: %s" % model_path) 293 | print("Having Following Statistics:") 294 | for key, value in stat.items(): 295 | print("\t%s: %s" % (key, value)) 296 | else: 297 | model_path = self.inf_model_path 298 | 299 | model_constructor = self.get_model() 300 | pred_config = PredictConfig( 301 | model=model_constructor(), 302 | session_init=get_model_loader(model_path), 303 | input_names=self.eval_inf_input_tensor_names, 304 | output_names=self.eval_inf_output_tensor_names, 305 | create_graph=False, 306 | ) 307 | predictor = OfflinePredictor(pred_config) 308 | 309 | for data_dir in self.inf_data_list: 310 | save_dir = self.inf_output_dir + "/raw/" 311 | file_list = glob.glob("%s/*%s" % 
(data_dir, self.inf_imgs_ext)) 312 | file_list.sort() # ensure same order 313 | 314 | rm_n_mkdir(save_dir) 315 | for filename in file_list: 316 | start = time.time() 317 | filename = os.path.basename(filename) 318 | basename = filename.split(".")[0] 319 | print(data_dir, basename, end=" ", flush=True) 320 | 321 | ## 322 | img = cv2.imread(data_dir + filename) 323 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 324 | 325 | pred_map = self.__gen_prediction(img, predictor) 326 | 327 | np.save("%s/%s.npy" % (save_dir, basename), [pred_map]) 328 | end = time.time() 329 | diff = str(round(end - start, 2)) 330 | print("FINISH. TIME: %s" % diff) 331 | 332 | 333 | #### 334 | if __name__ == "__main__": 335 | args = docopt(__doc__) 336 | print(args) 337 | 338 | if args["--gpu"]: 339 | os.environ["CUDA_VISIBLE_DEVICES"] = args["--gpu"] 340 | nr_gpus = len(args["--gpu"].split(",")) 341 | 342 | if args["--mode"] is None: 343 | raise Exception('Mode cannot be empty. Use either "class" or "seg".') 344 | 345 | if args["--mode"] == "class": 346 | infer = InferClass() 347 | elif args["--mode"] == "seg": 348 | infer = InferSeg() 349 | else: 350 | raise Exception('Mode not recognised. 
Use either "class" or "seg".') 351 | 352 | infer.run() 353 | -------------------------------------------------------------------------------- /src/loader/custom_augs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom augmentations 3 | """ 4 | 5 | import math 6 | import cv2 7 | import matplotlib.cm as cm 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | from scipy import ndimage 12 | from scipy.ndimage import measurements 13 | from scipy.ndimage.filters import gaussian_filter 14 | from scipy.ndimage.interpolation import affine_transform, map_coordinates 15 | from scipy.ndimage.morphology import distance_transform_cdt, distance_transform_edt 16 | from skimage import morphology as morph 17 | 18 | from tensorpack.dataflow.imgaug import ImageAugmentor 19 | from tensorpack.utils.utils import get_rng 20 | 21 | from misc.utils import cropping_center, bounding_box 22 | 23 | 24 | class GenInstance(ImageAugmentor): 25 | def __init__(self, crop_shape=None): 26 | super(GenInstance, self).__init__() 27 | self.crop_shape = crop_shape 28 | 29 | def reset_state(self): 30 | self.rng = get_rng(self) 31 | 32 | def _fix_mirror_padding(self, ann): 33 | """ 34 | Deal with duplicated instances due to mirroring in interpolation 35 | during shape augmentation (scale, rotation etc.) 36 | """ 37 | current_max_id = np.amax(ann) 38 | inst_list = list(np.unique(ann)) 39 | try: 40 | inst_list.remove(0) # remove background 41 | except ValueError: 42 | pass 43 | for inst_id in inst_list: 44 | inst_map = np.array(ann == inst_id, np.uint8) 45 | remapped_ids = measurements.label(inst_map)[0] 46 | remapped_ids[remapped_ids > 1] += current_max_id 47 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1] 48 | current_max_id = np.amax(ann) 49 | return ann 50 | 51 | 52 | class GenInstanceContourMap(GenInstance): 53 | """ 54 | Input annotation must be of original shape. 
55 | """ 56 | 57 | def __init__(self, mode, crop_shape=None): 58 | super(GenInstanceContourMap, self).__init__() 59 | self.crop_shape = crop_shape 60 | self.mode = mode 61 | 62 | def _augment(self, img, _): 63 | img = np.copy(img) 64 | orig_ann = img[..., 0] # instance ID map 65 | fixed_ann = self._fix_mirror_padding(orig_ann) 66 | fixed_ann = orig_ann 67 | # re-cropping with fixed instance id map 68 | # crop_ann = cropping_center(fixed_ann, self.crop_shape) 69 | 70 | # setting 1 boundary pix of each instance to background 71 | inner_map = np.zeros(fixed_ann.shape[:2], np.uint8) 72 | contour_map = np.zeros(fixed_ann.shape[:2], np.uint8) 73 | 74 | inst_list = list(np.unique(fixed_ann)) 75 | 76 | try: # remove background 77 | inst_list.remove(0) # 0 is background 78 | except ValueError: 79 | pass 80 | 81 | if self.mode == "seg_gland": 82 | k_disk = np.array( 83 | [ 84 | [0, 0, 0, 0, 1, 0, 0, 0, 0], 85 | [0, 0, 0, 1, 1, 1, 0, 0, 0], 86 | [0, 0, 1, 1, 1, 1, 1, 0, 0], 87 | [0, 1, 1, 1, 1, 1, 1, 1, 0], 88 | [1, 1, 1, 1, 1, 1, 1, 1, 1], 89 | [0, 1, 1, 1, 1, 1, 1, 1, 0], 90 | [0, 0, 1, 1, 1, 1, 1, 0, 0], 91 | [0, 0, 0, 1, 1, 1, 0, 0, 0], 92 | [0, 0, 0, 0, 1, 0, 0, 0, 0], 93 | ], 94 | np.uint8, 95 | ) 96 | else: 97 | k_disk = np.array( 98 | [ 99 | [0, 0, 1, 0, 0], 100 | [0, 1, 1, 1, 0], 101 | [1, 1, 1, 1, 1], 102 | [0, 1, 1, 1, 0], 103 | [0, 0, 1, 0, 0], 104 | ], 105 | np.uint8, 106 | ) 107 | 108 | for inst_id in inst_list: 109 | inst_map = np.array(fixed_ann == inst_id, np.uint8) 110 | inner = cv2.erode(inst_map, k_disk, iterations=1) 111 | outer = cv2.dilate(inst_map, k_disk, iterations=1) 112 | inner_map += inner 113 | contour_map += outer - inner 114 | inner_map[inner_map > 0] = 1 # binarize 115 | contour_map[contour_map > 0] = 1 # binarize 116 | bg_map = 1 - (inner_map + contour_map) 117 | img = np.dstack([inner_map, contour_map, bg_map, img[..., 1:]]) 118 | return img 119 | 120 | 121 | class GenInstanceMarkerMap(GenInstance): 122 | """ 123 | Input annotation must be 
of original shape. 124 | Perform following operation: 125 | 1) Remove the 1px of boundary of each instance 126 | to create separation between touching instances 127 | 2) Generate the weight map from the result of 1) 128 | according to the unet paper equation. 129 | Args: 130 | wc (dict) : Dictionary of weight classes. 131 | w0 (int/float) : Border weight parameter. 132 | sigma (int/float): Border width parameter. 133 | """ 134 | 135 | def __init__(self, wc=None, w0=10.0, sigma=4.0, crop_shape=None): 136 | super(GenInstanceMarkerMap, self).__init__() 137 | self.crop_shape = crop_shape 138 | self.wc = wc 139 | self.w0 = w0 140 | self.sigma = sigma 141 | 142 | def _erode_obj(self, ann): 143 | new_ann = np.zeros(ann.shape[:2], np.int32) 144 | inst_list = list(np.unique(ann)) 145 | inst_list.remove(0) # 0 is background 146 | 147 | inner_map = np.zeros(ann.shape[:2], np.uint8) 148 | contour_map = np.zeros(ann.shape[:2], np.uint8) 149 | 150 | k = np.array( 151 | [ 152 | [0, 0, 0, 1, 0, 0, 0], 153 | [0, 0, 1, 1, 1, 0, 0], 154 | [0, 1, 1, 1, 1, 1, 0], 155 | [1, 1, 1, 1, 1, 1, 1], 156 | [0, 1, 1, 1, 1, 1, 0], 157 | [0, 0, 1, 1, 1, 0, 0], 158 | [0, 0, 0, 1, 0, 0, 0], 159 | ], 160 | np.uint8, 161 | ) 162 | 163 | for inst_id in inst_list: 164 | inst_map = np.array(ann == inst_id, np.uint8) 165 | inner = cv2.erode(inst_map, k, iterations=1) 166 | outer = cv2.dilate(inst_map, k, iterations=1) 167 | inner_map += inner 168 | contour_map += outer - inner 169 | inner_map[inner_map > 0] = 1 # binarize 170 | contour_map[contour_map > 0] = 1 # binarize 171 | bg_map = 1 - (inner_map + contour_map) 172 | 173 | return inner_map, contour_map, bg_map 174 | 175 | def _get_weight_map(self, ann, inst_list): 176 | if len(inst_list) <= 1: # 1 instance only 177 | return np.zeros(ann.shape[:2]) 178 | stacked_inst_bgd_dst = np.zeros(ann.shape[:2] + (len(inst_list),)) 179 | 180 | for idx, inst_id in enumerate(inst_list): 181 | inst_bgd_map = np.array(ann != inst_id, np.uint8) 182 | inst_bgd_dst = 
distance_transform_edt(inst_bgd_map) 183 | stacked_inst_bgd_dst[..., idx] = inst_bgd_dst 184 | 185 | near1_dst = np.amin(stacked_inst_bgd_dst, axis=2) 186 | near2_dst = np.expand_dims(near1_dst, axis=2) 187 | near2_dst = stacked_inst_bgd_dst - near2_dst 188 | near2_dst[near2_dst == 0] = np.inf # very large 189 | near2_dst = np.amin(near2_dst, axis=2) 190 | near2_dst[ann > 0] = 0 # the instances 191 | near2_dst = near2_dst + near1_dst 192 | # to fix pixel where near1 == near2 193 | near2_eve = np.expand_dims(near1_dst, axis=2) 194 | # to avoid the warning of a / 0 195 | near2_eve = (1.0 + stacked_inst_bgd_dst) / (1.0 + near2_eve) 196 | near2_eve[near2_eve != 1] = 0 197 | near2_eve = np.sum(near2_eve, axis=2) 198 | near2_dst[near2_eve > 1] = near1_dst[near2_eve > 1] 199 | # 200 | pix_dst = near1_dst + near2_dst 201 | pen_map = pix_dst / self.sigma 202 | pen_map = self.w0 * np.exp(-(pen_map ** 2) / 2) 203 | pen_map[ann > 0] = 0 # inner instances zero 204 | return pen_map 205 | 206 | def _augment(self, img, _): 207 | img = np.copy(img) 208 | orig_ann = img[..., 0] # instance ID map 209 | orig_ann_copy = orig_ann.copy() 210 | fixed_ann = self._fix_mirror_padding(orig_ann) 211 | # setting 1 boundary pix of each instance to background 212 | inner_map, contour_map, bg_map = self._erode_obj(fixed_ann) 213 | 214 | # can't do the shortcut because near2 also needs instances 215 | # outside of cropped portion 216 | inst_list = list(np.unique(fixed_ann)) 217 | inst_list.remove(0) # 0 is background 218 | wmap = self._get_weight_map(fixed_ann, inst_list) 219 | 220 | if self.wc is None: 221 | wmap += 1 # uniform weight for all classes 222 | else: 223 | class_weights = np.zeros(fixed_ann.shape[:2]) 224 | for class_id, class_w in self.wc.items(): 225 | class_weights[fixed_ann == class_id] = class_w 226 | wmap += class_weights 227 | 228 | # fix other maps to align 229 | img[fixed_ann == 0] = 0 230 | orig_ann[orig_ann > 0] = 1 231 | img = np.dstack([orig_ann_copy, inner_map, 
contour_map, bg_map, wmap]) 232 | 233 | return img 234 | 235 | 236 | class GaussianBlur(ImageAugmentor): 237 | """ 238 | Gaussian blur the image with random window size 239 | """ 240 | 241 | def __init__(self, max_size=3): 242 | """ 243 | Args: 244 | max_size (int): max possible Gaussian window size would be 2 * max_size + 1 245 | """ 246 | super(GaussianBlur, self).__init__() 247 | self.max_size = max_size 248 | 249 | def _get_augment_params(self, img): 250 | sx, sy = self.rng.randint(1, self.max_size, size=(2,)) 251 | sx = sx * 2 + 1 252 | sy = sy * 2 + 1 253 | return sx, sy 254 | 255 | def _augment(self, img, s): 256 | return np.reshape( 257 | cv2.GaussianBlur( 258 | img, s, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE 259 | ), 260 | img.shape, 261 | ) 262 | 263 | 264 | class BinarizeLabel(ImageAugmentor): 265 | """ 266 | Convert labels to binary maps 267 | """ 268 | 269 | def __init__(self): 270 | super(BinarizeLabel, self).__init__() 271 | 272 | def _get_augment_params(self, img): 273 | return None 274 | 275 | def _augment(self, img, s): 276 | img = np.copy(img) 277 | arr = img[..., 0] 278 | arr[arr > 0] = 1 279 | return img 280 | 281 | 282 | class MedianBlur(ImageAugmentor): 283 | """ 284 | Median blur the image with random window size 285 | """ 286 | 287 | def __init__(self, max_size=3): 288 | """ 289 | Args: 290 | max_size (int): max possible window size 291 | would be 2 * max_size + 1 292 | """ 293 | super(MedianBlur, self).__init__() 294 | self.max_size = max_size 295 | 296 | def _get_augment_params(self, img): 297 | s = self.rng.randint(1, self.max_size) 298 | s = s * 2 + 1 299 | return s 300 | 301 | def _augment(self, img, ksize): 302 | return cv2.medianBlur(img, ksize) 303 | -------------------------------------------------------------------------------- /src/loader/loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset loader 3 | """ 4 | 5 | import random 6 | import matplotlib.pyplot as 
plt 7 | import numpy as np 8 | 9 | from tensorpack.dataflow import ( 10 | AugmentImageComponent, 11 | AugmentImageComponents, 12 | BatchData, 13 | BatchDataByShape, 14 | CacheData, 15 | PrefetchDataZMQ, 16 | RNGDataFlow, 17 | RepeatedData, 18 | ) 19 | 20 | from config import Config 21 | 22 | 23 | class DatasetSerial(RNGDataFlow, Config): 24 | """ 25 | Produce ``(image, label)`` pair, where 26 | ``image`` has shape HWC and is RGB, has values in range [0-255]. 27 | 28 | ``label`` is a float image of shape (H, W, C). Number of C depends 29 | on `self.model_mode` within `config.py` 30 | """ 31 | 32 | def __init__(self, path_list): 33 | super(DatasetSerial, self).__init__() 34 | self.path_list = path_list 35 | 36 | ## 37 | 38 | def size(self): 39 | return len(self.path_list) 40 | 41 | ## 42 | 43 | def get_data(self): 44 | idx_list = list(range(0, len(self.path_list))) 45 | random.shuffle(idx_list) 46 | for idx in idx_list: 47 | 48 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc": 49 | data = np.load(self.path_list[idx]) 50 | # split stacked channel into image and label 51 | img = data[..., :3] # RGB image 52 | img = img.astype("uint8") 53 | lab = data[..., 3:] # instance ID map 54 | yield [img, lab] 55 | 56 | else: 57 | data = np.load(self.path_list[idx]) 58 | # split stacked channel into image and label 59 | img = data[:-1] # RGB image 60 | # reshape vector to HxWxC 61 | img = np.reshape(img, self.train_input_shape + [self.input_chans]) 62 | if self.model_mode != "class_rotmnist": 63 | img = img.astype("uint8") 64 | 65 | lab = data[-1] # label 66 | lab = np.reshape(lab, [1, 1, 1]) 67 | 68 | yield [img, lab] 69 | 70 | 71 | def valid_generator_seg( 72 | ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=1 73 | ): 74 | ### augment both the input and label 75 | ds = ( 76 | ds 77 | if shape_aug is None 78 | else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True) 79 | ) 80 | ### augment just the input 81 | ds = ( 82 | ds 83 | 
if input_aug is None 84 | else AugmentImageComponent(ds, input_aug, index=0, copy=False) 85 | ) 86 | ### augment just the output 87 | ds = ( 88 | ds 89 | if label_aug is None 90 | else AugmentImageComponent(ds, label_aug, index=1, copy=True) 91 | ) 92 | # 93 | ds = BatchData(ds, batch_size, remainder=True) 94 | ds = CacheData(ds) # cache all inference images 95 | return ds 96 | 97 | 98 | def valid_generator_class( 99 | ds, shape_aug=None, input_aug=None, batch_size=16, nr_procs=1 100 | ): 101 | ### augment the input 102 | ds = ( 103 | ds 104 | if shape_aug is None 105 | else AugmentImageComponent(ds, shape_aug, index=0, copy=True) 106 | ) 107 | ### augment the input 108 | ds = ( 109 | ds 110 | if input_aug is None 111 | else AugmentImageComponent(ds, input_aug, index=0, copy=False) 112 | ) 113 | # 114 | ds = BatchData(ds, batch_size, remainder=True) 115 | ds = CacheData(ds) # cache all inference images 116 | return ds 117 | 118 | 119 | def train_generator_seg( 120 | ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=8 121 | ): 122 | ### augment both the input and label 123 | ds = ( 124 | ds 125 | if shape_aug is None 126 | else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True) 127 | ) 128 | ### augment just the input i.e index 0 within each yield of DatasetSerial 129 | ds = ( 130 | ds 131 | if input_aug is None 132 | else AugmentImageComponent(ds, input_aug, index=0, copy=False) 133 | ) 134 | ### augment just the output i.e index 1 within each yield of DatasetSerial 135 | ds = ( 136 | ds 137 | if label_aug is None 138 | else AugmentImageComponent(ds, label_aug, index=1, copy=True) 139 | ) 140 | # 141 | ds = BatchDataByShape(ds, batch_size, idx=0) 142 | ds = PrefetchDataZMQ(ds, nr_procs) 143 | return ds 144 | 145 | 146 | def train_generator_class( 147 | ds, shape_aug=None, input_aug=None, batch_size=16, nr_procs=8 148 | ): 149 | ### augment the input 150 | ds = ( 151 | ds 152 | if shape_aug is None 153 | else AugmentImageComponent(ds, 
shape_aug, index=0, copy=True) 154 | ) 155 | ### augment the input i.e index 0 within each yield of DatasetSerial 156 | ds = ( 157 | ds 158 | if input_aug is None 159 | else AugmentImageComponent(ds, input_aug, index=0, copy=False) 160 | ) 161 | # 162 | ds = BatchDataByShape(ds, batch_size, idx=0) 163 | ds = PrefetchDataZMQ(ds, nr_procs) 164 | return ds 165 | 166 | 167 | def visualize(datagen, batch_size): 168 | """ 169 | Read batches from 'datagen' and display the first few images 170 | of each batch with their corresponding ground truth 171 | """ 172 | cfg = Config() 173 | 174 | def prep_imgs(img, lab): 175 | 176 | # Deal with HxWx1 case 177 | img = np.squeeze(img) 178 | 179 | if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc": 180 | cmap = plt.get_cmap("jet") 181 | # cmap may fail on other dtypes, so cast to float32 182 | lab = lab.astype("float32") 183 | lab_chs = np.dsplit(lab, lab.shape[-1]) 184 | for i, ch in enumerate(lab_chs): 185 | ch = np.squeeze(ch) 186 | # scale by the value range so the colormap behaves predictably 187 | ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16) 188 | # take RGB from RGBA heat map 189 | lab_chs[i] = cmap(ch)[..., :3] 190 | img = img.astype("float32") / 255.0 191 | prepped_img = np.concatenate([img] + lab_chs, axis=1) 192 | else: 193 | prepped_img = img 194 | return prepped_img 195 | 196 | ds = RepeatedData(datagen, -1) 197 | ds.reset_state() 198 | for imgs, labs in ds.get_data(): 199 | if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc": 200 | for idx in range(0, 4): 201 | displayed_img = prep_imgs(imgs[idx], labs[idx]) 202 | # plot the image and the label 203 | plt.subplot(4, 1, idx + 1) 204 | plt.imshow(displayed_img, vmin=-1, vmax=1) 205 | plt.axis("off") 206 | plt.show() 207 | else: 208 | for idx in range(0, 8): 209 | displayed_img = prep_imgs(imgs[idx], labs[idx]) 210 | # plot the image and the label 211 | plt.subplot(2, 4, idx + 1) 212 | plt.imshow(displayed_img) 213 | if len(cfg.label_names) > 0: 214 | lab_title = 
cfg.label_names[int(labs[idx])] 215 | else: 216 | lab_title = int(labs[idx]) 217 | plt.title(lab_title) 218 | plt.axis("off") 219 | plt.show() 220 | return 221 | -------------------------------------------------------------------------------- /src/misc/patch_extractor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Patch extraction script 3 | """ 4 | 5 | import math 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | from .utils import cropping_center 11 | 12 | 13 | class PatchExtractor(object): 14 | """ 15 | Extractor to generate patches with or without padding. 16 | Turn on debug mode to see how it is done. 17 | 18 | Args: 19 | x : input image, should be of shape HWC 20 | win_size : a tuple of (h, w) 21 | step_size : a tuple of (h, w) 22 | debug : flag to see how it is done 23 | 24 | Return: 25 | a list of sub patches, each patch has dtype same as x 26 | 27 | Examples: 28 | >>> xtractor = PatchExtractor((450, 450), (120, 120)) 29 | >>> img = np.full([1200, 1200, 3], 255, np.uint8) 30 | >>> patches = xtractor.extract(img, 'mirror') 31 | """ 32 | 33 | def __init__(self, win_size, step_size, debug=False): 34 | 35 | self.patch_type = 'mirror' 36 | self.win_size = win_size 37 | self.step_size = step_size 38 | self.debug = debug 39 | self.counter = 0 40 | 41 | def __get_patch(self, x, ptx): 42 | pty = (ptx[0]+self.win_size[0], 43 | ptx[1]+self.win_size[1]) 44 | win = x[ptx[0]:pty[0], 45 | ptx[1]:pty[1]] 46 | assert win.shape[0] == self.win_size[0] and \ 47 | win.shape[1] == self.win_size[1], \ 48 | '[BUG] Incorrect Patch Size {0}'.format(win.shape) 49 | if self.debug: 50 | if self.patch_type == 'mirror': 51 | cen = cropping_center(win, self.step_size) 52 | cen = cen[..., self.counter % 3] 53 | cen.fill(150) 54 | cv2.rectangle(x, ptx, pty, (255, 0, 0), 2) 55 | plt.imshow(x) 56 | plt.show(block=False) 57 | plt.pause(1) 58 | plt.close() 59 | self.counter += 1 60 | return win 61 | 62 | def 
__extract_valid(self, x): 63 | """ 64 | Extract patches without padding; only works when win_size > step_size 65 | 66 | Note: to deal with the remaining portions at the boundary (i.e. 67 | those which do not fit when sliding left->right, top->bottom), we flip 68 | the sliding direction then extract 1 patch starting from right / bottom edge. 69 | There will be 1 additional patch extracted at the bottom-right corner 70 | 71 | Args: 72 | x : input image, should be of shape HWC 73 | win_size : a tuple of (h, w) 74 | step_size : a tuple of (h, w) 75 | 76 | Return: 77 | a list of sub patches, each patch is same dtype as x 78 | """ 79 | 80 | im_h = x.shape[0] 81 | im_w = x.shape[1] 82 | 83 | def extract_infos(length, win_size, step_size): 84 | flag = (length - win_size) % step_size != 0 85 | last_step = math.floor((length - win_size) / step_size) 86 | last_step = (last_step + 1) * step_size 87 | return flag, last_step 88 | 89 | h_flag, h_last = extract_infos( 90 | im_h, self.win_size[0], self.step_size[0]) 91 | w_flag, w_last = extract_infos( 92 | im_w, self.win_size[1], self.step_size[1]) 93 | 94 | sub_patches = [] 95 | #### Deal with valid block 96 | for row in range(0, h_last, self.step_size[0]): 97 | for col in range(0, w_last, self.step_size[1]): 98 | win = self.__get_patch(x, (row, col)) 99 | sub_patches.append(win) 100 | #### Deal with edge case 101 | if h_flag: 102 | row = im_h - self.win_size[0] 103 | for col in range(0, w_last, self.step_size[1]): 104 | win = self.__get_patch(x, (row, col)) 105 | sub_patches.append(win) 106 | if w_flag: 107 | col = im_w - self.win_size[1] 108 | for row in range(0, h_last, self.step_size[0]): 109 | win = self.__get_patch(x, (row, col)) 110 | sub_patches.append(win) 111 | if h_flag and w_flag: 112 | ptx = (im_h - self.win_size[0], im_w - self.win_size[1]) 113 | win = self.__get_patch(x, ptx) 114 | sub_patches.append(win) 115 | return sub_patches 116 | 117 | def __extract_mirror(self, x): 118 | """ 119 | Extract 
patches with mirror padding at the boundary such that the 120 | central region of each patch is always within the original (non-padded) 121 | image while the central regions of all patches cover the whole original image 122 | 123 | Args: 124 | x : input image, should be of shape HWC 125 | win_size : a tuple of (h, w) 126 | step_size : a tuple of (h, w) 127 | 128 | Return: 129 | a list of sub patches, each patch is same dtype as x 130 | """ 131 | 132 | diff_h = self.win_size[0] - self.step_size[0] 133 | padt = diff_h // 2 134 | padb = diff_h - padt 135 | 136 | diff_w = self.win_size[1] - self.step_size[1] 137 | padl = diff_w // 2 138 | padr = diff_w - padl 139 | 140 | pad_type = 'constant' if self.debug else 'reflect' 141 | x = np.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type) 142 | sub_patches = self.__extract_valid(x) 143 | return sub_patches 144 | 145 | def extract(self, x, patch_type): 146 | """ 147 | Extract patches 148 | """ 149 | patch_type = patch_type.lower() 150 | self.patch_type = patch_type 151 | if patch_type == 'valid': 152 | return self.__extract_valid(x) 153 | elif patch_type == 'mirror': 154 | return self.__extract_mirror(x) 155 | else: 156 | assert False, 'Unknown Patch Type [%s]' % patch_type 157 | return 158 | 159 | ########################################################################### 160 | 161 | 162 | if __name__ == '__main__': 163 | # debugging 164 | xtractor = PatchExtractor((450, 450), (120, 120), debug=True) 165 | a = np.full([1200, 1200, 3], 255, np.uint8) 166 | xtractor.extract(a, 'mirror') 167 | xtractor.extract(a, 'valid') 168 | -------------------------------------------------------------------------------- /src/misc/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils 3 | """ 4 | 5 | import glob 6 | import os 7 | import shutil 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | def bounding_box(img): 13 | """ 14 | Get the bounding box of a binary region 15 | 16 | Args: 17 | 
img: input array, should contain one 18 | binary object. 19 | """ 20 | rows = np.any(img, axis=1) 21 | cols = np.any(img, axis=0) 22 | rmin, rmax = np.where(rows)[0][[0, -1]] 23 | cmin, cmax = np.where(cols)[0][[0, -1]] 24 | # due to python indexing, need to add 1 to max 25 | # else accessing will be 1px in the box, not out 26 | rmax += 1 27 | cmax += 1 28 | return [rmin, rmax, cmin, cmax] 29 | 30 | 31 | def cropping_center(img, crop_shape, batch=False): 32 | """ 33 | Crop an array at the centre 34 | 35 | Args: 36 | img: input array 37 | crop_shape: new spatial dimensions (h,w) 38 | """ 39 | 40 | orig_shape = img.shape 41 | if not batch: 42 | h_0 = int((orig_shape[0] - crop_shape[0]) * 0.5) 43 | w_0 = int((orig_shape[1] - crop_shape[1]) * 0.5) 44 | img = img[h_0 : h_0 + crop_shape[0], w_0 : w_0 + crop_shape[1]] 45 | else: 46 | h_0 = int((orig_shape[1] - crop_shape[0]) * 0.5) 47 | w_0 = int((orig_shape[2] - crop_shape[1]) * 0.5) 48 | img = img[:, h_0 : h_0 + crop_shape[0], w_0 : w_0 + crop_shape[1]] 49 | return img 50 | 51 | 52 | def rm_n_mkdir(dir_path): 53 | """ 54 | Remove, then create a new directory 55 | """ 56 | 57 | if os.path.isdir(dir_path): 58 | shutil.rmtree(dir_path) 59 | os.makedirs(dir_path) 60 | 61 | 62 | def get_files(data_dir_list, data_ext): 63 | """ 64 | Given a list of directories containing data with extension 'data_ext', 65 | generate a list of paths for all files within these directories 66 | """ 67 | 68 | data_files = [] 69 | for sub_dir in data_dir_list: 70 | files = glob.glob(sub_dir + "/*" + data_ext) 71 | data_files.extend(files) 72 | 73 | return data_files 74 | 75 | 76 | def remap_label(pred, by_size=False): 77 | """Rename all instance IDs so that they are contiguous, i.e. [0, 1, 2, 3] 78 | not [0, 2, 4, 6]. The ordering of instances (which one comes first) 79 | is preserved unless by_size=True, in which case the instances are reordered 80 | so that bigger objects have smaller IDs. 
81 | Args: 82 | pred : the 2D array containing instances, where each instance is marked 83 | by a non-zero integer 84 | by_size : if True, larger objects get smaller IDs (on top) 85 | """ 86 | pred_id = list(np.unique(pred)) 87 | pred_id.remove(0) 88 | if len(pred_id) == 0: 89 | return pred # no label 90 | if by_size: 91 | pred_size = [] 92 | for inst_id in pred_id: 93 | size = (pred == inst_id).sum() 94 | pred_size.append(size) 95 | # sort the id by size in descending order 96 | pair_list = zip(pred_id, pred_size) 97 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True) 98 | pred_id, pred_size = zip(*pair_list) 99 | 100 | new_pred = np.zeros(pred.shape, np.int32) 101 | for idx, inst_id in enumerate(pred_id): 102 | new_pred[pred == inst_id] = idx + 1 103 | return new_pred 104 | -------------------------------------------------------------------------------- /src/misc/viz_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualisation utils 3 | """ 4 | 5 | import math 6 | import random 7 | import colorsys 8 | import cv2 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from .utils import bounding_box 13 | 14 | 15 | def random_colors(N, bright=True): 16 | """ 17 | Generate random colors. 18 | To get visually distinct colors, generate them in HSV space then 19 | convert to RGB. 
20 | """ 21 | brightness = 1.0 if bright else 0.7 22 | hsv = [(i / N, 1, brightness) for i in range(N)] 23 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 24 | random.shuffle(colors) 25 | return colors 26 | 27 | 28 | def visualize_instances(mask, canvas=None): 29 | """ 30 | Args: 31 | mask: instance ID map of shape HW 32 | Return: 33 | Image with the instances overlaid 34 | """ 35 | 36 | colour = [255, 255, 0] # yellow 37 | 38 | canvas = ( 39 | np.full((mask.shape[0], mask.shape[1]) + (3,), 200, dtype=np.uint8) 40 | if canvas is None 41 | else np.copy(canvas) 42 | ) 43 | 44 | insts_list = list(np.unique(mask)) 45 | insts_list.remove(0) # remove background 46 | 47 | for idx, inst_id in enumerate(insts_list): 48 | inst_map = np.array(mask == inst_id, np.uint8) 49 | y1, y2, x1, x2 = bounding_box(inst_map) 50 | y1 = y1 - 2 if y1 - 2 >= 0 else y1 51 | x1 = x1 - 2 if x1 - 2 >= 0 else x1 52 | x2 = x2 + 2 if x2 + 2 <= mask.shape[1] - 1 else x2 53 | y2 = y2 + 2 if y2 + 2 <= mask.shape[0] - 1 else y2 54 | inst_map_crop = inst_map[y1:y2, x1:x2] 55 | inst_canvas_crop = canvas[y1:y2, x1:x2] 56 | contours = cv2.findContours( 57 | inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 58 | ) 59 | cv2.drawContours(inst_canvas_crop, contours[-2], -1, colour, 3) # [-2] is the contour list in both OpenCV 3.x and 4.x 60 | canvas[y1:y2, x1:x2] = inst_canvas_crop 61 | 62 | return canvas 63 | -------------------------------------------------------------------------------- /src/model/class_pcam/graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | DSF-CNN for tumour classification 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | from tensorpack import * 8 | from tensorpack.models import BNReLU, Conv2D, MaxPooling 9 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary 10 | 11 | from model.utils.model_utils import * 12 | from model.utils.gconv_utils import * 13 | 14 | import sys 15 | 16 | sys.path.append("..") # adds higher directory to python modules path. 
17 | try: # HACK: import beyond current level, may need to restructure 18 | from config import Config 19 | except ImportError: 20 | assert False, "Fail to import config.py" 21 | 22 | 23 | def group_concat(x, y, nr_orients): 24 | shape1 = x.get_shape().as_list() 25 | chans1 = shape1[3] 26 | c1 = int(chans1 / nr_orients) 27 | x = tf.reshape(x, [-1, shape1[1], shape1[2], nr_orients, c1]) 28 | 29 | shape2 = y.get_shape().as_list() 30 | chans2 = shape2[3] 31 | c2 = int(chans2 / nr_orients) 32 | y = tf.reshape(y, [-1, shape2[1], shape2[2], nr_orients, c2]) 33 | 34 | z = tf.concat([x, y], axis=-1) 35 | 36 | return tf.reshape(z, [-1, shape1[1], shape1[2], nr_orients * (c1 + c2)]) 37 | 38 | 39 | def g_dense_blk( 40 | name, 41 | l, 42 | ch, 43 | ksize, 44 | count, 45 | nr_orients, 46 | filter_type, 47 | basis_filter_list, 48 | rot_matrix_list, 49 | padding="same", 50 | bn_init=True, 51 | ): 52 | with tf.variable_scope(name): 53 | for i in range(0, count): 54 | with tf.variable_scope("blk/" + str(i)): 55 | if bn_init: 56 | x = GBNReLU("preact_bna", l, nr_orients) 57 | else: 58 | x = l 59 | x = GConv2D( 60 | "conv1", 61 | x, 62 | ch[0], 63 | ksize[0], 64 | nr_orients, 65 | filter_type, 66 | basis_filter_list[0], 67 | rot_matrix_list[0], 68 | ) 69 | x = GConv2D( 70 | "conv2", 71 | x, 72 | ch[1], 73 | ksize[1], 74 | nr_orients, 75 | filter_type, 76 | basis_filter_list[1], 77 | rot_matrix_list[1], 78 | activation="identity", 79 | ) 80 | ## 81 | if padding == "valid": 82 | x_shape = x.get_shape().as_list() 83 | l_shape = l.get_shape().as_list() 84 | l = crop_op( 85 | l, 86 | (l_shape[1] - x_shape[2], l_shape[1] - x_shape[2]), 87 | "channels_last", 88 | ) 89 | 90 | l = group_concat(l, x, nr_orients) 91 | l = GBNReLU("blk_bna", l, nr_orients) 92 | return l 93 | 94 | 95 | def net( 96 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training 97 | ): 98 | """ 99 | Dense Steerable Filter CNN 100 | """ 101 | 102 | dense_basis_list = [basis_filter_list[0], 
basis_filter_list[1]] 103 | dense_rot_list = [rot_matrix_list[0], rot_matrix_list[1]] 104 | 105 | with tf.variable_scope(name): 106 | 107 | c1 = GConv2D( 108 | "ds_conv1", 109 | i, 110 | 8, 111 | 7, 112 | nr_orients, 113 | filter_type, 114 | basis_filter_list[1], 115 | rot_matrix_list[1], 116 | input_layer=True, 117 | ) 118 | c2 = GConv2D( 119 | "ds_conv2", 120 | c1, 121 | 8, 122 | 7, 123 | nr_orients, 124 | filter_type, 125 | basis_filter_list[1], 126 | rot_matrix_list[1], 127 | activation="identity", 128 | ) 129 | p1 = MaxPooling("max_pool1", c2, 2) 130 | #### 131 | 132 | d1 = g_dense_blk( 133 | "dense1", 134 | p1, 135 | [32, 8], 136 | [5, 7], 137 | 2, 138 | nr_orients, 139 | filter_type, 140 | dense_basis_list, 141 | dense_rot_list, 142 | bn_init=False, 143 | ) 144 | c3 = GConv2D( 145 | "ds_conv3", 146 | d1, 147 | 32, 148 | 5, 149 | nr_orients, 150 | filter_type, 151 | basis_filter_list[0], 152 | rot_matrix_list[0], 153 | activation="identity", 154 | ) 155 | p2 = MaxPooling("max_pool2", c3, 2, padding="valid") 156 | #### 157 | 158 | d2 = g_dense_blk( 159 | "dense2", 160 | p2, 161 | [32, 8], 162 | [5, 7], 163 | 2, 164 | nr_orients, 165 | filter_type, 166 | dense_basis_list, 167 | dense_rot_list, 168 | bn_init=False, 169 | ) 170 | c4 = GConv2D( 171 | "ds_conv4", 172 | d2, 173 | 32, 174 | 5, 175 | nr_orients, 176 | filter_type, 177 | basis_filter_list[0], 178 | rot_matrix_list[0], 179 | activation="identity", 180 | ) 181 | p3 = MaxPooling("max_pool3", c4, 2, padding="valid") 182 | #### 183 | 184 | d3 = g_dense_blk( 185 | "dense3", 186 | p3, 187 | [32, 8], 188 | [5, 7], 189 | 3, 190 | nr_orients, 191 | filter_type, 192 | dense_basis_list, 193 | dense_rot_list, 194 | bn_init=False, 195 | ) 196 | c5 = GConv2D( 197 | "ds_conv5", 198 | d3, 199 | 32, 200 | 5, 201 | nr_orients, 202 | filter_type, 203 | basis_filter_list[0], 204 | rot_matrix_list[0], 205 | activation="identity", 206 | ) 207 | p4 = MaxPooling("max_pool4", c5, 2, padding="valid") 208 | #### 209 | 210 | d4 = 
g_dense_blk( 211 | "dense4", 212 | p4, 213 | [32, 8], 214 | [5, 7], 215 | 3, 216 | nr_orients, 217 | filter_type, 218 | dense_basis_list, 219 | dense_rot_list, 220 | bn_init=False, 221 | ) 222 | c6 = GConv2D( 223 | "ds_conv6", 224 | d4, 225 | 32, 226 | 5, 227 | nr_orients, 228 | filter_type, 229 | basis_filter_list[0], 230 | rot_matrix_list[0], 231 | ) 232 | p5 = AvgPooling("glb_avg_pool", c6, 6, padding="valid") 233 | p6 = GroupPool("orient_pool", p5, nr_orients, pool_type="max") 234 | #### 235 | 236 | c7 = Conv2D("conv3", p6, 96, 1, use_bias=True, nl=BNReLU) 237 | c7 = tf.layers.dropout(c7, rate=0.3, seed=5, training=is_training) 238 | c8 = Conv2D("conv4", c7, 96, 1, use_bias=True, nl=BNReLU) 239 | c8 = tf.layers.dropout(c8, rate=0.3, seed=5, training=is_training) 240 | 241 | return c8 242 | 243 | 244 | class Model(ModelDesc, Config): 245 | def __init__(self, freeze=False): 246 | super(Model, self).__init__() 247 | # assert tf.test.is_gpu_available() 248 | self.freeze = freeze 249 | self.data_format = "NHWC" 250 | 251 | def _get_inputs(self): 252 | return [ 253 | InputDesc(tf.float32, [None] + self.train_input_shape + [3], "images"), 254 | InputDesc( 255 | tf.float32, [None] + self.train_output_shape + [None], "truemap-coded" 256 | ), 257 | ] 258 | 259 | # for node to receive manual info such as learning rate. 
260 | def add_manual_variable(self, name, init_value, summary=True): 261 | var = tf.get_variable(name, initializer=init_value, trainable=False) 262 | if summary: 263 | tf.summary.scalar(name + "-summary", var) 264 | return 265 | 266 | def _get_optimizer(self): 267 | with tf.variable_scope("", reuse=True): 268 | lr = tf.get_variable("learning_rate") 269 | opt = self.optimizer(learning_rate=lr) 270 | return opt 271 | 272 | 273 | class Graph(Model): 274 | def _build_graph(self, inputs): 275 | 276 | is_training = get_current_tower_context().is_training 277 | 278 | images, truemap_coded = inputs 279 | orig_imgs = images 280 | true = truemap_coded[..., 0] 281 | true = tf.cast(true, tf.int32) 282 | true = tf.identity(true, name="truemap") 283 | one_hot = tf.one_hot(true, 2, axis=-1) 284 | true = tf.expand_dims(true, axis=-1) 285 | 286 | #### 287 | with argscope( 288 | Conv2D, 289 | activation=tf.identity, 290 | use_bias=False, # K.he initializer 291 | W_init=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"), 292 | ), argscope([Conv2D], data_format=self.data_format): 293 | 294 | i = images if not self.input_norm else images / 255.0 295 | 296 | #### 297 | feat = net( 298 | "net", 299 | i, 300 | self.basis_filter_list, 301 | self.rot_matrix_list, 302 | self.nr_orients, 303 | self.filter_type, 304 | is_training, 305 | ) 306 | 307 | #### Prediction 308 | o_logi = Conv2D("output", feat, 2, 1, use_bias=True, nl=tf.identity) 309 | soft = tf.nn.softmax(o_logi, axis=-1) 310 | 311 | prob = tf.identity(soft, name="predmap-prob") 312 | 313 | # encoded so that inference can extract all output at once 314 | predmap_coded = tf.concat(prob, axis=-1, name="predmap-coded") 315 | 316 | #### 317 | if get_current_tower_context().is_training: 318 | # ---- LOSS ----# 319 | loss = 0 320 | for term, weight in self.loss_term.items(): 321 | if term == "bce": 322 | term_loss = categorical_crossentropy(soft, one_hot) 323 | term_loss = tf.reduce_mean(term_loss, name="loss-bce") 324 | else: 325 
| assert False, "Not support loss term: %s" % term 326 | add_moving_summary(term_loss) 327 | loss += term_loss * weight 328 | 329 | ### combine the loss into single cost function 330 | wd_loss = regularize_cost(".*/W", l2_regularizer(1.0e-7), name="l2_wd_loss") 331 | add_moving_summary(wd_loss) 332 | self.cost = tf.identity(loss + wd_loss, name="overall-loss") 333 | add_moving_summary(self.cost) 334 | #### 335 | 336 | add_param_summary((".*/W", ["histogram"])) # monitor W 337 | 338 | ### logging visual sthg 339 | orig_imgs = tf.cast(orig_imgs, tf.uint8) 340 | tf.summary.image("input", orig_imgs, max_outputs=1) 341 | 342 | return 343 | -------------------------------------------------------------------------------- /src/model/seg_gland/graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | DSF-CNN for gland segmentation 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | from tensorpack import * 8 | from tensorpack.models import BNReLU, Conv2D, MaxPooling 9 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary 10 | 11 | from model.utils.model_utils import * 12 | from model.utils.gconv_utils import * 13 | 14 | import sys 15 | 16 | sys.path.append("..") # adds higher directory to python modules path. 
17 | try: # HACK: import beyond current level, may need to restructure 18 | from config import Config 19 | except ImportError: 20 | assert False, "Fail to import config.py" 21 | 22 | 23 | def upsample(name, x, size): 24 | return tf.image.resize_images(x, [size, size]) 25 | 26 | 27 | def group_concat(x, y, nr_orients): 28 | shape1 = x.get_shape().as_list() 29 | chans1 = shape1[3] 30 | c1 = int(chans1 / nr_orients) 31 | x = tf.reshape(x, [-1, shape1[1], shape1[2], nr_orients, c1]) 32 | 33 | shape2 = y.get_shape().as_list() 34 | chans2 = shape2[3] 35 | c2 = int(chans2 / nr_orients) 36 | y = tf.reshape(y, [-1, shape2[1], shape2[2], nr_orients, c2]) 37 | 38 | z = tf.concat([x, y], axis=-1) 39 | 40 | return tf.reshape(z, [-1, shape1[1], shape1[2], nr_orients * (c1 + c2)]) 41 | 42 | 43 | def g_dense_blk( 44 | name, 45 | l, 46 | ch, 47 | ksize, 48 | count, 49 | nr_orients, 50 | filter_type, 51 | basis_filter_list, 52 | rot_matrix_list, 53 | padding="same", 54 | ): 55 | with tf.variable_scope(name): 56 | for i in range(0, count): 57 | with tf.variable_scope("blk/" + str(i)): 58 | x = GBNReLU("preact_bna", l, nr_orients) 59 | x = GConv2D( 60 | "conv1", 61 | x, 62 | ch[0], 63 | ksize[0], 64 | nr_orients, 65 | filter_type, 66 | basis_filter_list[0], 67 | rot_matrix_list[0], 68 | ) 69 | x = GConv2D( 70 | "conv2", 71 | x, 72 | ch[1], 73 | ksize[1], 74 | nr_orients, 75 | filter_type, 76 | basis_filter_list[1], 77 | rot_matrix_list[1], 78 | activation="identity", 79 | ) 80 | ## 81 | if padding == "valid": 82 | x_shape = x.get_shape().as_list() 83 | l_shape = l.get_shape().as_list() 84 | l = crop_op(l, (l_shape[2] - x_shape[2], l_shape[3] - x_shape[3])) 85 | 86 | l = group_concat(l, x, nr_orients) 87 | l = GBNReLU("blk_bna", l, nr_orients) 88 | return l 89 | 90 | 91 | def encoder( 92 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training 93 | ): 94 | """ 95 | Dense Steerable Filter Encoder 96 | """ 97 | 98 | dense_basis_list = [basis_filter_list[1], 
basis_filter_list[0]] 99 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]] 100 | 101 | with tf.variable_scope(name): 102 | 103 | c1 = GConv2D( 104 | "ds_conv1", 105 | i, 106 | 10, 107 | 7, 108 | nr_orients, 109 | filter_type, 110 | basis_filter_list[1], 111 | rot_matrix_list[1], 112 | input_layer=True, 113 | ) 114 | c2 = GConv2D( 115 | "ds_conv2", 116 | c1, 117 | 10, 118 | 7, 119 | nr_orients, 120 | filter_type, 121 | basis_filter_list[1], 122 | rot_matrix_list[1], 123 | activation="identity", 124 | ) 125 | p1 = MaxPooling("max_pool1", c2, 2) 126 | #### 127 | 128 | d1 = g_dense_blk( 129 | "dense1", 130 | p1, 131 | [14, 6], 132 | [7, 5], 133 | 3, 134 | nr_orients, 135 | filter_type, 136 | dense_basis_list, 137 | dense_rot_list, 138 | ) 139 | c3 = GConv2D( 140 | "ds_conv3", 141 | d1, 142 | 16, 143 | 5, 144 | nr_orients, 145 | filter_type, 146 | basis_filter_list[0], 147 | rot_matrix_list[0], 148 | activation="identity", 149 | ) 150 | p2 = MaxPooling("max_pool2", c3, 2, padding="valid") 151 | #### 152 | 153 | d2 = g_dense_blk( 154 | "dense2", 155 | p2, 156 | [14, 6], 157 | [7, 5], 158 | 4, 159 | nr_orients, 160 | filter_type, 161 | dense_basis_list, 162 | dense_rot_list, 163 | ) 164 | c4 = GConv2D( 165 | "ds_conv4", 166 | d2, 167 | 32, 168 | 5, 169 | nr_orients, 170 | filter_type, 171 | basis_filter_list[0], 172 | rot_matrix_list[0], 173 | activation="identity", 174 | ) 175 | p3 = MaxPooling("max_pool3", c4, 2, padding="valid") 176 | #### 177 | 178 | d3 = g_dense_blk( 179 | "dense3", 180 | p3, 181 | [14, 6], 182 | [7, 5], 183 | 5, 184 | nr_orients, 185 | filter_type, 186 | dense_basis_list, 187 | dense_rot_list, 188 | ) 189 | c5 = GConv2D( 190 | "ds_conv5", 191 | d3, 192 | 32, 193 | 5, 194 | nr_orients, 195 | filter_type, 196 | basis_filter_list[0], 197 | rot_matrix_list[0], 198 | activation="identity", 199 | ) 200 | p4 = MaxPooling("max_pool4", c5, 2, padding="valid") 201 | #### 202 | 203 | d4 = g_dense_blk( 204 | "dense4", 205 | p4, 206 | [14, 6], 207 | 
[7, 5], 208 | 6, 209 | nr_orients, 210 | filter_type, 211 | dense_basis_list, 212 | dense_rot_list, 213 | ) 214 | c6 = GConv2D( 215 | "ds_conv6", 216 | d4, 217 | 32, 218 | 5, 219 | nr_orients, 220 | filter_type, 221 | basis_filter_list[0], 222 | rot_matrix_list[0], 223 | activation="identity", 224 | ) 225 | 226 | return [c2, c3, c4, c5, c6] 227 | 228 | 229 | def decoder( 230 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training 231 | ): 232 | """ 233 | Dense Steerable Filter Decoder 234 | """ 235 | 236 | dense_basis_list = [basis_filter_list[1], basis_filter_list[0]] 237 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]] 238 | 239 | with tf.variable_scope(name): 240 | with tf.variable_scope("us1"): 241 | us1 = upsample("us1", i[-1], 56) 242 | us1 = g_dense_blk( 243 | "dense_us1", 244 | us1, 245 | [14, 6], 246 | [7, 5], 247 | 4, 248 | nr_orients, 249 | filter_type, 250 | dense_basis_list, 251 | dense_rot_list, 252 | ) 253 | us1 = GConv2D( 254 | "us_conv1", 255 | us1, 256 | 32, 257 | 5, 258 | nr_orients, 259 | filter_type, 260 | basis_filter_list[0], 261 | rot_matrix_list[0], 262 | activation="identity", 263 | ) 264 | #### 265 | 266 | with tf.variable_scope("us2"): 267 | us2 = upsample("us2", us1, 112) 268 | us2_sum = tf.add_n([us2, i[-3]]) 269 | us2 = g_dense_blk( 270 | "dense_us2", 271 | us2_sum, 272 | [14, 6], 273 | [7, 5], 274 | 3, 275 | nr_orients, 276 | filter_type, 277 | dense_basis_list, 278 | dense_rot_list, 279 | ) 280 | us2 = GConv2D( 281 | "us_conv2", 282 | us2, 283 | 16, 284 | 5, 285 | nr_orients, 286 | filter_type, 287 | basis_filter_list[0], 288 | rot_matrix_list[0], 289 | activation="identity", 290 | ) 291 | #### 292 | 293 | with tf.variable_scope("us3"): 294 | us3 = upsample("us3", us2, 224) 295 | us3_sum = tf.add_n([us3, i[-4]]) 296 | us3 = g_dense_blk( 297 | "dense_us3", 298 | us3_sum, 299 | [14, 6], 300 | [7, 5], 301 | 3, 302 | nr_orients, 303 | filter_type, 304 | dense_basis_list, 305 | dense_rot_list, 306 | 
) 307 | us3 = GConv2D( 308 | "us_conv3", 309 | us3, 310 | 10, 311 | 5, 312 | nr_orients, 313 | filter_type, 314 | basis_filter_list[0], 315 | rot_matrix_list[0], 316 | activation="identity", 317 | ) 318 | #### 319 | 320 | with tf.variable_scope("us4"): 321 | us4 = upsample("us4", us3, 448) 322 | us4_sum = tf.add_n([us4, i[-5]]) 323 | us4 = GConv2D( 324 | "us_conv4", 325 | us4_sum, 326 | 10, 327 | 7, 328 | nr_orients, 329 | filter_type, 330 | basis_filter_list[1], 331 | rot_matrix_list[1], 332 | ) 333 | feat = GroupPool("us4", us4, nr_orients, pool_type="max") 334 | 335 | return feat 336 | 337 | 338 | class Model(ModelDesc, Config): 339 | def __init__(self, freeze=False): 340 | super(Model, self).__init__() 341 | # assert tf.test.is_gpu_available() 342 | self.freeze = freeze 343 | self.data_format = "NHWC" 344 | 345 | def _get_inputs(self): 346 | return [ 347 | InputDesc(tf.float32, [None] + self.train_input_shape + [3], "images"), 348 | InputDesc( 349 | tf.float32, [None] + self.train_output_shape + [None], "truemap-coded" 350 | ), 351 | ] 352 | 353 | # for node to receive manual info such as learning rate. 
354 | def add_manual_variable(self, name, init_value, summary=True): 355 | var = tf.get_variable(name, initializer=init_value, trainable=False) 356 | if summary: 357 | tf.summary.scalar(name + "-summary", var) 358 | return 359 | 360 | def _get_optimizer(self): 361 | with tf.variable_scope("", reuse=True): 362 | lr = tf.get_variable("learning_rate") 363 | opt = self.optimizer(learning_rate=lr) 364 | return opt 365 | 366 | 367 | class Graph(Model): 368 | def _build_graph(self, inputs): 369 | 370 | is_training = get_current_tower_context().is_training 371 | 372 | images, truemap_coded = inputs 373 | orig_imgs = images 374 | 375 | true = truemap_coded[..., :3] 376 | true = tf.cast(true, tf.int32) 377 | true = tf.identity(true, name="truemap") 378 | one_hot = tf.cast(true, tf.float32) 379 | 380 | #### 381 | with argscope( 382 | Conv2D, 383 | activation=tf.identity, 384 | use_bias=False, # K.he initializer 385 | W_init=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"), 386 | ), argscope([Conv2D], data_format=self.data_format): 387 | 388 | i = images if not self.input_norm else images / 255.0 389 | 390 | #### 391 | d = encoder( 392 | "encoder", 393 | i, 394 | self.basis_filter_list, 395 | self.rot_matrix_list, 396 | self.nr_orients, 397 | self.filter_type, 398 | is_training, 399 | ) 400 | 401 | #### 402 | feat = decoder( 403 | "decoder", 404 | d, 405 | self.basis_filter_list, 406 | self.rot_matrix_list, 407 | self.nr_orients, 408 | self.filter_type, 409 | is_training, 410 | ) 411 | 412 | feat1 = Conv2D("feat", feat, 96, 1, use_bias=True, nl=BNReLU) 413 | o_logi = Conv2D("output", feat, 3, 1, use_bias=True, nl=tf.identity) 414 | soft = tf.nn.softmax(o_logi, axis=-1) 415 | 416 | prob = tf.identity(soft[..., :2], name="predmap-prob") 417 | 418 | # encoded so that inference can extract all output at once 419 | predmap_coded = tf.concat(prob, axis=-1, name="predmap-coded") 420 | 421 | #### 422 | if get_current_tower_context().is_training: 423 | # ---- LOSS ----# 424 | 
loss = 0 425 | for term, weight in self.loss_term.items(): 426 | if term == "bce": 427 | term_loss = categorical_crossentropy(soft, one_hot) 428 | term_loss = tf.reduce_mean(term_loss, name="loss-bce") 429 | elif term == "dice": 430 | # branch 1 431 | term_loss = dice_loss(soft[..., 0], one_hot[..., 0]) + dice_loss( 432 | soft[..., 1], one_hot[..., 1] 433 | ) 434 | term_loss = tf.identity(term_loss, name="loss-dice") 435 | else: 436 | assert False, "Unsupported loss term: %s" % term 437 | add_moving_summary(term_loss) 438 | loss += term_loss 439 | 440 | ### combine the loss into single cost function 441 | wd_loss = regularize_cost(".*/W", l2_regularizer(1.0e-7), name="l2_wd_loss") 442 | add_moving_summary(wd_loss) 443 | self.cost = tf.identity(loss + wd_loss, name="overall-loss") 444 | add_moving_summary(self.cost) 445 | #### 446 | 447 | add_param_summary((".*/W", ["histogram"])) # monitor W 448 | 449 | ### log visualisations 450 | orig_imgs = tf.cast(orig_imgs, tf.uint8) 451 | tf.summary.image("input", orig_imgs, max_outputs=1) 452 | 453 | pred_blb = colorize(prob[..., 0], cmap="jet") 454 | true_blb = colorize(true[..., 0], cmap="jet") 455 | 456 | pred_cnt = colorize(prob[..., 1], cmap="jet") 457 | true_cnt = colorize(true[..., 1], cmap="jet") 458 | 459 | viz = tf.concat([orig_imgs, pred_blb, pred_cnt, true_blb, true_cnt], 2) 460 | 461 | viz = tf.concat([viz[0], viz[-1]], axis=0) 462 | viz = tf.expand_dims(viz, axis=0) 463 | tf.summary.image("output", viz, max_outputs=1) 464 | 465 | return 466 | 467 | -------------------------------------------------------------------------------- /src/model/seg_nuc/graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | DSF-CNN for nuclear segmentation 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | from tensorpack import * 8 | from tensorpack.models import BNReLU, Conv2D, MaxPooling 9 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary 10 | 11 |
from model.utils.model_utils import * 12 | from model.utils.gconv_utils import * 13 | 14 | import sys 15 | 16 | sys.path.append("..") # adds higher directory to python modules path. 17 | try: # HACK: import beyond current level, may need to restructure 18 | from config import Config 19 | except ImportError: 20 | assert False, "Fail to import config.py" 21 | 22 | 23 | def upsample(name, x, size): 24 | return tf.image.resize_images(x, [size, size]) 25 | 26 | 27 | def group_concat(x, y, nr_orients): 28 | shape1 = x.get_shape().as_list() 29 | chans1 = shape1[3] 30 | c1 = int(chans1 / nr_orients) 31 | x = tf.reshape(x, [-1, shape1[1], shape1[2], nr_orients, c1]) 32 | 33 | shape2 = y.get_shape().as_list() 34 | chans2 = shape2[3] 35 | c2 = int(chans2 / nr_orients) 36 | y = tf.reshape(y, [-1, shape2[1], shape2[2], nr_orients, c2]) 37 | 38 | z = tf.concat([x, y], axis=-1) 39 | 40 | return tf.reshape(z, [-1, shape1[1], shape1[2], nr_orients * (c1 + c2)]) 41 | 42 | 43 | def g_dense_blk( 44 | name, 45 | l, 46 | ch, 47 | ksize, 48 | count, 49 | nr_orients, 50 | filter_type, 51 | basis_filter_list, 52 | rot_matrix_list, 53 | padding="same", 54 | bn_init=True, 55 | ): 56 | with tf.variable_scope(name): 57 | for i in range(0, count): 58 | with tf.variable_scope("blk/" + str(i)): 59 | if bn_init: 60 | x = GBNReLU("preact_bna", l, nr_orients) 61 | else: 62 | x = l 63 | x = GConv2D( 64 | "conv1", 65 | x, 66 | ch[0], 67 | ksize[0], 68 | nr_orients, 69 | filter_type, 70 | basis_filter_list[0], 71 | rot_matrix_list[0], 72 | ) 73 | x = GConv2D( 74 | "conv2", 75 | x, 76 | ch[1], 77 | ksize[1], 78 | nr_orients, 79 | filter_type, 80 | basis_filter_list[1], 81 | rot_matrix_list[1], 82 | activation="identity", 83 | ) 84 | ## 85 | if padding == "valid": 86 | x_shape = x.get_shape().as_list() 87 | l_shape = l.get_shape().as_list() 88 | l = crop_op( 89 | l, 90 | (l_shape[1] - x_shape[2], l_shape[1] - x_shape[2]), 91 | "channels_last", 92 | ) 93 | 94 | l = group_concat(l, x, nr_orients) 95 | l = 
GBNReLU("blk_bna", l, nr_orients) 96 | return l 97 | 98 | 99 | def encoder( 100 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training 101 | ): 102 | """ 103 | Dense Steerable Filter Encoder 104 | """ 105 | 106 | dense_basis_list = [basis_filter_list[1], basis_filter_list[0]] 107 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]] 108 | 109 | with tf.variable_scope(name): 110 | 111 | c1 = GConv2D( 112 | "ds_conv1", 113 | i, 114 | 10, 115 | 7, 116 | nr_orients, 117 | filter_type, 118 | basis_filter_list[1], 119 | rot_matrix_list[1], 120 | input_layer=True, 121 | ) 122 | c2 = GConv2D( 123 | "ds_conv2", 124 | c1, 125 | 10, 126 | 7, 127 | nr_orients, 128 | filter_type, 129 | basis_filter_list[1], 130 | rot_matrix_list[1], 131 | activation="identity", 132 | ) 133 | p1 = MaxPooling("max_pool1", c2, 2) 134 | #### 135 | 136 | d1 = g_dense_blk( 137 | "dense1", 138 | p1, 139 | [14, 6], 140 | [7, 5], 141 | 3, 142 | nr_orients, 143 | filter_type, 144 | dense_basis_list, 145 | dense_rot_list, 146 | ) 147 | c3 = GConv2D( 148 | "ds_conv3", 149 | d1, 150 | 16, 151 | 5, 152 | nr_orients, 153 | filter_type, 154 | basis_filter_list[0], 155 | rot_matrix_list[0], 156 | activation="identity", 157 | ) 158 | p2 = MaxPooling("max_pool2", c3, 2, padding="valid") 159 | #### 160 | 161 | d2 = g_dense_blk( 162 | "dense2", 163 | p2, 164 | [14, 6], 165 | [7, 5], 166 | 4, 167 | nr_orients, 168 | filter_type, 169 | dense_basis_list, 170 | dense_rot_list, 171 | ) 172 | c4 = GConv2D( 173 | "ds_conv4", 174 | d2, 175 | 32, 176 | 5, 177 | nr_orients, 178 | filter_type, 179 | basis_filter_list[0], 180 | rot_matrix_list[0], 181 | activation="identity", 182 | ) 183 | p3 = MaxPooling("max_pool3", c4, 2, padding="valid") 184 | #### 185 | 186 | d3 = g_dense_blk( 187 | "dense3", 188 | p3, 189 | [14, 6], 190 | [7, 5], 191 | 5, 192 | nr_orients, 193 | filter_type, 194 | dense_basis_list, 195 | dense_rot_list, 196 | ) 197 | c5 = GConv2D( 198 | "ds_conv5", 199 | d3, 200 | 32, 
201 | 5, 202 | nr_orients, 203 | filter_type, 204 | basis_filter_list[0], 205 | rot_matrix_list[0], 206 | activation="identity", 207 | ) 208 | p4 = MaxPooling("max_pool4", c5, 2, padding="valid") 209 | #### 210 | 211 | d4 = g_dense_blk( 212 | "dense4", 213 | p4, 214 | [14, 6], 215 | [7, 5], 216 | 6, 217 | nr_orients, 218 | filter_type, 219 | dense_basis_list, 220 | dense_rot_list, 221 | ) 222 | c6 = GConv2D( 223 | "ds_conv6", 224 | d4, 225 | 32, 226 | 5, 227 | nr_orients, 228 | filter_type, 229 | basis_filter_list[0], 230 | rot_matrix_list[0], 231 | activation="identity", 232 | ) 233 | 234 | return [c2, c3, c4, c5, c6] 235 | 236 | 237 | def decoder( 238 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training 239 | ): 240 | """ 241 | Dense Steerable Filter Decoder 242 | """ 243 | 244 | dense_basis_list = [basis_filter_list[1], basis_filter_list[0]] 245 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]] 246 | 247 | with tf.variable_scope(name): 248 | with tf.variable_scope("us1"): 249 | us1 = upsample("us1", i[-1], 32) 250 | us1_sum = tf.add_n([us1, i[-2]]) 251 | us1 = g_dense_blk( 252 | "dense_us1", 253 | us1_sum, 254 | [14, 6], 255 | [7, 5], 256 | 4, 257 | nr_orients, 258 | filter_type, 259 | dense_basis_list, 260 | dense_rot_list, 261 | ) 262 | us1 = GConv2D( 263 | "us_conv1", 264 | us1, 265 | 32, 266 | 5, 267 | nr_orients, 268 | filter_type, 269 | basis_filter_list[0], 270 | rot_matrix_list[0], 271 | activation="identity", 272 | ) 273 | #### 274 | 275 | with tf.variable_scope("us2"): 276 | us2 = upsample("us2", us1, 64) 277 | us2_sum = tf.add_n([us2, i[-3]]) 278 | us2 = g_dense_blk( 279 | "dense_us2", 280 | us2_sum, 281 | [14, 6], 282 | [7, 5], 283 | 3, 284 | nr_orients, 285 | filter_type, 286 | dense_basis_list, 287 | dense_rot_list, 288 | ) 289 | us2 = GConv2D( 290 | "us_conv2", 291 | us2, 292 | 16, 293 | 5, 294 | nr_orients, 295 | filter_type, 296 | basis_filter_list[0], 297 | rot_matrix_list[0], 298 | 
activation="identity", 299 | ) 300 | #### 301 | 302 | with tf.variable_scope("us3"): 303 | us3 = upsample("us3", us2, 128) 304 | us3_sum = tf.add_n([us3, i[-4]]) 305 | us3 = g_dense_blk( 306 | "dense_us3", 307 | us3_sum, 308 | [14, 6], 309 | [7, 5], 310 | 2, 311 | nr_orients, 312 | filter_type, 313 | dense_basis_list, 314 | dense_rot_list, 315 | ) 316 | us3 = GConv2D( 317 | "us_conv3", 318 | us3, 319 | 10, 320 | 5, 321 | nr_orients, 322 | filter_type, 323 | basis_filter_list[0], 324 | rot_matrix_list[0], 325 | activation="identity", 326 | ) 327 | #### 328 | 329 | with tf.variable_scope("us4"): 330 | us4 = upsample("us4", us3, 256) 331 | us4_sum = tf.add_n([us4, i[-5]]) 332 | us4 = GConv2D( 333 | "us_conv4", 334 | us4_sum, 335 | 10, 336 | 7, 337 | nr_orients, 338 | filter_type, 339 | basis_filter_list[1], 340 | rot_matrix_list[1], 341 | ) 342 | feat = GroupPool("us4", us4, nr_orients, pool_type="max") 343 | 344 | return feat 345 | 346 | 347 | class Model(ModelDesc, Config): 348 | def __init__(self, freeze=False): 349 | super(Model, self).__init__() 350 | # assert tf.test.is_gpu_available() 351 | self.freeze = freeze 352 | self.data_format = "NHWC" 353 | 354 | def _get_inputs(self): 355 | return [ 356 | InputDesc(tf.float32, [None] + self.train_input_shape + [3], "images"), 357 | InputDesc( 358 | tf.float32, [None] + self.train_output_shape + [None], "truemap-coded" 359 | ), 360 | ] 361 | 362 | # for node to receive manual info such as learning rate. 
363 | def add_manual_variable(self, name, init_value, summary=True): 364 | var = tf.get_variable(name, initializer=init_value, trainable=False) 365 | if summary: 366 | tf.summary.scalar(name + "-summary", var) 367 | return 368 | 369 | def _get_optimizer(self): 370 | with tf.variable_scope("", reuse=True): 371 | lr = tf.get_variable("learning_rate") 372 | opt = self.optimizer(learning_rate=lr) 373 | return opt 374 | 375 | 376 | class Graph(Model): 377 | def _build_graph(self, inputs): 378 | 379 | is_training = get_current_tower_context().is_training 380 | 381 | images, truemap_coded = inputs 382 | orig_imgs = images 383 | 384 | pen_map = truemap_coded[..., -1] 385 | 386 | true_np = truemap_coded[..., 0] 387 | true_np = tf.cast(true_np, tf.int32) 388 | true_np = tf.identity(true_np, name="truemap-np") 389 | one_np = tf.one_hot(true_np, 2, axis=-1) 390 | true_np = tf.expand_dims(true_np, axis=-1) 391 | 392 | true_mk = truemap_coded[..., 1:4] 393 | true_mk = tf.cast(true_mk, tf.int32) 394 | true_mk = tf.identity(true_mk, name="truemap-mk") 395 | one_mk = tf.cast(true_mk, tf.float32) 396 | 397 | #### 398 | with argscope( 399 | Conv2D, 400 | activation=tf.identity, 401 | use_bias=False, # K.he initializer 402 | W_init=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"), 403 | ), argscope([Conv2D], data_format=self.data_format): 404 | 405 | i = images if not self.input_norm else images / 255.0 406 | 407 | #### 408 | d = encoder( 409 | "encoder", 410 | i, 411 | self.basis_filter_list, 412 | self.rot_matrix_list, 413 | self.nr_orients, 414 | self.filter_type, 415 | is_training, 416 | ) 417 | #### 418 | feat = decoder( 419 | "decoder", 420 | d, 421 | self.basis_filter_list, 422 | self.rot_matrix_list, 423 | self.nr_orients, 424 | self.filter_type, 425 | is_training, 426 | ) 427 | 428 | feat_np = Conv2D("feat_np", feat, 96, 1, use_bias=True, nl=BNReLU) 429 | o_logi_np = Conv2D( 430 | "output_np", feat_np, 2, 1, use_bias=True, nl=tf.identity 431 | ) 432 | soft_np = 
tf.nn.softmax(o_logi_np, axis=-1) 433 | prob_np = tf.identity(soft_np[..., 1], name="predmap-prob") 434 | prob_np = tf.expand_dims(prob_np, -1) 435 | 436 | feat_mk = Conv2D("feat_mk", feat, 96, 1, use_bias=True, nl=BNReLU) 437 | o_logi_mk = Conv2D( 438 | "output_mk", feat_mk, 3, 1, use_bias=True, nl=tf.identity 439 | ) 440 | soft_mk = tf.nn.softmax(o_logi_mk, axis=-1) 441 | prob_mk = tf.identity(soft_mk[..., :2], name="predmap-prob") 442 | 443 | # encoded so that inference can extract all output at once 444 | predmap_coded = tf.concat([prob_np, prob_mk], axis=-1, name="predmap-coded") 445 | 446 | #### 447 | if get_current_tower_context().is_training: 448 | # ---- LOSS ----# 449 | loss = 0 450 | for term, weight in self.loss_term.items(): 451 | if term == "bce": 452 | term_loss_np = categorical_crossentropy(soft_np, one_np) 453 | term_loss_np = tf.reduce_mean(term_loss_np, name="loss-bce-np") 454 | 455 | term_loss_mk = categorical_crossentropy(soft_mk, one_mk) 456 | term_loss_mk = tf.reduce_mean( 457 | term_loss_mk * pen_map, name="loss-bce-mk" 458 | ) 459 | elif term == "dice": 460 | # branch 1 461 | term_loss_np = dice_loss( 462 | soft_np[..., 0], one_np[..., 0] 463 | ) + dice_loss(soft_np[..., 1], one_np[..., 1]) 464 | term_loss_np = tf.identity(term_loss_np, name="loss-dice-np") 465 | 466 | term_loss_mk = dice_loss( 467 | soft_mk[..., 0], one_mk[..., 0] 468 | ) + dice_loss(soft_mk[..., 1], one_mk[..., 1]) 469 | term_loss_mk = tf.identity(term_loss_mk, name="loss-dice-mk") 470 | else: 471 | assert False, "Unsupported loss term: %s" % term 472 | add_moving_summary(term_loss_np) 473 | add_moving_summary(term_loss_mk) 474 | loss += term_loss_np + term_loss_mk 475 | 476 | ### combine the loss into single cost function 477 | wd_loss = regularize_cost(".*/W", l2_regularizer(1.0e-7), name="l2_wd_loss") 478 | add_moving_summary(wd_loss) 479 | self.cost = tf.identity(loss + wd_loss, name="overall-loss") 480 | add_moving_summary(self.cost) 481 | #### 482 | 483 |
add_param_summary((".*/W", ["histogram"])) # monitor W 484 | 485 | ### logging visual sthg 486 | orig_imgs = tf.cast(orig_imgs, tf.uint8) 487 | tf.summary.image("input", orig_imgs, max_outputs=1) 488 | 489 | pred_np = colorize(prob_np[..., 0], cmap="jet") 490 | true_np = colorize(true_np[..., 0], cmap="jet") 491 | 492 | pred_mk_blb = colorize(prob_mk[..., 0], cmap="jet") 493 | true_mk_blb = colorize(true_mk[..., 0], cmap="jet") 494 | pred_mk_cnt = colorize(prob_mk[..., 1], cmap="jet") 495 | true_mk_cnt = colorize(true_mk[..., 1], cmap="jet") 496 | 497 | viz = tf.concat( 498 | [ 499 | orig_imgs, 500 | pred_np, 501 | pred_mk_blb, 502 | pred_mk_cnt, 503 | true_np, 504 | true_mk_blb, 505 | true_mk_cnt, 506 | ], 507 | 2, 508 | ) 509 | 510 | viz = tf.concat([viz[0], viz[-1]], axis=0) 511 | viz = tf.expand_dims(viz, axis=0) 512 | tf.summary.image("output", viz, max_outputs=1) 513 | 514 | return 515 | -------------------------------------------------------------------------------- /src/model/utils/gconv_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Group equivariant convolution utils 3 | """ 4 | 5 | import math 6 | import numpy as np 7 | 8 | import tensorflow as tf 9 | from tensorflow.python.framework import dtypes 10 | from tensorflow.python.ops import random_ops 11 | 12 | from tensorpack import * 13 | from tensorpack.tfutils.symbolic_functions import * 14 | from tensorpack.tfutils.summary import * 15 | from tensorpack.tfutils.common import get_tf_version_tuple 16 | 17 | from matplotlib import cm 18 | 19 | from model.utils.norm_utils import * 20 | from model.utils.rotation_utils import * 21 | 22 | 23 | def GBNReLU(name, x, nr_orients): 24 | """ 25 | A shorthand of Group Equivariant BatchNormalization + ReLU. 
26 | 27 | Args: 28 | name: variable scope name 29 | x: input tensor 30 | nr_orients: number of filter orientations 31 | 32 | Returns: 33 | out: normalised tensor with ReLU activation 34 | """ 35 | 36 | shape = x.get_shape().as_list() 37 | chans = shape[3] 38 | 39 | c = int(chans / nr_orients) 40 | 41 | x = tf.reshape(x, [-1, shape[1], shape[2], nr_orients, c]) 42 | bn = BatchNorm3d(name + "_bn", x) 43 | act = tf.nn.relu(bn, name="relu") 44 | out = tf.reshape(act, [-1, shape[1], shape[2], chans]) 45 | return out 46 | 47 | 48 | def GBatchNorm(name, x, nr_orients): 49 | """ 50 | Group Equivariant BatchNormalization. 51 | 52 | Args: 53 | name: variable scope name 54 | x: input tensor 55 | nr_orients: number of filter orientations 56 | 57 | Returns: 58 | out: normalised tensor 59 | """ 60 | 61 | shape = x.get_shape().as_list() 62 | chans = shape[3] 63 | 64 | c = int(chans / nr_orients) 65 | 66 | x = tf.reshape(x, [-1, shape[1], shape[2], nr_orients, c]) 67 | bn = BatchNorm3d(name + "_bn", x) 68 | out = tf.reshape(bn, [-1, shape[1], shape[2], chans]) 69 | return out 70 | 71 | 72 | def get_basis_params(k_size): 73 | """ 74 | Get the filter parameters for a given kernel size 75 | 76 | Args: 77 | k_size (int): input kernel size 78 | 79 | Returns: 80 | alpha_list: list of alpha values 81 | beta_list: list of beta values 82 | bl_list: used to bandlimit high frequency filters in get_basis_filters() 83 | """ 84 | 85 | if k_size == 5: 86 | alpha_list = [0, 1, 2] 87 | beta_list = [0, 1, 2] 88 | bl_list = [0, 2, 2] 89 | if k_size == 7: 90 | alpha_list = [0, 1, 2, 3] 91 | beta_list = [0, 1, 2, 3] 92 | bl_list = [0, 2, 3, 2] 93 | if k_size == 9: 94 | alpha_list = [0, 1, 2, 3, 4] 95 | beta_list = [0, 1, 2, 3, 4] 96 | bl_list = [0, 3, 4, 4, 3] 97 | if k_size == 11: 98 | alpha_list = [0, 1, 2, 3, 4] 99 | beta_list = [1, 2, 3, 4] 100 | bl_list = [0, 3, 4, 4, 3] 101 | 102 | return alpha_list, beta_list, bl_list 103 | 104 | 105 | def get_basis_filters(alpha_list, beta_list, bl_list,
k_size, eps=10 ** -8): 106 | """ 107 | Gets the atomic basis filters 108 | 109 | Args: 110 | alpha_list: list of alpha values for basis filters 111 | beta_list: list of beta values for the basis filters 112 | bl_list: bandlimit list to reduce aliasing of basis filters 113 | k_size (int): kernel size of basis filters 114 | eps: epsilon used to prevent division by 0 115 | 116 | Returns: 117 | filters: complex64 tensor of bandlimited basis filters, 118 | with shape [nr_basis_filters, k_size, k_size, 1, 1, 1] 119 | freq_list: frequency (alpha) of each basis filter, needed for phase manipulation 120 | """ 121 | 122 | filter_list = [] 123 | freq_list = [] 124 | for beta in beta_list: 125 | for alpha in alpha_list: 126 | if alpha <= bl_list[beta]: 127 | his = k_size // 2 # half image size 128 | y_index, x_index = np.mgrid[-his : (his + 1), -his : (his + 1)] 129 | y_index *= -1 130 | z_index = x_index + 1j * y_index 131 | 132 | # convert z to natural coordinates and add eps to avoid division by zero 133 | z = z_index + eps 134 | r = np.abs(z) 135 | 136 | if beta == beta_list[-1]: 137 | sigma = 0.4 138 | else: 139 | sigma = 0.6 140 | rad_prof = np.exp(-((r - beta) ** 2) / (2 * (sigma ** 2))) 141 | c_image = rad_prof * (z / r) ** alpha 142 | c_image_norm = (math.sqrt(2) * c_image) / np.linalg.norm(c_image) 143 | 144 | # add basis filter to list 145 | filter_list.append(c_image) 146 | # add corresponding frequency of filter to list (info needed for phase manipulation) 147 | freq_list.append(alpha) 148 | 149 | filter_array = np.array(filter_list) 150 | 151 | filter_array = np.reshape( 152 | filter_array, 153 | [filter_array.shape[0], filter_array.shape[1], filter_array.shape[2], 1, 1, 1], 154 | ) 155 | return tf.convert_to_tensor(filter_array, dtype=tf.complex64), freq_list 156 | 157 | 158 | def get_rot_info(nr_orients, alpha_list): 159 | """ 160 | Generate rotation info for phase manipulation of steerable filters.
161 | Rotation is dependent on the frequency of the filter (alpha) 162 | 163 | Args: 164 | nr_orients: number of filter rotations 165 | alpha_list: list of alpha values that determine the frequency 166 | 167 | Returns: 168 | rot_info used to rotate steerable filters 169 | """ 170 | 171 | # Generate rotation matrix for phase manipulation of steerable function 172 | rot_list = [] 173 | for i in range(len(alpha_list)): 174 | list_tmp = [] 175 | for j in range(nr_orients): 176 | # Rotation is dependent on the frequency of the basis filter 177 | angle = (2 * np.pi / nr_orients) * j 178 | list_tmp.append(np.exp(-1j * alpha_list[i] * angle)) 179 | rot_list.append(list_tmp) 180 | rot_info = np.array(rot_list) 181 | 182 | # Reshape to enable matrix multiplication 183 | rot_info = np.reshape(rot_info, [rot_info.shape[0], 1, 1, 1, 1, nr_orients]) 184 | rot_info = tf.convert_to_tensor(rot_info, dtype=tf.complex64) 185 | return rot_info 186 | 187 | 188 | def GroupPool(name, x, nr_orients, pool_type="max"): 189 | """ 190 | Perform pooling along the orientation axis. 191 | 192 | Args: 193 | name: variable scope name 194 | x: input tensor 195 | nr_orients: number of filter orientations 196 | pool_type: choose either 'max' or 'mean' 197 | 198 | Returns: 199 | pool: pooled tensor 200 | """ 201 | shape = x.get_shape().as_list() 202 | new_shape = [-1, shape[1], shape[2], nr_orients, shape[3] // nr_orients] 203 | x_reshape = tf.reshape(x, new_shape) 204 | if pool_type == "max": 205 | pool = tf.reduce_max(x_reshape, 3) 206 | elif pool_type == "mean": 207 | pool = tf.reduce_mean(x_reshape, 3) 208 | else: 209 | raise ValueError("Pool type not recognised") 210 | return pool 211 | 212 | 213 | def steerable_initializer( 214 | nr_orients, factor=2.0, mode="FAN_IN", seed=None, dtype=dtypes.float32 215 | ): 216 | """ 217 | Initialise complex coefficients in accordance with Weiler et al.
(https://arxiv.org/pdf/1711.07289.pdf) 218 | Note, here we use the truncated normal dist, whereas Weiler et al. use the regular normal dist. 219 | 220 | Args: 221 | nr_orients: number of filter orientations 222 | factor: factor used for weight init 223 | mode: 'FAN_IN' or 'FAN_OUT' 224 | seed: seed for weight init 225 | dtype: data type 226 | 227 | Returns: 228 | _initializer: function that samples the coefficients from a truncated normal dist 229 | """ 230 | 231 | 232 | def _initializer(shape, dtype=dtype, partition_info=None): 233 | 234 | # total number of basis filters 235 | Q = shape[0] * shape[1] 236 | if mode == "FAN_IN": 237 | fan_in = shape[-2] 238 | C = fan_in 239 | # count number of input connections. 240 | elif mode == "FAN_OUT": 241 | fan_out = shape[-1] 242 | # count number of output connections. 243 | C = fan_out 244 | n = C * Q 245 | # to get stddev = math.sqrt(factor / n) need to adjust for truncated. 246 | trunc_stddev = math.sqrt(factor / n) / 0.87962566103423978 247 | return random_ops.truncated_normal(shape, 0.0, trunc_stddev, dtype, seed=seed) 248 | 249 | return _initializer 250 | 251 | 252 | def cycle_channels(filters, shape_list): 253 | """ 254 | Perform cyclic permutation of the orientation channels for kernels on the group G.
255 | 256 | Args: 257 | filters: input filters 258 | shape_list: [nr_orients_out, ksize, ksize, 259 | nr_orients_in, filters_in, filters_out] 260 | 261 | Returns: 262 | tensor of filters with channels permuted 263 | """ 264 | 265 | nr_orients_out = shape_list[0] 266 | rotated_filters = [None] * nr_orients_out 267 | for orientation in range(nr_orients_out): 268 | # [K, K, nr_orients_in, filters_in, filters_out] 269 | filters_temp = filters[orientation] 270 | # [K, K, filters_in, filters_out, nr_orients] 271 | filters_temp = tf.transpose(filters_temp, [0, 1, 3, 4, 2]) 272 | # [K * K * filters_in * filters_out, nr_orients_in] 273 | filters_temp = tf.reshape( 274 | filters_temp, 275 | [ 276 | shape_list[1] * shape_list[2] * shape_list[4] * shape_list[5], 277 | shape_list[3], 278 | ], 279 | ) 280 | # Cycle along the orientation axis 281 | roll_matrix = tf.constant( 282 | np.roll(np.identity(shape_list[3]), orientation, axis=1), dtype=tf.float32 283 | ) 284 | filters_temp = tf.matmul(filters_temp, roll_matrix) 285 | filters_temp = tf.reshape( 286 | filters_temp, 287 | [shape_list[1], shape_list[2], shape_list[4], shape_list[5], shape_list[3]], 288 | ) 289 | filters_temp = tf.transpose(filters_temp, [0, 1, 4, 2, 3]) 290 | rotated_filters[orientation] = filters_temp 291 | 292 | return tf.stack(rotated_filters) 293 | 294 | 295 | def gen_rotated_filters( 296 | w, filter_type, input_layer, nr_orients_out, basis_filters=None, rot_info=None 297 | ): 298 | """ 299 | Generate the rotated filters either by phase manipulation or direct rotation of planar filter. 300 | Cyclic permutation of channels is performed for kernels on the group G. 
301 | 302 | Args: 303 | w: coefficients used to perform a linear combination of basis filters 304 | filter_type: either 'steerable' or 'standard' 305 | input_layer (bool): whether 1st layer convolution or not 306 | nr_orients_out: number of output filter orientations 307 | basis_filters: atomic basis filters 308 | rot_info: array to determine how to rotate filters 309 | 310 | Returns: 311 | rot_filters: rotated steerable basis filters, with 312 | cyclic permutation if not the first layer 313 | """ 314 | 315 | if filter_type == "steerable": 316 | # if using steerable filters, then rotate by phase manipulation 317 | 318 | rot_filters = [None] * nr_orients_out 319 | for orientation in range(nr_orients_out): 320 | rot_info_tmp = tf.expand_dims(rot_info[..., orientation], -1) 321 | filter_tmp = w * rot_info_tmp * basis_filters # phase manipulation 322 | rot_filters[orientation] = filter_tmp 323 | # [nr_orients_out, J, K, K, nr_orients_in, filters_in, filters_out] (M: nr frequencies, R: nr radial profile params) 324 | rot_filters = tf.stack(rot_filters) 325 | 326 | # Linear combination of basis filters 327 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out] 328 | rot_filters = tf.reduce_sum(rot_filters, axis=1) 329 | # Get real part of filters 330 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out] 331 | rot_filters = tf.math.real(rot_filters, name="filters") 332 | 333 | else: 334 | # if using regular kernels, rotate by sparse matrix multiplication 335 | 336 | # [K, K, nr_orients_in, filters_in, filters_out] 337 | filter_shape = w.get_shape().as_list() 338 | 339 | # Flatten the filter 340 | filter_flat = tf.reshape( 341 | w, 342 | [ 343 | filter_shape[0] * filter_shape[1], 344 | filter_shape[2] * filter_shape[3] * filter_shape[4], 345 | ], 346 | ) 347 | 348 | # Generate a set of rotated kernels via rotation matrix multiplication 349 | idx, vals = MultiRotationOperatorMatrixSparse( 350 | [filter_shape[0], filter_shape[1]], 351 | 
nr_orients_out, 352 | periodicity=2 * np.pi, 353 | diskMask=True, 354 | ) 355 | 356 | # Sparse rotation matrix 357 | rotOp_matrix = tf.SparseTensor( 358 | idx, 359 | vals, 360 | [ 361 | nr_orients_out * filter_shape[0] * filter_shape[1], 362 | filter_shape[0] * filter_shape[1], 363 | ], 364 | ) 365 | 366 | # Matrix multiplication 367 | rot_filters = tf.sparse_tensor_dense_matmul(rotOp_matrix, filter_flat) 368 | # [nr_orients_out * K * K, filters_in * filters_out] 369 | 370 | # Reshape the filters to [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out] 371 | rot_filters = tf.reshape( 372 | rot_filters, 373 | [ 374 | nr_orients_out, 375 | filter_shape[0], 376 | filter_shape[1], 377 | filter_shape[2], 378 | filter_shape[3], 379 | filter_shape[4], 380 | ], 381 | ) 382 | 383 | # Do not cycle filter for input convolution f: Z2 -> G 384 | if input_layer is False: 385 | shape_list = rot_filters.get_shape().as_list() 386 | # cycle channels - [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out] 387 | rot_filters = cycle_channels(rot_filters, shape_list) 388 | 389 | return rot_filters 390 | 391 | 392 | def GConv2D( 393 | name, 394 | inputs, 395 | filters_out, 396 | kernel_size, 397 | nr_orients, 398 | filter_type, 399 | basis_filters=None, 400 | rot_info=None, 401 | input_layer=False, 402 | strides=[1, 1, 1, 1], 403 | padding="SAME", 404 | data_format="NHWC", 405 | activation="bnrelu", 406 | use_bias=False, 407 | bias_initializer=tf.zeros_initializer(), 408 | ): 409 | """ 410 | Rotation equivariant group convolution layer 411 | 412 | Args: 413 | name: variable scope name 414 | inputs: input tensor 415 | filters_out: number of filters out (per orientation) 416 | kernel_size: size of kernel 417 | basis_filters: atomic basis filters 418 | rot_info: array to determine how to rotate filters 419 | input_layer: whether the operation is the input layer (1st conv) 420 | strides: stride of kernel for convolution 421 | padding: choose either 'SAME' or 'VALID' 422 |
data_format: either 'NHWC' or 'NCHW' 423 | activation: 'bnrelu', 'bn' or 'relu'; any other value applies no activation 424 | use_bias: whether to use bias 425 | bias_initializer: bias initialiser method 426 | 427 | Returns: 428 | conv: group equivariant convolution of input with 429 | steerable filters and optional activation. 430 | """ 431 | 432 | if filter_type == "steerable": 433 | assert ( 434 | basis_filters is not None and rot_info is not None 435 | ), "Must provide basis filters and rotation matrix" 436 | 437 | in_shape = inputs.get_shape().as_list() 438 | channel_axis = 3 if data_format == "NHWC" else 1 439 | 440 | if not input_layer: 441 | nr_orients_in = nr_orients 442 | else: 443 | nr_orients_in = 1 444 | nr_orients_out = nr_orients 445 | 446 | filters_in = in_shape[channel_axis] // nr_orients_in 447 | 448 | if filter_type == "steerable": 449 | # shape for the filter coefficients 450 | nr_b_filts = basis_filters.shape[0] 451 | w_shape = [nr_b_filts, 1, 1, nr_orients_in, filters_in, filters_out] 452 | 453 | # init complex valued weights with the adapted He init (Weiler et al.) 
454 | w1 = tf.get_variable( 455 | name + "_W_real", w_shape, initializer=steerable_initializer(nr_orients_out) 456 | ) 457 | w2 = tf.get_variable( 458 | name + "_W_imag", w_shape, initializer=steerable_initializer(nr_orients_out) 459 | ) 460 | w = tf.complex(w1, w2) 461 | 462 | # Generate filters at different orientations; also perform cyclic permutation of channels if f: G -> G 463 | # Cyclic permutation of filters happens for all rotation equivariant layers except for the input layer 464 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out] 465 | filters = gen_rotated_filters( 466 | w, filter_type, input_layer, nr_orients_out, basis_filters, rot_info 467 | ) 468 | 469 | else: 470 | w_shape = [kernel_size, kernel_size, nr_orients_in, filters_in, filters_out] 471 | w = tf.get_variable( 472 | name + "_W", 473 | w_shape, 474 | initializer=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"), 475 | ) 476 | 477 | # Generate filters at different orientations; also perform cyclic permutation of channels if f: G -> G 478 | # Cyclic permutation of filters happens for all rotation equivariant layers except for the input layer 479 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out] 480 | filters = gen_rotated_filters(w, filter_type, input_layer, nr_orients_out) 481 | 482 | # reshape filters for 2D convolution 483 | # [K, K, nr_orients_in, filters_in, nr_orients_out, filters_out] 484 | filters = tf.transpose(filters, [1, 2, 3, 4, 0, 5]) 485 | filters = tf.reshape( 486 | filters, 487 | [ 488 | kernel_size, 489 | kernel_size, 490 | nr_orients_in * filters_in, 491 | nr_orients_out * filters_out, 492 | ], 493 | ) 494 | 495 | # perform conv with rotated filters (reshaped so we can perform 2D convolution) 496 | kwargs = dict(data_format=data_format) 497 | conv = tf.nn.conv2d(inputs, filters, strides, padding.upper(), **kwargs) 498 | if use_bias: 499 | # Use same bias for all orientations 500 | b = tf.get_variable( 501 | name + "_bias", 
[filters_out], initializer=bias_initializer 502 | ) 503 | b = tf.stack([b] * nr_orients_out) 504 | b = tf.reshape(b, [nr_orients_out * filters_out]) 505 | conv = tf.nn.bias_add(conv, b, data_format=data_format) 506 | 507 | if activation == "bnrelu": 508 | # Rotation equivariant batch normalisation followed by ReLU 509 | conv = GBNReLU(name, conv, nr_orients_out) 510 | 511 | if activation == "bn": 512 | # Rotation equivariant batch normalisation 513 | conv = GBatchNorm(name, conv, nr_orients_out) 514 | 515 | if activation == "relu": 516 | # ReLU only (pointwise, so it preserves rotation equivariance) 517 | conv = tf.nn.relu(conv) 518 | 519 | return conv 520 | -------------------------------------------------------------------------------- /src/model/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model utils 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.python.framework import dtypes 8 | from tensorflow.python.ops import random_ops 9 | 10 | from tensorpack import * 11 | from tensorpack.tfutils.symbolic_functions import * 12 | from tensorpack.tfutils.summary import * 13 | from tensorpack.tfutils.common import get_tf_version_tuple 14 | 15 | from matplotlib import cm 16 | 17 | from .norm_utils import * 18 | 19 | 20 | def resize_op( 21 | x, 22 | height_factor=None, 23 | width_factor=None, 24 | size=None, 25 | interp="bicubic", 26 | data_format="channels_last", 27 | ): 28 | """ 29 | Resize by a factor if `size=None` else resize to `size` 30 | """ 31 | original_shape = x.get_shape().as_list() 32 | if size is not None: 33 | if data_format == "channels_first": 34 | x = tf.transpose(x, [0, 2, 3, 1]) 35 | if interp == "bicubic": 36 | x = tf.image.resize_bicubic(x, size) 37 | elif interp == "bilinear": 38 | x = tf.image.resize_bilinear(x, size) 39 | else: 40 | x = tf.image.resize_nearest_neighbor(x, size) 41 | x = tf.transpose(x, [0, 3, 1, 2]) 42 | x.set_shape( 43 | ( 44 | None, 45 | original_shape[1] if original_shape[1] is not 
None else None, 46 | size[0], 47 | size[1], 48 | ) 49 | ) 50 | else: 51 | if interp == "bicubic": 52 | x = tf.image.resize_bicubic(x, size) 53 | elif interp == "bilinear": 54 | x = tf.image.resize_bilinear(x, size) 55 | else: 56 | x = tf.image.resize_nearest_neighbor(x, size) 57 | x.set_shape( 58 | ( 59 | None, 60 | size[0], 61 | size[1], 62 | original_shape[3] if original_shape[3] is not None else None, 63 | ) 64 | ) 65 | else: 66 | if data_format == "channels_first": 67 | new_shape = tf.cast(tf.shape(x)[2:], tf.float32) 68 | new_shape *= tf.constant( 69 | np.array([height_factor, width_factor]).astype("float32") 70 | ) 71 | new_shape = tf.cast(new_shape, tf.int32) 72 | x = tf.transpose(x, [0, 2, 3, 1]) 73 | if interp == "bicubic": 74 | x = tf.image.resize_bicubic(x, new_shape) 75 | elif interp == "bilinear": 76 | x = tf.image.resize_bilinear(x, new_shape) 77 | else: 78 | x = tf.image.resize_nearest_neighbor(x, new_shape) 79 | x = tf.transpose(x, [0, 3, 1, 2]) 80 | x.set_shape( 81 | ( 82 | None, 83 | original_shape[1] if original_shape[1] is not None else None, 84 | int(original_shape[2] * height_factor) 85 | if original_shape[2] is not None 86 | else None, 87 | int(original_shape[3] * width_factor) 88 | if original_shape[3] is not None 89 | else None, 90 | ) 91 | ) 92 | else: 93 | original_shape = x.get_shape().as_list() 94 | new_shape = tf.cast(tf.shape(x)[1:3], tf.float32) 95 | new_shape *= tf.constant( 96 | np.array([height_factor, width_factor]).astype("float32") 97 | ) 98 | new_shape = tf.cast(new_shape, tf.int32) 99 | if interp == "bicubic": 100 | x = tf.image.resize_bicubic(x, new_shape) 101 | elif interp == "bilinear": 102 | x = tf.image.resize_bilinear(x, new_shape) 103 | else: 104 | x = tf.image.resize_nearest_neighbor(x, new_shape) 105 | x.set_shape( 106 | ( 107 | None, 108 | int(original_shape[1] * height_factor) 109 | if original_shape[1] is not None 110 | else None, 111 | int(original_shape[2] * width_factor) 112 | if original_shape[2] is not None 
113 | else None, 114 | original_shape[3] if original_shape[3] is not None else None, 115 | ) 116 | ) 117 | return x 118 | 119 | 120 | #### 121 | def crop_op(x, cropping, data_format="channels_last"): 122 | """ 123 | Center crop image 124 | 125 | Args: 126 | cropping: the subtracted portion as (height, width); each entry must be >= 1 127 | """ 128 | 129 | crop_t = cropping[0] // 2 130 | crop_b = cropping[0] - crop_t 131 | crop_l = cropping[1] // 2 132 | crop_r = cropping[1] - crop_l 133 | if data_format == "channels_first": 134 | x = x[:, :, crop_t:-crop_b, crop_l:-crop_r] 135 | else: 136 | x = x[:, crop_t:-crop_b, crop_l:-crop_r] 137 | return x 138 | 139 | 140 | def categorical_crossentropy(output, target): 141 | """ 142 | Categorical cross-entropy; accepts probabilities, not logits 143 | 144 | Args: 145 | output: predicted class probabilities (classes on the last axis) 146 | target: target distribution with the same shape as `output` 147 | """ 148 | 149 | # scale preds so that the class probs of each sample sum to 1 150 | output /= tf.reduce_sum( 151 | output, reduction_indices=len(output.get_shape()) - 1, keepdims=True 152 | ) 153 | # manual computation of crossentropy 154 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype) 155 | output = tf.clip_by_value(output, epsilon, 1.0 - epsilon) 156 | return -tf.reduce_sum( 157 | target * tf.log(output), reduction_indices=len(output.get_shape()) - 1 158 | ) 159 | 160 | 161 | def dice_loss(output, target, loss_type="sorensen", axis=None, smooth=1e-3): 162 | """Soft dice (Sørensen or Jaccard) loss, i.e. 1 - coefficient, for comparing the similarity 163 | of two batches of data, usually used for binary image segmentation, 164 | i.e. labels are binary. The coefficient lies between 0 and 1, where 1 means a total match. 165 | Parameters 166 | ----------- 167 | output : Tensor 168 | A distribution with shape: [batch_size, ....], (any dimensions). 169 | target : Tensor 170 | The target distribution, same format as `output`. 171 | loss_type : str 172 | ``jaccard`` or ``sorensen``, default is ``sorensen``. 173 | axis : tuple of int 174 | Dimensions to reduce; the default ``None`` reduces all dimensions. 
175 | smooth : float 176 | This small value will be added to the numerator and denominator. 177 | - If both output and target are empty, it makes sure dice is 1. 178 | - If either output or target is empty (all pixels are background), 179 | dice = ``smooth/(small_value + smooth)``, so if smooth is very small, 180 | dice is close to 0 (even if the image values are lower than the threshold); 181 | in this case, a higher smooth gives a higher dice. 182 | Examples 183 | --------- 184 | >>> loss = dice_loss(outputs, y_) 185 | """ 186 | 187 | target = tf.squeeze(tf.cast(target, tf.float32)) 188 | output = tf.squeeze(tf.cast(output, tf.float32)) 189 | 190 | inse = tf.reduce_sum(output * target, axis=axis) 191 | if loss_type == "jaccard": 192 | l = tf.reduce_sum(output * output, axis=axis) 193 | r = tf.reduce_sum(target * target, axis=axis) 194 | elif loss_type == "sorensen": 195 | l = tf.reduce_sum(output, axis=axis) 196 | r = tf.reduce_sum(target, axis=axis) 197 | else: 198 | raise ValueError("Unknown loss_type") 199 | # already flatten 200 | dice = 1.0 - (2.0 * inse + smooth) / (l + r + smooth) 201 | ## 202 | return dice 203 | 204 | 205 | def colorize(value, vmin=None, vmax=None, cmap=None): 206 | """ 207 | Args: 208 | - value: input tensor, NHWC ('channels_last') 209 | - vmin: the minimum value of the range used for normalization. 210 | (Default: value minimum) 211 | - vmax: the maximum value of the range used for normalization. 212 | (Default: value maximum) 213 | - cmap: a valid cmap name for use with matplotlib's `get_cmap`. 214 | (Default: 'gray') 215 | 216 | Example usage: 217 | ``` 218 | output = tf.random_uniform(shape=[256, 256, 1]) 219 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis') 220 | tf.summary.image('output', output_color) 221 | ``` 222 | 223 | Returns: 224 | 3D tensor of shape [height, width, 3], uint8. 
225 | """ 226 | 227 | # normalize 228 | if vmin is None: 229 | vmin = tf.reduce_min(value, axis=[1, 2]) 230 | vmin = tf.reshape(vmin, [-1, 1, 1]) 231 | if vmax is None: 232 | vmax = tf.reduce_max(value, axis=[1, 2]) 233 | vmax = tf.reshape(vmax, [-1, 1, 1]) 234 | value = (value - vmin) / (vmax - vmin) # vmin..vmax 235 | 236 | # squeeze last dim if it exists 237 | # NOTE: will throw error if use get_shape() 238 | # value = tf.squeeze(value) 239 | 240 | # quantize 241 | value = tf.round(value * 255) 242 | indices = tf.cast(value, np.int32) 243 | 244 | # gather 245 | colormap = cm.get_cmap(cmap if cmap is not None else "gray") 246 | colors = colormap(np.arange(256))[:, :3] 247 | colors = tf.constant(colors, dtype=tf.float32) 248 | value = tf.gather(colors, indices) 249 | value = tf.cast(value * 255, tf.uint8) 250 | return value 251 | 252 | 253 | @layer_register(use_scope=None) 254 | def BNELU(x, name=None): 255 | """ 256 | A shorthand of BatchNormalization + ELU. 257 | 258 | Args: 259 | x (tf.Tensor): the input 260 | name: deprecated, don't use. 
261 | """ 262 | if name is not None: 263 | log_deprecated("BNELU(name=...)", "The output tensor will be named `output`.") 264 | 265 | x = BatchNorm("bn", x) 266 | x = tf.nn.elu(x, name=name) 267 | return x 268 | 269 | -------------------------------------------------------------------------------- /src/model/utils/norm_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File: norm_utils.py 3 | ### 4 | # https://github.com/tensorpack/tensorpack/blob/master/tensorpack/models/batch_norm.py 5 | ### 6 | 7 | import tensorflow as tf 8 | from tensorflow.contrib.framework import add_model_variable 9 | from tensorflow.python.training import moving_averages 10 | import re 11 | import six 12 | import functools 13 | 14 | from tensorpack.utils import logger 15 | from tensorpack.utils.argtools import get_data_format 16 | from tensorpack.tfutils.tower import get_current_tower_context 17 | from tensorpack.tfutils.common import get_tf_version_tuple 18 | from tensorpack.tfutils.collection import backup_collection, restore_collection 19 | from tensorpack import layer_register, VariableHolder 20 | from tensorpack.tfutils.varreplace import custom_getter_scope 21 | 22 | 23 | def rename_get_variable(mapping): 24 | """ 25 | Args: 26 | 27 | mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'} 28 | 29 | Returns: 30 | A context where the variables are renamed. 
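
Example (an illustrative sketch, not part of tensorpack's API): the renaming only touches the last component of the scoped variable name. The helper `rename_basename` is hypothetical and simply mirrors the string handling done by the custom getter.

```python
def rename_basename(name, mapping):
    # Split the scoped variable name and remap only its final component.
    splits = name.split('/')
    splits[-1] = mapping.get(splits[-1], splits[-1])
    return '/'.join(splits)


print(rename_basename('conv1/kernel', {'kernel': 'W'}))  # conv1/W
print(rename_basename('kernel/bias', {'kernel': 'W'}))   # kernel/bias
```

Scope names that happen to match a key are left alone, because only the basename is looked up in `mapping`.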
31 | """ 32 | def custom_getter(getter, name, *args, **kwargs): 33 | splits = name.split('/') 34 | basename = splits[-1] 35 | if basename in mapping: 36 | basename = mapping[basename] 37 | splits[-1] = basename 38 | name = '/'.join(splits) 39 | return getter(name, *args, **kwargs) 40 | return custom_getter_scope(custom_getter) 41 | 42 | 43 | def map_common_tfargs(kwargs): 44 | df = kwargs.pop('data_format', None) 45 | if df is not None: 46 | df = get_data_format(df, tfmode=True) 47 | kwargs['data_format'] = df 48 | 49 | old_nl = kwargs.pop('nl', None) 50 | if old_nl is not None: 51 | kwargs['activation'] = lambda x, name=None: old_nl(x, name=name) 52 | 53 | if 'W_init' in kwargs: 54 | kwargs['kernel_initializer'] = kwargs.pop('W_init') 55 | 56 | if 'b_init' in kwargs: 57 | kwargs['bias_initializer'] = kwargs.pop('b_init') 58 | return kwargs 59 | 60 | 61 | def convert_to_tflayer_args(args_names, name_mapping): 62 | """ 63 | After applying this decorator: 64 | 1. data_format becomes tf.layers style 65 | 2. nl becomes activation 66 | 3. initializers are renamed 67 | 4. positional args are transformed to corresponding kwargs, according to args_names 68 | 5. 
kwargs are mapped to tf.layers names if needed, by name_mapping 69 | """ 70 | 71 | def decorator(func): 72 | @functools.wraps(func) 73 | def decorated_func(inputs, *args, **kwargs): 74 | kwargs = map_common_tfargs(kwargs) 75 | 76 | posarg_dic = {} 77 | assert len(args) <= len(args_names), \ 78 | "Please use kwargs instead of positional args to call this model, " \ 79 | "except for the following arguments: {}".format( 80 | ', '.join(args_names)) 81 | for pos_arg, name in zip(args, args_names): 82 | posarg_dic[name] = pos_arg 83 | 84 | ret = {} 85 | for name, arg in six.iteritems(kwargs): 86 | newname = name_mapping.get(name, None) 87 | if newname is not None: 88 | assert newname not in kwargs, \ 89 | "Argument {} and {} conflicts!".format(name, newname) 90 | else: 91 | newname = name 92 | ret[newname] = arg 93 | # Let pos arg overwrite kw arg, for argscope to work 94 | ret.update(posarg_dic) 95 | 96 | return func(inputs, **ret) 97 | 98 | return decorated_func 99 | 100 | return decorator 101 | 102 | 103 | __all__ = ['BatchNorm3d'] 104 | 105 | 106 | def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init): 107 | if use_bias: 108 | beta = tf.get_variable('beta', [n_out], initializer=beta_init) 109 | else: 110 | beta = tf.zeros([n_out], name='beta') 111 | if use_scale: 112 | gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init) 113 | else: 114 | gamma = tf.ones([n_out], name='gamma') 115 | # x * gamma + beta 116 | 117 | moving_mean = tf.get_variable('mean/EMA', [n_out], 118 | initializer=tf.constant_initializer(), trainable=False) 119 | moving_var = tf.get_variable('variance/EMA', [n_out], 120 | initializer=tf.constant_initializer(1.0), trainable=False) 121 | 122 | if get_current_tower_context().is_main_training_tower: 123 | for v in [moving_mean, moving_var]: 124 | tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v) 125 | return beta, gamma, moving_mean, moving_var 126 | 127 | 128 | def internal_update_bn_ema(xn, batch_mean, batch_var, 129 
| moving_mean, moving_var, decay): 130 | update_op1 = moving_averages.assign_moving_average( 131 | moving_mean, batch_mean, decay, zero_debias=False, 132 | name='mean_ema_op') 133 | update_op2 = moving_averages.assign_moving_average( 134 | moving_var, batch_var, decay, zero_debias=False, 135 | name='var_ema_op') 136 | 137 | # When sync_statistics is True, always enable internal_update. 138 | # Otherwise the update ops (only executed on main tower) 139 | # will hang when some BatchNorm layers are unused (https://github.com/tensorpack/tensorpack/issues/1078) 140 | with tf.control_dependencies([update_op1, update_op2]): 141 | return tf.identity(xn, name='output') 142 | 143 | 144 | @layer_register() 145 | @convert_to_tflayer_args( 146 | args_names=[], 147 | name_mapping={ 148 | 'use_bias': 'center', 149 | 'use_scale': 'scale', 150 | 'gamma_init': 'gamma_initializer', 151 | 'decay': 'momentum', 152 | 'use_local_stat': 'training' 153 | }) 154 | def BatchNorm3d(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, 155 | center=True, scale=True, 156 | beta_initializer=tf.zeros_initializer(), 157 | gamma_initializer=tf.ones_initializer(), 158 | virtual_batch_size=None, 159 | data_format='channels_last', 160 | internal_update=False, 161 | sync_statistics=None): 162 | """ 163 | Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful) 164 | in the following: 165 | 1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored. 166 | 2. Default value for `momentum` and `epsilon` is different. 167 | 3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten. 168 | 4. Support the `internal_update` option, which enables the use of BatchNorm layer inside conditionals. 169 | 5. Support the `sync_statistics` option, which is very useful in small-batch models. 
170 | Args: 171 | internal_update (bool): if False, add EMA update ops to 172 | `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies. 173 | They are very similar in speed, but `internal_update=True` can be used 174 | when you have conditionals in your model, or when you have multiple networks to train. 175 | Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699 176 | sync_statistics: either None or "nccl". By default (None), it uses statistics of the input tensor to normalize. 177 | When set to "nccl", this layer must be used under tensorpack multi-gpu trainers, 178 | and it then uses per-machine (multiple GPU) statistics to normalize. 179 | Note that this implementation averages the per-tower E[x] and E[x^2] among towers to compute 180 | global mean&variance. The result is the global mean&variance only if each tower has the same batch size. 181 | This option has no effect when not training. 182 | This option is also known as "Cross-GPU BatchNorm" as mentioned in https://arxiv.org/abs/1711.07240. 183 | Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222 184 | Variable Names: 185 | * ``beta``: the bias term. Will be zero-inited by default. 186 | * ``gamma``: the scale term. Will be one-inited by default. 187 | * ``mean/EMA``: the moving average of mean. 188 | * ``variance/EMA``: the moving average of variance. 189 | Note: 190 | Combinations of ``training`` and ``ctx.is_training``: 191 | * ``training == ctx.is_training``: standard BN, EMA are maintained during training 192 | and used during inference. This is the default. 193 | * ``training and not ctx.is_training``: still use batch statistics in inference. 194 | * ``not training and ctx.is_training``: use EMA to normalize in 195 | training. This is useful when you load a pre-trained BN and 196 | don't want to fine tune the EMA. EMA will not be updated in 197 | this case. 
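
Example (a plain-Python sketch of the `sync_statistics` note above, not the layer's actual code path): averaging the per-tower E[x] and E[x^2] reproduces the global variance only when every tower sees the same number of samples. All helper names here are hypothetical.

```python
def moments(xs):
    # Return (E[x], E[x^2]) for one tower's batch.
    n = float(len(xs))
    return sum(xs) / n, sum(x * x for x in xs) / n


def synced_variance(towers):
    # Average the per-tower moments, then form Var = E[x^2] - E[x]^2.
    stats = [moments(t) for t in towers]
    mean = sum(s[0] for s in stats) / len(stats)
    mean_sq = sum(s[1] for s in stats) / len(stats)
    return mean_sq - mean ** 2


def global_variance(towers):
    # Variance over all samples pooled together.
    mean, mean_sq = moments([x for t in towers for x in t])
    return mean_sq - mean ** 2


equal = [[1.0, 2.0], [3.0, 4.0]]
unequal = [[1.0], [2.0, 3.0, 4.0]]
print(synced_variance(equal), global_variance(equal))      # the two agree
print(synced_variance(unequal), global_variance(unequal))  # the two differ
```

With unequal per-tower batch sizes, the averaged moments weight each tower equally rather than each sample, which is why the two results diverge.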
198 | """ 199 | # parse shapes 200 | data_format = get_data_format(data_format, tfmode=False) 201 | shape = inputs.get_shape().as_list() 202 | ndims = len(shape) 203 | # in 3d conv, we have 5d dim [batch, h, w, t, c] 204 | if sync_statistics is not None: 205 | sync_statistics = sync_statistics.lower() 206 | assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics 207 | 208 | if axis is None: 209 | assert ndims == 5, 'Number of input dims must be 5 for 3d Batch Norm' 210 | axis = 1 if data_format == 'NCHW' else 4 211 | 212 | num_chan = shape[axis] 213 | 214 | # parse training/ctx 215 | ctx = get_current_tower_context() 216 | if training is None: 217 | training = ctx.is_training 218 | training = bool(training) 219 | TF_version = get_tf_version_tuple() 220 | if not training and ctx.is_training: 221 | assert TF_version >= (1, 4), \ 222 | "Fine tuning a BatchNorm model with fixed statistics is only " \ 223 | "supported after https://github.com/tensorflow/tensorflow/pull/12580 " 224 | if ctx.is_main_training_tower: # only warn in first tower 225 | logger.warn( 226 | "[BatchNorm] Using moving_mean/moving_variance in training.") 227 | # Using moving_mean/moving_variance in training, which means we 228 | # loaded a pre-trained BN and only fine-tuning the affine part. 229 | 230 | if sync_statistics is None or not (training and ctx.is_training): 231 | coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS]) 232 | with rename_get_variable( 233 | {'moving_mean': 'mean/EMA', 234 | 'moving_variance': 'variance/EMA'}): 235 | tf_args = dict( 236 | axis=axis, 237 | momentum=momentum, epsilon=epsilon, 238 | center=center, scale=scale, 239 | beta_initializer=beta_initializer, 240 | gamma_initializer=gamma_initializer, 241 | fused=True, 242 | _reuse=tf.get_variable_scope().reuse) 243 | if TF_version >= (1, 5): 244 | tf_args['virtual_batch_size'] = virtual_batch_size 245 | else: 246 | assert virtual_batch_size is None, "Feature not supported in this version of TF!" 
247 | layer = tf.layers.BatchNormalization(**tf_args) 248 | xn = layer.apply(inputs, training=training, 249 | scope=tf.get_variable_scope()) 250 | 251 | # maintain EMA only on one GPU is OK, even in replicated mode. 252 | # because during training, EMA isn't used 253 | if ctx.is_main_training_tower: 254 | for v in layer.non_trainable_variables: 255 | add_model_variable(v) 256 | if not ctx.is_main_training_tower or internal_update: 257 | restore_collection(coll_bk) 258 | 259 | if training and internal_update: 260 | assert layer.updates 261 | with tf.control_dependencies(layer.updates): 262 | ret = tf.identity(xn, name='output') 263 | else: 264 | ret = tf.identity(xn, name='output') 265 | 266 | vh = ret.variables = VariableHolder( 267 | moving_mean=layer.moving_mean, 268 | mean=layer.moving_mean, # for backward-compatibility 269 | moving_variance=layer.moving_variance, 270 | variance=layer.moving_variance) # for backward-compatibility 271 | if scale: 272 | vh.gamma = layer.gamma 273 | if center: 274 | vh.beta = layer.beta 275 | else: 276 | red_axis = [0] if ndims == 2 else ( 277 | [0, 2, 3] if axis == 1 else [0, 1, 2]) 278 | if ndims == 5: 279 | red_axis = [0, 2, 3, 4] if axis == 1 else [0, 1, 2, 3] 280 | new_shape = None # don't need to reshape unless ... 281 | if ndims == 4 and axis == 1: 282 | new_shape = [1, num_chan, 1, 1] 283 | if ndims == 5 and axis == 1: 284 | new_shape = [1, num_chan, 1, 1, 1] 285 | 286 | batch_mean = tf.reduce_mean(inputs, axis=red_axis) 287 | batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis) 288 | 289 | if sync_statistics == 'nccl': 290 | if six.PY3 and TF_version <= (1, 8) and ctx.is_main_training_tower: 291 | logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail. 
" 292 | "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360") 293 | 294 | from tensorflow.contrib.nccl.ops import gen_nccl_ops 295 | shared_name = re.sub( 296 | 'tower[0-9]+/', '', tf.get_variable_scope().name) 297 | num_dev = ctx.total 298 | batch_mean = gen_nccl_ops.nccl_all_reduce( 299 | input=batch_mean, 300 | reduction='sum', 301 | num_devices=num_dev, 302 | shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev) 303 | batch_mean_square = gen_nccl_ops.nccl_all_reduce( 304 | input=batch_mean_square, 305 | reduction='sum', 306 | num_devices=num_dev, 307 | shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev) 308 | elif sync_statistics == 'horovod': 309 | # Require https://github.com/uber/horovod/pull/331 310 | # Proof-of-concept, not ready yet. 311 | import horovod.tensorflow as hvd 312 | batch_mean = hvd.allreduce(batch_mean, average=True) 313 | batch_mean_square = hvd.allreduce(batch_mean_square, average=True) 314 | batch_var = batch_mean_square - tf.square(batch_mean) 315 | batch_mean_vec = batch_mean 316 | batch_var_vec = batch_var 317 | 318 | beta, gamma, moving_mean, moving_var = get_bn_variables( 319 | num_chan, scale, center, beta_initializer, gamma_initializer) 320 | if new_shape is not None: 321 | batch_mean = tf.reshape(batch_mean, new_shape) 322 | batch_var = tf.reshape(batch_var, new_shape) 323 | # Using fused_batch_norm(is_training=False) is actually slightly faster, 324 | # but hopefully this call will be JITed in the future. 
325 | xn = tf.nn.batch_normalization( 326 | inputs, batch_mean, batch_var, 327 | tf.reshape(beta, new_shape), 328 | tf.reshape(gamma, new_shape), epsilon) 329 | else: 330 | xn = tf.nn.batch_normalization( 331 | inputs, batch_mean, batch_var, 332 | beta, gamma, epsilon) 333 | 334 | if ctx.is_main_training_tower: 335 | ret = internal_update_bn_ema( 336 | xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var, 337 | momentum) 338 | else: 339 | ret = tf.identity(xn, name='output') 340 | 341 | vh = ret.variables = VariableHolder( 342 | moving_mean=moving_mean, 343 | mean=moving_mean, # for backward-compatibility 344 | moving_variance=moving_var, 345 | variance=moving_var) # for backward-compatibility 346 | if scale: 347 | vh.gamma = gamma 348 | if center: 349 | vh.beta = beta 350 | return ret 351 | -------------------------------------------------------------------------------- /src/model/utils/rotation_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | The below functions are adapted from www.github.com/tueimage/SE2CNN 4 | 5 | Released in June 2018 6 | @author: EJ Bekkers, Eindhoven University of Technology, The Netherlands 7 | @author: MW Lafarge, Eindhoven University of Technology, The Netherlands 8 | ________________________________________________________________________ 9 | 10 | Copyright 2018 Erik J Bekkers and Maxime W Lafarge, Eindhoven University 11 | of Technology, the Netherlands 12 | 13 | Licensed under the Apache License, Version 2.0 (the "License"); 14 | you may not use this file except in compliance with the License. 15 | You may obtain a copy of the License at 16 | 17 | http://www.apache.org/licenses/LICENSE-2.0 18 | 19 | Unless required by applicable law or agreed to in writing, software 20 | distributed under the License is distributed on an "AS IS" BASIS, 21 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
22 | See the License for the specific language governing permissions and 23 | limitations under the License. 24 | ________________________________________________________________________ 25 | """ 26 | 27 | import numpy as np 28 | import math as m 29 | 30 | 31 | def CoordRotationInv(ij, NiNj, theta): 32 | """ Applies the inverse rotation transformation on input coordinates (i,j). 33 | The rotation is around the center of the image with dimensions Ni, Nj 34 | (resp. # of rows and columns). Input theta is the applied rotation. 35 | 36 | INPUT: 37 | - ij, a list of length 2 containing the i and j coordinate as [i,j] 38 | - NiNj, a list of length 2 containing the dimensions of the 2D domain 39 | as [Ni, Nj] 40 | - theta, a real number specifying the rotation angle in radians 41 | 42 | OUTPUT: 43 | - ijOld, a np.array of length 2 containing the new coordinate of the 44 | inverse rotation, i.e., the old coordinate which was mapped to 45 | the new one via (forward) rotation over theta. 46 | """ 47 | 48 | # Define the center of rotation 49 | centeri = m.floor(NiNj[0] / 2) 50 | centerj = m.floor(NiNj[1] / 2) 51 | 52 | # Compute the output of the inverse rotation transformation 53 | ijOld = np.zeros([2]) 54 | ijOld[0] = m.cos(theta) * (ij[0] - centeri) + \ 55 | m.sin(theta) * (ij[1] - centerj) + centeri 56 | ijOld[1] = -1 * m.sin(theta) * (ij[0] - centeri) + \ 57 | m.cos(theta) * (ij[1] - centerj) + centerj 58 | 59 | # Return the "old" indices 60 | return ijOld 61 | 62 | 63 | def LinIntIndicesAndWeights(ij, NiNj): 64 | """ Returns, given a target index (i,j), the 4 neighbouring indices and 65 | their corresponding weights used for linear interpolation. 66 | 67 | INPUT: 68 | - ij, a list of length 2 containing the i and j coordinate 69 | as [i,j] 70 | - NiNj, a list of length 2 containing the dimensions of the 2D 71 | domain as [Ni,Nj] 72 | 73 | OUTPUT: 74 | - indicesAndWeights, a list of index-weight pairs as [[i0,j0,w00], 75 | [i0,j1,w01],...]; neighbours outside the image are omitted 
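
Example (a hand-worked sketch; `bilinear_weights` is a hypothetical stand-alone helper that follows the same rule as this function): for the target (0.25, 0.5) in a 3x3 image, all four neighbours are in range and their weights sum to 1.

```python
import math


def bilinear_weights(ij, NiNj):
    # Neighbouring integer indices around the (fractional) target.
    i1, j1 = int(math.floor(ij[0])), int(math.floor(ij[1]))
    ti, tj = ij[0] - i1, ij[1] - j1
    candidates = [
        (i1, j1, (1 - ti) * (1 - tj)),
        (i1, j1 + 1, (1 - ti) * tj),
        (i1 + 1, j1, ti * (1 - tj)),
        (i1 + 1, j1 + 1, ti * tj),
    ]
    # Keep only neighbours that fall inside the image.
    return [[i, j, w] for i, j, w in candidates
            if 0 <= i < NiNj[0] and 0 <= j < NiNj[1]]


pairs = bilinear_weights([0.25, 0.5], [3, 3])
print(pairs)  # [[0, 0, 0.375], [0, 1, 0.375], [1, 0, 0.125], [1, 1, 0.125]]
```

Near the border fewer than four pairs are returned, so the remaining weights no longer sum to 1.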
76 | """ 77 | 78 | # The index where we want to obtain the value 79 | i = ij[0] 80 | j = ij[1] 81 | # Image size 82 | Ni = NiNj[0] 83 | Nj = NiNj[1] 84 | 85 | # The neighbouring indices 86 | i1 = int(m.floor(i)) # -- to integer format 87 | i2 = i1 + 1 88 | j1 = int(m.floor(j)) # -- to integer format 89 | j2 = j1 + 1 90 | 91 | # The 1D weights 92 | ti = i - i1 93 | tj = j - j1 94 | 95 | # The 2D weights 96 | w11 = (1 - ti) * (1 - tj) 97 | w12 = (1 - ti) * tj 98 | w21 = ti * (1 - tj) 99 | w22 = ti * tj 100 | 101 | # Only add indices and weights if they fall in the range of the image with 102 | # dimensions NiNj 103 | indicesAndWeights = [] 104 | if (0 <= i1 < Ni) and (0 <= j1 < Nj): 105 | indicesAndWeights.append([i1, j1, w11]) 106 | if (0 <= i1 < Ni) and (0 <= j2 < Nj): 107 | indicesAndWeights.append([i1, j2, w12]) 108 | if (0 <= i2 < Ni) and (0 <= j1 < Nj): 109 | indicesAndWeights.append([i2, j1, w21]) 110 | if (0 <= i2 < Ni) and (0 <= j2 < Nj): 111 | indicesAndWeights.append([i2, j2, w22]) 112 | 113 | return indicesAndWeights 114 | 115 | 116 | def ToLinearIndex(ij, NiNj): 117 | """ Returns the linear (row-major) index of a flattened 2D image that has dimensions 118 | [Ni,Nj] before flattening. 119 | 120 | INPUT: 121 | - ij, a list of length 2 containing the i and j coordinate 122 | as [i,j] 123 | - NiNj, a list of length 2 containing the dimensions of the 2D 124 | domain as [Ni,Nj] 125 | 126 | OUTPUT: 127 | - ijFlat = ij[0] * NiNj[1] + ij[1] 128 | """ 129 | 130 | return ij[0] * NiNj[1] + ij[1] 131 | 132 | 133 | def RotationOperatorMatrix(NiNj, theta, diskMask=True): 134 | """ Returns the matrix that rotates a square image by R.f, where f is the 135 | flattened image (a vector of length Ni*Nj). 136 | The resulting vector needs to be repartitioned into a [Ni,Nj] sized image later. 
137 | 138 | INPUT: 139 | - NiNj, a list of length 2 containing the dimensions of the 2D 140 | domain as [Ni,Nj] 141 | - theta, a real number specifying the rotation angle 142 | 143 | INPUT (optional): 144 | - diskMask = True, by default values outside a circular mask are set 145 | to zero. 146 | 147 | OUTPUT: 148 | - rotationMatrix, a np.array of dimensions [Ni*Nj,Ni*Nj] 149 | """ 150 | 151 | # Image size 152 | Ni = NiNj[0] 153 | Nj = NiNj[1] 154 | cij = m.floor(Ni / 2) # center 155 | 156 | # Fill the rotation operator matrix 157 | rotationMatrix = np.zeros([Ni * Nj, Ni * Nj]) 158 | for i in range(0, Ni): 159 | for j in range(0, Nj): 160 | # Apply a circular mask (disk matrix) if desired 161 | if not diskMask or ((i - cij) * (i - cij) + (j - cij) * (j - cij) <= (cij + 0.5) * (cij + 0.5)): 162 | # The row index of the operator matrix 163 | linij = ToLinearIndex([i, j], NiNj) 164 | # The interpolation points 165 | ijOld = CoordRotationInv([i, j], NiNj, theta) 166 | # The indices used for interpolation and their weights 167 | linIntIndicesAndWeights = LinIntIndicesAndWeights(ijOld, NiNj) 168 | # Fill the weights in the rotationMatrix 169 | for indexAndWeight in linIntIndicesAndWeights: 170 | indexOld = [indexAndWeight[0], indexAndWeight[1]] 171 | linIndexOld = ToLinearIndex(indexOld, NiNj) 172 | weight = indexAndWeight[2] 173 | rotationMatrix[linij, linIndexOld] = weight 174 | return rotationMatrix 175 | 176 | 177 | def MultiRotationOperatorMatrix(NiNj, Ntheta, periodicity=2 * np.pi, diskMask=True): 178 | """ Concatenates multiple operator matrices along the first dimension for a 179 | direct multi-orientation transformation. 180 | I.e., this function returns the matrix that rotates a square image over several angles via R.f, 181 | where f is the flattened image (a vector of length Ni*Nj). 182 | The dimensions of R are [Ntheta*Ni*Nj, Ni*Nj], with Ntheta the number of orientations 183 | sampled from 0 to "periodicity". 
184 | The resulting vector needs to be repartitioned into a [Ntheta,Ni,Nj] stack of rotated images later. 185 | 186 | INPUT: 187 | - NiNj, a list of length 2 containing the dimensions of the 2D domain as [Ni,Nj] 188 | - Ntheta, an integer specifying the number of rotations 189 | 190 | INPUT (optional): 191 | - periodicity = 2*np.pi, by default rotations from 0 to 2 pi are considered. 192 | - diskMask = True, by default values outside a circular mask are set to zero. 193 | 194 | OUTPUT: 195 | - rotationMatrix, a np.array of dimensions [Ntheta*Ni*Nj,Ni*Nj] 196 | """ 197 | matrices = [None] * Ntheta 198 | for r in range(Ntheta): 199 | matrices[r] = RotationOperatorMatrix( 200 | NiNj, 201 | periodicity * r / Ntheta, 202 | diskMask=diskMask) 203 | return np.concatenate(matrices, axis=0) 204 | 205 | 206 | def RotationOperatorMatrixSparse(NiNj, theta, diskMask=True, linIndOffset=0): 207 | """ Returns the idx and vals, where idx is a tuple of 2D indices (also as tuples) and vals the corresponding values. 208 | The indices and weights can be converted to a sparse tensorflow matrix via 209 | R = tf.SparseTensor(idx,vals,[Ni*Nj,Ni*Nj]) 210 | The resulting matrix rotates a square image by R.f, where f is the flattened image (a vector of length Ni*Nj). 211 | The resulting vector needs to be repartitioned into a [Ni,Nj] sized image later. 212 | 213 | INPUT: 214 | - NiNj, a list of length 2 containing the dimensions of the 2D 215 | domain as [Ni,Nj] 216 | - theta, a real number specifying the rotation angle 217 | 218 | INPUT (optional): 219 | - diskMask = True, by default values outside a circular mask are set 220 | to zero. 
221 | 222 | OUTPUT: 223 | - idx, a tuple containing the non-zero indices (tuples of length 2) 224 | - vals, the corresponding values at these indices 225 | """ 226 | 227 | # Image size 228 | Ni = NiNj[0] 229 | Nj = NiNj[1] 230 | cij = m.floor(Ni / 2) # center 231 | 232 | # Collect the non-zero entries of the rotation operator matrix 233 | # (sparse equivalent of rotationMatrix = np.zeros([Ni * Nj, Ni * Nj])) 234 | idx = [] # This will contain a list of index tuples 235 | vals = [] # This will contain the corresponding weights 236 | for i in range(0, Ni): 237 | for j in range(0, Nj): 238 | # Apply a circular mask (disk matrix) if desired 239 | if not diskMask or ((i - cij) * (i - cij) + (j - cij) * (j - cij) <= (cij + 0.5) * (cij + 0.5)): 240 | # The row index of the operator matrix 241 | linij = ToLinearIndex([i, j], NiNj) 242 | # The interpolation points 243 | ijOld = CoordRotationInv([i, j], NiNj, theta) 244 | # The indices used for interpolation and their weights 245 | linIntIndicesAndWeights = LinIntIndicesAndWeights(ijOld, NiNj) 246 | indicesAndWeights = linIntIndicesAndWeights 247 | # Store the index pair and weight of each non-zero entry 248 | for indexAndWeight in linIntIndicesAndWeights: 249 | indexOld = [indexAndWeight[0], indexAndWeight[1]] 250 | linIndexOld = ToLinearIndex(indexOld, NiNj) 251 | weight = indexAndWeight[2] 252 | idx.append((linij + linIndOffset, linIndexOld)) 253 | vals.append(weight) 254 | 255 | # Return the indices and weights as tuples 256 | return tuple(idx), tuple(vals) 257 | 258 | 259 | def MultiRotationOperatorMatrixSparse(NiNj, Ntheta, periodicity=2 * np.pi, diskMask=True): 260 | """ Returns the idx and vals, where idx is a tuple of 2D indices (also as tuples) and vals the corresponding values. 261 | The indices and weights can be converted to a sparse tensorflow matrix via 262 | R = tf.SparseTensor(idx,vals,[Ntheta*Ni*Nj,Ni*Nj]). 263 | This matrix rotates a square image over several angles via R.f, 264 | where f is the flattened image (a vector of length Ni*Nj). 
265 | The dimensions of R are [Ntheta*Ni*Nj,Ni*Nj], with Ntheta the number of orientations 266 | sampled from 0 to "periodicity". 267 | The resulting vector needs to be repartitioned into a [Ntheta,Ni,Nj] stack of rotated images later. 268 | 269 | INPUT: 270 | - NiNj, a list of length 2 containing the dimensions of the 2D 271 | domain as [Ni,Nj] 272 | - Ntheta, an integer specifying the number of rotations 273 | 274 | INPUT (optional): 275 | - periodicity = 2*np.pi, by default rotations from 0 to 2 pi are 276 | considered. 277 | - diskMask = True, by default values outside a circular mask are set 278 | to zero. 279 | 280 | OUTPUT: 281 | - idx, a tuple containing the non-zero indices (tuples of length 2) 282 | - vals, the corresponding values at these indices 283 | """ 284 | idx = () 285 | vals = () 286 | for r in range(Ntheta): 287 | idxr, valsr = RotationOperatorMatrixSparse( 288 | NiNj, periodicity * r / Ntheta, 289 | linIndOffset=r * NiNj[0] * NiNj[1], 290 | diskMask=diskMask) 291 | idx = idx + idxr 292 | vals = vals + valsr 293 | return idx, vals 294 | -------------------------------------------------------------------------------- /src/opt/augs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Augmentation pipeline 3 | """ 4 | 5 | from tensorpack import imgaug 6 | import cv2 7 | from loader.custom_augs import ( 8 | BinarizeLabel, 9 | GaussianBlur, 10 | MedianBlur, 11 | GenInstanceContourMap, 12 | GenInstanceMarkerMap, 13 | ) 14 | 15 | # refer to https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html for 16 | # information on how to modify the augmentation parameters 17 | 18 | 19 | def get_train_augmentors(self, input_shape, output_shape, view=False): 20 | print(input_shape, output_shape) 21 | if self.model_mode == "class_rotmnist": 22 | shape_augs = [ 23 | imgaug.Affine( 24 | rotate_max_deg=359, interp=cv2.INTER_NEAREST, border=cv2.BORDER_CONSTANT 25 | ), 26 | ] 27 | 28 | input_augs = [] 29 | 30 | else: 31 | 
shape_augs = [ 32 | imgaug.Affine( 33 | rotate_max_deg=359, 34 | translate_frac=(0.01, 0.01), 35 | interp=cv2.INTER_NEAREST, 36 | border=cv2.BORDER_REFLECT, 37 | ), 38 | imgaug.Flip(vert=True), 39 | imgaug.Flip(horiz=True), 40 | imgaug.CenterCrop(input_shape), 41 | ] 42 | 43 | input_augs = [ 44 | imgaug.RandomApplyAug( 45 | imgaug.RandomChooseAug( 46 | [GaussianBlur(), MedianBlur(), imgaug.GaussianNoise(),] 47 | ), 48 | 0.5, 49 | ), 50 | # Standard colour augmentation 51 | imgaug.RandomOrderAug( 52 | [ 53 | imgaug.Hue((-8, 8), rgb=True), 54 | imgaug.Saturation(0.2, rgb=True), 55 | imgaug.Brightness(26, clip=True), 56 | imgaug.Contrast((0.75, 1.25), clip=True), 57 | ] 58 | ), 59 | imgaug.ToUint8(), 60 | ] 61 | 62 | if self.model_mode == "seg_gland": 63 | label_augs = [] 64 | label_augs = [GenInstanceContourMap(mode=self.model_mode)] 65 | label_augs.append(BinarizeLabel()) 66 | 67 | if not view: 68 | label_augs.append(imgaug.CenterCrop(output_shape)) 69 | else: 70 | label_augs.append(imgaug.CenterCrop(input_shape)) 71 | 72 | return shape_augs, input_augs, label_augs 73 | elif self.model_mode == "seg_nuc": 74 | label_augs = [] 75 | label_augs = [GenInstanceMarkerMap()] 76 | label_augs.append(BinarizeLabel()) 77 | if not view: 78 | label_augs.append(imgaug.CenterCrop(output_shape)) 79 | else: 80 | label_augs.append(imgaug.CenterCrop(input_shape)) 81 | 82 | return shape_augs, input_augs, label_augs 83 | 84 | else: 85 | return shape_augs, input_augs 86 | 87 | 88 | def get_valid_augmentors(self, input_shape, output_shape, view=False): 89 | print(input_shape, output_shape) 90 | shape_augs = [ 91 | imgaug.CenterCrop(input_shape), 92 | ] 93 | input_augs = [] 94 | 95 | if self.model_mode == "seg_gland": 96 | label_augs = [] 97 | label_augs = [GenInstanceContourMap(mode=self.model_mode)] 98 | label_augs.append(BinarizeLabel()) 99 | 100 | if not view: 101 | label_augs.append(imgaug.CenterCrop(output_shape)) 102 | else: 103 | label_augs.append(imgaug.CenterCrop(input_shape)) 104 
| 105 | return shape_augs, input_augs, label_augs 106 | elif self.model_mode == "seg_nuc": 107 | label_augs = [] 108 | label_augs = [GenInstanceMarkerMap()] 109 | label_augs.append(BinarizeLabel()) 110 | if not view: 111 | label_augs.append(imgaug.CenterCrop(output_shape)) 112 | else: 113 | label_augs.append(imgaug.CenterCrop(input_shape)) 114 | 115 | return shape_augs, input_augs, label_augs 116 | else: 117 | return shape_augs, input_augs 118 | -------------------------------------------------------------------------------- /src/opt/params.py: -------------------------------------------------------------------------------- 1 | """ 2 | Model hyperparameters 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | class_pcam = { 8 | "train_input_shape": [96, 96], 9 | "train_output_shape": [1, 1], 10 | "infer_input_shape": [96, 96], 11 | "infer_output_shape": [1, 1], 12 | "input_chans": 3, 13 | "label_names": ["Non-Tumour", "Tumour"], 14 | "filter_sizes": [5, 7, 9], 15 | "input_norm": True, 16 | "training_phase": [ 17 | { 18 | "nr_epochs": 50, 19 | "manual_parameters": { 20 | # tuple(initial value, schedule) 21 | "learning_rate": ( 22 | 5.0e-5, 23 | [ 24 | ("15", 1.0e-5), 25 | ("25", 1.0e-5), 26 | ("35", 1.0e-5), 27 | ("40", 2.0e-5), 28 | ("45", 1.0e-5), 29 | ], 30 | ), 31 | }, 32 | "pretrained_path": None, # randomly initialise weights 33 | "train_batch_size": 32, 34 | "infer_batch_size": 64, 35 | "model_flags": {"freeze": False}, 36 | } 37 | ], 38 | "loss_term": {"bce": 1}, 39 | "optimizer": tf.train.AdamOptimizer, 40 | "inf_auto_metric": "valid_auc", 41 | "inf_auto_comparator": ">", 42 | "inf_batch_size": 64, 43 | } 44 | seg_gland = { 45 | "train_input_shape": [448, 448], 46 | "train_output_shape": [448, 448], 47 | "infer_input_shape": [448, 448], 48 | "infer_output_shape": [112, 112], 49 | "input_chans": 3, 50 | "filter_sizes": [5, 7, 11], 51 | "input_norm": True, 52 | "training_phase": [ 53 | { 54 | "nr_epochs": 70, 55 | "manual_parameters": { 56 | # tuple(initial value, 
schedule) 57 | "learning_rate": (1.0e-3, [("15", 1.0e-4), ("50", 5.0e-5)]), 58 | }, 59 | "pretrained_path": None, # randomly initialise weights 60 | "train_batch_size": 6, 61 | "infer_batch_size": 12, 62 | "model_flags": {"freeze": False}, 63 | } 64 | ], 65 | "loss_term": {"bce": 1}, 66 | "optimizer": tf.train.AdamOptimizer, 67 | "inf_auto_metric": "valid_dice", 68 | "inf_auto_comparator": ">", 69 | "inf_batch_size": 12, 70 | } 71 | seg_nuc = { 72 | "train_input_shape": [256, 256], 73 | "train_output_shape": [256, 256], 74 | "infer_input_shape": [256, 256], 75 | "infer_output_shape": [112, 112], 76 | "input_chans": 3, 77 | "filter_sizes": [5, 7, 11], 78 | "input_norm": True, 79 | "training_phase": [ 80 | { 81 | "nr_epochs": 70, 82 | "manual_parameters": { 83 | # tuple(initial value, schedule) 84 | "learning_rate": (1.0e-3, [("15", 1.0e-4), ("30", 5.0e-5)]), 85 | }, 86 | "pretrained_path": None, # randomly initialise weights 87 | "train_batch_size": 6, 88 | "infer_batch_size": 12, 89 | "model_flags": {"freeze": False}, 90 | } 91 | ], 92 | "loss_term": {"bce": 1, "dice": 1}, 93 | "optimizer": tf.train.AdamOptimizer, 94 | "inf_auto_metric": "valid_dice", 95 | "inf_auto_comparator": ">", 96 | "inf_batch_size": 12, 97 | } 98 | -------------------------------------------------------------------------------- /src/process.py: -------------------------------------------------------------------------------- 1 | """ 2 | Post-processing 3 | """ 4 | 5 | import glob 6 | import os 7 | import time 8 | import cv2 9 | import numpy as np 10 | from scipy.ndimage import measurements 11 | from scipy.ndimage.morphology import binary_fill_holes 12 | from skimage.morphology import remove_small_objects, watershed 13 | 14 | from config import Config 15 | 16 | from misc.viz_utils import visualize_instances 17 | from misc.utils import remap_label 18 | 19 | 20 | def process_utils(pred_map, mode): 21 | """ 22 | Performs post processing for a given image 23 | 24 | Args: 25 | pred_map: output of 
CNN 26 | mode: choose either 'seg_gland' or 'seg_nuc' 27 | """ 28 | 29 | if mode == "seg_gland": 30 | pred = np.squeeze(pred_map) 31 | 32 | blb = pred[..., 0] 33 | blb = np.squeeze(blb) 34 | cnt = pred[..., 1] 35 | cnt = np.squeeze(cnt) 36 | cnt[cnt > 0.5] = 1 37 | cnt[cnt <= 0.5] = 0 38 | 39 | pred = blb - cnt 40 | pred[pred > 0.55] = 1 41 | pred[pred <= 0.55] = 0 42 | k_disk1 = np.array( 43 | [ 44 | [0, 0, 1, 0, 0], 45 | [0, 1, 1, 1, 0], 46 | [1, 1, 1, 1, 1], 47 | [0, 1, 1, 1, 0], 48 | [0, 0, 1, 0, 0], 49 | ], 50 | np.uint8, 51 | ) 52 | # ! refactor these 53 | pred = binary_fill_holes(pred) 54 | pred = pred.astype("uint16") 55 | pred = cv2.morphologyEx(pred, cv2.MORPH_OPEN, k_disk1) 56 | pred = measurements.label(pred)[0] 57 | pred = remove_small_objects(pred, min_size=1500) 58 | 59 | k_disk2 = np.array( 60 | [ 61 | [0, 0, 0, 0, 1, 0, 0, 0, 0], 62 | [0, 0, 0, 1, 1, 1, 0, 0, 0], 63 | [0, 0, 1, 1, 1, 1, 1, 0, 0], 64 | [0, 1, 1, 1, 1, 1, 1, 1, 0], 65 | [1, 1, 1, 1, 1, 1, 1, 1, 1], 66 | [0, 1, 1, 1, 1, 1, 1, 1, 0], 67 | [0, 0, 1, 1, 1, 1, 1, 0, 0], 68 | [0, 0, 0, 1, 1, 1, 0, 0, 0], 69 | [0, 0, 0, 0, 1, 0, 0, 0, 0], 70 | ], 71 | np.uint8, 72 | ) 73 | 74 | pred = pred.astype("uint16") 75 | proced_pred = cv2.dilate(pred, k_disk2, iterations=1) 76 | elif mode == "seg_nuc": 77 | blb_raw = pred_map[..., 0] 78 | blb_raw = np.squeeze(blb_raw) 79 | blb = blb_raw.copy() 80 | blb[blb > 0.5] = 1 81 | blb[blb <= 0.5] = 0 82 | blb = measurements.label(blb)[0] 83 | blb = remove_small_objects(blb, min_size=10) 84 | blb[blb > 0] = 1 85 | 86 | mrk_raw = pred_map[..., 1] 87 | mrk_raw = np.squeeze(mrk_raw) 88 | cnt_raw = pred_map[..., 2] 89 | cnt_raw = np.squeeze(cnt_raw) 90 | cnt = cnt_raw.copy() 91 | cnt[cnt >= 0.4] = 1 92 | cnt[cnt < 0.4] = 0 93 | mrk = mrk_raw - cnt 94 | mrk = mrk * blb 95 | mrk[mrk > 0.75] = 1 96 | mrk[mrk <= 0.75] = 0 97 | 98 | marker = mrk.copy() 99 | marker = binary_fill_holes(marker) 100 | marker = measurements.label(marker)[0] 101 | marker = 
remove_small_objects(marker, min_size=10) 102 | proced_pred = watershed(-mrk_raw, marker, mask=blb) 103 | 104 | return proced_pred 105 | 106 | 107 | def process(): 108 | """ 109 | Performs post processing for a list of images 110 | 111 | """ 112 | 113 | cfg = Config() 114 | 115 | for data_dir in cfg.inf_data_list: 116 | 117 | proc_dir = cfg.inf_output_dir + "/processed/" 118 | pred_dir = cfg.inf_output_dir + "/raw/" 119 | file_list = glob.glob(pred_dir + "*.npy") 120 | file_list.sort() # ensure same order 121 | 122 | if not os.path.isdir(proc_dir): 123 | os.makedirs(proc_dir) 124 | for filename in file_list: 125 | start = time.time() 126 | filename = os.path.basename(filename) 127 | basename = filename.split(".")[0] 128 | 129 | test_set = basename.split("_")[0] 130 | test_set = test_set[-1] 131 | 132 | print(pred_dir, basename, end=" ", flush=True) 133 | 134 | ## 135 | img = cv2.imread(data_dir + basename + cfg.inf_imgs_ext) 136 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 137 | 138 | pred_map = np.load(pred_dir + "/%s.npy" % basename) 139 | 140 | # get the instance level prediction 141 | pred_inst = process_utils(pred_map, cfg.model_mode) 142 | 143 | # ! remap label is slow - check to see whether it is needed! 144 | pred_inst = remap_label(pred_inst, by_size=True) 145 | 146 | overlaid_output = visualize_instances(pred_inst, img) 147 | overlaid_output = cv2.cvtColor(overlaid_output, cv2.COLOR_BGR2RGB) 148 | cv2.imwrite("%s/%s.png" % (proc_dir, basename), overlaid_output) 149 | 150 | # save segmentation mask 151 | np.save("%s/%s" % (proc_dir, basename), pred_inst) 152 | 153 | end = time.time() 154 | diff = str(round(end - start, 2)) 155 | print("FINISH. TIME: %s" % diff) 156 | 157 | 158 | if __name__ == "__main__": 159 | process() 160 | 161 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """train.py 2 | 3 | Main training script. 
4 | 5 | Usage: 6 | train.py [--gpu=] [--view=] 7 | train.py (-h | --help) 8 | train.py --version 9 | 10 | Options: 11 | -h --help Show this string. 12 | --version Show version. 13 | --gpu= Comma separated GPU list. 14 | --view= View dataset- use either 'train' or 'valid'. 15 | """ 16 | 17 | from docopt import docopt 18 | import argparse 19 | import json 20 | import os 21 | 22 | import numpy as np 23 | import tensorflow as tf 24 | from tensorpack import Inferencer, logger 25 | from tensorpack.callbacks import ( 26 | DataParallelInferenceRunner, 27 | ModelSaver, 28 | MinSaver, 29 | MaxSaver, 30 | ScheduledHyperParamSetter, 31 | ) 32 | from tensorpack.tfutils import SaverRestore, get_model_loader 33 | from tensorpack.train import ( 34 | SyncMultiGPUTrainerParameterServer, 35 | TrainConfig, 36 | launch_train_with_config, 37 | ) 38 | 39 | import loader.loader as loader 40 | from config import Config 41 | from misc.utils import get_files 42 | 43 | from sklearn.metrics import roc_auc_score 44 | 45 | 46 | class StatCollector(Inferencer, Config): 47 | """ 48 | Accumulate output of inference during training. 49 | After the inference finishes, calculate the statistics 50 | """ 51 | 52 | def __init__(self, prefix="valid"): 53 | super(StatCollector, self).__init__() 54 | self.prefix = prefix 55 | 56 | def _get_fetches(self): 57 | return self.train_inf_output_tensor_names 58 | 59 | def _before_inference(self): 60 | self.true_list = [] 61 | self.pred_list = [] 62 | 63 | def _on_fetches(self, outputs): 64 | pred, true = outputs 65 | self.true_list.extend(true) 66 | self.pred_list.extend(pred) 67 | 68 | def _after_inference(self): 69 | # ! 
factor this out 70 | def _dice(true, pred, label): 71 | true = np.array(true[..., label], np.int32) 72 | pred = np.array(pred[..., label], np.int32) 73 | inter = (pred * true).sum() 74 | total = (pred + true).sum() 75 | return 2 * inter / (total + 1.0e-8) 76 | 77 | stat_dict = {} 78 | pred = np.array(self.pred_list) 79 | true = np.array(self.true_list) 80 | 81 | if self.model_mode == "seg_gland": 82 | # Get the segmentation stats 83 | 84 | pred = pred[..., :2] 85 | true = true[..., :2] 86 | 87 | # Binarize the prediction 88 | pred[pred > 0.5] = 1.0 89 | 90 | stat_dict[self.prefix + "_dice_obj"] = _dice(true, pred, 0) 91 | stat_dict[self.prefix + "_dice_cnt"] = _dice(true, pred, 1) 92 | 93 | elif self.model_mode == "seg_nuc": 94 | # Get the segmentation stats 95 | 96 | pred = pred[..., :3] 97 | true = true[..., :3] 98 | 99 | # Binarize the prediction 100 | pred[pred > 0.5] = 1.0 101 | 102 | stat_dict[self.prefix + "_dice_np"] = _dice(true, pred, 0) 103 | stat_dict[self.prefix + "_dice_mk_blb"] = _dice(true, pred, 1) 104 | stat_dict[self.prefix + "_dice_mk_cnt"] = _dice(true, pred, 2) 105 | 106 | else: 107 | # Get the classification stats 108 | 109 | # Convert vector to scalar prediction 110 | prob = np.squeeze(pred[..., 1]) 111 | pred = np.argmax(pred, -1) 112 | pred = np.squeeze(pred) 113 | true = np.squeeze(true) 114 | 115 | accuracy = (pred == true).sum() / np.size(true) 116 | error = (pred != true).sum() / np.size(true) 117 | 118 | stat_dict[self.prefix + "_acc"] = accuracy * 100 119 | stat_dict[self.prefix + "_error"] = error * 100 120 | 121 | if self.model_mode == "class_pcam": 122 | auc = roc_auc_score(true, prob) 123 | stat_dict[self.prefix + "_auc"] = auc 124 | 125 | return stat_dict 126 | 127 | 128 | ########################################### 129 | 130 | 131 | class Trainer(Config): 132 | def get_datagen(self, batch_size, mode="train", view=False): 133 | if mode == "train": 134 | augmentors = self.get_train_augmentors( 135 | self.train_input_shape, 
self.train_output_shape, view 136 | ) 137 | data_files = get_files(self.train_dir, self.data_ext) 138 | # Different data generators for segmentation and classification 139 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc": 140 | data_generator = loader.train_generator_seg 141 | else: 142 | data_generator = loader.train_generator_class 143 | nr_procs = self.nr_procs_train 144 | else: 145 | augmentors = self.get_valid_augmentors( 146 | self.train_input_shape, self.train_output_shape, view 147 | ) 148 | # Different data generators for segmentation and classification 149 | data_files = get_files(self.valid_dir, self.data_ext) 150 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc": 151 | data_generator = loader.valid_generator_seg 152 | else: 153 | data_generator = loader.valid_generator_class 154 | nr_procs = self.nr_procs_valid 155 | 156 | # set nr_proc=1 for viewing to ensure clean ctrl-z 157 | nr_procs = 1 if view else nr_procs 158 | dataset = loader.DatasetSerial(data_files) 159 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc": 160 | datagen = data_generator( 161 | dataset, 162 | shape_aug=augmentors[0], 163 | input_aug=augmentors[1], 164 | label_aug=augmentors[2], 165 | batch_size=batch_size, 166 | nr_procs=nr_procs, 167 | ) 168 | else: 169 | datagen = data_generator( 170 | dataset, 171 | shape_aug=augmentors[0], 172 | input_aug=augmentors[1], 173 | batch_size=batch_size, 174 | nr_procs=nr_procs, 175 | ) 176 | 177 | return datagen 178 | 179 | def view_dataset(self, mode="train"): 180 | assert mode == "train" or mode == "valid", "Invalid view mode" 181 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc": 182 | datagen = self.get_datagen(4, mode=mode, view=True) 183 | loader.visualize(datagen, 4) 184 | else: 185 | # visualise more for classification- don't need to show label 186 | datagen = self.get_datagen(8, mode=mode, view=True) 187 | loader.visualize(datagen, 8) 188 | return 189 | 190 | def 
run_once(self, opt, sess_init=None, save_dir=None): 191 | 192 | train_datagen = self.get_datagen(opt["train_batch_size"], mode="train") 193 | valid_datagen = self.get_datagen(opt["infer_batch_size"], mode="valid") 194 | 195 | ###### must be called before ModelSaver 196 | if save_dir is None: 197 | logger.set_logger_dir(self.save_dir) 198 | else: 199 | logger.set_logger_dir(save_dir) 200 | 201 | ###### 202 | model_flags = opt["model_flags"] 203 | model = self.get_model()(**model_flags) 204 | ###### 205 | callbacks = [ 206 | ModelSaver(max_to_keep=1, keep_checkpoint_every_n_hours=None), 207 | ] 208 | 209 | for param_name, param_info in opt["manual_parameters"].items(): 210 | model.add_manual_variable(param_name, param_info[0]) 211 | callbacks.append(ScheduledHyperParamSetter(param_name, param_info[1])) 212 | # multi-GPU inference (with mandatory queue prefetch) 213 | infs = [StatCollector()] 214 | callbacks.append( 215 | DataParallelInferenceRunner(valid_datagen, infs, list(range(nr_gpus))) 216 | ) 217 | if self.model_mode == "seg_gland": 218 | callbacks.append(MaxSaver("valid_dice_obj")) 219 | elif self.model_mode == "seg_nuc": 220 | callbacks.append(MaxSaver("valid_dice_np")) 221 | else: 222 | callbacks.append(MaxSaver("valid_auc")) 223 | 224 | steps_per_epoch = train_datagen.size() // nr_gpus 225 | 226 | config = TrainConfig( 227 | model=model, 228 | callbacks=callbacks, 229 | dataflow=train_datagen, 230 | steps_per_epoch=steps_per_epoch, 231 | max_epoch=opt["nr_epochs"], 232 | ) 233 | config.session_init = sess_init 234 | 235 | launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpus)) 236 | tf.reset_default_graph() # remove the entire graph in case of multiple runs 237 | return 238 | 239 | def run(self): 240 | def get_last_chkpt_path(prev_phase_dir): 241 | stat_file_path = prev_phase_dir + "/stats.json" 242 | with open(stat_file_path) as stat_file: 243 | info = json.load(stat_file) 244 | chkpt_list = [epoch_stat["global_step"] for epoch_stat 
in info] 245 | last_chkpts_path = "%s/model-%d.index" % (prev_phase_dir, max(chkpt_list)) 246 | return last_chkpts_path 247 | 248 | phase_opts = self.training_phase 249 | 250 | if len(phase_opts) > 1: 251 | for idx, opt in enumerate(phase_opts): 252 | 253 | log_dir = "%s/%02d" % (self.save_dir, idx) 254 | if opt["pretrained_path"] == -1: 255 | pretrained_path = get_last_chkpt_path(prev_log_dir) 256 | init_weights = SaverRestore( 257 | pretrained_path, ignore=["learning_rate"] 258 | ) 259 | elif opt["pretrained_path"] is not None: 260 | init_weights = get_model_loader(opt["pretrained_path"]) 261 | else: 262 | init_weights = None # randomly initialise weights 263 | self.run_once(opt, sess_init=init_weights, save_dir=log_dir + "/") 264 | prev_log_dir = log_dir 265 | else: 266 | 267 | opt = phase_opts[0] 268 | if "pretrained_path" in opt: 269 | if opt["pretrained_path"] is None: 270 | init_weights = None 271 | elif opt["pretrained_path"] == -1: 272 | log_dir_prev = "%s" % self.save_dir 273 | pretrained_path = get_last_chkpt_path(log_dir_prev) 274 | init_weights = SaverRestore( 275 | pretrained_path, ignore=["learning_rate"] 276 | ) 277 | else: 278 | init_weights = get_model_loader(opt["pretrained_path"]) 279 | self.run_once(opt, sess_init=init_weights, save_dir=self.save_dir) 280 | 281 | return 282 | 283 | 284 | ########################################################################### 285 | 286 | 287 | if __name__ == "__main__": 288 | 289 | args = docopt(__doc__) 290 | print(args) 291 | 292 | trainer = Trainer() 293 | 294 | if args["--view"] and args["--gpu"]: 295 | raise Exception("Supply only one of --view and --gpu.") 296 | 297 | if args["--view"]: 298 | if args["--view"] != "train" and args["--view"] != "valid": 299 | raise Exception('Use "train" or "valid" for --view.') 300 | trainer.view_dataset(args["--view"]) 301 | else: 302 | os.environ["CUDA_VISIBLE_DEVICES"] = args["--gpu"] 303 | nr_gpus = len(args["--gpu"].split(",")) 304 | trainer.run() 305 | --------------------------------------------------------------------------------
/src/viz_filters.py: -------------------------------------------------------------------------------- 1 | """viz_filters.py 2 | 3 | Visualise basis filters of the form (in polar coordinates): 4 | 5 | R_alpha(r)e^{i*alpha*phi} 6 | 7 | Here, R_alpha is a Gaussian centred at beta. 8 | 9 | Usage: 10 | viz_filters.py [--ksize=] 11 | viz_filters.py (-h | --help) 12 | viz_filters.py --version 13 | 14 | Options: 15 | -h --help Show this string. 16 | --version Show version. 17 | --ksize= Kernel size to display. [default: 7] 18 | """ 19 | 20 | 21 | from docopt import docopt 22 | import numpy as np 23 | import matplotlib.pyplot as plt 24 | import matplotlib.cm as cm 25 | import math 26 | 27 | 28 | def get_filter_info(k_size): 29 | """ 30 | Get the filter parameters for a given kernel size 31 | 32 | Args: 33 | k_size (int): input kernel size 34 | 35 | Returns: 36 | alpha_list: list of alpha values 37 | beta_list: list of beta values 38 | bl_list: used to bandlimit high frequency filters in get_basis_filters() 39 | """ 40 | 41 | if k_size == 5: 42 | alpha_list = [0, 1, 2] 43 | beta_list = [0, 1, 2] 44 | bl_list = [0, 2, 2] 45 | elif k_size == 7: 46 | alpha_list = [0, 1, 2, 3] 47 | beta_list = [0, 1, 2, 3] 48 | bl_list = [0, 2, 3, 2] 49 | elif k_size == 9: 50 | alpha_list = [0, 1, 2, 3, 4] 51 | beta_list = [0, 1, 2, 3, 4] 52 | bl_list = [0, 3, 4, 4, 3] 53 | elif k_size == 11: 54 | alpha_list = [0, 1, 2, 3, 4] 55 | beta_list = [1, 2, 3, 4] 56 | bl_list = [0, 3, 4, 4, 3] 57 | 58 | return alpha_list, beta_list, bl_list 59 | 60 | 61 | def get_basis_filters(alpha_list, beta_list, bl_list, k_size, eps=10 ** -8): 62 | """ 63 | Gets the atomic basis filters 64 | 65 | Args: 66 | alpha_list: list of alpha values for basis filters 67 | beta_list: list of beta values for the basis filters 68 | bl_list: bandlimit list to reduce aliasing of basis filters 69 | k_size (int): kernel size of basis filters 70 | eps=10**-8: epsilon used to prevent division by 0 71 | 72 | Returns: 73 | 
filter_list_bl: list of filters, with bandlimiting (bl) to reduce aliasing 74 | alpha_list_bl: corresponding list of alpha used in bandlimited filters 75 | beta_list_bl: corresponding list of beta used in bandlimited filters 76 | """ 77 | 78 | filter_list_bl = [] 79 | alpha_list_bl = [] 80 | beta_list_bl = [] 81 | for alpha in alpha_list: 82 | for beta in beta_list: 83 | if np.abs(alpha) <= bl_list[beta]: 84 | his = k_size // 2 # half image size 85 | y_index, x_index = np.mgrid[-his : (his + 1), -his : (his + 1)] 86 | y_index *= -1 87 | z_index = x_index + 1j * y_index 88 | 89 | # convert z to natural coordinates and add eps to avoid division by zero 90 | z = z_index + eps 91 | r = np.abs(z) 92 | 93 | if beta == beta_list[-1]: 94 | sigma = 0.6 95 | else: 96 | sigma = 0.6 97 | rad_prof = np.exp(-((r - beta) ** 2) / (2 * (sigma ** 2))) 98 | c_image = rad_prof * (z / r) ** alpha 99 | 100 | # add filter to list 101 | filter_list_bl.append(c_image) 102 | # add frequency of filter to list (needed for phase manipulation) 103 | alpha_list_bl.append(alpha) 104 | beta_list_bl.append(beta) 105 | 106 | return filter_list_bl, alpha_list_bl, beta_list_bl 107 | 108 | 109 | def plot_filters(filter_list, alpha_list, beta_list): 110 | """ 111 | Plot the real and imaginary parts of the basis filters. 
112 | 113 | Args: 114 | filter_list: list of basis filters 115 | alpha_list: alpha of each basis filter 116 | beta_list: beta of each basis filter 117 | """ 118 | 119 | count = 1 120 | plt.figure(figsize=(25, 11)) 121 | nr_filts = 2 * len(filter_list) 122 | len_x = math.ceil(np.sqrt(nr_filts)) 123 | len_y = math.ceil(np.sqrt(nr_filts)) 124 | len_x = 5 125 | len_y = 10 126 | 127 | for i in range(len(filter_list)): 128 | filt_real = filter_list[i].real 129 | filt_imag = filter_list[i].imag 130 | plt.subplot(len_x, len_y, count) 131 | plt.imshow(filt_real, vmin=-1, vmax=1, cmap=cm.gist_gray) 132 | plt.axis("off") 133 | plt.title( 134 | "Real: $alpha$= %s, $beta$= %s" % (alpha_list[i], beta_list[i]), fontsize=6 135 | ) 136 | plt.subplot(len_x, len_y, count + 1) 137 | plt.imshow(filt_imag, vmin=-1, vmax=1, cmap=cm.gist_gray) 138 | plt.axis("off") 139 | plt.title( 140 | "Imaginary: $alpha$= %s, $beta$= %s" % (alpha_list[i], beta_list[i]), 141 | fontsize=6, 142 | ) 143 | count += 2 144 | plt.tight_layout() 145 | plt.show() 146 | 147 | 148 | ##### 149 | if __name__ == "__main__": 150 | args = docopt(__doc__) 151 | 152 | ksize = int(args["--ksize"]) 153 | 154 | if ksize not in [5, 7, 9, 11]: 155 | raise Exception("Select ksize to be either 5,7,9 or 11") 156 | 157 | info = get_filter_info(ksize) 158 | 159 | filter_list, alpha_list, beta_list = get_basis_filters( 160 | info[0], info[1], info[2], ksize 161 | ) 162 | 163 | plot_filters(filter_list, alpha_list, beta_list) 164 | --------------------------------------------------------------------------------
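Note on the basis filters in viz_filters.py: the form R_alpha(r)e^{i*alpha*phi} means that rotating a filter only changes its phase, which is what makes the basis steerable. The sketch below is not part of the repository; it is a minimal, assumed re-implementation of a single filter (the names `basis_filter` and `is_steerable_under_rot90` are hypothetical) that checks this property for a 90 degree rotation. It uses `np.angle` in place of the `eps` trick and skips the centre pixel, where the angle is undefined.

```python
import numpy as np


def basis_filter(k_size, alpha, beta, sigma=0.6):
    """Sample R(r) * exp(i * alpha * phi) on a k_size x k_size grid.

    R is a Gaussian ring of width sigma centred at radius beta
    (hypothetical helper, mirroring the formula in viz_filters.py).
    """
    his = k_size // 2  # half image size
    y, x = np.mgrid[-his:his + 1, -his:his + 1]
    z = x - 1j * y  # flip y so that it points up, as in viz_filters.py
    r = np.abs(z)
    phi = np.angle(z)  # np.angle(0) is 0, so no eps is needed here
    rad_prof = np.exp(-((r - beta) ** 2) / (2 * sigma ** 2))
    return rad_prof * np.exp(1j * alpha * phi)


def is_steerable_under_rot90(filt, alpha):
    """Check rot90(filt) == exp(-i * alpha * pi / 2) * filt away from the centre."""
    mask = np.ones(filt.shape, dtype=bool)
    c = filt.shape[0] // 2
    mask[c, c] = False  # the angle is undefined at the origin, so skip it
    expected = np.exp(-1j * alpha * np.pi / 2) * filt
    return bool(np.allclose(np.rot90(filt)[mask], expected[mask]))


if __name__ == "__main__":
    for alpha in range(4):
        filt = basis_filter(7, alpha, beta=2)
        print(alpha, is_steerable_under_rot90(filt, alpha))
```

A 90 degree rotation is used because `np.rot90` is exact on the pixel grid; other angles would need interpolation, so the phase relation would only hold approximately.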