├── .gitignore ├── README.md ├── simforest ├── __init__.py └── _simforest.py ├── setup.py ├── example.py └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | 4 | *.pyc 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # similarity-forest 2 | Basic (not especially-optimised) implementation of the Similarity Forest algorithm, as outlined 3 | [here](http://www.kdd.org/kdd2017/papers/view/similarity-forests) 4 | -------------------------------------------------------------------------------- /simforest/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Similarity Forest implementation as in 3 | 'Similarity Forests', S. Sathe and C. C. Aggarwal, KDD 2017. 4 | """ 5 | 6 | from ._simforest import SimilarityForest 7 | 8 | __all__ = ( 9 | 'SimilarityForest', 10 | ) 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | def readme(): 5 | with open('README.md') as f: 6 | return f.read() 7 | 8 | setup(name='Similarity Forest', 9 | version='0.1', 10 | description='Similarity Forest', 11 | url='https://github.com/rrricharrrd/similarity-forest', 12 | author='Richard Harris', 13 | author_email='rrricharrrd@gmail.com', 14 | license='MIT', 15 | packages=['simforest'], 16 | install_requires=['numpy'], 17 | scripts=[], 18 | zip_safe=False) 19 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from simforest import SimilarityForest 2 | 3 | from sklearn.datasets import make_blobs 4 | from sklearn.ensemble import 
RandomForestClassifier 5 | from sklearn.metrics import accuracy_score 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | if __name__ == '__main__': 10 | X, y = make_blobs(n_samples=1000, centers=[(0, 0), (1, 1)], random_state=1234) 11 | X_train, X_test, y_train, y_test = train_test_split( 12 | X, y, test_size=0.2, random_state=1234) 13 | 14 | sf = SimilarityForest(n_estimators=20, n_axes=1, random_state=1234) 15 | sf.fit(X_train, y_train) 16 | 17 | sf_pred = sf.predict(X_test) 18 | sf_prob = sf.predict_proba(X_test) 19 | 20 | print('Similarity Forest') 21 | print(sf_prob[:, 1]) 22 | print(y_test) 23 | print(accuracy_score(y_test, sf_pred)) 24 | 25 | rf = RandomForestClassifier(random_state=1234) 26 | rf.fit(X_train, y_train) 27 | 28 | rf_pred = rf.predict(X_test) 29 | rf_prob = rf.predict_proba(X_test) 30 | 31 | print('Random Forest') 32 | print(rf_prob[:, 1]) 33 | print(y_test) 34 | print(accuracy_score(y_test, rf_pred)) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Richard Harris 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# --------------------------------------------------------------------------------
# Remainder of /LICENSE (this span of the repository dump also carried the tail
# of the MIT licence text; it is preserved here):
#
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
# --------------------------------------------------------------------------------
# /simforest/_simforest.py:
# --------------------------------------------------------------------------------
"""Core implementation of the Similarity Forest binary classifier.

See 'Similarity Forests', S. Sathe and C. C. Aggarwal, KDD 2017.
"""
from __future__ import division, print_function

import numpy as np


def _sample_axes(labels, rand, n_samples=1):
    """Yield ``n_samples`` random ``(positive_index, negative_index)`` pairs.

    Each pair holds the row index of one sample labelled 1 and one sample
    labelled 0; the two rows act as the anchor points of a candidate split.

    :param labels: sequence of 0/1 labels (must contain both classes)
    :param rand: ``np.random.RandomState`` used for the draws
    :param n_samples: number of pairs to yield
    """
    # asarray so that plain-list labels compare element-wise below.
    labels = np.asarray(labels)
    pos = np.where(labels == 1)[0]
    neg = np.where(labels == 0)[0]
    for _ in range(n_samples):
        yield rand.choice(pos), rand.choice(neg)


def _split_metric(total_left, total_right, true_left, true_right):
    """Return the size-weighted Gini impurity of a binary split.

    :param total_left: number of samples sent left (must be > 0)
    :param total_right: number of samples sent right (must be > 0)
    :param true_left: number of label-1 samples sent left
    :param true_right: number of label-1 samples sent right
    :return: impurity in [0, 0.5]; lower is better
    """
    left_pred = true_left / total_left
    right_pred = true_right / total_right

    left_gini = 1 - left_pred ** 2 - (1 - left_pred) ** 2
    right_gini = 1 - right_pred ** 2 - (1 - right_pred) ** 2

    left_prop = total_left / (total_left + total_right)
    return left_prop * left_gini + (1 - left_prop) * right_gini


class Node:
    """One node of a similarity tree; the root node represents the whole tree.

    An internal node projects each sample ``x`` onto the direction defined by
    two anchor samples ``p`` (label 1) and ``q`` (label 0) as
    ``sim(x, q) - sim(x, p)``.  Samples whose projection is ``<= criterion``
    go to the left child, and samples whose projection is ``> criterion`` go
    to the right child.
    """

    def __init__(self, depth, similarity_function=np.dot, n_axes=1,
                 max_depth=None, rand=None):
        self.depth = depth
        self.max_depth = max_depth
        self._sim = similarity_function
        self.n_axes = n_axes
        self._left = None
        self._right = None
        self._p = None
        self._q = None
        self.criterion = None
        self.prediction = None  # fraction of label-1 samples at this node
        self._rand = np.random.RandomState() if rand is None else rand

    def _project(self, X, p, q):
        """Return ``sim(x, q) - sim(x, p)`` for every row of ``X`` as a float array.

        Returning an ndarray (not a list) guarantees that comparing the result
        with a threshold yields a boolean mask, whatever scalar type the
        similarity function returns.
        """
        return np.array([self._sim(x, q) - self._sim(x, p) for x in X],
                        dtype=float)

    def _child(self):
        """Return an unfitted child node sharing this node's settings and RNG."""
        return Node(self.depth + 1, self._sim, self.n_axes, self.max_depth,
                    self._rand)

    def _find_split(self, X, y, p, q):
        """Find the best threshold along the ``(p, q)`` projection.

        Samples whose projection is NaN (unknown similarity) are ignored.

        :return: ``(metric, p, q, criterion)``. When fewer than two usable
            samples exist, ``metric`` is 1 and ``p``/``q`` are ``None``.
        """
        sims = self._project(X, p, q)
        indices = sorted((i for i in range(len(y)) if not np.isnan(sims[i])),
                         key=lambda i: sims[i])

        best_metric = 1
        best_p = None
        best_q = None
        best_criterion = 0

        n = len(indices)
        total_true = sum(y[j] for j in indices)
        left_true = 0
        # Try every cut between consecutive sorted samples. Both sides are
        # therefore non-empty, so _split_metric never divides by zero.
        for i in range(n - 1):
            left_true += y[indices[i]]
            right_true = total_true - left_true
            split_metric = _split_metric(i + 1, n - i - 1, left_true, right_true)
            if split_metric < best_metric:
                best_metric = split_metric
                best_p = p
                best_q = q
                # Threshold halfway between the two neighbouring projections.
                best_criterion = (sims[indices[i]] + sims[indices[i + 1]]) / 2
        return best_metric, best_p, best_q, best_criterion

    def fit(self, X, y):
        """Grow the subtree rooted at this node from samples ``X`` and 0/1 labels ``y``.

        :param X: array-like of samples, one per row
        :param y: array-like of 0/1 labels, same length as ``X`` (non-empty)
        :return: self
        """
        X = np.asarray(X)
        y = np.asarray(y)
        self.prediction = sum(y) / len(y)
        # Pure node, or depth budget exhausted: stay a leaf.
        if self.prediction in [0, 1]:
            return self
        if self.max_depth is not None and self.depth >= self.max_depth:
            return self

        best_metric = 1
        best_p = None
        best_q = None
        best_criterion = 0
        for i, j in _sample_axes(y, self._rand, self.n_axes):
            metric, p, q, criterion = self._find_split(X, y, X[i], X[j])
            if metric < best_metric:
                best_metric = metric
                best_p = p
                best_q = q
                best_criterion = criterion

        if best_metric >= 1:
            # No usable split (e.g. every similarity was NaN): stay a leaf.
            return self

        self._p = best_p
        self._q = best_q
        self.criterion = best_criterion

        # Use an ndarray so the comparisons yield boolean masks. A plain list
        # only compared element-wise when the criterion was a NumPy scalar, and
        # raised TypeError for similarity functions returning Python floats.
        sims = self._project(X, self._p, self._q)
        go_left = sims <= self.criterion
        go_right = sims > self.criterion  # NaN projections go to neither side

        # Ties at the threshold can leave one side empty; then stay a leaf.
        if go_left.any() and go_right.any():
            self._left = self._child().fit(X[go_left], y[go_left])
            self._right = self._child().fit(X[go_right], y[go_right])

        return self

    def _predict_proba_once(self, x):
        """Return the estimated probability of label 1 for the single sample ``x``."""
        if self._left is None:
            return self.prediction
        sim = self._sim(x, self._q) - self._sim(x, self._p)
        if sim <= self.criterion:
            return self._left._predict_proba_once(x)
        if sim > self.criterion:
            return self._right._predict_proba_once(x)
        # Unknown (NaN) similarity: fall back to this node's own estimate.
        return self.prediction

    def predict_proba(self, X):
        """Return a list with the label-1 probability for each row of ``X``."""
        return [self._predict_proba_once(x) for x in X]


class SimilarityForest:
    """
    Basic implementation of SimForest, as outlined in
    'Similarity Forests', S. Sathe and C. C. Aggarwal, KDD 2017.

    Binary classifier: training labels are expected to be 0 and 1.

    :param n_estimators: number of trees in the forest (default=10)
    :param similarity_function: similarity function (default is dot product) -
        should return np.nan if similarity unknown
    :param n_axes: number of 'axes' (candidate anchor pairs) tried per split
    :param max_depth: maximum depth to grow trees to (default=None)
    :param random_state: seed for ``np.random.RandomState`` (default=None)
    """
    def __init__(self, n_estimators=10, similarity_function=np.dot, n_axes=1,
                 max_depth=None, random_state=None):
        self.n_estimators = n_estimators
        self.n_axes = n_axes
        self.max_depth = max_depth
        self._sim = similarity_function
        self._trees = None
        self._rand = np.random.RandomState(random_state)

    def _bag(self, X, y):
        """Draw a bootstrap sample of ``(X, y)`` with duplicate draws removed."""
        # NOTE: duplicates are dropped via set(), so each tree sees the distinct
        # rows of a bootstrap draw rather than a true with-replacement sample.
        selection = np.array(list(set(self._rand.choice(len(y), size=len(y)))))
        return X[selection, :], y[selection]

    def fit(self, X, y):
        """
        Build a forest of trees from the training set (X, y).

        :param X: training set (2-D array-like, one sample per row)
        :param y: training set labels (0/1)
        :return: self
        :raises ValueError: if ``X`` and ``y`` have different lengths
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(X) != len(y):  # TODO: more input validation
            raise ValueError('Bad sizes: {}, {}'.format(X.shape, y.shape))
        self._trees = [Node(1,
                            self._sim,
                            self.n_axes,
                            self.max_depth,
                            self._rand).fit(*self._bag(X, y))
                       for _ in range(self.n_estimators)]
        return self

    def predict_proba(self, X):
        """
        Predict class probabilities of X.

        :param X: samples to make prediction probabilities for
        :return: array of shape (n_samples, 2): columns are P(label 0), P(label 1)
        """
        probs = np.mean([t.predict_proba(X) for t in self._trees], axis=0)
        return np.c_[1 - probs, probs]

    def predict(self, X):
        """
        Predict class of X.

        :param X: samples to make predictions for
        :return: integer array of class predictions (0 or 1)
        """
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        return (self.predict_proba(X)[:, 1] > 0.5).astype(int)