├── .gitignore
├── EnsembleBench
│   ├── __init__.py
│   ├── diversityMetrics.py
│   ├── frameworks
│   │   ├── pytorchUtility.py
│   │   └── sklearnUtility.py
│   ├── groupMetrics.py
│   └── teamSelection.py
├── README.md
├── demo
│   ├── BaselineDiversityBasedEnsembleSelection.ipynb
│   └── FocalDiversityBasedEnsembleSelection.ipynb
├── env.sh
├── setup.cfg
├── setup.py
└── setup.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | # PyCharm
131 | .idea/*
132 |
133 | # Data folder
134 | data/*
135 | imagenetData
136 | *prediction
137 |
138 | # pytorch pt
139 | *.pt
140 |
141 | # numpy data
142 | *.np
143 | *.npz
144 |
145 | # images
146 | *.pdf
147 | *.png
148 |
149 | # binary object
150 | *.obj
151 | *.joblib
--------------------------------------------------------------------------------
/EnsembleBench/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/git-disl/EnsembleBench/cb202c11172466373706075fdc390a38b81e482b/EnsembleBench/__init__.py
--------------------------------------------------------------------------------
/EnsembleBench/diversityMetrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def cohen_kappa_statistic(M):
5 |     """
6 |     M: the multiple inputs, each input represents a model output.
7 | """ 8 | from sklearn.metrics import cohen_kappa_score 9 | Qs = [] 10 | for i in range(M.shape[0]): 11 | for j in range(i + 1, M.shape[0]): 12 | Qs.append(cohen_kappa_score(M[i, :], M[j, :])) 13 | return np.mean(Qs) 14 | 15 | def correlation_coefficient(M, y_true): 16 | Qs = [] 17 | for i in range(M.shape[0]): 18 | for j in range(i + 1, M.shape[0]): 19 | N_1_1 = np.sum(np.logical_and(y_true == M[i, :], y_true == M[j, :])) # number of both correct 20 | N_0_0 = np.sum(np.logical_and(y_true != M[i, :], y_true != M[j, :])) # number of both incorrect 21 | N_0_1 = np.sum(np.logical_and(y_true != M[i, :], y_true == M[j, :])) # number of j correct but not i 22 | N_1_0 = np.sum(np.logical_and(y_true == M[i, :], y_true != M[j, :])) # number of i correct but not j 23 | Qs.append((N_1_1*N_0_0-N_0_1*N_1_0) * 1. / (np.sqrt((N_1_1+N_1_0)*(N_0_1+N_0_0)*(N_1_1+N_0_1)*(N_1_0+N_0_0))+np.finfo(float).eps)+np.finfo(float).eps) 24 | return np.mean(Qs) 25 | 26 | 27 | def Q_statistic(M, y_true): 28 | Qs = [] 29 | for i in range(M.shape[0]): 30 | for j in range(i+1, M.shape[0]): 31 | N_1_1 = np.sum(np.logical_and(y_true == M[i, :], y_true == M[j, :])) # number of both correct 32 | N_0_0 = np.sum(np.logical_and(y_true != M[i, :], y_true != M[j, :])) # number of both incorrect 33 | N_0_1 = np.sum(np.logical_and(y_true != M[i, :], y_true == M[j, :])) # number of j correct but not i 34 | N_1_0 = np.sum(np.logical_and(y_true == M[i, :], y_true != M[j, :])) # number of i correct but not j 35 | Qs.append((N_1_1*N_0_0 - N_0_1*N_1_0)*1./(N_1_1*N_0_0+N_0_1*N_1_0+np.finfo(float).eps)) 36 | return np.mean(Qs) 37 | 38 | def binary_disagreement(M, y_true): 39 | Qs = [] 40 | for i in range(M.shape[0]): 41 | for j in range(i + 1, M.shape[0]): 42 | N_1_1 = np.sum(np.logical_and(y_true == M[i, :], y_true == M[j, :])) # number of both correct 43 | N_0_0 = np.sum(np.logical_and(y_true != M[i, :], y_true != M[j, :])) # number of both incorrect 44 | N_0_1 = np.sum(np.logical_and(y_true != M[i, :], y_true == M[j, :])) # number of j correct but not i 45 | N_1_0 = np.sum(np.logical_and(y_true == M[i, :], y_true != M[j, :])) # number of i correct but not j 46 | Qs.append((N_0_1+N_1_0)*1./(N_1_1+N_1_0+N_0_1+N_0_0)) 47 | return np.mean(Qs) 48 | 49 | 50 | def fleiss_kappa_statistic(M, y_true, n_classes=10): 51 | M_ = np.zeros(shape=(M.shape[1], n_classes)) 52 | for row in M: 53 | for sid in range(len(row)): 54 | M_[sid, row[sid]] += 1 55 | 56 | N, k = M_.shape # N is # of items, k is # of categories 57 | n_annotators = float(np.sum(M_[0, :])) # # of annotators 58 | 59 | p = np.sum(M_, axis=0) / (N * n_annotators) 60 | P = (np.sum(M_ * M_, axis=1) - n_annotators) / (n_annotators * (n_annotators - 1)) 61 | Pbar = np.sum(P) / N 62 | PbarE = np.sum(p * p) 63 | 64 | return (Pbar - PbarE) / (1 - PbarE) 65 | 66 | 67 | def entropy(M, y_true): 68 | N = M.shape[1] 69 | L = M.shape[0] 70 | E = 0. 71 | for j in range(M.shape[1]): 72 | l_zj = list(M[:, j]).count(y_true[j]) 73 | E += min(l_zj, L-l_zj) 74 | E = E * 1. / (N * (L - np.ceil(L / 2.))) 75 | return E 76 | 77 | def kohavi_wolpert_variance(M, y_true): 78 | N = M.shape[1] 79 | L = M.shape[0] 80 | kw = 0. 81 | for j in range(N): 82 | l_zj = list(M[:, j]).count(y_true[j]) 83 | kw += l_zj * (L-l_zj) 84 | kw = kw * 1. 
77 | def kohavi_wolpert_variance(M, y_true):
78 |     N = M.shape[1]
79 |     L = M.shape[0]
80 |     kw = 0.
81 |     for j in range(N):
82 |         l_zj = list(M[:, j]).count(y_true[j])
83 |         kw += l_zj * (L-l_zj)
84 |     kw = kw * 1. / (N * L * L)
85 |     return kw
86 |
87 | def generalized_diversity(M, y_true):
88 |     N = M.shape[1]
89 |     L = M.shape[0]
90 |     #print('num samples:', N, 'num models:', L)
91 |     pi = np.zeros(L + 1)  # pi[k]: fraction of samples misclassified by exactly k of the L models
92 |     for i in range(N):
93 |         pIdx = 0
94 |         for j in range(L):
95 |             if M[j][i] != y_true[i]:
96 |                 pIdx += 1
97 |         pi[pIdx] += 1
98 |
99 |     pi = [x*1.0/N for x in pi]
100 |
101 |     P1 = 0
102 |     P2 = 0
103 |     for i in range(L + 1):
104 |         P1 += i * 1.0 * pi[i] / L
105 |         P2 += i * (i-1) * 1.0 * pi[i] / (L * (L-1))
106 |     return 1.0 - P2/P1
107 |
108 |
109 |
110 |
--------------------------------------------------------------------------------
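A quick sanity check of the metrics above on a toy problem (a sketch, not a file from the repository; the 3x8 prediction matrix and label vector are made up for illustration). Each row of M holds one model's predicted labels over the same samples:

    import numpy as np
    from EnsembleBench.diversityMetrics import (
        cohen_kappa_statistic, Q_statistic, binary_disagreement,
        kohavi_wolpert_variance, generalized_diversity,
    )

    # three models, eight samples; rows are per-model predicted labels
    M = np.array([[0, 1, 1, 0, 1, 0, 1, 1],
                  [0, 1, 0, 0, 1, 1, 1, 1],
                  [1, 1, 1, 0, 0, 0, 1, 0]])
    y_true = np.array([0, 1, 1, 0, 1, 0, 1, 1])

    print(cohen_kappa_statistic(M))            # lower kappa -> more diverse
    print(Q_statistic(M, y_true))              # in [-1, 1]; lower -> more diverse
    print(binary_disagreement(M, y_true))      # higher -> more diverse
    print(kohavi_wolpert_variance(M, y_true))  # higher -> more diverse
    print(generalized_diversity(M, y_true))    # higher -> more diverse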
disagreed samples""" 92 | batchSize = target.size(0) 93 | predictionList = list() 94 | 95 | for pVL in predictionVectorsList: 96 | _, pred = pVL.max(dim=1) 97 | predictionList.append(pred) 98 | 99 | # return sampleID, sampleTarget, predictions, predVectors 100 | sampleID = list() 101 | sampleTarget = list() 102 | predictions = list() 103 | predVectors = list() 104 | 105 | for i in range(batchSize): 106 | pred = [] 107 | predVect = [] 108 | for j, p in enumerate(predictionList): 109 | pred.append(p[i].item()) 110 | predVect.append(predictionVectorsList[j][i]) 111 | if predictionList[oneTargetIdx][i] != target[i]: 112 | sampleID.append(i) 113 | sampleTarget.append(target[i].item()) 114 | predictions.append(pred) 115 | predVectors.append(predVect) 116 | return sampleID, sampleTarget, predictions, predVectors 117 | 118 | 119 | def filterModelsFixed(sampleID, sampleTarget, predictions, predVectors, selectModels): 120 | filteredPredictions = predictions[:, selectModels] 121 | #print(filteredPredictions.shape) 122 | filteredPredVectors = predVectors[:, selectModels] 123 | return sampleID, sampleTarget, filteredPredictions, filteredPredVectors 124 | 125 | 126 | from ..groupMetrics import calAllDiversityMetrics 127 | 128 | def calFocalDiversityScoresPyTorch( 129 | oneFocalModel, 130 | teamModelList, 131 | negativeSamplesList, 132 | diversityMetricsList, 133 | # save time 134 | crossValidation = True, 135 | nRandomSamples = 100, 136 | crossValidationTimes = 3 137 | ): 138 | sampleID, sampleTarget, predictions, predVectors = negativeSamplesList[oneFocalModel] 139 | teamSampleID, teamSampleTarget, teamPredictions, teamPredVectors = \ 140 | filterModelsFixed(sampleID, sampleTarget, predictions, predVectors, teamModelList) 141 | 142 | if crossValidation: 143 | tmpMetrics = list() 144 | for _ in range(crossValidationTimes): 145 | randomIdx = np.random.choice(np.arange(teamPredictions.shape[0]), nRandomSamples) 146 | tmpMetrics.append(calAllDiversityMetrics(teamPredictions[randomIdx], 147 | teamSampleTarget[randomIdx], 148 | diversityMetricsList)) 149 | tmpMetrics = np.mean(np.array(tmpMetrics), axis=0) 150 | else: 151 | tmpMetrics.append(calAllDiversityMetrics(teamPredictions[randomIdx], 152 | teamSampleTarget[randomIdx], 153 | diversityMetricsList)) 154 | return {diversityMetricsList[i]:tmpMetrics[i].item() for i in range(len(tmpMetrics))} 155 | 156 | -------------------------------------------------------------------------------- /EnsembleBench/frameworks/sklearnUtility.py: -------------------------------------------------------------------------------- 1 | # for scikit-learn 2 | 3 | import copy 4 | import numpy as np 5 | from sklearn.metrics import accuracy_score 6 | 7 | # Test on scikit-learn GradientBoostingClassifier 8 | def getEnsModelPred( 9 | X, 10 | ens_wrapper, 11 | model_ids, 12 | ): 13 | if isinstance(ens_wrapper, EnsWrapper): 14 | return ens_wrapper.predict(X, model_ids) 15 | else: 16 | org_estimators = ens_wrapper.estimators_ 17 | ens_wrapper = copy.deepcopy(ens_wrapper) 18 | ens_wrapper.estimators_ = org_estimators[model_ids, ...] 
/EnsembleBench/frameworks/sklearnUtility.py:
--------------------------------------------------------------------------------
1 | # for scikit-learn
2 |
3 | import copy
4 | import numpy as np
5 | from sklearn.metrics import accuracy_score
6 |
7 | # Tested on scikit-learn GradientBoostingClassifier
8 | def getEnsModelPred(
9 |     X,
10 |     ens_wrapper,
11 |     model_ids,
12 | ):
13 |     if isinstance(ens_wrapper, EnsWrapper):
14 |         return ens_wrapper.predict(X, model_ids)
15 |     else:
16 |         org_estimators = ens_wrapper.estimators_
17 |         ens_wrapper = copy.deepcopy(ens_wrapper)
18 |         ens_wrapper.estimators_ = org_estimators[model_ids, ...]
19 |         y_pred = ens_wrapper.predict(X)
20 |         return y_pred
21 |
22 |
23 | class EnsWrapper:  # generally follow sklearn API
24 |     def __init__(self, classifiers, names=None, voting='plurality', weights=None):
25 |         self.classifiers = np.array(classifiers, dtype=object)
26 |         self.names = names
27 |         self.voting = voting
28 |         self.weights = weights
29 |
30 |     def fit(self, X, y):
31 |         for name, clf in zip(self.names, self.classifiers):
32 |             clf.fit(X, y)
33 |             print("finish training of ", name)
34 |
35 |     def predict(self, X, model_ids=None):  # different from sklearn API
36 |         if model_ids is not None:
37 |             classifiers = self.classifiers[model_ids, ...]
38 |         else:
39 |             classifiers = self.classifiers
40 |
41 |         if self.voting == 'plurality':
42 |             predictions = np.asarray([clf.predict(X) for clf in classifiers]).T.astype(np.int64)
43 |             plural = np.apply_along_axis(lambda x:
44 |                                          np.argmax(np.bincount(
45 |                                              x, weights=self.weights)),
46 |                                          axis=1,
47 |                                          arr=predictions)
48 |             return plural
49 |         else:
50 |             raise NotImplementedError("Voting method not implemented: ", self.voting)
51 |
52 |     def score(self, X, y, model_ids=None):
53 |         return accuracy_score(y, self.predict(X, model_ids))
54 |
55 |
56 | def calNegativeSamplesFocalModel(predictionList, target, oneTargetIdx):
57 |     """Obtain the negative samples for the focal model oneTargetIdx"""
58 |     sampleID = list()
59 |     for i in range(len(target)):
60 |         if predictionList[oneTargetIdx][i] != target[i]:
61 |             sampleID.append(i)
62 |     return sampleID
--------------------------------------------------------------------------------
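A minimal sketch of EnsWrapper on a scikit-learn toy task (the dataset and member models are illustrative stand-ins, not part of the repository):

    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.naive_bayes import GaussianNB
    from EnsembleBench.frameworks.sklearnUtility import EnsWrapper

    X, y = load_digits(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    ens = EnsWrapper(
        classifiers=[DecisionTreeClassifier(random_state=0),
                     LogisticRegression(max_iter=5000),
                     GaussianNB()],
        names=['tree', 'logreg', 'nb'])
    ens.fit(X_train, y_train)

    print(ens.score(X_test, y_test))          # full ensemble, plurality voting
    print(ens.score(X_test, y_test, [0, 1]))  # sub-ensemble of members 0 and 1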
Found!") 44 | if metric == 'CK': 45 | return group_kappa_score(prediction) 46 | if metric == 'QS' and len(target) > 0: 47 | return group_Q_statistic(prediction, target) 48 | if metric == 'BD' and len(target) > 0: 49 | return 1.0 - group_binary_disagreement(prediction, target) 50 | if metric == 'FK': 51 | return fleiss_kappa_score(prediction) 52 | if metric == 'GD' and len(target) > 0: 53 | return 1.0 - group_generalized_diversity(prediction, target) 54 | if metric == 'KW' and len(target) > 0: 55 | return 1.0 - group_KW_variance(prediction, target) 56 | raise Exception("Diversity Metric Error!") 57 | 58 | def calAllDiversityMetrics(prediction, target=None, metrics=None): 59 | if metrics == None: 60 | return 61 | results = list() 62 | for m in metrics: 63 | #print(m) 64 | results.append(calDiversityMetric(prediction, target, m)) 65 | return results 66 | 67 | # calFocalDiversityScores 68 | def calFocalDiversityScores( 69 | oneFocalModel, 70 | teamModelList, 71 | y_pred_list, 72 | gt_label, 73 | negative_sample_list, 74 | diversityMetricsList, 75 | # save time 76 | crossValidation = True, 77 | nRandomSamples = 100, 78 | crossValidationTimes = 3 79 | ): 80 | # num samples x num member models 81 | teamPredictions = np.transpose(y_pred_list[teamModelList, ...], (1, 0)) 82 | np.random.seed(2021) # fix random state 83 | if crossValidation: 84 | tmpMetrics = list() 85 | for _ in range(crossValidationTimes): 86 | randomIdx = np.random.choice(np.arange(len(negative_sample_list[oneFocalModel])), nRandomSamples) 87 | tmpMetrics.append(np.array(calAllDiversityMetrics( 88 | teamPredictions[negative_sample_list[oneFocalModel]][randomIdx], 89 | gt_label[negative_sample_list[oneFocalModel]][randomIdx], 90 | diversityMetricsList))) 91 | tmpMetrics = np.mean(np.array(tmpMetrics), axis=0) 92 | else: 93 | tmpMetrics = np.array(calAllDiversityMetrics( 94 | teamPredictions[negative_sample_list[oneFocalModel]], 95 | gt_label[negative_sample_list[oneFocalModel]], 96 | diversityMetricsList)) 97 | 98 | return {diversityMetricsList[i]:tmpMetrics[i].item() for i in range(len(tmpMetrics))} -------------------------------------------------------------------------------- /EnsembleBench/teamSelection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | from operator import itemgetter 4 | 5 | 6 | def getThreshold(target, metric, k=1.0): 7 | avg = np.mean(target) 8 | std = np.std(target) 9 | for i, (m, t) in enumerate(sorted(zip(metric, target))): 10 | if t < avg - k * std: 11 | return m 12 | return np.max(metric) 13 | 14 | from sklearn.cluster import KMeans 15 | def getThresholdFromClusteringKMeans(target, metric, kmeansInit='random'): 16 | data = [[t, m] for t, m in zip(target, metric)] 17 | kmeans = KMeans(n_clusters=2, init=kmeansInit, random_state=0).fit(data) 18 | c0 = metric[np.ma.make_mask(kmeans.labels_)] 19 | c0min, c0max = min(c0), max(c0) 20 | c1 = metric[np.logical_not(kmeans.labels_)] 21 | c1min, c1max = min(c1), max(c1) 22 | if c0min > c1max or c1min > c0max: 23 | return max(c0min, c1min) 24 | return min(c0max, c1max) 25 | 26 | def getThresholdFromKMeans(target, metric, kmeansInit='random'): 27 | data = [[t, m] for t, m in zip(target, metric)] 28 | kmeans = KMeans(n_clusters=2, init=kmeansInit, random_state=0).fit(data) 29 | c0 = metric[np.ma.make_mask(kmeans.labels_)] 30 | c0min, c0max = min(c0), max(c0) 31 | c1 = metric[np.logical_not(kmeans.labels_)] 32 | c1min, c1max = min(c1), max(c1) 33 | if c0min > c1max or c1min > c0max: 34 | 
/EnsembleBench/teamSelection.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | from operator import itemgetter
4 |
5 |
6 | def getThreshold(target, metric, k=1.0):
7 |     avg = np.mean(target)
8 |     std = np.std(target)
9 |     for i, (m, t) in enumerate(sorted(zip(metric, target))):
10 |         if t < avg - k * std:
11 |             return m
12 |     return np.max(metric)
13 |
14 | from sklearn.cluster import KMeans
15 | def getThresholdFromClusteringKMeans(target, metric, kmeansInit='random'):
16 |     data = [[t, m] for t, m in zip(target, metric)]
17 |     kmeans = KMeans(n_clusters=2, init=kmeansInit, random_state=0).fit(data)
18 |     c0 = metric[np.ma.make_mask(kmeans.labels_)]
19 |     c0min, c0max = min(c0), max(c0)
20 |     c1 = metric[np.logical_not(kmeans.labels_)]
21 |     c1min, c1max = min(c1), max(c1)
22 |     if c0min > c1max or c1min > c0max:
23 |         return max(c0min, c1min)
24 |     return min(c0max, c1max)
25 |
26 | def getThresholdFromKMeans(target, metric, kmeansInit='random'):
27 |     data = [[t, m] for t, m in zip(target, metric)]
28 |     kmeans = KMeans(n_clusters=2, init=kmeansInit, random_state=0).fit(data)
29 |     c0 = metric[np.ma.make_mask(kmeans.labels_)]
30 |     c0min, c0max = min(c0), max(c0)
31 |     c1 = metric[np.logical_not(kmeans.labels_)]
32 |     c1min, c1max = min(c1), max(c1)
33 |     if c0min > c1max or c1min > c0max:
34 |         return max(c0min, c1min), kmeans
35 |     return min(c0max, c1max), kmeans
36 |
37 | def getThresholdClusteringKMeans(target, metric, kmeansInit='random'):
38 |     data = [[t, m] for t, m in zip(target, metric)]
39 |     if kmeansInit == 'strategic':
40 |         kmeansInit = np.array([[np.max(target), np.min(metric)],
41 |                                [np.min(target), np.max(metric)]])
42 |
43 |     kmeans = KMeans(n_clusters=2, init=kmeansInit, random_state=0).fit(data)
44 |     c0 = metric[np.logical_not(kmeans.labels_)]
45 |     c0min, c0max = min(c0), max(c0)
46 |     c1 = metric[np.ma.make_mask(kmeans.labels_)]
47 |     c1min, c1max = min(c1), max(c1)
48 |     centers = kmeans.cluster_centers_
49 |     if centers[0][0] > centers[1][0]:
50 |         return c1min, kmeans
51 |     return c0min, kmeans
52 |
53 | def getThresholdClusteringKMeansCenter(target, metric, kmeansInit='random'):
54 |     data = [[t, m] for t, m in zip(target, metric)]
55 |     if kmeansInit == 'strategic':
56 |         kmeansInit = np.array([[np.max(target), np.min(metric)],
57 |                                [np.min(target), np.max(metric)]])
58 |
59 |     kmeans = KMeans(n_clusters=2, init=kmeansInit, random_state=0).fit(data)
60 |     c0 = metric[np.logical_not(kmeans.labels_)]
61 |     c0min, c0max = min(c0), max(c0)
62 |     c1 = metric[np.ma.make_mask(kmeans.labels_)]
63 |     c1min, c1max = min(c1), max(c1)
64 |     centers = kmeans.cluster_centers_
65 |     if centers[0][0] > centers[1][0]:
66 |         return centers[1][1], kmeans
67 |     return centers[0][1], kmeans
68 |
69 | def oneThirdThreshold(metric):
70 |     metricSort = copy.deepcopy(metric)
71 |     metricSort.sort()
72 |     for i in range(len(metric)):
73 |         if i >= len(metric)/3.0:
74 |             return metricSort[i]
75 |
76 | def normalize01(array):
77 |     if max(array) == min(array):  # TODO: to consider more cases
78 |         return array
79 |     return (array-min(array))/(max(array)-min(array))
80 |
81 | def isTeamContainsAny(tA, tBs):
82 |     setA = set(tA)
83 |     for tB in tBs:
84 |         assert len(tA) >= len(tB), "len(tA) >= len(tB)"
85 |         if set(tB).issubset(setA):
86 |             return True
87 |     return False
88 |
89 | def centeredMean(nums):
90 |     if len(nums) <= 2:
91 |         return np.mean(nums)
92 |     else:
93 |         return (np.sum(nums) - np.max(nums) - np.min(nums)) / (len(nums) - 2)
94 |
95 | def getNTeamStatistics(teamNameList, accuracyDict, minAcc, avgAcc, maxAcc, tmpAccList):
96 |     nAboveMin = 0
97 |     nAboveAvg = 0
98 |     nAboveMax = 0
99 |     nHigherMember = 0
100 |     allAcc = []
101 |     for teamName in teamNameList:
102 |         acc = accuracyDict[teamName]
103 |         allAcc.append(acc)
104 |         if acc >= round(minAcc, 2):
105 |             nAboveMin += 1
106 |         if acc >= round(avgAcc, 2):
107 |             nAboveAvg += 1
108 |         if acc >= round(maxAcc, 2):
109 |             nAboveMax += 1
110 |             #print(teamName)
111 |         # count whether an ensemble is higher than all its member models
112 |         nHigherMember += 1
113 |         if ',' in teamName:
114 |             teamName = teamName.split(',')
115 |         for modelName in teamName:
116 |             if len(tmpAccList) > 1 and isinstance(tmpAccList[0], list):
117 |                 modelAcc = tmpAccList[int(modelName)][0].item()
118 |             else:
119 |                 modelAcc = tmpAccList[int(modelName)].item()
120 |             if acc < modelAcc:
121 |                 nHigherMember -= 1
122 |                 break
123 |     return len(teamNameList), np.min(allAcc), np.max(allAcc), np.mean(allAcc), np.std(allAcc), nHigherMember, nAboveMax, nAboveAvg, nAboveMin
124 |
125 | # random selection
126 | def randomSelection(teamNameList, nRandomSamples = 1, nRepeat = 1, verbose = False):
127 |     selectedTeamLists = []
128 |     for i in range(nRepeat):
129 |         randomIdx = np.random.choice(np.arange(len(teamNameList)), nRandomSamples)
130 |         for idx in randomIdx:
131 |             selectedTeamLists.append(teamNameList[idx])
132 |     if verbose:
133 |         print(selectedTeamLists)
134 |     return selectedTeamLists
135 |
136 | def printTopNTeamStatistics(teamNameList, accuracyDict, minAcc, avgAcc, maxAcc, tmpAccList, divScores, dm, topN=5, divFormat="teamName-dm", verbose=False):
137 |     tmpFQTeamNameAccList = []
138 |     for teamName in teamNameList:
139 |         if divFormat == "dm-teamName":
140 |             tmpFQTeamNameAccList.append([divScores[dm][teamName],
141 |                                          teamName, accuracyDict[teamName]])
142 |         else:
143 |             tmpFQTeamNameAccList.append([divScores[teamName][dm],
144 |                                          teamName, accuracyDict[teamName]])
145 |
146 |     #tmpFQTeamNameAccList.sort()
147 |
148 |     tmpFQTeamNameAccList = sorted(tmpFQTeamNameAccList, key=itemgetter(2), reverse=True)
149 |     tmpFQTeamNameAccList = sorted(tmpFQTeamNameAccList, key=itemgetter(0))
150 |
151 |     tmpFQTeamNameAccList = tmpFQTeamNameAccList[:topN]
152 |     if verbose:
153 |         tmpTeamNameList = []
154 |         for i in range(min(topN, len(tmpFQTeamNameAccList))):
155 |             print(tmpFQTeamNameAccList[i])
156 |             tmpTeamNameList.append(tmpFQTeamNameAccList[i][1])
157 |         print(tmpTeamNameList)
158 |     print(dm, getNTeamStatistics([tmpFTA[1] for tmpFTA in tmpFQTeamNameAccList],
159 |                                  accuracyDict, minAcc, avgAcc, maxAcc, tmpAccList))
160 |
--------------------------------------------------------------------------------
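Putting the pieces together, a self-contained selection sketch on synthetic predictions (the five simulated 80%-accuracy models and the QS metric choice are illustrative; the demo notebooks show the real pipeline on saved model predictions):

    import numpy as np
    from itertools import combinations
    from EnsembleBench.groupMetrics import calDiversityMetric
    from EnsembleBench.teamSelection import getThreshold

    rng = np.random.RandomState(0)
    n_models, n_samples = 5, 1000
    y_true = rng.randint(0, 10, n_samples)
    # each toy model is correct ~80% of the time, independently of the others
    y_pred = np.stack([np.where(rng.rand(n_samples) < 0.8, y_true,
                                rng.randint(0, 10, n_samples))
                       for _ in range(n_models)])

    teams, accs, divs = [], [], []
    for size in range(2, n_models + 1):
        for team in combinations(range(n_models), size):
            votes = y_pred[list(team)]
            # plurality voting over the member predictions
            maj = np.apply_along_axis(lambda x: np.bincount(x).argmax(), 0, votes)
            teams.append(team)
            accs.append(100.0 * np.mean(maj == y_true))
            divs.append(calDiversityMetric(votes.T, y_true, metric='QS'))

    accs, divs = np.array(accs), np.array(divs)
    threshold = getThreshold(accs, divs)  # mean/std-based cutoff on the diversity axis
    selected = [t for t, d in zip(teams, divs) if d < threshold]
    print(round(float(threshold), 3), len(selected), round(float(accs.max()), 2))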
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # EnsembleBench
3 |
4 | -----------------
5 | [![GitHub license](https://img.shields.io/badge/license-apache-green.svg?style=flat)](https://www.apache.org/licenses/LICENSE-2.0)
6 | [![Version](https://img.shields.io/badge/version-0.0.1-red.svg?style=flat)]()
7 |
12 | ## Introduction
13 |
14 | A set of tools for building high diversity ensembles.
15 |
16 | * a set of quantitative metrics for assessing the quality of ensembles;
17 | * a suite of baseline diversity metrics and optimized diversity metrics for identifying and selecting ensembles with high diversity and high quality;
18 | * representative ensemble consensus methods: soft voting (model averaging), majority voting, plurality voting and boosting voting.
19 |
20 | CogMI 2020 Presentation Video: https://youtu.be/ErZj_OxyYxc
21 |
22 | If you find this work useful in your research, please cite the following papers:
23 |
24 | **Bibtex**:
25 | ```bibtex
26 | @INPROCEEDINGS{ensemblebench,
27 |   author={Y. {Wu} and L. {Liu} and Z. {Xie} and J. {Bae} and K. -H. {Chow} and W. {Wei}},
28 |   booktitle={2020 IEEE Second International Conference on Cognitive Machine Intelligence (CogMI)},
29 |   title={Promoting High Diversity Ensemble Learning with EnsembleBench},
30 |   year={2020},
31 |   volume={},
32 |   number={},
33 |   pages={208-217},
34 |   doi={10.1109/CogMI50398.2020.00034}
35 | }
36 | @INPROCEEDINGS{dp-ensemble,
37 |   author={Wu, Yanzhao and Liu, Ling and Xie, Zhongwei and Chow, Ka-Ho and Wei, Wenqi},
38 |   booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
39 |   title={Boosting Ensemble Accuracy by Revisiting Ensemble Diversity Metrics},
40 |   year={2021},
41 |   volume={},
42 |   number={},
43 |   pages={16464-16472},
44 |   doi={10.1109/CVPR46437.2021.01620}
45 | }
46 | @INPROCEEDINGS{hq-ensemble,
47 |   author={Wu, Yanzhao and Liu, Ling},
48 |   booktitle={2021 IEEE International Conference on Data Mining (ICDM)},
49 |   title={Boosting Deep Ensemble Performance with Hierarchical Pruning},
50 |   year={2021},
51 |   volume={},
52 |   number={},
53 |   pages={1433-1438},
54 |   doi={10.1109/ICDM51629.2021.00184}
55 | }
56 | ```
57 |
58 | ## Instructions
59 |
60 |
61 | ### Installation
62 |
63 | 1. It is recommended to clone this git repo and refer to the demo folder for building your own projects using EnsembleBench.
64 |
65 |        git clone https://github.com/git-disl/EnsembleBench.git
66 |
67 | 2. Initialize the environment variables:
68 |
69 |        source env.sh
70 |
71 | 3. Install the Python dependencies.
72 |
73 | 4. Run the demos under the demo folder.
74 |
75 |
76 | If you would like to simply use some functions provided by EnsembleBench, you may install it using the following command.
77 |
78 |        pip install EnsembleBench
79 |
80 |
81 |
82 | ## Supported Platforms
83 |
84 | The source code has been tested on Ubuntu 16.04 and Ubuntu 20.04.
85 |
86 |
87 |
88 | ## Development / Contributing
89 |
90 |
91 | ## Issues
92 |
93 |
94 | ## Status
95 |
96 |
97 | ## Contributors
98 |
99 | See the [people page](https://github.com/git-disl/EnsembleBench/graphs/contributors) for the full listing of contributors.
100 |
101 | ## License
102 |
103 | Copyright (c) 20XX-20XX [Georgia Tech DiSL](https://github.com/git-disl)
104 | Licensed under the [Apache License](LICENSE).
105 |
--------------------------------------------------------------------------------
/demo/BaselineDiversityBasedEnsembleSelection.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "c22b5df4", 6 | "metadata": {}, 7 | "source": [ 8 | "# Baseline Diversity-based Ensemble Selection\n", 9 | "\n", 10 | "This demo provides the baseline diversity-based ensemble selection examples on CIFAR-10 and ImageNet."
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "f2d4f7a8", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "import time\n", 22 | "import timeit\n", 23 | "import numpy as np\n", 24 | "\n", 25 | "import torch\n", 26 | "from itertools import combinations\n", 27 | "\n", 28 | "from EnsembleBench.frameworks.pytorchUtility import (\n", 29 | " calAccuracy,\n", 30 | " calAveragePredictionVectorAccuracy,\n", 31 | " calNegativeSamplesSet,\n", 32 | " calDisagreementSamplesNoGroundTruth,\n", 33 | " filterModelsFixed,\n", 34 | ")\n", 35 | "\n", 36 | "from EnsembleBench.groupMetrics import (\n", 37 | " calAllDiversityMetrics,\n", 38 | ")\n", 39 | "from EnsembleBench.teamSelection import (\n", 40 | " getNTeamStatistics,\n", 41 | ")\n", 42 | "\n", 43 | "%load_ext autoreload\n", 44 | "%autoreload 2" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "17bf108a", 50 | "metadata": {}, 51 | "source": [ 52 | "## Dataset Configurations\n", 53 | "\n", 54 | "You can download the extracted predictions for CIFAR-10 and ImageNet from the following Google Drive folder.\n", 55 | "https://drive.google.com/drive/folders/18rEcjSpMSy-XN2bUQ3PfsBppwb874B8q?usp=sharing" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "cfabb251", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "# simply use the extracted prediction results to calculate the diversity scores and perform ensemble selection\n", 66 | "\n", 67 | "dataset = 'cifar10'\n", 68 | "diversityMetricsList = ['CK', 'QS', 'BD', 'FK', 'KW', 'GD']\n", 69 | "\n", 70 | "if dataset == 'cifar10':\n", 71 | " predictionDir = './cifar10/prediction'\n", 72 | " models = ['densenet-L190-k40', 'densenetbc-100-12', 'resnext8x64d', 'wrn-28-10-drop', 'vgg19_bn', \n", 73 | " 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110']\n", 74 | "elif dataset == 'imagenet':\n", 75 | " predictionDir = './imagenet/prediction'\n", 76 | " models = np.array(['AlexNet', 'DenseNet', 'EfficientNetb0', 'ResNeXt50', 'Inception3', 'ResNet152', 'ResNet18', 'SqueezeNet', 'VGG16', 'VGG19bn'])\n", 77 | "else:\n", 78 | " raise Exception(\"Dataset not supported!\")\n", 79 | "\n", 80 | "suffix = '.pt'\n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "67b1f777", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# load prediction vectors\n", 91 | "labelVectorsList = list()\n", 92 | "predictionVectorsList = list()\n", 93 | "tmpAccList = list()\n", 94 | "for m in models:\n", 95 | " predictionPath = os.path.join(predictionDir, m+suffix)\n", 96 | " prediction = torch.load(predictionPath)\n", 97 | " predictionVectors = prediction['predictionVectors']\n", 98 | " predictionVectorsList.append(torch.nn.functional.softmax(predictionVectors, dim=-1).cpu())\n", 99 | " labelVectors = prediction['labelVectors']\n", 100 | " labelVectorsList.append(labelVectors.cpu())\n", 101 | " tmpAccList.append(calAccuracy(predictionVectors, labelVectors)[0].cpu())\n", 102 | " print(tmpAccList[-1])\n", 103 | "\n", 104 | "minAcc = np.min(tmpAccList)\n", 105 | "avgAcc = np.mean(tmpAccList)\n", 106 | "maxAcc = np.max(tmpAccList)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "3a2469e9", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# Calculate the team accuracy (optional)\n", 117 | "# team -> accuracy map\n", 118 | "# model -> team\n", 119 | "import timeit\n", 120 |
"teamAccuracyDict = dict()\n", 121 | "startTime = timeit.default_timer()\n", 122 | "for n in range(2, len(models)+1):\n", 123 | " comb = combinations(list(range(len(models))), n)\n", 124 | " for selectedModels in list(comb):\n", 125 | " tmpAccuracy = calAveragePredictionVectorAccuracy(predictionVectorsList, labelVectorsList[0], modelsList=selectedModels)[0].cpu().item()\n", 126 | " teamName = \"\".join(map(str, selectedModels))\n", 127 | " teamAccuracyDict[teamName] = tmpAccuracy\n", 128 | "endTime = timeit.default_timer()\n", 129 | "print(\"Time: \", endTime-startTime)\n" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "630b1d53", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "# obtain negative samples for any base models\n", 140 | "sampleID, sampleTarget, predictions, predVectors = calDisagreementSamplesNoGroundTruth(\n", 141 | " predictionVectorsList, labelVectorsList[0]\n", 142 | ")\n", 143 | "\n", 144 | "sampleID = np.array(sampleID)\n", 145 | "sampleTarget = np.array(sampleTarget)\n", 146 | "predictions = np.array(predictions)\n", 147 | "predVectors = np.array([np.array([np.array(pp) for pp in p]) for p in predVectors])\n", 148 | "\n", 149 | "# settings for the diversity score calculation\n", 150 | "nModels = len(predictions[0])\n", 151 | "modelIdx = list(range(nModels))" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "07373f1e", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "# statistics for different diversity metrics\n", 162 | "np.random.seed(0)\n", 163 | "crossValidation = True\n", 164 | "crossValidationTimes = 3\n", 165 | "nRandomSamples = 100\n", 166 | "\n", 167 | "teamSizeList = list()\n", 168 | "teamList = list()\n", 169 | "diversityScoresList = list()\n", 170 | "\n", 171 | "startTime = timeit.default_timer()\n", 172 | "for n in range(2, nModels+1):\n", 173 | " comb = combinations(modelIdx, n)\n", 174 | " for selectedModels in list(comb):\n", 175 | " teamSampleID, teamSampleTarget, teamPredictions, teamPredVectors = filterModelsFixed(sampleID, sampleTarget, predictions, predVectors, selectedModels) \n", 176 | " \n", 177 | " if len(teamPredictions) == 0:\n", 178 | " print(\"negative sample not found\")\n", 179 | " continue\n", 180 | " \n", 181 | " if crossValidation:\n", 182 | " tmpMetrics = list() \n", 183 | " for _ in range(crossValidationTimes):\n", 184 | " randomIdx = np.random.choice(np.arange(teamPredictions.shape[0]), nRandomSamples)\n", 185 | " tmpMetrics.append(calAllDiversityMetrics(teamPredictions[randomIdx], teamSampleTarget[randomIdx], diversityMetricsList))\n", 186 | " tmpMetrics = np.mean(np.array(tmpMetrics), axis=0)\n", 187 | " else:\n", 188 | " tmpMetrics = np.array(calAllDiversityMetrics(teamPredictions, teamSampleTarget, diversityMetricsList))\n", 189 | " \n", 190 | " diversityScoresList.append(tmpMetrics) \n", 191 | " teamSizeList.append(n)\n", 192 | " teamList.append(selectedModels)\n", 193 | "endTime = timeit.default_timer()\n", 194 | "print(\"Time: \", endTime-startTime)\n", 195 | "\n", 196 | "diversityScoresList = np.array(diversityScoresList)\n", 197 | "teamSizeList = np.array(teamSizeList)\n", 198 | "teamList = np.array(teamList, dtype=object)\n" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "56b18918", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# perform mean-threshold based ensemble selection (baseline approach)\n", 209 | "QMetrics = {}\n", 
210 | "QMetricsThreshold = {}\n", 211 | "teamSelectedQAllDict = {}\n", 212 | "\n", 213 | "for j, dm in enumerate(diversityMetricsList):\n", 214 | " QMetricsThreshold[dm] = np.mean(diversityScoresList[..., j])\n", 215 | "\n", 216 | "print(QMetricsThreshold)\n", 217 | "\n", 218 | "for i, t in enumerate(teamList):\n", 219 | " teamName = \"\".join(map(str, t))\n", 220 | " for j, dm in enumerate(diversityMetricsList):\n", 221 | " QMetricsDM = QMetrics.get(dm, {})\n", 222 | " QMetricsDM[teamName] = diversityScoresList[i][j]\n", 223 | " QMetrics[dm] = QMetricsDM\n", 224 | " if QMetricsDM[teamName] < round(QMetricsThreshold[dm], 3):\n", 225 | " teamSelectedQAllSet = teamSelectedQAllDict.get(dm, set())\n", 226 | " teamSelectedQAllSet.add(teamName)\n", 227 | " teamSelectedQAllDict[dm] = teamSelectedQAllSet\n", 228 | "\n", 229 | "for dm in diversityMetricsList:\n", 230 | " print(dm, getNTeamStatistics(list(teamSelectedQAllDict[dm]), teamAccuracyDict,\n", 231 | " minAcc, avgAcc, maxAcc, tmpAccList))\n" 232 | ] 233 | } 234 | ], 235 | "metadata": { 236 | "kernelspec": { 237 | "display_name": "Python 3 (ipykernel)", 238 | "language": "python", 239 | "name": "python3" 240 | }, 241 | "language_info": { 242 | "codemirror_mode": { 243 | "name": "ipython", 244 | "version": 3 245 | }, 246 | "file_extension": ".py", 247 | "mimetype": "text/x-python", 248 | "name": "python", 249 | "nbconvert_exporter": "python", 250 | "pygments_lexer": "ipython3", 251 | "version": "3.7.11" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 5 256 | } 257 | -------------------------------------------------------------------------------- /demo/FocalDiversityBasedEnsembleSelection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "031d9572", 6 | "metadata": {}, 7 | "source": [ 8 | "# Focal Diversity-based Ensemble Selection\n", 9 | "\n", 10 | "This demo provides the focal diversity-based ensemble selection examples on CIFAR-10 and ImageNet." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "f433d980", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "import time\n", 22 | "import torch\n", 23 | "import numpy as np\n", 24 | "from itertools import combinations\n", 25 | "import timeit\n", 26 | "\n", 27 | "# EnsembleBench modules\n", 28 | "from EnsembleBench.frameworks.pytorchUtility import (\n", 29 | " calAccuracy,\n", 30 | " calAveragePredictionVectorAccuracy,\n", 31 | " calNegativeSamplesSet,\n", 32 | " calDisagreementSamplesOneTargetNegative,\n", 33 | " filterModelsFixed,\n", 34 | ")\n", 35 | "\n", 36 | "%load_ext autoreload\n", 37 | "%autoreload 2" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "3a9e3966", 43 | "metadata": {}, 44 | "source": [ 45 | "## Dataset Configurations\n", 46 | "\n", 47 | "You can download the extracted predictions for CIFAR-10 and ImageNet from the following Google Drive folder.\n", 48 | "https://drive.google.com/drive/folders/18rEcjSpMSy-XN2bUQ3PfsBppwb874B8q?usp=sharing" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "75188ecb", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# simply use the extracted prediction results to calculate the diversity scores and perform ensemble selection\n", 59 | "\n", 60 | "dataset = 'cifar10'\n", 61 | "diversityMetricsList = ['CK', 'QS', 'BD', 'FK', 'KW', 'GD']\n", 62 | "\n", 63 | "if dataset == 'cifar10':\n", 64 | " predictionDir = './cifar10/prediction'\n", 65 | " models = ['densenet-L190-k40', 'densenetbc-100-12', 'resnext8x64d', 'wrn-28-10-drop', 'vgg19_bn', \n", 66 | " 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110']\n", 67 | " maxModel = 0\n", 68 | " thresholdAcc = 96.68\n", 69 | "elif dataset == 'imagenet':\n", 70 | " predictionDir = './imagenet/prediction'\n", 71 | " models = np.array(['AlexNet', 'DenseNet', 'EfficientNetb0', 'ResNeXt50', 'Inception3', 'ResNet152', 'ResNet18', 'SqueezeNet', 'VGG16', 'VGG19bn'])\n", 72 | " maxModel = 5\n", 73 | " thresholdAcc = 78.25\n", 74 | "\n", 75 | "else:\n", 76 | " raise Exception(\"Dataset not supported!\")\n", 77 | "\n", 78 | "suffix = '.pt'" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "07d3ebc4", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "labelVectorsList = list()\n", 89 | "predictionVectorsList = list()\n", 90 | "tmpAccList = list()\n", 91 | "for m in models:\n", 92 | " predictionPath = os.path.join(predictionDir, m+suffix)\n", 93 | " prediction = torch.load(predictionPath)\n", 94 | " predictionVectors = prediction['predictionVectors']\n", 95 | " predictionVectorsList.append(torch.nn.functional.softmax(predictionVectors, dim=-1).cpu())\n", 96 | " labelVectors = prediction['labelVectors']\n", 97 | " labelVectorsList.append(labelVectors.cpu())\n", 98 | " tmpAccList.append(calAccuracy(predictionVectors, labelVectors)[0].cpu())\n", 99 | " print(tmpAccList[-1])\n", 100 | "\n", 101 | "minAcc = np.min(tmpAccList)\n", 102 | "avgAcc = np.mean(tmpAccList)\n", 103 | "maxAcc = np.max(tmpAccList)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "9f776923", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# preprocessing\n", 114 | "# team -> accuracy map\n", 115 | "# model -> team\n", 116 | "teamAccuracyDict = dict()\n", 117 | "modelTeamDict = dict()\n", 118 | "teamNameDict = dict()\n", 119 | "startTime = timeit.default_timer()\n", 120 | "for n in
range(2, len(models)+1):\n", 121 | " comb = combinations(list(range(len(models))), n)\n", 122 | " for selectedModels in list(comb):\n", 123 | " tmpAccuracy = calAveragePredictionVectorAccuracy(predictionVectorsList, labelVectorsList[0], modelsList=selectedModels)[0].cpu().item()\n", 124 | " teamName = \"\".join(map(str, selectedModels))\n", 125 | " teamNameDict[teamName] = selectedModels\n", 126 | " teamAccuracyDict[teamName] = tmpAccuracy\n", 127 | " for m in teamName:\n", 128 | " if m in modelTeamDict:\n", 129 | " modelTeamDict[m].add(teamName)\n", 130 | " else:\n", 131 | " modelTeamDict[m] = set([teamName,])\n", 132 | "endTime = timeit.default_timer()\n", 133 | "print(\"Time: \", endTime-startTime)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "d577abdd", 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "# calculate the diversity measures for all configurations\n", 144 | "import numpy as np\n", 145 | "from EnsembleBench.groupMetrics import *\n", 146 | "np.random.seed(0)\n", 147 | "nRandomSamples = 100\n", 148 | "crossValidation = True\n", 149 | "crossValidationTimes = 3\n", 150 | "\n", 151 | "teamDiversityMetricMap = dict()\n", 152 | "negAccuracyDict = dict()\n", 153 | "startTime = timeit.default_timer()\n", 154 | "for oneTargetModel in range(len(models)):\n", 155 | " sampleID, sampleTarget, predictions, predVectors = calDisagreementSamplesOneTargetNegative(predictionVectorsList, labelVectorsList[0], oneTargetModel)\n", 156 | " if len(predictions) == 0:\n", 157 | " print(\"negative sample not found\")\n", 158 | " continue\n", 159 | " sampleID = np.array(sampleID)\n", 160 | " sampleTarget = np.array(sampleTarget)\n", 161 | " predictions = np.array(predictions)\n", 162 | " predVectors = np.array([np.array([np.array(pp) for pp in p]) for p in predVectors])\n", 163 | " for teamName in modelTeamDict[str(oneTargetModel)]:\n", 164 | " selectedModels = teamNameDict[teamName]\n", 165 | " teamSampleID, teamSampleTarget, teamPredictions, teamPredVectors = filterModelsFixed(sampleID, sampleTarget, predictions, predVectors, selectedModels) \n", 166 | " if crossValidation:\n", 167 | " tmpMetrics = list()\n", 168 | " for _ in range(crossValidationTimes):\n", 169 | " randomIdx = np.random.choice(np.arange(teamPredictions.shape[0]), nRandomSamples) \n", 170 | " tmpMetrics.append(calAllDiversityMetrics(teamPredictions[randomIdx], teamSampleTarget[randomIdx], diversityMetricsList))\n", 171 | " tmpMetrics = np.mean(np.array(tmpMetrics), axis=0)\n", 172 | " else:\n", 173 | " tmpMetrics = np.array(calAllDiversityMetrics(teamPredictions, teamSampleTarget, diversityMetricsList))\n", 174 | " diversityMetricDict = {diversityMetricsList[i]:tmpMetrics[i].item() for i in range(len(tmpMetrics))}\n", 175 | " targetDiversity = teamDiversityMetricMap.get(teamName, dict())\n", 176 | " targetDiversity[str(oneTargetModel)] = diversityMetricDict\n", 177 | " teamDiversityMetricMap[teamName] = targetDiversity\n", 178 | " \n", 179 | " tmpNegAccuracy = calAccuracy(torch.tensor(np.mean(np.transpose(teamPredVectors, (1, 0, 2)), axis=0)), torch.tensor(teamSampleTarget))[0].cpu().item()\n", 180 | " targetNegAccuracy = negAccuracyDict.get(teamName, dict())\n", 181 | " targetNegAccuracy[str(oneTargetModel)] = tmpNegAccuracy\n", 182 | " negAccuracyDict[teamName] = targetNegAccuracy\n", 183 | "\n", 184 | "endTime = timeit.default_timer()\n", 185 | "print(\"Time: \", endTime-startTime)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 
191 | "id": "fc991e2d", 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "# calculate the targetTeamSizeDict\n", 196 | "startTime = timeit.default_timer()\n", 197 | "targetTeamSizeDict = dict()\n", 198 | "for oneTargetModel in range(len(models)):\n", 199 | " for teamName in modelTeamDict[str(oneTargetModel)]:\n", 200 | " teamSize = len(teamName)\n", 201 | " teamSizeDict = targetTeamSizeDict.get(str(oneTargetModel), dict())\n", 202 | " fixedTeamDict = teamSizeDict.get(str(teamSize), dict())\n", 203 | " \n", 204 | " teamList = fixedTeamDict.get('TeamList', list())\n", 205 | " teamList.append(teamName)\n", 206 | " fixedTeamDict['TeamList'] = teamList\n", 207 | " \n", 208 | " # diversity measures\n", 209 | " diversityVector = np.expand_dims(np.array([teamDiversityMetricMap[teamName][str(oneTargetModel)][dm]\n", 210 | " for dm in diversityMetricsList]), axis=0)\n", 211 | " \n", 212 | " diversityMatrix = fixedTeamDict.get('DiversityMatrix', None)\n", 213 | " if diversityMatrix is None:\n", 214 | " diversityMatrix = diversityVector\n", 215 | " else:\n", 216 | " diversityMatrix = np.append(diversityMatrix, diversityVector, axis=0)\n", 217 | " fixedTeamDict['DiversityMatrix'] = diversityMatrix\n", 218 | " \n", 219 | " teamSizeDict[str(teamSize)] = fixedTeamDict\n", 220 | " targetTeamSizeDict[str(oneTargetModel)] = teamSizeDict \n", 221 | "endTime = timeit.default_timer()\n", 222 | "print(\"Time: \", endTime-startTime)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "0e907336", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "teamSelectedFQDict = dict()\n", 233 | "teamSelectedFQOutDict = dict()\n", 234 | "from EnsembleBench.teamSelection import *\n", 235 | "for oneTargetModel in range(len(models)):\n", 236 | " targetFQDict = teamSelectedFQDict.get(str(oneTargetModel), dict())\n", 237 | " targetFQOutDict = teamSelectedFQOutDict.get(str(oneTargetModel), dict())\n", 238 | " for teamSize in range(2, len(models)):\n", 239 | " targetTeamSizeFQDict = targetFQDict.get(str(teamSize), dict())\n", 240 | " targetTeamSizeFQOutDict = targetFQOutDict.get(str(teamSize), dict())\n", 241 | " fixedTeamDict = targetTeamSizeDict[str(oneTargetModel)][str(teamSize)]\n", 242 | " thresholds = list()\n", 243 | " kmeans = list()\n", 244 | " teamList = fixedTeamDict['TeamList']\n", 245 | " accuracyList = [teamAccuracyDict[teamName] for teamName in teamList]\n", 246 | " diversityMatrix = fixedTeamDict['DiversityMatrix']\n", 247 | " for i in range(len(diversityMetricsList)):\n", 248 | " tmpThreshold, tmpKMeans = getThresholdClusteringKMeans(accuracyList, diversityMatrix[:, i], kmeansInit='strategic')\n", 249 | " tmpThreshold = min(np.mean(diversityMatrix[:, i]), tmpThreshold)\n", 250 | " thresholds.append(tmpThreshold)\n", 251 | " kmeans.append(tmpKMeans)\n", 252 | " fixedTeamDict['Threshold'] = thresholds\n", 253 | " fixedTeamDict['KMeans'] = kmeans\n", 254 | " \n", 255 | " # calculate scaled diversity scores\n", 256 | " scaledDiversityMeasures = list()\n", 257 | " for i in range(len(diversityMetricsList)):\n", 258 | " scaledDiversityMeasures.append(normalize01(diversityMatrix[:, i]))\n", 259 | " scaledDiversityMatrix = np.stack(scaledDiversityMeasures, axis=1)\n", 260 | " fixedTeamDict['ScaledDiversityMatrix'] = scaledDiversityMatrix\n", 261 | " targetTeamSizeDict[str(oneTargetModel)][str(teamSize)] = fixedTeamDict\n", 262 | " \n", 263 | " for i, teamName in enumerate(fixedTeamDict['TeamList']):\n", 264 | " for j in 
range(len(diversityMetricsList)):\n", 265 | " targetTeamSizeFQDiversitySet = targetTeamSizeFQDict.get(diversityMetricsList[j], set())\n", 266 | " targetTeamSizeFQOutDiversitySet = targetTeamSizeFQOutDict.get(diversityMetricsList[j], set())\n", 267 | " if diversityMatrix[i, j] < round(thresholds[j], 3):\n", 268 | " targetTeamSizeFQDiversitySet.add(teamName)\n", 269 | " else:\n", 270 | " targetTeamSizeFQOutDiversitySet.add(teamName)\n", 271 | " targetTeamSizeFQDict[diversityMetricsList[j]] = targetTeamSizeFQDiversitySet\n", 272 | " targetTeamSizeFQOutDict[diversityMetricsList[j]] = targetTeamSizeFQOutDiversitySet\n", 273 | "\n", 274 | " targetFQDict[str(teamSize)] = targetTeamSizeFQDict\n", 275 | " targetFQOutDict[str(teamSize)] = targetTeamSizeFQOutDict\n", 276 | "\n", 277 | " \n", 278 | " teamSelectedFQDict[str(oneTargetModel)] = targetFQDict\n", 279 | " teamSelectedFQOutDict[str(oneTargetModel)] = targetFQOutDict" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "id": "3083f71a", 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "teamSelectedFQAllDict = dict()\n", 290 | "for j, dm in enumerate(diversityMetricsList):\n", 291 | " teamSelectedFQAllDiversitySet = teamSelectedFQAllDict.get(dm, set())\n", 292 | " for teamSize in range(2, len(models)):\n", 293 | " teamSizeSelectedTeamsSet = set()\n", 294 | " tmpTeamDict = dict() # teamName & Metric\n", 295 | " for oneTargetModel in range(len(models)):\n", 296 | " for teamName in teamSelectedFQDict[str(oneTargetModel)][str(teamSize)][dm]:\n", 297 | " if teamName in tmpTeamDict:\n", 298 | " continue\n", 299 | " tmpMetricList = list()\n", 300 | " teamModelIdx = map(int, [modelName for modelName in teamName])\n", 301 | " teamModelAcc = [tmpAccList[modelIdx].item() for modelIdx in teamModelIdx]\n", 302 | " teamModelWeights = np.argsort(teamModelAcc)\n", 303 | " tmpModelWeights = list()\n", 304 | " for (k, modelName) in enumerate(teamName):\n", 305 | " fixedTeamDict = targetTeamSizeDict[modelName][str(teamSize)]\n", 306 | " for i, tmpTeamName in enumerate(fixedTeamDict['TeamList']):\n", 307 | " if tmpTeamName == teamName:\n", 308 | " tmpMetricList.append(fixedTeamDict['ScaledDiversityMatrix'][i, j])\n", 309 | " tmpModelWeights.append(teamModelWeights[k])\n", 310 | " tmpTeamDict[teamName] = np.average(tmpMetricList, weights=tmpModelWeights)\n", 311 | " if len(tmpTeamDict) > 0:\n", 312 | " accuracyList = np.array([teamAccuracyDict[teamName] for teamName in tmpTeamDict])\n", 313 | " metricList = np.array([tmpTeamDict[teamName] for teamName in tmpTeamDict])\n", 314 | " tmpThreshold, _ = getThresholdClusteringKMeansCenter(accuracyList, metricList, kmeansInit='strategic')\n", 315 | " for teamName in tmpTeamDict:\n", 316 | " if tmpTeamDict[teamName] < tmpThreshold:\n", 317 | " teamSizeSelectedTeamsSet.add(teamName)\n", 318 | " teamSelectedFQAllDiversitySet.update(teamSizeSelectedTeamsSet)\n", 319 | " teamSelectedFQAllDict[dm] = teamSelectedFQAllDiversitySet\n", 320 | "\n", 321 | "\n", 322 | "# print the ensemble selection results\n", 323 | "for dm in diversityMetricsList:\n", 324 | " print(dm, getNTeamStatistics(list(teamSelectedFQAllDict[dm]), teamAccuracyDict, minAcc, avgAcc, maxAcc, tmpAccList))\n", 325 | " \n", 326 | " " 327 | ] 328 | } 329 | ], 330 | "metadata": { 331 | "kernelspec": { 332 | "display_name": "Python 3 (ipykernel)", 333 | "language": "python", 334 | "name": "python3" 335 | }, 336 | "language_info": { 337 | "codemirror_mode": { 338 | "name": "ipython", 339 | "version": 3 340 | }, 341 
| "file_extension": ".py", 342 | "mimetype": "text/x-python", 343 | "name": "python", 344 | "nbconvert_exporter": "python", 345 | "pygments_lexer": "ipython3", 346 | "version": "3.7.11" 347 | } 348 | }, 349 | "nbformat": 4, 350 | "nbformat_minor": 5 351 | } 352 | -------------------------------------------------------------------------------- /env.sh: -------------------------------------------------------------------------------- 1 | #conda activate ens 2 | export PYTHONPATH=$PYTHONPATH:`pwd` 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Inside of setup.cfg 2 | [metadata] 3 | description-file = README.md 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | import os 4 | import sys 5 | 6 | with open('README.md') as f: 7 | long_description = f.read() 8 | 9 | 10 | setup( 11 | name = 'EnsembleBench', 12 | packages = ['EnsembleBench',], 13 | version = '0.0.0.1', 14 | description = 'A set of tools for building good ensemble model teams in machine learning.', 15 | long_description=long_description, 16 | long_description_content_type='text/markdown', 17 | author = 'Yanzhao Wu', 18 | author_email = 'yanzhaowumail@gmail.com', 19 | url = 'https://github.com/git-disl/EnsembleBench', 20 | download_url = 'https://github.com/git-disl/EnsembleBench/archive/master.zip', 21 | keywords = ['ENSEMBLE', 'INFERENCE', 'MACHINE LEARNING'], 22 | install_requires=[ 23 | 'numpy', 24 | 'matplotlib', 25 | 'scikit-learn', 26 | 'statsmodels', 27 | ], 28 | classifiers=[ 29 | 'Development Status :: 3 - Alpha', 30 | 'Intended Audience :: Developers', 31 | 'Intended Audience :: Science/Research', 32 | 'Programming Language :: Python :: 2.7', 33 | 'Programming Language :: Python :: 3', 34 | 'Programming Language :: Python :: 3.4', 35 | 'Programming Language :: Python :: 3.5', 36 | 'Programming Language :: Python :: 3.6', 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if ! [ -x "$(command -v conda)" ]; then 3 | echo "Error: Please intall conda" 4 | exit 1 5 | fi 6 | source $(conda info --base)/etc/profile.d/conda.sh 7 | conda create -n ens python=2.7 # or python=3.6 8 | conda activate ens 9 | # install jupyter notebook 10 | conda install jupyter 11 | --------------------------------------------------------------------------------