├── images ├── CapsNet.jpg ├── chart.jpg ├── vallina.jpg ├── Optimized CNN.jpg └── 00000017_001_small.png ├── data sample └── images │ ├── 00000013_005.png │ ├── 00000017_001.png │ ├── 00000030_001.png │ ├── 00000032_001.png │ ├── 00000040_003.png │ ├── 00000042_002.png │ ├── 00000057_001.png │ ├── 00000061_002.png │ ├── 00000061_019.png │ └── 00000079_000.png ├── SampleDataset Log ├── log_bCNN_rgb.csv ├── log_bCNN_gray.csv ├── log_pretrained_extradata_CNN.csv ├── log_pretrained_extradata_stn_CNN.csv ├── CapsNet_log.csv └── log_pretrained_CNN.csv ├── .github └── FUNDING.yml ├── FullDataset Log ├── log_bCNN_rgb.csv ├── log_pretrained_extradata_stn_CNN.csv ├── log_pretrained_CNN.csv ├── CapsNetBasic_log.csv └── CapsNet_log.csv ├── utils.py ├── README.md ├── spatial_transformer.py ├── Data preprocessing - SampleDataset.ipynb ├── capsulelayers.py ├── vanilla CNN - SampleDataset.ipynb └── vanilla CNN - FullDataset.ipynb /images/CapsNet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/images/CapsNet.jpg -------------------------------------------------------------------------------- /images/chart.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/images/chart.jpg -------------------------------------------------------------------------------- /images/vallina.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/images/vallina.jpg -------------------------------------------------------------------------------- /images/Optimized CNN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/images/Optimized CNN.jpg -------------------------------------------------------------------------------- /images/00000017_001_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/images/00000017_001_small.png -------------------------------------------------------------------------------- /data sample/images/00000013_005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000013_005.png -------------------------------------------------------------------------------- /data sample/images/00000017_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000017_001.png -------------------------------------------------------------------------------- /data sample/images/00000030_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000030_001.png -------------------------------------------------------------------------------- /data sample/images/00000032_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000032_001.png -------------------------------------------------------------------------------- 
/data sample/images/00000040_003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000040_003.png -------------------------------------------------------------------------------- /data sample/images/00000042_002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000042_002.png -------------------------------------------------------------------------------- /data sample/images/00000057_001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000057_001.png -------------------------------------------------------------------------------- /data sample/images/00000061_002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000061_002.png -------------------------------------------------------------------------------- /data sample/images/00000061_019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000061_019.png -------------------------------------------------------------------------------- /data sample/images/00000079_000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawn0123/classify_lung_diseases/HEAD/data sample/images/00000079_000.png -------------------------------------------------------------------------------- /SampleDataset Log/log_bCNN_rgb.csv: -------------------------------------------------------------------------------- 1 | epoch,acc,fbeta_score,loss,precision,recall,val_acc,val_fbeta_score,val_loss,val_precision,val_recall 2 | 0,0.5452941176,0,0.6889533999,0,0,0.5354545452,0,0.6906476062,0,0 3 | 1,0.5452941176,0,0.6889169875,0,0,0.5354545452,0,0.6907717941,0,0 4 | 2,0.5452941176,0,0.6889572777,0,0,0.5354545452,0,0.6907827871,0,0 5 | 3,0.5452941176,0,0.6889632463,0,0,0.5354545452,0,0.6906308712,0,0 6 | 4,0.5452941176,0,0.6889291063,0,0,0.5354545452,0,0.6907129791,0,0 7 | 5,0.5452941176,0,0.6889408926,0,0,0.5354545452,0,0.6908171818,0,0 8 | 6,0.5452941176,0,0.6889159724,0,0,0.5354545452,0,0.690645408,0,0 9 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: doduy 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: https://www.buymeacoffee.com/duydo 13 | -------------------------------------------------------------------------------- /SampleDataset Log/log_bCNN_gray.csv: 
-------------------------------------------------------------------------------- 1 | epoch,acc,fbeta_score,loss,precision,recall,val_acc,val_fbeta_score,val_loss,val_precision,val_recall 2 | 0,0.54205882352941182,0.017420550795162425,0.69158509086160103,0.018934715775882495,0.025474158034605138,0.53545454523780134,0.0,0.69111661715940997,0.0,0.0 3 | 1,0.54529411764705882,0.0,0.689841628074646,0.0,0.0,0.53545454523780134,0.0,0.69066418214277792,0.0,0.0 4 | 2,0.54529411764705882,0.0,0.68933707868351657,0.0,0.0,0.53545454523780134,0.0,0.69063535408540211,0.0,0.0 5 | 3,0.54529411764705882,0.0,0.68917345664080454,0.0,0.0,0.53545454523780134,0.0,0.69071199742230505,0.0,0.0 6 | 4,0.54529411764705882,0.0,0.68913845146403596,0.0,0.0,0.53545454523780134,0.0,0.69077103679830376,0.0,0.0 7 | 5,0.54529411764705882,0.0,0.68906468601787796,0.0,0.0,0.53545454523780134,0.0,0.6907148636471141,0.0,0.0 8 | -------------------------------------------------------------------------------- /FullDataset Log/log_bCNN_rgb.csv: -------------------------------------------------------------------------------- 1 | epoch,acc,fbeta_score,loss,precision,recall,val_acc,val_fbeta_score,val_loss,val_precision,val_recall 2 | 0,0.6177790179,0.5386506745,0.6561404834,0.5974430825,0.4632201908,0.6148214286,0.5805270764,0.6567168893,0.5573584282,0.7223831506 3 | 1,0.6412165179,0.6006557989,0.6415947197,0.6352549547,0.5353033635,0.6459821429,0.6024931234,0.6383138497,0.6060842178,0.6121549356 4 | 2,0.6486830357,0.611487572,0.6354371689,0.6416548027,0.5549154656,0.655625,0.6105580592,0.630429735,0.6347088113,0.553317985 5 | 3,0.6538169643,0.6173962522,0.6302867381,0.6454706586,0.5639977905,0.6525,0.6032367876,0.6303255286,0.6626741399,0.4710843952 6 | 4,0.6602566964,0.6259580519,0.6267321785,0.6505693664,0.5799647066,0.6641964286,0.6212870885,0.6257025086,0.6349920879,0.5960478524 7 | 5,0.6626897321,0.6286710487,0.6232196966,0.6519224093,0.5838990266,0.6627678571,0.6196908268,0.6225213589,0.6655496187,0.5105087916 8 | 6,0.6668861607,0.6333232983,0.6204124281,0.6558316377,0.5893091821,0.6624107143,0.6217035891,0.6236729975,0.6175520038,0.6638687555 9 | 7,0.6692075893,0.6363973958,0.6179867112,0.6574121728,0.5978735098,0.6617857143,0.6205778218,0.6242765649,0.6156787112,0.6677082909 10 | 8,0.6701450893,0.6362396245,0.6151568225,0.6563695704,0.5988331204,0.6701785714,0.6273954472,0.6170851012,0.6748169265,0.5153670008 11 | 9,0.67390625,0.6416461812,0.6124990746,0.6617951331,0.6043337277,0.6746428571,0.6353322488,0.616449727,0.658534082,0.5816819168 12 | 10,0.6762388393,0.6443485806,0.6102004521,0.6636122449,0.6092205987,0.6685714286,0.628634784,0.6159811424,0.6263339555,0.6632383823 13 | 11,0.6782477679,0.6460882632,0.6074418454,0.6644314631,0.6139909088,0.6752678571,0.6363492554,0.6127620067,0.6399746943,0.6476058971 14 | 12,0.681171875,0.6498036205,0.6052222207,0.6677827419,0.6174215108,0.6748214286,0.63479978,0.6094208223,0.6545126119,0.5914182316 15 | 13,0.6835714286,0.6520364265,0.6021596814,0.6698982288,0.6209990523,0.6644642857,0.6248786549,0.618718201,0.6155199111,0.6913215696 16 | 14,0.685234375,0.6544460262,0.5999377874,0.6724498862,0.6239075368,0.6701785714,0.6307483191,0.6146799781,0.6373010629,0.6320903125 17 | 15,0.6882924107,0.6586153701,0.5970020791,0.675440178,0.6311963792,0.6778571429,0.6389423406,0.6095196782,0.6441667758,0.6437159697 18 | 16,0.6915848214,0.662176338,0.5937102495,0.6792800181,0.6334463308,0.6790178571,0.6402863067,0.6132769273,0.6844196773,0.5337847916 19 | 
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | import csv 4 | import math 5 | 6 | def plot_log(filename, show=True): 7 | # load data 8 | keys = [] 9 | values = [] 10 | with open(filename, 'r') as f: 11 | reader = csv.DictReader(f) 12 | for row in reader: 13 | if keys == []: 14 | for key, value in row.items(): 15 | keys.append(key) 16 | values.append(float(value)) 17 | continue 18 | 19 | for _, value in row.items(): 20 | values.append(float(value)) 21 | 22 | values = np.reshape(values, newshape=(-1, len(keys))) 23 | values[:,0] += 1 24 | 25 | fig = plt.figure(figsize=(4,6)) 26 | fig.subplots_adjust(top=0.95, bottom=0.05, right=0.95) 27 | fig.add_subplot(211) 28 | for i, key in enumerate(keys): 29 | if key.find('loss') >= 0 and not key.find('val') >= 0: # training loss 30 | plt.plot(values[:, 0], values[:, i], label=key) 31 | plt.legend() 32 | plt.title('Training loss') 33 | 34 | fig.add_subplot(212) 35 | for i, key in enumerate(keys): 36 | if key.find('acc') >= 0: # acc 37 | plt.plot(values[:, 0], values[:, i], label=key) 38 | plt.legend() 39 | plt.title('Training and validation accuracy') 40 | 41 | # fig.savefig('result/log.png') 42 | if show: 43 | plt.show() 44 | 45 | 46 | def combine_images(generated_images, height=None, width=None): 47 | num = generated_images.shape[0] 48 | if width is None and height is None: 49 | width = int(math.sqrt(num)) 50 | height = int(math.ceil(float(num)/width)) 51 | elif width is not None and height is None: # height not given 52 | height = int(math.ceil(float(num)/width)) 53 | elif height is not None and width is None: # width not given 54 | width = int(math.ceil(float(num)/height)) 55 | 56 | shape = generated_images.shape[1:3] 57 | image = np.zeros((height*shape[0], width*shape[1]), 58 | dtype=generated_images.dtype) 59 | for index, img in enumerate(generated_images): 60 | i = int(index/width) 61 | j = index % width 62 | image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \ 63 | img[:, :, 0] 64 | return image 65 | 66 | if __name__=="__main__": 67 | plot_log('result/log.csv') 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /FullDataset Log/log_pretrained_extradata_stn_CNN.csv: -------------------------------------------------------------------------------- 1 | epoch,acc,fbeta_score,loss,precision,recall,val_acc,val_fbeta_score,val_loss,val_precision,val_recall 2 | 0,0.6433705357,0.606827141,0.644004979,0.6296350707,0.568511455,0.6750892857,0.6417777568,0.6117657816,0.6260079176,0.7378581279 3 | 1,0.6759598214,0.6441278785,0.6159628053,0.6647343584,0.6082652122,0.691875,0.6623803228,0.5952583789,0.6655518265,0.6749131134 4 | 2,0.6818526786,0.6511089651,0.6098953197,0.6703847418,0.6165134504,0.6945535714,0.6629181119,0.59410755,0.657492414,0.7096231996 5 | 3,0.6858035714,0.6564224436,0.6061968212,0.6746499242,0.6245757858,0.695,0.6660019566,0.5909016198,0.672393455,0.6646228709 6 | 4,0.6898549107,0.6601768706,0.6013707069,0.6768638748,0.6320583098,0.69875,0.6738470113,0.5892657177,0.6999090261,0.6102929111 7 | 5,0.6922209821,0.6632680464,0.5980793631,0.6793944952,0.6358640046,0.6996428571,0.6723688069,0.5889285081,0.6850050129,0.6498667787 8 | 6,0.6949888393,0.6660706334,0.5969122961,0.6819024986,0.6390870021,0.7001785714,0.671727505,0.5868884751,0.6818388769,0.658200434 9 | 
7,0.6950223214,0.6664602499,0.5954480429,0.6829884909,0.6373863858,0.7000892857,0.670837555,0.5867267475,0.6825731966,0.6517059623 10 | 8,0.6968303571,0.6684101428,0.5946692328,0.6849095886,0.6392685586,0.7045535714,0.6771999932,0.585861764,0.6858409586,0.6694560301 11 | 9,0.6982254464,0.6698569038,0.592244897,0.6856790254,0.6414952115,0.7028571429,0.6737729474,0.5874182601,0.682295861,0.6655389947 12 | 10,0.6997433036,0.6713167114,0.5896426111,0.6863063719,0.6466843469,0.6957142857,0.6634636697,0.5857509109,0.6545690194,0.7253536847 13 | 11,0.7018638393,0.6740568228,0.5882245679,0.6880789199,0.6518941659,0.7025,0.6712242673,0.5906474084,0.6678851537,0.7083117699 14 | 12,0.7020200893,0.6741936766,0.5877912395,0.6884021113,0.6511598381,0.7040178571,0.6772869911,0.5815721386,0.6924503132,0.6463362403 15 | 13,0.7030691964,0.6747234924,0.5853738394,0.6897755574,0.6499794544,0.7045535714,0.6812548486,0.5822989676,0.7109024567,0.6088890095 16 | 14,0.7037946429,0.6758338093,0.5852408169,0.6912716901,0.6496556995,0.7091964286,0.6866951036,0.5818049651,0.7162198149,0.6137113285 17 | 15,0.7034709821,0.6767402032,0.5851002679,0.6928432949,0.6482835422,0.7066964286,0.6781579984,0.580472147,0.6860266338,0.6715104289 18 | 16,0.7064397321,0.6793232491,0.5830834405,0.6946879706,0.652332694,0.7035714286,0.674013966,0.5844901007,0.6777562135,0.683859978 19 | 17,0.7063950893,0.6788849064,0.5825721574,0.6929895401,0.6550349756,0.7041964286,0.673136076,0.5805599592,0.6711237983,0.7055387334 20 | 18,0.70796875,0.6807134869,0.5809952305,0.6942501476,0.65794537,0.7099107143,0.6881338269,0.5794417344,0.7188703191,0.6119631298 21 | 19,0.7084821429,0.6812101476,0.577897298,0.6955909147,0.656644689,0.7108035714,0.6860205713,0.5767416284,0.7037970119,0.6468023708 22 | -------------------------------------------------------------------------------- /FullDataset Log/log_pretrained_CNN.csv: -------------------------------------------------------------------------------- 1 | epoch,acc,fbeta_score,loss,precision,recall,val_acc,val_fbeta_score,val_loss,val_precision,val_recall 2 | 0,0.635970982143,0.591053643434,0.643349900448,0.626120905661,0.530210411268,0.679375,0.651867729127,0.608484711817,0.696576484357,0.546841056475 3 | 1,0.669285714286,0.63638479377,0.620337603433,0.66201374989,0.589934698214,0.686428571429,0.66091887555,0.599252312013,0.702254353208,0.562250821676 4 | 2,0.675334821429,0.643939219814,0.614072720462,0.665799178657,0.604705330868,0.694464285714,0.668845196622,0.596943780695,0.697186935501,0.599895583136 5 | 3,0.678973214286,0.648025740774,0.61063301245,0.666844693039,0.613625235398,0.687053571429,0.661002421847,0.598165822625,0.70959956382,0.545265223895 6 | 4,0.682399553571,0.651977766438,0.608081463182,0.670761127935,0.617918852131,0.696071428571,0.66672830488,0.595203837156,0.672570065643,0.667413902198 7 | 5,0.68359375,0.653542766363,0.606103459969,0.672625494988,0.618472392979,0.695892857143,0.666515616945,0.592701400774,0.671527543451,0.671013373988 8 | 6,0.683705357143,0.653485037014,0.603725311788,0.673514930909,0.616564621925,0.695089285714,0.670247367876,0.591533994249,0.700838973586,0.592610327389 9 | 7,0.685558035714,0.655448664368,0.602686744939,0.673808791776,0.620683654765,0.696696428571,0.670639496488,0.592265090091,0.692403849576,0.618565468448 10 | 8,0.688046875,0.657961240863,0.601731579815,0.675256170056,0.627409563873,0.697767857143,0.672142378858,0.591167051707,0.696025858841,0.614748275493 11 | 
9,0.68828125,0.658227533863,0.600291013239,0.675622779722,0.625940713159,0.696785714286,0.666368269069,0.591178885528,0.668177728781,0.682145910348 12 | 10,0.690111607143,0.66057028209,0.599465306593,0.677595569891,0.629089863452,0.6975,0.667152820059,0.592249507393,0.669076864507,0.682002081956 13 | 11,0.691752232143,0.661520785583,0.597646372116,0.679080401185,0.631218052091,0.685625,0.661457589056,0.601983797721,0.740479374443,0.489270496624 14 | 12,0.692232142857,0.662995156353,0.597939613206,0.680829776062,0.628886399296,0.70125,0.67483607356,0.58741309166,0.691254036469,0.639681392823 15 | 13,0.693113839286,0.664533910807,0.596982922618,0.682875497338,0.630863658302,0.697857142857,0.671169082395,0.586260054197,0.688830576198,0.632390896337 16 | 14,0.692399553571,0.663682696186,0.596263061336,0.680327967608,0.633326174803,0.700446428571,0.671680691242,0.586133022564,0.684447795791,0.648081343004 17 | 15,0.694955357143,0.66651288396,0.595092779472,0.681906055792,0.639971819414,0.70125,0.67625013956,0.585961831042,0.70247002029,0.611895676894 18 | 16,0.69390625,0.664665958072,0.59506952621,0.680674381754,0.638062689786,0.698125,0.674852831364,0.588498389806,0.716525242094,0.571123302877 19 | 17,0.69546875,0.666268449846,0.593827833118,0.683033510329,0.635367047486,0.701071428571,0.670759862874,0.590240892427,0.669798341862,0.697540981599 20 | 18,0.695993303571,0.667212238285,0.592807079341,0.683769631386,0.63691160082,0.702142857143,0.675679355413,0.588691526992,0.693341882144,0.637373978581 21 | 19,0.696696428571,0.668135398829,0.592458417001,0.684236841734,0.639203407003,0.702410714286,0.675308369803,0.583563267589,0.690152602877,0.646714306048 22 | -------------------------------------------------------------------------------- /FullDataset Log/CapsNetBasic_log.csv: -------------------------------------------------------------------------------- 1 | epoch,capsnet_binary_accuracy,capsnet_fbeta_score,capsnet_loss,capsnet_precision,capsnet_recall,decoder_binary_accuracy,decoder_fbeta_score,decoder_loss,decoder_precision,decoder_recall,loss,val_capsnet_binary_accuracy,val_capsnet_fbeta_score,val_capsnet_loss,val_capsnet_precision,val_capsnet_recall,val_decoder_binary_accuracy,val_decoder_fbeta_score,val_decoder_loss,val_decoder_precision,val_decoder_recall,val_loss 2 | 0,0.5812834821,0.419877565,0.6850839681,0.52311746,0.3072689789,0.0177143206,0.9702252487,0.6931162065,0.9669968281,0.9836562477,1.3782001755,0.6035714286,0.5101954912,0.6649704496,0.6418044885,0.3063546592,0.0170326451,0.9632478694,0.6930965991,0.9582627508,0.9841570101,1.3580670476 3 | 1,0.6022544643,0.5415249526,0.6667785529,0.5960183373,0.4459431742,0.0172436523,0.9698514304,0.6930901111,0.966470077,0.9839353929,1.3598686629,0.6151785714,0.5557879618,0.6590087768,0.6100974996,0.4362498447,0.0171470424,0.9632545558,0.6930723281,0.9582579803,0.9842117972,1.352081104 4 | 2,0.6067522321,0.5497387045,0.6634429719,0.5984147459,0.4619553043,0.0173515974,0.970224656,0.693070751,0.9668298934,0.9843654609,1.3565137224,0.616875,0.5597558112,0.6563696914,0.6114250647,0.4446691353,0.0171651786,0.9631511896,0.6930524385,0.9582241055,0.9838154714,1.3494221272 5 | 3,0.6091964286,0.5556359762,0.6619076519,0.6001930588,0.4694988412,0.0180025809,0.9696710953,0.693054223,0.9662221106,0.9840464248,1.3549618743,0.6199107143,0.5650182249,0.6548419014,0.614022995,0.4546268646,0.0183147321,0.9635769129,0.6930340297,0.9582834962,0.9857930223,1.3478759319 6 | 
4,0.6101674107,0.556990365,0.6606000541,0.6009333461,0.4724650374,0.0189737374,0.9698048716,0.6930400444,0.9662108165,0.9847816283,1.3536400995,0.6229464286,0.5705767977,0.6536211245,0.6155474319,0.4666798495,0.0180189732,0.9632413207,0.6930212654,0.9582481188,0.9841863227,1.3466423879 7 | 5,0.6138392857,0.5630162288,0.6586354627,0.6046364338,0.4797259121,0.0195197405,0.9704371522,0.6930280836,0.9668055596,0.9855458923,1.351663547,0.6245535714,0.5714056115,0.6524844984,0.6262454118,0.4492402432,0.0189760045,0.9635639397,0.6930068627,0.9583067545,0.9856253488,1.3454913647 8 | 6,0.617109375,0.5678135592,0.6571139084,0.607176159,0.4885518316,0.0196025739,0.9701177343,0.6930170621,0.9665771363,0.9848620684,1.3501309694,0.6163392857,0.5407781426,0.6552665218,0.6519447841,0.3487137047,0.0195800781,0.9636108175,0.6929927353,0.9583522941,0.9856743264,1.3482592576 9 | 7,0.6180357143,0.57004543,0.6563217701,0.6102186664,0.486848869,0.0202082171,0.9700672913,0.6930074705,0.9663872608,0.9853916376,1.3493292409,0.626875,0.5785514483,0.6510692288,0.6194930891,0.4814248909,0.0199553571,0.9638080369,0.6929850207,0.9583580128,0.98668196,1.3440542521 10 | 8,0.61875,0.5711229064,0.6563567093,0.6091405722,0.4901548124,0.0205339704,0.9704733762,0.6929989923,0.9667744193,0.9858680261,1.3493557013,0.6142857143,0.5314444663,0.6588688529,0.6594279123,0.3260890907,0.0206752232,0.9638753448,0.6929744092,0.9583673329,0.9869966034,1.3518432644 11 | 9,0.6214620536,0.5744765517,0.6553598474,0.6128447165,0.4939647606,0.0208986119,0.9704094166,0.6929908309,0.9666654189,0.9860069546,1.3483506793,0.6225,0.561233149,0.6518824186,0.641436476,0.4013976838,0.0208412388,0.9637223135,0.6929656834,0.9583760732,0.9861590772,1.3448480981 12 | -------------------------------------------------------------------------------- /SampleDataset Log/log_pretrained_extradata_CNN.csv: -------------------------------------------------------------------------------- 1 | epoch,binary_accuracy,fbeta_score_1,fbeta_score_2,fbeta_score_3,loss,precision_1,precision_2,precision_3,recall_1,recall_2,recall_3,val_binary_accuracy,val_fbeta_score_1,val_fbeta_score_2,val_fbeta_score_3,val_loss,val_precision_1,val_precision_2,val_precision_3,val_recall_1,val_recall_2,val_recall_3 2 | 0,0.54529411764705882,0.50010541649425733,0.39698555932325474,0.2098102502261891,0.70157888047835404,0.47846905021106495,0.49032567585215847,0.33373907173381134,0.7320749528969035,0.3659721560337964,0.13149450586122624,0.47090909090909089,0.51533396352421157,0.51736689112403178,0.56924184929240829,0.72926406773653896,0.46181818181818179,0.46495432506908069,0.53158905224366626,1.0,0.97917780442671343,0.83131492983211175 3 | 1,0.59529411764705886,0.55018575065276198,0.51414689358542942,0.35037625551223756,0.67276278018951419,0.53324669557459214,0.58974518579595225,0.52950439803740557,0.75904924434774057,0.45986414313316343,0.21467758539844961,0.64363636363636367,0.57831396558068016,0.61481496529145674,0.62099796988747336,0.64394684076309205,0.53858718091791324,0.59895660053599964,0.64806363279169255,0.85989459189501671,0.72534265279769894,0.56026108904318372 4 | 
2,0.62852941176470589,0.58691396993749279,0.57435101943857525,0.46647512253592999,0.65536156345816221,0.57553752225988053,0.63220113936592548,0.64789107168422022,0.74426292251138126,0.5395278962219463,0.30378244613899902,0.55909090930765326,0.52115117398175326,0.55771685535257509,0.62425520701841875,0.68075691548260775,0.46827418847517532,0.51469714858315208,0.60333628914572979,0.98591326800259682,0.87749630732969808,0.76181666742671617 5 | 3,0.64147058823529413,0.58321033730226401,0.59736734810997461,0.53731772282544299,0.64400507618399228,0.57448839075425095,0.64546548001906456,0.67596062435823323,0.70688462523853079,0.54544416427612308,0.36742763564867131,0.68090909069234673,0.62534199042753735,0.64993242350491609,0.63235385981473058,0.61386448166587138,0.602754035429521,0.66741150769320401,0.70383355053988372,0.77028300176967279,0.62361531203443354,0.47753337372433058 6 | 4,0.64558823529411768,0.59860244049745448,0.60339565052705657,0.56133037216523118,0.63594650731367219,0.5894838958628037,0.64866261005401615,0.71175529788522163,0.72305797198239496,0.55724643167327437,0.38817568663288565,0.67636363614689221,0.6488530572977933,0.6453172701055353,0.62968355005437682,0.6075954218344255,0.64724479328502305,0.6777971050956032,0.6972817958484997,0.6912844616716558,0.56775431752204897,0.4784028498692946 7 | 5,0.65558823529411769,0.614511262178421,0.61664122013484735,0.57814602697596829,0.62896492467207066,0.60682613134384156,0.65275812745094297,0.70734179693109844,0.73107504606246954,0.58184882991454179,0.41753459390471964,0.67181818203492594,0.64494439710270279,0.63578874089501125,0.60886771288785069,0.61376092889092182,0.64189544677734378,0.6890040961178866,0.73203306458213113,0.68567124388434675,0.50866155320947826,0.38937777508388866 8 | 6,0.67176470588235293,0.6238640359569998,0.63557584937881018,0.59992028096142935,0.61844616903978233,0.61649218895856073,0.67407886280732998,0.70889153340283562,0.73047929932089417,0.59911609285018019,0.44595210874781888,0.66909090887416489,0.64674885121258818,0.63820697524330827,0.56348717364397916,0.61602819356051364,0.6480017176541415,0.70107081239873714,0.75488390228965063,0.674365419474515,0.49017075441100383,0.30475706696510313 9 | 7,0.68294117647058827,0.62488969564437868,0.64319171554902022,0.60990407537011537,0.60827918866101438,0.61781573646208821,0.68122672024895159,0.72828136107500863,0.71937771348392265,0.60390872913248395,0.45483989238739014,0.58181818192655388,0.48091441609642721,0.35346689365126871,0.21279955863952638,0.78733144326643512,0.75986147186972874,0.73160171118649564,0.47175755587491119,0.2156346056136218,0.13109766784039412,0.075006225434216589 10 | -------------------------------------------------------------------------------- /SampleDataset Log/log_pretrained_extradata_stn_CNN.csv: -------------------------------------------------------------------------------- 1 | epoch,binary_accuracy,fbeta_score_1,fbeta_score_2,fbeta_score_3,loss,precision_1,precision_2,precision_3,recall_1,recall_2,recall_3,val_binary_accuracy,val_fbeta_score_1,val_fbeta_score_2,val_fbeta_score_3,val_loss,val_precision_1,val_precision_2,val_precision_3,val_recall_1,val_recall_2,val_recall_3 2 | 0,0.5291176471,0.4948684085,0.4734873144,0.4293279282,0.7860639564,0.4770355913,0.4906423529,0.5074351194,0.6240776376,0.4579474817,0.30534815,0.560909091,0.5840411178,0.3973683961,0,0.6622885489,0.5510671993,0.557090914,0,0.8045604597,0.2082604222,0 3 | 
1,0.5591176471,0.5262217562,0.4977622193,0.4389819417,0.7122719232,0.5052910017,0.5235498664,0.5552654221,0.6610600428,0.4474099576,0.2655760634,0.6236363634,0.611125521,0.5624346421,0.1295643256,0.6453593395,0.6013877791,0.6529943666,0.4072726926,0.6869489117,0.380995596,0.0369938008 4 | 2,0.5882352941,0.5526220309,0.5403339899,0.4611366558,0.6884518703,0.5267964062,0.5577130825,0.5789883842,0.7261991952,0.5182860456,0.2898107196,0.6336363634,0.6297961918,0.5835532951,0.2490095794,0.6355488645,0.623007992,0.652951661,0.5308744452,0.6869811867,0.4306025643,0.0895825776 5 | 3,0.6105882353,0.5727684636,0.5668266104,0.4963725947,0.6636590144,0.5471778342,0.5840948974,0.6040591905,0.737354099,0.5411044401,0.3223898519,0.6499999999,0.6431130763,0.6100552297,0.4606089898,0.631596919,0.6452689752,0.6726128873,0.6750476222,0.6611636025,0.4655458634,0.2228111841 6 | 4,0.6235294118,0.5708756799,0.5840678148,0.5061266693,0.6629603737,0.5421245372,0.5947905857,0.6415891361,0.7558109373,0.5815788906,0.3109303993,0.6309090909,0.6355006285,0.5618793351,0.3379328629,0.6447034196,0.6502440288,0.6554949778,0.6736969545,0.6021461029,0.3792736354,0.1302810828 7 | 5,0.6361764706,0.5913460419,0.6023645219,0.5293285818,0.6420808557,0.565684132,0.6216840872,0.6706275786,0.7566254253,0.5720829783,0.3199714131,0.6763636366,0.6196758207,0.64334103,0.478556008,0.6277111606,0.6029419439,0.6770145512,0.6398961111,0.7267144413,0.5602550576,0.2599851919 8 | 6,0.64,0.5882683248,0.6023104722,0.5640312335,0.636434495,0.5578925837,0.614709073,0.6737144223,0.7850761977,0.589807812,0.3723302717,0.6636363639,0.6273880575,0.628712207,0.4102360982,0.628507328,0.6152214007,0.6825639863,0.6350476209,0.7107735948,0.5005488019,0.1884386918 9 | 7,0.6441176471,0.5909201121,0.6017604348,0.5334679264,0.6388929319,0.5644448322,0.622110494,0.6871935477,0.7614843949,0.5637710834,0.3161400353,0.6763636366,0.6157423754,0.6452073698,0.6043251352,0.6143034157,0.5882904105,0.6632563331,0.7138260607,0.7873941192,0.6081931747,0.4010576748 10 | 8,0.6423529412,0.5988272741,0.6060973636,0.556261388,0.6346050704,0.569961235,0.6150978767,0.6760325787,0.7887131761,0.6062545476,0.3600848272,0.6663636366,0.6545777952,0.6325311663,0.5142697934,0.6247146756,0.6649139647,0.6999317559,0.6883788607,0.645168009,0.483284825,0.2756088151 11 | 9,0.6585294118,0.5965161963,0.6187274397,0.5771112589,0.6291027651,0.5713034565,0.6366625374,0.6970924052,0.7645239334,0.5997196918,0.3779517821,0.6854545457,0.6251281147,0.658441047,0.5940724089,0.6082561942,0.5987732514,0.6770479722,0.6994682139,0.7863554974,0.6195663348,0.4005364192 12 | 10,0.6647058824,0.6102660887,0.6310433456,0.58647103,0.6163415677,0.5846500057,0.6440189841,0.6958706665,0.7761099609,0.6119595465,0.3936816204,0.6754545455,0.6148748196,0.6409533039,0.6603116079,0.6083095787,0.5809394906,0.6333904093,0.7005206585,0.838183921,0.7015222057,0.5652364101 13 | 11,0.6652941176,0.6046179318,0.6297607587,0.5728990813,0.6201219235,0.5776400349,0.6412247393,0.674313804,0.7809407411,0.6287425251,0.4003526451,0.6736363634,0.6403363252,0.6489446683,0.5548460351,0.6115122023,0.6249228764,0.6877390523,0.7400750438,0.7430146551,0.5593629814,0.3003518676 14 | 12,0.6738235294,0.616810721,0.6398559415,0.5650821569,0.6199422207,0.589750173,0.6590079602,0.7099026493,0.7994630631,0.6332062399,0.3587324006,0.6772727271,0.6429986119,0.652995641,0.5852010842,0.6087274248,0.6340589844,0.6913573343,0.7191940238,0.7091327669,0.560175316,0.357291924 15 | 
-------------------------------------------------------------------------------- /SampleDataset Log/CapsNet_log.csv: -------------------------------------------------------------------------------- 1 | epoch,capsnet_binary_accuracy,capsnet_fbeta_score,capsnet_loss,capsnet_precision,capsnet_recall,decoder_binary_accuracy,decoder_fbeta_score,decoder_loss,decoder_precision,decoder_recall,loss,val_capsnet_binary_accuracy,val_capsnet_fbeta_score,val_capsnet_loss,val_capsnet_precision,val_capsnet_recall,val_decoder_binary_accuracy,val_decoder_fbeta_score,val_decoder_loss,val_decoder_precision,val_decoder_recall,val_loss 2 | 0,0.5403891509,0.0604318035,0.8377897329,0.078733481,0.0569308767,0.01178324,0.9658051482,0.6931599671,0.9650408946,0.969268566,1.5309496983,0.5400000001,0,0.7040503032,0,0,0.0097017045,0.9615046987,0.6931650056,0.9595970214,0.96969906,1.3972152966 3 | 1,0.5403800475,0,0.706588868,0,0,0.0118022565,0.9658012274,0.693162281,0.9649487203,0.969684887,1.3997511456,0.5400000001,0,0.7018701139,0,0,0.0097443182,0.9614613706,0.6931640122,0.9595903717,0.9695076541,1.3950341199 4 | 2,0.5374109264,0.0098420483,0.7048673522,0.0285035595,0.0027366669,0.0116909145,0.9655348668,0.6931614714,0.9646237613,0.9695544884,1.3980288254,0.5400000001,0,0.7000222208,0,0,0.0097159091,0.9614373467,0.6931638026,0.9595857906,0.9694035112,1.3931860321 5 | 3,0.541567696,0,0.7014485272,0,0,0.0119321556,0.9665974935,0.6931613743,0.9659529808,0.9694732396,1.3946099015,0.5400000001,0.0072727251,0.6982513183,0.0290909056,0.0018181818,0.0096732955,0.9614204008,0.6931629092,0.9595828256,0.9693294402,1.3914142171 6 | 4,0.5400831354,0.028479713,0.7006629871,0.1045130517,0.007322597,0.0120342191,0.9660563865,0.6931604954,0.9653199594,0.9694005163,1.3938234821,0.5390909092,0.014928226,0.6968195566,0.0581818112,0.0037575759,0.0096164773,0.9613966612,0.6931623459,0.9595786693,0.9692257324,1.3899818932 7 | 5,0.5439429929,0.0130706548,0.6985187026,0.0443388714,0.0037326095,0.0119646303,0.9652344872,0.6931605193,0.9643036072,0.9693265135,1.3916792179,0.540909091,0.0478710266,0.6953984358,0.1454545316,0.0132239886,0.0096022727,0.9613800413,0.6931617685,0.9595761897,0.969152218,1.388560198 8 | 6,0.5430522565,0.0709600477,0.6976494772,0.2399049618,0.0191322498,0.011811535,0.966184668,0.6931592763,0.9655504677,0.9690863773,1.3908087488,0.5400000001,0.0296179117,0.694051492,0.0872727203,0.0083415981,0.0095880682,0.9613734575,0.6931615346,0.9595858877,0.9690769436,1.3872130273 9 | 7,0.5406769596,0.0415988284,0.6957445567,0.1330166138,0.011613291,0.0112409071,0.9659845181,0.6931582828,0.9652662981,0.9692187907,1.3889028412,0.5472727274,0.1098674531,0.6929822365,0.3345454233,0.0319955833,0.0095738636,0.961363189,0.6931609208,0.9595839769,0.9690323619,1.3861431434 10 | 8,0.5483966746,0.1015663861,0.6948592479,0.3357085966,0.0277395309,0.0117697818,0.9654959437,0.6931590588,0.9646773685,0.969115506,1.3880183034,0.5436363637,0.0734061436,0.6920612129,0.2181817974,0.0221632889,0.0095454545,0.9613492359,0.6931608831,0.9595810673,0.9689725388,1.3852220787 11 | 9,0.5451306413,0.1316083372,0.6948565919,0.4115597434,0.0373521547,0.0120017444,0.9637256589,0.6931581616,0.9624959164,0.9690295802,1.3880147506,0.5472727274,0.1098674531,0.6911536514,0.3345454233,0.0319955833,0.0095454545,0.9613492359,0.6931603848,0.9595810673,0.9689725388,1.3843140437 12 | 
10,0.5472090261,0.1692559134,0.6938915602,0.470308758,0.052533176,0.0118208135,0.965931682,0.6931581719,0.9652486384,0.9690027558,1.387049726,0.5518181819,0.1468754209,0.6903963776,0.392727245,0.0455959226,0.00953125,0.9613492359,0.6931600395,0.9595810673,0.9689725388,1.3835564119 13 | 11,0.5489904988,0.1238760209,0.6913897176,0.3304829579,0.0379785143,0.0115192622,0.965805539,0.6931577316,0.9651046949,0.968950873,1.384547451,0.5554545456,0.1649574614,0.6896505928,0.426666636,0.0530154089,0.0095454545,0.9613415978,0.6931597762,0.9595786676,0.9689415843,1.3828103742 14 | 12,0.5576009501,0.2141060373,0.6907890653,0.5187648269,0.0690326717,0.0116769967,0.9656950129,0.6931576083,0.9649408276,0.9689723746,1.3839466832,0.5472727274,0.1098674531,0.6891993919,0.3345454233,0.0319955833,0.0095454545,0.9613347886,0.6931597981,0.9595774469,0.9689119354,1.3823591813 15 | 13,0.5507719715,0.2081532316,0.6922442161,0.5371337789,0.0683104553,0.0115378192,0.9656368587,0.6931570024,0.9649121271,0.9689110945,1.3854012172,0.5563636365,0.1749141684,0.6885329255,0.446060576,0.0571508814,0.0095454545,0.9613347886,0.69315944,0.9595774469,0.9689119354,1.3816923644 16 | 14,0.5593824228,0.2338715528,0.6906064062,0.5546318041,0.0791792001,0.0118022565,0.9664538906,0.6931571819,0.9659074308,0.9688993115,1.3837635857,0.5563636365,0.1718838657,0.688126056,0.4557575417,0.0547266389,0.0095454545,0.9613347886,0.6931593364,0.9595774469,0.9689119354,1.38128539 17 | 15,0.5510688836,0.2329366584,0.6902837671,0.5178938886,0.0822814077,0.0113754454,0.9655846247,0.6931568059,0.9648508422,0.9688696911,1.3834405735,0.5563636365,0.1749141684,0.6876910112,0.446060576,0.0571508814,0.00953125,0.9613347886,0.6931591496,0.9595774469,0.9689119354,1.3808501573 18 | 16,0.5611638955,0.2415512721,0.6887494871,0.5735550006,0.0827743244,0.0118393705,0.9655340838,0.6931568854,0.9647346663,0.9689833907,1.3819063791,0.5563636365,0.1749141684,0.687322485,0.446060576,0.0571508814,0.00953125,0.9613313883,0.6931590106,0.9595768348,0.9688971084,1.3804815102 19 | 17,0.5507719715,0.2602351972,0.6909028259,0.5562417357,0.0944792301,0.0111156473,0.9667201116,0.6931565799,0.9662566004,0.9688788791,1.3840594014,0.5581818183,0.2275056046,0.6869036551,0.5333333033,0.0769413135,0.00953125,0.9613313883,0.6931587115,0.9595768348,0.9688971084,1.3800623729 20 | 18,0.567695962,0.3108273653,0.6892577919,0.6206537596,0.1167721621,0.0113708061,0.9650901975,0.6931565209,0.9642599134,0.9688046592,1.382414307,0.5563636365,0.2034829634,0.6866317411,0.5042423907,0.0671231318,0.0095170455,0.9613277158,0.6931586734,0.9595758447,0.968881886,1.3797904127 21 | 19,0.5632422803,0.2876404992,0.6875848078,0.5709421915,0.1098303381,0.0115470977,0.9656580752,0.6931561019,0.9649334401,0.9688740326,1.3807409093,0.5563636365,0.1880428531,0.6863606574,0.4557575486,0.0626209671,0.0095170455,0.9613277158,0.693158557,0.9595758447,0.968881886,1.379519219 22 | -------------------------------------------------------------------------------- /FullDataset Log/CapsNet_log.csv: -------------------------------------------------------------------------------- 1 | epoch,capsnet_binary_accuracy,capsnet_fbeta_score,capsnet_loss,capsnet_precision,capsnet_recall,decoder_binary_accuracy,decoder_fbeta_score,decoder_loss,decoder_precision,decoder_recall,loss,val_capsnet_binary_accuracy,val_capsnet_fbeta_score,val_capsnet_loss,val_capsnet_precision,val_capsnet_recall,val_decoder_binary_accuracy,val_decoder_fbeta_score,val_decoder_loss,val_decoder_precision,val_decoder_recall,val_loss 2 | 
0,0.5758147321,0.3792350322,0.6901096179,0.4934636126,0.2591311921,0.0173819406,0.9715220758,0.693157311,0.9667303369,0.9915030845,1.3832669293,0.6001785714,0.5231256321,0.6679280068,0.6245713578,0.3468264713,0.0164020647,0.9653517638,0.6931430793,0.9591393445,0.9914397187,1.361071087 3 | 1,0.60234375,0.5392347566,0.6683946115,0.6013095137,0.4277395425,0.017209647,0.9711848351,0.6931384424,0.9665413958,0.9905472273,1.3615330546,0.6023214286,0.5211550929,0.6630945984,0.6344391949,0.3321683579,0.0178306362,0.9652242005,0.6931241359,0.9591408798,0.9907621191,1.3562187369 4 | 2,0.6077120536,0.5506250451,0.6643018668,0.6021545008,0.4493390042,0.0182929339,0.9712360599,0.6931223041,0.966619684,0.9904655133,1.3574241703,0.6078571429,0.5373162717,0.6599591563,0.6305740686,0.3650544037,0.0185658482,0.9651442036,0.6931065091,0.9594158357,0.9891714605,1.3530656672 5 | 3,0.608515625,0.5533487427,0.6627090156,0.6014107552,0.4573045804,0.0188874163,0.9711502013,0.6931073953,0.9668960993,0.9888872922,1.355816411,0.6184821429,0.5684140084,0.6570801279,0.617394966,0.4616279766,0.0188992746,0.965225116,0.6930907798,0.9593369739,0.9899326505,1.3501709076 6 | 4,0.6105133929,0.5585282862,0.6619549086,0.6031027647,0.4666727006,0.0190952846,0.9713612807,0.6930948028,0.967165207,0.9888230462,1.3550497114,0.6082142857,0.5360415837,0.6589123435,0.6369084955,0.3564060503,0.0192522321,0.9649958534,0.6930776053,0.9594505097,0.9882435765,1.3519899464 7 | 5,0.6111049107,0.5592149998,0.6609195369,0.6022689779,0.4713428681,0.0195778111,0.9710703028,0.693082662,0.9670304159,0.9878906374,1.3540021985,0.6200892857,0.5677337058,0.6559482735,0.6237367885,0.4470510539,0.0189090402,0.965066843,0.6930670241,0.9595427035,0.9882281268,1.3490152955 8 | 6,0.6118973214,0.5617322192,0.6599511596,0.6020159375,0.4762518434,0.0197858538,0.97113774,0.6930721241,0.9668583152,0.9889612782,1.3530232831,0.6220535714,0.5761084723,0.65567954,0.6137983135,0.4915759329,0.0196065848,0.965496778,0.693055281,0.9595164306,0.9905934235,1.3487348233 9 | 7,0.6148772321,0.5647736093,0.6587705943,0.6040131324,0.4810965797,0.0202094378,0.9716602003,0.6930618906,0.9670621344,0.9908104239,1.351832486,0.6190178571,0.5640063307,0.6551002676,0.6264988695,0.4312130506,0.0202566964,0.9656825285,0.6930435021,0.9594728114,0.9917597497,1.3481437693 10 | 8,0.6144866071,0.5647041997,0.6592954256,0.6034562958,0.4811278726,0.0204225377,0.9717656483,0.6930527877,0.9669852187,0.9917050311,1.3523482132,0.6208035714,0.5719605477,0.6542359061,0.6209972551,0.4629330194,0.0210365513,0.9656569816,0.6930350179,0.9594652011,0.9916560653,1.3472709237 11 | 9,0.6150334821,0.5662143931,0.6585881304,0.6044013564,0.4831428053,0.0208482143,0.9720832006,0.6930447945,0.9672179537,0.9923641636,1.3516329266,0.6227678571,0.576746479,0.6539186421,0.6166060566,0.4883319126,0.0214578683,0.9658880833,0.6930267297,0.9594311505,0.9930232966,1.3469453733 12 | 10,0.6153459821,0.5657373727,0.657989501,0.6026548185,0.4840274082,0.0210313198,0.971783132,0.6930382044,0.9666276067,0.9933035693,1.3510277047,0.6186607143,0.5596324362,0.6550096319,0.636310265,0.4055970725,0.0222921317,0.9659300196,0.6930195756,0.9594734769,0.9930616134,1.3480292082 13 | 11,0.615625,0.5673893831,0.6581966331,0.6048445246,0.4851007374,0.0211251395,0.9720749748,0.6930319224,0.9667870317,0.9941409812,1.3512285569,0.6183928571,0.5625120934,0.654006716,0.6281487175,0.4258920732,0.0221707589,0.9660318072,0.6930136268,0.9594673978,0.9936234966,1.3470203413 14 | 
12,0.6167857143,0.5687148751,0.6577280569,0.6063288942,0.4866817805,0.0214901297,0.972172127,0.6930264342,0.9669164691,0.9940889786,1.3507544921,0.6216964286,0.5709786533,0.6534766788,0.6277011839,0.4491567526,0.0223632813,0.9660182357,0.6930083632,0.9594417812,0.9936637553,1.3464850426 15 | 13,0.6155691964,0.5669862596,0.6578721742,0.6036282994,0.4852845614,0.0216003418,0.9721375357,0.6930218499,0.9667878232,0.9944588583,1.3508940252,0.6175892857,0.555193677,0.6557565655,0.6471007961,0.3818566174,0.0230287388,0.9662067262,0.6930024905,0.9594400791,0.9946673007,1.3487590541 16 | 14,0.6160825893,0.56779103,0.656636801,0.6056179665,0.4854416268,0.0219304548,0.9722371336,0.6930166933,0.9666392028,0.9956284269,1.3496534944,0.6219642857,0.5727078312,0.6529427842,0.6253200151,0.4580057097,0.0222516741,0.9664500972,0.6929971395,0.9593843433,0.9961998406,1.3459399247 17 | 15,0.6166852679,0.5692440124,0.6564575649,0.6059141374,0.4848656416,0.0219321987,0.9725399895,0.6930127422,0.9668634952,0.9962388092,1.3494703059,0.6229464286,0.5743146912,0.6527860647,0.6253869681,0.4621744092,0.0229087612,0.966521359,0.6929924398,0.9593855933,0.9965729218,1.3457785027 18 | 16,0.6179129464,0.5710363887,0.6566438014,0.6075242505,0.4892737625,0.0220392718,0.9725867802,0.6930091974,0.9668868095,0.996383066,1.3496529993,0.61875,0.5604942875,0.6541440252,0.639071218,0.4042046348,0.0234333147,0.9666451129,0.692988972,0.9594224725,0.9970714288,1.3471329955 19 | 17,0.6187165179,0.5725177876,0.656182669,0.6088398008,0.4913023395,0.0222438267,0.9728325558,0.693005865,0.9671412216,0.9965952638,1.3491885355,0.6228571429,0.5741376383,0.6526756799,0.6246632154,0.4630239294,0.0225167411,0.9666062122,0.692986794,0.9594202352,0.9968766715,1.3456624753 20 | 18,0.6185602679,0.5708953117,0.6565121366,0.6073653003,0.4895949723,0.0222521973,0.9725900414,0.6930034113,0.9668838104,0.9964303029,1.3495155483,0.6225,0.5705339127,0.6527074163,0.6315155782,0.440859914,0.0237695313,0.9665889171,0.6929826403,0.9594175206,0.9967928304,1.3456900549 21 | 19,0.6185044643,0.570672852,0.6561067066,0.6060979802,0.4901266841,0.0224698312,0.972695397,0.6930005677,0.9669122003,0.9968606166,1.3491072739,0.6216964286,0.5666606352,0.6530412974,0.6349740566,0.4251009309,0.0235909598,0.9666384528,0.6929801539,0.9594195167,0.9970475752,1.3460214506 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diseases Detection from Chest X-ray data 2 | Machine Learning Capstone Project - Udacity MLND 3 | 4 | ## Project Overview 5 | With so many lung diseases people can get, here is just one example of diseases we can save if we find them out earlier. 6 | With the technology machine and computer power, the earlier identification of diseases, particularly lung disease, we can be helped to detect earlier and more accurately, which can save many many people as well as reduce the pressure on the system. The health system has not developed in time with the development of the population. 
7 | 8 | ## Analysis 9 | ### Data Exploration 10 | ![](./images/00000017_001_small.png) 11 | ### Exploratory Visualization 12 | ![](./images/chart.jpg) 13 | 14 | ## Datasets 15 | ### [Sample dataset](https://www.kaggle.com/nih-chest-xrays/sample) 16 | * File contents: this is a random sample (5%) of the full dataset: 17 | sample.zip: Contains 5,606 images with size 1024 x 1024 18 | sample_labels.csv: Class labels and patient data for the entire dataset 19 | * Class descriptions: there are 15 classes (14 diseases, and one for "No findings") in the full dataset, but since this is drastically reduced version of the full dataset, some of the classes are sparse with the labeled as "No findings": Hernia - 13 images, Pneumonia - 62 images, Fibrosis - 84 images, Edema - 118 images, Emphysema - 127 images, Cardiomegaly - 141 images, Pleural_Thickening - 176 images, Consolidation - 226 images, Pneumothorax - 271 images, Mass - 284 images, Nodule - 313 images, Atelectasis - 508 images, Effusion - 644 images, Infiltration - 967 images, No Finding - 3044 images. 20 | ### [Full dataset](https://www.kaggle.com/nih-chest-xrays/data) 21 | * File contents: 22 | images_00x.zip: 12 files with 112,120 total images with size 1024 x 1024 23 | README_ChestXray.pdf: Original README file 24 | BBox_list_2017.csv: Bounding box coordinates. Note: Start at x,y, extend horizontally w pixels, and vertically h pixels 25 | Data_entry_2017.csv: Class labels and patient data for the entire dataset 26 | * Class descriptions: there are 15 classes (14 diseases, and one for "No findings"). Images can be classified as "No findings" or one or more disease classes: Atelectasis, Consolidation, Infiltration, Pneumothorax, Edema, Emphysema, Fibrosis, Effusion, Pneumonia, Pleural_thickening, Cardiomegaly, Nodule Mass, Hernia. 27 | 28 | ## Algorithms and Techniques 29 | * CNN 30 | * Spacial Transformer 31 | * VGG finetuning 32 | * Capsule Network 33 | ### Architecture 34 | 35 | #### Vanilla CNN 36 | ![](./images/vallina.jpg) 37 | 38 | #### Optimized CNN 39 | ![](./images/Optimized%20CNN.jpg) 40 | 41 | #### CapsNet 42 | ![](./images/CapsNet.jpg) 43 | 44 | ## Metrics & Result 45 | F-beta score with β = 0.5 to represent precision will be more important than recall in this case. 46 | 47 | Result: 48 | In sample dataset: 49 | 50 | | Model | Precision | Recall | F 0.5 score | Accuracy | Training time/ epoch | no. parameters | 51 | | ------ | ------ | ------ | ------ | ------ | ------ | ------ | 52 | | Vanilla rgb | 0.617 | 0.589 | 0.611 | 0.503 | 2 s | 322793 | 53 | | Vanilla gray | 0.577 | 0.48 | 0.555 | 0.517 | 2 s | 321225 | 54 | | CNN + VGG | 0.645 | 0.555 | 0.624 | 0.667 | 16 s | 15252133 | 55 | | CNN + VGG + data | 0.647 | 0.588 | 0.634 | 0.675 | 16 s | 15240769 | 56 | | CNN + VGG + data + STN | 0.642 | 0.614 | 0.636 | 0.677 | 19 s | 15488051 | 57 | | CapsNet basic | 0.614 | 0.599 | 0.611 | 0.581 | 75 s | 14788864 | 58 | | CapsNet changed | 0.735 | 0.073 | 0.261 | 0.575 | 37 s | 12167424 | 59 | 60 | In full dataset: 61 | 62 | | Model | Precision | Recall | F 0.5 score | Accuracy | Training time/ epoch | no. 
parameters | 63 | | ------ | ------ | ------ | ------ | ------ | ------ | ------ | 64 | | Vanilla rgb | 0.672 | 0.594 | 0.655 | 0.672 | 53 s | 322793 | 65 | | Vanilla gray | 0.672 | 0.572 | 0.649 | 0.667 | 51 s | 321225 | 66 | | CNN + VGG | 0.675 | 0.619 | 0.663 | 0.688 | 384 s | 15252133 | 67 | | CNN + VGG + data + STN | 0.684 | 0.621 | 0.67 | 0.693 | 431 s | 15488051 | 68 | | CapsNet basic | 0.64 | 0.498 | 0.605 | 0.635 | 1815 s | 14788864 | 69 | | CapsNet changed | 0.625 | 0.474 | 0.588 | 0.625 | 856 s | 12167424 | 70 | 71 | ## Installation 72 | ### [Jupyter Notebook](http://jupyter.readthedocs.io/en/latest/install.html) with [python3](http://docs.python-guide.org/en/latest/starting/install3/linux/) 73 | ```sh 74 | $ sudo apt-get update 75 | $ sudo apt-get install python3-pip python3-dev 76 | $ pip3 install --upgrade pip 77 | $ pip3 install jupyter 78 | ``` 79 | ### [Tensorflow](https://www.tensorflow.org/install/install_linux) for GPU 80 | ```sh 81 | $ pip3 install tensorflow==1.8.0 # Python 3.n; CPU support (no GPU support) 82 | $ pip3 install tensorflow-gpu==1.8.0 # Python 3.n; GPU support 83 | ``` 84 | ### [Keras](https://keras.io/#installation) for GPU 85 | ```sh 86 | $ pip3 install keras 87 | ``` 88 | ### Others 89 | * numpy 90 | * pandas 91 | * seaborn 92 | * matplotlib 93 | * opencv 94 | * glob 95 | * tqdm 96 | * sklearn 97 | * pickle 98 | 99 | ## Note 100 | 1. Run [Data preprocessing](./Data%20preprocessing%20-%20SampleDataset.ipynb) first to create preprocessing file in Sample dataset before run other notebook for Sample dataset. 101 | 102 | 2. Following are the file descriptions and URL’s from which the data can be obtained: 103 | * data sample/sample_labels.csv: Class labels and patient data for the sample dataset 104 | * data sample/Data_entry_2017.csv: Class labels and patient data for the full dataset 105 | * data sample/images/*: 10 chest X-ray images 106 | 107 | 3. 
Following are the notebooks descriptions and python files descriptions, files log: 108 | Notebooks: 109 | * Capsule Network - FullDataset.ipynb: Capsule Network with my architecture in full dataset 110 | * Capsule Network - SampleDataset.ipynb: Capsule Network with my architecture in sample dataset 111 | * Capsule Network basic - FullDataset.ipynb: Capsule Network with Hinton's architecture in full dataset 112 | * Capsule Network basic - SampleDataset.ipynb: Capsule Network with Hinton's architecture in sample dataset 113 | * Data analysis - FullDataset.ipynb: Data analysis in full dataset 114 | * Data analysis - SampleDataset.ipynb: data analysis in sample dataset 115 | * Data preprocessing - SampleDataset.ipynb: Data preprocessing 116 | * Demo.ipynb: Demo prediction 20 samples 117 | * optimized CNN - FullDataset.ipynb: My optimized CNN architecture in full dataset 118 | * optimized CNN - SampleDataset.ipynb: My optimized CNN architecture in sample dataset 119 | * vanilla CNN - FullDataset.ipynb: Vanilla CNN in full dataset 120 | * vanilla CNN - SampleDataset.ipynb: Vanilla CNN in sample dataset 121 | 122 | Python files 123 | * capsulelayers.py: capsule layer from [XifengGuo](https://github.com/XifengGuo/CapsNet-Keras) 124 | * spatial_transformer.py: spatial transformer layser from [hello2all](https://github.com/hello2all/GTSRB_Keras_STN) 125 | So thank you guys for support me with capsule layer and spatial transformer layer in Keras-gpu 126 | 127 | Log: 128 | * FullDataset Log: all log file in full dataset 129 | * SampleDataset Log: all log file in sample dataset 130 | -------------------------------------------------------------------------------- /spatial_transformer.py: -------------------------------------------------------------------------------- 1 | from keras.layers.core import Layer 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | class SpatialTransformer(Layer): 6 | """Spatial Transformer Layer 7 | Implements a spatial transformer layer as described in [1]_. 8 | Borrowed from [2]_: 9 | downsample_fator : float 10 | A value of 1 will keep the orignal size of the image. 11 | Values larger than 1 will down sample the image. Values below 1 will 12 | upsample the image. 13 | example image: height= 100, width = 200 14 | downsample_factor = 2 15 | output image will then be 50, 100 16 | References 17 | ---------- 18 | .. [1] Spatial Transformer Networks 19 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu 20 | Submitted on 5 Jun 2015 21 | .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 22 | 23 | .. 
[3] https://github.com/EderSantana/seya/blob/keras1/seya/layers/attention.py 24 | """ 25 | 26 | def __init__(self, 27 | localization_net, 28 | output_size, 29 | **kwargs): 30 | self.locnet = localization_net 31 | self.output_size = output_size 32 | super(SpatialTransformer, self).__init__(**kwargs) 33 | 34 | def build(self, input_shape): 35 | self.locnet.build(input_shape) 36 | self.trainable_weights = self.locnet.trainable_weights 37 | # self.constraints = self.locnet.constraints 38 | 39 | def compute_output_shape(self, input_shape): 40 | output_size = self.output_size 41 | return (None, 42 | int(output_size[0]), 43 | int(output_size[1]), 44 | int(input_shape[-1])) 45 | 46 | def call(self, X, mask=None): 47 | affine_transformation = self.locnet.call(X) 48 | output = self._transform(affine_transformation, X, self.output_size) 49 | return output 50 | 51 | def _repeat(self, x, num_repeats): 52 | ones = tf.ones((1, num_repeats), dtype='int32') 53 | x = tf.reshape(x, shape=(-1,1)) 54 | x = tf.matmul(x, ones) 55 | return tf.reshape(x, [-1]) 56 | 57 | def _interpolate(self, image, x, y, output_size): 58 | batch_size = tf.shape(image)[0] 59 | height = tf.shape(image)[1] 60 | width = tf.shape(image)[2] 61 | num_channels = tf.shape(image)[3] 62 | 63 | x = tf.cast(x , dtype='float32') 64 | y = tf.cast(y , dtype='float32') 65 | 66 | height_float = tf.cast(height, dtype='float32') 67 | width_float = tf.cast(width, dtype='float32') 68 | 69 | output_height = output_size[0] 70 | output_width = output_size[1] 71 | 72 | x = .5*(x + 1.0)*(width_float) 73 | y = .5*(y + 1.0)*(height_float) 74 | 75 | x0 = tf.cast(tf.floor(x), 'int32') 76 | x1 = x0 + 1 77 | y0 = tf.cast(tf.floor(y), 'int32') 78 | y1 = y0 + 1 79 | 80 | max_y = tf.cast(height - 1, dtype='int32') 81 | max_x = tf.cast(width - 1, dtype='int32') 82 | zero = tf.zeros([], dtype='int32') 83 | 84 | x0 = tf.clip_by_value(x0, zero, max_x) 85 | x1 = tf.clip_by_value(x1, zero, max_x) 86 | y0 = tf.clip_by_value(y0, zero, max_y) 87 | y1 = tf.clip_by_value(y1, zero, max_y) 88 | 89 | flat_image_dimensions = width*height 90 | pixels_batch = tf.range(batch_size)*flat_image_dimensions 91 | flat_output_dimensions = output_height*output_width 92 | base = self._repeat(pixels_batch, flat_output_dimensions) 93 | base_y0 = base + y0*width 94 | base_y1 = base + y1*width 95 | indices_a = base_y0 + x0 96 | indices_b = base_y1 + x0 97 | indices_c = base_y0 + x1 98 | indices_d = base_y1 + x1 99 | 100 | flat_image = tf.reshape(image, shape=(-1, num_channels)) 101 | flat_image = tf.cast(flat_image, dtype='float32') 102 | pixel_values_a = tf.gather(flat_image, indices_a) 103 | pixel_values_b = tf.gather(flat_image, indices_b) 104 | pixel_values_c = tf.gather(flat_image, indices_c) 105 | pixel_values_d = tf.gather(flat_image, indices_d) 106 | 107 | x0 = tf.cast(x0, 'float32') 108 | x1 = tf.cast(x1, 'float32') 109 | y0 = tf.cast(y0, 'float32') 110 | y1 = tf.cast(y1, 'float32') 111 | 112 | area_a = tf.expand_dims(((x1 - x) * (y1 - y)), 1) 113 | area_b = tf.expand_dims(((x1 - x) * (y - y0)), 1) 114 | area_c = tf.expand_dims(((x - x0) * (y1 - y)), 1) 115 | area_d = tf.expand_dims(((x - x0) * (y - y0)), 1) 116 | output = tf.add_n([area_a*pixel_values_a, 117 | area_b*pixel_values_b, 118 | area_c*pixel_values_c, 119 | area_d*pixel_values_d]) 120 | return output 121 | 122 | def _meshgrid(self, height, width): 123 | x_linspace = tf.linspace(-1., 1., width) 124 | y_linspace = tf.linspace(-1., 1., height) 125 | x_coordinates, y_coordinates = tf.meshgrid(x_linspace, y_linspace) 126 | 
x_coordinates = tf.reshape(x_coordinates, shape=(1, -1)) 127 | y_coordinates = tf.reshape(y_coordinates, shape=(1, -1)) 128 | ones = tf.ones_like(x_coordinates) 129 | indices_grid = tf.concat([x_coordinates, y_coordinates, ones], 0) 130 | return indices_grid 131 | 132 | def _transform(self, affine_transformation, input_shape, output_size): 133 | batch_size = tf.shape(input_shape)[0] 134 | height = tf.shape(input_shape)[1] 135 | width = tf.shape(input_shape)[2] 136 | num_channels = tf.shape(input_shape)[3] 137 | 138 | affine_transformation = tf.reshape(affine_transformation, shape=(batch_size,2,3)) 139 | 140 | affine_transformation = tf.reshape(affine_transformation, (-1, 2, 3)) 141 | affine_transformation = tf.cast(affine_transformation, 'float32') 142 | 143 | width = tf.cast(width, dtype='float32') 144 | height = tf.cast(height, dtype='float32') 145 | output_height = output_size[0] 146 | output_width = output_size[1] 147 | indices_grid = self._meshgrid(output_height, output_width) 148 | indices_grid = tf.expand_dims(indices_grid, 0) 149 | indices_grid = tf.reshape(indices_grid, [-1]) # flatten? 150 | indices_grid = tf.tile(indices_grid, tf.stack([batch_size])) 151 | indices_grid = tf.reshape(indices_grid, tf.stack([batch_size, 3, -1])) 152 | 153 | # transformed_grid = tf.batch_matmul(affine_transformation, indices_grid) 154 | transformed_grid = tf.matmul(affine_transformation, indices_grid) 155 | x_s = tf.slice(transformed_grid, [0, 0, 0], [-1, 1, -1]) 156 | y_s = tf.slice(transformed_grid, [0, 1, 0], [-1, 1, -1]) 157 | x_s_flatten = tf.reshape(x_s, [-1]) 158 | y_s_flatten = tf.reshape(y_s, [-1]) 159 | 160 | transformed_image = self._interpolate(input_shape, 161 | x_s_flatten, 162 | y_s_flatten, 163 | output_size) 164 | 165 | transformed_image = tf.reshape(transformed_image, shape=(batch_size, 166 | output_height, 167 | output_width, 168 | num_channels)) 169 | return transformed_image 170 | 171 | -------------------------------------------------------------------------------- /SampleDataset Log/log_pretrained_CNN.csv: -------------------------------------------------------------------------------- 1 | epoch,binary_accuracy,fbeta_score_1,fbeta_score_2,fbeta_score_3,loss,precision_1,precision_2,precision_3,recall_1,recall_2,recall_3,val_binary_accuracy,val_fbeta_score_1,val_fbeta_score_2,val_fbeta_score_3,val_loss,val_precision_1,val_precision_2,val_precision_3,val_recall_1,val_recall_2,val_recall_3 2 | 0,0.49395161290322581,0.53720117095978026,0.50560964788160012,0.38699235454682379,0.74134328096143665,0.50242005817351798,0.50752923757799207,0.47299940739908525,0.76874662599255961,0.52573722985482985,0.23675813958529504,0.50600000000000001,0.49397757291793826,0.49726141309738159,0.017777769088745118,0.69598266410827636,0.44,0.4599198532104492,0.032000000000000001,1.0,0.76322349309921267,0.0064000000953674313 3 | 1,0.5175619834710744,0.54769208953400284,0.53233478138269474,0.47588054552551146,0.71442696082690527,0.50596581510275851,0.52418360631327987,0.5705436950872752,0.83954150420575102,0.60009212336264361,0.31049355766004766,0.5659999990463257,0.49397757291793826,0.40315093231201171,0.0,0.68024671459197994,0.44,0.5383965396881103,0.0,1.0,0.21731317949295043,0.0 4 | 
2,0.52376033057851235,0.55032241590752085,0.51164758254673859,0.38250933501346052,0.70456456448421001,0.51766746989951651,0.53125669739463111,0.5449822983465904,0.75364684073393007,0.47308463744880741,0.19092910225726356,0.61600000047683712,0.49397757291793826,0.55838899278640752,0.029151126861572264,0.68217520236968998,0.44,0.56514202594757079,0.12799998474121094,1.0,0.54524399375915522,0.0071331269741058351 5 | 3,0.49586776859504134,0.54721591443069706,0.50577884222850322,0.41757567017531594,0.71897703852535277,0.50626543484443476,0.49937426107974092,0.51724346363840024,0.84993210508803696,0.56673639766440909,0.25366915011208907,0.60199999952316285,0.49397757291793826,0.55451092052459716,0.061214609622955322,0.68276256084442144,0.44,0.5356033861637115,0.23199997234344483,1.0,0.66140743350982667,0.015704555809497832 6 | 4,0.55061983471074383,0.55255316998347759,0.55527278854827256,0.45327203456035331,0.69484167010330955,0.50864232128316711,0.55187221154693733,0.60977306075332582,0.87260602968783418,0.59754219158621857,0.24392163507209336,0.60799999999999998,0.49494351148605348,0.54356690835952759,0.1493375506401062,0.67695041894912722,0.44096774101257324,0.55419616889953616,0.46666662502288819,1.0,0.51668846344947816,0.041192014038562777 7 | 5,0.54442148760330578,0.55749023404003173,0.55365916618630906,0.44177543870673691,0.69316851353842368,0.5130982960551238,0.55357990383116662,0.57523285455940187,0.8871740654480359,0.58647387776493043,0.24804671896883279,0.61799999904632563,0.49605483579635617,0.54079250288009639,0.079315610408782961,0.67478464984893793,0.44212903022766115,0.57431920003890991,0.24266664791107179,1.0,0.44727409434318544,0.022815666973590852 8 | 6,0.54545454545454541,0.54877914131180316,0.54015308298355291,0.37359242596902137,0.69921979923878819,0.50934111653280656,0.55072251651897908,0.56828020919453015,0.83594096987700661,0.52964774192857345,0.17992015577052251,0.61800000047683712,0.49397757291793826,0.55305094289779666,0.16447113084793091,0.67661386394500733,0.44,0.56115925788879395,0.4666666326522827,1.0,0.53610664176940914,0.047592014133930209 9 | 7,0.53925619834710747,0.5423160195350647,0.53396599745947471,0.3767378014966476,0.69447678573860605,0.49723164281569238,0.53199072416163673,0.56966358078412771,0.8665746337126109,0.5528253491752404,0.18057682371336567,0.63000000095367437,0.49934831953048708,0.55191696023941039,0.16896236085891725,0.67086642742156988,0.44562580299377441,0.60327850198745725,0.49866662502288817,0.99663158035278321,0.4197926604747772,0.047592014133930209 10 | 8,0.51446280991735538,0.56933571730763455,0.52277973269628097,0.38929898581228967,0.70190296685400089,0.52956144647164782,0.53156804521221757,0.57705628354687333,0.84411857738967766,0.51139959495914866,0.17822590864394322,0.62799999904632564,0.50093313360214231,0.53649974966049196,0.15013883638381958,0.66944918441772461,0.44721720314025881,0.61801897001266481,0.43466663265228273,0.9954285697937012,0.36210220980644225,0.042668937027454376 11 | 9,0.56301652892561982,0.55932851480551005,0.54334156375286002,0.39620274009783407,0.68301203920821518,0.52255389217502812,0.56807865596507201,0.65123967149040918,0.81022279124614616,0.49304729158228094,0.17327465424852923,0.63799999904632565,0.50585260152816769,0.5547002415657043,0.19001742982864381,0.66679926013946533,0.45258791160583495,0.6368587946891785,0.50933330059051518,0.98672681808471685,0.37625652766227724,0.057676863014698029 12 | 
10,0.54958677685950408,0.57356923128947734,0.55915420951922079,0.40796302074243213,0.69055659169993122,0.53291707590591808,0.56567247987778724,0.60548341225001434,0.86515560623042842,0.56320452887164663,0.20110818238790371,0.63800000095367426,0.50030369329452518,0.58447565078735353,0.41195426845550537,0.66971020030975337,0.44657634162902832,0.59174238014221192,0.61919999694824224,0.99663158035278321,0.57421087074279786,0.19049104785919188 13 | 11,0.55991735537190079,0.56217146855740507,0.56239737360930642,0.47146018871591111,0.68075222614382913,0.51734886583217909,0.56008481930110077,0.64035753868828138,0.88913768185071707,0.59408721992792179,0.25460668720982294,0.62800000047683713,0.50281999158859247,0.55170554161071772,0.3774713411331177,0.66463589000701906,0.45002343368530273,0.5813404450416565,0.65119998931884771,0.97753707504272458,0.46952063679695127,0.15458669805526734 14 | 12,0.56818181818181823,0.5661139882300511,0.56277453308263103,0.43837457402678559,0.68253791824845245,0.52470086527264803,0.58607696304636558,0.6742489825595509,0.86517241375505427,0.52344600622318993,0.20140908720079531,0.64800000095367427,0.49708198308944701,0.58662760543823245,0.42590164089202881,0.66824434614181516,0.44360860061645507,0.60495448303222654,0.62560000610351563,0.99206015014648441,0.53530382347106931,0.19591608476638794 15 | 13,0.55578512396694213,0.56819997690925916,0.55631851903663199,0.49278282232520992,0.68171666673392306,0.52748538713809867,0.56665799735991418,0.65101675947835624,0.8522430944048669,0.54955758092817197,0.27507319750864645,0.63200000095367437,0.50844989538192753,0.55194022035598755,0.41445411586761477,0.6635497694015503,0.45589747238159178,0.59587435770034791,0.65744762420654301,0.97327040863037106,0.44461979842185972,0.1759829592704773 16 | 14,0.57541322314049592,0.57687143254871209,0.57260145678007901,0.51372163601158083,0.67439483709571779,0.53587944300706725,0.5776920456531619,0.66961826864352891,0.86027889458601137,0.57868845797767321,0.29334368636785457,0.64000000095367426,0.50456268835067752,0.57054553461074831,0.45983920860290528,0.66024498081207272,0.45398939514160158,0.60168081426620479,0.65592381668090816,0.93585857772827152,0.48648181080818176,0.21993034601211547 17 | 15,0.55888429752066116,0.57079703305378438,0.55343190894639194,0.46205269452954129,0.67633081566203723,0.53022239080145339,0.56111675796429972,0.62813376641470542,0.84270595519010683,0.54474245516721864,0.24840255758979105,0.63999999904632565,0.5004555430412293,0.57332537317276,0.46251095008850096,0.65974418354034425,0.45321076011657713,0.60942027235031127,0.65912381744384763,0.88193282699584963,0.48288552403450014,0.21934000253677369 18 | 16,0.59400826446280997,0.57204657004884452,0.58170411015345047,0.45000788249260137,0.67079960413215578,0.53603101860393176,0.61263899960793744,0.70156106722256373,0.8118663701144132,0.50187606703151355,0.20892135663466019,0.63599999904632565,0.49826028394699096,0.55737909555435183,0.41445411586761477,0.65906746816635131,0.45092140007019044,0.61142644834518434,0.65744762420654301,0.8848728866577148,0.42537214732170103,0.1759829592704773 19 | 
17,0.57438016528925617,0.57272755951920817,0.57338549307555209,0.50526916340362926,0.67721039007517914,0.5326380756768313,0.58975890968456746,0.66830589948606889,0.8469455453975141,0.5560549798090596,0.28873877793796793,0.62599999904632564,0.51047858953475955,0.53582693433761597,0.39782307338714601,0.6571477479934692,0.46548043370246889,0.60396191358566287,0.65424761581420898,0.84872106981277462,0.37869071459770204,0.16714486312866211 20 | 18,0.60537190082644632,0.58465119777632157,0.62343196090587905,0.52681741783441594,0.66964504935524682,0.54624344494717181,0.63991112945493589,0.70053721016103576,0.84189683346708943,0.60045767027484487,0.29390942533154135,0.61800000095367436,0.50639952588081361,0.51825738477706906,0.39444982624053954,0.65773095893859868,0.46180869483947756,0.59293911647796627,0.6458666648864746,0.84822192335128788,0.35879621028900144,0.1649348120689392 21 | 19,0.60433884297520657,0.58107777962014695,0.59755370981437117,0.55636262056256125,0.66160529999693563,0.54147753045578628,0.61528228284898867,0.77133078318982085,0.841574892032245,0.55759252644767443,0.29751545301646243,0.62800000095367436,0.50065768384933473,0.57092650985717774,0.52323693037033081,0.664879566192627,0.44848987388610839,0.5700584402084351,0.59138586378097535,0.96367040634155277,0.58610981082916258,0.3693716781139374 22 | -------------------------------------------------------------------------------- /Data preprocessing - SampleDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "from glob import glob\n", 14 | "from tqdm import tqdm\n", 15 | "from sklearn.utils import shuffle\n", 16 | "\n", 17 | "df = pd.read_csv('sample/sample_labels.csv')\n", 18 | "\n", 19 | "diseases = ['Cardiomegaly','Emphysema','Effusion','Hernia','Nodule','Pneumothorax','Atelectasis','Pleural_Thickening','Mass','Edema','Consolidation','Infiltration','Fibrosis','Pneumonia']\n", 20 | "#Number diseases\n", 21 | "for disease in diseases :\n", 22 | " df[disease] = df['Finding Labels'].apply(lambda x: 1 if disease in x else 0)\n", 23 | "\n", 24 | "# #test to perfect\n", 25 | "# df = df.drop(df[df['Emphysema']==0][:-127].index.values)\n", 26 | " \n", 27 | "#remove Y after age\n", 28 | "df['Age']=df['Patient Age'].apply(lambda x: x[:-1]).astype(int)\n", 29 | "df['Age Type']=df['Patient Age'].apply(lambda x: x[-1:])\n", 30 | "df.loc[df['Age Type']=='M',['Age']] = df[df['Age Type']=='M']['Age'].apply(lambda x: round(x/12.)).astype(int)\n", 31 | "df.loc[df['Age Type']=='D',['Age']] = df[df['Age Type']=='D']['Age'].apply(lambda x: round(x/365.)).astype(int)\n", 32 | "# remove outliers\n", 33 | "df = df.drop(df['Age'].sort_values(ascending=False).head(1).index)\n", 34 | "df['Age'] = df['Age']/df['Age'].max()\n", 35 | "\n", 36 | "#one hot data\n", 37 | "# df = df.drop(df.index[4242])\n", 38 | "df = df.join(pd.get_dummies(df['Patient Gender']))\n", 39 | "df = df.join(pd.get_dummies(df['View Position']))\n", 40 | "\n", 41 | "#random samples\n", 42 | "df = shuffle(df)\n", 43 | "\n", 44 | "#get other data\n", 45 | "data = df[['Age', 'F', 'M', 'AP', 'PA']]\n", 46 | "data = np.array(data)\n", 47 | "\n", 48 | "labels = df[diseases].as_matrix()\n", 49 | "files_list = ('sample/images/' + df['Image Index']).tolist()\n", 50 | "\n", 51 | "# #test to perfect\n", 52 | "# labelB = 
df['Emphysema'].tolist()\n", 53 | "\n", 54 | "labelB = (df[diseases].sum(axis=1)>0).tolist()\n", 55 | "labelB = np.array(labelB, dtype=int)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# RGB images" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 2, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stderr", 72 | "output_type": "stream", 73 | "text": [ 74 | "100%|██████████| 3400/3400 [00:53<00:00, 64.04it/s]\n", 75 | "100%|██████████| 1100/1100 [00:17<00:00, 61.92it/s]\n", 76 | "100%|██████████| 1105/1105 [00:17<00:00, 62.04it/s]\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "from keras.preprocessing import image \n", 82 | "from tqdm import tqdm\n", 83 | "\n", 84 | "def path_to_tensor(img_path, shape):\n", 85 | " # loads RGB image as PIL.Image.Image type\n", 86 | " img = image.load_img(img_path, target_size=shape)\n", 87 | " # convert PIL.Image.Image type to 3D tensor with shape (64, 64, 3)\n", 88 | " x = image.img_to_array(img)/255\n", 89 | " # convert 3D tensor to 4D tensor with shape (1, 64, 64, 3) and return 4D tensor\n", 90 | " return np.expand_dims(x, axis=0)\n", 91 | "\n", 92 | "def paths_to_tensor(img_paths, shape):\n", 93 | " list_of_tensors = [path_to_tensor(img_path, shape) for img_path in tqdm(img_paths)]\n", 94 | " return np.vstack(list_of_tensors)\n", 95 | "\n", 96 | "train_labels = labelB[:3400][:, np.newaxis]\n", 97 | "valid_labels = labelB[3400:4500][:, np.newaxis]\n", 98 | "test_labels = labelB[4500:][:, np.newaxis]\n", 99 | "\n", 100 | "train_data = data[:3400]\n", 101 | "valid_data = data[3400:4500]\n", 102 | "test_data = data[4500:]\n", 103 | "\n", 104 | "img_shape = (64, 64)\n", 105 | "train_tensors = paths_to_tensor(files_list[:3400], shape = img_shape)\n", 106 | "valid_tensors = paths_to_tensor(files_list[3400:4500], shape = img_shape)\n", 107 | "test_tensors = paths_to_tensor(files_list[4500:], shape = img_shape)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 3, 113 | "metadata": { 114 | "collapsed": true 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "import pickle\n", 119 | "\n", 120 | "train_filename = \"data_preprocessed/train_data_sample_rgb.p\"\n", 121 | "pickle.dump((train_labels, train_data, train_tensors), open(train_filename, 'wb'))\n", 122 | "\n", 123 | "valid_filename = \"data_preprocessed/valid_data_sample_rgb.p\"\n", 124 | "pickle.dump((valid_labels, valid_data, valid_tensors), open(valid_filename, 'wb'))\n", 125 | "\n", 126 | "test_filename = \"data_preprocessed/test_data_sample_rgb.p\"\n", 127 | "pickle.dump((test_labels, test_data, test_tensors), open(test_filename, 'wb'))" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "# Gray images" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 4, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stderr", 144 | "output_type": "stream", 145 | "text": [ 146 | "100%|██████████| 3400/3400 [00:43<00:00, 78.34it/s]\n", 147 | "100%|██████████| 1100/1100 [00:12<00:00, 84.98it/s]\n", 148 | "100%|██████████| 1105/1105 [00:13<00:00, 84.73it/s]\n" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "from keras.preprocessing import image \n", 154 | "from tqdm import tqdm\n", 155 | "\n", 156 | "def path_to_tensor(img_path, shape):\n", 157 | " # loads grayscale image as PIL.Image.Image type\n", 158 | " img = image.load_img(img_path, grayscale=True, target_size=shape)\n", 159 | " # convert PIL.Image.Image type to 3D tensor with shape (64, 64, 1)\n", 160 | " x = image.img_to_array(img)/255\n", 161 | " # convert 3D tensor to 4D tensor with shape (1, 64, 64, 1) and return 4D tensor\n", 162 | " return np.expand_dims(x, axis=0)\n", 163 | "\n", 164 | "def paths_to_tensor(img_paths, shape):\n", 165 | " list_of_tensors = [path_to_tensor(img_path, shape) for img_path in tqdm(img_paths)]\n", 166 | " return np.vstack(list_of_tensors)\n", 167 | "\n", 168 | "train_labels = labelB[:3400][:, np.newaxis]\n", 169 | "valid_labels = labelB[3400:4500][:, np.newaxis]\n", 170 | "test_labels = labelB[4500:][:, np.newaxis]\n", 171 | "\n", 172 | "train_data = data[:3400]\n", 173 | "valid_data = data[3400:4500]\n", 174 | "test_data = data[4500:]\n", 175 | "\n", 176 | "img_shape = (64, 64)\n", 177 | "train_tensors = paths_to_tensor(files_list[:3400], shape = img_shape)\n", 178 | "valid_tensors = paths_to_tensor(files_list[3400:4500], shape = img_shape)\n", 179 | "test_tensors = paths_to_tensor(files_list[4500:], shape = img_shape)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 5, 185 | "metadata": { 186 | "collapsed": true 187 | }, 188 | "outputs": [], 189 | "source": [ 190 | "import pickle\n", 191 | "\n", 192 | "train_filename = \"data_preprocessed/train_data_sample_gray.p\"\n", 193 | "pickle.dump((train_labels, train_data, train_tensors), open(train_filename, 'wb'))\n", 194 | "\n", 195 | "valid_filename = \"data_preprocessed/valid_data_sample_gray.p\"\n", 196 | "pickle.dump((valid_labels, valid_data, valid_tensors), open(valid_filename, 'wb'))\n", 197 | "\n", 198 | "test_filename = \"data_preprocessed/test_data_sample_gray.p\"\n", 199 | "pickle.dump((test_labels, test_data, test_tensors), open(test_filename, 'wb'))" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": { 206 | "collapsed": true 207 | }, 208 | "outputs": [], 209 | "source": [] 210 | } 211 | ], 212 | "metadata": { 213 | "kernelspec": { 214 | "display_name": "Python 3", 215 | "language": "python", 216 | "name": "python3" 217 | }, 218 | "language_info": { 219 | "codemirror_mode": { 220 | "name": "ipython", 221 | "version": 3 222 | }, 223 | "file_extension": ".py", 224 | "mimetype": "text/x-python", 225 | "name": "python", 226 | "nbconvert_exporter": "python", 227 | "pygments_lexer": "ipython3", 228 | "version": "3.5.2" 229 | } 230 | }, 231 | "nbformat": 4, 232 | "nbformat_minor": 2 233 | } 234 | -------------------------------------------------------------------------------- /capsulelayers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some key layers used for constructing a Capsule Network. These layers can be used to construct a CapsNet on other datasets, 3 | not just on MNIST. 4 | *NOTE*: some functions can be implemented in multiple ways; I keep all of them. You can try them for yourself just by 5 | uncommenting them and commenting their counterparts. 6 | 7 | Author: Xifeng Guo, E-mail: `guoxifeng1990@163.com`, Github: `https://github.com/XifengGuo/CapsNet-Keras` 8 | """ 9 | 10 | import keras.backend as K 11 | import tensorflow as tf 12 | from keras import initializers, layers 13 | 14 | 15 | class Length(layers.Layer): 16 | """ 17 | Compute the length of vectors. This is used to compute a Tensor that has the same shape as y_true in margin_loss. 
18 | Using this layer as the model's output, you can directly predict labels with `y_pred = np.argmax(model.predict(x), 1)`. 19 | inputs: shape=[None, num_vectors, dim_vector] 20 | output: shape=[None, num_vectors] 21 | """ 22 | def call(self, inputs, **kwargs): 23 | return K.sqrt(K.sum(K.square(inputs), -1)) 24 | 25 | def compute_output_shape(self, input_shape): 26 | return input_shape[:-1] 27 | 28 | 29 | class Mask(layers.Layer): 30 | """ 31 | Mask a Tensor with shape=[None, num_capsule, dim_vector] either by the capsule with max length or by an additional 32 | input mask. Except for the max-length capsule (or the specified capsule), all vectors are masked to zeros. Then flatten the 33 | masked Tensor. 34 | For example: 35 | ``` 36 | x = keras.layers.Input(shape=[3, 2]) # each sample contains 3 capsules with dim_vector=2 37 | y = keras.layers.Input(shape=[3]) # true labels: 3 classes, one-hot encoded. 38 | out = Mask()(x) # out.shape=[None, 6] 39 | # or 40 | out2 = Mask()([x, y]) # out2.shape=[None, 6]. Masked with the true labels y. Of course y can also be manipulated. 41 | ``` 42 | """ 43 | def call(self, inputs, **kwargs): 44 | if type(inputs) is list: # true label is provided with shape = [None, n_classes], i.e. one-hot code. 45 | assert len(inputs) == 2 46 | inputs, mask = inputs 47 | else: # if no true label is provided, mask by the max length of the capsules. Mainly used for prediction. 48 | # compute lengths of capsules 49 | x = K.sqrt(K.sum(K.square(inputs), -1)) 50 | # generate the mask, which is a one-hot code. 51 | # mask.shape=[None, n_classes]=[None, num_capsule] 52 | mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1]) 53 | 54 | # inputs.shape=[None, num_capsule, dim_capsule] 55 | # mask.shape=[None, num_capsule] 56 | # masked.shape=[None, num_capsule * dim_capsule] 57 | masked = K.batch_flatten(inputs * K.expand_dims(mask, -1)) 58 | return masked 59 | 60 | def compute_output_shape(self, input_shape): 61 | if type(input_shape[0]) is tuple: # true label provided 62 | return tuple([None, input_shape[0][1] * input_shape[0][2]]) 63 | else: # no true label provided 64 | return tuple([None, input_shape[1] * input_shape[2]]) 65 | 66 | 67 | def squash(vectors, axis=-1): 68 | """ 69 | The non-linear activation used in Capsule. It drives the length of a large vector to near 1 and that of a small vector to near 0. 70 | :param vectors: some vectors to be squashed, N-dim tensor 71 | :param axis: the axis to squash 72 | :return: a Tensor with the same shape as the input vectors 73 | """ 74 | s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True) 75 | scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon()) 76 | return scale * vectors 77 | 78 | 79 | class CapsuleLayer(layers.Layer): 80 | """ 81 | The capsule layer. It is similar to a Dense layer. A Dense layer has `in_num` inputs, each of which is a scalar (the output of a 82 | neuron in the former layer), and it has `out_num` output neurons. CapsuleLayer just expands the output of each neuron 83 | from a scalar to a vector. So its input shape = [None, input_num_capsule, input_dim_capsule] and output shape = \ 84 | [None, num_capsule, dim_capsule]. For a Dense layer, input_dim_capsule = dim_capsule = 1. 
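    For example (a minimal sketch rather than code from this repo: `primarycaps` is assumed to be the output of
    `PrimaryCap` below, with shape=[None, input_num_capsule, 8], and 14 matches the number of disease classes
    used in the notebooks):
    ```
    digitcaps = CapsuleLayer(num_capsule=14, dim_capsule=16, routings=3,
                             name='digitcaps')(primarycaps)  # shape=[None, 14, 16]
    out_caps = Length(name='capsnet')(digitcaps)             # shape=[None, 14]
    ```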
85 | 86 | :param num_capsule: number of capsules in this layer 87 | :param dim_capsule: dimension of the output vectors of the capsules in this layer 88 | :param routings: number of iterations for the routing algorithm 89 | """ 90 | def __init__(self, num_capsule, dim_capsule, routings=3, 91 | kernel_initializer='glorot_uniform', 92 | **kwargs): 93 | super(CapsuleLayer, self).__init__(**kwargs) 94 | self.num_capsule = num_capsule 95 | self.dim_capsule = dim_capsule 96 | self.routings = routings 97 | self.kernel_initializer = initializers.get(kernel_initializer) 98 | 99 | def build(self, input_shape): 100 | assert len(input_shape) >= 3, "The input Tensor should have shape=[None, input_num_capsule, input_dim_capsule]" 101 | self.input_num_capsule = input_shape[1] 102 | self.input_dim_capsule = input_shape[2] 103 | 104 | # Transform matrix 105 | self.W = self.add_weight(shape=[self.num_capsule, self.input_num_capsule, 106 | self.dim_capsule, self.input_dim_capsule], 107 | initializer=self.kernel_initializer, 108 | name='W') 109 | 110 | self.built = True 111 | 112 | def call(self, inputs, training=None): 113 | # inputs.shape=[None, input_num_capsule, input_dim_capsule] 114 | # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule] 115 | inputs_expand = K.expand_dims(inputs, 1) 116 | 117 | # Replicate the num_capsule dimension to prepare for being multiplied by W 118 | # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule] 119 | inputs_tiled = K.tile(inputs_expand, [1, self.num_capsule, 1, 1]) 120 | 121 | # Compute `inputs * W` by scanning inputs_tiled on dimension 0. 122 | # x.shape=[num_capsule, input_num_capsule, input_dim_capsule] 123 | # W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule] 124 | # Regard the first two dimensions as the `batch` dimension, 125 | # then matmul: [input_dim_capsule] x [dim_capsule, input_dim_capsule]^T -> [dim_capsule]. 126 | # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule] 127 | inputs_hat = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs_tiled) 128 | 129 | # Begin: Routing algorithm ---------------------------------------------------------------------# 130 | # The prior for the coupling coefficients, initialized as zeros. 131 | # b.shape = [None, self.num_capsule, self.input_num_capsule]. 132 | b = tf.zeros(shape=[K.shape(inputs_hat)[0], self.num_capsule, self.input_num_capsule]) 133 | 134 | assert self.routings > 0, 'The routings should be > 0.' 135 | for i in range(self.routings): 136 | # c.shape=[batch_size, num_capsule, input_num_capsule] 137 | c = tf.nn.softmax(b, dim=1) 138 | 139 | # c.shape = [batch_size, num_capsule, input_num_capsule] 140 | # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule] 141 | # Regard the first two dimensions as the `batch` dimension, 142 | # then matmul: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule]. 143 | # outputs.shape=[None, num_capsule, dim_capsule] 144 | outputs = squash(K.batch_dot(c, inputs_hat, [2, 2])) # e.g. [None, 10, 16] 145 | 146 | if i < self.routings - 1: 147 | # outputs.shape = [None, num_capsule, dim_capsule] 148 | # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule] 149 | # Regard the first two dimensions as the `batch` dimension, 150 | # then matmul: [dim_capsule] x [input_num_capsule, dim_capsule]^T -> [input_num_capsule]. 
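# (This update implements routing by agreement: a prediction vector in
# inputs_hat that points in the same direction as the consensus output gets a
# larger logit b, and therefore a larger coupling coefficient c on the next
# iteration.)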
151 | # b.shape=[batch_size, num_capsule, input_num_capsule] 152 | b += K.batch_dot(outputs, inputs_hat, [2, 3]) 153 | # End: Routing algorithm -----------------------------------------------------------------------# 154 | 155 | return outputs 156 | 157 | def compute_output_shape(self, input_shape): 158 | return tuple([None, self.num_capsule, self.dim_capsule]) 159 | 160 | 161 | def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding): 162 | """ 163 | Apply Conv2D `n_channels` times and concatenate all capsules 164 | :param inputs: 4D tensor, shape=[None, width, height, channels] 165 | :param dim_capsule: the dim of the output vector of capsule 166 | :param n_channels: the number of types of capsules 167 | :return: output tensor, shape=[None, num_capsule, dim_capsule] 168 | """ 169 | output = layers.Conv2D(filters=dim_capsule*n_channels, kernel_size=kernel_size, strides=strides, padding=padding, 170 | name='primarycap_conv2d')(inputs) 171 | outputs = layers.Reshape(target_shape=[-1, dim_capsule], name='primarycap_reshape')(output) 172 | return layers.Lambda(squash, name='primarycap_squash')(outputs) 173 | 174 | 175 | """ 176 | # The following is another way to implement primary capsule layer. This is much slower. 177 | # Apply Conv2D `n_channels` times and concatenate all capsules 178 | def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding): 179 | outputs = [] 180 | for _ in range(n_channels): 181 | output = layers.Conv2D(filters=dim_capsule, kernel_size=kernel_size, strides=strides, padding=padding)(inputs) 182 | outputs.append(layers.Reshape([output.get_shape().as_list()[1] ** 2, dim_capsule])(output)) 183 | outputs = layers.Concatenate(axis=1)(outputs) 184 | return layers.Lambda(squash)(outputs) 185 | """ 186 | -------------------------------------------------------------------------------- /vanilla CNN - SampleDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# With rgb images" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Load data" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import pickle\n", 26 | "\n", 27 | "train_filename = \"data_preprocessed/train_data_sample_rgb.p\"\n", 28 | "(train_labels, train_data, train_tensors) = pickle.load(open(train_filename, mode='rb'))\n", 29 | "\n", 30 | "valid_filename = \"data_preprocessed/valid_data_sample_rgb.p\"\n", 31 | "(valid_labels, valid_data, valid_tensors) = pickle.load(open(valid_filename, mode='rb'))\n", 32 | "\n", 33 | "test_filename = \"data_preprocessed/test_data_sample_rgb.p\"\n", 34 | "(test_labels, test_data, test_tensors) = pickle.load(open(test_filename, mode='rb'))" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### CNN model" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": { 48 | "scrolled": true 49 | }, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "_________________________________________________________________\n", 56 | "Layer (type) Output Shape Param # \n", 57 | "=================================================================\n", 58 | "conv2d_1 (Conv2D) (None, 64, 64, 16) 2368 \n", 59 | 
"_________________________________________________________________\n", 60 | "max_pooling2d_1 (MaxPooling2 (None, 32, 32, 16) 0 \n", 61 | "_________________________________________________________________\n", 62 | "conv2d_2 (Conv2D) (None, 32, 32, 32) 12832 \n", 63 | "_________________________________________________________________\n", 64 | "max_pooling2d_2 (MaxPooling2 (None, 16, 16, 32) 0 \n", 65 | "_________________________________________________________________\n", 66 | "conv2d_3 (Conv2D) (None, 16, 16, 64) 51264 \n", 67 | "_________________________________________________________________\n", 68 | "max_pooling2d_3 (MaxPooling2 (None, 8, 8, 64) 0 \n", 69 | "_________________________________________________________________\n", 70 | "conv2d_4 (Conv2D) (None, 4, 4, 128) 204928 \n", 71 | "_________________________________________________________________\n", 72 | "max_pooling2d_4 (MaxPooling2 (None, 2, 2, 128) 0 \n", 73 | "_________________________________________________________________\n", 74 | "flatten_1 (Flatten) (None, 512) 0 \n", 75 | "_________________________________________________________________\n", 76 | "dense_1 (Dense) (None, 100) 51300 \n", 77 | "_________________________________________________________________\n", 78 | "dense_2 (Dense) (None, 1) 101 \n", 79 | "=================================================================\n", 80 | "Total params: 322,793\n", 81 | "Trainable params: 322,793\n", 82 | "Non-trainable params: 0\n", 83 | "_________________________________________________________________\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "import time\n", 89 | "\n", 90 | "from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dropout, Flatten, Dense\n", 91 | "from keras.models import Sequential\n", 92 | "from keras.layers.normalization import BatchNormalization\n", 93 | "from keras import regularizers, initializers, optimizers\n", 94 | "\n", 95 | "model = Sequential()\n", 96 | "\n", 97 | "model.add(Conv2D(filters=16, \n", 98 | " kernel_size=7,\n", 99 | " padding='same', \n", 100 | " activation='relu', \n", 101 | " input_shape=train_tensors.shape[1:]))\n", 102 | "model.add(MaxPooling2D(pool_size=2))\n", 103 | "\n", 104 | "model.add(Conv2D(filters=32, \n", 105 | " kernel_size=5,\n", 106 | " padding='same', \n", 107 | " activation='relu'))\n", 108 | "model.add(MaxPooling2D(pool_size=2))\n", 109 | "\n", 110 | "model.add(Conv2D(filters=64, \n", 111 | " kernel_size=5,\n", 112 | " padding='same', \n", 113 | " activation='relu'))\n", 114 | "model.add(MaxPooling2D(pool_size=2))\n", 115 | "\n", 116 | "model.add(Conv2D(filters=128, \n", 117 | " kernel_size=5,\n", 118 | " strides=2,\n", 119 | " padding='same', \n", 120 | " activation='relu'))\n", 121 | "model.add(MaxPooling2D(pool_size=2))\n", 122 | "\n", 123 | "model.add(Flatten())\n", 124 | "model.add(Dense(100, activation='relu'))\n", 125 | "model.add(Dense(1, activation='sigmoid'))\n", 126 | "\n", 127 | "model.summary()" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 4, 133 | "metadata": { 134 | "collapsed": true 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "from keras import backend as K\n", 139 | "\n", 140 | "def binary_accuracy(y_true, y_pred):\n", 141 | " return K.mean(K.equal(y_true, K.round(y_pred)))\n", 142 | "\n", 143 | "def precision_threshold(threshold = 0.5):\n", 144 | " def precision(y_true, y_pred):\n", 145 | " threshold_value = threshold\n", 146 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 147 | " true_positives = 
K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 148 | " predicted_positives = K.sum(y_pred)\n", 149 | " precision_ratio = true_positives / (predicted_positives + K.epsilon())\n", 150 | " return precision_ratio\n", 151 | " return precision\n", 152 | "\n", 153 | "def recall_threshold(threshold = 0.5):\n", 154 | " def recall(y_true, y_pred):\n", 155 | " threshold_value = threshold\n", 156 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 157 | " true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 158 | " possible_positives = K.sum(K.clip(y_true, 0, 1))\n", 159 | " recall_ratio = true_positives / (possible_positives + K.epsilon())\n", 160 | " return recall_ratio\n", 161 | " return recall\n", 162 | "\n", 163 | "def fbeta_score_threshold(beta = 1, threshold = 0.5):\n", 164 | " def fbeta_score(y_true, y_pred):\n", 165 | " threshold_value = threshold\n", 166 | " beta_value = beta\n", 167 | " p = precision_threshold(threshold_value)(y_true, y_pred)\n", 168 | " r = recall_threshold(threshold_value)(y_true, y_pred)\n", 169 | " bb = beta_value ** 2\n", 170 | " fbeta_score = (1 + bb) * (p * r) / (bb * p + r + K.epsilon())\n", 171 | " return fbeta_score\n", 172 | " return fbeta_score" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 5, 178 | "metadata": { 179 | "collapsed": true 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "model.compile(optimizer='sgd', loss='binary_crossentropy', \n", 184 | " metrics=[precision_threshold(threshold = 0.5), \n", 185 | " recall_threshold(threshold = 0.5), \n", 186 | " fbeta_score_threshold(beta=0.5, threshold = 0.5),\n", 187 | " 'accuracy'])" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 6, 193 | "metadata": { 194 | "scrolled": true 195 | }, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "Train on 3400 samples, validate on 1100 samples\n", 202 | "Epoch 1/20\n", 203 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6880 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5457Epoch 00001: val_loss improved from inf to 0.68814, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 204 | "3400/3400 [==============================] - 6s 2ms/step - loss: 0.6878 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5465 - val_loss: 0.6881 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_fbeta_score: 0.0000e+00 - val_acc: 0.5482\n", 205 | "Epoch 2/20\n", 206 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6853 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5469Epoch 00002: val_loss improved from 0.68814 to 0.68352, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 207 | "3400/3400 [==============================] - 2s 637us/step - loss: 0.6853 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5465 - val_loss: 0.6835 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_fbeta_score: 0.0000e+00 - val_acc: 0.5482\n", 208 | "Epoch 3/20\n", 209 | "3392/3400 [============================>.] 
- ETA: 0s - loss: 0.6822 - precision: 0.0660 - recall: 0.0078 - fbeta_score: 0.0256 - acc: 0.5478Epoch 00003: val_loss improved from 0.68352 to 0.68176, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 210 | "3400/3400 [==============================] - 2s 638us/step - loss: 0.6821 - precision: 0.0659 - recall: 0.0078 - fbeta_score: 0.0256 - acc: 0.5479 - val_loss: 0.6818 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_fbeta_score: 0.0000e+00 - val_acc: 0.5482\n", 211 | "Epoch 4/20\n", 212 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6778 - precision: 0.2974 - recall: 0.0929 - fbeta_score: 0.1770 - acc: 0.5643Epoch 00004: val_loss improved from 0.68176 to 0.67342, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 213 | "3400/3400 [==============================] - 2s 636us/step - loss: 0.6778 - precision: 0.2979 - recall: 0.0932 - fbeta_score: 0.1776 - acc: 0.5641 - val_loss: 0.6734 - val_precision: 0.6478 - val_recall: 0.3187 - val_fbeta_score: 0.5230 - val_acc: 0.6082\n", 214 | "Epoch 5/20\n", 215 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6695 - precision: 0.5334 - recall: 0.2741 - fbeta_score: 0.3894 - acc: 0.5843Epoch 00005: val_loss improved from 0.67342 to 0.66195, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 216 | "3400/3400 [==============================] - 2s 634us/step - loss: 0.6696 - precision: 0.5334 - recall: 0.2740 - fbeta_score: 0.3894 - acc: 0.5841 - val_loss: 0.6619 - val_precision: 0.6358 - val_recall: 0.3035 - val_fbeta_score: 0.5063 - val_acc: 0.6018\n", 217 | "Epoch 6/20\n", 218 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6619 - precision: 0.6025 - recall: 0.3949 - fbeta_score: 0.4957 - acc: 0.5970Epoch 00006: val_loss did not improve\n", 219 | "3400/3400 [==============================] - 2s 629us/step - loss: 0.6614 - precision: 0.6011 - recall: 0.3940 - fbeta_score: 0.4945 - acc: 0.5976 - val_loss: 0.8364 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_fbeta_score: 0.0000e+00 - val_acc: 0.5482\n", 220 | "Epoch 7/20\n", 221 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6571 - precision: 0.6038 - recall: 0.4296 - fbeta_score: 0.5116 - acc: 0.6067Epoch 00007: val_loss did not improve\n", 222 | "3400/3400 [==============================] - 2s 630us/step - loss: 0.6567 - precision: 0.6047 - recall: 0.4302 - fbeta_score: 0.5126 - acc: 0.6074 - val_loss: 0.6649 - val_precision: 0.6424 - val_recall: 0.2262 - val_fbeta_score: 0.4548 - val_acc: 0.5936\n", 223 | "Epoch 8/20\n", 224 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6537 - precision: 0.6049 - recall: 0.4763 - fbeta_score: 0.5451 - acc: 0.6223Epoch 00008: val_loss improved from 0.66195 to 0.65951, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 225 | "3400/3400 [==============================] - 2s 636us/step - loss: 0.6539 - precision: 0.6049 - recall: 0.4769 - fbeta_score: 0.5453 - acc: 0.6224 - val_loss: 0.6595 - val_precision: 0.5727 - val_recall: 0.6779 - val_fbeta_score: 0.5868 - val_acc: 0.6209\n", 226 | "Epoch 9/20\n", 227 | "3392/3400 [============================>.] 
- ETA: 0s - loss: 0.6503 - precision: 0.6183 - recall: 0.4933 - fbeta_score: 0.5667 - acc: 0.6312Epoch 00009: val_loss improved from 0.65951 to 0.65006, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 228 | "3400/3400 [==============================] - 2s 638us/step - loss: 0.6502 - precision: 0.6186 - recall: 0.4939 - fbeta_score: 0.5672 - acc: 0.6315 - val_loss: 0.6501 - val_precision: 0.6236 - val_recall: 0.5744 - val_fbeta_score: 0.6056 - val_acc: 0.6445\n", 229 | "Epoch 10/20\n", 230 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6486 - precision: 0.6144 - recall: 0.4919 - fbeta_score: 0.5607 - acc: 0.6197Epoch 00010: val_loss did not improve\n", 231 | "3400/3400 [==============================] - 2s 633us/step - loss: 0.6491 - precision: 0.6153 - recall: 0.4912 - fbeta_score: 0.5607 - acc: 0.6194 - val_loss: 0.6955 - val_precision: 0.4951 - val_recall: 0.8959 - val_fbeta_score: 0.5416 - val_acc: 0.5382\n", 232 | "Epoch 11/20\n", 233 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6462 - precision: 0.6221 - recall: 0.5064 - fbeta_score: 0.5710 - acc: 0.6277Epoch 00011: val_loss did not improve\n", 234 | "3400/3400 [==============================] - 2s 633us/step - loss: 0.6461 - precision: 0.6218 - recall: 0.5076 - fbeta_score: 0.5710 - acc: 0.6276 - val_loss: 0.6587 - val_precision: 0.6374 - val_recall: 0.2935 - val_fbeta_score: 0.5025 - val_acc: 0.6027\n", 235 | "Epoch 12/20\n", 236 | "3392/3400 [============================>.] - ETA: 0s - loss: 0.6435 - precision: 0.6175 - recall: 0.4894 - fbeta_score: 0.5614 - acc: 0.6262Epoch 00012: val_loss did not improve\n", 237 | "3400/3400 [==============================] - 2s 632us/step - loss: 0.6434 - precision: 0.6172 - recall: 0.4890 - fbeta_score: 0.5612 - acc: 0.6262 - val_loss: 0.6522 - val_precision: 0.6512 - val_recall: 0.3691 - val_fbeta_score: 0.5529 - val_acc: 0.6209\n", 238 | "Epoch 00012: early stopping\n", 239 | "training time: 0.50 minutes\n" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping\n", 245 | "import numpy as np\n", 246 | "\n", 247 | "epochs = 20\n", 248 | "batch_size = 32\n", 249 | "\n", 250 | "earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=1, mode='auto')\n", 251 | "log = CSVLogger('saved_models/log_bCNN_rgb.csv')\n", 252 | "checkpointer = ModelCheckpoint(filepath='saved_models/bCNN.best.from_scratch.hdf5', \n", 253 | " verbose=1, save_best_only=True)\n", 254 | "\n", 255 | "start = time.time()\n", 256 | "\n", 257 | "model.fit(train_tensors, train_labels, \n", 258 | " validation_data=(valid_tensors, valid_labels),\n", 259 | " epochs=epochs, batch_size=batch_size, callbacks=[checkpointer, log, earlystop], verbose=1)\n", 260 | "\n", 261 | "# Show total training time\n", 262 | "print(\"training time: %.2f minutes\"%((time.time()-start)/60))" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "### Metric" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 7, 275 | "metadata": { 276 | "collapsed": true 277 | }, 278 | "outputs": [], 279 | "source": [ 280 | "model.load_weights('saved_models/bCNN.best.from_scratch.hdf5')\n", 281 | "prediction = model.predict(test_tensors)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 8, 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "Precision: 0.617234 
%\n", 294 | "Recall: 0.588910 %\n", 295 | "Fscore: 0.611354 %\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "threshold = 0.5\n", 301 | "beta = 0.5\n", 302 | "\n", 303 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 304 | " K.variable(value=prediction)))\n", 305 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 306 | " K.variable(value=prediction)))\n", 307 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 308 | " K.variable(value=prediction)))\n", 309 | "\n", 310 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 9, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "data": { 320 | "text/plain": [ 321 | "0.50258511" 322 | ] 323 | }, 324 | "execution_count": 9, 325 | "metadata": {}, 326 | "output_type": "execute_result" 327 | } 328 | ], 329 | "source": [ 330 | "K.eval(binary_accuracy(K.variable(value=test_labels),\n", 331 | " K.variable(value=prediction)))" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 10, 337 | "metadata": { 338 | "scrolled": true 339 | }, 340 | "outputs": [ 341 | { 342 | "data": { 343 | "text/plain": [ 344 | "array([[ 0.46597424],\n", 345 | " [ 0.57600302],\n", 346 | " [ 0.30108091],\n", 347 | " [ 0.50593966],\n", 348 | " [ 0.61561286],\n", 349 | " [ 0.6416322 ],\n", 350 | " [ 0.29955843],\n", 351 | " [ 0.30611175],\n", 352 | " [ 0.42266327],\n", 353 | " [ 0.40697429],\n", 354 | " [ 0.48799837],\n", 355 | " [ 0.34801716],\n", 356 | " [ 0.55648535],\n", 357 | " [ 0.52279401],\n", 358 | " [ 0.62823325],\n", 359 | " [ 0.35642451],\n", 360 | " [ 0.50304908],\n", 361 | " [ 0.42197177],\n", 362 | " [ 0.72991049],\n", 363 | " [ 0.50801474],\n", 364 | " [ 0.31001693],\n", 365 | " [ 0.49956188],\n", 366 | " [ 0.50922167],\n", 367 | " [ 0.47676209],\n", 368 | " [ 0.36952221],\n", 369 | " [ 0.3691574 ],\n", 370 | " [ 0.63229394],\n", 371 | " [ 0.49165967],\n", 372 | " [ 0.53164726],\n", 373 | " [ 0.54903966]], dtype=float32)" 374 | ] 375 | }, 376 | "execution_count": 10, 377 | "metadata": {}, 378 | "output_type": "execute_result" 379 | } 380 | ], 381 | "source": [ 382 | "prediction[:30]" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 11, 388 | "metadata": {}, 389 | "outputs": [ 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "Precision: 0.543257 %\n", 395 | "Recall: 0.816444 %\n", 396 | "Fscore: 0.582220 %\n" 397 | ] 398 | } 399 | ], 400 | "source": [ 401 | "threshold = 0.4\n", 402 | "beta = 0.5\n", 403 | "\n", 404 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 405 | " K.variable(value=prediction)))\n", 406 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 407 | " K.variable(value=prediction)))\n", 408 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 409 | " K.variable(value=prediction)))\n", 410 | "\n", 411 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 12, 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "Precision: 0.672000 %\n", 424 | "Recall: 0.321224 %\n", 425 | "Fscore: 0.551543 %\n" 426 | 
] 427 | } 428 | ], 429 | "source": [ 430 | "threshold = 0.6\n", 431 | "beta = 0.5\n", 432 | "\n", 433 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 434 | " K.variable(value=prediction)))\n", 435 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 436 | " K.variable(value=prediction)))\n", 437 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 438 | " K.variable(value=prediction)))\n", 439 | "\n", 440 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "metadata": {}, 446 | "source": [ 447 | "# With gray images" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 13, 453 | "metadata": { 454 | "collapsed": true 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "import pickle\n", 459 | "\n", 460 | "train_filename = \"data_preprocessed/train_data_sample_gray.p\"\n", 461 | "(train_labels, train_data, train_tensors) = pickle.load(open(train_filename, mode='rb'))\n", 462 | "\n", 463 | "valid_filename = \"data_preprocessed/valid_data_sample_gray.p\"\n", 464 | "(valid_labels, valid_data, valid_tensors) = pickle.load(open(valid_filename, mode='rb'))\n", 465 | "\n", 466 | "test_filename = \"data_preprocessed/test_data_sample_gray.p\"\n", 467 | "(test_labels, test_data, test_tensors) = pickle.load(open(test_filename, mode='rb'))" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": 15, 473 | "metadata": {}, 474 | "outputs": [ 475 | { 476 | "name": "stdout", 477 | "output_type": "stream", 478 | "text": [ 479 | "_________________________________________________________________\n", 480 | "Layer (type) Output Shape Param # \n", 481 | "=================================================================\n", 482 | "conv2d_5 (Conv2D) (None, 64, 64, 16) 800 \n", 483 | "_________________________________________________________________\n", 484 | "max_pooling2d_5 (MaxPooling2 (None, 32, 32, 16) 0 \n", 485 | "_________________________________________________________________\n", 486 | "conv2d_6 (Conv2D) (None, 32, 32, 32) 12832 \n", 487 | "_________________________________________________________________\n", 488 | "max_pooling2d_6 (MaxPooling2 (None, 16, 16, 32) 0 \n", 489 | "_________________________________________________________________\n", 490 | "conv2d_7 (Conv2D) (None, 16, 16, 64) 51264 \n", 491 | "_________________________________________________________________\n", 492 | "max_pooling2d_7 (MaxPooling2 (None, 8, 8, 64) 0 \n", 493 | "_________________________________________________________________\n", 494 | "conv2d_8 (Conv2D) (None, 4, 4, 128) 204928 \n", 495 | "_________________________________________________________________\n", 496 | "max_pooling2d_8 (MaxPooling2 (None, 2, 2, 128) 0 \n", 497 | "_________________________________________________________________\n", 498 | "flatten_2 (Flatten) (None, 512) 0 \n", 499 | "_________________________________________________________________\n", 500 | "dense_3 (Dense) (None, 100) 51300 \n", 501 | "_________________________________________________________________\n", 502 | "dense_4 (Dense) (None, 1) 101 \n", 503 | "=================================================================\n", 504 | "Total params: 321,225\n", 505 | "Trainable params: 321,225\n", 506 | "Non-trainable params: 0\n", 507 | "_________________________________________________________________\n" 508 | ] 509 | } 510 | ], 
511 | "source": [ 512 | "import time\n", 513 | "\n", 514 | "from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dropout, Flatten, Dense\n", 515 | "from keras.models import Sequential\n", 516 | "from keras.layers.normalization import BatchNormalization\n", 517 | "from keras import regularizers, initializers, optimizers\n", 518 | "\n", 519 | "model = Sequential()\n", 520 | "\n", 521 | "model.add(Conv2D(filters=16, \n", 522 | " kernel_size=7,\n", 523 | " padding='same', \n", 524 | " activation='relu', \n", 525 | " input_shape=train_tensors.shape[1:]))\n", 526 | "model.add(MaxPooling2D(pool_size=2))\n", 527 | "\n", 528 | "model.add(Conv2D(filters=32, \n", 529 | " kernel_size=5,\n", 530 | " padding='same', \n", 531 | " activation='relu'))\n", 532 | "model.add(MaxPooling2D(pool_size=2))\n", 533 | "\n", 534 | "model.add(Conv2D(filters=64, \n", 535 | " kernel_size=5,\n", 536 | " padding='same', \n", 537 | " activation='relu'))\n", 538 | "model.add(MaxPooling2D(pool_size=2))\n", 539 | "\n", 540 | "model.add(Conv2D(filters=128, \n", 541 | " kernel_size=5,\n", 542 | " strides=2,\n", 543 | " padding='same', \n", 544 | " activation='relu'))\n", 545 | "model.add(MaxPooling2D(pool_size=2))\n", 546 | "\n", 547 | "model.add(Flatten())\n", 548 | "model.add(Dense(100, activation='relu'))\n", 549 | "model.add(Dense(1, activation='sigmoid'))\n", 550 | "\n", 551 | "model.summary()" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 16, 557 | "metadata": { 558 | "collapsed": true 559 | }, 560 | "outputs": [], 561 | "source": [ 562 | "from keras import backend as K\n", 563 | "\n", 564 | "def binary_accuracy(y_true, y_pred):\n", 565 | " return K.mean(K.equal(y_true, K.round(y_pred)))\n", 566 | "\n", 567 | "def precision_threshold(threshold = 0.5):\n", 568 | " def precision(y_true, y_pred):\n", 569 | " threshold_value = threshold\n", 570 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 571 | " true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 572 | " predicted_positives = K.sum(y_pred)\n", 573 | " precision_ratio = true_positives / (predicted_positives + K.epsilon())\n", 574 | " return precision_ratio\n", 575 | " return precision\n", 576 | "\n", 577 | "def recall_threshold(threshold = 0.5):\n", 578 | " def recall(y_true, y_pred):\n", 579 | " threshold_value = threshold\n", 580 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 581 | " true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 582 | " possible_positives = K.sum(K.clip(y_true, 0, 1))\n", 583 | " recall_ratio = true_positives / (possible_positives + K.epsilon())\n", 584 | " return recall_ratio\n", 585 | " return recall\n", 586 | "\n", 587 | "def fbeta_score_threshold(beta = 1, threshold = 0.5):\n", 588 | " def fbeta_score(y_true, y_pred):\n", 589 | " threshold_value = threshold\n", 590 | " beta_value = beta\n", 591 | " p = precision_threshold(threshold_value)(y_true, y_pred)\n", 592 | " r = recall_threshold(threshold_value)(y_true, y_pred)\n", 593 | " bb = beta_value ** 2\n", 594 | " fbeta_score = (1 + bb) * (p * r) / (bb * p + r + K.epsilon())\n", 595 | " return fbeta_score\n", 596 | " return fbeta_score" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 17, 602 | "metadata": { 603 | "collapsed": true 604 | }, 605 | "outputs": [], 606 | "source": [ 607 | "model.compile(optimizer='sgd', loss='binary_crossentropy', \n", 608 | " metrics=[precision_threshold(threshold = 0.5), \n", 609 | " 
recall_threshold(threshold = 0.5), \n", 610 | " fbeta_score_threshold(beta=0.5, threshold = 0.5),\n", 611 | " 'accuracy'])" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": 18, 617 | "metadata": {}, 618 | "outputs": [ 619 | { 620 | "name": "stdout", 621 | "output_type": "stream", 622 | "text": [ 623 | "Train on 3400 samples, validate on 1100 samples\n", 624 | "Epoch 1/20\n", 625 | "3328/3400 [============================>.] - ETA: 0s - loss: 0.6900 - precision: 0.1182 - recall: 0.2178 - fbeta_score: 0.1266 - acc: 0.5373Epoch 00001: val_loss improved from inf to 0.69002, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 626 | "3400/3400 [==============================] - 2s 682us/step - loss: 0.6901 - precision: 0.1157 - recall: 0.2131 - fbeta_score: 0.1239 - acc: 0.5359 - val_loss: 0.6900 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_fbeta_score: 0.0000e+00 - val_acc: 0.5291\n", 627 | "Epoch 2/20\n", 628 | "3360/3400 [============================>.] - ETA: 0s - loss: 0.6877 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5405Epoch 00002: val_loss improved from 0.69002 to 0.68985, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 629 | "3400/3400 [==============================] - 2s 569us/step - loss: 0.6877 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5409 - val_loss: 0.6899 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_fbeta_score: 0.0000e+00 - val_acc: 0.5291\n", 630 | "Epoch 3/20\n", 631 | "3360/3400 [============================>.] - ETA: 0s - loss: 0.6861 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5399Epoch 00003: val_loss did not improve\n", 632 | "3400/3400 [==============================] - 2s 563us/step - loss: 0.6859 - precision: 0.0000e+00 - recall: 0.0000e+00 - fbeta_score: 0.0000e+00 - acc: 0.5409 - val_loss: 0.6921 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_fbeta_score: 0.0000e+00 - val_acc: 0.5291\n", 633 | "Epoch 4/20\n", 634 | "3328/3400 [============================>.] - ETA: 0s - loss: 0.6837 - precision: 0.0096 - recall: 7.3964e-04 - fbeta_score: 0.0028 - acc: 0.5406 Epoch 00004: val_loss improved from 0.68985 to 0.68468, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 635 | "3400/3400 [==============================] - 2s 574us/step - loss: 0.6837 - precision: 0.0141 - recall: 0.0014 - fbeta_score: 0.0049 - acc: 0.5403 - val_loss: 0.6847 - val_precision: 0.0727 - val_recall: 0.0063 - val_fbeta_score: 0.0230 - val_acc: 0.5309\n", 636 | "Epoch 5/20\n", 637 | "3360/3400 [============================>.] - ETA: 0s - loss: 0.6802 - precision: 0.2212 - recall: 0.0330 - fbeta_score: 0.0913 - acc: 0.5473Epoch 00005: val_loss improved from 0.68468 to 0.68110, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 638 | "3400/3400 [==============================] - 2s 570us/step - loss: 0.6803 - precision: 0.2186 - recall: 0.0326 - fbeta_score: 0.0902 - acc: 0.5468 - val_loss: 0.6811 - val_precision: 0.2812 - val_recall: 0.0339 - val_fbeta_score: 0.1079 - val_acc: 0.5364\n", 639 | "Epoch 6/20\n", 640 | "3328/3400 [============================>.] 
- ETA: 0s - loss: 0.6745 - precision: 0.5036 - recall: 0.1669 - fbeta_score: 0.3177 - acc: 0.5817Epoch 00006: val_loss improved from 0.68110 to 0.67981, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 641 | "3400/3400 [==============================] - 2s 569us/step - loss: 0.6746 - precision: 0.5105 - recall: 0.1685 - fbeta_score: 0.3214 - acc: 0.5809 - val_loss: 0.6798 - val_precision: 0.5527 - val_recall: 0.7384 - val_fbeta_score: 0.5775 - val_acc: 0.5891\n", 642 | "Epoch 7/20\n", 643 | "3296/3400 [============================>.] - ETA: 0s - loss: 0.6672 - precision: 0.6067 - recall: 0.3669 - fbeta_score: 0.4821 - acc: 0.6062Epoch 00007: val_loss did not improve\n", 644 | "3400/3400 [==============================] - 2s 563us/step - loss: 0.6667 - precision: 0.6079 - recall: 0.3663 - fbeta_score: 0.4841 - acc: 0.6076 - val_loss: 0.6825 - val_precision: 0.5411 - val_recall: 0.0988 - val_fbeta_score: 0.2625 - val_acc: 0.5491\n", 645 | "Epoch 8/20\n", 646 | "3328/3400 [============================>.] - ETA: 0s - loss: 0.6620 - precision: 0.5974 - recall: 0.4096 - fbeta_score: 0.5063 - acc: 0.5983Epoch 00008: val_loss did not improve\n", 647 | "3400/3400 [==============================] - 2s 564us/step - loss: 0.6616 - precision: 0.5987 - recall: 0.4100 - fbeta_score: 0.5080 - acc: 0.5988 - val_loss: 0.7088 - val_precision: 0.4828 - val_recall: 0.9658 - val_fbeta_score: 0.5347 - val_acc: 0.4964\n", 648 | "Epoch 9/20\n", 649 | "3360/3400 [============================>.] - ETA: 0s - loss: 0.6549 - precision: 0.6199 - recall: 0.4667 - fbeta_score: 0.5570 - acc: 0.6214Epoch 00009: val_loss improved from 0.67981 to 0.65802, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 650 | "3400/3400 [==============================] - 2s 569us/step - loss: 0.6568 - precision: 0.6173 - recall: 0.4635 - fbeta_score: 0.5541 - acc: 0.6197 - val_loss: 0.6580 - val_precision: 0.6221 - val_recall: 0.4836 - val_fbeta_score: 0.5805 - val_acc: 0.6209\n", 651 | "Epoch 10/20\n", 652 | "3296/3400 [============================>.] - ETA: 0s - loss: 0.6548 - precision: 0.6294 - recall: 0.4779 - fbeta_score: 0.5588 - acc: 0.6204Epoch 00010: val_loss did not improve\n", 653 | "3400/3400 [==============================] - 2s 564us/step - loss: 0.6544 - precision: 0.6312 - recall: 0.4814 - fbeta_score: 0.5618 - acc: 0.6209 - val_loss: 0.6714 - val_precision: 0.5552 - val_recall: 0.7640 - val_fbeta_score: 0.5833 - val_acc: 0.5964\n", 654 | "Epoch 11/20\n", 655 | "3328/3400 [============================>.] - ETA: 0s - loss: 0.6514 - precision: 0.6366 - recall: 0.5072 - fbeta_score: 0.5816 - acc: 0.6232Epoch 00011: val_loss did not improve\n", 656 | "3400/3400 [==============================] - 2s 564us/step - loss: 0.6521 - precision: 0.6321 - recall: 0.5051 - fbeta_score: 0.5780 - acc: 0.6215 - val_loss: 0.6751 - val_precision: 0.6973 - val_recall: 0.2211 - val_fbeta_score: 0.4559 - val_acc: 0.5773\n", 657 | "Epoch 12/20\n", 658 | "3296/3400 [============================>.] 
- ETA: 0s - loss: 0.6508 - precision: 0.6292 - recall: 0.4952 - fbeta_score: 0.5799 - acc: 0.6323Epoch 00012: val_loss did not improve\n", 659 | "3400/3400 [==============================] - 2s 565us/step - loss: 0.6498 - precision: 0.6315 - recall: 0.5001 - fbeta_score: 0.5829 - acc: 0.6335 - val_loss: 0.6675 - val_precision: 0.6483 - val_recall: 0.3281 - val_fbeta_score: 0.5318 - val_acc: 0.5991\n", 660 | "Epoch 00012: early stopping\n", 661 | "training time: 0.40 minutes\n" 662 | ] 663 | } 664 | ], 665 | "source": [ 666 | "from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping\n", 667 | "import numpy as np\n", 668 | "\n", 669 | "epochs = 20\n", 670 | "batch_size = 32\n", 671 | "\n", 672 | "earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=1, mode='auto')\n", 673 | "log = CSVLogger('saved_models/log_bCNN_gray.csv')\n", 674 | "checkpointer = ModelCheckpoint(filepath='saved_models/bCNN_gray.best.from_scratch.hdf5', \n", 675 | " verbose=1, save_best_only=True)\n", 676 | "\n", 677 | "start = time.time()\n", 678 | "\n", 679 | "model.fit(train_tensors, train_labels, \n", 680 | " validation_data=(valid_tensors, valid_labels),\n", 681 | " epochs=epochs, batch_size=batch_size, callbacks=[checkpointer, log, earlystop], verbose=1)\n", 682 | "\n", 683 | "# Show total training time\n", 684 | "print(\"training time: %.2f minutes\"%((time.time()-start)/60))" 685 | ] 686 | }, 687 | { 688 | "cell_type": "code", 689 | "execution_count": 19, 690 | "metadata": { 691 | "collapsed": true 692 | }, 693 | "outputs": [], 694 | "source": [ 695 | "model.load_weights('saved_models/bCNN_gray.best.from_scratch.hdf5')\n", 696 | "prediction = model.predict(test_tensors)" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 20, 702 | "metadata": {}, 703 | "outputs": [ 704 | { 705 | "name": "stdout", 706 | "output_type": "stream", 707 | "text": [ 708 | "Precision: 0.577114 %\n", 709 | "Recall: 0.480331 %\n", 710 | "Fscore: 0.554758 %\n" 711 | ] 712 | } 713 | ], 714 | "source": [ 715 | "threshold = 0.5\n", 716 | "beta = 0.5\n", 717 | "\n", 718 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 719 | " K.variable(value=prediction)))\n", 720 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 721 | " K.variable(value=prediction)))\n", 722 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 723 | " K.variable(value=prediction)))\n", 724 | "\n", 725 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": 21, 731 | "metadata": {}, 732 | "outputs": [ 733 | { 734 | "data": { 735 | "text/plain": [ 736 | "0.51713276" 737 | ] 738 | }, 739 | "execution_count": 21, 740 | "metadata": {}, 741 | "output_type": "execute_result" 742 | } 743 | ], 744 | "source": [ 745 | "K.eval(binary_accuracy(K.variable(value=test_labels),\n", 746 | " K.variable(value=prediction)))" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": 22, 752 | "metadata": {}, 753 | "outputs": [ 754 | { 755 | "data": { 756 | "text/plain": [ 757 | "array([[ 0.3081094 ],\n", 758 | " [ 0.24159142],\n", 759 | " [ 0.52262437],\n", 760 | " [ 0.59462857],\n", 761 | " [ 0.3100515 ],\n", 762 | " [ 0.62393486],\n", 763 | " [ 0.47555083],\n", 764 | " [ 0.48481095],\n", 765 | " [ 0.47963724],\n", 766 | " [ 0.46049529],\n", 767 | " [ 0.52123272],\n", 768 | " 
[ 0.38751996],\n", 769 | " [ 0.35624275],\n", 770 | " [ 0.53882909],\n", 771 | " [ 0.63341409],\n", 772 | " [ 0.47135681],\n", 773 | " [ 0.61958778],\n", 774 | " [ 0.42561847],\n", 775 | " [ 0.51211774],\n", 776 | " [ 0.29424879],\n", 777 | " [ 0.38310093],\n", 778 | " [ 0.28851342],\n", 779 | " [ 0.35126474],\n", 780 | " [ 0.65281165],\n", 781 | " [ 0.48659411],\n", 782 | " [ 0.43335259],\n", 783 | " [ 0.32977027],\n", 784 | " [ 0.65944982],\n", 785 | " [ 0.6016748 ],\n", 786 | " [ 0.62601507]], dtype=float32)" 787 | ] 788 | }, 789 | "execution_count": 22, 790 | "metadata": {}, 791 | "output_type": "execute_result" 792 | } 793 | ], 794 | "source": [ 795 | "prediction[:30]" 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "execution_count": 23, 801 | "metadata": {}, 802 | "outputs": [ 803 | { 804 | "name": "stdout", 805 | "output_type": "stream", 806 | "text": [ 807 | "Precision: 0.523416 %\n", 808 | "Recall: 0.786749 %\n", 809 | "Fscore: 0.560968 %\n" 810 | ] 811 | } 812 | ], 813 | "source": [ 814 | "threshold = 0.4\n", 815 | "beta = 0.5\n", 816 | "\n", 817 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 818 | " K.variable(value=prediction)))\n", 819 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 820 | " K.variable(value=prediction)))\n", 821 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 822 | " K.variable(value=prediction)))\n", 823 | "\n", 824 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 825 | ] 826 | }, 827 | { 828 | "cell_type": "code", 829 | "execution_count": null, 830 | "metadata": { 831 | "collapsed": true 832 | }, 833 | "outputs": [], 834 | "source": [] 835 | } 836 | ], 837 | "metadata": { 838 | "kernelspec": { 839 | "display_name": "Python 3", 840 | "language": "python", 841 | "name": "python3" 842 | }, 843 | "language_info": { 844 | "codemirror_mode": { 845 | "name": "ipython", 846 | "version": 3 847 | }, 848 | "file_extension": ".py", 849 | "mimetype": "text/x-python", 850 | "name": "python", 851 | "nbconvert_exporter": "python", 852 | "pygments_lexer": "ipython3", 853 | "version": "3.5.2" 854 | } 855 | }, 856 | "nbformat": 4, 857 | "nbformat_minor": 2 858 | } 859 | -------------------------------------------------------------------------------- /vanilla CNN - FullDataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# With rgb images" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Load data" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": { 21 | "collapsed": true 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import numpy as np\n", 26 | "import pandas as pd\n", 27 | "from glob import glob\n", 28 | "from tqdm import tqdm\n", 29 | "from sklearn.utils import shuffle\n", 30 | "\n", 31 | "df = pd.read_csv('sample/Data_Entry_2017.csv')\n", 32 | "\n", 33 | "diseases = ['Cardiomegaly','Emphysema','Effusion','Hernia','Nodule','Pneumothorax','Atelectasis','Pleural_Thickening','Mass','Edema','Consolidation','Infiltration','Fibrosis','Pneumonia']\n", 34 | "#Number diseases\n", 35 | "for disease in diseases :\n", 36 | " df[disease] = df['Finding Labels'].apply(lambda x: 1 if disease in x else 0)\n", 37 | "\n", 38 | "# #test to perfect\n", 39 | "# df = 
df.drop(df[df['Emphysema']==0][:-127].index.values)\n", 40 | " \n", 41 | "#remove Y after age\n", 42 | "df['Age']=df['Patient Age'].apply(lambda x: x[:-1]).astype(int)\n", 43 | "df['Age Type']=df['Patient Age'].apply(lambda x: x[-1:])\n", 44 | "df.loc[df['Age Type']=='M',['Age']] = df[df['Age Type']=='M']['Age'].apply(lambda x: round(x/12.)).astype(int)\n", 45 | "df.loc[df['Age Type']=='D',['Age']] = df[df['Age Type']=='D']['Age'].apply(lambda x: round(x/365.)).astype(int)\n", 46 | "# remove outliers\n", 47 | "df = df.drop(df['Age'].sort_values(ascending=False).head(16).index)\n", 48 | "df['Age'] = df['Age']/df['Age'].max()\n", 49 | "\n", 50 | "#one hot data\n", 51 | "# df = df.drop(df.index[4242])\n", 52 | "df = df.join(pd.get_dummies(df['Patient Gender']))\n", 53 | "df = df.join(pd.get_dummies(df['View Position']))\n", 54 | "\n", 55 | "#random samples\n", 56 | "df = shuffle(df)\n", 57 | "\n", 58 | "#get other data\n", 59 | "data = df[['Age', 'F', 'M', 'AP', 'PA']]\n", 60 | "data = np.array(data)\n", 61 | "\n", 62 | "labels = df[diseases].as_matrix()\n", 63 | "files_list = ('sample/images/' + df['Image Index']).tolist()\n", 64 | "\n", 65 | "# #test to perfect\n", 66 | "# labelB = df['Emphysema'].tolist()\n", 67 | "\n", 68 | "labelB = (df[diseases].sum(axis=1)>0).tolist()\n", 69 | "labelB = np.array(labelB, dtype=int)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 3, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stderr", 79 | "output_type": "stream", 80 | "text": [ 81 | "Using TensorFlow backend.\n", 82 | "/home/aind2/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n", 83 | " return f(*args, **kwds)\n", 84 | "100%|██████████| 89600/89600 [20:26<00:00, 73.03it/s]\n", 85 | "100%|██████████| 11200/11200 [02:33<00:00, 73.10it/s]\n", 86 | "100%|██████████| 11319/11319 [02:38<00:00, 71.35it/s]\n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "from keras.preprocessing import image \n", 92 | "from tqdm import tqdm\n", 93 | "\n", 94 | "def path_to_tensor(img_path, shape):\n", 95 | " # loads RGB image as PIL.Image.Image type\n", 96 | " img = image.load_img(img_path, target_size=shape)\n", 97 | " # convert PIL.Image.Image type to 3D tensor with shape (*shape, 3)\n", 98 | " x = image.img_to_array(img)/255\n", 99 | " # convert 3D tensor to 4D tensor with shape (1, *shape, 3) and return 4D tensor\n", 100 | " return np.expand_dims(x, axis=0)\n", 101 | "\n", 102 | "def paths_to_tensor(img_paths, shape):\n", 103 | " list_of_tensors = [path_to_tensor(img_path, shape) for img_path in tqdm(img_paths)]\n", 104 | " return np.vstack(list_of_tensors)\n", 105 | "\n", 106 | "train_labels = labelB[:89600][:, np.newaxis]\n", 107 | "valid_labels = labelB[89600:100800][:, np.newaxis]\n", 108 | "test_labels = labelB[100800:][:, np.newaxis]\n", 109 | "\n", 110 | "train_data = data[:89600]\n", 111 | "valid_data = data[89600:100800]\n", 112 | "test_data = data[100800:]\n", 113 | "\n", 114 | "img_shape = (64, 64)\n", 115 | "train_tensors = paths_to_tensor(files_list[:89600], shape = img_shape)\n", 116 | "valid_tensors = paths_to_tensor(files_list[89600:100800], shape = img_shape)\n", 117 | "test_tensors = paths_to_tensor(files_list[100800:], shape = img_shape)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "### CNN model" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 |
"execution_count": 11, 130 | "metadata": { 131 | "scrolled": true 132 | }, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "_________________________________________________________________\n", 139 | "Layer (type) Output Shape Param # \n", 140 | "=================================================================\n", 141 | "conv2d_13 (Conv2D) (None, 64, 64, 16) 2368 \n", 142 | "_________________________________________________________________\n", 143 | "max_pooling2d_13 (MaxPooling (None, 32, 32, 16) 0 \n", 144 | "_________________________________________________________________\n", 145 | "conv2d_14 (Conv2D) (None, 32, 32, 32) 12832 \n", 146 | "_________________________________________________________________\n", 147 | "max_pooling2d_14 (MaxPooling (None, 16, 16, 32) 0 \n", 148 | "_________________________________________________________________\n", 149 | "conv2d_15 (Conv2D) (None, 16, 16, 64) 51264 \n", 150 | "_________________________________________________________________\n", 151 | "max_pooling2d_15 (MaxPooling (None, 8, 8, 64) 0 \n", 152 | "_________________________________________________________________\n", 153 | "conv2d_16 (Conv2D) (None, 4, 4, 128) 204928 \n", 154 | "_________________________________________________________________\n", 155 | "max_pooling2d_16 (MaxPooling (None, 2, 2, 128) 0 \n", 156 | "_________________________________________________________________\n", 157 | "flatten_3 (Flatten) (None, 512) 0 \n", 158 | "_________________________________________________________________\n", 159 | "dense_5 (Dense) (None, 100) 51300 \n", 160 | "_________________________________________________________________\n", 161 | "dense_6 (Dense) (None, 1) 101 \n", 162 | "=================================================================\n", 163 | "Total params: 322,793\n", 164 | "Trainable params: 322,793\n", 165 | "Non-trainable params: 0\n", 166 | "_________________________________________________________________\n" 167 | ] 168 | } 169 | ], 170 | "source": [ 171 | "import time\n", 172 | "\n", 173 | "from keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dropout, Flatten, Dense\n", 174 | "from keras.models import Sequential\n", 175 | "from keras.layers.normalization import BatchNormalization\n", 176 | "from keras import regularizers, initializers, optimizers\n", 177 | "\n", 178 | "model = Sequential()\n", 179 | "\n", 180 | "model.add(Conv2D(filters=16, \n", 181 | " kernel_size=7,\n", 182 | " padding='same', \n", 183 | " activation='relu', \n", 184 | " input_shape=train_tensors.shape[1:]))\n", 185 | "model.add(MaxPooling2D(pool_size=2))\n", 186 | "\n", 187 | "model.add(Conv2D(filters=32, \n", 188 | " kernel_size=5,\n", 189 | " padding='same', \n", 190 | " activation='relu'))\n", 191 | "model.add(MaxPooling2D(pool_size=2))\n", 192 | "\n", 193 | "model.add(Conv2D(filters=64, \n", 194 | " kernel_size=5,\n", 195 | " padding='same', \n", 196 | " activation='relu'))\n", 197 | "model.add(MaxPooling2D(pool_size=2))\n", 198 | "\n", 199 | "model.add(Conv2D(filters=128, \n", 200 | " kernel_size=5,\n", 201 | " strides=2,\n", 202 | " padding='same', \n", 203 | " activation='relu'))\n", 204 | "model.add(MaxPooling2D(pool_size=2))\n", 205 | "\n", 206 | "model.add(Flatten())\n", 207 | "model.add(Dense(100, activation='relu'))\n", 208 | "model.add(Dense(1, activation='sigmoid'))\n", 209 | "\n", 210 | "model.summary()" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 12, 216 | "metadata": { 217 | "collapsed": true 
218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "from keras import backend as K\n", 222 | "\n", 223 | "def binary_accuracy(y_true, y_pred):\n", 224 | " return K.mean(K.equal(y_true, K.round(y_pred)))\n", 225 | "\n", 226 | "def precision_threshold(threshold = 0.5):\n", 227 | " def precision(y_true, y_pred):\n", 228 | " threshold_value = threshold\n", 229 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 230 | " true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 231 | " predicted_positives = K.sum(y_pred)\n", 232 | " precision_ratio = true_positives / (predicted_positives + K.epsilon())\n", 233 | " return precision_ratio\n", 234 | " return precision\n", 235 | "\n", 236 | "def recall_threshold(threshold = 0.5):\n", 237 | " def recall(y_true, y_pred):\n", 238 | " threshold_value = threshold\n", 239 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 240 | " true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 241 | " possible_positives = K.sum(K.clip(y_true, 0, 1))\n", 242 | " recall_ratio = true_positives / (possible_positives + K.epsilon())\n", 243 | " return recall_ratio\n", 244 | " return recall\n", 245 | "\n", 246 | "def fbeta_score_threshold(beta = 1, threshold = 0.5):\n", 247 | " def fbeta_score(y_true, y_pred):\n", 248 | " threshold_value = threshold\n", 249 | " beta_value = beta\n", 250 | " p = precision_threshold(threshold_value)(y_true, y_pred)\n", 251 | " r = recall_threshold(threshold_value)(y_true, y_pred)\n", 252 | " bb = beta_value ** 2\n", 253 | " fbeta_score = (1 + bb) * (p * r) / (bb * p + r + K.epsilon())\n", 254 | " return fbeta_score\n", 255 | " return fbeta_score" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 13, 261 | "metadata": { 262 | "collapsed": true 263 | }, 264 | "outputs": [], 265 | "source": [ 266 | "model.compile(optimizer='sgd', loss='binary_crossentropy', \n", 267 | " metrics=[precision_threshold(threshold = 0.5), \n", 268 | " recall_threshold(threshold = 0.5), \n", 269 | " fbeta_score_threshold(beta=0.5, threshold = 0.5),\n", 270 | " 'accuracy'])" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 14, 276 | "metadata": { 277 | "scrolled": true 278 | }, 279 | "outputs": [ 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | "Train on 89600 samples, validate on 11200 samples\n", 285 | "Epoch 1/20\n", 286 | "89536/89600 [============================>.] - ETA: 0s - loss: 0.6561 - precision: 0.5974 - recall: 0.4632 - fbeta_score: 0.5386 - acc: 0.6178Epoch 00000: val_loss improved from inf to 0.65672, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 287 | "89600/89600 [==============================] - 52s - loss: 0.6561 - precision: 0.5974 - recall: 0.4632 - fbeta_score: 0.5387 - acc: 0.6178 - val_loss: 0.6567 - val_precision: 0.5574 - val_recall: 0.7224 - val_fbeta_score: 0.5805 - val_acc: 0.6148\n", 288 | "Epoch 2/20\n", 289 | "89504/89600 [============================>.] 
- ETA: 0s - loss: 0.6416 - precision: 0.6351 - recall: 0.5353 - fbeta_score: 0.6006 - acc: 0.6412Epoch 00001: val_loss improved from 0.65672 to 0.63831, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 290 | "89600/89600 [==============================] - 51s - loss: 0.6416 - precision: 0.6353 - recall: 0.5353 - fbeta_score: 0.6007 - acc: 0.6412 - val_loss: 0.6383 - val_precision: 0.6061 - val_recall: 0.6122 - val_fbeta_score: 0.6025 - val_acc: 0.6460\n", 291 | "Epoch 3/20\n", 292 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6354 - precision: 0.6417 - recall: 0.5549 - fbeta_score: 0.6115 - acc: 0.6487Epoch 00002: val_loss improved from 0.63831 to 0.63043, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 293 | "89600/89600 [==============================] - 51s - loss: 0.6354 - precision: 0.6417 - recall: 0.5549 - fbeta_score: 0.6115 - acc: 0.6487 - val_loss: 0.6304 - val_precision: 0.6347 - val_recall: 0.5533 - val_fbeta_score: 0.6106 - val_acc: 0.6556\n", 294 | "Epoch 4/20\n", 295 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6303 - precision: 0.6453 - recall: 0.5640 - fbeta_score: 0.6173 - acc: 0.6537Epoch 00003: val_loss improved from 0.63043 to 0.63033, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 296 | "89600/89600 [==============================] - 51s - loss: 0.6303 - precision: 0.6455 - recall: 0.5640 - fbeta_score: 0.6174 - acc: 0.6538 - val_loss: 0.6303 - val_precision: 0.6627 - val_recall: 0.4711 - val_fbeta_score: 0.6032 - val_acc: 0.6525\n", 297 | "Epoch 5/20\n", 298 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6267 - precision: 0.6506 - recall: 0.5799 - fbeta_score: 0.6259 - acc: 0.6603Epoch 00004: val_loss improved from 0.63033 to 0.62570, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 299 | "89600/89600 [==============================] - 52s - loss: 0.6267 - precision: 0.6506 - recall: 0.5800 - fbeta_score: 0.6260 - acc: 0.6603 - val_loss: 0.6257 - val_precision: 0.6350 - val_recall: 0.5960 - val_fbeta_score: 0.6213 - val_acc: 0.6642\n", 300 | "Epoch 6/20\n", 301 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6232 - precision: 0.6519 - recall: 0.5841 - fbeta_score: 0.6287 - acc: 0.6627Epoch 00005: val_loss improved from 0.62570 to 0.62252, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 302 | "89600/89600 [==============================] - 52s - loss: 0.6232 - precision: 0.6519 - recall: 0.5839 - fbeta_score: 0.6287 - acc: 0.6627 - val_loss: 0.6225 - val_precision: 0.6655 - val_recall: 0.5105 - val_fbeta_score: 0.6197 - val_acc: 0.6628\n", 303 | "Epoch 7/20\n", 304 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6204 - precision: 0.6558 - recall: 0.5892 - fbeta_score: 0.6332 - acc: 0.6668Epoch 00006: val_loss did not improve\n", 305 | "89600/89600 [==============================] - 51s - loss: 0.6204 - precision: 0.6558 - recall: 0.5893 - fbeta_score: 0.6333 - acc: 0.6669 - val_loss: 0.6237 - val_precision: 0.6176 - val_recall: 0.6639 - val_fbeta_score: 0.6217 - val_acc: 0.6624\n", 306 | "Epoch 8/20\n", 307 | "89504/89600 [============================>.] 
- ETA: 0s - loss: 0.6179 - precision: 0.6574 - recall: 0.5978 - fbeta_score: 0.6364 - acc: 0.6693Epoch 00007: val_loss did not improve\n", 308 | "89600/89600 [==============================] - 51s - loss: 0.6180 - precision: 0.6574 - recall: 0.5979 - fbeta_score: 0.6364 - acc: 0.6692 - val_loss: 0.6243 - val_precision: 0.6157 - val_recall: 0.6677 - val_fbeta_score: 0.6206 - val_acc: 0.6618\n", 309 | "Epoch 9/20\n", 310 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6151 - precision: 0.6565 - recall: 0.5990 - fbeta_score: 0.6364 - acc: 0.6702Epoch 00008: val_loss improved from 0.62252 to 0.61709, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 311 | "89600/89600 [==============================] - 51s - loss: 0.6152 - precision: 0.6564 - recall: 0.5988 - fbeta_score: 0.6362 - acc: 0.6701 - val_loss: 0.6171 - val_precision: 0.6748 - val_recall: 0.5154 - val_fbeta_score: 0.6274 - val_acc: 0.6702\n", 312 | "Epoch 10/20\n", 313 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6125 - precision: 0.6618 - recall: 0.6044 - fbeta_score: 0.6417 - acc: 0.6739Epoch 00009: val_loss improved from 0.61709 to 0.61645, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 314 | "89600/89600 [==============================] - 51s - loss: 0.6125 - precision: 0.6618 - recall: 0.6043 - fbeta_score: 0.6416 - acc: 0.6739 - val_loss: 0.6164 - val_precision: 0.6585 - val_recall: 0.5817 - val_fbeta_score: 0.6353 - val_acc: 0.6746\n", 315 | "Epoch 11/20\n", 316 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6102 - precision: 0.6636 - recall: 0.6092 - fbeta_score: 0.6443 - acc: 0.6762Epoch 00010: val_loss improved from 0.61645 to 0.61598, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 317 | "89600/89600 [==============================] - 52s - loss: 0.6102 - precision: 0.6636 - recall: 0.6092 - fbeta_score: 0.6443 - acc: 0.6762 - val_loss: 0.6160 - val_precision: 0.6263 - val_recall: 0.6632 - val_fbeta_score: 0.6286 - val_acc: 0.6686\n", 318 | "Epoch 12/20\n", 319 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6075 - precision: 0.6644 - recall: 0.6140 - fbeta_score: 0.6461 - acc: 0.6782Epoch 00011: val_loss improved from 0.61598 to 0.61276, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 320 | "89600/89600 [==============================] - 53s - loss: 0.6074 - precision: 0.6644 - recall: 0.6140 - fbeta_score: 0.6461 - acc: 0.6782 - val_loss: 0.6128 - val_precision: 0.6400 - val_recall: 0.6476 - val_fbeta_score: 0.6363 - val_acc: 0.6753\n", 321 | "Epoch 13/20\n", 322 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6052 - precision: 0.6678 - recall: 0.6175 - fbeta_score: 0.6498 - acc: 0.6812Epoch 00012: val_loss improved from 0.61276 to 0.60942, saving model to saved_models/bCNN.best.from_scratch.hdf5\n", 323 | "89600/89600 [==============================] - 53s - loss: 0.6052 - precision: 0.6678 - recall: 0.6174 - fbeta_score: 0.6498 - acc: 0.6812 - val_loss: 0.6094 - val_precision: 0.6545 - val_recall: 0.5914 - val_fbeta_score: 0.6348 - val_acc: 0.6748\n", 324 | "Epoch 14/20\n", 325 | "89504/89600 [============================>.] 
- ETA: 0s - loss: 0.6021 - precision: 0.6698 - recall: 0.6210 - fbeta_score: 0.6520 - acc: 0.6836Epoch 00013: val_loss did not improve\n", 326 | "89600/89600 [==============================] - 53s - loss: 0.6022 - precision: 0.6699 - recall: 0.6210 - fbeta_score: 0.6520 - acc: 0.6836 - val_loss: 0.6187 - val_precision: 0.6155 - val_recall: 0.6913 - val_fbeta_score: 0.6249 - val_acc: 0.6645\n", 327 | "Epoch 15/20\n", 328 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.5997 - precision: 0.6727 - recall: 0.6242 - fbeta_score: 0.6547 - acc: 0.6855Epoch 00014: val_loss did not improve\n", 329 | "89600/89600 [==============================] - 53s - loss: 0.5999 - precision: 0.6724 - recall: 0.6239 - fbeta_score: 0.6544 - acc: 0.6852 - val_loss: 0.6147 - val_precision: 0.6373 - val_recall: 0.6321 - val_fbeta_score: 0.6307 - val_acc: 0.6702\n", 330 | "Epoch 16/20\n", 331 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.5969 - precision: 0.6754 - recall: 0.6313 - fbeta_score: 0.6586 - acc: 0.6883Epoch 00015: val_loss did not improve\n", 332 | "89600/89600 [==============================] - 53s - loss: 0.5970 - precision: 0.6754 - recall: 0.6312 - fbeta_score: 0.6586 - acc: 0.6883 - val_loss: 0.6095 - val_precision: 0.6442 - val_recall: 0.6437 - val_fbeta_score: 0.6389 - val_acc: 0.6779\n", 333 | "Epoch 17/20\n", 334 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.5937 - precision: 0.6794 - recall: 0.6334 - fbeta_score: 0.6622 - acc: 0.6916Epoch 00016: val_loss did not improve\n", 335 | "89600/89600 [==============================] - 53s - loss: 0.5937 - precision: 0.6793 - recall: 0.6334 - fbeta_score: 0.6622 - acc: 0.6916 - val_loss: 0.6133 - val_precision: 0.6844 - val_recall: 0.5338 - val_fbeta_score: 0.6403 - val_acc: 0.6790\n", 336 | "Epoch 00016: early stopping\n", 337 | "training time: 14.89 minutes\n" 338 | ] 339 | } 340 | ], 341 | "source": [ 342 | "from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping\n", 343 | "import numpy as np\n", 344 | "\n", 345 | "epochs = 20\n", 346 | "batch_size = 32\n", 347 | "\n", 348 | "earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=1, mode='auto')\n", 349 | "log = CSVLogger('saved_models/log_bCNN_rgb.csv')\n", 350 | "checkpointer = ModelCheckpoint(filepath='saved_models/bCNN.best.from_scratch.hdf5', \n", 351 | " verbose=1, save_best_only=True)\n", 352 | "\n", 353 | "start = time.time()\n", 354 | "\n", 355 | "model.fit(train_tensors, train_labels, \n", 356 | " validation_data=(valid_tensors, valid_labels),\n", 357 | " epochs=epochs, batch_size=batch_size, callbacks=[checkpointer, log, earlystop], verbose=1)\n", 358 | "\n", 359 | "# Show total training time\n", 360 | "print(\"training time: %.2f minutes\"%((time.time()-start)/60))" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": {}, 366 | "source": [ 367 | "### Metric" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 15, 373 | "metadata": { 374 | "collapsed": true 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "model.load_weights('saved_models/bCNN.best.from_scratch.hdf5')\n", 379 | "prediction = model.predict(test_tensors)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 16, 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "name": "stdout", 389 | "output_type": "stream", 390 | "text": [ 391 | "Precision: 0.672176 %\n", 392 | "Recall: 0.594230 %\n", 393 | "Fscore: 0.654993 %\n" 394 | ] 395 | } 396 | ], 397 | 
"source": [ 398 | "threshold = 0.5\n", 399 | "beta = 0.5\n", 400 | "\n", 401 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 402 | " K.variable(value=prediction)))\n", 403 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 404 | " K.variable(value=prediction)))\n", 405 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 406 | " K.variable(value=prediction)))\n", 407 | "\n", 408 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 17, 414 | "metadata": {}, 415 | "outputs": [ 416 | { 417 | "data": { 418 | "text/plain": [ 419 | "0.67196751" 420 | ] 421 | }, 422 | "execution_count": 17, 423 | "metadata": {}, 424 | "output_type": "execute_result" 425 | } 426 | ], 427 | "source": [ 428 | "K.eval(binary_accuracy(K.variable(value=test_labels),\n", 429 | " K.variable(value=prediction)))" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 18, 435 | "metadata": { 436 | "scrolled": true 437 | }, 438 | "outputs": [ 439 | { 440 | "data": { 441 | "text/plain": [ 442 | "array([[ 0.58734381],\n", 443 | " [ 0.57867444],\n", 444 | " [ 0.25462931],\n", 445 | " [ 0.77830732],\n", 446 | " [ 0.25104493],\n", 447 | " [ 0.29603228],\n", 448 | " [ 0.32766595],\n", 449 | " [ 0.3778576 ],\n", 450 | " [ 0.70840192],\n", 451 | " [ 0.18522891],\n", 452 | " [ 0.55603856],\n", 453 | " [ 0.60000223],\n", 454 | " [ 0.73821551],\n", 455 | " [ 0.2869918 ],\n", 456 | " [ 0.22979702],\n", 457 | " [ 0.4054445 ],\n", 458 | " [ 0.32552701],\n", 459 | " [ 0.56464356],\n", 460 | " [ 0.55663085],\n", 461 | " [ 0.58321428],\n", 462 | " [ 0.49937385],\n", 463 | " [ 0.61920291],\n", 464 | " [ 0.76322109],\n", 465 | " [ 0.48952124],\n", 466 | " [ 0.51417869],\n", 467 | " [ 0.26597881],\n", 468 | " [ 0.33098736],\n", 469 | " [ 0.5749808 ],\n", 470 | " [ 0.56771249],\n", 471 | " [ 0.27707309]], dtype=float32)" 472 | ] 473 | }, 474 | "execution_count": 18, 475 | "metadata": {}, 476 | "output_type": "execute_result" 477 | } 478 | ], 479 | "source": [ 480 | "prediction[:30]" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 19, 486 | "metadata": {}, 487 | "outputs": [ 488 | { 489 | "name": "stdout", 490 | "output_type": "stream", 491 | "text": [ 492 | "Precision: 0.617976 %\n", 493 | "Recall: 0.735481 %\n", 494 | "Fscore: 0.638374 %\n" 495 | ] 496 | } 497 | ], 498 | "source": [ 499 | "threshold = 0.4\n", 500 | "beta = 0.5\n", 501 | "\n", 502 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 503 | " K.variable(value=prediction)))\n", 504 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 505 | " K.variable(value=prediction)))\n", 506 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 507 | " K.variable(value=prediction)))\n", 508 | "\n", 509 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": 20, 515 | "metadata": {}, 516 | "outputs": [ 517 | { 518 | "name": "stdout", 519 | "output_type": "stream", 520 | "text": [ 521 | "Precision: 0.712731 %\n", 522 | "Recall: 0.404833 %\n", 523 | "Fscore: 0.618630 %\n" 524 | ] 525 | } 526 | ], 527 | "source": [ 528 | "threshold = 0.6\n", 529 | "beta = 0.5\n", 530 | 
"\n", 531 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 532 | " K.variable(value=prediction)))\n", 533 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 534 | " K.variable(value=prediction)))\n", 535 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 536 | " K.variable(value=prediction)))\n", 537 | "\n", 538 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 539 | ] 540 | }, 541 | { 542 | "cell_type": "markdown", 543 | "metadata": {}, 544 | "source": [ 545 | "# With gray images" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 1, 551 | "metadata": { 552 | "collapsed": true 553 | }, 554 | "outputs": [], 555 | "source": [ 556 | "import numpy as np\n", 557 | "import pandas as pd\n", 558 | "from glob import glob\n", 559 | "from tqdm import tqdm\n", 560 | "from sklearn.utils import shuffle\n", 561 | "\n", 562 | "df = pd.read_csv('sample/Data_Entry_2017.csv')\n", 563 | "\n", 564 | "diseases = ['Cardiomegaly','Emphysema','Effusion','Hernia','Nodule','Pneumothorax','Atelectasis','Pleural_Thickening','Mass','Edema','Consolidation','Infiltration','Fibrosis','Pneumonia']\n", 565 | "#Number diseases\n", 566 | "for disease in diseases :\n", 567 | " df[disease] = df['Finding Labels'].apply(lambda x: 1 if disease in x else 0)\n", 568 | "\n", 569 | "# #test to perfect\n", 570 | "# df = df.drop(df[df['Emphysema']==0][:-127].index.values)\n", 571 | " \n", 572 | "#remove Y after age\n", 573 | "df['Age']=df['Patient Age'].apply(lambda x: x[:-1]).astype(int)\n", 574 | "df['Age Type']=df['Patient Age'].apply(lambda x: x[-1:])\n", 575 | "df.loc[df['Age Type']=='M',['Age']] = df[df['Age Type']=='M']['Age'].apply(lambda x: round(x/12.)).astype(int)\n", 576 | "df.loc[df['Age Type']=='D',['Age']] = df[df['Age Type']=='D']['Age'].apply(lambda x: round(x/365.)).astype(int)\n", 577 | "# remove outliers\n", 578 | "df = df.drop(df['Age'].sort_values(ascending=False).head(16).index)\n", 579 | "df['Age'] = df['Age']/df['Age'].max()\n", 580 | "\n", 581 | "#one hot data\n", 582 | "# df = df.drop(df.index[4242])\n", 583 | "df = df.join(pd.get_dummies(df['Patient Gender']))\n", 584 | "df = df.join(pd.get_dummies(df['View Position']))\n", 585 | "\n", 586 | "#random samples\n", 587 | "df = shuffle(df)\n", 588 | "\n", 589 | "#get other data\n", 590 | "data = df[['Age', 'F', 'M', 'AP', 'PA']]\n", 591 | "data = np.array(data)\n", 592 | "\n", 593 | "labels = df[diseases].as_matrix()\n", 594 | "files_list = ('sample/images/' + df['Image Index']).tolist()\n", 595 | "\n", 596 | "# #test to perfect\n", 597 | "# labelB = df['Emphysema'].tolist()\n", 598 | "\n", 599 | "labelB = (df[diseases].sum(axis=1)>0).tolist()\n", 600 | "labelB = np.array(labelB, dtype=int)" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 2, 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "name": "stderr", 610 | "output_type": "stream", 611 | "text": [ 612 | "Using TensorFlow backend.\n", 613 | "/home/aind2/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n", 614 | " return f(*args, **kwds)\n", 615 | "100%|██████████| 89600/89600 [19:43<00:00, 75.69it/s]\n", 616 | "100%|██████████| 11200/11200 [02:32<00:00, 73.55it/s]\n", 617 | "100%|██████████| 11319/11319 [02:33<00:00, 73.74it/s]\n" 618 | ] 619 
| } 620 | ], 621 | "source": [ 622 | "from keras.preprocessing import image \n", 623 | "from tqdm import tqdm\n", 624 | "\n", 625 | "def path_to_tensor(img_path, shape):\n", 626 | " # loads image as grayscale PIL.Image.Image type\n", 627 | " img = image.load_img(img_path, grayscale=True, target_size=shape)\n", 628 | " # convert PIL.Image.Image type to 3D tensor with shape (*shape, 1)\n", 629 | " x = image.img_to_array(img)/255\n", 630 | " # convert 3D tensor to 4D tensor with shape (1, *shape, 1) and return 4D tensor\n", 631 | " return np.expand_dims(x, axis=0)\n", 632 | "\n", 633 | "def paths_to_tensor(img_paths, shape):\n", 634 | " list_of_tensors = [path_to_tensor(img_path, shape) for img_path in tqdm(img_paths)]\n", 635 | " return np.vstack(list_of_tensors)\n", 636 | "\n", 637 | "train_labels = labelB[:89600][:, np.newaxis]\n", 638 | "valid_labels = labelB[89600:100800][:, np.newaxis]\n", 639 | "test_labels = labelB[100800:][:, np.newaxis]\n", 640 | "\n", 641 | "train_data = data[:89600]\n", 642 | "valid_data = data[89600:100800]\n", 643 | "test_data = data[100800:]\n", 644 | "\n", 645 | "img_shape = (64, 64)\n", 646 | "train_tensors = paths_to_tensor(files_list[:89600], shape = img_shape)\n", 647 | "valid_tensors = paths_to_tensor(files_list[89600:100800], shape = img_shape)\n", 648 | "test_tensors = paths_to_tensor(files_list[100800:], shape = img_shape)" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": 3, 654 | "metadata": {}, 655 | "outputs": [ 656 | { 657 | "name": "stdout", 658 | "output_type": "stream", 659 | "text": [ 660 | "_________________________________________________________________\n", 661 | "Layer (type) Output Shape Param # \n", 662 | "=================================================================\n", 663 | "conv2d_1 (Conv2D) (None, 64, 64, 16) 800 \n", 664 | "_________________________________________________________________\n", 665 | "max_pooling2d_1 (MaxPooling2 (None, 32, 32, 16) 0 \n", 666 | "_________________________________________________________________\n", 667 | "conv2d_2 (Conv2D) (None, 32, 32, 32) 12832 \n", 668 | "_________________________________________________________________\n", 669 | "max_pooling2d_2 (MaxPooling2 (None, 16, 16, 32) 0 \n", 670 | "_________________________________________________________________\n", 671 | "conv2d_3 (Conv2D) (None, 16, 16, 64) 51264 \n", 672 | "_________________________________________________________________\n", 673 | "max_pooling2d_3 (MaxPooling2 (None, 8, 8, 64) 0 \n", 674 | "_________________________________________________________________\n", 675 | "conv2d_4 (Conv2D) (None, 4, 4, 128) 204928 \n", 676 | "_________________________________________________________________\n", 677 | "max_pooling2d_4 (MaxPooling2 (None, 2, 2, 128) 0 \n", 678 | "_________________________________________________________________\n", 679 | "flatten_1 (Flatten) (None, 512) 0 \n", 680 | "_________________________________________________________________\n", 681 | "dense_1 (Dense) (None, 100) 51300 \n", 682 | "_________________________________________________________________\n", 683 | "dense_2 (Dense) (None, 1) 101 \n", 684 | "=================================================================\n", 685 | "Total params: 321,225\n", 686 | "Trainable params: 321,225\n", 687 | "Non-trainable params: 0\n", 688 | "_________________________________________________________________\n" 689 | ] 690 | } 691 | ], 692 | "source": [ 693 | "import time\n", 694 | "\n", 695 | "from keras.layers import Conv2D, MaxPooling2D,
GlobalAveragePooling2D, Dropout, Flatten, Dense\n", 696 | "from keras.models import Sequential\n", 697 | "from keras.layers.normalization import BatchNormalization\n", 698 | "from keras import regularizers, initializers, optimizers\n", 699 | "\n", 700 | "model = Sequential()\n", 701 | "\n", 702 | "model.add(Conv2D(filters=16, \n", 703 | " kernel_size=7,\n", 704 | " padding='same', \n", 705 | " activation='relu', \n", 706 | " input_shape=train_tensors.shape[1:]))\n", 707 | "model.add(MaxPooling2D(pool_size=2))\n", 708 | "\n", 709 | "model.add(Conv2D(filters=32, \n", 710 | " kernel_size=5,\n", 711 | " padding='same', \n", 712 | " activation='relu'))\n", 713 | "model.add(MaxPooling2D(pool_size=2))\n", 714 | "\n", 715 | "model.add(Conv2D(filters=64, \n", 716 | " kernel_size=5,\n", 717 | " padding='same', \n", 718 | " activation='relu'))\n", 719 | "model.add(MaxPooling2D(pool_size=2))\n", 720 | "\n", 721 | "model.add(Conv2D(filters=128, \n", 722 | " kernel_size=5,\n", 723 | " strides=2,\n", 724 | " padding='same', \n", 725 | " activation='relu'))\n", 726 | "model.add(MaxPooling2D(pool_size=2))\n", 727 | "\n", 728 | "model.add(Flatten())\n", 729 | "model.add(Dense(100, activation='relu'))\n", 730 | "model.add(Dense(1, activation='sigmoid'))\n", 731 | "\n", 732 | "model.summary()" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": 4, 738 | "metadata": { 739 | "collapsed": true 740 | }, 741 | "outputs": [], 742 | "source": [ 743 | "from keras import backend as K\n", 744 | "\n", 745 | "def binary_accuracy(y_true, y_pred):\n", 746 | " return K.mean(K.equal(y_true, K.round(y_pred)))\n", 747 | "\n", 748 | "def precision_threshold(threshold = 0.5):\n", 749 | " def precision(y_true, y_pred):\n", 750 | " threshold_value = threshold\n", 751 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 752 | " true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 753 | " predicted_positives = K.sum(y_pred)\n", 754 | " precision_ratio = true_positives / (predicted_positives + K.epsilon())\n", 755 | " return precision_ratio\n", 756 | " return precision\n", 757 | "\n", 758 | "def recall_threshold(threshold = 0.5):\n", 759 | " def recall(y_true, y_pred):\n", 760 | " threshold_value = threshold\n", 761 | " y_pred = K.cast(K.greater(K.clip(y_pred, 0, 1), threshold_value), K.floatx())\n", 762 | " true_positives = K.round(K.sum(K.clip(y_true * y_pred, 0, 1)))\n", 763 | " possible_positives = K.sum(K.clip(y_true, 0, 1))\n", 764 | " recall_ratio = true_positives / (possible_positives + K.epsilon())\n", 765 | " return recall_ratio\n", 766 | " return recall\n", 767 | "\n", 768 | "def fbeta_score_threshold(beta = 1, threshold = 0.5):\n", 769 | " def fbeta_score(y_true, y_pred):\n", 770 | " threshold_value = threshold\n", 771 | " beta_value = beta\n", 772 | " p = precision_threshold(threshold_value)(y_true, y_pred)\n", 773 | " r = recall_threshold(threshold_value)(y_true, y_pred)\n", 774 | " bb = beta_value ** 2\n", 775 | " fbeta_score = (1 + bb) * (p * r) / (bb * p + r + K.epsilon())\n", 776 | " return fbeta_score\n", 777 | " return fbeta_score" 778 | ] 779 | }, 780 | { 781 | "cell_type": "code", 782 | "execution_count": 5, 783 | "metadata": { 784 | "collapsed": true 785 | }, 786 | "outputs": [], 787 | "source": [ 788 | "model.compile(optimizer='sgd', loss='binary_crossentropy', \n", 789 | " metrics=[precision_threshold(threshold = 0.5), \n", 790 | " recall_threshold(threshold = 0.5), \n", 791 | " fbeta_score_threshold(beta=0.5, threshold = 0.5),\n", 792 | 
" 'accuracy'])" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": 6, 798 | "metadata": {}, 799 | "outputs": [ 800 | { 801 | "name": "stdout", 802 | "output_type": "stream", 803 | "text": [ 804 | "Train on 89600 samples, validate on 11200 samples\n", 805 | "Epoch 1/20\n", 806 | "89472/89600 [============================>.] - ETA: 0s - loss: 0.6632 - precision: 0.5035 - recall: 0.3776 - fbeta_score: 0.4438 - acc: 0.6012Epoch 00000: val_loss improved from inf to 0.64742, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 807 | "89600/89600 [==============================] - 51s - loss: 0.6632 - precision: 0.5035 - recall: 0.3778 - fbeta_score: 0.4439 - acc: 0.6012 - val_loss: 0.6474 - val_precision: 0.6516 - val_recall: 0.3998 - val_fbeta_score: 0.5670 - val_acc: 0.6283\n", 808 | "Epoch 2/20\n", 809 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.6456 - precision: 0.6291 - recall: 0.5252 - fbeta_score: 0.5922 - acc: 0.6347Epoch 00001: val_loss improved from 0.64742 to 0.63779, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 810 | "89600/89600 [==============================] - 48s - loss: 0.6456 - precision: 0.6291 - recall: 0.5252 - fbeta_score: 0.5922 - acc: 0.6346 - val_loss: 0.6378 - val_precision: 0.6263 - val_recall: 0.5591 - val_fbeta_score: 0.6054 - val_acc: 0.6468\n", 811 | "Epoch 3/20\n", 812 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.6391 - precision: 0.6361 - recall: 0.5440 - fbeta_score: 0.6047 - acc: 0.6444Epoch 00002: val_loss did not improve\n", 813 | "89600/89600 [==============================] - 48s - loss: 0.6391 - precision: 0.6361 - recall: 0.5440 - fbeta_score: 0.6048 - acc: 0.6444 - val_loss: 0.6387 - val_precision: 0.6762 - val_recall: 0.4284 - val_fbeta_score: 0.5951 - val_acc: 0.6449\n", 814 | "Epoch 4/20\n", 815 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6341 - precision: 0.6433 - recall: 0.5593 - fbeta_score: 0.6146 - acc: 0.6520Epoch 00003: val_loss improved from 0.63779 to 0.62709, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 816 | "89600/89600 [==============================] - 48s - loss: 0.6341 - precision: 0.6434 - recall: 0.5594 - fbeta_score: 0.6147 - acc: 0.6520 - val_loss: 0.6271 - val_precision: 0.6313 - val_recall: 0.6086 - val_fbeta_score: 0.6215 - val_acc: 0.6601\n", 817 | "Epoch 5/20\n", 818 | "89472/89600 [============================>.] - ETA: 0s - loss: 0.6300 - precision: 0.6482 - recall: 0.5680 - fbeta_score: 0.6209 - acc: 0.6565Epoch 00004: val_loss improved from 0.62709 to 0.62225, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 819 | "89600/89600 [==============================] - 47s - loss: 0.6299 - precision: 0.6484 - recall: 0.5681 - fbeta_score: 0.6211 - acc: 0.6566 - val_loss: 0.6222 - val_precision: 0.6434 - val_recall: 0.5803 - val_fbeta_score: 0.6238 - val_acc: 0.6628\n", 820 | "Epoch 6/20\n", 821 | "89472/89600 [============================>.] 
- ETA: 0s - loss: 0.6266 - precision: 0.6516 - recall: 0.5729 - fbeta_score: 0.6244 - acc: 0.6599Epoch 00005: val_loss improved from 0.62225 to 0.61933, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 822 | "89600/89600 [==============================] - 47s - loss: 0.6266 - precision: 0.6516 - recall: 0.5730 - fbeta_score: 0.6244 - acc: 0.6599 - val_loss: 0.6193 - val_precision: 0.6418 - val_recall: 0.5943 - val_fbeta_score: 0.6267 - val_acc: 0.6646\n", 823 | "Epoch 7/20\n", 824 | "89536/89600 [============================>.] - ETA: 0s - loss: 0.6234 - precision: 0.6541 - recall: 0.5796 - fbeta_score: 0.6289 - acc: 0.6633Epoch 00006: val_loss did not improve\n", 825 | "89600/89600 [==============================] - 47s - loss: 0.6234 - precision: 0.6539 - recall: 0.5796 - fbeta_score: 0.6287 - acc: 0.6633 - val_loss: 0.6254 - val_precision: 0.6851 - val_recall: 0.4468 - val_fbeta_score: 0.6094 - val_acc: 0.6534\n", 826 | "Epoch 8/20\n", 827 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.6206 - precision: 0.6558 - recall: 0.5843 - fbeta_score: 0.6316 - acc: 0.6660Epoch 00007: val_loss did not improve\n", 828 | "89600/89600 [==============================] - 47s - loss: 0.6206 - precision: 0.6558 - recall: 0.5843 - fbeta_score: 0.6316 - acc: 0.6660 - val_loss: 0.6205 - val_precision: 0.6817 - val_recall: 0.4982 - val_fbeta_score: 0.6261 - val_acc: 0.6633\n", 829 | "Epoch 9/20\n", 830 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.6182 - precision: 0.6580 - recall: 0.5909 - fbeta_score: 0.6348 - acc: 0.6679Epoch 00008: val_loss improved from 0.61933 to 0.61392, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 831 | "89600/89600 [==============================] - 47s - loss: 0.6182 - precision: 0.6580 - recall: 0.5909 - fbeta_score: 0.6348 - acc: 0.6679 - val_loss: 0.6139 - val_precision: 0.6392 - val_recall: 0.6666 - val_fbeta_score: 0.6401 - val_acc: 0.6762\n", 832 | "Epoch 10/20\n", 833 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.6158 - precision: 0.6609 - recall: 0.5961 - fbeta_score: 0.6388 - acc: 0.6712Epoch 00009: val_loss improved from 0.61392 to 0.61250, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 834 | "89600/89600 [==============================] - 47s - loss: 0.6158 - precision: 0.6608 - recall: 0.5961 - fbeta_score: 0.6387 - acc: 0.6711 - val_loss: 0.6125 - val_precision: 0.6469 - val_recall: 0.6398 - val_fbeta_score: 0.6406 - val_acc: 0.6764\n", 835 | "Epoch 11/20\n", 836 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.6134 - precision: 0.6610 - recall: 0.5980 - fbeta_score: 0.6392 - acc: 0.6727Epoch 00010: val_loss did not improve\n", 837 | "89600/89600 [==============================] - 47s - loss: 0.6134 - precision: 0.6610 - recall: 0.5980 - fbeta_score: 0.6392 - acc: 0.6727 - val_loss: 0.6131 - val_precision: 0.6434 - val_recall: 0.6591 - val_fbeta_score: 0.6417 - val_acc: 0.6771\n", 838 | "Epoch 12/20\n", 839 | "89568/89600 [============================>.] 
- ETA: 0s - loss: 0.6111 - precision: 0.6655 - recall: 0.6038 - fbeta_score: 0.6440 - acc: 0.6755Epoch 00011: val_loss improved from 0.61250 to 0.60702, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 840 | "89600/89600 [==============================] - 47s - loss: 0.6111 - precision: 0.6655 - recall: 0.6038 - fbeta_score: 0.6440 - acc: 0.6755 - val_loss: 0.6070 - val_precision: 0.6684 - val_recall: 0.5950 - val_fbeta_score: 0.6462 - val_acc: 0.6793\n", 841 | "Epoch 13/20\n", 842 | "89536/89600 [============================>.] - ETA: 0s - loss: 0.6091 - precision: 0.6655 - recall: 0.6088 - fbeta_score: 0.6460 - acc: 0.6770Epoch 00012: val_loss improved from 0.60702 to 0.60598, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 843 | "89600/89600 [==============================] - 47s - loss: 0.6091 - precision: 0.6655 - recall: 0.6087 - fbeta_score: 0.6459 - acc: 0.6769 - val_loss: 0.6060 - val_precision: 0.6706 - val_recall: 0.5808 - val_fbeta_score: 0.6443 - val_acc: 0.6776\n", 844 | "Epoch 14/20\n", 845 | "89504/89600 [============================>.] - ETA: 0s - loss: 0.6066 - precision: 0.6661 - recall: 0.6114 - fbeta_score: 0.6467 - acc: 0.6792Epoch 00013: val_loss did not improve\n", 846 | "89600/89600 [==============================] - 47s - loss: 0.6067 - precision: 0.6661 - recall: 0.6114 - fbeta_score: 0.6467 - acc: 0.6792 - val_loss: 0.6065 - val_precision: 0.6767 - val_recall: 0.5648 - val_fbeta_score: 0.6445 - val_acc: 0.6774\n", 847 | "Epoch 15/20\n", 848 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.6043 - precision: 0.6677 - recall: 0.6155 - fbeta_score: 0.6490 - acc: 0.6804Epoch 00014: val_loss did not improve\n", 849 | "89600/89600 [==============================] - 47s - loss: 0.6043 - precision: 0.6677 - recall: 0.6155 - fbeta_score: 0.6490 - acc: 0.6803 - val_loss: 0.6122 - val_precision: 0.6375 - val_recall: 0.6991 - val_fbeta_score: 0.6449 - val_acc: 0.6806\n", 850 | "Epoch 16/20\n", 851 | "89472/89600 [============================>.] - ETA: 0s - loss: 0.6018 - precision: 0.6698 - recall: 0.6196 - fbeta_score: 0.6518 - acc: 0.6833Epoch 00015: val_loss improved from 0.60598 to 0.60483, saving model to saved_models/bCNN_gray.best.from_scratch.hdf5\n", 852 | "89600/89600 [==============================] - 47s - loss: 0.6017 - precision: 0.6698 - recall: 0.6196 - fbeta_score: 0.6518 - acc: 0.6833 - val_loss: 0.6048 - val_precision: 0.6760 - val_recall: 0.5804 - val_fbeta_score: 0.6481 - val_acc: 0.6796\n", 853 | "Epoch 17/20\n", 854 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.5996 - precision: 0.6730 - recall: 0.6234 - fbeta_score: 0.6555 - acc: 0.6861Epoch 00016: val_loss did not improve\n", 855 | "89600/89600 [==============================] - 48s - loss: 0.5997 - precision: 0.6730 - recall: 0.6234 - fbeta_score: 0.6555 - acc: 0.6861 - val_loss: 0.6108 - val_precision: 0.6958 - val_recall: 0.5157 - val_fbeta_score: 0.6425 - val_acc: 0.6755\n", 856 | "Epoch 18/20\n", 857 | "89536/89600 [============================>.] - ETA: 0s - loss: 0.5973 - precision: 0.6737 - recall: 0.6247 - fbeta_score: 0.6562 - acc: 0.6867Epoch 00017: val_loss did not improve\n", 858 | "89600/89600 [==============================] - 50s - loss: 0.5972 - precision: 0.6737 - recall: 0.6248 - fbeta_score: 0.6563 - acc: 0.6868 - val_loss: 0.6143 - val_precision: 0.6934 - val_recall: 0.5248 - val_fbeta_score: 0.6437 - val_acc: 0.6756\n", 859 | "Epoch 19/20\n", 860 | "89568/89600 [============================>.] 
- ETA: 0s - loss: 0.5948 - precision: 0.6778 - recall: 0.6327 - fbeta_score: 0.6612 - acc: 0.6910Epoch 00018: val_loss did not improve\n", 861 | "89600/89600 [==============================] - 49s - loss: 0.5948 - precision: 0.6779 - recall: 0.6327 - fbeta_score: 0.6613 - acc: 0.6910 - val_loss: 0.6085 - val_precision: 0.6664 - val_recall: 0.6087 - val_fbeta_score: 0.6480 - val_acc: 0.6796\n", 862 | "Epoch 20/20\n", 863 | "89568/89600 [============================>.] - ETA: 0s - loss: 0.5929 - precision: 0.6769 - recall: 0.6336 - fbeta_score: 0.6603 - acc: 0.6906Epoch 00019: val_loss did not improve\n", 864 | "89600/89600 [==============================] - 49s - loss: 0.5929 - precision: 0.6769 - recall: 0.6336 - fbeta_score: 0.6603 - acc: 0.6906 - val_loss: 0.6067 - val_precision: 0.6810 - val_recall: 0.5719 - val_fbeta_score: 0.6497 - val_acc: 0.6806\n", 865 | "Epoch 00019: early stopping\n", 866 | "training time: 16.09 minutes\n" 867 | ] 868 | } 869 | ], 870 | "source": [ 871 | "from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping\n", 872 | "import numpy as np\n", 873 | "\n", 874 | "epochs = 20\n", 875 | "batch_size = 32\n", 876 | "\n", 877 | "earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=1, mode='auto')\n", 878 | "log = CSVLogger('saved_models/log_bCNN_gray.csv')\n", 879 | "checkpointer = ModelCheckpoint(filepath='saved_models/bCNN_gray.best.from_scratch.hdf5', \n", 880 | " verbose=1, save_best_only=True)\n", 881 | "\n", 882 | "start = time.time()\n", 883 | "\n", 884 | "model.fit(train_tensors, train_labels, \n", 885 | " validation_data=(valid_tensors, valid_labels),\n", 886 | " epochs=epochs, batch_size=batch_size, callbacks=[checkpointer, log, earlystop], verbose=1)\n", 887 | "\n", 888 | "# Show total training time\n", 889 | "print(\"training time: %.2f minutes\"%((time.time()-start)/60))" 890 | ] 891 | }, 892 | { 893 | "cell_type": "code", 894 | "execution_count": 7, 895 | "metadata": { 896 | "collapsed": true 897 | }, 898 | "outputs": [], 899 | "source": [ 900 | "model.load_weights('saved_models/bCNN_gray.best.from_scratch.hdf5')\n", 901 | "prediction = model.predict(test_tensors)" 902 | ] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "execution_count": 8, 907 | "metadata": {}, 908 | "outputs": [ 909 | { 910 | "name": "stdout", 911 | "output_type": "stream", 912 | "text": [ 913 | "Precision: 0.671851 %\n", 914 | "Recall: 0.572077 %\n", 915 | "Fscore: 0.649206 %\n" 916 | ] 917 | } 918 | ], 919 | "source": [ 920 | "threshold = 0.5\n", 921 | "beta = 0.5\n", 922 | "\n", 923 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 924 | " K.variable(value=prediction)))\n", 925 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 926 | " K.variable(value=prediction)))\n", 927 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 928 | " K.variable(value=prediction)))\n", 929 | "\n", 930 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 931 | ] 932 | }, 933 | { 934 | "cell_type": "code", 935 | "execution_count": 10, 936 | "metadata": {}, 937 | "outputs": [ 938 | { 939 | "data": { 940 | "text/plain": [ 941 | "0.666622" 942 | ] 943 | }, 944 | "execution_count": 10, 945 | "metadata": {}, 946 | "output_type": "execute_result" 947 | } 948 | ], 949 | "source": [ 950 | "K.eval(binary_accuracy(K.variable(value=test_labels),\n", 951 | " K.variable(value=prediction)))" 952 | ] 953 | }, 
954 | { 955 | "cell_type": "code", 956 | "execution_count": 11, 957 | "metadata": {}, 958 | "outputs": [ 959 | { 960 | "name": "stdout", 961 | "output_type": "stream", 962 | "text": [ 963 | "Precision: 0.627903 %\n", 964 | "Recall: 0.710935 %\n", 965 | "Fscore: 0.642921 %\n" 966 | ] 967 | } 968 | ], 969 | "source": [ 970 | "threshold = 0.4\n", 971 | "beta = 0.5\n", 972 | "\n", 973 | "pre = K.eval(precision_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 974 | " K.variable(value=prediction)))\n", 975 | "rec = K.eval(recall_threshold(threshold = threshold)(K.variable(value=test_labels),\n", 976 | " K.variable(value=prediction)))\n", 977 | "fsc = K.eval(fbeta_score_threshold(beta = beta, threshold = threshold)(K.variable(value=test_labels),\n", 978 | " K.variable(value=prediction)))\n", 979 | "\n", 980 | "print (\"Precision: %f %%\\nRecall: %f %%\\nFscore: %f %%\"% (pre, rec, fsc))" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": null, 986 | "metadata": { 987 | "collapsed": true 988 | }, 989 | "outputs": [], 990 | "source": [] 991 | } 992 | ], 993 | "metadata": { 994 | "kernelspec": { 995 | "display_name": "Python 3", 996 | "language": "python", 997 | "name": "python3" 998 | }, 999 | "language_info": { 1000 | "codemirror_mode": { 1001 | "name": "ipython", 1002 | "version": 3 1003 | }, 1004 | "file_extension": ".py", 1005 | "mimetype": "text/x-python", 1006 | "name": "python", 1007 | "nbconvert_exporter": "python", 1008 | "pygments_lexer": "ipython3", 1009 | "version": "3.5.2" 1010 | } 1011 | }, 1012 | "nbformat": 4, 1013 | "nbformat_minor": 2 1014 | } 1015 | --------------------------------------------------------------------------------
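Both FullDataset runs collapse the 14 disease indicator columns into a single binary "any finding" target (labelB), then split the 112,119 usable rows into 89,600/11,200/11,319 train/validation/test examples (roughly 80/10/10). The sketch below isolates that labeling step; it assumes the NIH Data_Entry_2017.csv layout in which 'Finding Labels' is a pipe-separated string, and it swaps the long-deprecated as_matrix() for to_numpy(). Everything else uses the notebooks' own names.

import numpy as np
import pandas as pd

df = pd.read_csv('sample/Data_Entry_2017.csv')
diseases = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Nodule',
            'Pneumothorax', 'Atelectasis', 'Pleural_Thickening', 'Mass',
            'Edema', 'Consolidation', 'Infiltration', 'Fibrosis', 'Pneumonia']

# One 0/1 indicator column per disease, parsed from the label string.
for disease in diseases:
    df[disease] = df['Finding Labels'].apply(lambda x, d=disease: int(d in x))

# Multi-label matrix (n_samples x 14) and the binary "any finding" target,
# shaped (n_samples, 1) to match the sigmoid output of the network.
labels = df[diseases].to_numpy()
labelB = (df[diseases].sum(axis=1) > 0).to_numpy().astype(int)[:, np.newaxis]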
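The RGB and grayscale models are architecturally identical except for the first convolution's input channels, and that single layer accounts for the entire gap between the two summaries (322,793 vs. 321,225 parameters). A quick check of the Conv2D parameter-count formula, weights plus biases:

def conv2d_params(kernel, in_channels, filters):
    # Conv2D trainable parameters: kernel_h * kernel_w * in_channels * filters + biases
    return kernel * kernel * in_channels * filters + filters

assert conv2d_params(7, 3, 16) == 2368   # first conv, RGB model
assert conv2d_params(7, 1, 16) == 800    # first conv, grayscale model
assert 2368 - 800 == 322793 - 321225     # the whole difference between the models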
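The custom metrics binarize the sigmoid output at a fixed threshold before computing precision P, recall R, and the F-beta score F_beta = (1 + beta^2) * P * R / (beta^2 * P + R); with beta = 0.5 the score weights precision more heavily than recall. Note that although the notebooks print these values with a '%' suffix, they are fractions in [0, 1], not percentages. Here is a minimal NumPy sketch of the same computation, useful for checking results outside a Keras session (the function name is illustrative, not from the notebooks):

import numpy as np

def thresholded_scores(y_true, y_prob, threshold=0.5, beta=0.5):
    # Binarize probabilities at the threshold, as the Keras-backend metrics do,
    # and guard zero denominators with a small epsilon.
    eps = 1e-7
    y_pred = (y_prob > threshold).astype(float)
    tp = float(np.sum(y_true * y_pred))
    precision = tp / (np.sum(y_pred) + eps)
    recall = tp / (np.sum(y_true) + eps)
    bb = beta ** 2
    fbeta = (1 + bb) * precision * recall / (bb * precision + recall + eps)
    return precision, recall, fbeta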
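Each notebook then repeats the same three K.eval calls once per threshold (0.4 and 0.5 everywhere, plus 0.6 for the RGB model). With a helper like the one above, a loop over thresholds shows the trade-off at a glance; this sketch assumes the notebooks' prediction and test_labels arrays are already in memory:

for threshold in (0.4, 0.5, 0.6):
    p, r, f = thresholded_scores(test_labels.ravel(), prediction.ravel(),
                                 threshold=threshold, beta=0.5)
    print('threshold %.1f -> precision %.4f, recall %.4f, F0.5 %.4f'
          % (threshold, p, r, f))

On the full RGB test set, lowering the threshold from 0.5 to 0.4 trades precision for recall (0.672/0.594 to 0.618/0.735), while raising it to 0.6 does the opposite (0.713/0.405); which operating point is preferable depends on whether missed findings or false alarms are costlier.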