├── python ├── tests │ ├── pharmbio │ │ ├── __init__.py │ │ ├── cp │ │ │ ├── __init__.py │ │ │ ├── metrics │ │ │ │ ├── __init__.py │ │ │ │ ├── reg_metrics_test.py │ │ │ │ └── clf_metrics_test.py │ │ │ ├── plotting │ │ │ │ ├── __init__.py │ │ │ │ ├── reg_plotting_test.py │ │ │ │ └── clf_plotting_test.py │ │ │ └── utils_test.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ └── load_data_test.py │ │ └── cpsign │ │ │ ├── __init__.py │ │ │ └── load_cpsign_test.py │ ├── resources │ │ ├── boston_labels.npy │ │ ├── boston_pred_out_3D_169.npy │ │ ├── boston_pred_out_3D_169_normalized.npy │ │ ├── cpsign_reg_stats_excl_sd.csv │ │ ├── cpsign_clf_stats_excl_sd.csv │ │ ├── cpsign_reg_stats_incl_sd.csv │ │ ├── cpsign_reg_predictions_10_incl_inf.csv │ │ ├── cpsign_clf_stats_incl_sd.csv │ │ ├── multiclass.csv │ │ ├── transporters.p-values.csv │ │ ├── cpsign_clf_predictions.csv │ │ └── er.p-values.csv │ ├── context.py │ ├── generate_test_files │ │ ├── generate_multilabel_predictions.py │ │ └── generate_regression_preds.py │ └── help_utils.py ├── src │ └── pharmbio │ │ ├── cp │ │ ├── __init__.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ ├── _regression.py │ │ │ └── _classification.py │ │ ├── plotting │ │ │ ├── __init__.py │ │ │ ├── _utils.py │ │ │ ├── _regression.py │ │ │ └── _common.py │ │ └── utils.py │ │ ├── __init__.py │ │ ├── cpsign │ │ ├── __init__.py │ │ └── _load.py │ │ └── data │ │ ├── __init__.py │ │ └── _load.py ├── pyproject.toml ├── requirements.txt └── run_tests.sh ├── .gitignore └── README.md /python/tests/pharmbio/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/src/pharmbio/cp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/pharmbio/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cpsign/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/src/pharmbio/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.0' 2 | -------------------------------------------------------------------------------- /python/src/pharmbio/cpsign/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._load import * -------------------------------------------------------------------------------- /python/src/pharmbio/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from ._load import * -------------------------------------------------------------------------------- /python/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = [ 3 | "--import-mode=importlib", 4 | ] 5 | pythonpath = "src" -------------------------------------------------------------------------------- 
/python/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.5 2 | numpy>=1.2 3 | pytest>=7.4 4 | pandas>=2.1 5 | scikit-learn>=1.2.0 6 | seaborn>=0.12 7 | -------------------------------------------------------------------------------- /python/tests/resources/boston_labels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pharmbio/plot_utils/HEAD/python/tests/resources/boston_labels.npy -------------------------------------------------------------------------------- /python/tests/resources/boston_pred_out_3D_169.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pharmbio/plot_utils/HEAD/python/tests/resources/boston_pred_out_3D_169.npy
--------------------------------------------------------------------------------
/python/run_tests.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Run the Python test-suite from the directory holding pyproject.toml,
# regardless of the caller's working directory. Extra arguments are
# forwarded to pytest.
set -euo pipefail

cd "$(dirname "$0")"

# The plotting tests write figures below tests/test_output (output_dir in
# tests/context.py), so create those directories, not ./test_output.
mkdir -p tests/test_output/clf tests/test_output/reg

exec python -m pytest "$@"
--------------------------------------------------------------------------------
/python/tests/resources/boston_pred_out_3D_169_normalized.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pharmbio/plot_utils/HEAD/python/tests/resources/boston_pred_out_3D_169_normalized.npy -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore the 'compiled' python code 2 | *.pyc 3 | **.ipynb_checkpoints* 4 | # Mac meta files 5 | .DS_Store 6 | *.png 7 | test_output/ 8 | 9 | # Python venv directory and contents 10 | venv/ 11 | 12 | --------------------------------------------------------------------------------
/python/src/pharmbio/cp/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `pharmbio.cp.metrics` module contains metrics that can be computed 3 | for conformal prediction output. 4 | """ 5 | 6 | # all 'public' classification metrics 7 | from ._classification import * 8 | 9 | # all 'public' regression metrics 10 | from ._regression import * 11 | -------------------------------------------------------------------------------- /python/tests/resources/cpsign_reg_stats_excl_sd.csv: -------------------------------------------------------------------------------- 1 | Confidence Accuracy Mean prediction interval width Efficiency (median prediction interval width) MAE RMSE R^2 2 | 0.7 0.07 4.9 4.49 31.9 38.6 -2.08 3 | 0.8 0.32 26.9 27.2 31.9 38.6 -2.08 4 | 0.9 1.0 Infinity Infinity 31.9 38.6 -2.08 5 | 0.95 1.0 Infinity Infinity 31.9 38.6 -2.08 6 | 1.0 1.0 Infinity Infinity 31.9 38.6 -2.08 -------------------------------------------------------------------------------- /python/tests/context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 4 | 5 | 6 | # calculate absolute paths for loading test-files and writing output 7 | resource_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources") 8 | output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_output") 9 | -------------------------------------------------------------------------------- /python/src/pharmbio/cp/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `pharmbio.cp.plotting` module contains plotting functions 3 | for conformal prediction output. 
4 | """ 5 | 6 | # all 'public' classification plotting functions 7 | from ._classification import * 8 | 9 | # all 'public' regression functions - TODO 10 | from ._regression import * 11 | 12 | # From the common stuff 13 | from ._common import update_plot_settings,plot_calibration 14 | -------------------------------------------------------------------------------- /python/tests/pharmbio/data/load_data_test.py: -------------------------------------------------------------------------------- 1 | from pharmbio import data 2 | import pytest 3 | from ...help_utils import get_resource 4 | 5 | class Test_regression(): 6 | 7 | def test_load(self): 8 | (y,predictions,signs) = data.load_regression(get_resource('cpsign_reg_predictions.csv'),'y') 9 | assert len(y) == len(predictions) 10 | assert predictions.shape[2] ==len(signs) 11 | assert pytest.approx(1) == signs[0] 12 | 13 | -------------------------------------------------------------------------------- /python/tests/resources/cpsign_clf_stats_excl_sd.csv: -------------------------------------------------------------------------------- 1 | Confidence Accuracy Accuracy(mutagen) Accuracy(nonmutagen) Proportion empty-label prediction sets Proportion multi-label prediction sets Proportion single-label prediction sets AverageC Balanced Observed Fuzziness Observed Fuzziness Unobserved Confidence Unobserved Credibility Balanced Accuracy Classifier Accuracy F1Score_macro F1Score_micro F1Score_weighted NPV Precision ROC AUC Recall 2 | 0.5 0.516 0.484 0.548 0.444 0.0 0.556 1.34 0.148 0.148 0.916 0.577 0.785 0.786 0.785 0.786 0.786 0.776 0.797 0.871 0.758 3 | 0.7 0.746 0.766 0.726 0.0952 0.0 0.905 1.34 0.148 0.148 0.916 0.577 0.785 0.786 0.785 0.786 0.786 0.776 0.797 0.871 0.758 4 | 0.9 0.937 0.969 0.903 0.0 0.341 0.659 1.34 0.148 0.148 0.916 0.577 0.785 0.786 0.785 0.786 0.786 0.776 0.797 0.871 0.758 5 | -------------------------------------------------------------------------------- 
/python/tests/generate_test_files/generate_multilabel_predictions.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_iris 2 | import numpy as np 3 | from sklearn.svm import SVC 4 | 5 | import sys 6 | sys.path.append('/Users/staffan/git/peptid_studie/experiments/src') # Nonconformist 7 | 8 | from nonconformist.cp import TcpClassifier 9 | from nonconformist.nc import NcFactory 10 | 11 | 12 | iris = load_iris() 13 | 14 | idx = np.random.permutation(iris.target.size) 15 | 16 | # Divide the data into training set and test set 17 | idx_train, idx_test = idx[:100], idx[100:] 18 | 19 | model = SVC(probability=True) # Create the underlying model 20 | nc = NcFactory.create_nc(model) # Create a default nonconformity function 21 | tcp = TcpClassifier(nc) # Create a transductive conformal classifier 22 | 23 | # Fit the TCP using the proper training set 24 | tcp.fit(iris.data[idx_train, :], iris.target[idx_train]) 25 | 26 | # Produce predictions for the test set 27 | predictions = tcp.predict(iris.data[idx_test, :]) 28 | 29 | # 30 | targets = np.array(iris.target[idx_test], copy=True) 31 | targets.shape = (len(targets),1) 32 | output = np.hstack((targets, predictions)) 33 | 34 | np.savetxt('resources/multiclass.csv', output, delimiter=',') -------------------------------------------------------------------------------- /python/tests/resources/cpsign_reg_stats_incl_sd.csv: -------------------------------------------------------------------------------- 1 | Confidence Accuracy Accuracy_SD Mean prediction interval width Mean prediction interval width_SD Efficiency (median prediction interval width) Efficiency (median prediction interval width)_SD MAE MAE_SD RMSE RMSE_SD R^2 R^2_SD 2 | 0.05 0.0463 0.0214 0.126 0.072 0.0791 0.0212 0.769 0.0875 1.03 0.111 0.483 0.101 3 | 0.1 0.104 0.038 0.26 0.109 0.169 0.0441 0.769 0.0875 1.03 0.111 0.483 0.101 4 | 0.15 0.166 0.0543 0.421 0.167 0.273 0.0655 0.769 0.0875 1.03 
0.111 0.483 0.101 5 | 0.2 0.231 0.0484 0.58 0.239 0.373 0.0768 0.769 0.0875 1.03 0.111 0.483 0.101 6 | 0.25 0.279 0.0331 0.741 0.368 0.467 0.0757 0.769 0.0875 1.03 0.111 0.483 0.101 7 | 0.3 0.339 0.0407 0.935 0.425 0.593 0.0888 0.769 0.0875 1.03 0.111 0.483 0.101 8 | 0.35 0.379 0.03 1.1 0.508 0.696 0.084 0.769 0.0875 1.03 0.111 0.483 0.101 9 | 0.4 0.422 0.0315 1.3 0.622 0.819 0.0949 0.769 0.0875 1.03 0.111 0.483 0.101 10 | 0.45 0.46 0.0342 1.48 0.686 0.933 0.0959 0.769 0.0875 1.03 0.111 0.483 0.101 11 | 0.5 0.499 0.0324 1.68 0.788 1.05 0.0912 0.769 0.0875 1.03 0.111 0.483 0.101 12 | 0.55 0.535 0.0409 1.9 0.873 1.2 0.0949 0.769 0.0875 1.03 0.111 0.483 0.101 13 | 0.6 0.59 0.0394 2.24 1.04 1.4 0.0903 0.769 0.0875 1.03 0.111 0.483 0.101 14 | 0.65 0.627 0.0341 2.49 1.17 1.56 0.121 0.769 0.0875 1.03 0.111 0.483 0.101 15 | 0.7 0.691 0.0398 3.05 1.51 1.9 0.154 0.769 0.0875 1.03 0.111 0.483 0.101 16 | 0.75 0.746 0.0342 3.62 1.87 2.24 0.172 0.769 0.0875 1.03 0.111 0.483 0.101 17 | 0.8 0.81 0.0453 4.32 2.11 2.69 0.212 0.769 0.0875 1.03 0.111 0.483 0.101 18 | 0.85 0.861 0.0393 5.3 2.59 3.29 0.362 0.769 0.0875 1.03 0.111 0.483 0.101 19 | 0.9 0.912 0.0276 7.29 4.75 4.37 0.72 0.769 0.0875 1.03 0.111 0.483 0.101 20 | 0.95 0.953 0.0221 11.0 7.35 6.59 1.63 0.769 0.0875 1.03 0.111 0.483 0.101 21 | 22 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/utils_test.py: -------------------------------------------------------------------------------- 1 | from pharmbio.cp import utils 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | 7 | def test_validate_sign_single(): 8 | # Test single value 9 | # this should be fine 10 | utils.validate_sign(0) 11 | utils.validate_sign(.1) 12 | utils.validate_sign(1) 13 | 14 | 15 | # Too low value 16 | with pytest.raises(ValueError): 17 | utils.validate_sign(-.1) 18 | # Too high value 19 | with pytest.raises(ValueError): 20 | utils.validate_sign(1.001) 21 | 22 | def 
test_validate_sign_numpy(): 23 | # Should all be OK 24 | utils.validate_sign(np.asarray(0)) 25 | utils.validate_sign(np.asarray(.1)) 26 | utils.validate_sign(np.asarray(1)) 27 | utils.validate_sign(np.asarray((0,.1,.2,.3,.5,.9,.99,1))) 28 | 29 | # Too low value 30 | with pytest.raises(ValueError): 31 | utils.validate_sign(np.asarray((-.1,.2))) 32 | # Too high value 33 | with pytest.raises(ValueError): 34 | utils.validate_sign(np.asarray((0.1,.5,1.001))) 35 | 36 | # Invalid shape 37 | with pytest.raises(ValueError): 38 | utils.validate_sign(np.asarray(1.001).reshape((1,1))) 39 | 40 | def test_validate_sign_pd_series(): 41 | # Should all be OK 42 | utils.validate_sign(pd.Series(0)) 43 | utils.validate_sign(pd.Series(.1)) 44 | utils.validate_sign(pd.Series(1)) 45 | utils.validate_sign(pd.Series((0,.1,.2,.3,.5,.9,.99,1))) 46 | 47 | # Too low value 48 | with pytest.raises(ValueError): 49 | utils.validate_sign(pd.Series((-.1,.2))) 50 | # Too high value 51 | with pytest.raises(ValueError): 52 | utils.validate_sign(pd.Series((0.1,.5,1.001))) 53 | 54 | # Invalid shape 55 | with pytest.raises(TypeError): 56 | utils.validate_sign(pd.DataFrame(data=[[0.1]])) 57 | -------------------------------------------------------------------------------- /python/tests/resources/cpsign_reg_predictions_10_incl_inf.csv: -------------------------------------------------------------------------------- 1 | canonical_smiles solubility Predicted value (ŷ) Prediction interval lower bound (confidence=0.1) Prediction interval upper bound (confidence=0.1) Capped prediction interval lower bound (confidence=0.1) Capped prediction interval upper bound (confidence=0.1) Prediction interval lower bound (confidence=0.4) Prediction interval upper bound (confidence=0.4) Capped prediction interval lower bound (confidence=0.4) Capped prediction interval upper bound (confidence=0.4) Prediction interval lower bound (confidence=0.9) Prediction interval upper bound (confidence=0.9) Capped prediction interval lower 
bound (confidence=0.9) Capped prediction interval upper bound (confidence=0.9) Prediction interval lower bound (confidence=1.0) Prediction interval upper bound (confidence=1.0) Capped prediction interval lower bound (confidence=1.0) Capped prediction interval upper bound (confidence=1.0) 2 | CC(C1=CC=CC=C1)NCCC(C2=CC=CC=C2)C3=CC=CC=C3.Cl 52.8 3.38 3.22 3.53 3.22 3.53 2.66 4.1 2.66 4.1 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 3 | CCN(CC)CCOC(=O)C(C1CCCCC1)C2=CC=CC=C2.Cl 53.1 4.18 4.02 4.34 4.02 4.34 3.39 4.97 3.39 4.97 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 4 | CN(C)CCOC(C1=CC=CC=C1)C2=CC=CC=C2.Cl 43.8 3.53 3.37 3.7 3.37 3.7 2.73 4.34 2.73 4.34 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 5 | C1=CC=C(C=C1)N=NC2=C(N=C(C=C2)N)N.Cl 30.4 4.79 4.6 4.98 4.6 4.98 3.96 5.63 3.96 5.63 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 6 | CC1=C(C(=NN1)C)NC(=O)CN2CCN(CC2)S(=O)(=O)C3=CC(=C(C=C3)OC)OC.C(=O)(C(=O)O)O 63.6 3.64 3.47 3.8 3.47 3.8 2.91 4.36 2.91 4.36 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 7 | CN(CC(=O)NC1=CC(=CC(=C1)C(=O)OC)C(=O)OC)C2CCS(=O)(=O)C2.C(=O)(C(=O)O)O 73.3 2.45 2.33 2.58 2.33 2.58 1.84 3.07 1.84 3.07 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 8 | CC(C(=O)NC1=CC=C(C=C1)C(F)(F)F)N(C)CC(=O)N2CCN(CC2)CC3=CC=CC=C3.C(=O)(C(=O)O)O 52.1 1.0 0.9 1.1 0.9 1.1 0.563 1.44 0.563 1.44 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 9 | CC1=CC(=CC=C1)S(=O)(=O)NC2=C(C=C(C=C2)N3CCN(CC3)C)C.Cl 52 5.16 4.89 5.42 4.89 5.42 4.14 6.18 4.14 6.18 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 10 | CC(C(=O)NC1=CC=CC=C1OC)N(C)CC(=O)NC2=CC=CC3=CC=CC=C32.Cl 64.2 4.12 3.93 4.31 3.93 4.31 3.31 4.93 3.31 4.93 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 11 | C1COCCN1CCNC(=O)C2=CC=CC=C2OCC3=CC=CC=C3.Cl 56.5 3.31 3.16 3.46 3.16 3.46 2.56 4.06 2.56 4.06 -Infinity Infinity 0.05 31.0 -Infinity Infinity 0.05 31.0 12 | 
-------------------------------------------------------------------------------- /python/src/pharmbio/cp/metrics/_regression.py: -------------------------------------------------------------------------------- 1 | """"CP Regression metrics 2 | 3 | """ 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from ..utils import * 8 | 9 | def pred_width(predictions, median = True): 10 | """**Regression** - Calculates the median or mean interval width of the Confidence Intervals 11 | 12 | Parameters 13 | ---------- 14 | predictions : 2D or 3D numpy array 15 | The input can either be a 2D array from a single significance level with shape (n_samples,2), or a 3D array with predictions for several significance levels with shape (n_samples,2,n_significance_levels). In the second dimension of the array, the first index should contain the lower/min value and the second the upper/max value of the prediction interval 16 | 17 | median : bool 18 | True - if the median interval should be calculated 19 | False - if mean should be calculated 20 | 21 | Returns 22 | ------- 23 | widths : float or 1D numpy array 24 | A scalar value if the `predictions` input is 2D, or a 1D array if the `predictions` is 3D (one median/mean value for each significance level) 25 | """ 26 | n_sign, pred_matrix = validate_regression_preds(predictions) 27 | 28 | if n_sign > 1: 29 | # 3D matrix 30 | widths = pred_matrix[:,1,:] - pred_matrix[:,0,:] 31 | else: 32 | # 2D matrix 33 | widths = pred_matrix[:,1] - pred_matrix[:,0] 34 | 35 | if (np.any(widths < 0)): 36 | raise ValueError('Invalid input, prediction intervals cannot be negative') 37 | 38 | return (np.median(widths, axis=0) if median else np.mean(widths, axis=0)) 39 | 40 | def frac_error_reg(y_true, predictions): 41 | """**Regression** - Calculate the fraction of errors 42 | 43 | Parameters 44 | ---------- 45 | y_true : 1d array like 46 | List or array with the true labels, must be convertable to numpy ndarray 47 | 48 | predictions : 2D or 3D ndarray 49 
| A matrix with either shape (n_samples, 2, n_sign_levels) or (n_sampes,2). The shape of the preidctions will decide the output dimension of the error_rates 50 | 51 | Returns 52 | ------- 53 | error_rates : float or 1D ndarray 54 | The either a single float in case input is 2D, or an array of error rates (one for each significance level) 55 | 56 | """ 57 | # Validation and potential 58 | n_sign, pred_matrix = validate_regression_preds(predictions) 59 | ys = to_numpy1D_reg_y_true(y_true, pred_matrix.shape[0]) 60 | 61 | if n_sign > 1: 62 | # 3D matrix 63 | ys.shape = (ys.shape[0],1) # turn to matrix in order to broadcast 64 | truth_vals = (np.greater_equal(ys,pred_matrix[:,0])) & (np.greater_equal(pred_matrix[:,1],ys)) 65 | else: 66 | # 2D matrix 67 | truth_vals = (pred_matrix[:,0] <= ys) & (ys <= pred_matrix[:,1]) 68 | # True = 1, False = 0, mean will be fraction of true values 69 | return 1 - truth_vals.mean(0) 70 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/metrics/reg_metrics_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from ....help_utils import get_resource 6 | from pharmbio.cp.metrics import (pred_width,frac_error_reg) 7 | from statistics import mean 8 | 9 | class Test_pred_width(): 10 | 11 | def test_single_pred(self): 12 | pred = np.array([ 13 | [0,1] 14 | ]) 15 | m = pred_width(pred) 16 | assert 1 == m 17 | 18 | pred_2 = np.array([ 19 | [-1.5,5.2] 20 | ]) 21 | assert pytest.approx(1.5+5.2) == pred_width(pred_2) 22 | 23 | def test_two_pred(self): 24 | pred = np.array([ 25 | [0.5,1], 26 | [3.5,9] 27 | ]) 28 | # median 29 | assert pytest.approx(np.median([.5, 9-3.5])) == pred_width(pred) 30 | # mean 31 | assert pytest.approx(np.mean([.5, 9-3.5])) == pred_width(pred,median=False) 32 | 33 | def test_3d_pred(self): 34 | # Create a (3,2,2) matrix (for two different significance levels) 
35 | 36 | # First "significance" level 37 | pred_sig1 = np.array([ 38 | [0.5,1], 39 | [3.5,9], 40 | [5.1,8.7]]) 41 | 42 | # Second "significance" level 43 | pred_sig2= np.array([ 44 | [0.25,1.5], 45 | [3,9.5], 46 | [5,9]]) 47 | 48 | # Stack them to 3D 49 | pred_3d = np.stack((pred_sig1,pred_sig2), axis=-1) 50 | assert (3,2,2) == pred_3d.shape 51 | # first level: np.mean([.5,9-3.5,8.7-5.1]) = 3.2, median = 3.6 52 | # second level: np.mean([1.5-.25,9.5-3,9-5]) = 3.9166667, median = 4 53 | median = pred_width(pred_3d, median=True) 54 | mean = pred_width(pred_3d, median=False) 55 | assert equal_np_arrays([3.2, 3.9166667], mean) 56 | assert equal_np_arrays([3.6, 4], median) 57 | 58 | def test_3d_boston(self): 59 | boston_preds = np.load(get_resource('boston_pred_out_3D_169.npy')) 60 | assert (169,2,99) == boston_preds.shape 61 | withs = pred_width(boston_preds) 62 | #print(withs) 63 | assert len(withs) == 99 64 | 65 | class Test_frac_error_reg(): 66 | 67 | def test_boston(self): 68 | boston_preds = np.load(get_resource('boston_pred_out_3D_169.npy')) 69 | boston_labels = np.load(get_resource('boston_labels.npy')) 70 | # Try 3D 71 | errs3d = frac_error_reg(boston_labels,boston_preds) 72 | #print(errs3d) 73 | # As 2D (a single significance level) 74 | for i in range(boston_preds.shape[2]): 75 | errs2d = frac_error_reg(boston_labels,boston_preds[:,:,i]) 76 | #print(errs2d) 77 | assert pytest.approx(errs2d) == errs3d[i] 78 | 79 | 80 | 81 | def equal_np_arrays(arr1, arr2): 82 | if len(arr1) != len(arr2): 83 | return False 84 | return True if np.all((arr1 - arr2) < 0.000001) else False 85 | -------------------------------------------------------------------------------- /python/tests/resources/cpsign_clf_stats_incl_sd.csv: -------------------------------------------------------------------------------- 1 | Confidence Accuracy(nonmutagen) Accuracy(nonmutagen)_SD Accuracy Accuracy_SD Accuracy(mutagen) Accuracy(mutagen)_SD Proportion empty-label prediction sets Proportion 
empty-label prediction sets_SD Proportion multi-label prediction sets Proportion multi-label prediction sets_SD Proportion single-label prediction sets Proportion single-label prediction sets_SD AverageC AverageC_SD Balanced Observed Fuzziness Balanced Observed Fuzziness_SD Observed Fuzziness Observed Fuzziness_SD Unobserved Confidence Unobserved Confidence_SD Unobserved Credibility Unobserved Credibility_SD Balanced Accuracy Balanced Accuracy_SD Classifier Accuracy Classifier Accuracy_SD F1Score_weighted F1Score_weighted_SD F1Score_macro F1Score_macro_SD F1Score_micro F1Score_micro_SD NPV NPV_SD Precision Precision_SD ROC AUC ROC AUC_SD Recall Recall_SD 2 | 0.1 0.124 0.0219 0.104 0.012 0.0885 0.0164 0.884 0.0149 0.0 0.0 0.116 0.0149 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 3 | 0.2 0.216 0.0262 0.195 0.0176 0.178 0.0283 0.784 0.0208 0.0 0.0 0.216 0.0208 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 4 | 0.3 0.321 0.0308 0.291 0.0244 0.267 0.0286 0.678 0.0299 0.0 0.0 0.322 0.0299 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 5 | 0.4 0.42 0.0365 0.382 0.0256 0.352 0.0287 0.573 0.0333 0.0 0.0 0.427 0.0333 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 6 | 0.5 0.503 0.0357 0.474 0.024 0.451 0.036 0.463 0.0312 0.0 0.0 0.537 0.0312 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 7 | 0.6 0.596 0.0337 
0.569 0.0202 0.548 0.0283 0.345 0.0246 0.0 0.0 0.655 0.0246 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 8 | 0.7 0.705 0.0414 0.686 0.0287 0.671 0.0397 0.196 0.0346 0.0 0.0 0.804 0.0346 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 9 | 0.8 0.789 0.0306 0.777 0.0348 0.767 0.0457 0.0559 0.0343 0.0 0.0 0.944 0.0343 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 10 | 0.9 0.882 0.0319 0.879 0.0224 0.877 0.0299 0.0 0.0 0.185 0.0405 0.815 0.0405 0.944 0.0343 0.122 0.0123 0.123 0.0127 0.946 0.00585 0.553 0.0176 0.805 0.0242 0.805 0.0252 0.805 0.0249 0.803 0.0251 0.805 0.0252 0.843 0.0211 0.762 0.0358 0.866 0.0203 0.813 0.028 11 | 12 | -------------------------------------------------------------------------------- /python/tests/help_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from matplotlib.figure import Figure 5 | import matplotlib.pyplot as plt 6 | from pharmbio.cp import metrics 7 | import os 8 | 9 | import pharmbio.cp.plotting._utils as plt_utils 10 | 11 | from .context import output_dir, resource_dir 12 | # Utility functions 13 | 14 | def get_resource(file_name:str)->str: 15 | return os.path.join(resource_dir,file_name) 16 | 17 | def _save_reg(fig, test_func, close_after=True): 18 | test_func = test_func if test_func is not None and isinstance(test_func,str) else str(test_func) 19 | ax1 = fig.axes[0] 20 | ax1_orig_title = ax1.get_title() 21 | ax1.set_title(ax1_orig_title+":"+test_func) 22 | fig.savefig(output_dir+'/reg/'+test_func+".pdf", 
bbox_inches='tight') 23 | ax1.set_title(ax1_orig_title) 24 | if close_after: 25 | plt.close(fig) 26 | 27 | def _save_clf(fig, test_func, close_after=True): 28 | test_func = test_func if test_func is not None and isinstance(test_func,str) else str(test_func) 29 | ax1 = fig.axes[0] 30 | ax1_orig_title = ax1.get_title() 31 | ax1.set_title(ax1_orig_title+":"+test_func) 32 | fig.savefig(output_dir+'/clf/'+test_func+".pdf", bbox_inches='tight') 33 | ax1.set_title(ax1_orig_title) 34 | if close_after: 35 | plt.close(fig) 36 | 37 | def assert_fig_wh(fig : Figure, w : int, h : int, ax): 38 | """ 39 | fig: matplotlib.Figure 40 | w: float, expected width of fig 41 | h: float, expected height of fig 42 | ax: Figure.Axes that should be part of the `fig` object 43 | 44 | """ 45 | 46 | assert w == fig.get_figwidth() 47 | assert h == fig.get_figheight() 48 | 49 | ax_found = False 50 | assert ax is not None, "Axes should not be None" 51 | for a in fig.axes: 52 | if a == ax: 53 | ax_found = True 54 | assert ax_found, "Should found axes" 55 | 56 | class TestMetricUtils(): 57 | 58 | def test_get_onehot_bool(self): 59 | y_true = [1, 0, 1, 2, 0] 60 | one_hot, cats = metrics.to_numpy1D_onehot(y_true,'test') 61 | assert (5,3) == one_hot.shape 62 | assert np.all(np.ones(5) == one_hot.sum(axis=1)) 63 | assert np.all(cats == np.unique(y_true)) 64 | 65 | class TestPlottingUtils(): 66 | 67 | def test_generate_figure(self): 68 | fig, ax = plt_utils.get_fig_and_axis() 69 | ax.set_title('Default fig') 70 | assert_fig_wh(fig,10,8,ax) 71 | 72 | fig, ax = plt_utils.get_fig_and_axis(figsize=10) 73 | assert_fig_wh(fig,10,10,ax) 74 | 75 | fig, ax = plt_utils.get_fig_and_axis(figsize=(10,5)) 76 | assert_fig_wh(fig,10,5,ax) 77 | 78 | def test_existing_ax_object(self): 79 | fig, (ax1, ax2) = plt.subplots(2,1, figsize=(2,3)) 80 | assert_fig_wh(fig,2,3,ax1) 81 | assert_fig_wh(fig,2,3,ax2) 82 | 83 | def test_faulty_figsize_param(self): 84 | with pytest.raises(TypeError): 85 | 
plt_utils.get_fig_and_axis(figsize="") 86 | 87 | def test_faulty_ax_param(self): 88 | with pytest.raises(TypeError): 89 | plt_utils.get_fig_and_axis(ax=1) 90 | 91 | def test_utility_save_func(self): 92 | fig, _ = plt_utils.get_fig_and_axis(figsize=(10,5)) 93 | _save_reg(fig,'test_utility_save_func') 94 | _save_clf(fig,'test_utility_save_func_clf') 95 | 96 | 97 | -------------------------------------------------------------------------------- /python/tests/generate_test_files/generate_regression_preds.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Example: inductive conformal regression using DecisionTreeRegressor 5 | """ 6 | 7 | # Authors: Henrik Linusson 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from sklearn.tree import DecisionTreeRegressor 13 | from sklearn.neighbors import KNeighborsRegressor 14 | from sklearn.datasets import load_boston 15 | 16 | from nonconformist.base import RegressorAdapter 17 | from nonconformist.icp import IcpRegressor 18 | from nonconformist.nc import RegressorNc, AbsErrorErrFunc, RegressorNormalizer 19 | 20 | # ----------------------------------------------------------------------------- 21 | # Setup training, calibration and test indices 22 | # ----------------------------------------------------------------------------- 23 | data = load_boston() 24 | 25 | idx = np.random.permutation(data.target.size) 26 | train = idx[:int(idx.size / 3)] 27 | calibrate = idx[int(idx.size / 3):int(2 * idx.size / 3)] 28 | test = idx[int(2 * idx.size / 3):] 29 | 30 | # ----------------------------------------------------------------------------- 31 | # Without normalization 32 | # ----------------------------------------------------------------------------- 33 | # Train and calibrate 34 | # ----------------------------------------------------------------------------- 35 | underlying_model = RegressorAdapter(DecisionTreeRegressor(min_samples_leaf=5)) 36 | nc 
= RegressorNc(underlying_model, AbsErrorErrFunc()) 37 | icp = IcpRegressor(nc) 38 | icp.fit(data.data[train, :], data.target[train]) 39 | icp.calibrate(data.data[calibrate, :], data.target[calibrate]) 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Predict 43 | # ----------------------------------------------------------------------------- 44 | prediction3D = icp.predict(data.data[test, :], significance=None) 45 | 46 | np.save('resources/boston_pred_out_3D_169.npy',prediction3D) 47 | np.save('resources/boston_labels.npy',data.target[test]) 48 | # print(prediction3D.shape) 49 | # print(data.target[test]) 50 | # exit() 51 | # prediction = icp.predict(data.data[test, :], significance=0.1) 52 | # header = ['min','max','truth','size'] 53 | # size = prediction[:, 1] - prediction[:, 0] 54 | # table = np.vstack([prediction.T, data.target[test], size.T]).T 55 | # df = pd.DataFrame(table, columns=header) 56 | # print(df) 57 | 58 | # ----------------------------------------------------------------------------- 59 | # With normalization 60 | # ----------------------------------------------------------------------------- 61 | # Train and calibrate 62 | # ----------------------------------------------------------------------------- 63 | underlying_model = RegressorAdapter(DecisionTreeRegressor(min_samples_leaf=5)) 64 | normalizing_model = RegressorAdapter(KNeighborsRegressor(n_neighbors=1)) 65 | normalizer = RegressorNormalizer(underlying_model, normalizing_model, AbsErrorErrFunc()) 66 | nc = RegressorNc(underlying_model, AbsErrorErrFunc(), normalizer) 67 | icp = IcpRegressor(nc) 68 | icp.fit(data.data[train, :], data.target[train]) 69 | icp.calibrate(data.data[calibrate, :], data.target[calibrate]) 70 | 71 | # ----------------------------------------------------------------------------- 72 | # Predict 73 | # ----------------------------------------------------------------------------- 74 | prediction = icp.predict(data.data[test, 
:], significance=None) #0.1) 75 | np.save('resources/boston_pred_out_3D_169_normalized.npy',prediction) 76 | 77 | exit() 78 | header = ['min','max','truth','size'] 79 | size = prediction[:, 1] - prediction[:, 0] 80 | table = np.vstack([prediction.T, data.target[test], size.T]).T 81 | df = pd.DataFrame(table, columns=header) 82 | print(df) 83 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/plotting/reg_plotting_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import matplotlib.pyplot as plt 4 | 5 | from pharmbio.cp import metrics,plotting 6 | from ....help_utils import _save_reg, get_resource 7 | from ....context import resource_dir 8 | 9 | boston_preds = np.load(get_resource('boston_pred_out_3D_169.npy')) 10 | boston_preds_norm = np.load(get_resource('boston_pred_out_3D_169_normalized.npy')) 11 | boston_labels = np.load(get_resource('boston_labels.npy')) 12 | # These are the ones nonconformist calculates 13 | significance_lvls = np.arange(0.01,1,0.01) 14 | 15 | class Test_calib_plot(): 16 | 17 | def test_boston(self): 18 | error_rates = metrics.frac_error_reg(boston_labels,boston_preds) 19 | fig = plotting.plot_calibration(significance_lvls,error_rates) 20 | _save_reg(fig,"Test_calib_plot.test_boston") 21 | 22 | def test_boston_flipping(self): 23 | error_rates = metrics.frac_error_reg(boston_labels,boston_preds) 24 | # Plot all in one image 25 | fig, axes = plt.subplots(2,2,figsize=(10,10)) 26 | # std 27 | plotting.plot_calibration(significance_lvls,error_rates, ax=axes[0,0], title='std') 28 | # flip x 29 | plotting.plot_calibration(significance_lvls,error_rates, ax=axes[0,1], flip_x=True, title='flip x') 30 | # flip y 31 | plotting.plot_calibration(significance_lvls,error_rates, ax=axes[1,0], flip_y=True, title='flip y') 32 | # flip both 33 | plotting.plot_calibration(significance_lvls,error_rates,ax=axes[1,1], 
flip_x=True, flip_y=True, title='both') 34 | 35 | _save_reg(fig,"Test_calib_plot.test_boston_flippin") 36 | 37 | 38 | class Test_pred_width(): 39 | 40 | def test_boston(self): 41 | pred_widths = metrics.pred_width(boston_preds) 42 | fig = plotting.plot_pred_widths(significance_lvls,pred_widths) 43 | _save_reg(fig,"Test_pred_width.test_boston") 44 | 45 | def test_boston_subset(self): 46 | pred_widths = metrics.pred_width(boston_preds) 47 | fig = plotting.plot_pred_widths(significance_lvls[10:],pred_widths[10:]) 48 | ax = fig.axes[0] 49 | _save_reg(fig,"Test_pred_width.test_boston_subset") 50 | 51 | def test_boston_non_std(self): 52 | pred_widths = metrics.pred_width(boston_preds) 53 | fig, axes = plt.subplots(1,2,figsize=(10,10)) 54 | # No flip 55 | plotting.plot_pred_widths(significance_lvls,pred_widths,ax=axes[0],title="standard") 56 | # Flip 57 | plotting.plot_pred_widths(significance_lvls,pred_widths,ax=axes[1],flip_x=True,title="flip_x") 58 | # Save 59 | _save_reg(fig,"Test_pred_width.test_boston_flippin'") 60 | 61 | 62 | class Test_pred_intervals(): 63 | 64 | def test_boston(self): 65 | fig = plotting.plot_pred_intervals(boston_labels, 66 | boston_preds[:,:,70] 67 | # ,incorrect_ci='red' 68 | # , line_cap = 2 69 | ) 70 | fig.get_axes()[0].legend(loc='upper left') 71 | _save_reg(fig,"Test_pred_intervals.test_boston") 72 | 73 | def test_boston_norm(self): 74 | fig = plotting.plot_pred_intervals(boston_labels, 75 | boston_preds_norm[:,:,70] 76 | # ,incorrect_ci='red' 77 | , line_cap = True 78 | , incorrect_ci= 'k' 79 | ) 80 | fig.get_axes()[0].legend(loc='upper left') 81 | _save_reg(fig,"Test_pred_width:test_boston_norm") 82 | 83 | def test_boston_gray(self): 84 | fig = plotting.plot_pred_intervals(boston_labels, 85 | boston_preds_norm[:,:,70], 86 | correct_color='gray', 87 | correct_marker='o', 88 | incorrect_color='k', 89 | incorrect_marker='X' 90 | # ,incorrect_ci='red' 91 | , line_cap = True 92 | , incorrect_ci= 'k' 93 | ) 94 | 
fig.get_axes()[0].legend(loc='upper left') 95 | _save_reg(fig,"Test_pred_width:test_boston_gray") 96 | 97 | -------------------------------------------------------------------------------- /python/tests/resources/multiclass.csv: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00,-2.471993677768674104e-03,3.342445095027559571e-02,3.484596123322193650e-02 2 | 2.000000000000000000e+00,1.556395788347502557e-02,1.211005214303505738e-02,2.355212318453251963e-01 3 | 0.000000000000000000e+00,3.745365420569936599e-01,-4.765865232552582431e-03,8.838414066113570514e-03 4 | 0.000000000000000000e+00,5.720112566031760881e-01,3.483023822294546712e-03,-4.751797856114580582e-03 5 | 0.000000000000000000e+00,9.252899078705828506e-01,7.424839358164513591e-03,-5.801593008546311082e-03 6 | 1.000000000000000000e+00,1.118654241131473819e-02,2.554505473865385512e-01,-5.039412562607241235e-03 7 | 0.000000000000000000e+00,7.502572635671252987e-01,6.760726725763965307e-03,-2.315200765770199752e-03 8 | 1.000000000000000000e+00,-3.613450140674955013e-03,1.209502876193996462e-01,1.130923719758454397e-02 9 | 1.000000000000000000e+00,5.509020755024003277e-03,9.510823959044310827e-03,1.270433050416252263e-01 10 | 0.000000000000000000e+00,4.469313512680007094e-01,-1.910398435348542806e-03,9.761274306024953981e-03 11 | 2.000000000000000000e+00,-4.834290912428928079e-04,9.994808315212969432e-03,7.950628716162053689e-01 12 | 1.000000000000000000e+00,3.019798413930200199e-03,1.805155315461476062e-01,8.281148928887323291e-03 13 | 0.000000000000000000e+00,6.880007393388273274e-01,-6.853648899008451178e-04,2.117356584667600284e-03 14 | 0.000000000000000000e+00,3.346417126138364062e-01,8.767088522338964754e-03,6.645578243473292818e-03 15 | 2.000000000000000000e+00,7.609737807878711076e-03,1.135180008324206298e-02,1.889110523699844280e-01 16 | 1.000000000000000000e+00,-4.713981466599250457e-03,8.532573207881267718e-01,1.657967488414324383e-02 17 | 
0.000000000000000000e+00,8.075704129282984534e-01,-5.089553831733425104e-05,-4.407336534641186770e-03 18 | 2.000000000000000000e+00,-3.771348873885878045e-03,1.455238207844535871e-02,1.777285394295365639e-01 19 | 1.000000000000000000e+00,4.075145022443840692e-03,1.698302646639709979e-01,-4.804522157191564759e-03 20 | 0.000000000000000000e+00,5.727087349639652381e-01,2.602912263795892876e-03,9.487246116507528845e-03 21 | 0.000000000000000000e+00,5.284652042499206059e-01,5.861591985020897511e-03,5.333661169113309251e-03 22 | 1.000000000000000000e+00,6.693968712869571326e-04,3.750394799447906991e-01,1.056609345890428353e-04 23 | 1.000000000000000000e+00,9.981314573283711980e-03,7.935531384197062055e-01,-6.923769763538601026e-03 24 | 0.000000000000000000e+00,6.662355833353265089e-01,7.390669888855553931e-04,3.831899358935787606e-04 25 | 2.000000000000000000e+00,1.481316697742679142e-02,5.702374059757067421e-04,9.835020121317296082e-01 26 | 0.000000000000000000e+00,8.036794400387450299e-01,8.162695369305861259e-03,3.140783682409668126e-03 27 | 2.000000000000000000e+00,1.199346831897858678e-02,-7.973626059003013564e-03,2.715472613369635524e-01 28 | 1.000000000000000000e+00,9.027228761799550513e-03,1.302300471717894870e-01,6.393993992578914619e-03 29 | 0.000000000000000000e+00,6.552402208825344188e-01,7.798667285393446301e-03,8.065186919302705371e-03 30 | 1.000000000000000000e+00,8.779792276928906694e-03,3.897762196532796874e-01,-3.831293394496527397e-03 31 | 1.000000000000000000e+00,3.899505532873354975e-03,6.441484959966428114e-01,2.152220785037516158e-03 32 | 1.000000000000000000e+00,5.765768700346124084e-03,1.376100692590369670e-02,1.059143149125150035e-01 33 | 2.000000000000000000e+00,1.360173800456625573e-02,1.337571081344019440e-02,3.751691095243766538e-01 34 | 0.000000000000000000e+00,5.017787623104140726e-01,-3.331001548780091349e-03,1.482468665797487782e-02 35 | 1.000000000000000000e+00,2.659029973942042286e-03,2.908336433022385425e-01,1.252984687939190898e-02 
36 | 1.000000000000000000e+00,9.873060216096721589e-03,8.491782149503738220e-01,8.449299396097627313e-04 37 | 2.000000000000000000e+00,3.921371408179670504e-03,1.329127041265017878e-02,1.673234882352435948e-01 38 | 0.000000000000000000e+00,2.555151186788799933e-01,-7.548933228795473745e-03,1.085908420892274344e-02 39 | 2.000000000000000000e+00,5.504141645294919651e-03,7.858537751163758439e-03,9.959976371033908427e-01 40 | 2.000000000000000000e+00,1.715697165944922269e-03,6.771827916344049479e-03,8.145896778420456119e-01 41 | 2.000000000000000000e+00,4.235598731275280922e-03,4.569731385919316956e-03,8.480405032914033470e-01 42 | 0.000000000000000000e+00,4.470161022439992360e-01,1.405681337592660368e-02,8.990446873413583673e-03 43 | 2.000000000000000000e+00,8.445957569601331844e-03,-3.903910688608832735e-03,8.872091582724088710e-01 44 | 0.000000000000000000e+00,7.211998129395873480e-01,1.892283749308038172e-03,-1.867973883951963478e-03 45 | 1.000000000000000000e+00,4.245732303056424031e-03,2.813617164852835195e-01,2.434727045814426283e-03 46 | 2.000000000000000000e+00,1.430777080493685592e-02,8.326352371799681726e-03,8.765112037142801515e-01 47 | 2.000000000000000000e+00,4.437157723732447837e-03,5.604327276156647854e-05,9.902454936301050559e-01 48 | 0.000000000000000000e+00,4.707246476225954157e-01,9.336090489739731421e-03,6.662155679284908170e-03 49 | 2.000000000000000000e+00,-7.062030648820293022e-04,1.188596085571762967e-02,2.337961846231434437e-01 50 | 0.000000000000000000e+00,7.765991494020043628e-01,9.204967912989803214e-03,-2.666203382402914691e-03 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # plot_utils 2 | Plotting library for conformal prediction metrics, intended to facilitate fast testing in e.g. notebooks. 
3 | 4 | ## Examples 5 | Example usage can be found in the [User Guide Classification notebook](python/examples/User_guide_classification.ipynb), [User Guide Regression notebook](python/examples/User_guide_regression.ipynb) and [Nonconformist+plot_utils notebook](python/examples/Nonconformist_and_plot_utils.ipynb). 6 | 7 | ## Package dependencies 8 | See [requirements.txt](python/requirements.txt) for package dependencies used in our development. Here are links to the libraries: 9 | 10 | - [matplotlib](https://matplotlib.org/) 11 | - [Numpy](https://numpy.org/) 12 | - [Pandas](https://pandas.pydata.org/) 13 | - [Seaborn](https://seaborn.pydata.org/) - **Optional!** 14 | 15 | ## API 16 | 17 | ### Data format 18 | The code internally uses numpy ndarrays for matrices and vectors, but tries to be agnostic about input being lists, arrays or their Pandas equivalents. For performance reasons, however, it is recommended to convert the data to numpy format once when using several of the methods in this library, as otherwise a new conversion is performed for each function call. 19 | 20 | ### Rendering backends 21 | Internally this library requires [matplotlib](https://matplotlib.org/) and (optionally) [Seaborn](https://seaborn.pydata.org/). Only `plot_confusion_matrix_heatmap` has a hard requirement on seaborn being available; otherwise this library only interacts with the matplotlib classes/functions and uses the seaborn settings for generating somewhat nicer plots (in our opinion). Styling and colors can always be changed through the matplotlib API. 22 | 23 | ### Data loading 24 | To simplify loading and conversion of data, the `plot_utils` library now has some utility functions for loading CSV files with predictions or validation metrics (typically generated using [CPSign](https://github.com/arosbio/cpsign)).
For regression, predictions are required to be in the same format as used by [nonconformist](https://github.com/donlnz/nonconformist): 2D or 3D numpy ndarrays of shape `(num_examples,2)` or `(num_examples,2,num_significance_levels)`, where the second dimension contains the lower and upper limits of the prediction intervals. 25 | 26 | 27 | ## Supported plots 28 | ### Classification 29 | * Calibration plot 30 | * Label ratio plot, showing the ratio of single/multi/empty predictions for each significance level 31 | * p-value distribution plot: plot p-values as a scatter plot 32 | * "Bubble plot" confusion matrix 33 | * Heatmap confusion matrix 34 | 35 | ### Regression 36 | * Calibration plot 37 | * Efficiency plot (mean or median prediction interval width vs significance) 38 | * Prediction intervals (for a given significance level) 39 | 40 | ## Set up 41 | To use this package, clone this repo and add the `/python/src/` directory to your `$PYTHONPATH`. 42 | 43 | ## Developer notes 44 | We should aim at supplying proper docstrings, following the [numpy docstring guide](https://numpydoc.readthedocs.io/en/latest/format.html). 45 | 46 | ### Testing 47 | All Python tests are located in the [tests folder](python/tests) and are meant to be run using [pytest](https://docs.pytest.org). Tests should be started from within the `python` folder and can be run "all at once" (`python -m pytest`), "per file" (`python -m pytest tests/pharmbio/cp/metrics/clf_metrics_test.py`), or as a single test function (`python -m pytest tests/pharmbio/cp/metrics/clf_metrics_test.py::TestConfusionMatrix::test_with_custom_labels`). 48 | - **Note1:** The invocation `python -m pytest [opt args]` is preferred here, as it adds the current directory to the Python path so that the application code is resolved automatically. Simply running `pytest` instead requires manual setup of the `PYTHONPATH`.
49 | - **Note2:** The plotting tests generate images that are saved in the [test_output](python/tests/test_output) directory, and these should be checked manually (there is no good way of automating plotting tests). 50 | 51 | ### TODOs: 52 | 53 | Add/finish the following plots: 54 | - [x] calibration plot - Staffan 55 | - [x] 'area plot' with label-distributions - Staffan 56 | - [x] bubble plot - Jonathan 57 | - [x] heatmap - Staffan 58 | - [x] p0-p1 plot - Staffan 59 | - [x] Add regression metrics 60 | - [x] Add regression plots 61 | 62 | 63 | ### Change log: 64 | - **0.1.0**: 65 | * Added `pharmbio.cpsign` package with loading functionality for CPSign generated files, loading calibration statistics, efficiency statistics and predictions. 66 | * Updated plotting functions in order to use pre-computed metrics where applicable (e.g. when computed by CPSign). 67 | * Added the possibility to add shading for +/- standard deviation where applicable, e.g. for the calibration curve. 68 | * Updated calibration curve plotting to have a general `plotting.plot_calibration` acting on pre-computed values, or for classification `plotting.plot_calibration_clf` where true labels and p-values can be given. 69 | * Updated parameter order to make it consistent across plotting functions, e.g. ordered as x, y (significance vs error rate) in the plots. 70 | * Added a utility function for setting the seaborn theme and context using `plotting.update_plot_settings`, which updates the matplotlib global settings. *Note:* this will affect all other plots generated in the same Python session if those rely on matplotlib.
# /python/src/pharmbio/data/_load.py
"""Utility functions for loading and converting datasets
"""

import numpy as np
import pandas as pd
import re


def _parse_level(header, num_pattern):
    """Extract the first decimal number of a column header as a float.

    Raises a ValueError with the offending header instead of an opaque
    IndexError when the header contains no number.
    """
    matches = num_pattern.findall(header)
    if len(matches) == 0:
        raise ValueError('Could not find a significance/confidence value in column header: {}'.format(header))
    return float(matches[0])


def load_regression(f,
                    y_true_col,
                    sep = ',',
                    lower_regex=r'^prediction.*interval.*lower.*\d+',
                    upper_regex=r'^prediction.*interval.*upper.*\d+',
                    specifies_significance=None):
    """Loads a CSV file with predictions and converts to the format used by Plot_utils

    The required format is that the csv has;
    - A header
    - Specifies significance or confidence in the header names of the 'lower' and 'upper' columns
    - Those headers must only contain a single number

    Note that there is no requirement for a true label to exist, the `y_true_col` can be set to None and no y-labels will be returned

    Parameters
    ----------
    f : str or buffer
        File path or buffer that `Pandas.read_csv` can read

    y_true_col : str or None
        The (case insensitive) column header of the true labels, or None if it should not be loaded

    sep : str, default ','
        Delimiter that is used in the CSV between columns

    lower_regex, upper_regex : str or re.Pattern
        Regex used for finding columns of the lower and upper interval limits. Must match the column headers.
        Strings are compiled case-insensitively; pre-compiled patterns are used as given

    specifies_significance : bool or None, default None
        If the numbers in the headers are significance level (True) or confidence (False). If None, the first column-header found by `lower_regex` will be used to check for occurrences of 'significance' or 'confidence' to try to infer what is used

    Returns
    -------
    (y, pred_matrix, sign_values)
        y : 1D ndarray of float, or None if `y_true_col` is None
        pred_matrix : 3D ndarray of shape (n_examples, 2, n_levels), see `convert_regression`
        sign_values : 1D ndarray with the significance level of each interval pair

    Raises
    ------
    ValueError
        If no column matches `lower_regex`, if `y_true_col` is given but no such column
        exists, if a matched header contains no number, or if `specifies_significance`
        is None and cannot be inferred from the header
    AssertionError
        If the lower and upper columns do not pair up (different levels or counts)
    """

    if not isinstance(lower_regex, re.Pattern):
        lower_regex = re.compile(lower_regex, re.IGNORECASE)
    if not isinstance(upper_regex, re.Pattern):
        upper_regex = re.compile(upper_regex, re.IGNORECASE)
    # Raw string: '\d' in a plain string literal is an invalid escape sequence
    num_pattern = re.compile(r'\d*\.\d*')
    y_col_lc = None if y_true_col is None else y_true_col.lower()
    y_true_ind = None

    df = pd.read_csv(f, sep=sep)
    low_ind, upp_ind, sign_low, sign_upp = [], [], [], []
    for i, c in enumerate(df.columns):
        if lower_regex.match(c) is not None:
            low_ind.append(i)
            sign_low.append(_parse_level(c, num_pattern))
        elif upper_regex.match(c) is not None:
            upp_ind.append(i)
            sign_upp.append(_parse_level(c, num_pattern))
        elif y_col_lc is not None and c.lower() == y_col_lc:
            y_true_ind = i

    # Some validation
    if len(low_ind) == 0:
        raise ValueError('No prediction interval columns found matching the given lower_regex')
    if y_col_lc is not None and y_true_ind is None:
        raise ValueError('Column \'{}\' (for y_true_col) not found in the input'.format(y_true_col))
    assert sign_low == sign_upp
    assert len(low_ind) == len(upp_ind)

    if not isinstance(specifies_significance, bool):
        col_lc = df.columns[low_ind[0]].lower()
        contains_sign = 'significance' in col_lc
        contains_conf = 'confidence' in col_lc

        # Either both or neither keyword present -> ambiguous
        if contains_sign == contains_conf:
            raise ValueError('Parameter \'specifies_significance\' not set, could not deduce if significance or confidence is used. Explicitly set this parameter and try again')

        specifies_significance = contains_sign

    levels = np.array(sign_low)
    sign_vals = levels if specifies_significance else 1 - levels

    y, p = convert_regression(df, y_true_ind, low_ind, upp_ind)
    return y, p, sign_vals


def convert_regression(data,
                       y_true_index,
                       min_index,
                       max_index):
    """
    Converts a 2D input matrix to a 3D ndarray that
    is required by the metrics and plotting functions

    Parameters
    ----------
    data : 2d array like
        Data matrix that must be convertible to 2D ndarray

    y_true_index : int or None
        Column index that the ground truth values are, or None if no
        y values should be generated. Output `y` will then be None

    min_index, max_index : list or array of int
        Column indices for min and max values for prediction intervals

    Returns
    -------
    (y, predictions)
        y : 1D ndarray
            The y_true values or None if `y_true_index` is None

        predictions : 3D ndarray
            matrix of shape (n_examples, 2, n_significance_levels), where the second
            dimension is [min, max] of the prediction intervals

    Raises
    ------
    ValueError
        If `data` is not 2D or if `min_index` and `max_index` differ in length
    """
    if not isinstance(data, np.ndarray):
        data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError('Input must be a 2D array type')

    ys = data[:, y_true_index].astype(np.float64) if y_true_index is not None else None

    if len(min_index) != len(max_index):
        raise ValueError('min_index and max_index must be of same length')

    # Allocate matrix
    preds = np.zeros((len(data), 2, len(min_index)), dtype=np.float64)

    for s, (lo_col, hi_col) in enumerate(zip(min_index, max_index)):
        preds[:, 0, s] = data[:, lo_col]
        preds[:, 1, s] = data[:, hi_col]

    # Return tuple
    return (ys, preds)
# /python/src/pharmbio/cp/plotting/_utils.py
import matplotlib as mpl
import matplotlib.pyplot as plt
import logging
import numbers
from warnings import warn
import math
import numpy as np
from sklearn.utils import check_consistent_length
import pandas as pd


_using_seaborn = False
# Try to import sns as they create somewhat nicer plots
try:
    import seaborn as sns
    sns.set_theme()
    logging.debug('Using Seaborn plotting defaults')
    _using_seaborn = True
except ImportError:
    # Seaborn is optional; fall back to plain matplotlib styling
    logging.debug('Seaborn not available - using Matplotlib defaults')

__DEFAULT_FIG_SIZE = (10,8)

def get_fig_and_axis(ax=None, figsize = __DEFAULT_FIG_SIZE):
    '''Function for instantiating a Figure / axes object

    Validates the input parameters, instantiates a Figure and axes if not
    sent as a parameter to the plotting function.

    Parameters
    ----------
    ax : matplotlib axes or None
        An existing axes object that the user may send,
        to write the plot in. When given, `figsize` is ignored

    figsize : real number or (float, float), default (10, 8)
        A figure size to generate. If a single number is given (including
        numpy scalars), the figure will be a square with each side of that
        size. A pair may be given as a tuple or a list. None means the default

    Returns
    -------
    fig : matplotlib Figure

    ax : matplotlib Axes

    Raises
    ------
    TypeError
        If `figsize` or `ax` is of invalid type
    '''
    # Override if figsize is None
    if figsize is None:
        figsize = __DEFAULT_FIG_SIZE

    if ax is not None:
        if not isinstance(ax, mpl.axes.Axes):
            raise TypeError('parameter ax must be either None or a matplotlib.axes object, was: {}'.format(type(ax)))
        return ax.get_figure(), ax

    # No current axes, create a new Figure
    if isinstance(figsize, numbers.Real):
        # numbers.Real also covers numpy int/float scalars
        size = (figsize, figsize)
    elif isinstance(figsize, (tuple, list)):
        size = tuple(figsize)
    else:
        raise TypeError('parameter figsize must either be float or (float, float), was: {}'.format(type(figsize)))
    fig = plt.figure(figsize = size)
    # Add an axes spanning the entire Figure
    ax = fig.add_subplot(111)
    return fig, ax

def cm_as_list(cm, default_cm):
    """Normalises a color specification into a list

    Parameters
    ----------
    cm : None, matplotlib.colors.ListedColormap, list or other
        None returns `default_cm`; a ListedColormap returns its `colors` as a
        list; a list is returned as-is; any other value is wrapped in a
        one-element list
    default_cm : object
        Value returned unchanged when `cm` is None

    Returns
    -------
    list (or `default_cm` when `cm` is None)
    """
    if cm is None:
        return default_cm
    elif isinstance(cm, mpl.colors.ListedColormap):
        return list(cm.colors)
    elif isinstance(cm, list):
        return cm
    else:
        return [cm]

def _set_label_if_not_set(ax, text, x_axis=True):
    """Sets the x or y axis labels if not set before

    Checks if the axis label has been set previously,
    does not overwrite any existing label

    Parameters
    ----------
    ax : matplotlib Axes
        The Axes to write to, must not be None
    text : str or None
        Optional text to write to the label, function simply
        returns if None (or any non-str value) is given
    x_axis : bool
        If it the x axis (True) or y axis (False) that should be written
    """
    if not isinstance(text, str):
        return
    axis = ax.xaxis if x_axis else ax.yaxis
    if len(axis.get_label().get_text()) == 0:
        # Label not set before
        if x_axis:
            ax.set_xlabel(text)
        else:
            ax.set_ylabel(text)
| 108 | def _set_title(ax, title=None): 109 | """Sets the title if given and not previously set 110 | 111 | """ 112 | if title is None: 113 | return 114 | if len(ax.get_title()) == 0: 115 | if not isinstance(title,str): 116 | title = str(title) 117 | ax.set_title(title,fontdict={'fontsize':'x-large'}) 118 | 119 | def _set_chart_size(ax, 120 | x_vals, 121 | y_vals, 122 | padding = 0.025, 123 | flip_x=False, 124 | flip_y=False): 125 | """Sets the chart drawing limits 126 | 127 | Handles padding and finds the max and min values 128 | 129 | Parameters 130 | ---------- 131 | ax : matplotlib Axes 132 | 133 | x_vals, y_vals : array_like 134 | The values to find limits of, can optionally be calculated prior to this function and sent as a list of e.g. [min,max] to save computation time 135 | 136 | padding : float or (float,float), default = 0.025 137 | Padding as percentage of the value range, if a single value is given the same padding is applied to both axes. For two values, the first is applied to x-axes and the second to the y-axes. 
138 | 139 | flip_x : bool, default False 140 | If the x-axes should display significance level (`False`) or confidence (`True`) 141 | 142 | flip_y : bool, default False 143 | If the y-axes should display error-rate (`False`) or accuracy (`True`) 144 | """ 145 | x_min,x_max = np.min(x_vals), np.max(x_vals) 146 | y_min,y_max = np.min(y_vals), np.max(y_vals) 147 | if flip_x: 148 | x_min,x_max = 1-x_max, 1-x_min 149 | if flip_y: 150 | y_min,y_max = 1-y_max, 1-y_min 151 | x_w, y_w = x_max - x_min, y_max - y_min 152 | 153 | if padding is None: 154 | x_padd = 0 155 | y_padd = 0 156 | elif isinstance(padding,float): 157 | x_padd = x_w*padding 158 | y_padd = y_w*padding 159 | elif isinstance(padding,tuple) or isinstance(padding,list): 160 | if len(padding) == 1: 161 | x_padd = x_w*padding[0] 162 | y_padd = y_w*padding[0] 163 | elif len(padding) > 1: 164 | x_padd = x_w*padding[0] 165 | y_padd = y_w*padding[1] 166 | else: 167 | raise TypeError('padding should be a float or list/tuple of 2 floats') 168 | else: 169 | raise TypeError('padding should be a float or list/tuple of 2 floats, got {}'.format(type(padding))) 170 | 171 | ax.axis([x_min-x_padd, 172 | x_max+x_padd, 173 | y_min-y_padd, 174 | y_max+y_padd]) 175 | 176 | def _plot_vline(x,y_min,y_max,ax,color='gray',alpha=0.7,line_cap=None): 177 | # The vertical line itself 178 | ax.vlines(x = x, 179 | ymin = y_min, 180 | ymax = y_max, 181 | color = color, 182 | alpha = alpha) 183 | if (isinstance(line_cap,bool) and not line_cap) or line_cap is None: 184 | return 185 | elif (isinstance(line_cap,bool) and line_cap) or line_cap == 1: 186 | m_u = m_l = '_' 187 | elif line_cap == 2: 188 | m_u,m_l = 6,7 189 | elif isinstance(line_cap,str): 190 | m_u, m_l = line_cap,line_cap 191 | else: 192 | warn('Invalid argument for line_cap {}, falling back to not printing any'.format(line_cap)) 193 | # Upper 'cap' 194 | ax.plot(x, y_max, 195 | marker = m_u, 196 | lw = 0, 197 | alpha = alpha, 198 | color = color) 199 | # Lower 'cap' 200 | 
ax.plot(x, y_min, 201 | marker = m_l, 202 | lw = 0, 203 | alpha = alpha, 204 | color = color) -------------------------------------------------------------------------------- /python/src/pharmbio/cp/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import OneHotEncoder 2 | import numpy as np 3 | import pandas as pd 4 | 5 | __sklearn_1_2_0_or_later = False 6 | 7 | try: 8 | from packaging.version import Version, parse as parse_version 9 | import sklearn 10 | __sklearn_1_2_0_or_later = parse_version(sklearn.__version__)>= Version('1.2.0') 11 | except ImportError as e: 12 | pass 13 | 14 | 15 | def get_n_classes(y_true, p_vals): 16 | """Helper method for finding the maximum number of classes 17 | 18 | The number could either be the number of columns in the p-value matrix. 19 | Or the user could have only sent a slice of the p-values/added more labels 20 | in the `y_true` due to wanting to plot them in a different color. The value 21 | of `n_class` is the maximum number of these, so trying to access the `n_class'th - 1` 22 | column the p-value matrix might be out of range! 
23 | 24 | """ 25 | if y_true is None and p_vals is None: 26 | raise ValueError('Neither y_true nor p_values were given') 27 | elif y_true is None: 28 | if isinstance(p_vals, np.ndarray): 29 | return p_vals.shape[1] 30 | return to_numpy2D(p_vals, None).shape[1] 31 | elif p_vals is None: 32 | if isinstance(y_true, (list,np.ndarray,pd.Series)): 33 | return int(np.max(y_true)) 34 | return max(int(np.max(y_true)+1), p_vals.shape[1]) # +1 on max(y_true) as labels start at 0 35 | 36 | def get_str_labels(labels, n_class): 37 | """Helper method for turning numerical labels to str labels 38 | 39 | Parameters 40 | ---------- 41 | labels : list of str or None 42 | Labels given as parameter, or None if not given 43 | 44 | n_class : int 45 | The number of classes 46 | 47 | """ 48 | if labels is not None: 49 | if not isinstance(labels, (np.ndarray, list, pd.Series)): 50 | raise TypeError('parameter labels must be either a list or 1D numpy array') 51 | if len(labels) < n_class: 52 | raise TypeError('parameter labels and number of classes does not match') 53 | return np.array(labels).astype(str) 54 | else: 55 | # No labels, generate n_classes labels 56 | return ['Label {}'.format(i) for i in range(0,n_class)] 57 | 58 | def validate_sign(sign): 59 | """Validate that `sign` is within [0,1] or raise error 60 | 61 | Checks both the type and content are OK. 
If numpy.ndarray the array must be 1dim 62 | 63 | Parameters 64 | ---------- 65 | sign : int, float, numpy.ndarray, pandas.Series 66 | The significance level to check 67 | """ 68 | if isinstance(sign, np.ndarray) and sign.ndim == 0: 69 | # This is a single element, convert to float 70 | sign = float(sign) 71 | 72 | if isinstance(sign, (np.ndarray,pd.Series, list, tuple)): 73 | # Check that ndarray is 1dim 74 | if isinstance(sign, np.ndarray): 75 | # must be dim == 1 76 | if sign.ndim != 1 : 77 | raise ValueError('Significance levels must be given as a single value or an array / 1dim ndarray') 78 | # validate each value 79 | for s in sign: 80 | if s < 0 or s >1: 81 | raise ValueError('All significance levels must be in the range [0..1], got: {}'.format(s)) 82 | elif isinstance(sign, (int,float)): 83 | # I.e. a single value 84 | if sign < 0 or sign >1: 85 | # Single value but which is outside 86 | raise ValueError('parameter sign must be in the range [0,1]') 87 | else: 88 | raise TypeError('parameter sign must be a number or sequence of numbers') 89 | 90 | def to_numpy2D(input, param_name, min_num_cols=2, return_copy=True, unravel=False): 91 | """ Converts python list-based matrices and Pandas DataFrames into numpy 2D arrays 92 | 93 | If input is already a numpy array, it will be copied in case `return_copy` is True 94 | """ 95 | 96 | if input is None: 97 | raise ValueError('Input {} cannot be None'.format(param_name)) 98 | elif isinstance(input, list): 99 | # This should be a python list-matrix, convert to numpy matrix 100 | matrix = np.array(input) 101 | elif isinstance(input, pd.DataFrame): 102 | matrix = input.to_numpy() 103 | elif isinstance(input, np.ndarray): 104 | if input.ndim != 2: 105 | if input.ndim == 1 and unravel: 106 | # if we are allowed to unravel (i.e. 
add an additional dim to the ndarray) - we create one 107 | return input.reshape((len(input),1)) 108 | raise ValueError('parameter {} must be a 2D matrix, was a ndarray of shape {}'.format(param_name,input.shape)) 109 | matrix = input.copy() if return_copy else input 110 | else: 111 | raise TypeError('parameter {} in unsupported format: {}'.format(param_name,type(input))) 112 | # Validate at least min num columns present 113 | if len(matrix.shape) < 2 or matrix.shape[1] < min_num_cols: 114 | raise ValueError('parameter {} must be a matrix with at least {} columns'.format(param_name, min_num_cols)) 115 | return matrix 116 | 117 | def to_numpy1D(input,param_name,return_copy=True): 118 | """Convert lists and Panda Series to 1D numpy array. 119 | 120 | If input is already a numpy array, it is copied if `return_copy` is True 121 | """ 122 | if isinstance(input, (list, pd.Series)): 123 | return np.array(input) 124 | elif isinstance(input, np.ndarray): 125 | if len(input.shape) == 1: 126 | return input.copy() if return_copy else input 127 | elif input.shape[1]>1: 128 | raise ValueError('parameter {} must be a list, 1D numpy array or pandas Series'.format(param_name)) 129 | else: 130 | if return_copy: 131 | cpy = input.copy() 132 | cpy.shape = (len(cpy), ) 133 | return cpy 134 | else: 135 | input.shape = (len(input),) 136 | return input 137 | else: 138 | raise ValueError('parameter {} must be a list, 1D numpy array or pandas Series'.format(param_name)) 139 | 140 | def to_numpy1D_int(input, param_name): 141 | return to_numpy1D(input,param_name).astype(np.int16) 142 | 143 | def to_numpy1D_onehot(input, param_name, return_encoder=False, dtype=bool, labels=None): 144 | """ 145 | Returns 146 | ------- 147 | (matrix, array): 148 | matrix : numpy 2D of bool 149 | The one-hot-encoded version of y_true 150 | array : numpy 1D 151 | The categories, corresponding to the indices of `matrix` 152 | 153 | (matrix, array, sklearn.preprocessing.OneHotEncoder) 154 | When `return_encoder` is 
set to True. 155 | """ 156 | one_dim = to_numpy1D(input,param_name,return_copy=False).reshape(-1,1) 157 | if labels is None: 158 | labels = np.unique(one_dim) 159 | 160 | if __sklearn_1_2_0_or_later: 161 | enc = OneHotEncoder(sparse_output=False,dtype=dtype,categories=[labels]) 162 | else: 163 | enc = OneHotEncoder(sparse=False,dtype=dtype,categories=[labels]) 164 | one_hot = enc.fit_transform(one_dim) 165 | 166 | if return_encoder: 167 | return one_hot, enc.categories_[0], enc 168 | else: 169 | return one_hot, enc.categories_[0] 170 | 171 | 172 | def validate_regression_preds(input_matrix): 173 | """Checks if input is either 2D (one significance level) or 3D (multiple significance levels) 174 | 175 | The second dimension must always be 2 [lower, upper] of the prediction interval 176 | 177 | Returns 178 | ------- 179 | (n_significance_lvls, 2D/3D ndarray) 180 | """ 181 | if not isinstance(input_matrix, np.ndarray): 182 | raise ValueError('Regression predictions only supports numpy 2D or 3D arrays') 183 | 184 | if input_matrix.ndim == 2: 185 | # 2D matrix should be (N,2) shape 186 | if input_matrix.shape[1] != 2: 187 | raise ValueError('Regression predictions should be of the shape (N,2) or (N,2,S), where N is the number of predictions and S is the number of significance levels') 188 | return 1, input_matrix 189 | elif input_matrix.ndim == 3: 190 | # 3D matrix should be (N,2,S) shape 191 | if (input_matrix.shape[1]!= 2): 192 | raise ValueError('Regression predictions should be of the shape (N,2) or (N,2,S), where N is the number of predictions and S is the number of significance levels') 193 | if input_matrix.shape[2]==1: 194 | # "Fake 3D matrix" 195 | return 1, input_matrix[:,:,0] 196 | return input_matrix.shape[2], input_matrix 197 | else: 198 | raise ValueError('Regression predictions only supported as numpy 2D or 3D arrays') 199 | 200 | def to_numpy1D_reg_y_true(y_true,expected_len): 201 | arr = to_numpy1D(y_true,'y_true') 202 | if len(arr)!=expected_len: 203 | 
raise ValueError('Input predictions and true labels not of the same length: {} != {}'.format(len(arr),expected_len)) 204 | return arr -------------------------------------------------------------------------------- /python/tests/pharmbio/cpsign/load_cpsign_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import pytest 4 | 5 | from pharmbio.cp import plotting 6 | from pharmbio import cpsign 7 | from statistics import mean 8 | from ...help_utils import _save_clf, _save_reg, get_resource 9 | 10 | # file names 11 | clf_stats_incl_sd_file = 'cpsign_clf_stats_incl_sd.csv' 12 | clf_stats_excl_sd_file = 'cpsign_clf_stats_excl_sd.csv' 13 | clf_predictions_file = 'cpsign_clf_predictions.csv' 14 | reg_stats_incl_sd_file = 'cpsign_reg_stats_incl_sd.csv' 15 | reg_stats_excl_sd_file = 'cpsign_reg_stats_excl_sd.csv' 16 | reg_pred_incl_infs_file = 'cpsign_reg_predictions_10_incl_inf.csv' 17 | 18 | 19 | 20 | 21 | class TestClassification(): 22 | 23 | # update plot-settings 24 | plotting.update_plot_settings() 25 | 26 | @pytest.fixture 27 | def load_data(self): 28 | (self.signs, self.errs, self.errs_sd, self.labels) = cpsign.load_calib_stats(get_resource(clf_stats_incl_sd_file), sep='\t') 29 | 30 | 31 | def test_load_stats_file_calib(self,load_data): 32 | 33 | # plot and save it 34 | fig = plotting.plot_calibration(sign_vals=self.signs, error_rates=self.errs, error_rates_sd=self.errs_sd, labels=self.labels) 35 | fig.axes[0].set_title('from calculated values CPSign') 36 | _save_clf(fig,"CPSign_computed_clf.test_load_stats_file") 37 | 38 | # No labels and no error-SD 39 | fig_no_label = plotting.plot_calibration(sign_vals=self.signs, error_rates=self.errs) 40 | fig_no_label.axes[0].set_title('from CPSign - no labels') 41 | _save_clf(fig_no_label,"CPSign_computed_clf.test_load_stats_file_2") 42 | 43 | # Flip axes 44 | fig_flipped = plotting.plot_calibration(sign_vals=self.signs, error_rates=self.errs, 45 | 
error_rates_sd=self.errs_sd, labels=self.labels, 46 | flip_x=True,flip_y=True,title='precomputed from CPSign - flipped') 47 | _save_clf(fig_flipped,"CPSign_computed_clf.test_load_stats_file_3") 48 | 49 | 50 | def test_load_stats_file_calib_conf_acc(self, load_data): 51 | # convert to confidence and accuracy instead 52 | confs = 1- self.signs 53 | accs = 1 - self.errs 54 | accs_sd = self.errs_sd 55 | # plot and save it 56 | fig = plotting.plot_calibration(conf_vals=confs,accuracy_vals=accs, accuracy_sd=accs_sd, labels=self.labels) 57 | fig.axes[0].set_title('confs + accs from CPSign') 58 | _save_clf(fig,"CPSign_computed_clf.test_load_stats_file_conf_acc") 59 | 60 | fig_no_label = plotting.plot_calibration(conf_vals=confs,accuracy_vals=accs, accuracy_sd=accs_sd, sd_alpha=.1) 61 | fig_no_label.axes[0].set_title('from CPSign - no labels') 62 | _save_clf(fig_no_label,"CPSign_computed_clf.test_load_stats_file_conf_acc_2") 63 | 64 | 65 | def test_plot_single_calib_line(self, load_data): 66 | assert self.labels[0].lower() == 'overall' 67 | fig = plotting.plot_calibration(sign_vals=self.signs,error_rates=self.errs[:,0], error_rates_sd=self.errs_sd[:,0], labels=self.labels[0], title='cpsign only overall calib') 68 | _save_clf(fig,"TestCLF_CPSign.test_plot_single_calib_line") 69 | 70 | 71 | def test_load_stats_label_eff(self, load_data): 72 | (signs,single,multi,empty, prop_single_sd, prop_multi_sd, prop_empty_sd) = cpsign.load_clf_efficiency_stats(get_resource(clf_stats_incl_sd_file), sep='\t') 73 | assert len(signs) == len(single) 74 | assert len(signs) == len(multi) 75 | assert len(signs) == len(empty) 76 | assert len(signs) == len(prop_single_sd) 77 | assert len(signs) == len(prop_multi_sd) 78 | assert len(signs) == len(prop_empty_sd) 79 | # Explicitly turn of reading of SD values 80 | (signs2,single2,multi2,empty2) = cpsign.load_clf_efficiency_stats(get_resource(clf_stats_incl_sd_file), sep='\t', 81 | prop_e_sd_regex=None, prop_m_sd_regex=None, prop_s_sd_regex=None) 82 | 
# Output should be identical for both function calls 83 | assert np.array_equal(signs, signs2) 84 | assert np.array_equal(single, single2) 85 | assert np.array_equal(multi, multi2) 86 | assert np.array_equal(empty, empty2) 87 | fig = plotting.plot_label_distribution(prop_single=single, sign_vals=signs,prop_multi=multi, prop_empty=empty) 88 | _save_clf(fig, "TestCLF_CPSign.label_distr") 89 | 90 | def test_load_stats_eff_2(self): 91 | (signs,single,multi,empty) = cpsign.load_clf_efficiency_stats(get_resource(clf_stats_excl_sd_file), sep='\t') 92 | assert np.allclose(np.sort(signs), [0.1,0.3,0.5]) 93 | assert len(signs) == len(single) 94 | assert len(signs) == len(multi) 95 | assert len(signs) == len(empty) 96 | 97 | 98 | def test_load_preds(self): 99 | (ys, pvals, labels) = cpsign.load_clf_predictions(get_resource(clf_predictions_file),'target',';') 100 | fig = plotting.plot_label_distribution(y_true=ys,p_values= pvals) 101 | _save_clf(fig, "TestCLF_CPSign.load_clf_pred") 102 | 103 | class TestRegression(): 104 | 105 | def assert_label_output(self,labels): 106 | assert len(labels) == 1 107 | assert labels[0].lower() == 'overall' 108 | 109 | def test_load_reg_calib(self): 110 | (signs,errs,errs_sd,labels) = cpsign.load_calib_stats(get_resource(reg_stats_incl_sd_file), sep='\t') 111 | self.assert_label_output(labels) 112 | assert errs_sd is not None 113 | assert len(signs) == len(errs) 114 | assert len(signs) == len(errs_sd) 115 | fig = plotting.plot_calibration(sign_vals=signs,error_rates=errs, error_rates_sd=errs_sd, labels='Error rate', title='cpsign only overall calib') 116 | _save_reg(fig,"TestREG_CPSign.test_plot_calib") 117 | 118 | def test_load_reg_calib_2(self): 119 | (signs,errs,errs_sd,labels) = cpsign.load_calib_stats(get_resource(reg_stats_excl_sd_file), sep='\t') 120 | self.assert_label_output(labels) 121 | assert errs_sd is None 122 | assert len(signs) == len(errs) 123 | # print("cpsign-reg2: ",signs,errs,errs_sd,labels) 124 | fig = 
plotting.plot_calibration(sign_vals=signs,error_rates=errs, error_rates_sd=errs_sd, labels='Error rate', title='cpsign only overall calib') 125 | _save_reg(fig,"TestREG_CPSign.test_plot_calib_2") 126 | 127 | def test_load_reg_eff(self): 128 | (sign_vals, median_widths, mean_widths, median_widths_sd, mean_widths_sd) = cpsign.load_reg_efficiency_stats(get_resource(reg_stats_incl_sd_file), sep='\t') 129 | fig = plotting.plot_pred_widths(sign_vals,median_widths) 130 | _save_reg(fig, "TestREG_CPSign.test_plot_widths") 131 | # With std 132 | fig_std = plotting.plot_pred_widths(sign_vals,median_widths, median_widths_sd) 133 | _save_reg(fig_std, "TestREG_CPSign.test_plot_widths_std") 134 | assert median_widths_sd is not None 135 | assert mean_widths_sd is not None 136 | 137 | # check the skip_inf 138 | (sign_vals, median_widths, mean_widths) = cpsign.load_reg_efficiency_stats(get_resource(reg_stats_excl_sd_file), sep='\t', skip_inf=True) 139 | assert np.allclose(np.sort(sign_vals), [.2, 0.3]) 140 | assert len(median_widths) == 2 141 | assert len(mean_widths) == 2 142 | # Check the values (relies on order of sign_vals) 143 | assert np.isclose(4.9, mean_widths[0]) 144 | assert np.isclose(26.9, mean_widths[1]) 145 | 146 | (sign_vals_incl, median_widths_incl, mean_widths_incl) = cpsign.load_reg_efficiency_stats(get_resource(reg_stats_excl_sd_file), sep='\t', skip_inf=False) 147 | assert np.allclose(np.sort(sign_vals_incl), [0, 0.05, 0.1, 0.2, 0.3]) 148 | assert len(median_widths_incl) == 5 and np.any(np.isposinf(median_widths_incl)) 149 | assert len(mean_widths_incl) == 5 and np.any(np.isposinf(mean_widths_incl)) 150 | # Check the values (relies on order of sign_vals_incl) 151 | assert np.isclose(4.49, median_widths_incl[0]) 152 | assert np.isclose(27.2, median_widths_incl[1]) 153 | 154 | 155 | def test_load_cpsign_preds(self): 156 | # Load predictions and exclude inf values 157 | (y, pred_matrix, sign_values) = 
cpsign.load_reg_predictions(get_resource(reg_pred_incl_infs_file),y_true_col='solubility', sep='\t', skip_inf=True) 158 | # Load predictions and include inf values 159 | (y_incl, pred_matrix_incl, sign_values_incl) = cpsign.load_reg_predictions(get_resource(reg_pred_incl_infs_file),y_true_col='solubility', sep='\t', skip_inf=False) 160 | 161 | assert np.all(y == y_incl) 162 | # Should be as many predictions (i.e. two first dimensions) 163 | assert len(pred_matrix) == len(pred_matrix_incl) 164 | assert pred_matrix.shape[1] == 2 and pred_matrix_incl.shape[1] == 2 165 | # The one with including inf values should have some inf values 166 | assert np.any(np.isinf(pred_matrix_incl)) 167 | assert not np.any(np.isinf(pred_matrix)) 168 | assert len(sign_values) == pred_matrix.shape[2] 169 | assert len(sign_values_incl) == pred_matrix_incl.shape[2] 170 | -------------------------------------------------------------------------------- /python/src/pharmbio/cp/plotting/_regression.py: -------------------------------------------------------------------------------- 1 | """CP Regression plots 2 | """ 3 | import matplotlib as mpl 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from sklearn.utils import check_consistent_length 7 | 8 | from pharmbio.cp.utils import to_numpy1D, to_numpy2D, validate_sign 9 | 10 | # The following import sets seaborn etc if available 11 | from ._utils import get_fig_and_axis, _set_chart_size, _plot_vline,_set_label_if_not_set, _set_title 12 | from ._common import add_calib_curve 13 | 14 | 15 | def plot_pred_widths(sign_vals, 16 | pred_widths, 17 | pred_widths_sd = None, 18 | color = 'blue', 19 | sd_alpha = 0.3, 20 | flip_x = False, 21 | ax = None, 22 | figsize = (10,8), 23 | chart_padding = 0.025, 24 | title = None, 25 | y_label = 'Median Prediction interval width', 26 | tight_layout = True, 27 | **kwargs): 28 | """**Regression** - Plot prediction widths at different significance levels 29 | 30 | Parameters 31 | ---------- 32 | sign_vals 
: array like 33 | List of significance levels for each of the `pred_widths` 34 | 35 | pred_widths : array like 36 | List or 1D array of prediction widths, typically generated from `metrics.pred_width` 37 | 38 | pred_widths_sd : array like, optional 39 | List or 1D array of standard deviations for the prediction widths, to depict the width+/-std area 40 | 41 | color : str or matplotlib recognized color-input 42 | Color of the plotted curve 43 | 44 | flip_x : bool, default False 45 | If the x-axes should display significance level (`False`) or confidence (`True`) 46 | 47 | ax : matplotlib Axes 48 | Axes to plot in 49 | 50 | figsize : float or (float, float), optional 51 | Figure size to generate, ignored if `ax` is given 52 | 53 | chart_padding : float, (float,float) or None 54 | padding added to the chart-area outside of the min and max values found in data. If two values the first value will be used as x-padding and second y-padding. E.g. 0.025 means 2.5% on both sides. None means no padding at all 55 | 56 | title : str, optional 57 | Optional title that will be printed in 'x-large' font size (default None) 58 | 59 | y_label : str or None 60 | Label for the y-axis, default is 'Median Prediction interval width' 61 | 62 | tight_layout : bool, optional 63 | Set `tight_layout` on the matplotlib Figure object 64 | 65 | **kwargs : dict, optional 66 | Keyword arguments, passed to matplotlib 67 | 68 | Returns 69 | ------- 70 | Figure 71 | matplotlib.figure.Figure object 72 | 73 | See Also 74 | -------- 75 | metrics.pred_width 76 | Calculates median or mean prediction interval widths 77 | """ 78 | 79 | check_consistent_length(sign_vals, pred_widths) 80 | if pred_widths_sd is not None: 81 | check_consistent_length(sign_vals, pred_widths_sd) 82 | 83 | validate_sign(sign_vals) 84 | fig, ax = get_fig_and_axis(ax, figsize) 85 | 86 | if flip_x: 87 | xs = 1 - sign_vals if isinstance(sign_vals,np.ndarray) else 1 - np.array(sign_vals) 88 | x_label = 'Confidence' 89 | else: 90 | xs = 
sign_vals 91 | x_label = 'Significance' 92 | 93 | # Set chart range 94 | _set_chart_size(ax, 95 | xs, 96 | pred_widths, 97 | chart_padding) 98 | 99 | if pred_widths_sd is not None: 100 | ax.fill_between(xs, pred_widths-pred_widths_sd, pred_widths+pred_widths_sd, interpolate=True, alpha = sd_alpha) 101 | ax.plot(xs, pred_widths, color=color,label=y_label, **kwargs) 102 | 103 | # Print some labels and title if appropriate 104 | _set_label_if_not_set(ax,x_label,True) 105 | _set_label_if_not_set(ax,y_label,False) 106 | _set_title(ax,title) 107 | 108 | if tight_layout: 109 | fig.tight_layout() 110 | 111 | return fig 112 | 113 | 114 | 115 | def plot_pred_intervals(y_true, 116 | predictions, 117 | ax = None, 118 | figsize = (10,8), 119 | chart_padding = 0.025, 120 | 121 | correct_color = 'blue', 122 | correct_marker = 'o', 123 | correct_alpha = 0.75, 124 | correct_ci = 'gray', 125 | correct_ci_alpha = 0.7, 126 | correct_label = 'Correct', 127 | 128 | incorrect_color = 'red', 129 | incorrect_marker = 'o', 130 | incorrect_alpha = 0.75, 131 | incorrect_ci ='gray', 132 | incorrect_ci_alpha = 0.7, 133 | incorrect_label = 'Incorrect', 134 | 135 | line_cap = 1, 136 | 137 | title = None, 138 | x_label = 'Predicted examples', 139 | y_label = None, 140 | 141 | x_start_index = 0, 142 | tight_layout = True, 143 | **kwargs): 144 | """**Regression** - Plot predictions and their confidence intervals 145 | 146 | Sorts the predictions after the size of the `y_true` values and plots both the true labels and the prediction/confidence intervals (CI) for each prediction. Erronious and correctly predicted examples can be discerned by using different colors and markers for the CI and/or the true value-points. 
147 | 148 | Parameters 149 | ---------- 150 | y_true : 1D numpy array, list or pandas Series 151 | True labels 152 | 153 | predictions : 2D ndarray 154 | 2D array of shape (n_samples, 2) where the second dimension should have min interval limit at index 0 and max interval limit at index 1 155 | 156 | ax : matplotlib Axes, optional 157 | An existing matplotlib Axes to plot in (default None) 158 | 159 | figsize : float or (float, float), optional 160 | Figure size to generate, ignored if `ax` is given 161 | 162 | chart_padding : float, (float,float) or None, default 0.025 163 | padding added to the chart-area outside of the min and max values found in data. If two values the first value will be used as x-padding and second y-padding. E.g. 0.025 means 2.5% on both sides 164 | 165 | correct_color,incorrect_color : str of matplotlib recognized color-input 166 | Color of the points for the true values 167 | 168 | correct_marker, incorrect_marker : str of matplotlib recognized marker 169 | The shape of the true values 170 | 171 | correct_alpha, incorrect_alpha : float, default 0.75 172 | The alpha (transparency) of the true values 173 | 174 | correct_ci, incorrect_ci : str of matplotlib recognized color-input 175 | Color of the confidence/prediction intervals 176 | 177 | correct_ci_alpha, incorrect_ci_alpha : float 178 | The alpha (transparency) of the confidence/prediction intervals 179 | 180 | correct_label,incorrect_label : str 181 | The label, if any, that should be added to the correct/incorrect true examples, which will end up in the plot if adding a legend in the figure 182 | 183 | line_cap : {None, 1, 2 or str}, default 1 184 | The end of the confidence/prediction intervals, 1:'_', 2:6 / 7 out of the accepted markers list: https://matplotlib.org/stable/api/markers_api.html 185 | 186 | title : str, default None 187 | A title to add to the figure (default None) 188 | 189 | x_label, y_label : str, optional 190 | label for the x/y-axis, default is None for the y-axis 
and 'Predicted examples' on the x-axis 191 | 192 | x_start_ind : int, default 0 193 | The starting index on the x-axis 194 | 195 | Returns 196 | ------- 197 | Figure 198 | matplotlib.figure.Figure object 199 | 200 | See Also 201 | -------- 202 | """ 203 | 204 | check_consistent_length((y_true, predictions)) 205 | ys = to_numpy1D(y_true, "y_true",return_copy=True) 206 | preds = to_numpy2D(predictions,"predictions") 207 | 208 | fig, ax = get_fig_and_axis(ax, figsize) 209 | 210 | # sorted by the true labels 211 | sort_order = ys.argsort() 212 | ys = ys[sort_order] 213 | preds = preds[sort_order] 214 | xs = np.arange(x_start_index,x_start_index+len(ys),1) 215 | 216 | # find the correct and incorrect predictions 217 | corr_ind = (preds[:,0] <= ys) & (ys<= preds[:,1]) 218 | incorr_ind = ~corr_ind 219 | 220 | # Set the chart size 221 | _set_chart_size(ax, 222 | [x_start_index,len(y_true)+x_start_index], 223 | [np.max(predictions), np.max(y_true), np.min(predictions), np.min(y_true)], 224 | chart_padding) 225 | 226 | # VERTICAL INTERVALS 227 | # plot corrects 228 | _plot_vline(x = xs[corr_ind], 229 | y_min = preds[corr_ind,0], 230 | y_max = preds[corr_ind,1], 231 | ax = ax, 232 | color = correct_ci, 233 | alpha = correct_ci_alpha, 234 | line_cap=line_cap) 235 | 236 | # plot incorrect intervals 237 | _plot_vline(x = xs[incorr_ind], 238 | y_min = preds[incorr_ind,0], 239 | y_max = preds[incorr_ind,1], 240 | ax = ax, 241 | color = incorrect_ci, 242 | alpha = incorrect_ci_alpha, 243 | line_cap=line_cap) 244 | 245 | # plot the true values 246 | # corrects 247 | ax.scatter(xs[corr_ind], 248 | ys[corr_ind], 249 | label = correct_label, 250 | marker = correct_marker, 251 | alpha = correct_alpha, 252 | color = correct_color, 253 | zorder = 2.5) 254 | # incorrects 255 | ax.scatter(xs[incorr_ind], 256 | ys[incorr_ind], 257 | label = incorrect_label, 258 | marker = incorrect_marker, 259 | alpha = incorrect_alpha, 260 | color = incorrect_color, 261 | zorder = 2.5) 262 | 263 | # Print 
some labels and title if appropriate 264 | _set_label_if_not_set(ax, y_label, x_axis=False) 265 | _set_label_if_not_set(ax, x_label, x_axis=True) 266 | _set_title(ax,title) 267 | 268 | if tight_layout: 269 | fig.tight_layout() 270 | 271 | return fig 272 | -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/plotting/clf_plotting_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import matplotlib.pyplot as plt 4 | 5 | from pharmbio.cp import metrics,plotting 6 | from ....help_utils import _save_clf, get_resource 7 | 8 | # Some testing data - 2 class 9 | my_data = np.genfromtxt(get_resource('transporters.p-values.csv'), delimiter=';', skip_header=1) 10 | true_labels_2_class = (my_data[:,1] == 1).astype(np.int16) 11 | p_vals_2_class = my_data[:,[2,3]] 12 | cm_2_class_015 = metrics.confusion_matrix( true_labels_2_class, p_vals_2_class, sign=0.15 ) 13 | cm_2_class_015_normalized = metrics.confusion_matrix( true_labels_2_class, p_vals_2_class, sign=0.15, normalize_per_class=True) 14 | cm_2_class_075 = metrics.confusion_matrix( true_labels_2_class, p_vals_2_class, sign=0.75 ) 15 | 16 | # 3 class 17 | data3class = np.genfromtxt(get_resource('multiclass.csv'), delimiter=',', skip_header=0) 18 | true_labels_3_class = data3class[:,0].astype(np.int16) 19 | p_vals_3_class = data3class[:,1:] 20 | cm_3_class_015 = metrics.confusion_matrix( true_labels_3_class, p_vals_3_class, sign=0.15 ) 21 | 22 | # hER predictions 23 | er_data = np.genfromtxt(get_resource('er.p-values.csv'), delimiter=',', skip_header=1) 24 | er_labels = er_data[:,0].astype(np.int16) 25 | er_pvals = er_data[:,1:] 26 | 27 | 28 | class TestPValuesPlot(): 29 | def test_2_class(self): 30 | fig = plotting.plot_pvalues(true_labels_2_class, p_values=p_vals_2_class) 31 | fig.axes[0].set_title('p0/p1 2-class') 32 | _save_clf(fig,"TestPValuesPlot.test_2_class") 33 | 34 | def 
test_3_class_01(self): 35 | fig = plotting.plot_pvalues(true_labels_3_class, p_values=p_vals_3_class,split_chart=False) 36 | fig.axes[0].set_title('p0/p1 3-class') 37 | _save_clf(fig,"TestPValuesPlot.test_3_class01") 38 | 39 | def test_3_class_21(self): 40 | fig = plotting.plot_pvalues(true_labels_3_class, p_values=p_vals_3_class, cols=[2,1]) 41 | fig.axes[0].set_title('p2/p1 3-class') 42 | _save_clf(fig,"TestPValuesPlot.test_3_class21") 43 | 44 | def test_3_class_only_send21(self): 45 | excl_filter = true_labels_3_class == 0 46 | 47 | fig = plotting.plot_pvalues(true_labels_3_class[~excl_filter], p_values=p_vals_3_class[~excl_filter], cols=[2,1]) 48 | fig.axes[0].set_title('p2/p1 3-class only single class 1 and 2') 49 | _save_clf(fig,"TestPValuesPlot.test_3_class21_rm0") 50 | 51 | def test_3_class_only_send_2pvals(self): 52 | fig = plotting.plot_pvalues(true_labels_3_class, p_values=p_vals_3_class[:,[0,1]],cm=['r','b','k'],title='p0/p1 3-class (2-vals sent)') 53 | fig.axes[0].set_title('p0/p1 3-class (2-vals sent)') 54 | _save_clf(fig,"TestPValuesPlot.test_3_class_only_send_2pvals") 55 | 56 | def test_cols_outside_range(self): 57 | with pytest.raises(ValueError): 58 | plotting.plot_pvalues(true_labels_2_class, p_values=p_vals_2_class, cols=[2,1]) 59 | 60 | def test_her(self): 61 | import matplotlib as mpl 62 | from matplotlib.markers import MarkerStyle 63 | import seaborn as sns 64 | sns.set_style('ticks') 65 | 66 | non_filled_o = MarkerStyle(marker='o', fillstyle='none') 67 | 68 | markers_ls = [None,'*',['*','o'],[non_filled_o,'*']] 69 | colors = [None, ['r','b'],'y'] 70 | 71 | # FREQ order 72 | # print("FREQ ORDER") 73 | freq_fig, axes = plt.subplots(3,4,figsize = (5*4,5*3)) 74 | for row, c in enumerate(colors): 75 | for col, m in enumerate(markers_ls): 76 | plotting.plot_pvalues(er_labels, er_pvals, 77 | ax=axes[row,col], 78 | order=None, 79 | cm = c, 80 | alphas= [.9,.5], 81 | markers=m, 82 | linewidths=.5, 83 | fontargs={'fontsize':'medium'}) 84 | 
freq_fig.tight_layout() 85 | _save_clf(freq_fig,"TestPValuesPlot.hER_all_freq") 86 | 87 | # print("CLS ORDER") 88 | label_fig, axes = plt.subplots(3,4,figsize = (5*4,5*3)) 89 | # class order 90 | for row, c in enumerate(colors): 91 | for col, m in enumerate(markers_ls): 92 | plotting.plot_pvalues(er_labels, er_pvals, 93 | ax=axes[row,col], 94 | order='class', 95 | alpha=0.8, # Same alpha for both 96 | # alphas = [.8,.75], # default 97 | cm = c, 98 | markers=m, 99 | sizes = mpl.rcParams['lines.markersize']**2, 100 | lw=1.5, 101 | fontargs={'fontsize':'large'}) 102 | label_fig.tight_layout() 103 | _save_clf(label_fig,"TestPValuesPlot.hER_all_class") 104 | 105 | # print("REV ORDER") 106 | # reverse class order 107 | rev_label_fig, axes = plt.subplots(3,4,figsize = (5*4,5*3)) 108 | for row, c in enumerate(colors): 109 | for col, m in enumerate(markers_ls): 110 | plotting.plot_pvalues(er_labels, er_pvals, 111 | ax=axes[row,col], 112 | order='rev class', 113 | cm = c, 114 | alphas= [.3,.75], 115 | markers=m,fontargs={'fontsize':'x-large'}) 116 | rev_label_fig.tight_layout() 117 | _save_clf(rev_label_fig,"TestPValuesPlot.hER_all_rev_class") 118 | 119 | 120 | class TestLabelDistributionPlot(): 121 | 122 | def test_2_class(self): 123 | fig1 = plotting.plot_label_distribution(y_true = true_labels_2_class, p_values=p_vals_2_class) 124 | fig1.axes[0].set_title('LabelDistribution 2-class') 125 | _save_clf(fig1,"TestLabelDistPlot.test_2_class") 126 | 127 | def test_3_class(self): 128 | fig = plotting.plot_label_distribution(y_true = true_labels_3_class, p_values=p_vals_3_class) 129 | fig.axes[0].set_title('LabelDistribution 3-class') 130 | _save_clf(fig,"TestLabelDistPlot.test_3_class") 131 | 132 | class TestCalibrationPlot(): 133 | def test_2_class(self): 134 | fig1 = plotting.plot_calibration_clf(true_labels_2_class, p_vals_2_class) 135 | fig1.axes[0].set_title('Calib plot 2-class') 136 | fig2 = plotting.plot_calibration_clf(true_labels_2_class, p_vals_2_class, labels = 
['class 0', 'class 1']) 137 | fig2.axes[0].set_title('Calib plot 2-class with labels') 138 | _save_clf(fig2,"TestCalibPlot.test_2_class") 139 | 140 | def test_3_class(self): 141 | fig = plotting.plot_calibration_clf(true_labels_3_class, p_vals_3_class, labels = ['A', 'B', 'C']) 142 | fig.axes[0].set_title('Calib plot 3-class, labels={A,B,C}') 143 | _save_clf(fig,"TestCalibPlot.test_3_class") 144 | 145 | def test_3_class_conf_acc(self): 146 | # Plot all in one image 147 | fig, axes = plt.subplots(2,2,figsize=(10,10)) 148 | # std 149 | plotting.plot_calibration_clf(true_labels_3_class, p_values=p_vals_3_class, labels = ['A', 'B', 'C'], ax=axes[0,0], title='std') 150 | # flip x 151 | plotting.plot_calibration_clf(true_labels_3_class, p_values=p_vals_3_class, labels = ['A', 'B', 'C'], ax=axes[0,1], flip_x=True, title='flip x') 152 | # flip y 153 | plotting.plot_calibration_clf(true_labels_3_class, p_values=p_vals_3_class, labels = ['A', 'B', 'C'], ax=axes[1,0], flip_y=True, title='flip y') 154 | # flip both 155 | plotting.plot_calibration_clf(true_labels_3_class, p_values=p_vals_3_class, labels = ['A', 'B', 'C'], ax=axes[1,1], flip_x=True, flip_y=True, title='both') 156 | _save_clf(fig,"TestCalibPlot.test_3_class_flip") 157 | 158 | 159 | class TestBubblePlot(): 160 | 161 | def test_3_class(self): 162 | fig1 = plotting.plot_confusion_matrix_bubbles(cm_3_class_015,color_scheme=None) 163 | fig1.axes[0].set_title('Bubbles 3-class 0.15') 164 | _save_clf(fig1,"TestBubblebPlot.test_3_class") 165 | 166 | def test_2_class_percentage(self): 167 | fig2 = plotting.plot_confusion_matrix_bubbles(cm_2_class_015_normalized, annotate=True, annotate_as_percentage=True, figsize=(6,7)) 168 | fig2.axes[0].set_title('Bubbles 2-class 0.15 - percentage - scale 5.5') 169 | _save_clf(fig2,"TestBubblebPlot.test_2_class_1_percentage") 170 | 171 | # Test without normalized CM 172 | with pytest.raises(ValueError): 173 | _ = plotting.plot_confusion_matrix_bubbles(cm_2_class_015, annotate=True, 
annotate_as_percentage=True, figsize=(6,7)) 174 | 175 | def test_2_class(self): 176 | fig2 = plotting.plot_confusion_matrix_bubbles(cm_2_class_015, annotate=False, scale_factor=5.5, figsize=(6,7)) 177 | fig2.axes[0].set_title('Bubbles 2-class 0.15 - no annotation - scale 5.5') 178 | _save_clf(fig2,"TestBubblebPlot.test_2_class_1") 179 | 180 | fig3 = plotting.plot_confusion_matrix_bubbles(cm_2_class_075) 181 | fig3.axes[0].set_title('Bubbles 2-class 0.75') 182 | _save_clf(fig3,"TestBubblebPlot.test_2_class_2") 183 | 184 | def test_illegal_color_scheme(self): 185 | with pytest.warns(UserWarning): 186 | fig_ = plotting.plot_confusion_matrix_bubbles(cm_2_class_015, color_scheme='bad_arg', annotate=False, scale_factor=5.5, figsize=(6,7)) 187 | 188 | class TestConfusionMatrixHeatmap(): 189 | 190 | def test_3_class(self): 191 | fig1 = plotting.plot_confusion_matrix_heatmap(cm_3_class_015) 192 | fig1.axes[0].set_title('Heatmap 3-class 0.15') 193 | _save_clf(fig1,"TestConfMatrixHeatMap.test_3_class") 194 | 195 | def test_2_class(self): 196 | fig2 = plotting.plot_confusion_matrix_heatmap(cm_2_class_015, cmap="YlGnBu") 197 | fig2.axes[0].set_title('Heatmap 2-class 0.15 (YllGnBu colormap)') 198 | _save_clf(fig2,"TestConfMatrixHeatMap.test_2_class_1") 199 | fig3 = plotting.plot_confusion_matrix_heatmap(cm_2_class_075) 200 | fig3.axes[0].set_title('Heatmap 2-class 0.75') 201 | _save_clf(fig3,"TestConfMatrixHeatMap.test_2_class_2") 202 | 203 | class FinalTest(): 204 | 205 | def display_plots(self): 206 | plt.show() 207 | 208 | -------------------------------------------------------------------------------- /python/tests/resources/transporters.p-values.csv: -------------------------------------------------------------------------------- 1 | name;target;p-value (class -1);p-value (class 1) s=Cramer_2007__n=1b;-1;0.932813525;0.001769804 s=Feng_2009__n=103612219;1;0.003274166;0.770759509 s=Marighetti_2013__n=4;1;0.013958888;0.679011672 
s=Matsson_2009__n=Tetracycline;-1;0.924027199;0.002966592 s=Patel_2011__n=5;1;0.201022127;0.262879965 cmpd_86;1;0.001956786;0.767231702 s=Boumendjel_2005__n=103513516;-1;0.919921501;0.00113258 s=Ivnitski-Steele_2008__n=14741197;-1;0.80137057;0.024913439 s=Ivnitski-Steele_2010__n=92764892;1;0.007062843;0.966611259 s=Matsson_2007__n=Ketoconazole;1;0.230752181;0.227916077 s=Ivnitski-Steele_2010__n=96022052;1;0.004578412;0.968710764 s=acridones_Boumendjel_2007__n=1d;-1;0.044183517;0.595940981 cmpd_92;-1;0.143142093;0.288384415 s=Loevezijn_2001__n=B1;-1;0.384835685;0.191695442 cmpd_121;1;0.073748827;0.472165758 s=Patel_2011__n=17;-1;0.801885869;0.026671154 cmpd_184;1;0.006685016;0.840032875 s=Matsson_2007__n=Sulfinpyrazone;-1;0.564692637;0.085891033 s=Ahmed-Belkacem_2007__n=103510952;1;0.183255546;0.274586019 cmpd_29;1;0.04892914;0.536122222 s=Saito_2006__n=acetylsalicylic_acid;-1;0.948662571;0.004042035 s=Ivnitski-Steele_2008__n=22403643;-1;0.473267846;0.137467786 s=Ivnitski-Steele_2008__n=22406563;-1;0.417097958;0.183401522 s=Marighetti_2013__n=22;1;0.14803234;0.284981491 s=Pick_2010__n=103720970;1;0.007135356;0.769870906 cmpd_81;-1;0.119375479;0.393580204 s=Marighetti_2013__n=5;-1;0.041378474;0.609280235 s=Loevezijn_2001__n=B2;-1;0.299202878;0.226029135 cmpd_144;-1;0.483513885;0.122007782 s=Matsson_2009__n=Chloroquine;-1;0.49395026;0.114129711 cmpd_34;1;0.005886064;0.810672124 s=Feng_2008__n=103646945;1;0.006732286;0.787776563 cmpd_109;1;0.001195506;0.833281963 cmpd_141;-1;0.047641278;0.534795994 s=Ivnitski-Steele_2008__n=844597;-1;0.657618059;0.051218606 s=Weiss_2007__n=Atazanavir;-1;0.149902388;0.312446901 cmpd_181;-1;0.011387457;0.666118775 s=Ivnitski-Steele_2008__n=22403400;-1;0.359985818;0.205065334 s=Ivnitski-Steele_2008__n=24817134;1;0.74776768;0.031178886 s=Ivnitski-Steele_2008__n=22402816;-1;0.253416596;0.225831351 s=Colabufo_2008__n=103578071;-1;0.20909241;0.230101003 s=Matsson_2007__n=Dipyridamole;1;0.53865369;0.105530757 
s=Cramer_2007__n=2c;-1;0.14598154;0.291794999 cmpd_113;1;0.000409625;0.969941797 s=Saito_2006__n=Glycine;-1;0.976897905;0.002441175 s=Ivnitski-Steele_2008__n=17401132;-1;0.065973131;0.472574675 s=phenylquinazolines_Juvale_2012__n=2;1;0.006914999;0.706425155 s=Ivnitski-Steele_2008__n=24797597;-1;0.548175722;0.102448981 s=Weiss_2007__n=Tenofovir;-1;0.931829055;0.00586298 cmpd_99;1;0.003889876;0.836457619 s=Matsson_2009__n=Taurolithocholic_acid;-1;0.213377706;0.228525223 s=Jin_2006__n=Ginsenoside_Rh2;1;0.185948651;0.28077806 s=Ivnitski-Steele_2010__n=93619256;1;0.005318768;0.96029725 s=Ivnitski-Steele_2010__n=96022047;1;0.002327457;0.769598158 s=Matsson_2007__n=Folic_acid;-1;0.420302633;0.167428871 cmpd_42;1;0.00173474;0.726949803 s=Sugimoto_2003__n=TAG-3;-1;0.181894216;0.262375001 s=Matsson_2007__n=Chenodeoxycholic_acid;-1;0.463814855;0.152097892 s=Ivnitski-Steele_2008__n=22407563;-1;0.332899617;0.211780505 s=Matsson_2007__n=Salicylic_acid;-1;0.963611216;0.001891898 cmpd_82;1;0.040877303;0.615887326 s=Saito_2006__n=hematoporphyrin;1;0.051830554;0.536852535 s=Ivnitski-Steele_2008__n=16953419;-1;0.766622057;0.027581134 s=Juvale_2012__n=103566963;1;0.002737976;0.845463629 s=Matsson_2007__n=Zidovudine;-1;0.866520476;0.012728905 cmpd_2;1;0.570848116;0.095781726 s=Ivnitski-Steele_2008__n=17412701;-1;0.519219188;0.115053187 cmpd_31;-1;0.585628703;0.071204787 s=Ivnitski-Steele_2008__n=7977793;-1;0.679134808;0.051217426 s=Ivnitski-Steele_2008__n=22403593;-1;0.669138879;0.052118405 s=Matsson_2007__n=Neomycin_sulfate;-1;0.745936181;0.032489536 s=Ivnitski-Steele_2008__n=24837996;-1;0.768021418;0.020180326 s=Matsson_2007__n=Methotrexate;-1;0.496012931;0.122519985 s=Juvale_2012__n=103323399;-1;0.743794822;0.033495725 s=Juvale_2012__n=103567050;1;0.14790034;0.282777564 s=Matsson_2007__n=Tinidazole;-1;0.916997893;0.002704392 s=Ivnitski-Steele_2008__n=24826880;-1;0.47658495;0.116760994 s=Loevezijn_2001__n=C1;-1;0.15180502;0.276857446 
s=Saito_2006__n=cortisone;-1;0.833350518;0.02110262 s=Ivnitski-Steele_2010__n=93619262;1;0.096394943;0.403408323 cmpd_105;-1;0.001328766;0.76193346 s=Juvale_2012__n=103206494;-1;0.643926628;0.059435498 s=flavonoids_Zhang_2005__n=6,2',3'-7-Hydroxyflavanone;1;0.293255238;0.226199887 s=Boumendjel_2005__n=103513508;-1;0.667180584;0.049436423 cmpd_10;1;0.014222697;0.654097376 s=Kuhnle_2009__n=103591272;1;0.027792779;0.637360101 s=Juvale_2012__n=136929453;-1;0.408486331;0.185488756 s=Ivnitski-Steele_2008__n=852968;-1;0.560437393;0.091096121 s=Ivnitski-Steele_2008__n=17512364;-1;0.285014128;0.225456174 s=Marighetti_2013__n=19;1;0.040175248;0.610524843 s=Pick_2008__n=103569129;1;0.180872213;0.271645975 s=Ivnitski-Steele_2010__n=93619254;1;0.003422729;0.807595632 cmpd_50;-1;0.073377633;0.442244087 s=Cramer_2007__n=4;-1;0.641954325;0.058926528 s=Ivnitski-Steele_2008__n=14743034;-1;0.830819598;0.020460879 cmpd_156;-1;0.419357437;0.187710939 s=Wang_2008__n=Olanzapine;-1;0.72104682;0.044061217 s=Ivnitski-Steele_2008__n=17504108;-1;0.13552658;0.347442771 s=Matsson_2007__n=Phenytoin;-1;0.600661495;0.077093697 s=Ochoa-Puentes_2011__n=131287273;1;0.05220358;0.533382602 s=Sugimoto_2003__n=TAG-11;1;0.056886073;0.5353999 s=Ivnitski-Steele_2010__n=87550714;1;0.000962545;0.972253702 s=Saito_2006__n=melatonin;-1;0.943934607;0.00628152 s=Imai_2004__n=Kaempferide;1;0.003601839;0.699825831 s=Ivnitski-Steele_2008__n=17504141;-1;0.214662592;0.231835087 s=Matsson_2007__n=Maprotiline;-1;0.15011401;0.289677259 s=Ivnitski-Steele_2010__n=88095709;1;0.000640263;0.967372063 s=Jin_2006__n=Ginsenoside_Rg3;-1;0.644566822;0.064307676 s=Ivnitski-Steele_2010__n=93619268;1;3.99E-05;0.79966521 s=Juvale_2012__n=136939285;1;0.00616336;0.858734883 s=Matsson_2007__n=Digoxin;-1;0.523545727;0.113885099 s=Matsson_2009__n=Bromosulfalein;-1;0.336851277;0.209172237 s=Feng_2009__n=103453032;1;0.000333607;0.758069409 cmpd_157;-1;0.475926911;0.127284992 cmpd_62;-1;0.795732202;0.020479168 
s=Matsson_2007__n=Carbamazepine;-1;0.3592486;0.189069823 s=Imai_2004__n=Luteolin-4'-beta-D-glucoside;1;0.829937726;0.020374387 s=Feng_2009__n=103220117;1;0.116431978;0.377932408 s=Matsson_2007__n=Hydralazine;-1;0.933409102;0.001120833 cmpd_191;-1;0.082530994;0.426719758 s=Versiani_2011__n=124968332;1;0.129736243;0.366192851 s=Pan_2013__n=Fosinopril;1;0.181439142;0.264104966 s=Juvale_2012__n=103449034;1;0.148517482;0.316633732 cmpd_133;1;0.041322963;0.606339268 cmpd_20;-1;0.529126076;0.112510597 cmpd_117;1;0.00324315;0.76834121 cmpd_85;1;0.043494584;0.593973792 s=Ivnitski-Steele_2010__n=90944694;1;0.006699916;0.769166789 cmpd_56;1;0.001300806;0.830839697 s=Ahmed-Belkacem_2007__n=103510963;-1;0.077076491;0.420495625 s=Ivnitski-Steele_2010__n=99361158;1;0.002721518;0.971101179 s=Matsson_2007__n=Carisoprodol;-1;0.905882236;0.000994242 s=Loevezijn_2001__n=F2;-1;0.652929745;0.052153513 cmpd_185;1;0.005106766;0.744737035 s=flavonoids_Zhang_2004__n=Silymarin;1;0.466046059;0.123099436 s=Marighetti_2013__n=7;1;0.071719116;0.436964602 cmpd_91;1;0.046905816;0.570666603 s=Ivnitski-Steele_2008__n=4263775;-1;0.758692384;0.030644283 cmpd_12;1;0.038372866;0.60520206 cmpd_5;1;0.047818479;0.542044143 s=Matsson_2007__n=Meclizine;-1;0.388884806;0.191143136 s=Saito_2006__n=naproxen;-1;0.635671972;0.072760386 s=Pan_2013__n=Nicergoline;-1;0.317822333;0.221183308 s=Njus_2010__n=87350361;1;0.001528985;0.88737961 s=Ivnitski-Steele_2010__n=99376136;1;0.005311175;0.762783025 cmpd_58;-1;0.493474235;0.114726484 s=Matsson_2009__n=Indinavir;-1;0.097336572;0.403771324 s=Feng_2008__n=103646946;1;0.00595446;0.928400455 s=Ivnitski-Steele_2008__n=22407547;-1;0.532384586;0.113329923 s=Versiani_2011__n=124965660;1;0.062292426;0.493781693 s=Ahmed-Belkacem_2007__n=103510962;-1;0.719232903;0.051909753 cmpd_118;1;0.039656241;0.591777466 s=Matsson_2007__n=Hoechst_33342;1;0.162227333;0.276519564 s=Pick_2008__n=103569328;-1;0.499281962;0.116913916 s=Ivnitski-Steele_2008__n=17509535;-1;0.306858548;0.227947139 
s=Matsson_2007__n=Diazepam;-1;0.653451582;0.063673122 s=Matsson_2007__n=Levothyroxine;-1;0.235538669;0.225533283 s=Imai_2004__n=Diosmin;-1;0.691754488;0.051347811 s=acridones_Boumendjel_2007__n=4d;1;0.011421807;0.703674991 s=Feng_2009__n=103612047;1;0.047817571;0.564051493 cmpd_52;1;0.071777767;0.445883202 cmpd_46;1;0.000261912;0.913932971 s=Zembruski_2011__n=103181784;1;0.416487486;0.165114888 s=Wang_2008__n=Paliperidone;-1;0.331173207;0.213282565 s=Ahmed-Belkacem_2005__n=7-hydroxyflavone;1;0.079099819;0.418927123 s=Ivnitski-Steele_2008__n=17433121;-1;0.469304444;0.135846025 cmpd_161;-1;0.814698375;0.021012226 s=Ivnitski-Steele_2008__n=24832853;-1;0.043543793;0.587924001 s=Ivnitski-Steele_2008__n=22407014;-1;0.732776017;0.036797902 s=Loevezijn_2001__n=C7;1;0.005574495;0.70575201 cmpd_8;1;0.023117277;0.643808177 s=Ivnitski-Steele_2008__n=17432057;-1;0.539258383;0.108474817 s=Matsson_2007__n=Chlorpromazine;1;0.114125679;0.39336815 s=Ivnitski-Steele_2008__n=24815249;-1;0.325797904;0.215267359 s=Jin_2006__n=34080-08-5;1;0.124712544;0.373409996 s=Boumendjel_2005__n=103513524;-1;0.690599986;0.053087858 s=Matsson_2009__n=Probenecid;-1;0.984202901;0.005250061 s=Boumendjel_2005__n=103513530;1;0.147097911;0.282508718 s=Juvale_2012__n=136929454;1;0.020350505;0.645920578 s=flavonoids_Zhang_2005__n=7,8-Benzoflavone;1;0.005745439;0.842326491 s=Ivnitski-Steele_2010__n=99361143;1;0.00140969;0.768252283 s=Ivnitski-Steele_2008__n=4254626;-1;0.189766194;0.267626181 s=Matsson_2007__n=Warfarin;-1;0.152390086;0.280673602 s=Pick_2008__n=103587881;-1;0.757748658;0.030733362 s=Marighetti_2013__n=12;1;0.038100143;0.575106497 s=Colabufo_2008_ext__n=103578561;1;0.116238964;0.390072741 s=Matsson_2007__n=Captopril;-1;0.909542298;0.000724404 s=Imai_2004__n=Diosmetin;1;0.001076233;0.808618984 s=Ivnitski-Steele_2008__n=24804288;-1;0.906527262;0.00165525 s=Ivnitski-Steele_2008__n=24806339;-1;0.467917054;0.131328372 s=Juvale_2012__n=136926279;-1;0.068408237;0.441660134 
s=Feng_2009__n=103612430;1;0.001646375;0.747613628 s=Juvale_2012__n=136923018;1;0.480798287;0.122288295 s=Ivnitski-Steele_2010__n=93619259;1;0.080821188;0.426910829 cmpd_23;-1;0.129036979;0.379865885 s=Juvale_2012__n=136926280;-1;0.413604697;0.156954884 s=Marighetti_2013__n=13;1;0.041347527;0.575919817 s=Ivnitski-Steele_2010__n=92123917;1;0.003873648;0.697697871 s=Ivnitski-Steele_2008__n=17514180;-1;0.039029862;0.609431522 s=Ivnitski-Steele_2008__n=3713915;-1;0.240214149;0.233932003 s=Pick_2011__n=Nobiletin;1;0.038868012;0.586388673 s=phenylquinazolines_Juvale_2012__n=8;1;2.36E-05;0.9716344 s=Feng_2009__n=103612045;1;0.068635467;0.430730508 s=Juvale_2012__n=136939284;1;0.009041616;0.653182059 s=Ivnitski-Steele_2008__n=14744220;-1;0.420743931;0.16249045 s=Ivnitski-Steele_2008__n=17517128;-1;0.652399784;0.057710533 s=Ivnitski-Steele_2008__n=3711455;-1;0.904119952;0.002086851 s=Ivnitski-Steele_2008__n=22407122;-1;0.335377732;0.214029022 s=Katayama_2007__n=3',4',7-trimethoxyflavone;1;0.001959491;0.91880804 s=Ivnitski-Steele_2008__n=22407329;-1;0.335985599;0.210446616 cmpd_172;1;0.077003198;0.429763957 cmpd_79;1;0.055266658;0.536132507 s=Matsson_2007__n=Erlotinib;1;0.15720837;0.283925846 cmpd_110;1;0.001082325;0.843910014 s=Ivnitski-Steele_2008__n=17509379;-1;0.276423245;0.221672802 s=Juvale_2012__n=136945677;-1;0.141693795;0.366155713 s=Loevezijn_2001__n=C5;1;0.069083831;0.440977171 s=Ivnitski-Steele_2008__n=17407808;-1;0.332997059;0.210592765 s=Ivnitski-Steele_2008__n=22404620;-1;0.461682969;0.126603045 s=Ivnitski-Steele_2008__n=3715631;-1;0.564840477;0.083920483 s=Juvale_2012__n=136929456;1;0.526433192;0.094308157 s=Bokesch_2010__n=103766244;1;0.255777806;0.226072958 s=Njus_2010__n=87577984;1;0.005373387;0.808530143 s=Matsson_2009__n=Benzbromarone;1;0.321874889;0.220257383 s=Boumendjel_2005__n=103261745;-1;0.907670855;0.012656731 s=Cramer_2007__n=5c;1;0.187789949;0.275758806 s=Matsson_2007__n=Sulindac;-1;0.637527306;0.048594502 
s=Ivnitski-Steele_2008__n=857021;-1;0.647914708;0.050270328 s=Holland_2007__n=tetrahydrocannabinol;1;0.199978569;0.254577722 s=Matsson_2009__n=Rifampicin;-1;0.33244987;0.208066361 cmpd_84;1;0.003481579;0.78090516 s=Ivnitski-Steele_2008__n=846390;-1;0.510554812;0.117497069 s=Pan_2013__n=Trifluoperazine;1;0.598185478;0.067640628 -------------------------------------------------------------------------------- /python/tests/resources/cpsign_clf_predictions.csv: -------------------------------------------------------------------------------- 1 | name;target;p-value [label=-1];p-value [label=1] 2 | s=Cramer_2007__n=1b;-1;0.932813525;0.001769804 3 | s=Feng_2009__n=103612219;1;0.003274166;0.770759509 4 | s=Marighetti_2013__n=4;1;0.013958888;0.679011672 5 | s=Matsson_2009__n=Tetracycline;-1;0.924027199;0.002966592 6 | s=Patel_2011__n=5;1;0.201022127;0.262879965 7 | cmpd_86;1;0.001956786;0.767231702 8 | s=Boumendjel_2005__n=103513516;-1;0.919921501;0.00113258 9 | s=Ivnitski-Steele_2008__n=14741197;-1;0.80137057;0.024913439 10 | s=Ivnitski-Steele_2010__n=92764892;1;0.007062843;0.966611259 11 | s=Matsson_2007__n=Ketoconazole;1;0.230752181;0.227916077 12 | s=Ivnitski-Steele_2010__n=96022052;1;0.004578412;0.968710764 13 | s=acridones_Boumendjel_2007__n=1d;-1;0.044183517;0.595940981 14 | cmpd_92;-1;0.143142093;0.288384415 15 | s=Loevezijn_2001__n=B1;-1;0.384835685;0.191695442 16 | cmpd_121;1;0.073748827;0.472165758 17 | s=Patel_2011__n=17;-1;0.801885869;0.026671154 18 | cmpd_184;1;0.006685016;0.840032875 19 | s=Matsson_2007__n=Sulfinpyrazone;-1;0.564692637;0.085891033 20 | s=Ahmed-Belkacem_2007__n=103510952;1;0.183255546;0.274586019 21 | cmpd_29;1;0.04892914;0.536122222 22 | s=Saito_2006__n=acetylsalicylic_acid;-1;0.948662571;0.004042035 23 | s=Ivnitski-Steele_2008__n=22403643;-1;0.473267846;0.137467786 24 | s=Ivnitski-Steele_2008__n=22406563;-1;0.417097958;0.183401522 25 | s=Marighetti_2013__n=22;1;0.14803234;0.284981491 26 | s=Pick_2010__n=103720970;1;0.007135356;0.769870906 27 
| cmpd_81;-1;0.119375479;0.393580204 28 | s=Marighetti_2013__n=5;-1;0.041378474;0.609280235 29 | s=Loevezijn_2001__n=B2;-1;0.299202878;0.226029135 30 | cmpd_144;-1;0.483513885;0.122007782 31 | s=Matsson_2009__n=Chloroquine;-1;0.49395026;0.114129711 32 | cmpd_34;1;0.005886064;0.810672124 33 | s=Feng_2008__n=103646945;1;0.006732286;0.787776563 34 | cmpd_109;1;0.001195506;0.833281963 35 | cmpd_141;-1;0.047641278;0.534795994 36 | s=Ivnitski-Steele_2008__n=844597;-1;0.657618059;0.051218606 37 | s=Weiss_2007__n=Atazanavir;-1;0.149902388;0.312446901 38 | cmpd_181;-1;0.011387457;0.666118775 39 | s=Ivnitski-Steele_2008__n=22403400;-1;0.359985818;0.205065334 40 | s=Ivnitski-Steele_2008__n=24817134;1;0.74776768;0.031178886 41 | s=Ivnitski-Steele_2008__n=22402816;-1;0.253416596;0.225831351 42 | s=Colabufo_2008__n=103578071;-1;0.20909241;0.230101003 43 | s=Matsson_2007__n=Dipyridamole;1;0.53865369;0.105530757 44 | s=Cramer_2007__n=2c;-1;0.14598154;0.291794999 45 | cmpd_113;1;0.000409625;0.969941797 46 | s=Saito_2006__n=Glycine;-1;0.976897905;0.002441175 47 | s=Ivnitski-Steele_2008__n=17401132;-1;0.065973131;0.472574675 48 | s=phenylquinazolines_Juvale_2012__n=2;1;0.006914999;0.706425155 49 | s=Ivnitski-Steele_2008__n=24797597;-1;0.548175722;0.102448981 50 | s=Weiss_2007__n=Tenofovir;-1;0.931829055;0.00586298 51 | cmpd_99;1;0.003889876;0.836457619 52 | s=Matsson_2009__n=Taurolithocholic_acid;-1;0.213377706;0.228525223 53 | s=Jin_2006__n=Ginsenoside_Rh2;1;0.185948651;0.28077806 54 | s=Ivnitski-Steele_2010__n=93619256;1;0.005318768;0.96029725 55 | s=Ivnitski-Steele_2010__n=96022047;1;0.002327457;0.769598158 56 | s=Matsson_2007__n=Folic_acid;-1;0.420302633;0.167428871 57 | cmpd_42;1;0.00173474;0.726949803 58 | s=Sugimoto_2003__n=TAG-3;-1;0.181894216;0.262375001 59 | s=Matsson_2007__n=Chenodeoxycholic_acid;-1;0.463814855;0.152097892 60 | s=Ivnitski-Steele_2008__n=22407563;-1;0.332899617;0.211780505 61 | s=Matsson_2007__n=Salicylic_acid;-1;0.963611216;0.001891898 62 | 
cmpd_82;1;0.040877303;0.615887326 63 | s=Saito_2006__n=hematoporphyrin;1;0.051830554;0.536852535 64 | s=Ivnitski-Steele_2008__n=16953419;-1;0.766622057;0.027581134 65 | s=Juvale_2012__n=103566963;1;0.002737976;0.845463629 66 | s=Matsson_2007__n=Zidovudine;-1;0.866520476;0.012728905 67 | cmpd_2;1;0.570848116;0.095781726 68 | s=Ivnitski-Steele_2008__n=17412701;-1;0.519219188;0.115053187 69 | cmpd_31;-1;0.585628703;0.071204787 70 | s=Ivnitski-Steele_2008__n=7977793;-1;0.679134808;0.051217426 71 | s=Ivnitski-Steele_2008__n=22403593;-1;0.669138879;0.052118405 72 | s=Matsson_2007__n=Neomycin_sulfate;-1;0.745936181;0.032489536 73 | s=Ivnitski-Steele_2008__n=24837996;-1;0.768021418;0.020180326 74 | s=Matsson_2007__n=Methotrexate;-1;0.496012931;0.122519985 75 | s=Juvale_2012__n=103323399;-1;0.743794822;0.033495725 76 | s=Juvale_2012__n=103567050;1;0.14790034;0.282777564 77 | s=Matsson_2007__n=Tinidazole;-1;0.916997893;0.002704392 78 | s=Ivnitski-Steele_2008__n=24826880;-1;0.47658495;0.116760994 79 | s=Loevezijn_2001__n=C1;-1;0.15180502;0.276857446 80 | s=Saito_2006__n=cortisone;-1;0.833350518;0.02110262 81 | s=Ivnitski-Steele_2010__n=93619262;1;0.096394943;0.403408323 82 | cmpd_105;-1;0.001328766;0.76193346 83 | s=Juvale_2012__n=103206494;-1;0.643926628;0.059435498 84 | s=flavonoids_Zhang_2005__n=6,2',3'-7-Hydroxyflavanone;1;0.293255238;0.226199887 85 | s=Boumendjel_2005__n=103513508;-1;0.667180584;0.049436423 86 | cmpd_10;1;0.014222697;0.654097376 87 | s=Kuhnle_2009__n=103591272;1;0.027792779;0.637360101 88 | s=Juvale_2012__n=136929453;-1;0.408486331;0.185488756 89 | s=Ivnitski-Steele_2008__n=852968;-1;0.560437393;0.091096121 90 | s=Ivnitski-Steele_2008__n=17512364;-1;0.285014128;0.225456174 91 | s=Marighetti_2013__n=19;1;0.040175248;0.610524843 92 | s=Pick_2008__n=103569129;1;0.180872213;0.271645975 93 | s=Ivnitski-Steele_2010__n=93619254;1;0.003422729;0.807595632 94 | cmpd_50;-1;0.073377633;0.442244087 95 | s=Cramer_2007__n=4;-1;0.641954325;0.058926528 96 | 
s=Ivnitski-Steele_2008__n=14743034;-1;0.830819598;0.020460879 97 | cmpd_156;-1;0.419357437;0.187710939 98 | s=Wang_2008__n=Olanzapine;-1;0.72104682;0.044061217 99 | s=Ivnitski-Steele_2008__n=17504108;-1;0.13552658;0.347442771 100 | s=Matsson_2007__n=Phenytoin;-1;0.600661495;0.077093697 101 | s=Ochoa-Puentes_2011__n=131287273;1;0.05220358;0.533382602 102 | s=Sugimoto_2003__n=TAG-11;1;0.056886073;0.5353999 103 | s=Ivnitski-Steele_2010__n=87550714;1;0.000962545;0.972253702 104 | s=Saito_2006__n=melatonin;-1;0.943934607;0.00628152 105 | s=Imai_2004__n=Kaempferide;1;0.003601839;0.699825831 106 | s=Ivnitski-Steele_2008__n=17504141;-1;0.214662592;0.231835087 107 | s=Matsson_2007__n=Maprotiline;-1;0.15011401;0.289677259 108 | s=Ivnitski-Steele_2010__n=88095709;1;0.000640263;0.967372063 109 | s=Jin_2006__n=Ginsenoside_Rg3;-1;0.644566822;0.064307676 110 | s=Ivnitski-Steele_2010__n=93619268;1;3.99E-05;0.79966521 111 | s=Juvale_2012__n=136939285;1;0.00616336;0.858734883 112 | s=Matsson_2007__n=Digoxin;-1;0.523545727;0.113885099 113 | s=Matsson_2009__n=Bromosulfalein;-1;0.336851277;0.209172237 114 | s=Feng_2009__n=103453032;1;0.000333607;0.758069409 115 | cmpd_157;-1;0.475926911;0.127284992 116 | cmpd_62;-1;0.795732202;0.020479168 117 | s=Matsson_2007__n=Carbamazepine;-1;0.3592486;0.189069823 118 | s=Imai_2004__n=Luteolin-4'-beta-D-glucoside;1;0.829937726;0.020374387 119 | s=Feng_2009__n=103220117;1;0.116431978;0.377932408 120 | s=Matsson_2007__n=Hydralazine;-1;0.933409102;0.001120833 121 | cmpd_191;-1;0.082530994;0.426719758 122 | s=Versiani_2011__n=124968332;1;0.129736243;0.366192851 123 | s=Pan_2013__n=Fosinopril;1;0.181439142;0.264104966 124 | s=Juvale_2012__n=103449034;1;0.148517482;0.316633732 125 | cmpd_133;1;0.041322963;0.606339268 126 | cmpd_20;-1;0.529126076;0.112510597 127 | cmpd_117;1;0.00324315;0.76834121 128 | cmpd_85;1;0.043494584;0.593973792 129 | s=Ivnitski-Steele_2010__n=90944694;1;0.006699916;0.769166789 130 | cmpd_56;1;0.001300806;0.830839697 131 | 
s=Ahmed-Belkacem_2007__n=103510963;-1;0.077076491;0.420495625 132 | s=Ivnitski-Steele_2010__n=99361158;1;0.002721518;0.971101179 133 | s=Matsson_2007__n=Carisoprodol;-1;0.905882236;0.000994242 134 | s=Loevezijn_2001__n=F2;-1;0.652929745;0.052153513 135 | cmpd_185;1;0.005106766;0.744737035 136 | s=flavonoids_Zhang_2004__n=Silymarin;1;0.466046059;0.123099436 137 | s=Marighetti_2013__n=7;1;0.071719116;0.436964602 138 | cmpd_91;1;0.046905816;0.570666603 139 | s=Ivnitski-Steele_2008__n=4263775;-1;0.758692384;0.030644283 140 | cmpd_12;1;0.038372866;0.60520206 141 | cmpd_5;1;0.047818479;0.542044143 142 | s=Matsson_2007__n=Meclizine;-1;0.388884806;0.191143136 143 | s=Saito_2006__n=naproxen;-1;0.635671972;0.072760386 144 | s=Pan_2013__n=Nicergoline;-1;0.317822333;0.221183308 145 | s=Njus_2010__n=87350361;1;0.001528985;0.88737961 146 | s=Ivnitski-Steele_2010__n=99376136;1;0.005311175;0.762783025 147 | cmpd_58;-1;0.493474235;0.114726484 148 | s=Matsson_2009__n=Indinavir;-1;0.097336572;0.403771324 149 | s=Feng_2008__n=103646946;1;0.00595446;0.928400455 150 | s=Ivnitski-Steele_2008__n=22407547;-1;0.532384586;0.113329923 151 | s=Versiani_2011__n=124965660;1;0.062292426;0.493781693 152 | s=Ahmed-Belkacem_2007__n=103510962;-1;0.719232903;0.051909753 153 | cmpd_118;1;0.039656241;0.591777466 154 | s=Matsson_2007__n=Hoechst_33342;1;0.162227333;0.276519564 155 | s=Pick_2008__n=103569328;-1;0.499281962;0.116913916 156 | s=Ivnitski-Steele_2008__n=17509535;-1;0.306858548;0.227947139 157 | s=Matsson_2007__n=Diazepam;-1;0.653451582;0.063673122 158 | s=Matsson_2007__n=Levothyroxine;-1;0.235538669;0.225533283 159 | s=Imai_2004__n=Diosmin;-1;0.691754488;0.051347811 160 | s=acridones_Boumendjel_2007__n=4d;1;0.011421807;0.703674991 161 | s=Feng_2009__n=103612047;1;0.047817571;0.564051493 162 | cmpd_52;1;0.071777767;0.445883202 163 | cmpd_46;1;0.000261912;0.913932971 164 | s=Zembruski_2011__n=103181784;1;0.416487486;0.165114888 165 | s=Wang_2008__n=Paliperidone;-1;0.331173207;0.213282565 166 | 
s=Ahmed-Belkacem_2005__n=7-hydroxyflavone;1;0.079099819;0.418927123 167 | s=Ivnitski-Steele_2008__n=17433121;-1;0.469304444;0.135846025 168 | cmpd_161;-1;0.814698375;0.021012226 169 | s=Ivnitski-Steele_2008__n=24832853;-1;0.043543793;0.587924001 170 | s=Ivnitski-Steele_2008__n=22407014;-1;0.732776017;0.036797902 171 | s=Loevezijn_2001__n=C7;1;0.005574495;0.70575201 172 | cmpd_8;1;0.023117277;0.643808177 173 | s=Ivnitski-Steele_2008__n=17432057;-1;0.539258383;0.108474817 174 | s=Matsson_2007__n=Chlorpromazine;1;0.114125679;0.39336815 175 | s=Ivnitski-Steele_2008__n=24815249;-1;0.325797904;0.215267359 176 | s=Jin_2006__n=34080-08-5;1;0.124712544;0.373409996 177 | s=Boumendjel_2005__n=103513524;-1;0.690599986;0.053087858 178 | s=Matsson_2009__n=Probenecid;-1;0.984202901;0.005250061 179 | s=Boumendjel_2005__n=103513530;1;0.147097911;0.282508718 180 | s=Juvale_2012__n=136929454;1;0.020350505;0.645920578 181 | s=flavonoids_Zhang_2005__n=7,8-Benzoflavone;1;0.005745439;0.842326491 182 | s=Ivnitski-Steele_2010__n=99361143;1;0.00140969;0.768252283 183 | s=Ivnitski-Steele_2008__n=4254626;-1;0.189766194;0.267626181 184 | s=Matsson_2007__n=Warfarin;-1;0.152390086;0.280673602 185 | s=Pick_2008__n=103587881;-1;0.757748658;0.030733362 186 | s=Marighetti_2013__n=12;1;0.038100143;0.575106497 187 | s=Colabufo_2008_ext__n=103578561;1;0.116238964;0.390072741 188 | s=Matsson_2007__n=Captopril;-1;0.909542298;0.000724404 189 | s=Imai_2004__n=Diosmetin;1;0.001076233;0.808618984 190 | s=Ivnitski-Steele_2008__n=24804288;-1;0.906527262;0.00165525 191 | s=Ivnitski-Steele_2008__n=24806339;-1;0.467917054;0.131328372 192 | s=Juvale_2012__n=136926279;-1;0.068408237;0.441660134 193 | s=Feng_2009__n=103612430;1;0.001646375;0.747613628 194 | s=Juvale_2012__n=136923018;1;0.480798287;0.122288295 195 | s=Ivnitski-Steele_2010__n=93619259;1;0.080821188;0.426910829 196 | cmpd_23;-1;0.129036979;0.379865885 197 | s=Juvale_2012__n=136926280;-1;0.413604697;0.156954884 198 | 
s=Marighetti_2013__n=13;1;0.041347527;0.575919817 199 | s=Ivnitski-Steele_2010__n=92123917;1;0.003873648;0.697697871 200 | s=Ivnitski-Steele_2008__n=17514180;-1;0.039029862;0.609431522 201 | s=Ivnitski-Steele_2008__n=3713915;-1;0.240214149;0.233932003 202 | s=Pick_2011__n=Nobiletin;1;0.038868012;0.586388673 203 | s=phenylquinazolines_Juvale_2012__n=8;1;2.36E-05;0.9716344 204 | s=Feng_2009__n=103612045;1;0.068635467;0.430730508 205 | s=Juvale_2012__n=136939284;1;0.009041616;0.653182059 206 | s=Ivnitski-Steele_2008__n=14744220;-1;0.420743931;0.16249045 207 | s=Ivnitski-Steele_2008__n=17517128;-1;0.652399784;0.057710533 208 | s=Ivnitski-Steele_2008__n=3711455;-1;0.904119952;0.002086851 209 | s=Ivnitski-Steele_2008__n=22407122;-1;0.335377732;0.214029022 210 | s=Katayama_2007__n=3',4',7-trimethoxyflavone;1;0.001959491;0.91880804 211 | s=Ivnitski-Steele_2008__n=22407329;-1;0.335985599;0.210446616 212 | cmpd_172;1;0.077003198;0.429763957 213 | cmpd_79;1;0.055266658;0.536132507 214 | s=Matsson_2007__n=Erlotinib;1;0.15720837;0.283925846 215 | cmpd_110;1;0.001082325;0.843910014 216 | s=Ivnitski-Steele_2008__n=17509379;-1;0.276423245;0.221672802 217 | s=Juvale_2012__n=136945677;-1;0.141693795;0.366155713 218 | s=Loevezijn_2001__n=C5;1;0.069083831;0.440977171 219 | s=Ivnitski-Steele_2008__n=17407808;-1;0.332997059;0.210592765 220 | s=Ivnitski-Steele_2008__n=22404620;-1;0.461682969;0.126603045 221 | s=Ivnitski-Steele_2008__n=3715631;-1;0.564840477;0.083920483 222 | s=Juvale_2012__n=136929456;1;0.526433192;0.094308157 223 | s=Bokesch_2010__n=103766244;1;0.255777806;0.226072958 224 | s=Njus_2010__n=87577984;1;0.005373387;0.808530143 225 | s=Matsson_2009__n=Benzbromarone;1;0.321874889;0.220257383 226 | s=Boumendjel_2005__n=103261745;-1;0.907670855;0.012656731 227 | s=Cramer_2007__n=5c;1;0.187789949;0.275758806 228 | s=Matsson_2007__n=Sulindac;-1;0.637527306;0.048594502 229 | s=Ivnitski-Steele_2008__n=857021;-1;0.647914708;0.050270328 230 | 
s=Holland_2007__n=tetrahydrocannabinol;1;0.199978569;0.254577722 231 | s=Matsson_2009__n=Rifampicin;-1;0.33244987;0.208066361 232 | cmpd_84;1;0.003481579;0.78090516 233 | s=Ivnitski-Steele_2008__n=846390;-1;0.510554812;0.117497069 234 | s=Pan_2013__n=Trifluoperazine;1;0.598185478;0.067640628 -------------------------------------------------------------------------------- /python/src/pharmbio/cp/plotting/_common.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | import matplotlib as mpl 3 | import numpy as np 4 | from sklearn.utils import check_consistent_length 5 | from numpy.core.fromnumeric import sort 6 | from ._utils import get_fig_and_axis, cm_as_list, _set_title, _set_label_if_not_set,_set_chart_size 7 | from ..utils import to_numpy2D, validate_sign, to_numpy1D 8 | 9 | _default_color_map = list(mpl.rcParams['axes.prop_cycle'].by_key()['color']) 10 | 11 | def update_plot_settings(theme_style = 'ticks', context = 'notebook', font_scale = 1): 12 | ''' 13 | Update the global plot-settings, requires having seaborn available. 14 | 15 | This is simply a convenience wrapper of the functions `seaborn.set_context` and `seaborn.set_theme`. 
If seaborn is not available, this function 16 | will not make any alterations 17 | ''' 18 | try: 19 | import seaborn as sns 20 | sns.set_context(context=context) 21 | custom_params = {"axes.spines.right": False, "axes.spines.top": False} 22 | sns.set_theme(style=theme_style, rc=custom_params, font_scale=font_scale) 23 | except ImportError as e: 24 | pass 25 | 26 | def add_calib_curve(ax, 27 | sign_vals, 28 | error_rates, 29 | label = None, 30 | color = 'k', 31 | flip_x = False, 32 | flip_y = False, 33 | chart_padding = 0.025, 34 | set_chart_size = False, 35 | plot_expected = True, 36 | zorder = None, 37 | **kwargs): 38 | """Utility function for adding a single line to an Axes 39 | 40 | Solves setting chart size, error rate vs significance or Accuracy vs confidence 41 | Parameters 42 | ---------- 43 | ax : matplotlib Axes 44 | Axes to plot in 45 | 46 | sign_vals : 1d ndarray 47 | 48 | error_rates : 1d ndarray 49 | 50 | label : str or None 51 | An optional label to add to the plotted values 52 | 53 | color : str or matplotlib recognized color-input 54 | 55 | flip_x : bool, default False 56 | If the x-axes should display significance level (`False`) or confidence (`True`) 57 | 58 | flip_y : bool, default False 59 | If the y-axes should display error-rate (`False`) or accuracy (`True`) 60 | 61 | chart_padding : float, (float,float) or None 62 | padding added to the chart-area outside of the min and max values found in data. If two values the first value will be used as x-padding and second y-padding. E.g. 
0.025 means 2.5% on both sides 63 | 64 | set_chart_size : bool, default False 65 | If the chart size should be set 66 | 67 | plot_expected : bool, default True 68 | If the dashed 'expected' error/accuracy line should be plotted 69 | 70 | zorder : float or None 71 | An zorder (controlling the order of plotted artists), a higher zorder means plotting on-top of other objects plotted 72 | 73 | Returns 74 | ------- 75 | (x_label, y_label) : the str labels for what is plotted 76 | """ 77 | 78 | # Handle x-axis 79 | if flip_x: 80 | x_label, xs = 'Confidence', 1 - np.array(sign_vals) 81 | else: 82 | x_label, xs = 'Significance', sign_vals 83 | # Handle y-axis 84 | if flip_y: 85 | y_label, ys = 'Accuracy', 1 - np.array(error_rates) 86 | else: 87 | y_label, ys = 'Error rate', error_rates 88 | 89 | if set_chart_size: 90 | min = np.min([np.min(ys), np.min(xs)]) 91 | max = np.max([np.max(ys), np.max(xs)]) 92 | _set_chart_size(ax, 93 | [min,max], 94 | [min,max], 95 | chart_padding) 96 | 97 | if plot_expected: 98 | if flip_x == flip_y: 99 | # If both flipped or both normal 100 | ax.plot(xs, xs, '--', color='gray', linewidth=1) 101 | else: 102 | ax.plot(xs, 1-np.array(xs), '--', color='gray', linewidth=1) 103 | 104 | # If there's an explicit zorder - add it to a new dict 105 | if zorder is not None: 106 | kwargs = dict(kwargs, zorder=zorder) 107 | 108 | # Plot the values 109 | if color is not None: 110 | ax.plot(xs, ys, 111 | label=(label if label is not None else y_label), 112 | color=color, 113 | **kwargs) 114 | 115 | return (x_label,y_label) 116 | 117 | def plot_calibration(sign_vals = None, 118 | error_rates = None, 119 | error_rates_sd = None, 120 | conf_vals = None, 121 | accuracy_vals = None, 122 | accuracy_sd = None, 123 | labels = None, 124 | ax = None, 125 | figsize = (10,8), 126 | chart_padding=0.025, 127 | cm = None, 128 | flip_x = False, 129 | flip_y = False, 130 | title=None, 131 | tight_layout=True, 132 | plot_expected = True, 133 | sd_alpha = .3, 134 | **kwargs): 
135 | ''' 136 | **Classification and regression ** - Create a calibration plot from computed values 137 | 138 | This function creates a plot of calibration curves, given precomputed values for either 139 | accuracy or error rates given significance or confidence. Note that either accuracy _or_ error_rate 140 | can be given (not both) and significance _or_ confidence (not both) must be given. Additionally 141 | a standard-deviation "_sd" parameters can be given which will be displayed with the same color (according to `cm` argument) 142 | behind the error/accuracy values. 143 | 144 | Parameters 145 | ---------- 146 | sign_vals : a 1D Iterable, default None 147 | Significance values for the corresponding accuracy/error rates 148 | 149 | error_rates : 1D or 2D list like, default None 150 | Error rates, either a single (e.g. overall value) or multiple (i.e. one for class) 151 | 152 | error_rates_sd : 1D or 2D list like, default None 153 | Standard deviations for the `error_rates`, used for plotting `error_rate +/- SD` areas 154 | 155 | conf_vals : a 1D Iterable, default None 156 | Confidence values for the corresponding accuracy/error rates 157 | 158 | accuracy_vals : 1D or 2D list like, default None 159 | Accuracy values, either a single (e.g. overall value) or multiple (i.e. 
one for class) 160 | 161 | accuracy_sd : 1D or 2D list like, default None 162 | Standard deviations for the `accuracy_vals`, used for plotting `accuracy +/- SD` areas 163 | 164 | labels : list of str, optional 165 | Descriptive labels for the input, for regression input it can be a single str, for classification 166 | the number of columns in `error_rates` or `accuracy_vals` should match the number of labels 167 | 168 | ax : matplotlib Axes, optional 169 | An existing matplotlib Axes to plot in (default None) 170 | 171 | figsize : float or (float, float), optional 172 | Figure size to generate, ignored if `ax` is given 173 | 174 | chart_padding : float, (float,float) or None, default 0.025 175 | padding added to the chart-area outside of the min and max values found in data. 176 | If two values the first value will be used as x-padding and second y-padding. E.g. 0.025 means 2.5% on both sides 177 | 178 | cm : color, list of colors or ListedColorMap, optional 179 | The colors to use. First color will be for class 0, second for class 1, .. 180 | 181 | flip_x : bool, default False 182 | If the x-axes should be 'flipped', i.e. if `sign_vals` was given the default is to display "Significance" on the x-axis, 183 | this will flip it and display "Confidence instead", or vise versa in case conf_vals was given 184 | 185 | flip_y : bool, default False 186 | If the y-axes should be 'flipped', i.e. if accuracy_vals was given, the default is to display "Accuracy" on the y-axis, 187 | this will flip it and instead display "Error rate". And vise versa if `error_rates` is given. 
188 | 189 | title : str, optional 190 | Optional title that will be printed in 'x-large' font size (default None) 191 | 192 | tight_layout : bool, optional 193 | Set `tight_layout` on the matplotlib Figure object 194 | 195 | plot_expected : bool, optional 196 | Plot the diagonal, representing the expected error/accuracy (default `True`) 197 | 198 | **kwargs : dict, optional 199 | Keyword arguments, passed to matplotlib 200 | 201 | Returns 202 | ------- 203 | fig : Figure 204 | matplotlib.figure.Figure object 205 | 206 | See Also 207 | -------- 208 | matplotlib.colors.ListedColormap 209 | ''' 210 | 211 | colors = cm_as_list(cm, _default_color_map) 212 | 213 | # Validate either significance or confidence values were given 214 | if (sign_vals is None and conf_vals is None) or (sign_vals is not None and conf_vals is not None): 215 | raise ValueError('Either sign_vals or conf_vals must be given (not both)') 216 | # Validate either error_rates or accuracy are given 217 | if (error_rates is None and accuracy_vals is None) or (error_rates is not None and accuracy_vals is not None): 218 | raise ValueError('Either error_rates or accuracy_vals must be given (not both)') 219 | 220 | # ====================================================== 221 | # Deduce the x-values + label 222 | if sign_vals is not None: 223 | # Using sign input 224 | validate_sign(sign_vals) 225 | if len(sign_vals) < 2: 226 | raise ValueError('Must have at least 2 significance values to plot a calibration curve') 227 | x_lab = 'Confidence' if flip_x else 'Significance' 228 | x_vals = 1 - to_numpy1D(sign_vals,'sign_vals') if flip_x else to_numpy1D(sign_vals,'sign_vals') 229 | else: 230 | # Using conf input 231 | validate_sign(conf_vals) 232 | if len(conf_vals) < 2: 233 | raise ValueError('Must have at least 2 confidence values to plot a calibration curve') 234 | x_lab = 'Significance' if flip_x else 'Confidence' 235 | x_vals = 1 - to_numpy1D(conf_vals,'conf_vals') if flip_x else 
to_numpy1D(conf_vals,'conf_vals') 236 | 237 | # ====================================================== 238 | # Deduce the y-values + label 239 | if error_rates is not None: 240 | # Using error rate input 241 | y_lab = 'Accuracy' if flip_y else 'Error rate' 242 | y_vals = 1 - to_numpy2D(error_rates,'error_rates', unravel=True, min_num_cols=1) if flip_y else to_numpy2D(error_rates,'error_rates', unravel=True, min_num_cols=1) 243 | y_SD = None if error_rates_sd is None else to_numpy2D(error_rates_sd,'error_rates_sd', unravel=True, min_num_cols=1) 244 | else: 245 | # Using accuracy input 246 | y_lab = 'Error rate' if flip_y else 'Accuracy' 247 | y_vals = 1 - to_numpy2D(accuracy_vals,'accuracy_vals', unravel=True, min_num_cols=1) if flip_y else to_numpy2D(accuracy_vals,'accuracy_vals', unravel=True, min_num_cols=1) 248 | y_SD = None if accuracy_sd is None else to_numpy2D(accuracy_sd,'accuracy_sd', unravel=True, min_num_cols=1) 249 | 250 | # Create the figure and axis to plot in 251 | error_fig, ax = get_fig_and_axis(ax, figsize) 252 | 253 | # Create labels if not set 254 | if labels is None: 255 | labels = ['Overall'] 256 | if y_vals.shape[1]>1: 257 | labels += ['Label {}'.format(i-1) for i in range(1,y_vals.shape[1])] 258 | elif isinstance(labels,Iterable): 259 | if len(labels) < y_vals.shape[1]: 260 | raise ValueError('Invalid number of labels given, should be {}'.format(y_vals.shape[1])) 261 | if isinstance(labels,str): 262 | # str is iterable, which forces us to special case this 263 | if y_vals.shape[1]==1: 264 | labels = [labels] 265 | else: 266 | raise ValueError("Invalid 'labels' argument, should be a list of labels of length {}".format(y_vals.shape[1])) 267 | 268 | elif y_vals.shape[1]==1: 269 | # Single line to be plotted, wrap in a list 270 | labels = [labels] 271 | else: 272 | raise ValueError("Invalid 'labels' argument, should be a list of labels") 273 | 274 | 275 | # Check consistent length of x and y points 276 | check_consistent_length(y_vals,x_vals) 277 
| 278 | # Set the chart size, flipping handled before, set to False 279 | _set_chart_size(ax,x_vals,y_vals, 280 | padding=chart_padding, 281 | flip_x=False, 282 | flip_y=False) 283 | 284 | # Plot the expected 285 | if plot_expected: 286 | if flip_x == flip_y: 287 | # If both flipped or both normal 288 | ax.plot(x_vals, x_vals, '--', color='gray', linewidth=1) 289 | else: 290 | ax.plot(x_vals, 1-np.array(x_vals), '--', color='gray', linewidth=1) 291 | 292 | z_offset = 20 293 | # Plot all curves 294 | for col in range(0,y_vals.shape[1]): 295 | # Plot SD area 296 | if y_SD is not None and y_SD.shape[1]>= col: 297 | ax.fill_between(x_vals, y_vals[:,col]-y_SD[:,col], y_vals[:,col]+y_SD[:,col], interpolate=True, zorder = col+z_offset, color = colors[col], alpha = sd_alpha) 298 | # Plot the mean line 299 | ax.plot(x_vals, y_vals[:,col], color=colors[col],zorder=col+1, label=labels[col],**kwargs) 300 | 301 | if flip_x != flip_y: 302 | ax.legend(loc='lower left') 303 | else: 304 | ax.legend(loc='lower right') 305 | 306 | _set_label_if_not_set(ax,x_lab, True) 307 | _set_label_if_not_set(ax,y_lab, False) 308 | _set_title(ax,title) 309 | 310 | if tight_layout: 311 | error_fig.tight_layout() 312 | 313 | return error_fig -------------------------------------------------------------------------------- /python/tests/pharmbio/cp/metrics/clf_metrics_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | 6 | from pharmbio.cp.metrics import (confusion_matrix,frac_errors,frac_multi_label_preds, 7 | frac_single_label_preds,n_criterion,u_criterion, 8 | f_criteria,s_criterion,cp_credibility,cp_confidence,frac_error) 9 | from statistics import mean 10 | import time 11 | from ....help_utils import get_resource 12 | 13 | class TestConfusionMatrix(): 14 | 15 | def test_only_0_class(self): 16 | p_vals_m = np.array([ 17 | [0.05, 0.85], 18 | [0.23, 0.1], 19 | [.1, .1], 20 | [.21, .21], 21 
| [.3, 0.15] 22 | ]) 23 | tr = np.array([0,0,0,0,0]) 24 | expected_CM = np.array([[2, 0], [1,0], [1,0], [1,0]]) 25 | 26 | cm = confusion_matrix(tr,p_vals_m, sign=0.2) 27 | assert np.array_equal(cm.to_numpy(), expected_CM) 28 | # test input as normal arrays instead 29 | cm2 = confusion_matrix(tr.tolist(), p_vals_m.tolist(), sign=0.2) 30 | assert cm.equals( cm2 ) 31 | # assert cm == cm2 32 | # Test with Pandas as input 33 | pvals_pd = pd.DataFrame(p_vals_m) 34 | tr_pd = pd.Series(tr) 35 | cm3 = confusion_matrix(tr_pd, pvals_pd, sign=0.2) 36 | assert cm.equals( cm3 ) 37 | # assert cm == cm3 38 | 39 | def test_small_binary_ex(self): 40 | p_vals_m = np.array([[0.05, 0.85], [0.23, 0.1], [.1, .1], [.21, .21], [.3, 0.15]]) 41 | tr = np.array([0,1,0,1,0]) 42 | expected_CM = np.array([ 43 | [1,1], 44 | [1,0], 45 | [1,0], 46 | [0,1] 47 | ]) 48 | 49 | cm = confusion_matrix(tr,p_vals_m, sign=0.2) 50 | assert np.array_equal(cm.to_numpy(), expected_CM) 51 | 52 | def test_with_custom_labels(self): 53 | p_vals_m = np.array([[0.05, 0.85], [0.23, 0.1], [.1, .1], [.21, .21], [.3, 0.15]]) 54 | tr = np.array([0,1,0,1,0]) 55 | custom_labels = ['Mutagen', 'Nonmutagen'] 56 | expected_CM = np.array([[1, 1], [1,0], [1,0], [0,1]]) 57 | 58 | cm = confusion_matrix(tr,p_vals_m, sign=0.2, labels=custom_labels) 59 | assert cm.shape == expected_CM.shape 60 | assert np.array_equal(cm.to_numpy(), expected_CM) 61 | #assert(cm.shape, expected_CM.shape) 62 | assert np.array_equal(cm.to_numpy(), expected_CM) 63 | 64 | cm_names = list(cm.columns) 65 | assert custom_labels == cm_names 66 | assert custom_labels == list(cm.index.values)[:2] 67 | assert 4 == len(list(cm.index.values)) 68 | assert len(tr) == cm.to_numpy().sum() 69 | 70 | def test_3_class(self): 71 | p_vals_m = np.array( 72 | [ 73 | [0.05, 0.1, 0.5], 74 | [0.05, 0.1, 0.5], 75 | [0.05, 0.3, 0.5], 76 | [0.05, 0.3, 0.1], 77 | ] 78 | ) 79 | true_l = np.array([0, 1, 2, 0]) 80 | custom_labels = [4, 5, 6] 81 | cm = confusion_matrix(true_l,p_vals_m, 
sign=0.2, labels=custom_labels) 82 | 83 | expected_CM = np.array([ 84 | [0,0,0], 85 | [1,0,0], 86 | [1,1,0], 87 | [0,0,0], 88 | [0,0,1], 89 | [0,0,0] 90 | ]) 91 | assert np.array_equal(cm.to_numpy(), expected_CM) 92 | 93 | ## Ebba TODO write test 94 | 95 | def test_normalize_3(self): 96 | p_vals_m = np.array( 97 | [ 98 | [0.05, 0.1, 0.5], 99 | [0.05, 0.1, 0.5], 100 | [0.05, 0.3, 0.5], 101 | [0.05, 0.3, 0.1], 102 | ] 103 | ) 104 | true_l = np.array([0, 1, 2, 0]) 105 | custom_labels = [4, 5, 6] 106 | cm = confusion_matrix(true_l,p_vals_m, sign=0.2, labels=custom_labels, normalize_per_class=True) 107 | #print(cm) 108 | expected_CM = np.array([ 109 | [0,0,0], 110 | [.5,0,0], 111 | [.5,1,0], 112 | [0,0,0], 113 | [0,0,1], 114 | [0,0,0] 115 | ]) 116 | assert np.array_equal(cm.to_numpy(), expected_CM) 117 | 118 | class TestObservedMetrics(): 119 | 120 | @pytest.fixture 121 | def set_up(self): 122 | raw_data = np.genfromtxt(get_resource('transporters.p-values.csv'), delimiter=';', skip_header=1) 123 | self.true_labels = np.array([1 if x == 1.0 else 0 for x in raw_data[:,1]]) 124 | self.p_values = raw_data[:,[2,3]] 125 | 126 | multiclass_data = np.genfromtxt(get_resource('multiclass.csv'), delimiter=',') 127 | self.m_p_values = multiclass_data[:,1:] 128 | self.m_true_labels = multiclass_data[:,:1].astype(int) 129 | 130 | 131 | def test_3D_frac_err(self, set_up): 132 | sign_vals = [0.7,0.8,0.9] 133 | (overall,class_wise) = frac_errors(self.true_labels,self.p_values,sign_vals) 134 | assert 3 == overall.shape[0] # One for each sign-value 135 | assert len(overall.shape)==1 # 1D array 136 | assert (len(sign_vals),self.p_values.shape[1]) == class_wise.shape 137 | 138 | # Test using a single significance level 139 | (overall,class_wise) = frac_errors(self.true_labels,self.p_values,sign_vals=0.25) 140 | assert 1 == len(overall) 141 | assert (1, self.p_values.shape[1]) == class_wise.shape 142 | 143 | 144 | def test_3D_frac_err_multiclass(self, set_up): 145 | sign_vals = [0.7,0.8,0.9] 
146 | (overall,class_wise) = frac_errors(self.m_true_labels,self.m_p_values,sign_vals) 147 | assert 3 == overall.shape[0] # One for each sign-value 148 | assert len(overall.shape) == 1 # 1D array 149 | assert (len(sign_vals),self.m_p_values.shape[1]) == class_wise.shape 150 | 151 | 152 | def test_3D_frac_err_multiclass_only_one_cls(self, set_up): 153 | sign_vals = [0.7,0.8,0.9] 154 | 155 | only_1_index = self.m_true_labels == 1 156 | ys = self.m_true_labels[only_1_index] 157 | pvals = self.m_p_values[only_1_index.reshape(-1),:] 158 | 159 | (overall,class_wise) = frac_errors(ys,pvals,sign_vals) 160 | # The 0 and 2 classes should have error-rate of 0 for all significance levels! 161 | assert np.all(np.zeros(3)==class_wise[:,0]) 162 | assert np.all(np.zeros(3)==class_wise[:,2]) 163 | assert 3 == overall.shape[0] # One for each sign-value 164 | assert len(overall.shape)==1 # 1D array 165 | assert (len(sign_vals),self.m_p_values.shape[1]) == class_wise.shape 166 | 167 | 168 | def test_3D_vs_2D(self, set_up): 169 | sign_vals = np.arange(0,1,0.01) 170 | # Here check consistent results 171 | (overall,cls_wise) = frac_errors(self.true_labels,self.p_values,sign_vals) 172 | joined_overall = [] 173 | joined_cls_wise = np.zeros((len(sign_vals),self.p_values.shape[1])) 174 | for i,s in enumerate(sign_vals): 175 | err, cls_ = frac_error(self.true_labels, self.p_values, s) 176 | joined_overall.append(err) 177 | joined_cls_wise[i,:] = cls_ 178 | assert np.allclose(overall,np.array(joined_overall)) 179 | assert np.allclose(joined_cls_wise, cls_wise) 180 | 181 | # Small benchmark of the two versions, not even considering the 182 | num_iter = 0 183 | tic = time.perf_counter() 184 | for _ in range(num_iter): 185 | for s in sign_vals: 186 | _ = frac_error(self.true_labels, self.p_values, s) 187 | toc = time.perf_counter() 188 | if num_iter >10: 189 | print(f"For loop in {toc - tic:0.4f} seconds") 190 | 191 | tic = time.perf_counter() 192 | for _ in range(num_iter): 193 | _ = 
frac_errors(self.true_labels,self.p_values,sign_vals) 194 | toc = time.perf_counter() 195 | if num_iter >10: 196 | print(f"All in one in {toc - tic:0.4f} seconds") 197 | 198 | 199 | def test_fraction_errors(self, set_up): 200 | # Taken from the values Ulf gave for this dataset 201 | sign = .25 202 | overall, (e0, e1) = frac_error(self.true_labels, self.p_values, sign) 203 | assert pytest.approx(0.1845) == round(overall, 4) 204 | assert pytest.approx(.24) == e0 205 | assert pytest.approx(.12,abs=1e-3) == e1 206 | 207 | sign=.2 208 | overall, (e0, e1) = frac_error(self.true_labels, self.p_values, sign) 209 | assert pytest.approx(.1459,abs=1e-4) == overall 210 | assert pytest.approx(.2) == e0 211 | assert pytest.approx(0.083, abs=1e-3) == e1 212 | 213 | sign=.15 214 | overall, (e0, e1) = frac_error(self.true_labels, self.p_values, sign) 215 | assert pytest.approx(.12,abs=1e-3) == overall 216 | assert pytest.approx(.16) == e0 217 | assert pytest.approx(0.074, abs=1e-3) == e1 218 | 219 | 220 | def test_single_label_ext(self, set_up): 221 | for s in np.arange(0,1,0.1): 222 | overall, correct_s, incorrect_s = frac_single_label_preds(self.true_labels, self.p_values, s) 223 | all_single, = frac_single_label_preds(None, self.p_values, s) 224 | assert pytest.approx(all_single) == correct_s+incorrect_s 225 | 226 | 227 | def test_multilabel_ext(self, set_up): 228 | for s in np.arange(0,1,0.1): 229 | overall, correct_m, incorrect_m = frac_multi_label_preds(self.true_labels, self.p_values, s) 230 | all_m, = frac_multi_label_preds(None, self.p_values, s) 231 | assert pytest.approx(all_m) == correct_m+incorrect_m 232 | assert 0 == incorrect_m # For binary - all multi-label are always correct! 
233 | 234 | 235 | def test_multilabel_ext_3class(self, set_up): 236 | for s in np.arange(0,.1,0.01): 237 | overall, correct_m, incorrect_m = frac_multi_label_preds(self.m_true_labels, self.m_p_values, s) 238 | all_m, = frac_multi_label_preds(None, self.m_p_values, s) 239 | assert pytest.approx(correct_m+incorrect_m) == all_m 240 | 241 | 242 | def test_multilabel_ext_synthetic(self): 243 | p = np.array([ 244 | [0.1,0.2,0.3,0.5], 245 | [0.1,0.2,0.3,0.5], 246 | [0.1,0.2,0.3,0.5], 247 | [0.1,0.2,0.3,0.5], 248 | ]) 249 | s = 0.09 250 | all_m, correct_m, incorrect_m = frac_multi_label_preds([3,3,3,3], p, s) 251 | assert 1 == correct_m # All predicted and all correct 252 | assert 0 == incorrect_m 253 | assert all_m == correct_m + incorrect_m 254 | 255 | s = 0.11 256 | all_m, correct_m, incorrect_m = frac_multi_label_preds([0,0,0,3], p, s) 257 | assert all_m == correct_m + incorrect_m 258 | all_m, = frac_multi_label_preds(None, p, s) 259 | assert .25 == correct_m 260 | assert .75 == incorrect_m 261 | assert all_m == incorrect_m+correct_m 262 | 263 | class TestUnobMetrics(): 264 | 265 | @pytest.fixture 266 | def set_up(self): 267 | self.pvals_2 = np.array([ 268 | [0.05, 0.25], 269 | [0.45, 0.03], 270 | [0.65, 0.15], 271 | [0.03, 0.92] 272 | ]) 273 | 274 | self.pvals_3 = np.array([ 275 | [0.05, 0.25, 0.67], 276 | [0.45, 0.03, 0.8], 277 | [0.65, 0.15, 0.05], 278 | [0.03, 0.92, 0.76], 279 | [0.23, 0.02, 0.5] 280 | ]) 281 | 282 | self.pvals_4 = np.array([ 283 | [0.05, 0.25, 0.67, 0.2], 284 | [0.45, 0.03, 0.8, 0.1], 285 | [0.65, 0.15, 0.05, 0.5] 286 | ]) 287 | # Ulfs real-life data 288 | raw_data = np.genfromtxt(get_resource('transporters.p-values.csv'), delimiter=';', skip_header=1) 289 | self.p_values = raw_data[:,[2,3]] 290 | 291 | 292 | def test_single_label_preds(self, set_up): 293 | sign = 0.25 294 | assert pytest.approx(.974) == round(frac_single_label_preds(None, self.p_values, sign)[0],3) 295 | sign = 0.2 296 | assert pytest.approx(.893) == 
round(frac_single_label_preds(None,self.p_values, sign)[0],3) 297 | sign = 0.15 298 | assert pytest.approx(.79) == round(frac_single_label_preds(None,self.p_values, sign)[0],3) 299 | 300 | 301 | def test_f_criteria(self, set_up): 302 | assert pytest.approx(mean([.05, .03, .15,.03])) == f_criteria(self.pvals_2) 303 | assert pytest.approx(mean([.3, .48, .2, .79, .25])) == f_criteria(self.pvals_3) 304 | assert pytest.approx(mean([.5, .58, .7])) == f_criteria(self.pvals_4) 305 | 306 | 307 | def test_u_criterion(self, set_up): 308 | assert pytest.approx(mean([.05, .03, .15,.03])) == u_criterion(self.pvals_2) 309 | assert pytest.approx(mean([.25, .45, .15, .76, .23])) == u_criterion(self.pvals_3) 310 | assert pytest.approx(mean([.25, .45, .5])) == u_criterion(self.pvals_4) 311 | 312 | 313 | def test_n_criterion(self, set_up): 314 | # sig = 0.01 > all labels predicted! 315 | sig = 0.01 316 | assert 2 == n_criterion(self.pvals_2, sig) 317 | assert 3 == n_criterion(self.pvals_3, sig) 318 | assert 4 == n_criterion(self.pvals_4, sig) 319 | # sig = 0.1 - most labels predicted 320 | sig = 0.1 321 | assert pytest.approx(mean([1,1,2,1])) == n_criterion(self.pvals_2, sig) 322 | assert 2 == n_criterion(self.pvals_3, sig) 323 | assert mean([3,2,3]) == n_criterion(self.pvals_4, sig) 324 | # sig = 0.5 - few labels 325 | sig = 0.5 326 | assert pytest.approx(.5) == n_criterion(self.pvals_2, sig) 327 | assert pytest.approx(mean([1,1,1,2,0])) == n_criterion(self.pvals_3, sig) 328 | assert pytest.approx(1) == n_criterion(self.pvals_4, sig) 329 | # sig = 1.0 - no labels predicted 330 | sig = 1.0 331 | assert 0 == n_criterion(self.pvals_2, sig) 332 | assert 0 == n_criterion(self.pvals_3, sig) 333 | assert 0 == n_criterion(self.pvals_4, sig) 334 | 335 | 336 | def test_s_criterion(self, set_up): 337 | assert pytest.approx(mean([.3, .48, .8, .95])) == s_criterion(self.pvals_2) 338 | assert pytest.approx(mean([.97, 1.28, .85, 1.71, .75])) == s_criterion(self.pvals_3) 339 | assert 
pytest.approx(mean([1.17, 1.38, 1.35])) == s_criterion(self.pvals_4) 340 | 341 | 342 | def test_confidence(self, set_up): 343 | assert pytest.approx(mean([.95, .97, .85, .97])) == cp_confidence(self.pvals_2) 344 | assert pytest.approx(mean([.75, .55, .85, .24, .77])) == cp_confidence(self.pvals_3) 345 | assert pytest.approx(mean([.75, .55, .5])) == cp_confidence(self.pvals_4) 346 | 347 | 348 | def test_credibility(self, set_up): 349 | assert pytest.approx(mean([.25, .45,.65, .92])) == cp_credibility(self.pvals_2) 350 | assert pytest.approx(mean([.67, .8, .65, .92, .5])) == cp_credibility(self.pvals_3) 351 | assert pytest.approx(mean([.67, .8, .65])) == cp_credibility(self.pvals_4) 352 | 353 | 354 | 355 | -------------------------------------------------------------------------------- /python/src/pharmbio/cpsign/_load.py: -------------------------------------------------------------------------------- 1 | """Utility functions for loading and converting datasets 2 | """ 3 | 4 | import numpy as np 5 | from sklearn.utils import Bunch 6 | import pandas as pd 7 | import re 8 | from sklearn.utils import check_consistent_length 9 | 10 | 11 | def load_calib_stats(f, 12 | sep=',', 13 | overall_accuracy_regex = r'^accuracy$', 14 | overall_accuracy_sd_regex = r'^accuracy_SD$', 15 | accuracy_regex = r'^accuracy\((.*?)\)$', 16 | accuracy_sd_regex = r'^accuracy\((.*?)\)_SD$' 17 | ): 18 | """ 19 | Read a CSV formatted file with calibration statistics from CPSign 20 | 21 | Requires that the first column contains confidence levels and the accuracy values are picked out 22 | given the regex parameters. 
Depending on if there are standard-deviations given the output tuple will 23 | contain 3 or 5 values 24 | 25 | Parameters 26 | ---------- 27 | f : file or file path 28 | 29 | sep : column-separator, str, default ',' 30 | Character that separate columns in the CSV 31 | 32 | overall_accuracy_regex, overall_accuracy_sd_regex, accuracy_regex, accuracy_sd_regex : str or re.Pattern, optional 33 | Regex patterns for picking out accuracy values for each confidence level 34 | 35 | Returns 36 | ------- 37 | (sign_vals, error_rates, error_rates_SD, labels) 38 | significance levels, error-rates, standard-deviations of error-rate (may be `None` if not available in the file), labels 39 | for those. For regression the "labels" will be a single label 'Overall' as there are no error rates per class. 40 | The error-rates and error-rate-SD are 2d ndarray, where each column correspond to the same index in the labels-list 41 | 42 | """ 43 | # Convert to regex if not already given 44 | overall_accuracy_regex = __get_regex_or_None(overall_accuracy_regex) 45 | overall_accuracy_sd_regex = __get_regex_or_None(overall_accuracy_sd_regex) 46 | accuracy_regex = __get_regex_or_None(accuracy_regex) 47 | accuracy_sd_regex = __get_regex_or_None(accuracy_sd_regex) 48 | 49 | df = pd.read_csv(f,sep=sep) 50 | sign_vals = 1 - df.iloc[:,0].to_numpy() 51 | n_rows = len(sign_vals) 52 | 53 | # Create two lists, to make sure they are in the same order 54 | labels = [] 55 | sd_labels = [] 56 | accuracies = np.empty((n_rows,0)) 57 | accuracies_sd = np.empty((n_rows,0)) 58 | overall_sd = None # fallback if no SD column is given 59 | 60 | for i, c in enumerate(df.columns): 61 | if __matches(overall_accuracy_regex,c): 62 | overall = df.iloc[:,i].to_numpy().reshape((n_rows,1)) 63 | elif __matches(overall_accuracy_sd_regex,c): 64 | overall_sd = df.iloc[:,i].to_numpy().reshape((n_rows,1)) 65 | elif __matches(accuracy_regex,c): 66 | accuracies = np.hstack((accuracies, df.iloc[:,i].to_numpy().reshape(n_rows,1))) 67 | 
# Add label for this value 68 | labels.append(accuracy_regex.match(c).group(1)) 69 | elif __matches(accuracy_sd_regex, c): 70 | accuracies_sd = np.hstack((accuracies_sd, df.iloc[:,i].to_numpy().reshape(n_rows,1))) 71 | # Add label for this value 72 | sd_labels.append(accuracy_sd_regex.match(c).group(1)) 73 | # Overall stuff 74 | check_consistent_length(sign_vals,overall) 75 | 76 | # Only in case several accuracies (i.e. one per label) 77 | if accuracies.shape[1]>0: 78 | check_consistent_length(sign_vals,accuracies) 79 | if len(sd_labels)>0: 80 | check_consistent_length(sign_vals, accuracies_sd, overall_sd) 81 | if sd_labels is not None and sd_labels != labels: 82 | raise ValueError('Inconsistent input file, different labels for accuracies and SD versions') 83 | 84 | accuracies = np.hstack((overall, accuracies)) 85 | accuracies_sd = np.hstack((overall_sd,accuracies_sd)) 86 | labels = ['Overall'] + labels # pre-pend the overall label 87 | 88 | return sign_vals, 1 - accuracies, accuracies_sd, labels 89 | else: 90 | # Only overall accuracy given 91 | return sign_vals, 1-overall, overall_sd, ['Overall'] 92 | 93 | 94 | def load_reg_efficiency_stats(f, 95 | sep: str =',', 96 | skip_inf: bool = True, 97 | median_regex = r'.*median.*prediction.*interval.*width.*(?[^\]]+)\]'): 400 | """ 401 | Utility method for loading predictions from a file generated by CPSign 402 | 403 | Loads the true labels, the predicted p-values and the corresponding labels for those predictions 404 | 405 | Parameters 406 | ---------- 407 | f : str or buffer 408 | File path or buffer that `Pandas.read_csv` can read 409 | 410 | y_true_col : str or None 411 | The (case insensitive) column header of the true labels, or None if it should not be loaded 412 | 413 | sep : str, default ',' 414 | Delimiter that is used in the CSV between columns 415 | 416 | pvalue_regex : str, re.Pattern 417 | A regular expression for getting column headers matching those that should contain p-values, and that 418 | retrieves 
the textual label for the p-value 419 | 420 | Returns 421 | ------- 422 | (y_true, p_values, labels) 423 | The true labels (`y_true`), the predicted `p_values` and the corresponding `labels`. The labels are sorted 424 | in the same way as p-values matrix. Column 0 in `p_values` was the predictions for `labels[0]` etc. 425 | """ 426 | df = pd.read_csv(f, sep=sep) 427 | n_rows = len(df) 428 | 429 | pvalue_regex = __get_regex_or_None(pvalue_regex) 430 | y_label_lc = None if y_true_col is None else y_true_col.lower() 431 | 432 | y = None 433 | pvals = np.empty((n_rows,0)) 434 | labels = [] 435 | 436 | for i, c in enumerate(df.columns): 437 | if y_label_lc is not None and c.lower() == y_label_lc: 438 | y = df.iloc[:,i] 439 | elif __matches(pvalue_regex,c): 440 | pvals = np.hstack((pvals,df.iloc[:,i].to_numpy().reshape((n_rows,1)))) 441 | labels.append(pvalue_regex.match(c).group('label')) 442 | 443 | return y, pvals, labels 444 | 445 | 446 | def __get_regex_or_None(input): 447 | if input is None: 448 | return None 449 | if isinstance(input, re.Pattern): 450 | # Correct type already 451 | return input 452 | # Try to convert into re.Pattern 453 | return re.compile(input,re.IGNORECASE) 454 | 455 | 456 | def __matches(regex, txt): 457 | if regex is None: 458 | return False 459 | return regex.match(txt) is not None -------------------------------------------------------------------------------- /python/tests/resources/er.p-values.csv: -------------------------------------------------------------------------------- 1 | label,p[0],p[1] 2 | 0,0.952,0.0271 3 | 0,0.648,0.174 4 | 0,0.515,0.244 5 | 0,0.522,0.257 6 | 0,0.52,0.26 7 | 0,0.736,0.122 8 | 0,0.925,0.0515 9 | 0,0.885,0.0786 10 | 0,0.582,0.211 11 | 0,0.365,0.33 12 | 0,0.934,0.0645 13 | 0,0.765,0.127 14 | 1,0.532,0.24 15 | 0,0.867,0.0862 16 | 0,0.394,0.327 17 | 0,0.585,0.195 18 | 1,0.721,0.148 19 | 0,0.489,0.26 20 | 0,0.0174,0.796 21 | 0,0.354,0.33 22 | 1,0.713,0.174 23 | 0,0.944,0.0421 24 | 0,0.451,0.313 25 | 1,0.0289,0.733 
26 | 0,0.538,0.237 27 | 1,0.499,0.228 28 | 0,0.93,0.0558 29 | 0,0.606,0.204 30 | 0,0.0239,0.793 31 | 0,0.244,0.393 32 | 0,0.0442,0.7 33 | 1,0.0171,0.793 34 | 0,0.233,0.404 35 | 0,0.158,0.488 36 | 1,0.00658,1.0 37 | 0,0.326,0.403 38 | 0,0.0356,0.749 39 | 1,0.02,0.768 40 | 1,0.328,0.385 41 | 1,0.00658,0.95 42 | 0,0.0689,0.602 43 | 0,0.866,0.0814 44 | 0,0.181,0.453 45 | 0,0.702,0.171 46 | 1,0.0464,0.679 47 | 1,0.0219,0.802 48 | 1,0.529,0.237 49 | 0,0.202,0.447 50 | 0,0.909,0.0806 51 | 0,0.185,0.478 52 | 0,0.89,0.0583 53 | 0,0.926,0.0617 54 | 0,0.942,0.0418 55 | 0,0.149,0.484 56 | 0,0.26,0.391 57 | 0,0.884,0.083 58 | 0,0.768,0.119 59 | 1,0.099,0.572 60 | 1,0.101,0.558 61 | 0,0.675,0.162 62 | 0,0.504,0.296 63 | 0,0.633,0.169 64 | 0,0.125,0.493 65 | 0,0.842,0.0754 66 | 0,0.704,0.161 67 | 0,0.0602,0.597 68 | 0,0.902,0.0595 69 | 0,0.951,0.0424 70 | 0,0.292,0.397 71 | 0,0.22,0.419 72 | 0,0.0445,0.722 73 | 0,0.436,0.292 74 | 0,0.445,0.296 75 | 0,0.671,0.188 76 | 1,0.532,0.236 77 | 0,0.173,0.449 78 | 0,0.371,0.302 79 | 1,0.101,0.578 80 | 0,0.28,0.415 81 | 0,0.208,0.468 82 | 1,0.00942,0.836 83 | 0,0.786,0.115 84 | 0,0.537,0.255 85 | 0,0.198,0.458 86 | 0,0.739,0.134 87 | 0,0.269,0.409 88 | 1,0.0206,0.774 89 | 0,0.943,0.0339 90 | 0,0.915,0.0793 91 | 0,0.508,0.248 92 | 0,0.794,0.117 93 | 0,0.953,0.0496 94 | 0,0.834,0.0907 95 | 0,0.721,0.172 96 | 1,0.672,0.163 97 | 0,0.887,0.08 98 | 0,0.26,0.396 99 | 0,0.127,0.492 100 | 0,0.863,0.0659 101 | 1,0.068,0.656 102 | 0,0.104,0.56 103 | 0,0.862,0.0748 104 | 0,0.232,0.432 105 | 0,0.327,0.377 106 | 0,0.366,0.319 107 | 0,0.0899,0.555 108 | 0,0.464,0.221 109 | 0,0.249,0.381 110 | 0,0.162,0.451 111 | 0,0.573,0.128 112 | 0,0.683,0.136 113 | 0,0.0596,0.634 114 | 0,0.733,0.11 115 | 0,0.193,0.399 116 | 0,0.301,0.329 117 | 0,0.952,0.0435 118 | 0,0.929,0.0393 119 | 0,0.947,0.0349 120 | 0,0.195,0.398 121 | 1,0.13,0.481 122 | 0,0.712,0.106 123 | 0,0.709,0.121 124 | 0,0.681,0.123 125 | 0,0.202,0.434 126 | 0,0.307,0.311 127 | 0,0.383,0.278 128 | 
0,0.517,0.217 129 | 0,0.179,0.429 130 | 0,0.0837,0.575 131 | 1,0.467,0.23 132 | 1,0.501,0.211 133 | 0,0.515,0.227 134 | 0,0.0264,0.715 135 | 0,0.831,0.0762 136 | 1,0.127,0.5 137 | 0,0.543,0.202 138 | 0,0.639,0.152 139 | 0,0.347,0.296 140 | 1,0.362,0.308 141 | 0,0.365,0.307 142 | 0,0.488,0.199 143 | 0,0.73,0.107 144 | 0,0.309,0.312 145 | 0,0.728,0.108 146 | 0,0.549,0.203 147 | 0,0.65,0.127 148 | 1,0.502,0.234 149 | 1,0.0548,0.62 150 | 1,0.0125,0.787 151 | 0,0.947,0.0309 152 | 0,0.169,0.448 153 | 0,0.504,0.22 154 | 1,0.00881,0.812 155 | 0,0.0176,0.754 156 | 0,0.248,0.377 157 | 0,0.199,0.452 158 | 0,0.0787,0.552 159 | 0,0.511,0.185 160 | 1,0.246,0.347 161 | 0,0.902,0.0503 162 | 1,0.0516,0.635 163 | 0,0.879,0.0661 164 | 0,0.89,0.0646 165 | 0,0.314,0.305 166 | 0,0.615,0.149 167 | 0,0.867,0.0653 168 | 1,0.997,0.0256 169 | 0,0.519,0.201 170 | 0,0.596,0.144 171 | 1,0.874,0.0576 172 | 0,0.891,0.0537 173 | 0,0.531,0.187 174 | 0,0.7,0.118 175 | 0,0.846,0.0749 176 | 0,0.715,0.12 177 | 0,0.0594,0.633 178 | 0,0.786,0.0942 179 | 0,0.0801,0.554 180 | 1,0.392,0.299 181 | 0,0.803,0.0821 182 | 0,0.726,0.129 183 | 0,0.888,0.0627 184 | 0,0.509,0.224 185 | 0,0.861,0.0704 186 | 0,0.738,0.124 187 | 0,0.46,0.248 188 | 1,0.00872,0.823 189 | 0,0.00659,0.818 190 | 0,0.435,0.247 191 | 0,0.48,0.192 192 | 0,0.371,0.253 193 | 1,0.109,0.536 194 | 0,0.462,0.235 195 | 0,0.38,0.248 196 | 0,0.517,0.202 197 | 0,0.399,0.277 198 | 0,0.125,0.506 199 | 0,0.668,0.144 200 | 1,0.00658,0.893 201 | 0,0.131,0.475 202 | 0,0.0744,0.585 203 | 0,0.0263,0.701 204 | 1,0.454,0.247 205 | 1,0.0237,0.712 206 | 1,0.0171,0.774 207 | 1,0.0123,0.781 208 | 0,0.724,0.103 209 | 0,0.489,0.205 210 | 0,0.542,0.201 211 | 0,0.808,0.091 212 | 0,0.447,0.298 213 | 0,0.668,0.142 214 | 0,0.953,0.0397 215 | 0,0.422,0.311 216 | 0,0.563,0.226 217 | 0,0.59,0.181 218 | 1,0.0748,0.61 219 | 0,0.766,0.0865 220 | 0,0.771,0.101 221 | 1,0.0232,0.751 222 | 1,0.284,0.383 223 | 0,0.979,0.0256 224 | 0,0.262,0.384 225 | 0,0.0349,0.694 226 | 0,0.102,0.561 
227 | 1,0.187,0.458 228 | 0,0.528,0.248 229 | 0,0.498,0.29 230 | 0,0.648,0.182 231 | 0,0.609,0.218 232 | 0,0.274,0.387 233 | 0,0.949,0.0367 234 | 0,0.617,0.215 235 | 1,0.00658,0.944 236 | 0,0.493,0.29 237 | 0,0.475,0.297 238 | 0,0.0432,0.669 239 | 0,0.817,0.089 240 | 1,0.412,0.349 241 | 0,0.99,0.0256 242 | 0,0.415,0.308 243 | 0,0.245,0.409 244 | 1,0.965,0.0256 245 | 0,0.178,0.46 246 | 1,0.00809,0.862 247 | 0,0.586,0.221 248 | 0,0.989,0.0256 249 | 1,0.00658,0.93 250 | 1,0.332,0.375 251 | 0,0.561,0.214 252 | 1,0.00658,0.904 253 | 1,0.00658,0.939 254 | 1,0.00658,0.941 255 | 0,0.251,0.423 256 | 0,0.128,0.537 257 | 0,0.208,0.441 258 | 0,0.096,0.527 259 | 0,0.307,0.367 260 | 0,0.971,0.0281 261 | 0,0.108,0.546 262 | 0,0.825,0.0868 263 | 0,0.406,0.305 264 | 1,0.154,0.491 265 | 1,0.00658,0.959 266 | 0,0.432,0.333 267 | 0,0.532,0.249 268 | 0,0.615,0.199 269 | 0,0.8,0.0926 270 | 1,0.32,0.376 271 | 0,0.942,0.0414 272 | 0,0.872,0.0641 273 | 0,0.861,0.0657 274 | 0,0.906,0.0576 275 | 0,0.489,0.256 276 | 0,0.374,0.341 277 | 0,0.887,0.0677 278 | 0,0.916,0.0461 279 | 0,0.462,0.281 280 | 0,0.0731,0.585 281 | 0,0.54,0.227 282 | 0,0.739,0.12 283 | 0,0.461,0.3 284 | 0,0.815,0.0848 285 | 0,0.426,0.332 286 | 0,0.322,0.37 287 | 0,0.54,0.286 288 | 0,0.456,0.337 289 | 1,0.0212,0.751 290 | 0,0.838,0.0698 291 | 0,0.918,0.0435 292 | 0,0.618,0.196 293 | 0,0.606,0.19 294 | 0,0.594,0.211 295 | 1,0.0464,0.664 296 | 0,0.293,0.368 297 | 0,0.594,0.194 298 | 0,0.755,0.142 299 | 0,0.632,0.177 300 | 0,0.614,0.224 301 | 1,0.0169,0.764 302 | 0,0.803,0.101 303 | 1,0.00658,0.949 304 | 0,0.308,0.367 305 | 0,0.747,0.112 306 | 0,0.261,0.371 307 | 0,0.457,0.287 308 | 0,0.957,0.0422 309 | 0,0.152,0.451 310 | 0,1.0,0.0256 311 | 0,0.493,0.277 312 | 0,0.636,0.181 313 | 0,0.522,0.263 314 | 0,0.669,0.168 315 | 0,0.63,0.207 316 | 1,0.382,0.341 317 | 0,0.233,0.455 318 | 0,0.529,0.257 319 | 0,0.891,0.0797 320 | 0,0.744,0.136 321 | 0,0.387,0.325 322 | 0,0.78,0.113 323 | 1,0.00658,0.956 324 | 0,0.935,0.0538 325 | 
0,0.401,0.329 326 | 1,0.0519,0.681 327 | 0,0.817,0.0809 328 | 0,0.152,0.553 329 | 0,0.639,0.201 330 | 0,0.697,0.16 331 | 1,0.2,0.498 332 | 1,0.44,0.302 333 | 0,0.322,0.33 334 | 0,0.287,0.404 335 | 0,0.454,0.279 336 | 0,0.631,0.188 337 | 1,0.00835,0.88 338 | 0,0.628,0.237 339 | 1,0.066,0.673 340 | 0,0.399,0.332 341 | 1,0.00658,0.946 342 | 0,0.243,0.447 343 | 1,0.0109,0.886 344 | 1,0.012,0.874 345 | 0,0.0187,0.853 346 | 0,0.851,0.0938 347 | 0,0.159,0.552 348 | 0,0.835,0.0913 349 | 0,0.116,0.584 350 | 0,0.0874,0.62 351 | 0,0.0398,0.766 352 | 0,0.771,0.122 353 | 0,0.136,0.559 354 | 1,0.451,0.287 355 | 0,0.442,0.287 356 | 1,0.00658,0.966 357 | 0,0.378,0.314 358 | 0,0.706,0.147 359 | 0,0.226,0.466 360 | 1,0.24,0.418 361 | 0,0.765,0.142 362 | 0,0.349,0.394 363 | 0,0.531,0.251 364 | 0,0.821,0.0989 365 | 0,0.0478,0.705 366 | 0,0.923,0.0677 367 | 0,0.893,0.0724 368 | 0,0.603,0.23 369 | 1,0.48,0.275 370 | 0,0.308,0.39 371 | 0,1.0,0.0263 372 | 0,0.347,0.351 373 | 1,0.155,0.52 374 | 1,0.00658,0.957 375 | 0,0.737,0.18 376 | 0,0.634,0.177 377 | 0,0.0268,0.774 378 | 0,0.171,0.546 379 | 0,0.769,0.129 380 | 0,0.676,0.151 381 | 0,0.799,0.102 382 | 0,0.623,0.232 383 | 1,0.506,0.273 384 | 0,0.494,0.244 385 | 0,0.439,0.283 386 | 0,0.618,0.192 387 | 0,0.479,0.278 388 | 0,0.685,0.14 389 | 0,0.215,0.456 390 | 0,0.168,0.536 391 | 0,0.888,0.0905 392 | 0,0.111,0.642 393 | 0,0.213,0.511 394 | 0,0.893,0.0858 395 | 0,0.814,0.0987 396 | 0,0.587,0.202 397 | 0,0.193,0.497 398 | 1,0.0114,0.869 399 | 0,0.546,0.215 400 | 1,0.515,0.224 401 | 0,0.377,0.319 402 | 0,0.542,0.223 403 | 0,0.965,0.0541 404 | 0,0.864,0.101 405 | 1,0.421,0.286 406 | 0,0.312,0.387 407 | 1,0.426,0.347 408 | 0,0.438,0.298 409 | 0,0.421,0.285 410 | 0,0.766,0.119 411 | 0,0.338,0.38 412 | 0,0.258,0.379 413 | 0,0.987,0.0341 414 | 0,0.987,0.0339 415 | 0,0.0696,0.669 416 | 0,0.557,0.238 417 | 0,0.465,0.296 418 | 1,0.607,0.255 419 | 0,0.506,0.255 420 | 0,0.501,0.258 421 | 0,0.936,0.0662 422 | 0,0.592,0.194 423 | 0,0.126,0.517 424 | 
0,0.755,0.104 425 | 0,0.58,0.224 426 | 0,0.82,0.065 427 | 0,0.106,0.534 428 | 1,0.304,0.373 429 | 1,0.941,0.0416 430 | 0,0.118,0.541 431 | 0,0.304,0.366 432 | 0,0.494,0.229 433 | 0,0.581,0.158 434 | 0,0.183,0.442 435 | 1,0.0194,0.788 436 | 1,0.00658,0.961 437 | 0,0.6,0.175 438 | 0,0.403,0.273 439 | 1,0.864,0.0597 440 | 0,0.909,0.0495 441 | 0,0.464,0.267 442 | 0,0.0892,0.572 443 | 1,0.456,0.264 444 | 0,0.464,0.248 445 | 0,0.651,0.127 446 | 1,0.135,0.535 447 | 0,0.738,0.0912 448 | 0,0.135,0.494 449 | 0,0.294,0.394 450 | 0,0.672,0.117 451 | 0,0.488,0.237 452 | 0,0.831,0.0781 453 | 0,0.934,0.0425 454 | 0,0.0284,0.734 455 | 0,0.447,0.246 456 | 1,0.00658,1.0 457 | 0,0.492,0.257 458 | 1,0.495,0.214 459 | 1,0.0596,0.614 460 | 0,0.325,0.346 461 | 1,0.298,0.462 462 | 0,0.311,0.344 463 | 1,0.00658,0.935 464 | 1,0.00658,0.963 465 | 0,0.0501,0.656 466 | 0,0.109,0.539 467 | 0,0.625,0.15 468 | 0,0.11,0.532 469 | 0,0.331,0.392 470 | 0,0.799,0.0729 471 | 0,0.269,0.405 472 | 0,0.432,0.316 473 | 0,0.522,0.202 474 | 0,0.16,0.501 475 | 0,0.0489,0.674 476 | 1,0.545,0.198 477 | 0,0.429,0.28 478 | 0,0.149,0.508 479 | 0,0.0178,0.79 480 | 0,0.377,0.31 481 | 0,0.847,0.0654 482 | 0,0.717,0.115 483 | 0,0.634,0.174 484 | 1,0.342,0.353 485 | 0,0.976,0.0267 486 | 0,0.831,0.0765 487 | 0,0.403,0.259 488 | 0,0.552,0.18 489 | 0,0.308,0.35 490 | 0,0.261,0.395 491 | 0,0.257,0.396 492 | 0,0.79,0.0764 493 | 0,0.945,0.0464 494 | 0,0.626,0.153 495 | 0,0.183,0.486 496 | 0,0.402,0.32 497 | 0,0.811,0.0963 498 | 1,0.671,0.124 499 | 0,0.302,0.379 500 | 0,0.364,0.334 501 | 0,0.251,0.416 502 | 0,0.303,0.402 503 | 0,0.559,0.211 504 | 0,0.878,0.0508 505 | 0,0.886,0.0522 506 | 0,0.547,0.197 507 | 0,0.435,0.235 508 | 1,0.254,0.405 509 | 0,0.545,0.228 510 | 0,0.557,0.196 511 | 0,0.644,0.173 512 | 0,0.565,0.171 513 | 0,0.615,0.194 514 | 1,0.444,0.249 515 | 1,0.129,0.519 516 | 0,0.714,0.146 517 | 0,0.516,0.215 518 | 0,0.987,0.0256 519 | 0,0.749,0.129 520 | 0,0.0136,0.797 521 | 0,0.197,0.449 522 | 1,0.00658,0.908 523 | 
0,0.256,0.401 524 | 0,0.198,0.442 525 | 0,0.129,0.513 526 | 1,0.283,0.341 527 | 0,0.97,0.0256 528 | 0,0.667,0.17 529 | 0,0.968,0.0282 530 | 0,0.344,0.339 531 | 0,0.0613,0.626 532 | 0,0.804,0.0935 533 | 0,0.741,0.119 534 | 0,0.505,0.224 535 | 0,0.717,0.127 536 | 0,0.775,0.103 537 | 0,0.935,0.0416 538 | 0,0.756,0.112 539 | 0,0.382,0.267 540 | 0,0.591,0.176 541 | 0,0.522,0.245 542 | 0,1.0,0.0256 543 | 0,0.973,0.0256 544 | 0,0.869,0.0738 545 | 0,0.339,0.328 546 | 0,0.212,0.445 547 | 1,0.0137,0.807 548 | 1,0.0136,0.792 549 | 0,0.0329,0.691 550 | 0,0.796,0.113 551 | 0,0.582,0.204 552 | 0,0.334,0.405 553 | 0,0.835,0.0695 554 | 0,0.549,0.19 555 | 1,0.0583,0.592 556 | 1,0.116,0.505 557 | 1,0.0413,0.68 558 | 0,0.386,0.278 559 | 0,0.831,0.0723 560 | 0,0.0457,0.656 561 | 1,0.709,0.136 562 | 0,0.515,0.221 563 | 0,0.712,0.138 564 | 0,0.448,0.281 565 | 1,0.704,0.147 566 | 0,0.879,0.0601 567 | 0,0.101,0.55 568 | 0,0.769,0.0966 569 | 1,0.0932,0.556 570 | 0,0.263,0.337 571 | 0,0.393,0.275 572 | 0,0.335,0.299 573 | 1,0.0529,0.624 574 | 0,0.559,0.194 575 | 0,0.61,0.161 576 | 0,0.0576,0.625 577 | 0,0.294,0.359 578 | 0,0.46,0.257 579 | 0,0.45,0.241 580 | 0,0.323,0.337 581 | 0,0.19,0.43 582 | 0,0.275,0.362 583 | 1,0.0256,0.717 584 | 1,0.55,0.201 585 | 0,0.947,0.0264 586 | 0,0.912,0.062 587 | 0,0.145,0.507 588 | 0,0.744,0.12 589 | 0,0.905,0.0626 590 | 0,0.737,0.116 591 | 0,0.0894,0.55 592 | 0,0.921,0.0511 593 | 0,0.904,0.0502 594 | 0,0.891,0.0501 595 | 0,0.743,0.119 596 | 1,0.199,0.436 597 | 0,0.428,0.274 598 | 1,0.582,0.182 599 | 0,0.354,0.323 600 | 1,0.437,0.216 601 | 0,0.132,0.447 602 | 1,0.309,0.312 603 | 0,0.483,0.255 604 | 0,0.743,0.102 605 | 0,0.344,0.263 606 | 0,0.0146,0.77 607 | 0,0.66,0.157 608 | 0,0.717,0.129 609 | 0,0.45,0.267 610 | 0,0.79,0.112 611 | 0,0.677,0.139 612 | 0,0.584,0.16 613 | 0,0.434,0.256 614 | 1,0.661,0.151 615 | 0,0.284,0.337 616 | 1,0.116,0.521 617 | 1,0.385,0.254 618 | 0,0.586,0.192 619 | 0,0.687,0.12 620 | 0,0.553,0.187 621 | 0,0.444,0.235 622 | 
0,0.333,0.293 623 | 0,0.358,0.305 624 | 1,0.0588,0.65 625 | 1,0.0832,0.541 626 | 0,0.465,0.225 627 | 0,0.914,0.0521 628 | 1,0.494,0.198 629 | 0,0.52,0.223 630 | 0,0.868,0.0769 631 | 0,0.416,0.241 632 | 0,0.525,0.184 633 | 1,0.0175,0.735 634 | 0,0.453,0.189 635 | 0,0.701,0.115 636 | 0,0.154,0.461 637 | 0,0.956,0.0334 638 | 0,0.286,0.315 639 | 0,0.467,0.226 640 | 0,0.524,0.182 641 | 0,0.406,0.262 642 | 1,0.143,0.445 643 | 0,0.762,0.115 644 | 0,0.521,0.156 645 | 1,0.389,0.267 646 | 0,0.417,0.244 647 | 0,0.0127,0.747 648 | 0,0.479,0.234 649 | 0,0.858,0.0734 650 | 1,0.748,0.131 651 | 0,0.443,0.226 652 | 0,0.551,0.185 653 | 0,0.271,0.362 654 | 0,0.883,0.0488 655 | 1,0.0196,0.711 656 | 0,0.979,0.0256 657 | 0,0.424,0.263 658 | 0,0.255,0.344 659 | 1,0.00764,0.844 660 | 0,0.812,0.0995 661 | 0,0.47,0.212 662 | 0,0.962,0.0256 663 | 1,0.0172,0.723 664 | 0,0.447,0.216 665 | 0,0.204,0.414 666 | 0,0.635,0.146 667 | 1,0.268,0.328 668 | 0,0.538,0.174 669 | 1,0.0074,0.856 670 | 0,0.324,0.323 671 | 0,0.411,0.215 672 | 1,0.0102,0.79 673 | 0,0.201,0.423 674 | 0,0.563,0.166 675 | 0,0.0987,0.502 676 | 1,0.00658,0.937 677 | 1,0.0984,0.521 678 | 0,0.522,0.187 679 | 1,0.731,0.108 680 | 0,0.0067,0.845 681 | 0,0.602,0.168 682 | 0,0.0579,0.605 683 | 0,0.11,0.541 684 | 0,0.628,0.164 685 | 0,0.00658,0.86 686 | 1,0.0199,0.751 687 | 0,0.148,0.466 688 | 0,0.99,0.0256 689 | 0,0.148,0.452 690 | 0,0.648,0.155 691 | 0,0.156,0.464 692 | 0,0.152,0.437 693 | 0,0.0981,0.487 694 | 0,0.139,0.473 695 | 0,0.577,0.166 696 | 0,0.27,0.316 697 | 0,0.48,0.182 698 | 0,0.801,0.107 699 | 1,0.629,0.161 700 | 0,0.405,0.227 701 | 0,0.735,0.123 702 | 0,0.741,0.118 703 | 1,0.0169,0.757 704 | 0,0.938,0.0435 705 | 0,0.625,0.167 706 | 0,0.721,0.152 707 | 0,0.882,0.0581 708 | 0,0.691,0.136 709 | 0,0.48,0.195 710 | 0,0.17,0.4 711 | 0,0.555,0.134 712 | 1,0.665,0.154 713 | 0,0.714,0.12 714 | 0,0.436,0.223 715 | 1,0.89,0.07 716 | 0,0.00658,0.885 717 | 0,0.861,0.0688 718 | 1,0.194,0.419 719 | 1,0.0232,0.706 720 | 0,0.547,0.156 721 | 
0,0.0187,0.695 722 | 0,0.401,0.266 723 | 0,0.107,0.525 724 | 0,0.196,0.411 725 | 1,0.979,0.0256 726 | 0,0.798,0.106 727 | 0,0.384,0.266 728 | 0,0.589,0.17 729 | 0,0.376,0.266 730 | 0,0.418,0.259 731 | 0,0.0435,0.601 732 | 1,0.173,0.417 733 | 0,0.194,0.394 734 | 0,0.154,0.422 735 | 0,0.833,0.102 736 | 0,0.553,0.189 737 | 0,0.616,0.163 738 | 0,0.687,0.114 739 | 1,0.203,0.403 740 | 1,0.141,0.508 741 | 0,0.0296,0.654 742 | 1,0.0108,0.802 743 | 0,0.622,0.132 744 | 0,0.849,0.04 745 | 1,0.0922,0.558 746 | 0,0.314,0.372 747 | 0,0.426,0.291 748 | 0,0.577,0.184 749 | 1,0.406,0.291 750 | 0,0.204,0.454 751 | 0,0.473,0.266 752 | 0,0.861,0.0496 753 | 0,0.584,0.148 754 | 1,0.0373,0.648 755 | 0,0.584,0.209 756 | 1,0.352,0.35 757 | 0,0.66,0.132 758 | 0,0.91,0.0432 759 | 0,0.708,0.121 760 | 0,0.387,0.331 761 | 0,0.198,0.439 762 | 0,0.235,0.425 763 | 0,0.281,0.349 764 | 1,0.043,0.66 765 | 0,0.478,0.271 766 | 0,0.862,0.0588 767 | 0,0.385,0.344 768 | 0,0.0349,0.678 769 | 1,0.575,0.193 770 | 0,0.375,0.281 771 | 0,0.133,0.531 772 | 1,0.829,0.0584 773 | 1,0.00658,0.875 774 | 1,0.0301,0.664 775 | 1,0.0522,0.629 776 | 0,0.279,0.399 777 | 1,0.504,0.238 778 | 0,0.0545,0.633 779 | 0,0.481,0.29 780 | 0,0.367,0.343 781 | 0,0.348,0.348 782 | 0,0.0902,0.567 783 | 0,0.161,0.526 784 | 0,0.424,0.326 785 | 0,0.406,0.31 786 | 0,0.0566,0.654 787 | 0,0.184,0.5 788 | 0,0.0973,0.568 789 | 0,0.0804,0.593 790 | 0,0.841,0.067 791 | 0,0.844,0.0718 792 | 0,0.0525,0.662 793 | 1,0.0272,0.695 794 | 0,0.372,0.348 795 | 0,0.784,0.0742 796 | 1,0.00658,1.0 797 | 0,0.57,0.209 798 | 0,0.796,0.0834 799 | 0,0.823,0.066 800 | 0,0.0294,0.718 801 | 0,0.409,0.296 802 | 0,0.248,0.395 803 | 0,0.709,0.13 804 | 0,0.598,0.16 805 | 0,0.896,0.0451 806 | 0,0.821,0.0816 807 | 0,0.739,0.111 808 | 0,0.911,0.0434 809 | 0,0.453,0.276 810 | 0,0.415,0.323 811 | 0,0.0244,0.701 812 | 0,0.822,0.0627 813 | 0,0.464,0.284 814 | 0,0.772,0.0899 815 | 1,0.467,0.275 816 | 0,0.202,0.46 817 | 0,0.305,0.346 818 | 0,0.894,0.0608 819 | 0,0.605,0.142 820 | 
0,0.909,0.0364 821 | 0,0.467,0.245 822 | 0,0.942,0.0395 823 | 0,0.441,0.304 824 | 0,0.415,0.293 825 | 0,0.119,0.521 826 | 0,0.445,0.273 827 | 0,0.0918,0.591 828 | 0,0.511,0.266 829 | 0,0.231,0.441 830 | 0,0.88,0.0468 831 | 1,0.00658,0.929 832 | 0,0.914,0.044 833 | 0,0.954,0.0256 834 | 0,0.362,0.323 835 | 1,0.632,0.128 836 | 1,0.0507,0.644 837 | 1,0.155,0.467 838 | 0,0.521,0.248 839 | 0,0.811,0.0827 840 | 0,0.825,0.0687 841 | 0,0.703,0.104 842 | 0,0.406,0.294 843 | 0,0.654,0.151 844 | 0,0.473,0.239 845 | 1,0.333,0.295 846 | 0,0.228,0.414 847 | 0,0.265,0.36 848 | 0,0.666,0.147 849 | 0,0.59,0.199 850 | 1,0.438,0.243 851 | 0,0.88,0.0707 852 | 0,0.808,0.0985 853 | 0,0.528,0.224 854 | 0,0.901,0.0619 855 | 0,0.435,0.272 856 | 0,0.89,0.0665 857 | 0,0.567,0.206 858 | 0,0.682,0.123 859 | 0,0.721,0.1 860 | 0,0.112,0.534 861 | 0,0.143,0.464 862 | 0,0.216,0.386 863 | 0,0.937,0.0464 864 | 0,0.532,0.189 865 | 0,0.379,0.26 866 | 0,0.172,0.425 867 | 1,0.00658,0.892 868 | 0,0.744,0.108 869 | 0,0.553,0.193 870 | 0,0.398,0.267 871 | 0,0.198,0.448 872 | 1,0.00658,0.924 873 | 0,0.0283,0.748 874 | 0,0.195,0.437 875 | 0,0.904,0.0696 876 | 0,0.118,0.512 877 | 0,0.869,0.0624 878 | 1,0.00658,0.887 879 | 0,0.314,0.315 880 | 1,0.827,0.103 881 | 1,0.934,0.048 882 | 0,0.372,0.269 883 | 0,0.407,0.259 884 | 1,0.396,0.258 885 | 1,0.0261,0.759 886 | 0,0.378,0.276 887 | 0,0.624,0.159 888 | 0,0.184,0.416 889 | 1,0.265,0.376 890 | 1,0.0899,0.561 891 | 0,0.341,0.313 892 | 1,0.72,0.111 893 | 0,0.282,0.344 894 | 0,0.193,0.392 895 | 1,0.832,0.103 896 | 0,0.672,0.16 897 | 0,0.38,0.281 898 | 0,0.14,0.504 899 | 0,0.13,0.509 900 | 0,0.763,0.0997 901 | 1,0.0255,0.755 902 | 0,0.777,0.0962 903 | 0,0.784,0.0854 904 | 1,0.763,0.0954 905 | 0,0.749,0.0763 906 | 0,0.648,0.16 907 | 0,0.771,0.0976 908 | 0,0.914,0.0679 909 | 0,0.402,0.269 910 | 0,0.518,0.21 911 | 1,0.245,0.357 912 | 0,0.593,0.171 913 | 0,0.591,0.194 914 | 0,0.754,0.102 915 | 0,0.572,0.185 916 | 0,0.475,0.271 917 | 0,0.474,0.265 918 | 0,0.501,0.231 919 | 
1,0.427,0.251 920 | 0,0.283,0.374 921 | 0,0.513,0.206 922 | 0,0.25,0.398 923 | 0,0.279,0.373 924 | 0,0.0906,0.554 925 | 0,0.288,0.305 926 | 0,0.325,0.31 927 | 0,0.625,0.168 928 | 0,0.577,0.225 929 | 0,0.501,0.235 930 | 1,0.443,0.241 931 | 0,0.134,0.537 932 | 0,0.798,0.0943 933 | 0,0.0516,0.657 934 | 0,0.747,0.115 935 | 0,0.221,0.41 936 | 1,0.814,0.0926 937 | 0,0.661,0.134 938 | 1,0.627,0.184 939 | 0,0.4,0.286 940 | 0,0.318,0.292 941 | 0,0.993,0.0256 942 | 0,0.506,0.212 943 | 0,0.0513,0.671 944 | 0,0.235,0.392 945 | 0,0.125,0.56 946 | 1,0.565,0.206 947 | 0,0.552,0.209 948 | 0,0.865,0.0724 949 | 1,0.208,0.451 950 | 0,0.201,0.44 951 | 1,0.971,0.0256 952 | 0,0.637,0.189 953 | 0,0.506,0.268 954 | 0,0.59,0.215 955 | 0,0.76,0.121 956 | 1,0.112,0.512 957 | 0,0.131,0.495 958 | 0,0.397,0.321 959 | 1,0.175,0.484 960 | 0,0.0651,0.576 961 | 0,0.798,0.0989 962 | 0,0.899,0.0438 963 | 1,0.0121,0.826 964 | 0,0.857,0.0589 965 | 0,0.213,0.443 966 | 0,0.292,0.365 967 | 0,0.395,0.298 968 | 1,0.0218,0.746 969 | 1,0.497,0.3 970 | 0,0.404,0.301 971 | 1,0.0629,0.633 972 | 1,0.0279,0.702 973 | 0,0.68,0.199 974 | 0,0.305,0.372 975 | 0,0.167,0.475 976 | 0,0.269,0.39 977 | 0,0.839,0.07 978 | 0,0.113,0.53 979 | 0,0.0225,0.766 980 | 0,0.109,0.554 981 | 0,0.987,0.0256 982 | 0,0.1,0.52 983 | 1,0.0884,0.536 984 | 0,0.282,0.425 985 | 1,0.00658,0.919 986 | 0,0.623,0.201 987 | 1,0.0211,0.729 988 | 0,0.46,0.268 989 | 0,0.988,0.0256 990 | 1,0.112,0.501 991 | 0,0.593,0.196 992 | 0,0.489,0.28 993 | 0,0.792,0.105 994 | 0,0.792,0.0861 995 | 0,0.561,0.253 996 | 0,0.278,0.373 997 | 0,0.412,0.287 998 | 0,0.807,0.087 999 | 0,0.759,0.107 1000 | 0,0.262,0.369 1001 | 0,0.516,0.265 1002 | 0,0.249,0.371 1003 | 0,0.216,0.431 1004 | 0,0.0817,0.572 1005 | 0,0.075,0.556 1006 | 1,0.77,0.115 1007 | 0,0.614,0.177 1008 | 0,0.838,0.0711 1009 | 0,0.396,0.308 1010 | 0,0.801,0.0806 1011 | 0,0.408,0.3 1012 | 0,0.143,0.506 1013 | 0,0.75,0.125 1014 | 0,0.0915,0.537 1015 | 0,0.264,0.381 1016 | 0,0.6,0.161 1017 | 0,0.623,0.173 1018 
| 0,0.667,0.186 1019 | 0,0.224,0.401 1020 | 0,0.52,0.272 1021 | 0,0.035,0.698 1022 | 0,0.191,0.472 1023 | 0,0.0534,0.672 1024 | 0,0.11,0.521 1025 | 1,0.0472,0.658 1026 | 0,0.633,0.176 1027 | 0,0.569,0.217 1028 | 0,0.616,0.182 1029 | 0,0.73,0.13 1030 | 0,0.828,0.0757 1031 | 0,0.751,0.123 1032 | 0,0.438,0.281 1033 | 0,0.528,0.22 1034 | 0,0.863,0.0799 1035 | 0,0.0942,0.546 1036 | 0,0.798,0.0745 1037 | 0,0.803,0.0839 1038 | 0,0.446,0.276 1039 | 0,0.119,0.522 1040 | 1,0.54,0.257 1041 | 0,0.346,0.352 1042 | 0,0.717,0.143 1043 | 0,0.716,0.123 1044 | 1,0.0202,0.76 1045 | 0,0.258,0.373 1046 | 1,0.168,0.467 1047 | 1,0.56,0.208 1048 | 1,0.468,0.274 1049 | 0,0.596,0.166 1050 | 0,0.00658,0.937 1051 | 1,0.11,0.499 1052 | 0,0.543,0.242 1053 | 0,0.491,0.247 1054 | -------------------------------------------------------------------------------- /python/src/pharmbio/cp/metrics/_classification.py: -------------------------------------------------------------------------------- 1 | """CP Classification metrics 2 | 3 | Module with classification metrics for CP. See https://arxiv.org/abs/1603.04416 4 | for references. Note that some metrics are 'unobserved' - i.e. a metric 5 | that can be calculated without knowing the ground truth (correct) labels 6 | for all predictions. 
7 | 8 | """ 9 | 10 | import numpy as np 11 | import pandas as pd 12 | from collections import Counter 13 | 14 | from ..utils import * 15 | from sklearn.utils import check_consistent_length 16 | 17 | _default_significance = 0.8 18 | 19 | 20 | ###################################### 21 | ### OBSERVED METRICS 22 | ###################################### 23 | 24 | def frac_error(y_true, p_values, sign): 25 | """**Classification** - Calculate the fraction of errors 26 | 27 | Calculate the fraction of erroneous predictions at a given significance level `sign` 28 | 29 | Parameters 30 | ---------- 31 | y_true : 1D numpy array, list or pandas Series 32 | True labels 33 | 34 | p_values : 2D numpy array or DataFrame 35 | The predicted p-values, first column for the class 0, second for class 1, .. 36 | 37 | sign : float in [0,1] 38 | Significance the metric should be calculated for 39 | 40 | Returns 41 | ------- 42 | frac_error : float 43 | Overall fraction of errors 44 | 45 | label_wise_fraction_error : array, shape = (n_classes,) 46 | Fraction of errors for each true label, first index for class 0, ... 47 | 48 | See Also 49 | -------- 50 | frac_errors : caculate error rates for a list of significance levels at the same time - much faster! 51 | 52 | .. 
deprecated:: 53 | Use `frac_errors` instead as it uses vector functions and is roughly 30 times faster to compute 54 | """ 55 | validate_sign(sign) 56 | p_values = to_numpy2D(p_values,'p_values') 57 | y_true = to_numpy1D_int(y_true, 'y_true') 58 | 59 | check_consistent_length(y_true, p_values) 60 | 61 | total_errors = 0 62 | # lists containing errors/counts for each class label 63 | label_wise_errors = np.zeros(p_values.shape[1], dtype=float) # get an exception if not having float 64 | label_wise_counts = np.zeros(p_values.shape[1], dtype=int) 65 | 66 | for test_ex in range(0,p_values.shape[0]): 67 | ex_value = y_true[test_ex] 68 | #print(ex_value) 69 | if p_values[test_ex, ex_value] < sign: 70 | total_errors += 1 71 | label_wise_errors[ex_value] += 1 72 | label_wise_counts[ex_value] += 1 73 | 74 | label_wise_errors = np.divide(label_wise_errors, label_wise_counts, 75 | out=np.full_like(label_wise_errors,np.nan), 76 | where=np.array(label_wise_counts)!=0) 77 | 78 | return float(total_errors) / y_true.shape[0], label_wise_errors 79 | 80 | def frac_errors(y_true,p_values,sign_vals): 81 | """**Classification:** Calculate the fraction of errors for each significance level 82 | Parameters 83 | ---------- 84 | y_true : 1D numpy array, list or pandas Series 85 | True labels 86 | 87 | p_values : 2D numpy array or DataFrame 88 | The predicted p-values, first column for the class 0, second for class 1, .. 
89 | 90 | sign : float in [0,1] 91 | Significance the metric should be calculated for 92 | 93 | Returns 94 | ------- 95 | (array, matrix) 96 | array : numpy 1D array 97 | Overall error rates, will have the same length as `sign_vals` 98 | matrix : numpy 2D array 99 | Class-wise error-rates, will have the shape (num_sign_vals, n_classes) 100 | 101 | """ 102 | validate_sign(sign_vals) 103 | if isinstance(sign_vals,float): 104 | sign_vals = [sign_vals] 105 | pval2D = to_numpy2D(p_values,'p_values',return_copy=False) 106 | predicted = _get_predicted(pval2D,sign_vals) 107 | (y_onehot, _) = to_numpy1D_onehot(y_true,'y_true',labels=np.arange(pval2D.shape[1])) 108 | overall_err = __calc_frac_errors(predicted[y_onehot]) 109 | 110 | cls_err = np.zeros((predicted.shape[2],y_onehot.shape[1]),dtype=np.float32) 111 | for c in range(y_onehot.shape[1]): 112 | if y_onehot[:,c].sum() < 1: 113 | continue 114 | cls_err[:,c] = __calc_frac_errors(predicted[y_onehot[:,c]][:,c,:]) 115 | 116 | return overall_err, cls_err 117 | 118 | def __calc_frac_errors(predictions): 119 | return (1 - (predictions.sum(axis=0) / predictions.shape[0])).reshape(-1) 120 | 121 | def _get_predicted(p_vals,sign_vals): 122 | preds = np.empty((p_vals.shape[0],p_vals.shape[1],len(sign_vals)),dtype=np.bool_) 123 | for i, s in enumerate(sign_vals): 124 | preds[:,:,i] = p_vals > s 125 | return preds 126 | 127 | def _unobs_frac_single_label_preds(p_values, sign): 128 | """**Classification** - Calculate the fraction of single label predictions 129 | 130 | Parameters 131 | ---------- 132 | p_values : array, 2D numpy array or DataFrame 133 | The predicted p-values, first column for the class 0, second for class 1, .. 
134 | 135 | sign : float in [0,1] 136 | Significance the metric should be calculated for 137 | 138 | Returns 139 | ------- 140 | score : float 141 | """ 142 | validate_sign(sign) 143 | p_values = to_numpy2D(p_values,'p_values') 144 | 145 | predictions = p_values > sign 146 | return np.mean(np.sum(predictions, axis=1) == 1) 147 | 148 | def frac_single_label_preds(y_true, p_values, sign): 149 | """**Classification** - Calculate the fraction of single label predictions 150 | 151 | It is possible to both calculate this as an observed and un-observed metric, 152 | the `y_true` is given the function returns three values - if no true values 153 | are known - only the fraction of multi-label predictions is returned. 154 | 155 | Parameters 156 | ---------- 157 | y_true : 1D numpy array, list, pandas Series or None 158 | True labels or None. If given, the fraction of correct and incorrect 159 | single label predictions can be calculated as well. Otherwise this will 160 | be calculated in an unobserved fashion. 161 | 162 | p_values : 2D numpy array or DataFrame 163 | The predicted p-values, first column for the class 0, second for class 1, .. 
164 | 165 | sign : float in [0,1] 166 | Significance the metric should be calculated for 167 | 168 | Returns 169 | ------- 170 | frac_single : float 171 | Overall fraction of single-labelpredictio 172 | 173 | frac_correct_single : float, optional 174 | Fraction of correct single label predictions, not returned if no `y_true` was given 175 | 176 | frac_incorrect_single : float, optional 177 | Fraction of incorrect single label predictions, not returned if no `y_true` was given 178 | """ 179 | # If no y_true - calculate in an un-observed fashion 180 | if y_true is None: 181 | return _unobs_frac_single_label_preds(p_values, sign), 182 | 183 | validate_sign(sign) 184 | p_values = to_numpy2D(p_values,'p_values') 185 | y_true = to_numpy1D_int(y_true, 'y_true') 186 | check_consistent_length(y_true, p_values) 187 | 188 | n_total = len(y_true) 189 | 190 | predictions = p_values > sign 191 | s_label_filter = np.sum(predictions, axis=1) == 1 192 | s_preds = predictions[s_label_filter] 193 | s_trues = y_true[s_label_filter] 194 | 195 | n_corr = 0 196 | n_incorr = 0 197 | for i in range(0, s_trues.shape[0]): 198 | if s_preds[i, s_trues[i]]: 199 | n_corr +=1 200 | else: 201 | n_incorr += 1 202 | 203 | return (n_corr+n_incorr)/n_total, n_corr/n_total, n_incorr/n_total 204 | 205 | def _unobs_frac_multi_label_preds(p_values, sign): 206 | """**Classification** - Calculate the fraction of multi-label predictions 207 | 208 | Calculates the fraction of multi-label predictions in an un-observed fashion - 209 | i.e. disregarding the true labels 210 | 211 | Parameters 212 | ---------- 213 | p_values : array, 2D numpy array or DataFrame 214 | The predicted p-values, first column for the class 0, second for class 1, .. 
215 | 216 | sign : float in [0,1] 217 | Significance the metric should be calculated for 218 | 219 | Returns 220 | ------- 221 | float 222 | 223 | See Also 224 | -------- 225 | frac_multi_label_preds 226 | 227 | """ 228 | p_values = to_numpy2D(p_values,'p_values') 229 | validate_sign(sign) 230 | 231 | predictions = p_values > sign 232 | return np.mean(np.sum(predictions, axis=1) > 1) 233 | 234 | def frac_multi_label_preds(y_true, p_values, sign): 235 | """**Classification** - Calculate the fraction of multi-label predictions 236 | 237 | It is possible to both calculate this as an observed and un-observed metric, 238 | if the `y_true` is given the function returns three values - if no true values 239 | are known - only the fraction of multi-label predictions is returned. 240 | 241 | Parameters 242 | ---------- 243 | y_true : 1D numpy array, list, pandas Series, optional 244 | True labels or None. If given, the fraction of correct and incorrect 245 | multi-label predictions can be calculated as well. Otherwise this will 246 | be calculated in an unobserved fashion 247 | 248 | p_values : 2D numpy array or DataFrame 249 | The predicted p-values, first column for the class 0, second for class 1, .. 250 | 251 | sign : float in [0,1] 252 | Significance the metric should be calculated for 253 | 254 | Returns 255 | ------- 256 | frac_multi_label : float 257 | Fraction of multi-label predictions 258 | 259 | frac_correct : float or None 260 | Fraction of correct multi-label predictions (i.e. the true label is part of the set of predictions) 261 | Not returned if no `y_true` was given 262 | 263 | frac_incorrect : float or None 264 | Fraction of incorrect multi-label predictions. 
Not returned if no `y_true` was given 265 | """ 266 | # If no y_true - calculate in an un-observed fashion 267 | if y_true is None: 268 | return _unobs_frac_multi_label_preds(p_values, sign), 269 | 270 | validate_sign(sign) 271 | p_values = to_numpy2D(p_values,'p_values') 272 | y_true = to_numpy1D_int(y_true, 'y_true') 273 | check_consistent_length(y_true, p_values) 274 | 275 | n_total = len(y_true) 276 | 277 | predictions = p_values > sign 278 | m_label_filter = np.sum(predictions, axis=1) > 1 279 | m_preds = predictions[m_label_filter] 280 | m_trues = y_true[m_label_filter] 281 | 282 | n_corr = 0 283 | n_incorr = 0 284 | for i in range(0, m_trues.shape[0]): 285 | if m_preds[i, m_trues[i]]: 286 | n_corr +=1 287 | else: 288 | n_incorr +=1 289 | 290 | return (n_corr+n_incorr)/n_total, n_corr/n_total, n_incorr/n_total 291 | 292 | def obs_fuzziness(y_true, p_values): 293 | """**Classification** - Calculate the Observed Fuzziness (OF) 294 | 295 | Significance independent metric, smaller is better 296 | 297 | Parameters 298 | ---------- 299 | y_true : 1D numpy array, list or pandas Series 300 | True labels 301 | 302 | p_values : 2D numpy array or DataFrame 303 | The predicted p-values, first column for the class 0, second for class 1, .. 
304 | 305 | Returns 306 | ------- 307 | obs_fuzz : float 308 | Observed fuzziness 309 | """ 310 | p_values = to_numpy2D(p_values,'p_values') 311 | y_true = to_numpy1D_int(y_true, 'y_true') 312 | check_consistent_length(y_true, p_values) 313 | 314 | of_sum = 0 315 | for i in range(0,p_values.shape[0]): 316 | # Mask the p-value of the true label 317 | p_vals_masked = np.ma.array(p_values[i,:], mask=False) 318 | p_vals_masked.mask[y_true[i]] = True 319 | # Sum the remaining p-values 320 | of_sum += p_vals_masked.sum() 321 | 322 | return of_sum / len(y_true) 323 | 324 | def confusion_matrix(y_true, 325 | p_values, 326 | sign, 327 | labels=None, 328 | normalize_per_class = False): 329 | """**Classification** - Calculate a conformal confusion matrix 330 | 331 | A conformal confusion matrix includes the number of predictions for each class, empty predition sets and 332 | multi-prediction sets. 333 | 334 | Parameters 335 | ---------- 336 | y_true : 1D numpy array, list or pandas Series 337 | True labels 338 | 339 | p_values : 2D numpy array or DataFrame 340 | The predicted p-values, first column for the class 0, second for class 1, .. 
341 | 342 | sign : float in [0,1] 343 | Significance the confusion matrix should be calculated for 344 | 345 | labels : list of str, optional 346 | Descriptive labels for the classes 347 | 348 | normalize_per_class : bool, optional 349 | Normalizes the count so that each column sums to 1, good when visualizing imbalanced datasets (default False) 350 | 351 | Returns 352 | ------- 353 | cm : pandas DataFrame 354 | The confusion matrix 355 | """ 356 | validate_sign(sign) 357 | p_values = to_numpy2D(p_values,'p_values') 358 | y_true = to_numpy1D_int(y_true, 'y_true') 359 | check_consistent_length(y_true, p_values) 360 | 361 | predictions = p_values > sign 362 | 363 | n_class = p_values.shape[1] 364 | 365 | # We create two different 'multi-label' predictions, either including or excluding the correct label 366 | if n_class == 2: 367 | result_matrix = np.zeros((n_class+2, n_class)) 368 | else: 369 | result_matrix = np.zeros((n_class+3, n_class)) 370 | 371 | labels = get_str_labels(labels, get_n_classes(y_true,p_values)) 372 | 373 | # For every observed class - t 374 | for t in range(n_class): 375 | 376 | # Get the predictions for this class 377 | t_filter = y_true == t 378 | t_preds = predictions[t_filter] 379 | 380 | # For every (single) predicted label - p 381 | for p in range(n_class): 382 | predicted_p = [False]*n_class 383 | predicted_p[p] = True 384 | result_matrix[p,t] = (t_preds == predicted_p).all(axis=1).sum() 385 | 386 | # Empty predictions for class t 387 | result_matrix[n_class,t] = ( t_preds.sum(axis=1) == 0 ).sum() 388 | 389 | # multi-label predictions for class t 390 | t_multi_preds = t_preds.sum(axis=1) > 1 391 | t_num_all_multi = t_multi_preds.sum() 392 | if n_class == 2: 393 | result_matrix[n_class+1,t] = t_num_all_multi 394 | else: 395 | # For multi-class we have two different multi-sets - correct or incorrect! 
396 | # first do a filter of rows that are multi-labeled then check t was predicted 397 | t_num_correct_multi = (t_preds[t_multi_preds][:,t] == True).sum() 398 | t_num_incorrect_multi = t_num_all_multi - t_num_correct_multi 399 | 400 | result_matrix[n_class+1,t] = t_num_correct_multi 401 | result_matrix[n_class+2,t] = t_num_incorrect_multi 402 | 403 | row_labels = list(labels) 404 | row_labels.append('Empty') 405 | if n_class == 2: 406 | row_labels.append('Both') 407 | else: 408 | row_labels.append('Correct Multi-set') 409 | row_labels.append('Incorrect Multi-set') 410 | 411 | if normalize_per_class: 412 | result_matrix = result_matrix / result_matrix.sum(axis=0) 413 | else: 414 | # Convert to int values! 415 | result_matrix = result_matrix.astype(int) 416 | 417 | return pd.DataFrame(result_matrix, columns=labels, index = row_labels) 418 | 419 | def confusion_matrices_multiclass(y_true, 420 | p_values, 421 | sign, 422 | labels=None, 423 | normalize_per_class = False): 424 | """**Classification** - Calculate two confusion matrix only for the incorrect multiclasses and only for the correct multiclass 425 | 426 | The correct multiclass confusion matrix will demonstrate the distribution of classes available in the multiclass cases where 427 | the true class is part of the multiclass. The incorrect multiclass confusion matrix will demonstrate the distribution of classes available in the multiclass cases where 428 | the true class is not part of the multiclass. 429 | 430 | For the binary case the same matrix will be returned twice. 431 | 432 | Parameters 433 | ---------- 434 | y_true : 1D numpy array, list or pandas Series 435 | True labels 436 | 437 | p_values : 2D numpy array or DataFrame 438 | The predicted p-values, first column for the class 0, second for class 1, .. 
439 | 440 | sign : float in [0,1] 441 | Significance the confusion matrix should be calculated for 442 | 443 | labels : list of str, optional 444 | Descriptive labels for the classes 445 | 446 | normalize_per_class : bool, optional 447 | Normalizes the count so that each column sums to 1, good when visualizing imbalanced datasets (default False) 448 | 449 | Returns 450 | ------- 451 | cm : pandas DataFrame 452 | The confusion matrix 453 | """ 454 | validate_sign(sign) 455 | p_values = to_numpy2D(p_values,'p_values') 456 | y_true = to_numpy1D_int(y_true, 'y_true') 457 | check_consistent_length(y_true, p_values) 458 | 459 | predictions = p_values > sign 460 | n_class = p_values.shape[1] 461 | 462 | # We create two different 'multi-label' predictions, either including or excluding the correct label 463 | if n_class == 2: 464 | correct_result_matrix = np.zeros((n_class, n_class)) 465 | else: 466 | correct_result_matrix = np.zeros((n_class, n_class)) 467 | incorrect_result_matrix = np.zeros((n_class, n_class)) 468 | 469 | labels = get_str_labels(labels, get_n_classes(y_true,p_values)) 470 | 471 | # For every observed class - t 472 | for t in range(n_class): 473 | 474 | # Get the predictions for this class 475 | t_filter = y_true == t 476 | t_predictions = predictions[t_filter] 477 | 478 | # multi-label predictions for class t 479 | t_multi_preds = t_predictions.sum(axis=1) > 1 480 | if ~t_multi_preds.any(): 481 | correct_result_matrix[t] = np.zeros((1, n_class)) 482 | incorrect_result_matrix[t] = np.zeros((1, n_class)) 483 | else: 484 | t_multi_predictions = t_predictions[t_multi_preds] 485 | if n_class == 2: 486 | correct_result_matrix[t] =t_multi_predictions 487 | incorrect_result_matrix[t] = t_multi_predictions 488 | else: 489 | # For multi-class we have two different multi-sets - correct or incorrect! 
490 | # first do a filter of rows that are multi-labeled then check t was predicted 491 | t_multi_predictions 492 | correct_predictions = t_multi_predictions[:,t] == True 493 | correct_t_multi_predictions = t_multi_predictions[correct_predictions] #TODO sum use np.add 494 | correct_t_multi_predictions = correct_t_multi_predictions.astype(int) 495 | incorrect_t_multi_predictions = t_multi_predictions[~correct_predictions]#TODO sum use np.add 496 | incorrect_t_multi_predictions = incorrect_t_multi_predictions.astype(int) 497 | 498 | summed_correct_t_multi_predictions = np.sum(correct_t_multi_predictions, axis=0, dtype=int) 499 | summed_incorrect_t_multi_predictions = np.sum(incorrect_t_multi_predictions, axis=0, dtype=int) 500 | 501 | correct_result_matrix[t] = summed_correct_t_multi_predictions 502 | incorrect_result_matrix[t] = summed_incorrect_t_multi_predictions 503 | 504 | row_labels = list(labels) 505 | 506 | if normalize_per_class: 507 | correct_result_matrix = correct_result_matrix / correct_result_matrix.sum(axis=0) 508 | incorrect_result_matrix = incorrect_result_matrix / incorrect_result_matrix.sum(axis=0) 509 | else: 510 | correct_result_matrix = correct_result_matrix.astype(int) 511 | incorrect_result_matrix = incorrect_result_matrix.astype(int) 512 | 513 | return pd.DataFrame(correct_result_matrix, columns=labels, index = row_labels), pd.DataFrame(incorrect_result_matrix, columns=labels, index = row_labels) 514 | 515 | ######################################## 516 | ### UNOBSERVED METRICS 517 | ######################################## 518 | 519 | 520 | def cp_credibility(p_values): 521 | """**Classification** - CP Credibility 522 | 523 | Mean of the largest p-values 524 | 525 | Parameters 526 | ---------- 527 | p_values : array, 2D numpy array or DataFrame 528 | The predicted p-values, first column for the class 0, second for class 1, .. 
529 | 530 | Returns 531 | ------- 532 | credibility : float 533 | """ 534 | p_values = to_numpy2D(p_values,'p_values') 535 | sorted_matrix = np.sort(p_values, axis=1) 536 | return np.mean(sorted_matrix[:,-1]) # last index is the largest 537 | 538 | def cp_confidence(p_values): 539 | """**Classification** - CP Confidence 540 | 541 | Mean of 1-'second largest p-value' 542 | 543 | Parameters 544 | ---------- 545 | p_values : array, 2D numpy array or DataFrame 546 | The predicted p-values, first column for the class 0, second for class 1, .. 547 | 548 | Returns 549 | ------- 550 | confidence : float 551 | """ 552 | p_values = to_numpy2D(p_values,'p_values') 553 | sorted_matrix = np.sort(p_values, axis=1) 554 | return np.mean(1-sorted_matrix[:,-2]) 555 | 556 | def s_criterion(p_values): 557 | """**Classification** - S criterion 558 | 559 | Mean of the sum of all p-values 560 | 561 | Parameters 562 | ---------- 563 | p_values : array, 2D numpy array or DataFrame 564 | The predicted p-values, first column for the class 0, second for class 1, .. 565 | 566 | Returns 567 | ------- 568 | s_score : float 569 | """ 570 | p_values = to_numpy2D(p_values,'p_values') 571 | return np.mean(np.sum(p_values, axis=1)) 572 | 573 | def n_criterion(p_values, sign=_default_significance): 574 | """**Classification** - N criterion 575 | 576 | "Number" criterion - the average number of predicted labels. Significance dependent metric 577 | 578 | Parameters 579 | ---------- 580 | p_values : array, 2D numpy array or DataFrame 581 | The predicted p-values, first column for the class 0, second for class 1, .. 
582 | 583 | sign : float, in range [0..1], default 0.8 584 | The significance level 585 | 586 | Returns 587 | ------- 588 | n_score : float 589 | """ 590 | p_values = to_numpy2D(p_values,'p_values') 591 | validate_sign(sign) 592 | return np.mean(np.sum(p_values > sign, axis=1)) 593 | 594 | def u_criterion(p_values): 595 | """**Classification** - U criterion - "Unconfidence" 596 | 597 | Smaller values are preferable 598 | 599 | Parameters 600 | ---------- 601 | p_values : array, 2D numpy array or DataFrame 602 | The predicted p-values, first column for the class 0, second for class 1, .. 603 | 604 | Returns 605 | ------- 606 | u_score : float 607 | """ 608 | p_values = to_numpy2D(p_values,'p_values') 609 | sorted_matrix = np.sort(p_values, axis=1) 610 | return np.mean(sorted_matrix[:,-2]) 611 | 612 | def f_criteria(p_values): 613 | """**Classification** - F criterion 614 | 615 | Average fuzziness. Average of the sum of all p-values apart from the largest one. 616 | Smaller values are preferable. 617 | 618 | Parameters 619 | ---------- 620 | p_values : array, 2D numpy array or DataFrame 621 | The predicted p-values, first column for the class 0, second for class 1, .. 622 | 623 | Returns 624 | ------- 625 | f_score : float 626 | """ 627 | p_values = to_numpy2D(p_values,'p_values') 628 | sorted_matrix = np.sort(p_values, axis=1) 629 | if sorted_matrix.shape[1] == 2: 630 | # Mean of only the smallest p-value 631 | return np.mean(sorted_matrix[:,0]) 632 | else: 633 | # Here we must take the sum of the values appart from the first column 634 | return np.mean(np.sum(sorted_matrix[:,:-1], axis=1)) 635 | --------------------------------------------------------------------------------