├── .gitignore
├── LICENSE
├── README.md
├── feature_maps.gif
├── requirements.txt
└── src
    ├── README.md
    ├── config.py
    ├── extract_patches.py
    ├── infer.py
    ├── loader
    │   ├── custom_augs.py
    │   └── loader.py
    ├── misc
    │   ├── patch_extractor.py
    │   ├── utils.py
    │   └── viz_utils.py
    ├── model
    │   ├── class_pcam
    │   │   └── graph.py
    │   ├── seg_gland
    │   │   └── graph.py
    │   ├── seg_nuc
    │   │   └── graph.py
    │   └── utils
    │       ├── gconv_utils.py
    │       ├── model_utils.py
    │       ├── norm_utils.py
    │       └── rotation_utils.py
    ├── opt
    │   ├── augs.py
    │   └── params.py
    ├── process.py
    ├── train.py
    └── viz_filters.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled/cached Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Temporaries
7 | .ipynb_checkpoints/
8 | *.swp
9 |
10 | # Outputs
11 | *.png
12 | *.svg
13 | *.eps
14 | *.out
15 |
16 | # Model and other data
17 | *.pth
18 | *.npy
19 | .idea
20 | *.DS_Store
21 | .pytest_cache/
22 | *.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Simon Graham
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dense Steerable Filter CNNs for Exploiting Rotational Symmetry in Histology Images
2 |
3 | A densely connected rotation-equivariant CNN for histology image analysis.
4 |
5 | [Link](https://arxiv.org/abs/2004.03037) to the pre-print.
6 |
7 | **NEWS**: Our paper has now been published in IEEE Transactions on Medical Imaging. Find the published article [here](https://ieeexplore.ieee.org/document/9153847).
8 |
9 | ## Getting Started
10 |
11 | Environment instructions:
12 |
13 | ```
14 | conda create --name dsf-cnn python=3.6
15 | conda activate dsf-cnn
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ## Repository Structure
20 |
21 | - `src/` contains executable files used to run the model. Further information on running the code can be found in the corresponding directory.
22 | - `loader/` contains scripts for data loading and self-implemented augmentation functions.
23 | - `misc/` contains utility scripts.
24 | - `model/class_pcam/` contains the model architecture for DSF-CNN on the PCam dataset.
25 | - `model/seg_nuc/` contains the model architecture for DSF-CNN on the Kumar dataset.
26 | - `model/seg_gland/` contains the model architecture for DSF-CNN on the CRAG dataset.
27 | - `model/utils/` contains util scripts for the models.
28 | - `opt/` contains scripts that define the model hyperparameters and augmentation pipeline.
29 | - `config.py` is the configuration file. Paths need to be changed accordingly.
30 | - `train.py` and `infer.py` are the training and inference scripts respectively.
31 | - `process.py` is the post processing script for obtaining the final instances for segmentation.
32 |
33 |
34 |
35 |
36 |
37 | ## Citation
38 |
39 | If any part of this code is used, please give appropriate citation to our paper.
40 |
41 | ```
42 | @article{graham2020dense,
43 | title={Dense Steerable Filter CNNs for Exploiting Rotational Symmetry in Histology Images},
44 | author={Graham, Simon and Epstein, David and Rajpoot, Nasir},
45 | journal={arXiv preprint arXiv:2004.03037},
46 | year={2020}
47 | }
48 | ```
49 |
50 | ## Authors
51 |
52 | See the list of [contributors](https://github.com/simongraham/dsf-cnn/graphs/contributors) who participated in this project.
53 |
54 | ## License
55 |
56 | This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
57 |
--------------------------------------------------------------------------------
/feature_maps.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simongraham/dsf-cnn/69ba9fb4834e86c238ce866d3ddebfc3fe6c5c57/feature_maps.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | docopt==0.6.2
2 | tensorflow-gpu==1.12.0
3 | tensorpack==0.9.0.1
4 | scikit-image==0.14.2
5 | matplotlib==3.0.2
6 | numpy==1.15.4
7 | opencv-python==4.1.2.30
8 |
--------------------------------------------------------------------------------
/src/README.md:
--------------------------------------------------------------------------------
1 | # Training and Inference Instructions
2 |
3 | ## Choosing the network
4 |
5 | The model to use, along with other top-level settings such as the filter type and the number of orientations, is selected in `config.py`. The available models are:
6 | - Classification PCam DSF-CNN: `model/class_pcam/graph.py`
7 | - Segmentation CRAG DSF-CNN: `model/seg_gland/graph.py`
8 | - Segmentation Kumar DSF-CNN: `model/seg_nuc/graph.py`
9 |
10 | ## Modifying Hyperparameters
11 |
12 | To modify hyperparameters, refer to `opt/params.py`
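
`config.py` picks up these values by importing `opt.params`, looking up the dictionary named after `model_mode`, and copying each entry onto the `Config` instance. The following is a minimal sketch of that pattern. It uses a stand-in namespace instead of the real `opt/params.py`, and the class name `DemoConfig` and the values shown are invented for illustration.

```python
from types import SimpleNamespace

# Stand-in for opt/params.py: one dictionary of settings per model mode.
# The keys mirror attribute names used in the code; the values are made up.
params = SimpleNamespace(
    seg_nuc={"filter_sizes": [7, 7], "inf_batch_size": 16},
    class_pcam={"filter_sizes": [5], "inf_batch_size": 64},
)


class DemoConfig(object):
    def __init__(self, model_mode):
        self.model_mode = model_mode
        # Look up the dictionary named after the mode and copy every
        # entry onto the instance as an attribute.
        for variable, value in getattr(params, model_mode).items():
            setattr(self, variable, value)


cfg = DemoConfig("seg_nuc")
print(cfg.filter_sizes, cfg.inf_batch_size)
```

With this layout, adding a hyperparameter only requires a new key in the relevant dictionary; it then becomes available as `cfg.<key>` without further changes to the configuration class.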
13 |
14 | ## Augmentation
15 |
16 | To modify the augmentation pipeline, refer to `get_train_augmentors()` in `opt/augs.py`. Refer to [this webpage](https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html) for information on how to modify the augmentation parameters.
17 |
18 | ## Data Format
19 |
20 | For segmentation, store each patch as a numpy array of shape [H, W, 4], with channels [RGB, inst]. Here, inst is the instance segmentation ground truth, i.e. pixel values range from 0 to N, where 0 is the background and N is the number of instances (e.g. nuclei) in that particular image.
For classification, save each image patch as a 1D array of size [(H * W * C)+1], where H, W and C refer to height, width and number of channels. The final value is the label of the image (starting from 0).
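
As an illustration of the two layouts, the sketch below packs a dummy patch in each format with NumPy. The function names, the 4x4 patch size and the dtypes are assumptions made for the example and are not part of the repository.

```python
import numpy as np


def pack_seg_patch(rgb, inst):
    """Stack an RGB image [H, W, 3] and an instance map [H, W] into [H, W, 4]."""
    inst = inst.astype("int32")[..., None]  # add a trailing channel axis
    return np.concatenate([rgb.astype("int32"), inst], axis=-1)


def pack_class_patch(rgb, label):
    """Flatten an RGB image and append its label: length (H * W * C) + 1."""
    return np.append(rgb.reshape(-1), label)


def unpack_class_patch(flat, shape):
    """Recover (image, label) from the flat classification layout."""
    return flat[:-1].reshape(shape), int(flat[-1])


# Dummy data: a 4x4 RGB patch containing two labelled instances.
rgb = np.zeros((4, 4, 3), dtype="uint8")
inst = np.zeros((4, 4), dtype="int32")
inst[0, 0] = 1
inst[3, 3] = 2

seg_patch = pack_seg_patch(rgb, inst)    # segmentation layout
class_patch = pack_class_patch(rgb, 1)   # classification layout, label 1
```

Each array would then be written with `np.save`, one `.npy` file per patch.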
21 |
22 | ## Training
23 |
24 | To train the network, the command is:
25 |
26 | `python train.py --gpu=<gpu_id>`
27 | where `<gpu_id>` is a comma-separated list denoting which GPU(s) will be used for training. For example, to use GPUs 0 and 1, the command is:
28 | `python train.py --gpu='0,1'`
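
`infer.py` handles its `--gpu` argument by writing the list to `CUDA_VISIBLE_DEVICES` and counting the entries. Assuming `train.py` treats the argument in a similar way (its source is not shown here), the handling can be sketched as follows; `select_gpus` is a hypothetical helper name.

```python
import os


def select_gpus(gpu_arg):
    """Expose only the listed GPUs and return how many were requested."""
    gpu_ids = [g.strip() for g in gpu_arg.split(",") if g.strip()]
    # Frameworks such as TensorFlow only see the devices listed here.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
    return len(gpu_ids)


nr_gpus = select_gpus("0,1")
print(nr_gpus, os.environ["CUDA_VISIBLE_DEVICES"])
```

The variable must be set before the deep learning framework initialises its devices, which is why it is handled at the very start of the script.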
29 |
30 | Before training, set in `config.py`:
31 | - path to the data directories
32 | - path where checkpoints will be saved
33 | - path to where the output will be saved
34 |
35 | ## Inference
36 |
37 | To generate the network predictions, the command is:
38 | `python infer.py --gpu=<gpu_id> --mode=<mode>`
39 | Currently, the inference code only supports 1 GPU. For `<mode>`, use `'seg'` or `'class'`. Use `'class'` when processing PCam and `'seg'` when processing Kumar and CRAG.
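
The `--mode` value has to agree with the `model_mode` chosen in `config.py`. A small lookup makes the pairing explicit; the dictionary and the helper `infer_command` are illustrative only and do not exist in the repository.

```python
# Pairing taken from the text above: PCam is classification,
# Kumar (seg_nuc) and CRAG (seg_gland) are segmentation.
INFER_MODE = {
    "class_pcam": "class",
    "seg_nuc": "seg",
    "seg_gland": "seg",
}


def infer_command(model_mode, gpu_id="0"):
    """Build the infer.py command line for a given `model_mode`."""
    return "python infer.py --gpu=%s --mode=%s" % (gpu_id, INFER_MODE[model_mode])


command = infer_command("seg_nuc")
print(command)
```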
40 |
41 | Before running inference, set in `config.py`:
42 | - path where the output will be saved
43 | - path to data root directories
44 | - path to model checkpoint.
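
`config.py` also exposes an `inf_auto_find_chkpt` flag. When it is enabled, `infer.py` reads `stats.json` from the checkpoint directory and picks the epoch with the best value of a chosen metric. The sketch below shows the core of that idea in isolation. The metric name `valid_dice` and the statistics are made up, and the check that the checkpoint file still exists on disk is left out.

```python
import json
import os
import tempfile


def best_epoch(stat_file, metric_name, comparator=">"):
    """Return the stats entry with the best value of `metric_name`."""
    with open(stat_file) as f:
        info = json.load(f)
    # ">" means larger is better, anything else means smaller is better.
    pick = max if comparator == ">" else min
    return pick(info, key=lambda epoch_stat: epoch_stat[metric_name])


# Hypothetical stats file with one entry per epoch.
stats = [
    {"global_step": 100, "valid_dice": 0.71},
    {"global_step": 200, "valid_dice": 0.83},
    {"global_step": 300, "valid_dice": 0.79},
]
with tempfile.TemporaryDirectory() as tmp_dir:
    stat_path = os.path.join(tmp_dir, "stats.json")
    with open(stat_path, "w") as f:
        json.dump(stats, f)
    best = best_epoch(stat_path, "valid_dice", ">")
    chkpt_name = "model-%d.index" % best["global_step"]
```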
45 |
46 | ## Post Processing
47 |
48 | To obtain the final segmentation, use the command:
49 | `python process.py`
50 | for post-processing the network predictions. Note, this is only for segmentation networks.
51 |
52 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration file
3 | """
4 |
5 | import importlib
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | import opt.augs as augs
10 |
11 | from model.utils.gconv_utils import get_basis_filters, get_rot_info, get_basis_params
12 |
13 |
14 | class Config(object):
15 |     def __init__(self,):
16 |
17 |         self.seed = 10
18 |
19 |         self.model_mode = "seg_nuc"  # choose seg_gland, seg_nuc, class_pcam
20 |         self.filter_type = "steerable"  # choose steerable or standard
21 |
22 |         self.nr_orients = 8  # number of orientations for the filters
23 |
24 |         #### Dynamically setting the hyper-param and aug files into variable
25 |         param_file = importlib.import_module("opt.params")
26 |         param_dict = param_file.__getattribute__(self.model_mode)
27 |
28 |         for variable, value in param_dict.items():
29 |             self.__setattr__(variable, value)
30 |
31 |         self.data_ext = ".npy"
32 |         # list of directories containing training and validation files.
33 |         # Each directory contains one numpy file per image.
34 |         # if self.model_mode == 'seg_nuc' or 'seg_gland',
35 |         # data is of size [H,W,4] (RGB + instance label)
36 |         # if self.model_mode == 'class_pcam', data is of size [(H*W*C)+1],
37 |         # where the final value is the class label
38 |         self.train_dir = ["/media/simon/Storage 1/Data/Nuclei/patches/kumar/train/"]
39 |         self.valid_dir = ["/media/simon/Storage 1/Data/Nuclei/patches/kumar/valid/"]
40 |
41 |         # nr of processes for parallel processing input
42 |         self.nr_procs_train = 8
43 |         self.nr_procs_valid = 4
44 |
45 |         exp_id = "v1.0"
46 |         # loading chkpts in tensorflow, the path must not contain extra '/'
47 |         self.log_path = "/media/simon/Storage 1/dsf-cnn/checkpoints/"  # log root path
48 |         self.save_dir = "%s/%s/%s_%s_%s" % (
49 |             self.log_path,
50 |             self.model_mode,
51 |             self.filter_type,
52 |             self.nr_orients,
53 |             exp_id,
54 |         )  # log file destination
55 |
56 |         #### Info for running inference
57 |         self.inf_auto_find_chkpt = False
58 |         # path to the checkpoint that will be used for inference, replace accordingly
59 |         self.inf_model_path = self.save_dir + "/model-xxxx.index"
60 |
61 |         # paths to files for inference. Note, for PCam we use the original .h5 file.
62 |         self.inf_imgs_ext = ".tif"
63 |         self.inf_data_list = ["/media/simon/Storage 1/Data/Nuclei/kumar/test/Images/"]
64 |
65 |         output_root = "/media/simon/Storage 1/output/"
66 |         # log file destination
67 |         self.inf_output_dir = "%s/%s_%s/" % (
68 |             output_root,
69 |             self.filter_type,
70 |             self.nr_orients,
71 |         )
72 |
73 |         if self.filter_type == "steerable":
74 |             # Generate the basis filters- only need to do this once before training
75 |             self.basis_filter_list = []
76 |             self.rot_matrix_list = []
77 |             for ksize in self.filter_sizes:
78 |                 alpha_list, beta_list, bl_list = get_basis_params(ksize)
79 |                 b_filters, freq_filters = get_basis_filters(
80 |                     alpha_list, beta_list, bl_list, ksize
81 |                 )
82 |                 self.basis_filter_list.append(b_filters)
83 |                 self.rot_matrix_list.append(get_rot_info(self.nr_orients, freq_filters))
84 |
85 |         # for inference during evaluation mode, i.e. run by infer.py
86 |         self.eval_inf_input_tensor_names = ["images"]
87 |         self.eval_inf_output_tensor_names = ["predmap-coded"]
88 |         # for inference during training mode, i.e. run by train.py
89 |         self.train_inf_output_tensor_names = ["predmap-coded", "truemap-coded"]
90 |
91 |     def get_model(self):
92 |         model_constructor = importlib.import_module("model.%s.graph" % self.model_mode)
93 |         model_constructor = model_constructor.Graph
94 |         return model_constructor
95 |
96 |     # refer to https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html
97 |     # for information on how to modify the augmentation parameters.
98 |     # Pipeline can be modified in opt/augs.py
99 |
100 |     def get_train_augmentors(self, input_shape, output_shape, view=False):
101 |         return augs.get_train_augmentors(self, input_shape, output_shape, view)
102 |
103 |     def get_valid_augmentors(self, input_shape, output_shape, view=False):
104 |         return augs.get_valid_augmentors(self, input_shape, output_shape, view)
105 |
--------------------------------------------------------------------------------
/src/extract_patches.py:
--------------------------------------------------------------------------------
1 | """extract_patches.py
2 |
3 | Script for extracting patches from image tiles. The script will read
4 | an RGB image and a corresponding label and form image patches to be
5 | used by the network.
6 | """
7 |
8 |
9 | import glob
10 | import os
11 |
12 | import cv2
13 | import numpy as np
14 |
15 | from misc.patch_extractor import PatchExtractor
16 | from misc.utils import rm_n_mkdir
17 |
18 | from config import Config
19 |
20 | ###########################################################################
21 | if __name__ == "__main__":
22 |
23 |     cfg = Config()
24 |
25 |     extract_type = "mirror"  # 'valid' or 'mirror'
26 |     # 'mirror' reflects at the borders; 'valid' doesn't.
27 |     # check the patch_extractor.py 'main' to see the difference
28 |
29 |     # original size (win size) - input size - output size (step size)
30 |     step_size = [112, 112]
31 |     # set to size of network input: 448 for glands, 256 for nuclei
32 |     win_size = [448, 448]
33 |
34 |     xtractor = PatchExtractor(win_size, step_size)
35 |
36 |     ### Paths to data - these need to be modified according to where the original data is stored
37 |     img_ext = ".png"
38 |     # img_dir should contain RGB image tiles from where to extract patches.
39 |     img_dir = "path/to/images/"
40 |     # ann_dir should contain 2D npy image tiles, with values ranging from 0 to N.
41 |     # 0 is background and then each nucleus is uniquely labelled from 1-N.
42 |     ann_dir = "path/to/labels/"
43 |     ####
44 |     out_dir = "output_path/%dx%d_%dx%d" % (
45 |         win_size[0],
46 |         win_size[1],
47 |         step_size[0],
48 |         step_size[1],
49 |     )
50 |
51 |     file_list = glob.glob("%s/*%s" % (img_dir, img_ext))
52 |     file_list.sort()
53 |
54 |     rm_n_mkdir(out_dir)
55 |     for filename in file_list:
56 |         filename = os.path.basename(filename)
57 |         basename = filename.split(".")[0]
58 |         print(filename)
59 |
60 |         img = cv2.imread(img_dir + basename + img_ext)
61 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
62 |
63 |         # assumes that ann is HxW
64 |         ann_inst = np.load(ann_dir + basename + ".npy")
65 |         ann_inst = ann_inst.astype("int32")
66 |         ann = np.expand_dims(ann_inst, -1)
67 |
68 |         img = np.concatenate([img, ann], axis=-1)
69 |         sub_patches = xtractor.extract(img, extract_type)
70 |         for idx, patch in enumerate(sub_patches):
71 |             np.save("{0}/{1}_{2:03d}.npy".format(out_dir, basename, idx), patch)
72 |
--------------------------------------------------------------------------------
/src/infer.py:
--------------------------------------------------------------------------------
1 | """infer.py
2 |
3 | Main inference script.
4 |
5 | Usage:
6 |   infer.py [--gpu=<id>] [--mode=<mode>]
7 |   infer.py (-h | --help)
8 |   infer.py --version
9 |
10 | Options:
11 |   -h --help       Show this string.
12 |   --version       Show version.
13 |   --gpu=<id>      Comma separated GPU list. [default: 0]
14 |   --mode=<mode>   Inference mode: use either 'seg' or 'class'.
15 | """
16 |
17 | from docopt import docopt
18 | import argparse
19 | import glob
20 | import math
21 | import time
22 | import os
23 | import sys
24 | from collections import deque
25 | from keras.utils import HDF5Matrix
26 |
27 | import cv2
28 | import numpy as np
29 |
30 | from tensorpack.predict import OfflinePredictor, PredictConfig
31 | from tensorpack.tfutils.sessinit import get_model_loader
32 |
33 | from config import Config
34 | from misc.utils import rm_n_mkdir, cropping_center
35 | from model.utils.model_utils import crop_op
36 |
37 | import json
38 | import operator
39 |
40 | from sklearn.metrics import roc_auc_score
41 |
42 |
43 | def get_best_chkpts(path, metric_name, comparator=">"):
44 |     """
45 |     Return the best checkpoint according to some criteria.
46 |     Note that it will only return a valid path, so any checkpoint that has been
47 |     removed won't be returned (i.e. moving to the next one that satisfies the
48 |     criteria, such as the second best etc.)
49 |
50 |     Args:
51 |         path: directory containing all checkpoints, including the "stats.json" file
52 |     """
53 |     stat_file = path + "/stats.json"
54 |     ops = {
55 |         ">": operator.gt,
56 |         "<": operator.lt,
57 |     }
58 |
59 |     op_func = ops[comparator]
60 |     with open(stat_file) as f:
61 |         info = json.load(f)
62 |
63 |     if comparator == ">":
64 |         best_value = -float("inf")
65 |     else:
66 |         best_value = +float("inf")
67 |
68 |     best_chkpt, selected_stat = None, None  # stays None if no checkpoint file is found
69 |     for epoch_stat in info:
70 |         epoch_value = epoch_stat[metric_name]
71 |         if op_func(epoch_value, best_value):
72 |             chkpt_path = "%s/model-%d.index" % (path, epoch_stat["global_step"])
73 |             if os.path.isfile(chkpt_path):
74 |                 selected_stat = epoch_stat
75 |                 best_value = epoch_value
76 |                 best_chkpt = chkpt_path
77 |     return best_chkpt, selected_stat
78 |
78 |
79 |
80 | class InferClass(Config):
81 | def __gen_prediction(self, x, predictor):
82 | """
83 | Using 'predictor' to generate the prediction of image 'x'
84 |
85 | Args:
86 | x : batch of input image patches to be classified
87 | predictor: callable returning the per-class probabilities for 'x'
88 | Returns:
89 | (prob, pred): probability of class 1 and the argmax label for each patch
90 | """
91 |
92 | prob = predictor(x)[0]
93 | pred = np.argmax(prob, -1)
94 | pred = np.squeeze(pred)
95 | prob = np.squeeze(prob[..., 1])
96 |
97 | return prob, pred
98 |
99 | ####
100 | def run(self):
101 | if self.inf_auto_find_chkpt:
102 | print(
103 | '-----Auto Selecting Checkpoint Basing On "%s" Through "%s" Comparison'
104 | % (self.inf_auto_metric, self.inf_auto_comparator)
105 | )
106 | model_path, stat = get_best_chkpts(
107 | self.save_dir, self.inf_auto_metric, self.inf_auto_comparator
108 | )
109 | print("Selecting: %s" % model_path)
110 | print("Having Following Statistics:")
111 | for key, value in stat.items():
112 | print("\t%s: %s" % (key, value))
113 | else:
114 | model_path = self.inf_model_path
115 | model_constructor = self.get_model()
116 | pred_config = PredictConfig(
117 | model=model_constructor(),
118 | session_init=get_model_loader(model_path),
119 | input_names=self.eval_inf_input_tensor_names,
120 | output_names=self.eval_inf_output_tensor_names,
121 | create_graph=False)
122 | predictor = OfflinePredictor(pred_config)
123 |
124 |
125 | ####
126 | save_dir = self.inf_output_dir
127 | predict_list = [["case", "prediction"]]
128 |
129 | file_load_img = HDF5Matrix(
130 | self.inf_data_list[0] + "camelyonpatch_level_2_split_test_x.h5", "x"
131 | )
132 | file_load_lab = HDF5Matrix(
133 | self.inf_data_list[0] + "camelyonpatch_level_2_split_test_y.h5", "y"
134 | )
135 |
136 | true_list = []
137 | prob_list = []
138 | pred_list = []
139 |
140 | num_ims = file_load_img.shape[0]
141 | last_step = math.floor(num_ims / self.inf_batch_size)
142 | last_step = self.inf_batch_size * last_step
143 | last_batch = num_ims - last_step
144 | count = 0
145 | for start_batch in range(0, last_step + 1, self.inf_batch_size):
146 | sys.stdout.write("\rProcessed (%d/%d)" % (start_batch, num_ims))
147 | sys.stdout.flush()
148 | if start_batch != last_step:
149 | img = file_load_img[start_batch : start_batch + self.inf_batch_size]
150 | img = img.astype("uint8")
151 | lab = np.squeeze(
152 | file_load_lab[start_batch : start_batch + self.inf_batch_size]
153 | )
154 | else:
155 | img = file_load_img[start_batch : start_batch + last_batch]
156 | img = img.astype("uint8")
157 | lab = np.squeeze(file_load_lab[start_batch : start_batch + last_batch])
158 |
159 | prob, pred = self.__gen_prediction(img, predictor)
160 |
161 | for j in range(prob.shape[0]):
162 | predict_list.append([str(count), str(prob[j])])
163 | count += 1
164 |
165 | prob_list.extend(prob)
166 | pred_list.extend(pred)
167 | true_list.extend(lab)
168 |
169 | prob_list = np.array(prob_list)
170 | pred_list = np.array(pred_list)
171 | true_list = np.array(true_list)
172 | accuracy = (pred_list == true_list).sum() / np.size(true_list)
173 | error = (pred_list != true_list).sum() / np.size(true_list)
174 |
175 | print("Accuracy (%): ", 100 * accuracy)
176 | print("Error (%): ", 100 * error)
177 | if self.model_mode == "class_pcam":
178 | auc = roc_auc_score(true_list, prob_list)
179 | print("AUC: ", auc)
180 |
181 | # Save predictions to csv
182 | rm_n_mkdir(save_dir)
183 | for result in predict_list:
184 | predict_file = open("%s/predict.csv" % save_dir, "a")
185 | predict_file.write(result[0])
186 | predict_file.write(",")
187 | predict_file.write(result[1])
188 | predict_file.write("\n")
189 | predict_file.close()
190 |
191 |
192 | class InferSeg(Config):
193 | def __gen_prediction(self, x, predictor):
194 | """
195 | Using 'predictor' to generate the prediction of image 'x'
196 |
197 | Args:
198 | x : input image to be segmented. It will be split into patches
199 | to run the prediction upon before being assembled back
200 | predictor: callable that runs the network on a batch of patches and
201 | returns the corresponding prediction maps
202 | """
203 | step_size = self.infer_output_shape
204 | msk_size = self.infer_output_shape
205 | win_size = self.infer_input_shape
206 |
207 | def get_last_steps(length, msk_size, step_size):
208 | nr_step = math.ceil((length - msk_size) / step_size)
209 | last_step = (nr_step + 1) * step_size
210 | return int(last_step), int(nr_step + 1)
211 |
212 | im_h = x.shape[0]
213 | im_w = x.shape[1]
214 |
215 | last_h, nr_step_h = get_last_steps(im_h, msk_size[0], step_size[0])
216 | last_w, nr_step_w = get_last_steps(im_w, msk_size[1], step_size[1])
217 |
218 | diff_h = win_size[0] - step_size[0]
219 | padt = diff_h // 2
220 | padb = last_h + win_size[0] - im_h
221 |
222 | diff_w = win_size[1] - step_size[1]
223 | padl = diff_w // 2
224 | padr = last_w + win_size[1] - im_w
225 |
226 | x = np.lib.pad(x, ((padt, padb), (padl, padr), (0, 0)), "reflect")
227 |
228 | #### TODO: optimize this
229 | sub_patches = []
230 | skipped_idx = []
231 | # generating subpatches from original
232 | idx = 0
233 | for row in range(0, last_h, step_size[0]):
234 | for col in range(0, last_w, step_size[1]):
235 | win = x[row : row + win_size[0], col : col + win_size[1]]
236 | sub_patches.append(win)
237 | idx += 1
238 |
239 | pred_map = deque()
240 | while len(sub_patches) > self.inf_batch_size:
241 | mini_batch = sub_patches[: self.inf_batch_size]
242 | sub_patches = sub_patches[self.inf_batch_size :]
243 | mini_output = predictor(mini_batch)[0]
244 | if win_size[0] > msk_size[0]:
245 | mini_output = cropping_center(mini_output, (diff_h, diff_w))
246 | mini_output = np.split(mini_output, self.inf_batch_size, axis=0)
247 | pred_map.extend(mini_output)
248 | if len(sub_patches) != 0:
249 | mini_output = predictor(sub_patches)[0]
250 | if win_size[0] > msk_size[0]:
251 | mini_output = cropping_center(mini_output, (diff_h, diff_w))
252 | mini_output = np.split(mini_output, len(sub_patches), axis=0)
253 | pred_map.extend(mini_output)
254 |
255 | #### Assemble back into full image
256 | output_patch_shape = np.squeeze(pred_map[0]).shape
257 |
258 | ch = 1 if len(output_patch_shape) == 2 else output_patch_shape[-1]
259 |
260 | #### Assemble back into full image
261 | pred_map = np.squeeze(np.array(pred_map))
262 | pred_map = np.reshape(pred_map, (nr_step_h, nr_step_w) + pred_map.shape[1:])
263 | pred_map = (
264 | np.transpose(pred_map, [0, 2, 1, 3, 4])
265 | if ch != 1
266 | else np.transpose(pred_map, [0, 2, 1, 3])
267 | )
268 | pred_map = np.reshape(
269 | pred_map,
270 | (
271 | pred_map.shape[0] * pred_map.shape[1],
272 | pred_map.shape[2] * pred_map.shape[3],
273 | ch,
274 | ),
275 | )
276 | # just crop back to original size
277 | pred_map = np.squeeze(pred_map[:im_h, :im_w])
278 |
279 | return pred_map
280 |
281 | ####
282 | def run(self):
283 |
284 | if self.inf_auto_find_chkpt:
285 | print(
286 | '-----Auto Selecting Checkpoint Basing On "%s" Through "%s" Comparison'
287 | % (self.inf_auto_metric, self.inf_auto_comparator)
288 | )
289 | model_path, stat = get_best_chkpts(
290 | self.save_dir, self.inf_auto_metric, self.inf_auto_comparator
291 | )
292 | print("Selecting: %s" % model_path)
293 | print("Having Following Statistics:")
294 | for key, value in stat.items():
295 | print("\t%s: %s" % (key, value))
296 | else:
297 | model_path = self.inf_model_path
298 |
299 | model_constructor = self.get_model()
300 | pred_config = PredictConfig(
301 | model=model_constructor(),
302 | session_init=get_model_loader(model_path),
303 | input_names=self.eval_inf_input_tensor_names,
304 | output_names=self.eval_inf_output_tensor_names,
305 | create_graph=False,
306 | )
307 | predictor = OfflinePredictor(pred_config)
308 |
309 | for data_dir in self.inf_data_list:
310 | save_dir = self.inf_output_dir + "/raw/"
311 | file_list = glob.glob("%s/*%s" % (data_dir, self.inf_imgs_ext))
312 | file_list.sort() # ensure same order
313 |
314 | rm_n_mkdir(save_dir)
315 | for filename in file_list:
316 | start = time.time()
317 | filename = os.path.basename(filename)
318 | basename = filename.split(".")[0]
319 | print(data_dir, basename, end=" ", flush=True)
320 |
321 | ##
322 | img = cv2.imread(data_dir + filename)
323 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
324 |
325 | pred_map = self.__gen_prediction(img, predictor)
326 |
327 | np.save("%s/%s.npy" % (save_dir, basename), [pred_map])
328 | end = time.time()
329 | diff = str(round(end - start, 2))
330 | print("FINISH. TIME: %s" % diff)
331 |
332 |
333 | ####
334 | if __name__ == "__main__":
335 | args = docopt(__doc__)
336 | print(args)
337 |
338 | if args["--gpu"]:
339 | os.environ["CUDA_VISIBLE_DEVICES"] = args["--gpu"]
340 | nr_gpus = len(args["--gpu"].split(","))
341 |
342 | if args["--mode"] is None:
343 | raise Exception('Mode cannot be empty. Use either "class" or "seg".')
344 |
345 | if args["--mode"] == "class":
346 | infer = InferClass()
347 | elif args["--mode"] == "seg":
348 | infer = InferSeg()
349 | else:
350 | raise Exception('Mode not recognised. Use either "class" or "seg".')
351 |
352 | infer.run()
353 |
--------------------------------------------------------------------------------
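
The tiling arithmetic in `InferSeg.__gen_prediction` can be checked in isolation. The sketch below follows the same steps (compute the number of steps, reflect-pad, slide a window, reassemble, crop) on a single-channel image, and replaces the network with a dummy that returns the centre of each window unchanged. The function `tile_and_reassemble`, the sizes and the dummy predictor are assumptions for illustration; the real method also handles batching and multi-channel outputs.

```python
import math

import numpy as np


def get_last_steps(length, msk_size, step_size):
    """Number of steps needed so that the output tiles cover `length`."""
    nr_step = math.ceil((length - msk_size) / step_size)
    last_step = (nr_step + 1) * step_size
    return int(last_step), int(nr_step + 1)


def tile_and_reassemble(x, win_size, out_size):
    """Tile a 2D image, 'predict' each tile, and stitch the outputs back."""
    im_h, im_w = x.shape
    last_h, nr_h = get_last_steps(im_h, out_size[0], out_size[0])
    last_w, nr_w = get_last_steps(im_w, out_size[1], out_size[1])

    # Pad so every window has full context and the last window fits.
    padt = (win_size[0] - out_size[0]) // 2
    padl = (win_size[1] - out_size[1]) // 2
    padb = last_h + win_size[0] - im_h
    padr = last_w + win_size[1] - im_w
    x = np.pad(x, ((padt, padb), (padl, padr)), "reflect")

    tiles = []
    for row in range(0, last_h, out_size[0]):
        for col in range(0, last_w, out_size[1]):
            win = x[row : row + win_size[0], col : col + win_size[1]]
            # Dummy "network": return the central out_size region as is.
            tiles.append(win[padt : padt + out_size[0], padl : padl + out_size[1]])

    # (nr_h * nr_w, h, w) -> (nr_h, h, nr_w, w) -> full image, then crop.
    grid = np.array(tiles).reshape(nr_h, nr_w, out_size[0], out_size[1])
    full = grid.transpose(0, 2, 1, 3).reshape(nr_h * out_size[0], nr_w * out_size[1])
    return full[:im_h, :im_w]


image = np.arange(20 * 22, dtype="float32").reshape(20, 22)
stitched = tile_and_reassemble(image, win_size=(8, 8), out_size=(4, 4))
```

Because the dummy predictor is the identity on the central region, `stitched` should equal `image`, which gives a quick way to confirm that the padding and the reshape/transpose order are consistent.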
/src/loader/custom_augs.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom augmentations
3 | """
4 |
5 | import math
6 | import cv2
7 | import matplotlib.cm as cm
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
11 | from scipy import ndimage
12 | from scipy.ndimage import measurements
13 | from scipy.ndimage.filters import gaussian_filter
14 | from scipy.ndimage.interpolation import affine_transform, map_coordinates
15 | from scipy.ndimage.morphology import distance_transform_cdt, distance_transform_edt
16 | from skimage import morphology as morph
17 |
18 | from tensorpack.dataflow.imgaug import ImageAugmentor
19 | from tensorpack.utils.utils import get_rng
20 |
21 | from misc.utils import cropping_center, bounding_box
22 |
23 |
24 | class GenInstance(ImageAugmentor):
25 | def __init__(self, crop_shape=None):
26 | super(GenInstance, self).__init__()
27 | self.crop_shape = crop_shape
28 |
29 | def reset_state(self):
30 | self.rng = get_rng(self)
31 |
32 | def _fix_mirror_padding(self, ann):
33 | """
34 | Deal with duplicated instances due to mirroring in interpolation
35 | during shape augmentation (scale, rotation etc.)
36 | """
37 | current_max_id = np.amax(ann)
38 | inst_list = list(np.unique(ann))
39 | try:
40 | inst_list.remove(0) # remove background
41 | except ValueError:
42 | pass
43 | for inst_id in inst_list:
44 | inst_map = np.array(ann == inst_id, np.uint8)
45 | remapped_ids = measurements.label(inst_map)[0]
46 | remapped_ids[remapped_ids > 1] += current_max_id
47 | ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
48 | current_max_id = np.amax(ann)
49 | return ann
50 |
51 |
52 | class GenInstanceContourMap(GenInstance):
53 | """
54 | Input annotation must be of original shape.
55 | """
56 |
57 | def __init__(self, mode, crop_shape=None):
58 | super(GenInstanceContourMap, self).__init__()
59 | self.crop_shape = crop_shape
60 | self.mode = mode
61 |
62 | def _augment(self, img, _):
63 | img = np.copy(img)
64 | orig_ann = img[..., 0] # instance ID map
65 | fixed_ann = self._fix_mirror_padding(orig_ann)
66 | fixed_ann = orig_ann
67 | # re-cropping with fixed instance id map
68 | # crop_ann = cropping_center(fixed_ann, self.crop_shape)
69 |
70 | # setting 1 boundary pix of each instance to background
71 | inner_map = np.zeros(fixed_ann.shape[:2], np.uint8)
72 | contour_map = np.zeros(fixed_ann.shape[:2], np.uint8)
73 |
74 | inst_list = list(np.unique(fixed_ann))
75 |
76 | try: # remove background
77 | inst_list.remove(0) # 0 is background
78 | except ValueError:
79 | pass
80 |
81 | if self.mode == "seg_gland":
82 | k_disk = np.array(
83 | [
84 | [0, 0, 0, 0, 1, 0, 0, 0, 0],
85 | [0, 0, 0, 1, 1, 1, 0, 0, 0],
86 | [0, 0, 1, 1, 1, 1, 1, 0, 0],
87 | [0, 1, 1, 1, 1, 1, 1, 1, 0],
88 | [1, 1, 1, 1, 1, 1, 1, 1, 1],
89 | [0, 1, 1, 1, 1, 1, 1, 1, 0],
90 | [0, 0, 1, 1, 1, 1, 1, 0, 0],
91 | [0, 0, 0, 1, 1, 1, 0, 0, 0],
92 | [0, 0, 0, 0, 1, 0, 0, 0, 0],
93 | ],
94 | np.uint8,
95 | )
96 | else:
97 | k_disk = np.array(
98 | [
99 | [0, 0, 1, 0, 0],
100 | [0, 1, 1, 1, 0],
101 | [1, 1, 1, 1, 1],
102 | [0, 1, 1, 1, 0],
103 | [0, 0, 1, 0, 0],
104 | ],
105 | np.uint8,
106 | )
107 |
108 | for inst_id in inst_list:
109 | inst_map = np.array(fixed_ann == inst_id, np.uint8)
110 | inner = cv2.erode(inst_map, k_disk, iterations=1)
111 | outer = cv2.dilate(inst_map, k_disk, iterations=1)
112 | inner_map += inner
113 | contour_map += outer - inner
114 | inner_map[inner_map > 0] = 1 # binarize
115 | contour_map[contour_map > 0] = 1 # binarize
116 | bg_map = 1 - (inner_map + contour_map)
117 | img = np.dstack([inner_map, contour_map, bg_map, img[..., 1:]])
118 | return img
119 |
120 |
121 | class GenInstanceMarkerMap(GenInstance):
122 | """
123 | Input annotation must be of original shape.
124 | Perform following operation:
125 | 1) Remove the 1px of boundary of each instance
126 | to create separation between touching instances
127 | 2) Generate the weight map from the result of 1)
128 | according to the unet paper equation.
129 | Args:
130 | wc (dict) : Dictionary of weight classes.
131 | w0 (int/float) : Border weight parameter.
132 | sigma (int/float): Border width parameter.
133 | """
134 |
135 | def __init__(self, wc=None, w0=10.0, sigma=4.0, crop_shape=None):
136 | super(GenInstanceMarkerMap, self).__init__()
137 | self.crop_shape = crop_shape
138 | self.wc = wc
139 | self.w0 = w0
140 | self.sigma = sigma
141 |
142 | def _erode_obj(self, ann):
143 | new_ann = np.zeros(ann.shape[:2], np.int32)
144 | inst_list = list(np.unique(ann))
145 | inst_list.remove(0) # 0 is background
146 |
147 | inner_map = np.zeros(ann.shape[:2], np.uint8)
148 | contour_map = np.zeros(ann.shape[:2], np.uint8)
149 |
150 | k = np.array(
151 | [
152 | [0, 0, 0, 1, 0, 0, 0],
153 | [0, 0, 1, 1, 1, 0, 0],
154 | [0, 1, 1, 1, 1, 1, 0],
155 | [1, 1, 1, 1, 1, 1, 1],
156 | [0, 1, 1, 1, 1, 1, 0],
157 | [0, 0, 1, 1, 1, 0, 0],
158 | [0, 0, 0, 1, 0, 0, 0],
159 | ],
160 | np.uint8,
161 | )
162 |
163 | for inst_id in inst_list:
164 | inst_map = np.array(ann == inst_id, np.uint8)
165 | inner = cv2.erode(inst_map, k, iterations=1)
166 | outer = cv2.dilate(inst_map, k, iterations=1)
167 | inner_map += inner
168 | contour_map += outer - inner
169 | inner_map[inner_map > 0] = 1 # binarize
170 | contour_map[contour_map > 0] = 1 # binarize
171 | bg_map = 1 - (inner_map + contour_map)
172 |
173 | return inner_map, contour_map, bg_map
174 |
175 | def _get_weight_map(self, ann, inst_list):
176 | if len(inst_list) <= 1: # 1 instance only
177 | return np.zeros(ann.shape[:2])
178 | stacked_inst_bgd_dst = np.zeros(ann.shape[:2] + (len(inst_list),))
179 |
180 | for idx, inst_id in enumerate(inst_list):
181 | inst_bgd_map = np.array(ann != inst_id, np.uint8)
182 | inst_bgd_dst = distance_transform_edt(inst_bgd_map)
183 | stacked_inst_bgd_dst[..., idx] = inst_bgd_dst
184 |
185 | near1_dst = np.amin(stacked_inst_bgd_dst, axis=2)
186 | near2_dst = np.expand_dims(near1_dst, axis=2)
187 | near2_dst = stacked_inst_bgd_dst - near2_dst
188 | near2_dst[near2_dst == 0] = np.inf # very large
189 | near2_dst = np.amin(near2_dst, axis=2)
190 | near2_dst[ann > 0] = 0 # the instances
191 | near2_dst = near2_dst + near1_dst
192 | # to fix pixel where near1 == near2
193 | near2_eve = np.expand_dims(near1_dst, axis=2)
194 | # add 1 to both terms to avoid a divide-by-zero warning
195 | near2_eve = (1.0 + stacked_inst_bgd_dst) / (1.0 + near2_eve)
196 | near2_eve[near2_eve != 1] = 0
197 | near2_eve = np.sum(near2_eve, axis=2)
198 | near2_dst[near2_eve > 1] = near1_dst[near2_eve > 1]
199 | #
200 | pix_dst = near1_dst + near2_dst
201 | pen_map = pix_dst / self.sigma
202 | pen_map = self.w0 * np.exp(-(pen_map ** 2) / 2)
203 | pen_map[ann > 0] = 0 # inner instances zero
204 | return pen_map
205 |
206 | def _augment(self, img, _):
207 | img = np.copy(img)
208 | orig_ann = img[..., 0] # instance ID map
209 | orig_ann_copy = orig_ann.copy()
210 | fixed_ann = self._fix_mirror_padding(orig_ann)
211 | # setting 1 boundary pix of each instance to background
212 | inner_map, contour_map, bg_map = self._erode_obj(fixed_ann)
213 |
214 | # can't take the shortcut because near2 also needs instances
215 | # outside of the cropped portion
216 | inst_list = list(np.unique(fixed_ann))
217 | inst_list.remove(0) # 0 is background
218 | wmap = self._get_weight_map(fixed_ann, inst_list)
219 |
220 | if self.wc is None:
221 | wmap += 1 # uniform weight for all classes
222 | else:
223 | class_weights = np.zeros(fixed_ann.shape[:2])
224 | for class_id, class_w in self.wc.items():
225 | class_weights[fixed_ann == class_id] = class_w
226 | wmap += class_weights
227 |
228 | # fix other maps to align
229 | img[fixed_ann == 0] = 0
230 | orig_ann[orig_ann > 0] = 1
231 | img = np.dstack([orig_ann_copy, inner_map, contour_map, bg_map, wmap])
232 |
233 | return img
234 |
235 |
236 | class GaussianBlur(ImageAugmentor):
237 | """
238 | Gaussian blur the image with random window size
239 | """
240 |
241 | def __init__(self, max_size=3):
242 | """
243 | Args:
244 | max_size (int): exclusive upper bound of the sampled half-size; the largest Gaussian window is 2 * (max_size - 1) + 1
245 | """
246 | super(GaussianBlur, self).__init__()
247 | self.max_size = max_size
248 |
249 | def _get_augment_params(self, img):
250 | sx, sy = self.rng.randint(1, self.max_size, size=(2,))
251 | sx = sx * 2 + 1
252 | sy = sy * 2 + 1
253 | return sx, sy
254 |
255 | def _augment(self, img, s):
256 | return np.reshape(
257 | cv2.GaussianBlur(
258 | img, s, sigmaX=0, sigmaY=0, borderType=cv2.BORDER_REPLICATE
259 | ),
260 | img.shape,
261 | )
262 |
263 |
264 | class BinarizeLabel(ImageAugmentor):
265 | """
266 | Convert labels to binary maps
267 | """
268 |
269 | def __init__(self):
270 | super(BinarizeLabel, self).__init__()
271 |
272 | def _get_augment_params(self, img):
273 | return None
274 |
275 | def _augment(self, img, s):
276 | img = np.copy(img)
277 | arr = img[..., 0]
278 | arr[arr > 0] = 1
279 | return img
280 |
281 |
282 | class MedianBlur(ImageAugmentor):
283 | """
284 | Median blur the image with random window size
285 | """
286 |
287 | def __init__(self, max_size=3):
288 | """
289 | Args:
290 | max_size (int): exclusive upper bound of the sampled half-size;
291 | the largest window is 2 * (max_size - 1) + 1
292 | """
293 | super(MedianBlur, self).__init__()
294 | self.max_size = max_size
295 |
296 | def _get_augment_params(self, img):
297 | s = self.rng.randint(1, self.max_size)
298 | s = s * 2 + 1
299 | return s
300 |
301 | def _augment(self, img, ksize):
302 | return cv2.medianBlur(img, ksize)
303 |
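The border weighting described in the `GenInstanceMarkerMap` docstring can be illustrated with a minimal NumPy sketch. It assumes the standard U-Net form `w0 * exp(-(d1 + d2)^2 / (2 * sigma^2))`, where `d1` and `d2` are the distances to the nearest and second-nearest instance. The function name `border_weight_map` is hypothetical, and the brute-force distance computation stands in for `distance_transform_edt`; it is only practical for tiny arrays.

```python
import numpy as np


def border_weight_map(ann, w0=10.0, sigma=4.0):
    """Brute-force border weights for a small 2D instance ID map (0 = background)."""
    inst_ids = [i for i in np.unique(ann) if i != 0]
    if len(inst_ids) < 2:
        # no pair of instances, hence no border to emphasise
        return np.zeros(ann.shape, dtype=np.float64)
    rows, cols = np.indices(ann.shape)
    dists = []
    for inst_id in inst_ids:
        r, c = np.nonzero(ann == inst_id)
        # distance from every pixel to the nearest pixel of this instance
        d = np.sqrt((rows[..., None] - r) ** 2 + (cols[..., None] - c) ** 2)
        dists.append(d.min(axis=-1))
    # sort per pixel so index 0 is the nearest and index 1 the second nearest
    dists = np.sort(np.stack(dists, axis=-1), axis=-1)
    d1, d2 = dists[..., 0], dists[..., 1]
    weights = w0 * np.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
    weights[ann > 0] = 0.0  # pixels inside instances get no border weight
    return weights


if __name__ == "__main__":
    ann = np.zeros((1, 7), dtype=np.int32)
    ann[0, :2] = 1
    ann[0, 5:] = 2
    print(border_weight_map(ann))
```

Background pixels squeezed between two instances receive the largest weights, which is what pushes the network to separate touching objects.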
--------------------------------------------------------------------------------
/src/loader/loader.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset loader
3 | """
4 |
5 | import random
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 |
9 | from tensorpack.dataflow import (
10 | AugmentImageComponent,
11 | AugmentImageComponents,
12 | BatchData,
13 | BatchDataByShape,
14 | CacheData,
15 | PrefetchDataZMQ,
16 | RNGDataFlow,
17 | RepeatedData,
18 | )
19 |
20 | from config import Config
21 |
22 |
23 | class DatasetSerial(RNGDataFlow, Config):
24 | """
25 | Produce an ``(image, label)`` pair, where
26 | ``image`` has shape HWC, is RGB and has values in the range [0, 255].
27 |
28 | ``label`` is a float image of shape (H, W, C). The number of channels C depends
29 | on `self.model_mode` within `config.py`
30 | """
31 |
32 | def __init__(self, path_list):
33 | super(DatasetSerial, self).__init__()
34 | self.path_list = path_list
35 |
36 | ##
37 |
38 | def size(self):
39 | return len(self.path_list)
40 |
41 | ##
42 |
43 | def get_data(self):
44 | idx_list = list(range(0, len(self.path_list)))
45 | random.shuffle(idx_list)
46 | for idx in idx_list:
47 |
48 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc":
49 | data = np.load(self.path_list[idx])
50 | # split stacked channel into image and label
51 | img = data[..., :3] # RGB image
52 | img = img.astype("uint8")
53 | lab = data[..., 3:] # instance ID map
54 | yield [img, lab]
55 |
56 | else:
57 | data = np.load(self.path_list[idx])
58 | # split stacked channel into image and label
59 | img = data[:-1] # RGB image
60 | # reshape vector to HxWxC
61 | img = np.reshape(img, self.train_input_shape + [self.input_chans])
62 | if self.model_mode != "class_rotmnist":
63 | img = img.astype("uint8")
64 |
65 | lab = data[-1] # label
66 | lab = np.reshape(lab, [1, 1, 1])
67 |
68 | yield [img, lab]
69 |
70 |
71 | def valid_generator_seg(
72 | ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=1
73 | ):
74 | ### augment both the input and label
75 | ds = (
76 | ds
77 | if shape_aug is None
78 | else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)
79 | )
80 | ### augment just the input
81 | ds = (
82 | ds
83 | if input_aug is None
84 | else AugmentImageComponent(ds, input_aug, index=0, copy=False)
85 | )
86 | ### augment just the output
87 | ds = (
88 | ds
89 | if label_aug is None
90 | else AugmentImageComponent(ds, label_aug, index=1, copy=True)
91 | )
92 | #
93 | ds = BatchData(ds, batch_size, remainder=True)
94 | ds = CacheData(ds) # cache all inference images
95 | return ds
96 |
97 |
98 | def valid_generator_class(
99 | ds, shape_aug=None, input_aug=None, batch_size=16, nr_procs=1
100 | ):
101 | ### augment the input
102 | ds = (
103 | ds
104 | if shape_aug is None
105 | else AugmentImageComponent(ds, shape_aug, index=0, copy=True)
106 | )
107 | ### augment the input
108 | ds = (
109 | ds
110 | if input_aug is None
111 | else AugmentImageComponent(ds, input_aug, index=0, copy=False)
112 | )
113 | #
114 | ds = BatchData(ds, batch_size, remainder=True)
115 | ds = CacheData(ds) # cache all inference images
116 | return ds
117 |
118 |
119 | def train_generator_seg(
120 | ds, shape_aug=None, input_aug=None, label_aug=None, batch_size=16, nr_procs=8
121 | ):
122 | ### augment both the input and label
123 | ds = (
124 | ds
125 | if shape_aug is None
126 | else AugmentImageComponents(ds, shape_aug, (0, 1), copy=True)
127 | )
128 | ### augment just the input i.e index 0 within each yield of DatasetSerial
129 | ds = (
130 | ds
131 | if input_aug is None
132 | else AugmentImageComponent(ds, input_aug, index=0, copy=False)
133 | )
134 | ### augment just the output i.e index 1 within each yield of DatasetSerial
135 | ds = (
136 | ds
137 | if label_aug is None
138 | else AugmentImageComponent(ds, label_aug, index=1, copy=True)
139 | )
140 | #
141 | ds = BatchDataByShape(ds, batch_size, idx=0)
142 | ds = PrefetchDataZMQ(ds, nr_procs)
143 | return ds
144 |
145 |
146 | def train_generator_class(
147 | ds, shape_aug=None, input_aug=None, batch_size=16, nr_procs=8
148 | ):
149 | ### augment the input
150 | ds = (
151 | ds
152 | if shape_aug is None
153 | else AugmentImageComponent(ds, shape_aug, index=0, copy=True)
154 | )
155 | ### augment the input i.e index 0 within each yield of DatasetSerial
156 | ds = (
157 | ds
158 | if input_aug is None
159 | else AugmentImageComponent(ds, input_aug, index=0, copy=False)
160 | )
161 | #
162 | ds = BatchDataByShape(ds, batch_size, idx=0)
163 | ds = PrefetchDataZMQ(ds, nr_procs)
164 | return ds
165 |
166 |
167 | def visualize(datagen, batch_size):
168 | """
169 | Read batches from 'datagen' and display the first few images of each
170 | batch with their corresponding ground truth
171 | """
172 | cfg = Config()
173 |
174 | def prep_imgs(img, lab):
175 |
176 | # Deal with HxWx1 case
177 | img = np.squeeze(img)
178 |
179 | if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc":
180 | cmap = plt.get_cmap("jet")
181 | # cmap may fail on other dtypes, so cast to float32
182 | lab = lab.astype("float32")
183 | lab_chs = np.dsplit(lab, lab.shape[-1])
184 | for i, ch in enumerate(lab_chs):
185 | ch = np.squeeze(ch)
186 | # scale by the value range before applying cmap
187 | ch = ch / (np.max(ch) - np.min(ch) + 1.0e-16)
188 | # take RGB from RGBA heat map
189 | lab_chs[i] = cmap(ch)[..., :3]
190 | img = img.astype("float32") / 255.0
191 | prepped_img = np.concatenate([img] + lab_chs, axis=1)
192 | else:
193 | prepped_img = img
194 | return prepped_img
195 |
196 | ds = RepeatedData(datagen, -1)
197 | ds.reset_state()
198 | for imgs, labs in ds.get_data():
199 | if cfg.model_mode == "seg_gland" or cfg.model_mode == "seg_nuc":
200 | for idx in range(0, 4):
201 | displayed_img = prep_imgs(imgs[idx], labs[idx])
202 | # plot the image and the label
203 | plt.subplot(4, 1, idx + 1)
204 | plt.imshow(displayed_img, vmin=-1, vmax=1)
205 | plt.axis("off")
206 | plt.show()
207 | else:
208 | for idx in range(0, 8):
209 | displayed_img = prep_imgs(imgs[idx], labs[idx])
210 | # plot the image and the label
211 | plt.subplot(2, 4, idx + 1)
212 | plt.imshow(displayed_img)
213 | if len(cfg.label_names) > 0:
214 | lab_title = cfg.label_names[int(labs[idx])]
215 | else:
216 | lab_title = int(labs[idx])
217 | plt.title(lab_title)
218 | plt.axis("off")
219 | plt.show()
220 | return
221 |
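The two branches of `DatasetSerial.get_data` can be sketched in isolation as follows. This is a minimal illustration under the assumption that segmentation samples are stored as HxWx(3+K) arrays and classification samples as a flat vector with the label in the last element; the helper names `split_seg_sample` and `split_class_sample` are hypothetical.

```python
import numpy as np


def split_seg_sample(data):
    """Split a stacked HxWx(3+K) array into an RGB image and its label channels."""
    img = data[..., :3].astype("uint8")  # first three channels: RGB image
    lab = data[..., 3:]                  # remaining channels: instance ID map
    return img, lab


def split_class_sample(data, input_shape, input_chans):
    """Split a flat vector into an HxWxC image and a 1x1x1 label."""
    img = np.reshape(data[:-1], list(input_shape) + [input_chans])
    lab = np.reshape(data[-1], [1, 1, 1])  # last element: class label
    return img, lab


if __name__ == "__main__":
    img, lab = split_seg_sample(np.zeros((4, 4, 4)))
    print(img.shape, lab.shape)
    img, lab = split_class_sample(np.arange(13), [2, 2], 3)
    print(img.shape, lab.ravel())
```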
--------------------------------------------------------------------------------
/src/misc/patch_extractor.py:
--------------------------------------------------------------------------------
1 | """
2 | Patch extraction script
3 | """
4 |
5 | import math
6 | import cv2
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 |
10 | from .utils import cropping_center
11 |
12 |
13 | class PatchExtractor(object):
14 | """
15 | Extractor to generate patches with or without padding.
16 | Turn on debug mode to see how it is done.
17 |
18 | Args:
19 | x : input image, should be of shape HWC
20 | win_size : a tuple of (h, w)
21 | step_size : a tuple of (h, w)
22 | debug : flag to see how it is done
23 |
24 | Return:
25 | a list of sub patches, each patch has dtype same as x
26 |
27 | Examples:
28 | >>> xtractor = PatchExtractor((450, 450), (120, 120))
29 | >>> img = np.full([1200, 1200, 3], 255, np.uint8)
30 | >>> patches = xtractor.extract(img, 'mirror')
31 | """
32 |
33 | def __init__(self, win_size, step_size, debug=False):
34 |
35 | self.patch_type = 'mirror'
36 | self.win_size = win_size
37 | self.step_size = step_size
38 | self.debug = debug
39 | self.counter = 0
40 |
41 | def __get_patch(self, x, ptx):
42 | pty = (ptx[0]+self.win_size[0],
43 | ptx[1]+self.win_size[1])
44 | win = x[ptx[0]:pty[0],
45 | ptx[1]:pty[1]]
46 | assert win.shape[0] == self.win_size[0] and \
47 | win.shape[1] == self.win_size[1], \
48 | '[BUG] Incorrect Patch Size {0}'.format(win.shape)
49 | if self.debug:
50 | if self.patch_type == 'mirror':
51 | cen = cropping_center(win, self.step_size)
52 | cen = cen[..., self.counter % 3]
53 | cen.fill(150)
54 | cv2.rectangle(x, ptx, pty, (255, 0, 0), 2)
55 | plt.imshow(x)
56 | plt.show(block=False)
57 | plt.pause(1)
58 | plt.close()
59 | self.counter += 1
60 | return win
61 |
62 | def __extract_valid(self, x):
63 | """
64 | Extract patches without padding; only works when win_size > step_size
65 |
66 | Note: to deal with the remaining portions at the boundary (i.e. those
67 | which do not fit when sliding left->right, top->bottom), we flip
68 | the sliding direction and extract 1 patch starting from the right / bottom edge.
69 | There will be 1 additional patch extracted at the bottom-right corner
70 |
71 | Args:
72 | x : input image, should be of shape HWC
73 | win_size : a tuple of (h, w)
74 | step_size : a tuple of (h, w)
75 |
76 | Return:
77 | a list of sub patches, each patch is same dtype as x
78 | """
79 |
80 | im_h = x.shape[0]
81 | im_w = x.shape[1]
82 |
83 | def extract_infos(length, win_size, step_size):
84 | flag = (length - win_size) % step_size != 0
85 | last_step = math.floor((length - win_size) / step_size)
86 | last_step = (last_step + 1) * step_size
87 | return flag, last_step
88 |
89 | h_flag, h_last = extract_infos(
90 | im_h, self.win_size[0], self.step_size[0])
91 | w_flag, w_last = extract_infos(
92 | im_w, self.win_size[1], self.step_size[1])
93 |
94 | sub_patches = []
95 | #### Deal with valid block
96 | for row in range(0, h_last, self.step_size[0]):
97 | for col in range(0, w_last, self.step_size[1]):
98 | win = self.__get_patch(x, (row, col))
99 | sub_patches.append(win)
100 | #### Deal with edge case
101 | if h_flag:
102 | row = im_h - self.win_size[0]
103 | for col in range(0, w_last, self.step_size[1]):
104 | win = self.__get_patch(x, (row, col))
105 | sub_patches.append(win)
106 | if w_flag:
107 | col = im_w - self.win_size[1]
108 | for row in range(0, h_last, self.step_size[0]):
109 | win = self.__get_patch(x, (row, col))
110 | sub_patches.append(win)
111 | if h_flag and w_flag:
112 | ptx = (im_h - self.win_size[0], im_w - self.win_size[1])
113 | win = self.__get_patch(x, ptx)
114 | sub_patches.append(win)
115 | return sub_patches
116 |
117 | def __extract_mirror(self, x):
118 | """
119 | Extract patches with mirror padding at the boundary such that the
120 | central region of each patch is always within the original (non-padded)
121 | image while the central regions of all patches cover the whole original image
122 |
123 | Args:
124 | x : input image, should be of shape HWC
125 | win_size : a tuple of (h, w)
126 | step_size : a tuple of (h, w)
127 |
128 | Return:
129 | a list of sub patches, each patch is same dtype as x
130 | """
131 |
132 | diff_h = self.win_size[0] - self.step_size[0]
133 | padt = diff_h // 2
134 | padb = diff_h - padt
135 |
136 | diff_w = self.win_size[1] - self.step_size[1]
137 | padl = diff_w // 2
138 | padr = diff_w - padl
139 |
140 | pad_type = 'constant' if self.debug else 'reflect'
141 | x = np.pad(x, ((padt, padb), (padl, padr), (0, 0)), pad_type)
142 | sub_patches = self.__extract_valid(x)
143 | return sub_patches
144 |
145 | def extract(self, x, patch_type):
146 | """
147 | Extract patches
148 | """
149 | patch_type = patch_type.lower()
150 | self.patch_type = patch_type
151 | if patch_type == 'valid':
152 | return self.__extract_valid(x)
153 | elif patch_type == 'mirror':
154 | return self.__extract_mirror(x)
155 | else:
156 | assert False, 'Unknown Patch Type [%s]' % patch_type
157 | return
158 |
159 | ###########################################################################
160 |
161 |
162 | if __name__ == '__main__':
163 | # debugging
164 | xtractor = PatchExtractor((450, 450), (120, 120), debug=True)
165 | a = np.full([1200, 1200, 3], 255, np.uint8)
166 | xtractor.extract(a, 'mirror')
167 | xtractor.extract(a, 'valid')
168 |
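The sliding-window logic of the 'valid' extraction can be sketched along a single axis. The helper name `valid_starts` is hypothetical; it mirrors the `extract_infos` arithmetic above and returns the window start offsets, including the extra window anchored at the far edge when the regular steps do not fit exactly.

```python
import math


def valid_starts(length, win_size, step_size):
    """Start offsets of windows along one axis for 'valid' patch extraction."""
    # exclusive end of the regular sliding range
    last_step = (math.floor((length - win_size) / step_size) + 1) * step_size
    starts = list(range(0, last_step, step_size))
    # leftover region at the boundary: add one window flush with the far edge
    if (length - win_size) % step_size != 0:
        starts.append(length - win_size)
    return starts


if __name__ == "__main__":
    # same sizes as the debugging example: 1200 px axis, 450 px window, 120 px step
    print(valid_starts(1200, 450, 120))
```

With the sizes from the debugging example, the last regular window starts at 720 and the extra edge window starts at 750, so the two overlap heavily.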
--------------------------------------------------------------------------------
/src/misc/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils
3 | """
4 |
5 | import glob
6 | import os
7 | import shutil
8 | import cv2
9 | import numpy as np
10 |
11 |
12 | def bounding_box(img):
13 | """
14 | Get the bounding box of a binary region
15 |
16 | Args:
17 | img: input array- should contain one
18 | binary object.
19 | """
20 | rows = np.any(img, axis=1)
21 | cols = np.any(img, axis=0)
22 | rmin, rmax = np.where(rows)[0][[0, -1]]
23 | cmin, cmax = np.where(cols)[0][[0, -1]]
24 | # due to python indexing, need to add 1 to max
25 | # else accessing will be 1px in the box, not out
26 | rmax += 1
27 | cmax += 1
28 | return [rmin, rmax, cmin, cmax]
29 |
30 |
31 | def cropping_center(img, crop_shape, batch=False):
32 | """
33 | Crop an array at the centre
34 |
35 | Args:
36 | img: input array
37 | crop_shape: new spatial dimensions (h,w)
38 | """
39 |
40 | orig_shape = img.shape
41 | if not batch:
42 | h_0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
43 | w_0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
44 | img = img[h_0 : h_0 + crop_shape[0], w_0 : w_0 + crop_shape[1]]
45 | else:
46 | h_0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
47 | w_0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
48 | img = img[:, h_0 : h_0 + crop_shape[0], w_0 : w_0 + crop_shape[1]]
49 | return img
50 |
51 |
52 | def rm_n_mkdir(dir_path):
53 | """
54 | Remove, then create a new directory
55 | """
56 |
57 | if os.path.isdir(dir_path):
58 | shutil.rmtree(dir_path)
59 | os.makedirs(dir_path)
60 |
61 |
62 | def get_files(data_dir_list, data_ext):
63 | """
64 | Given a list of directories containing data with extension 'data_ext',
65 | generate a list of paths for all files within these directories
66 | """
67 |
68 | data_files = []
69 | for sub_dir in data_dir_list:
70 | files = glob.glob(sub_dir + "/*" + data_ext)
71 | data_files.extend(files)
72 |
73 | return data_files
74 |
75 |
76 | def remap_label(pred, by_size=False):
77 | """Rename all instance id so that the id is contiguous i.e [0, 1, 2, 3]
78 | not [0, 2, 4, 6]. The ordering of instances (which one comes first)
79 | is preserved unless by_size=True, then the instances will be reordered
80 | so that bigger object has smaller ID.
81 | Args:
82 | pred : the 2d array contain instances where each instances is marked
83 | by non-zero integer
84 | by_size : renaming with larger object has smaller id (on-top)
85 | """
86 | pred_id = list(np.unique(pred))
87 | pred_id.remove(0)
88 | if len(pred_id) == 0:
89 | return pred # no label
90 | if by_size:
91 | pred_size = []
92 | for inst_id in pred_id:
93 | size = (pred == inst_id).sum()
94 | pred_size.append(size)
95 | # sort the id by size in descending order
96 | pair_list = zip(pred_id, pred_size)
97 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
98 | pred_id, pred_size = zip(*pair_list)
99 |
100 | new_pred = np.zeros(pred.shape, np.int32)
101 | for idx, inst_id in enumerate(pred_id):
102 | new_pred[pred == inst_id] = idx + 1
103 | return new_pred
104 |
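The default behaviour of `remap_label` (contiguous IDs, original ordering preserved) can be sketched as below. This is a simplified variant, not the function above: the name `remap_contiguous` is hypothetical, it omits the `by_size` option, and it filters the background with a mask so that an input without any 0 pixel is also accepted.

```python
import numpy as np


def remap_contiguous(pred):
    """Relabel non-zero instance IDs to 1..N, preserving their sorted order."""
    ids = np.unique(pred)
    ids = ids[ids != 0]  # drop the background ID if present
    new_pred = np.zeros(pred.shape, np.int32)
    for new_id, old_id in enumerate(ids, start=1):
        new_pred[pred == old_id] = new_id
    return new_pred


if __name__ == "__main__":
    pred = np.array([[0, 2, 2], [4, 0, 6]])
    print(remap_contiguous(pred))
```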
--------------------------------------------------------------------------------
/src/misc/viz_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualisation utils
3 | """
4 |
5 | import math
6 | import random
7 | import colorsys
8 | import cv2
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | from .utils import bounding_box
13 |
14 |
15 | def random_colors(N, bright=True):
16 | """
17 | Generate random colors.
18 | To get visually distinct colors, generate them in HSV space then
19 | convert to RGB.
20 | """
21 | brightness = 1.0 if bright else 0.7
22 | hsv = [(i / N, 1, brightness) for i in range(N)]
23 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
24 | random.shuffle(colors)
25 | return colors
26 |
27 |
28 | def visualize_instances(mask, canvas=None):
29 | """
30 | Args:
31 | mask: 2D instance ID map of shape HW (0 is background)
32 | Return:
33 | Image with the instance overlaid
34 | """
35 |
36 | colour = [255, 255, 0] # yellow
37 |
38 | canvas = (
39 | np.full((mask.shape[0], mask.shape[1]) + (3,), 200, dtype=np.uint8)
40 | if canvas is None
41 | else np.copy(canvas)
42 | )
43 |
44 | insts_list = list(np.unique(mask))
45 | insts_list.remove(0) # remove background
46 |
47 | for idx, inst_id in enumerate(insts_list):
48 | inst_map = np.array(mask == inst_id, np.uint8)
49 | y1, y2, x1, x2 = bounding_box(inst_map)
50 | y1 = y1 - 2 if y1 - 2 >= 0 else y1
51 | x1 = x1 - 2 if x1 - 2 >= 0 else x1
52 | x2 = x2 + 2 if x2 + 2 <= mask.shape[1] - 1 else x2
53 | y2 = y2 + 2 if y2 + 2 <= mask.shape[0] - 1 else y2
54 | inst_map_crop = inst_map[y1:y2, x1:x2]
55 | inst_canvas_crop = canvas[y1:y2, x1:x2]
56 | contours = cv2.findContours(
57 | inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
58 | )
59 | cv2.drawContours(inst_canvas_crop, contours[0], -1, colour, 3)
60 | canvas[y1:y2, x1:x2] = inst_canvas_crop
61 |
62 | return canvas
63 |
--------------------------------------------------------------------------------
/src/model/class_pcam/graph.py:
--------------------------------------------------------------------------------
1 | """
2 | DSF-CNN for tumour classification
3 | """
4 |
5 | import tensorflow as tf
6 |
7 | from tensorpack import *
8 | from tensorpack.models import BNReLU, Conv2D, MaxPooling
9 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
10 |
11 | from model.utils.model_utils import *
12 | from model.utils.gconv_utils import *
13 |
14 | import sys
15 |
16 | sys.path.append("..") # adds higher directory to python modules path.
17 | try: # HACK: import beyond current level, may need to restructure
18 | from config import Config
19 | except ImportError:
20 | assert False, "Fail to import config.py"
21 |
22 |
23 | def group_concat(x, y, nr_orients):
24 | shape1 = x.get_shape().as_list()
25 | chans1 = shape1[3]
26 | c1 = int(chans1 / nr_orients)
27 | x = tf.reshape(x, [-1, shape1[1], shape1[2], nr_orients, c1])
28 |
29 | shape2 = y.get_shape().as_list()
30 | chans2 = shape2[3]
31 | c2 = int(chans2 / nr_orients)
32 | y = tf.reshape(y, [-1, shape2[1], shape2[2], nr_orients, c2])
33 |
34 | z = tf.concat([x, y], axis=-1)
35 |
36 | return tf.reshape(z, [-1, shape1[1], shape1[2], nr_orients * (c1 + c2)])
37 |
38 |
39 | def g_dense_blk(
40 | name,
41 | l,
42 | ch,
43 | ksize,
44 | count,
45 | nr_orients,
46 | filter_type,
47 | basis_filter_list,
48 | rot_matrix_list,
49 | padding="same",
50 | bn_init=True,
51 | ):
52 | with tf.variable_scope(name):
53 | for i in range(0, count):
54 | with tf.variable_scope("blk/" + str(i)):
55 | if bn_init:
56 | x = GBNReLU("preact_bna", l, nr_orients)
57 | else:
58 | x = l
59 | x = GConv2D(
60 | "conv1",
61 | x,
62 | ch[0],
63 | ksize[0],
64 | nr_orients,
65 | filter_type,
66 | basis_filter_list[0],
67 | rot_matrix_list[0],
68 | )
69 | x = GConv2D(
70 | "conv2",
71 | x,
72 | ch[1],
73 | ksize[1],
74 | nr_orients,
75 | filter_type,
76 | basis_filter_list[1],
77 | rot_matrix_list[1],
78 | activation="identity",
79 | )
80 | ##
81 | if padding == "valid":
82 | x_shape = x.get_shape().as_list()
83 | l_shape = l.get_shape().as_list()
84 | l = crop_op(
85 | l,
86 | (l_shape[1] - x_shape[1], l_shape[2] - x_shape[2]),
87 | "channels_last",
88 | )
89 |
90 | l = group_concat(l, x, nr_orients)
91 | l = GBNReLU("blk_bna", l, nr_orients)
92 | return l
93 |
94 |
95 | def net(
96 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training
97 | ):
98 | """
99 | Dense Steerable Filter CNN
100 | """
101 |
102 | dense_basis_list = [basis_filter_list[0], basis_filter_list[1]]
103 | dense_rot_list = [rot_matrix_list[0], rot_matrix_list[1]]
104 |
105 | with tf.variable_scope(name):
106 |
107 | c1 = GConv2D(
108 | "ds_conv1",
109 | i,
110 | 8,
111 | 7,
112 | nr_orients,
113 | filter_type,
114 | basis_filter_list[1],
115 | rot_matrix_list[1],
116 | input_layer=True,
117 | )
118 | c2 = GConv2D(
119 | "ds_conv2",
120 | c1,
121 | 8,
122 | 7,
123 | nr_orients,
124 | filter_type,
125 | basis_filter_list[1],
126 | rot_matrix_list[1],
127 | activation="identity",
128 | )
129 | p1 = MaxPooling("max_pool1", c2, 2)
130 | ####
131 |
132 | d1 = g_dense_blk(
133 | "dense1",
134 | p1,
135 | [32, 8],
136 | [5, 7],
137 | 2,
138 | nr_orients,
139 | filter_type,
140 | dense_basis_list,
141 | dense_rot_list,
142 | bn_init=False,
143 | )
144 | c3 = GConv2D(
145 | "ds_conv3",
146 | d1,
147 | 32,
148 | 5,
149 | nr_orients,
150 | filter_type,
151 | basis_filter_list[0],
152 | rot_matrix_list[0],
153 | activation="identity",
154 | )
155 | p2 = MaxPooling("max_pool2", c3, 2, padding="valid")
156 | ####
157 |
158 | d2 = g_dense_blk(
159 | "dense2",
160 | p2,
161 | [32, 8],
162 | [5, 7],
163 | 2,
164 | nr_orients,
165 | filter_type,
166 | dense_basis_list,
167 | dense_rot_list,
168 | bn_init=False,
169 | )
170 | c4 = GConv2D(
171 | "ds_conv4",
172 | d2,
173 | 32,
174 | 5,
175 | nr_orients,
176 | filter_type,
177 | basis_filter_list[0],
178 | rot_matrix_list[0],
179 | activation="identity",
180 | )
181 | p3 = MaxPooling("max_pool3", c4, 2, padding="valid")
182 | ####
183 |
184 | d3 = g_dense_blk(
185 | "dense3",
186 | p3,
187 | [32, 8],
188 | [5, 7],
189 | 3,
190 | nr_orients,
191 | filter_type,
192 | dense_basis_list,
193 | dense_rot_list,
194 | bn_init=False,
195 | )
196 | c5 = GConv2D(
197 | "ds_conv5",
198 | d3,
199 | 32,
200 | 5,
201 | nr_orients,
202 | filter_type,
203 | basis_filter_list[0],
204 | rot_matrix_list[0],
205 | activation="identity",
206 | )
207 | p4 = MaxPooling("max_pool4", c5, 2, padding="valid")
208 | ####
209 |
210 | d4 = g_dense_blk(
211 | "dense4",
212 | p4,
213 | [32, 8],
214 | [5, 7],
215 | 3,
216 | nr_orients,
217 | filter_type,
218 | dense_basis_list,
219 | dense_rot_list,
220 | bn_init=False,
221 | )
222 | c6 = GConv2D(
223 | "ds_conv6",
224 | d4,
225 | 32,
226 | 5,
227 | nr_orients,
228 | filter_type,
229 | basis_filter_list[0],
230 | rot_matrix_list[0],
231 | )
232 | p5 = AvgPooling("glb_avg_pool", c6, 6, padding="valid")
233 | p6 = GroupPool("orient_pool", p5, nr_orients, pool_type="max")
234 | ####
235 |
236 | c7 = Conv2D("conv3", p6, 96, 1, use_bias=True, nl=BNReLU)
237 | c7 = tf.layers.dropout(c7, rate=0.3, seed=5, training=is_training)
238 | c8 = Conv2D("conv4", c7, 96, 1, use_bias=True, nl=BNReLU)
239 | c8 = tf.layers.dropout(c8, rate=0.3, seed=5, training=is_training)
240 |
241 | return c8
242 |
243 |
244 | class Model(ModelDesc, Config):
245 | def __init__(self, freeze=False):
246 | super(Model, self).__init__()
247 | # assert tf.test.is_gpu_available()
248 | self.freeze = freeze
249 | self.data_format = "NHWC"
250 |
251 | def _get_inputs(self):
252 | return [
253 | InputDesc(tf.float32, [None] + self.train_input_shape + [3], "images"),
254 | InputDesc(
255 | tf.float32, [None] + self.train_output_shape + [None], "truemap-coded"
256 | ),
257 | ]
258 |
259 | # for node to receive manual info such as learning rate.
260 | def add_manual_variable(self, name, init_value, summary=True):
261 | var = tf.get_variable(name, initializer=init_value, trainable=False)
262 | if summary:
263 | tf.summary.scalar(name + "-summary", var)
264 | return
265 |
266 | def _get_optimizer(self):
267 | with tf.variable_scope("", reuse=True):
268 | lr = tf.get_variable("learning_rate")
269 | opt = self.optimizer(learning_rate=lr)
270 | return opt
271 |
272 |
273 | class Graph(Model):
274 | def _build_graph(self, inputs):
275 |
276 | is_training = get_current_tower_context().is_training
277 |
278 | images, truemap_coded = inputs
279 | orig_imgs = images
280 | true = truemap_coded[..., 0]
281 | true = tf.cast(true, tf.int32)
282 | true = tf.identity(true, name="truemap")
283 | one_hot = tf.one_hot(true, 2, axis=-1)
284 | true = tf.expand_dims(true, axis=-1)
285 |
286 | ####
287 | with argscope(
288 | Conv2D,
289 | activation=tf.identity,
290 | use_bias=False,
291 | W_init=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"), # He initializer
292 | ), argscope([Conv2D], data_format=self.data_format):
293 |
294 | i = images if not self.input_norm else images / 255.0
295 |
296 | ####
297 | feat = net(
298 | "net",
299 | i,
300 | self.basis_filter_list,
301 | self.rot_matrix_list,
302 | self.nr_orients,
303 | self.filter_type,
304 | is_training,
305 | )
306 |
307 | #### Prediction
308 | o_logi = Conv2D("output", feat, 2, 1, use_bias=True, nl=tf.identity)
309 | soft = tf.nn.softmax(o_logi, axis=-1)
310 |
311 | prob = tf.identity(soft, name="predmap-prob")
312 |
313 | # encoded so that inference can extract all output at once
314 | predmap_coded = tf.concat(prob, axis=-1, name="predmap-coded")
315 |
316 | ####
317 | if get_current_tower_context().is_training:
318 | # ---- LOSS ----#
319 | loss = 0
320 | for term, weight in self.loss_term.items():
321 | if term == "bce":
322 | term_loss = categorical_crossentropy(soft, one_hot)
323 | term_loss = tf.reduce_mean(term_loss, name="loss-bce")
324 | else:
325 | assert False, "Not support loss term: %s" % term
326 | add_moving_summary(term_loss)
327 | loss += term_loss * weight
328 |
329 | ### combine the loss into single cost function
330 | wd_loss = regularize_cost(".*/W", l2_regularizer(1.0e-7), name="l2_wd_loss")
331 | add_moving_summary(wd_loss)
332 | self.cost = tf.identity(loss + wd_loss, name="overall-loss")
333 | add_moving_summary(self.cost)
334 | ####
335 |
336 | add_param_summary((".*/W", ["histogram"])) # monitor W
337 |
338 | ### log the input images as a visual summary
339 | orig_imgs = tf.cast(orig_imgs, tf.uint8)
340 | tf.summary.image("input", orig_imgs, max_outputs=1)
341 |
342 | return
343 |
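The channel bookkeeping in `group_concat` is easier to follow on a tiny array. The sketch below is a NumPy analogue (the name `group_concat_np` is hypothetical) that assumes the same NHWC layout, with the channel axis holding `nr_orients` groups laid out contiguously one after another.

```python
import numpy as np


def group_concat_np(x, y, nr_orients):
    """Concatenate two NHWC arrays per orientation group rather than per tensor."""
    n, h, w, chans_x = x.shape
    chans_y = y.shape[3]
    c1, c2 = chans_x // nr_orients, chans_y // nr_orients
    # expose the orientation axis: (N, H, W, nr_orients, channels per orientation)
    x = x.reshape(n, h, w, nr_orients, c1)
    y = y.reshape(n, h, w, nr_orients, c2)
    z = np.concatenate([x, y], axis=-1)
    # fold the orientation axis back into the channel axis
    return z.reshape(n, h, w, nr_orients * (c1 + c2))


if __name__ == "__main__":
    x = np.arange(4).reshape(1, 1, 1, 4)         # two orientations, two channels each
    y = np.array([10, 20]).reshape(1, 1, 1, 2)   # two orientations, one channel each
    print(group_concat_np(x, y, nr_orients=2).ravel())
```

A plain channel concatenation would append all of `y` after all of `x`; grouping first keeps every orientation's features adjacent, which is the layout the group convolution layers appear to expect.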
--------------------------------------------------------------------------------
/src/model/seg_gland/graph.py:
--------------------------------------------------------------------------------
1 | """
2 | DSF-CNN for gland segmentation
3 | """
4 |
5 | import tensorflow as tf
6 |
7 | from tensorpack import *
8 | from tensorpack.models import BNReLU, Conv2D, MaxPooling
9 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
10 |
11 | from model.utils.model_utils import *
12 | from model.utils.gconv_utils import *
13 |
14 | import sys
15 |
16 | sys.path.append("..") # adds higher directory to python modules path.
17 | try: # HACK: import beyond current level, may need to restructure
18 | from config import Config
19 | except ImportError:
20 | assert False, "Fail to import config.py"
21 |
22 |
23 | def upsample(name, x, size):
24 | return tf.image.resize_images(x, [size, size])
25 |
26 |
27 | def group_concat(x, y, nr_orients):
28 | shape1 = x.get_shape().as_list()
29 | chans1 = shape1[3]
30 | c1 = int(chans1 / nr_orients)
31 | x = tf.reshape(x, [-1, shape1[1], shape1[2], nr_orients, c1])
32 |
33 | shape2 = y.get_shape().as_list()
34 | chans2 = shape2[3]
35 | c2 = int(chans2 / nr_orients)
36 | y = tf.reshape(y, [-1, shape2[1], shape2[2], nr_orients, c2])
37 |
38 | z = tf.concat([x, y], axis=-1)
39 |
40 | return tf.reshape(z, [-1, shape1[1], shape1[2], nr_orients * (c1 + c2)])
41 |
42 |
43 | def g_dense_blk(
44 | name,
45 | l,
46 | ch,
47 | ksize,
48 | count,
49 | nr_orients,
50 | filter_type,
51 | basis_filter_list,
52 | rot_matrix_list,
53 | padding="same",
54 | ):
55 | with tf.variable_scope(name):
56 | for i in range(0, count):
57 | with tf.variable_scope("blk/" + str(i)):
58 | x = GBNReLU("preact_bna", l, nr_orients)
59 | x = GConv2D(
60 | "conv1",
61 | x,
62 | ch[0],
63 | ksize[0],
64 | nr_orients,
65 | filter_type,
66 | basis_filter_list[0],
67 | rot_matrix_list[0],
68 | )
69 | x = GConv2D(
70 | "conv2",
71 | x,
72 | ch[1],
73 | ksize[1],
74 | nr_orients,
75 | filter_type,
76 | basis_filter_list[1],
77 | rot_matrix_list[1],
78 | activation="identity",
79 | )
80 | ##
81 | if padding == "valid":
82 | x_shape = x.get_shape().as_list()
83 | l_shape = l.get_shape().as_list()
84 | l = crop_op(l, (l_shape[2] - x_shape[2], l_shape[3] - x_shape[3]))
85 |
86 | l = group_concat(l, x, nr_orients)
87 | l = GBNReLU("blk_bna", l, nr_orients)
88 | return l
89 |
90 |
91 | def encoder(
92 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training
93 | ):
94 | """
95 | Dense Steerable Filter Encoder
96 | """
97 |
98 | dense_basis_list = [basis_filter_list[1], basis_filter_list[0]]
99 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]]
100 |
101 | with tf.variable_scope(name):
102 |
103 | c1 = GConv2D(
104 | "ds_conv1",
105 | i,
106 | 10,
107 | 7,
108 | nr_orients,
109 | filter_type,
110 | basis_filter_list[1],
111 | rot_matrix_list[1],
112 | input_layer=True,
113 | )
114 | c2 = GConv2D(
115 | "ds_conv2",
116 | c1,
117 | 10,
118 | 7,
119 | nr_orients,
120 | filter_type,
121 | basis_filter_list[1],
122 | rot_matrix_list[1],
123 | activation="identity",
124 | )
125 | p1 = MaxPooling("max_pool1", c2, 2)
126 | ####
127 |
128 | d1 = g_dense_blk(
129 | "dense1",
130 | p1,
131 | [14, 6],
132 | [7, 5],
133 | 3,
134 | nr_orients,
135 | filter_type,
136 | dense_basis_list,
137 | dense_rot_list,
138 | )
139 | c3 = GConv2D(
140 | "ds_conv3",
141 | d1,
142 | 16,
143 | 5,
144 | nr_orients,
145 | filter_type,
146 | basis_filter_list[0],
147 | rot_matrix_list[0],
148 | activation="identity",
149 | )
150 | p2 = MaxPooling("max_pool2", c3, 2, padding="valid")
151 | ####
152 |
153 | d2 = g_dense_blk(
154 | "dense2",
155 | p2,
156 | [14, 6],
157 | [7, 5],
158 | 4,
159 | nr_orients,
160 | filter_type,
161 | dense_basis_list,
162 | dense_rot_list,
163 | )
164 | c4 = GConv2D(
165 | "ds_conv4",
166 | d2,
167 | 32,
168 | 5,
169 | nr_orients,
170 | filter_type,
171 | basis_filter_list[0],
172 | rot_matrix_list[0],
173 | activation="identity",
174 | )
175 | p3 = MaxPooling("max_pool3", c4, 2, padding="valid")
176 | ####
177 |
178 | d3 = g_dense_blk(
179 | "dense3",
180 | p3,
181 | [14, 6],
182 | [7, 5],
183 | 5,
184 | nr_orients,
185 | filter_type,
186 | dense_basis_list,
187 | dense_rot_list,
188 | )
189 | c5 = GConv2D(
190 | "ds_conv5",
191 | d3,
192 | 32,
193 | 5,
194 | nr_orients,
195 | filter_type,
196 | basis_filter_list[0],
197 | rot_matrix_list[0],
198 | activation="identity",
199 | )
200 | p4 = MaxPooling("max_pool4", c5, 2, padding="valid")
201 | ####
202 |
203 | d4 = g_dense_blk(
204 | "dense4",
205 | p4,
206 | [14, 6],
207 | [7, 5],
208 | 6,
209 | nr_orients,
210 | filter_type,
211 | dense_basis_list,
212 | dense_rot_list,
213 | )
214 | c6 = GConv2D(
215 | "ds_conv6",
216 | d4,
217 | 32,
218 | 5,
219 | nr_orients,
220 | filter_type,
221 | basis_filter_list[0],
222 | rot_matrix_list[0],
223 | activation="identity",
224 | )
225 |
226 | return [c2, c3, c4, c5, c6]
227 |
228 |
229 | def decoder(
230 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training
231 | ):
232 | """
233 | Dense Steerable Filter Decoder
234 | """
235 |
236 | dense_basis_list = [basis_filter_list[1], basis_filter_list[0]]
237 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]]
238 |
239 | with tf.variable_scope(name):
240 | with tf.variable_scope("us1"):
241 | us1 = upsample("us1", i[-1], 56)
242 | us1 = g_dense_blk(
243 | "dense_us1",
244 | us1,
245 | [14, 6],
246 | [7, 5],
247 | 4,
248 | nr_orients,
249 | filter_type,
250 | dense_basis_list,
251 | dense_rot_list,
252 | )
253 | us1 = GConv2D(
254 | "us_conv1",
255 | us1,
256 | 32,
257 | 5,
258 | nr_orients,
259 | filter_type,
260 | basis_filter_list[0],
261 | rot_matrix_list[0],
262 | activation="identity",
263 | )
264 | ####
265 |
266 | with tf.variable_scope("us2"):
267 | us2 = upsample("us2", us1, 112)
268 | us2_sum = tf.add_n([us2, i[-3]])
269 | us2 = g_dense_blk(
270 | "dense_us2",
271 | us2_sum,
272 | [14, 6],
273 | [7, 5],
274 | 3,
275 | nr_orients,
276 | filter_type,
277 | dense_basis_list,
278 | dense_rot_list,
279 | )
280 | us2 = GConv2D(
281 | "us_conv2",
282 | us2,
283 | 16,
284 | 5,
285 | nr_orients,
286 | filter_type,
287 | basis_filter_list[0],
288 | rot_matrix_list[0],
289 | activation="identity",
290 | )
291 | ####
292 |
293 | with tf.variable_scope("us3"):
294 | us3 = upsample("us3", us2, 224)
295 | us3_sum = tf.add_n([us3, i[-4]])
296 | us3 = g_dense_blk(
297 | "dense_us3",
298 | us3_sum,
299 | [14, 6],
300 | [7, 5],
301 | 3,
302 | nr_orients,
303 | filter_type,
304 | dense_basis_list,
305 | dense_rot_list,
306 | )
307 | us3 = GConv2D(
308 | "us_conv3",
309 | us3,
310 | 10,
311 | 5,
312 | nr_orients,
313 | filter_type,
314 | basis_filter_list[0],
315 | rot_matrix_list[0],
316 | activation="identity",
317 | )
318 | ####
319 |
320 | with tf.variable_scope("us4"):
321 | us4 = upsample("us4", us3, 448)
322 | us4_sum = tf.add_n([us4, i[-5]])
323 | us4 = GConv2D(
324 | "us_conv4",
325 | us4_sum,
326 | 10,
327 | 7,
328 | nr_orients,
329 | filter_type,
330 | basis_filter_list[1],
331 | rot_matrix_list[1],
332 | )
333 | feat = GroupPool("us4", us4, nr_orients, pool_type="max")
334 |
335 | return feat
336 |
337 |
338 | class Model(ModelDesc, Config):
339 | def __init__(self, freeze=False):
340 | super(Model, self).__init__()
341 | # assert tf.test.is_gpu_available()
342 | self.freeze = freeze
343 | self.data_format = "NHWC"
344 |
345 | def _get_inputs(self):
346 | return [
347 | InputDesc(tf.float32, [None] + self.train_input_shape + [3], "images"),
348 | InputDesc(
349 | tf.float32, [None] + self.train_output_shape + [None], "truemap-coded"
350 | ),
351 | ]
352 |
353 | # for node to receive manual info such as learning rate.
354 | def add_manual_variable(self, name, init_value, summary=True):
355 | var = tf.get_variable(name, initializer=init_value, trainable=False)
356 | if summary:
357 | tf.summary.scalar(name + "-summary", var)
358 | return
359 |
360 | def _get_optimizer(self):
361 | with tf.variable_scope("", reuse=True):
362 | lr = tf.get_variable("learning_rate")
363 | opt = self.optimizer(learning_rate=lr)
364 | return opt
365 |
366 |
367 | class Graph(Model):
368 | def _build_graph(self, inputs):
369 |
370 | is_training = get_current_tower_context().is_training
371 |
372 | images, truemap_coded = inputs
373 | orig_imgs = images
374 |
375 | true = truemap_coded[..., :3]
376 | true = tf.cast(true, tf.int32)
377 | true = tf.identity(true, name="truemap")
378 | one_hot = tf.cast(true, tf.float32)
379 |
380 | ####
381 | with argscope(
382 | Conv2D,
383 | activation=tf.identity,
384 | use_bias=False, # K.he initializer
385 | W_init=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"),
386 | ), argscope([Conv2D], data_format=self.data_format):
387 |
388 | i = images if not self.input_norm else images / 255.0
389 |
390 | ####
391 | d = encoder(
392 | "encoder",
393 | i,
394 | self.basis_filter_list,
395 | self.rot_matrix_list,
396 | self.nr_orients,
397 | self.filter_type,
398 | is_training,
399 | )
400 |
401 | ####
402 | feat = decoder(
403 | "decoder",
404 | d,
405 | self.basis_filter_list,
406 | self.rot_matrix_list,
407 | self.nr_orients,
408 | self.filter_type,
409 | is_training,
410 | )
411 |
412 | feat1 = Conv2D("feat", feat, 96, 1, use_bias=True, nl=BNReLU)
413 | o_logi = Conv2D("output", feat, 3, 1, use_bias=True, nl=tf.identity)
414 | soft = tf.nn.softmax(o_logi, axis=-1)
415 |
416 | prob = tf.identity(soft[..., :2], name="predmap-prob")
417 |
418 | # encoded so that inference can extract all output at once
419 | predmap_coded = tf.concat(prob, axis=-1, name="predmap-coded")
420 |
421 | ####
422 | if get_current_tower_context().is_training:
423 | # ---- LOSS ----#
424 | loss = 0
425 | for term, weight in self.loss_term.items():
426 | if term == "bce":
427 | term_loss = categorical_crossentropy(soft, one_hot)
428 | term_loss = tf.reduce_mean(term_loss, name="loss-bce")
429 | elif term == "dice":
430 | # branch 1
431 | term_loss = dice_loss(soft[..., 0], one_hot[..., 0]) + dice_loss(
432 | soft[..., 1], one_hot[..., 1]
433 | )
434 | term_loss = tf.identity(term_loss, name="loss-dice")
435 | else:
436 | assert False, "Unsupported loss term: %s" % term
437 | add_moving_summary(term_loss)
438 | loss += term_loss
439 |
440 | ### combine the loss into single cost function
441 | wd_loss = regularize_cost(".*/W", l2_regularizer(1.0e-7), name="l2_wd_loss")
442 | add_moving_summary(wd_loss)
443 | self.cost = tf.identity(loss + wd_loss, name="overall-loss")
444 | add_moving_summary(self.cost)
445 | ####
446 |
447 | add_param_summary((".*/W", ["histogram"])) # monitor W
448 |
449 | ### log input and prediction visualisations
450 | orig_imgs = tf.cast(orig_imgs, tf.uint8)
451 | tf.summary.image("input", orig_imgs, max_outputs=1)
452 |
453 | pred_blb = colorize(prob[..., 0], cmap="jet")
454 | true_blb = colorize(true[..., 0], cmap="jet")
455 |
456 | pred_cnt = colorize(prob[..., 1], cmap="jet")
457 | true_cnt = colorize(true[..., 1], cmap="jet")
458 |
459 | viz = tf.concat([orig_imgs, pred_blb, pred_cnt, true_blb, true_cnt], 2)
460 |
461 | viz = tf.concat([viz[0], viz[-1]], axis=0)
462 | viz = tf.expand_dims(viz, axis=0)
463 | tf.summary.image("output", viz, max_outputs=1)
464 |
465 | return
466 |
467 |
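The `group_concat` helper used by `g_dense_blk` above concatenates features within each orientation block rather than along the flat channel axis. The following is a minimal NumPy sketch of that idea, assuming NHWC arrays whose channels are laid out as `[nr_orients, features]`; the name `group_concat_np` and the toy values are hypothetical and not part of the repository.

```python
import numpy as np


def group_concat_np(x, y, nr_orients):
    """NumPy analogue of group_concat for NHWC arrays (illustrative only)."""
    n, h, w, cx = x.shape
    cy = y.shape[3]
    # Split the flat channel axis into (orientation, feature) blocks.
    x5 = x.reshape(n, h, w, nr_orients, cx // nr_orients)
    y5 = y.reshape(n, h, w, nr_orients, cy // nr_orients)
    # Concatenate features within each orientation, then flatten again.
    z5 = np.concatenate([x5, y5], axis=-1)
    return z5.reshape(n, h, w, -1)


# Toy input: 2 orientations; x has 2 features per orientation, y has 1.
x = np.array([10, 11, 20, 21], dtype=np.float32).reshape(1, 1, 1, 4)
y = np.array([12, 22], dtype=np.float32).reshape(1, 1, 1, 2)
z = group_concat_np(x, y, nr_orients=2)
print(z.ravel())
```

A plain concatenation along the last axis would instead append both `y` channels after all `x` channels, so later reshapes to `[nr_orients, features]` would mix features from different orientations.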
--------------------------------------------------------------------------------
/src/model/seg_nuc/graph.py:
--------------------------------------------------------------------------------
1 | """
2 | DSF-CNN for nuclear segmentation
3 | """
4 |
5 | import tensorflow as tf
6 |
7 | from tensorpack import *
8 | from tensorpack.models import BNReLU, Conv2D, MaxPooling
9 | from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
10 |
11 | from model.utils.model_utils import *
12 | from model.utils.gconv_utils import *
13 |
14 | import sys
15 |
16 | sys.path.append("..") # adds higher directory to python modules path.
17 | try: # HACK: import beyond current level, may need to restructure
18 | from config import Config
19 | except ImportError:
20 | assert False, "Failed to import config.py"
21 |
22 |
23 | def upsample(name, x, size):
24 | return tf.image.resize_images(x, [size, size])
25 |
26 |
27 | def group_concat(x, y, nr_orients):
28 | shape1 = x.get_shape().as_list()
29 | chans1 = shape1[3]
30 | c1 = int(chans1 / nr_orients)
31 | x = tf.reshape(x, [-1, shape1[1], shape1[2], nr_orients, c1])
32 |
33 | shape2 = y.get_shape().as_list()
34 | chans2 = shape2[3]
35 | c2 = int(chans2 / nr_orients)
36 | y = tf.reshape(y, [-1, shape2[1], shape2[2], nr_orients, c2])
37 |
38 | z = tf.concat([x, y], axis=-1)
39 |
40 | return tf.reshape(z, [-1, shape1[1], shape1[2], nr_orients * (c1 + c2)])
41 |
42 |
43 | def g_dense_blk(
44 | name,
45 | l,
46 | ch,
47 | ksize,
48 | count,
49 | nr_orients,
50 | filter_type,
51 | basis_filter_list,
52 | rot_matrix_list,
53 | padding="same",
54 | bn_init=True,
55 | ):
56 | with tf.variable_scope(name):
57 | for i in range(0, count):
58 | with tf.variable_scope("blk/" + str(i)):
59 | if bn_init:
60 | x = GBNReLU("preact_bna", l, nr_orients)
61 | else:
62 | x = l
63 | x = GConv2D(
64 | "conv1",
65 | x,
66 | ch[0],
67 | ksize[0],
68 | nr_orients,
69 | filter_type,
70 | basis_filter_list[0],
71 | rot_matrix_list[0],
72 | )
73 | x = GConv2D(
74 | "conv2",
75 | x,
76 | ch[1],
77 | ksize[1],
78 | nr_orients,
79 | filter_type,
80 | basis_filter_list[1],
81 | rot_matrix_list[1],
82 | activation="identity",
83 | )
84 | ##
85 | if padding == "valid":
86 | x_shape = x.get_shape().as_list()
87 | l_shape = l.get_shape().as_list()
88 | l = crop_op(
89 | l,
90 | (l_shape[1] - x_shape[1], l_shape[2] - x_shape[2]),
91 | "channels_last",
92 | )
93 |
94 | l = group_concat(l, x, nr_orients)
95 | l = GBNReLU("blk_bna", l, nr_orients)
96 | return l
97 |
98 |
99 | def encoder(
100 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training
101 | ):
102 | """
103 | Dense Steerable Filter Encoder
104 | """
105 |
106 | dense_basis_list = [basis_filter_list[1], basis_filter_list[0]]
107 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]]
108 |
109 | with tf.variable_scope(name):
110 |
111 | c1 = GConv2D(
112 | "ds_conv1",
113 | i,
114 | 10,
115 | 7,
116 | nr_orients,
117 | filter_type,
118 | basis_filter_list[1],
119 | rot_matrix_list[1],
120 | input_layer=True,
121 | )
122 | c2 = GConv2D(
123 | "ds_conv2",
124 | c1,
125 | 10,
126 | 7,
127 | nr_orients,
128 | filter_type,
129 | basis_filter_list[1],
130 | rot_matrix_list[1],
131 | activation="identity",
132 | )
133 | p1 = MaxPooling("max_pool1", c2, 2)
134 | ####
135 |
136 | d1 = g_dense_blk(
137 | "dense1",
138 | p1,
139 | [14, 6],
140 | [7, 5],
141 | 3,
142 | nr_orients,
143 | filter_type,
144 | dense_basis_list,
145 | dense_rot_list,
146 | )
147 | c3 = GConv2D(
148 | "ds_conv3",
149 | d1,
150 | 16,
151 | 5,
152 | nr_orients,
153 | filter_type,
154 | basis_filter_list[0],
155 | rot_matrix_list[0],
156 | activation="identity",
157 | )
158 | p2 = MaxPooling("max_pool2", c3, 2, padding="valid")
159 | ####
160 |
161 | d2 = g_dense_blk(
162 | "dense2",
163 | p2,
164 | [14, 6],
165 | [7, 5],
166 | 4,
167 | nr_orients,
168 | filter_type,
169 | dense_basis_list,
170 | dense_rot_list,
171 | )
172 | c4 = GConv2D(
173 | "ds_conv4",
174 | d2,
175 | 32,
176 | 5,
177 | nr_orients,
178 | filter_type,
179 | basis_filter_list[0],
180 | rot_matrix_list[0],
181 | activation="identity",
182 | )
183 | p3 = MaxPooling("max_pool3", c4, 2, padding="valid")
184 | ####
185 |
186 | d3 = g_dense_blk(
187 | "dense3",
188 | p3,
189 | [14, 6],
190 | [7, 5],
191 | 5,
192 | nr_orients,
193 | filter_type,
194 | dense_basis_list,
195 | dense_rot_list,
196 | )
197 | c5 = GConv2D(
198 | "ds_conv5",
199 | d3,
200 | 32,
201 | 5,
202 | nr_orients,
203 | filter_type,
204 | basis_filter_list[0],
205 | rot_matrix_list[0],
206 | activation="identity",
207 | )
208 | p4 = MaxPooling("max_pool4", c5, 2, padding="valid")
209 | ####
210 |
211 | d4 = g_dense_blk(
212 | "dense4",
213 | p4,
214 | [14, 6],
215 | [7, 5],
216 | 6,
217 | nr_orients,
218 | filter_type,
219 | dense_basis_list,
220 | dense_rot_list,
221 | )
222 | c6 = GConv2D(
223 | "ds_conv6",
224 | d4,
225 | 32,
226 | 5,
227 | nr_orients,
228 | filter_type,
229 | basis_filter_list[0],
230 | rot_matrix_list[0],
231 | activation="identity",
232 | )
233 |
234 | return [c2, c3, c4, c5, c6]
235 |
236 |
237 | def decoder(
238 | name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training
239 | ):
240 | """
241 | Dense Steerable Filter Decoder
242 | """
243 |
244 | dense_basis_list = [basis_filter_list[1], basis_filter_list[0]]
245 | dense_rot_list = [rot_matrix_list[1], rot_matrix_list[0]]
246 |
247 | with tf.variable_scope(name):
248 | with tf.variable_scope("us1"):
249 | us1 = upsample("us1", i[-1], 32)
250 | us1_sum = tf.add_n([us1, i[-2]])
251 | us1 = g_dense_blk(
252 | "dense_us1",
253 | us1_sum,
254 | [14, 6],
255 | [7, 5],
256 | 4,
257 | nr_orients,
258 | filter_type,
259 | dense_basis_list,
260 | dense_rot_list,
261 | )
262 | us1 = GConv2D(
263 | "us_conv1",
264 | us1,
265 | 32,
266 | 5,
267 | nr_orients,
268 | filter_type,
269 | basis_filter_list[0],
270 | rot_matrix_list[0],
271 | activation="identity",
272 | )
273 | ####
274 |
275 | with tf.variable_scope("us2"):
276 | us2 = upsample("us2", us1, 64)
277 | us2_sum = tf.add_n([us2, i[-3]])
278 | us2 = g_dense_blk(
279 | "dense_us2",
280 | us2_sum,
281 | [14, 6],
282 | [7, 5],
283 | 3,
284 | nr_orients,
285 | filter_type,
286 | dense_basis_list,
287 | dense_rot_list,
288 | )
289 | us2 = GConv2D(
290 | "us_conv2",
291 | us2,
292 | 16,
293 | 5,
294 | nr_orients,
295 | filter_type,
296 | basis_filter_list[0],
297 | rot_matrix_list[0],
298 | activation="identity",
299 | )
300 | ####
301 |
302 | with tf.variable_scope("us3"):
303 | us3 = upsample("us3", us2, 128)
304 | us3_sum = tf.add_n([us3, i[-4]])
305 | us3 = g_dense_blk(
306 | "dense_us3",
307 | us3_sum,
308 | [14, 6],
309 | [7, 5],
310 | 2,
311 | nr_orients,
312 | filter_type,
313 | dense_basis_list,
314 | dense_rot_list,
315 | )
316 | us3 = GConv2D(
317 | "us_conv3",
318 | us3,
319 | 10,
320 | 5,
321 | nr_orients,
322 | filter_type,
323 | basis_filter_list[0],
324 | rot_matrix_list[0],
325 | activation="identity",
326 | )
327 | ####
328 |
329 | with tf.variable_scope("us4"):
330 | us4 = upsample("us4", us3, 256)
331 | us4_sum = tf.add_n([us4, i[-5]])
332 | us4 = GConv2D(
333 | "us_conv4",
334 | us4_sum,
335 | 10,
336 | 7,
337 | nr_orients,
338 | filter_type,
339 | basis_filter_list[1],
340 | rot_matrix_list[1],
341 | )
342 | feat = GroupPool("us4", us4, nr_orients, pool_type="max")
343 |
344 | return feat
345 |
346 |
347 | class Model(ModelDesc, Config):
348 | def __init__(self, freeze=False):
349 | super(Model, self).__init__()
350 | # assert tf.test.is_gpu_available()
351 | self.freeze = freeze
352 | self.data_format = "NHWC"
353 |
354 | def _get_inputs(self):
355 | return [
356 | InputDesc(tf.float32, [None] + self.train_input_shape + [3], "images"),
357 | InputDesc(
358 | tf.float32, [None] + self.train_output_shape + [None], "truemap-coded"
359 | ),
360 | ]
361 |
362 | # for node to receive manual info such as learning rate.
363 | def add_manual_variable(self, name, init_value, summary=True):
364 | var = tf.get_variable(name, initializer=init_value, trainable=False)
365 | if summary:
366 | tf.summary.scalar(name + "-summary", var)
367 | return
368 |
369 | def _get_optimizer(self):
370 | with tf.variable_scope("", reuse=True):
371 | lr = tf.get_variable("learning_rate")
372 | opt = self.optimizer(learning_rate=lr)
373 | return opt
374 |
375 |
376 | class Graph(Model):
377 | def _build_graph(self, inputs):
378 |
379 | is_training = get_current_tower_context().is_training
380 |
381 | images, truemap_coded = inputs
382 | orig_imgs = images
383 |
384 | pen_map = truemap_coded[..., -1]
385 |
386 | true_np = truemap_coded[..., 0]
387 | true_np = tf.cast(true_np, tf.int32)
388 | true_np = tf.identity(true_np, name="truemap-np")
389 | one_np = tf.one_hot(true_np, 2, axis=-1)
390 | true_np = tf.expand_dims(true_np, axis=-1)
391 |
392 | true_mk = truemap_coded[..., 1:4]
393 | true_mk = tf.cast(true_mk, tf.int32)
394 | true_mk = tf.identity(true_mk, name="truemap-mk")
395 | one_mk = tf.cast(true_mk, tf.float32)
396 |
397 | ####
398 | with argscope(
399 | Conv2D,
400 | activation=tf.identity,
401 | use_bias=False, # K.he initializer
402 | W_init=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"),
403 | ), argscope([Conv2D], data_format=self.data_format):
404 |
405 | i = images if not self.input_norm else images / 255.0
406 |
407 | ####
408 | d = encoder(
409 | "encoder",
410 | i,
411 | self.basis_filter_list,
412 | self.rot_matrix_list,
413 | self.nr_orients,
414 | self.filter_type,
415 | is_training,
416 | )
417 | ####
418 | feat = decoder(
419 | "decoder",
420 | d,
421 | self.basis_filter_list,
422 | self.rot_matrix_list,
423 | self.nr_orients,
424 | self.filter_type,
425 | is_training,
426 | )
427 |
428 | feat_np = Conv2D("feat_np", feat, 96, 1, use_bias=True, nl=BNReLU)
429 | o_logi_np = Conv2D(
430 | "output_np", feat_np, 2, 1, use_bias=True, nl=tf.identity
431 | )
432 | soft_np = tf.nn.softmax(o_logi_np, axis=-1)
433 | prob_np = tf.identity(soft_np[..., 1], name="predmap-prob")
434 | prob_np = tf.expand_dims(prob_np, -1)
435 |
436 | feat_mk = Conv2D("feat_mk", feat, 96, 1, use_bias=True, nl=BNReLU)
437 | o_logi_mk = Conv2D(
438 | "output_mk", feat_mk, 3, 1, use_bias=True, nl=tf.identity
439 | )
440 | soft_mk = tf.nn.softmax(o_logi_mk, axis=-1)
441 | prob_mk = tf.identity(soft_mk[..., :2], name="predmap-prob")
442 |
443 | # encoded so that inference can extract all output at once
444 | predmap_coded = tf.concat([prob_np, prob_mk], axis=-1, name="predmap-coded")
445 |
446 | ####
447 | if get_current_tower_context().is_training:
448 | # ---- LOSS ----#
449 | loss = 0
450 | for term, weight in self.loss_term.items():
451 | if term == "bce":
452 | term_loss_np = categorical_crossentropy(soft_np, one_np)
453 | term_loss_np = tf.reduce_mean(term_loss_np, name="loss-bce-np")
454 |
455 | term_loss_mk = categorical_crossentropy(soft_mk, one_mk)
456 | term_loss_mk = tf.reduce_mean(
457 | term_loss_mk * pen_map, name="loss-bce-mk"
458 | )
459 | elif term == "dice":
460 | # branch 1
461 | term_loss_np = dice_loss(
462 | soft_np[..., 0], one_np[..., 0]
463 | ) + dice_loss(soft_np[..., 1], one_np[..., 1])
464 | term_loss_np = tf.identity(term_loss_np, name="loss-dice-np")
465 |
466 | term_loss_mk = dice_loss(
467 | soft_mk[..., 0], one_mk[..., 0]
468 | ) + dice_loss(soft_mk[..., 1], one_mk[..., 1])
469 | term_loss_mk = tf.identity(term_loss_mk, name="loss-dice-mk")
470 | else:
471 | assert False, "Unsupported loss term: %s" % term
472 | add_moving_summary(term_loss_np)
473 | add_moving_summary(term_loss_mk)
474 | loss += term_loss_np + term_loss_mk
475 |
476 | ### combine the loss into single cost function
477 | wd_loss = regularize_cost(".*/W", l2_regularizer(1.0e-7), name="l2_wd_loss")
478 | add_moving_summary(wd_loss)
479 | self.cost = tf.identity(loss + wd_loss, name="overall-loss")
480 | add_moving_summary(self.cost)
481 | ####
482 |
483 | add_param_summary((".*/W", ["histogram"])) # monitor W
484 |
485 | ### log input and prediction visualisations
486 | orig_imgs = tf.cast(orig_imgs, tf.uint8)
487 | tf.summary.image("input", orig_imgs, max_outputs=1)
488 |
489 | pred_np = colorize(prob_np[..., 0], cmap="jet")
490 | true_np = colorize(true_np[..., 0], cmap="jet")
491 |
492 | pred_mk_blb = colorize(prob_mk[..., 0], cmap="jet")
493 | true_mk_blb = colorize(true_mk[..., 0], cmap="jet")
494 | pred_mk_cnt = colorize(prob_mk[..., 1], cmap="jet")
495 | true_mk_cnt = colorize(true_mk[..., 1], cmap="jet")
496 |
497 | viz = tf.concat(
498 | [
499 | orig_imgs,
500 | pred_np,
501 | pred_mk_blb,
502 | pred_mk_cnt,
503 | true_np,
504 | true_mk_blb,
505 | true_mk_cnt,
506 | ],
507 | 2,
508 | )
509 |
510 | viz = tf.concat([viz[0], viz[-1]], axis=0)
511 | viz = tf.expand_dims(viz, axis=0)
512 | tf.summary.image("output", viz, max_outputs=1)
513 |
514 | return
515 |
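The decoder above ends with `GroupPool(..., pool_type="max")`, which collapses the orientation axis. In group-equivariant networks, rotating the input by a multiple of 360/nr_orients degrees appears in the feature map as a spatial rotation plus a cyclic shift of the orientation channels. The sketch below checks only the second part: that max pooling over orientations is unchanged by such a cyclic shift. It is a NumPy illustration under the assumption of an NHWC layout with channels ordered `[nr_orients, features]`; `group_pool_np` is a hypothetical name, not a repository function.

```python
import numpy as np


def group_pool_np(x, nr_orients, pool_type="max"):
    """NumPy analogue of pooling over the orientation axis (illustrative only)."""
    n, h, w, c = x.shape
    x5 = x.reshape(n, h, w, nr_orients, c // nr_orients)
    if pool_type == "max":
        return x5.max(axis=3)
    if pool_type == "mean":
        return x5.mean(axis=3)
    raise ValueError("Pool type not recognised")


rng = np.random.default_rng(0)
nr_orients, feats = 4, 3
x = rng.normal(size=(2, 5, 5, nr_orients * feats))

# Cyclically shift the orientation blocks by one step.
x_shift = np.roll(x.reshape(2, 5, 5, nr_orients, feats), 1, axis=3).reshape(x.shape)

pooled = group_pool_np(x, nr_orients)
pooled_shift = group_pool_np(x_shift, nr_orients)
print(np.array_equal(pooled, pooled_shift))
```

Because the maximum of a set does not depend on the order of its elements, the pooled features carry no orientation index, which is why the subsequent 1x1 `Conv2D` heads can be ordinary convolutions.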
--------------------------------------------------------------------------------
/src/model/utils/gconv_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Group equivariant convolution utils
3 | """
4 |
5 | import math
6 | import numpy as np
7 |
8 | import tensorflow as tf
9 | from tensorflow.python.framework import dtypes
10 | from tensorflow.python.ops import random_ops
11 |
12 | from tensorpack import *
13 | from tensorpack.tfutils.symbolic_functions import *
14 | from tensorpack.tfutils.summary import *
15 | from tensorpack.tfutils.common import get_tf_version_tuple
16 |
17 | from matplotlib import cm
18 |
19 | from model.utils.norm_utils import *
20 | from model.utils.rotation_utils import *
21 |
22 |
23 | def GBNReLU(name, x, nr_orients):
24 | """
25 | A shorthand of Group Equivariant BatchNormalization + ReLU.
26 |
27 | Args:
28 | name: variable scope name
29 | x: input tensor
30 | nr_orients: number of filter orientations
31 |
32 | Returns:
33 | out: normalised tensor with ReLU activation
34 | """
35 |
36 | shape = x.get_shape().as_list()
37 | chans = shape[3]
38 |
39 | c = int(chans / nr_orients)
40 |
41 | x = tf.reshape(x, [-1, shape[1], shape[2], nr_orients, c])
42 | bn = BatchNorm3d(name + "_bn", x)
43 | act = tf.nn.relu(bn, name="relu")
44 | out = tf.reshape(act, [-1, shape[1], shape[2], chans])
45 | return out
46 |
47 |
48 | def GBatchNorm(name, x, nr_orients):
49 | """
50 | Group Equivariant BatchNormalization.
51 |
52 | Args:
53 | name: variable scope name
54 | x: input tensor
55 | nr_orients: number of filter orientations
56 |
57 | Returns:
58 | out: normalised tensor
59 | """
60 |
61 | shape = x.get_shape().as_list()
62 | chans = shape[3]
63 |
64 | c = int(chans / nr_orients)
65 |
66 | x = tf.reshape(x, [-1, shape[1], shape[2], nr_orients, c])
67 | bn = BatchNorm3d(name + "_bn", x)
68 | out = tf.reshape(bn, [-1, shape[1], shape[2], chans])
69 | return out
70 |
71 |
72 | def get_basis_params(k_size):
73 | """
74 | Get the filter parameters for a given kernel size
75 |
76 | Args:
77 | k_size (int): input kernel size
78 |
79 | Returns:
80 | alpha_list: list of alpha values
81 | beta_list: list of beta values
82 | bl_list: used to bandlimit high frequency filters in get_basis_filters()
83 | """
84 |
85 | if k_size == 5:
86 | alpha_list = [0, 1, 2]
87 | beta_list = [0, 1, 2]
88 | bl_list = [0, 2, 2]
89 | if k_size == 7:
90 | alpha_list = [0, 1, 2, 3]
91 | beta_list = [0, 1, 2, 3]
92 | bl_list = [0, 2, 3, 2]
93 | if k_size == 9:
94 | alpha_list = [0, 1, 2, 3, 4]
95 | beta_list = [0, 1, 2, 3, 4]
96 | bl_list = [0, 3, 4, 4, 3]
97 | if k_size == 11:
98 | alpha_list = [0, 1, 2, 3, 4]
99 | beta_list = [1, 2, 3, 4]
100 | bl_list = [0, 3, 4, 4, 3]
101 |
102 | return alpha_list, beta_list, bl_list
103 |
104 |
105 | def get_basis_filters(alpha_list, beta_list, bl_list, k_size, eps=10 ** -8):
106 | """
107 | Gets the atomic basis filters
108 |
109 | Args:
110 | alpha_list: list of alpha values for basis filters
111 | beta_list: list of beta values for the basis filters
112 | bl_list: bandlimit list to reduce aliasing of basis filters
113 | k_size (int): kernel size of basis filters
114 | eps=10**-8: epsilon used to prevent division by 0
115 |
116 | Returns:
117 | filters: complex64 tensor of bandlimited basis filters,
118 | shape [nr_filters, k_size, k_size, 1, 1, 1]
119 | freq_list: frequency (alpha) of each basis filter, in the same order
120 | """
121 |
122 | filter_list = []
123 | freq_list = []
124 | for beta in beta_list:
125 | for alpha in alpha_list:
126 | if alpha <= bl_list[beta]:
127 | his = k_size // 2 # half image size
128 | y_index, x_index = np.mgrid[-his : (his + 1), -his : (his + 1)]
129 | y_index *= -1
130 | z_index = x_index + 1j * y_index
131 |
132 | # convert z to natural coordinates and add eps to avoid division by zero
133 | z = z_index + eps
134 | r = np.abs(z)
135 |
136 | if beta == beta_list[-1]:
137 | sigma = 0.4
138 | else:
139 | sigma = 0.6
140 | rad_prof = np.exp(-((r - beta) ** 2) / (2 * (sigma ** 2)))
141 | c_image = rad_prof * (z / r) ** alpha
142 | c_image_norm = (math.sqrt(2) * c_image) / np.linalg.norm(c_image)
143 |
144 | # add basis filter to list
145 | filter_list.append(c_image)
146 | # add corresponding frequency of filter to list (info needed for phase manipulation)
147 | freq_list.append(alpha)
148 |
149 | filter_array = np.array(filter_list)
150 |
151 | filter_array = np.reshape(
152 | filter_array,
153 | [filter_array.shape[0], filter_array.shape[1], filter_array.shape[2], 1, 1, 1],
154 | )
155 | return tf.convert_to_tensor(filter_array, dtype=tf.complex64), freq_list
156 |
157 |
158 | def get_rot_info(nr_orients, alpha_list):
159 | """
160 | Generate rotation info for phase manipulation of steerable filters.
161 | Rotation is dependent on the frequency of the filter (alpha)
162 |
163 | Args:
164 | nr_orients: number of filter rotations
165 | alpha_list: list of alpha values that determine the frequency
166 |
167 | Returns:
168 | rot_info used to rotate steerable filters
169 | """
170 |
171 | # Generate rotation matrix for phase manipulation of steerable function
172 | rot_list = []
173 | for i in range(len(alpha_list)):
174 | list_tmp = []
175 | for j in range(nr_orients):
176 | # Rotation is dependent on the frequency of the basis filter
177 | angle = (2 * np.pi / nr_orients) * j
178 | list_tmp.append(np.exp(-1j * alpha_list[i] * angle))
179 | rot_list.append(list_tmp)
180 | rot_info = np.array(rot_list)
181 |
182 | # Reshape to enable matrix multiplication
183 | rot_info = np.reshape(rot_info, [rot_info.shape[0], 1, 1, 1, 1, nr_orients])
184 | rot_info = tf.convert_to_tensor(rot_info, dtype=tf.complex64)
185 | return rot_info
186 |
187 |
188 | def GroupPool(name, x, nr_orients, pool_type="max"):
189 | """
190 | Perform pooling along the orientation axis.
191 |
192 | Args:
193 | name: variable scope name
194 | x: input tensor
195 | nr_orients: number of filter orientations
196 | pool_type: choose either 'max' or 'mean'
197 |
198 | Returns:
199 | pool: pooled tensor
200 | """
201 | shape = x.get_shape().as_list()
202 | new_shape = [-1, shape[1], shape[2], nr_orients, shape[3] // nr_orients]
203 | x_reshape = tf.reshape(x, new_shape)
204 | if pool_type == "max":
205 | pool = tf.reduce_max(x_reshape, 3)
206 | elif pool_type == "mean":
207 | pool = tf.reduce_mean(x_reshape, 3)
208 | else:
209 | raise ValueError("Pool type not recognised")
210 | return pool
211 |
212 |
213 | def steerable_initializer(
214 | nr_orients, factor=2.0, mode="FAN_IN", seed=None, dtype=dtypes.float32
215 | ):
216 | """
217 | Initialise complex coefficients in accordance with Weiler et al. (https://arxiv.org/pdf/1711.07289.pdf)
218 | Note, here we use the truncated normal dist, whereas Weiler et al. uses the regular normal dist.
219 |
220 | Args:
221 | input_layer:
222 | nr_orients: number of filter orientations
223 | factor: factor used for weight init
224 | mode: 'FAN_IN' or 'FAN_OUT'
225 | seed: seed for weight init
226 | dtype: data type
227 |
228 | Returns:
229 | _initializer:
230 | """
231 |
232 | def _initializer(shape, dtype=dtype, partition_info=None):
233 |
234 | # total number of basis filters
235 | Q = shape[0] * shape[1]
236 | if mode == "FAN_IN":
237 | fan_in = shape[-2]
238 | C = fan_in
239 | # count number of input connections.
240 | elif mode == "FAN_OUT":
241 | fan_out = shape[-1]
242 | # count number of output connections.
243 | C = fan_out
244 | n = C * Q
245 | # to get stddev = math.sqrt(factor / n) need to adjust for truncated.
246 | trunc_stddev = math.sqrt(factor / n) / 0.87962566103423978
247 | return random_ops.truncated_normal(shape, 0.0, trunc_stddev, dtype, seed=seed)
248 |
249 | return _initializer
250 |
251 |
252 | def cycle_channels(filters, shape_list):
253 | """
254 | Perform cyclic permutation of the orientation channels for kernels on the group G.
255 |
256 | Args:
257 | filters: input filters
258 | shape_list: [nr_orients_out, ksize, ksize,
259 | nr_orients_in, filters_in, filters_out]
260 |
261 | Returns:
262 | tensor of filters with channels permuted
263 | """
264 |
265 | nr_orients_out = shape_list[0]
266 | rotated_filters = [None] * nr_orients_out
267 | for orientation in range(nr_orients_out):
268 | # [K, K, nr_orients_in, filters_in, filters_out]
269 | filters_temp = filters[orientation]
270 | # [K, K, filters_in, filters_out, nr_orients_in]
271 | filters_temp = tf.transpose(filters_temp, [0, 1, 3, 4, 2])
272 | # [K * K * filters_in * filters_out, nr_orients_in]
273 | filters_temp = tf.reshape(
274 | filters_temp,
275 | [
276 | shape_list[1] * shape_list[2] * shape_list[4] * shape_list[5],
277 | shape_list[3],
278 | ],
279 | )
280 | # Cycle along the orientation axis
281 | roll_matrix = tf.constant(
282 | np.roll(np.identity(shape_list[3]), orientation, axis=1), dtype=tf.float32
283 | )
284 | filters_temp = tf.matmul(filters_temp, roll_matrix)
285 | filters_temp = tf.reshape(
286 | filters_temp,
287 | [shape_list[1], shape_list[2], shape_list[4], shape_list[5], shape_list[3]],
288 | )
289 | filters_temp = tf.transpose(filters_temp, [0, 1, 4, 2, 3])
290 | rotated_filters[orientation] = filters_temp
291 |
292 | return tf.stack(rotated_filters)
293 |
294 |
295 | def gen_rotated_filters(
296 | w, filter_type, input_layer, nr_orients_out, basis_filters=None, rot_info=None
297 | ):
298 | """
299 | Generate the rotated filters either by phase manipulation or direct rotation of planar filter.
300 | Cyclic permutation of channels is performed for kernels on the group G.
301 |
302 | Args:
303 | w: coefficients used to perform a linear combination of basis filters
304 | filter_type: either 'steerable' or 'standard'
305 | input_layer (bool): whether 1st layer convolution or not
306 | nr_orients_out: number of output filter orientations
307 | basis_filters: atomic basis filters
308 | rot_info: array to determine how to rotate filters
309 |
310 | Returns:
311 | rot_filters: rotated steerable basis filters, with
312 | cyclic permutation if not the first layer
313 | """
314 |
315 | if filter_type == "steerable":
316 | # if using steerable filters, then rotate by phase manipulation
317 |
318 | rot_filters = [None] * nr_orients_out
319 | for orientation in range(nr_orients_out):
320 | rot_info_tmp = tf.expand_dims(rot_info[..., orientation], -1)
321 | filter_tmp = w * rot_info_tmp * basis_filters # phase manipulation
322 | rot_filters[orientation] = filter_tmp
323 | # [nr_orients_out, J, K, K, nr_orients_in, filters_in, filters_out] (J: number of basis filters)
324 | rot_filters = tf.stack(rot_filters)
325 |
326 | # Linear combination of basis filters
327 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out]
328 | rot_filters = tf.reduce_sum(rot_filters, axis=1)
329 | # Get real part of filters
330 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out]
331 | rot_filters = tf.math.real(rot_filters, name="filters")
332 |
333 | else:
334 | # if using regular kernels, rotate by sparse matrix multiplication
335 |
336 | # [K, K, nr_orients_in, filters_in, filters_out]
337 | filter_shape = w.get_shape().as_list()
338 |
339 | # Flatten the filter
340 | filter_flat = tf.reshape(
341 | w,
342 | [
343 | filter_shape[0] * filter_shape[1],
344 | filter_shape[2] * filter_shape[3] * filter_shape[4],
345 | ],
346 | )
347 |
348 | # Generate a set of rotated kernels via rotation matrix multiplication
349 | idx, vals = MultiRotationOperatorMatrixSparse(
350 | [filter_shape[0], filter_shape[1]],
351 | nr_orients_out,
352 | periodicity=2 * np.pi,
353 | diskMask=True,
354 | )
355 |
356 | # Sparse rotation matrix
357 | rotOp_matrix = tf.SparseTensor(
358 | idx,
359 | vals,
360 | [
361 | nr_orients_out * filter_shape[0] * filter_shape[1],
362 | filter_shape[0] * filter_shape[1],
363 | ],
364 | )
365 |
366 | # Matrix multiplication
367 | rot_filters = tf.sparse_tensor_dense_matmul(rotOp_matrix, filter_flat)
368 | # [nr_orients_out * K * K, nr_orients_in * filters_in * filters_out]
369 |
370 | # Reshape the filters to [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out]
371 | rot_filters = tf.reshape(
372 | rot_filters,
373 | [
374 | nr_orients_out,
375 | filter_shape[0],
376 | filter_shape[1],
377 | filter_shape[2],
378 | filter_shape[3],
379 | filter_shape[4],
380 | ],
381 | )
382 |
383 | # Do not cycle filter for input convolution f: Z2 -> G
384 | if input_layer is False:
385 | shape_list = rot_filters.get_shape().as_list()
386 | # cycle channels - [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out]
387 | rot_filters = cycle_channels(rot_filters, shape_list)
388 |
389 | return rot_filters
390 |
391 |
392 | def GConv2D(
393 | name,
394 | inputs,
395 | filters_out,
396 | kernel_size,
397 | nr_orients,
398 | filter_type,
399 | basis_filters=None,
400 | rot_info=None,
401 | input_layer=False,
402 | strides=[1, 1, 1, 1],
403 | padding="SAME",
404 | data_format="NHWC",
405 | activation="bnrelu",
406 | use_bias=False,
407 | bias_initializer=tf.zeros_initializer(),
408 | ):
409 | """
410 | Rotation equivariant group convolution layer
411 |
412 | Args:
413 | name: variable scope name
414 | inputs: input tensor
415 | filters_out: number of filters out (per orientation)
416 | kernel_size: size of kernel
417 | basis_filters: atomic basis filters
418 | rot_info: array to determine how to rotate filters
419 | input_layer: whether the operation is the input layer (1st conv)
420 | strides: stride of kernel for convolution
421 | padding: choose either 'SAME' or 'VALID'
422 | data_format: either 'NHWC' or 'NCHW'
423 | activation: activation function to apply
424 | use_bias: whether to use bias
425 | bias_initializer: bias initialiser method
426 |
427 | Returns:
428 | conv: group equivariant convolution of input with
429 | steerable filters and optional activation.
430 | """
431 |
432 | if filter_type == "steerable":
433 | assert (
434 | basis_filters is not None and rot_info is not None
435 | ), "Must provide basis filters and rotation matrix"
436 |
437 | in_shape = inputs.get_shape().as_list()
438 | channel_axis = 3 if data_format == "NHWC" else 1
439 |
440 | if input_layer == False:
441 | nr_orients_in = nr_orients
442 | else:
443 | nr_orients_in = 1
444 | nr_orients_out = nr_orients
445 |
446 | filters_in = int(in_shape[channel_axis] / nr_orients_in)
447 |
448 | if filter_type == "steerable":
449 | # shape for the filter coefficients
450 | nr_b_filts = basis_filters.shape[0]
451 | w_shape = [nr_b_filts, 1, 1, nr_orients_in, filters_in, filters_out]
452 |
453 | # init complex valued weights with the adapted He init (Weiler et al.)
454 | w1 = tf.get_variable(
455 | name + "_W_real", w_shape, initializer=steerable_initializer(nr_orients_out)
456 | )
457 | w2 = tf.get_variable(
458 | name + "_W_imag", w_shape, initializer=steerable_initializer(nr_orients_out)
459 | )
460 | w = tf.complex(w1, w2)
461 |
462 | # Generate filters at different orientations; also perform cyclic permutation of channels if f: G -> G
463 | # Cyclic permutation of filters happens for all rotation equivariant layers except the input layer
464 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out]
465 | filters = gen_rotated_filters(
466 | w, filter_type, input_layer, nr_orients_out, basis_filters, rot_info
467 | )
468 |
469 | else:
470 | w_shape = [kernel_size, kernel_size, nr_orients_in, filters_in, filters_out]
471 | w = tf.get_variable(
472 | name + "_W",
473 | w_shape,
474 | initializer=tf.variance_scaling_initializer(scale=2.0, mode="fan_out"),
475 | )
476 |
477 | # Generate filters at different orientations; also perform cyclic permutation of channels if f: G -> G
478 | # Cyclic permutation of filters happens for all rotation equivariant layers except the input layer
479 | # [nr_orients_out, K, K, nr_orients_in, filters_in, filters_out]
480 | filters = gen_rotated_filters(w, filter_type, input_layer, nr_orients_out)
481 |
482 | # reshape filters for 2D convolution
483 | # [K, K, nr_orients_in, filters_in, nr_orients_out, filters_out]
484 | filters = tf.transpose(filters, [1, 2, 3, 4, 0, 5])
485 | filters = tf.reshape(
486 | filters,
487 | [
488 | kernel_size,
489 | kernel_size,
490 | nr_orients_in * filters_in,
491 | nr_orients_out * filters_out,
492 | ],
493 | )
494 |
495 | # perform conv with rotated filters (reshaped so we can perform 2D convolution)
496 | kwargs = dict(data_format=data_format)
497 | conv = tf.nn.conv2d(inputs, filters, strides, padding.upper(), **kwargs)
498 | if use_bias:
499 | # Use same bias for all orientations
500 | b = tf.get_variable(
501 | name + "_bias", [filters_out], initializer=bias_initializer
502 | )
503 | b = tf.stack([b] * nr_orients_out)
504 | b = tf.reshape(b, [nr_orients_out * filters_out])
505 | conv = tf.nn.bias_add(conv, b)
506 |
507 | if activation == "bnrelu":
508 | # Rotation equivariant batch normalisation
509 | conv = GBNReLU(name, conv, nr_orients_out)
510 |
511 | if activation == "bn":
512 | # Rotation equivariant batch normalisation
513 | conv = GBatchNorm(name, conv, nr_orients_out)
514 |
515 | if activation == "relu":
516 | # ReLU only, without batch normalisation
517 | conv = tf.nn.relu(conv)
518 |
519 | return conv
520 |
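Note on the cyclic channel permutation used by `cycle_channels` above: multiplying by `np.roll(np.identity(n), orientation, axis=1)` shifts the orientation axis cyclically. The sketch below is a minimal pure-Python illustration, not the repository's implementation; the helper names `roll_matrix`, `apply_roll` and `roll_direct` and the sample values are hypothetical. It shows that right-multiplying by the rolled identity equals reading entry `(k - shift) mod n` directly.

```python
def roll_matrix(n, shift):
    """Permutation matrix equal to np.roll(np.identity(n), shift, axis=1):
    the 1 in each row moves from column `row` to column `(row + shift) % n`."""
    return [
        [1.0 if (row + shift) % n == col else 0.0 for col in range(n)]
        for row in range(n)
    ]


def apply_roll(values, shift):
    """Right-multiply a row vector by the roll matrix (plain Python matmul)."""
    n = len(values)
    mat = roll_matrix(n, shift)
    return [sum(values[i] * mat[i][k] for i in range(n)) for k in range(n)]


def roll_direct(values, shift):
    """Same permutation without a matrix: entry k comes from (k - shift) mod n."""
    n = len(values)
    return [values[(k - shift) % n] for k in range(n)]


# Hypothetical responses, one per input orientation channel.
orientation_responses = [10.0, 20.0, 30.0, 40.0]

# The matrix form and the direct form agree for every shift.
for shift in range(len(orientation_responses)):
    assert apply_roll(orientation_responses, shift) == roll_direct(
        orientation_responses, shift
    )

rolled_once = roll_direct(orientation_responses, 1)
print(rolled_once)  # [40.0, 10.0, 20.0, 30.0]
```

The matrix form maps onto a single `tf.matmul` over the flattened filter bank, which is presumably why the code above uses it instead of index arithmetic.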
--------------------------------------------------------------------------------
/src/model/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Model utils
3 | """
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from tensorflow.python.framework import dtypes
8 | from tensorflow.python.ops import random_ops
9 |
10 | from tensorpack import *
11 | from tensorpack.tfutils.symbolic_functions import *
12 | from tensorpack.tfutils.summary import *
13 | from tensorpack.tfutils.common import get_tf_version_tuple
14 |
15 | from matplotlib import cm
16 |
17 | from .norm_utils import *
18 |
19 |
20 | def resize_op(
21 | x,
22 | height_factor=None,
23 | width_factor=None,
24 | size=None,
25 | interp="bicubic",
26 | data_format="channels_last",
27 | ):
28 | """
29 | Resize by a factor if `size=None` else resize to `size`
30 | """
31 | original_shape = x.get_shape().as_list()
32 | if size is not None:
33 | if data_format == "channels_first":
34 | x = tf.transpose(x, [0, 2, 3, 1])
35 | if interp == "bicubic":
36 | x = tf.image.resize_bicubic(x, size)
37 | elif interp == "bilinear":
38 | x = tf.image.resize_bilinear(x, size)
39 | else:
40 | x = tf.image.resize_nearest_neighbor(x, size)
41 | x = tf.transpose(x, [0, 3, 1, 2])
42 | x.set_shape(
43 | (
44 | None,
45 | original_shape[1] if original_shape[1] is not None else None,
46 | size[0],
47 | size[1],
48 | )
49 | )
50 | else:
51 | if interp == "bicubic":
52 | x = tf.image.resize_bicubic(x, size)
53 | elif interp == "bilinear":
54 | x = tf.image.resize_bilinear(x, size)
55 | else:
56 | x = tf.image.resize_nearest_neighbor(x, size)
57 | x.set_shape(
58 | (
59 | None,
60 | size[0],
61 | size[1],
62 | original_shape[3] if original_shape[3] is not None else None,
63 | )
64 | )
65 | else:
66 | if data_format == "channels_first":
67 | new_shape = tf.cast(tf.shape(x)[2:], tf.float32)
68 | new_shape *= tf.constant(
69 | np.array([height_factor, width_factor]).astype("float32")
70 | )
71 | new_shape = tf.cast(new_shape, tf.int32)
72 | x = tf.transpose(x, [0, 2, 3, 1])
73 | if interp == "bicubic":
74 | x = tf.image.resize_bicubic(x, new_shape)
75 | elif interp == "bilinear":
76 | x = tf.image.resize_bilinear(x, new_shape)
77 | else:
78 | x = tf.image.resize_nearest_neighbor(x, new_shape)
79 | x = tf.transpose(x, [0, 3, 1, 2])
80 | x.set_shape(
81 | (
82 | None,
83 | original_shape[1] if original_shape[1] is not None else None,
84 | int(original_shape[2] * height_factor)
85 | if original_shape[2] is not None
86 | else None,
87 | int(original_shape[3] * width_factor)
88 | if original_shape[3] is not None
89 | else None,
90 | )
91 | )
92 | else:
93 | original_shape = x.get_shape().as_list()
94 | new_shape = tf.cast(tf.shape(x)[1:3], tf.float32)
95 | new_shape *= tf.constant(
96 | np.array([height_factor, width_factor]).astype("float32")
97 | )
98 | new_shape = tf.cast(new_shape, tf.int32)
99 | if interp == "bicubic":
100 | x = tf.image.resize_bicubic(x, new_shape)
101 | elif interp == "bilinear":
102 | x = tf.image.resize_bilinear(x, new_shape)
103 | else:
104 | x = tf.image.resize_nearest_neighbor(x, new_shape)
105 | x.set_shape(
106 | (
107 | None,
108 | int(original_shape[1] * height_factor)
109 | if original_shape[1] is not None
110 | else None,
111 | int(original_shape[2] * width_factor)
112 | if original_shape[2] is not None
113 | else None,
114 | original_shape[3] if original_shape[3] is not None else None,
115 | )
116 | )
117 | return x
118 |
119 |
120 | ####
121 | def crop_op(x, cropping, data_format="channels_last"):
122 | """
123 | Center crop image
124 |
125 | Args:
126 | cropping: total number of pixels to subtract along [height, width]
127 | """
128 |
129 | crop_t = cropping[0] // 2
130 | crop_b = cropping[0] - crop_t
131 | crop_l = cropping[1] // 2
132 | crop_r = cropping[1] - crop_l
133 | if data_format == "channels_first":
134 | x = x[:, :, crop_t:-crop_b, crop_l:-crop_r]
135 | else:
136 | x = x[:, crop_t:-crop_b, crop_l:-crop_r]
137 | return x
138 |
139 |
140 | def categorical_crossentropy(output, target):
141 | """
142 | Categorical cross-entropy; accepts probabilities, not logits
143 |
144 | Args:
145 | output:
146 | target:
147 | """
148 |
149 | # scale preds so that the class probs of each sample sum to 1
150 | output /= tf.reduce_sum(
151 | output, reduction_indices=len(output.get_shape()) - 1, keepdims=True
152 | )
153 | # manual computation of crossentropy
154 | epsilon = tf.convert_to_tensor(10e-8, output.dtype.base_dtype)
155 | output = tf.clip_by_value(output, epsilon, 1.0 - epsilon)
156 | return -tf.reduce_sum(
157 | target * tf.log(output), reduction_indices=len(output.get_shape()) - 1
158 | )
159 |
160 |
161 | def dice_loss(output, target, loss_type="sorensen", axis=None, smooth=1e-3):
162 | """Soft dice (Sørensen or Jaccard) loss, i.e. 1 - dice coefficient, for comparing the similarity
163 | of two batches of data. Usually used for binary image segmentation,
164 | i.e. labels are binary. The loss lies between 0 and 1; 0 means a total match.
165 | Parameters
166 | -----------
167 | output : Tensor
168 | A distribution with shape: [batch_size, ....], (any dimensions).
169 | target : Tensor
170 | The target distribution, format the same with `output`.
171 | loss_type : str
172 | ``jaccard`` or ``sorensen``, default is ``sorensen``.
173 | axis : tuple of int
174 | Axes to reduce over; the default ``None`` reduces all dimensions.
175 | smooth : float
176 | This small value will be added to the numerator and denominator.
177 | - If both output and target are empty, it makes sure dice is 1.
178 | - If either output or target are empty (all pixels are background),
179 | dice = ``smooth/(small_value + smooth)``, then if smooth is very small,
180 | dice close to 0 (even the image values lower than the threshold),
181 | so in this case, higher smooth can have a higher dice.
182 | Examples
183 | ---------
184 | >>> loss = dice_loss(outputs, y_)
185 | """
186 |
187 | target = tf.squeeze(tf.cast(target, tf.float32))
188 | output = tf.squeeze(tf.cast(output, tf.float32))
189 |
190 | inse = tf.reduce_sum(output * target, axis=axis)
191 | if loss_type == "jaccard":
192 | l = tf.reduce_sum(output * output, axis=axis)
193 | r = tf.reduce_sum(target * target, axis=axis)
194 | elif loss_type == "sorensen":
195 | l = tf.reduce_sum(output, axis=axis)
196 | r = tf.reduce_sum(target, axis=axis)
197 | else:
198 | raise Exception("Unknown loss_type")
199 | # already flatten
200 | dice = 1.0 - (2.0 * inse + smooth) / (l + r + smooth)
201 | ##
202 | return dice
203 |
204 |
205 | def colorize(value, vmin=None, vmax=None, cmap=None):
206 | """
207 | Args:
208 | - value: input tensor, NHWC ('channels_last')
209 | - vmin: the minimum value of the range used for normalization.
210 | (Default: value minimum)
211 | - vmax: the maximum value of the range used for normalization.
212 | (Default: value maximum)
213 | - cmap: a valid cmap named for use with matplotlib's `get_cmap`.
214 | (Default: 'gray')
215 |
216 | Example usage:
217 | ```
218 | output = tf.random_uniform(shape=[256, 256, 1])
219 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis')
220 | tf.summary.image('output', output_color)
221 | ```
222 |
223 | Returns:
224 | 3D tensor of shape [height, width, 3], uint8.
225 | """
226 |
227 | # normalize
228 | if vmin is None:
229 | vmin = tf.reduce_min(value, axis=[1, 2])
230 | vmin = tf.reshape(vmin, [-1, 1, 1])
231 | if vmax is None:
232 | vmax = tf.reduce_max(value, axis=[1, 2])
233 | vmax = tf.reshape(vmax, [-1, 1, 1])
234 | value = (value - vmin) / (vmax - vmin) # vmin..vmax
235 |
236 | # squeeze last dim if it exists
237 | # NOTE: will throw error if use get_shape()
238 | # value = tf.squeeze(value)
239 |
240 | # quantize
241 | value = tf.round(value * 255)
242 | indices = tf.cast(value, np.int32)
243 |
244 | # gather
245 | colormap = cm.get_cmap(cmap if cmap is not None else "gray")
246 | colors = colormap(np.arange(256))[:, :3]
247 | colors = tf.constant(colors, dtype=tf.float32)
248 | value = tf.gather(colors, indices)
249 | value = tf.cast(value * 255, tf.uint8)
250 | return value
251 |
252 |
253 | @layer_register(use_scope=None)
254 | def BNELU(x, name=None):
255 | """
256 | A shorthand of BatchNormalization + ELU.
257 |
258 | Args:
259 | x (tf.Tensor): the input
260 | name: deprecated, don't use.
261 | """
262 | if name is not None:
263 | log_deprecated("BNELU(name=...)", "The output tensor will be named `output`.")
264 |
265 | x = BatchNorm("bn", x)
266 | x = tf.nn.elu(x, name=name)
267 | return x
268 |
269 |
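Note on `crop_op` above: it slices with negative end indices (`crop_t:-crop_b`). In Python a slice ending at `-0` is empty, so a zero crop along either axis would appear to return an empty result rather than the unchanged input. The sketch below is a minimal pure-Python illustration on a list; the helper `center_crop_bounds` is hypothetical and not part of this repository. It computes explicit `(start, stop)` bounds that behave the same for non-zero crops and stay correct for a zero crop.

```python
def center_crop_bounds(size, cropping):
    """Return (start, stop) so that data[start:stop] removes `cropping`
    elements in total, split as in crop_op (the smaller half goes first)."""
    crop_before = cropping // 2
    crop_after = cropping - crop_before
    return crop_before, size - crop_after


row = list(range(8))

# Non-zero crop: matches the negative-index slice used in crop_op.
start, stop = center_crop_bounds(len(row), 3)
cropped = row[start:stop]
assert cropped == row[1:-2]

# Zero crop: explicit bounds keep the whole row ...
start0, stop0 = center_crop_bounds(len(row), 0)
uncropped = row[start0:stop0]

# ... whereas a slice ending at -0 is empty.
empty = row[0:-0]

print(cropped, uncropped, empty)
```

With tensors, the same idea would need the static or dynamic spatial size in place of `len(row)`; whether that is worth the extra code depends on whether zero crops can occur in the calling graphs.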
--------------------------------------------------------------------------------
/src/model/utils/norm_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File: norm_utils.py
3 | ###
4 | # https://github.com/tensorpack/tensorpack/blob/master/tensorpack/models/batch_norm.py
5 | ###
6 |
7 | import tensorflow as tf
8 | from tensorflow.contrib.framework import add_model_variable
9 | from tensorflow.python.training import moving_averages
10 | import re
11 | import six
12 | import functools
13 |
14 | from tensorpack.utils import logger
15 | from tensorpack.utils.argtools import get_data_format
16 | from tensorpack.tfutils.tower import get_current_tower_context
17 | from tensorpack.tfutils.common import get_tf_version_tuple
18 | from tensorpack.tfutils.collection import backup_collection, restore_collection
19 | from tensorpack import layer_register, VariableHolder
20 | from tensorpack.tfutils.varreplace import custom_getter_scope
21 |
22 |
23 | def rename_get_variable(mapping):
24 | """
25 | Args:
26 |
27 | mapping(dict): an old -> new mapping for variable basename. e.g. {'kernel': 'W'}
28 |
29 | Returns:
30 | A context where the variables are renamed.
31 | """
32 | def custom_getter(getter, name, *args, **kwargs):
33 | splits = name.split('/')
34 | basename = splits[-1]
35 | if basename in mapping:
36 | basename = mapping[basename]
37 | splits[-1] = basename
38 | name = '/'.join(splits)
39 | return getter(name, *args, **kwargs)
40 | return custom_getter_scope(custom_getter)
41 |
42 |
43 | def map_common_tfargs(kwargs):
44 | df = kwargs.pop('data_format', None)
45 | if df is not None:
46 | df = get_data_format(df, tfmode=True)
47 | kwargs['data_format'] = df
48 |
49 | old_nl = kwargs.pop('nl', None)
50 | if old_nl is not None:
51 | kwargs['activation'] = lambda x, name=None: old_nl(x, name=name)
52 |
53 | if 'W_init' in kwargs:
54 | kwargs['kernel_initializer'] = kwargs.pop('W_init')
55 |
56 | if 'b_init' in kwargs:
57 | kwargs['bias_initializer'] = kwargs.pop('b_init')
58 | return kwargs
59 |
60 |
61 | def convert_to_tflayer_args(args_names, name_mapping):
62 | """
63 | After applying this decorator:
64 | 1. data_format becomes tf.layers style
65 | 2. nl becomes activation
66 | 3. initializers are renamed
67 | 4. positional args are transformed to corresponding kwargs, according to args_names
68 | 5. kwargs are mapped to tf.layers names if needed, by name_mapping
69 | """
70 |
71 | def decorator(func):
72 | @functools.wraps(func)
73 | def decorated_func(inputs, *args, **kwargs):
74 | kwargs = map_common_tfargs(kwargs)
75 |
76 | posarg_dic = {}
77 | assert len(args) <= len(args_names), \
78 | "Please use kwargs instead of positional args to call this model, " \
79 | "except for the following arguments: {}".format(
80 | ', '.join(args_names))
81 | for pos_arg, name in zip(args, args_names):
82 | posarg_dic[name] = pos_arg
83 |
84 | ret = {}
85 | for name, arg in six.iteritems(kwargs):
86 | newname = name_mapping.get(name, None)
87 | if newname is not None:
88 | assert newname not in kwargs, \
89 | "Argument {} and {} conflicts!".format(name, newname)
90 | else:
91 | newname = name
92 | ret[newname] = arg
93 | # Let pos arg overwrite kw arg, for argscope to work
94 | ret.update(posarg_dic)
95 |
96 | return func(inputs, **ret)
97 |
98 | return decorated_func
99 |
100 | return decorator
101 |
102 |
103 | __all__ = ['BatchNorm3d']
104 |
105 |
106 | def get_bn_variables(n_out, use_scale, use_bias, beta_init, gamma_init):
107 | if use_bias:
108 | beta = tf.get_variable('beta', [n_out], initializer=beta_init)
109 | else:
110 | beta = tf.zeros([n_out], name='beta')
111 | if use_scale:
112 | gamma = tf.get_variable('gamma', [n_out], initializer=gamma_init)
113 | else:
114 | gamma = tf.ones([n_out], name='gamma')
115 | # x * gamma + beta
116 |
117 | moving_mean = tf.get_variable('mean/EMA', [n_out],
118 | initializer=tf.constant_initializer(), trainable=False)
119 | moving_var = tf.get_variable('variance/EMA', [n_out],
120 | initializer=tf.constant_initializer(1.0), trainable=False)
121 |
122 | if get_current_tower_context().is_main_training_tower:
123 | for v in [moving_mean, moving_var]:
124 | tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
125 | return beta, gamma, moving_mean, moving_var
126 |
127 |
128 | def internal_update_bn_ema(xn, batch_mean, batch_var,
129 | moving_mean, moving_var, decay):
130 | update_op1 = moving_averages.assign_moving_average(
131 | moving_mean, batch_mean, decay, zero_debias=False,
132 | name='mean_ema_op')
133 | update_op2 = moving_averages.assign_moving_average(
134 | moving_var, batch_var, decay, zero_debias=False,
135 | name='var_ema_op')
136 |
137 | # When sync_statistics is True, always enable internal_update.
138 | # Otherwise the update ops (only executed on main tower)
139 | # will hang when some BatchNorm layers are unused (https://github.com/tensorpack/tensorpack/issues/1078)
140 | with tf.control_dependencies([update_op1, update_op2]):
141 | return tf.identity(xn, name='output')
142 |
143 |
144 | @layer_register()
145 | @convert_to_tflayer_args(
146 | args_names=[],
147 | name_mapping={
148 | 'use_bias': 'center',
149 | 'use_scale': 'scale',
150 | 'gamma_init': 'gamma_initializer',
151 | 'decay': 'momentum',
152 | 'use_local_stat': 'training'
153 | })
154 | def BatchNorm3d(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
155 | center=True, scale=True,
156 | beta_initializer=tf.zeros_initializer(),
157 | gamma_initializer=tf.ones_initializer(),
158 | virtual_batch_size=None,
159 | data_format='channels_last',
160 | internal_update=False,
161 | sync_statistics=None):
162 | """
163 | Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
164 | in the following:
165 | 1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
166 | 2. Default value for `momentum` and `epsilon` is different.
167 | 3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
168 | 4. Support the `internal_update` option, which enables the use of BatchNorm layer inside conditionals.
169 | 5. Support the `sync_statistics` option, which is very useful in small-batch models.
170 | Args:
171 | internal_update (bool): if False, add EMA update ops to
172 | `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
173 | They are very similar in speed, but `internal_update=True` can be used
174 | when you have conditionals in your model, or when you have multiple networks to train.
175 | Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
176 | sync_statistics: either None or "nccl". By default (None), it uses statistics of the input tensor to normalize.
177 | When set to "nccl", this layer must be used under tensorpack multi-gpu trainers,
178 | and it then uses per-machine (multiple GPU) statistics to normalize.
179 | Note that this implementation averages the per-tower E[x] and E[x^2] among towers to compute
180 | global mean&variance. The result is the global mean&variance only if each tower has the same batch size.
181 | This option has no effect when not training.
182 | This option is also known as "Cross-GPU BatchNorm" as mentioned in https://arxiv.org/abs/1711.07240.
183 | Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222
184 | Variable Names:
185 | * ``beta``: the bias term. Will be zero-inited by default.
186 | * ``gamma``: the scale term. Will be one-inited by default.
187 | * ``mean/EMA``: the moving average of mean.
188 | * ``variance/EMA``: the moving average of variance.
189 | Note:
190 | Combinations of ``training`` and ``ctx.is_training``:
191 | * ``training == ctx.is_training``: standard BN, EMA are maintained during training
192 | and used during inference. This is the default.
193 | * ``training and not ctx.is_training``: still use batch statistics in inference.
194 | * ``not training and ctx.is_training``: use EMA to normalize in
195 | training. This is useful when you load a pre-trained BN and
196 | don't want to fine tune the EMA. EMA will not be updated in
197 | this case.
198 | """
199 | # parse shapes
200 | data_format = get_data_format(data_format, tfmode=False)
201 | shape = inputs.get_shape().as_list()
202 | ndims = len(shape)
203 | # in 3d conv, we have 5d dim [batch, h, w, t, c]
204 | if sync_statistics is not None:
205 | sync_statistics = sync_statistics.lower()
206 | assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics
207 |
208 | if axis is None:
209 | assert ndims == 5, 'Number of input dims must be 5 for 3d Batch Norm'
210 | axis = 1 if data_format == 'NCHW' else 4
211 |
212 | num_chan = shape[axis]
213 |
214 | # parse training/ctx
215 | ctx = get_current_tower_context()
216 | if training is None:
217 | training = ctx.is_training
218 | training = bool(training)
219 | TF_version = get_tf_version_tuple()
220 | if not training and ctx.is_training:
221 | assert TF_version >= (1, 4), \
222 | "Fine tuning a BatchNorm model with fixed statistics is only " \
223 | "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
224 | if ctx.is_main_training_tower: # only warn in first tower
225 | logger.warn(
226 | "[BatchNorm] Using moving_mean/moving_variance in training.")
227 | # Using moving_mean/moving_variance in training, which means we
228 | # loaded a pre-trained BN and only fine-tuning the affine part.
229 |
230 | if sync_statistics is None or not (training and ctx.is_training):
231 | coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
232 | with rename_get_variable(
233 | {'moving_mean': 'mean/EMA',
234 | 'moving_variance': 'variance/EMA'}):
235 | tf_args = dict(
236 | axis=axis,
237 | momentum=momentum, epsilon=epsilon,
238 | center=center, scale=scale,
239 | beta_initializer=beta_initializer,
240 | gamma_initializer=gamma_initializer,
241 | fused=True,
242 | _reuse=tf.get_variable_scope().reuse)
243 | if TF_version >= (1, 5):
244 | tf_args['virtual_batch_size'] = virtual_batch_size
245 | else:
246 | assert virtual_batch_size is None, "Feature not supported in this version of TF!"
247 | layer = tf.layers.BatchNormalization(**tf_args)
248 | xn = layer.apply(inputs, training=training,
249 | scope=tf.get_variable_scope())
250 |
251 | # maintain EMA only on one GPU is OK, even in replicated mode.
252 | # because during training, EMA isn't used
253 | if ctx.is_main_training_tower:
254 | for v in layer.non_trainable_variables:
255 | add_model_variable(v)
256 | if not ctx.is_main_training_tower or internal_update:
257 | restore_collection(coll_bk)
258 |
259 | if training and internal_update:
260 | assert layer.updates
261 | with tf.control_dependencies(layer.updates):
262 | ret = tf.identity(xn, name='output')
263 | else:
264 | ret = tf.identity(xn, name='output')
265 |
266 | vh = ret.variables = VariableHolder(
267 | moving_mean=layer.moving_mean,
268 | mean=layer.moving_mean, # for backward-compatibility
269 | moving_variance=layer.moving_variance,
270 | variance=layer.moving_variance) # for backward-compatibility
271 | if scale:
272 | vh.gamma = layer.gamma
273 | if center:
274 | vh.beta = layer.beta
275 | else:
276 | red_axis = [0] if ndims == 2 else (
277 | [0, 2, 3] if axis == 1 else [0, 1, 2])
278 | if ndims == 5:
279 | red_axis = [0, 2, 3, 4] if axis == 1 else [0, 1, 2, 3]
280 | new_shape = None # don't need to reshape unless ...
281 | if ndims == 4 and axis == 1:
282 | new_shape = [1, num_chan, 1, 1]
283 | if ndims == 5 and axis == 1:
284 | new_shape = [1, num_chan, 1, 1, 1]
285 |
286 | batch_mean = tf.reduce_mean(inputs, axis=red_axis)
287 | batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
288 |
289 | if sync_statistics == 'nccl':
290 | if six.PY3 and TF_version <= (1, 8) and ctx.is_main_training_tower:
291 | logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
292 | "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360")
293 |
294 | from tensorflow.contrib.nccl.ops import gen_nccl_ops
295 | shared_name = re.sub(
296 | 'tower[0-9]+/', '', tf.get_variable_scope().name)
297 | num_dev = ctx.total
298 | batch_mean = gen_nccl_ops.nccl_all_reduce(
299 | input=batch_mean,
300 | reduction='sum',
301 | num_devices=num_dev,
302 | shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
303 | batch_mean_square = gen_nccl_ops.nccl_all_reduce(
304 | input=batch_mean_square,
305 | reduction='sum',
306 | num_devices=num_dev,
307 | shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
308 | elif sync_statistics == 'horovod':
309 | # Require https://github.com/uber/horovod/pull/331
310 | # Proof-of-concept, not ready yet.
311 | import horovod.tensorflow as hvd
312 | batch_mean = hvd.allreduce(batch_mean, average=True)
313 | batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
314 | batch_var = batch_mean_square - tf.square(batch_mean)
315 | batch_mean_vec = batch_mean
316 | batch_var_vec = batch_var
317 |
318 | beta, gamma, moving_mean, moving_var = get_bn_variables(
319 | num_chan, scale, center, beta_initializer, gamma_initializer)
320 | if new_shape is not None:
321 | batch_mean = tf.reshape(batch_mean, new_shape)
322 | batch_var = tf.reshape(batch_var, new_shape)
323 | # Using fused_batch_norm(is_training=False) is actually slightly faster,
324 | # but hopefully this call will be JITed in the future.
325 | xn = tf.nn.batch_normalization(
326 | inputs, batch_mean, batch_var,
327 | tf.reshape(beta, new_shape),
328 | tf.reshape(gamma, new_shape), epsilon)
329 | else:
330 | xn = tf.nn.batch_normalization(
331 | inputs, batch_mean, batch_var,
332 | beta, gamma, epsilon)
333 |
334 | if ctx.is_main_training_tower:
335 | ret = internal_update_bn_ema(
336 | xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var,
337 | momentum)
338 | else:
339 | ret = tf.identity(xn, name='output')
340 |
341 | vh = ret.variables = VariableHolder(
342 | moving_mean=moving_mean,
343 | mean=moving_mean, # for backward-compatibility
344 | moving_variance=moving_var,
345 | variance=moving_var) # for backward-compatibility
346 | if scale:
347 | vh.gamma = gamma
348 | if center:
349 | vh.beta = beta
350 | return ret
351 |
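Note on the `sync_statistics` branch above: each tower computes `E[x]` and `E[x^2]`, the towers average those two moments, and the variance is recovered as `E[x^2] - E[x]^2`. As the docstring says, this matches the true global statistics only when every tower has the same batch size. The sketch below is a minimal pure-Python illustration with made-up numbers; the helpers `moments`, `synced_mean_var` and `pooled_mean_var` are hypothetical and stand in for the all-reduce.

```python
def moments(values):
    """Return (E[x], E[x^2]) for one tower's batch."""
    n = len(values)
    mean = sum(values) / n
    mean_square = sum(v * v for v in values) / n
    return mean, mean_square


def synced_mean_var(towers):
    """Average per-tower moments with equal weight, then derive the variance."""
    per_tower = [moments(t) for t in towers]
    mean = sum(m for m, _ in per_tower) / len(per_tower)
    mean_square = sum(s for _, s in per_tower) / len(per_tower)
    return mean, mean_square - mean * mean


def pooled_mean_var(towers):
    """Reference: statistics over all samples pooled together."""
    flat = [v for t in towers for v in t]
    mean, mean_square = moments(flat)
    return mean, mean_square - mean * mean


equal_towers = [[1.0, 2.0], [3.0, 4.0]]      # same batch size per tower
unequal_towers = [[1.0], [2.0, 3.0, 4.0]]    # different batch sizes

print(synced_mean_var(equal_towers), pooled_mean_var(equal_towers))
print(synced_mean_var(unequal_towers), pooled_mean_var(unequal_towers))
```

With equal batch sizes both routes give the same mean and variance; with unequal sizes the equally weighted average drifts away from the pooled statistics.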
--------------------------------------------------------------------------------
/src/model/utils/rotation_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | The below functions are adapted from www.github.com/tueimage/SE2CNN
4 |
5 | Released in June 2018
6 | @author: EJ Bekkers, Eindhoven University of Technology, The Netherlands
7 | @author: MW Lafarge, Eindhoven University of Technology, The Netherlands
8 | ________________________________________________________________________
9 |
10 | Copyright 2018 Erik J Bekkers and Maxime W Lafarge, Eindhoven University
11 | of Technology, the Netherlands
12 |
13 | Licensed under the Apache License, Version 2.0 (the "License");
14 | you may not use this file except in compliance with the License.
15 | You may obtain a copy of the License at
16 |
17 | http://www.apache.org/licenses/LICENSE-2.0
18 |
19 | Unless required by applicable law or agreed to in writing, software
20 | distributed under the License is distributed on an "AS IS" BASIS,
21 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 | See the License for the specific language governing permissions and
23 | limitations under the License.
24 | ________________________________________________________________________
25 | """
26 |
27 | import numpy as np
28 | import math as m
29 |
30 |
31 | def CoordRotationInv(ij, NiNj, theta):
32 | """ Applies the inverse rotation transformation on input coordinates (i,j).
33 | The rotation is around the center of the image with dimensions Ni, Nj
34 | (resp. # of rows and columns). Input theta is the applied rotation.
35 |
36 | INPUT:
37 | - ij, a list of length 2 containing the i and j coordinate as [i,j]
38 | - NiNj, a list of length 2 containing the dimensions of the 2D domain
39 | as [Ni, Nj]
40 | - theta, a real number specifying the angle of rotation
41 |
42 | OUTPUT:
43 | - ijOld, a list of length 2 containing the new coordinate of the
44 | inverse rotation, i.e., the old coordinate which was mapped to
45 | the new one via (forward) rotation over theta.
46 | """
47 |
48 | # Define the center of rotation
49 | centeri = m.floor(NiNj[0] / 2)
50 | centerj = m.floor(NiNj[1] / 2)
51 |
52 | # Compute the output of the inverse rotation transformation
53 | ijOld = np.zeros([2])
54 | ijOld[0] = m.cos(theta) * (ij[0] - centeri) + \
55 | m.sin(theta) * (ij[1] - centerj) + centeri
56 | ijOld[1] = -1 * m.sin(theta) * (ij[0] - centeri) + \
57 | m.cos(theta) * (ij[1] - centerj) + centerj
58 |
59 | # Return the "old" indices
60 | return ijOld
61 |
62 |
63 | def LinIntIndicesAndWeights(ij, NiNj):
64 | """ Returns, given a target index (i,j), the 4 neighbouring indices and
65 | their corresponding weights used for linear interpolation.
66 |
67 | INPUT:
68 | - ij, a list of length 2 containing the i and j coordinate
69 | as [i,j]
70 | - NiNj, a list of length 2 containing the dimensions of the 2D
71 | domain as [Ni,Nj]
72 |
73 | OUTPUT:
74 |         - indicesAndWeights, a list of index-weight pairs as [[i0,j0,w00],
75 |             [i0,j1,w01],...]
76 | """
77 |
78 |     # The index where we want to obtain the value
79 | i = ij[0]
80 | j = ij[1]
81 | # Image size
82 | Ni = NiNj[0]
83 | Nj = NiNj[1]
84 |
85 | # The neighbouring indices
86 | i1 = int(m.floor(i)) # -- to integer format
87 | i2 = i1 + 1
88 | j1 = int(m.floor(j)) # -- to integer format
89 | j2 = j1 + 1
90 |
91 | # The 1D weights
92 | ti = i - i1
93 | tj = j - j1
94 |
95 | # The 2D weights
96 | w11 = (1 - ti) * (1 - tj)
97 | w12 = (1 - ti) * tj
98 | w21 = ti * (1 - tj)
99 | w22 = ti * tj
100 |
101 | # Only add indices and weights if they fall in the range of the image with
102 | # dimensions NiNj
103 | indicesAndWeights = []
104 | if (0 <= i1 < Ni) and (0 <= j1 < Nj):
105 | indicesAndWeights.append([i1, j1, w11])
106 | if (0 <= i1 < Ni) and (0 <= j2 < Nj):
107 | indicesAndWeights.append([i1, j2, w12])
108 | if (0 <= i2 < Ni) and (0 <= j1 < Nj):
109 | indicesAndWeights.append([i2, j1, w21])
110 | if (0 <= i2 < Ni) and (0 <= j2 < Nj):
111 | indicesAndWeights.append([i2, j2, w22])
112 |
113 | return indicesAndWeights
114 |
115 |
116 | def ToLinearIndex(ij, NiNj):
117 | """ Returns the linear index of a flattened 2D image that has dimensions
118 | [Ni,Nj] before flattening.
119 |
120 | INPUT:
121 | - ij, a list of length 2 containing the i and j coordinate
122 | as [i,j]
123 | - NiNj, a list of length 2 containing the dimensions of the 2D
124 | domain as [Ni,Nj]
125 |
126 | OUTPUT:
127 |         - ijFlat = ij[0] * NiNj[1] + ij[1]
128 |     """
129 |
130 |     return ij[0] * NiNj[1] + ij[1]  # row-major: row index times number of columns
131 |
132 |
133 | def RotationOperatorMatrix(NiNj, theta, diskMask=True):
134 |     """ Returns the matrix R that rotates a square image via R.f, where f is the
135 |     flattened image (a vector of length Ni*Nj).
136 |     The resulting vector needs to be repartitioned into a [Ni,Nj] sized image later.
137 |
138 | INPUT:
139 | - NiNj, a list of length 2 containing the dimensions of the 2D
140 | domain as [Ni,Nj]
141 | - theta, a real number specifying the rotation angle
142 |
143 | INPUT (optional):
144 | - diskMask = True, by default values outside a circular mask are set
145 | to zero.
146 |
147 | OUTPUT:
148 | - rotationMatrix, a np.array of dimensions [Ni*Nj,Ni*Nj]
149 | """
150 |
151 | # Image size
152 | Ni = NiNj[0]
153 | Nj = NiNj[1]
154 | cij = m.floor(Ni / 2) # center
155 |
156 | # Fill the rotation operator matrix
157 | rotationMatrix = np.zeros([Ni * Nj, Ni * Nj])
158 |     for i in range(0, Ni):
159 |         for j in range(0, Nj):
160 | # Apply a circular mask (disk matrix) if desired
161 | if not(diskMask) or ((i - cij) * (i - cij) + (j - cij) * (j - cij) <= (cij + 0.5) * (cij + 0.5)):
162 | # The row index of the operator matrix
163 | linij = ToLinearIndex([i, j], NiNj)
164 | # The interpolation points
165 | ijOld = CoordRotationInv([i, j], NiNj, theta)
166 | # The indices used for interpolation and their weights
167 | linIntIndicesAndWeights = LinIntIndicesAndWeights(ijOld, NiNj)
168 | # Fill the weights in the rotationMatrix
169 | for indexAndWeight in linIntIndicesAndWeights:
170 | indexOld = [indexAndWeight[0], indexAndWeight[1]]
171 | linIndexOld = ToLinearIndex(indexOld, NiNj)
172 | weight = indexAndWeight[2]
173 | rotationMatrix[linij, linIndexOld] = weight
174 | return rotationMatrix
175 |
176 |
177 | def MultiRotationOperatorMatrix(NiNj, Ntheta, periodicity=2 * np.pi, diskMask=True):
178 | """ Concatenates multiple operator matrices along the first dimension for a
179 | direct multi-orientation transformation.
180 | I.e., this function returns the matrix that rotates a square image over several angles via R.f,
181 |     where f is the flattened image (a vector of length Ni*Nj).
182 |     The dimensions of R are [Ntheta*Ni*Nj, Ni*Nj], with Ntheta the number of orientations
183 | sampled from 0 to "periodicity".
184 | The resulting vector needs to be repartitioned into a [Ntheta,Ni,Nj] stack of rotated images later.
185 |
186 | INPUT:
187 | - NiNj, a list of length 2 containing the dimensions of the 2D domain as [Ni,Nj]
188 |         - Ntheta, an integer specifying the number of rotations
189 |
190 | INPUT (optional):
191 | - periodicity = 2*np.pi, by default rotations from 0 to 2 pi are considered.
192 | - diskMask = True, by default values outside a circular mask are set to zero.
193 |
194 | OUTPUT:
195 | - rotationMatrix, a np.array of dimensions [Ntheta*Ni*Nj,Ni*Nj]
196 | """
197 | matrices = [None] * Ntheta
198 | for r in range(Ntheta):
199 | matrices[r] = RotationOperatorMatrix(
200 | NiNj,
201 | periodicity * r / Ntheta,
202 | diskMask=diskMask)
203 | return np.concatenate(matrices, axis=0)
204 |
205 |
206 | def RotationOperatorMatrixSparse(NiNj, theta, diskMask=True, linIndOffset=0):
207 |     """ Returns the idx and vals, where idx is a tuple of 2D indices (also as tuples) and vals the corresponding values.
208 |     The indices and weights can be converted to a sparse tensorflow matrix via
209 |     R = tf.SparseTensor(idx,vals,[Ni*Nj,Ni*Nj])
210 |     The resulting matrix rotates a square image by R.f, where f is the flattened image (a vector of length Ni*Nj).
211 |     The resulting vector needs to be repartitioned into a [Ni,Nj] sized image later.
212 |
213 | INPUT:
214 | - NiNj, a list of length 2 containing the dimensions of the 2D
215 | domain as [Ni,Nj]
216 | - theta, a real number specifying the rotation angle
217 |
218 | INPUT (optional):
219 | - diskMask = True, by default values outside a circular mask are set
220 | to zero.
221 |
222 | OUTPUT:
223 | - idx, a tuple containing the non-zero indices (tuples of length 2)
224 | - vals, the corresponding values at these indices
225 | """
226 |
227 | # Image size
228 | Ni = NiNj[0]
229 | Nj = NiNj[1]
230 | cij = m.floor(Ni / 2) # center
231 |
232 | # Fill the rotation operator matrix
233 | # rotationMatrix = np.zeros([Ni * Nj, Ni * Nj])
234 | idx = [] # This will contain a list of index tuples
235 | vals = [] # This will contain the corresponding weights
236 |     for i in range(0, Ni):
237 |         for j in range(0, Nj):
238 | # Apply a circular mask (disk matrix) if desired
239 | if not(diskMask) or ((i - cij) * (i - cij) + (j - cij) * (j - cij) <= (cij + 0.5) * (cij + 0.5)):
240 | # The row index of the operator matrix
241 | linij = ToLinearIndex([i, j], NiNj)
242 | # The interpolation points
243 | ijOld = CoordRotationInv([i, j], NiNj, theta)
244 | # The indices used for interpolation and their weights
245 | linIntIndicesAndWeights = LinIntIndicesAndWeights(ijOld, NiNj)
246 | indicesAndWeights = linIntIndicesAndWeights
247 | # Fill the weights in the rotationMatrix
248 | for indexAndWeight in linIntIndicesAndWeights:
249 | indexOld = [indexAndWeight[0], indexAndWeight[1]]
250 | linIndexOld = ToLinearIndex(indexOld, NiNj)
251 | weight = indexAndWeight[2]
252 |                     idx.append((linij + linIndOffset, linIndexOld))
253 |                     vals.append(weight)
254 |
255 | # Return the indices and weights as tuples
256 | return tuple(idx), tuple(vals)
257 |
258 |
259 | def MultiRotationOperatorMatrixSparse(NiNj, Ntheta, periodicity=2 * np.pi, diskMask=True):
260 | """ Returns the idx and vals, where idx is a tuple of 2D indices (also as tuples) and vals the corresponding values.
261 | The indices and weights can be converted to a sparse tensorflow matrix via
262 | R = tf.SparseTensor(idx,vals,[Ntheta*Ni*Nj,Ni*Nj]).
263 | This matrix rotates a square image over several angles via R.f,
264 |     where f is the flattened image (a vector of length Ni*Nj).
265 |     The dimensions of R are [Ntheta*Ni*Nj, Ni*Nj], with Ntheta the number of orientations
266 | sampled from 0 to "periodicity".
267 | The resulting vector needs to be repartitioned into a [Ntheta,Ni,Nj] stack of rotated images later.
268 |
269 | INPUT:
270 | - NiNj, a list of length 2 containing the dimensions of the 2D
271 | domain as [Ni,Nj]
272 |         - Ntheta, an integer specifying the number of rotations
273 |
274 | INPUT (optional):
275 | - periodicity = 2*np.pi, by default rotations from 0 to 2 pi are
276 | considered.
277 | - diskMask = True, by default values outside a circular mask are set
278 | to zero.
279 |
280 | OUTPUT:
281 | - idx, a tuple containing the non-zero indices (tuples of length 2)
282 | - vals, the corresponding values at these indices
283 | """
284 | idx = ()
285 | vals = ()
286 | for r in range(Ntheta):
287 | idxr, valsr = RotationOperatorMatrixSparse(
288 | NiNj, periodicity * r / Ntheta,
289 | linIndOffset=r * NiNj[0] * NiNj[1],
290 | diskMask=diskMask)
291 | idx = idx + idxr
292 | vals = vals + valsr
293 | return idx, vals
294 |
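Usage sketch for the rotation operator described above. This is a minimal, dependency-free illustration and not the module's own code: `rotation_operator` and `apply_operator` are hypothetical helper names, the disk mask is omitted, and plain Python lists stand in for NumPy arrays. It follows the same recipe as the functions in this file: inverse-rotate each target pixel, take the four bilinear neighbours, and store their weights in the row belonging to that pixel.

```python
import math


def rotation_operator(n, theta):
    """Dense (n*n x n*n) operator that rotates a flattened n x n image by theta."""
    c = n // 2  # centre of rotation
    R = [[0.0] * (n * n) for _ in range(n * n)]
    for i in range(n):
        for j in range(n):
            # Inverse rotation: where does target pixel (i, j) come from?
            io = math.cos(theta) * (i - c) + math.sin(theta) * (j - c) + c
            jo = -math.sin(theta) * (i - c) + math.cos(theta) * (j - c) + c
            i1, j1 = math.floor(io), math.floor(jo)
            ti, tj = io - i1, jo - j1
            neighbours = (
                (i1, j1, (1 - ti) * (1 - tj)),
                (i1, j1 + 1, (1 - ti) * tj),
                (i1 + 1, j1, ti * (1 - tj)),
                (i1 + 1, j1 + 1, ti * tj),
            )
            for a, b, w in neighbours:
                # Neighbours outside the image are dropped
                if 0 <= a < n and 0 <= b < n:
                    R[i * n + j][a * n + b] = w
    return R


def apply_operator(R, image):
    """Flatten the image row-major, multiply by R, and reshape back."""
    n = len(image)
    flat = [v for row in image for v in row]
    out = [sum(w * f for w, f in zip(row, flat)) for row in R]
    return [out[k * n:(k + 1) * n] for k in range(n)]


image = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
rotated = apply_operator(rotation_operator(3, math.pi / 2), image)
# Up to floating-point rounding: [[3, 6, 9], [2, 5, 8], [1, 4, 7]]
```

For a quarter turn the interpolation weights collapse to (almost exactly) 0 and 1, so the result matches a plain 90 degree rotation of the grid; for other angles each output pixel is a bilinear blend of up to four input pixels.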
--------------------------------------------------------------------------------
/src/opt/augs.py:
--------------------------------------------------------------------------------
1 | """
2 | Augmentation pipeline
3 | """
4 |
5 | from tensorpack import imgaug
6 | import cv2
7 | from loader.custom_augs import (
8 | BinarizeLabel,
9 | GaussianBlur,
10 | MedianBlur,
11 | GenInstanceContourMap,
12 | GenInstanceMarkerMap,
13 | )
14 |
15 | # refer to https://tensorpack.readthedocs.io/modules/dataflow.imgaug.html for
16 | # information on how to modify the augmentation parameters
17 |
18 |
19 | def get_train_augmentors(self, input_shape, output_shape, view=False):
20 | print(input_shape, output_shape)
21 | if self.model_mode == "class_rotmnist":
22 | shape_augs = [
23 | imgaug.Affine(
24 | rotate_max_deg=359, interp=cv2.INTER_NEAREST, border=cv2.BORDER_CONSTANT
25 | ),
26 | ]
27 |
28 | input_augs = []
29 |
30 | else:
31 | shape_augs = [
32 | imgaug.Affine(
33 | rotate_max_deg=359,
34 | translate_frac=(0.01, 0.01),
35 | interp=cv2.INTER_NEAREST,
36 | border=cv2.BORDER_REFLECT,
37 | ),
38 | imgaug.Flip(vert=True),
39 | imgaug.Flip(horiz=True),
40 | imgaug.CenterCrop(input_shape),
41 | ]
42 |
43 | input_augs = [
44 | imgaug.RandomApplyAug(
45 | imgaug.RandomChooseAug(
46 | [GaussianBlur(), MedianBlur(), imgaug.GaussianNoise(),]
47 | ),
48 | 0.5,
49 | ),
50 | # Standard colour augmentation
51 | imgaug.RandomOrderAug(
52 | [
53 | imgaug.Hue((-8, 8), rgb=True),
54 | imgaug.Saturation(0.2, rgb=True),
55 | imgaug.Brightness(26, clip=True),
56 | imgaug.Contrast((0.75, 1.25), clip=True),
57 | ]
58 | ),
59 | imgaug.ToUint8(),
60 | ]
61 |
62 | if self.model_mode == "seg_gland":
63 | label_augs = []
64 | label_augs = [GenInstanceContourMap(mode=self.model_mode)]
65 | label_augs.append(BinarizeLabel())
66 |
67 | if not view:
68 | label_augs.append(imgaug.CenterCrop(output_shape))
69 | else:
70 | label_augs.append(imgaug.CenterCrop(input_shape))
71 |
72 | return shape_augs, input_augs, label_augs
73 | elif self.model_mode == "seg_nuc":
74 | label_augs = []
75 | label_augs = [GenInstanceMarkerMap()]
76 | label_augs.append(BinarizeLabel())
77 | if not view:
78 | label_augs.append(imgaug.CenterCrop(output_shape))
79 | else:
80 | label_augs.append(imgaug.CenterCrop(input_shape))
81 |
82 | return shape_augs, input_augs, label_augs
83 |
84 | else:
85 | return shape_augs, input_augs
86 |
87 |
88 | def get_valid_augmentors(self, input_shape, output_shape, view=False):
89 | print(input_shape, output_shape)
90 | shape_augs = [
91 | imgaug.CenterCrop(input_shape),
92 | ]
93 | input_augs = []
94 |
95 | if self.model_mode == "seg_gland":
96 | label_augs = []
97 | label_augs = [GenInstanceContourMap(mode=self.model_mode)]
98 | label_augs.append(BinarizeLabel())
99 |
100 | if not view:
101 | label_augs.append(imgaug.CenterCrop(output_shape))
102 | else:
103 | label_augs.append(imgaug.CenterCrop(input_shape))
104 |
105 | return shape_augs, input_augs, label_augs
106 | elif self.model_mode == "seg_nuc":
107 | label_augs = []
108 | label_augs = [GenInstanceMarkerMap()]
109 | label_augs.append(BinarizeLabel())
110 | if not view:
111 | label_augs.append(imgaug.CenterCrop(output_shape))
112 | else:
113 | label_augs.append(imgaug.CenterCrop(input_shape))
114 |
115 | return shape_augs, input_augs, label_augs
116 | else:
117 | return shape_augs, input_augs
118 |
--------------------------------------------------------------------------------
/src/opt/params.py:
--------------------------------------------------------------------------------
1 | """
2 | Model hyperparameters
3 | """
4 |
5 | import tensorflow as tf
6 |
7 | class_pcam = {
8 | "train_input_shape": [96, 96],
9 | "train_output_shape": [1, 1],
10 | "infer_input_shape": [96, 96],
11 | "infer_output_shape": [1, 1],
12 | "input_chans": 3,
13 | "label_names": ["Non-Tumour", "Tumour"],
14 | "filter_sizes": [5, 7, 9],
15 | "input_norm": True,
16 | "training_phase": [
17 | {
18 | "nr_epochs": 50,
19 | "manual_parameters": {
20 | # tuple(initial value, schedule)
21 | "learning_rate": (
22 | 5.0e-5,
23 | [
24 | ("15", 1.0e-5),
25 | ("25", 1.0e-5),
26 | ("35", 1.0e-5),
27 | ("40", 2.0e-5),
28 | ("45", 1.0e-5),
29 | ],
30 | ),
31 | },
32 | "pretrained_path": None, # randomly initialise weights
33 | "train_batch_size": 32,
34 | "infer_batch_size": 64,
35 | "model_flags": {"freeze": False},
36 | }
37 | ],
38 | "loss_term": {"bce": 1},
39 | "optimizer": tf.train.AdamOptimizer,
40 | "inf_auto_metric": "valid_auc",
41 | "inf_auto_comparator": ">",
42 | "inf_batch_size": 64,
43 | }
44 | seg_gland = {
45 | "train_input_shape": [448, 448],
46 | "train_output_shape": [448, 448],
47 | "infer_input_shape": [448, 448],
48 | "infer_output_shape": [112, 112],
49 | "input_chans": 3,
50 | "filter_sizes": [5, 7, 11],
51 | "input_norm": True,
52 | "training_phase": [
53 | {
54 | "nr_epochs": 70,
55 | "manual_parameters": {
56 | # tuple(initial value, schedule)
57 | "learning_rate": (1.0e-3, [("15", 1.0e-4), ("50", 5.0e-5)]),
58 | },
59 | "pretrained_path": None, # randomly initialise weights
60 | "train_batch_size": 6,
61 | "infer_batch_size": 12,
62 | "model_flags": {"freeze": False},
63 | }
64 | ],
65 | "loss_term": {"bce": 1},
66 | "optimizer": tf.train.AdamOptimizer,
67 | "inf_auto_metric": "valid_dice",
68 | "inf_auto_comparator": ">",
69 | "inf_batch_size": 12,
70 | }
71 | seg_nuc = {
72 | "train_input_shape": [256, 256],
73 | "train_output_shape": [256, 256],
74 | "infer_input_shape": [256, 256],
75 | "infer_output_shape": [112, 112],
76 | "input_chans": 3,
77 | "filter_sizes": [5, 7, 11],
78 | "input_norm": True,
79 | "training_phase": [
80 | {
81 | "nr_epochs": 70,
82 | "manual_parameters": {
83 | # tuple(initial value, schedule)
84 | "learning_rate": (1.0e-3, [("15", 1.0e-4), ("30", 5.0e-5)]),
85 | },
86 | "pretrained_path": None, # randomly initialise weights
87 | "train_batch_size": 6,
88 | "infer_batch_size": 12,
89 | "model_flags": {"freeze": False},
90 | }
91 | ],
92 | "loss_term": {"bce": 1, "dice": 1},
93 | "optimizer": tf.train.AdamOptimizer,
94 | "inf_auto_metric": "valid_dice",
95 | "inf_auto_comparator": ">",
96 | "inf_batch_size": 12,
97 | }
98 |
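Reading the `learning_rate` entries above: each one is a tuple of (initial value, schedule), where the schedule lists (epoch, value) pairs. The sketch below shows one way to interpret such an entry. It assumes a step schedule in which a value takes effect once the listed epoch has completed; `lr_after` is a hypothetical helper for illustration and is not part of this repository or of tensorpack.

```python
def lr_after(epochs_completed, learning_rate):
    """Return the learning rate in effect after `epochs_completed` epochs.

    `learning_rate` is a (initial value, schedule) tuple as used in params.py.
    Assumption: each (epoch, value) pair applies from the end of that epoch on.
    """
    initial, schedule = learning_rate
    value = initial
    # Epochs are stored as strings in params.py, so convert before sorting
    for epoch, new_value in sorted((int(e), v) for e, v in schedule):
        if epoch <= epochs_completed:
            value = new_value
    return value


seg_gland_lr = (1.0e-3, [("15", 1.0e-4), ("50", 5.0e-5)])
for done in (0, 15, 49, 50):
    print(done, lr_after(done, seg_gland_lr))
```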
--------------------------------------------------------------------------------
/src/process.py:
--------------------------------------------------------------------------------
1 | """
2 | Post-processing
3 | """
4 |
5 | import glob
6 | import os
7 | import time
8 | import cv2
9 | import numpy as np
10 | from scipy.ndimage import measurements
11 | from scipy.ndimage.morphology import binary_fill_holes
12 | from skimage.morphology import remove_small_objects, watershed
13 |
14 | from config import Config
15 |
16 | from misc.viz_utils import visualize_instances
17 | from misc.utils import remap_label
18 |
19 |
20 | def process_utils(pred_map, mode):
21 | """
22 | Performs post processing for a given image
23 |
24 | Args:
25 | pred_map: output of CNN
26 | mode: choose either 'seg_gland' or 'seg_nuc'
27 | """
28 |
29 | if mode == "seg_gland":
30 | pred = np.squeeze(pred_map)
31 |
32 | blb = pred[..., 0]
33 | blb = np.squeeze(blb)
34 | cnt = pred[..., 1]
35 | cnt = np.squeeze(cnt)
36 | cnt[cnt > 0.5] = 1
37 | cnt[cnt <= 0.5] = 0
38 |
39 | pred = blb - cnt
40 | pred[pred > 0.55] = 1
41 | pred[pred <= 0.55] = 0
42 | k_disk1 = np.array(
43 | [
44 | [0, 0, 1, 0, 0],
45 | [0, 1, 1, 1, 0],
46 | [1, 1, 1, 1, 1],
47 | [0, 1, 1, 1, 0],
48 | [0, 0, 1, 0, 0],
49 | ],
50 | np.uint8,
51 | )
52 | # ! refactor these
53 | pred = binary_fill_holes(pred)
54 | pred = pred.astype("uint16")
55 | pred = cv2.morphologyEx(pred, cv2.MORPH_OPEN, k_disk1)
56 | pred = measurements.label(pred)[0]
57 | pred = remove_small_objects(pred, min_size=1500)
58 |
59 | k_disk2 = np.array(
60 | [
61 | [0, 0, 0, 0, 1, 0, 0, 0, 0],
62 | [0, 0, 0, 1, 1, 1, 0, 0, 0],
63 | [0, 0, 1, 1, 1, 1, 1, 0, 0],
64 | [0, 1, 1, 1, 1, 1, 1, 1, 0],
65 | [1, 1, 1, 1, 1, 1, 1, 1, 1],
66 | [0, 1, 1, 1, 1, 1, 1, 1, 0],
67 | [0, 0, 1, 1, 1, 1, 1, 0, 0],
68 | [0, 0, 0, 1, 1, 1, 0, 0, 0],
69 | [0, 0, 0, 0, 1, 0, 0, 0, 0],
70 | ],
71 | np.uint8,
72 | )
73 |
74 | pred = pred.astype("uint16")
75 | proced_pred = cv2.dilate(pred, k_disk2, iterations=1)
76 | elif mode == "seg_nuc":
77 | blb_raw = pred_map[..., 0]
78 | blb_raw = np.squeeze(blb_raw)
79 | blb = blb_raw.copy()
80 | blb[blb > 0.5] = 1
81 | blb[blb <= 0.5] = 0
82 | blb = measurements.label(blb)[0]
83 | blb = remove_small_objects(blb, min_size=10)
84 | blb[blb > 0] = 1
85 |
86 | mrk_raw = pred_map[..., 1]
87 | mrk_raw = np.squeeze(mrk_raw)
88 | cnt_raw = pred_map[..., 2]
89 | cnt_raw = np.squeeze(cnt_raw)
90 | cnt = cnt_raw.copy()
91 | cnt[cnt >= 0.4] = 1
92 | cnt[cnt < 0.4] = 0
93 | mrk = mrk_raw - cnt
94 | mrk = mrk * blb
95 | mrk[mrk > 0.75] = 1
96 | mrk[mrk <= 0.75] = 0
97 |
98 | marker = mrk.copy()
99 | marker = binary_fill_holes(marker)
100 | marker = measurements.label(marker)[0]
101 | marker = remove_small_objects(marker, min_size=10)
102 | proced_pred = watershed(-mrk_raw, marker, mask=blb)
103 |
104 | return proced_pred
105 |
106 |
107 | def process():
108 | """
109 | Performs post processing for a list of images
110 |
111 | """
112 |
113 | cfg = Config()
114 |
115 | for data_dir in cfg.inf_data_list:
116 |
117 | proc_dir = cfg.inf_output_dir + "/processed/"
118 | pred_dir = cfg.inf_output_dir + "/raw/"
119 | file_list = glob.glob(pred_dir + "*.npy")
120 | file_list.sort() # ensure same order
121 |
122 | if not os.path.isdir(proc_dir):
123 | os.makedirs(proc_dir)
124 | for filename in file_list:
125 | start = time.time()
126 | filename = os.path.basename(filename)
127 | basename = filename.split(".")[0]
128 |
129 | test_set = basename.split("_")[0]
130 | test_set = test_set[-1]
131 |
132 | print(pred_dir, basename, end=" ", flush=True)
133 |
134 | ##
135 | img = cv2.imread(data_dir + basename + cfg.inf_imgs_ext)
136 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
137 |
138 | pred_map = np.load(pred_dir + "/%s.npy" % basename)
139 |
140 | # get the instance level prediction
141 | pred_inst = process_utils(pred_map, cfg.model_mode)
142 |
143 | # ! remap label is slow - check to see whether it is needed!
144 | pred_inst = remap_label(pred_inst, by_size=True)
145 |
146 | overlaid_output = visualize_instances(pred_inst, img)
147 | overlaid_output = cv2.cvtColor(overlaid_output, cv2.COLOR_BGR2RGB)
148 | cv2.imwrite("%s/%s.png" % (proc_dir, basename), overlaid_output)
149 |
150 | # save segmentation mask
151 | np.save("%s/%s" % (proc_dir, basename), pred_inst)
152 |
153 | end = time.time()
154 | diff = str(round(end - start, 2))
155 | print("FINISH. TIME: %s" % diff)
156 |
157 |
158 | if __name__ == "__main__":
159 | process()
160 |
161 |
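The first step of the "seg_gland" branch above can be illustrated on a tiny example. This sketch uses nested lists instead of NumPy arrays and covers only the thresholding, not the hole filling, opening, labelling, or dilation that follow; `gland_foreground` is a hypothetical name. The thresholds 0.5 and 0.55 are the ones used in `process_utils`.

```python
def gland_foreground(blb, cnt, cnt_thresh=0.5, fg_thresh=0.55):
    """Binarise the contour map, subtract it from the object map, threshold."""
    out = []
    for blb_row, cnt_row in zip(blb, cnt):
        row = []
        for b, c in zip(blb_row, cnt_row):
            c_bin = 1.0 if c > cnt_thresh else 0.0  # binarised contour
            row.append(1 if (b - c_bin) > fg_thresh else 0)
        out.append(row)
    return out


blb = [[0.9, 0.9, 0.4], [0.8, 0.95, 0.6]]  # object probabilities
cnt = [[0.1, 0.7, 0.0], [0.2, 0.4, 0.3]]   # contour probabilities
mask = gland_foreground(blb, cnt)
# mask == [[1, 0, 0], [1, 1, 1]]
```

Subtracting the binarised contour removes pixels on predicted gland boundaries, which helps keep touching glands apart before connected-component labelling.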
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | """train.py
2 |
3 | Main training script.
4 |
5 | Usage:
6 |   train.py [--gpu=<id>] [--view=<dset>]
7 |   train.py (-h | --help)
8 |   train.py --version
9 |
10 | Options:
11 |   -h --help       Show this string.
12 |   --version       Show version.
13 |   --gpu=<id>      Comma separated GPU list.
14 |   --view=<dset>   View dataset; use either 'train' or 'valid'.
15 | """
16 |
17 | from docopt import docopt
18 | import argparse
19 | import json
20 | import os
21 |
22 | import numpy as np
23 | import tensorflow as tf
24 | from tensorpack import Inferencer, logger
25 | from tensorpack.callbacks import (
26 | DataParallelInferenceRunner,
27 | ModelSaver,
28 | MinSaver,
29 | MaxSaver,
30 | ScheduledHyperParamSetter,
31 | )
32 | from tensorpack.tfutils import SaverRestore, get_model_loader
33 | from tensorpack.train import (
34 | SyncMultiGPUTrainerParameterServer,
35 | TrainConfig,
36 | launch_train_with_config,
37 | )
38 |
39 | import loader.loader as loader
40 | from config import Config
41 | from misc.utils import get_files
42 |
43 | from sklearn.metrics import roc_auc_score
44 |
45 |
46 | class StatCollector(Inferencer, Config):
47 | """
48 | Accumulate output of inference during training.
49 | After the inference finishes, calculate the statistics
50 | """
51 |
52 | def __init__(self, prefix="valid"):
53 | super(StatCollector, self).__init__()
54 | self.prefix = prefix
55 |
56 | def _get_fetches(self):
57 | return self.train_inf_output_tensor_names
58 |
59 | def _before_inference(self):
60 | self.true_list = []
61 | self.pred_list = []
62 |
63 | def _on_fetches(self, outputs):
64 | pred, true = outputs
65 | self.true_list.extend(true)
66 | self.pred_list.extend(pred)
67 |
68 | def _after_inference(self):
69 | # ! factor this out
70 | def _dice(true, pred, label):
71 | true = np.array(true[..., label], np.int32)
72 | pred = np.array(pred[..., label], np.int32)
73 | inter = (pred * true).sum()
74 | total = (pred + true).sum()
75 | return 2 * inter / (total + 1.0e-8)
76 |
77 | stat_dict = {}
78 | pred = np.array(self.pred_list)
79 | true = np.array(self.true_list)
80 |
81 | if self.model_mode == "seg_gland":
82 | # Get the segmentation stats
83 |
84 | pred = pred[..., :2]
85 | true = true[..., :2]
86 |
87 | # Binarize the prediction
88 | pred[pred > 0.5] = 1.0
89 |
90 | stat_dict[self.prefix + "_dice_obj"] = _dice(true, pred, 0)
91 | stat_dict[self.prefix + "_dice_cnt"] = _dice(true, pred, 1)
92 |
93 | elif self.model_mode == "seg_nuc":
94 | # Get the segmentation stats
95 |
96 | pred = pred[..., :3]
97 | true = true[..., :3]
98 |
99 | # Binarize the prediction
100 | pred[pred > 0.5] = 1.0
101 |
102 | stat_dict[self.prefix + "_dice_np"] = _dice(true, pred, 0)
103 | stat_dict[self.prefix + "_dice_mk_blb"] = _dice(true, pred, 1)
104 | stat_dict[self.prefix + "_dice_mk_cnt"] = _dice(true, pred, 2)
105 |
106 | else:
107 | # Get the classification stats
108 |
109 | # Convert vector to scalar prediction
110 | prob = np.squeeze(pred[..., 1])
111 | pred = np.argmax(pred, -1)
112 | pred = np.squeeze(pred)
113 | true = np.squeeze(true)
114 |
115 | accuracy = (pred == true).sum() / np.size(true)
116 | error = (pred != true).sum() / np.size(true)
117 |
118 | stat_dict[self.prefix + "_acc"] = accuracy * 100
119 | stat_dict[self.prefix + "_error"] = error * 100
120 |
121 | if self.model_mode == "class_pcam":
122 | auc = roc_auc_score(true, prob)
123 | stat_dict[self.prefix + "_auc"] = auc
124 |
125 | return stat_dict
126 |
127 |
128 | ###########################################
129 |
130 |
131 | class Trainer(Config):
132 | def get_datagen(self, batch_size, mode="train", view=False):
133 | if mode == "train":
134 | augmentors = self.get_train_augmentors(
135 | self.train_input_shape, self.train_output_shape, view
136 | )
137 | data_files = get_files(self.train_dir, self.data_ext)
138 | # Different data generators for segmentation and classification
139 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc":
140 | data_generator = loader.train_generator_seg
141 | else:
142 | data_generator = loader.train_generator_class
143 | nr_procs = self.nr_procs_train
144 | else:
145 | augmentors = self.get_valid_augmentors(
146 | self.train_input_shape, self.train_output_shape, view
147 | )
148 | # Different data generators for segmentation and classification
149 | data_files = get_files(self.valid_dir, self.data_ext)
150 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc":
151 | data_generator = loader.valid_generator_seg
152 | else:
153 | data_generator = loader.valid_generator_class
154 | nr_procs = self.nr_procs_valid
155 |
156 | # set nr_proc=1 for viewing to ensure clean ctrl-z
157 | nr_procs = 1 if view else nr_procs
158 | dataset = loader.DatasetSerial(data_files)
159 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc":
160 | datagen = data_generator(
161 | dataset,
162 | shape_aug=augmentors[0],
163 | input_aug=augmentors[1],
164 | label_aug=augmentors[2],
165 | batch_size=batch_size,
166 | nr_procs=nr_procs,
167 | )
168 | else:
169 | datagen = data_generator(
170 | dataset,
171 | shape_aug=augmentors[0],
172 | input_aug=augmentors[1],
173 | batch_size=batch_size,
174 | nr_procs=nr_procs,
175 | )
176 |
177 | return datagen
178 |
179 | def view_dataset(self, mode="train"):
180 | assert mode == "train" or mode == "valid", "Invalid view mode"
181 | if self.model_mode == "seg_gland" or self.model_mode == "seg_nuc":
182 | datagen = self.get_datagen(4, mode=mode, view=True)
183 | loader.visualize(datagen, 4)
184 | else:
185 | # visualise more for classification- don't need to show label
186 | datagen = self.get_datagen(8, mode=mode, view=True)
187 | loader.visualize(datagen, 8)
188 | return
189 |
190 | def run_once(self, opt, sess_init=None, save_dir=None):
191 |
192 | train_datagen = self.get_datagen(opt["train_batch_size"], mode="train")
193 | valid_datagen = self.get_datagen(opt["infer_batch_size"], mode="valid")
194 |
195 | ###### must be called before ModelSaver
196 | if save_dir is None:
197 | logger.set_logger_dir(self.save_dir)
198 | else:
199 | logger.set_logger_dir(save_dir)
200 |
201 | ######
202 | model_flags = opt["model_flags"]
203 | model = self.get_model()(**model_flags)
204 | ######
205 | callbacks = [
206 | ModelSaver(max_to_keep=1, keep_checkpoint_every_n_hours=None),
207 | ]
208 |
209 | for param_name, param_info in opt["manual_parameters"].items():
210 | model.add_manual_variable(param_name, param_info[0])
211 | callbacks.append(ScheduledHyperParamSetter(param_name, param_info[1]))
212 | # multi-GPU inference (with mandatory queue prefetch)
213 | infs = [StatCollector()]
214 | callbacks.append(
215 | DataParallelInferenceRunner(valid_datagen, infs, list(range(nr_gpus)))
216 | )
217 | if self.model_mode == "seg_gland":
218 | callbacks.append(MaxSaver("valid_dice_obj"))
219 | elif self.model_mode == "seg_nuc":
220 | callbacks.append(MaxSaver("valid_dice_np"))
221 | else:
222 | callbacks.append(MaxSaver("valid_auc"))
223 |
224 | steps_per_epoch = train_datagen.size() // nr_gpus
225 |
226 | config = TrainConfig(
227 | model=model,
228 | callbacks=callbacks,
229 | dataflow=train_datagen,
230 | steps_per_epoch=steps_per_epoch,
231 | max_epoch=opt["nr_epochs"],
232 | )
233 | config.session_init = sess_init
234 |
235 | launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_gpus))
236 | tf.reset_default_graph() # remove the entire graph in case of multiple runs
237 | return
238 |
239 | def run(self):
240 | def get_last_chkpt_path(prev_phase_dir):
241 | stat_file_path = prev_phase_dir + "/stats.json"
242 | with open(stat_file_path) as stat_file:
243 | info = json.load(stat_file)
244 | chkpt_list = [epoch_stat["global_step"] for epoch_stat in info]
245 |             last_chkpts_path = "%s/model-%d.index" % (prev_phase_dir, max(chkpt_list))
246 | return last_chkpts_path
247 |
248 | phase_opts = self.training_phase
249 |
250 | if len(phase_opts) > 1:
251 | for idx, opt in enumerate(phase_opts):
252 |
253 | log_dir = "%s/%02d" % (self.save_dir, idx)
254 |                 if opt["pretrained_path"] == -1:
255 |                     pretrained_path = get_last_chkpt_path(prev_log_dir)
256 |                     init_weights = SaverRestore(pretrained_path, ignore=["learning_rate"])
257 |                 elif opt["pretrained_path"] is not None:
258 |                     init_weights = get_model_loader(opt["pretrained_path"])
259 |                 else:
260 |                     init_weights = None  # randomly initialise weights
261 |                 self.run_once(opt, sess_init=init_weights, save_dir=log_dir + "/")
262 | prev_log_dir = log_dir
263 | else:
264 |
265 | opt = phase_opts[0]
266 | if "pretrained_path" in opt:
267 |                 if opt["pretrained_path"] is None:
268 | init_weights = None
269 | elif opt["pretrained_path"] == -1:
270 | log_dir_prev = "%s" % self.save_dir
271 | pretrained_path = get_last_chkpt_path(log_dir_prev)
272 | init_weights = SaverRestore(
273 | pretrained_path, ignore=["learning_rate"]
274 | )
275 | else:
276 | init_weights = get_model_loader(opt["pretrained_path"])
277 | self.run_once(opt, sess_init=init_weights, save_dir=self.save_dir)
278 |
279 | return
280 |
281 |
282 | ###########################################################################
283 |
284 |
285 | if __name__ == "__main__":
286 |
287 | args = docopt(__doc__)
288 | print(args)
289 |
290 | trainer = Trainer()
291 |
292 | if args["--view"] and args["--gpu"]:
293 | raise Exception("Supply only one of --view and --gpu.")
294 |
295 | if args["--view"]:
296 | if args["--view"] != "train" and args["--view"] != "valid":
297 | raise Exception('Use "train" or "valid" for --view.')
298 | trainer.view_dataset(args["--view"])
299 | else:
300 | os.environ["CUDA_VISIBLE_DEVICES"] = args["--gpu"]
301 | nr_gpus = len(args["--gpu"].split(","))
302 | trainer.run()
303 |
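The Dice statistic computed in `StatCollector._after_inference` above can be checked by hand on a small example. The sketch below is a plain-Python restatement of the same formula, 2 * |pred AND true| / (|pred| + |true| + eps), operating on flat lists of 0/1 values rather than NumPy arrays; the standalone `dice` function is for illustration only.

```python
def dice(true, pred, eps=1.0e-8):
    """Dice coefficient for two equally long sequences of 0/1 labels."""
    inter = sum(t * p for t, p in zip(true, pred))
    total = sum(true) + sum(pred)
    # eps avoids division by zero when both masks are empty
    return 2 * inter / (total + eps)


true = [1, 1, 0, 0]
pred = [1, 0, 1, 0]
score = dice(true, pred)  # one overlapping pixel out of 2 + 2: about 0.5
```

Because of the eps term, two empty masks give a score of 0 rather than an undefined value.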
--------------------------------------------------------------------------------
/src/viz_filters.py:
--------------------------------------------------------------------------------
1 | """viz_filters.py
2 |
3 | Visualise basis filters of the form (in polar coordinates):
4 |
5 | R_alpha(r)e^{i*alpha*phi}
6 |
7 | Here, R_alpha is a Gaussian centred at beta.
8 |
9 | Usage:
10 |   viz_filters.py [--ksize=<n>]
11 |   viz_filters.py (-h | --help)
12 |   viz_filters.py --version
13 |
14 | Options:
15 |   -h --help      Show this string.
16 |   --version      Show version.
17 |   --ksize=<n>    Kernel size to display. [default: 7]
18 | """
19 |
20 |
21 | from docopt import docopt
22 | import numpy as np
23 | import matplotlib.pyplot as plt
24 | import matplotlib.cm as cm
25 | import math
26 |
27 |
28 | def get_filter_info(k_size):
29 | """
30 | Get the filter parameters for a given kernel size
31 |
32 | Args:
33 | k_size (int): input kernel size
34 |
35 | Returns:
36 | alpha_list: list of alpha values
37 | beta_list: list of beta values
38 | bl_list: used to bandlimit high frequency filters in get_basis_filters()
39 | """
40 |
41 | if k_size == 5:
42 | alpha_list = [0, 1, 2]
43 | beta_list = [0, 1, 2]
44 | bl_list = [0, 2, 2]
45 | elif k_size == 7:
46 | alpha_list = [0, 1, 2, 3]
47 | beta_list = [0, 1, 2, 3]
48 | bl_list = [0, 2, 3, 2]
49 | elif k_size == 9:
50 | alpha_list = [0, 1, 2, 3, 4]
51 | beta_list = [0, 1, 2, 3, 4]
52 | bl_list = [0, 3, 4, 4, 3]
53 | elif k_size == 11:
54 | alpha_list = [0, 1, 2, 3, 4]
55 | beta_list = [1, 2, 3, 4]
56 | bl_list = [0, 3, 4, 4, 3]
57 |
58 | return alpha_list, beta_list, bl_list
59 |
60 |
61 | def get_basis_filters(alpha_list, beta_list, bl_list, k_size, eps=1e-8):
62 |     """
63 |     Get the atomic basis filters.
64 |
65 |     Args:
66 |         alpha_list: list of alpha values (angular frequencies) for the basis filters
67 |         beta_list: list of beta values (ring radii) for the basis filters
68 |         bl_list: bandlimit list to reduce aliasing of basis filters
69 |         k_size (int): kernel size of basis filters
70 |         eps (float): small offset used to prevent division by 0
71 |
72 |     Returns:
73 |         filter_list_bl: list of filters, with bandlimiting (bl) to reduce aliasing
74 |         alpha_list_bl: corresponding list of alpha used in bandlimited filters
75 |         beta_list_bl: corresponding list of beta used in bandlimited filters
76 |     """
77 |
78 |     filter_list_bl = []
79 |     alpha_list_bl = []
80 |     beta_list_bl = []
81 |     for alpha in alpha_list:
82 |         for beta in beta_list:
83 |             if np.abs(alpha) <= bl_list[beta]:
84 |                 his = k_size // 2  # half image size
85 |                 y_index, x_index = np.mgrid[-his : (his + 1), -his : (his + 1)]
86 |                 y_index *= -1
87 |                 z_index = x_index + 1j * y_index
88 |
89 |                 # shift z by eps so that r is never exactly 0 at the center pixel
90 |                 z = z_index + eps
91 |                 r = np.abs(z)
92 |
93 |                 # Gaussian ring of radius beta; sigma is the same for every
94 |                 # beta, so the outermost ring needs no special case
95 |                 sigma = 0.6
96 |                 rad_prof = np.exp(-((r - beta) ** 2) / (2 * (sigma ** 2)))
97 |                 # angular part: unit phasor z / r raised to the frequency alpha
98 |                 c_image = rad_prof * (z / r) ** alpha
99 |
100 |                 # add filter to list
101 |                 filter_list_bl.append(c_image)
102 |                 # add frequency of filter to list (needed for phase manipulation)
103 |                 alpha_list_bl.append(alpha)
104 |                 beta_list_bl.append(beta)
105 |
106 |     return filter_list_bl, alpha_list_bl, beta_list_bl
107 |
108 |
109 | def plot_filters(filter_list, alpha_list, beta_list):
110 |     """
111 |     Plot the real and imaginary parts of the basis filters.
112 |
113 |     Args:
114 |         filter_list: list of basis filters
115 |         alpha_list: alpha of each basis filter
116 |         beta_list: beta of each basis filter
117 |     """
118 |
119 |     count = 1
120 |     plt.figure(figsize=(25, 11))
121 |     # Fixed 5 x 10 subplot grid: room for 25 real/imaginary pairs, which is
122 |     # enough for every supported kernel size. The even column count keeps
123 |     # each pair side by side on the same row.
124 |     len_x = 5
125 |     len_y = 10
126 |
127 |     for i in range(len(filter_list)):
128 |         filt_real = filter_list[i].real
129 |         filt_imag = filter_list[i].imag
130 |         plt.subplot(len_x, len_y, count)
131 |         plt.imshow(filt_real, vmin=-1, vmax=1, cmap=cm.gist_gray)
132 |         plt.axis("off")
133 |         plt.title(
134 |             r"Real: $\alpha$= %s, $\beta$= %s" % (alpha_list[i], beta_list[i]), fontsize=6
135 |         )
136 |         plt.subplot(len_x, len_y, count + 1)
137 |         plt.imshow(filt_imag, vmin=-1, vmax=1, cmap=cm.gist_gray)
138 |         plt.axis("off")
139 |         plt.title(
140 |             r"Imaginary: $\alpha$= %s, $\beta$= %s" % (alpha_list[i], beta_list[i]),
141 |             fontsize=6,
142 |         )
143 |         count += 2
144 |     plt.tight_layout()
145 |     plt.show()
146 |
147 |
148 | #####
149 | if __name__ == "__main__":
150 |     args = docopt(__doc__)
151 |
152 |     ksize = int(args["--ksize"])
153 |
154 |     if ksize not in [5, 7, 9, 11]:
155 |         raise ValueError("Select ksize to be either 5, 7, 9 or 11")
156 |
157 |     alpha_list, beta_list, bl_list = get_filter_info(ksize)
158 |
159 |     filter_list, alpha_list, beta_list = get_basis_filters(
160 |         alpha_list, beta_list, bl_list, ksize
161 |     )
162 |
163 |     plot_filters(filter_list, alpha_list, beta_list)
164 |
--------------------------------------------------------------------------------
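
Each filter built by get_basis_filters() is a Gaussian ring of radius beta multiplied by a unit phasor raised to the power alpha, so rotating a filter should only change its phase. The sketch below is a self-contained illustration and not part of the repository: `basis_filter` and `count_filters` are hypothetical helper names, and the sigma of 0.6 and the k_size = 7 parameter lists are copied from the code above. It counts the filters that pass the bandlimit test and checks that a 90-degree counter-clockwise rotation matches multiplication by exp(-1j * alpha * pi / 2). The center pixel is left out of the comparison because its phase is set only by eps.

```python
import numpy as np


def basis_filter(alpha, beta, k_size, sigma=0.6, eps=1e-8):
    """Hypothetical helper: one complex basis filter, mirroring get_basis_filters()."""
    his = k_size // 2  # half image size
    y, x = np.mgrid[-his : his + 1, -his : his + 1]
    # flip y so the imaginary axis points up, then shift by eps to avoid r == 0
    z = x - 1j * y + eps
    r = np.abs(z)
    rad_prof = np.exp(-((r - beta) ** 2) / (2 * sigma ** 2))  # Gaussian ring
    return rad_prof * (z / r) ** alpha  # ring times angular phasor


def count_filters(alpha_list, beta_list, bl_list):
    """Hypothetical helper: number of (alpha, beta) pairs kept by the bandlimit test."""
    return sum(1 for a in alpha_list for b in beta_list if abs(a) <= bl_list[b])


# parameter lists for k_size = 7, copied from get_filter_info()
n_filters = count_filters([0, 1, 2, 3], [0, 1, 2, 3], [0, 2, 3, 2])
print(n_filters)  # 11 filters pass the bandlimit test

alpha, beta, k_size = 1, 2, 7
filt = basis_filter(alpha, beta, k_size)

# rotating the array by 90 degrees should equal a pure phase shift
rotated = np.rot90(filt)
phase_shifted = filt * np.exp(-1j * alpha * np.pi / 2)

# exclude the center pixel, which has no well-defined angle
mask = np.ones(filt.shape, dtype=bool)
mask[k_size // 2, k_size // 2] = False
print(np.allclose(rotated[mask], phase_shifted[mask], atol=1e-6))
```

This phase relation is the reason the alpha of each filter is returned alongside the filter itself: a rotated copy can be obtained from the stored alpha without resampling the kernel.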