├── ADNI_subject_id ├── AD.csv ├── MCI.csv └── Normal.csv ├── README.md ├── __init__.py ├── conv3d2d.py ├── convnet_3d.py ├── dlt_utils.py ├── main.py └── maxpool3d.py /ADNI_subject_id/AD.csv: -------------------------------------------------------------------------------- 1 | Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,Format 2 | 002_S_0619,Patient,M,78,1,MRI,MPR-R; GradWarp; N3,Processed,6/01/2006,NiFTI 3 | 002_S_0816,Patient,M,71,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,8/30/2006,NiFTI 4 | 002_S_0938,Patient,F,83,3,MRI,MPR; GradWarp; B1 Correction; N3,Processed,4/12/2007,NiFTI 5 | 002_S_0955,Patient,F,78,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/11/2006,NiFTI 6 | 002_S_1018,Patient,F,71,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,11/29/2006,NiFTI 7 | 002_S_5018,Patient,M,73,22,MRI,MT1; N3m,Processed,11/08/2012,NiFTI 8 | 003_S_1059,Patient,F,85,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,11/09/2006,NiFTI 9 | 003_S_1257,Patient,M,85,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/01/2007,NiFTI 10 | 003_S_4136,Patient,M,67,22,MRI,MT1; GradWarp; N3m,Processed,8/10/2011,NiFTI 11 | 003_S_4142,Patient,F,90,22,MRI,MT1; GradWarp; N3m,Processed,8/31/2011,NiFTI 12 | 003_S_4152,Patient,M,61,22,MRI,MT1; GradWarp; N3m,Processed,8/30/2011,NiFTI 13 | 003_S_4373,Patient,F,71,22,MRI,MT1; GradWarp; N3m,Processed,12/15/2011,NiFTI 14 | 003_S_4892,Patient,F,75,22,MRI,MT1; GradWarp; N3m,Processed,8/23/2012,NiFTI 15 | 003_S_5165,Patient,M,79,22,MRI,MT1; GradWarp; N3m,Processed,5/16/2013,NiFTI 16 | 003_S_5187,Patient,F,62,22,MRI,MT1; GradWarp; N3m,Processed,6/07/2013,NiFTI 17 | 005_S_0221,Patient,M,68,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,2/22/2006,NiFTI 18 | 005_S_0814,Patient,F,71,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,8/30/2006,NiFTI 19 | 005_S_0929,Patient,M,82,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,10/02/2006,NiFTI 20 | 005_S_1341,Patient,F,72,1,MRI,MPR; GradWarp; B1 Correction; 
N3,Processed,3/07/2007,NiFTI 21 | 005_S_4707,Patient,M,68,22,MRI,MT1; GradWarp; N3m,Processed,5/15/2012,NiFTI 22 | 005_S_4910,Patient,F,82,22,MRI,MT1; GradWarp; N3m,Processed,9/21/2012,NiFTI 23 | 005_S_5038,Patient,M,82,22,MRI,MT1; GradWarp; N3m,Processed,12/13/2012,NiFTI 24 | 005_S_5119,Patient,F,77,22,MRI,MT1; N3m,Processed,3/28/2013,NiFTI 25 | 006_S_0547,Patient,M,76,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,6/29/2006,NiFTI 26 | 006_S_0653,Patient,F,74,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,7/05/2006,NiFTI 27 | 006_S_4153,Patient,M,79,22,MRI,MT1; N3m,Processed,8/03/2011,NiFTI 28 | 006_S_4192,Patient,M,82,22,MRI,MT1; N3m,Processed,9/27/2011,NiFTI 29 | 006_S_4546,Patient,M,71,22,MRI,MT1; N3m,Processed,3/05/2012,NiFTI 30 | 006_S_4867,Patient,M,75,22,MRI,MT1; N3m,Processed,8/07/2012,NiFTI 31 | 007_S_0316,Patient,M,81,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/29/2006,NiFTI 32 | 007_S_1248,Patient,F,80,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,1/25/2007,NiFTI 33 | 007_S_1304,Patient,F,75,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,2/13/2007,NiFTI 34 | 007_S_1339,Patient,F,80,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,2/22/2007,NiFTI 35 | 007_S_4568,Patient,F,71,22,MRI,MT1; GradWarp; N3m,Processed,3/01/2012,NiFTI 36 | 007_S_4911,Patient,M,75,22,MRI,MT1; GradWarp; N3m,Processed,8/30/2012,NiFTI 37 | 007_S_5196,Patient,F,73,22,MRI,MT1; GradWarp; N3m,Processed,6/03/2013,NiFTI 38 | 009_S_1334,Patient,M,64,1,MRI,MPR-R; GradWarp; N3,Processed,3/01/2007,NiFTI 39 | 009_S_1354,Patient,F,59,1,MRI,MPR; GradWarp; N3,Processed,3/06/2007,NiFTI 40 | 009_S_5027,Patient,M,76,22,MRI,MT1; GradWarp; N3m,Processed,11/29/2012,NiFTI 41 | 009_S_5037,Patient,M,67,22,MRI,MT1; GradWarp; N3m,Processed,1/08/2013,NiFTI 42 | 009_S_5224,Patient,M,78,22,MRI,MT1; GradWarp; N3m,Processed,7/11/2013,NiFTI 43 | 009_S_5252,Patient,M,57,22,MRI,MT1; GradWarp; N3m,Processed,7/18/2013,NiFTI 44 | 010_S_0786,Patient,M,75,1,MRI,MPR-R; ; N3,Processed,9/27/2006,NiFTI 45 | 
010_S_0829,Patient,F,65,1,MRI,MPR; ; N3,Processed,2/08/2007,NiFTI 46 | 010_S_5163,Patient,M,67,22,MRI,MT1; N3m,Processed,5/29/2013,NiFTI 47 | 011_S_0003,Patient,M,81,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,9/01/2005,NiFTI 48 | 011_S_0010,Patient,F,74,3,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,5/09/2006,NiFTI 49 | 011_S_0053,Patient,M,80,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,11/14/2005,NiFTI 50 | 011_S_0183,Patient,F,72,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,3/03/2006,NiFTI 51 | 011_S_4827,Patient,M,71,22,MRI,MT1; GradWarp; N3m,Processed,7/02/2012,NiFTI 52 | 011_S_4845,Patient,F,68,22,MRI,MT1; GradWarp; N3m,Processed,7/12/2012,NiFTI 53 | 011_S_4906,Patient,F,76,22,MRI,MT1; GradWarp; N3m,Processed,8/13/2012,NiFTI 54 | 011_S_4912,Patient,F,69,22,MRI,MT1; GradWarp; N3m,Processed,8/22/2012,NiFTI 55 | 011_S_4949,Patient,F,78,22,MRI,MT1; GradWarp; N3m,Processed,9/19/2012,NiFTI 56 | 012_S_0689,Patient,M,64,1,MRI,MPR-R; ; N3,Processed,7/05/2006,NiFTI 57 | 012_S_0712,Patient,M,77,1,MRI,MPR; ; N3,Processed,6/29/2006,NiFTI 58 | 012_S_0720,Patient,F,78,1,MRI,MPR; ; N3,Processed,8/09/2006,NiFTI 59 | 012_S_0803,Patient,F,85,1,MRI,MPR; ; N3,Processed,9/22/2006,NiFTI 60 | 013_S_0592,Patient,M,78,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,8/17/2006,NiFTI 61 | 013_S_0699,Patient,M,82,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,8/28/2006,NiFTI 62 | 013_S_0996,Patient,F,91,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,11/06/2006,NiFTI 63 | 013_S_1161,Patient,M,80,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,12/20/2006,NiFTI 64 | 013_S_1205,Patient,M,83,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,1/11/2007,NiFTI 65 | 013_S_5071,Patient,M,76,22,MRI,MT1; N3m,Processed,2/12/2013,NiFTI 66 | 014_S_0328,Patient,M,77,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/22/2006,NiFTI 67 | 014_S_0356,Patient,M,80,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,4/05/2006,NiFTI 68 | 014_S_0357,Patient,F,71,1,MRI,MPR-R; 
GradWarp; B1 Correction; N3,Processed,4/04/2006,NiFTI 69 | 014_S_1095,Patient,F,80,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,12/03/2006,NiFTI 70 | 014_S_4039,Patient,M,56,22,MRI,MT1; GradWarp; N3m,Processed,6/02/2011,NiFTI 71 | 014_S_4615,Patient,M,87,22,MRI,MT1; GradWarp; N3m,Processed,3/29/2012,NiFTI -------------------------------------------------------------------------------- /ADNI_subject_id/MCI.csv: -------------------------------------------------------------------------------- 1 | Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,Format 2 | 002_S_0782,Patient,M,82,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,8/14/2006,NiFTI 3 | 002_S_0954,Patient,F,69,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/10/2006,NiFTI 4 | 002_S_1070,Patient,M,74,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,11/28/2006,NiFTI 5 | 005_S_0222,Patient,M,86,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,2/21/2006,NiFTI 6 | 005_S_0324,Patient,F,75,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,3/30/2006,NiFTI 7 | 005_S_1224,Patient,M,81,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,1/23/2007,NiFTI 8 | 006_S_0322,Patient,M,66,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,6/20/2006,NiFTI 9 | 006_S_0521,Patient,F,71,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,6/27/2006,NiFTI 10 | 006_S_0675,Patient,F,79,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,8/31/2006,NiFTI 11 | 007_S_0041,Patient,F,71,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/21/2005,NiFTI 12 | 007_S_0128,Patient,F,64,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,1/16/2006,NiFTI 13 | 007_S_0249,Patient,F,72,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/02/2006,NiFTI 14 | 007_S_0293,Patient,M,88,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/14/2006,NiFTI 15 | 007_S_0344,Patient,M,79,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/31/2006,NiFTI 16 | 007_S_0414,Patient,F,80,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,5/22/2006,NiFTI 17 | 
009_S_1199,Patient,M,58,1,MRI,MPR; GradWarp; N3,Processed,1/16/2007,NiFTI 18 | 010_S_0422,Patient,M,62,1,MRI,MPR; ; N3,Processed,6/15/2006,NiFTI 19 | 010_S_0662,Patient,M,79,1,MRI,MPR; ; N3,Processed,7/12/2006,NiFTI 20 | 010_S_0788,Patient,F,62,1,MRI,MPR; ; N3,Processed,8/28/2006,NiFTI 21 | 010_S_0904,Patient,M,84,1,MRI,MPR; ; N3,Processed,12/07/2006,NiFTI 22 | 011_S_0168,Patient,M,89,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,2/10/2006,NiFTI 23 | 011_S_0241,Patient,M,82,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/10/2006,NiFTI 24 | 011_S_0326,Patient,M,77,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/20/2006,NiFTI 25 | 011_S_0362,Patient,F,71,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/28/2006,NiFTI 26 | 011_S_0856,Patient,M,60,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,9/15/2006,NiFTI 27 | 011_S_0861,Patient,M,87,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,9/27/2006,NiFTI 28 | 011_S_1282,Patient,F,77,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,2/09/2007,NiFTI 29 | 012_S_0634,Patient,M,82,1,MRI,MPR; ; N3,Processed,6/16/2006,NiFTI 30 | 012_S_0917,Patient,M,70,1,MRI,MPR; ; N3,Processed,12/01/2006,NiFTI 31 | 012_S_0932,Patient,M,75,1,MRI,MPR; ; N3,Processed,9/20/2006,NiFTI 32 | 012_S_1033,Patient,F,73,1,MRI,MPR; ; N3,Processed,11/16/2006,NiFTI 33 | 012_S_1165,Patient,M,82,1,MRI,MPR; ; N3,Processed,12/28/2006,NiFTI 34 | 012_S_1175,Patient,M,73,1,MRI,MPR; ; N3,Processed,1/05/2007,NiFTI 35 | 012_S_1292,Patient,M,76,1,MRI,MPR; ; N3,Processed,3/01/2007,NiFTI 36 | 012_S_1321,Patient,M,83,1,MRI,MPR; ; N3,Processed,2/22/2007,NiFTI 37 | 013_S_0240,Patient,M,88,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/20/2006,NiFTI 38 | 013_S_0325,Patient,F,71,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,4/19/2006,NiFTI 39 | 013_S_0860,Patient,M,85,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,9/21/2006,NiFTI 40 | 013_S_1120,Patient,F,78,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,11/22/2006,NiFTI 41 | 
013_S_1275,Patient,F,79,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,2/22/2007,NiFTI 42 | 016_S_0354,Patient,M,76,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,5/05/2006,NiFTI 43 | 016_S_0590,Patient,F,78,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,6/06/2006,NiFTI 44 | 016_S_0769,Patient,M,62,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,8/02/2006,NiFTI 45 | 016_S_1028,Patient,F,77,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,11/02/2006,NiFTI 46 | 016_S_1092,Patient,M,74,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,12/11/2006,NiFTI 47 | 016_S_1121,Patient,F,56,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,12/06/2006,NiFTI 48 | 016_S_1138,Patient,M,67,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,12/28/2006,NiFTI 49 | 016_S_1149,Patient,M,84,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,1/04/2007,NiFTI 50 | 018_S_0057,Patient,M,77,1,MRI,MPR; ; N3,Processed,11/17/2005,NiFTI 51 | 018_S_0080,Patient,M,85,1,MRI,MPR; ; N3,Processed,12/29/2005,NiFTI 52 | 018_S_0087,Patient,M,75,1,MRI,MPR; ; N3,Processed,12/22/2005,NiFTI 53 | 018_S_0103,Patient,M,87,1,MRI,MPR; ; N3,Processed,1/05/2006,NiFTI 54 | 018_S_0155,Patient,M,81,1,MRI,MPR; ; N3,Processed,2/23/2006,NiFTI 55 | 018_S_0406,Patient,M,78,1,MRI,MPR; ; N3,Processed,4/20/2006,NiFTI 56 | 018_S_0450,Patient,M,69,1,MRI,MPR; ; N3,Processed,5/04/2006,NiFTI 57 | 021_S_0141,Patient,M,81,1,MRI,MPR; GradWarp; N3,Processed,1/23/2006,NiFTI 58 | 021_S_0231,Patient,M,60,1,MRI,MPR; GradWarp; N3,Processed,2/28/2006,NiFTI 59 | 021_S_0273,Patient,M,63,1,MRI,MPR; GradWarp; N3,Processed,3/14/2006,NiFTI 60 | 021_S_0332,Patient,M,70,1,MRI,MPR; GradWarp; N3,Processed,4/04/2006,NiFTI 61 | 021_S_0424,Patient,M,81,1,MRI,MPR; GradWarp; N3,Processed,4/20/2006,NiFTI 62 | 022_S_0004,Patient,M,68,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,9/22/2005,NiFTI 63 | 022_S_0044,Patient,F,86,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,11/03/2005,NiFTI 64 | 022_S_0544,Patient,F,77,1,MRI,MPR; GradWarp; B1 Correction; 
N3,Processed,5/17/2006,NiFTI 65 | 022_S_0750,Patient,F,76,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,8/07/2006,NiFTI 66 | 022_S_0924,Patient,M,70,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,9/27/2006,NiFTI 67 | 022_S_0961,Patient,M,73,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/20/2006,NiFTI 68 | 022_S_1366,Patient,M,74,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,3/28/2007,NiFTI 69 | 023_S_0030,Patient,F,80,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/10/2005,NiFTI 70 | 023_S_0078,Patient,F,76,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,12/16/2005,NiFTI 71 | 023_S_0604,Patient,M,87,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,6/02/2006,NiFTI -------------------------------------------------------------------------------- /ADNI_subject_id/Normal.csv: -------------------------------------------------------------------------------- 1 | Subject,Group,Sex,Age,Visit,Modality,Description,Type,Acq Date,Format 2 | 002_S_0295,Patient,M,85,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,4/18/2006,NiFTI 3 | 002_S_0413,Patient,F,76,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,5/02/2006,NiFTI 4 | 002_S_0559,Patient,M,79,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,5/23/2006,NiFTI 5 | 002_S_0685,Patient,F,90,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,7/06/2006,NiFTI 6 | 002_S_1261,Patient,F,71,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,2/15/2007,NiFTI 7 | 002_S_1280,Patient,F,71,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,2/13/2007,NiFTI 8 | 002_S_4213,Patient,F,78,22,MRI,MT1; N3m,Processed,9/02/2011,NiFTI 9 | 002_S_4225,Patient,M,70,22,MRI,MT1; N3m,Processed,9/21/2011,NiFTI 10 | 002_S_4262,Patient,F,73,22,MRI,MT1; N3m,Processed,10/05/2011,NiFTI 11 | 002_S_4264,Patient,F,74,22,MRI,MT1; N3m,Processed,10/05/2011,NiFTI 12 | 002_S_4270,Patient,F,75,22,MRI,MT1; N3m,Processed,10/11/2011,NiFTI 13 | 003_S_0907,Patient,F,89,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,9/11/2006,NiFTI 14 | 
003_S_0931,Patient,F,86,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/11/2006,NiFTI 15 | 003_S_0981,Patient,F,84,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/19/2006,NiFTI 16 | 003_S_1021,Patient,F,87,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,11/01/2006,NiFTI 17 | 003_S_4081,Patient,F,73,22,MRI,MT1; GradWarp; N3m,Processed,7/05/2011,NiFTI 18 | 003_S_4119,Patient,M,79,22,MRI,MT1; GradWarp; N3m,Processed,8/09/2011,NiFTI 19 | 003_S_4288,Patient,F,73,22,MRI,MT1; GradWarp; N3m,Processed,10/20/2011,NiFTI 20 | 003_S_4350,Patient,M,73,22,MRI,MT1; GradWarp; N3m,Processed,11/09/2011,NiFTI 21 | 003_S_4441,Patient,F,69,22,MRI,MT1; GradWarp; N3m,Processed,1/03/2012,NiFTI 22 | 003_S_4555,Patient,F,66,22,MRI,MT1; GradWarp; N3m,Processed,3/05/2012,NiFTI 23 | 003_S_4644,Patient,F,68,22,MRI,MT1; GradWarp; N3m,Processed,4/19/2012,NiFTI 24 | 003_S_4839,Patient,M,66,22,MRI,MT1; GradWarp; N3m,Processed,7/24/2012,NiFTI 25 | 003_S_4840,Patient,F,62,22,MRI,MT1; GradWarp; N3m,Processed,7/24/2012,NiFTI 26 | 003_S_4872,Patient,F,69,22,MRI,MT1; GradWarp; N3m,Processed,8/02/2012,NiFTI 27 | 003_S_4900,Patient,F,60,22,MRI,MT1; GradWarp; N3m,Processed,8/15/2012,NiFTI 28 | 005_S_0223,Patient,F,78,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,2/28/2006,NiFTI 29 | 005_S_0553,Patient,M,85,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,6/14/2006,NiFTI 30 | 005_S_0602,Patient,M,71,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,6/29/2006,NiFTI 31 | 005_S_0610,Patient,M,79,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,6/21/2006,NiFTI 32 | 006_S_0484,Patient,M,71,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,9/13/2006,NiFTI 33 | 006_S_0498,Patient,M,71,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,6/26/2006,NiFTI 34 | 006_S_0681,Patient,F,77,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,8/31/2006,NiFTI 35 | 006_S_0731,Patient,M,72,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,8/28/2006,NiFTI 36 | 006_S_4150,Patient,M,74,22,MRI,MT1; 
N3m,Processed,8/08/2011,NiFTI 37 | 006_S_4357,Patient,F,74,22,MRI,MT1; N3m,Processed,11/28/2011,NiFTI 38 | 006_S_4449,Patient,F,67,22,MRI,MT1; N3m,Processed,1/19/2012,NiFTI 39 | 006_S_4485,Patient,M,73,22,MRI,MT1; N3m,Processed,2/01/2012,NiFTI 40 | 007_S_0068,Patient,F,75,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,11/30/2005,NiFTI 41 | 007_S_0070,Patient,M,74,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,1/18/2006,NiFTI 42 | 007_S_1206,Patient,M,73,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,1/17/2007,NiFTI 43 | 007_S_1222,Patient,F,73,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,1/19/2007,NiFTI 44 | 007_S_4387,Patient,F,76,22,MRI,MT1; GradWarp; N3m,Processed,12/05/2011,NiFTI 45 | 007_S_4488,Patient,M,73,22,MRI,MT1; GradWarp; N3m,Processed,1/31/2012,NiFTI 46 | 007_S_4516,Patient,M,71,22,MRI,MT1; GradWarp; N3m,Processed,2/13/2012,NiFTI 47 | 007_S_4620,Patient,M,77,22,MRI,MT1; GradWarp; N3m,Processed,3/28/2012,NiFTI 48 | 007_S_4637,Patient,F,71,22,MRI,MT1; GradWarp; N3m,Processed,4/05/2012,NiFTI 49 | 009_S_0751,Patient,M,71,1,MRI,MPR; GradWarp; N3,Processed,7/25/2006,NiFTI 50 | 009_S_0842,Patient,M,74,1,MRI,MPR; GradWarp; N3,Processed,9/07/2006,NiFTI 51 | 009_S_0862,Patient,F,74,1,MRI,MPR; GradWarp; N3,Processed,9/19/2006,NiFTI 52 | 009_S_4337,Patient,M,72,22,MRI,MT1; GradWarp; N3m,Processed,11/07/2011,NiFTI 53 | 009_S_4388,Patient,M,67,22,MRI,MT1; GradWarp; N3m,Processed,12/13/2011,NiFTI 54 | 009_S_4612,Patient,F,69,22,MRI,MT1; GradWarp; N3m,Processed,3/29/2012,NiFTI 55 | 010_S_0067,Patient,M,75,1,MRI,MPR; ; N3,Processed,11/30/2005,NiFTI 56 | 010_S_0419,Patient,M,70,1,MRI,MPR; ; N3,Processed,5/16/2006,NiFTI 57 | 010_S_0420,Patient,M,74,1,MRI,MPR; ; N3,Processed,6/15/2006,NiFTI 58 | 010_S_0472,Patient,M,72,1,MRI,MPR; ; N3,Processed,7/05/2006,NiFTI 59 | 010_S_4345,Patient,M,70,22,MRI,MT1; N3m,Processed,1/24/2012,NiFTI 60 | 010_S_4442,Patient,F,74,22,MRI,MT1; N3m,Processed,2/07/2012,NiFTI 61 | 011_S_0002,Patient,M,74,1,MRI,MPR-R; GradWarp; B1 
Correction; N3,Processed,8/26/2005,NiFTI 62 | 011_S_0005,Patient,M,74,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,9/02/2005,NiFTI 63 | 011_S_0008,Patient,F,85,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,9/13/2005,NiFTI 64 | 011_S_0016,Patient,M,66,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,9/27/2005,NiFTI 65 | 011_S_0021,Patient,F,73,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,10/10/2005,NiFTI 66 | 011_S_0022,Patient,M,63,1,MRI,MPR; GradWarp; B1 Correction; N3,Processed,10/10/2005,NiFTI 67 | 011_S_0023,Patient,M,72,1,MRI,MPR-R; GradWarp; B1 Correction; N3,Processed,10/31/2005,NiFTI 68 | 011_S_4075,Patient,M,73,22,MRI,MT1; GradWarp; N3m,Processed,6/10/2011,NiFTI 69 | 011_S_4105,Patient,F,71,22,MRI,MT1; GradWarp; N3m,Processed,7/08/2011,NiFTI 70 | 011_S_4120,Patient,F,82,22,MRI,MT1; GradWarp; N3m,Processed,7/15/2011,NiFTI 71 | 011_S_4222,Patient,F,82,22,MRI,MT1; GradWarp; N3m,Processed,9/14/2011,NiFTI -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Alzheimer Disease Diagnosis by Deeply Supervised 3D Convolutional Network 2 | Diagnosing Alzheimer disease from 3D MRI T1 scans from ADNI dataset. The initial results using 3D Convolutional Network is published in ICIP 2016 [[1]](https://arxiv.org/abs/1607.00455). The second model used deeply supervision to boost the performance on all binary and three-way classification of AD/MCI/Normal classes. 
The results are published on arxiv [[2]](https://arxiv.org/abs/1607.00556) 3 | 4 | ### Using Transfer Learning 5 | * Pretraining 3D CNN with 3D Convolutional Autoencoder on source domain 6 | * Finetuning uper fully-connected layers of 3D CNN using supervised fine-tuning on target domain 7 | * Using deeply supervision in supervised fine-tuning of upper fully-connected layers 8 | 9 | ### DATA 10 | List of all subject ids are in ADNI_subject_id directory 11 | 12 | 13 | ###Papers 14 | * [1] E. Hosseini-Asl, R. Keynton and A. El-Baz, "Alzheimer's disease diagnostics by adaptation of 3D convolutional network," 2016 IEEE International Conference on Image Processing (ICIP), Phoenix, AZ, USA, 2016, pp. 126-130. 15 | * [2] E. Hosseini-Asl, G. Gimel'farb, and A. El-Baz, “Alzheimer's Disease Diagnostics by a Deeply Supervised Adaptable 3D Convolutional Network”, arXiv:1607.00556 [cs.LG, q-bio.NC, stat.ML], 2016. 16 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ehosseiniasl/3d-convolutional-network/ff6149f02081996dfa006605b93924468efadba8/__init__.py -------------------------------------------------------------------------------- /conv3d2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D convolutions using GPU accelereration for Theano (using conv2d) 3 | https://github.com/jaberg/TheanoConv3d2d 4 | 5 | """ 6 | 7 | 8 | import theano 9 | from theano.gradient import DisconnectedType 10 | from theano.gof import Op, Apply, TopoOptimizer 11 | from theano import tensor 12 | import theano.sandbox.cuda as cuda 13 | 14 | 15 | def get_diagonal_subtensor_view(x, i0, i1): 16 | """Helper function for DiagonalSubtensor and 17 | IncDiagonalSubtensor 18 | 19 | :note: it return a partial view of x, not a partial copy. 
20 | """ 21 | 22 | if x.shape[i0] < x.shape[i1]: 23 | raise NotImplementedError('is this allowed?') 24 | idx = [slice(None)] * x.ndim 25 | # idx[i0] = slice(x.shape[i1] - 1, None, None) 26 | xview = x.__getitem__(tuple(idx)) 27 | strides = list(xview.strides) 28 | # strides[i1] -= strides[i0] 29 | xview.strides = strides 30 | return xview 31 | 32 | 33 | class DiagonalSubtensor(Op): 34 | """Return a form a nd diagonal subtensor. 35 | 36 | :param x: n-d tensor 37 | :param i0: axis index in x 38 | :param i1: axis index in x 39 | :note: Work on the GPU. 40 | 41 | ``x`` is some n-dimensional tensor, but this Op only deals with a 42 | matrix-shaped slice, using axes i0 and i1. Without loss of 43 | generality, suppose that ``i0`` picks out our ``row`` dimension, 44 | and i1 the ``column`` dimension. 45 | 46 | So the relevant part of ``x`` is some matrix ``u``. Suppose it has 7 rows 47 | and 4 columns:: 48 | 49 | [ 0 0 0 0 ] 50 | [ 0 0 0 0 ] 51 | [ 0 0 0 0 ] 52 | [ 0 0 0 0 ] 53 | [ 0 0 0 0 ] 54 | [ 0 0 0 0 ] 55 | 56 | The view returned by this function is also a matrix. It's a thick, 57 | diagonal ``stripe`` across u that discards the lower left triangle 58 | and the upper right triangle: 59 | 60 | [ x 0 0 0 ] 61 | [ x x 0 0 ] 62 | [ x x x 0 ] 63 | [ 0 x x x ] 64 | [ 0 0 x x ] 65 | [ 0 0 0 x ] 66 | 67 | In this case the return value would be this view of shape 3x4. The 68 | returned view has the same number of dimensions as the input 69 | ``x``, and the only difference is that the shape along dimension 70 | ``i0`` has been reduced by ``shape[i1] - 1`` because of the 71 | triangles that got chopped out. 72 | 73 | The NotImplementedError is meant to catch the case where shape[i0] 74 | is too small for the stripe to reach across the matrix, in which 75 | case it's not clear what this function should do. Maybe always 76 | raise an error. I'd look back to the call site in the Conv3D to 77 | see what's necessary at that point. 
78 | 79 | """ 80 | def __str__(self): 81 | if self.inplace: 82 | return "%s{inplace}" % self.__class__.__name__ 83 | return "%s" % self.__class__.__name__ 84 | 85 | def __init__(self, inplace=False): 86 | self.inplace = inplace 87 | if inplace: 88 | self.view_map = {0: [0]} 89 | 90 | def __eq__(self, other): 91 | return type(self) == type(other) and self.inplace == other.inplace 92 | 93 | def __hash__(self): 94 | return hash((type(self), self.inplace)) 95 | 96 | def make_node(self, x, i0, i1): 97 | _i0 = tensor.as_tensor_variable(i0) 98 | _i1 = tensor.as_tensor_variable(i1) 99 | return Apply(self, [x, _i0, _i1], [x.type()]) 100 | 101 | def perform(self, node, inputs, output_storage): 102 | xview = get_diagonal_subtensor_view(*inputs) 103 | if self.inplace: 104 | output_storage[0][0] = xview 105 | else: 106 | output_storage[0][0] = xview.copy() 107 | 108 | def grad(self, inputs, g_outputs): 109 | z = tensor.zeros_like(inputs[0]) 110 | gx = inc_diagonal_subtensor(z, inputs[1], inputs[2], g_outputs[0]) 111 | return [gx, DisconnectedType()(), DisconnectedType()()] 112 | 113 | def connection_pattern(self, node): 114 | rval = [[True], [False], [False]] 115 | return rval 116 | 117 | diagonal_subtensor = DiagonalSubtensor(False) 118 | 119 | 120 | class IncDiagonalSubtensor(Op): 121 | """ 122 | The gradient of DiagonalSubtensor 123 | """ 124 | def __str__(self): 125 | if self.inplace: 126 | return "%s{inplace}" % self.__class__.__name__ 127 | return "%s" % self.__class__.__name__ 128 | 129 | def __init__(self, inplace=False): 130 | self.inplace = inplace 131 | if inplace: 132 | self.destroy_map = {0: [0]} 133 | 134 | def __eq__(self, other): 135 | return type(self) == type(other) and self.inplace == other.inplace 136 | 137 | def __hash__(self): 138 | return hash((type(self), self.inplace)) 139 | 140 | def make_node(self, x, i0, i1, amt): 141 | _i0 = tensor.as_tensor_variable(i0) 142 | _i1 = tensor.as_tensor_variable(i1) 143 | return Apply(self, [x, _i0, _i1, amt], 
[x.type()]) 144 | 145 | def perform(self, node, inputs, output_storage): 146 | x, i0, i1, amt = inputs 147 | if not self.inplace: 148 | x = x.copy() 149 | xview = get_diagonal_subtensor_view(x, i0, i1) 150 | xview += amt 151 | output_storage[0][0] = x 152 | 153 | def grad(self, inputs, g_outputs): 154 | x, i0, i1, amt = inputs 155 | gy = g_outputs[0] 156 | return [gy, DisconnectedType()(), DisconnectedType()(), 157 | diagonal_subtensor(gy, i0, i1)] 158 | 159 | def connection_pattern(self, node): 160 | rval = [[True], [False], [False], [True]] 161 | return rval 162 | inc_diagonal_subtensor = IncDiagonalSubtensor(False) 163 | 164 | 165 | def conv3d(signals, filters, 166 | signals_shape=None, filters_shape=None, 167 | border_mode='valid'): 168 | """Convolve spatio-temporal filters with a movie. 169 | 170 | :param signals: timeseries of images whose pixels have color channels. 171 | shape: [Ns, Ts, C, Hs, Ws] 172 | :param filters: spatio-temporal filters 173 | shape: [Nf, Tf, C, Hf, Wf] 174 | :param signals_shape: None or a tuple/list with the shape of signals 175 | :param filters_shape: None or a tuple/list with the shape of filters 176 | :param border_mode: The only one tested is 'valid'. 177 | 178 | :note: Work on the GPU. 
179 | """ 180 | 181 | if isinstance(border_mode, str): 182 | border_mode = (border_mode, border_mode, border_mode) 183 | 184 | _signals_shape_5d = signals.shape if signals_shape is None else signals_shape 185 | _filters_shape_5d = filters.shape if filters_shape is None else filters_shape 186 | 187 | _signals_shape_4d = ( 188 | _signals_shape_5d[0] * _signals_shape_5d[1], 189 | _signals_shape_5d[2], 190 | _signals_shape_5d[3], 191 | _signals_shape_5d[4], 192 | ) 193 | _filters_shape_4d = ( 194 | _filters_shape_5d[0] * _filters_shape_5d[1], 195 | _filters_shape_5d[2], 196 | _filters_shape_5d[3], 197 | _filters_shape_5d[4], 198 | ) 199 | 200 | if border_mode[1] != border_mode[2]: 201 | raise NotImplementedError('height and width bordermodes must match') 202 | conv2d_signal_shape = _signals_shape_4d 203 | conv2d_filter_shape = _filters_shape_4d 204 | if signals_shape is None: 205 | conv2d_signal_shape = None 206 | if filters_shape is None: 207 | conv2d_filter_shape = None 208 | 209 | out_4d = tensor.nnet.conv2d( 210 | signals.reshape(_signals_shape_4d), 211 | filters.reshape(_filters_shape_4d), 212 | image_shape=conv2d_signal_shape, 213 | filter_shape=conv2d_filter_shape, 214 | border_mode = border_mode[1]) # ignoring border_mode[2] 215 | 216 | # reshape the output to restore its original size 217 | # shape = Ns, Ts, Nf, Tf, W-Wf+1, H-Hf+1 218 | if border_mode[1] == 'valid': 219 | out_tmp = out_4d.reshape(( 220 | _signals_shape_5d[0], # Ns 221 | _signals_shape_5d[1], # Ts 222 | _filters_shape_5d[0], # Nf 223 | _filters_shape_5d[1], # Tf 224 | _signals_shape_5d[3] - _filters_shape_5d[3] + 1, 225 | _signals_shape_5d[4] - _filters_shape_5d[4] + 1, 226 | )) 227 | elif border_mode[1] == 'full': 228 | out_tmp = out_4d.reshape(( 229 | _signals_shape_5d[0], # Ns 230 | _signals_shape_5d[1], # Ts 231 | _filters_shape_5d[0], # Nf 232 | _filters_shape_5d[1], # Tf 233 | _signals_shape_5d[3] + _filters_shape_5d[3] - 1, 234 | _signals_shape_5d[4] + _filters_shape_5d[4] - 1, 235 | )) 
236 | elif border_mode[1] == 'same': 237 | raise NotImplementedError() 238 | else: 239 | raise ValueError('invalid border mode', border_mode[1]) 240 | 241 | # now sum out along the Tf to get the output 242 | # but we have to sum on a diagonal through the Tf and Ts submatrix. 243 | if border_mode[0] == 'valid': 244 | out_5d = diagonal_subtensor(out_tmp, 1, 3).sum(axis=3) 245 | elif border_mode[0] in ('full', 'same'): 246 | out_5d = diagonal_subtensor(out_tmp, 1, 3).sum(axis=3) 247 | # out_5d = out_4d.reshape((_signals_shape_5d)) 248 | # raise NotImplementedError('sequence border mode', border_mode[0]) 249 | else: 250 | raise ValueError('invalid border mode', border_mode[1]) 251 | return out_5d 252 | 253 | 254 | def make_gpu_optimizer(op, to_gpu): 255 | """This function create optimizer that move some inputs to the GPU 256 | for op that work on both CPU and GPU. 257 | 258 | The op object is created by calling op(), so good default value 259 | are needed. 260 | 261 | We suppose the same op work with CPU and GPU inputs. 262 | 263 | :param op: the op that support GPU inputs 264 | :param to_gpu: a list of op inputs that are moved to the GPU. 265 | 266 | """ 267 | @theano.gof.local_optimizer([op, cuda.gpu_from_host]) 268 | def local_to_gpu(node): 269 | """ 270 | op(host_from_gpu()) -> host_from_gpu(op) 271 | gpu_from_host(op) -> op(gpu_from_host) 272 | """ 273 | if isinstance(node.op, op): 274 | #op(host_from_gpu()) -> host_from_gpu(op) 275 | #If any of the input that go on the GPU are on the GPU, 276 | #move the op to the gpu. 
277 | if any(node.inputs[idx].owner and 278 | isinstance(node.inputs[idx].owner.op, cuda.HostFromGpu) 279 | for idx in to_gpu): 280 | new_inp = list(node.inputs) 281 | for idx in to_gpu: 282 | new_inp[idx] = cuda.gpu_from_host(new_inp[idx]) 283 | return [cuda.host_from_gpu(op()(*new_inp))] 284 | if node.op == cuda.gpu_from_host: 285 | #gpu_from_host(op) -> op(gpu_from_host) 286 | host_input = node.inputs[0] 287 | if host_input.owner and isinstance(host_input.owner.op, 288 | op): 289 | op_node = host_input.owner 290 | new_inp = list(op_node.inputs) 291 | for idx in to_gpu: 292 | new_inp[idx] = cuda.gpu_from_host(new_inp[idx]) 293 | return [op()(*new_inp)] 294 | return False 295 | local_to_gpu.__name__ = "local_to_gpu_" + op.__name__ 296 | cuda.opt.register_opt()(local_to_gpu) 297 | 298 | if cuda.cuda_available: 299 | make_gpu_optimizer(DiagonalSubtensor, [0]) 300 | make_gpu_optimizer(IncDiagonalSubtensor, [0, 3]) 301 | 302 | 303 | @theano.gof.local_optimizer([DiagonalSubtensor, IncDiagonalSubtensor]) 304 | def local_inplace_DiagonalSubtensor(node): 305 | """ also work for IncDiagonalSubtensor """ 306 | if (isinstance(node.op, (DiagonalSubtensor, IncDiagonalSubtensor)) and 307 | not node.op.inplace): 308 | new_op = node.op.__class__(inplace=True) 309 | new_node = new_op(*node.inputs) 310 | return [new_node] 311 | return False 312 | theano.compile.optdb.register( 313 | 'local_inplace_DiagonalSubtensor', 314 | TopoOptimizer( 315 | local_inplace_DiagonalSubtensor, 316 | failure_callback=TopoOptimizer.warn_inplace), 317 | 60, 'fast_run', 'inplace') 318 | -------------------------------------------------------------------------------- /convnet_3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """ 3 | 3D-CAE with max-pooling 4 | 5 | Stacked 3D-CAE for Alzheimer 6 | 7 | 11-11-15 Ehsan Hosseini-Asl 8 | 9 | """ 10 | __author__ = 'ehsanh' 11 | 12 | import numpy as np 13 | import pickle 14 | import maxpool3d 15 | 
import theano 16 | import theano.tensor as T 17 | from theano.tensor import nnet 18 | from theano.tensor.signal import downsample 19 | import conv3d2d 20 | from itertools import izip 21 | 22 | FLOAT_PRECISION = np.float32 23 | 24 | def adadelta_updates(parameters, gradients, rho, eps): 25 | 26 | # create variables to store intermediate updates 27 | # ipdb.set_trace() 28 | gradients_sq = [ theano.shared(np.zeros(p.get_value().shape, dtype=FLOAT_PRECISION),) for p in parameters ] 29 | deltas_sq = [ theano.shared(np.zeros(p.get_value().shape, dtype=FLOAT_PRECISION)) for p in parameters ] 30 | 31 | # calculates the new "average" delta for the next iteration 32 | gradients_sq_new = [ rho*g_sq + (1-rho)*(g**2) for g_sq,g in izip(gradients_sq,gradients) ] 33 | 34 | # calculates the step in direction. The square root is an approximation to getting the RMS for the average value 35 | deltas = [ (T.sqrt(d_sq+eps)/T.sqrt(g_sq+eps))*grad for d_sq,g_sq,grad in izip(deltas_sq,gradients_sq_new,gradients) ] 36 | 37 | # calculates the new "average" deltas for the next step. 
38 | deltas_sq_new = [ rho*d_sq + (1-rho)*(d**2) for d_sq,d in izip(deltas_sq,deltas) ] 39 | 40 | # Prepare it as a list f 41 | gradient_sq_updates = zip(gradients_sq,gradients_sq_new) 42 | deltas_sq_updates = zip(deltas_sq,deltas_sq_new) 43 | parameters_updates = [ (p,p - d) for p,d in izip(parameters,deltas) ] 44 | # ipdb.set_trace() 45 | return gradient_sq_updates + deltas_sq_updates + parameters_updates 46 | # return parameters_updates 47 | 48 | 49 | class ConvolutionLayer3D(object): 50 | 51 | def __init__(self, rng, input, signal_shape, filter_shape, poolsize=(2, 2, 2), stride=None, if_pool=False, if_hidden_pool=False, 52 | act=None, 53 | share_with=None, 54 | tied=None, 55 | border_mode='valid'): 56 | self.input = input 57 | 58 | if share_with: 59 | self.W = share_with.W 60 | self.b = share_with.b 61 | 62 | self.W_delta = share_with.W_delta 63 | self.b_delta = share_with.b_delta 64 | 65 | elif tied: 66 | self.W = tied.W.dimshuffle(1,0,2,3) 67 | self.b = tied.b 68 | 69 | self.W_delta = tied.W_delta.dimshuffle(1,0,2,3) 70 | self.b_delta = tied.b_delta 71 | 72 | else: 73 | fan_in = np.prod(filter_shape[1:]) 74 | poolsize_size = np.prod(poolsize) if poolsize else 1 75 | fan_out = (filter_shape[0] * np.prod(filter_shape[2:]) / poolsize_size) 76 | W_bound = np.sqrt(6. 
/ (fan_in + fan_out)) 77 | self.W = theano.shared( 78 | np.asarray( 79 | rng.uniform(low=-W_bound, high=W_bound, size=filter_shape), 80 | dtype=theano.config.floatX 81 | ), 82 | borrow=True 83 | ) 84 | b_values = np.zeros((filter_shape[0],), dtype=theano.config.floatX) 85 | self.b = theano.shared(value=b_values, borrow=True) 86 | 87 | self.W_delta = theano.shared( 88 | np.zeros(filter_shape, dtype=theano.config.floatX), 89 | borrow=True 90 | ) 91 | 92 | self.b_delta = theano.shared(value=b_values, borrow=True) 93 | 94 | # convolution 95 | conv_out = conv3d2d.conv3d( 96 | signals=input, 97 | filters=self.W, 98 | signals_shape=signal_shape, 99 | filters_shape=filter_shape, 100 | border_mode=border_mode) 101 | 102 | #if poolsize: 103 | if if_pool: 104 | conv_out = conv_out.dimshuffle(0,2,1,3,4) #maxpool3d works on last 3 dimesnions 105 | pooled_out = maxpool3d.max_pool_3d( 106 | input=conv_out, 107 | ds=poolsize, 108 | ignore_border=True) 109 | tmp_out = pooled_out.dimshuffle(0,2,1,3,4) 110 | tmp = tmp_out + self.b.dimshuffle('x', 'x', 0, 'x', 'x') 111 | elif if_hidden_pool: 112 | pooled_out = downsample.max_pool_2d( 113 | input=conv_out, 114 | ds=poolsize[:2], 115 | st=stride, 116 | ignore_border=True) 117 | tmp = pooled_out + self.b.dimshuffle('x', 'x', 0, 'x', 'x') 118 | else: 119 | tmp = conv_out + self.b.dimshuffle('x', 'x', 0, 'x', 'x') 120 | 121 | if act == 'tanh': 122 | self.output = T.tanh(tmp) 123 | elif act == 'sigmoid': 124 | self.output = nnet.sigmoid(tmp) 125 | elif act == 'relu': 126 | # self.output = tmp * (tmp>0) 127 | self.output = 0.5 * (tmp + abs(tmp)) + 1e-9 128 | elif act == 'softplus': 129 | # self.output = T.log2(1+T.exp(tmp)) 130 | self.output = nnet.softplus(tmp) 131 | else: 132 | self.output = tmp 133 | 134 | self.get_activation = theano.function( 135 | [self.input], 136 | self.output, 137 | updates=None, 138 | name='get hidden activation') 139 | 140 | # store parameters of this layer 141 | self.params = [self.W, self.b] 142 | self.deltas = 
[self.W_delta, self.b_delta] 143 | 144 | def get_state(self): 145 | return self.W.get_value(), self.b.get_value() 146 | 147 | def set_state(self, state): 148 | self.W.set_value(state[0], borrow=True) 149 | self.b.set_value(state[1], borrow=True) 150 | 151 | 152 | class HiddenLayer(object): 153 | 154 | def __init__(self, rng, input, n_in, n_out, share_with=None, activation=None): 155 | 156 | self.input = input 157 | self.n_in = n_in 158 | self.n_out = n_out 159 | self.activation = activation 160 | if share_with: 161 | self.W = share_with.W 162 | self.b = share_with.b 163 | 164 | self.W_delta = share_with.W_delta 165 | self.b_delta = share_with.b_delta 166 | else: 167 | W_values = np.asarray( 168 | rng.uniform( 169 | low=-np.sqrt(6. / (n_in + n_out)), 170 | high=np.sqrt(6. / (n_in + n_out)), 171 | size=(n_in, n_out) 172 | ), 173 | dtype=theano.config.floatX 174 | ) 175 | if activation == nnet.sigmoid: 176 | W_values *= 4 177 | 178 | self.W = theano.shared(value=W_values, name='W', borrow=True) 179 | 180 | b_values = np.zeros((n_out,), dtype=theano.config.floatX) 181 | self.b = theano.shared(value=b_values, name='b', borrow=True) 182 | 183 | self.W_delta = theano.shared( 184 | np.zeros((n_in, n_out), dtype=theano.config.floatX), 185 | borrow=True 186 | ) 187 | 188 | self.b_delta = theano.shared(value=b_values, borrow=True) 189 | 190 | self.params = [self.W, self.b] 191 | 192 | self.deltas = [self.W_delta, self.b_delta] 193 | 194 | lin_output = T.dot(self.input, self.W) + self.b 195 | 196 | if activation == 'tanh': 197 | self.output = T.tanh(lin_output) 198 | elif activation == 'sigmoid': 199 | self.output = nnet.sigmoid(lin_output) 200 | elif activation == 'relu': 201 | self.output = T.maximum(lin_output, 0) 202 | else: 203 | self.output = lin_output 204 | 205 | def get_state(self): 206 | return self.W.get_value(), self.b.get_value() 207 | 208 | def set_state(self, state): 209 | self.W.set_value(state[0], borrow=True) 210 | self.b.set_value(state[1], borrow=True) 211 
| 212 | def initialize_layer(self): 213 | rng = np.random.RandomState(None) 214 | W_values = np.asarray( 215 | rng.uniform( 216 | low=-np.sqrt(6. / (self.n_in + self.n_out)), 217 | high=np.sqrt(6. / (self.n_in + self.n_out)), 218 | size=(self.n_in, self.n_out)), 219 | dtype=theano.config.floatX) 220 | 221 | if self.activation == nnet.sigmoid: 222 | W_values *= 4 223 | 224 | b_values = np.zeros((self.n_out,), dtype=theano.config.floatX) 225 | self.W.set_value(W_values, borrow=True) 226 | self.b.set_value(b_values, borrow=True) 227 | 228 | 229 | class softmaxLayer(object): 230 | def __init__(self, input, n_in, n_out): 231 | self.n_in = n_in 232 | self.n_out = n_out 233 | self.W = theano.shared( 234 | value=np.zeros( 235 | (n_in,n_out), 236 | dtype=theano.config.floatX 237 | ), 238 | name='W', 239 | borrow=True 240 | ) 241 | 242 | self.b = theano.shared( 243 | value=np.zeros( 244 | (n_out,), 245 | dtype=theano.config.floatX 246 | ), 247 | name='b', 248 | borrow=True 249 | ) 250 | 251 | self.W_delta = theano.shared( 252 | np.zeros((n_in,n_out), dtype=theano.config.floatX), 253 | borrow=True 254 | ) 255 | 256 | self.b_delta = theano.shared( 257 | value=np.zeros( 258 | (n_out,), 259 | dtype=theano.config.floatX), 260 | name='b', 261 | borrow=True) 262 | 263 | self.p_y_given_x = T.nnet.softmax(T.dot(input, self.W) + self.b) 264 | self.y_pred = T.argmax(self.p_y_given_x, axis=1) 265 | self.params = [self.W, self.b] 266 | 267 | self.deltas = [self.W_delta, self.b_delta] 268 | 269 | def initialize_layer(self): 270 | 271 | W_value=np.zeros( 272 | (self.n_in, self.n_out), 273 | dtype=theano.config.floatX 274 | ) 275 | 276 | b_value=np.zeros( 277 | (self.n_out,), 278 | dtype=theano.config.floatX 279 | ) 280 | 281 | self.W.set_value(W_value, borrow=True) 282 | self.b.set_value(b_value, borrow=True) 283 | 284 | def negative_log_likelihood(self, y): 285 | return -T.mean(T.log(self.p_y_given_x)[T.arange(y.shape[0]), y]) 286 | 287 | def errors(self, y): 288 | if y.ndim != 
self.y_pred.ndim: 289 | raise TypeError( 290 | 'y should have the same shape as self.y_pred', 291 | ('y', y.type, 'y_pred', self.y_pred.type) 292 | ) 293 | if y.dtype.startswith('int'): 294 | return T.mean(T.neq(self.y_pred, y)) 295 | else: 296 | raise NotImplementedError() 297 | 298 | def get_state(self): 299 | return self.W.get_value(), self.b.get_value() 300 | 301 | def set_state(self, state): 302 | self.W.set_value(state[0], borrow=True) 303 | self.b.set_value(state[1], borrow=True) 304 | 305 | 306 | class CAE3d(object): 307 | def __init__(self, signal_shape, filter_shape, poolsize, activation=None): 308 | rng = np.random.RandomState(None) 309 | dtensor5 = T.TensorType('float32', (False,)*5) 310 | self.inputs = dtensor5(name='inputs') 311 | self.image_shape = signal_shape 312 | self.batchsize = signal_shape[0] 313 | self.in_channels = signal_shape[2] 314 | self.in_depth = signal_shape[1] 315 | self.in_width = signal_shape[4] 316 | self.in_height = signal_shape[3] 317 | self.flt_channels = filter_shape[0] 318 | self.flt_time = filter_shape[1] 319 | self.flt_width = filter_shape[4] 320 | self.flt_height = filter_shape[3] 321 | self.activation = activation 322 | 323 | self.hidden_layer=ConvolutionLayer3D(rng, 324 | input=self.inputs, 325 | signal_shape=signal_shape, 326 | filter_shape=filter_shape, 327 | act=activation, 328 | border_mode='full', 329 | if_hidden_pool=False) 330 | 331 | self.hidden_image_shape = (self.batchsize, 332 | self.in_depth, 333 | self.flt_channels, 334 | self.in_height+self.flt_height-1, 335 | self.in_width+self.flt_width-1) 336 | 337 | self.hidden_pooled_image_shape = (self.batchsize, 338 | self.in_depth/2, 339 | self.flt_channels, 340 | (self.in_height+self.flt_height-1)/2, 341 | (self.in_width+self.flt_width-1)/2) 342 | 343 | self.hidden_filter_shape = (self.in_channels, self.flt_time, self.flt_channels, self.flt_height, 344 | self.flt_width) 345 | 346 | self.recon_layer=ConvolutionLayer3D(rng, 347 | input=self.hidden_layer.output, 348 | 
signal_shape=self.hidden_image_shape, 349 | filter_shape=self.hidden_filter_shape, 350 | act=activation, 351 | border_mode='valid') 352 | 353 | self.layers = [self.hidden_layer, self.recon_layer] 354 | self.params = sum([layer.params for layer in self.layers], []) 355 | L=T.sum(T.pow(T.sub(self.recon_layer.output, self.inputs), 2), axis=(1,2,3,4)) 356 | self.cost = 0.5*T.mean(L) 357 | self.grads = T.grad(self.cost, self.params) 358 | self.updates = adadelta_updates(self.params, self.grads, rho=0.95, eps=1e-6) 359 | 360 | self.train = theano.function( 361 | [self.inputs], 362 | self.cost, 363 | updates=self.updates, 364 | name = "train cae model" 365 | ) 366 | 367 | self.activation = maxpool3d.max_pool_3d( 368 | input=self.hidden_layer.output.dimshuffle(0,2,1,3,4), 369 | ds=poolsize, 370 | ignore_border=True) 371 | self.activation = self.activation.dimshuffle(0,2,1,3,4) 372 | self.get_activation = theano.function( 373 | [self.inputs], 374 | self.activation, 375 | updates=None, 376 | name='get hidden activation') 377 | 378 | def save(self, filename): 379 | f = open(filename, 'w') 380 | for layer in self.layers: 381 | pickle.dump(layer.get_state(), f, -1) 382 | f.close() 383 | 384 | 385 | def load(self, filename): 386 | f = open(filename) 387 | for layer in self.layers: 388 | layer.set_state(pickle.load(f)) 389 | f.close() 390 | print 'cae model loaded from', filename 391 | 392 | 393 | class stacked_CAE3d(object): 394 | def __init__(self, image_shape, filter_shapes, poolsize, activation_cae=None, activation_final=None, hidden_size=(2000, 500, 200, 20, 3)): 395 | rng = np.random.RandomState(None) 396 | dtensor5 = T.TensorType('float32', (False,)*5) 397 | images = dtensor5(name='images') 398 | labels = T.lvector('labels') 399 | 400 | self.image_shape = image_shape 401 | self.batchsize = image_shape[0] 402 | self.in_channels = image_shape[2] 403 | self.in_depth = image_shape[1] 404 | self.in_width = image_shape[4] 405 | self.in_height = image_shape[3] 406 | 
self.flt_channels1 = filter_shapes[0][0] 407 | self.flt_channels2 = filter_shapes[1][0] 408 | self.flt_channels3 = filter_shapes[2][0] 409 | self.flt_time = filter_shapes[0][1] 410 | self.flt_width = filter_shapes[0][4] 411 | self.flt_height = filter_shapes[0][3] 412 | 413 | conv1 = ConvolutionLayer3D(rng, 414 | input=images, 415 | signal_shape=image_shape, 416 | filter_shape=filter_shapes[0], 417 | act=activation_cae, 418 | poolsize=poolsize, 419 | if_pool=True, 420 | border_mode='valid') 421 | 422 | self.conv1_output_shape = (self.batchsize, 423 | self.in_depth/2, 424 | self.flt_channels1, 425 | (self.in_height-self.flt_height+1)/2, 426 | (self.in_width-self.flt_width+1)/2) 427 | 428 | #conv2_input=conv1.output.flatten(2) 429 | conv2 = ConvolutionLayer3D(rng, 430 | input=conv1.output, 431 | signal_shape=self.conv1_output_shape, 432 | filter_shape=filter_shapes[1], 433 | act=activation_cae, 434 | poolsize=poolsize, 435 | if_pool=True, 436 | border_mode='valid') 437 | 438 | self.conv2_output_shape = (self.batchsize, 439 | self.conv1_output_shape[1]/2, 440 | self.flt_channels2, 441 | (self.conv1_output_shape[3]-self.flt_height+1)/2, 442 | (self.conv1_output_shape[4]-self.flt_width+1)/2) 443 | 444 | conv3_input=conv2.output.flatten(2) 445 | conv3 = ConvolutionLayer3D(rng, 446 | input=conv2.output, 447 | signal_shape=self.conv2_output_shape, 448 | filter_shape=filter_shapes[2], 449 | act=activation_cae, 450 | poolsize=poolsize, 451 | if_pool=True, 452 | border_mode='valid') 453 | 454 | self.conv3_output_shape = (self.batchsize, 455 | self.conv2_output_shape[1]/2, 456 | self.flt_channels3, 457 | (self.conv2_output_shape[3]-self.flt_height+1)/2, 458 | (self.conv2_output_shape[4]-self.flt_width+1)/2) 459 | 460 | # 4 layers in hidden_size: 461 | ip1_input=conv3.output.flatten(2) 462 | ip1 = HiddenLayer(rng, 463 | input=ip1_input, 464 | n_in=np.prod(self.conv3_output_shape[1:]), 465 | n_out=hidden_size[0], 466 | activation=activation_final) 467 | 468 | ip2 = 
HiddenLayer(rng, 469 | input=ip1.output, 470 | n_in=hidden_size[0], 471 | n_out=hidden_size[1], 472 | activation=activation_final) 473 | 474 | ip3 = HiddenLayer(rng, 475 | input=ip2.output, 476 | n_in=hidden_size[1], 477 | n_out=hidden_size[2], 478 | activation=activation_final) 479 | 480 | ip4 = HiddenLayer(rng, 481 | input=ip3.output, 482 | n_in=hidden_size[2], 483 | n_out=hidden_size[3], 484 | activation=activation_final) 485 | 486 | output_layer = softmaxLayer(input=ip4.output, 487 | n_in=hidden_size[1], 488 | n_out=hidden_size[4]) 489 | 490 | self.layers = [conv1, 491 | conv2, 492 | conv3, 493 | ip1, 494 | ip2, 495 | output_layer] 496 | 497 | # freeze first 3 conv layers 498 | self.params = sum([l.params for l in self.layers[3:]], []) 499 | self.cost = output_layer.negative_log_likelihood(labels) 500 | self.grads = T.grad(self.cost, self.params) 501 | self.grads_input = T.grad(self.cost, images) 502 | 503 | self.updates = adadelta_updates(parameters=self.params, 504 | gradients=self.grads, 505 | rho=0.95, 506 | eps=1e-6) 507 | 508 | self.error = output_layer.errors(labels) 509 | self.y_pred = output_layer.y_pred 510 | self.prob = output_layer.p_y_given_x.max(axis=1) 511 | self.true_prob = output_layer.p_y_given_x[T.arange(labels.shape[0]), labels] 512 | self.p_y_given_x = output_layer.p_y_given_x 513 | self.train = theano.function( 514 | inputs=[images, labels], 515 | outputs=(self.error, self.cost, self.y_pred, self.prob), 516 | updates=self.updates 517 | ) 518 | 519 | self.forward = theano.function( 520 | inputs=[images, labels], 521 | outputs=(self.error, self.y_pred, self.prob, self.true_prob, self.p_y_given_x, 522 | conv3_input, ip1_input, self.layers[-2].output, self.layers[-3].output, 523 | self.grads_input) 524 | ) 525 | 526 | def load_cae(self, filename, cae_layer): 527 | f = open(filename) 528 | self.layers[cae_layer].set_state(pickle.load(f)) 529 | print 'cae %d loaded from %s' % (cae_layer, filename) 530 | 531 | def save(self, filename): 532 | f = 
open(filename, 'w') 533 | for l in self.layers: 534 | pickle.dump(l.get_state(), f, -1) 535 | f.close() 536 | 537 | def load(self, filename): 538 | f = open(filename) 539 | for l in self.layers: 540 | l.set_state(pickle.load(f)) 541 | f.close() 542 | print 'model loaded from', filename 543 | 544 | def load_binary(self, filename): 545 | f = open(filename) 546 | for l in self.layers[:-1]: 547 | l.set_state(pickle.load(f)) 548 | f.close() 549 | print 'model loaded from', filename 550 | 551 | def load_conv(self, filename): 552 | f = open(filename) 553 | for l in self.layers[:3]: 554 | l.set_state(pickle.load(f)) 555 | f.close() 556 | print 'model conv layers loaded from', filename 557 | 558 | def load_fc(self, filename): 559 | f = open(filename) 560 | for l in self.layers[-3:]: 561 | l.set_state(pickle.load(f)) 562 | f.close() 563 | print 'model fc layers loaded from', filename 564 | 565 | 566 | -------------------------------------------------------------------------------- /dlt_utils.py: -------------------------------------------------------------------------------- 1 | """ This file contains different utility functions that are not connected 2 | in anyway to the networks presented in the tutorials, but rather help in 3 | processing the outputs into a more understandable way. 4 | 5 | For example ``tile_raster_images`` helps in generating a easy to grasp 6 | image from a set of samples or weights. 7 | """ 8 | 9 | 10 | import numpy 11 | 12 | 13 | def scale_to_unit_interval(ndar, eps=1e-8): 14 | """ Scales all values in the ndarray ndar to be between 0 and 1 """ 15 | ndar = ndar.copy() 16 | ndar -= ndar.min() 17 | ndar *= 1.0 / (ndar.max() + eps) 18 | return ndar 19 | 20 | 21 | def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0), 22 | scale_rows_to_unit_interval=True, 23 | output_pixel_vals=True): 24 | """ 25 | Transform an array with one flattened image per row, into an array in 26 | which images are reshaped and layed out like tiles on a floor. 
27 | 28 | This function is useful for visualizing datasets whose rows are images, 29 | and also columns of matrices for transforming those rows 30 | (such as the first layer of a neural net). 31 | 32 | :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can 33 | be 2-D ndarrays or None; 34 | :param X: a 2-D array in which every row is a flattened image. 35 | 36 | :type img_shape: tuple; (height, width) 37 | :param img_shape: the original shape of each image 38 | 39 | :type tile_shape: tuple; (rows, cols) 40 | :param tile_shape: the number of images to tile (rows, cols) 41 | 42 | :param output_pixel_vals: if output should be pixel values (i.e. int8 43 | values) or floats 44 | 45 | :param scale_rows_to_unit_interval: if the values need to be scaled before 46 | being plotted to [0,1] or not 47 | 48 | 49 | :returns: array suitable for viewing as an image. 50 | (See:`Image.fromarray`.) 51 | :rtype: a 2-d array with same dtype as X. 52 | 53 | """ 54 | 55 | assert len(img_shape) == 2 56 | assert len(tile_shape) == 2 57 | assert len(tile_spacing) == 2 58 | 59 | # The expression below can be re-written in a more C style as 60 | # follows : 61 | # 62 | # out_shape = [0,0] 63 | # out_shape[0] = (img_shape[0]+tile_spacing[0])*tile_shape[0] - 64 | # tile_spacing[0] 65 | # out_shape[1] = (img_shape[1]+tile_spacing[1])*tile_shape[1] - 66 | # tile_spacing[1] 67 | out_shape = [ 68 | (ishp + tsp) * tshp - tsp 69 | for ishp, tshp, tsp in zip(img_shape, tile_shape, tile_spacing) 70 | ] 71 | 72 | if isinstance(X, tuple): 73 | assert len(X) == 4 74 | # Create an output numpy ndarray to store the image 75 | if output_pixel_vals: 76 | out_array = numpy.zeros((out_shape[0], out_shape[1], 4), 77 | dtype='uint8') 78 | else: 79 | out_array = numpy.zeros((out_shape[0], out_shape[1], 4), 80 | dtype=X.dtype) 81 | 82 | #colors default to 0, alpha defaults to 1 (opaque) 83 | if output_pixel_vals: 84 | channel_defaults = [0, 0, 0, 255] 85 | else: 86 | channel_defaults = [0., 0., 0., 
1.] 87 | 88 | for i in xrange(4): 89 | if X[i] is None: 90 | # if channel is None, fill it with zeros of the correct 91 | # dtype 92 | dt = out_array.dtype 93 | if output_pixel_vals: 94 | dt = 'uint8' 95 | out_array[:, :, i] = numpy.zeros( 96 | out_shape, 97 | dtype=dt 98 | ) + channel_defaults[i] 99 | else: 100 | # use a recurrent call to compute the channel and store it 101 | # in the output 102 | out_array[:, :, i] = tile_raster_images( 103 | X[i], img_shape, tile_shape, tile_spacing, 104 | scale_rows_to_unit_interval, output_pixel_vals) 105 | return out_array 106 | 107 | else: 108 | # if we are dealing with only one channel 109 | H, W = img_shape 110 | Hs, Ws = tile_spacing 111 | 112 | # generate a matrix to store the output 113 | dt = X.dtype 114 | if output_pixel_vals: 115 | dt = 'uint8' 116 | out_array = numpy.zeros(out_shape, dtype=dt) 117 | 118 | for tile_row in xrange(tile_shape[0]): 119 | for tile_col in xrange(tile_shape[1]): 120 | if tile_row * tile_shape[1] + tile_col < X.shape[0]: 121 | this_x = X[tile_row * tile_shape[1] + tile_col] 122 | if scale_rows_to_unit_interval: 123 | # if we should scale values to be between 0 and 1 124 | # do this by calling the `scale_to_unit_interval` 125 | # function 126 | this_img = scale_to_unit_interval( 127 | this_x.reshape(img_shape)) 128 | else: 129 | this_img = this_x.reshape(img_shape) 130 | # add the slice to the corresponding position in the 131 | # output array 132 | c = 1 133 | if output_pixel_vals: 134 | c = 255 135 | out_array[ 136 | tile_row * (H + Hs): tile_row * (H + Hs) + H, 137 | tile_col * (W + Ws): tile_col * (W + Ws) + W 138 | ] = this_img * c 139 | return out_array -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | """ 3 | 3D-CAE with max-pooling 4 | 5 | Stacked 3D-CAE for Alzheimer 6 | 7 | 11-11-15 Ehsan Hosseini-Asl 8 | 9 | """ 10 | __author__ = 
'ehsanh' 11 | 12 | import numpy as np 13 | import argparse 14 | import os 15 | import pickle 16 | import random 17 | import sys 18 | import time 19 | import scipy.io as sio 20 | from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report, \ 21 | roc_curve, auc, roc_auc_score 22 | from convnet_3d import CAE3d, stacked_CAE3d 23 | FLOAT_PRECISION = np.float32 24 | 25 | 26 | 27 | def load_batch(batch_idx, num_batches, image_shape, data_dir, data_list): 28 | batch_size, d, c, h, w = image_shape 29 | if batch_idx= save_interval: 206 | filename = 'cae'+str(cae_layer)\ 207 | +'_[act=%s,fn=%d,fs=%d].pkl'%\ 208 | (models[cae_layer-1].activation, models[cae_layer-1].flt_channels, models[cae_layer-1].flt_width) 209 | models[cae_layer-1].save(filename) 210 | print 'model saved to', filename 211 | sys.stdout.flush() 212 | last_save = time.time() 213 | if epoch >= max_epoch-1: 214 | filename = 'cae'+str(cae_layer)\ 215 | +'_[act=%s,fn=%d,fs=%d].pkl'%\ 216 | (models[cae_layer-1].activation, models[cae_layer-1].flt_channels, models[cae_layer-1].flt_width) 217 | models[cae_layer-1].save(filename) 218 | print 'max epoch reached. 
model saved to', filename
            sys.stdout.flush()
            return filename

    except KeyboardInterrupt:
        # On Ctrl-C, checkpoint the current CAE before returning.
        # filename = 'cae_'+time.strftime('%Y%m%d-%H%M%S') + ('-%06d.pkl' % epoch)
        filename = 'cae'+str(cae_layer)\
                   +'_[act=%s,fn=%d,fs=%d].pkl'%\
                   (models[cae_layer-1].activation, models[cae_layer-1].flt_channels, models[cae_layer-1].flt_width)
        models[cae_layer-1].save(filename)
        print 'model saved to', filename
        sys.stdout.flush()
        return filename


def finetune_scae(data_dir, model, binary_classification=(False, False, False, False), max_epoch=1):
    """Fine-tune a stacked_CAE3d classifier on the data in `data_dir`.

    `binary_classification` selects the task as a one-hot-ish tuple
    (AD_Normal, AD_MCI, MCI_Normal, AM_N); all False means 3-way
    AD/MCI/Normal.  Returns the checkpoint filename.
    """
    data_list = os.listdir(data_dir)
    random.shuffle(data_list)
    batch_size, d, c, h, w = model.image_shape
    progress_report = 1
    save_interval = 1800
    num_subjects = len(data_list)
    num_batches = num_subjects/batch_size
    if num_subjects%batch_size !=0:
        num_batches +=1
    last_save = time.time()
    epoch = 0
    AD_Normal, AD_MCI, MCI_Normal, AM_N = binary_classification

    # Pick the checkpoint file name from the selected classification task.
    if True not in binary_classification:
        filename = 'scae.pkl'
        print 'training scae for AD_MCI_Normal'
    elif not MCI_Normal and not AM_N:
        filename = 'scae_%s.pkl'%('AD_Normal' if AD_Normal else 'AD_MCI')
        print 'training scae for %s'%('AD_Normal' if AD_Normal else 'AD_MCI')
    else:
        filename = 'scae_%s.pkl'%('MCI_Normal' if MCI_Normal else 'AM_N')
        print 'training scae for %s'%('MCI_Normal' if MCI_Normal else 'AM_N')

    while True:
        try:
            loss_hist = np.empty((num_batches,), dtype=FLOAT_PRECISION)
            error_hist = np.empty((num_batches,), dtype=FLOAT_PRECISION)
            start_time = time.time()
            for batch in xrange(num_batches):
                # Each task has its own batch loader defined earlier in
                # this module.
                if True not in binary_classification:
                    batch_data, batch_labels, batch_names = load_batch(batch, num_batches, model.image_shape, data_dir,
                                                                       data_list=data_list)
                elif AD_Normal:
                    batch_data, batch_labels, batch_names = load_batch_AD_Normal(batch, num_batches, model.image_shape, data_dir,
                                                                                 data_list=data_list)
                elif AD_MCI:
                    batch_data, batch_labels, batch_names = load_batch_AD_MCI(batch, num_batches, model.image_shape, data_dir,
                                                                              data_list=data_list)
                elif MCI_Normal:
                    batch_data, batch_labels, batch_names = load_batch_MCI_Normal(batch, num_batches, model.image_shape, data_dir,
                                                                                  data_list=data_list)
                elif AM_N:
                    batch_data, batch_labels, batch_names = load_batch_AM_N(batch, num_batches, model.image_shape, data_dir,
                                                                            data_list=data_list)
                start = time.time()

                batch_error, cost, pred, prob = model.train(batch_data, batch_labels)
                loss_hist[batch] = cost
                train_time = time.time()-start
                print
                error_hist[batch] = batch_error
                print 'batch:%02d\terror:%.2f\tcost:%.2f\ttime:%.2f' % (batch, batch_error, cost, train_time/60.)
                print 'subjects:\t',
                for name in batch_names:
                    print name[:-4],
                print
                print 'labels:\t',
                for l in batch_labels:
                    print l,
                print
                print 'pred:\t',
                for p in pred:
                    print p,
                print
                print 'prob:',
                for p in prob:
                    print '%.2f'%p,
                print
                sys.stdout.flush()
            epoch += 1
            if epoch % progress_report == 0:
                print 'epoch:%02d\terror:%g\tloss:%g\ttime:%f' % (epoch, error_hist.mean(), loss_hist.mean(),
                                                                  (time.time()-start_time)/60.)
                sys.stdout.flush()
            # Periodic wall-clock checkpointing (save_interval seconds).
            if time.time() - last_save >= save_interval:
                model.save(filename)
                print 'scae model saved to', filename
                sys.stdout.flush()
                last_save = time.time()
            if epoch >= max_epoch-1:
                model.save(filename)
                print 'max epoch reached. scae model saved to', filename
                sys.stdout.flush()
                return filename
        except KeyboardInterrupt:
            # Checkpoint on Ctrl-C and stop training.
            model.save(filename)
            print 'scae model saved to', filename
            sys.stdout.flush()
            return filename


def finetune_scae_crossvalidate(data_dir, model,
                                binary_classification=(False, False, False, False),
                                max_epoch=1):
    """Fine-tune a stacked_CAE3d with 5-fold cross-validation.

    For each fold the FC head (layers 3..5) is re-initialized and trained
    until the mean epoch error drops below 0.04.  Subjects whose trailing
    two-digit id satisfies id % 5 == fold are held out.
    NOTE(review): the fold loop runs xrange(2, 5), i.e. folds 2..4 only --
    presumably folds 0 and 1 were run separately; confirm.
    """
    data_list = os.listdir(data_dir)
    random.shuffle(data_list)
    batch_size, d, c, h, w = model.image_shape
    progress_report = 1
    save_interval = 1800
    num_subjects = int(4./5*len(data_list))
    num_batches = num_subjects/batch_size
    if num_subjects%batch_size !=0:
        num_batches +=1
    last_save = time.time()
    AD_Normal, AD_MCI, MCI_Normal, AM_N = binary_classification

    for fold in xrange(2,5):
        epoch = 0
        # Reset the trainable head (ip1, ip2, softmax) for this fold.
        model.layers[3].initialize_layer()
        model.layers[4].initialize_layer()
        model.layers[5].initialize_layer()

        data_list_fold = [data for data in data_list if int(data[-6:-4])%5!=fold]
        if True not in binary_classification:
            filename = 'scae_fold%d.pkl'%(fold)
            print 'training scae for AD_MCI_Normal'
        elif not MCI_Normal and not AM_N:
            filename = 'scae_%s_%d.pkl'%('AD_Normal' if AD_Normal else 'AD_MCI', fold)
            print 'training scae for %s, fold %d'%('AD_Normal' if AD_Normal else 'AD_MCI', fold)
        else:
            filename = 'scae_%s_fold%d.pkl'%('MCI_Normal' if MCI_Normal else 'AM_N', fold)
            print 'training scae for %s, fold %d'%('MCI_Normal' if MCI_Normal else 'AM_N', fold)

        error = 1
        # Train this fold until the mean epoch error reaches the threshold.
        while error>0.04:
            try:
                loss_hist = np.empty((num_batches,), dtype=FLOAT_PRECISION)
                error_hist = np.empty((num_batches,), dtype=FLOAT_PRECISION)
                start_time = time.time()
                for batch in xrange(num_batches):
                    if True not in binary_classification:
                        batch_data, batch_labels, batch_names = load_batch(batch, num_batches, model.image_shape, data_dir,
                                                                           data_list=data_list_fold)
                    elif AD_Normal:
                        batch_data, batch_labels, batch_names = load_batch_AD_Normal(batch, num_batches, model.image_shape, data_dir,
                                                                                     data_list=data_list_fold)
                    elif AD_MCI:
                        batch_data, batch_labels, batch_names = load_batch_AD_MCI(batch, num_batches, model.image_shape, data_dir,
                                                                                  data_list=data_list_fold)
                    elif MCI_Normal:
                        batch_data, batch_labels, batch_names = load_batch_MCI_Normal(batch, num_batches, model.image_shape, data_dir,
                                                                                      data_list=data_list_fold)
                    elif AM_N:
                        batch_data, batch_labels, batch_names = load_batch_AM_N(batch, num_batches, model.image_shape, data_dir,
                                                                                data_list=data_list_fold)
                    start = time.time()

                    batch_error, cost, pred, prob = model.train(batch_data, batch_labels)
                    loss_hist[batch] = cost
                    train_time = time.time()-start
                    print
                    error_hist[batch] = batch_error
                    print 'batch:%02d\terror:%.2f\tcost:%.2f\ttime:%.2f' % (batch, batch_error, cost, train_time/60.)
                    print 'subjects:\t',
                    for name in batch_names:
                        print name[:-4],
                    print
                    print 'labels:\t',
                    for l in batch_labels:
                        print l,
                    print
                    print 'pred:\t',
                    for p in pred:
                        print p,
                    print
                    print 'prob:',
                    for p in prob:
                        print '%.2f'%p,
                    print
                    sys.stdout.flush()
                epoch += 1
                error = error_hist.mean()
                if epoch % progress_report == 0:
                    print 'epoch:%02d\terror:%.2f\tloss:%.2f\ttime:%02d min' % (epoch, error_hist.mean(),
                                                                                loss_hist.mean(),
                                                                                (time.time()-start_time)/60.)
                    sys.stdout.flush()
                # Periodic wall-clock checkpointing (save_interval seconds).
                if time.time() - last_save >= save_interval:
                    model.save(filename)
                    print 'scae model fold %d saved to %s'% (fold, filename)
                    sys.stdout.flush()
                    last_save = time.time()
            except KeyboardInterrupt:
                # On Ctrl-C checkpoint and skip to the next fold.
                model.save(filename)
                print 'scae model fold %d saved to %s'% (fold, filename)
                sys.stdout.flush()
                continue
        model.save(filename)
        print 'error threshold reached. scae model fold %d saved to %s' % (fold, filename)
        sys.stdout.flush()
        continue


def get_hidden_data(dir, image_shape, models, layer):
    """Compute hidden activations of pre-trained CAEs for every subject.

    Feeds each class directory (AD/MCI/Normal) through the first `layer`
    CAEs in `models` (1, 2 or 3 stages) and returns a dict mapping the
    class name to the stacked activations; the first 100 subjects are also
    written out as ten .mat files of 10 subjects each.
    """
    hidden_data = {}
    for type in ['AD', 'MCI', 'Normal']:
        print 'get %s hidden activation' % type
        data_dir = dir+type+'/'
        data_list = os.listdir(data_dir)
        data_list.sort()
        num_subjects = len(data_list)
        batch_size = image_shape[0]
        num_batches = num_subjects/batch_size

        # Run one sample batch first just to discover the activation shape.
        sample_batch, _, _ = load_batch(0, num_batches, image_shape, data_dir, data_list=data_list)
        if layer == 1:
            sample_hidden = models[layer-1].get_activation(sample_batch)
        elif layer == 2:
            hidden_batch = models[layer-2].get_activation(sample_batch)
            sample_hidden = models[layer-1].get_activation(hidden_batch)
        else:
            hidden1_batch = models[layer-3].get_activation(sample_batch)
            hidden2_batch = models[layer-2].get_activation(hidden1_batch)
            sample_hidden = models[layer-1].get_activation(hidden2_batch)
        sample_shape = sample_hidden.shape
        _, depth, channel, height, width = sample_shape
        hidden_shape = (len(data_list), depth, channel, height, width)

        hidden_data[type] = np.empty(hidden_shape, dtype=FLOAT_PRECISION)
        for batch in xrange(num_batches):
            print batch
            batch_data, _, _ = load_batch(batch, num_batches, image_shape, data_dir, data_list=data_list)
            if layer == 1:
                batch_hidden = models[layer-1].get_activation(batch_data)
            elif layer == 2:
                batch_hidden1 = models[layer-2].get_activation(batch_data)
                batch_hidden = models[layer-1].get_activation(batch_hidden1)
            else:
                batch_hidden1 = models[layer-3].get_activation(batch_data)
                batch_hidden2 = models[layer-2].get_activation(batch_hidden1)
                batch_hidden = models[layer-1].get_activation(batch_hidden2)
            hidden_data[type][batch*batch_size:(batch+1)*batch_size] = batch_hidden
        # Export the first 100 subjects as ten 10-subject .mat chunks.
        for i in xrange(10):
            filename = '%s_hidden_layer%d_%d.mat' % (type, layer, i)
            sio.savemat(filename, {'hidden_data':hidden_data[type][i*10:(i+1)*10]})
    return hidden_data


def get_hidden_finetuned(dir, model, layer):
    """Same as get_hidden_data, but reads activations from the layers of a
    fine-tuned stacked_CAE3d `model` instead of separate CAE models."""
    hidden_data = {}
    for type in ['AD', 'MCI', 'Normal']:
        print 'get %s hidden activation' % type
        sys.stdout.flush()
        data_dir = dir+type+'/'
        data_list = os.listdir(data_dir)
        data_list.sort()
        num_subjects = len(data_list)
        batch_size = model.image_shape[0]
        num_batches = num_subjects/batch_size

        # Probe one batch to discover the activation shape.
        sample_batch, _, _ = load_batch(0, num_batches, model.image_shape, data_dir, data_list=data_list)
        if layer == 1:
            sample_hidden = model.layers[layer-1].get_activation(sample_batch)
        elif layer == 2:
            hidden_batch = model.layers[layer-2].get_activation(sample_batch)
            sample_hidden = model.layers[layer-1].get_activation(hidden_batch)
        else:
            hidden1_batch = model.layers[layer-3].get_activation(sample_batch)
            hidden2_batch = model.layers[layer-2].get_activation(hidden1_batch)
            sample_hidden = model.layers[layer-1].get_activation(hidden2_batch)
        sample_shape = sample_hidden.shape
        _, depth, channel, height, width = sample_shape
        hidden_shape = (len(data_list), depth, channel, height, width)

        hidden_data[type] = np.empty(hidden_shape, dtype=FLOAT_PRECISION)
        for batch
in xrange(num_batches): 502 | batch_data, _, _ = load_batch(batch, num_batches, model.image_shape, data_dir, data_list=data_list) 503 | start_time = time.time() 504 | if layer == 1: 505 | batch_hidden = model.layers[layer-1].get_activation(batch_data) 506 | elif layer == 2: 507 | batch_hidden1 = model.layers[layer-2].get_activation(batch_data) 508 | batch_hidden = model.layers[layer-1].get_activation(batch_hidden1) 509 | else: 510 | batch_hidden1 = model.layers[layer-3].get_activation(batch_data) 511 | batch_hidden2 = model.layers[layer-2].get_activation(batch_hidden1) 512 | batch_hidden = model.layers[layer-1].get_activation(batch_hidden2) 513 | forward_time = time.time()-start_time 514 | print 'batch:%d\ttime: %.2f min' % (batch, forward_time/60.) 515 | hidden_data[type][batch*batch_size:(batch+1)*batch_size] = batch_hidden 516 | sys.stdout.flush() 517 | for i in xrange(10): 518 | filename = 'hidden_layer%d/%s_hidden_layer%d_%d.mat' % (layer, type, layer, i) 519 | sio.savemat(filename, {'hidden_data':hidden_data[type][i*10:(i+1)*10]}) 520 | sys.stdout.flush() 521 | return hidden_data 522 | 523 | 524 | def ProcessCommandLine(): 525 | parser = argparse.ArgumentParser(description='train scae on alzheimer') 526 | default_image_dir = 'ADNI_original/data/' 527 | parser.add_argument('-I', '--data_dir', default=default_image_dir, 528 | help='location of image files; default=%s' % default_image_dir) 529 | parser.add_argument('-m', '--scae_model', 530 | help='start with this scae model') 531 | parser.add_argument('-cae1', '--cae1_model', 532 | help='Initialize cae1 model') 533 | parser.add_argument('-cae2', '--cae2_model', 534 | help='Initialize cae2 model') 535 | parser.add_argument('-cae3', '--cae3_model', 536 | help='Initialize cae3 model') 537 | parser.add_argument('-ac', '--activation_cae', type=str, default='relu', 538 | help='cae activation function') 539 | parser.add_argument('-af', '--activation_final', type=str, default='relu', 540 | help='final layer activation 
function') 541 | parser.add_argument('-fn', '--filter_channel', type=int, default=[8,8,8], nargs='+', 542 | help='filter channel list') 543 | parser.add_argument('-fs', '--filter_size', type=int, default=3, 544 | help='filter size') 545 | parser.add_argument('-p', '--pretrain_layer', type=int, default=0, 546 | help='pretrain cae layer') 547 | parser.add_argument('-gh', '--get_hidden', type=int, default=0, 548 | help='get hidden layer') 549 | parser.add_argument('-t', '--test', action='store_true', 550 | help='do testing') 551 | parser.add_argument('-ft', '--finetune', action='store_true', 552 | help='do fine tuning') 553 | parser.add_argument('-AN', '--AD_Normal', action='store_true', 554 | help='AD-Normal classification') 555 | parser.add_argument('-AM', '--AD_MCI', action='store_true', 556 | help='AD-MCI classification') 557 | parser.add_argument('-MN', '--MCI_Normal', action='store_true', 558 | help='MCI-Normal classification') 559 | parser.add_argument('-AMN', '--AM_N', action='store_true', 560 | help='AM-Normal classification') 561 | parser.add_argument('-lcn', '--load_conv', action='store_true', 562 | help='load only conv layers') 563 | parser.add_argument('-batch', '--batchsize', type=int, default=1, 564 | help='batch size') 565 | args = parser.parse_args() 566 | return args.data_dir, args.scae_model, args.cae1_model, args.cae2_model, args.cae3_model, args.activation_cae, \ 567 | args.activation_final, \ 568 | args.filter_channel, args.filter_size, args.pretrain_layer, args.get_hidden, args.test, \ 569 | args.finetune, args.AD_Normal, args.AD_MCI, args.MCI_Normal, args.AM_N, args.load_conv, args.batchsize 570 | 571 | 572 | def test_scae(data_dir, model, binary_classification=(False, False, False, False)): 573 | data_list = os.listdir(data_dir) 574 | batch_size, d, c, h, w = model.image_shape 575 | num_subjects = len(data_list) 576 | num_batches = num_subjects/batch_size 577 | if num_subjects%batch_size !=0: 578 | num_batches +=1 579 | AD_Normal, AD_MCI, 
MCI_Normal, AM_N = binary_classification 580 | if True not in binary_classification: 581 | print 'testing scae for AD_MCI_Normal' 582 | elif not MCI_Normal and not AM_N: 583 | print 'testing scae for %s'%('AD_Normal' if AD_Normal else 'AD_MCI') 584 | filename = 'test_%s.pkl'%('AD_Normal' if AD_Normal else 'AD_MCI') 585 | else: 586 | filename = 'test_%s.pkl'%('MCI_Normal' if MCI_Normal else 'AM_N') 587 | print 'testing scae for %s'%('MCI_Normal' if MCI_Normal else 'AM_N') 588 | sys.stdout.flush() 589 | 590 | test_labels, test_names, test_pred, test_prob, test_label_prob= [], [], [], [], [] 591 | num_labels = 2 if True in binary_classification else 3 592 | p_y_given_x = np.empty((num_subjects, num_labels), dtype=FLOAT_PRECISION) 593 | conv2_feat = np.empty((num_subjects, np.prod(model.conv2_output_shape[1:])), dtype=FLOAT_PRECISION) 594 | conv3_feat = np.empty((num_subjects, np.prod(model.conv3_output_shape[1:])), dtype=FLOAT_PRECISION) 595 | ip2_feat = np.empty((num_subjects, 500), dtype=FLOAT_PRECISION) 596 | ip1_feat = np.empty((num_subjects, 2000), dtype=FLOAT_PRECISION) 597 | image_gradient = np.empty((num_subjects, d, c, h, w), dtype=FLOAT_PRECISION) 598 | for batch in xrange(num_batches): 599 | if True not in binary_classification: 600 | batch_data, batch_labels, batch_names = load_batch(batch, num_batches, model.image_shape, data_dir, 601 | data_list=data_list) 602 | elif AD_Normal: 603 | batch_data, batch_labels, batch_names = load_batch_AD_Normal(batch, num_batches, model.image_shape, data_dir, 604 | data_list=data_list) 605 | elif AD_MCI: 606 | batch_data, batch_labels, batch_names = load_batch_AD_MCI(batch, num_batches, model.image_shape, data_dir, 607 | data_list=data_list) 608 | elif MCI_Normal: 609 | batch_data, batch_labels, batch_names = load_batch_MCI_Normal(batch, num_batches, model.image_shape, data_dir, 610 | data_list=data_list) 611 | elif AM_N: 612 | batch_data, batch_labels, batch_names = load_batch_AM_N(batch, num_batches, model.image_shape, 
data_dir, 613 | data_list=data_list) 614 | batch_error, pred, prob, truth_prob, batch_p_y_given_x, batch_conv2_feat, \ 615 | batch_conv3_feat, batch_ip2_feat, batch_ip1_feat, batch_gradient\ 616 | = model.forward(batch_data, batch_labels) 617 | test_labels.extend(batch_labels) 618 | test_names.extend(batch_names) 619 | test_pred.extend(pred) 620 | test_prob.extend(prob) 621 | test_label_prob.extend(truth_prob) 622 | p_y_given_x[batch*batch_size:(batch+1)*batch_size, :] = batch_p_y_given_x 623 | conv2_feat[batch*batch_size:(batch+1)*batch_size, :] = batch_conv2_feat 624 | conv3_feat[batch*batch_size:(batch+1)*batch_size, :] = batch_conv3_feat 625 | ip2_feat[batch*batch_size:(batch+1)*batch_size, :] = batch_ip2_feat 626 | ip1_feat[batch*batch_size:(batch+1)*batch_size, :] = batch_ip1_feat 627 | image_gradient[batch*batch_size:(batch+1)*batch_size, :] = batch_gradient 628 | for i, subject in enumerate(batch_names): 629 | sio.savemat('{0}_gradient.mat'.format(subject[:-4]), {'gradient':batch_gradient[i]}) 630 | 631 | print '\n\nbatch:%02d\terror:%.2f' % (batch, batch_error) 632 | print 'subjects:\t', 633 | for name in batch_names: 634 | print name[:-4], 635 | print 636 | print 'labels:\t', 637 | for l in batch_labels: 638 | print l, 639 | print 640 | print 'pred:\t', 641 | for p in pred: 642 | print p, 643 | print 644 | print 'prob:', 645 | for p in prob: 646 | print '%.2f'%p, 647 | print 648 | sys.stdout.flush() 649 | 650 | accuracy = accuracy_score(test_labels, test_pred) 651 | f_score = f1_score(np.asarray(test_labels), np.asarray(test_pred)) 652 | confusion = confusion_matrix(np.asarray(test_labels), np.asarray(test_pred)) 653 | 654 | print '\n\nAccuracy:%.4f\tF1_Score:%.4f' % (accuracy, f_score) 655 | print '\nconfusion:' 656 | print confusion 657 | 658 | if True not in binary_classification: 659 | class_names = ['AD', 'MCI', 'Normal'] 660 | filename = 'test_AMN.pkl' 661 | elif AD_Normal: 662 | class_names = ['AD', 'Normal'] 663 | filename = 'test_AN.pkl' 664 | 
elif AD_MCI: 665 | class_names = ['AD', 'MCI'] 666 | filename = 'test_AM.pkl' 667 | elif MCI_Normal: 668 | class_names = ['MCI', 'Normal'] 669 | filename = 'test_MN.pkl' 670 | elif AM_N: 671 | class_names = ['AD_MCI', 'Normal'] 672 | filename = 'test_AM_N.pkl' 673 | 674 | results_report = classification_report(test_labels, test_pred, target_names=class_names) 675 | print '\nclassification report:' 676 | print results_report 677 | 678 | results = (test_names, test_labels, test_label_prob, test_pred, test_prob, 679 | p_y_given_x, results_report, class_names) 680 | 681 | f = open(filename, 'wb') 682 | pickle.dump(results, f, -1) 683 | f.close() 684 | f=open('image_gradient.pkl', 'wb') 685 | pickle.dump(image_gradient, f, -1) 686 | f.close() 687 | 688 | 689 | def test_scae_crossvalidate(data_dir, model, binary_classification=(False, False, False, False)): 690 | data_list = os.listdir(data_dir) 691 | batch_size, d, c, h, w = model.image_shape 692 | num_subjects = int(1./5*len(data_list)) 693 | num_batches = num_subjects/batch_size 694 | if num_subjects%batch_size != 0: 695 | num_batches += 1 696 | AD_Normal, AD_MCI, MCI_Normal, AM_N = binary_classification 697 | 698 | for fold in xrange(5): 699 | data_list_fold = [data for data in data_list if int(data[-6:-4])%5==fold] 700 | if True not in binary_classification: 701 | print 'testing scae for fold %d AD_MCI_Normal' % (fold) 702 | elif not MCI_Normal and not AM_N: 703 | print 'testing scae for %s for fold %d'%('AD_Normal' if AD_Normal else 'AD_MCI', fold) 704 | filename = 'scae_%s_fold%d.pkl'%('AD_Normal' if AD_Normal else 'AD_MCI', fold) 705 | else: 706 | filename = 'scae_%s_fold%d.pkl'%('MCI_Normal' if MCI_Normal else 'AM_N', fold) 707 | print 'testing scae for %s for fold %d'%('MCI_Normal' if MCI_Normal else 'AM_N', fold) 708 | model.load(filename) 709 | sys.stdout.flush() 710 | 711 | test_labels, test_names, test_pred, test_prob, test_label_prob= [], [], [], [], [] 712 | num_labels = 2 if True in binary_classification 
else 3 713 | p_y_given_x = np.empty((num_subjects, num_labels), dtype=FLOAT_PRECISION) 714 | conv2_feat = np.empty((num_subjects, np.prod(model.conv2_output_shape[1:])), dtype=FLOAT_PRECISION) 715 | conv3_feat = np.empty((num_subjects, np.prod(model.conv3_output_shape[1:])), dtype=FLOAT_PRECISION) 716 | ip2_feat = np.empty((num_subjects, 500), dtype=FLOAT_PRECISION) 717 | ip1_feat = np.empty((num_subjects, 2000), dtype=FLOAT_PRECISION) 718 | for batch in xrange(num_batches): 719 | if True not in binary_classification: 720 | batch_data, batch_labels, batch_names = load_batch(batch, num_batches, model.image_shape, data_dir, 721 | data_list=data_list_fold) 722 | elif AD_Normal: 723 | batch_data, batch_labels, batch_names = load_batch_AD_Normal(batch, num_batches, model.image_shape, data_dir, 724 | data_list=data_list_fold) 725 | elif AD_MCI: 726 | batch_data, batch_labels, batch_names = load_batch_AD_MCI(batch, num_batches, model.image_shape, data_dir, 727 | data_list=data_list_fold) 728 | elif MCI_Normal: 729 | batch_data, batch_labels, batch_names = load_batch_MCI_Normal(batch, num_batches, model.image_shape, data_dir, 730 | data_list=data_list_fold) 731 | elif AM_N: 732 | batch_data, batch_labels, batch_names = load_batch_AM_N(batch, num_batches, model.image_shape, data_dir, 733 | data_list=data_list_fold) 734 | batch_error, pred, prob, truth_prob, batch_p_y_given_x, batch_conv2_feat, \ 735 | batch_conv3_feat, batch_ip2_feat, batch_ip1_feat, batch_gradient\ 736 | = model.forward(batch_data, batch_labels) 737 | test_labels.extend(batch_labels) 738 | test_names.extend(batch_names) 739 | test_pred.extend(pred) 740 | test_prob.extend(prob) 741 | test_label_prob.extend(truth_prob) 742 | p_y_given_x[batch*batch_size:(batch+1)*batch_size, :] = batch_p_y_given_x 743 | conv2_feat[batch*batch_size:(batch+1)*batch_size, :] = batch_conv2_feat 744 | conv3_feat[batch*batch_size:(batch+1)*batch_size, :] = batch_conv3_feat 745 | ip2_feat[batch*batch_size:(batch+1)*batch_size, :] = 
batch_ip2_feat 746 | ip1_feat[batch*batch_size:(batch+1)*batch_size, :] = batch_ip1_feat 747 | 748 | print '\n\nbatch:%02d\terror:%.2f' % (batch, batch_error) 749 | print 'subjects:\t', 750 | for name in batch_names: 751 | print name[:-4], 752 | print 753 | print 'labels:\t', 754 | for l in batch_labels: 755 | print l, 756 | print 757 | print 'pred:\t', 758 | for p in pred: 759 | print p, 760 | print 761 | print 'prob:', 762 | for p in prob: 763 | print '%.2f'%p, 764 | print 765 | sys.stdout.flush() 766 | 767 | accuracy = accuracy_score(test_labels, test_pred) 768 | f_score = f1_score(np.asarray(test_labels), np.asarray(test_pred)) 769 | confusion = confusion_matrix(np.asarray(test_labels), np.asarray(test_pred)) 770 | computed_auc = roc_auc_score(test_labels, test_pred) 771 | print '\n\nAccuracy:%.4f\tF1_Score:%.4f\tAUC:%.4f' % (accuracy, f_score, computed_auc) 772 | print '\nconfusion:' 773 | print confusion 774 | 775 | if True not in binary_classification: 776 | class_names = ['AD', 'MCI', 'Normal'] 777 | filename = 'test_AMN_fold{0}.pkl'.format(fold) 778 | elif AD_Normal: 779 | class_names = ['AD', 'Normal'] 780 | filename = 'test_AN_fold{0}.pkl'.format(fold) 781 | elif AD_MCI: 782 | class_names = ['AD', 'MCI'] 783 | filename = 'test_AM_fold{0}.pkl'.format(fold) 784 | elif MCI_Normal: 785 | class_names = ['MCI', 'Normal'] 786 | filename = 'test_MN_fold{0}.pkl'.format(fold) 787 | elif AM_N: 788 | class_names = ['AD_MCI', 'Normal'] 789 | filename = 'test_AM_N_fold{0}.pkl'.format(fold) 790 | 791 | results_report = classification_report(test_labels, test_pred, target_names=class_names) 792 | print '\nclassification report:' 793 | print results_report 794 | 795 | results = (test_names, test_labels, test_label_prob, test_pred, test_prob, 796 | p_y_given_x, results_report, class_names) 797 | 798 | f = open(filename, 'wb') 799 | pickle.dump(results, f, -1) 800 | f.close() 801 | 802 | 803 | def main(): 804 | data_dir, scae_model, cae1_model, cae2_model, cae3_model, 
activation_cae, activation_final, \ 805 | flt_channels, flt_size, pretrain_layer, get_hidden, test, finetune, AD_Normal, AD_MCI, MCI_Normal, AM_N, load_conv, batchsize = \ 806 | ProcessCommandLine() 807 | binary = (AD_Normal, AD_MCI, MCI_Normal, AM_N) 808 | print 'cae activation:', activation_cae 809 | print 'final layers activation:', activation_final 810 | print 'filter channels:', flt_channels 811 | print 'filter size:', flt_size 812 | sys.stdout.flush() 813 | data_list = os.listdir(data_dir) 814 | sample = sio.loadmat(data_dir+data_list[0]) 815 | depth, height, width = sample['original'].shape 816 | in_channels = 1 817 | in_time = depth 818 | in_width = width 819 | in_height = height 820 | flt_depth = flt_size 821 | flt_width = flt_size 822 | flt_height = flt_size 823 | 824 | image_shp = (batchsize, in_time, in_channels, in_height, in_width) 825 | filter_shp_1 = (flt_channels[0], flt_depth, in_channels, flt_height, flt_width) 826 | filter_shp_2 = (flt_channels[1], flt_depth, filter_shp_1[0], flt_height, flt_width) 827 | filter_shp_3 = (flt_channels[2], flt_depth, filter_shp_2[0], flt_height, flt_width) 828 | 829 | if not finetune and not test and get_hidden==0: 830 | cae1 = CAE3d(signal_shape=image_shp, 831 | filter_shape=filter_shp_1, 832 | poolsize=(2, 2, 2), 833 | activation=activation_cae) 834 | print 'CAE1 built' 835 | if cae1_model: 836 | cae1.load(cae1_model) 837 | sys.stdout.flush() 838 | 839 | cae2 = CAE3d(signal_shape=cae1.hidden_pooled_image_shape, 840 | filter_shape=filter_shp_2, 841 | poolsize=(2, 2, 2), 842 | activation=activation_cae) 843 | print 'CAE2 built' 844 | if cae2_model: 845 | cae2.load(cae2_model) 846 | sys.stdout.flush() 847 | 848 | cae3 = CAE3d(signal_shape=cae2.hidden_pooled_image_shape, 849 | filter_shape=filter_shp_3, 850 | poolsize=(2, 2, 2), 851 | activation=activation_cae) 852 | print 'CAE3 built' 853 | if cae3_model: 854 | cae3.load(cae3_model) 855 | sys.stdout.flush() 856 | 857 | if pretrain_layer != 0: 858 | cae_models = 
[cae1, cae2, cae3] 859 | do_pretraining_cae(data_dir=data_dir, 860 | models=cae_models, 861 | cae_layer=pretrain_layer, 862 | max_epoch=100) 863 | 864 | 865 | elif finetune or test or get_hidden: 866 | print 'creating scae...' 867 | sys.stdout.flush() 868 | if True not in binary: 869 | scae = stacked_CAE3d(image_shape=image_shp, 870 | filter_shapes=(filter_shp_1, filter_shp_2, filter_shp_3), 871 | poolsize=(2, 2, 2), 872 | activation_cae=activation_cae, 873 | activation_final=activation_final, 874 | hidden_size=(2000, 500, 200, 20, 3)) 875 | else: 876 | scae = stacked_CAE3d(image_shape=image_shp, 877 | filter_shapes=(filter_shp_1, filter_shp_2, filter_shp_3), 878 | poolsize=(2, 2, 2), 879 | activation_cae=activation_cae, 880 | activation_final=activation_final, 881 | hidden_size=(2000, 500, 200, 20, 2)) 882 | 883 | print 'scae model built' 884 | sys.stdout.flush() 885 | if cae1_model: 886 | scae.load_cae(cae1_model, cae_layer=0) 887 | pass 888 | if cae2_model: 889 | scae.load_cae(cae2_model, cae_layer=1) 890 | pass 891 | if cae3_model: 892 | scae.load_cae(cae3_model, cae_layer=2) 893 | pass 894 | sys.stdout.flush() 895 | 896 | if scae_model: 897 | if True in binary and scae_model[:-25] == 'scae': 898 | if load_conv: 899 | scae.load_conv(scae_model) 900 | else: 901 | scae.load_binary(scae_model) 902 | else: 903 | if load_conv: 904 | scae.load_conv(scae_model) 905 | else: 906 | scae.load(scae_model) 907 | pass 908 | sys.stdout.flush() 909 | 910 | if finetune: 911 | finetune_scae_crossvalidate(data_dir=data_dir, 912 | model=scae, 913 | binary_classification=binary, 914 | max_epoch=100) 915 | elif test: 916 | test_scae_crossvalidate(data_dir=data_dir, 917 | model=scae, 918 | binary_classification=binary) 919 | elif get_hidden!=0: 920 | get_hidden_finetuned(dir=data_dir, 921 | model=scae, 922 | layer=get_hidden) 923 | 924 | 925 | if __name__ == '__main__': 926 | sys.exit(main()) 927 | 928 | -------------------------------------------------------------------------------- 
/maxpool3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Max pooling spatio-temporal inputs for Theano 3 | """ 4 | 5 | 6 | from theano import tensor 7 | from theano.tensor.signal.downsample import DownsampleFactorMax 8 | 9 | 10 | def max_pool_3d(input, ds, ignore_border=False): 11 | """ 12 | Takes as input a N-D tensor, where N >= 3. It downscales the input video by 13 | the specified factor, by keeping only the maximum value of non-overlapping 14 | patches of size (ds[0],ds[1],ds[2]) (time, height, width) 15 | :type input: N-D theano tensor of input images. 16 | :param input: input images. Max pooling will be done over the 3 last dimensions. 17 | :type ds: tuple of length 3 18 | :param ds: factor by which to downscale. (2,2,2) will halve the video in each dimension. 19 | :param ignore_border: boolean value. When True, (5,5,5) input with ds=(2,2,2) will generate a 20 | (2,2,2) output. (3,3,3) otherwise. 21 | """ 22 | 23 | if input.ndim < 3: 24 | raise NotImplementedError('max_pool_3d requires a dimension >= 3') 25 | 26 | # extract nr dimensions 27 | vid_dim = input.ndim 28 | # max pool in two different steps, so we can use the 2d implementation of 29 | # downsamplefactormax. First maxpool frames as usual. 30 | # Then maxpool the time dimension. 
Shift the time dimension to the third 31 | # position, so rows and cols are in the back 32 | 33 | # extract dimensions 34 | frame_shape = input.shape[-2:] 35 | 36 | # count the number of "leading" dimensions, store as dmatrix 37 | batch_size = tensor.prod(input.shape[:-2]) 38 | batch_size = tensor.shape_padright(batch_size,1) 39 | 40 | # store as 4D tensor with shape: (batch_size,1,height,width) 41 | new_shape = tensor.cast(tensor.join(0, batch_size, 42 | tensor.as_tensor([1,]), 43 | frame_shape), 'int32') 44 | input_4D = tensor.reshape(input, new_shape, ndim=4) 45 | 46 | # downsample mini-batch of videos in rows and cols 47 | op = DownsampleFactorMax((ds[1],ds[2]), ignore_border) 48 | output = op(input_4D) 49 | # restore to original shape 50 | outshape = tensor.join(0, input.shape[:-2], output.shape[-2:]) 51 | out = tensor.reshape(output, outshape, ndim=input.ndim) 52 | 53 | # now maxpool time 54 | 55 | # output (time, rows, cols), reshape so that time is in the back 56 | shufl = (list(range(vid_dim-3)) + [vid_dim-2]+[vid_dim-1]+[vid_dim-3]) 57 | input_time = out.dimshuffle(shufl) 58 | # reset dimensions 59 | vid_shape = input_time.shape[-2:] 60 | 61 | # count the number of "leading" dimensions, store as dmatrix 62 | batch_size = tensor.prod(input_time.shape[:-2]) 63 | batch_size = tensor.shape_padright(batch_size,1) 64 | 65 | # store as 4D tensor with shape: (batch_size,1,width,time) 66 | new_shape = tensor.cast(tensor.join(0, batch_size, 67 | tensor.as_tensor([1,]), 68 | vid_shape), 'int32') 69 | input_4D_time = tensor.reshape(input_time, new_shape, ndim=4) 70 | # downsample mini-batch of videos in time 71 | op = DownsampleFactorMax((1,ds[0]), ignore_border) 72 | outtime = op(input_4D_time) 73 | # output 74 | # restore to original shape (xxx, rows, cols, time) 75 | outshape = tensor.join(0, input_time.shape[:-2], outtime.shape[-2:]) 76 | shufl = (list(range(vid_dim-3)) + [vid_dim-1]+[vid_dim-3]+[vid_dim-2]) 77 | return tensor.reshape(outtime, outshape, 
ndim=input.ndim).dimshuffle(shufl) --------------------------------------------------------------------------------