├── assets
│   ├── intel.png
│   ├── oatml.png
│   ├── oxcs.png
│   └── turing.png
├── leaderboard
│   ├── diabetic_retinopathy_diagnosis
│   │   ├── auc
│   │   │   ├── mfvi.csv
│   │   │   ├── random.csv
│   │   │   ├── deep_ensemble.csv
│   │   │   ├── deterministic.csv
│   │   │   ├── mc_dropout.csv
│   │   │   └── ensemble_mc_dropout.csv
│   │   ├── accuracy
│   │   │   ├── mfvi.csv
│   │   │   ├── random.csv
│   │   │   ├── deep_ensemble.csv
│   │   │   ├── deterministic.csv
│   │   │   ├── mc_dropout.csv
│   │   │   └── ensemble_mc_dropout.csv
│   │   └── README.md
│   ├── aptos2019
│   │   ├── auc
│   │   │   ├── random.csv
│   │   │   ├── ensemble_mc_dropout.csv
│   │   │   ├── deterministic.csv
│   │   │   ├── mfvi.csv
│   │   │   ├── deep_ensemble.csv
│   │   │   └── mc_dropout.csv
│   │   └── accuracy
│   │       ├── deterministic.csv
│   │       ├── ensemble_mc_dropout.csv
│   │       ├── deep_ensemble.csv
│   │       ├── mc_dropout.csv
│   │       ├── random.csv
│   │       └── mfvi.csv
│   └── pretraining
│       ├── auc
│       │   ├── mc_dropout.csv
│       │   ├── random.csv
│       │   └── deterministic.csv
│       └── accuracy
│           ├── deterministic.csv
│           ├── mc_dropout.csv
│           └── random.csv
├── baselines
│   ├── diabetic_retinopathy_diagnosis
│   │   ├── mfvi
│   │   │   ├── configs
│   │   │   │   ├── medium.cfg
│   │   │   │   └── realworld.cfg
│   │   │   ├── __init__.py
│   │   │   ├── main.py
│   │   │   └── model.py
│   │   ├── mc_dropout
│   │   │   ├── configs
│   │   │   │   ├── medium.cfg
│   │   │   │   └── realworld.cfg
│   │   │   ├── __init__.py
│   │   │   ├── main.py
│   │   │   └── model.py
│   │   ├── deep_ensembles
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   └── main.py
│   │   ├── deterministic
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   └── main.py
│   │   ├── ensemble_mc_dropout
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   └── main.py
│   │   └── README.md
│   └── __init__.py
├── bdlb
│   ├── core
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── benchmark.py
│   │   ├── levels.py
│   │   ├── registered.py
│   │   ├── transforms.py
│   │   └── plotting.py
│   ├── diabetic_retinopathy_diagnosis
│   │   ├── __init__.py
│   │   ├── tfds_adapter.py
│   │   └── benchmark.py
│   └── __init__.py
├── setup.py
├── .gitignore
├── README.md
└── LICENSE
/assets/intel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OATML/bdl-benchmarks/HEAD/assets/intel.png
--------------------------------------------------------------------------------
/assets/oatml.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OATML/bdl-benchmarks/HEAD/assets/oatml.png
--------------------------------------------------------------------------------
/assets/oxcs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OATML/bdl-benchmarks/HEAD/assets/oxcs.png
--------------------------------------------------------------------------------
/assets/turing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OATML/bdl-benchmarks/HEAD/assets/turing.png
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/auc/mfvi.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,86.6,0.9
3 | 0.6,85.4,1.2
4 | 0.7,84.0,1.0
5 | 0.8,83.0,0.9
6 | 0.9,81.8,0.8
7 | 1.0,82.1,1.3
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/auc/random.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,81.8,1.2
3 | 0.6,82.1,1.1
4 | 0.7,82.0,1.3
5 | 0.8,81.8,1.1
6 | 0.9,82.1,1.0
7 | 1.0,82.0,0.9
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/accuracy/mfvi.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,88.1,1.1
3 | 0.6,86.9,0.9
4 | 0.7,85.0,1.0
5 | 0.8,84.5,0.7
6 | 0.9,84.4,0.6
7 | 1.0,84.3,0.7
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/accuracy/random.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,84.8,0.9
3 | 0.6,84.6,0.8
4 | 0.7,84.3,0.7
5 | 0.8,84.6,0.7
6 | 0.9,84.3,0.6
7 | 1.0,84.2,0.5
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/auc/deep_ensemble.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,87.2,0.9
3 | 0.6,86.1,1.1
4 | 0.7,84.9,0.8
5 | 0.8,83.8,0.9
6 | 0.9,82.7,0.8
7 | 1.0,81.8,1.1
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/auc/deterministic.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,84.9,1.1
3 | 0.6,83.4,1.3
4 | 0.7,82.3,1.2
5 | 0.8,81.8,1.2
6 | 0.9,81.7,1.4
7 | 1.0,82.0,1.0
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/auc/mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,87.8,1.1
3 | 0.6,86.2,0.8
4 | 0.7,85.2,0.8
5 | 0.8,84.1,0.7
6 | 0.9,83.0,0.8
7 | 1.0,82.1,1.2
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/accuracy/deep_ensemble.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,89.9,0.9
3 | 0.6,88.3,1.1
4 | 0.7,86.1,1.0
5 | 0.8,85.1,0.9
6 | 0.9,84.7,0.8
7 | 1.0,84.6,0.7
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/accuracy/deterministic.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,86.1,0.6
3 | 0.6,85.2,0.5
4 | 0.7,84.9,0.5
5 | 0.8,84.4,0.5
6 | 0.9,84.3,0.5
7 | 1.0,84.2,0.6
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/accuracy/mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,91.3,0.7
3 | 0.6,88.9,0.9
4 | 0.7,87.1,0.9
5 | 0.8,86.4,0.7
6 | 0.9,85.2,0.8
7 | 1.0,84.5,0.9
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/auc/ensemble_mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,88.1,1.2
3 | 0.6,86.7,1.0
4 | 0.7,85.4,1.0
5 | 0.8,84.4,0.6
6 | 0.9,83.2,0.8
7 | 1.0,82.5,1.1
--------------------------------------------------------------------------------
/leaderboard/diabetic_retinopathy_diagnosis/accuracy/ensemble_mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,92.4,0.9
3 | 0.6,90.2,1.0
4 | 0.7,88.1,1.0
5 | 0.8,87.1,0.9
6 | 0.9,86.7,0.8
7 | 1.0,85.3,1.0
--------------------------------------------------------------------------------
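Note: a brief, hedged sketch of how these leaderboard CSVs could be inspected with pandas (an install requirement listed in setup.py), assuming the code is run from a local checkout of the repository:

import pandas as pd

# Each leaderboard file has columns: retained_data (fraction of the test set kept),
# and the mean and std of the metric (AUC or accuracy) in percent.
df = pd.read_csv("leaderboard/diabetic_retinopathy_diagnosis/auc/mc_dropout.csv")
print(df.sort_values("retained_data"))
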
/baselines/diabetic_retinopathy_diagnosis/mfvi/configs/medium.cfg:
--------------------------------------------------------------------------------
1 | --level=medium
2 | --output_dir=tmp/medium.mfvi
3 | --batch_size=64
4 | --num_epochs=50
5 | --num_base_filters=42
6 | --learning_rate=4e-4
7 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mfvi/configs/realworld.cfg:
--------------------------------------------------------------------------------
1 | --level=realworld
2 | --output_dir=tmp/realworld.mfvi
3 | --batch_size=64
4 | --num_epochs=50
5 | --num_base_filters=42
6 | --learning_rate=4e-4
7 |
--------------------------------------------------------------------------------
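Note: these .cfg files look like absl flagfiles (one --flag=value per line). The following is a minimal, hypothetical sketch of how such a file could be consumed, assuming an entry point defines matching absl flags, as the baseline main.py scripts in this repository do; absl parses --flagfile natively:

from absl import app, flags

# Flags mirroring the entries in medium.cfg above (names taken from the config).
flags.DEFINE_string("level", "medium", "Downstream task level.")
flags.DEFINE_string("output_dir", "tmp/medium.mfvi", "Output directory.")
flags.DEFINE_integer("batch_size", 64, "Batch size.")
flags.DEFINE_integer("num_epochs", 50, "Number of training epochs.")
flags.DEFINE_integer("num_base_filters", 42, "Base number of convolutional filters.")
flags.DEFINE_float("learning_rate", 4e-4, "Optimizer learning rate.")


def main(argv):
  del argv  # unused
  print(flags.FLAGS.level, flags.FLAGS.batch_size, flags.FLAGS.learning_rate)


if __name__ == "__main__":
  # Invoked e.g. with: --flagfile=baselines/diabetic_retinopathy_diagnosis/mfvi/configs/medium.cfg
  app.run(main)
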
/leaderboard/aptos2019/auc/random.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,74.1115282530884,1.2
3 | 0.6,73.98926352594685,1.1
4 | 0.7,74.31505347450982,1.3
5 | 0.8,73.63684036945445,1.1
6 | 0.9,73.70576881316478,1.0
7 | 1.0,74.44604600876292,0.9
8 |
--------------------------------------------------------------------------------
/leaderboard/pretraining/auc/mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,90.50087157323037,1.1
3 | 0.6,89.27465772369939,0.8
4 | 0.7,87.98128411710026,1.0
5 | 0.8,86.90757174461723,0.8
6 | 0.9,85.8656455334597,1.0
7 | 1.0,86.05092622000835,1.2
8 |
--------------------------------------------------------------------------------
/leaderboard/pretraining/auc/random.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,84.69111482898622,1.2
3 | 0.6,84.80000240671389,1.1
4 | 0.7,84.74939574019838,1.3
5 | 0.8,84.64244653558231,1.1
6 | 0.9,84.86491640995733,1.0
7 | 1.0,84.75539927127676,0.9
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/accuracy/deterministic.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,77.5813901979666,1.2
3 | 0.6,76.49447682269734,1.1
4 | 0.7,77.70305205105299,1.1
5 | 0.8,76.14151490098507,1.1
6 | 0.9,76.72487320864245,1.1
7 | 1.0,76.5230841411558,1.2
8 |
--------------------------------------------------------------------------------
/leaderboard/pretraining/auc/deterministic.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,87.7653697027395,1.1
3 | 0.6,86.29581740991412,1.3
4 | 0.7,85.37533171580017,1.2
5 | 0.8,84.60957109993197,1.2
6 | 0.9,84.51952536111276,1.4
7 | 1.0,84.98894869973235,1.0
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/auc/ensemble_mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,82.37751333962768,1.2
3 | 0.6,81.29736671796819,1.0
4 | 0.7,79.82695138171496,1.0
5 | 0.8,79.00445952639875,0.6
6 | 0.9,77.87282433308359,0.8
7 | 1.0,77.21169988706856,1.1
8 |
--------------------------------------------------------------------------------
/leaderboard/pretraining/accuracy/deterministic.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,89.81101313374596,0.6
3 | 0.6,88.20880358604539,0.5
4 | 0.7,88.67156209245978,0.5
5 | 0.8,88.3554739574559,0.5
6 | 0.9,88.26438144807636,0.5
7 | 1.0,87.48948402164251,0.6
8 |
--------------------------------------------------------------------------------
/leaderboard/pretraining/accuracy/mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,95.37552651460484,0.7
3 | 0.6,92.84161334475907,0.9
4 | 0.7,91.12201822951694,0.9
5 | 0.8,89.73503652410145,0.7
6 | 0.9,88.79844947088918,0.8
7 | 1.0,87.85851249188208,0.9
8 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mc_dropout/configs/medium.cfg:
--------------------------------------------------------------------------------
1 | --level=medium
2 | --output_dir=tmp/medium.mc_dropout
3 | --batch_size=64
4 | --num_epochs=50
5 | --num_base_filters=64
6 | --learning_rate=4e-4
7 | --dropout_rate=0.1
8 | --l2_reg=5e-5
9 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/accuracy/ensemble_mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,87.1431711534417,1.3
3 | 0.6,84.93029656023879,1.4
4 | 0.7,82.88382528204772,1.4
5 | 0.8,81.23390412356164,1.3
6 | 0.9,80.87747459924013,1.2
7 | 1.0,79.40399028140253,1.4
8 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mc_dropout/configs/realworld.cfg:
--------------------------------------------------------------------------------
1 | --level=realworld
2 | --output_dir=tmp/realworld.mc_dropout
3 | --batch_size=64
4 | --num_epochs=50
5 | --num_base_filters=64
6 | --learning_rate=4e-4
7 | --dropout_rate=0.1
8 | --l2_reg=5e-5
9 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/accuracy/deep_ensemble.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,84.33166109450175,0.7000000000000001
3 | 0.6,82.4987348870786,0.9
4 | 0.7,80.4561312734997,0.8
5 | 0.8,79.53716354380433,0.7000000000000001
6 | 0.9,79.30688328092553,0.6000000000000001
7 | 1.0,78.8957482417067,0.5
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/accuracy/mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,84.10168287156655,1.5
3 | 0.6,82.37704127549415,1.7000000000000002
4 | 0.7,79.93894164742707,1.7000000000000002
5 | 0.8,80.56345114826215,1.5
6 | 0.9,78.32015795734496,1.6
7 | 1.0,78.6132934251819,1.7000000000000002
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/auc/deterministic.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,76.35435199554918,0.5
3 | 0.6,74.31748839691352,0.7
4 | 0.7,73.55963606499797,0.5999999999999999
5 | 0.8,71.63022802104602,0.5999999999999999
6 | 0.9,73.83468349205513,0.8
7 | 1.0,73.85692043415344,0.39999999999999997
8 |
--------------------------------------------------------------------------------
/leaderboard/pretraining/accuracy/random.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,87.84742483344976,0.6000000000000001
3 | 0.6,88.01581754189371,0.5
4 | 0.7,87.20753593919196,0.39999999999999997
5 | 0.8,87.61709745818298,0.39999999999999997
6 | 0.9,86.91129532409715,0.3
7 | 1.0,87.23000053924596,0.2
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/auc/mfvi.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,77.85178917697054,0.09999999999999998
3 | 0.6,77.86896735637504,0.30000000000000004
4 | 0.7,75.55112560598127,0.0
5 | 0.8,75.2693398565609,0.09999999999999998
6 | 0.9,73.87362012532374,0.0
7 | 1.0,74.5137868570405,0.30000000000000004
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/accuracy/random.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,76.91767480701789,0.3000000000000001
3 | 0.6,76.73097026546617,0.2
4 | 0.7,76.41935840750058,0.10000000000000003
5 | 0.8,75.93500461156489,0.10000000000000003
6 | 0.9,76.26363128215164,0.0
7 | 1.0,76.93718103862291,-0.09999999999999998
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/auc/deep_ensemble.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,80.82331632570184,0.7000000000000001
3 | 0.6,80.61568433971902,0.9
4 | 0.7,78.56652659837609,0.6000000000000001
5 | 0.8,77.46110698253457,0.7000000000000001
6 | 0.9,77.00546441449661,0.6000000000000001
7 | 1.0,76.21950318328321,0.9
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/auc/mc_dropout.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,79.92363409730275,0.30000000000000004
3 | 0.6,78.50683527663989,0.0
4 | 0.7,76.86254965867806,0.19999999999999996
5 | 0.8,74.95801354030414,0.0
6 | 0.9,75.30579744841387,0.19999999999999996
7 | 1.0,74.0241101850024,0.3999999999999999
8 |
--------------------------------------------------------------------------------
/leaderboard/aptos2019/accuracy/mfvi.csv:
--------------------------------------------------------------------------------
1 | retained_data,mean,std
2 | 0.5,81.15193369214985,0.30000000000000004
3 | 0.6,80.18774373202086,0.09999999999999998
4 | 0.7,78.61797765336982,0.19999999999999996
5 | 0.8,76.99668588324471,-0.10000000000000009
6 | 0.9,77.59646926912359,-0.2
7 | 1.0,77.24944531248897,-0.10000000000000009
8 |
--------------------------------------------------------------------------------
/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/bdlb/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/bdlb/diabetic_retinopathy_diagnosis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mfvi/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mc_dropout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/deep_ensembles/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/deterministic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/ensemble_mc_dropout/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
--------------------------------------------------------------------------------
/bdlb/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Global API."""
16 |
17 | from .core.registered import load
18 |
--------------------------------------------------------------------------------
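Note: a minimal usage sketch of the exported `load` function, mirroring how the baseline scripts in this repository use it; the `medium` level and `download_and_prepare=False` are illustrative choices, and the example assumes the data has already been prepared:

import bdlb

# Load the registered benchmark without triggering a dataset download.
dtask = bdlb.load(
    benchmark="diabetic_retinopathy_diagnosis",
    level="medium",
    download_and_prepare=False,
)
ds_train, ds_validation, ds_test = dtask.datasets
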
/bdlb/core/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Default values for some parameters of the API when no values are passed."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 |
23 | # Root directory of the BDL Benchmarks module.
24 | BDLB_ROOT_DIR: str = os.path.abspath(
25 | os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
26 |
27 | # Directory where to store processed datasets.
28 | DATA_DIR: str = os.path.join(BDLB_ROOT_DIR, "data")
29 |
30 | # URL to hosted datasets
31 | DIABETIC_RETINOPATHY_DIAGNOSIS_URL_MEDIUM = "https://drive.google.com/uc?id=1WAvS-pQsVLxUJiClmKLnVNQkoKmRt2I_"
32 |
33 | # URL to hosted assets
34 | LEADERBOARD_DIR_URL: str = "https://drive.google.com/uc?id=1LQeAfqMQa4lot09qAuzWa3t6attmCeG-"
35 |
--------------------------------------------------------------------------------
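Note: a small sketch showing how the defaults above resolve at runtime; the printed paths depend on where the package is checked out, and the example assumes the package and its dependencies are installed:

from bdlb.core import constants

print(constants.BDLB_ROOT_DIR)  # repository root, two levels above bdlb/core
print(constants.DATA_DIR)       # defaults to <repository root>/data
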
/bdlb/core/benchmark.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Data structures and API for general benchmarks."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 |
23 | from .levels import Level
24 |
25 | # Abstract class for benchmark information.
26 | BenchmarkInfo = collections.namedtuple("BenchmarkInfo", [
27 | "description",
28 | "urls",
29 | "setup",
30 | "citation",
31 | ])
32 |
33 | # Container for train, validation and test sets.
34 | DataSplits = collections.namedtuple("DataSplits", [
35 | "train",
36 | "validation",
37 | "test",
38 | ])
39 |
40 |
41 | class Benchmark(object):
42 | """Abstract class for benchmark objects, specifying the core API."""
43 |
44 | def download_and_prepare(self):
45 | """Downloads and prepares necessary datasets for benchmark."""
46 | raise NotImplementedError()
47 |
48 | @property
49 | def info(self) -> BenchmarkInfo:
50 | """Text description of the benchmark."""
51 | raise NotImplementedError()
52 |
53 | @property
54 | def level(self) -> Level:
55 | """The downstream task level."""
56 | raise NotImplementedError()
57 |
--------------------------------------------------------------------------------
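Note: a hedged sketch of how a concrete benchmark could implement this abstract API; the class below is illustrative only and is not the actual diabetic retinopathy benchmark implementation:

from bdlb.core.benchmark import Benchmark, BenchmarkInfo
from bdlb.core.levels import Level


class ToyBenchmark(Benchmark):
  """Illustrative benchmark that fulfils the abstract interface."""

  def __init__(self, level="toy"):
    self._level = Level.from_str(level) if isinstance(level, str) else level

  def download_and_prepare(self):
    """Nothing to download for this toy example."""

  @property
  def info(self):
    return BenchmarkInfo(description="A toy benchmark.",
                         urls="",
                         setup="",
                         citation="")

  @property
  def level(self):
    return self._level
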
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import os
17 |
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | here = os.path.abspath(os.path.dirname(__file__))
22 |
23 | # Get the long description from the README file
24 | with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
25 | long_description = f.read()
26 |
27 | setup(
28 | name="bdlb",
29 | version="0.0.2",
30 | description="BDL Benchmarks",
31 | long_description=long_description,
32 | long_description_content_type="text/markdown",
33 | url="https://github.com/oatml/bdl-benchmarks",
34 | author="Oxford Applied and Theoretical Machine Learning Group",
35 | author_email="oatml@googlegroups.com",
36 | license="Apache-2.0",
37 | packages=find_packages(),
38 | install_requires=[
39 | "numpy==1.18.5",
40 | "scipy==1.4.1",
41 | "pandas==1.0.4",
42 | "matplotlib==3.2.1",
43 | "seaborn==0.10.1",
44 | "scikit-learn==0.21.3",
45 | "kaggle==1.5.6",
46 | "opencv-python==4.2.0.34",
47 | "tensorflow-gpu==2.0.0-beta0",
48 | "tensorflow-probability==0.7.0",
49 | "tensorflow-datasets==1.1.0",
50 | ],
51 | classifiers=[
52 | "Development Status :: 3 - Alpha",
53 |         "Intended Audience :: Science/Research",
54 | "Topic :: Software Development :: Build Tools",
55 |         "License :: OSI Approved :: Apache Software License",
56 | "Programming Language :: Python :: 3.6",
57 | "Programming Language :: Python :: 3.7",
58 | ],
59 | )
60 |
--------------------------------------------------------------------------------
/bdlb/core/levels.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Downstream task levels."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import enum
22 | from typing import Text
23 |
24 |
25 | class Level(enum.IntEnum):
26 | """Downstream task levels.
27 |
28 | TOY: Fewer examples and drastically reduced input dimensionality.
29 | This version is intended for sanity checks and debugging only.
30 | Training on a modern CPU should take five to ten minutes.
31 | MEDIUM: Fewer examples and reduced input dimensionality.
32 | This version is intended for prototyping before moving on to the
33 | real-world scale data. Training on a single modern GPU should take
34 | five or six hours.
35 | REALWORLD: The full dataset and input dimensionality. This version is
36 | intended for the evaluation of proposed techniques at a scale applicable
37 |     to the real world. There are no guidelines on training time for the real-world
38 | version of the task, reflecting the fact that any improvement will translate to
39 | safer, more robust and reliable systems.
40 | """
41 |
42 | TOY = 0
43 | MEDIUM = 1
44 | REALWORLD = 2
45 |
46 | @classmethod
47 | def from_str(cls, strvalue: Text) -> "Level":
48 | """Parses a string value to ``Level``.
49 |
50 | Args:
51 | strvalue: `str`, the level in string format.
52 |
53 | Returns:
54 | The `IntEnum` ``Level`` object.
55 | """
56 | strvalue = strvalue.lower()
57 | if strvalue == "toy":
58 | return cls.TOY
59 | elif strvalue == "medium":
60 | return cls.MEDIUM
61 | elif strvalue == "realworld":
62 | return cls.REALWORLD
63 | else:
64 | raise ValueError(
65 | "Unrecognized level value '{}' provided.".format(strvalue))
66 |
--------------------------------------------------------------------------------
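Note: a minimal usage sketch of the `Level` enum defined above:

from bdlb.core.levels import Level

assert Level.from_str("medium") is Level.MEDIUM
assert Level.TOY < Level.REALWORLD  # IntEnum members support ordering
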
/leaderboard/diabetic_retinopathy_diagnosis/README.md:
--------------------------------------------------------------------------------
1 | # Diabetic Retinopathy Diagnosis
2 | The baseline results we evaluated on this benchmark are ranked below by AUC@50% data retained:
3 |
4 | | Method | AUC<br>(50% data retained) | Accuracy<br>(50% data retained) | AUC<br>(100% data retained) | Accuracy<br>(100% data retained) |
5 | | ------------------- | :-------------------------: | :-----------------------------: | :-------------------------: | :-------------------------------: |
6 | | Ensemble MC Dropout | 88.1 ± 1.2 | 92.4 ± 0.9 | 82.5 ± 1.1 | 85.3 ± 1.0 |
7 | | MC Dropout | 87.8 ± 1.1 | 91.3 ± 0.7 | 82.1 ± 0.9 | 84.5 ± 0.9 |
8 | | Deep Ensembles | 87.2 ± 0.9 | 89.9 ± 0.9 | 81.8 ± 1.1 | 84.6 ± 0.7 |
9 | | Mean-field VI | 86.6 ± 1.1 | 88.1 ± 1.1 | 82.1 ± 1.2 | 84.3 ± 0.7 |
10 | | Deterministic | 84.9 ± 1.1 | 86.1 ± 0.6 | 82.0 ± 1.0 | 84.2 ± 0.6 |
11 | | Random | 81.8 ± 1.2 | 84.8 ± 0.9 | 82.0 ± 0.9 | 84.2 ± 0.5 |
12 |
13 | These results are also plotted below as the area under the receiver-operating characteristic curve (AUC) and the binary accuracy of the different baselines. The methods that capture uncertainty score better when less data is retained, because the least certain patients are referred to expert doctors. The best-scoring methods, _MC Dropout_, _mean-field variational inference_ and _Deep Ensembles_, estimate and use the predictive uncertainty. The deterministic deep model regularized by _standard dropout_ uses only aleatoric uncertainty and performs worse. Shading shows the standard error.
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
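Note: an illustrative, hedged sketch of the retained-data evaluation described above: predictions are ranked by uncertainty and the metric is computed on the most certain fraction of the test set. This is a simplified stand-in, not the benchmark's own `evaluate` implementation:

import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score


def retained_data_curve(y_true, y_prob, uncertainty,
                        fractions=(0.5, 0.6, 0.7, 0.8, 0.9, 1.0)):
  """Computes AUC and accuracy on the given fractions of most certain predictions."""
  order = np.argsort(uncertainty)  # most certain predictions first
  curve = {}
  for fraction in fractions:
    idx = order[:int(len(y_true) * fraction)]
    curve[fraction] = dict(
        auc=roc_auc_score(y_true[idx], y_prob[idx]),
        accuracy=accuracy_score(y_true[idx], (y_prob[idx] >= 0.5).astype(int)),
    )
  return curve
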
/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 | .vscode
6 |
7 | # Icon must end with two \r
8 | Icon
9 |
10 | # Thumbnails
11 | ._*
12 |
13 | # Files that might appear in the root of a volume
14 | .DocumentRevisions-V100
15 | .fseventsd
16 | .Spotlight-V100
17 | .TemporaryItems
18 | .Trashes
19 | .VolumeIcon.icns
20 | .com.apple.timemachine.donotpresent
21 |
22 | # Directories potentially created on remote AFP share
23 | .AppleDB
24 | .AppleDesktop
25 | Network Trash Folder
26 | Temporary Items
27 | .apdisk
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 | MANIFEST
55 |
56 | # PyInstaller
57 | # Usually these files are written by a python script from a template
58 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
59 | *.manifest
60 | *.spec
61 |
62 | # Installer logs
63 | pip-log.txt
64 | pip-delete-this-directory.txt
65 |
66 | # Unit test / coverage reports
67 | htmlcov/
68 | .tox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *.cover
75 | .hypothesis/
76 |
77 | # Translations
78 | *.mo
79 | *.pot
80 |
81 | # Django stuff:
82 | *.log
83 | .static_storage/
84 | .media/
85 | local_settings.py
86 |
87 | # Flask stuff:
88 | instance/
89 | .webassets-cache
90 |
91 | # Scrapy stuff:
92 | .scrapy
93 |
94 | # Sphinx documentation
95 | docs/_build/
96 |
97 | # PyBuilder
98 | target/
99 |
100 | # Jupyter Notebook
101 | .ipynb_checkpoints
102 |
103 | # pyenv
104 | .python-version
105 |
106 | # celery beat schedule file
107 | celerybeat-schedule
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 |
134 | # temporary files
135 | tmp
136 |
137 | # direnv
138 | .envrc
139 |
140 | # Alpha release
141 | assets/*.graffle
142 | data
143 |
144 | # workstation files
145 | .style.yapf
146 | Makefile
147 | matplotlibrc
148 |
149 | # checkpoints
150 | baselines/*/*/checkpoints
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/deterministic/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Uncertainty estimator for the deterministic deep model baseline."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | def predict(x, model, num_samples, type="entropy"):
23 | """Simple sigmoid uncertainty estimator.
24 |
25 | Args:
26 | x: `numpy.ndarray`, datapoints from input space,
27 | with shape [B, H, W, 3], where B the batch size and
28 | H, W the input images' height and width respectively.
29 | model: `tensorflow.keras.Model`, a probabilistic model,
30 | which accepts input with shape [B, H, W, 3] and
31 | outputs sigmoid probability [0.0, 1.0], and also
32 | accepts boolean arguments `training=False` for
33 | disabling dropout at test time.
34 | type: (optional) `str`, type of uncertainty returned,
35 | one of {"entropy", "stddev"}.
36 |
37 | Returns:
38 | mean: `numpy.ndarray`, predictive mean, with shape [B].
39 | uncertainty: `numpy.ndarray`, uncertainty in prediction,
40 | with shape [B].
41 | """
42 | import numpy as np
43 | import scipy.stats
44 |
45 | # Get shapes of data
46 | B, _, _, _ = x.shape
47 |
48 | # Single forward pass from the deterministic model
49 | p = model(x, training=False)
50 |
51 | # Bernoulli output distribution
52 | dist = scipy.stats.bernoulli(p)
53 |
54 | # Predictive mean calculation
55 | mean = dist.mean()
56 |
57 | # Use predictive entropy for uncertainty
58 | if type == "entropy":
59 | uncertainty = dist.entropy()
60 | # Use predictive standard deviation for uncertainty
61 | elif type == "stddev":
62 | uncertainty = dist.std()
63 | else:
64 | raise ValueError(
65 | "Unrecognized type={} provided, use one of {'entropy', 'stddev'}".
66 | format(type))
67 |
68 | return mean, uncertainty
69 |
--------------------------------------------------------------------------------
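Note: a hypothetical usage sketch of `predict` above; the tiny Keras model is a stand-in, not the benchmark's VGG architecture, and `num_samples` is unused by this deterministic baseline:

import numpy as np
import tensorflow as tf

from baselines.diabetic_retinopathy_diagnosis.deterministic.model import predict

# Stand-in binary classifier with a sigmoid output, shaped like the expected input.
toy_model = tf.keras.Sequential([
    tf.keras.layers.GlobalAveragePooling2D(input_shape=(64, 64, 3)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

x = np.random.rand(4, 64, 64, 3).astype("float32")
mean, uncertainty = predict(x, toy_model, num_samples=1, type="entropy")
print(mean.shape, uncertainty.shape)
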
/baselines/diabetic_retinopathy_diagnosis/deep_ensembles/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Uncertainty estimator for the Deep Ensemble baseline."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | def predict(x, models, type="entropy"):
23 | """Deep Ensembles uncertainty estimator.
24 |
25 | Args:
26 | x: `numpy.ndarray`, datapoints from input space,
27 | with shape [B, H, W, 3], where B the batch size and
28 | H, W the input images' height and width respectively.
29 | models: `iterable` of `tensorflow.keras.Model`,
30 | a probabilistic model, which accepts input with
31 | shape [B, H, W, 3] and outputs sigmoid probability
32 | [0.0, 1.0], and also accepts boolean arguments
33 | `training=False` for disabling dropout at test time.
34 | type: (optional) `str`, type of uncertainty returned,
35 | one of {"entropy", "stddev"}.
36 |
37 | Returns:
38 | mean: `numpy.ndarray`, predictive mean, with shape [B].
39 | uncertainty: `numpy.ndarray`, uncertainty in prediction,
40 | with shape [B].
41 | """
42 | import numpy as np
43 | import scipy.stats
44 |
45 | # Get shapes of data
46 | B, _, _, _ = x.shape
47 |
48 | # Monte Carlo samples from different deterministic models
49 | mc_samples = np.asarray([model(x, training=False) for model in models
50 | ]).reshape(-1, B)
51 |
52 | # Bernoulli output distribution
53 | dist = scipy.stats.bernoulli(mc_samples.mean(axis=0))
54 |
55 | # Predictive mean calculation
56 | mean = dist.mean()
57 |
58 | # Use predictive entropy for uncertainty
59 | if type == "entropy":
60 | uncertainty = dist.entropy()
61 | # Use predictive standard deviation for uncertainty
62 | elif type == "stddev":
63 | uncertainty = dist.std()
64 | else:
65 | raise ValueError(
66 | "Unrecognized type={} provided, use one of {'entropy', 'stddev'}".
67 | format(type))
68 |
69 | return mean, uncertainty
70 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/ensemble_mc_dropout/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Uncertainty estimator for the Ensemble Monte Carlo Dropout baseline."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | def predict(x, models, num_samples, type="entropy"):
23 | """Deep Ensembles uncertainty estimator.
24 |
25 | Args:
26 | x: `numpy.ndarray`, datapoints from input space,
27 | with shape [B, H, W, 3], where B the batch size and
28 | H, W the input images' height and width respectively.
29 | num_samples: `int`, number of Monte Carlo samples
30 | (i.e. forward passes from dropout) used for
31 | the calculation of predictive mean and uncertainty.
32 | type: (optional) `str`, type of uncertainty returned,
33 | one of {"entropy", "stddev"}.
34 | models: `iterable` of `tensorflow.keras.Model`,
35 | a probabilistic model, which accepts input with shape
36 | [B, H, W, 3] and outputs sigmoid probability [0.0, 1.0],
37 | and also accepts boolean arguments `training=True` for
38 | enabling dropout at test time.
39 |
40 | Returns:
41 | mean: `numpy.ndarray`, predictive mean, with shape [B].
42 | uncertainty: `numpy.ndarray`, uncertainty in prediction,
43 | with shape [B].
44 | """
45 | import numpy as np
46 | import scipy.stats
47 |
48 | # Get shapes of data
49 | B, _, _, _ = x.shape
50 |
51 | # Monte Carlo samples using different dropout masks at test time, from each model in the ensemble
52 | mc_samples = np.asarray([
53 | model(x, training=True) for _ in range(num_samples) for model in models
54 | ]).reshape(-1, B)
55 |
56 | # Bernoulli output distribution
57 | dist = scipy.stats.bernoulli(mc_samples.mean(axis=0))
58 |
59 | # Predictive mean calculation
60 | mean = dist.mean()
61 |
62 | # Use predictive entropy for uncertainty
63 | if type == "entropy":
64 | uncertainty = dist.entropy()
65 | # Use predictive standard deviation for uncertainty
66 | elif type == "stddev":
67 | uncertainty = dist.std()
68 | else:
69 | raise ValueError(
70 | "Unrecognized type={} provided, use one of {'entropy', 'stddev'}".
71 | format(type))
72 |
73 | return mean, uncertainty
74 |
--------------------------------------------------------------------------------
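Note: a hypothetical usage sketch of the ensemble estimator above, with two small stand-in models in place of the benchmark's trained VGG checkpoints:

import numpy as np
import tensorflow as tf

from baselines.diabetic_retinopathy_diagnosis.ensemble_mc_dropout.model import predict


def make_toy_model():
  # Stand-in model with dropout, so `training=True` samples different masks per pass.
  return tf.keras.Sequential([
      tf.keras.layers.GlobalAveragePooling2D(input_shape=(64, 64, 3)),
      tf.keras.layers.Dropout(0.5),
      tf.keras.layers.Dense(1, activation="sigmoid"),
  ])


models = [make_toy_model() for _ in range(2)]
x = np.random.rand(4, 64, 64, 3).astype("float32")
mean, uncertainty = predict(x, models, num_samples=8, type="stddev")
print(mean.shape, uncertainty.shape)
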
/bdlb/core/registered.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Benchmarks registry handlers and definitions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from typing import Dict
22 | from typing import Optional
23 | from typing import Text
24 | from typing import Union
25 |
26 | from ..core.benchmark import Benchmark
27 | from ..core.levels import Level
28 | from ..diabetic_retinopathy_diagnosis.benchmark import \
29 | DiabeticRetinopathyDiagnosisBecnhmark
30 |
31 | # Internal registry containing the registered benchmarks.
32 | _BENCHMARK_REGISTRY: Dict[Text, Benchmark] = {
33 | "diabetic_retinopathy_diagnosis": DiabeticRetinopathyDiagnosisBecnhmark
34 | }
35 |
36 |
37 | def load(
38 | benchmark: Text,
39 | level: Union[Text, Level] = "realworld",
40 | data_dir: Optional[Text] = None,
41 | download_and_prepare: bool = True,
42 | **dtask_kwargs,
43 | ) -> Benchmark:
44 | """Loads the named benchmark into a `bdlb.Benchmark`.
45 |
46 | Args:
47 | benchmark: `str`, the registered name of the `bdlb.Benchmark`.
48 | level: `bdlb.Level` or `str`, which level of the benchmark to load.
49 | data_dir: `str` (optional), directory to read/write data.
50 | Defaults to "~/.bdlb/data".
51 | download_and_prepare: (optional) `bool`, if the data is not available
52 | it downloads and preprocesses it.
53 | dtask_kwargs: keyword arguments for the benchmark constructor.
54 |
55 | Returns:
56 | A registered `bdlb.Benchmark` with `level` at `data_dir`.
57 |
58 | Raises:
59 | BenchmarkNotFoundError: if `benchmark` is unrecognised.
60 | """
61 | if benchmark not in _BENCHMARK_REGISTRY:
62 | raise BenchmarkNotFoundError(benchmark)
63 | # Fetch benchmark object
64 | return _BENCHMARK_REGISTRY.get(benchmark)(
65 | level=level,
66 | data_dir=data_dir,
67 | download_and_prepare=download_and_prepare,
68 | **dtask_kwargs,
69 | )
70 |
71 |
72 | class BenchmarkNotFoundError(ValueError):
73 | """The requested `bdlb.Benchmark` was not found."""
74 |
75 | def __init__(self, name: Text):
76 | all_benchmarks_str = "\n\t- ".join([""] + list(_BENCHMARK_REGISTRY.keys()))
77 | error_str = (
78 | "Benchmark {name} not found. Available benchmarks: {benchmarks}\n".format(
79 | name=name, benchmarks=all_benchmarks_str))
80 | super(BenchmarkNotFoundError, self).__init__(error_str)
81 |
--------------------------------------------------------------------------------
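Note: a minimal sketch exercising the registry API above, including the error path for an unknown benchmark name:

import bdlb
from bdlb.core.registered import BenchmarkNotFoundError

try:
  bdlb.load(benchmark="not_a_registered_benchmark", download_and_prepare=False)
except BenchmarkNotFoundError as error:
  print(error)
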
/bdlb/core/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Data augmentation and transformations."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from typing import Optional
22 | from typing import Sequence
23 | from typing import Tuple
24 |
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 |
29 | class Transform(object):
30 | """Abstract transformation class."""
31 |
32 | def __call__(
33 | self,
34 | x: tf.Tensor,
35 | y: Optional[tf.Tensor] = None,
36 | ) -> Tuple[tf.Tensor, tf.Tensor]:
37 | raise NotImplementedError()
38 |
39 |
40 | class Compose(Transform):
41 | """Uber transformation, composing a list of transformations."""
42 |
43 | def __init__(self, transforms: Sequence[Transform]):
44 | """Constructs a composition of transformations.
45 |
46 | Args:
47 | transforms: `iterable`, sequence of transformations to be composed.
48 | """
49 | self.transforms = transforms
50 |
51 | def __call__(
52 | self,
53 | x: tf.Tensor,
54 | y: Optional[tf.Tensor] = None,
55 | ) -> Tuple[tf.Tensor, tf.Tensor]:
56 | """Returns a composite function of transformations.
57 |
58 | Args:
59 | x: `any`, raw data format.
60 | y: `optional`, raw data format.
61 |
62 | Returns:
63 | A composite function to be used with `tf.data.Dataset.map()`.
64 | """
65 | for f in self.transforms:
66 | x, y = f(x, y)
67 | return x, y
68 |
69 |
70 | class RandomAugment(Transform):
71 |
72 | def __init__(self, **config):
73 | """Constructs a tranformer from `config`.
74 |
75 | Args:
76 | **config: keyword arguments for
77 | `tensorflow.keras.preprocessing.image.ImageDataGenerator`
78 | """
79 | self.idg = tf.keras.preprocessing.image.ImageDataGenerator(**config)
80 |
81 | def __call__(self, x: tf.Tensor, y: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
82 | """Returns a randomly augmented image and its label.
83 |
84 | Args:
85 | x: `tensorflow.Tensor`, an image, with shape [height, width, channels].
86 | y: `tensorflow.Tensor`, a target, with shape [].
87 |
88 |
89 | Returns:
90 | The processed tuple:
91 | * `x`: `tensorflow.Tensor`, the randomly augmented image,
92 | with shape [height, width, channels].
93 | * `y`: `tensorflow.Tensor`, the unchanged target, with shape [].
94 | """
95 | return tf.py_function(self._transform, inp=[x], Tout=tf.float32), y
96 |
97 | def _transform(self, x: tf.Tensor) -> tf.Tensor:
98 | """Helper function for `tensorflow.py_function`, will be removed when
99 | TensorFlow 2.0 is released."""
100 | return tf.cast(self.idg.random_transform(x.numpy()), tf.float32)
101 |
102 |
103 | class Resize(Transform):
104 |
105 | def __init__(self, target_height: int, target_width: int):
106 | """Constructs an image resizer.
107 |
108 | Args:
109 | target_height: `int`, number of pixels in height.
110 | target_width: `int`, number of pixels in width.
111 | """
112 | self.target_height = target_height
113 | self.target_width = target_width
114 |
115 | def __call__(self, x: tf.Tensor, y: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
116 | """Returns a resized image."""
117 | return tf.image.resize(x, size=[self.target_height, self.target_width]), y
118 |
119 |
120 | class Normalize(Transform):
121 |
122 | def __init__(self, loc: np.ndarray, scale: np.ndarray):
123 | self.loc = loc
124 | self.scale = scale
125 |
126 | def __call__(self, x: tf.Tensor, y: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
127 | return (x - self.loc) / self.scale, y
128 |
--------------------------------------------------------------------------------
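Note: a hedged usage sketch of the transforms above, applied to a toy `tf.data` pipeline; the shapes and normalization constants are illustrative, not the benchmark's preprocessing values:

import numpy as np
import tensorflow as tf

from bdlb.core.transforms import Compose, Normalize, Resize

transform = Compose([
    Resize(target_height=128, target_width=128),
    Normalize(loc=np.float32(127.5), scale=np.float32(127.5)),
])

images = tf.random.uniform((8, 256, 256, 3), maxval=255.0)
labels = tf.zeros((8,), dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).map(transform)
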
/baselines/diabetic_retinopathy_diagnosis/deep_ensembles/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Script for training and evaluating Deep Ensemble baseline for Diabetic
16 | Retinopathy Diagnosis benchmark."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | import os
24 |
25 | import tensorflow as tf
26 | from absl import app
27 | from absl import flags
28 | from absl import logging
29 |
30 | import bdlb
31 | from baselines.diabetic_retinopathy_diagnosis.deep_ensembles.model import \
32 | predict
33 | from baselines.diabetic_retinopathy_diagnosis.mc_dropout.model import VGGDrop
34 | from bdlb.core import plotting
35 |
36 | tfk = tf.keras
37 |
38 | ##########################
39 | # Command line arguments #
40 | ##########################
41 | FLAGS = flags.FLAGS
42 | flags.DEFINE_spaceseplist(
43 | name="model_checkpoints",
44 | default=None,
45 | help="Paths to checkpoints of the models.",
46 | )
47 | flags.DEFINE_string(
48 | name="output_dir",
49 | default="/tmp",
50 | help="Path to store model, tensorboard and report outputs.",
51 | )
52 | flags.DEFINE_enum(
53 | name="level",
54 | default="medium",
55 | enum_values=["realworld", "medium"],
56 | help="Downstream task level, one of {'medium', 'realworld'}.",
57 | )
58 | flags.DEFINE_integer(
59 | name="batch_size",
60 | default=128,
61 | help="Batch size used for training.",
62 | )
63 | flags.DEFINE_integer(
64 | name="num_epochs",
65 | default=50,
66 | help="Number of epochs of training over the whole training set.",
67 | )
68 | flags.DEFINE_enum(
69 | name="uncertainty",
70 | default="entropy",
71 | enum_values=["stddev", "entropy"],
72 | help="Uncertainty type, one of those defined "
73 | "with `estimator` function.",
74 | )
75 | flags.DEFINE_integer(
76 | name="num_base_filters",
77 | default=32,
78 | help="Number of base filters in convolutional layers.",
79 | )
80 | flags.DEFINE_float(
81 | name="learning_rate",
82 | default=4e-4,
83 | help="ADAM optimizer learning rate.",
84 | )
85 | flags.DEFINE_float(
86 | name="dropout_rate",
87 | default=0.1,
88 | help="The rate of dropout, between [0.0, 1.0).",
89 | )
90 | flags.DEFINE_float(
91 | name="l2_reg",
92 | default=5e-5,
93 | help="The L2-regularization coefficient.",
94 | )
95 |
96 |
97 | def main(argv):
98 |
99 | print(argv)
100 | print(FLAGS)
101 |
102 | ###########################
103 | # Hyperparameters & Model #
104 | ###########################
105 | input_shape = dict(medium=(256, 256, 3), realworld=(512, 512, 3))[FLAGS.level]
106 |
107 | hparams = dict(dropout_rate=FLAGS.dropout_rate,
108 | num_base_filters=FLAGS.num_base_filters,
109 | learning_rate=FLAGS.learning_rate,
110 | l2_reg=FLAGS.l2_reg,
111 | input_shape=input_shape)
112 | classifiers = list()
113 | for checkpoint in FLAGS.model_checkpoints:
114 | classifier = VGGDrop(**hparams)
115 | classifier.load_weights(checkpoint)
116 | classifier.summary()
117 | classifiers.append(classifier)
118 |
119 | #############
120 | # Load Task #
121 | #############
122 | dtask = bdlb.load(
123 | benchmark="diabetic_retinopathy_diagnosis",
124 | level=FLAGS.level,
125 | batch_size=FLAGS.batch_size,
126 | download_and_prepare=False, # do not download data from this script
127 | )
128 | _, _, ds_test = dtask.datasets
129 |
130 | ##############
131 | # Evaluation #
132 | ##############
133 | dtask.evaluate(functools.partial(predict,
134 | models=classifiers,
135 | type=FLAGS.uncertainty),
136 | dataset=ds_test,
137 | output_dir=FLAGS.output_dir)
138 |
139 |
140 | if __name__ == "__main__":
141 | app.run(main)
142 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/ensemble_mc_dropout/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Script for training and evaluating Ensemble Monte Carlo Dropout baseline for
16 | Diabetic Retinopathy Diagnosis benchmark."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | import os
24 |
25 | import tensorflow as tf
26 | from absl import app
27 | from absl import flags
28 | from absl import logging
29 |
30 | import bdlb
31 | from baselines.diabetic_retinopathy_diagnosis.ensemble_mc_dropout.model import \
32 | predict
33 | from baselines.diabetic_retinopathy_diagnosis.mc_dropout.model import VGGDrop
34 | from bdlb.core import plotting
35 |
36 | tfk = tf.keras
37 |
38 | ##########################
39 | # Command line arguments #
40 | ##########################
41 | FLAGS = flags.FLAGS
42 | flags.DEFINE_spaceseplist(
43 | name="model_checkpoints",
44 | default=None,
45 | help="Paths to checkpoints of the models.",
46 | )
47 | flags.DEFINE_string(
48 | name="output_dir",
49 | default="/tmp",
50 | help="Path to store model, tensorboard and report outputs.",
51 | )
52 | flags.DEFINE_enum(
53 | name="level",
54 | default="medium",
55 | enum_values=["realworld", "medium"],
56 | help="Downstream task level, one of {'medium', 'realworld'}.",
57 | )
58 | flags.DEFINE_integer(
59 | name="batch_size",
60 | default=128,
61 | help="Batch size used for training.",
62 | )
63 | flags.DEFINE_integer(
64 | name="num_epochs",
65 | default=50,
66 | help="Number of epochs of training over the whole training set.",
67 | )
68 | flags.DEFINE_integer(
69 | name="num_mc_samples",
70 | default=10,
71 | help="Number of Monte Carlo samples used for uncertainty estimation.",
72 | )
73 | flags.DEFINE_enum(
74 | name="uncertainty",
75 | default="entropy",
76 | enum_values=["stddev", "entropy"],
77 | help="Uncertainty type, one of those defined "
78 | "with `estimator` function.",
79 | )
80 | flags.DEFINE_integer(
81 | name="num_base_filters",
82 | default=32,
83 | help="Number of base filters in convolutional layers.",
84 | )
85 | flags.DEFINE_float(
86 | name="learning_rate",
87 | default=4e-4,
88 | help="ADAM optimizer learning rate.",
89 | )
90 | flags.DEFINE_float(
91 | name="dropout_rate",
92 | default=0.1,
93 | help="The rate of dropout, between [0.0, 1.0).",
94 | )
95 | flags.DEFINE_float(
96 | name="l2_reg",
97 | default=5e-5,
98 | help="The L2-regularization coefficient.",
99 | )
100 |
101 |
102 | def main(argv):
103 |
104 | print(argv)
105 | print(FLAGS)
106 |
107 | ###########################
108 | # Hyperparameters & Model #
109 | ###########################
110 | input_shape = dict(medium=(256, 256, 3), realworld=(512, 512, 3))[FLAGS.level]
111 |
112 | hparams = dict(dropout_rate=FLAGS.dropout_rate,
113 | num_base_filters=FLAGS.num_base_filters,
114 | learning_rate=FLAGS.learning_rate,
115 | l2_reg=FLAGS.l2_reg,
116 | input_shape=input_shape)
117 | classifiers = list()
118 | for checkpoint in FLAGS.model_checkpoints:
119 | classifier = VGGDrop(**hparams)
120 | classifier.load_weights(checkpoint)
121 | classifier.summary()
122 | classifiers.append(classifier)
123 |
124 | #############
125 | # Load Task #
126 | #############
127 | dtask = bdlb.load(
128 | benchmark="diabetic_retinopathy_diagnosis",
129 | level=FLAGS.level,
130 | batch_size=FLAGS.batch_size,
131 | download_and_prepare=False, # do not download data from this script
132 | )
133 | _, _, ds_test = dtask.datasets
134 |
135 | ##############
136 | # Evaluation #
137 | ##############
138 | dtask.evaluate(functools.partial(predict,
139 | models=classifiers,
140 | num_samples=FLAGS.num_mc_samples,
141 | type=FLAGS.uncertainty),
142 | dataset=ds_test,
143 | output_dir=FLAGS.output_dir)
144 |
145 |
146 | if __name__ == "__main__":
147 | app.run(main)
148 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/deterministic/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Script for training and evaluating a deterministic baseline for Diabetic
16 | Retinopathy Diagnosis benchmark."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | import os
24 |
25 | import tensorflow as tf
26 | from absl import app
27 | from absl import flags
28 | from absl import logging
29 |
30 | import bdlb
31 | from baselines.diabetic_retinopathy_diagnosis.deterministic.model import \
32 | predict
33 | from baselines.diabetic_retinopathy_diagnosis.mc_dropout.model import VGGDrop
34 | from bdlb.core import plotting
35 |
36 | tfk = tf.keras
37 |
38 | ##########################
39 | # Command line arguments #
40 | ##########################
41 | FLAGS = flags.FLAGS
42 | flags.DEFINE_string(
43 | name="output_dir",
44 | default="/tmp",
45 | help="Path to store model, tensorboard and report outputs.",
46 | )
47 | flags.DEFINE_enum(
48 | name="level",
49 | default="medium",
50 | enum_values=["realworld", "medium"],
51 | help="Downstream task level, one of {'medium', 'realworld'}.",
52 | )
53 | flags.DEFINE_integer(
54 | name="batch_size",
55 | default=128,
56 | help="Batch size used for training.",
57 | )
58 | flags.DEFINE_integer(
59 | name="num_epochs",
60 | default=50,
61 | help="Number of epochs of training over the whole training set.",
62 | )
63 | flags.DEFINE_enum(
64 | name="uncertainty",
65 | default="entropy",
66 | enum_values=["stddev", "entropy"],
67 | help="Uncertainty type, one of those defined "
68 | "with `estimator` function.",
69 | )
70 | flags.DEFINE_integer(
71 | name="num_base_filters",
72 | default=32,
73 | help="Number of base filters in convolutional layers.",
74 | )
75 | flags.DEFINE_float(
76 | name="learning_rate",
77 | default=4e-4,
78 | help="ADAM optimizer learning rate.",
79 | )
80 | flags.DEFINE_float(
81 | name="dropout_rate",
82 | default=0.1,
83 | help="The rate of dropout, between [0.0, 1.0).",
84 | )
85 | flags.DEFINE_float(
86 | name="l2_reg",
87 | default=5e-5,
88 | help="The L2-regularization coefficient.",
89 | )
90 |
91 |
92 | def main(argv):
93 |
94 | print(argv)
95 | print(FLAGS)
96 |
97 | ###########################
98 | # Hyperparameters & Model #
99 | ###########################
100 | input_shape = dict(medium=(256, 256, 3), realworld=(512, 512, 3))[FLAGS.level]
101 |
102 | hparams = dict(dropout_rate=FLAGS.dropout_rate,
103 | num_base_filters=FLAGS.num_base_filters,
104 | learning_rate=FLAGS.learning_rate,
105 | l2_reg=FLAGS.l2_reg,
106 | input_shape=input_shape)
107 | classifier = VGGDrop(**hparams)
108 | classifier.summary()
109 |
110 | #############
111 | # Load Task #
112 | #############
113 | dtask = bdlb.load(
114 | benchmark="diabetic_retinopathy_diagnosis",
115 | level=FLAGS.level,
116 | batch_size=FLAGS.batch_size,
117 | download_and_prepare=False, # do not download data from this script
118 | )
119 | ds_train, ds_validation, ds_test = dtask.datasets
120 |
121 | #################
122 | # Training Loop #
123 | #################
124 | history = classifier.fit(
125 | ds_train,
126 | epochs=FLAGS.num_epochs,
127 | validation_data=ds_validation,
128 | class_weight=dtask.class_weight(),
129 | callbacks=[
130 | tfk.callbacks.TensorBoard(
131 | log_dir=os.path.join(FLAGS.output_dir, "tensorboard"),
132 | update_freq="epoch",
133 | write_graph=True,
134 | histogram_freq=1,
135 | ),
136 | tfk.callbacks.ModelCheckpoint(
137 | filepath=os.path.join(
138 | FLAGS.output_dir,
139 | "checkpoints",
140 | "weights-{epoch}.ckpt",
141 | ),
142 | verbose=1,
143 | save_weights_only=True,
144 | )
145 | ],
146 | )
147 | plotting.tfk_history(history,
148 | output_dir=os.path.join(FLAGS.output_dir, "history"))
149 |
150 | ##############
151 | # Evaluation #
152 | ##############
153 | dtask.evaluate(functools.partial(predict,
154 | model=classifier,
155 | type=FLAGS.uncertainty),
156 | dataset=ds_test,
157 | output_dir=FLAGS.output_dir)
158 |
159 |
160 | if __name__ == "__main__":
161 | app.run(main)
162 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mfvi/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Script for training and evaluating Mean-Field Variational Inference baseline
16 | for Diabetic Retinopathy Diagnosis benchmark."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | import os
24 |
25 | import tensorflow as tf
26 | from absl import app
27 | from absl import flags
28 | from absl import logging
29 |
30 | import bdlb
31 | from bdlb.core import plotting
32 | from baselines.diabetic_retinopathy_diagnosis.mfvi.model import VGGFlipout
33 | from baselines.diabetic_retinopathy_diagnosis.mfvi.model import predict
34 |
35 | tfk = tf.keras
36 |
37 | ##########################
38 | # Command line arguments #
39 | ##########################
40 | FLAGS = flags.FLAGS
41 | flags.DEFINE_string(
42 | name="output_dir",
43 | default="/tmp",
44 | help="Path to store model, tensorboard and report outputs.",
45 | )
46 | flags.DEFINE_enum(
47 | name="level",
48 | default="medium",
49 | enum_values=["realworld", "medium"],
50 | help="Downstream task level, one of {'medium', 'realworld'}.",
51 | )
52 | flags.DEFINE_integer(
53 | name="batch_size",
54 | default=128,
55 | help="Batch size used for training.",
56 | )
57 | flags.DEFINE_integer(
58 | name="num_epochs",
59 | default=50,
60 | help="Number of epochs of training over the whole training set.",
61 | )
62 | flags.DEFINE_integer(
63 | name="num_mc_samples",
64 | default=10,
65 | help="Number of Monte Carlo samples used for uncertainty estimation.",
66 | )
67 | flags.DEFINE_enum(
68 | name="uncertainty",
69 | default="entropy",
70 | enum_values=["stddev", "entropy"],
71 | help="Uncertainty type, one of those defined "
72 | "with `estimator` function.",
73 | )
74 | flags.DEFINE_integer(
75 | name="num_base_filters",
76 | default=32,
77 | help="Number of base filters in convolutional layers.",
78 | )
79 | flags.DEFINE_float(
80 | name="learning_rate",
81 | default=4e-4,
82 | help="ADAM optimizer learning rate.",
83 | )
84 | flags.DEFINE_float(
85 | name="dropout_rate",
86 | default=0.1,
87 | help="The rate of dropout, between [0.0, 1.0).",
88 | )
89 | flags.DEFINE_float(
90 | name="l2_reg",
91 | default=5e-5,
92 | help="The L2-regularization coefficient.",
93 | )
94 |
95 |
96 | def main(argv):
97 |
98 | print(argv)
99 | print(FLAGS)
100 |
101 | ###########################
102 | # Hyperparameters & Model #
103 | ###########################
104 | input_shape = dict(medium=(256, 256, 3), realworld=(512, 512, 3))[FLAGS.level]
105 |
106 | hparams = dict(num_base_filters=FLAGS.num_base_filters,
107 | learning_rate=FLAGS.learning_rate,
108 | input_shape=input_shape)
109 | classifier = VGGFlipout(**hparams)
110 | classifier.summary()
111 |
112 | #############
113 | # Load Task #
114 | #############
115 | dtask = bdlb.load(
116 | benchmark="diabetic_retinopathy_diagnosis",
117 | level=FLAGS.level,
118 | batch_size=FLAGS.batch_size,
119 | download_and_prepare=False, # do not download data from this script
120 | )
121 | ds_train, ds_validation, ds_test = dtask.datasets
122 |
123 | #################
124 | # Training Loop #
125 | #################
126 | history = classifier.fit(
127 | ds_train,
128 | epochs=FLAGS.num_epochs,
129 | validation_data=ds_validation,
130 | class_weight=dtask.class_weight(),
131 | callbacks=[
132 | tfk.callbacks.TensorBoard(
133 | log_dir=os.path.join(FLAGS.output_dir, "tensorboard"),
134 | update_freq="epoch",
135 | write_graph=True,
136 | histogram_freq=1,
137 | ),
138 | tfk.callbacks.ModelCheckpoint(
139 | filepath=os.path.join(
140 | FLAGS.output_dir,
141 | "checkpoints",
142 | "weights-{epoch}.ckpt",
143 | ),
144 | verbose=1,
145 | save_weights_only=True,
146 | )
147 | ],
148 | )
149 | plotting.tfk_history(history,
150 | output_dir=os.path.join(FLAGS.output_dir, "history"))
151 |
152 | ##############
153 | # Evaluation #
154 | ##############
155 | dtask.evaluate(functools.partial(predict,
156 | model=classifier,
157 | num_samples=FLAGS.num_mc_samples,
158 | type=FLAGS.uncertainty),
159 | dataset=ds_test,
160 | output_dir=FLAGS.output_dir)
161 |
162 |
163 | if __name__ == "__main__":
164 | app.run(main)
165 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mc_dropout/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Script for training and evaluating Monte Carlo baseline for Diabetic
16 | Retinopathy Diagnosis benchmark."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import functools
23 | import os
24 |
25 | import tensorflow as tf
26 | from absl import app
27 | from absl import flags
28 | from absl import logging
29 |
30 | import bdlb
31 | from baselines.diabetic_retinopathy_diagnosis.mc_dropout.model import VGGDrop
32 | from baselines.diabetic_retinopathy_diagnosis.mc_dropout.model import predict
33 | from bdlb.core import plotting
34 |
35 | tfk = tf.keras
36 |
37 | ##########################
38 | # Command line arguments #
39 | ##########################
40 | FLAGS = flags.FLAGS
41 | flags.DEFINE_string(
42 | name="output_dir",
43 | default="/tmp",
44 | help="Path to store model, tensorboard and report outputs.",
45 | )
46 | flags.DEFINE_enum(
47 | name="level",
48 | default="medium",
49 | enum_values=["realworld", "medium"],
50 | help="Downstream task level, one of {'medium', 'realworld'}.",
51 | )
52 | flags.DEFINE_integer(
53 | name="batch_size",
54 | default=128,
55 | help="Batch size used for training.",
56 | )
57 | flags.DEFINE_integer(
58 | name="num_epochs",
59 | default=50,
60 | help="Number of epochs of training over the whole training set.",
61 | )
62 | flags.DEFINE_integer(
63 | name="num_mc_samples",
64 | default=10,
65 | help="Number of Monte Carlo samples used for uncertainty estimation.",
66 | )
67 | flags.DEFINE_enum(
68 | name="uncertainty",
69 | default="entropy",
70 | enum_values=["stddev", "entropy"],
71 | help="Uncertainty type, one of those defined "
72 | "with `estimator` function.",
73 | )
74 | flags.DEFINE_integer(
75 | name="num_base_filters",
76 | default=32,
77 | help="Number of base filters in convolutional layers.",
78 | )
79 | flags.DEFINE_float(
80 | name="learning_rate",
81 | default=4e-4,
82 | help="ADAM optimizer learning rate.",
83 | )
84 | flags.DEFINE_float(
85 | name="dropout_rate",
86 | default=0.1,
87 | help="The rate of dropout, between [0.0, 1.0).",
88 | )
89 | flags.DEFINE_float(
90 | name="l2_reg",
91 | default=5e-5,
92 | help="The L2-regularization coefficient.",
93 | )
94 |
95 |
96 | def main(argv):
97 |
98 | print(argv)
99 | print(FLAGS)
100 |
101 | ###########################
102 | # Hyperparameters & Model #
103 | ###########################
104 | input_shape = dict(medium=(256, 256, 3), realworld=(512, 512, 3))[FLAGS.level]
105 |
106 | hparams = dict(dropout_rate=FLAGS.dropout_rate,
107 | num_base_filters=FLAGS.num_base_filters,
108 | learning_rate=FLAGS.learning_rate,
109 | l2_reg=FLAGS.l2_reg,
110 | input_shape=input_shape)
111 | classifier = VGGDrop(**hparams)
112 | classifier.summary()
113 |
114 | #############
115 | # Load Task #
116 | #############
117 | dtask = bdlb.load(
118 | benchmark="diabetic_retinopathy_diagnosis",
119 | level=FLAGS.level,
120 | batch_size=FLAGS.batch_size,
121 | download_and_prepare=False, # do not download data from this script
122 | )
123 | ds_train, ds_validation, ds_test = dtask.datasets
124 |
125 | #################
126 | # Training Loop #
127 | #################
128 | history = classifier.fit(
129 | ds_train,
130 | epochs=FLAGS.num_epochs,
131 | validation_data=ds_validation,
132 | class_weight=dtask.class_weight(),
133 | callbacks=[
134 | tfk.callbacks.TensorBoard(
135 | log_dir=os.path.join(FLAGS.output_dir, "tensorboard"),
136 | update_freq="epoch",
137 | write_graph=True,
138 | histogram_freq=1,
139 | ),
140 | tfk.callbacks.ModelCheckpoint(
141 | filepath=os.path.join(
142 | FLAGS.output_dir,
143 | "checkpoints",
144 | "weights-{epoch}.ckpt",
145 | ),
146 | verbose=1,
147 | save_weights_only=True,
148 | )
149 | ],
150 | )
151 | plotting.tfk_history(history,
152 | output_dir=os.path.join(FLAGS.output_dir, "history"))
153 |
154 | ##############
155 | # Evaluation #
156 | ##############
157 | dtask.evaluate(functools.partial(predict,
158 | model=classifier,
159 | num_samples=FLAGS.num_mc_samples,
160 | type=FLAGS.uncertainty),
161 | dataset=ds_test,
162 | output_dir=FLAGS.output_dir)
163 |
164 |
165 | if __name__ == "__main__":
166 | app.run(main)
167 |
--------------------------------------------------------------------------------
/bdlb/core/plotting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Helper functions for visualization."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | from typing import Dict
23 | from typing import Optional
24 | from typing import Text
25 |
26 | import matplotlib.pyplot as plt
27 | import pandas as pd
28 | import tensorflow as tf
29 |
30 | tfk = tf.keras
31 |
32 |
33 | def tfk_history(
34 | history: tfk.callbacks.History,
35 | output_dir: Optional[Text] = None,
36 | **ax_set_kwargs,
37 | ):
38 | """Visualization of `tensorflow.keras.callbacks.History`, similar to
39 |   `TensorBoard`, for both train and validation metrics.
40 |
41 | Args:
42 | history: `tensorflow.keras.callbacks.History`, the logs of
43 | training a `tensorflow.keras.Model`.
44 | output_dir: (optional) `str`, the directory name to
45 | store the figures.
46 | """
47 | if not isinstance(history, tfk.callbacks.History):
48 | raise TypeError("`history` was expected to be of type "
49 | "`tensorflow.keras.callbacks.History`, "
50 | "but {} was provided.".format(type(history)))
51 | for metric in [k for k in history.history.keys() if not "val_" in k]:
52 | fig, ax = plt.subplots()
53 | ax.plot(history.history.get(metric), label="train")
54 | ax.plot(history.history.get("val_{}".format(metric)), label="validation")
55 | ax.set(title=metric, xlabel="epochs", **ax_set_kwargs)
56 | ax.legend()
57 | fig.tight_layout()
58 | if isinstance(output_dir, str):
59 | os.makedirs(output_dir, exist_ok=True)
60 | fig.savefig(
61 | os.path.join(output_dir, "{}.pdf".format(metric)),
62 |           transparent=True,
63 | dpi=300,
64 | format="pdf",
65 | )
66 | fig.show()
67 |
68 |
69 | def leaderboard(
70 | benchmark: Text,
71 | results: Optional[Dict[Text, pd.DataFrame]] = None,
72 | output_dir: Optional[Text] = None,
73 | leaderboard_dir: Optional[Text] = None,
74 | **ax_set_kwargs,
75 | ):
76 | """Generates a leaderboard for all metrics in `benchmark`, by appending the
77 | (optional) `results`.
78 |
79 | Args:
80 |     benchmark: `str`, the registered name of `bdlb.Benchmark`.
81 | results: (optional) `dict`, dictionary of `pandas.DataFrames`
82 | with the results from a new method to be plotted against
83 | the leaderboard.
84 | leaderboard_dir: (optional) `str`, the path to the parent
85 | directory with all the leaderboard results.
86 | output_dir: (optional) `str`, the directory name to
87 | store the figures.
88 | """
89 | from .constants import BDLB_ROOT_DIR
90 | COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color']
91 | MARKERS = ["o", "D", "s", "8", "^", "*"]
92 |
93 |   # The assumed path for stored baseline records
94 | leaderboard_dir = leaderboard_dir or os.path.join(BDLB_ROOT_DIR,
95 | "leaderboard")
96 | benchmark_dir = os.path.join(leaderboard_dir, benchmark)
97 | if not os.path.exists(benchmark_dir):
98 |     raise ValueError("No leaderboard data found at {}".format(benchmark_dir))
99 |
100 | # Metrics for which values are stored
101 | metrics = [
102 | x for x in os.listdir(benchmark_dir)
103 | if os.path.isdir(os.path.join(benchmark_dir, x))
104 | ]
105 | for metric in metrics:
106 | fig, ax = plt.subplots()
107 | # Iterate over baselines
108 | baselines = [
109 | x for x in os.listdir(os.path.join(benchmark_dir, metric))
110 | if ".csv" in x
111 | ]
112 | for b, baseline in enumerate(baselines):
113 | baseline = baseline.replace(".csv", "")
114 | # Fetch results
115 | df = pd.read_csv(
116 | os.path.join(benchmark_dir, metric, "{}.csv".format(baseline)))
117 | # Parse columns
118 | retained_data = df["retained_data"]
119 | mean = df["mean"]
120 | std = df["std"]
121 | # Visualize mean with standard error
122 | ax.plot(
123 | retained_data,
124 | mean,
125 | label=baseline,
126 | color=COLORS[b % len(COLORS)],
127 | marker=MARKERS[b % len(MARKERS)],
128 | )
129 | ax.fill_between(
130 | retained_data,
131 | mean - std,
132 | mean + std,
133 | color=COLORS[b % len(COLORS)],
134 | alpha=0.25,
135 | )
136 | if results is not None:
137 | # Plot results from dictionary
138 | if metric in results:
139 | df = results[metric]
140 | baseline = df.name if hasattr(df, "name") else "new_method"
141 | # Parse columns
142 | retained_data = df["retained_data"]
143 | mean = df["mean"]
144 | std = df["std"]
145 | # Visualize mean with standard error
146 | ax.plot(
147 | retained_data,
148 | mean,
149 | label=baseline,
150 | color=COLORS[(b + 1) % len(COLORS)],
151 | marker=MARKERS[(b + 1) % len(MARKERS)],
152 | )
153 | ax.fill_between(
154 | retained_data,
155 | mean - std,
156 | mean + std,
157 | color=COLORS[(b + 1) % len(COLORS)],
158 | alpha=0.25,
159 | )
160 | ax.set(xlabel="retained data", ylabel=metric)
161 | ax.legend()
162 | fig.tight_layout()
163 | if isinstance(output_dir, str):
164 | os.makedirs(output_dir, exist_ok=True)
165 | fig.savefig(
166 | os.path.join(output_dir, "{}.pdf".format(metric)),
167 |           transparent=True,
168 | dpi=300,
169 | format="pdf",
170 | )
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bayesian Deep Learning Benchmarks
2 |
3 | **This repository is no longer being updated.**
4 |
5 | Please refer to the [Diabetic Retinopathy Detection implementation in Google's 'uncertainty-baselines' repo](https://github.com/google/uncertainty-baselines/tree/master/baselines/diabetic_retinopathy_detection) for up-to-date baseline implementations.
6 |
7 | ## Overview
8 | In order to make real-world difference with **Bayesian Deep Learning** (BDL) tools, the tools must scale to real-world settings. And for that we, the research community, must be able to evaluate our inference tools (and iterate quickly) with real-world benchmark tasks. We should be able to do this without necessarily worrying about application-specific domain knowledge, like the expertise often required in medical applications for example. We require benchmarks to test for inference robustness, performance, and accuracy, in addition to cost and effort of development. These benchmarks should be at a variety of scales, ranging from toy `MNIST`-scale benchmarks for fast development cycles, to large data benchmarks which are truthful to real-world applications, capturing their constraints.
9 |
10 | Our BDL benchmarks should
11 | * provide a transparent, modular and consistent interface for the evaluation of deep probabilistic models on a variety of _downstream tasks_;
12 | * rely on expert-driven metrics of uncertainty quality (actual applications making use of BDL uncertainty in the real-world), but abstract away the expert-knowledge and eliminate the boilerplate steps necessary for running experiments on real-world datasets;
13 | * make it easy to compare the performance of new models against _well tuned baselines_, models that have been well-adopted by the machine learning community, under a fair and realistic setting (e.g., computational resources, model sizes, datasets);
14 | * provide reference implementations of baseline models (e.g., Monte Carlo Dropout Inference, Mean Field Variational Inference, Deep Ensembles), enabling rapid prototyping and easy development of new tools;
15 | * be independent of specific deep learning frameworks (e.g., not depend on `TensorFlow`, `PyTorch`, etc.), and integrate with the SciPy ecosystem (i.e., `NumPy`, `Pandas`, `Matplotlib`). Benchmarks are framework-agnostic, while baselines are framework-dependent.
16 |
17 | In this repo we strive to provide these much-needed benchmarks for the BDL community, and to collect and maintain new baselines and benchmarks contributed by the community. **A colab notebook demonstrating the MNIST-like workflow of our benchmarks is [available here](notebooks/diabetic_retinopathy_diagnosis.ipynb)**.
18 |
19 | **We highly encourage you to contribute your models as new *baselines* for others to compete against, as well as contribute new *benchmarks* for others to evaluate their models on!**
20 |
21 | ## List of Benchmarks
22 |
23 | **Bayesian Deep Learning Benchmarks** (BDL Benchmarks or `bdlb` for short) is an open-source framework that aims to bridge the gap between the design of deep probabilistic machine learning models and their application to real-world problems. Our currently supported benchmarks are:
24 |
25 | - [x] [Diabetic Retinopathy Diagnosis](baselines/diabetic_retinopathy_diagnosis) (in [`alpha`](https://github.com/OATML/bdl-benchmarks/tree/alpha/), following [Leibig et al.](https://www.nature.com/articles/s41598-017-17876-z))
26 | - [x] [Deterministic](baselines/diabetic_retinopathy_diagnosis/deterministic)
27 | - [x] [Monte Carlo Dropout](baselines/diabetic_retinopathy_diagnosis/mc_dropout) (following [Gal and Ghahramani, 2015](https://arxiv.org/abs/1506.02142))
28 | - [x] [Mean-Field Variational Inference](baselines/diabetic_retinopathy_diagnosis/mfvi) (following [Peterson and Anderson, 1987](https://pdfs.semanticscholar.org/37fa/18c66b8130b9f9748d9c94472c5671fb5622.pdf), [Wen et al., 2018](https://arxiv.org/abs/1803.04386))
29 | - [x] [Deep Ensembles](baselines/diabetic_retinopathy_diagnosis/deep_ensembles) (following [Lakshminarayanan et al., 2016](https://arxiv.org/abs/1612.01474))
30 | - [x] [Ensemble MC Dropout](baselines/diabetic_retinopathy_diagnosis/ensemble_mc_dropout) (following [Smith and Gal, 2018](https://arxiv.org/abs/1803.08533))
31 |
32 | - [ ] Autonomous Vehicle's Scene Segmentation (in `pre-alpha`, following [Mukhoti et al.](https://arxiv.org/abs/1811.12709))
33 | - [ ] Galaxy Zoo (in `pre-alpha`, following [Walmsley et al.](https://arxiv.org/abs/1905.07424))
34 | - [ ] Fishyscapes (in `pre-alpha`, following [Blum et al.](https://arxiv.org/abs/1904.03215))
35 |
36 |
37 | ## Installation
38 |
39 | *BDL Benchmarks* is a pip-installable package (Python 3 compatible), installed with:
40 |
41 | ```
42 | pip3 install git+https://github.com/OATML/bdl-benchmarks.git
43 | ```
44 |
45 | Data downloading and preparation are benchmark-specific; you can follow the relevant guide at `baselines/<benchmark>/README.md` (e.g. [`baselines/diabetic_retinopathy_diagnosis/README.md`](baselines/diabetic_retinopathy_diagnosis/README.md)).
46 |
47 |
48 | ## Examples
49 |
50 | For example, the [Diabetic Retinopathy Diagnosis](baselines/diabetic_retinopathy_diagnosis) benchmark comes with several baselines, including MC Dropout, MFVI, Deep Ensembles, and more. These models are trained with images of blood vessels in the eye:
51 |
52 |
53 |
54 |
55 |
56 | The models try to predict diabetic retinopathy, and use their uncertainty for *prescreening* (sending patients the model is uncertain about to an expert for further examination). When you implement a new model, you can easily benchmark it against the existing baseline results provided in the repo, and generate plots using expert metrics (such as the AUC on the retained data when the 50% most uncertain patients are referred to an expert):
57 |
58 |
59 |
60 |
61 |
62 |
63 |
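As a rough sketch of that workflow (a minimal example, not taken verbatim from the library documentation): `bdlb.load` returns the task together with its data splits, and `dtask.evaluate` expects a prediction function that maps a batch of images to a predictive mean and an uncertainty estimate per image. The random `predict` below is only a stand-in for your own trained model, and `/tmp/my_method` is an arbitrary output path.

```python
import numpy as np
import bdlb

# Load the benchmark task: data splits plus the expert evaluation metrics.
# Assumes the data has already been downloaded and prepared as described in
# the benchmark's README.
dtask = bdlb.load(
    benchmark="diabetic_retinopathy_diagnosis",
    level="medium",  # or "realworld" for 512x512 images
    batch_size=128,
)
ds_train, ds_validation, ds_test = dtask.datasets


def predict(x):
  """Stand-in predictor: replace with your own trained model.

  Must return a predictive mean and an uncertainty estimate,
  each with shape [batch_size].
  """
  batch_size = int(x.shape[0])
  mean = np.random.uniform(low=1e-3, high=1.0 - 1e-3, size=batch_size)
  # Binary predictive entropy as a simple uncertainty measure.
  uncertainty = -(mean * np.log(mean) + (1.0 - mean) * np.log(1.0 - mean))
  return mean, uncertainty


# Evaluate against the benchmark and write plots/reports to `output_dir`.
dtask.evaluate(predict, dataset=ds_test, output_dir="/tmp/my_method")
```

In practice `predict` would wrap a trained Keras model, as the baseline `main.py` scripts in this repository do.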
64 | **You can even play with a [colab notebook](notebooks/diabetic_retinopathy_diagnosis.ipynb) to see the workflow of the benchmark**, and contribute your model for others to benchmark against.
65 |
66 |
67 | ## Cite as
68 |
69 | Please cite individual benchmarks when you use these, as well as the baselines you compare against. For the [Diabetic Retinopathy Diagnosis](baselines/diabetic_retinopathy_diagnosis) benchmark please see [here](baselines/diabetic_retinopathy_diagnosis#cite-as).
70 |
71 | ## Acknowledgements
72 |
73 | The repository is developed and maintained by the [Oxford Applied and Theoretical Machine Learning](http://oatml.cs.ox.ac.uk/) group, with sponsorship from:
74 |
75 |
76 |
77 |  |
78 |  |
79 |  |
80 |  |
81 |
82 |
83 |
84 | ## Contact Us
85 |
86 | Email us for questions at oatml@cs.ox.ac.uk, or submit any issues to improve the framework.
87 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mfvi/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Model definition of the VGGish network for Mean-Field Variational Inference
16 | baseline."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import collections
23 | import functools
24 |
25 |
26 | def VGGFlipout(num_base_filters, learning_rate, input_shape):
27 | """VGG-like model with Flipout for diabetic retinopathy diagnosis.
28 |
29 | Args:
30 | num_base_filters: `int`, number of convolution filters in the
31 | first layer.
32 | learning_rate: `float`, ADAM optimizer learning rate.
33 | input_shape: `iterable`, the shape of the images in the input layer.
34 |
35 | Returns:
36 | A tensorflow.keras.Sequential VGG-like model with flipout.
37 | """
38 | import tensorflow as tf
39 | tfk = tf.keras
40 | tfkl = tfk.layers
41 | import tensorflow_probability as tfp
42 | tfpl = tfp.layers
43 | from bdlb.diabetic_retinopathy_diagnosis.benchmark import DiabeticRetinopathyDiagnosisBecnhmark
44 |
45 | # Feedforward neural network
46 | model = tfk.Sequential([
47 | tfkl.InputLayer(input_shape),
48 | # Block 1
49 | tfpl.Convolution2DFlipout(filters=num_base_filters,
50 | kernel_size=3,
51 | strides=(2, 2),
52 | padding="same"),
53 | tfkl.Activation("relu"),
54 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
55 | # Block 2
56 | tfpl.Convolution2DFlipout(filters=num_base_filters,
57 | kernel_size=3,
58 | strides=(1, 1),
59 | padding="same"),
60 | tfkl.Activation("relu"),
61 | tfpl.Convolution2DFlipout(filters=num_base_filters,
62 | kernel_size=3,
63 | strides=(1, 1),
64 | padding="same"),
65 | tfkl.Activation("relu"),
66 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
67 | # Block 3
68 | tfpl.Convolution2DFlipout(filters=num_base_filters * 2,
69 | kernel_size=3,
70 | strides=(1, 1),
71 | padding="same"),
72 | tfkl.Activation("relu"),
73 | tfpl.Convolution2DFlipout(filters=num_base_filters * 2,
74 | kernel_size=3,
75 | strides=(1, 1),
76 | padding="same"),
77 | tfkl.Activation("relu"),
78 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
79 | # Block 4
80 | tfpl.Convolution2DFlipout(filters=num_base_filters * 4,
81 | kernel_size=3,
82 | strides=(1, 1),
83 | padding="same"),
84 | tfkl.Activation("relu"),
85 | tfpl.Convolution2DFlipout(filters=num_base_filters * 4,
86 | kernel_size=3,
87 | strides=(1, 1),
88 | padding="same"),
89 | tfkl.Activation("relu"),
90 | tfpl.Convolution2DFlipout(filters=num_base_filters * 4,
91 | kernel_size=3,
92 | strides=(1, 1),
93 | padding="same"),
94 | tfkl.Activation("relu"),
95 | tfpl.Convolution2DFlipout(filters=num_base_filters * 4,
96 | kernel_size=3,
97 | strides=(1, 1),
98 | padding="same"),
99 | tfkl.Activation("relu"),
100 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
101 | # Block 5
102 | tfpl.Convolution2DFlipout(filters=num_base_filters * 8,
103 | kernel_size=3,
104 | strides=(1, 1),
105 | padding="same"),
106 | tfkl.Activation("relu"),
107 | tfpl.Convolution2DFlipout(filters=num_base_filters * 8,
108 | kernel_size=3,
109 | strides=(1, 1),
110 | padding="same"),
111 | tfkl.Activation("relu"),
112 | tfpl.Convolution2DFlipout(filters=num_base_filters * 8,
113 | kernel_size=3,
114 | strides=(1, 1),
115 | padding="same"),
116 | tfkl.Activation("relu"),
117 | tfpl.Convolution2DFlipout(filters=num_base_filters * 8,
118 | kernel_size=3,
119 | strides=(1, 1),
120 | padding="same"),
121 | tfkl.Activation("relu"),
122 | # Global poolings
123 | tfkl.Lambda(lambda x: tfk.backend.concatenate(
124 | [tfkl.GlobalAvgPool2D()
125 | (x), tfkl.GlobalMaxPool2D()(x)], axis=1)),
126 | # Fully-connected
127 | tfpl.DenseFlipout(1,),
128 | tfkl.Activation("sigmoid")
129 | ])
130 |
131 | model.compile(loss=DiabeticRetinopathyDiagnosisBecnhmark.loss(),
132 | optimizer=tfk.optimizers.Adam(learning_rate),
133 | metrics=DiabeticRetinopathyDiagnosisBecnhmark.metrics())
134 |
135 | return model
136 |
137 |
138 | def predict(x, model, num_samples, type="entropy"):
139 | """Monte Carlo dropout uncertainty estimator.
140 |
141 | Args:
142 | x: `numpy.ndarray`, datapoints from input space,
143 | with shape [B, H, W, 3], where B the batch size and
144 | H, W the input images height and width accordingly.
145 | model: `tensorflow.keras.Model`, a probabilistic model,
146 | which accepts input with shape [B, H, W, 3] and
147 | outputs sigmoid probability [0.0, 1.0].
148 | num_samples: `int`, number of Monte Carlo samples
149 |       (i.e. stochastic forward passes with sampled weights) used for
150 | the calculation of predictive mean and uncertainty.
151 | type: (optional) `str`, type of uncertainty returns,
152 | one of {"entropy", "stddev"}.
153 |
154 | Returns:
155 | mean: `numpy.ndarray`, predictive mean, with shape [B].
156 |     uncertainty: `numpy.ndarray`, uncertainty in prediction,
157 | with shape [B].
158 | """
159 | import numpy as np
160 | import scipy.stats
161 |
162 | # Get shapes of data
163 | B, _, _, _ = x.shape
164 |
165 |   # Monte Carlo samples with weights sampled from the variational posterior
166 | mc_samples = np.asarray([model(x) for _ in range(num_samples)]).reshape(-1, B)
167 |
168 | # Bernoulli output distribution
169 | dist = scipy.stats.bernoulli(mc_samples.mean(axis=0))
170 |
171 | # Predictive mean calculation
172 | mean = dist.mean()
173 |
174 | # Use predictive entropy for uncertainty
175 | if type == "entropy":
176 | uncertainty = dist.entropy()
177 | # Use predictive standard deviation for uncertainty
178 | elif type == "stddev":
179 | uncertainty = dist.std()
180 | else:
181 | raise ValueError(
182 | "Unrecognized type={} provided, use one of {'entropy', 'stddev'}".
183 | format(type))
184 |
185 | return mean, uncertainty
186 |
--------------------------------------------------------------------------------
/bdlb/diabetic_retinopathy_diagnosis/tfds_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import csv
17 | import io
18 | import os
19 | from typing import Sequence
20 |
21 | import numpy as np
22 | import tensorflow as tf
23 | import tensorflow_datasets as tfds
24 |
25 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
26 |
27 | cv2 = tfds.core.lazy_imports.cv2
28 |
29 |
30 | class DiabeticRetinopathyDiagnosisConfig(tfds.core.BuilderConfig):
31 | """BuilderConfig for DiabeticRetinopathyDiagnosis."""
32 |
33 | def __init__(
34 | self,
35 | target_height: int,
36 | target_width: int,
37 | crop: bool = False,
38 | scale: int = 500,
39 | **kwargs,
40 | ):
41 | """BuilderConfig for DiabeticRetinopathyDiagnosis.
42 |
43 | Args:
44 | target_height: `int`, number of pixels in height.
45 | target_width: `int`, number of pixels in width.
46 | scale: (optional) `int`, the radius of the neighborhood to apply
47 | Gaussian blur filtering.
48 | **kwargs: keyword arguments forward to super.
49 | """
50 | super(DiabeticRetinopathyDiagnosisConfig, self).__init__(**kwargs)
51 | self._target_height = target_height
52 | self._target_width = target_width
53 | self._scale = scale
54 |
55 | @property
56 | def target_height(self) -> int:
57 | return self._target_height
58 |
59 | @property
60 | def target_width(self) -> int:
61 | return self._target_width
62 |
63 | @property
64 | def scale(self) -> int:
65 | return self._scale
66 |
67 |
68 | class DiabeticRetinopathyDiagnosis(tfds.image.DiabeticRetinopathyDetection):
69 |
70 | BUILDER_CONFIGS: Sequence[DiabeticRetinopathyDiagnosisConfig] = [
71 | DiabeticRetinopathyDiagnosisConfig(
72 | name="medium",
73 | version="0.0.1",
74 | description="Images for Medium level.",
75 | target_height=256,
76 | target_width=256,
77 | ),
78 | DiabeticRetinopathyDiagnosisConfig(
79 | name="realworld",
80 | version="0.0.1",
81 | description="Images for RealWorld level.",
82 | target_height=512,
83 | target_width=512,
84 | ),
85 | ]
86 |
87 | def _info(self) -> tfds.core.DatasetInfo:
88 | return tfds.core.DatasetInfo(
89 | builder=self,
90 | description="A large set of high-resolution retina images taken under "
91 | "a variety of imaging conditions. "
92 | "Ehanced contrast and resized to {}x{}.".format(
93 | self.builder_config.target_height,
94 | self.builder_config.target_width),
95 | features=tfds.features.FeaturesDict({
96 | "name":
97 | tfds.features.Text(), # patient ID + eye. eg: "4_left".
98 | "image":
99 | tfds.features.Image(shape=(
100 | self.builder_config.target_height,
101 | self.builder_config.target_width,
102 | 3,
103 | )),
104 | # 0: (no DR)
105 | # 1: (with DR)
106 | "label":
107 | tfds.features.ClassLabel(num_classes=2),
108 | }),
109 | urls=["https://www.kaggle.com/c/diabetic-retinopathy-detection/data"],
110 | citation=tfds.image.diabetic_retinopathy_detection._CITATION,
111 | )
112 |
113 | def _generate_examples(self, images_dir_path, csv_path=None, csv_usage=None):
114 | """Yields Example instances from given CSV. Applies contrast enhancement as
115 | in https://github.com/btgraham/SparseConvNet/tree/kaggle_Diabetic_Retinopat
116 | hy_competition. Turns the multiclass (i.e. 5 classes) problem to binary
117 | classification according to
118 | https://www.nature.com/articles/s41598-017-17876-z.pdf.
119 |
120 | Args:
121 | images_dir_path: path to dir in which images are stored.
122 | csv_path: optional, path to csv file with two columns: name of image and
123 | label. If not provided, just scan image directory, don't set labels.
124 | csv_usage: optional, subset of examples from the csv file to use based on
125 | the "Usage" column from the csv.
126 | """
127 | if csv_path:
128 | with tf.io.gfile.GFile(csv_path) as csv_f:
129 | reader = csv.DictReader(csv_f)
130 | data = [(row["image"], int(row["level"]))
131 | for row in reader
132 | if csv_usage is None or row["Usage"] == csv_usage]
133 | else:
134 | data = [(fname[:-5], -1)
135 | for fname in tf.io.gfile.listdir(images_dir_path)
136 | if fname.endswith(".jpeg")]
137 | for name, label in data:
138 | record = {
139 | "name":
140 | name,
141 | "image":
142 | self._preprocess(
143 | tf.io.gfile.GFile("%s/%s.jpeg" % (images_dir_path, name),
144 | mode="rb"),
145 | target_height=self.builder_config.target_height,
146 | target_width=self.builder_config.target_width,
147 | ),
148 | "label":
149 | int(label > 1),
150 | }
151 |
152 | yield record
153 |
154 | @classmethod
155 | def _preprocess(
156 | cls,
157 | image_fobj,
158 | target_height: int,
159 | target_width: int,
160 | crop: bool = False,
161 | scale: int = 500,
162 | ) -> io.BytesIO:
163 | """Resize an image to have (roughly) the given number of target pixels.
164 |
165 | Args:
166 | image_fobj: File object containing the original image.
167 | target_height: `int`, number of pixels in height.
168 | target_width: `int`, number of pixels in width.
169 |       crop: (optional) `bool`, if True crops the centre of the original
170 |         image to the target size.
171 | scale: (optional) `int`, the radius of the neighborhood to apply
172 | Gaussian blur filtering.
173 |
174 | Returns:
175 | A file object.
176 | """
177 | # Decode image using OpenCV2.
178 | image = cv2.imdecode(np.fromstring(image_fobj.read(), dtype=np.uint8),
179 | flags=3)
180 | try:
181 | a = cls._get_radius(image, scale)
182 | b = np.zeros(a.shape)
183 | cv2.circle(img=b,
184 | center=(a.shape[1] // 2, a.shape[0] // 2),
185 | radius=int(scale * 0.9),
186 | color=(1, 1, 1),
187 | thickness=-1,
188 | lineType=8,
189 | shift=0)
190 | image = cv2.addWeighted(src1=a,
191 | alpha=4,
192 | src2=cv2.GaussianBlur(
193 | src=a, ksize=(0, 0), sigmaX=scale // 30),
194 | beta=-4,
195 | gamma=128) * b + 128 * (1 - b)
196 | except cv2.error:
197 | pass
198 |     # Resize image to target size; cv2.resize expects dsize as (width, height)
199 |     image = cv2.resize(image, (target_width, target_height))
200 | # Encode the image with quality=72 and store it in a BytesIO object.
201 | _, buff = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 72])
202 | return io.BytesIO(buff.tostring())
203 |
204 | @staticmethod
205 | def _get_radius(img: np.ndarray, scale: int) -> np.ndarray:
206 | """Returns radius of the circle to use.
207 |
208 | Args:
209 | img: `numpy.ndarray`, an image, with shape [height, width, 3].
210 | scale: `int`, the radius of the neighborhood.
211 |
212 | Returns:
213 | A resized image.
214 | """
215 | x = img[img.shape[0] // 2, ...].sum(axis=1)
216 | r = 0.5 * (x > x.mean() // 10).sum()
217 | s = scale * 1.0 / r
218 | return cv2.resize(img, (0, 0), fx=s, fy=s)
219 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/mc_dropout/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Model definition of the VGGish network for Monte Carlo Dropout baseline."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 |
22 | def VGGDrop(dropout_rate, num_base_filters, learning_rate, l2_reg, input_shape):
23 | """VGG-like model with dropout for diabetic retinopathy diagnosis.
24 |
25 | Args:
26 | dropout_rate: `float`, the rate of dropout, between [0.0, 1.0).
27 | num_base_filters: `int`, number of convolution filters in the
28 | first layer.
29 | learning_rate: `float`, ADAM optimizer learning rate.
30 | l2_reg: `float`, the L2-regularization coefficient.
31 | input_shape: `iterable`, the shape of the images in the input layer.
32 |
33 | Returns:
34 | A tensorflow.keras.Sequential VGG-like model with dropout.
35 | """
36 | import tensorflow as tf
37 | tfk = tf.keras
38 | tfkl = tfk.layers
39 | from bdlb.diabetic_retinopathy_diagnosis.benchmark import DiabeticRetinopathyDiagnosisBecnhmark
40 |
41 | # Feedforward neural network
42 | model = tfk.Sequential([
43 | tfkl.InputLayer(input_shape),
44 | # Block 1
45 | tfkl.Conv2D(filters=num_base_filters,
46 | kernel_size=3,
47 | strides=(2, 2),
48 | padding="same",
49 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
50 | tfkl.Activation("relu"),
51 | tfkl.Dropout(dropout_rate),
52 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
53 | # Block 2
54 | tfkl.Conv2D(filters=num_base_filters,
55 | kernel_size=3,
56 | strides=(1, 1),
57 | padding="same",
58 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
59 | tfkl.Activation("relu"),
60 | tfkl.Dropout(dropout_rate),
61 | tfkl.Conv2D(filters=num_base_filters,
62 | kernel_size=3,
63 | strides=(1, 1),
64 | padding="same",
65 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
66 | tfkl.Activation("relu"),
67 | tfkl.Dropout(dropout_rate),
68 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
69 | # Block 3
70 | tfkl.Conv2D(filters=num_base_filters * 2,
71 | kernel_size=3,
72 | strides=(1, 1),
73 | padding="same",
74 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
75 | tfkl.Activation("relu"),
76 | tfkl.Dropout(dropout_rate),
77 | tfkl.Conv2D(filters=num_base_filters * 2,
78 | kernel_size=3,
79 | strides=(1, 1),
80 | padding="same",
81 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
82 | tfkl.Activation("relu"),
83 | tfkl.Dropout(dropout_rate),
84 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
85 | # Block 4
86 | tfkl.Conv2D(filters=num_base_filters * 4,
87 | kernel_size=3,
88 | strides=(1, 1),
89 | padding="same",
90 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
91 | tfkl.Activation("relu"),
92 | tfkl.Dropout(dropout_rate),
93 | tfkl.Conv2D(filters=num_base_filters * 4,
94 | kernel_size=3,
95 | strides=(1, 1),
96 | padding="same",
97 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
98 | tfkl.Activation("relu"),
99 | tfkl.Dropout(dropout_rate),
100 | tfkl.Conv2D(filters=num_base_filters * 4,
101 | kernel_size=3,
102 | strides=(1, 1),
103 | padding="same",
104 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
105 | tfkl.Activation("relu"),
106 | tfkl.Dropout(dropout_rate),
107 | tfkl.Conv2D(filters=num_base_filters * 4,
108 | kernel_size=3,
109 | strides=(1, 1),
110 | padding="same",
111 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
112 | tfkl.Activation("relu"),
113 | tfkl.Dropout(dropout_rate),
114 | tfkl.MaxPooling2D(pool_size=3, strides=(2, 2), padding="same"),
115 | # Block 5
116 | tfkl.Conv2D(filters=num_base_filters * 8,
117 | kernel_size=3,
118 | strides=(1, 1),
119 | padding="same",
120 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
121 | tfkl.Activation("relu"),
122 | tfkl.Dropout(dropout_rate),
123 | tfkl.Conv2D(filters=num_base_filters * 8,
124 | kernel_size=3,
125 | strides=(1, 1),
126 | padding="same",
127 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
128 | tfkl.Activation("relu"),
129 | tfkl.Dropout(dropout_rate),
130 | tfkl.Conv2D(filters=num_base_filters * 8,
131 | kernel_size=3,
132 | strides=(1, 1),
133 | padding="same",
134 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
135 | tfkl.Activation("relu"),
136 | tfkl.Dropout(dropout_rate),
137 | tfkl.Conv2D(filters=num_base_filters * 8,
138 | kernel_size=3,
139 | strides=(1, 1),
140 | padding="same",
141 | kernel_regularizer=tfk.regularizers.l2(l2_reg)),
142 | tfkl.Activation("relu"),
143 | # Global poolings
144 | tfkl.Lambda(lambda x: tfk.backend.concatenate(
145 | [tfkl.GlobalAvgPool2D()
146 | (x), tfkl.GlobalMaxPool2D()(x)], axis=1)),
147 | # Fully-connected
148 | tfkl.Dense(1, kernel_regularizer=tfk.regularizers.l2(l2_reg)),
149 | tfkl.Activation("sigmoid")
150 | ])
151 |
152 | model.compile(loss=DiabeticRetinopathyDiagnosisBecnhmark.loss(),
153 | optimizer=tfk.optimizers.Adam(learning_rate),
154 | metrics=DiabeticRetinopathyDiagnosisBecnhmark.metrics())
155 |
156 | return model
157 |
158 |
159 | def predict(x, model, num_samples, type="entropy"):
160 | """Monte Carlo dropout uncertainty estimator.
161 |
162 | Args:
163 | x: `numpy.ndarray`, datapoints from input space,
164 | with shape [B, H, W, 3], where B the batch size and
165 | H, W the input images height and width accordingly.
166 | model: `tensorflow.keras.Model`, a probabilistic model,
167 | which accepts input with shape [B, H, W, 3] and
168 | outputs sigmoid probability [0.0, 1.0], and also
169 | accepts boolean arguments `training=True` for enabling
170 | dropout at test time.
171 | num_samples: `int`, number of Monte Carlo samples
172 | (i.e. forward passes from dropout) used for
173 | the calculation of predictive mean and uncertainty.
174 | type: (optional) `str`, type of uncertainty returns,
175 | one of {"entropy", "stddev"}.
176 |
177 | Returns:
178 | mean: `numpy.ndarray`, predictive mean, with shape [B].
179 |     uncertainty: `numpy.ndarray`, uncertainty in prediction,
180 | with shape [B].
181 | """
182 | import numpy as np
183 | import scipy.stats
184 |
185 | # Get shapes of data
186 | B, _, _, _ = x.shape
187 |
188 |   # Monte Carlo samples from different dropout masks at test time
189 | mc_samples = np.asarray([model(x, training=True) for _ in range(num_samples)
190 | ]).reshape(-1, B)
191 |
192 | # Bernoulli output distribution
193 | dist = scipy.stats.bernoulli(mc_samples.mean(axis=0))
194 |
195 | # Predictive mean calculation
196 | mean = dist.mean()
197 |
198 | # Use predictive entropy for uncertainty
199 | if type == "entropy":
200 | uncertainty = dist.entropy()
201 | # Use predictive standard deviation for uncertainty
202 | elif type == "stddev":
203 | uncertainty = dist.std()
204 | else:
205 | raise ValueError(
206 | "Unrecognized type={} provided, use one of {'entropy', 'stddev'}".
207 | format(type))
208 |
209 | return mean, uncertainty
210 |
--------------------------------------------------------------------------------
/baselines/diabetic_retinopathy_diagnosis/README.md:
--------------------------------------------------------------------------------
1 | # Diabetic Retinopathy Diagnosis
2 |
3 | Machine learning researchers often evaluate their predictions directly on the whole test set.
4 | In real-world settings, however, we have additional choices available, such as asking for more information when we are uncertain.
5 | Because of the importance of accurate diagnosis, it would be unreasonable _not_ to ask for further scans of the most ambiguous cases.
6 | Moreover, in this dataset, many images feature camera artefacts that distort results.
7 | In these cases, it is critically important that a model is able to tell when the information provided to it is not sufficiently reliable to classify the patient.
8 | Just like real medical professionals, any diagnostic algorithm should be able to flag cases that require more investigation by medical experts.
9 |
10 |
11 |
12 |
13 |
14 | This task is illustrated in the figure above, where a threshold, `τ`, is used to flag cases as certain or uncertain, with uncertain cases referred to an expert. Alternatively, the uncertainty estimates could be used to produce a priority list, matched to the available resources of a hospital, rather than wasting diagnostic resources on patients for whom the diagnosis is clear cut.
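
A minimal sketch of this referral rule, assuming `y_pred` holds sigmoid probabilities and `y_uncertainty` the matching uncertainty estimates (both flat NumPy arrays; the function name is illustrative):

```
import numpy as np

def refer_by_threshold(y_pred, y_uncertainty, tau):
    """Split a batch into model-diagnosed cases and expert referrals."""
    refer = y_uncertainty > tau  # cases too uncertain for an automatic diagnosis
    auto_diagnosis = (y_pred[~refer] >= 0.5).astype(int)  # model decides the rest
    return auto_diagnosis, np.flatnonzero(refer)  # predictions + indices to refer
```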
15 |
16 | In order to simulate this process of referring the uncertain cases to experts and relying on the model's predictions for cases it is certain of, we assess the techniques by their diagnostic accuracy and area under the receiver-operating-characteristic (ROC) curve, as a function of the
17 | referral rate. We expect the models with well-calibrated uncertainty to refer their least confident predictions to experts,
18 | improving their performance as the number of referrals increases.
19 |
20 | The accuracy of the binary classifier is defined as the ratio of correctly classified data-points over the size of the population.
21 | The receiver-operating-characteristic (ROC) curve illustrates
22 | the diagnostic ability of a binary classifier system as its discrimination threshold is varied.
23 | It is created by plotting the true positive rate (a.k.a. sensitivity) against the false positive rate (a.k.a. 1 - specificity).
24 | The quality of such a ROC curve can be summarized by its area under the curve (AUC), which varies between 0.5 (chance level) and 1.0 (best possible value).
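
The sketch below mirrors this evaluation (and the benchmark's `_evaluate_metric` helper): sort the test predictions by uncertainty and recompute accuracy and AUC on the most certain fraction of the data. It uses `sklearn` metrics in place of the Keras metrics the benchmark uses internally; `y_true`, `y_pred` and `y_uncertainty` are assumed to be flat NumPy arrays.

```
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

def metric_vs_retained(y_true, y_pred, y_uncertainty,
                       fractions=(0.5, 0.6, 0.7, 0.8, 0.9, 1.0)):
    """Score the most certain fraction of the test set at each retained-data level."""
    order = np.argsort(y_uncertainty)  # most certain predictions first
    scores = {}
    for frac in fractions:
        keep = order[:int(len(y_true) * frac)]
        scores[frac] = dict(
            accuracy=accuracy_score(y_true[keep], y_pred[keep] >= 0.5),
            auc=roc_auc_score(y_true[keep], y_pred[keep]),
        )
    return scores
```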
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | To get a better insight into the mechanics of these plots, below we show the relation between predictive uncertainty, e.g. entropy `H_{pred}` of MC Dropout (on the y-axis), and maximum-likelihood, i.e. sigmoid probabilities `p(disease | image)` of a deterministic dropout model (on the x-axis). In red are images classified incorrectly, and in green are images classified correctly. You can see that the model has higher *uncertainty* for the misclassified images, whereas the sigmoid probabilities cannot distinguish red from green for low p (i.e. the plot is separable along the y-axis, but not the x-axis). Hence the uncertainty can be used as an indicator to drive referral.
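
For reference, the predictive mean and entropy used in these plots can be computed from Monte Carlo dropout samples as in the sketch below, where `probs` is assumed to hold the per-sample sigmoid outputs with shape `[num_samples, batch_size]`:

```
import numpy as np

def bernoulli_mean_and_entropy(probs):
    """Predictive mean and entropy (in nats) of the averaged Bernoulli."""
    p = probs.mean(axis=0)  # predictive mean, p(disease | image)
    eps = 1e-12  # guard against log(0)
    entropy = -(p * np.log(p + eps) + (1.0 - p) * np.log(1.0 - p + eps))
    return p, entropy
```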
33 |
34 |
35 |
36 |
37 |
38 |
39 | ## Download and Prepare
40 |
41 | The raw data is licensed and hosted by [Kaggle](https://www.kaggle.com/c/diabetic-retinopathy-detection),
42 | hence you will need a Kaggle account to fetch it. A Kaggle API key can be created at
43 |
44 | ```
45 | https://www.kaggle.com//account -> "Create New API Key"
46 | ```
47 |
48 | After creating an API key you will need to accept the dataset license.
49 | Go to [the dataset page on Kaggle](https://www.kaggle.com/c/diabetic-retinopathy-detection/overview) and look
50 | for the button `I Understand and Accept` (make sure when reloading the page that the button does not pop up again).
51 |
52 | The [Kaggle command line interface](https://github.com/Kaggle/kaggle-api) is used for downloading the data, which
53 | assumes that the API token is stored at `~/.kaggle/kaggle.json`. With your credentials exported as the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables, run the following commands to populate it:
54 |
55 | ```
56 | mkdir -p ~/.kaggle
57 | echo "{\"username\":\"${KAGGLE_USERNAME}\",\"key\":\"${KAGGLE_KEY}\"}" > ~/.kaggle/kaggle.json
58 | chmod 600 ~/.kaggle/kaggle.json
59 | ```
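
If you prefer to do this from Python, a minimal equivalent sketch (assuming the same `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables are set):

```
import json
import os

kaggle_dir = os.path.expanduser("~/.kaggle")
os.makedirs(kaggle_dir, exist_ok=True)
token_path = os.path.join(kaggle_dir, "kaggle.json")
with open(token_path, "w") as f:
    # Write the API token in the format the Kaggle CLI expects
    json.dump({"username": os.environ["KAGGLE_USERNAME"],
               "key": os.environ["KAGGLE_KEY"]}, f)
os.chmod(token_path, 0o600)  # keep the token readable by the owner only
```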
60 |
61 | Download and prepare the data by running:
62 |
63 | ```
64 | python3 -u -c "from bdlb.diabetic_retinopathy_diagnosis.benchmark import DiabeticRetinopathyDiagnosisBecnhmark; DiabeticRetinopathyDiagnosisBecnhmark.download_and_prepare()"
65 | ```
66 |
67 | ## Run a Baseline
68 |
69 | Baselines we currently have implemented include:
70 | * [Deterministic](deterministic)
71 | * [Monte Carlo Dropout](mc_dropout) (following [Gal and Ghahramani](https://arxiv.org/abs/1506.02142))
72 | * [Mean-Field Variational Inference](mfvi) (following [Peterson and Anderson](https://pdfs.semanticscholar.org/37fa/18c66b8130b9f9748d9c94472c5671fb5622.pdf), [Wen et al., 2018](https://arxiv.org/abs/1803.04386))
73 | * [Deep Ensembles](deep_ensembles) (following [Lakshminarayanan et al.](https://arxiv.org/abs/1612.01474))
74 | * [Ensemble MC Dropout](ensemble_mc_dropout) (following [Smith and Gal](https://arxiv.org/abs/1803.08533))
75 |
76 |
77 | An executable script, `main.py`, is provided for each baseline and can be run as follows:
78 |
79 | ```
80 | python3 baselines/diabetic_retinopathy_diagnosis/mc_dropout/main.py \
81 | --level=medium \
82 | --dropout_rate=0.2 \
83 | --output_dir=tmp/medium.mc_dropout
84 | ```
85 |
86 | Alternatively, use the flagfiles in `baselines/*/configs`, which hold tuned hyperparameters for each baseline:
87 |
88 | ```
89 | python3 baselines/diabetic_retinopathy_diagnosis/mc_dropout/main.py --flagfile=baselines/diabetic_retinopathy_diagnosis/mc_dropout/configs/medium.cfg
90 | ```
91 |
92 | ## Leaderboard
93 |
94 | The baselines we evaluated on this benchmark are ranked below by AUC at 50% data retained:
95 |
96 | | Method | AUC (50% data retained) | Accuracy (50% data retained) | AUC (100% data retained) | Accuracy (100% data retained) |
97 | | ------------------- | :-------------------------: | :-----------------------------: | :-------------------------: | :-------------------------------: |
98 | | Ensemble MC Dropout | 88.1 ± 1.2 | 92.4 ± 0.9 | 82.5 ± 1.1 | 85.3 ± 1.0 |
99 | | MC Dropout | 87.8 ± 1.1 | 91.3 ± 0.7 | 82.1 ± 0.9 | 84.5 ± 0.9 |
100 | | Deep Ensembles | 87.2 ± 0.9 | 89.9 ± 0.9 | 81.8 ± 1.1 | 84.6 ± 0.7 |
101 | | Mean-field VI | 86.6 ± 1.1 | 88.1 ± 1.1 | 82.1 ± 1.2 | 84.3 ± 0.7 |
102 | | Deterministic | 84.9 ± 1.1 | 86.1 ± 0.6 | 82.0 ± 1.0 | 84.2 ± 0.6 |
103 | | Random | 81.8 ± 1.2 | 84.8 ± 0.9 | 82.0 ± 0.9 | 84.2 ± 0.5 |
104 |
105 | ## Cite as
106 |
107 | > [**A Systematic Comparison of Bayesian Deep Learning Robustness in Diabetic Retinopathy Tasks**](https://arxiv.org/abs/1912.10481)
108 | > Angelos Filos, Sebastian Farquhar, Aidan N. Gomez, Tim G. J. Rudner, Zachary Kenton, Lewis Smith, Milad Alizadeh, Arnoud de Kroon & Yarin Gal
109 | > [Bayesian Deep Learning Workshop @ NeurIPS 2019](http://bayesiandeeplearning.org/) (BDL2019)
110 | > _arXiv 1912.10481_
111 |
112 | ```
113 | @article{filos2019systematic,
114 | title={A Systematic Comparison of Bayesian Deep Learning Robustness in Diabetic Retinopathy Tasks},
115 | author={Filos, Angelos and Farquhar, Sebastian and Gomez, Aidan N and Rudner, Tim GJ and Kenton, Zachary and Smith, Lewis and Alizadeh, Milad and de Kroon, Arnoud and Gal, Yarin},
116 | journal={arXiv preprint arXiv:1912.10481},
117 | year={2019}
118 | }
119 | ```
120 | Please cite individual baselines you compare to as well:
121 | - [Monte Carlo Dropout](mc_dropout) [[Gal and Ghahramani, 2015](https://arxiv.org/abs/1506.02142)]
122 | - [Mean-Field Variational Inference](mfvi) [[Peterson and Anderson, 1987](https://pdfs.semanticscholar.org/37fa/18c66b8130b9f9748d9c94472c5671fb5622.pdf); [Wen et al., 2018](https://arxiv.org/abs/1803.04386)]
123 | - [Deep Ensembles](deep_ensembles) [[Lakshminarayanan et al., 2016](https://arxiv.org/abs/1612.01474)]
124 | - [Ensemble MC Dropout](ensemble_mc_dropout) [[Smith and Gal, 2018](https://arxiv.org/abs/1803.08533)]
125 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2018, BDL Benchmark Authors. All rights reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright 2018, Zac Kenton, Angelos Filos.
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
--------------------------------------------------------------------------------
/bdlb/diabetic_retinopathy_diagnosis/benchmark.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 BDL Benchmarks Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Diabetic retinopathy diagnosis BDL Benchmark."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import os
23 | from typing import Callable
24 | from typing import Dict
25 | from typing import Optional
26 | from typing import Sequence
27 | from typing import Text
28 | from typing import Tuple
29 | from typing import Union
30 |
31 | import numpy as np
32 | import pandas as pd
33 | import tensorflow as tf
34 | from absl import logging
35 |
36 | from ..core import transforms
37 | from ..core.benchmark import Benchmark
38 | from ..core.benchmark import BenchmarkInfo
39 | from ..core.benchmark import DataSplits
40 | from ..core.constants import DATA_DIR
41 | from ..core.levels import Level
42 |
43 | tfk = tf.keras
44 |
45 | _DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR = os.path.join(
46 | DATA_DIR, "downloads", "manual", "diabetic_retinopathy_diagnosis")
47 |
48 |
49 | class DiabeticRetinopathyDiagnosisBecnhmark(Benchmark):
50 | """Diabetic retinopathy diagnosis benchmark class."""
51 |
52 | def __init__(
53 | self,
54 | level: Union[Text, Level],
55 | batch_size: int = 64,
56 | data_dir: Optional[Text] = None,
57 | download_and_prepare: bool = False,
58 | ):
59 | """Constructs a benchmark object.
60 |
61 | Args:
62 | level: `Level` or `str`, downstream task level.
63 | batch_size: (optional) `int`, number of datapoints
64 | per mini-batch.
65 | data_dir: (optional) `str`, path to parent data directory.
66 | download_and_prepare: (optional) `bool`, if the data is not available
67 | it downloads and preprocesses it.
68 | """
69 | self.__level = level if isinstance(level, Level) else Level.from_str(level)
70 | try:
71 | self.__ds = self.load(level=level,
72 | batch_size=batch_size,
73 | data_dir=data_dir or DATA_DIR)
74 | except AssertionError:
75 | if not download_and_prepare:
76 | raise
77 | else:
78 | logging.info(
79 | "Data not found, `DiabeticRetinopathyDiagnosisBecnhmark.download_and_prepare()`"
80 | " is now running...")
81 | self.download_and_prepare()
82 |
83 | @classmethod
84 | def evaluate(
85 | cls,
86 | estimator: Callable[[np.ndarray], Tuple[np.ndarray, np.ndarray]],
87 | dataset: tf.data.Dataset,
88 | output_dir: Optional[Text] = None,
89 | name: Optional[Text] = None,
90 | ) -> Dict[Text, float]:
91 | """Evaluates an `estimator` on the `mode` benchmark dataset.
92 |
93 | Args:
94 | estimator: `lambda x: mu_x, uncertainty_x`, an uncertainty estimation
95 | function, which returns `mean_x` and predictive `uncertainty_x`.
96 | dataset: `tf.data.Dataset`, the dataset on which to perform evaluation.
97 | output_dir: (optional) `str`, directory to save figures.
98 | name: (optional) `str`, the name of the method.
99 | """
100 | import inspect
101 | import tqdm
102 | import tensorflow_datasets as tfds
103 | from sklearn.metrics import roc_auc_score
104 | from sklearn.metrics import accuracy_score
105 | import matplotlib.pyplot as plt
106 |
107 | # Containers used for caching performance evaluation
108 | y_true = list()
109 | y_pred = list()
110 | y_uncertainty = list()
111 |
112 | # Convert to NumPy iterator if necessary
113 | ds = dataset if inspect.isgenerator(dataset) else tfds.as_numpy(dataset)
114 |
115 | for x, y in tqdm.tqdm(ds):
116 | # Sample from probabilistic model
117 | mean, uncertainty = estimator(x)
118 | # Cache predictions
119 | y_true.append(y)
120 | y_pred.append(mean)
121 | y_uncertainty.append(uncertainty)
122 |
123 | # Use vectorized NumPy containers
124 | y_true = np.concatenate(y_true).flatten()
125 | y_pred = np.concatenate(y_pred).flatten()
126 | y_uncertainty = np.concatenate(y_uncertainty).flatten()
127 | fractions = np.asarray([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
128 |
129 | # Metrics for evaluation
130 | metrics = zip(["accuracy", "auc"], cls.metrics())
131 |
132 | return {
133 | metric: cls._evaluate_metric(
134 | y_true,
135 | y_pred,
136 | y_uncertainty,
137 | fractions,
138 | lambda y_true, y_pred: metric_fn(y_true, y_pred).numpy(),
139 | name,
140 | ) for (metric, metric_fn) in metrics
141 | }
142 |
143 | @staticmethod
144 | def _evaluate_metric(
145 | y_true: np.ndarray,
146 | y_pred: np.ndarray,
147 | y_uncertainty: np.ndarray,
148 | fractions: Sequence[float],
149 | metric_fn: Callable[[np.ndarray, np.ndarray], float],
150 | name=None,
151 | ) -> pd.DataFrame:
152 | """Evaluate model predictive distribution on `metric_fn` at data retain
153 | `fractions`.
154 |
155 | Args:
156 | y_true: `numpy.ndarray`, the ground truth labels, with shape [N].
157 | y_pred: `numpy.ndarray`, the model predictions, with shape [N].
158 | y_uncertainty: `numpy.ndarray`, the model uncertainties,
159 | with shape [N].
160 | fractions: `iterable`, the percentages of data to retain for
161 | calculating `metric_fn`.
162 | metric_fn: `lambda(y_true, y_pred) -> float`, a metric
163 | function that provides a score given ground truths
164 | and predictions.
165 | name: (optional) `str`, the name of the method.
166 |
167 | Returns:
168 | A `pandas.DataFrame` with columns ["retained_data", "mean", "std"],
169 | that summarizes the scores at different data retained fractions.
170 | """
171 |
172 | N = y_true.shape[0]
173 |
174 | # Sorts indexes by ascending uncertainty
175 | I_uncertainties = np.argsort(y_uncertainty)
176 |
177 | # Score containers
178 | mean = np.empty_like(fractions)
179 | # TODO(filangel): do bootstrap sampling and estimate standard error
180 | std = np.zeros_like(fractions)
181 |
182 | for i, frac in enumerate(fractions):
183 | # Keep only the %-frac of lowest uncertainties
184 | I = np.zeros(N, dtype=bool)
185 | I[I_uncertainties[:int(N * frac)]] = True
186 | mean[i] = metric_fn(y_true[I], y_pred[I])
187 |
188 | # Store
189 | df = pd.DataFrame(dict(retained_data=fractions, mean=mean, std=std))
190 | df.name = name
191 |
192 | return df
193 |
194 | @property
195 | def datasets(self) -> tf.data.Dataset:
196 | """Pointer to the processed datasets."""
197 | return self.__ds
198 |
199 | @property
200 | def info(self) -> BenchmarkInfo:
201 | """Text description of the benchmark."""
202 | return BenchmarkInfo(description="", urls="", setup="", citation="")
203 |
204 | @property
205 | def level(self) -> Level:
206 | """The downstream task level."""
207 | return self.__level
208 |
209 | @staticmethod
210 | def loss() -> tfk.losses.Loss:
211 | """Loss used for training binary classifiers."""
212 | return tfk.losses.BinaryCrossentropy()
213 |
214 | @staticmethod
215 | def metrics() -> tfk.metrics.Metric:
216 | """Evaluation metrics used for monitoring training."""
217 | return [tfk.metrics.BinaryAccuracy(), tfk.metrics.AUC()]
218 |
219 | @staticmethod
220 | def class_weight() -> Sequence[float]:
221 | """Class weights used for rebalancing the dataset, by skewing the `loss`
222 | accordingly."""
223 | return [1.0, 4.0]
224 |
225 | @classmethod
226 | def load(
227 | cls,
228 | level: Union[Text, Level] = "realworld",
229 | batch_size: int = 64,
230 | data_dir: Optional[Text] = None,
231 | as_numpy: bool = False,
232 | ) -> DataSplits:
233 | """Loads the datasets for the benchmark.
234 |
235 | Args:
236 | level: `Level` or `str`, downstream task level.
237 | batch_size: (optional) `int`, number of datapoints
238 | per mini-batch.
239 | data_dir: (optional) `str`, path to parent data directory.
240 | as_numpy: (optional) `bool`, if True returns python generators
241 | with `numpy.ndarray` outputs.
242 |
243 | Returns:
244 | A namedtuple with properties:
245 | * train: `tf.data.Dataset`, train dataset.
246 | * validation: `tf.data.Dataset`, validation dataset.
247 | * test: `tf.data.Dataset`, test dataset.
248 | """
249 | import tensorflow_datasets as tfds
250 | from .tfds_adapter import DiabeticRetinopathyDiagnosis
251 |
252 | # Fetch datasets
253 | try:
254 | ds_train, ds_validation, ds_test = DiabeticRetinopathyDiagnosis(
255 | data_dir=data_dir or DATA_DIR,
256 | config=level).as_dataset(split=["train", "validation", "test"],
257 | shuffle_files=True,
258 | batch_size=batch_size)
259 | except AssertionError as ae:
260 | raise AssertionError(
261 | str(ae) +
262 | " Run DiabeticRetinopathyDiagnosisBecnhmark.download_and_prepare()"
263 | " first and then retry.")
264 |
265 | # Parse task level
266 | level = level if isinstance(level, Level) else Level.from_str(level)
267 | # Dataset transformations
268 | transforms_train, transforms_eval = cls._preprocessors()
269 | # Apply transformations
270 | ds_train = ds_train.map(transforms_train,
271 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
272 | ds_validation = ds_validation.map(
273 | transforms_eval, num_parallel_calls=tf.data.experimental.AUTOTUNE)
274 | ds_test = ds_test.map(transforms_eval,
275 | num_parallel_calls=tf.data.experimental.AUTOTUNE)
276 |
277 | # Prefetches datasets to memory
278 | ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
279 | ds_validation = ds_validation.prefetch(tf.data.experimental.AUTOTUNE)
280 | ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)
281 |
282 | if as_numpy:
283 | # Convert to NumPy iterators
284 | ds_train = tfds.as_numpy(ds_train)
285 | ds_validation = tfds.as_numpy(ds_validation)
286 | ds_test = tfds.as_numpy(ds_test)
287 |
288 | return DataSplits(ds_train, ds_validation, ds_test)
289 |
290 | @classmethod
291 | def download_and_prepare(cls, levels=None) -> None:
292 | """Downloads dataset from Kaggle, extracts zip files and processes it using
293 | `tensorflow_datasets`.
294 |
295 | Args:
296 | levels: (optional) `iterable` of `str`, specifies which
297 | levels from {'medium', 'realworld'} to prepare,
298 | if None it prepares all the levels.
299 |
300 | Raises:
301 | OSError: if `~/.kaggle/kaggle.json` is not set up.
302 | """
303 | # Disable GPU for data download, extraction and preparation
304 | import os
305 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
306 | cls._download()
307 | cls._extract()
308 | cls._prepare(levels)
309 |
310 | @staticmethod
311 | def _download() -> None:
312 | """Downloads data from Kaggle using `tensorflow_datasets`.
313 |
314 | Raises:
315 | OSError: if `~/.kaggle/kaggle.json` is not set up.
316 | """
317 | import subprocess as sp
318 | import tensorflow_datasets as tfds
319 |
320 | # Append `/home/$USER/.local/bin` to path
321 | os.environ["PATH"] += ":/home/{}/.local/bin/".format(os.environ["USER"])
322 |
323 | # Download all files from Kaggle
324 | drd = tfds.download.kaggle.KaggleCompetitionDownloader(
325 | "diabetic-retinopathy-detection")
326 | try:
327 | for dfile in drd.competition_files:
328 | drd.download_file(dfile,
329 | output_dir=_DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR)
330 | except sp.CalledProcessError as cpe:
331 | raise OSError(
332 | str(cpe) + "." +
333 | " Make sure you have ~/.kaggle/kaggle.json setup, fetched from the Kaggle website"
334 | " https://www.kaggle.com//account -> 'Create New API Key'."
335 | " Also accept the dataset license by going to"
336 | " https://www.kaggle.com/c/diabetic-retinopathy-detection/rules"
337 | " and look for the button 'I Understand and Accept' (make sure when reloading the"
338 | " page that the button does not pop up again).")
339 |
340 | @staticmethod
341 | def _extract() -> None:
342 | """Extracts zip files downloaded from Kaggle."""
343 | import glob
344 | import tqdm
345 | import zipfile
346 | import tempfile
347 |
348 | # Extract train and test original images
349 | for split in ["train", "test"]:
350 | # Extract ".zip.00*"" files to ""
351 | with tempfile.NamedTemporaryFile() as tmp:
352 | # Concatenate ".zip.00*" to ".zip"
353 | for fname in tqdm.tqdm(
354 | sorted(
355 | glob.glob(
356 | os.path.join(_DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR,
357 | "{split}.zip.00*".format(split=split))))):
358 | # Unzip the concatenated "<split>.zip"
359 | with open(fname, "rb") as ztmp:
360 | tmp.write(ztmp.read())
361 | with zipfile.ZipFile(tmp) as zfile:
362 | for image in tqdm.tqdm(iterable=zfile.namelist(),
363 | total=len(zfile.namelist())):
364 | zfile.extract(member=image,
365 | path=_DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR)
366 | # Delete ".zip.00*" files
367 | for splitzip in os.listdir(_DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR):
368 | if "{split}.zip.00".format(split=split) in splitzip:
369 | os.remove(
370 | os.path.join(_DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR, splitzip))
371 |
372 | # Extract "sample.zip", "trainLabels.csv.zip"
373 | for fname in ["sample", "trainLabels.csv"]:
374 | zfname = os.path.join(_DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR,
375 | "{fname}.zip".format(fname=fname))
376 | with zipfile.ZipFile(zfname) as zfile:
377 | zfile.extractall(_DIABETIC_RETINOPATHY_DIAGNOSIS_DATA_DIR)
378 | os.remove(zfname)
379 |
380 | @staticmethod
381 | def _prepare(levels=None) -> None:
382 | """Generates the TFRecord objects for medium and realworld experiments."""
383 | import multiprocessing
384 | from absl import logging
385 | from .tfds_adapter import DiabeticRetinopathyDiagnosis
386 | # Handle each level individually
387 | for level in levels or ["medium", "realworld"]:
388 | dtask = DiabeticRetinopathyDiagnosis(data_dir=DATA_DIR, config=level)
389 | logging.debug("=== Preparing TFRecords for {} ===".format(level))
390 | dtask.download_and_prepare()
391 |
392 | @classmethod
393 | def _preprocessors(cls) -> Tuple[transforms.Transform, transforms.Transform]:
394 | """Applies transformations to the raw data."""
395 | import tensorflow_datasets as tfds
396 |
397 | # Transformation hyperparameters
398 | mean = np.asarray([0.42606387, 0.29752496, 0.21309826])
399 | stddev = np.asarray([0.27662534, 0.20280295, 0.1687619])
400 |
401 | class Parse(transforms.Transform):
402 | """Parses datapoints from raw `tf.data.Dataset`."""
403 |
404 | def __call__(self, x, y=None):
405 | """Returns `as_supervised` tuple."""
406 | return x["image"], x["label"]
407 |
408 | class CastX(transforms.Transform):
409 | """Casts image to `dtype`."""
410 |
411 | def __init__(self, dtype):
412 | """Constructs a type caster."""
413 | self.dtype = dtype
414 |
415 | def __call__(self, x, y):
416 | """Returns casted image (to `dtype`) and its (unchanged) label as
417 | tuple."""
418 | return tf.cast(x, self.dtype), y
419 |
420 | class To01X(transforms.Transform):
421 | """Rescales image to [min, max]=[0, 1]."""
422 |
423 | def __call__(self, x, y):
424 | """Returns rescaled image and its (unchanged) label as tuple."""
425 | return x / 255.0, y
426 |
427 | # Get augmentation schemes
428 | [augmentation_config,
429 | no_augmentation_config] = cls._ImageDataGenerator_config()
430 |
431 | # Transformations for train dataset
432 | transforms_train = transforms.Compose([
433 | Parse(),
434 | CastX(tf.float32),
435 | To01X(),
436 | transforms.Normalize(mean, stddev),
437 | # TODO(filangel): handle batch with ImageDataGenerator
438 | # transforms.RandomAugment(**augmentation_config),
439 | ])
440 |
441 | # Transformations for validation/test dataset
442 | transforms_eval = transforms.Compose([
443 | Parse(),
444 | CastX(tf.float32),
445 | To01X(),
446 | transforms.Normalize(mean, stddev),
447 | # TODO(filangel): handle batch with ImageDataGenerator
448 | # transforms.RandomAugment(**no_augmentation_config),
449 | ])
450 |
451 | return transforms_train, transforms_eval
452 |
453 | @staticmethod
454 | def _ImageDataGenerator_config():
455 | """Returns the configs for the
456 | `tensorflow.keras.preprocessing.image.ImageDataGenerator`, used for the
457 | random augmentation of the dataset, following the implementation of
458 | https://github.com/chleibig/disease-detection/blob/f3401b26aa9b832ff77afe93
459 | e3faa342f7d088e5/scripts/inspect_data_augmentation.py."""
460 | augmentation_config = dict(
461 | featurewise_center=False,
462 | samplewise_center=False,
463 | featurewise_std_normalization=False,
464 | samplewise_std_normalization=False,
465 | zca_whitening=False,
466 | rotation_range=180.0,
467 | width_shift_range=0.05,
468 | height_shift_range=0.05,
469 | shear_range=0.,
470 | zoom_range=0.10,
471 | channel_shift_range=0.,
472 | fill_mode="constant",
473 | cval=0.,
474 | horizontal_flip=True,
475 | vertical_flip=True,
476 | data_format="channels_last",
477 | )
478 | no_augmentation_config = dict(
479 | featurewise_center=False,
480 | samplewise_center=False,
481 | featurewise_std_normalization=False,
482 | samplewise_std_normalization=False,
483 | zca_whitening=False,
484 | rotation_range=0.0,
485 | width_shift_range=0.0,
486 | height_shift_range=0.0,
487 | shear_range=0.,
488 | zoom_range=0.0,
489 | channel_shift_range=0.,
490 | fill_mode="nearest",
491 | cval=0.,
492 | horizontal_flip=False,
493 | vertical_flip=False,
494 | data_format="channels_last",
495 | )
496 | return augmentation_config, no_augmentation_config
497 |
--------------------------------------------------------------------------------