├── LICENSE ├── README.md ├── datafolder └── download_data.sh ├── dummydataset.nf ├── nextflow.config ├── realdataset.nf ├── src_DummyDataSet ├── DDSCreation.py ├── Records.py ├── UNet_Logit_Normalized.py ├── UNet_Normalized.py └── UNet_UNNormalized.py └── src_RealData ├── Data ├── CreateTFRecords.py ├── DataGen2.py ├── DataGenClass.py ├── DataGenRandomT.py ├── FIMM_histo │ ├── __init__.py │ ├── deconvolution.py │ ├── post_analysis.py │ ├── preparation.py │ └── segmentation_test.py ├── ImageTransform.py ├── __init__.py └── utils.py ├── Dist.py ├── FCN.py ├── FCN_Object.py ├── Nets ├── DataReadDecode.py ├── DataTF.py ├── ObjectOriented.py ├── UNetBatchNorm.py ├── UNetDistance.py ├── UNetObject.py └── __init__.py ├── TFRecords.py ├── UNet.py ├── __init__.py ├── postproc ├── __init__.py ├── plot.py ├── postprocessing.py └── regroup.py ├── preproc ├── BinToDistance.py ├── MeanCalculation.py ├── __init__.py └── changescale.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 PeterJackNaylor 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.1175282.svg)](https://doi.org/10.5281/zenodo.1175282) 2 | # Deep Regression For Nuclei Segmentation 3 | 4 | In this repository, we provide the code for the submitted paper "Segmentation of Nuclei in Histopathology Images by deep regression of the distance map." in TMI, by P. Naylor, M. Laé, F. Reyal and T. Walter. The pipeline is managed with nextflow; for more information about nextflow, please refer to [Nextflow's documentation](https://www.nextflow.io/). Each process of this pipeline is a Python script called via nextflow. Nextflow made it easy to take advantage of a cluster with multiple queues, in particular for a pipeline using both CPUs and GPUs. 5 | 6 | ## Description 7 | We tackle the task of nuclei segmentation in histopathology tissue images. Many methods have been proposed in the past, but relatively few of them try to handle the wide heterogeneity typically encountered in such data. Following the current trend, we apply state-of-the-art CNN-based techniques, as well as our novel nuclei segmentation framework. We compare each method with two metrics, the first pixel-based and the second object-based. We build a benchmark of fully convolutional algorithms applied to these datasets.
In particular, we compare our methods to [\[Neeraj et al\]](https://nucleisegmentationbenchmark.weebly.com/) and show that our deep regression method outperforms previous state-of-the-art methods for separating cluttered objects. 8 | 9 | # Setup 10 | To set up, please [install](https://www.nextflow.io/docs/latest/getstarted.html) nextflow and adapt it to your environment by editing the nextflow.config file. 11 | This code works with python 2.7.11 and tensorflow 1.5. There may be other Python package requirements, but nothing too specific; install them with ```conda install ``` or ```pip install```. 12 | In addition, to reproduce the results achieved with FCN, please download this [repository](https://github.com/warmspringwinds/tf-image-segmentation) and add it to the PYTHONPATH. 13 | # Hardware 14 | This code was run on a K80 GPU. A K80 card holds two GPUs, which is why ```maxForks 2``` is set in the nextflow files. Do not hesitate to adapt the process environment. For instance, to ensure that jobs ran on separate nodes, I used the ```beforeScript``` and ```afterScript``` options. The scripts called at these moments are just locks: the job that launches first creates a lock (for example for GPU number 0) to alert other jobs that it is using that GPU. Remove these lines if you do not need them. 15 | # Data 16 | The data made publicly available by our institute can be found [here](https://zenodo.org/record/1175282/files/TNBC_NucleiSegmentation.zip). To run the code yourself, run the script download_data.sh found in the datafolder (```bash download_data.sh```). This will download *DS1* and *DS2* as described in the paper, as well as the pretrained weights for the FCN model. 17 | By running this script, the image files are subdivided into groups as described in the paper.
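The data generators locate the images of a group (what `DataGenRandomT` calls a patient, e.g. `train` or `test` for the synthetic data) with the pattern `<path>/Slide_<group>/*.png`, and `SetPatient` sets the generator length to the number of images found times `crop`. A minimal sketch of that count, which may help estimate how many examples a record will hold before launching the pipeline; `list_groups` and `dataset_length` are hypothetical helpers, not part of the repository:

```python
import glob
import os


def list_groups(path):
    """Hypothetical helper: suffixes of all Slide_<group> folders under path."""
    prefix = "Slide_"
    return sorted(
        name[len(prefix):]
        for name in os.listdir(path)
        if name.startswith(prefix) and os.path.isdir(os.path.join(path, name))
    )


def dataset_length(path, groups, crop):
    """Samples yielded by one pass over `groups`: the PNG images found in each
    Slide_<group> folder times the number of crops per image (the same count
    DataGenRandomT.SetPatient stores in self.length)."""
    n_images = 0
    for group in groups:
        pattern = os.path.join(path, "Slide_{}".format(group), "*.png")
        n_images += len(glob.glob(pattern))
    return n_images * crop
```

For instance, `realdataset.nf` passes `--crop 16` for the train split, so each training image contributes 16 samples per epoch, and `CreateTFRecord` writes `N_EPOCH * DG.length` examples in total.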
18 | The data made publicly available by [\[Neeraj et al\]](https://nucleisegmentationbenchmark.weebly.com/) can also be found on the website created by its authors. 19 | # Running the pipeline with synthetic data 20 | To run this pipeline, run the following command: ```nextflow run dummydataset.nf --epoch 10 -c nextflow.config -resume``` 21 | This command calls the nextflow pipeline script ```dummydataset.nf```, which is subdivided into 3 processes: 22 | 1) Generating the dummy data 23 | 2) Creating the tensorflow records 24 | 3) Training the different designs with fixed hyperparameters. 25 | TODO: insert a diagram of the pipeline. 26 | # Running the pipeline with the real data 27 | To run this pipeline, run the following command: ```nextflow run realdataset.nf --epoch 80 -c nextflow.config -resume``` 28 | This command calls the nextflow pipeline script ```realdataset.nf```, which is subdivided into 3 steps: 29 | 1) Preparing the data for the experiments, processes: *ChangeInput*, *BinToDistance*, *Mean* and *CreateRecords*. 30 | 2) Training, process: *Training*. 31 | 3) Validating the different models, processes: *Testing*, *GetBestPerKey*, *Validation* and *Plot*. 32 | TODO: insert a diagram of the pipeline. 33 | # Running with your own data 34 | If you wish to run the pipeline with your own data, please organize your data folder following the same layout as ```./datafolder/ ```. Furthermore, you can tweak how the model loads the data by modifying the data generator classes found in ```./src_RealData/Data/ ```.
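Before launching the pipeline on your own folder, it can help to check that every image has an annotation. The sketch below assumes, as in the synthetic data written by `src_DummyDataSet/DDSCreation.py` (`Slide_train/train_0.png` next to `GT_train/train_0.png`), that an image and its ground truth share a file name in sibling `Slide_<group>` and `GT_<group>` folders; check how `DataGen2.py` builds the annotation path for your data before relying on it. `check_pairs` is a hypothetical helper, not part of the repository:

```python
import os


def check_pairs(path, group):
    """Compare the PNG names in Slide_<group> and GT_<group>.

    Returns (missing_gt, orphan_gt): images without an annotation and
    annotations without an image, both sorted. A missing folder counts
    as empty. Assumes image and annotation share a file name.
    """
    def pngs(folder):
        full = os.path.join(path, folder)
        if not os.path.isdir(full):
            return set()
        # Same case-sensitive filter as the generators' "*.png" glob.
        return {f for f in os.listdir(full) if f.endswith(".png")}

    images = pngs("Slide_{}".format(group))
    annotations = pngs("GT_{}".format(group))
    return sorted(images - annotations), sorted(annotations - images)
```

An empty pair of lists for every group means the folder at least matches the naming convention of the synthetic data; it says nothing about image sizes or label values.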
35 | -------------------------------------------------------------------------------- /datafolder/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | wget https://zenodo.org/record/1175282/files/TNBC_NucleiSegmentation.zip && \ 4 | unzip TNBC_NucleiSegmentation.zip -d TNBC_NucleiSegmentation && \ 5 | rm TNBC_NucleiSegmentation.zip 6 | 7 | wget http://members.cbio.mines-paristech.fr/~pnaylor/Downloads/pretrained.zip && \ 8 | unzip pretrained.zip -d pretrained && \ 9 | rm pretrained.zip 10 | 11 | wget http://members.cbio.mines-paristech.fr/~pnaylor/Downloads/ForDataGenTrainTestVal.zip && \ 12 | unzip ForDataGenTrainTestVal.zip -d ForDataGenTrainTestVal && \ 13 | rm ForDataGenTrainTestVal.zip -------------------------------------------------------------------------------- /dummydataset.nf: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env nextflow 2 | 3 | params.toannotate = file('datafolder/TNBC_NucleiSegmentation') 4 | 5 | PYDS = file('src_DummyDataSet/DDSCreation.py') 6 | TFRecordPY = file('src_DummyDataSet/Records.py') 7 | 8 | UNET_NORMALIZED_LOGIT = file('src_DummyDataSet/UNet_Logit_Normalized.py') 9 | UNET_NORMALIZED = file('src_DummyDataSet/UNet_Normalized.py') 10 | UNET_UNNORMALIZED = file('src_DummyDataSet/UNet_UNNormalized.py') 11 | 12 | target = Channel.from( [0, UNET_NORMALIZED_LOGIT], [0, UNET_NORMALIZED], [1, UNET_UNNORMALIZED]) 13 | 14 | def getFileName( file ) { 15 | if ( file.name == "Normalized" ){ 16 | 0 17 | } else { 18 | 1 19 | } 20 | } 21 | 22 | VAL_NAME = [[0, "Normalized"], [1, "UNNormalized"]] 23 | 24 | process DummyDataSet { 25 | publishDir "./out_DDS/Data" 26 | input: 27 | file path from params.toannotate 28 | file py from PYDS 29 | each pair from VAL_NAME 30 | output: 31 | file "./${pair[1]}" into PATHS 32 | """ 33 | python $py --path $path --output ./${pair[1]} --test 10 --mu 127 --sigma 100 --sigma2 10 --normalized 
${pair[0]} 34 | """ 35 | } 36 | 37 | process CreateTFRecords { 38 | publishDir "./out_DDS/Records" 39 | input: 40 | file py from TFRecordPY 41 | file path from PATHS 42 | output: 43 | set file("$path"), file("${path}.tfrecords") into PATH_RECORDS 44 | """ 45 | python $py --output ${path}.tfrecords --path $path --crop 4 --UNet --size 212 --seed 42 --epoch 20 --type JUST_READ --train 46 | """ 47 | } 48 | 49 | PATH_RECORDS .map { folder, rec -> tuple(getFileName(folder) ,folder, rec) } .cross(target) .set { PATH_RECORDS_FILE } 50 | // PATH_RECORDS_FILE .subscribe{first, second -> println("New parameter:\n") println(first[0]) println(first[1]) println(first[2]) println(second[0]) println(second[1]) println(second[1].baseName) println("\n")} 51 | process Training { 52 | publishDir "./out_DDS/${second[1].baseName}" 53 | maxForks 1 54 | input: 55 | set first, second from PATH_RECORDS_FILE // first[1] is the path, first[2] is the record, second[1] is the python file 56 | output: 57 | file "step_*" 58 | 59 | 60 | """ 61 | python -W ignore ${second[1]} --tf_record ${first[2]} --path ${first[1]} --log . 
--learning_rate 0.001 --batch_size 4 --epoch 100 --n_features 2 --weight_decay 0.005 --dropout 0.5 --n_threads 50 62 | """ 63 | } 64 | 65 | 66 | -------------------------------------------------------------------------------- /nextflow.config: -------------------------------------------------------------------------------- 1 | env.PYTHONPATH = "$PWD/src_RealData:$PYTHONPATH" 2 | -------------------------------------------------------------------------------- /realdataset.nf: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env nextflow 2 | 3 | // General parameters 4 | params.image_dir = './datafolder' 5 | params.epoch = 1 6 | IMAGE_FOLD = file(params.image_dir + "/ForDataGenTrainTestVal") 7 | 8 | /* 0) a) Resave all the images so that they have 1 for label instead of 255 9 | 0) b) Resave all the images so that they are distance maps 10 | In outputs: 11 | newpath name 12 | */ 13 | 14 | CHANGESCALE = file('src_RealData/preproc/changescale.py') 15 | NAMES = ["FCN", "UNet"] 16 | 17 | process ChangeInput { 18 | input: 19 | file path from IMAGE_FOLD 20 | file changescale from CHANGESCALE 21 | each name from NAMES 22 | output: 23 | set val("$name"), file("ImageFolder") into IMAGE_FOLD2, IMAGE_FOLD3 24 | """ 25 | python $changescale --path $path 26 | 27 | """ 28 | } 29 | 30 | BinToDistanceFile = file('src_RealData/preproc/BinToDistance.py') 31 | 32 | process BinToDistance { 33 | input: 34 | file py from BinToDistanceFile 35 | file toannotate from IMAGE_FOLD 36 | output: 37 | set val("DIST"), file("ToAnnotateDistance") into DISTANCE_FOLD, DISTANCE_FOLD2 38 | 39 | """ 40 | python $py $toannotate 41 | """ 42 | } 43 | 44 | /* 1) We create all the needed records 45 | In outputs: 46 | a set with the name, the split and the record 47 | */ 48 | 49 | TFRECORDS = file('src_RealData/TFRecords.py') 50 | IMAGE_FOLD2 .concat(DISTANCE_FOLD) .into{FOLDS;FOLDS2} 51 | UNET_REC = ["UNet", "--UNet", 212] 52 | FCN_REC = ["FCN", "--no-UNet", 224]
53 | DIST_REC = ["DIST", "--UNet", 212] 54 | 55 | RECORDS_OPTIONS = Channel.from(UNET_REC, FCN_REC, DIST_REC) 56 | FOLDS.join(RECORDS_OPTIONS) .set{RECORDS_OPTIONS_v2} 57 | RECORDS_HP = [["train", "16", "0"], ["test", "1", 500], ["validation", "1", 996]] 58 | 59 | process CreateRecords { 60 | input: 61 | file py from TFRECORDS 62 | val epoch from params.epoch 63 | set name, file(path), unet, size_train from RECORDS_OPTIONS_v2 64 | each op from RECORDS_HP 65 | output: 66 | set val("${name}"), val("${op[0]}"), file("${op[0]}_${name}.tfrecords") into NSR0, NSR1, NSR2 67 | """ 68 | python $py --tf_record ${op[0]}_${name}.tfrecords --split ${op[0]} --path $path --crop ${op[1]} $unet --size_train $size_train --size_test ${op[2]} --seed 42 --epoch $epoch --type JUST_READ 69 | """ 70 | } 71 | 72 | NSR0.filter{ it -> it[1] == "train" }.set{TRAIN_REC} 73 | NSR1.filter{ it -> it[1] == "test" }.set{TEST_REC} 74 | NSR2.filter{ it -> it[1] == "validation" }.set{VAL_REC} 75 | 76 | /* 2) We compute the mean 77 | In outputs: 78 | a set with the name, the mean file and the image folder 79 | */ 80 | 81 | MEANPY = file('src_RealData/preproc/MeanCalculation.py') 82 | 83 | process Mean { 84 | input: 85 | file py from MEANPY 86 | set val(name), file(toannotate) from FOLDS2 87 | output: 88 | set val("$name"), file("mean_file.npy"), file("$toannotate") into MeanFile, Meanfile2, Meanfile2VAL, Meanfile3, Meanfile3VAL 89 | """ 90 | python $py --path $toannotate --output .
91 | """ 92 | } 93 | 94 | /* 3) We train 95 | In inputs: Meanfile, name, split, rec 96 | In outputs: 97 | a set with the name, the parameters of the model 98 | */ 99 | 100 | ITERTEST = 50 101 | ITER8 = 108 // 10800 102 | LEARNING_RATE = [0.01]//, 0.001, 0.0001, 0.00001, 0.000001] 103 | FEATURES = [16]//, 32, 64] 104 | WEIGHT_DECAY = [0.00005]//, 0.0005] 105 | BS = 10 106 | 107 | Unet_file = file('src_RealData/UNet.py') 108 | Fcn_file = file('src_RealData/FCN.py') 109 | Dist_file = file('src_RealData/Dist.py') 110 | 111 | UNET_TRAINING = ["UNet", Unet_file, 212, 0] 112 | FCN_TRAINING = ["FCN", Fcn_file, 224, ITER8] 113 | DIST_TRAINING = ["DIST", Dist_file, 212, 0] 114 | 115 | Channel.from(UNET_TRAINING, FCN_TRAINING, DIST_TRAINING) .into{ TRAINING_CHANNEL; TRAINING_CHANNEL2; TRAINING_CHANNELVAL2} 116 | PRETRAINED_8 = file(params.image_dir + "/pretrained/checkpoint16/") 117 | TRAIN_REC.join(TRAINING_CHANNEL).join(MeanFile) .set {TRAINING_OPTIONS} 118 | 119 | process Training { 120 | maxForks 2 121 | beforeScript "source \$HOME/CUDA_LOCK/.whichNODE" 122 | afterScript "source \$HOME/CUDA_LOCK/.freeNODE" 123 | input: 124 | set name, split, file(rec), file(py), size, iters, file(mean), file(path) from TRAINING_OPTIONS 125 | val bs from BS 126 | each feat from FEATURES 127 | each lr from LEARNING_RATE 128 | each wd from WEIGHT_DECAY 129 | file __ from PRETRAINED_8 130 | val epoch from params.epoch 131 | output: 132 | set val("$name"), file("${name}__${feat}_${wd}_${lr}"), file("$py"), feat, wd, lr into RESULT_TRAIN, RESULT_TRAIN2, RESULT_TRAIN_VAL, RESULT_TRAIN_VAL2 133 | when: 134 | "$name" != "FCN" || ("$feat" == "${FEATURES[0]}" && "$wd" == "${WEIGHT_DECAY[0]}") 135 | script: 136 | """ 137 | python $py --tf_record $rec --path $path --log ${name}__${feat}_${wd}_${lr} --learning_rate $lr --batch_size $bs --epoch $epoch --n_features $feat --weight_decay $wd --mean_file ${mean} --n_threads 100 --restore $__ --size_train $size --split $split --iters $iters 138 | """ 139 | } 
140 | 141 | /* 4) a) We choose the best hyperparameters with respect to the test data set 142 | 143 | In inputs: Meanfile, image_path resp., split, rec, model, python, feat 144 | In outputs: a set with the name and model or csv 145 | */ 146 | // a) 147 | P1 = [0, 1, 10, 11]//[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] 148 | P2 = [0.5, 1.0] //, 1.5, 2.0] 149 | TEST_REC.cross(RESULT_TRAIN).map{ first, second -> [first, second.drop(1)].flatten() } .set{ TEST_OPTIONS_pre } 150 | Meanfile2.cross(TEST_OPTIONS_pre).map { first, second -> [first, second.drop(1)].flatten() } .into{TEST_OPTIONS;TEST_OPTIONS2} 151 | 152 | process Testing { 153 | maxForks 2 154 | beforeScript "source \$HOME/CUDA_LOCK/.whichNODE" 155 | afterScript "source \$HOME/CUDA_LOCK/.freeNODE" 156 | input: 157 | set name, file(mean), file(path), split, file(rec), file(model), file(py), feat, wd, lr from TEST_OPTIONS 158 | each p1 from P1 159 | each p2 from P2 160 | val iters from ITERTEST 161 | output: 162 | set val("$name"), file("${name}__${feat}_${wd}_${lr}_${p1}_${p2}.csv") into RESULT_TEST 163 | set val("$name"), file("$model") into MODEL_TEST 164 | when: 165 | ("$name" =~ "DIST" && p1 < 6) || ( !("$name" =~ "DIST") && p2 == P2[0] && p1 > 5) 166 | script: 167 | """ 168 | python $py --tf_record $rec --path $path --log $model --batch_size 1 --n_features $feat --mean_file ${mean} --n_threads 100 --split $split --size_test 500 --p1 ${p1} --p2 ${p2} --restore $model --iters $iters --output ${name}__${feat}_${wd}_${lr}_${p1}_${p2}.csv 169 | """ 170 | 171 | } 172 | 173 | 174 | /* 5) We regroup a) the test on dataset 1 175 | In inputs: name, all result_test.csv per key 176 | In outputs: name, best_model, p1, p2 177 | */ 178 | // a) 179 | REGROUP = file('src_RealData/postproc/regroup.py') 180 | RESULT_TEST .groupTuple() 181 | .set { KEY_CSV } 182 | RESULT_TRAIN2.map{name, model, py, feat, wd, lr -> [name, model]} .groupTuple() .
set {ALL_MODELS} 183 | KEY_CSV .join(ALL_MODELS) .set {KEY_CSV_MODEL} 184 | 185 | process GetBestPerKey { 186 | publishDir "./out_RDS/Test_tables/" , pattern: "*.csv" 187 | input: 188 | file py from REGROUP 189 | set name, file(csv), file(model) from KEY_CSV_MODEL 190 | 191 | output: 192 | set val("$name"), file("best_model") into BEST_MODEL_TEST 193 | file 'feat_val' into N_FEATS 194 | file 'p1_val' into P1_VAL 195 | file 'p2_val' into P2_VAL 196 | file "${name}_test.csv" 197 | """ 198 | python $py --store_best best_model --output ${name}_test.csv 199 | """ 200 | } 201 | 202 | /* 203 | Compute validation score on validation set 204 | a) Validation with hyperparameters chosen on a different dataset 205 | 206 | */ 207 | // a) 208 | BEST_MODEL_TEST.join(TRAINING_CHANNEL2).join(Meanfile3) .set{ VALIDATION_OPTIONS} 209 | N_FEATS .map{ it.text } .set {FEATS_} 210 | P1_VAL .map{ it.text } .set {P1_} 211 | P2_VAL .map{ it.text } .set {P2_} 212 | 213 | process Validation { 214 | publishDir "./out_RDS/Validation/" 215 | input: 216 | set name, file(best_model), file(py), _, __, file(mean), file(path) from VALIDATION_OPTIONS 217 | val feat from FEATS_ 218 | val p1 from P1_ 219 | val p2 from P2_ 220 | output: 221 | file "./$name" 222 | file "${name}.csv" into CSV_VAL 223 | """ 224 | python $py --mean_file $mean --path $path --log $best_model --restore $best_model --batch_size 1 --n_features ${feat} --n_threads 100 --split validation --size_test 500 --p1 ${p1} --p2 ${p2} --output ${name}.csv --save_path $name 225 | """ 226 | } 227 | 228 | PLOT = file('src_RealData/postproc/plot.py') 229 | 230 | process Plot { 231 | publishDir "./out_RDS/Validation/" 232 | input: 233 | file _ from CSV_VAL .collect() 234 | file py from PLOT 235 | output: 236 | file "BarResult_train_test_val.png" 237 | """ 238 | python $py --output BarResult_train_test_val.png --output_csv Result_train_test_val.csv 239 | """ 240 | } 241 |
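`DistanceBinNormalise` in `src_DummyDataSet/DDSCreation.py` builds the regression target for the synthetic data: each labelled nucleus gets its own chessboard distance transform (`scipy.ndimage.distance_transform_cdt`), divided by that nucleus's maximum and scaled to 0..255. The dependency-free sketch below follows the same logic for reading purposes, under two assumptions: it takes an already-labelled image as nested lists (the original calls `skimage.measure.label` first), and it replaces SciPy's chamfer transform with a multi-source breadth-first search over 8-connected neighbours, which should give the same chessboard distances but is not the repository's code:

```python
from collections import deque


def chessboard_distance(mask):
    """Chessboard distance from each True pixel to the nearest False pixel.

    Multi-source BFS from every background pixel with 8-connected steps;
    a stand-in for scipy.ndimage.distance_transform_cdt(mask).
    """
    h, w = len(mask), len(mask[0])
    dist = [[None] * w for _ in range(h)]
    queue = deque()
    for i in range(h):
        for j in range(w):
            if not mask[i][j]:
                dist[i][j] = 0
                queue.append((i, j))
    if not queue:
        raise ValueError("mask has no background pixel")
    while queue:
        i, j = queue.popleft()
        for di in (-1, 0, 1):
            for dj in (-1, 0, 1):
                ni, nj = i + di, j + dj
                if 0 <= ni < h and 0 <= nj < w and dist[ni][nj] is None:
                    dist[ni][nj] = dist[i][j] + 1
                    queue.append((ni, nj))
    return dist


def normalized_distance_map(labels):
    """Per-nucleus distance map in 0..255, following DistanceBinNormalise:
    every pixel not in nucleus k (background or other nuclei) is background
    for k, and k is divided by its own maximum before scaling by 255."""
    h, w = len(labels), len(labels[0])
    result = [[0] * w for _ in range(h)]
    for k in sorted({v for row in labels for v in row if v > 0}):
        mask = [[v == k for v in row] for row in labels]
        dist = chessboard_distance(mask)
        peak = max(dist[i][j] for i in range(h) for j in range(w) if mask[i][j])
        for i in range(h):
            for j in range(w):
                if mask[i][j]:
                    # int() truncates like the original astype('uint8').
                    result[i][j] = int(255 * dist[i][j] / peak)
    return result
```

Because the pixels of neighbouring nuclei count as background for each other, two touching nuclei each keep their own peak at 255, which is presumably what lets the regressed map separate cluttered objects.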
-------------------------------------------------------------------------------- /src_DummyDataSet/DDSCreation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | from Data.DataGenClass import DataGenMulti 6 | 7 | from Data.ImageTransform import ListTransform 8 | import pdb 9 | from os.path import join 10 | from optparse import OptionParser 11 | import numpy as np 12 | import matplotlib.pylab as plt 13 | from scipy.ndimage.morphology import distance_transform_cdt 14 | from utils import CheckOrCreate 15 | from skimage.io import imsave 16 | from skimage.measure import label, regionprops 17 | from skimage.morphology import dilation, disk 18 | 19 | def AddNoiseBinary(GT, p): 20 | x, y = GT.shape 21 | vec = GT.flatten() 22 | noise = np.random.binomial(1, p, len(vec)) 23 | noisy_vec = 1 - ( vec * noise + (1 - vec) * (1 - noise) ) 24 | GT_noise = noisy_vec.reshape(x,y) 25 | return GT_noise 26 | 27 | def TruncatedNormal(loc, scale, size, min_v, max_v): 28 | vec = np.random.normal(loc=loc, scale=scale, size=size) 29 | def f(val): 30 | if min_v > val: 31 | return True 32 | elif val > max_v: 33 | return True 34 | else: 35 | return False 36 | res = np.array(map(f, vec)) 37 | if True in res: 38 | n_size = res.sum() 39 | vec[res] = TruncatedNormal(loc, scale, n_size, min_v, max_v) 40 | return vec 41 | 42 | def Noise(img, std): 43 | fl = img.flatten() 44 | vec = np.random.normal(loc=0, scale=std, size=len(fl)) 45 | def g(val): 46 | if 0 > val: 47 | return -val 48 | else: 49 | return val 50 | noise = np.array(map(g, vec)) 51 | res = fl + noise 52 | res = res.reshape(img.shape) 53 | res[res > 255.] = 255. 
54 | return res.astype('uint8') 55 | 56 | def AddNoise(GT, mu, std, std_2, min_v, max_v): 57 | GT_lbl = label(GT) 58 | x, y = GT.shape 59 | GT_noise = np.zeros(shape=(x, y, 3)) 60 | for i in range(1, GT_lbl.max() + 1): 61 | 62 | col = TruncatedNormal(loc=mu, scale=std, size=3, min_v=min_v, max_v=max_v) 63 | #GT_noise[GT_lbl == i] = col.astype(int) 64 | nrow = GT_lbl[GT_lbl == i].shape[0] 65 | RANDOM = np.zeros(shape=(nrow, 3)) 66 | for j in range(3): 67 | RANDOM[:, j] = TruncatedNormal(loc=col[j], scale=std_2, size=nrow, min_v=min_v, max_v=max_v) 68 | GT_noise[GT_lbl == i] = RANDOM 69 | 70 | return GT_noise.astype('uint8') 71 | def DistanceWithoutNormalise(bin_image): 72 | res = np.zeros_like(bin_image) 73 | for j in range(1, bin_image.max() + 1): 74 | one_cell = np.zeros_like(bin_image) 75 | one_cell[bin_image == j] = 1 76 | one_cell = distance_transform_cdt(one_cell) 77 | res[bin_image == j] = one_cell[bin_image == j] 78 | res = res.astype('uint8') 79 | return res 80 | 81 | def DistanceBinNormalise(bin_image): 82 | bin_image = label(bin_image) 83 | result = np.zeros_like(bin_image, dtype="float") 84 | 85 | for k in range(1, bin_image.max() + 1): 86 | tmp = np.zeros_like(result, dtype="float") 87 | tmp[bin_image == k] = 1 88 | dist = distance_transform_cdt(tmp) 89 | MAX = dist.max() 90 | dist = dist.astype(float) / MAX 91 | result[bin_image == k] = dist[bin_image == k] 92 | result = result * 255 93 | result = result.astype('uint8') 94 | return result 95 | 96 | 97 | 98 | if __name__ == '__main__': 99 | 100 | parser = OptionParser() 101 | 102 | parser.add_option("--path", dest="path",type="string", 103 | help="path to annotated dataset") 104 | parser.add_option("--output", dest="output",type="string", 105 | help="out path") 106 | parser.add_option("-p", dest="p", type="float") 107 | parser.add_option("--test", dest="test", type="int") 108 | parser.add_option("--mu", dest="mu", type="int") 109 | parser.add_option("--sigma", dest="sigma", type="int") 110 | 
parser.add_option("--sigma2", dest="sigma2", type="int") 111 | parser.add_option("--normalized", dest='normalized', type='int') 112 | (options, args) = parser.parse_args() 113 | 114 | if (options.normalized != 0 and options.normalized != 1): 115 | raise AssertionError('normalized not defined, give --normalized 0 or --normalized 1') 116 | 117 | 118 | path = "/data/users/pnaylor/Bureau/ToAnnotate" 119 | path = "/Users/naylorpeter/Documents/Histopathologie/ToAnnotate/ToAnnotate" 120 | 121 | 122 | path = options.path 123 | transf, transf_test = ListTransform() 124 | 125 | size = (512, 512) 126 | crop = 1 127 | DG = DataGenMulti(path, crop=crop, size=size, transforms=transf_test, 128 | split="train", num="", seed_=42) 129 | Slide_train = join(options.output, "Slide_train") 130 | CheckOrCreate(Slide_train) 131 | Gt_train = join(options.output, "GT_train") 132 | CheckOrCreate(Gt_train) 133 | 134 | Slide_test = join(options.output, "Slide_test") 135 | CheckOrCreate(Slide_test) 136 | Gt_test = join(options.output, "GT_test") 137 | CheckOrCreate(Gt_test) 138 | 139 | for i in range(DG.length): 140 | key = DG.NextKeyRandList(0) 141 | img, gt = DG[key] 142 | gt_lab = gt.copy() 143 | gt = label(gt) 144 | gt = dilation(gt, disk(3)) 145 | GT_noise = AddNoise(gt, options.mu, options.sigma, options.sigma2, 1, 255) 146 | GT_noise = Noise(GT_noise, 5) 147 | if options.test > i: 148 | imsave(join(Slide_test, "test_{}.png".format(i)), GT_noise) 149 | if options.normalized == 0: 150 | imsave(join(Gt_test, "test_{}.png".format(i)), DistanceBinNormalise(gt)) 151 | else: 152 | imsave(join(Gt_test, "test_{}.png".format(i)), DistanceWithoutNormalise(gt)) 153 | else: 154 | imsave(join(Slide_train, "train_{}.png".format(i)), GT_noise) 155 | if options.normalized == 0: 156 | imsave(join(Gt_train, "train_{}.png".format(i)), DistanceBinNormalise(gt)) 157 | else: 158 | imsave(join(Gt_train, "train_{}.png".format(i)), DistanceWithoutNormalise(gt)) 159 | # fig, ax = plt.subplots(nrows=2) 160 | #
ax[0].imshow(gt) 161 | # ax[1].imshow(GT_noise) 162 | # plt.show() 163 | #np.save(join(options.out, "mean_file.npy"), mean) 164 | -------------------------------------------------------------------------------- /src_DummyDataSet/Records.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from optparse import OptionParser 5 | from Data.ImageTransform import ListTransform 6 | from Data.CreateTFRecords import CreateTFRecord 7 | 8 | def options_parser(): 9 | 10 | parser = OptionParser() 11 | 12 | parser.add_option('--output', dest="TFRecords", type="string", 13 | help="name for the output .tfrecords") 14 | parser.add_option('--path', dest="path", type="str", 15 | help="Where to find the annotations") 16 | parser.add_option('--crop', dest="crop", type="int", 17 | help="Number of crops to divide one image in") 18 | # parser.add_option('--UNet', dest="UNet", type="bool", 19 | # help="If image and annotations will have different shapes") 20 | parser.add_option('--size', dest="size", type="int", 21 | help='first dimension for size') 22 | parser.add_option('--seed', dest="seed", type="int", default=42, 23 | help='Seed to use, still not really implemented') 24 | parser.add_option('--epoch', dest="epoch", type ="int", 25 | help="Number of epochs to perform") 26 | parser.add_option('--type', dest="type", type ="str", 27 | help="Type for the datagen") 28 | parser.add_option('--UNet', dest='UNet', action='store_true') 29 | parser.add_option('--no-UNet', dest='UNet', action='store_false') 30 | 31 | parser.add_option('--train', dest='split', action='store_true') 32 | parser.add_option('--test', dest='split', action='store_false') 33 | parser.set_defaults(feature=True) 34 | 35 | (options, args) = parser.parse_args() 36 | options.SIZE = (options.size, options.size) 37 | return options 38 | 39 | if __name__ == '__main__': 40 | 41 | options = options_parser() 42 | 43 | OUTNAME = 
options.TFRecords 44 | PATH = options.path 45 | CROP = options.crop 46 | SIZE = options.SIZE 47 | SPLIT = "train" if options.split else "test" 48 | transform_list, transform_list_test = ListTransform() 49 | TRANSFORM_LIST = transform_list 50 | UNET = options.UNet 51 | SEED = options.seed 52 | TEST_PATIENT = ["test"] 53 | N_EPOCH = options.epoch 54 | TYPE = options.type 55 | 56 | 57 | CreateTFRecord(OUTNAME, PATH, CROP, SIZE, 58 | TRANSFORM_LIST, UNET, None, 59 | SEED, TEST_PATIENT, N_EPOCH, 60 | TYPE=TYPE, SPLIT=SPLIT) -------------------------------------------------------------------------------- /src_RealData/Data/CreateTFRecords.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | from DataGenRandomT import DataGenRandomT 6 | from DataGenClass import DataGen3, DataGenMulti, DataGen3reduce 7 | import numpy as np 8 | 9 | def _bytes_feature(value): 10 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 11 | 12 | def _int64_feature(value): 13 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 14 | 15 | 16 | def CreateTFRecord(OUTNAME, PATH, CROP, SIZE, 17 | TRANSFORM_LIST, UNET, MEAN_FILE, 18 | SEED, TEST_PATIENT, N_EPOCH, TYPE = "Normal", 19 | SPLIT="train"): 20 | """ 21 | Takes a DataGen object and creates an associated TFRecord file. 22 | We do not perform data augmentation on the fly but save the 23 | augmented images in the record. Most of the parameters here 24 | are passed on to the DataGen object, in particular PATH, 25 | CROP, SIZE, TRANSFORM_LIST, UNET, SEED and TEST_PATIENT. 26 | OUTNAME is the name of the record.
27 | """ 28 | 29 | tfrecords_filename = OUTNAME 30 | writer = tf.python_io.TFRecordWriter(tfrecords_filename) 31 | 32 | 33 | if TYPE == "Normal": 34 | DG = DataGenRandomT(PATH, split=SPLIT, crop=CROP, size=SIZE, 35 | transforms=TRANSFORM_LIST, UNet=UNET, num=TEST_PATIENT, 36 | mean_file=MEAN_FILE, seed_=SEED) 37 | 38 | elif TYPE == "3class": 39 | DG = DataGen3(PATH, split=SPLIT, crop = CROP, size=SIZE, 40 | transforms=TRANSFORM_LIST, UNet=UNET, num=TEST_PATIENT, 41 | mean_file=MEAN_FILE, seed_=SEED) 42 | elif TYPE == "ReducedClass": 43 | DG = DataGen3reduce(PATH, split=SPLIT, crop = CROP, size=SIZE, 44 | transforms=TRANSFORM_LIST, UNet=UNET, num=TEST_PATIENT, 45 | mean_file=MEAN_FILE, seed_=SEED) 46 | elif TYPE == "JUST_READ": 47 | DG = DataGenMulti(PATH, split=SPLIT, crop = CROP, size=SIZE, 48 | transforms=TRANSFORM_LIST, UNet=UNET, num=TEST_PATIENT, 49 | mean_file=MEAN_FILE, seed_=SEED) 50 | else: raise ValueError("Unknown TYPE: {}".format(TYPE)) 51 | DG.SetPatient(TEST_PATIENT) 52 | N_ITER_MAX = N_EPOCH * DG.length 53 | 54 | original_images = [] 55 | key = DG.RandomKey(False) 56 | if not UNET: 57 | for _ in range(N_ITER_MAX): 58 | key = DG.NextKeyRandList(0) 59 | img, annotation = DG[key] 60 | # img = img.astype(np.uint8) 61 | annotation = annotation.astype(np.uint8) 62 | height = img.shape[0] 63 | width = img.shape[1] 64 | 65 | 66 | 67 | img_raw = img.tobytes() 68 | annotation_raw = annotation.tobytes() 69 | 70 | example = tf.train.Example(features=tf.train.Features(feature={ 71 | 'height': _int64_feature(height), 72 | 'width': _int64_feature(width), 73 | 'image_raw': _bytes_feature(img_raw), 74 | 'mask_raw': _bytes_feature(annotation_raw)})) 75 | 76 | writer.write(example.SerializeToString()) 77 | else: 78 | for _ in range(N_ITER_MAX): 79 | key = DG.NextKeyRandList(0) 80 | img, annotation = DG[key] 81 | # img = img.astype(np.uint8) 82 | annotation = annotation.astype(np.uint8) 83 | height_img = img.shape[0] 84 | width_img = img.shape[1] 85 | 86 | height_mask =
annotation.shape[0] 87 | width_mask = annotation.shape[1] 88 | 89 | 90 | 91 | img_raw = img.tobytes() 92 | annotation_raw = annotation.tobytes() 93 | 94 | example = tf.train.Example(features=tf.train.Features(feature={ 95 | 'height_img': _int64_feature(height_img), 96 | 'width_img': _int64_feature(width_img), 97 | 'height_mask': _int64_feature(height_mask), 98 | 'width_mask': _int64_feature(width_mask), 99 | 'image_raw': _bytes_feature(img_raw), 100 | 'mask_raw': _bytes_feature(annotation_raw)})) 101 | 102 | writer.write(example.SerializeToString()) 103 | 104 | 105 | writer.close() 106 | -------------------------------------------------------------------------------- /src_RealData/Data/DataGenClass.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from DataGenRandomT import DataGenRandomT 5 | import nibabel as ni 6 | from skimage import measure 7 | from scipy.ndimage.morphology import morphological_gradient 8 | from scipy import misc 9 | from utils import generate_wsl 10 | 11 | def Contours(bin_image, contour_size=3): 12 | """ 13 | Finds contours of binary images. 14 | """ 15 | grad = morphological_gradient(bin_image, size=(contour_size, contour_size)) 16 | return grad 17 | 18 | class DataGen3(DataGenRandomT): 19 | """ 20 | DG object with 3 classes: background, cell interior and cell contour. 21 | """ 22 | def LoadGT(self, path): 23 | image = ni.load(path) 24 | img = image.get_data() 25 | img = measure.label(img) 26 | wsl = generate_wsl(img[:,:,0]) 27 | img[ img > 0 ] = 1 28 | wsl[ wsl > 0 ] = 1 29 | img[:,:,0] = img[:,:,0] - wsl 30 | if len(img.shape) == 3: 31 | img = img[:, :, 0].transpose() 32 | else: 33 | img = img.transpose() 34 | cell_border = Contours(img, contour_size=3) 35 | img[cell_border > 0] = 2 36 | return img 37 | 38 | class DataGenMulti(DataGenRandomT): 39 | """
41 | """ 42 | def LoadGT(self, path): 43 | image = misc.imread(path.replace(".nii.gz", ".png")) 44 | return image 45 | 46 | class DataGen3reduce(DataGenRandomT): 47 | """ 48 | DG object that merges all classes above 4 into class 5. 49 | """ 50 | def LoadGT(self, path): 51 | image = misc.imread(path.replace(".nii.gz", ".png")) 52 | image[image > 4] = 5 53 | return image -------------------------------------------------------------------------------- /src_RealData/Data/DataGenRandomT.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from DataGen2 import DataGen 5 | import glob 6 | import itertools 7 | from random import shuffle, sample, seed 8 | import numpy as np 9 | 10 | 11 | class DataGenRandomT(DataGen): 12 | """ 13 | DG object with an improved random key generator. 14 | """ 15 | def SetPatient(self, num): 16 | """ 17 | Sets the test patient(s); all remaining patients form the training set. 18 | """ 19 | if isinstance(num, list): 20 | test_patient = num 21 | else: 22 | test_patient = [num] 23 | 24 | train_patient = [el for el in self.patient_num if el not in test_patient] 25 | 26 | if self.split == "train": 27 | 28 | images_train = [len(glob.glob(self.path + "/Slide_{}".format(el) + "/*.png")) for el in train_patient] 29 | self.length = np.sum(images_train) * self.crop 30 | self.patients_iter = train_patient 31 | 32 | else: 33 | 34 | images_test = [len(glob.glob(self.path + "/Slide_{}".format(el) + "/*.png")) for el in test_patient] 35 | self.length = np.sum(images_test) * self.crop 36 | self.patients_iter = test_patient 37 | 38 | self.SetRandomList() 39 | 40 | 41 | def GeneratePossibleKeys(self): 42 | """ 43 | Generates all possible (patient, image, transform, crop) keys; the transform slot is set to -1 here.
44 | """ 45 | len_key = 4 46 | 47 | AllPossibleKeys = [] 48 | i = 0 49 | 50 | for num in self.patients_iter: 51 | lists = ([i],) 52 | i += 1 53 | nber_per_patient = len(self.patient_img[num]) 54 | lists += (range(nber_per_patient),) 55 | lists += ([-1],) 56 | lists += (range(self.crop),) 57 | 58 | AllPossibleKeys += list(itertools.product(*lists)) 59 | 60 | return AllPossibleKeys 61 | 62 | 63 | def SetRandomList(self): 64 | """ 65 | Builds all possible keys, shuffles them and resets the key counter. 66 | """ 67 | 68 | RandomList = self.GeneratePossibleKeys() 69 | shuffle(RandomList) 70 | self.RandomList = RandomList 71 | self.key_iter = 0 72 | 73 | 74 | def NextKeyRandList(self, key): 75 | """ 76 | Returns the next key of the shuffled list, with a randomly drawn transform index. 77 | """ 78 | 79 | if not hasattr(self, "RandomList"): 80 | self.SetRandomList() 81 | self.key_iter = 0 82 | else: 83 | self.key_iter += 1 84 | if self.key_iter >= len(self.RandomList): 85 | self.key_iter = 0 86 | 87 | a, b, c, d = self.RandomList[self.key_iter] 88 | c = np.random.randint(0, self.n_trans) 89 | 90 | 91 | return (a, b, c, d) 92 | 93 | def SortPatients(self): 94 | """ 95 | Randomly holds out self.leave_out patients for testing; the remaining patients are used for training.
96 | """ 97 | if self.seed is not None: 98 | seed(self.seed) 99 | 100 | n = len(self.patient_num) 101 | test_patient = sample(self.patient_num, self.leave_out) 102 | train_patient = [el for el in self.patient_num if el not in test_patient] 103 | number_of_transforms = len(self.transforms) 104 | 105 | if self.split == "train": 106 | 107 | train_images = [len(glob.glob(self.path + "/Slide_{}".format(el) + "/*.png")) for el in train_patient] 108 | self.length = np.sum(train_images) * self.crop 109 | self.patients_iter = train_patient 110 | 111 | else: 112 | test_images = [len(glob.glob(self.path + "/Slide_{}".format(el) + "/*.png")) for el in test_patient] 113 | self.length = np.sum(test_images) * self.crop 114 | self.patients_iter = test_patient 115 | 116 | -------------------------------------------------------------------------------- /src_RealData/Data/FIMM_histo/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src_RealData/Data/FIMM_histo/deconvolution.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from numpy import linalg 3 | #import vigra 4 | 5 | from optparse import OptionParser 6 | import os 7 | import sys 8 | import numpy as np 9 | 10 | 11 | class Deconvolution(object): 12 | 13 | def __init__(self): 14 | self.params = { 15 | 'image_type': 'HEDab' 16 | } 17 | 18 | return 19 | 20 | def log_transform(self, colorin): 21 | res = - 255.0 / numpy.log(256.0) * numpy.log((colorin + 1) / 256.0) 22 | res[res < 0] = 0.0 23 | res[res > 255.0] = 255.0 24 | return res 25 | 26 | def exp_transform(self, colorin): 27 | res = numpy.exp((255 - colorin) * numpy.log(255) / 255) 28 | res[res < 0] = 0.0 29 | res[res > 255.0] = 255.0 30 | return res 31 | 32 | def colorDeconv(self, imin): 33 | M_h_e_dab_meas = numpy.array([[0.650, 0.072, 0.268], 34 | [0.704, 0.990, 0.570], 35 | [0.286, 0.105, 
0.776]]) 36 | 37 | # [H,E] 38 | M_h_e_meas = numpy.array([[0.644211, 0.092789], 39 | [0.716556, 0.954111], 40 | [0.266844, 0.283111]]) 41 | 42 | if self.params['image_type'] == "HE": 43 | # print "HE stain" 44 | M = M_h_e_meas 45 | M_inv = numpy.dot(linalg.inv(numpy.dot(M.T, M)), M.T) 46 | 47 | elif self.params['image_type'] == "HEDab": 48 | # print "HEDab stain" 49 | M = M_h_e_dab_meas 50 | M_inv = linalg.inv(M) 51 | 52 | else: 53 | # print "Unrecognized image type !! image type set to \"HE\" " 54 | M = numpy.diag([1, 1, 1]) 55 | M_inv = numpy.diag([1, 1, 1]) 56 | 57 | imDecv = numpy.dot(self.log_transform(imin.astype('float')), M_inv.T) 58 | imout = self.exp_transform(imDecv) 59 | 60 | return imout 61 | 62 | def colorDeconvHE(self, imin): 63 | """ 64 | Does the opposite of colorDeconv 65 | """ 66 | M_h_e_dab_meas = numpy.array([[0.650, 0.072, 0.268], 67 | [0.704, 0.990, 0.570], 68 | [0.286, 0.105, 0.776]]) 69 | 70 | # [H,E] 71 | M_h_e_meas = numpy.array([[0.644211, 0.092789], 72 | [0.716556, 0.954111], 73 | [0.266844, 0.283111]]) 74 | 75 | if self.params['image_type'] == "HE": 76 | # print "HE stain" 77 | M = M_h_e_meas 78 | 79 | elif self.params['image_type'] == "HEDab": 80 | # print "HEDab stain" 81 | M = M_h_e_dab_meas 82 | 83 | else: 84 | # print "Unrecognized image type !! 
image type set to \"HE\" " 85 | M = numpy.diag([1, 1, 1]) 86 | M_inv = numpy.diag([1, 1, 1]) 87 | 88 | imDecv = numpy.dot(self.log_transform(imin.astype('float')), M.T) 89 | imout = self.exp_transform(imDecv) 90 | # imout = numpy.zeros(imDecv.shape, dtype = numpy.uint8) 91 | 92 | # Normalization 93 | # for i in range(imout.shape[-1]): 94 | # toto = imDecv[:,:,i] 95 | # vmax = toto.max() 96 | # vmin = toto.min() 97 | # if (vmax - vmin) < 0.0001: 98 | # continue 99 | # titi = (toto - vmin) / (vmax - vmin) * 255 100 | # titi = titi.astype(numpy.uint8) 101 | # imout[:,:,i] = titi 102 | 103 | return imout 104 | 105 | 106 | # DISABLED VIGRA AS IT IS NOT INSTALLED ON KEPLER 107 | 108 | 109 | # def __call__(self, filename, out_path): 110 | # if not os.path.exists(out_path): 111 | # os.makedirs(out_path) 112 | # print 'made %s' % out_path 113 | 114 | # colorin = vigra.readImage(filename) 115 | # filename_base, extension = os.path.splitext(os.path.basename(filename)) 116 | 117 | # col_dec = self.colorDeconv(colorin) 118 | # channels = ['h', 'e', 'dab'] 119 | # for i in range(3): 120 | # new_filename = os.path.join(out_path, 121 | # filename_base + '__%s' % channels[i] + extension) 122 | # vigra.impex.writeImage(col_dec[:,:,i], new_filename) 123 | # print 'written %s' % new_filename 124 | 125 | # return 126 | 127 | # def generate_dec_crops(self, filename, out_path, crop_size=1024, nb_positions=1): 128 | # if not os.path.exists(out_path): 129 | # os.makedirs(out_path) 130 | # print 'made %s' % out_path 131 | 132 | # colorin = vigra.readImage(filename) 133 | # width = colorin.shape[0] 134 | # height = colorin.shape[1] 135 | 136 | # frow = np.sqrt(nb_positions) 137 | # if np.abs(int(frow) - frow) > 1e-10: 138 | # raise ValueError('number of positions needs to be squared.') 139 | # frow = int(frow) 140 | 141 | # if frow*crop_size > width or frow*crop_size > height: 142 | # print 'crop_size is too large (exceeds image dimensions) ... 
skipping %s' % filename 143 | # return 144 | 145 | # offset_x = (width - frow*crop_size) / 2 146 | # offset_y = (height - frow*crop_size) / 2 147 | 148 | # filename_base, extension = os.path.splitext(os.path.basename(filename)) 149 | 150 | # col_dec = self.colorDeconv(colorin) 151 | # channels = ['h', 'e', 'dab'] 152 | 153 | # for i in range(3): 154 | 155 | # position = 1 156 | # for y in range(frow): 157 | # for x in range(frow): 158 | 159 | # ref_img = col_dec[offset_x+x*crop_size:offset_x+(x+1)*crop_size, 160 | # offset_y+y*crop_size:offset_y+(y+1)*crop_size, i] 161 | 162 | # new_filename = os.path.join(out_path, 163 | # '%s__P%05i__%s%s' % (filename_base, position, channels[i], extension)) 164 | # vigra.impex.writeImage(ref_img, new_filename) 165 | # print 'written %s' % new_filename 166 | # position += 1 167 | 168 | # return 169 | 170 | # def process_folder(self, input_folder, output_folder, crop_size, nb_positions): 171 | # image_names = filter(lambda x: os.path.splitext(x)[-1].lower() in ['.tiff', '.png', '.tif'], os.listdir(input_folder)) 172 | # for image_name in image_names: 173 | # full_filename = os.path.join(input_folder, image_name) 174 | # self.generate_dec_crops(full_filename, output_folder, crop_size, nb_positions) 175 | 176 | # return 177 | 178 | 179 | # if __name__ == "__main__": 180 | 181 | # description =\ 182 | # ''' 183 | # %prog - running segmentation tool . 
184 | # ''' 185 | 186 | # parser = OptionParser(usage="usage: %prog [options]", 187 | # description=description) 188 | 189 | # parser.add_option("-i", "--input_folder", dest="input_folder", 190 | # help="Input folder") 191 | # parser.add_option("-o", "--output_folder", dest="output_folder", 192 | # help="Output folder") 193 | # parser.add_option("--crop", dest="crop", 194 | # help="Crop size") 195 | # parser.add_option("--fields", dest="fields", 196 | # help="number of fields to generate (there will be fields of 197 | # size , centered in the image") 198 | 199 | 200 | # (options, args) = parser.parse_args() 201 | 202 | # if (options.input_folder is None) or (options.output_folder is None): 203 | # parser.error("incorrect number of arguments!") 204 | 205 | # if options.crop is None: 206 | # crop_size = 1024 207 | # else: 208 | # crop_size = int(options.crop) 209 | 210 | # if options.fields is None: 211 | # fields = 1 212 | # else: 213 | # fields = int(options.fields) 214 | 215 | # dec = Deconvolution() 216 | # dec.process_folder(options.input_folder, options.output_folder, crop_size, fields) 217 | # print 'DONE' 218 | -------------------------------------------------------------------------------- /src_RealData/Data/FIMM_histo/post_analysis.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | import cellh5 4 | 5 | from optparse import OptionParser 6 | import numpy as np 7 | import operator 8 | import pickle 9 | import pdb 10 | 11 | class SimpleAnalyzer(object): 12 | def __init__(self, input_folder, output_folder): 13 | print 'SimpleAnalyzer' 14 | self.input_folder = input_folder 15 | self.output_folder = output_folder 16 | if not os.path.exists(self.output_folder): 17 | os.makedirs(self.output_folder) 18 | print 'made %s' % self.output_folder 19 | return 20 | 21 | def get_positions(self, in_folder): 22 | filenames = filter(lambda x: os.path.splitext(x)[-1].lower() == '.ch5' and x[0] != '_', 23 | os.listdir(in_folder)) 24
| if len(filenames) < 1: 25 | raise ValueError('no ch5 files found in %s' % in_folder) 26 | return None 27 | positions = {} 28 | for fn in filenames: 29 | well, pos = os.path.splitext(fn)[0].split('_') 30 | if not well in positions: 31 | positions[well] = [] 32 | positions[well].append(pos) 33 | return positions 34 | 35 | def get_first_pos(self, in_folder): 36 | filenames = filter(lambda x: os.path.splitext(x)[-1].lower() == '.ch5' and x[0] != '_', 37 | os.listdir(in_folder)) 38 | if len(filenames) > 0: 39 | return filenames[0] 40 | else: 41 | raise ValueError('no ch5 files found in %s' % in_folder) 42 | 43 | return None 44 | 45 | def get_counts(self, plate=None, channel='merged', region='histonucgrey-expanded-expanded', 46 | result_folder=None): 47 | 48 | if result_folder is None: 49 | result_folder = self.input_folder 50 | if plate is None: 51 | plate_folder = result_folder 52 | else: 53 | plate_folder = os.path.join(result_folder, plate) 54 | 55 | ch5_folder = os.path.join(plate_folder, 'hdf5') 56 | if not os.path.isdir(ch5_folder): 57 | print 'problem with folder settings: hdf5-folder for plate %s not found' % plate 58 | print 'given folder name: %s' % ch5_folder 59 | 60 | cm_general = cellh5.CH5MappedFile(os.path.join(ch5_folder, self.get_first_pos(ch5_folder))) 61 | positions = self.get_positions(ch5_folder) 62 | 63 | classes = cm_general.class_definition('%s__%s' % (channel, region)) 64 | class_colors = [x[-1] for x in classes] 65 | class_names = [x[1] for x in classes] 66 | class_labels = [x[0] for x in classes] 67 | 68 | res = {} 69 | 70 | #when _allpositions workes: for well, poslist in cm_general.positions.iteritems(): 71 | for well, poslist in positions.iteritems(): 72 | res[well] = dict(zip(class_names, [0 for x in classes])) 73 | 74 | for pos in poslist: 75 | cm = cellh5.CH5MappedFile(os.path.join(ch5_folder, '%05i_%02i.ch5' % (int(well), int(pos)) )) 76 | 77 | # dirty hack: 78 | #pdb.set_trace() 79 | lw = cm.positions.keys()[0] 80 | lpl = 
cm.positions[lw] 81 | #pos_obj = cm.get_position(well, pos) 82 | 83 | pos_obj = cm.get_position(lw, lpl[0]) 84 | 85 | if len(pos_obj['object']['%s__%s' % (channel, region)]) == 0: 86 | cm.close() 87 | continue 88 | 89 | # attention: the function get_class_prediction gives back an index, not a label. 90 | # the labels start with 1 and the index starts with 0. 91 | # but the labels can also be in completely different order. 92 | predictions = [class_names[x[0]] for x in pos_obj.get_class_prediction('%s__%s' % (channel, region))] 93 | 94 | for pred in predictions: 95 | res[well][pred] += 1 96 | 97 | cm.close() 98 | 99 | cm_general.close() 100 | 101 | return res 102 | 103 | def get_classes(self, plate=None, result_folder=None): 104 | if result_folder is None: result_folder = self.input_folder 105 | if plate is None: plate_folder = result_folder 106 | else: 107 | plate_folder = os.path.join(result_folder, plate) 108 | ch5_folder = os.path.join(plate_folder, 'hdf5') 109 | if not os.path.isdir(ch5_folder): 110 | print 'problem with folder settings: hdf5-folder for plate %s not found' % plate 111 | 112 | cm = cellh5.CH5MappedFile(os.path.join(ch5_folder, self.get_first_pos(ch5_folder))) 113 | 114 | classes = cm.class_definition('merged__histonucgrey-expanded-expanded') 115 | class_colors = [x[-1] for x in classes] 116 | class_names = [x[1] for x in classes] 117 | class_labels = [x[0] for x in classes] 118 | 119 | cm.close() 120 | 121 | return classes 122 | 123 | def export_predictions(self, predictions, filename): 124 | print self.output_folder 125 | print filename 126 | fp = open(os.path.join(self.output_folder, filename), 'w') 127 | exp_id = predictions.keys()[0] 128 | phenos = sorted(predictions[exp_id].keys()) 129 | title = '\t'.join([exp_id] + phenos) 130 | fp.write(title + '\n') 131 | for exp_id in predictions: 132 | temp_str = '\t'.join([exp_id] + ['%i' % predictions[exp_id][pheno] for pheno in phenos]) 133 | fp.write(temp_str + '\n') 134 | fp.close() 135 | return 136 | 137 | def mitocheck_analysis(resD=None): 138 | phenos =
['max_Prometaphase', 'max_Metaphase', 'max_MetaphaseAlignment'] 139 | 140 | if resD is None: 141 | filename = '/Users/twalter/data/mitocheck_results_primary_screen/meta_2007_11_06/id_result_file_essential.pickle' 142 | fp = open(filename, 'r') 143 | resD = pickle.load(fp) 144 | fp.close() 145 | 146 | gene_list = ['ATM','BARD1','BRCA1','BRCA2','BRIP1','CASP8','CDH1','CHEK2', 147 | 'CTLA4','CYP19A1','FGFR2','H19','LSP1','MAP3K1','MRE11A','NBN', 148 | 'PALB2','PTEN','RAD51','RAD51C','STK11','TERT','TOX3','TP53','XRCC2','XRCC3'] 149 | for gene in gene_list: 150 | print 151 | print ' **************************** ' 152 | print gene 153 | id_list = filter(lambda x: resD[x]['geneName'].lower()==gene.lower(), resD.keys()) 154 | for exp_id in id_list: 155 | tempStr = '%s\t%s\t%s: ' % (gene, resD[exp_id]['sirnaId'], exp_id) 156 | for pheno in phenos: 157 | tempStr += ' %s: %.4f' % (pheno, resD[exp_id][pheno]) 158 | print tempStr 159 | 160 | return 161 | 162 | def mitocheck_analysis2(resD, filename_hit_table): 163 | phenos = ['max_Metaphase', 'max_Prometaphase', 'max_Apoptosis', 164 | 'max_MetaphaseAlignment'] 165 | sirnas = {} 166 | for exp_id in resD.keys(): 167 | sirna = resD[exp_id]['sirnaId'] 168 | if not sirna in sirnas: 169 | sirnas[sirna] = {'gene': resD[exp_id]['geneName'], 170 | 'idL': []} 171 | sirnas[sirna]['idL'].append(exp_id) 172 | 173 | scores = {} 174 | for sirna in sirnas: 175 | scores[sirna] = {'gene': sirnas[sirna]['gene']} 176 | for pheno in phenos: 177 | scores[sirna][pheno] = np.median([resD[x][pheno] for x in sirnas[sirna]['idL']]) 178 | 179 | # hit lists: 180 | for pheno in phenos: 181 | print 182 | print 183 | print ' ******************************************************** ' 184 | print pheno 185 | score_list = [(sirna, scores[sirna]['gene'], scores[sirna][pheno]) for sirna in scores.keys()] 186 | score_list.sort(key=operator.itemgetter(-1), reverse=True) 187 | for i in range(10): 188 | sirna = score_list[i][0] 189 | print '%s\t%s\t%.5f' % 
(scores[sirna]['gene'], sirna, scores[sirna][pheno] ) 190 | 191 | threshD = {'max_Prometaphase': 0.06, 192 | 'max_Metaphase': 0.03, 193 | 'max_MetaphaseAlignment': 0.06} 194 | 195 | hit_table = {} 196 | for pheno in threshD.keys(): 197 | sirnas = filter(lambda x: scores[x][pheno] > threshD[pheno], scores.keys()) 198 | for sirna in sirnas: 199 | if sirna in hit_table: 200 | continue 201 | hit_table[sirna] = { 202 | 'gene': scores[sirna]['gene'] 203 | } 204 | for pheno in phenos: 205 | hit_table[sirna][pheno] = scores[sirna][pheno] 206 | 207 | hit_table2 = {} 208 | for sirna in hit_table: 209 | gene = hit_table[sirna]['gene'] 210 | if not gene in hit_table2: 211 | hit_table2[gene] = {} 212 | for pheno in phenos: 213 | hit_table2[gene][pheno] = hit_table[sirna][pheno] 214 | else: 215 | for pheno in phenos: 216 | hit_table2[gene][pheno] = max(hit_table2[gene][pheno], 217 | hit_table[sirna][pheno]) 218 | 219 | # export hit_table: 220 | fp = open(filename_hit_table, 'w') 221 | tempStr = '\t'.join(['gene'] + phenos) 222 | fp.write(tempStr + '\n') 223 | for gene in hit_table2: 224 | tempStr = '\t'.join([gene] + 225 | ['%.5f' % hit_table2[gene][pheno] for pheno in phenos]) 226 | fp.write(tempStr + '\n') 227 | fp.close() 228 | 229 | return 230 | 231 | 232 | if __name__ == "__main__": 233 | 234 | description =\ 235 | ''' 236 | %prog - running segmentation tool . 
237 | ''' 238 | 239 | parser = OptionParser(usage="usage: %prog [options]", 240 | description=description) 241 | 242 | parser.add_option("-i", "--input_folder", dest="input_folder", 243 | help="Input folder (raw data)") 244 | parser.add_option("-o", "--output_folder", dest="output_folder", 245 | help="Output folder (properly adjusted images)") 246 | 247 | 248 | (options, args) = parser.parse_args() 249 | si = SimpleAnalyzer(options.input_folder, options.output_folder) 250 | predictions = si.get_counts() 251 | filename = options.input_folder.split('/')[-1] + '.txt' 252 | si.export_predictions(predictions, filename) 253 | 254 | -------------------------------------------------------------------------------- /src_RealData/Data/FIMM_histo/preparation.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | import numpy as np 4 | 5 | import vigra 6 | from optparse import OptionParser 7 | 8 | import pdb 9 | 10 | def get_extensions(in_folder): 11 | 12 | max_width = 0 13 | max_height = 0 14 | image_names = os.listdir(in_folder) 15 | image_names = sorted(filter(lambda x: os.path.splitext(x)[-1].lower() in ['.tif', '.tiff', '.png', '.jpg'], image_names)) 16 | 17 | for image_name in image_names: 18 | img = vigra.readImage(os.path.join(in_folder, image_name)) 19 | width = img.shape[0] 20 | height = img.shape[1] 21 | print '%s: %i, %i' % (image_name, width, height) 22 | 23 | max_width = max(width, max_width) 24 | max_height = max(height, max_height) 25 | 26 | print 'maximal extensions: %i, %i' % (max_width, max_height) 27 | return max_width, max_height 28 | 29 | def get_corner_color(colorin, width): 30 | #avg_color = np.array([ np.mean(colorin[0:width,0:width,i]) for i in range(3)]) 31 | avg_color = np.mean(np.mean(colorin[:width, :width], axis=1), axis=0) 32 | return avg_color 33 | 34 | def adjust_images(in_folder, out_folder, max_width, max_height): 35 | image_names = os.listdir(in_folder) 36 | image_names = 
sorted(filter(lambda x: os.path.splitext(x)[-1].lower() in ['.tif', '.tiff', '.png', '.jpg'], image_names)) 37 | 38 | 39 | 40 | if not os.path.exists(out_folder): 41 | os.makedirs(out_folder) 42 | print 'made %s' % out_folder 43 | for image_name in image_names: 44 | img = vigra.readImage(os.path.join(in_folder, image_name)) 45 | width = img.shape[0] 46 | height = img.shape[1] 47 | # padding: center the image on a new canvas filled with its corner color 48 | if width <= max_width and height <= max_height: 49 | ref_img = vigra.RGBImage((max_width, max_height)) 50 | ref_img[:,:,:] = get_corner_color(img, 5) 51 | 52 | offset_x = (max_width - width) // 2 53 | offset_y = (max_height - height) // 2 54 | 55 | ref_img[offset_x:offset_x + width, offset_y:offset_y + height, :] = img 56 | 57 | elif width > max_width and height > max_height: 58 | # in this case, we have a crop situation 59 | offset_x = (width - max_width) // 2 60 | offset_y = (height - max_height) // 2 61 | 62 | ref_img = img[offset_x:offset_x + max_width, 63 | offset_y:offset_y + max_height, :] 64 | else: 65 | print 'skipping %s: mixed padding and cropping is not supported' % image_name; continue 66 | # export 67 | filename = os.path.join(out_folder, image_name) 68 | vigra.impex.writeImage(ref_img, filename) 69 | 70 | return 71 | 72 | 73 | 74 | if __name__ == "__main__": 75 | 76 | description =\ 77 | ''' 78 | %prog - pads or crops every image of the input folder to a common size.
79 | ''' 80 | 81 | parser = OptionParser(usage="usage: %prog [options]", 82 | description=description) 83 | 84 | parser.add_option("-i", "--input_folder", dest="input_folder", 85 | help="Input folder (raw data)") 86 | parser.add_option("-o", "--output_folder", dest="output_folder", 87 | help="Output folder (properly adjusted images)") 88 | parser.add_option("--max_width", dest="max_width", 89 | help="Maximal width (if not given, it is taken as the maximal width of the images in the input folder)") 90 | parser.add_option("--max_height", dest="max_height", 91 | help="Maximal height (if not given, it is taken as the maximal height of the images in the input folder)") 92 | 93 | (options, args) = parser.parse_args() 94 | 95 | if (options.input_folder is None) or (options.output_folder is None): 96 | parser.error("both --input_folder and --output_folder are required") 97 | 98 | print 99 | print ' ******************* ' 100 | print 'getting the maximal width and maximal height' 101 | if (options.max_width is None) or (options.max_height is None): 102 | max_width, max_height = get_extensions(options.input_folder) 103 | 104 | if options.max_width is not None: 105 | max_width = int(options.max_width) 106 | 107 | if options.max_height is not None: 108 | max_height = int(options.max_height) 109 | 110 | print 111 | print ' ******************* ' 112 | print 'adjusting the images' 113 | adjust_images(options.input_folder, options.output_folder, 114 | max_width, max_height) 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /src_RealData/Data/FIMM_histo/segmentation_test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pdb 3 | 4 | import numpy as np 5 | 6 | import vigra 7 | 8 | 9 | 10 | test_in_folder = '/Users/twalter/data/FIMM_histopath/test_in' 11 | test_out_folder =
'/Users/twalter/data/FIMM_histopath/test_out' 12 | 13 | import skimage 14 | import skimage.io 15 | import skimage.draw 16 | from skimage.feature import blob_doh 17 | from skimage.morphology import disk 18 | from skimage import morphology 19 | from skimage.filters import rank 20 | from skimage import filters 21 | from skimage import color 22 | from skimage import restoration 23 | from skimage.measure import label 24 | from skimage import measure 25 | from skimage.morphology import watershed 26 | from skimage.feature import peak_local_max 27 | 28 | from scipy import ndimage as ndi 29 | 30 | 31 | from math import sqrt 32 | 33 | import matplotlib.pyplot as plt 34 | 35 | from cecog import ccore 36 | 37 | class Segmentation(object): 38 | def __init__(self): 39 | print 'SEGMENTATION' 40 | if not os.path.exists(test_out_folder): 41 | os.makedirs(test_out_folder) 42 | print 'made %s' % test_out_folder 43 | self.image_name = 'H1' 44 | return 45 | 46 | def read_H_image(self, image_name='H1'): 47 | filename = os.path.join(test_in_folder, '%s.png' % image_name) 48 | img = skimage.io.imread(filename) 49 | self.image_name = image_name 50 | return img 51 | 52 | def overlay(self, img, imbin, contour=False): 53 | colim = color.gray2rgb(img) 54 | colorvalue = (0, 100, 200) 55 | if contour: 56 | se = morphology.diamond(2) 57 | ero = morphology.erosion(imbin, se) 58 | grad = imbin - ero 59 | colim[grad > 0] = colorvalue 60 | else: 61 | colim[imbin>0] = colorvalue 62 | 63 | return colim 64 | 65 | def output_blob_detection(self, img, blobs): 66 | colim = color.gray2rgb(img) 67 | 68 | for blob in blobs: 69 | x, y, r = blob 70 | 71 | rr, cc = skimage.draw.circle(x,y,r) 72 | colorvalue = (255, 0, 0) 73 | 74 | if np.min(rr) < 0 or np.min(cc) < 0 or np.max(rr) >= img.shape[0] or np.max(cc) >= img.shape[1]: 75 | continue 76 | 77 | for i, col in enumerate(colorvalue): 78 | colim[rr,cc,i] = col 79 | 80 | return colim 81 | 82 | def blob_detection(self, img): 83 | blobs = blob_doh(img, max_sigma=80, threshold=.001)
84 | return blobs 85 | 86 | def difference_of_gaussian(self, imin, bigsize=30.0, smallsize=3.0): 87 | g1 = filters.gaussian_filter(imin, bigsize) 88 | g2 = filters.gaussian_filter(imin, smallsize) 89 | diff = 255*(g1 - g2) 90 | 91 | diff[diff < 0] = 0.0 92 | diff[diff > 255.0] = 255.0 93 | diff = diff.astype(np.uint8) 94 | 95 | return diff 96 | 97 | def test_dog(self, bigsize=30.0, smallsize=3.0, addon='none'): 98 | img = self.read_H_image() 99 | diff = self.difference_of_gaussian(img, bigsize, smallsize) 100 | filename = os.path.join(test_out_folder, 'dog_%s_%s.png' % (self.image_name, addon)) 101 | skimage.io.imsave(filename, diff) 102 | 103 | diff = self.difference_of_gaussian(-img, bigsize, smallsize) 104 | filename = os.path.join(test_out_folder, 'invdog_%s_%s.png' % (self.image_name, addon)) 105 | skimage.io.imsave(filename, diff) 106 | 107 | return 108 | 109 | def test_blob_detection(self): 110 | img = self.read_H_image() 111 | blobs = self.blob_detection(img) 112 | color_out = self.output_blob_detection(img, blobs) 113 | filename = os.path.join(test_out_folder, 'blobs_%s.png' % self.image_name) 114 | skimage.io.imsave(filename, color_out) 115 | 116 | return 117 | 118 | def morpho_rec(self, img, size=10): 119 | # opening by reconstruction: removes bright structures smaller than the structuring element 120 | se = morphology.diamond(size) 121 | ero = morphology.erosion(img, se) 122 | rec = morphology.reconstruction(ero, img, method='dilation').astype(np.dtype('uint8')) 123 | 124 | return rec 125 | 126 | def morpho_rec2(self, img, size=10): 127 | # closing by reconstruction: fills dark structures smaller than the structuring element 128 | se = morphology.diamond(size) 129 | dil = morphology.dilation(img, se) 130 | rec = morphology.reconstruction(dil, img, method='erosion').astype(np.dtype('uint8')) 131 | 132 | return rec 133 | 134 | def test_morpho(self): 135 | 136 | img = self.read_H_image() 137 | 138 | pref = self.morpho_rec(img, 10) 139 | filename = os.path.join(test_out_folder, 'rec_%s.png' % self.image_name) 140 | skimage.io.imsave(filename, pref) 141 | 142 | diff =
self.difference_of_gaussian(img, 50.0, 2.0) 143 | filename = os.path.join(test_out_folder, 'recdog_%s_%s.png' % (self.image_name, '40_1')) 144 | skimage.io.imsave(filename, diff) 145 | 146 | return 147 | 148 | 149 | def test_morpho2(self, bigsize=20.0, smallsize=3.0, threshold=5.0): 150 | 151 | img = self.read_H_image() 152 | 153 | pref = self.morpho_rec(img, 10) 154 | filename = os.path.join(test_out_folder, 'morpho_00_rec_%s.png' % self.image_name) 155 | skimage.io.imsave(filename, pref) 156 | 157 | res = self.difference_of_gaussian(pref, bigsize, smallsize) 158 | filename = os.path.join(test_out_folder, 'morpho_01_diff_%s_%i_%i.png' % (self.image_name, int(bigsize), int(smallsize))) 159 | skimage.io.imsave(filename, res) 160 | 161 | #res = self.morpho_rec2(diff, 15) 162 | #filename = os.path.join(test_out_folder, 'morpho_02_rec_%s.png' % self.image_name) 163 | #skimage.io.imsave(filename, res) 164 | 165 | res[res>threshold] = 255 166 | filename = os.path.join(test_out_folder, 'morpho_03_res_%s_%i.png' % (self.image_name, threshold)) 167 | skimage.io.imsave(filename, res) 168 | 169 | se = morphology.diamond(3) 170 | ero = morphology.erosion(res, se) 171 | filename = os.path.join(test_out_folder, 'morpho_03_ero_%s_%i.png' % (self.image_name, threshold)) 172 | skimage.io.imsave(filename, ero) 173 | res[ero>0] = 0 174 | 175 | overlay_img = self.overlay(img, res) 176 | filename = os.path.join(test_out_folder, 'morpho_04_overlay_%s_%i.png' % (self.image_name, int(threshold))) 177 | skimage.io.imsave(filename, overlay_img) 178 | 179 | return 180 | 181 | def get_rough_detection(self, img, bigsize=40.0, smallsize=4.0, thresh = 0): 182 | diff = self.difference_of_gaussian(-img, bigsize, smallsize) 183 | diff[diff>thresh] = 1 184 | 185 | se = morphology.square(4) 186 | ero = morphology.erosion(diff, se) 187 | 188 | labimage = label(ero) 189 | #rec = morphology.reconstruction(ero, img, method='dilation').astype(np.dtype('uint8')) 190 | 191 | # connectivity=1 corresponds to 
4-connectivity. 192 | morphology.remove_small_objects(labimage, min_size=600, connectivity=1, in_place=True) 193 | #res = np.zeros(img.shape) 194 | ero[labimage==0] = 0 195 | ero = 1 - ero 196 | labimage = label(ero) 197 | morphology.remove_small_objects(labimage, min_size=400, connectivity=1, in_place=True) 198 | ero[labimage==0] = 0 199 | res = 1 - ero 200 | res[res>0] = 255 201 | 202 | #temp = 255 - temp 203 | #temp = morphology.remove_small_objects(temp, min_size=400, connectivity=1, in_place=True) 204 | #res = 255 - temp 205 | 206 | return res 207 | 208 | def test_rough_detection(self, bigsize=40.0, smallsize=4.0, thresh = 0): 209 | 210 | img = self.read_H_image() 211 | rough = self.get_rough_detection(img, bigsize, smallsize, thresh) 212 | 213 | colorim = self.overlay(img, rough) 214 | filename = os.path.join(test_out_folder, 'roughdetection_%s.png' % (self.image_name)) 215 | skimage.io.imsave(filename, colorim) 216 | filename = os.path.join(test_out_folder, 'roughdetection_original_%s.png' % (self.image_name)) 217 | skimage.io.imsave(filename, img) 218 | 219 | return 220 | 221 | def prefilter(self, img, rec_size=20, se_size=3): 222 | 223 | se = morphology.disk(se_size) 224 | 225 | im1 = self.morpho_rec(img, rec_size) 226 | im2 = self.morpho_rec2(im1, int(rec_size / 2)) 227 | 228 | im3 = morphology.closing(im2, se) 229 | 230 | return im3 231 | 232 | def prefilter_new(self, img, rec_size=20, se_size=3): 233 | 234 | img_cc = ccore.numpy_to_image(img, copy=True) 235 | im1 = ccore.diameter_open(img_cc, rec_size, 8) 236 | im2 = ccore.diameter_close(im1, int(rec_size / 2), 8) 237 | 238 | #im1 = self.morpho_rec(img, rec_size) 239 | #im2 = self.morpho_rec2(im1, int(rec_size / 2)) 240 | 241 | se = morphology.disk(se_size) 242 | im3 = morphology.closing(im2.toArray(), se) 243 | 244 | return im3 245 | 246 | def h_minima(self, img, h): 247 | img_shift = img.copy() 248 | img_shift[img_shift >= 255 - h] = 255-h 249 | img_shift = img_shift + h 250 | rec = 
morphology.reconstruction(img_shift, img, method='erosion').astype(np.dtype('uint8')) 251 | diff = rec - img 252 | return diff 253 | 254 | def diameter_close(self, img, max_size): 255 | img_cc = ccore.numpy_to_image(img, copy=True) 256 | res_cc = ccore.diameter_close(img_cc, max_size, 8) 257 | res = res_cc.toArray() 258 | 259 | return res 260 | 261 | def diameter_tophat(self, img, max_size): 262 | img_cc = ccore.numpy_to_image(img, copy=True) 263 | diam_close_cc = ccore.diameter_close(img_cc, max_size, 8) 264 | diam_close = diam_close_cc.toArray() 265 | res = diam_close - img 266 | return res 267 | 268 | def split(self, img, imbin, alpha=0.5, dynval=2): 269 | 270 | img_cc = ccore.numpy_to_image(img, copy=True) 271 | imbin_cc = ccore.numpy_to_image(imbin.astype(np.dtype('uint8')), copy=True) 272 | 273 | # inversion 274 | imbin_inv = ccore.linearRangeMapping(imbin_cc, 255, 0, 0, 255) 275 | 276 | # distance function of the inverted image 277 | imDist = ccore.distanceTransform(imbin_inv, 2) 278 | 279 | # gradient of the image 280 | imGrad = ccore.externalGradient(img_cc, 1, 8) 281 | 282 | im1 = imDist.toArray() 283 | im2 = imGrad.toArray() 284 | im1 = im1.astype(np.dtype('float32')) 285 | im2 = im2.astype(np.dtype('float32')) 286 | 287 | temp = im1 + alpha * im2 288 | minval = temp.min() 289 | maxval = temp.max() 290 | 291 | if maxval==minval: 292 | return 293 | 294 | temp = 254 / (maxval - minval) * (temp - minval) 295 | temp = temp.astype(np.dtype('uint8')) 296 | temp_cc = ccore.numpy_to_image(temp, copy=True) 297 | 298 | ws = ccore.watershed_dynamic_split(temp_cc, dynval) 299 | res = ccore.infimum(ws, imbin_cc) 300 | 301 | return res 302 | 303 | 304 | 305 | def test_current(self, threshold1=4, threshold2=10): 306 | 307 | img = self.read_H_image() 308 | 309 | pref = self.prefilter(img, 20, 5) 310 | 311 | filename = os.path.join(test_out_folder, 'current_01_prefilter_%s.png' % self.image_name) 312 | skimage.io.imsave(filename, pref) 313 | 314 | diff1 =
self.h_minima(pref, h=15) 315 | filename = os.path.join(test_out_folder, 'current_02_h_tophat_%s.png' % self.image_name) 316 | skimage.io.imsave(filename, 4*diff1) 317 | 318 | diff2 = self.diameter_tophat(pref, 80) 319 | filename = os.path.join(test_out_folder, 'current_03_diam_tophat_%s.png' % self.image_name) 320 | skimage.io.imsave(filename, diff2) 321 | 322 | res1 = np.zeros(diff1.shape) 323 | res1[diff1>threshold1] = 255 324 | 325 | res2 = np.zeros(diff2.shape) 326 | res2[diff2>threshold2] = 255 327 | 328 | overlay_img = self.overlay(img, res1, contour=True) 329 | filename = os.path.join(test_out_folder, 'current_04_overlay_h_tophat_%s.png' % self.image_name) 330 | skimage.io.imsave(filename, overlay_img) 331 | 332 | overlay_img = self.overlay(img, res2, contour=True) 333 | filename = os.path.join(test_out_folder, 'current_05_overlay_diamthresh_%s.png' % self.image_name) 334 | skimage.io.imsave(filename, overlay_img) 335 | 336 | res = res1 337 | res[res2>0] = 255 338 | overlay_img = self.overlay(img, res, contour=True) 339 | filename = os.path.join(test_out_folder, 'current_06_overlay_all_%s.png' % self.image_name) 340 | skimage.io.imsave(filename, overlay_img) 341 | 342 | filename = os.path.join(test_out_folder, 'current_07_original_%s.png' % self.image_name) 343 | skimage.io.imsave(filename, img) 344 | 345 | res_final = self.split(pref, res) 346 | 347 | # prefiltering removing bright structures inside the cells (opening by reconstruction) 348 | #pref = self.morpho_rec(img, 10) 349 | #filename = os.path.join(test_out_folder, 'current_01_%s.png' % self.image_name) 350 | #skimage.io.imsave(filename, pref) 351 | 352 | #res = self.difference_of_gaussian(pref, bigsize, smallsize) 353 | #filename = os.path.join(test_out_folder, 'current_02_%s_%i_%i.png' % (self.image_name, int(bigsize), int(smallsize))) 354 | #skimage.io.imsave(filename, res) 355 | return 356 | 357 | 358 | 359 | if __name__ == "__main__": 360 | 361 | description =\ 362 | ''' 363 | %prog - running 
segmentation tool. 364 | ''' 365 | 366 | segmentation = Segmentation() 367 | segmentation.test_current() 368 | 369 | #segmentation.test_rough_detection() 370 | 371 | #segmentation.test_blob_detection() 372 | #segmentation.test_dog(30.0, 3.0, addon='30_3') 373 | #segmentation.test_dog(10.0, 1.0, addon='10_1') 374 | #segmentation.test_morpho2() 375 | #parser = OptionParser(usage="usage: %prog [options]", 376 | # description=description) 377 | 378 | #parser.add_option("-i", "--input_image", dest="input_image", 379 | # help="Input image") 380 | #parser.add_option("-o", "--output_folder", dest="output_folder", 381 | # help="Output folder") 382 | 383 | #(options, args) = parser.parse_args() 384 | 385 | # if (options.input_image is None) or (options.output_folder is None): 386 | # parser.error("incorrect number of arguments!") 387 | # 388 | # dec = Deconvolution() 389 | # dec(options.input_image, options.output_folder) 390 | 391 | print 'DONE' 392 | 393 | 394 | 395 | -------------------------------------------------------------------------------- /src_RealData/Data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeterJackNaylor/DRFNS/73fc5683db5e9f860846e22c8c0daf73b7103082/src_RealData/Data/__init__.py -------------------------------------------------------------------------------- /src_RealData/Data/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | from skimage.morphology import dilation, erosion, square 6 | 7 | 8 | def generate_wsl(ws): 9 | """ 10 | Generates the watershed line. In particular, useful for separating objects 11 | in the ground truth, as they are labeled by different integers.
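Below is a minimal NumPy-only sketch of the same logic, for illustration. It assumes a 3x3 square neighbourhood with edge replication at the image border. The helpers `_filter3x3` and `watershed_line_sketch` are hypothetical names and are not part of this repository.

Background pixels are raised above every label before the minimum filter. As a result, only pixels whose 3x3 neighbourhood contains two different non-zero labels are marked. Object edges that touch only the background stay unmarked.

```python
import numpy as np


def _filter3x3(img, reducer):
    # 3x3 min/max filter; the border is handled by edge replication
    # (assumption made for this sketch).
    padded = np.pad(img, 1, mode="edge")
    h, w = img.shape
    # Stack the 9 shifted views of the padded image and reduce over them.
    shifts = [padded[dy:dy + h, dx:dx + w] for dy in range(3) for dx in range(3)]
    return reducer(np.stack(shifts), axis=0)


def watershed_line_sketch(ws):
    # ws: 2-D integer label image, 0 = background, 1..N = objects.
    ero = ws.astype(np.int64)  # astype returns a copy
    ero[ero == 0] = ero.max() + 1  # background must never win the minimum
    ero = _filter3x3(ero, np.min)
    ero[ws == 0] = 0
    grad = _filter3x3(ws.astype(np.int64), np.max) - ero
    grad[ws == 0] = 0
    # Non-zero where the 3x3 neighbourhood holds a different object label.
    return np.where(grad > 0, 255, 0).astype(np.uint8)
```

For two 3x2 objects placed side by side, only the two columns where they touch are set to 255. A single isolated object produces no line at all.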
12 | """ 13 | se = square(3) 14 | ero = ws.copy() 15 | ero[ero == 0] = ero.max() + 1 16 | ero = erosion(ero, se) 17 | ero[ws == 0] = 0 18 | 19 | grad = dilation(ws, se) - ero 20 | grad[ws == 0] = 0 21 | grad[grad > 0] = 255 22 | grad = grad.astype(np.uint8) 23 | return grad -------------------------------------------------------------------------------- /src_RealData/Dist.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from utils import GetOptions, ComputeMetrics, CheckOrCreate 5 | import tensorflow as tf 6 | import numpy as np 7 | import os 8 | from glob import glob 9 | from Data.ImageTransform import ListTransform 10 | from Data.DataGenClass import DataGenMulti 11 | from Nets.DataReadDecode import read_and_decode 12 | from Nets.UNetDistance import UNetDistance 13 | 14 | 15 | 16 | class Model(UNetDistance): 17 | def test(self, p1, p2, steps): 18 | """ 19 | How the model tests 20 | """ 21 | loss, roc = 0., 0. 22 | acc, F1, recall = 0., 0., 0. 23 | precision, jac, AJI = 0., 0., 0. 
24 | init_op = tf.group(tf.global_variables_initializer(), 25 | tf.local_variables_initializer()) 26 | self.sess.run(init_op) 27 | self.Saver() 28 | coord = tf.train.Coordinator() 29 | threads = tf.train.start_queue_runners(coord=coord) 30 | 31 | for step in range(steps): 32 | feed_dict = {self.is_training: False} 33 | l, prob, batch_labels = self.sess.run([self.loss, self.predictions, 34 | self.train_labels_node], feed_dict=feed_dict) 35 | prob[prob > 255] = 255 36 | prob[prob < 0] = 0 37 | prob = prob.astype(int) 38 | batch_labels[batch_labels > 0] = 255 39 | loss += l 40 | out = ComputeMetrics(prob[0], batch_labels[0], p1, p2) 41 | acc += out[0] 42 | roc += out[1] 43 | jac += out[2] 44 | recall += out[3] 45 | precision += out[4] 46 | F1 += out[5] 47 | AJI += out[6] 48 | coord.request_stop() 49 | coord.join(threads) 50 | loss, acc, F1 = np.array([loss, acc, F1]) / steps 51 | recall, precision, roc = np.array([recall, precision, roc]) / steps 52 | jac, AJI = np.array([jac, AJI]) / steps 53 | return loss, acc, F1, recall, precision, roc, jac, AJI 54 | 55 | def validation(self, DG_TEST, p1, p2, save_path): 56 | """ 57 | How the model validates 58 | """ 59 | n_test = DG_TEST.length 60 | n_batch = int(np.ceil(float(n_test) / self.BATCH_SIZE)) 61 | res = [] 62 | 63 | for i in range(n_batch): 64 | Xval, Yval = DG_TEST.Batch(0, self.BATCH_SIZE) 65 | feed_dict = {self.input_node: Xval, 66 | self.train_labels_node: Yval, 67 | self.is_training: False} 68 | l, pred = self.sess.run([self.loss, self.predictions], 69 | feed_dict=feed_dict) 70 | pred[pred > 255] = 255 71 | pred[pred < 0] = 0 72 | pred = pred.astype(int) 73 | rgb = (Xval[0,92:-92,92:-92] + np.load(self.MEAN_FILE)).astype(np.uint8) 74 | Yval[ Yval > 0 ] = 255 75 | out = ComputeMetrics(pred[0,:,:], Yval[0,:,:,0], p1, p2, rgb=rgb, save_path=save_path, ind=i) 76 | out = [l] + list(out) 77 | res.append(out) 78 | return res 79 | 80 | 81 | 82 | if __name__== "__main__": 83 | 84 | transform_list, transform_list_test = 
ListTransform(n_elastic=0) 85 | options = GetOptions() 86 | 87 | SPLIT = options.split 88 | 89 | ## Model parameters 90 | TFRecord = options.TFRecord 91 | LEARNING_RATE = options.lr 92 | BATCH_SIZE = options.bs 93 | SIZE = (options.size_train, options.size_train) 94 | if options.size_test is not None: 95 | SIZE = (options.size_test, options.size_test) 96 | N_ITER_MAX = 0 ## defined later 97 | LRSTEP = "10epoch" 98 | N_TRAIN_SAVE = 1000 99 | LOG = options.log 100 | WEIGHT_DECAY = options.weight_decay 101 | N_FEATURES = options.n_features 102 | N_EPOCH = options.epoch 103 | N_THREADS = options.THREADS 104 | MEAN_FILE = options.mean_file 105 | DROPOUT = options.dropout 106 | 107 | ## Datagen parameters 108 | PATH = options.path 109 | 110 | TEST_PATIENT = ["testbreast", "testliver", "testkidney", "testprostate", 111 | "bladder", "colorectal", "stomach", "test"] 112 | DG_TRAIN = DataGenMulti(PATH, split='train', crop = 16, size=SIZE, 113 | transforms=transform_list, UNet=True, mean_file=None, num=TEST_PATIENT) 114 | TEST_PATIENT = ["test"] 115 | DG_TEST = DataGenMulti(PATH, split="test", crop = 1, size=(500, 500), 116 | transforms=transform_list_test, UNet=True, mean_file=MEAN_FILE, num=TEST_PATIENT) 117 | if SPLIT == "train": 118 | N_ITER_MAX = N_EPOCH * DG_TRAIN.length // BATCH_SIZE 119 | elif SPLIT == "test": 120 | N_ITER_MAX = N_EPOCH * DG_TEST.length // BATCH_SIZE 121 | elif SPLIT == "validation": 122 | LOG = glob(os.path.join(LOG, '*'))[0] 123 | model = Model(TFRecord, LEARNING_RATE=LEARNING_RATE, 124 | BATCH_SIZE=BATCH_SIZE, 125 | IMAGE_SIZE=SIZE, 126 | NUM_CHANNELS=3, 127 | STEPS=N_ITER_MAX, 128 | LRSTEP=LRSTEP, 129 | N_PRINT=N_TRAIN_SAVE, 130 | LOG=LOG, 131 | SEED=42, 132 | WEIGHT_DECAY=WEIGHT_DECAY, 133 | N_FEATURES=N_FEATURES, 134 | N_EPOCH=N_EPOCH, 135 | N_THREADS=N_THREADS, 136 | MEAN_FILE=MEAN_FILE, 137 | DROPOUT=DROPOUT) 138 | if SPLIT == "train": 139 | model.train(DG_TEST) 140 | elif SPLIT == "test": 141 | p1 = options.p1 142 | p2 = options.p2 143 | 
file_name = options.output 144 | f = open(file_name, 'w') 145 | outs = model.test(p1, p2, N_ITER_MAX) 146 | outs = [LOG] + list(outs) + [p1, p2] 147 | NAMES = ["ID", "Loss", "Acc", "F1", "Recall", "Precision", "ROC", "Jaccard", "AJI", "p1", "p2"] 148 | f.write('{},{},{},{},{},{},{},{},{},{},{}\n'.format(*NAMES)) 149 | 150 | f.write('{},{},{},{},{},{},{},{},{},{},{}\n'.format(*outs)) 151 | f.close() 152 | elif SPLIT == "validation": 153 | 154 | TEST_PATIENT = ["testbreast", "testliver", "testkidney", "testprostate", 155 | "bladder", "colorectal", "stomach"] 156 | 157 | file_name = options.output 158 | f = open(file_name, 'w') 159 | NAMES = ["NUMBER", "ORGAN", "Loss", "Acc", "ROC", "Jaccard", "Recall", "Precision", "F1", "AJI", "p1", "p2"] 160 | f.write('{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(*NAMES)) 161 | 162 | 163 | for organ in TEST_PATIENT: 164 | DG_TEST = DataGenMulti(PATH, split="test", crop = 1, size=(996, 996),num=[organ], 165 | transforms=transform_list_test, UNet=True, mean_file=MEAN_FILE) 166 | save_organ = os.path.join(options.save_path, organ) 167 | CheckOrCreate(save_organ) 168 | outs = model.validation(DG_TEST, options.p1, options.p2, save_organ) 169 | for i in range(len(outs)): 170 | small_o = outs[i] 171 | small_o = [i, organ] + small_o + [options.p1, options.p2] 172 | f.write('{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(*small_o)) 173 | f.close() 174 | 175 | -------------------------------------------------------------------------------- /src_RealData/FCN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from FCN_Object import FCN8 5 | from utils import GetOptions, ComputeMetrics, CheckOrCreate 6 | import tensorflow as tf 7 | import numpy as np 8 | import os 9 | from Data.ImageTransform import ListTransform 10 | from Nets.DataReadDecode import read_and_decode 11 | from Data.DataGenClass import DataGenMulti 12 | 13 | if __name__== "__main__": 14 | 15 | options
= GetOptions() 16 | 17 | SPLIT = options.split 18 | 19 | ## Model parameters 20 | TFRecord = options.TFRecord 21 | LEARNING_RATE = options.lr 22 | SIZE = (options.size_train, options.size_train) 23 | if options.size_test is not None: 24 | SIZE = (options.size_test, options.size_test) 25 | N_ITER_MAX = options.iters 26 | N_TRAIN_SAVE = 1000 27 | MEAN_FILE = options.mean_file 28 | save_dir = options.log 29 | checkpoint = options.restore 30 | 31 | model = FCN8(checkpoint, save_dir, TFRecord, SIZE[0], 32 | 2, 1000, options.split) 33 | 34 | if SPLIT == "train": 35 | model.train8(N_ITER_MAX, LEARNING_RATE) 36 | elif SPLIT == "test": 37 | p1 = options.p1 38 | LOG = options.log 39 | 40 | file_name = options.output 41 | f = open(file_name, 'w') 42 | 43 | checkpoint = os.path.join(checkpoint, checkpoint) 44 | outs = model.test8(N_ITER_MAX, checkpoint, p1) 45 | outs = [LOG] + list(outs) + [p1, 0.5] 46 | NAMES = ["ID", "Loss", "Acc", "F1", "Recall", "Precision", "ROC", "Jaccard", "AJI", "p1", "p2"] 47 | f.write('{},{},{},{},{},{},{},{},{},{},{}\n'.format(*NAMES)) 48 | f.write('{},{},{},{},{},{},{},{},{},{},{}\n'.format(*outs)) 49 | f.close() 50 | elif SPLIT == "validation": 51 | 52 | TEST_PATIENT = ["testbreast", "testliver", "testkidney", "testprostate", 53 | "bladder", "colorectal", "stomach"] 54 | 55 | file_name = options.output 56 | f = open(file_name, 'w') 57 | NAMES = ["NUMBER", "ORGAN", "Loss", "Acc", "ROC", "Jaccard", "Recall", "Precision", "F1", "AJI", "p1", "p2"] 58 | f.write('{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(*NAMES)) 59 | transform_list, transform_list_test = ListTransform() 60 | PATH = options.path 61 | for organ in TEST_PATIENT: 62 | DG_TEST = DataGenMulti(PATH, split="test", crop = 1, size=(1000, 1000),num=[organ], 63 | transforms=transform_list_test, UNet=False, mean_file=None) 64 | save_organ = os.path.join(options.save_path, organ) 65 | CheckOrCreate(save_organ) 66 | outs = model.validation(DG_TEST, 2, options.p1, 0.5, save_organ) 67 | for i in
range(len(outs)): 68 | small_o = outs[i] 69 | small_o = [i, organ] + small_o + [options.p1, 0.5] 70 | f.write('{},{},{},{},{},{},{},{},{},{},{},{}\n'.format(*small_o)) 71 | f.close() 72 | -------------------------------------------------------------------------------- /src_RealData/FCN_Object.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | import copy 6 | from tf_image_segmentation.models.fcn_8s import FCN_8s 7 | from tf_image_segmentation.utils.tf_records import read_tfrecord_and_decode_into_image_annotation_pair_tensors 8 | from tf_image_segmentation.utils.training import get_valid_logits_and_labels 9 | from tf_image_segmentation.utils.inference import adapt_network_for_any_size_input 10 | from tf_image_segmentation.utils.visualization import visualize_segmentation_adaptive 11 | from tf_image_segmentation.utils.augmentation import (distort_randomly_image_color, 12 | flip_randomly_left_right_image_with_annotation, 13 | scale_randomly_image_with_annotation_with_fixed_size_output) 14 | import os 15 | from glob import glob 16 | import numpy as np 17 | from scipy import misc 18 | slim = tf.contrib.slim 19 | from utils import ComputeMetrics 20 | 21 | class FCN8(): 22 | """ 23 | FCN8 object for performing training, testing and validating.
24 | """ 25 | def __init__(self, checkpoint, save_dir, record, 26 | size, num_labels, n_print, split='train'): 27 | 28 | self.checkpoint8 = checkpoint 29 | self.Setuped = False 30 | self.checkpointnew = save_dir 31 | self.record = record 32 | self.size = size 33 | self.num_labels = num_labels 34 | self.n_print = n_print 35 | self.setup_record() 36 | self.class_labels = range(num_labels) 37 | self.class_labels.append(255) 38 | if split != "validation": 39 | self.fcn_8s_checkpoint_path = glob(self.checkpoint8 + "/*.data*")[0].split(".data")[0] 40 | else: 41 | self.setup_val(record) 42 | def setup_record(self): 43 | """ 44 | Sets up record reading. 45 | """ 46 | filename_queue = tf.train.string_input_producer( 47 | [self.record], num_epochs=10) 48 | 49 | self.image, self.annotation = read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue) 50 | self.resized_image, resized_annotation = scale_randomly_image_with_annotation_with_fixed_size_output(self.image, self.annotation, (self.size, self.size)) 51 | self.resized_annotation = tf.squeeze(resized_annotation) 52 | def setup_train8(self, lr): 53 | """ 54 | Sets up the input queues, the network, the loss and the optimizer. 55 | """ 56 | 57 | 58 | image_batch, annotation_batch = tf.train.shuffle_batch( [self.resized_image, self.resized_annotation], 59 | batch_size=1, 60 | capacity=3000, 61 | num_threads=2, 62 | min_after_dequeue=1000) 63 | 64 | 65 | upsampled_logits_batch, fcn_8s_variables_mapping = FCN_8s(image_batch_tensor=image_batch, 66 | number_of_classes=self.num_labels, 67 | is_training=True) 68 | 69 | valid_labels_batch_tensor, valid_logits_batch_tensor = get_valid_logits_and_labels(annotation_batch_tensor=annotation_batch, 70 | logits_batch_tensor=upsampled_logits_batch, 71 | class_labels=self.class_labels) 72 | 73 | 74 | # Flattened ground-truth labels (not used further in this method).
75 | actual = tf.contrib.layers.flatten(tf.cast(annotation_batch, tf.int64)) 76 | 77 | self.predicted_img = tf.argmax(upsampled_logits_batch, axis=3) 78 | cross_entropies = tf.nn.softmax_cross_entropy_with_logits(logits=valid_logits_batch_tensor, 79 | labels=valid_labels_batch_tensor) 80 | self.cross_entropy_sum = tf.reduce_mean(cross_entropies) 81 | 82 | pred = tf.argmax(upsampled_logits_batch, dimension=3) 83 | 84 | probabilities = tf.nn.softmax(upsampled_logits_batch) 85 | 86 | 87 | with tf.variable_scope("adam_vars"): 88 | self.train_step = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.cross_entropy_sum) 89 | 90 | 91 | # Variable's initialization functions 92 | 93 | self.init_fn = slim.assign_from_checkpoint_fn(model_path=self.fcn_8s_checkpoint_path, 94 | var_list=fcn_8s_variables_mapping) 95 | 96 | global_vars_init_op = tf.global_variables_initializer() 97 | 98 | self.merged_summary_op = tf.summary.merge_all() 99 | 100 | self.summary_string_writer = tf.summary.FileWriter("./log_8_{}".format(lr)) 101 | 102 | if not os.path.exists(self.checkpointnew): 103 | os.makedirs(self.checkpointnew) 104 | 105 | #The op for initializing the variables. 106 | local_vars_init_op = tf.local_variables_initializer() 107 | 108 | self.combined_op = tf.group(local_vars_init_op, global_vars_init_op) 109 | 110 | # We need this to save only model variables and omit 111 | # optimization-related and other variables. 112 | model_variables = slim.get_model_variables() 113 | self.saver = tf.train.Saver(model_variables) 114 | 115 | def train8(self, iters, lr): 116 | """ 117 | Trains the model. 
118 | """ 119 | 120 | 121 | self.setup_train8(lr) 122 | model_save = os.path.join(self.checkpointnew, self.checkpointnew) 123 | with tf.Session() as sess: 124 | 125 | sess.run(self.combined_op) 126 | self.init_fn(sess) 127 | 128 | coord = tf.train.Coordinator() 129 | threads = tf.train.start_queue_runners(coord=coord) 130 | 131 | # 10 epochs 132 | for i in xrange(0, iters): 133 | cross_entropy, summary_string, _ = sess.run([self.cross_entropy_sum, 134 | self.merged_summary_op, 135 | self.train_step ]) 136 | 137 | if i % self.n_print == 0: 138 | self.summary_string_writer.add_summary(summary_string, i) 139 | save_path = self.saver.save(sess, model_save) 140 | print("Model saved in file: %s" % save_path) 141 | 142 | 143 | coord.request_stop() 144 | coord.join(threads) 145 | 146 | save_path = self.saver.save(sess, model_save) 147 | print("Model saved in file: %s" % save_path) 148 | 149 | 150 | def test8(self, steps, restore, p1): 151 | """ 152 | Tests the model. 153 | """ 154 | # Fake batch for image and annotation by adding 155 | # leading empty axis. 156 | image_batch_tensor = tf.expand_dims(self.image, axis=0) 157 | annotation_batch_tensor = tf.expand_dims(self.annotation, axis=0) 158 | # Be careful: after adaptation, network returns final labels 159 | # and not logits 160 | FCN_8s_bis = adapt_network_for_any_size_input(FCN_8s, 32) 161 | 162 | 163 | pred, fcn_16s_variables_mapping = FCN_8s_bis(image_batch_tensor=image_batch_tensor, 164 | number_of_classes=self.num_labels, 165 | is_training=False) 166 | prob = [h for h in [s for s in [t for t in pred.op.inputs][0].op.inputs][0].op.inputs][0] 167 | 168 | initializer = tf.local_variables_initializer() 169 | saver = tf.train.Saver() 170 | loss, roc = 0., 0. 171 | acc, F1, recall = 0., 0., 0. 172 | precision, jac, AJI = 0., 0., 0. 
173 | with tf.Session() as sess: 174 | 175 | sess.run(initializer) 176 | saver.restore(sess, restore) 177 | 178 | coord = tf.train.Coordinator() 179 | threads = tf.train.start_queue_runners(coord=coord) 180 | 181 | 182 | for i in xrange(steps): 183 | 184 | image_np, annotation_np, pred_np, prob_np = sess.run([self.image, self.annotation, pred, prob]) 185 | prob_float = np.exp(-prob_np[0,:,:,0]) / (np.exp(-prob_np[0,:,:,0]) + np.exp(-prob_np[0,:,:,1])) 186 | prob_int8 = misc.imresize(prob_float, size=image_np[:,:,0].shape) 187 | 188 | prob_float = (prob_int8.copy().astype(float) / 255) 189 | out = ComputeMetrics(prob_float, annotation_np[:,:,0], p1, 0.5) 190 | acc += out[0] 191 | roc += out[1] 192 | jac += out[2] 193 | recall += out[3] 194 | precision += out[4] 195 | F1 += out[5] 196 | AJI += out[6] 197 | coord.request_stop() 198 | coord.join(threads) 199 | loss, acc, F1 = np.array([loss, acc, F1]) / steps 200 | recall, precision, roc = np.array([recall, precision, roc]) / steps 201 | jac, AJI = np.array([jac, AJI]) / steps 202 | return loss, acc, F1, recall, precision, roc, jac, AJI 203 | def setup_val(self, tfname): 204 | """ 205 | Sets up the model for validation.
206 | """ 207 | self.restore = glob(os.path.join(self.checkpoint8, "FCN__*", "*.data*" ))[0].split(".data")[0] 208 | 209 | filename_queue = tf.train.string_input_producer( 210 | [tfname], num_epochs=10) 211 | self.image_queue, self.annotation_queue = read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue) 212 | self.image = tf.placeholder_with_default(self.image_queue, shape=[None, 213 | None, 214 | 3]) 215 | self.annotation = tf.placeholder_with_default(self.annotation_queue, shape=[None, 216 | None, 217 | 1]) 218 | self.resized_image, resized_annotation = scale_randomly_image_with_annotation_with_fixed_size_output(self.image, self.annotation, (self.size, self.size)) 219 | self.resized_annotation = tf.squeeze(resized_annotation) 220 | image_batch_tensor = tf.expand_dims(self.image, axis=0) 221 | annotation_batch_tensor = tf.expand_dims(self.annotation, axis=0) 222 | # Be careful: after adaptation, network returns final labels 223 | # and not logits 224 | FCN_8s_bis = adapt_network_for_any_size_input(FCN_8s, 32) 225 | self.pred, fcn_16s_variables_mapping = FCN_8s_bis(image_batch_tensor=image_batch_tensor, 226 | number_of_classes=self.num_labels, 227 | is_training=False) 228 | self.prob = [h for h in [s for s in [t for t in self.pred.op.inputs][0].op.inputs][0].op.inputs][0] 229 | initializer = tf.local_variables_initializer() 230 | self.saver = tf.train.Saver() 231 | with tf.Session() as sess: 232 | sess.run(initializer) 233 | self.saver.restore(sess, self.restore) 234 | def validation(self, DG_TEST, steps, p1, p2, save_organ): 235 | """ 236 | Validates the model.
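Both `test8` and `validation` turn the two-channel network output into a foreground probability with `np.exp(-x0) / (np.exp(-x0) + np.exp(-x1))`. Assume the tensor recovered as `prob` holds raw logits, with `x0` for background and `x1` for foreground. Under that assumption, the expression is exactly the softmax probability of class 1, i.e. `sigmoid(x1 - x0)`.

Below is a minimal NumPy sketch of an equivalent that cannot overflow. The name `foreground_probability` is hypothetical and not part of this repository.

```python
import numpy as np


def foreground_probability(logits):
    # logits: array of shape (..., 2); channel 0 = background,
    # channel 1 = foreground (assumed channel ordering).
    z = logits[..., 1] - logits[..., 0]
    # sigmoid(z) written through tanh, which stays finite for any z.
    return 0.5 * (1.0 + np.tanh(0.5 * z))
```

When a logit is a large negative number (e.g. -1000), `np.exp(-x)` overflows to `inf` and the original expression evaluates `inf / inf`, giving NaN. The tanh form stays within [0, 1].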
237 | """ 238 | tmp_name = os.path.basename(save_organ) + ".tfRecord" 239 | res = [] 240 | print "Validating {}".format(save_organ) 241 | with tf.Session() as sess: 242 | sess.run(tf.global_variables_initializer()) 243 | 244 | self.saver.restore(sess, self.restore) 245 | #coord = tf.train.Coordinator() 246 | #threads = tf.train.start_queue_runners(coord=coord) 247 | for i in xrange(steps): 248 | Xval, Yval = DG_TEST.Batch(0, 1) 249 | 250 | feed_dict = {self.image: Xval[0], self.annotation: Yval[0]} 251 | image_np, annotation_np, pred_np, prob_np = sess.run([self.image, self.annotation, self.pred, self.prob], feed_dict=feed_dict) 252 | prob_float = np.exp(-prob_np[0,:,:,0]) / (np.exp(-prob_np[0,:,:,0]) + np.exp(-prob_np[0,:,:,1])) 253 | prob_int8 = misc.imresize(prob_float, size=image_np[:,:,0].shape) 254 | 255 | prob_float = (prob_int8.copy().astype(float) / 255) 256 | out = ComputeMetrics(prob_float, annotation_np[:,:,0], p1, p2, rgb=image_np, save_path=save_organ, ind=i) 257 | out = [0.] + list(out) 258 | res.append(out) 259 | return res 260 | def copy(self): 261 | """ 262 | Returns a deep copy of this object. 263 | """ 264 | return copy.deepcopy(self) 265 | -------------------------------------------------------------------------------- /src_RealData/Nets/DataReadDecode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import tensorflow as tf 5 | 6 | 7 | def read_and_decode(filename_queue, IMAGE_HEIGHT, IMAGE_WIDTH, 8 | BATCH_SIZE, N_THREADS, UNET, CHANNELS=3): 9 | """ 10 | Read and decode images from TFRecord file. 11 | """ 12 | reader = tf.TFRecordReader() 13 | 14 | _, serialized_example = reader.read(filename_queue) 15 | if not UNET: 16 | features = tf.parse_single_example( 17 | serialized_example, 18 | # Defaults are not specified since all keys are required.
19 | features={ 20 | 'height': tf.FixedLenFeature([], tf.int64), 21 | 'width': tf.FixedLenFeature([], tf.int64), 22 | 'image_raw': tf.FixedLenFeature([], tf.string), 23 | 'mask_raw': tf.FixedLenFeature([], tf.string) 24 | }) 25 | 26 | height_img = tf.cast(features['height'], tf.int32) 27 | width_img = tf.cast(features['width'], tf.int32) 28 | 29 | height_mask = height_img 30 | width_mask = width_img 31 | 32 | const_IMG_HEIGHT = IMAGE_HEIGHT 33 | const_IMG_WIDTH = IMAGE_WIDTH 34 | 35 | const_MASK_HEIGHT = IMAGE_HEIGHT 36 | const_MASK_WIDTH = IMAGE_WIDTH 37 | 38 | 39 | else: 40 | features = tf.parse_single_example( 41 | serialized_example, 42 | # Defaults are not specified since all keys are required. 43 | features={ 44 | 'height_img': tf.FixedLenFeature([], tf.int64), 45 | 'width_img': tf.FixedLenFeature([], tf.int64), 46 | 'height_mask': tf.FixedLenFeature([], tf.int64), 47 | 'width_mask': tf.FixedLenFeature([], tf.int64), 48 | 'image_raw': tf.FixedLenFeature([], tf.string), 49 | 'mask_raw': tf.FixedLenFeature([], tf.string) 50 | }) 51 | 52 | height_img = tf.cast(features['height_img'], tf.int32) 53 | width_img = tf.cast(features['width_img'], tf.int32) 54 | 55 | height_mask = tf.cast(features['height_mask'], tf.int32) 56 | width_mask = tf.cast(features['width_mask'], tf.int32) 57 | # UNet inputs are 184 pixels (2 * 92) larger than the output mask in each dimension. 58 | const_IMG_HEIGHT = IMAGE_HEIGHT + 184 59 | const_IMG_WIDTH = IMAGE_WIDTH + 184 60 | 61 | const_MASK_HEIGHT = IMAGE_HEIGHT 62 | const_MASK_WIDTH = IMAGE_WIDTH 63 | 64 | 65 | # Convert each scalar string tensor (the raw bytes stored in the 66 | # record) to a flat uint8 tensor; it is reshaped to 67 | # [height, width, channels] below.
68 | image = tf.decode_raw(features['image_raw'], tf.uint8) 69 | annotation = tf.decode_raw(features['mask_raw'], tf.uint8) 70 | 71 | 72 | image_shape = tf.stack([height_img, width_img, CHANNELS]) 73 | annotation_shape = tf.stack([height_mask, width_mask, 1]) 74 | 75 | image = tf.reshape(image, image_shape) 76 | annotation = tf.reshape(annotation, annotation_shape) 77 | 78 | image_size_const = tf.constant((const_IMG_HEIGHT, const_IMG_WIDTH, CHANNELS), dtype=tf.int32) 79 | annotation_size_const = tf.constant((const_MASK_HEIGHT, const_MASK_WIDTH, 1), dtype=tf.int32) 80 | 81 | # Random transformations can be put here, right before the images 82 | # are cropped or padded to the predefined size. 83 | 84 | image_f = tf.cast(image, tf.float32) 85 | annotation_f = tf.cast(annotation, tf.float32) 86 | 87 | resized_image = tf.image.resize_image_with_crop_or_pad(image=image_f, 88 | target_height=const_IMG_HEIGHT, 89 | target_width=const_IMG_WIDTH) 90 | 91 | resized_annotation = tf.image.resize_image_with_crop_or_pad(image=annotation_f, 92 | target_height=const_MASK_HEIGHT, 93 | target_width=const_MASK_WIDTH) 94 | 95 | images, annotations = tf.train.shuffle_batch( [resized_image, resized_annotation], 96 | batch_size=BATCH_SIZE, 97 | capacity=10 + 3 * BATCH_SIZE, 98 | num_threads=N_THREADS, 99 | min_after_dequeue=10) 100 | 101 | return images, annotations 102 | -------------------------------------------------------------------------------- /src_RealData/Nets/DataTF.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from ObjectOriented import ConvolutionalNeuralNetwork 5 | import os 6 | import tensorflow as tf 7 | from datetime import datetime 8 | import numpy as np 9 | from DataReadDecode import read_and_decode 10 | 11 | 12 | class DataReader(ConvolutionalNeuralNetwork): 13 | """ 14 | Class for TF modules that read a TFRecord file to feed the images
15 | and labels to the DNN. 16 | """ 17 | def __init__( 18 | self, 19 | TF_RECORDS, 20 | LEARNING_RATE=0.01, 21 | K=0.96, 22 | BATCH_SIZE=10, 23 | IMAGE_SIZE=28, 24 | NUM_LABELS=10, 25 | NUM_CHANNELS=1, 26 | NUM_TEST=10000, 27 | STEPS=2000, 28 | LRSTEP=200, 29 | DECAY_EMA=0.9999, 30 | N_PRINT = 100, 31 | LOG="/tmp/net", 32 | SEED=42, 33 | DEBUG=True, 34 | WEIGHT_DECAY=0.00005, 35 | LOSS_FUNC=tf.nn.l2_loss, 36 | N_FEATURES=16, 37 | N_EPOCH=1, 38 | N_THREADS=1, 39 | MEAN_FILE=None, 40 | DROPOUT=0.5): 41 | 42 | self.LEARNING_RATE = LEARNING_RATE 43 | self.K = K 44 | self.BATCH_SIZE = BATCH_SIZE 45 | self.IMAGE_SIZE = IMAGE_SIZE 46 | self.NUM_LABELS = NUM_LABELS 47 | self.NUM_CHANNELS = NUM_CHANNELS 48 | self.N_FEATURES = N_FEATURES 49 | # self.NUM_TEST = NUM_TEST 50 | self.STEPS = STEPS 51 | self.N_PRINT = N_PRINT 52 | self.LRSTEP = LRSTEP 53 | self.DECAY_EMA = DECAY_EMA 54 | self.LOG = LOG 55 | self.SEED = SEED 56 | self.N_EPOCH = N_EPOCH 57 | self.N_THREADS = N_THREADS 58 | self.DROPOUT = DROPOUT 59 | self.MEAN_FILE = MEAN_FILE 60 | if MEAN_FILE is not None: 61 | MEAN_ARRAY = tf.constant(np.load(MEAN_FILE), dtype=tf.float32) # (3) 62 | self.MEAN_ARRAY = tf.reshape(MEAN_ARRAY, [1, 1, 3]) 63 | self.SUB_MEAN = True 64 | else: 65 | self.SUB_MEAN = False 66 | 67 | self.sess = tf.InteractiveSession() 68 | 69 | self.sess.as_default() 70 | 71 | self.var_to_reg = [] 72 | self.var_to_sum = [] 73 | self.TF_RECORDS = TF_RECORDS 74 | self.init_queue(TF_RECORDS) 75 | 76 | self.init_vars() 77 | self.init_model_architecture() 78 | self.init_training_graph() 79 | self.Saver() 80 | self.DEBUG = DEBUG 81 | self.loss_func = LOSS_FUNC 82 | self.weight_decay = WEIGHT_DECAY 83 | 84 | def init_queue(self, tfrecords_filename): 85 | """ 86 | New queues for coordinator 87 | """ 88 | self.filename_queue = tf.train.string_input_producer( 89 | [tfrecords_filename], num_epochs=10) 90 | with tf.device('/cpu:0'): 91 | self.image, self.annotation =
read_and_decode(self.filename_queue, 92 | self.IMAGE_SIZE[0], 93 | self.IMAGE_SIZE[1], 94 | self.BATCH_SIZE, 95 | self.N_THREADS) 96 | print("Queue initialized") 97 | def input_node_f(self): 98 | """ 99 | The input node reads from the record queue by default, but it can also be 100 | fed through a feed dict (for example at test time). 101 | """ 102 | if self.SUB_MEAN: 103 | self.images_queue = self.image - self.MEAN_ARRAY 104 | else: 105 | self.images_queue = self.image 106 | self.image_PH = tf.placeholder_with_default(self.images_queue, shape=[None, 107 | None, 108 | None, 109 | 3]) 110 | return self.image_PH 111 | def label_node_f(self): 112 | """ 113 | Same as input_node_f, for the label node. 114 | """ 115 | self.labels_queue = self.annotation 116 | self.labels_PH = tf.placeholder_with_default(self.labels_queue, shape=[None, 117 | None, 118 | None, 119 | 1]) 120 | 121 | return self.labels_PH 122 | def Validation(self, DG_TEST, step): 123 | """ 124 | How the model validates itself on the test set. 125 | """ 126 | if DG_TEST is None: 127 | print "no validation" 128 | else: 129 | n_test = DG_TEST.length 130 | n_batch = int(np.ceil(float(n_test) / self.BATCH_SIZE)) 131 | 132 | l, acc, F1, recall, precision, meanacc = 0., 0., 0., 0., 0., 0.

            for i in range(n_batch):
                Xval, Yval = DG_TEST.Batch(0, self.BATCH_SIZE)
                feed_dict = {self.input_node: Xval,
                             self.train_labels_node: Yval}
                l_tmp, acc_tmp, F1_tmp, recall_tmp, precision_tmp, meanacc_tmp, pred, s = self.sess.run(
                    [self.loss, self.accuracy, self.F1,
                     self.recall, self.precision,
                     self.MeanAcc, self.predictions,
                     self.merged_summary], feed_dict=feed_dict)
                l += l_tmp
                acc += acc_tmp
                F1 += F1_tmp
                recall += recall_tmp
                precision += precision_tmp
                meanacc += meanacc_tmp

            l, acc, F1, recall, precision, meanacc = np.array([l, acc, F1, recall, precision, meanacc]) / n_batch

            summary = tf.Summary()
            summary.value.add(tag="TestMan/Accuracy", simple_value=acc)
            summary.value.add(tag="TestMan/Loss", simple_value=l)
            summary.value.add(tag="TestMan/F1", simple_value=F1)
            summary.value.add(tag="TestMan/Recall", simple_value=recall)
            summary.value.add(tag="TestMan/Precision", simple_value=precision)
            summary.value.add(tag="TestMan/Performance", simple_value=meanacc)
            self.summary_test_writer.add_summary(summary, step)
            self.summary_test_writer.add_summary(s, step)
            print(' Validation loss: %.1f' % l)
            print(' Accuracy: %1.f%% \n acc1: %.1f%% \n recall: %1.f%% \n prec: %1.f%% \n f1 : %1.f%% \n' % (acc * 100, meanacc * 100, recall * 100, precision * 100, F1 * 100))
            self.saver.save(self.sess, self.LOG + '/' + "model.ckpt", step)

    def train(self, DG_TEST=None):
        """
        How the model trains.
        """

        epoch = self.STEPS * self.BATCH_SIZE // self.N_EPOCH

        self.LearningRateSchedule(self.LEARNING_RATE, self.K, epoch)

        trainable_var = tf.trainable_variables()

        self.regularize_model()
        self.optimization(trainable_var)
        self.ExponentialMovingAverage(trainable_var, self.DECAY_EMA)

        self.summary_test_writer = tf.summary.FileWriter(self.LOG + '/test',
                                                         graph=self.sess.graph)

        self.summary_writer = tf.summary.FileWriter(self.LOG + '/train', graph=self.sess.graph)
        self.merged_summary = tf.summary.merge_all()
        steps = self.STEPS

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        self.sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for step in range(steps):
            # self.optimizer is replaced by self.training_op for the exponential moving average
            _, l, lr, predictions, batch_labels, s = self.sess.run(
                [self.training_op, self.loss, self.learning_rate,
                 self.train_prediction, self.train_labels_node,
                 self.merged_summary])

            if step % self.N_PRINT == 0:
                i = datetime.now()
                print(i.strftime('%Y/%m/%d %H:%M:%S: \n '))
                self.summary_writer.add_summary(s, step)
                error, acc, acc1, recall, prec, f1 = self.error_rate(predictions, batch_labels, step)
                print(' Step %d of %d' % (step, steps))
                print(' Learning rate: %.5f \n' % lr)
                print(' Mini-batch loss: %.5f \n Accuracy: %.1f%% \n acc1: %.1f%% \n recall: %1.f%% \n prec: %1.f%% \n f1 : %1.f%% \n' %
                      (l, acc, acc1, recall, prec, f1))
                self.Validation(DG_TEST, step)
        coord.request_stop()
        coord.join(threads)
--------------------------------------------------------------------------------
/src_RealData/Nets/ObjectOriented.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import os
from sklearn.metrics import confusion_matrix
from datetime import datetime

class ConvolutionalNeuralNetwork:
    """
    Generic object for creating DNN models.
    This class instantiates all functions
    needed for DNN operations.
    """
    def __init__(
            self,
            LEARNING_RATE=0.01,
            K=0.96,
            BATCH_SIZE=1,
            IMAGE_SIZE=28,
            NUM_LABELS=10,
            NUM_CHANNELS=1,
            NUM_TEST=10000,
            STEPS=2000,
            LRSTEP=200,
            DECAY_EMA=0.9999,
            N_PRINT=100,
            LOG="/tmp/net",
            SEED=42,
            DEBUG=True,
            WEIGHT_DECAY=0.00005,
            LOSS_FUNC=tf.nn.l2_loss,
            N_FEATURES=16):

        self.LEARNING_RATE = LEARNING_RATE
        self.K = K
        self.BATCH_SIZE = BATCH_SIZE
        self.IMAGE_SIZE = IMAGE_SIZE
        self.NUM_LABELS = NUM_LABELS
        self.NUM_CHANNELS = NUM_CHANNELS
        self.N_FEATURES = N_FEATURES
        # self.NUM_TEST = NUM_TEST
        self.STEPS = STEPS
        self.N_PRINT = N_PRINT
        self.LRSTEP = LRSTEP
        self.DECAY_EMA = DECAY_EMA
        self.LOG = LOG
        self.SEED = SEED

        self.sess = tf.InteractiveSession()

        self.sess.as_default()

        self.var_to_reg = []
        self.var_to_sum = []

        self.init_vars()
        self.init_model_architecture()
        self.init_training_graph()
        self.Saver()
        self.DEBUG = DEBUG
        self.loss_func = LOSS_FUNC
        self.weight_decay = WEIGHT_DECAY

    def regularize_model(self):
        """
        Adds summaries (in DEBUG mode) and the weight-decay regularization,
        computed with LOSS_FUNC, for the parameters of the model.
        """
        if self.DEBUG:
            for var in self.var_to_sum + self.var_to_reg:
                self.add_to_summary(var)
            self.WritteSummaryImages()

        for var in self.var_to_reg:
            self.add_to_regularization(var)

    def add_to_summary(self, var):
        """
        Adds a histogram summary for the given variable.
        """
        if var is not None:
            tf.summary.histogram(var.op.name, var)

    def add_to_regularization(self, var):
        """
        Adds the weight-decay penalty of var to the loss.
        """
        if var is not None:
            self.loss = self.loss + self.weight_decay * self.loss_func(var)

    def add_activation_summary(self, var):
        """
        Adds an activation summary with information about sparsity.
        """
        if var is not None:
            tf.summary.histogram(var.op.name + "/activation", var)
            tf.summary.scalar(var.op.name + "/sparsity", tf.nn.zero_fraction(var))

    def add_gradient_summary(self, grad, var):
        """
        Adds a gradient histogram to the summary.
        """
        if grad is not None:
            tf.summary.histogram(var.op.name + "/gradient", grad)

    def input_node_f(self):
        """
        Input node, called when initialising the network.
        """
        return tf.placeholder(
            tf.float32,
            shape=(self.BATCH_SIZE, self.IMAGE_SIZE, self.IMAGE_SIZE, self.NUM_CHANNELS))

    def label_node_f(self):
        """
        Label node, called when initialising the network.
        """
        return tf.placeholder(
            tf.float32,
            shape=(self.BATCH_SIZE, self.IMAGE_SIZE, self.IMAGE_SIZE, 1))

    def conv_layer_f(self, i_layer, w_var, strides, scope_name, padding="SAME"):
        """
        Defines a convolution layer.
        """
        with tf.name_scope(scope_name):
            return tf.nn.conv2d(i_layer, w_var, strides=strides, padding=padding)

    def relu_layer_f(self, i_layer, biases, scope_name):
        """
        Defines a ReLU layer (bias addition followed by ReLU).
        """
        with tf.name_scope(scope_name):
            act = tf.nn.relu(tf.nn.bias_add(i_layer, biases))
            self.var_to_sum.append(act)
            return act

    def weight_const_f(self, ks, inchannels, outchannels, stddev, scope_name, name="W", reg="True"):
        """
        Defines the kernel of a convolution layer, initialised from a
        truncated normal distribution with the given stddev.
        """
        with tf.name_scope(scope_name):
            K = tf.Variable(tf.truncated_normal([ks, ks, inchannels, outchannels],  # ks x ks filter, outchannels deep
                                               stddev=stddev,
                                               seed=self.SEED))
            self.var_to_reg.append(K)
            self.var_to_sum.append(K)
            return K

    def weight_xavier(self, ks, inchannels, outchannels, scope_name, name="W"):
        """
        Initialises a convolution kernel with Xavier initialisation,
        here stddev = sqrt(1 / (ks * ks * inchannels)).
        """
        xavier_std = np.sqrt(1. / float(ks * ks * inchannels))
        return self.weight_const_f(ks, inchannels, outchannels, xavier_std, scope_name, name=name)

    def biases_const_f(self, const, shape, scope_name, name="B"):
        """
        Initialises the biases to a constant value.
        """
        with tf.name_scope(scope_name):
            b = tf.Variable(tf.constant(const, shape=[shape]), name=name)
            self.var_to_sum.append(b)
            return b

    def max_pool(self, i_layer, ksize=[1,2,2,1], strides=[1,2,2,1],
                 padding="SAME", name="MaxPool"):
        """
        Performs a max pool operation.
        """
        return tf.nn.max_pool(i_layer, ksize=ksize, strides=strides,
                              padding=padding, name=name)

    def BatchNorm(self, Input, n_out, phase_train, scope='bn', decay=0.9, eps=1e-5):
        """
        Performs batch normalisation.
        Code taken from http://stackoverflow.com/a/34634291/2267819
        """
        with tf.name_scope(scope):
            init_beta = tf.constant(0.0, shape=[n_out])
            beta = tf.Variable(init_beta, name="beta")
            init_gamma = tf.random_normal([n_out], 1.0, 0.02)
            gamma = tf.Variable(init_gamma)
            batch_mean, batch_var = tf.nn.moments(Input, [0, 1, 2], name='moments')
            ema = tf.train.ExponentialMovingAverage(decay=decay)

            def mean_var_with_update():
                ema_apply_op = ema.apply([batch_mean, batch_var])
                with tf.control_dependencies([ema_apply_op]):
                    return tf.identity(batch_mean), tf.identity(batch_var)

            # Batch statistics while training, moving averages at inference
            mean, var = tf.cond(phase_train,
                                mean_var_with_update,
                                lambda: (ema.average(batch_mean), ema.average(batch_var)))
            normed = tf.nn.batch_normalization(Input, mean, var, beta, gamma, eps)
            return normed

    def DropOutLayer(self, Input, scope="DropOut"):
        """
        Performs dropout on the input layer.
        """
        with tf.name_scope(scope):
            return tf.nn.dropout(Input, self.keep_prob)  # keep_prob has to be defined in init_vars

    def init_vars(self):
        """
        Initialises variables for the graph.
        """
        self.input_node = self.input_node_f()

        self.train_labels_node = self.label_node_f()

        self.conv1_weights = self.weight_xavier(5, self.NUM_CHANNELS, 8, "conv1/")
        self.conv1_biases = self.biases_const_f(0.1, 8, "conv1/")

        self.conv2_weights = self.weight_xavier(5, 8, 8, "conv2/")
        self.conv2_biases = self.biases_const_f(0.1, 8, "conv2/")

        self.conv3_weights = self.weight_xavier(5, 8, 8, "conv3/")
        self.conv3_biases = self.biases_const_f(0.1, 8, "conv3/")

        self.logits_weight = self.weight_xavier(1, 8, self.NUM_LABELS, "logits/")
        self.logits_biases = self.biases_const_f(0.1, self.NUM_LABELS, "logits/")

        self.keep_prob = tf.Variable(0.5, name="dropout_prob")

        print('Model variables initialised')

    def WritteSummaryImages(self):
        """
        Image summaries (input, label and prediction) to add to the summary.
        """
        tf.summary.image("Input", self.input_node, max_outputs=4)
        tf.summary.image("Label", self.train_labels_node, max_outputs=4)
        tf.summary.image("Pred", tf.expand_dims(tf.cast(self.predictions, tf.float32), axis=3), max_outputs=4)

    def init_model_architecture(self):
        """
        Graph structure for the model.
        """
        self.conv1 = self.conv_layer_f(self.input_node, self.conv1_weights,
                                       [1,1,1,1], "conv1/")
        self.relu1 = self.relu_layer_f(self.conv1, self.conv1_biases, "conv1/")

        self.conv2 = self.conv_layer_f(self.relu1, self.conv2_weights,
                                       [1,1,1,1], "conv2/")
        self.relu2 = self.relu_layer_f(self.conv2, self.conv2_biases, "conv2/")

        self.conv3 = self.conv_layer_f(self.relu2, self.conv3_weights,
                                       [1,1,1,1], "conv3/")
        self.relu3 = self.relu_layer_f(self.conv3, self.conv3_biases, "conv3/")

        self.last = self.relu3

        print('Model architecture initialised')

    def init_training_graph(self):
        """
        Graph optimization part: defines the loss and how the model is evaluated.
        """

        with tf.name_scope('Evaluation'):
            self.logits = self.conv_layer_f(self.last, self.logits_weight, strides=[1,1,1,1], scope_name="logits/")
            self.predictions = tf.argmax(self.logits, axis=3)

            with tf.name_scope('Loss'):
                self.loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
                                            labels=tf.squeeze(tf.cast(self.train_labels_node, tf.int32), squeeze_dims=[3]),
                                            name="entropy")))
                tf.summary.scalar("entropy", self.loss)

            with tf.name_scope('Accuracy'):

                LabelInt = tf.squeeze(tf.cast(self.train_labels_node, tf.int64), squeeze_dims=[3])
                CorrectPrediction = tf.equal(self.predictions, LabelInt)
                self.accuracy = tf.reduce_mean(tf.cast(CorrectPrediction, tf.float32))
                tf.summary.scalar("accuracy", self.accuracy)

            with tf.name_scope('Prediction'):
                # Pixel counts, assuming binary labels and predictions in {0, 1}
                self.TP = tf.count_nonzero(self.predictions * LabelInt)
                self.TN = tf.count_nonzero((self.predictions - 1) * (LabelInt - 1))
                self.FP = tf.count_nonzero(self.predictions * (LabelInt - 1))
                self.FN = tf.count_nonzero((self.predictions - 1) * LabelInt)

            with tf.name_scope('Precision'):

                self.precision = tf.divide(self.TP, tf.add(self.TP, self.FP))
                tf.summary.scalar('Precision', self.precision)

            with tf.name_scope('Recall'):

                self.recall = tf.divide(self.TP, tf.add(self.TP, self.FN))
                tf.summary.scalar('Recall', self.recall)

            with tf.name_scope('F1'):

                num = tf.multiply(self.precision, self.recall)
                dem = tf.add(self.precision, self.recall)
                self.F1 = tf.scalar_mul(2, tf.divide(num, dem))
                tf.summary.scalar('F1', self.F1)

            with tf.name_scope('MeanAccuracy'):
                # Mean of the positive and negative predictive values
                Nprecision = tf.divide(self.TN, tf.add(self.TN, self.FN))
                self.MeanAcc = tf.divide(tf.add(self.precision, Nprecision), 2)
                tf.summary.scalar('Performance', self.MeanAcc)
        #self.batch = tf.Variable(0, name = "batch_iterator")

        self.train_prediction = tf.nn.softmax(self.logits)

        self.test_prediction = tf.nn.softmax(self.logits)

        tf.global_variables_initializer().run()

        print('Computational graph initialised')

    def error_rate(self, predictions, labels, iter):
        """
        Operations to perform on the training prediction every N_PRINT iterations.
        These values are printed to screen.
        """
        predictions = np.argmax(predictions, 3)
        labels = labels[:,:,:,0]

        # sklearn puts the true labels on the rows and the predicted labels on
        # the columns: cm[i, j] counts pixels of true class i predicted as j.
        cm = confusion_matrix(labels.flatten(), predictions.flatten(), labels=[0, 1]).astype(np.float64)
        b, x, y = predictions.shape
        total = b * x * y

        TP = cm[1, 1]
        TN = cm[0, 0]
        FP = cm[0, 1]
        FN = cm[1, 0]

        acc = (TP + TN) / (TP + TN + FN + FP) * 100
        precision = TP / (TP + FP)
        acc1 = np.mean([precision, TN / (TN + FN)]) * 100
        recall = TP / (TP + FN)

        F1 = 2 * precision * recall / (recall + precision)
        error = 100 - acc

        return error, acc, acc1, recall * 100, precision * 100, F1 * 100

    def optimization(self, var_list):
        """
        Defines the optimization method (Adam) used to solve the task.
        """
        with tf.name_scope('optimizer'):
            optimizer = tf.train.AdamOptimizer(self.learning_rate)
            grads = optimizer.compute_gradients(self.loss, var_list=var_list)
            if self.DEBUG:
                for grad, var in grads:
                    self.add_gradient_summary(grad, var)
            self.optimizer = optimizer.apply_gradients(grads, global_step=self.global_step)

    def LearningRateSchedule(self, lr, k, epoch):
        """
        Defines the learning rate: lr is multiplied by k every decay_step
        steps (staircase), where decay_step is derived from LRSTEP.
        """
        with tf.name_scope('LearningRateSchedule'):
            self.global_step = tf.Variable(0., trainable=False)
            tf.add_to_collection('global_step', self.global_step)
            # LRSTEP is either a number of steps (e.g. the default 200) or a
            # string such as "epoch/2" or "<n>epoch"; only strings can be
            # searched for "epoch".
            if self.LRSTEP == "epoch/2":

                decay_step = float(epoch) / (2 * self.BATCH_SIZE)

            elif isinstance(self.LRSTEP, str) and "epoch" in self.LRSTEP:
                num = int(self.LRSTEP[:-5])
                decay_step = float(num) * float(epoch) / self.BATCH_SIZE
            else:
                decay_step = float(self.LRSTEP)

            self.learning_rate = tf.train.exponential_decay(
                lr,
                self.global_step,
                decay_step,
                k,
                staircase=True)
            tf.summary.scalar("learning_rate", self.learning_rate)

    def Validation(self, DG_TEST, step):
        """
        How the model validates itself on the test set.
        """
        n_test = DG_TEST.length
        n_batch = int(np.ceil(float(n_test) / self.BATCH_SIZE))

        l, acc, F1, recall, precision, meanacc = 0., 0., 0., 0., 0., 0.

        for i in range(n_batch):
            Xval, Yval = DG_TEST.Batch(0, self.BATCH_SIZE)
            feed_dict = {self.input_node: Xval,
                         self.train_labels_node: Yval}
            l_tmp, acc_tmp, F1_tmp, recall_tmp, precision_tmp, meanacc_tmp, pred = self.sess.run(
                [self.loss, self.accuracy, self.F1, self.recall,
                 self.precision, self.MeanAcc, self.predictions],
                feed_dict=feed_dict)
            l += l_tmp
            acc += acc_tmp
            F1 += F1_tmp
            recall += recall_tmp
            precision += precision_tmp
            meanacc += meanacc_tmp

        l, acc, F1, recall, precision, meanacc = np.array([l, acc, F1, recall, precision, meanacc]) / n_batch

        summary = tf.Summary()
        summary.value.add(tag="Test/Accuracy", simple_value=acc)
        summary.value.add(tag="Test/Loss", simple_value=l)
        summary.value.add(tag="Test/F1", simple_value=F1)
        summary.value.add(tag="Test/Recall", simple_value=recall)
        summary.value.add(tag="Test/Precision", simple_value=precision)
        summary.value.add(tag="Test/Performance", simple_value=meanacc)
        self.summary_test_writer.add_summary(summary, step)

        print(' Validation loss: %.1f' % l)
        print(' Accuracy: %1.f%% \n acc1: %.1f%% \n recall: %1.f%% \n prec: %1.f%% \n f1 : %1.f%% \n' % (acc * 100, meanacc * 100, recall * 100, precision * 100, F1 * 100))
        self.saver.save(self.sess, self.LOG + '/' + "model.ckpt", global_step=self.global_step)

    def Saver(self):
        """
        Defines the saver and restores the latest checkpoint found in LOG, if any.
        """
        print("Setting up Saver...")
        self.saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(self.LOG)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            print("Model restored...")

    def ExponentialMovingAverage(self, var_list, decay=0.9999):
        """
        Maintains an exponential moving average of the parameters, updated
        after every optimizer step. Averaging the parameters over training
        steps is intended to give a more robust classifier.
        """
        with tf.name_scope('ExponentialMovingAverage'):

            ema = tf.train.ExponentialMovingAverage(decay=decay)
            maintain_averages_op = ema.apply(var_list)

            # Create an op that will update the moving averages after each training
            # step. This is what we will use in place of the usual training op.
            with tf.control_dependencies([self.optimizer]):
                self.training_op = tf.group(maintain_averages_op)

    def train(self, DGTrain, DGTest, saver=True):
        """
        How the model should train.
        """

        epoch = DGTrain.length

        self.LearningRateSchedule(self.LEARNING_RATE, self.K, epoch)

        trainable_var = tf.trainable_variables()

        self.regularize_model()
        self.optimization(trainable_var)
        self.ExponentialMovingAverage(trainable_var, self.DECAY_EMA)

        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

        self.summary_test_writer = tf.summary.FileWriter(self.LOG + '/test',
                                                         graph=self.sess.graph)

        self.summary_writer = tf.summary.FileWriter(self.LOG + '/train', graph=self.sess.graph)
        merged_summary = tf.summary.merge_all()
        steps = self.STEPS

        for step in range(steps):
            batch_data, batch_labels = DGTrain.Batch(0, self.BATCH_SIZE)
            feed_dict = {self.input_node: batch_data,
                         self.train_labels_node: batch_labels}

            # self.optimizer is replaced by self.training_op for the exponential moving average
            _, l, lr, predictions, s = self.sess.run(
                [self.training_op, self.loss, self.learning_rate,
                 self.train_prediction, merged_summary],
                feed_dict=feed_dict)

            if step % self.N_PRINT == 0:
                i = datetime.now()
                print(i.strftime('%Y/%m/%d %H:%M:%S: \n '))
                self.summary_writer.add_summary(s, step)
                error, acc, acc1, recall, prec, f1 = self.error_rate(predictions, batch_labels, step)
                print(' Step %d of %d' % (step, steps))
                print(' Learning rate: %.5f \n' % lr)
                print(' Mini-batch loss: %.5f \n Accuracy: %.1f%% \n acc1: %.1f%% \n recall: %1.f%% \n prec: %1.f%% \n f1 : %1.f%% \n' %
                      (l, acc, acc1, recall, prec, f1))
                self.Validation(DGTest, step)
--------------------------------------------------------------------------------
/src_RealData/Nets/UNetBatchNorm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
from UNetObject import UNet
import tensorflow as tf
import numpy as np
from datetime import datetime
import os

class UNetBatchNorm(UNet):
    """
    UNet object with batch normalisation after each convolution.
    """
    def conv_layer_f(self, i_layer, w_var, scope_name, strides=[1,1,1,1], padding="VALID"):
        with tf.name_scope(scope_name):
            conv = tf.nn.conv2d(i_layer, w_var, strides=strides, padding=padding)
            n_out = w_var.shape[3].value
            BN = self.BatchNorm(conv, n_out, self.is_training)
            return BN

    def init_vars(self):
        """
        Same architecture as UNet, except that we add self.is_training for batch normalisation.
        """
        self.is_training = tf.placeholder_with_default(True, shape=[])

        self.input_node = self.input_node_f()

        self.train_labels_node = self.label_node_f()
        n_features = self.N_FEATURES

        self.conv1_1weights = self.weight_xavier(3, self.NUM_CHANNELS, n_features, "conv1_1/")
        self.conv1_1biases = self.biases_const_f(0.1, n_features, "conv1_1/")

        self.conv1_2weights = self.weight_xavier(3, n_features, n_features, "conv1_2/")
        self.conv1_2biases = self.biases_const_f(0.1, n_features, "conv1_2/")

        self.conv1_3weights = self.weight_xavier(3, 2 * n_features, n_features, "conv1_3/")
        self.conv1_3biases = self.biases_const_f(0.1, n_features, "conv1_3/")

        self.conv1_4weights = self.weight_xavier(3, n_features, n_features, "conv1_4/")
        self.conv1_4biases = self.biases_const_f(0.1, n_features, "conv1_4/")

        self.conv2_1weights = self.weight_xavier(3, n_features, 2 * n_features, "conv2_1/")
        self.conv2_1biases = self.biases_const_f(0.1, 2 * n_features, "conv2_1/")

        self.conv2_2weights = self.weight_xavier(3, 2 * n_features, 2 * n_features, "conv2_2/")
        self.conv2_2biases = self.biases_const_f(0.1, 2 * n_features, "conv2_2/")

        self.conv2_3weights = self.weight_xavier(3, 4 * n_features, 2 * n_features, "conv2_3/")
        self.conv2_3biases = self.biases_const_f(0.1, 2 * n_features, "conv2_3/")

        self.conv2_4weights = self.weight_xavier(3, 2 * n_features, 2 * n_features, "conv2_4/")
        self.conv2_4biases = self.biases_const_f(0.1, 2 * n_features, "conv2_4/")

        self.conv3_1weights = self.weight_xavier(3, 2 * n_features, 4 * n_features, "conv3_1/")
        self.conv3_1biases = self.biases_const_f(0.1, 4 * n_features, "conv3_1/")

        self.conv3_2weights = self.weight_xavier(3, 4 * n_features, 4 * n_features, "conv3_2/")
        self.conv3_2biases = self.biases_const_f(0.1, 4 * n_features, "conv3_2/")

        self.conv3_3weights = self.weight_xavier(3, 8 * n_features, 4 * n_features, "conv3_3/")
        self.conv3_3biases = self.biases_const_f(0.1, 4 * n_features, "conv3_3/")

        self.conv3_4weights = self.weight_xavier(3, 4 * n_features, 4 * n_features, "conv3_4/")
        self.conv3_4biases = self.biases_const_f(0.1, 4 * n_features, "conv3_4/")

        self.conv4_1weights = self.weight_xavier(3, 4 * n_features, 8 * n_features, "conv4_1/")
        self.conv4_1biases = self.biases_const_f(0.1, 8 * n_features, "conv4_1/")

        self.conv4_2weights = self.weight_xavier(3, 8 * n_features, 8 * n_features, "conv4_2/")
        self.conv4_2biases = self.biases_const_f(0.1, 8 * n_features, "conv4_2/")

        self.conv4_3weights = self.weight_xavier(3, 16 * n_features, 8 * n_features, "conv4_3/")
        self.conv4_3biases = self.biases_const_f(0.1, 8 * n_features, "conv4_3/")

        self.conv4_4weights = self.weight_xavier(3, 8 * n_features, 8 * n_features, "conv4_4/")
        self.conv4_4biases = self.biases_const_f(0.1, 8 * n_features, "conv4_4/")

        self.conv5_1weights = self.weight_xavier(3, 8 * n_features, 16 * n_features, "conv5_1/")
        self.conv5_1biases = self.biases_const_f(0.1, 16 * n_features, "conv5_1/")

        self.conv5_2weights = self.weight_xavier(3, 16 * n_features, 16 * n_features, "conv5_2/")
        self.conv5_2biases = self.biases_const_f(0.1, 16 * n_features, "conv5_2/")

        self.tconv5_4weights = self.weight_xavier(2, 8 * n_features, 16 * n_features, "tconv5_4/")
        self.tconv5_4biases = self.biases_const_f(0.1, 8 * n_features, "tconv5_4/")

        self.tconv4_3weights = self.weight_xavier(2, 4 * n_features, 8 * n_features, "tconv4_3/")
        self.tconv4_3biases = self.biases_const_f(0.1, 4 * n_features, "tconv4_3/")

        self.tconv3_2weights = self.weight_xavier(2, 2 * n_features, 4 * n_features, "tconv3_2/")
        self.tconv3_2biases = self.biases_const_f(0.1, 2 * n_features, "tconv3_2/")

        self.tconv2_1weights = self.weight_xavier(2, n_features, 2 * n_features, "tconv2_1/")
        self.tconv2_1biases = self.biases_const_f(0.1, n_features, "tconv2_1/")

        self.logits_weight = self.weight_xavier(1, n_features, self.NUM_LABELS, "logits/")
        self.logits_biases = self.biases_const_f(0.1, self.NUM_LABELS, "logits/")

        self.keep_prob = tf.Variable(self.DROPOUT, name="dropout_prob")

        print('Model variables initialised')

    def Validation(self, DG_TEST, step):
        """
        How to validate: the test set is evaluated one image at a time,
        with batch normalisation in inference mode (is_training=False).
        """
        if DG_TEST is None:
            print("no validation")
        else:
            n_test = DG_TEST.length
            n_batch = int(np.ceil(float(n_test) / 1))

            l, acc, F1, recall, precision, meanacc = 0., 0., 0., 0., 0., 0.

            for i in range(n_batch):
                Xval, Yval = DG_TEST.Batch(0, 1)
                feed_dict = {self.input_node: Xval,
                             self.train_labels_node: Yval,
                             self.is_training: False}
                l_tmp, acc_tmp, F1_tmp, recall_tmp, precision_tmp, meanacc_tmp, s = self.sess.run(
                    [self.loss, self.accuracy, self.F1,
                     self.recall, self.precision,
                     self.MeanAcc,
                     self.merged_summary], feed_dict=feed_dict)
                l += l_tmp
                acc += acc_tmp
                F1 += F1_tmp
                recall += recall_tmp
                precision += precision_tmp
                meanacc += meanacc_tmp

            l, acc, F1, recall, precision, meanacc = np.array([l, acc, F1, recall, precision, meanacc]) / n_batch

            summary = tf.Summary()
            summary.value.add(tag="TestMan/Accuracy", simple_value=acc)
            summary.value.add(tag="TestMan/Loss", simple_value=l)
            summary.value.add(tag="TestMan/F1", simple_value=F1)
            summary.value.add(tag="TestMan/Recall", simple_value=recall)
            summary.value.add(tag="TestMan/Precision", simple_value=precision)
            summary.value.add(tag="TestMan/Performance", simple_value=meanacc)
            self.summary_test_writer.add_summary(summary, step)

            self.summary_test_writer.add_summary(s, step)
            print(' Validation loss: %.1f' % l)
            print(' Accuracy: %1.f%% \n acc1: %.1f%% \n recall: %1.f%% \n prec: %1.f%% \n f1 : %1.f%% \n' % (acc * 100, meanacc * 100, recall * 100, precision * 100, F1 * 100))
            self.saver.save(self.sess, self.LOG + '/' + "model.ckpt", step)

    def train(self, DG_TEST):
        """
        How to train the model.
        """

        epoch = self.STEPS * self.BATCH_SIZE // self.N_EPOCH

        self.LearningRateSchedule(self.LEARNING_RATE, self.K, epoch)

        trainable_var = tf.trainable_variables()

        self.regularize_model()
        self.optimization(trainable_var)
        self.ExponentialMovingAverage(trainable_var, self.DECAY_EMA)

        self.summary_test_writer = tf.summary.FileWriter(self.LOG + '/test',
                                                         graph=self.sess.graph)

        self.summary_writer = tf.summary.FileWriter(self.LOG + '/train', graph=self.sess.graph)
        self.merged_summary = tf.summary.merge_all()
        steps = self.STEPS

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        self.sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for step in range(steps):
            # self.optimizer is replaced by self.training_op for the exponential moving average
            _, l, lr, predictions, batch_labels, s = self.sess.run(
                [self.training_op, self.loss, self.learning_rate,
                 self.train_prediction, self.train_labels_node,
                 self.merged_summary])

            if step % self.N_PRINT == 0:
                i = datetime.now()
                print(i.strftime('%Y/%m/%d %H:%M:%S: \n '))
                self.summary_writer.add_summary(s, step)
                error, acc, acc1, recall, prec, f1 = self.error_rate(predictions, batch_labels, step)
                print(' Step %d of %d' % (step, steps))
                print(' Learning rate: %.5f \n' % lr)
                print(' Mini-batch loss: %.5f \n Accuracy: %.1f%% \n acc1: %.1f%% \n recall: %1.f%% \n prec: %1.f%% \n f1 : %1.f%% \n' %
                      (l, acc, acc1, recall, prec, f1))
                self.Validation(DG_TEST, step)
        coord.request_stop()
        coord.join(threads)
--------------------------------------------------------------------------------
/src_RealData/Nets/UNetDistance.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from UNetBatchNorm import UNetBatchNorm
import tensorflow as tf
import numpy as np
from sklearn.metrics import mean_squared_error
from datetime import datetime
from DataReadDecode import read_and_decode

class UNetDistance(UNetBatchNorm):
    def __init__(
            self,
            TF_RECORDS,
LEARNING_RATE=0.01, 16 | K=0.96, 17 | BATCH_SIZE=10, 18 | IMAGE_SIZE=28, 19 | NUM_CHANNELS=1, 20 | NUM_TEST=10000, 21 | STEPS=2000, 22 | LRSTEP=200, 23 | DECAY_EMA=0.9999, 24 | N_PRINT = 100, 25 | LOG="/tmp/net", 26 | SEED=42, 27 | DEBUG=True, 28 | WEIGHT_DECAY=0.00005, 29 | LOSS_FUNC=tf.nn.l2_loss, 30 | N_FEATURES=16, 31 | N_EPOCH=1, 32 | N_THREADS=1, 33 | MEAN_FILE=None, 34 | DROPOUT=0.5): 35 | 36 | self.LEARNING_RATE = LEARNING_RATE 37 | self.K = K 38 | self.BATCH_SIZE = BATCH_SIZE 39 | self.IMAGE_SIZE = IMAGE_SIZE 40 | self.NUM_CHANNELS = NUM_CHANNELS 41 | self.N_FEATURES = N_FEATURES 42 | self.STEPS = STEPS 43 | self.N_PRINT = N_PRINT 44 | self.LRSTEP = LRSTEP 45 | self.DECAY_EMA = DECAY_EMA 46 | self.LOG = LOG 47 | self.SEED = SEED 48 | self.N_EPOCH = N_EPOCH 49 | self.N_THREADS = N_THREADS 50 | self.DROPOUT = DROPOUT 51 | self.MEAN_FILE = MEAN_FILE 52 | if MEAN_FILE is not None: 53 | MEAN_ARRAY = tf.constant(np.load(MEAN_FILE), dtype=tf.float32) # (3) 54 | self.MEAN_ARRAY = tf.reshape(MEAN_ARRAY, [1, 1, 3]) 55 | self.SUB_MEAN = True 56 | else: 57 | self.SUB_MEAN = False 58 | 59 | self.sess = tf.InteractiveSession() 60 | 61 | self.sess.as_default() 62 | 63 | self.var_to_reg = [] 64 | self.var_to_sum = [] 65 | self.TF_RECORDS = TF_RECORDS 66 | self.init_queue(TF_RECORDS) 67 | 68 | self.init_vars() 69 | self.init_model_architecture() 70 | self.init_training_graph() 71 | self.Saver() 72 | self.DEBUG = DEBUG 73 | self.loss_func = LOSS_FUNC 74 | self.weight_decay = WEIGHT_DECAY 75 | 76 | def init_queue(self, tfrecords_filename): 77 | """ 78 | Added the number of channels to extract to 79 | """ 80 | self.filename_queue = tf.train.string_input_producer( 81 | [tfrecords_filename], num_epochs=10) 82 | with tf.device('/cpu:0'): 83 | self.image, self.annotation = read_and_decode(self.filename_queue, 84 | self.IMAGE_SIZE[0], 85 | self.IMAGE_SIZE[1], 86 | self.BATCH_SIZE, 87 | self.N_THREADS, 88 | True, 89 | self.NUM_CHANNELS) 90 | #self.annotation = 
tf.divide(self.annotation, 255.) 91 | print("Queue initialized") 92 | 93 | def init_training_graph(self): 94 | """ 95 | Modified the evaluation criterion so that it computes the mean squared error 96 | instead of the cross-entropy loss used by the classification UNet. 97 | """ 98 | with tf.name_scope('Evaluation'): 99 | with tf.name_scope("logits/"): 100 | self.logits2 = tf.nn.conv2d(self.last, self.logits_weight, strides=[1,1,1,1], padding="VALID") 101 | self.logits = tf.nn.bias_add(self.logits2, self.logits_biases) 102 | self.predictions = self.logits 103 | with tf.name_scope('Loss'): 104 | 105 | self.loss = tf.losses.mean_squared_error(labels=self.train_labels_node, predictions=self.logits) 106 | tf.summary.scalar("mean_squared_error", self.loss) 107 | self.predictions = tf.squeeze(self.predictions, [3]) 108 | self.train_prediction = self.predictions 109 | 110 | self.test_prediction = self.predictions 111 | 112 | tf.global_variables_initializer().run() 113 | 114 | print('Computational graph initialised') 115 | 116 | def init_vars(self): 117 | """ 118 | Same variables as the classification UNet, except that the logits layer has a single output channel (distance-map regression).
119 | """ 120 | self.is_training = tf.placeholder_with_default(True, shape=[]) 121 | #### 122 | 123 | self.input_node = self.input_node_f() 124 | 125 | self.train_labels_node = self.label_node_f() 126 | n_features = self.N_FEATURES 127 | 128 | self.conv1_1weights = self.weight_xavier(3, self.NUM_CHANNELS, n_features, "conv1_1/") 129 | self.conv1_1biases = self.biases_const_f(0.1, n_features, "conv1_1/") 130 | 131 | self.conv1_2weights = self.weight_xavier(3, n_features, n_features, "conv1_2/") 132 | self.conv1_2biases = self.biases_const_f(0.1, n_features, "conv1_2/") 133 | 134 | self.conv1_3weights = self.weight_xavier(3, 2 * n_features, n_features, "conv1_3/") 135 | self.conv1_3biases = self.biases_const_f(0.1, n_features, "conv1_3/") 136 | 137 | self.conv1_4weights = self.weight_xavier(3, n_features, n_features, "conv1_4/") 138 | self.conv1_4biases = self.biases_const_f(0.1, n_features, "conv1_4/") 139 | 140 | 141 | 142 | self.conv2_1weights = self.weight_xavier(3, n_features, 2 * n_features, "conv2_1/") 143 | self.conv2_1biases = self.biases_const_f(0.1, 2 * n_features, "conv2_1/") 144 | 145 | self.conv2_2weights = self.weight_xavier(3, 2 * n_features, 2 * n_features, "conv2_2/") 146 | self.conv2_2biases = self.biases_const_f(0.1, 2 * n_features, "conv2_2/") 147 | 148 | self.conv2_3weights = self.weight_xavier(3, 4 * n_features, 2 * n_features, "conv2_3/") 149 | self.conv2_3biases = self.biases_const_f(0.1, 2 * n_features, "conv2_3/") 150 | 151 | self.conv2_4weights = self.weight_xavier(3, 2 * n_features, 2 * n_features, "conv2_4/") 152 | self.conv2_4biases = self.biases_const_f(0.1, 2 * n_features, "conv2_4/") 153 | 154 | 155 | 156 | self.conv3_1weights = self.weight_xavier(3, 2 * n_features, 4 * n_features, "conv3_1/") 157 | self.conv3_1biases = self.biases_const_f(0.1, 4 * n_features, "conv3_1/") 158 | 159 | self.conv3_2weights = self.weight_xavier(3, 4 * n_features, 4 * n_features, "conv3_2/") 160 | self.conv3_2biases = self.biases_const_f(0.1, 4 * 
n_features, "conv3_2/") 161 | 162 | self.conv3_3weights = self.weight_xavier(3, 8 * n_features, 4 * n_features, "conv3_3/") 163 | self.conv3_3biases = self.biases_const_f(0.1, 4 * n_features, "conv3_3/") 164 | 165 | self.conv3_4weights = self.weight_xavier(3, 4 * n_features, 4 * n_features, "conv3_4/") 166 | self.conv3_4biases = self.biases_const_f(0.1, 4 * n_features, "conv3_4/") 167 | 168 | 169 | 170 | self.conv4_1weights = self.weight_xavier(3, 4 * n_features, 8 * n_features, "conv4_1/") 171 | self.conv4_1biases = self.biases_const_f(0.1, 8 * n_features, "conv4_1/") 172 | 173 | self.conv4_2weights = self.weight_xavier(3, 8 * n_features, 8 * n_features, "conv4_2/") 174 | self.conv4_2biases = self.biases_const_f(0.1, 8 * n_features, "conv4_2/") 175 | 176 | self.conv4_3weights = self.weight_xavier(3, 16 * n_features, 8 * n_features, "conv4_3/") 177 | self.conv4_3biases = self.biases_const_f(0.1, 8 * n_features, "conv4_3/") 178 | 179 | self.conv4_4weights = self.weight_xavier(3, 8 * n_features, 8 * n_features, "conv4_4/") 180 | self.conv4_4biases = self.biases_const_f(0.1, 8 * n_features, "conv4_4/") 181 | 182 | 183 | 184 | self.conv5_1weights = self.weight_xavier(3, 8 * n_features, 16 * n_features, "conv5_1/") 185 | self.conv5_1biases = self.biases_const_f(0.1, 16 * n_features, "conv5_1/") 186 | 187 | self.conv5_2weights = self.weight_xavier(3, 16 * n_features, 16 * n_features, "conv5_2/") 188 | self.conv5_2biases = self.biases_const_f(0.1, 16 * n_features, "conv5_2/") 189 | 190 | 191 | 192 | 193 | self.tconv5_4weights = self.weight_xavier(2, 8 * n_features, 16 * n_features, "tconv5_4/") 194 | self.tconv5_4biases = self.biases_const_f(0.1, 8 * n_features, "tconv5_4/") 195 | 196 | self.tconv4_3weights = self.weight_xavier(2, 4 * n_features, 8 * n_features, "tconv4_3/") 197 | self.tconv4_3biases = self.biases_const_f(0.1, 4 * n_features, "tconv4_3/") 198 | 199 | self.tconv3_2weights = self.weight_xavier(2, 2 * n_features, 4 * n_features, "tconv3_2/") 200 | 
self.tconv3_2biases = self.biases_const_f(0.1, 2 * n_features, "tconv3_2/") 201 | 202 | self.tconv2_1weights = self.weight_xavier(2, n_features, 2 * n_features, "tconv2_1/") 203 | self.tconv2_1biases = self.biases_const_f(0.1, n_features, "tconv2_1/") 204 | 205 | 206 | 207 | self.logits_weight = self.weight_xavier(1, n_features, 1, "logits/") 208 | self.logits_biases = self.biases_const_f(0.1, 1, "logits/") 209 | 210 | self.keep_prob = tf.Variable(self.DROPOUT, name="dropout_prob") 211 | 212 | 213 | def init_model_architecture(self): 214 | """ 215 | Graph structure. 216 | """ 217 | self.conv1_1 = self.conv_layer_f(self.input_node, self.conv1_1weights, "conv1_1/") 218 | self.relu1_1 = self.relu_layer_f(self.conv1_1, self.conv1_1biases, "conv1_1/") 219 | 220 | self.conv1_2 = self.conv_layer_f(self.relu1_1, self.conv1_2weights, "conv1_2/") 221 | self.relu1_2 = self.relu_layer_f(self.conv1_2, self.conv1_2biases, "conv1_2/") 222 | 223 | 224 | self.pool1_2 = self.max_pool(self.relu1_2, name="pool1_2") 225 | 226 | 227 | self.conv2_1 = self.conv_layer_f(self.pool1_2, self.conv2_1weights, "conv2_1/") 228 | self.relu2_1 = self.relu_layer_f(self.conv2_1, self.conv2_1biases, "conv2_1/") 229 | 230 | self.conv2_2 = self.conv_layer_f(self.relu2_1, self.conv2_2weights, "conv2_2/") 231 | self.relu2_2 = self.relu_layer_f(self.conv2_2, self.conv2_2biases, "conv2_2/") 232 | 233 | 234 | self.pool2_3 = self.max_pool(self.relu2_2, name="pool2_3") 235 | 236 | 237 | self.conv3_1 = self.conv_layer_f(self.pool2_3, self.conv3_1weights, "conv3_1/") 238 | self.relu3_1 = self.relu_layer_f(self.conv3_1, self.conv3_1biases, "conv3_1/") 239 | 240 | self.conv3_2 = self.conv_layer_f(self.relu3_1, self.conv3_2weights, "conv3_2/") 241 | self.relu3_2 = self.relu_layer_f(self.conv3_2, self.conv3_2biases, "conv3_2/") 242 | 243 | 244 | self.pool3_4 = self.max_pool(self.relu3_2, name="pool3_4") 245 | 246 | 247 | self.conv4_1 = self.conv_layer_f(self.pool3_4, self.conv4_1weights, "conv4_1/") 248 | 
self.relu4_1 = self.relu_layer_f(self.conv4_1, self.conv4_1biases, "conv4_1/") 249 | 250 | self.conv4_2 = self.conv_layer_f(self.relu4_1, self.conv4_2weights, "conv4_2/") 251 | self.relu4_2 = self.relu_layer_f(self.conv4_2, self.conv4_2biases, "conv4_2/") 252 | 253 | 254 | self.pool4_5 = self.max_pool(self.relu4_2, name="pool4_5") 255 | 256 | 257 | self.conv5_1 = self.conv_layer_f(self.pool4_5, self.conv5_1weights, "conv5_1/") 258 | self.relu5_1 = self.relu_layer_f(self.conv5_1, self.conv5_1biases, "conv5_1/") 259 | 260 | self.conv5_2 = self.conv_layer_f(self.relu5_1, self.conv5_2weights, "conv5_2/") 261 | self.relu5_2 = self.relu_layer_f(self.conv5_2, self.conv5_2biases, "conv5_2/") 262 | 263 | 264 | 265 | self.tconv5_4 = self.transposeconv_layer_f(self.relu5_2, self.tconv5_4weights, "tconv5_4/") 266 | self.trelu5_4 = self.relu_layer_f(self.tconv5_4, self.tconv5_4biases, "tconv5_4/") 267 | self.bridge4 = self.CropAndMerge(self.relu4_2, self.trelu5_4, "bridge4") 268 | 269 | 270 | 271 | self.conv4_3 = self.conv_layer_f(self.bridge4, self.conv4_3weights, "conv4_3/") 272 | self.relu4_3 = self.relu_layer_f(self.conv4_3, self.conv4_3biases, "conv4_3/") 273 | 274 | self.conv4_4 = self.conv_layer_f(self.relu4_3, self.conv4_4weights, "conv4_4/") 275 | self.relu4_4 = self.relu_layer_f(self.conv4_4, self.conv4_4biases, "conv4_4/") 276 | 277 | 278 | 279 | self.tconv4_3 = self.transposeconv_layer_f(self.relu4_4, self.tconv4_3weights, "tconv4_3/") 280 | self.trelu4_3 = self.relu_layer_f(self.tconv4_3, self.tconv4_3biases, "tconv4_3/") 281 | self.bridge3 = self.CropAndMerge(self.relu3_2, self.trelu4_3, "bridge3") 282 | 283 | 284 | 285 | self.conv3_3 = self.conv_layer_f(self.bridge3, self.conv3_3weights, "conv3_3/") 286 | self.relu3_3 = self.relu_layer_f(self.conv3_3, self.conv3_3biases, "conv3_3/") 287 | 288 | self.conv3_4 = self.conv_layer_f(self.relu3_3, self.conv3_4weights, "conv3_4/") 289 | self.relu3_4 = self.relu_layer_f(self.conv3_4, self.conv3_4biases, "conv3_4/") 290 | 
291 | 292 | 293 | self.tconv3_2 = self.transposeconv_layer_f(self.relu3_4, self.tconv3_2weights, "tconv3_2/") 294 | self.trelu3_2 = self.relu_layer_f(self.tconv3_2, self.tconv3_2biases, "tconv3_2/") 295 | self.bridge2 = self.CropAndMerge(self.relu2_2, self.trelu3_2, "bridge2") 296 | 297 | 298 | 299 | self.conv2_3 = self.conv_layer_f(self.bridge2, self.conv2_3weights, "conv2_3/") 300 | self.relu2_3 = self.relu_layer_f(self.conv2_3, self.conv2_3biases, "conv2_3/") 301 | 302 | self.conv2_4 = self.conv_layer_f(self.relu2_3, self.conv2_4weights, "conv2_4/") 303 | self.relu2_4 = self.relu_layer_f(self.conv2_4, self.conv2_4biases, "conv2_4/") 304 | 305 | 306 | 307 | self.tconv2_1 = self.transposeconv_layer_f(self.relu2_4, self.tconv2_1weights, "tconv2_1/") 308 | self.trelu2_1 = self.relu_layer_f(self.tconv2_1, self.tconv2_1biases, "tconv2_1/") 309 | self.bridge1 = self.CropAndMerge(self.relu1_2, self.trelu2_1, "bridge1") 310 | 311 | 312 | 313 | self.conv1_3 = self.conv_layer_f(self.bridge1, self.conv1_3weights, "conv1_3/") 314 | self.relu1_3 = self.relu_layer_f(self.conv1_3, self.conv1_3biases, "conv1_3/") 315 | 316 | self.conv1_4 = self.conv_layer_f(self.relu1_3, self.conv1_4weights, "conv1_4/") 317 | self.relu1_4 = self.relu_layer_f(self.conv1_4, self.conv1_4biases, "conv1_4/") 318 | 319 | self.last = self.relu1_4 320 | 321 | print('Model architecture initialised') 322 | 323 | 324 | 325 | def error_rate(self, predictions, labels, iter): 326 | """ 327 | Mean squared error between labels and predictions on a training batch, for display. 328 | """ 329 | error = mean_squared_error(labels.flatten(), predictions.flatten()) 330 | 331 | return error 332 | 333 | def Validation(self, DG_TEST, step): 334 | """ 335 | Computes the mean loss over DG_TEST (one image per batch), logs it to TensorBoard and saves a checkpoint. 336 | """ 337 | if DG_TEST is None: 338 | print "no validation" 339 | else: 340 | n_test = DG_TEST.length 341 | n_batch = int(np.ceil(float(n_test) / 1)) 342 | 343 | l = 0. 344 | for i in range(n_batch): 345 | Xval, Yval = DG_TEST.Batch(0, 1) 346 | #Yval = Yval / 255.
347 | feed_dict = {self.input_node: Xval, 348 | self.train_labels_node: Yval, 349 | self.is_training: False} 350 | l_tmp, pred, s = self.sess.run([self.loss, 351 | self.predictions, 352 | self.merged_summary], 353 | feed_dict=feed_dict) 354 | l += l_tmp 355 | 356 | l = l / n_batch 357 | 358 | summary = tf.Summary() 359 | summary.value.add(tag="TestMan/Loss", simple_value=l) 360 | self.summary_test_writer.add_summary(summary, step) 361 | self.summary_test_writer.add_summary(s, step) 362 | print(' Validation loss: %.1f' % l) 363 | self.saver.save(self.sess, self.LOG + '/' + "model.ckpt", step) 364 | 365 | def train(self, DGTest): 366 | """ 367 | Sets up the learning-rate schedule, optimizer and EMA, then trains for STEPS steps starting from the current global step. 368 | """ 369 | epoch = self.STEPS * self.BATCH_SIZE // self.N_EPOCH 370 | self.Saver() 371 | trainable_var = tf.trainable_variables() 372 | self.LearningRateSchedule(self.LEARNING_RATE, self.K, epoch) 373 | self.optimization(trainable_var) 374 | self.ExponentialMovingAverage(trainable_var, self.DECAY_EMA) 375 | init_op = tf.group(tf.global_variables_initializer(), 376 | tf.local_variables_initializer()) 377 | self.sess.run(init_op) 378 | self.regularize_model() 379 | 380 | self.Saver() 381 | 382 | 383 | 384 | self.summary_test_writer = tf.summary.FileWriter(self.LOG + '/test', 385 | graph=self.sess.graph) 386 | 387 | self.summary_writer = tf.summary.FileWriter(self.LOG + '/train', graph=self.sess.graph) 388 | self.merged_summary = tf.summary.merge_all() 389 | steps = self.STEPS 390 | 391 | print "self.global step", int(self.global_step.eval()) 392 | coord = tf.train.Coordinator() 393 | threads = tf.train.start_queue_runners(coord=coord) 394 | begin = int(self.global_step.eval()) 395 | print "begin", begin 396 | for step in range(begin, steps + begin): 397 | # self.optimizer is replaced by self.training_op for the exponential moving average 398 | _, l, lr, predictions, batch_labels, s = self.sess.run( 399 | [self.training_op, self.loss, self.learning_rate, 400 | self.train_prediction,
self.train_labels_node, 401 | self.merged_summary]) 402 | 403 | if step % self.N_PRINT == 0: 404 | i = datetime.now() 405 | print i.strftime('%Y/%m/%d %H:%M:%S: \n ') 406 | self.summary_writer.add_summary(s, step) 407 | print(' Step %d of %d' % (step, steps)) 408 | print(' Learning rate: %.5f \n' % lr) 409 | print(' Mini-batch loss: %.5f \n ' % l) 410 | print(' Max value: %.5f \n ' % np.max(predictions)) 411 | self.Validation(DGTest, step) 412 | coord.request_stop() 413 | coord.join(threads) 414 | def predict(self, tensor): 415 | feed_dict = {self.input_node: tensor, 416 | self.is_training: False} 417 | pred = self.sess.run(self.predictions, 418 | feed_dict=feed_dict) 419 | return pred -------------------------------------------------------------------------------- /src_RealData/Nets/UNetObject.py: -------------------------------------------------------------------------------- 1 | from DataTF import DataReader 2 | import tensorflow as tf 3 | import os 4 | import numpy as np 5 | from DataReadDecode import read_and_decode 6 | 7 | def print_dim(text, tensor): 8 | """ 9 | Prints the static shape of a tensor, for debugging. 10 | """ 11 | print text, tensor.get_shape() 12 | print 13 | 14 | class UNet(DataReader): 15 | """ 16 | UNet variant of the DNN: because of the valid convolutions, the RGB input 17 | is larger than the GT (ground-truth) map. 18 | """ 19 | def init_queue(self, tfrecords_filename): 20 | """ 21 | UNet-specific decoding of the TFRecord queue. 22 | """ 23 | self.filename_queue = tf.train.string_input_producer( 24 | [tfrecords_filename], num_epochs=10) 25 | with tf.device('/cpu:0'): 26 | self.image, self.annotation = read_and_decode(self.filename_queue, 27 | self.IMAGE_SIZE[0], 28 | self.IMAGE_SIZE[1], 29 | self.BATCH_SIZE, 30 | self.N_THREADS, 31 | True) 32 | print("Queue initialized") 33 | 34 | def WritteSummaryImages(self): 35 | """ 36 | Crops the UNet input image so that it matches the GT size.
37 | """ 38 | Size1 = tf.shape(self.input_node)[1] 39 | Size_to_be = tf.cast(Size1, tf.int32) - 184 40 | crop_input_node = tf.slice(self.input_node, [0, 92, 92, 0], [-1, Size_to_be, Size_to_be, -1]) 41 | 42 | tf.summary.image("Input", crop_input_node, max_outputs=4) 43 | tf.summary.image("Label", self.train_labels_node, max_outputs=4) 44 | tf.summary.image("Pred", tf.expand_dims(tf.cast(self.predictions, tf.float32), dim=3), max_outputs=4) 45 | 46 | def input_node_f(self): 47 | """ 48 | Input node can be of any size, useful for testing on different size 49 | images. 50 | """ 51 | if self.SUB_MEAN: 52 | self.images_queue = self.image - self.MEAN_ARRAY 53 | else: 54 | self.images_queue = self.image 55 | self.image_PH = tf.placeholder_with_default(self.images_queue, shape=[None, 56 | None, 57 | None, 58 | 3]) 59 | return self.image_PH 60 | 61 | def conv_layer_f(self, i_layer, w_var, scope_name="conv", strides=[1,1,1,1], padding="VALID"): 62 | """ 63 | Convolution layer with more default parameters 64 | """ 65 | with tf.name_scope(scope_name): 66 | return tf.nn.conv2d(i_layer, w_var, strides=strides, padding=padding) 67 | 68 | 69 | def transposeconv_layer_f(self, i_layer, w_, scope_name="tconv", padding="VALID"): 70 | """ 71 | Transpose convolution layer 72 | """ 73 | i_shape = tf.shape(i_layer) 74 | o_shape = tf.stack([i_shape[0], i_shape[1]*2, i_shape[2]*2, i_shape[3]//2]) 75 | return tf.nn.conv2d_transpose(i_layer, w_, output_shape=o_shape, 76 | strides=[1,2,2,1], padding=padding) 77 | 78 | def max_pool(self, i_layer, ksize=[1,2,2,1], strides=[1,2,2,1], 79 | padding="VALID", name="MaxPool"): 80 | """ 81 | Max pooling with more default parameters. 82 | """ 83 | return tf.nn.max_pool(i_layer, ksize=ksize, strides=strides, 84 | padding=padding, name=name) 85 | 86 | def CropAndMerge(self, Input1, Input2, name="bridge"): 87 | """ 88 | Crop input1 so that it matches input2 and then 89 | return the concatenation of both channels. 
90 | """ 91 | Size1_x = tf.shape(Input1)[1] 92 | Size2_x = tf.shape(Input2)[1] 93 | 94 | Size1_y = tf.shape(Input1)[2] 95 | Size2_y = tf.shape(Input2)[2] 96 | with tf.name_scope(name): 97 | diff_x = tf.divide(tf.subtract(Size1_x, Size2_x), 2) 98 | diff_y = tf.divide(tf.subtract(Size1_y, Size2_y), 2) 99 | diff_x = tf.cast(diff_x, tf.int32) 100 | Size2_x = tf.cast(Size2_x, tf.int32) 101 | diff_y = tf.cast(diff_y, tf.int32) 102 | Size2_y = tf.cast(Size2_y, tf.int32) 103 | crop = tf.slice(Input1, [0, diff_x, diff_y, 0], [-1, Size2_x, Size2_y, -1]) 104 | concat = tf.concat([crop, Input2], axis=3) 105 | 106 | return concat 107 | 108 | def init_vars(self): 109 | """ 110 | Parameter initialisation for the DNN. 111 | """ 112 | self.input_node = self.input_node_f() 113 | 114 | self.train_labels_node = self.label_node_f() 115 | n_features = self.N_FEATURES 116 | 117 | self.conv1_1weights = self.weight_xavier(3, self.NUM_CHANNELS, n_features, "conv1_1/") 118 | self.conv1_1biases = self.biases_const_f(0.1, n_features, "conv1_1/") 119 | 120 | self.conv1_2weights = self.weight_xavier(3, n_features, n_features, "conv1_2/") 121 | self.conv1_2biases = self.biases_const_f(0.1, n_features, "conv1_2/") 122 | 123 | self.conv1_3weights = self.weight_xavier(3, 2 * n_features, n_features, "conv1_3/") 124 | self.conv1_3biases = self.biases_const_f(0.1, n_features, "conv1_3/") 125 | 126 | self.conv1_4weights = self.weight_xavier(3, n_features, n_features, "conv1_4/") 127 | self.conv1_4biases = self.biases_const_f(0.1, n_features, "conv1_4/") 128 | 129 | 130 | 131 | self.conv2_1weights = self.weight_xavier(3, n_features, 2 * n_features, "conv2_1/") 132 | self.conv2_1biases = self.biases_const_f(0.1, 2 * n_features, "conv2_1/") 133 | 134 | self.conv2_2weights = self.weight_xavier(3, 2 * n_features, 2 * n_features, "conv2_2/") 135 | self.conv2_2biases = self.biases_const_f(0.1, 2 * n_features, "conv2_2/") 136 | 137 | self.conv2_3weights = self.weight_xavier(3, 4 * n_features, 2 * n_features, 
"conv2_3/") 138 | self.conv2_3biases = self.biases_const_f(0.1, 2 * n_features, "conv2_3/") 139 | 140 | self.conv2_4weights = self.weight_xavier(3, 2 * n_features, 2 * n_features, "conv2_4/") 141 | self.conv2_4biases = self.biases_const_f(0.1, 2 * n_features, "conv2_4/") 142 | 143 | 144 | 145 | self.conv3_1weights = self.weight_xavier(3, 2 * n_features, 4 * n_features, "conv3_1/") 146 | self.conv3_1biases = self.biases_const_f(0.1, 4 * n_features, "conv3_1/") 147 | 148 | self.conv3_2weights = self.weight_xavier(3, 4 * n_features, 4 * n_features, "conv3_2/") 149 | self.conv3_2biases = self.biases_const_f(0.1, 4 * n_features, "conv3_2/") 150 | 151 | self.conv3_3weights = self.weight_xavier(3, 8 * n_features, 4 * n_features, "conv3_3/") 152 | self.conv3_3biases = self.biases_const_f(0.1, 4 * n_features, "conv3_3/") 153 | 154 | self.conv3_4weights = self.weight_xavier(3, 4 * n_features, 4 * n_features, "conv3_4/") 155 | self.conv3_4biases = self.biases_const_f(0.1, 4 * n_features, "conv3_4/") 156 | 157 | 158 | 159 | self.conv4_1weights = self.weight_xavier(3, 4 * n_features, 8 * n_features, "conv4_1/") 160 | self.conv4_1biases = self.biases_const_f(0.1, 8 * n_features, "conv4_1/") 161 | 162 | self.conv4_2weights = self.weight_xavier(3, 8 * n_features, 8 * n_features, "conv4_2/") 163 | self.conv4_2biases = self.biases_const_f(0.1, 8 * n_features, "conv4_2/") 164 | 165 | self.conv4_3weights = self.weight_xavier(3, 16 * n_features, 8 * n_features, "conv4_3/") 166 | self.conv4_3biases = self.biases_const_f(0.1, 8 * n_features, "conv4_3/") 167 | 168 | self.conv4_4weights = self.weight_xavier(3, 8 * n_features, 8 * n_features, "conv4_4/") 169 | self.conv4_4biases = self.biases_const_f(0.1, 8 * n_features, "conv4_4/") 170 | 171 | 172 | 173 | self.conv5_1weights = self.weight_xavier(3, 8 * n_features, 16 * n_features, "conv5_1/") 174 | self.conv5_1biases = self.biases_const_f(0.1, 16 * n_features, "conv5_1/") 175 | 176 | self.conv5_2weights = self.weight_xavier(3, 16 * 
n_features, 16 * n_features, "conv5_2/") 177 | self.conv5_2biases = self.biases_const_f(0.1, 16 * n_features, "conv5_2/") 178 | 179 | 180 | 181 | 182 | self.tconv5_4weights = self.weight_xavier(2, 8 * n_features, 16 * n_features, "tconv5_4/") 183 | self.tconv5_4biases = self.biases_const_f(0.1, 8 * n_features, "tconv5_4/") 184 | 185 | self.tconv4_3weights = self.weight_xavier(2, 4 * n_features, 8 * n_features, "tconv4_3/") 186 | self.tconv4_3biases = self.biases_const_f(0.1, 4 * n_features, "tconv4_3/") 187 | 188 | self.tconv3_2weights = self.weight_xavier(2, 2 * n_features, 4 * n_features, "tconv3_2/") 189 | self.tconv3_2biases = self.biases_const_f(0.1, 2 * n_features, "tconv3_2/") 190 | 191 | self.tconv2_1weights = self.weight_xavier(2, n_features, 2 * n_features, "tconv2_1/") 192 | self.tconv2_1biases = self.biases_const_f(0.1, n_features, "tconv2_1/") 193 | 194 | 195 | 196 | self.logits_weight = self.weight_xavier(1, n_features, 2, "logits/") 197 | self.logits_biases = self.biases_const_f(0.1, 2, "logits/") 198 | 199 | self.keep_prob = tf.Variable(0.5, name="dropout_prob") 200 | 201 | print('Model variables initialised') 202 | 203 | 204 | 205 | def init_model_architecture(self): 206 | """ 207 | Graph structure 208 | """ 209 | 210 | self.conv1_1 = self.conv_layer_f(self.input_node, self.conv1_1weights, "conv1_1/") 211 | self.relu1_1 = self.relu_layer_f(self.conv1_1, self.conv1_1biases, "conv1_1/") 212 | 213 | self.conv1_2 = self.conv_layer_f(self.relu1_1, self.conv1_2weights, "conv1_2/") 214 | self.relu1_2 = self.relu_layer_f(self.conv1_2, self.conv1_2biases, "conv1_2/") 215 | 216 | 217 | self.pool1_2 = self.max_pool(self.relu1_2, name="pool1_2") 218 | 219 | 220 | self.conv2_1 = self.conv_layer_f(self.pool1_2, self.conv2_1weights, "conv2_1/") 221 | self.relu2_1 = self.relu_layer_f(self.conv2_1, self.conv2_1biases, "conv2_1/") 222 | 223 | self.conv2_2 = self.conv_layer_f(self.relu2_1, self.conv2_2weights, "conv2_2/") 224 | self.relu2_2 = 
self.relu_layer_f(self.conv2_2, self.conv2_2biases, "conv2_2/") 225 | 226 | 227 | self.pool2_3 = self.max_pool(self.relu2_2, name="pool2_3") 228 | 229 | 230 | self.conv3_1 = self.conv_layer_f(self.pool2_3, self.conv3_1weights, "conv3_1/") 231 | self.relu3_1 = self.relu_layer_f(self.conv3_1, self.conv3_1biases, "conv3_1/") 232 | 233 | self.conv3_2 = self.conv_layer_f(self.relu3_1, self.conv3_2weights, "conv3_2/") 234 | self.relu3_2 = self.relu_layer_f(self.conv3_2, self.conv3_2biases, "conv3_2/") 235 | 236 | 237 | self.pool3_4 = self.max_pool(self.relu3_2, name="pool3_4") 238 | 239 | 240 | self.conv4_1 = self.conv_layer_f(self.pool3_4, self.conv4_1weights, "conv4_1/") 241 | self.relu4_1 = self.relu_layer_f(self.conv4_1, self.conv4_1biases, "conv4_1/") 242 | 243 | self.conv4_2 = self.conv_layer_f(self.relu4_1, self.conv4_2weights, "conv4_2/") 244 | self.relu4_2 = self.relu_layer_f(self.conv4_2, self.conv4_2biases, "conv4_2/") 245 | 246 | 247 | self.pool4_5 = self.max_pool(self.relu4_2, name="pool4_5") 248 | 249 | 250 | self.conv5_1 = self.conv_layer_f(self.pool4_5, self.conv5_1weights, "conv5_1/") 251 | self.relu5_1 = self.relu_layer_f(self.conv5_1, self.conv5_1biases, "conv5_1/") 252 | 253 | self.conv5_2 = self.conv_layer_f(self.relu5_1, self.conv5_2weights, "conv5_2/") 254 | self.relu5_2 = self.relu_layer_f(self.conv5_2, self.conv5_2biases, "conv5_2/") 255 | 256 | 257 | 258 | self.tconv5_4 = self.transposeconv_layer_f(self.relu5_2, self.tconv5_4weights, "tconv5_4/") 259 | self.trelu5_4 = self.relu_layer_f(self.tconv5_4, self.tconv5_4biases, "tconv5_4/") 260 | self.bridge4 = self.CropAndMerge(self.relu4_2, self.trelu5_4, "bridge4") 261 | 262 | 263 | 264 | self.conv4_3 = self.conv_layer_f(self.bridge4, self.conv4_3weights, "conv4_3/") 265 | self.relu4_3 = self.relu_layer_f(self.conv4_3, self.conv4_3biases, "conv4_3/") 266 | 267 | self.conv4_4 = self.conv_layer_f(self.relu4_3, self.conv4_4weights, "conv4_4/") 268 | self.relu4_4 = self.relu_layer_f(self.conv4_4, 
self.conv4_4biases, "conv4_4/") 269 | 270 | 271 | 272 | self.tconv4_3 = self.transposeconv_layer_f(self.relu4_4, self.tconv4_3weights, "tconv4_3/") 273 | self.trelu4_3 = self.relu_layer_f(self.tconv4_3, self.tconv4_3biases, "tconv4_3/") 274 | self.bridge3 = self.CropAndMerge(self.relu3_2, self.trelu4_3, "bridge3") 275 | 276 | 277 | 278 | self.conv3_3 = self.conv_layer_f(self.bridge3, self.conv3_3weights, "conv3_3/") 279 | self.relu3_3 = self.relu_layer_f(self.conv3_3, self.conv3_3biases, "conv3_3/") 280 | 281 | self.conv3_4 = self.conv_layer_f(self.relu3_3, self.conv3_4weights, "conv3_4/") 282 | self.relu3_4 = self.relu_layer_f(self.conv3_4, self.conv3_4biases, "conv3_4/") 283 | 284 | 285 | 286 | self.tconv3_2 = self.transposeconv_layer_f(self.relu3_4, self.tconv3_2weights, "tconv3_2/") 287 | self.trelu3_2 = self.relu_layer_f(self.tconv3_2, self.tconv3_2biases, "tconv3_2/") 288 | self.bridge2 = self.CropAndMerge(self.relu2_2, self.trelu3_2, "bridge2") 289 | 290 | 291 | 292 | self.conv2_3 = self.conv_layer_f(self.bridge2, self.conv2_3weights, "conv2_3/") 293 | self.relu2_3 = self.relu_layer_f(self.conv2_3, self.conv2_3biases, "conv2_3/") 294 | 295 | self.conv2_4 = self.conv_layer_f(self.relu2_3, self.conv2_4weights, "conv2_4/") 296 | self.relu2_4 = self.relu_layer_f(self.conv2_4, self.conv2_4biases, "conv2_4/") 297 | 298 | 299 | 300 | self.tconv2_1 = self.transposeconv_layer_f(self.relu2_4, self.tconv2_1weights, "tconv2_1/") 301 | self.trelu2_1 = self.relu_layer_f(self.tconv2_1, self.tconv2_1biases, "tconv2_1/") 302 | self.bridge1 = self.CropAndMerge(self.relu1_2, self.trelu2_1, "bridge1") 303 | 304 | 305 | 306 | self.conv1_3 = self.conv_layer_f(self.bridge1, self.conv1_3weights, "conv1_3/") 307 | self.relu1_3 = self.relu_layer_f(self.conv1_3, self.conv1_3biases, "conv1_3/") 308 | 309 | self.conv1_4 = self.conv_layer_f(self.relu1_3, self.conv1_4weights, "conv1_4/") 310 | self.relu1_4 = self.relu_layer_f(self.conv1_4, self.conv1_4biases, "conv1_4/") 311 | self.last = 
self.relu1_4 312 | 313 | print('Model architecture initialised') -------------------------------------------------------------------------------- /src_RealData/Nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PeterJackNaylor/DRFNS/73fc5683db5e9f860846e22c8c0daf73b7103082/src_RealData/Nets/__init__.py -------------------------------------------------------------------------------- /src_RealData/TFRecords.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from Data.CreateTFRecords import CreateTFRecord 5 | from Data.ImageTransform import ListTransform 6 | import numpy as np 7 | from optparse import OptionParser 8 | from utils import GetOptions 9 | 10 | 11 | if __name__ == '__main__': 12 | 13 | options = GetOptions() 14 | 15 | OUTNAME = options.TFRecord 16 | PATH = options.path 17 | CROP = options.crop 18 | SIZE = options.size_train 19 | SPLIT = options.split 20 | var_elast = [1.3, 0.03, 0.15] 21 | var_he = [0.01, 0.2] 22 | var_hsv = [0.2, 0.15] 23 | UNET = options.UNet 24 | SEED = options.seed 25 | N_EPOCH = options.epoch 26 | TYPE = options.type 27 | 28 | 29 | transform_list, transform_list_test = ListTransform(var_elast=var_elast, 30 | var_hsv=var_hsv, 31 | var_he=var_he) 32 | if options.split == "train": 33 | TEST_PATIENT = ["testbreast", "testliver", "testkidney", "testprostate", 34 | "bladder", "colorectal", "stomach", "test"] 35 | TRANSFORM_LIST = transform_list 36 | elif options.split == "test": 37 | TEST_PATIENT = ["test"] 38 | TRANSFORM_LIST = transform_list_test 39 | SIZE = options.size_test 40 | 41 | elif options.split == "validation": 42 | options.split = "test" 43 | TEST_PATIENT = ["testbreast", "testliver", "testkidney", "testprostate", 44 | "bladder", "colorectal", "stomach"] 45 | TRANSFORM_LIST = transform_list_test 46 | SIZE = options.size_test 47 | 48 | 49 | SIZE = (SIZE, 
SIZE) 50 | CreateTFRecord(OUTNAME, PATH, CROP, SIZE, 51 | TRANSFORM_LIST, UNET, None, 52 | SEED, TEST_PATIENT, N_EPOCH, 53 | TYPE=TYPE, SPLIT=SPLIT) 54 | -------------------------------------------------------------------------------- /src_RealData/UNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from utils import GetOptions, ComputeMetrics 5 | from glob import glob 6 | import tensorflow as tf 7 | import numpy as np 8 | import os 9 | from utils import CheckOrCreate 10 | from Data.DataGenClass import DataGenMulti 11 | from Data.ImageTransform import ListTransform 12 | from Nets.UNetBatchNorm import UNetBatchNorm 13 | 14 | 15 | 16 | class Model(UNetBatchNorm): 17 | """ 18 | UNet model 19 | """ 20 | def test(self, p1, p2, steps): 21 | """ 22 | Evaluates the model on `steps` batches from the TFRecord queue; metrics are computed on the first image of each batch and averaged over the steps. 23 | """ 24 | loss, roc = 0., 0. 25 | acc, F1, recall = 0., 0., 0. 26 | precision, jac, AJI = 0., 0., 0. 27 | init_op = tf.group(tf.global_variables_initializer(), 28 | tf.local_variables_initializer()) 29 | self.sess.run(init_op) 30 | self.Saver() 31 | coord = tf.train.Coordinator() 32 | threads = tf.train.start_queue_runners(coord=coord) 33 | 34 | for step in range(steps): 35 | feed_dict = {self.is_training: False} 36 | l, prob, batch_labels = self.sess.run([self.loss, self.train_prediction, 37 | self.train_labels_node], feed_dict=feed_dict) 38 | loss += l 39 | out = ComputeMetrics(prob[0,:,:,1], batch_labels[0,:,:,0], p1, p2) 40 | acc += out[0] 41 | roc += out[1] 42 | jac += out[2] 43 | recall += out[3] 44 | precision += out[4] 45 | F1 += out[5] 46 | AJI += out[6] 47 | coord.request_stop() 48 | coord.join(threads) 49 | loss, acc, F1 = np.array([loss, acc, F1]) / steps 50 | recall, precision, roc = np.array([recall, precision, roc]) / steps 51 | jac, AJI = np.array([jac, AJI]) / steps 52 | return loss, acc, F1, recall, precision, roc, jac, AJI 53 | 54 | def validation(self, DG_TEST, p1, p2, save_path):
55 | """ 56 | Evaluates the model on every batch of DG_TEST and returns one row [loss, metrics...] per batch (metrics use the first image of each batch); save_path is passed to ComputeMetrics. 57 | """ 58 | n_test = DG_TEST.length 59 | n_batch = int(np.ceil(float(n_test) / self.BATCH_SIZE)) 60 | loss, roc = [], [] 61 | acc, F1, recall = [], [], [] 62 | precision, jac, AJI = [], [], [] 63 | res = [] 64 | 65 | for i in range(n_batch): 66 | Xval, Yval = DG_TEST.Batch(0, self.BATCH_SIZE) 67 | feed_dict = {self.input_node: Xval, 68 | self.train_labels_node: Yval, 69 | self.is_training: False} 70 | l, pred = self.sess.run([self.loss, self.train_prediction], 71 | feed_dict=feed_dict) 72 | rgb = (Xval[0,92:-92,92:-92] + np.load(self.MEAN_FILE)).astype(np.uint8) 73 | out = ComputeMetrics(pred[0,:,:,1], Yval[0,:,:,0], p1, p2, rgb=rgb, save_path=save_path, ind=i) 74 | out = [l] + list(out) 75 | res.append(out) 76 | return res 77 | 78 | 79 | if __name__== "__main__": 80 | 81 | transform_list, transform_list_test = ListTransform() 82 | options = GetOptions() 83 | 84 | SPLIT = options.split 85 | 86 | ## Model parameters 87 | TFRecord = options.TFRecord 88 | LEARNING_RATE = options.lr 89 | BATCH_SIZE = options.bs 90 | SIZE = (options.size_train, options.size_train) 91 | if options.size_test is not None: 92 | SIZE = (options.size_test, options.size_test) 93 | N_ITER_MAX = 0 ## defined later 94 | LRSTEP = "10epoch" 95 | N_TRAIN_SAVE = 100 96 | LOG = options.log 97 | WEIGHT_DECAY = options.weight_decay 98 | N_FEATURES = options.n_features 99 | N_EPOCH = options.epoch 100 | N_THREADS = options.THREADS 101 | MEAN_FILE = options.mean_file 102 | DROPOUT = options.dropout 103 | 104 | ## Datagen parameters 105 | PATH = options.path 106 | TEST_PATIENT = ["testbreast", "testliver", "testkidney", "testprostate", 107 | "bladder", "colorectal", "stomach", "test"] 108 | DG_TRAIN = DataGenMulti(PATH, split='train', crop = 16, size=SIZE, 109 | transforms=transform_list, UNet=True, mean_file=None, num=TEST_PATIENT) 110 | TEST_PATIENT = ["test"] 111 | DG_TEST = DataGenMulti(PATH, split="test", crop = 1, size=(500, 500), 112 |
                           transforms=transform_list_test, UNet=True, mean_file=MEAN_FILE, num=TEST_PATIENT)
    if SPLIT == "train":
        N_ITER_MAX = N_EPOCH * DG_TRAIN.length // BATCH_SIZE
    elif SPLIT == "test":
        N_ITER_MAX = N_EPOCH * DG_TEST.length // BATCH_SIZE
    elif SPLIT == "validation":
        LOG = glob(os.path.join(LOG, '*'))[0]
    model = Model(TFRecord, LEARNING_RATE=LEARNING_RATE,
                  BATCH_SIZE=BATCH_SIZE,
                  IMAGE_SIZE=SIZE,
                  NUM_LABELS=2,
                  NUM_CHANNELS=3,
                  STEPS=N_ITER_MAX,
                  LRSTEP=LRSTEP,
                  N_PRINT=N_TRAIN_SAVE,
                  LOG=LOG,
                  SEED=42,
                  WEIGHT_DECAY=WEIGHT_DECAY,
                  N_FEATURES=N_FEATURES,
                  N_EPOCH=N_EPOCH,
                  N_THREADS=N_THREADS,
                  MEAN_FILE=MEAN_FILE,
                  DROPOUT=DROPOUT)
    if SPLIT == "train":
        model.train(DG_TEST)
    elif SPLIT == "test":
        p1 = options.p1
        file_name = options.output
        f = open(file_name, 'w')
        outs = model.test(options.p1, 0.5, N_ITER_MAX)
        outs = [LOG] + list(outs) + [p1, 0.5]
        NAMES = ["ID", "Loss", "Acc", "F1", "Recall", "Precision", "ROC", "Jaccard", "AJI", "p1", "p2"]
        f.write('{},{},{},{},{},{},{},{},{},{},{}\n'.format(*NAMES))

        f.write('{},{},{},{},{},{},{},{},{},{},{}\n'.format(*outs))
        f.close()

    elif SPLIT == "validation":

        TEST_PATIENT = ["testbreast", "testliver", "testkidney", "testprostate",
                        "bladder", "colorectal", "stomach"]

        file_name = options.output
        f = open(file_name, 'w')
        NAMES = ["NUMBER", "ORGAN", "Loss", "Acc", "ROC", "Jaccard", "Recall", "Precision", "F1", "AJI", "p1", "p2"]
        # one "{}" per column, so that every header and value is written
        row_fmt = ','.join(['{}'] * len(NAMES)) + '\n'
        f.write(row_fmt.format(*NAMES))


        for organ in TEST_PATIENT:
            DG_TEST = DataGenMulti(PATH, split="test", crop = 1, size=(996, 996), num=[organ],
                                   transforms=transform_list_test, UNet=True, mean_file=MEAN_FILE)
            save_organ = os.path.join(options.save_path, organ)
            CheckOrCreate(save_organ)
            outs = model.validation(DG_TEST, options.p1, 0.5, save_organ)
            for i in range(len(outs)):
                small_o = outs[i]
                small_o = [i, organ] + small_o + [options.p1, 0.5]
                f.write(row_fmt.format(*small_o))
        f.close()

--------------------------------------------------------------------------------
/src_RealData/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterJackNaylor/DRFNS/73fc5683db5e9f860846e22c8c0daf73b7103082/src_RealData/__init__.py
--------------------------------------------------------------------------------
/src_RealData/postproc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterJackNaylor/DRFNS/73fc5683db5e9f860846e22c8c0daf73b7103082/src_RealData/postproc/__init__.py
--------------------------------------------------------------------------------
/src_RealData/postproc/plot.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pandas as pd
from glob import glob
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import lines
from utils import GetOptions
matplotlib.rc('text', usetex = True)
options = GetOptions()
CSV = glob('*.csv')

df_s = []
for f in CSV:
    tmp = pd.read_csv(f)
    tmp["Model"] = f.split('.')[0]
    df_s.append(tmp)

table = pd.concat(df_s)
smaller = table[["AJI", "F1", "ORGAN", "Model", "NUMBER"]]
smaller = smaller.set_index(['ORGAN', 'NUMBER', 'Model']).unstack()
smaller.loc[('mean', 0), :] = smaller.mean()
print(smaller.round(4).to_latex(multicolumn=True, multirow=True))
#smaller.to_csv(options.output_csv)

grouped = table.groupby(['ORGAN', 'NUMBER'])

fig, ax = plt.subplots(nrows=2, sharey=False, figsize=(15.0, 8.0), gridspec_kw = {'height_ratios':[1, 2]})
# Alternative color palettes: only the last assignment is used
colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00"]
colors = ["#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e", "#e6ab02"]
colors = ["#7fc97f", "#beaed4", "#7570b3", "#fdc086", "#ffff99", "#fed98e"]
colors = ["#e41a1c", "#41b6c4", "#225ea8", "#238443", "#9e9ac8", "#6a51a3"]
colors = ["#e41a1c", "#6baed6", "#2171b5", "#9e9ac8", "#6a51a3", "#74c476"]

Patches = []
methods = np.unique(table["Model"])
organ_order = np.unique(table["ORGAN"])

mod = {el:[0.] * len(organ_order) * len(np.unique(table["NUMBER"])) for el in methods}
mod2 = {el:[0.] * len(organ_order) * len(np.unique(table["NUMBER"])) for el in methods}
mod3 = {el:[0.] * len(organ_order) * len(np.unique(table["NUMBER"])) for el in methods}
dic = {"F1":mod2, "AJI":mod3}

def fill_dic(row, ind):
    f1 = row["F1"]
    aji = row["AJI"]
    model = row["Model"]
    dic["F1"][model][ind] = f1
    dic["AJI"][model][ind] = aji

for i, ((slide, num), tble) in enumerate(grouped):
    ind = 2 * np.where(organ_order == slide)[0][0] + int(num)
    tble.apply(lambda row: fill_dic(row, ind), axis=1)

width = 0.15
for j, model in enumerate(methods):
    ind = np.arange(len(dic["F1"][model]))
    rectsF1 = ax[0].bar(ind + j * width, dic["F1"][model], width, color=colors[j])
    rectsAJI = ax[1].bar(ind + j * width, dic["AJI"][model], width, color=colors[j])
    patch_to_add = mpatches.Patch(color=colors[j], label=model)
    Patches.append(patch_to_add)

ax[0].set_ylabel('$F1$')
ax[1].set_ylabel('$AJI$')

Image_names = []
for el in organ_order:
    Image_names.append("1")
    Image_names.append("2")

for j in range(2):
    if j == 1:
        ax[j].set_xticks(ind + 5 * width / 2)
        ax[j].set_xticklabels(Image_names, rotation=0)
    else:
        ax[j].set_xticks(ind + 5 * width / 2)
        ax[j].set_xticklabels(Image_names, rotation=0)
    ax[j].axhline(y=0.75,xmin=0,xmax=3,c="gray",ls="--",linewidth=0.5,zorder=0)
    ax[j].axhline(y=0.50,xmin=0,xmax=3,c="black",ls="-.",linewidth=0.5,zorder=0)
    ax[j].axhline(y=0.25,xmin=0,xmax=3,c="gray",ls="--",linewidth=0.5,zorder=0)
box = ax[1].get_position()
ax[1].set_position([box.x0, box.y0 + box.height * 0.1,
                    box.width, box.height * 0.9])
ax[1].legend(handles=Patches, loc="center left", bbox_to_anchor=(1., 1.),
             fancybox=True, shadow=True, ncol=1)

ax2 = plt.axes([0, 0, 1, 1])
ax2.set_position([box.x0 + 0.05, box.y0 + box.height * 0.1 - 0.15,
                  box.width - 0.08, box.height * 0.9])
ax2.axis('off')
for j, organ in enumerate(organ_order):
    P0 = [ind[j] * 2 -2, -2]
    P1 = [ind[j] * 2 + 1 -2, -2]
    line_x, line_y = np.array([P0, P1])
    line_x, line_y = np.array([[0.05, 0.05], [0.05, 0.55]])
    line1 = lines.Line2D(line_x, line_y, lw=2., color='k')
    #ax2.add_line(line1)
    alpha = 0.851
    xmin = (float(ind[j * 2]) / np.max(ind) + 0.01 + 0.16) * alpha + 0.006
    xmax = (float(ind[j * 2 + 1]) / np.max(ind) - 0.01 ) * alpha
    XY = np.mean([xmin, xmax]) + 0.003
    fs = 6.
    if 'test' in organ:
        organ = organ[4:]
    ax[1].annotate(organ.capitalize(), xy=(XY, 0.85), xytext=(XY, 0.95), xycoords='axes fraction',
                   fontsize=fs*1.5, ha='center', va='top',
                   bbox=dict(boxstyle='square', fc='white'),
                   arrowprops=dict(arrowstyle='-[, widthB=5.95, lengthB=1.', lw=1.5))
    ax[1].annotate(organ.capitalize(), xy=(XY, 1.05), xytext=(XY, 0.95), xycoords='axes fraction',
                   fontsize=fs*1.5, ha='center', va='top',
                   bbox=dict(boxstyle='square', fc='white'),
                   arrowprops=dict(arrowstyle='-[, widthB=5.95, lengthB=1.', lw=1.5))
    if j != 0:
        # no vertical separator before the first organ
        x_line = ind[j] * 2 - 2 * width / 2 + 0.02
        ax[0].axvline(ymin=0, ymax=1, x=x_line, c="black",ls="--",linewidth=0.5,zorder=0)
        ax[1].axvline(ymin=0, ymax=1, x=x_line, c="black",ls="--",linewidth=0.5,zorder=0)
ax[0].set_ylim([0.5,1])
ax[1].set_ylim([0,1])
ax[1].set_xlabel(r'$Image \ N^\circ$:')
ax[1].xaxis.set_label_coords(0.03, -0.025)
fig.savefig(options.output, bbox_inches='tight')
--------------------------------------------------------------------------------
/src_RealData/postproc/postprocessing.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from skimage.morphology import watershed
import numpy as np
from skimage.measure import label
from skimage.morphology import reconstruction, dilation, erosion, disk, diamond, square
from skimage import img_as_ubyte


def PrepareProb(img, convertuint8=True, inverse=True):
    """
    Prepares the probability image for post-processing: it can convert it
    from float to uint8 and invert it if needed.
    """
    if convertuint8:
        img = img_as_ubyte(img)
    if inverse:
        img = 255 - img
    return img


def HreconstructionErosion(prob_img, h):
    """
    Performs an h-minima reconstruction: morphological reconstruction by
    erosion of min(prob_img + h, 255) over prob_img, which suppresses every
    regional minimum whose depth is less than h.
    """

    def making_top_mask(x, lamb=h):
        return min(255, x + lamb)

    f = np.vectorize(making_top_mask)
    shift_prob_img = f(prob_img)

    seed = shift_prob_img
    mask = prob_img
    recons = reconstruction(
        seed, mask, method='erosion').astype(np.dtype('ubyte'))
    return recons


def find_maxima(img, convertuint8=False, inverse=False, mask=None):
    """
    Returns an image that is non-zero at the regional minima of img (after the
    optional conversion/inversion); on an inverted probability map these are
    the maxima of the original map.
    """
    img = PrepareProb(img, convertuint8=convertuint8, inverse=inverse)
    recons = HreconstructionErosion(img, 1)
    if mask is None:
        return recons - img
    else:
        res = recons - img
        res[mask==0] = 0
        return res


def GetContours(img):
    """
    Returns the morphological gradient (dilation minus erosion with disk(2))
    of the binarised image, i.e. a band around the object contours.
    Note that img is binarised in place.
    """
    img[img > 0] = 1
    return dilation(img, disk(2)) - erosion(img, disk(2))


def generate_wsl(ws):
    """
    Generates the watershed lines between touching objects: returns a uint8
    mask equal to 255 where a labelled pixel borders a different label.
    """
    se = square(3)
    ero = ws.copy()
    ero[ero == 0] = ero.max() + 1
    ero = erosion(ero, se)
    ero[ws == 0] = 0

    grad = dilation(ws, se) - ero
    grad[ws == 0] = 0
    grad[grad > 0] = 255
    grad = grad.astype(np.uint8)
    return grad

def DynamicWatershedAlias(p_img, lamb, p_thresh = 0.5):
    """
    Applies our dynamic watershed to a 2D probability/distance image: the
    foreground is p_img > p_thresh, the markers are the maxima of p_img
    (converted to uint8) whose dynamic is at least lamb, and the watershed
    lines between touching nuclei are removed from the result.
    """
    b_img = (p_img > p_thresh) + 0
    Probs_inv = PrepareProb(p_img)


    Hrecons = HreconstructionErosion(Probs_inv, lamb)
    markers_Probs_inv = find_maxima(Hrecons, mask = b_img)
    markers_Probs_inv = label(markers_Probs_inv)
    ws_labels = watershed(Hrecons, markers_Probs_inv, mask=b_img)
    arrange_label = ArrangeLabel(ws_labels)
    wsl = generate_wsl(arrange_label)
    arrange_label[wsl > 0] = 0


    return arrange_label

def ArrangeLabel(mat):
    """
    Relabels the connected components of mat so that its most frequent value
    (taken to be the background) becomes 0.
    """
    val, counts = np.unique(mat, return_counts=True)
    background_val = val[np.argmax(counts)]
    mat = label(mat, background = background_val)
    if np.min(mat) < 0:
        # older scikit-image versions label the background -1: shift it to 0
        mat -= np.min(mat)
    return mat


def PostProcess(prob_image, param=7, thresh = 0.5):
    """
    Runs DynamicWatershedAlias with lamb=param and p_thresh=thresh.
    """
    segmentation_mask = DynamicWatershedAlias(prob_image, param, thresh)
    return segmentation_mask
--------------------------------------------------------------------------------
/src_RealData/postproc/regroup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import pandas as pd
from glob import glob
from optparse import OptionParser


parser = OptionParser()
parser.add_option('--store_best', dest='store_best',type='str')
parser.add_option('--output', dest="output", type="str")

(options, args) = parser.parse_args()

CSV = glob('*.csv')
df_list = []
for f in CSV:
    df = pd.read_csv(f, index_col=False)
    name = f.split('.cs')[0]
    df.index = [name]
    df_list.append(df)
table = pd.concat(df_list)
# label (run name) of the row with the highest F1 score
best_index = table['F1'].idxmax()
table.to_csv(options.output, header=True, index=True)
tmove_name = "{}".format(best_index)
model = "_".join(best_index.split('_')[:-2])
n_feat = model.split('__')[1].split('_')[0]
name = options.store_best
os.mkdir(name)
os.rename(model, os.path.join(name, model))

p1 = table.loc[best_index, 'p1']
p2 = table.loc[best_index, 'p2']
f1 = open('p1_val', 'w')
f1.write('{}'.format(p1))
f1.close()
f2 = open('p2_val', 'w')
f2.write('{}'.format(p2))
f2.close()
f_feat = open('feat_val', 'w')
f_feat.write('{}'.format(n_feat))
f_feat.close()
--------------------------------------------------------------------------------
/src_RealData/preproc/BinToDistance.py:
--------------------------------------------------------------------------------
from skimage.io import imread, imsave
from glob import glob
from os.path import dirname, join, basename
from shutil import copy
from utils import CheckOrCreate
from scipy.ndimage import distance_transform_cdt
import numpy as np

def LoadGT(path):
    img = imread(path, dtype='uint8')
    return img

def DistanceWithoutNormalise(bin_image):
    """
    Per-nucleus distance map: for every label j of the label image, each
    pixel of nucleus j gets its chessboard distance (the default metric of
    distance_transform_cdt) to the nearest pixel outside nucleus j.
    The values are not normalised and are returned as uint8.
    """
    res = np.zeros_like(bin_image)
    for j in range(1, int(bin_image.max()) + 1):
        one_cell = np.zeros_like(bin_image)
        one_cell[bin_image == j] = 1
        one_cell = distance_transform_cdt(one_cell)
        res[bin_image == j] = one_cell[bin_image == j]
    res = res.astype('uint8')
    return res

NEW_FOLDER = 'ToAnnotateDistance'
CheckOrCreate(NEW_FOLDER)
for image in glob('ForDataGenTrainTestVal/Slide_*/*.png'):
    baseN = basename(image)
    Slide_name = dirname(image)
    GT_name = baseN.replace('Slide', 'GT')
    OLD_FOLDER = dirname(Slide_name)
    Slide_N = basename(dirname(image))
    GT_N = Slide_N.replace('Slide_', 'GT_')

    CheckOrCreate(join(NEW_FOLDER, Slide_N))
    CheckOrCreate(join(NEW_FOLDER, GT_N))

    copy(image, join(NEW_FOLDER, Slide_N, baseN))
    bin_image = LoadGT(join(OLD_FOLDER, GT_N, GT_name))
    res = DistanceWithoutNormalise(bin_image)
    imsave(join(NEW_FOLDER, GT_N, GT_name), res)
--------------------------------------------------------------------------------
/src_RealData/preproc/MeanCalculation.py:
--------------------------------------------------------------------------------
import numpy as np
from Data.DataGenClass import DataGenMulti
from Data.ImageTransform import ListTransform
from os.path import join
from optparse import OptionParser


if __name__ == '__main__':

    parser = OptionParser()

    parser.add_option("--path", dest="path",type="string",
                      help="path to annotated dataset")
    parser.add_option("--output", dest="out",type="string",
                      help="output folder for mean_file.npy")

    (options, args) = parser.parse_args()

    path = options.path
    transf, transf_test = ListTransform()

    size = (1000, 1000)
    size_test = (512, 512)
    crop = 1
    DG = DataGenMulti(path, crop=crop, size=size, transforms=transf_test,
                      split="train", num="test")
    DG_test = DataGenMulti(path, crop=crop, size=size_test, transforms=transf_test,
                           split="test", num="test")
    # average of the per-channel mean of every train and test image
    res = np.zeros(shape=3, dtype='float')
    count = 0
    for i in range(DG.length):
        key = DG.NextKeyRandList(0)
        res += np.mean(DG[key][0], axis=(0, 1))
        count += 1
    for i in range(DG_test.length):
        key = DG_test.NextKeyRandList(0)
        res += np.mean(DG_test[key][0], axis=(0, 1))
        count += 1
    mean = res / count
    np.save(join(options.out, "mean_file.npy"), mean)

--------------------------------------------------------------------------------
/src_RealData/preproc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PeterJackNaylor/DRFNS/73fc5683db5e9f860846e22c8c0daf73b7103082/src_RealData/preproc/__init__.py
--------------------------------------------------------------------------------
/src_RealData/preproc/changescale.py:
--------------------------------------------------------------------------------
from glob import glob
import shutil
from optparse import OptionParser
from utils import CheckOrCreate
from os.path import join
from skimage.io import imread, imsave

parser = OptionParser()
parser.add_option("--path", dest="path", type="string",
                  help="folder to copy into ImageFolder before binarising its ground truth")

(options, args) = parser.parse_args()

dst = 'ImageFolder'
# CheckOrCreate(dst)
shutil.copytree(options.path, dst, symlinks=False, ignore=None)

# binarise every ground-truth image of the copy (all nucleus labels -> 1)
FILES = glob(join(dst, "GT_*", "GT_*.png"))
AND_TEST = glob(join(dst, "GT_test", "test_*.png"))
for f in FILES + AND_TEST:
    img = imread(f)
    img[img > 0] = 1
    imsave(f, img.astype('int8'))
--------------------------------------------------------------------------------
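`DistanceWithoutNormalise` in BinToDistance.py builds the regression target of this approach: one `distance_transform_cdt` call per nucleus, so every nucleus gets its own distance peak even when it touches a neighbour. The sketch below illustrates that idea with NumPy only; it is not the repository's code. The names `erode_chessboard` and `per_nucleus_distance` are hypothetical, and unlike a direct SciPy call it treats pixels outside the image as background, so values may differ for nuclei cut by the image border.

```python
import numpy as np


def erode_chessboard(mask):
    """Binary erosion with a 3x3 square structuring element.

    Pixels outside the image are treated as background (an assumption of this sketch).
    """
    padded = np.pad(mask, 1, mode="constant", constant_values=False)
    out = np.ones_like(mask, dtype=bool)
    h, w = mask.shape
    # a pixel survives only if all 9 pixels of its 3x3 neighbourhood are set
    for dy in (0, 1, 2):
        for dx in (0, 1, 2):
            out &= padded[dy:dy + h, dx:dx + w]
    return out


def per_nucleus_distance(labels):
    """Chessboard distance of each nucleus pixel to the outside of its own nucleus.

    Every other nucleus counts as background, so touching nuclei stay separated.
    """
    dist = np.zeros(labels.shape, dtype=np.int32)
    for lab in np.unique(labels):
        if lab == 0:  # 0 is the background label
            continue
        current = labels == lab
        step = 0
        while current.any():
            step += 1
            # pixels that survived (step - 1) erosions are at distance >= step
            dist[current] = step
            current = erode_chessboard(current)
    return dist
```

The loop relies on the fact that surviving k erosions by a 3x3 square is the same as lying at chessboard distance greater than k from the background, so the last step at which a pixel is still set is its distance.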
/src_RealData/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from optparse import OptionParser
from skimage.measure import label
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.metrics import jaccard_similarity_score, f1_score
from sklearn.metrics import recall_score, precision_score, confusion_matrix
from skimage.morphology import erosion, disk
from os.path import join
import os
from skimage.io import imsave, imread
import numpy as np
import pdb
import time
from progressbar import ProgressBar
from postproc.postprocessing import PostProcess, generate_wsl

def GetOptions():
    """
    Parses the command-line options shared by the scripts of src_RealData.
    """
    parser = OptionParser()
    parser.add_option("--tf_record", dest="TFRecord", type="string", default="",
                      help="Where to find the TFrecord file")
    parser.add_option("--path", dest="path", type="string",
                      help="Where to collect the patches")
    parser.add_option("--size_train", dest="size_train", type="int",
                      help="size of the input image to the network")
    parser.add_option("--log", dest="log",
                      help="log dir")
    parser.add_option("--learning_rate", dest="lr", type="float", default=0.01,
                      help="learning rate")
    parser.add_option("--batch_size", dest="bs", type="int", default=1,
                      help="batch size")
    parser.add_option("--epoch", dest="epoch", type="int", default=1,
                      help="number of epochs")
    parser.add_option("--n_features", dest="n_features", type="int",
                      help="number of feature maps in the first layer")
    parser.add_option("--weight_decay", dest="weight_decay", type="float", default=0.00005,
                      help="weight decay value")
    parser.add_option("--dropout", dest="dropout", type="float",
                      default=0.5, help="dropout value to apply to the FC layers.")
    parser.add_option("--mean_file", dest="mean_file", type="str",
                      help="where to find the mean file to subtract from the original image.")
    parser.add_option('--n_threads', dest="THREADS", type=int, default=100,
                      help="number of threads to use for the preprocessing.")
    parser.add_option('--crop', dest="crop", type=int, default=4,
                      help="crop size depending on validation/test/train phase.")
    parser.add_option('--split', dest="split", type="str",
                      help="validation/test/train phase.")
    parser.add_option('--p1', dest="p1", type="int",
                      help="1st post-processing parameter: dynamic (lamb) of the watershed.")
    parser.add_option('--p2', dest="p2", type="float",
                      help="2nd post-processing parameter: foreground probability threshold.")
    parser.add_option('--iters', dest="iters", type="int")
    parser.add_option('--seed', dest="seed", type="int")
    parser.add_option('--size_test', dest="size_test", type="int")
    parser.add_option('--restore', dest="restore", type="str")
    parser.add_option('--save_path', dest="save_path", type="str", default=".")
    parser.add_option('--type', dest="type", type ="str",
                      help="Type for the datagen")
    parser.add_option('--UNet', dest='UNet', action='store_true')
    parser.add_option('--no-UNet', dest='UNet', action='store_false')
    parser.add_option('--output', dest="output", type="str")
    parser.add_option('--output_csv', dest="output_csv", type="str")

    (options, args) = parser.parse_args()

    return options

def ComputeMetrics(prob, batch_labels, p1, p2, rgb=None, save_path=None, ind=0):
    """
    Post-processes the probability map with PostProcess(prob, p1, p2) and
    computes pixel-level metrics (accuracy, ROC AUC, Jaccard, recall,
    precision, F1) and the object-level AJI against the label image.
    If an rgb image is also given, the input, ground truth, probability map,
    prediction and contour overlays are saved as PNG files in save_path.
    """
    GT = label(batch_labels.copy())
    PRED = PostProcess(prob, p1, p2)
    # PRED = label((prob > 0.5).astype('uint8'))
    lbl = GT.copy()
    pred = PRED.copy()
    aji = AJI_fast(lbl, pred)
    lbl[lbl > 0] = 1
    pred[pred > 0] = 1
    l, p = lbl.flatten(), pred.flatten()
    acc = accuracy_score(l, p)
    roc = roc_auc_score(l, p)
    jac = jaccard_similarity_score(l, p)
    f1 = f1_score(l, p)
    recall = recall_score(l, p)
    precision = precision_score(l, p)
    if rgb is not None:
        xval_n = join(save_path, "xval_{}.png".format(ind))
        yval_n = join(save_path, "yval_{}.png".format(ind))
        prob_n = join(save_path, "prob_{}.png".format(ind))
        pred_n = join(save_path, "pred_{}.png".format(ind))
        c_gt_n = join(save_path, "C_gt_{}.png".format(ind))
        c_pr_n = join(save_path, "C_pr_{}.png".format(ind))

        imsave(xval_n, rgb)
        imsave(yval_n, color_bin(GT))
        imsave(prob_n, prob)
        imsave(pred_n, color_bin(PRED))
        imsave(c_gt_n, add_contours(rgb, GT))
        imsave(c_pr_n, add_contours(rgb, PRED))

    return acc, roc, jac, recall, precision, f1, aji

def color_bin(bin_labl):
    """
    Gives each labelled nucleus a random RGB color so that neighbouring
    nuclei are easy to tell apart.
    """
    dim = bin_labl.shape
    x, y = dim[0], dim[1]
    res = np.zeros(shape=(x, y, 3))
    for i in range(1, bin_labl.max() + 1):
        rgb = np.random.normal(loc = 125, scale=100, size=3)
        rgb[rgb < 0 ] = 0
        rgb[rgb > 255] = 255
        rgb = rgb.astype(np.uint8)
        res[bin_labl == i] = rgb
    return res.astype(np.uint8)

def add_contours(rgb_image, contour, ds = 2):
    """
    Draws in black, on a copy of rgb_image, the inner boundary (width ds)
    of the objects in contour.
    contour can be a binary or a label image; it is binarised internally
    and left unmodified.
    """
    rgb = rgb_image.copy()
    contour = (contour > 0).astype(np.uint8)
    boundary = contour - erosion(contour, disk(ds))
    rgb[boundary > 0] = np.array([0, 0, 0])
    return rgb

def CheckOrCreate(path):
    """
    Creates the directory path if it does not exist yet.
    """
    if not os.path.isdir(path):
        os.makedirs(path)

def Intersection(A, B):
    """
    Returns the binary mask of the pixels that belong to both binary masks
    A and B (call .sum() on it for the pixel count).
    """
    C = A + B
    C[C != 2] = 0
    C[C == 2] = 1
    return C

def Union(A, B):
    """
    Returns the binary mask of the pixels that belong to A or B
    (call .sum() on it for the pixel count).
    """
    C = A + B
    C[C > 0] = 1
    return C


def AssociatedCell(G_i, S):
    """
    Returns the label of the predicted cell in S with the highest Jaccard
    index with the ground-truth mask G_i. If no predicted cell touches G_i,
    all scores are 0 and label 1 is returned.
    """
    def g(indice):
        S_indice = np.zeros_like(S)
        S_indice[ S == indice ] = 1
        NUM = float(Intersection(G_i, S_indice).sum())
        DEN = float(Union(G_i, S_indice).sum())
        return NUM / DEN
    res = [g(k) for k in range(1, S.max() + 1)]
    indice = np.array(res).argmax() + 1
    return indice

pbar = ProgressBar()

def AJI(G, S):
    """
    Straightforward implementation of the Aggregated Jaccard Index (AJI) as
    described in the paper; AJI_fast computes the same quantity much faster.
    """
    G = label(G, background=0)
    S = label(S, background=0)

    C = 0
    U = 0
    USED = np.zeros(S.max())

    for i in pbar(range(1, G.max() + 1)):
        only_ground_truth = np.zeros_like(G)
        only_ground_truth[ G == i ] = 1
        j = AssociatedCell(only_ground_truth, S)
        only_prediction = np.zeros_like(S)
        only_prediction[ S == j ] = 1
        C += Intersection(only_prediction, only_ground_truth).sum()
        U += Union(only_prediction, only_ground_truth).sum()
        USED[j - 1] = 1

    def h(indice):
        if USED[indice - 1] == 1:
            return 0
        else:
            only_prediction = np.zeros_like(S)
            only_prediction[ S == indice ] = 1
            return only_prediction.sum()
    U_sum = [h(k) for k in range(1, S.max() + 1)]
    U += np.sum(U_sum)
    return float(C) / float(U)



def AJI_fast(G, S):
    """
    AJI as described in the paper, computed from the confusion matrix
    between ground-truth and predicted labels (much faster than AJI).
    """
    G = label(G, background=0)
    S = label(S, background=0)
    if S.sum() == 0:
        return 0.
    C = 0
    U = 0
    USED = np.zeros(S.max())

    G_flat = G.flatten()
    S_flat = S.flatten()
    G_max = np.max(G_flat)
    S_max = np.max(S_flat)
    m_labels = max(G_max, S_max) + 1
    # cm[i, j] = number of pixels with ground-truth label i and predicted label j
    cm = confusion_matrix(G_flat, S_flat, labels=range(m_labels)).astype(float)
    # LIGNE_J[j - 1] = area of predicted nucleus j
    LIGNE_J = np.zeros(S_max)
    for j in range(1, S_max + 1):
        LIGNE_J[j - 1] = cm[:, j].sum()

    for i in range(1, G_max + 1):
        # area of ground-truth nucleus i
        LIGNE_I_sum = cm[i, :].sum()
        def h(indice):
            LIGNE_J_sum = LIGNE_J[indice - 1]
            inter = cm[i, indice]

            union = LIGNE_I_sum + LIGNE_J_sum - inter
            return inter / union

        # match nucleus i to the predicted nucleus with the highest IoU
        JI_ligne = [h(k) for k in range(1, S_max + 1)]
        best_indice = np.argmax(JI_ligne) + 1
        C += cm[i, best_indice]
        U += LIGNE_J[best_indice - 1] + LIGNE_I_sum - cm[i, best_indice]
        USED[best_indice - 1] = 1

    # predicted nuclei never matched only add their area to the union
    U_sum = ((1 - USED) * LIGNE_J).sum()
    U += U_sum
    return float(C) / float(U)

--------------------------------------------------------------------------------
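`AJI_fast` computes AJI = Σ_i |G_i ∩ S_j(i)| / (Σ_i |G_i ∪ S_j(i)| + Σ_{k unmatched} |S_k|), where S_j(i) is the predicted nucleus with the highest IoU with ground-truth nucleus G_i. The NumPy-only sketch below follows the same steps with `np.bincount` in place of scikit-learn's `confusion_matrix`, which makes the formula easy to check on toy label images. `aji_sketch` is a hypothetical name, not part of the repository; it assumes non-negative integer labels with 0 as background and, unlike `AJI_fast`, does not relabel connected components first.

```python
import numpy as np


def aji_sketch(gt, pred):
    """Aggregated Jaccard Index between two label images (0 = background)."""
    gt = np.asarray(gt, dtype=np.int64).ravel()
    pred = np.asarray(pred, dtype=np.int64).ravel()
    n_g, n_s = int(gt.max()) + 1, int(pred.max()) + 1
    if n_s == 1:  # no predicted nucleus at all
        return 0.0
    # cm[i, j] = number of pixels with ground-truth label i and predicted label j
    cm = np.bincount(gt * n_s + pred, minlength=n_g * n_s).reshape(n_g, n_s)
    area_g = cm.sum(axis=1)  # area of each ground-truth label
    area_s = cm.sum(axis=0)  # area of each predicted label
    C, U = 0, 0
    used = np.zeros(n_s, dtype=bool)
    for i in range(1, n_g):
        if area_g[i] == 0:  # label value absent from the ground truth
            continue
        inter = cm[i, 1:]
        union = area_g[i] + area_s[1:] - inter  # always >= area_g[i] > 0
        j = 1 + int(np.argmax(inter / union))  # best-matching prediction
        C += cm[i, j]
        U += area_g[i] + area_s[j] - cm[i, j]
        used[j] = True
    # predicted nuclei that were never a best match only enlarge the denominator
    U += area_s[1:][~used[1:]].sum()
    return float(C) / float(U)
```

For example, with a 4-pixel ground-truth nucleus, a 6-pixel prediction covering it and a stray 1-pixel prediction elsewhere, C = 4 and U = 6 + 1 = 7, so the score is 4/7: the unmatched prediction is penalised only through the denominator.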