├── .gitignore ├── Makefile ├── README.md ├── cpm.opam ├── data └── pr_curve.py ├── dune-project ├── src ├── MakeROC.ml ├── RegrStats.ml ├── TopKeeper.ml ├── TopKeeper.mli ├── dune ├── myList.ml ├── test.ml └── utls.ml └── test.scored-label /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UnixJunkie/cpmlib/d5570252e2f0d19e1b8f0b980fe0e30fab5d1ccb/.gitignore -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build config clean edit install uninstall reinstall test 2 | 3 | build: 4 | dune build @install 5 | dune build _build/default/src/test.exe 6 | 7 | clean: 8 | dune clean 9 | 10 | edit: 11 | emacs src/*.ml & 12 | 13 | install: build 14 | dune uninstall 15 | dune install 16 | 17 | uninstall: 18 | dune uninstall 19 | 20 | # unit tests 21 | test: 22 | dune build _build/default/src/test.exe 23 | _build/default/src/test.exe 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # cpmlib 2 | Classification and Regression Performance Metrics library 3 | -------------------------------------------------------------------------------- /cpm.opam: -------------------------------------------------------------------------------- 1 | opam-version: "2.0" 2 | authors: "Francois Berenger" 3 | maintainer: "unixjunkie@sdf.org" 4 | homepage: "https://github.com/UnixJunkie/cpmlib" 5 | bug-reports: "https://github.com/UnixJunkie/cpmlib/issues" 6 | dev-repo: "git+https://github.com/UnixJunkie/cpmlib.git" 7 | license: "LGPL-2.0-or-later" 8 | build: ["dune" "build" "-p" name "-j" jobs] 9 | depends: [ 10 | "dune" {>= "2.8"} 11 | "batteries" {>= "2.6"} 12 | "ocaml" 13 | ] 14 | depopts: [ 15 | "conf-gnuplot" 16 | ] 17 | synopsis: "The Classification and 
Regression Performance Metrics library" 18 | description: """ 19 | For classification/ranking: ROC AUC, BEDROC AUC, Enrichment Factor, 20 | Robust Initial Enhancement, Power Metric, Matthews' Correlation Coefficient, 21 | Platt scaling. 22 | 23 | For regression: Root Mean Squared Error, Mean Absolute Error, 24 | r^2 coefficient of determination, Raw Regression Error Characteristic Curve. 25 | 26 | Also features a TopKeeper module; to keep in memory the top 'k' 27 | scored items when dealing with very large datasets. 28 | """ 29 | #url { 30 | # src: "https://github.com/UnixJunkie/cpmlib/archive/vXXX.tar.gz" 31 | # checksum: "md5=YYY" 32 | #} 33 | -------------------------------------------------------------------------------- /data/pr_curve.py: -------------------------------------------------------------------------------- 1 | # date: 24/07/2019 author: Jason Brownlee 2 | # source: https://machinelearningmastery.com/ 3 | # roc-curves-and-precision-recall-curves-for-classification-in-python/ 4 | 5 | # precision-recall curve and f1 6 | from sklearn.datasets import make_classification 7 | from sklearn.neighbors import KNeighborsClassifier 8 | from sklearn.model_selection import train_test_split 9 | from sklearn.metrics import precision_recall_curve 10 | from sklearn.metrics import f1_score 11 | from sklearn.metrics import auc 12 | from sklearn.metrics import average_precision_score 13 | from matplotlib import pyplot 14 | # generate 2 class dataset 15 | X, y = make_classification(n_samples=1000, n_classes=2, weights=[1,1], random_state=1) 16 | # split into train/test sets 17 | trainX, testX, trainy, testy = train_test_split(X, y, test_size=0.5, random_state=2) 18 | # fit a model 19 | model = KNeighborsClassifier(n_neighbors=3) 20 | model.fit(trainX, trainy) 21 | # predict probabilities 22 | probs = model.predict_proba(testX) 23 | # keep probabilities for the positive outcome only 24 | probs = probs[:, 1] 25 | # predict class values 26 | yhat = model.predict(testX) 27 | 
# calculate precision-recall curve 28 | precision, recall, thresholds = precision_recall_curve(testy, probs) 29 | # calculate F1 score 30 | f1 = f1_score(testy, yhat) 31 | # calculate precision-recall AUC 32 | auc = auc(recall, precision) 33 | # calculate average precision score 34 | ap = average_precision_score(testy, probs) 35 | print('f1=%.3f auc=%.3f ap=%.3f' % (f1, auc, ap)) 36 | # plot no skill 37 | pyplot.plot([0, 1], [0.5, 0.5], linestyle='--') 38 | # plot the precision-recall curve for the model 39 | pyplot.plot(recall, precision, marker='.') 40 | # show the plot 41 | pyplot.show() 42 | -------------------------------------------------------------------------------- /dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 1.11) 2 | (name cpm) 3 | -------------------------------------------------------------------------------- /src/MakeROC.ml: -------------------------------------------------------------------------------- 1 | 2 | module A = BatArray 3 | module L = MyList 4 | 5 | module type SCORE_LABEL = sig 6 | type t 7 | val get_score: t -> float 8 | val get_label: t -> bool 9 | end 10 | 11 | module type ROC_FUNCTOR = functor (SL: SCORE_LABEL) -> 12 | sig 13 | (** sort score labels putting high scores first *) 14 | val rank_order_by_score: SL.t list -> SL.t list 15 | 16 | (** in-place sort of score labels; putting high scores first *) 17 | val rank_order_by_score_a: SL.t array -> unit 18 | 19 | (** cumulated actives curve given an already sorted list of score labels *) 20 | val cumulated_actives_curve: SL.t list -> int list 21 | 22 | (** ROC curve (list of (FPR,TPR) values) corresponding to 23 | those score labels *) 24 | val roc_curve: SL.t list -> (float * float) list 25 | 26 | (** logAUC of ROC curve [log_AUC lambda roc_curve]. 27 | [lambda] must be small but > 0.0 (e.g. 
0.001) *) 28 | val log_AUC: float -> (float * float) list -> float 29 | 30 | (** same as [roc_curve] but for an already sorted array of score-labels *) 31 | val fast_roc_curve_a: SL.t array -> (float * float) array 32 | 33 | (** Precision Recall curve (list of (recall,precision) pairs) 34 | corresponding to given score labels *) 35 | val pr_curve: SL.t list -> (float * float) list 36 | 37 | (** compute Area Under the ROC curve given an already sorted list of 38 | score labels *) 39 | val fast_auc: SL.t list -> float 40 | 41 | (** ROC AUC: Area Under the ROC curve given an unsorted list 42 | of score labels *) 43 | val auc: SL.t list -> float 44 | 45 | (** PR AUC: Area Under the Precision-Recall curve given an unsorted list 46 | of score labels *) 47 | val pr_auc: SL.t list -> float 48 | 49 | (** Area Under the ROC curve given an unsorted array 50 | of score labels; WARNING: array will be modified (sorted) *) 51 | val auc_a: SL.t array -> float 52 | 53 | (** Area Under the ROC curve given an already sorted array 54 | of score labels *) 55 | val fast_auc_a: SL.t array -> float 56 | 57 | (** (early) enrichment factor at given threshold (database fraction, e.g. 0.01 for the top 1%) 58 | given an unsorted list of score labels *) 59 | val enrichment_factor: float -> SL.t list -> float 60 | 61 | (** (early) enrichment factor at given threshold (database fraction, e.g. 0.01 for the top 1%) 62 | given an already sorted array of score labels *) 63 | val fast_enrichment_factor: float -> SL.t array -> float 64 | 65 | (** [initial_enhancement a score_labels] will compute 66 | S = sum_over_i (exp (-rank(active_i) / a)) 67 | given an unsorted list of score labels (ranks are 0-based, in decreasing score order). 68 | Robust Initial Enhancement (RIE) = S/<S> where <S> is 69 | the average S for randomly ordered score labels. 70 | RIE = 1.0 <=> random performance. Cf. DOI:10.1021/ci0100144 71 | for details.
*) 72 | val initial_enhancement: float -> SL.t list -> float 73 | 74 | (** same as [initial_enhancement] but does not reorder the list 75 | of score_labels *) 76 | val fast_initial_enhancement: float -> SL.t list -> float 77 | 78 | (** power metric at given threshold (database fraction, 0 < threshold <= 1) given an unsorted list of score labels *) 79 | val power_metric: float -> SL.t list -> float 80 | 81 | (** power metric at given threshold for an already decr. sorted array of score labels *) 82 | val fast_power_metric_a: float -> SL.t array -> float 83 | 84 | (** bedroc_auc at given alpha. Default alpha = 20.0. *) 85 | val bedroc_auc: ?alpha:float -> SL.t list -> float 86 | 87 | (** bedroc_auc at given alpha for an array of score-labels. 88 | Default alpha = 20.0. 89 | WARNING: the array will be modified (sorted by decreasing scores) 90 | if [sorted = false] which is the default. *) 91 | val bedroc_auc_a: ?alpha:float -> ?sorted:bool -> SL.t array -> float 92 | 93 | (** equivalent to [bedroc_auc_a ~alpha ~sorted:true arr]. *) 94 | val fast_bedroc_auc_a: ?alpha:float -> SL.t array -> float 95 | 96 | (** Matthews' Correlation Coefficient (MCC) 97 | use: [mcc classif_threshold score_labels]. 98 | scores >= threshold are predicted as targets; 99 | scores < threshold are predicted as non targets. Returns 0.0 when the MCC denominator is zero. *) 100 | val mcc: float -> SL.t list -> float 101 | 102 | (** [a, b = platt_scaling ~debug score_labels] 103 | Fit a logistic curve (1 / (1 + exp (a * x + b))) to [score_labels] 104 | and return its [(a, b)] parameters. 105 | Gnuplot is used underneath to do the fitting (requires the gnuplot command); if [debug] is true, the temporary score and script files are kept. 106 | Biblio: 107 | Platt, J. (1999). Probabilistic outputs for support vector machines 108 | and comparisons to regularized likelihood methods. 109 | Advances in large margin classifiers, 10(3), 61-74.
*) 110 | val platt_scaling: ?debug:bool -> SL.t list -> (float * float) 111 | 112 | (** [platt_probability a b score] transform [score] into 113 | a probability, given logistic function parameters [a] and [b] 114 | obtained from a prior call to [platt_scaling]. *) 115 | val platt_probability: float -> float -> float -> float 116 | 117 | end 118 | 119 | (* functions for ROC analysis *) 120 | module Make: ROC_FUNCTOR = functor (SL: SCORE_LABEL) -> 121 | struct 122 | 123 | let trapezoid_surface x1 x2 y1 y2 = 124 | let base = abs_float (x1 -. x2) in 125 | let height = 0.5 *. (y1 +. y2) in 126 | base *. height 127 | 128 | let rev_compare_scores x y = 129 | BatFloat.compare (SL.get_score y) (SL.get_score x) 130 | 131 | (* put molecules with the highest scores at the top of the list *) 132 | let rank_order_by_score (score_labels: SL.t list) = 133 | L.stable_sort rev_compare_scores score_labels 134 | 135 | (* put molecules with the highest scores at the top of the array *) 136 | let rank_order_by_score_a (score_labels: SL.t array) = 137 | A.stable_sort rev_compare_scores score_labels 138 | 139 | (* compute the cumulated number of actives curve, 140 | given an already sorted list of score labels *) 141 | let cumulated_actives_curve (high_scores_first: SL.t list) = 142 | let sum = ref 0 in 143 | L.map (fun sl -> 144 | if SL.get_label sl then incr sum; 145 | !sum 146 | ) high_scores_first 147 | 148 | let roc_curve (score_labels: SL.t list) = 149 | let high_scores_first = rank_order_by_score score_labels in 150 | let nacts = ref 0 in 151 | let ndecs = ref 0 in 152 | let nb_act_decs = 153 | L.fold_left (fun acc x -> 154 | if SL.get_label x then incr nacts 155 | else incr ndecs; 156 | (!nacts, !ndecs) :: acc 157 | ) [(0, 0)] high_scores_first in 158 | let nb_actives = float !nacts in 159 | let nb_decoys = float !ndecs in 160 | L.rev_map (fun (na, nd) -> 161 | let tpr = float na /. nb_actives in 162 | let fpr = float nd /. 
nb_decoys in 163 | (fpr, tpr) 164 | ) nb_act_decs 165 | 166 | (* $LogAUC_\lambda=\frac{\sum_{i}^{where~x_i\ge\lambda} (\log_{10} x_{i+1} - \log_{10} x_i) 167 | (\frac{y_{i+1}+y_i}{2})}{\log_{10}\frac{1}{\lambda}}$ *) 168 | let log_AUC (lambda: float) (roc: (float * float) list): float = 169 | if lambda <= 0.0 then 170 | failwith "MakeROC.logAUC: lambda <= 0.0" 171 | else 172 | let x_ge_lambda = L.drop_while (fun (x, _y) -> x < lambda) roc in 173 | let res = ref 0.0 in 174 | let rec loop = function 175 | | [] | [_] -> () 176 | | (x_i, y_i) :: (x_j, y_j) :: xs -> 177 | begin 178 | res := !res +. (((log10 x_j) -. (log10 x_i)) *. (0.5 *. (y_i +. y_j))); 179 | loop ((x_j, y_j) :: xs) 180 | end in 181 | loop x_ge_lambda; 182 | !res /. (log10 (1. /. lambda)) 183 | 184 | let fast_roc_curve_a (score_labels: SL.t array) = 185 | let nacts = ref 0 in 186 | let ndecs = ref 0 in 187 | let nb_act_decs = 188 | A.map (fun x -> 189 | if SL.get_label x then incr nacts 190 | else incr ndecs; 191 | (!nacts, !ndecs) 192 | ) score_labels in 193 | let nb_actives = float !nacts in 194 | let nb_decoys = float !ndecs in 195 | A.map (fun (na, nd) -> 196 | let tpr = float na /. nb_actives in 197 | let fpr = float nd /. nb_decoys in 198 | (fpr, tpr) 199 | ) nb_act_decs 200 | 201 | (* Saito, T., & Rehmsmeier, M. (2015). 202 | The precision-recall plot is more informative than the ROC plot when 203 | evaluating binary classifiers on imbalanced datasets. 204 | PloS one, 10(3), e0118432. *) 205 | (* Davis, J., & Goadrich, M. (2006, June). 206 | The relationship between Precision-Recall and ROC curves. 207 | In Proceedings of the 23rd international conference on Machine learning 208 | (pp. 233-240). ACM. *) 209 | let pr_curve (score_labels: SL.t list) = 210 | let precision tp fp = 211 | tp /. (tp +. fp) in 212 | let recall tp fn = 213 | tp /. (tp +. 
fn) in 214 | let negate p x = 215 | not (p x) in 216 | let high_scores_first_uniq = 217 | let all_scores = L.map SL.get_score score_labels in 218 | L.sort_uniq (fun x y -> BatFloat.compare y x) all_scores in 219 | (* L.iter (Printf.printf "threshold: %f\n") high_scores_first_uniq; *) 220 | let high_scores_first = rank_order_by_score score_labels in 221 | let before = ref [] in 222 | let after = ref high_scores_first in 223 | let res = 224 | L.map (fun threshold -> 225 | let higher, lower = 226 | L.partition_while (fun x -> (SL.get_score x) >= threshold) !after in 227 | before := L.rev_append higher !before; 228 | after := lower; 229 | (* TP <=> (score >= t) && label *) 230 | let tp = float (L.filter_count (SL.get_label) !before) in 231 | (* FN <=> (score < t) && label *) 232 | let fn = float (L.filter_count (SL.get_label) !after) in 233 | (* FP <=> (score >= t) && (not label) *) 234 | let fp = float (L.filter_count (negate SL.get_label) !before) in 235 | let r = recall tp fn in 236 | let p = precision tp fp in 237 | (r, p) 238 | ) high_scores_first_uniq in 239 | (0.0, 1.0) :: res (* add missing first point *) 240 | 241 | let fast_auc_common fold_fun high_scores_first = 242 | let fp, tp, fp_prev, tp_prev, a, _p_prev = 243 | fold_fun (fun (fp, tp, fp_prev, tp_prev, a, p_prev) sl -> 244 | let si = SL.get_score sl in 245 | let li = SL.get_label sl in 246 | let new_a, new_p_prev, new_fp_prev, new_tp_prev = 247 | if si <> p_prev then 248 | a +. trapezoid_surface fp fp_prev tp tp_prev, 249 | si, 250 | fp, 251 | tp 252 | else 253 | a, 254 | p_prev, 255 | fp_prev, 256 | tp_prev 257 | in 258 | let new_tp, new_fp = 259 | if li then 260 | tp +. 1., fp 261 | else 262 | tp, fp +. 1. 263 | in 264 | (new_fp, new_tp, new_fp_prev, new_tp_prev, new_a, new_p_prev) 265 | ) (0., 0., 0., 0., 0., neg_infinity) 266 | high_scores_first 267 | in 268 | (a +. trapezoid_surface fp fp_prev tp tp_prev) /. (fp *. 
tp) 269 | 270 | (* area under the ROC curve given an already sorted list of score-labels *) 271 | let fast_auc high_scores_first = 272 | fast_auc_common L.fold_left high_scores_first 273 | 274 | let fast_auc_a high_scores_first = 275 | fast_auc_common A.fold_left high_scores_first 276 | 277 | (* area under the ROC curve given an unsorted list of score-labels 278 | TP cases have the label set to true 279 | TN cases have the label unset *) 280 | let auc (score_labels: SL.t list) = 281 | let high_scores_first = rank_order_by_score score_labels in 282 | fast_auc high_scores_first 283 | 284 | let pr_auc (score_labels: SL.t list) = 285 | let curve = pr_curve score_labels in 286 | let rec loop acc = function 287 | | [] -> acc 288 | | [_] -> acc 289 | | (x1, y1) :: (x2, y2) :: xs -> 290 | let area = trapezoid_surface x1 x2 y1 y2 in 291 | loop (area +. acc) ((x2, y2) :: xs) in 292 | loop 0.0 curve 293 | 294 | let auc_a (score_labels: SL.t array) = 295 | rank_order_by_score_a score_labels; 296 | fast_auc_a score_labels 297 | 298 | (* proportion of actives given an unsorted list of score-labels 299 | TP cases have the label set to true 300 | TN cases have the label unset 301 | returns: (nb_molecules, actives_rate) *) 302 | let actives_rate (score_labels: SL.t list) = 303 | let tp_count, fp_count = 304 | L.fold_left 305 | (fun (tp_c, fp_c) sl -> 306 | if SL.get_label sl then 307 | (tp_c + 1, fp_c) 308 | else 309 | (tp_c, fp_c + 1) 310 | ) (0, 0) score_labels 311 | in 312 | let nb_molecules = tp_count + fp_count in 313 | (nb_molecules, (float tp_count) /. (float nb_molecules)) 314 | 315 | (* enrichment rate at x (e.g. x = 0.01 --> ER @ 1%) given a list 316 | of unsorted score-labels *) 317 | let enrichment_factor (p: float) (score_labels: SL.t list) = 318 | let nb_molecules, rand_actives_rate = actives_rate score_labels in 319 | let top_n = BatFloat.round_to_int (p *. 
(float nb_molecules)) in 320 | let top_p_percent_molecules = 321 | L.take top_n (rank_order_by_score score_labels) in 322 | let _, top_actives_rate = actives_rate top_p_percent_molecules in 323 | let enr_rate = top_actives_rate /. rand_actives_rate in 324 | enr_rate 325 | 326 | (* this should land in batteries, not here... *) 327 | let array_filter_count p a = 328 | let count = ref 0 in 329 | A.iter (fun x -> 330 | if p x then incr count 331 | ) a; 332 | !count 333 | 334 | let array_actives_rate a = 335 | let nb_actives = array_filter_count SL.get_label a in 336 | let n = A.length a in 337 | (float nb_actives) /. (float n) 338 | 339 | let fast_enrichment_factor p score_labels = 340 | let nb_molecules = A.length score_labels in 341 | let rand_actives_rate = array_actives_rate score_labels in 342 | let top_n = BatFloat.round_to_int (p *. (float nb_molecules)) in 343 | let top_p_percent_molecules = A.sub score_labels 0 top_n in 344 | let top_actives_rate = array_actives_rate top_p_percent_molecules in 345 | (top_actives_rate /. rand_actives_rate) 346 | 347 | let fast_initial_enhancement (a: float) (l: SL.t list) = 348 | L.fold_lefti (fun acc i x -> 349 | if SL.get_label x then 350 | let rank = float i in 351 | acc +. exp (-. rank /. a) 352 | else 353 | acc 354 | ) 0.0 l 355 | 356 | let initial_enhancement (a: float) (l: SL.t list) = 357 | fast_initial_enhancement a (rank_order_by_score l) 358 | 359 | let nb_actives l = 360 | float 361 | (L.fold_left (fun acc x -> 362 | if SL.get_label x then acc + 1 363 | else acc 364 | ) 0 l) 365 | 366 | let nb_actives_a a = 367 | let res = ref 0 in 368 | A.iter (fun x -> if SL.get_label x then incr res) a; 369 | !res 370 | 371 | (* Cf. http://jcheminf.springeropen.com/articles/10.1186/s13321-016-0189-4 372 | for formulas: 373 | The power metric: a new statistically robust enrichment-type metric for 374 | virtual screening applications with early recovery capability 375 | Lopes et. al. 
Journal of Cheminformatics 2017 *) 376 | let power_metric (cutoff: float) (scores_tot: SL.t list): float = 377 | assert(cutoff > 0.0 && cutoff <= 1.0); 378 | let size_tot = float (L.length scores_tot) in 379 | let x = BatFloat.round (cutoff *. size_tot) in 380 | let size_x = int_of_float x in 381 | assert(size_x >= 1); 382 | let sorted = rank_order_by_score scores_tot in 383 | let scores_x = L.take size_x sorted in 384 | let actives_x = nb_actives scores_x in 385 | let actives_tot = nb_actives scores_tot in 386 | let tpr_x = actives_x /. actives_tot in 387 | let fpr_x = (x -. actives_x) /. (size_tot -. actives_tot) in 388 | tpr_x /. (tpr_x +. fpr_x) 389 | 390 | (* Same as [power_metric], but for an already sorted array of score-labels. *) 391 | let fast_power_metric_a (cutoff: float) (scores_tot: SL.t array): float = 392 | assert(cutoff > 0.0 && cutoff <= 1.0); 393 | let size_tot = float (A.length scores_tot) in 394 | let x = BatFloat.round (cutoff *. size_tot) in 395 | let size_x = int_of_float x in 396 | assert(size_x >= 1); 397 | let scores_x = A.sub scores_tot 0 size_x in 398 | let actives_x = float (nb_actives_a scores_x) in 399 | let actives_tot = float (nb_actives_a scores_tot) in 400 | let tpr_x = actives_x /. actives_tot in 401 | let fpr_x = (x -. actives_x) /. (size_tot -. actives_tot) in 402 | tpr_x /. (tpr_x +. fpr_x) 403 | 404 | (* formula comes from 405 | "Evaluating Virtual Screening Methods: 406 | Good and Bad Metrics for the “Early Recognition” Problem" 407 | Jean-François Truchon * and Christopher I. Bayly. DOI: 10.1021/ci600426e 408 | Reference implementation in Python: 409 | --- 410 | def calculateBEDROC(self, alpha = 20.0 ): 411 | if alpha < 0.00001: 412 | os.stderr.write( "In method calculatBEDROC, 413 | the alpha parameter argument must be 414 | greater than zero." 
) 415 | sys.exit(1) 416 | N = float( self.getNbrTotal() ) 417 | n = float( self.getNbrActives() ) 418 | sum = 0.0 419 | for rank in self.ranks: 420 | sum += math.exp( -alpha * rank / N ) 421 | ra = n/N 422 | factor1 = 423 | ra * math.sinh( alpha/2.0 )/( math.cosh(alpha/2.0) - 424 | math.cosh(alpha/2.0 - ra*alpha ) ) 425 | factor2 = 426 | 1.0 / ra * (math.exp(alpha/N) - 1.0)/( 1.0 - math.exp(-alpha)) 427 | constant = 1.0 / ( 1.0 - math.exp( alpha * ( 1.0 - ra ) ) ) 428 | bedroc = sum * factor1 * factor2 + constant 429 | return bedroc 430 | --- *) 431 | let bedroc_auc_a ?alpha:(alpha = 20.0) ?sorted:(sorted = false) 432 | (score_labels: SL.t array): float = 433 | let half_alpha = 0.5 *. alpha in 434 | let n_tot = float (A.length score_labels) in 435 | let n_act = float (nb_actives_a score_labels) in 436 | (if not sorted then rank_order_by_score_a score_labels); 437 | let sum = 438 | A.fold_lefti (fun acc rank x -> 439 | if SL.get_label x then 440 | (* ranks must start at 1 *) 441 | acc +. exp (-.alpha *. (1.0 +. float rank) /. n_tot) 442 | else 443 | acc 444 | ) 0.0 score_labels in 445 | let r_a = n_act /. n_tot in 446 | let factor1 = r_a *. sinh half_alpha /. 447 | (cosh half_alpha -. cosh (half_alpha -. r_a *. alpha)) in 448 | let factor2 = 1.0 /. r_a *. 449 | (exp (alpha /. n_tot) -. 1.0) /. (1.0 -. exp (-.alpha)) in 450 | let constant = 1.0 /. (1.0 -. exp (alpha *. ( 1.0 -. r_a))) in 451 | sum *. factor1 *. factor2 +. constant 452 | 453 | let fast_bedroc_auc_a ?alpha:(alpha = 20.0) (score_labels: SL.t array) = 454 | bedroc_auc_a ~alpha ~sorted:true score_labels 455 | 456 | let bedroc_auc ?alpha:(alpha = 20.0) (score_labels: SL.t list): float = 457 | bedroc_auc_a ~alpha (A.of_list score_labels) 458 | 459 | (* accumulator type for the mcc metric *) 460 | type mcc_accum = { tp: int ; 461 | tn: int ; 462 | fp: int ; 463 | fn: int } 464 | 465 | (* Matthews' Correlation Coefficient (MCC) 466 | Biblio: Matthews, B. W. (1975). 
467 | "Comparison of the predicted and observed secondary structure of T4 phage 468 | lysozyme". Biochimica et Biophysica Acta (BBA)-Protein Structure, 469 | 405(2), 442-451. *) 470 | let mcc classif_threshold score_labels = 471 | (* formula: 472 | mcc = (tp * tn - fp * fn) / 473 | sqrt ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) *) 474 | let acc = 475 | L.fold_left (fun acc sl -> 476 | let truth = SL.get_label sl in 477 | let score = SL.get_score sl in 478 | let prediction = score >= classif_threshold in 479 | match (truth, prediction) with 480 | | (true, true) -> { acc with tp = acc.tp + 1 } (* TP *) 481 | | (false, false) -> { acc with tn = acc.tn + 1 } (* TN *) 482 | | (true, false) -> { acc with fn = acc.fn + 1 } (* FN *) 483 | | (false, true) -> { acc with fp = acc.fp + 1 } (* FP *) 484 | ) { tp = 0; tn = 0; fp = 0; fn = 0 } score_labels in 485 | let tp = float acc.tp in 486 | let tn = float acc.tn in 487 | let fp = float acc.fp in 488 | let fn = float acc.fn in 489 | let denum' = (tp +. fp) *. (tp +. fn) *. (tn +. fp) *. (tn +. fn) in 490 | if denum' = 0.0 then 0.0 (* div by 0 protection *) 491 | else 492 | let num = (tp *. tn) -. (fp *. fn) in 493 | let denum = sqrt denum' in 494 | num /. 
denum 495 | (* Platt scaling: write (score, 0/1 label) pairs to a temp file, let gnuplot fit g(x) = 1 / (1 + exp(a * x + b)) on them, then parse [a] and [b] from the last output line of gnuplot. Temp files are deleted unless [debug] is true. *) 496 | let platt_scaling ?(debug = false) score_labels = 497 | let scores_fn = Filename.temp_file "scores_" ".txt" in 498 | Utls.with_out_file scores_fn (fun out -> 499 | L.iter (fun sl -> 500 | let score = SL.get_score sl in 501 | let label = SL.get_label sl in 502 | Printf.fprintf out "%f %d\n" score (if label then 1 else 0) 503 | ) score_labels 504 | ); 505 | let gnuplot_script_fn = Filename.temp_file "gnuplot_" ".gpl" in 506 | Utls.with_out_file gnuplot_script_fn (fun out -> 507 | Printf.fprintf out 508 | "g(x) = 1 / (1 + exp(a * x + b))\n\ 509 | fit g(x) '%s' using 1:2 via a, b\n\ 510 | print a, b\n" 511 | scores_fn 512 | ); 513 | let a_b_str = 514 | (* quote the script path: the temp dir may contain spaces or shell metacharacters *) Utls.get_command_output 515 | (Printf.sprintf "gnuplot %s 2>&1 | tail -1" (Filename.quote gnuplot_script_fn)) in 516 | if not debug then 517 | L.iter Sys.remove [scores_fn; gnuplot_script_fn]; 518 | Scanf.sscanf a_b_str "%f %f" (fun a b -> (a, b)) 519 | 520 | let platt_probability a b x = 521 | 1.0 /. (1.0 +. exp (a *. x +. b)) 522 | 523 | end 524 | -------------------------------------------------------------------------------- /src/RegrStats.ml: -------------------------------------------------------------------------------- 1 | (* Performance measures for regression models 2 | cf. chapter 12 "regression models" in book 3 | Varnek, A. ed., 2017. Tutorials in chemoinformatics. John Wiley & Sons. *) 4 | 5 | module A = BatArray 6 | module L = BatList 7 | 8 | let square x = 9 | x *. x 10 | 11 | (** Root Mean Squared Error 12 | [rmse exp pred] *) 13 | let rmse (l1: float list) (l2: float list): float = 14 | let a1 = A.of_list l1 in 15 | let a2 = A.of_list l2 in 16 | let m = A.length a1 in 17 | let n = A.length a2 in 18 | assert(m = n); 19 | let sum_squared_diffs = 20 | A.fold_lefti (fun acc i x -> 21 | let y = a2.(i) in 22 | acc +. square (x -. y) 23 | ) 0.0 a1 in 24 | sqrt (sum_squared_diffs /.
(float n)) 25 | 26 | (** Mean Absolute Error 27 | [mae exp pred] *) 28 | let mae (l1: float list) (l2: float list): float = 29 | let a1 = A.of_list l1 in 30 | let a2 = A.of_list l2 in 31 | let m = A.length a1 in 32 | let n = A.length a2 in 33 | assert(m = n); 34 | let sum_abs_diffs = 35 | A.fold_lefti (fun acc i x -> 36 | let y = a2.(i) in 37 | acc +. abs_float (x -. y) 38 | ) 0.0 a1 in 39 | sum_abs_diffs /. (float n) 40 | 41 | (** standard deviation of residuals 42 | [std_dev_res exp pred] *) 43 | let std_dev_res (l1: float list) (l2: float list): float = 44 | let a1 = A.of_list l1 in 45 | let a2 = A.of_list l2 in 46 | let m = A.length a1 in 47 | let n = A.length a2 in 48 | assert(m = n); 49 | let sum_squared_diffs = 50 | A.fold_lefti (fun acc i x -> 51 | let y = a2.(i) in 52 | acc +. square (x -. y) 53 | ) 0.0 a1 in 54 | sqrt (sum_squared_diffs /. (float (n - 2))) 55 | 56 | (** coefficient of determination (for arrays) 57 | [r2_a exp pred] *) 58 | let r2_a a1 a2 = 59 | let m = A.length a1 in 60 | let n = A.length a2 in 61 | assert(m = n); 62 | let sum_squared_diffs = 63 | A.fold_lefti (fun acc i x -> 64 | let y = a2.(i) in 65 | acc +. square (x -. y) 66 | ) 0.0 a1 in 67 | let sum_squared_exp_diffs = 68 | let avg_exp = A.favg a1 in 69 | A.fold_left (fun acc x -> 70 | acc +. square (x -. avg_exp) 71 | ) 0.0 a1 in 72 | 1.0 -. (sum_squared_diffs /. sum_squared_exp_diffs) 73 | 74 | (** coefficient of determination 75 | [r2 exp pred] *) 76 | let r2 (l1: float list) (l2: float list): float = 77 | r2_a (A.of_list l1) (A.of_list l2) 78 | 79 | (** raw Regression Error Characteristic Curve 80 | (raw means not scaled by a null model) 81 | [raw_REC_curve exp pred] 82 | Cf. Bi, J. and Bennett, K.P., 2003. 83 | Regression error characteristic curves. 84 | In Proceedings of the 20th international conference on machine learning 85 | (ICML-03) (pp. 43-50). 
*) 86 | let raw_REC_curve (l1: float list) (l2: float list): (float * float) list = 87 | let array_filter_count p a = 88 | float 89 | (A.fold_left (fun acc x -> 90 | if p x then acc + 1 else acc 91 | ) 0 a) in 92 | let a1 = A.of_list l1 in 93 | let a2 = A.of_list l2 in 94 | let n = A.length a1 in assert (n = A.length a2); (* same-length inputs, as in rmse/mae *) if n = 0 then [] (* no residuals: empty curve instead of an out-of-bounds access on errors below *) else 95 | let errors = 96 | A.map2 (fun x y -> 97 | abs_float (x -. y) 98 | ) a1 a2 in 99 | A.sort BatFloat.compare errors; 100 | let max_err = errors.(n - 1) in 101 | (* 100 steps on the X axis *) 102 | let xs = L.frange 0.0 `To max_err 100 in 103 | (* WARNING: not very efficient algorithm *) 104 | let m = float n in 105 | L.map (fun err_tol -> 106 | let percent_ok = 107 | let ok_count = array_filter_count (fun err -> err <= err_tol) errors in 108 | (ok_count /. m) in 109 | (err_tol, percent_ok) 110 | ) xs 111 | -------------------------------------------------------------------------------- /src/TopKeeper.ml: -------------------------------------------------------------------------------- 1 | 2 | (* Keep only the N top scoring elements in memory.
3 | WARNING: we will have several elements with equal scores when screening 4 | a huge database *) 5 | 6 | module L = List 7 | 8 | type 'a t = { max_size: int; (* max number of (top scoring) elements *) 9 | mutable curr_size: int; (* how many elts currently *) 10 | mutable min_score: float; 11 | (* For a given score, elements are in LIFO order *) 12 | mutable elements: (float * 'a list) list } 13 | 14 | (* this does not update the count, on purpose because drop_lowest_score 15 | * is called when there is one element too many *) 16 | let drop_lowest_score t = 17 | match t.elements with 18 | | [] -> assert(false) 19 | | (score, elts) :: rest -> 20 | match elts with 21 | | [] -> assert(false) 22 | | [_x] -> 23 | (* this whole score class is dropped, since it has no more members *) 24 | t.elements <- rest 25 | | (_x :: y :: zs) -> 26 | (* just drop the last element that came in with that score *) 27 | t.elements <- (score, y :: zs) :: rest 28 | 29 | (* peek at the currently known min score *) 30 | let peek_score t = match t.elements with 31 | | [] -> assert(false) 32 | | (score, _elts) :: _rest -> score 33 | 34 | let insert t score x = 35 | let rec loop acc = function 36 | | [] -> L.rev_append acc [(score, [x])] 37 | | (score', elts) :: rest -> 38 | if score' < score then 39 | loop ((score', elts) :: acc) rest 40 | else if score' = score then 41 | L.rev_append acc ((score', x :: elts) :: rest) 42 | else (* score' > score *) 43 | L.rev_append acc ((score, [x]) :: (score', elts) :: rest) 44 | in 45 | t.elements <- loop [] t.elements 46 | 47 | let get_min_score t = 48 | t.min_score 49 | 50 | let get_curr_size t = 51 | t.curr_size 52 | 53 | let get_max_size t = 54 | t.max_size 55 | 56 | (* when we insert an element *) 57 | let update_bound t score = 58 | if score < t.min_score then 59 | t.min_score <- score 60 | 61 | (* after we drop one *) 62 | let recompute_bound t = 63 | t.min_score <- peek_score t 64 | 65 | let create (max_size: int): 'a t = 66 | assert(max_size > 0); 67
| let curr_size = 0 in 68 | let min_score = max_float in 69 | let elements = [] in 70 | { max_size; curr_size; elements; min_score } 71 | 72 | let add (t: 'a t) (score: float) (x: 'a): unit = 73 | if t.curr_size < t.max_size then 74 | begin 75 | (* don't filter, as long as there are not enough elements *) 76 | insert t score x; 77 | t.curr_size <- t.curr_size + 1; 78 | update_bound t score 79 | end 80 | else 81 | begin 82 | (* enforce data structure invariant *) 83 | assert(t.curr_size = t.max_size); 84 | if score > t.min_score then 85 | begin 86 | insert t score x; 87 | drop_lowest_score t; 88 | recompute_bound t 89 | end 90 | end 91 | 92 | let high_scores_first (t: 'a t): (float * 'a) list = 93 | (* put scores in decreasing order *) 94 | L.fold_left (fun acc1 (score, elts) -> 95 | (* put back elements in FIFO order *) 96 | L.fold_left (fun acc2 x -> 97 | (score, x) :: acc2 98 | ) acc1 elts 99 | ) [] t.elements 100 | -------------------------------------------------------------------------------- /src/TopKeeper.mli: -------------------------------------------------------------------------------- 1 | 2 | type 'a t 3 | 4 | (** [create n] creates a ['a TopKeeper.t] that will keep up to [n] 5 | best scoring elements. [n] must be greater than 0. *) 6 | val create: int -> 'a t 7 | 8 | (** [add t score elt] adds [elt] with [score] to the top_keeper [t] 9 | (if the top_keeper doesn't hold enough elements yet, or if [score] 10 | is strictly greater than its current minimum score, in which case a lowest-scoring element is evicted) *) 11 | val add: 'a t -> float -> 'a -> unit 12 | 13 | (** [high_scores_first t] retrieves the (up to [n]) best scores from [t] 14 | (with associated elements); scores are in decreasing order, ties in insertion order *) 15 | val high_scores_first: 'a t -> (float * 'a) list 16 | 17 | (** [get_min_score t] returns the current minimum score in [t] ([max_float] while [t] is empty) *) 18 | val get_min_score: 'a t -> float 19 | 20 | (** [get_curr_size t] returns the current number of elements in [t].
21 | Note that [get_curr_size t <= get_max_size t] is always true 22 | (it is the data-structure invariant). *) 23 | val get_curr_size: 'a t -> int 24 | 25 | (** [get_max_size t] returns the number of elements chosen at creation time *) 26 | val get_max_size: 'a t -> int 27 | -------------------------------------------------------------------------------- /src/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name cpm) 3 | (public_name cpm) 4 | (modules MakeROC TopKeeper RegrStats Utls MyList) 5 | (private_modules MyList) 6 | (libraries batteries)) 7 | 8 | (executable 9 | (name test) 10 | (modules test) 11 | (libraries cpm)) 12 | -------------------------------------------------------------------------------- /src/myList.ml: -------------------------------------------------------------------------------- 1 | 2 | include BatList 3 | 4 | (* count elements of [l] satisfying [p] *) 5 | let filter_count p l = 6 | fold_left (fun acc x -> 7 | if p x then acc + 1 8 | else acc 9 | ) 0 l 10 | 11 | (* split [l] in two parts, 12 | the first one being the longest prefix of [l] where [p] holds *) 13 | let partition_while p l = 14 | let rec loop acc = function 15 | | [] -> (rev acc, []) 16 | | x :: xs -> 17 | if p x then loop (x :: acc) xs 18 | else (rev acc, x :: xs) in 19 | loop [] l 20 | -------------------------------------------------------------------------------- /src/test.ml: -------------------------------------------------------------------------------- 1 | 2 | open Printf 3 | 4 | module L = BatList 5 | 6 | (* Example usage *) 7 | 8 | (* first, define your score_label module *) 9 | module SL = struct 10 | type t = string * float * int * bool 11 | let get_score (_, s, _, _) = s 12 | let get_label (_, _, _, l) = l 13 | end 14 | 15 | (* second, instantiate the ROC functor for your score_label module *) 16 | module ROC = Cpm.MakeROC.Make (SL) 17 | module Top = Cpm.TopKeeper 18 | 19 | let almost_equal epsilon x_curr x_ref
= 20 | (x_ref -. epsilon <= x_curr) && (x_curr <= x_ref +. epsilon) 21 | 22 | (* third, call any classification performance metric you need *) 23 | let main () = 24 | let create name score index label = 25 | (name, score, index, label) in 26 | let scores = 27 | [ create "" 14.0 0 true; 28 | create "" 13.0 0 true; 29 | create "" 12.0 0 false; 30 | create "" 11.0 0 true; 31 | create "" 10.0 0 false; 32 | create "" 9.0 0 false; 33 | create "" 8.0 0 false; 34 | create "" 7.0 0 false; 35 | create "" 6.0 0 true; 36 | create "" 5.0 0 false; 37 | create "" 4.0 0 false; 38 | create "" 3.0 0 false; 39 | create "" 2.0 0 false; 40 | create "" 1.0 0 false ] in 41 | let tpr_x = 3. /. 4. in 42 | let fpr_x = (5. -. 3.) /. (14. -. 4.) in 43 | assert(ROC.power_metric 0.35 scores = tpr_x /. (tpr_x +. fpr_x)); 44 | assert(ROC.auc [("", 1.0, 0, true) ; (("", 0.9, 1, false))] = 1.0); 45 | assert(ROC.auc [("", 1.0, 0, false); (("", 0.9, 1, true))] = 0.0); 46 | assert(ROC.auc_a [|("", 1.0, 0, true) ; (("", 0.9, 1, false))|] = 1.0); 47 | assert(ROC.auc_a [|("", 1.0, 0, false); (("", 0.9, 1, true))|] = 0.0); 48 | let scores_02 = 49 | [ create "" 0.1 0 false; 50 | create "" 0.2 0 false; 51 | create "" 0.3 0 true; 52 | create "" 0.4 0 false; 53 | create "" 0.5 0 false; 54 | create "" 0.6 0 true; 55 | create "" 0.7 0 true; 56 | create "" 0.8 0 true; 57 | create "" 0.9 0 false; 58 | create "" 1.0 0 true ] in 59 | (* cross validated with 'croc-curve < test.scored-label > /dev/null' *) 60 | assert(ROC.auc scores_02 = 0.76); 61 | assert(ROC.auc_a (Array.of_list scores_02) = 0.76); 62 | (* cross validated with 'croc-bedroc < test.scored-label > /dev/null' *) 63 | assert(almost_equal 0.0001 (ROC.bedroc_auc scores_02) 0.88297); 64 | (* wikipedia example: 65 | https://en.wikipedia.org/wiki/Matthews_correlation_coefficient *) 66 | let tp, fp, tn, fn = 90., 4., 1., 5.
in 67 | let scores_03 = 68 | (List.init 90 (fun i -> create "" 1.0 i true)) @ 69 | (List.init 4 (fun i -> create "" 1.0 i false)) @ 70 | [create "" 0.0 0 false] @ 71 | (List.init 5 (fun i -> create "" 0.0 i true)) in 72 | let mcc = ROC.mcc 0.5 scores_03 in 73 | assert(mcc = ((tp *. tn -. fp *. fn) /. 74 | sqrt ((tp +. fp) *. (tp +. fn) *. (tn +. fp) *. (tn +. fn)))); 75 | assert(almost_equal 0.0000001 mcc 0.135242); 76 | let top3 = Top.create 3 in 77 | Top.add top3 1. 'e'; 78 | Top.add top3 4. 'd'; 79 | Top.add top3 3. 'c'; 80 | Top.add top3 4. 'b'; 81 | Top.add top3 5. 'a'; 82 | assert(Top.high_scores_first top3 = [(5., 'a'); (4., 'd'); (4., 'b')]); 83 | Random.self_init (); 84 | let top50 = Top.create 50 in 85 | let rands = L.init 1000 (fun _ -> Random.float 1.0) in 86 | L.iter (fun rand -> 87 | Top.add top50 rand rand 88 | ) rands; 89 | let top_scores = Top.high_scores_first top50 in 90 | let reference = 91 | let l = L.take 50 (L.stable_sort (fun x y -> compare y x) rands) in 92 | L.combine l l in 93 | assert(top_scores = reference); 94 | (* score labels from the Python script *) 95 | let score_labs_py = 96 | [ create "" 1. 0 true; 97 | create "" 0.66666667 0 true; 98 | create "" 0.66666667 0 true; 99 | create "" 1. 0 true; 100 | create "" 0.66666667 0 false; 101 | create "" 0.66666667 0 true; 102 | create "" 1. 0 true; 103 | create "" 0. 0 false; 104 | create "" 1. 0 true; 105 | create "" 0.33333333 0 true; 106 | create "" 1. 0 true; 107 | create "" 1. 0 true; 108 | create "" 1. 0 true; 109 | create "" 1. 0 true; 110 | create "" 0.66666667 0 true; 111 | create "" 0.66666667 0 true; 112 | create "" 0.66666667 0 false; 113 | create "" 0.66666667 0 false; 114 | create "" 0.66666667 0 true; 115 | create "" 0. 0 false; 116 | create "" 0. 0 false; 117 | create "" 0.66666667 0 false; 118 | create "" 0.33333333 0 true; 119 | create "" 1. 0 true; 120 | create "" 0.33333333 0 true; 121 | create "" 0. 0 false; 122 | create "" 0.
0 false; 123 | create "" 0.66666667 0 true; 124 | create "" 0.33333333 0 true; 125 | create "" 0.33333333 0 false; 126 | create "" 0.66666667 0 false; 127 | create "" 0.33333333 0 false; 128 | create "" 1. 0 true; 129 | create "" 0. 0 false; 130 | create "" 1. 0 true; 131 | create "" 1. 0 true; 132 | create "" 0.66666667 0 true; 133 | create "" 1. 0 true; 134 | create "" 0. 0 false; 135 | create "" 0.33333333 0 false; 136 | create "" 0.66666667 0 true; 137 | create "" 0.33333333 0 false; 138 | create "" 1. 0 true; 139 | create "" 0.33333333 0 true; 140 | create "" 0.33333333 0 true; 141 | create "" 0. 0 false; 142 | create "" 0.66666667 0 false; 143 | create "" 0.66666667 0 true; 144 | create "" 0. 0 false; 145 | create "" 0. 0 false; 146 | create "" 0.66666667 0 true; 147 | create "" 0.33333333 0 true; 148 | create "" 0.66666667 0 false; 149 | create "" 0.66666667 0 true; 150 | create "" 0. 0 false; 151 | create "" 0.33333333 0 false; 152 | create "" 0. 0 false; 153 | create "" 1. 0 true; 154 | create "" 0. 0 false; 155 | create "" 0.33333333 0 true; 156 | create "" 0. 0 false; 157 | create "" 1. 0 true; 158 | create "" 0.66666667 0 true; 159 | create "" 1. 0 true; 160 | create "" 0. 0 false; 161 | create "" 0.66666667 0 false; 162 | create "" 0.33333333 0 true; 163 | create "" 0. 0 false; 164 | create "" 0.66666667 0 false; 165 | create "" 1. 0 true; 166 | create "" 0. 0 false; 167 | create "" 0.66666667 0 false; 168 | create "" 1. 0 true; 169 | create "" 1. 0 true; 170 | create "" 0.33333333 0 false; 171 | create "" 0.66666667 0 true; 172 | create "" 1. 0 true; 173 | create "" 1. 0 true; 174 | create "" 0.66666667 0 true; 175 | create "" 0. 0 false; 176 | create "" 0.66666667 0 true; 177 | create "" 0. 0 false; 178 | create "" 0.33333333 0 false; 179 | create "" 0. 0 false; 180 | create "" 0.66666667 0 true; 181 | create "" 1. 0 true; 182 | create "" 0.33333333 0 false; 183 | create "" 1. 0 false; 184 | create "" 0. 0 false; 185 | create "" 1.
0 true; 186 | create "" 0.66666667 0 false; 187 | create "" 1. 0 false; 188 | create "" 0. 0 false; 189 | create "" 0. 0 false; 190 | create "" 0. 0 false; 191 | create "" 1. 0 true; 192 | create "" 0. 0 false; 193 | create "" 0. 0 false; 194 | create "" 1. 0 true; 195 | create "" 0.66666667 0 false; 196 | create "" 0.66666667 0 false; 197 | create "" 0. 0 false; 198 | create "" 1. 0 false; 199 | create "" 0. 0 false; 200 | create "" 0. 0 false; 201 | create "" 1. 0 true; 202 | create "" 1. 0 true; 203 | create "" 1. 0 true; 204 | create "" 0.66666667 0 true; 205 | create "" 1. 0 true; 206 | create "" 0. 0 false; 207 | create "" 0. 0 false; 208 | create "" 0. 0 false; 209 | create "" 0.66666667 0 false; 210 | create "" 1. 0 true; 211 | create "" 1. 0 true; 212 | create "" 1. 0 false; 213 | create "" 0.33333333 0 true; 214 | create "" 0. 0 false; 215 | create "" 1. 0 false; 216 | create "" 1. 0 true; 217 | create "" 1. 0 true; 218 | create "" 0.33333333 0 true; 219 | create "" 0. 0 false; 220 | create "" 0.66666667 0 true; 221 | create "" 0.33333333 0 false; 222 | create "" 0.66666667 0 false; 223 | create "" 0.66666667 0 false; 224 | create "" 1. 0 true; 225 | create "" 0. 0 false; 226 | create "" 1. 0 true; 227 | create "" 0.33333333 0 false; 228 | create "" 0. 0 false; 229 | create "" 0.33333333 0 false; 230 | create "" 0.33333333 0 false; 231 | create "" 0. 0 false; 232 | create "" 0.33333333 0 false; 233 | create "" 0. 0 false; 234 | create "" 1. 0 true; 235 | create "" 0.33333333 0 false; 236 | create "" 0.33333333 0 false; 237 | create "" 1. 0 true; 238 | create "" 0.33333333 0 false; 239 | create "" 0.66666667 0 true; 240 | create "" 0. 0 false; 241 | create "" 0. 0 false; 242 | create "" 0. 0 false; 243 | create "" 0.33333333 0 false; 244 | create "" 0.66666667 0 true; 245 | create "" 0. 0 false; 246 | create "" 0.33333333 0 false; 247 | create "" 1. 0 true; 248 | create "" 0.66666667 0 true; 249 | create "" 0.66666667 0 true; 250 | create "" 1.
0 true; 251 | create "" 0. 0 false; 252 | create "" 1. 0 true; 253 | create "" 0. 0 false; 254 | create "" 0.33333333 0 false; 255 | create "" 0. 0 false; 256 | create "" 0.33333333 0 false; 257 | create "" 0. 0 false; 258 | create "" 0. 0 false; 259 | create "" 0. 0 false; 260 | create "" 0.66666667 0 false; 261 | create "" 0.66666667 0 true; 262 | create "" 0.66666667 0 true; 263 | create "" 1. 0 true; 264 | create "" 0. 0 false; 265 | create "" 0. 0 false; 266 | create "" 0. 0 false; 267 | create "" 0.66666667 0 false; 268 | create "" 0. 0 false; 269 | create "" 0. 0 false; 270 | create "" 0. 0 false; 271 | create "" 0. 0 false; 272 | create "" 0. 0 false; 273 | create "" 1. 0 true; 274 | create "" 1. 0 true; 275 | create "" 0.33333333 0 false; 276 | create "" 0. 0 false; 277 | create "" 0.66666667 0 true; 278 | create "" 0. 0 false; 279 | create "" 0.66666667 0 false; 280 | create "" 0. 0 false; 281 | create "" 0. 0 false; 282 | create "" 0.33333333 0 false; 283 | create "" 0. 0 false; 284 | create "" 0. 0 false; 285 | create "" 0.66666667 0 true; 286 | create "" 0. 0 false; 287 | create "" 1. 0 true; 288 | create "" 0. 0 false; 289 | create "" 1. 0 true; 290 | create "" 0.66666667 0 true; 291 | create "" 0.66666667 0 true; 292 | create "" 1. 0 true; 293 | create "" 1. 0 true; 294 | create "" 0.66666667 0 true; 295 | create "" 1. 0 true; 296 | create "" 0.66666667 0 true; 297 | create "" 1. 0 true; 298 | create "" 1. 0 true; 299 | create "" 0. 0 false; 300 | create "" 0. 0 false; 301 | create "" 0. 0 false; 302 | create "" 0. 0 false; 303 | create "" 1. 0 true; 304 | create "" 0. 0 false; 305 | create "" 1. 0 false; 306 | create "" 0. 0 false; 307 | create "" 0.66666667 0 false; 308 | create "" 1. 0 true; 309 | create "" 0. 0 false; 310 | create "" 1. 0 true; 311 | create "" 0.66666667 0 true; 312 | create "" 0.33333333 0 true; 313 | create "" 0. 0 false; 314 | create "" 0.66666667 0 true; 315 | create "" 1. 0 false; 316 | create "" 1.
0 true; 317 | create "" 0. 0 false; 318 | create "" 0. 0 false; 319 | create "" 1. 0 true; 320 | create "" 1. 0 true; 321 | create "" 0.66666667 0 true; 322 | create "" 1. 0 true; 323 | create "" 1. 0 true; 324 | create "" 0. 0 false; 325 | create "" 0.33333333 0 false; 326 | create "" 0. 0 false; 327 | create "" 0.66666667 0 true; 328 | create "" 1. 0 true; 329 | create "" 0.33333333 0 false; 330 | create "" 1. 0 true; 331 | create "" 1. 0 true; 332 | create "" 0. 0 false; 333 | create "" 0.33333333 0 false; 334 | create "" 0.66666667 0 false; 335 | create "" 0.33333333 0 true; 336 | create "" 1. 0 true; 337 | create "" 1. 0 true; 338 | create "" 0.66666667 0 true; 339 | create "" 0. 0 false; 340 | create "" 0. 0 false; 341 | create "" 0. 0 true; 342 | create "" 1. 0 true; 343 | create "" 0.33333333 0 false; 344 | create "" 0.66666667 0 true; 345 | create "" 1. 0 true; 346 | create "" 0. 0 false; 347 | create "" 1. 0 true; 348 | create "" 0. 0 false; 349 | create "" 1. 0 true; 350 | create "" 1. 0 true; 351 | create "" 1. 0 true; 352 | create "" 0. 0 false; 353 | create "" 1. 0 true; 354 | create "" 0. 0 false; 355 | create "" 0. 0 false; 356 | create "" 0. 0 false; 357 | create "" 0. 0 false; 358 | create "" 0. 0 false; 359 | create "" 0.66666667 0 true; 360 | create "" 0.66666667 0 false; 361 | create "" 0.33333333 0 true; 362 | create "" 0.33333333 0 true; 363 | create "" 1. 0 true; 364 | create "" 0.66666667 0 true; 365 | create "" 0. 0 false; 366 | create "" 0.33333333 0 false; 367 | create "" 0. 0 false; 368 | create "" 1. 0 true; 369 | create "" 0. 0 false; 370 | create "" 0.33333333 0 false; 371 | create "" 0.66666667 0 true; 372 | create "" 0.66666667 0 false; 373 | create "" 0. 0 false; 374 | create "" 0. 0 false; 375 | create "" 1. 0 true; 376 | create "" 1. 0 true; 377 | create "" 0. 0 false; 378 | create "" 1. 0 true; 379 | create "" 1. 0 true; 380 | create "" 1. 0 true; 381 | create "" 0.66666667 0 false; 382 | create "" 1.
0 true; 383 | create "" 0.66666667 0 false; 384 | create "" 1. 0 true; 385 | create "" 0. 0 false; 386 | create "" 0.33333333 0 false; 387 | create "" 0. 0 false; 388 | create "" 0.66666667 0 true; 389 | create "" 0.66666667 0 true; 390 | create "" 0. 0 false; 391 | create "" 1. 0 true; 392 | create "" 0.33333333 0 false; 393 | create "" 0.66666667 0 true; 394 | create "" 1. 0 true; 395 | create "" 0. 0 false; 396 | create "" 0. 0 false; 397 | create "" 0. 0 false; 398 | create "" 1. 0 true; 399 | create "" 0.33333333 0 true; 400 | create "" 1. 0 true; 401 | create "" 0.66666667 0 true; 402 | create "" 0. 0 false; 403 | create "" 0.66666667 0 true; 404 | create "" 1. 0 true; 405 | create "" 1. 0 true; 406 | create "" 0.66666667 0 true; 407 | create "" 0.66666667 0 false; 408 | create "" 0.66666667 0 false; 409 | create "" 0.66666667 0 true; 410 | create "" 0.33333333 0 true; 411 | create "" 1. 0 true; 412 | create "" 0. 0 false; 413 | create "" 0. 0 false; 414 | create "" 0.33333333 0 true; 415 | create "" 1. 0 false; 416 | create "" 0.66666667 0 false; 417 | create "" 1. 0 false; 418 | create "" 1. 0 true; 419 | create "" 0. 0 false; 420 | create "" 0. 0 false; 421 | create "" 0.66666667 0 true; 422 | create "" 0. 0 false; 423 | create "" 1. 0 true; 424 | create "" 0.33333333 0 true; 425 | create "" 0.66666667 0 true; 426 | create "" 1. 0 true; 427 | create "" 0.33333333 0 true; 428 | create "" 1. 0 true; 429 | create "" 1. 0 true; 430 | create "" 0. 0 false; 431 | create "" 1. 0 true; 432 | create "" 0.66666667 0 true; 433 | create "" 0. 0 false; 434 | create "" 0.66666667 0 false; 435 | create "" 0.33333333 0 false; 436 | create "" 1. 0 true; 437 | create "" 0.66666667 0 true; 438 | create "" 0.66666667 0 true; 439 | create "" 0. 0 false; 440 | create "" 1. 0 true; 441 | create "" 1. 0 true; 442 | create "" 0. 0 false; 443 | create "" 0.66666667 0 true; 444 | create "" 1. 0 true; 445 | create "" 0.33333333 0 false; 446 | create "" 1. 0 true; 447 | create "" 1.
0 true; 448 | create "" 0. 0 false; 449 | create "" 0. 0 false; 450 | create "" 0.33333333 0 false; 451 | create "" 0. 0 false; 452 | create "" 0. 0 false; 453 | create "" 1. 0 true; 454 | create "" 1. 0 true; 455 | create "" 0.66666667 0 true; 456 | create "" 0. 0 false; 457 | create "" 0. 0 false; 458 | create "" 1. 0 true; 459 | create "" 1. 0 false; 460 | create "" 1. 0 true; 461 | create "" 0. 0 false; 462 | create "" 0.33333333 0 false; 463 | create "" 0.33333333 0 false; 464 | create "" 1. 0 true; 465 | create "" 0.66666667 0 true; 466 | create "" 0. 0 false; 467 | create "" 0.66666667 0 true; 468 | create "" 1. 0 true; 469 | create "" 0.33333333 0 false; 470 | create "" 1. 0 true; 471 | create "" 0.66666667 0 false; 472 | create "" 0.66666667 0 false; 473 | create "" 1. 0 true; 474 | create "" 0. 0 false; 475 | create "" 1. 0 true; 476 | create "" 0.66666667 0 true; 477 | create "" 0. 0 false; 478 | create "" 0.66666667 0 false; 479 | create "" 1. 0 false; 480 | create "" 0.33333333 0 false; 481 | create "" 0. 0 false; 482 | create "" 0.66666667 0 true; 483 | create "" 0. 0 false; 484 | create "" 1. 0 true; 485 | create "" 1. 0 false; 486 | create "" 0.66666667 0 true; 487 | create "" 1. 0 true; 488 | create "" 0. 0 false; 489 | create "" 0.33333333 0 true; 490 | create "" 0.66666667 0 true; 491 | create "" 0. 0 true; 492 | create "" 1. 0 true; 493 | create "" 1. 0 true; 494 | create "" 0. 0 true; 495 | create "" 0.66666667 0 false; 496 | create "" 1. 0 true; 497 | create "" 1. 0 false; 498 | create "" 0.66666667 0 true; 499 | create "" 0. 0 false; 500 | create "" 0. 0 false; 501 | create "" 0.33333333 0 false; 502 | create "" 1. 0 true; 503 | create "" 0. 0 false; 504 | create "" 0.66666667 0 true; 505 | create "" 1. 0 true; 506 | create "" 0. 0 false; 507 | create "" 1. 0 true; 508 | create "" 0.33333333 0 false; 509 | create "" 0.66666667 0 true; 510 | create "" 0.66666667 0 false; 511 | create "" 0. 0 false; 512 | create "" 0.
0 false; 513 | create "" 0. 0 false; 514 | create "" 0.66666667 0 false; 515 | create "" 0. 0 false; 516 | create "" 0.66666667 0 false; 517 | create "" 0.66666667 0 true; 518 | create "" 0. 0 false; 519 | create "" 1. 0 true; 520 | create "" 0.33333333 0 false; 521 | create "" 1. 0 true; 522 | create "" 0. 0 false; 523 | create "" 0.66666667 0 false; 524 | create "" 0.66666667 0 false; 525 | create "" 0.33333333 0 true; 526 | create "" 0.66666667 0 true; 527 | create "" 1. 0 true; 528 | create "" 0. 0 false; 529 | create "" 1. 0 true; 530 | create "" 0. 0 false; 531 | create "" 0. 0 false; 532 | create "" 0. 0 false; 533 | create "" 0. 0 false; 534 | create "" 1. 0 true; 535 | create "" 0. 0 false; 536 | create "" 0.66666667 0 true; 537 | create "" 0. 0 false; 538 | create "" 1. 0 true; 539 | create "" 0. 0 false; 540 | create "" 0.66666667 0 false; 541 | create "" 1. 0 true; 542 | create "" 1. 0 true; 543 | create "" 0.33333333 0 false; 544 | create "" 0. 0 false; 545 | create "" 0. 0 false; 546 | create "" 1. 0 true; 547 | create "" 0.66666667 0 true; 548 | create "" 0.66666667 0 true; 549 | create "" 0. 0 false; 550 | create "" 0.66666667 0 false; 551 | create "" 0.33333333 0 false; 552 | create "" 0.66666667 0 true; 553 | create "" 0. 0 false; 554 | create "" 1. 0 true; 555 | create "" 0.33333333 0 false; 556 | create "" 0.66666667 0 true; 557 | create "" 1. 0 true; 558 | create "" 0. 0 false; 559 | create "" 0.66666667 0 true; 560 | create "" 1. 0 true; 561 | create "" 1. 0 true; 562 | create "" 1. 0 true; 563 | create "" 0.33333333 0 false; 564 | create "" 0.33333333 0 false; 565 | create "" 1. 0 false; 566 | create "" 0. 0 false; 567 | create "" 0. 0 false; 568 | create "" 0.66666667 0 true; 569 | create "" 0.66666667 0 true; 570 | create "" 0.33333333 0 false; 571 | create "" 0. 0 false; 572 | create "" 0. 0 false; 573 | create "" 0. 0 false; 574 | create "" 0. 0 false; 575 | create "" 0. 0 true; 576 | create "" 0.33333333 0 false; 577 | create "" 1.
0 true; 578 | create "" 0. 0 false; 579 | create "" 0.33333333 0 false; 580 | create "" 1. 0 true; 581 | create "" 1. 0 true; 582 | create "" 0.66666667 0 true; 583 | create "" 0. 0 true; 584 | create "" 0.66666667 0 true; 585 | create "" 0.66666667 0 true; 586 | create "" 1. 0 true; 587 | create "" 0.66666667 0 true; 588 | create "" 0.33333333 0 false; 589 | create "" 0. 0 true; 590 | create "" 0.66666667 0 true; 591 | create "" 1. 0 true; 592 | create "" 0. 0 false; 593 | create "" 1. 0 true; 594 | create "" 0.66666667 0 false; 595 | create "" 0.66666667 0 false] in 596 | let pr_auc = ROC.pr_auc score_labs_py in 597 | assert(almost_equal 0.000001 pr_auc 0.891669); 598 | printf "all OK\n" 599 | 600 | let () = main () 601 | -------------------------------------------------------------------------------- /src/utls.ml: -------------------------------------------------------------------------------- 1 | 2 | module L = BatList 3 | 4 | let with_in_file fn f = 5 | let input = open_in_bin fn in 6 | match f input with (* the channel is closed even if [f] raises *) 7 | | res -> close_in input; res 8 | | exception exn -> close_in_noerr input; raise exn 9 | 10 | let with_out_file fn f = 11 | let output = open_out_bin fn in 12 | match f output with (* the channel is closed even if [f] raises *) 13 | | res -> close_out output; res 14 | | exception exn -> close_out_noerr output; raise exn 15 | 16 | (* get the first line output by the given command (its exit status is ignored) *) 17 | let get_command_output ?(debug = false) (cmd: string): string = 18 | if debug then 19 | Printf.printf "get_command_output: %s\n" cmd; 20 | let _stat, output = BatUnix.run_and_read cmd in 21 | match BatString.split_on_char '\n' output with 22 | | first_line :: _others -> first_line 23 | | [] -> (* NOTE(review): unreachable, split_on_char always returns at least one string (empty output yields "") *) 24 | begin 25 | Printf.eprintf "get_command_output: no output for: %s\n" cmd; 26 | exit 1 27 | end 28 | 29 | (* filename to string list *) 30 | let lines_of_file fn = 31 | with_in_file fn (fun input -> 32 | let res, exn = L.unfold_exc (fun () -> input_line input) in 33 | if exn <> End_of_file then 34 | raise exn 35 | else res 36 | ) 37 | 38 | (* use fraction [p] as training set and fraction 1-p as test set *) 39 | let train_test_split p lines = 40 |
assert(p >= 0.0 && p <= 1.0); 41 | let n = float (L.length lines) in 42 | let for_training = BatFloat.round_to_int (p *. n) in 43 | let train, test = L.takedrop for_training lines in 44 | assert(L.length train = for_training); 45 | (train, test) 46 | 47 | (* shuffle then train_test_split *) 48 | let shuffle_then_cut seed p = function 49 | | [] -> failwith "Utls.shuffle_then_cut: no lines" 50 | | lines -> 51 | let rng = BatRandom.State.make [|seed|] in 52 | let rand_lines = L.shuffle ~state:rng lines in 53 | train_test_split p rand_lines 54 | 55 | (* split [l] into [n] (> 0) consecutive parts whose sizes differ by at most one: 56 | the first [len mod n] parts get one extra element (no part is empty when n <= len) *) 57 | let list_nparts n l = 58 | assert(n > 0); 59 | let len = L.length l in 60 | let q, r = len / n, len mod n in 61 | let rec loop acc i rest = 62 | if i = n then L.rev acc 63 | else 64 | let xs, ys = L.takedrop (if i < r then q + 1 else q) rest in 65 | loop (xs :: acc) (i + 1) ys in 66 | loop [] 0 l 67 | 68 | 69 | (* create [n] folds of cross validation; each fold consists of (train, test); folds come out in reverse order of their test parts *) 70 | let cv_folds n l = 71 | let test_sets = list_nparts n l in 72 | let rec loop acc prev = function 73 | | [] -> acc 74 | | x :: xs -> 75 | let before_after = L.flatten (L.rev_append prev xs) in 76 | let prev' = x :: prev in 77 | let train_test = (before_after, x) in 78 | let acc' = train_test :: acc in 79 | loop acc' prev' xs in 80 | loop [] [] test_sets 81 | 82 | (* shuffle then [n] train-test folds *) 83 | let shuffle_then_nfolds seed n = function 84 | | [] -> failwith "Utls.shuffle_then_nfolds: no lines" 85 | | lines -> 86 | let rng = BatRandom.State.make [|seed|] in 87 | let rand_lines = L.shuffle ~state:rng lines in 88 | cv_folds n rand_lines 89 | -------------------------------------------------------------------------------- /test.scored-label: -------------------------------------------------------------------------------- 1 | 0.1 0 2 | 0.2 0 3 | 0.3 1 4 | 0.4 0 5 | 0.5 0 6 | 0.6 1 7 | 0.7 1 8 | 0.8 1 9 | 0.9 0 10 | 1.0 1
11 | --------------------------------------------------------------------------------