├── .gitignore
├── README.rst
├── doc
│   ├── eights.jpeg
│   ├── featuture_gen_tut.rst
│   ├── new_features.md
│   ├── readme.rst
│   ├── sample_report.csv
│   ├── sample_report.pdf
│   ├── sweeping_architecture.odg
│   ├── sweeping_architecture.pdf
│   └── tutorial.py
├── eights
│   ├── __init__.py
│   ├── communicate
│   │   ├── __init__.py
│   │   ├── communicate.py
│   │   └── communicate_helper.py
│   ├── decontaminate
│   │   ├── __init__.py
│   │   └── decontaminate.py
│   ├── generate
│   │   ├── __init__.py
│   │   └── generate.py
│   ├── investigate
│   │   ├── __init__.py
│   │   ├── investigate.py
│   │   └── investigate_helper.py
│   ├── operate
│   │   ├── __init__.py
│   │   └── operate.py
│   ├── perambulate
│   │   ├── __init__.py
│   │   ├── perambulate.py
│   │   └── perambulate_helper.py
│   ├── truncate
│   │   ├── __init__.py
│   │   ├── truncate.py
│   │   └── truncate_helper.py
│   └── utils.py
├── setup.py
├── test_sklearn_iris.py
├── test_wine.py
└── tests
    ├── .DS_Store
    ├── data
    │   ├── full_test.csv
    │   ├── mixed.csv
    │   ├── small.db
    │   ├── test_communicate_ref.pdf
    │   ├── test_operate_std.pkl
    │   ├── test_perambulate
    │   │   ├── make_csv.csv
    │   │   ├── run_experiment.pkl
    │   │   ├── slice_by_best_score.pkl
    │   │   ├── slice_on_dimension_clf.pkl
    │   │   ├── slice_on_dimension_subset_params.pkl
    │   │   ├── sliding_windows.csv
    │   │   └── test_subsetting.pkl
    │   └── test_perambulate_ref.pdf
    ├── test_all.py
    ├── test_communicate.py
    ├── test_decontaminate.py
    ├── test_generate.py
    ├── test_investigate.py
    ├── test_operate.py
    ├── test_perambulate.py
    ├── test_sklearn_parity.py
    ├── test_truncate.py
    ├── test_utils.py
    ├── test_utils_for_tests.py
    └── utils_for_tests.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | something.txt
 5 | .DS_Store
 6 | # C extensions
 7 | *.so
 8 | Tweets-DataFixed.csv
 9 | report.csv
10 | report.pdf
11 | eights_temp/
12 | my_*
13 | 
14 | # Distribution / packaging
15 | .Python
16 | env/
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *,cover
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | 
59 | # Sphinx documentation
60 | docs/_build/
61 | 
62 | # PyBuilder
63 | target/
64 | 
65 | # vim
66 | *.swp
67 | *-e
68 | 

/README.rst:
--------------------------------------------------------------------------------
 1 | ::
 2 | 
 3 |   ============================================
 4 |         _       _     _
 5 |        (_)     | |   | |
 6 |     ___ _  __ _| |__ | |_ ___
 7 |    / _ \ |/ _` | '_ \| __/ __|
 8 |   |  __/ | (_| | | | | |_\__ \
 9 |    \___|_|\__, |_| |_|\__|___/
10 |            __/ |
11 |           |___/
12 | 
13 | ============================================
14 | 
15 | 
16 | ------------
17 | Introduction
18 | ------------
19 | 
20 | Eights is a Python library and workflow template for machine learning.
21 | Principally, it wraps scikit-learn, providing enhanced functionality and a
22 | simplified interface for often-used workflows.
23 | 
24 | ------------
25 | Installation
26 | ------------
27 | 
28 | `pip install git+git://github.com/dssg/eights.git`
29 | 
30 | Required
31 | ========
32 | 
33 | Python packages
34 | ---------------
35 | - `Python 2.7 `_
36 | - `Numpy `_
37 | - `scikit-learn `_
38 | - `pdfkit `_
39 | 
40 | Other packages
41 | --------------
42 | 
43 | - `wkhtmltopdf `_
44 | 
45 | Optional
46 | ========
47 | 
48 | Python packages
49 | ---------------
50 | - `matplotlib `_
51 | 
52 | 
53 | 
54 | Other packages
55 | --------------
56 | 
57 | 
58 | -------
59 | Example
60 | -------
61 | 
62 | 
63 | ----------
64 | Next Steps
65 | ----------
66 | 
67 | my_* are included in the .gitignore. We recommend a standard such as my_experiment, my_storage for local folders.
68 | 
69 | 
--------------------------------------------------------------------------------
/doc/eights.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/doc/eights.jpeg
--------------------------------------------------------------------------------
/doc/featuture_gen_tut.rst:
--------------------------------------------------------------------------------
 1 | ****************************************
 2 | Using Eights for Feature Generation
 3 | ****************************************
 4 | Quick test::
 5 | 
 6 |   import numpy as np
 7 | 
 8 |   import eights.investigate
 9 |   import eights.generate
10 | 
11 |   M = [[1,2,3], [2,3,4], [3,4,5]]
12 |   col_names = ['height', 'weight', 'age']
13 |   labels = [0, 0, 1]
14 | 
15 |   # Eights uses structured arrays, which allow for different data types in different columns
16 |   M = eights.investigate.convert_list_of_list_to_sa(np.array(M), col_names=col_names)
17 |   # By convention, M is the matrix on which our ML algorithm will run
18 | 
19 |   # This is a sample test function, to show how easy it is to craft your own.
20 |   # The signature (M, col_name, boundary) is standardized.
21 |   def test_equality(M, col_name, boundary):
22 |       return M[col_name] == boundary
23 | 
24 |   # This generates a new column which is True where all of the tests pass
25 |   M_new = eights.generate.where_all_are_true(
26 |       M,
27 |       [test_equality, test_equality, test_equality],
28 |       ['height', 'weight', 'age'],
29 |       [1, 2, 3],
30 |       ('new_column_name',)
31 |   )
32 |   # Read top to bottom:
33 |   # If test_equality in column 'height' == 1 AND
34 |   # If test_equality in column 'weight' == 2 AND
35 |   # If test_equality in column 'age' == 3
36 |   # return True
37 | 
38 | 
39 |   import numpy as np
40 |   from sklearn.ensemble import RandomForestClassifier
41 |   import sklearn.datasets
42 |   import eights as e
43 | 
44 | 
45 |   diab = sklearn.datasets.load_diabetes()
46 | 
47 |   data = diab.data
48 |   target = diab.target
49 | 
50 |   M = e.inv.convert_list_of_list_to_sa(data)
51 | 
52 |   # quick sanity check
53 |   # e.com.plot_simple_histogram(target)
54 |   # I want to bin all values at 205 or above
55 |   y = (target >= 205)
56 | 
57 |   # twist this into a classification problem
58 | 
59 |   e.inv.simple_CV(data, y, RandomForestClassifier)
--------------------------------------------------------------------------------
/doc/new_features.md:
--------------------------------------------------------------------------------
1 | # Feature Requests From Group Meetings, Week of August 10, 2015
2 | 
3 | ## Australia
4 | 
5 | * Incorporate SQL queries into perambulation.
This would mean that we can make queries and populate table with the perambulate. 6 | 7 | * Descriptive statistics on SQL tables like col_names, num of categories, num of nulls. 8 | 9 | * Plot of feature importance. We want to know how importance scores trail off. And distribution. 10 | 11 | * This group spent the majority of their time figuring out what database schema meant. We probably can't help, but it's good to know that that was the majority of their workload 12 | 13 | ## Babies 14 | * Train where x is < than all values in col B. test where col B > x. x is an element of col A 15 | 16 | ## Cincinnati 17 | * Caching intermediary results. As with Drake, only regenerate files as needed. Have some mechanism to keep track of when we've run what. 18 | 19 | * Features like mean value of homes in an area, change of value of homes in this area. Aggregates over a partial set of data in table. average value of home price over last 3 years based on date inspected. 20 | 21 | * Choose columns to sum across based on value in a column. 22 | 23 | * Distance from this entry to nearest X. For example, distance from this home to nearest abandoned home. This is apparently much easier to do in PostGIS than in Python 24 | 25 | ## Feeding 26 | 27 | * Simple multiprocessing. Thin wrapper around joblib just so people know it's there. 28 | 29 | * Deduplicate rows. If rows are identical, remove them. 30 | 31 | * ROC curves w/ more than one series. If we have more than one class in our label, we will have more than one series to ROC. We do this in each case by making one class the baseline and comparing it against all other classes 32 | 33 | ## High School 34 | 35 | * The biggest thing is the really confusing cross validation rules that involve leaving out columns. First we pick a max grade, then we have to leave out columns, then the max grade determines how far apart our train and test sets need to be. According to Robin: 36 | 37 | ``` 38 | min_year = train_start = min(available_cohorts) 39 | max_year = test_end = max(available_cohorts) 40 | for max_grade in seq(9, 11){ 41 | for train_end in seq(min_year, max_year - (12 - max_grade)){ 42 | test_start = train_end + (12 - max_grade) 43 | #...train on cohorts between train_start and train_end 44 | #...test on cohorts between test_start and test_end 45 | } 46 | } 47 | # remember, seq is inclusive 48 | ``` 49 | 50 | * pandas.get_dummies 51 | 52 | * Transposing 1-to-many relations. 53 | 54 | This circumstance arises when we have a table in a "log" format, where 55 | multiple rows are associated with one identity. For example, say we are 56 | predicting likelihood that a given student will drop out of school, and we 57 | have a GPA per student per year. So one of the tables we have is formatted like: 58 | 59 | | ID | Grade | GPA | 60 | | ---:| -----:| ---:| 61 | | 1 | 9 | 3.2 | 62 | | 1 | 10 | 3.4 | 63 | | 1 | 11 | 4.0 | 64 | | 2 | 9 | 2.1 | 65 | | 2 | 10 | 2.0 | 66 | | 2 | 11 | 2.3 | 67 | | 2 | 12 | 2.5 | 68 | 69 | Then we have another table of features that don't vary by year. 
For example:
70 | 
71 | | ID | Date of Birth | Graduated |
72 | | ---:| -------------:| ---------:|
73 | | 1   | 1988-09-22    | 1         |
74 | | 2   | 1989-08-08    | 0         |
75 | 
76 | The table that we actually analyze needs one row per student, so we need to
77 | attach a transpose of the former table to the latter table, like
78 | 
79 | | ID | Date of Birth | Graduated | GPA_9 | GPA_10 | GPA_11 | GPA_12 |
80 | | ---:| -------------:| ---------:| -----:| ------:| ------:| ------:|
81 | | 1   | 1988-09-22    | 1         | 3.2   | 3.4    | 4.0    |        |
82 | | 2   | 1989-08-08    | 0         | 2.1   | 2.0    | 2.3    | 2.5    |
83 | 
84 | * Removing columns by regex. In the above example, there are circumstances in which we want to exclude some of the columns (for example, GPA > grade 11). We should have a subsetter that does this. For example, we take a subset that removes anything > grade 11:
85 | 
86 | | ID | Date of Birth | Graduated | GPA_9 | GPA_10 | GPA_11 |
87 | | ---:| -------------:| ---------:| -----:| ------:| ------:|
88 | | 1   | 1988-09-22    | 1         | 3.2   | 3.4    | 4.0    |
89 | | 2   | 1989-08-08    | 0         | 2.1   | 2.0    | 2.3    |
90 | 
91 | And then another subset that removes anything > grade 10:
92 | 
93 | | ID | Date of Birth | Graduated | GPA_9 | GPA_10 |
94 | | ---:| -------------:| ---------:| -----:| ------:|
95 | | 1   | 1988-09-22    | 1         | 3.2   | 3.4    |
96 | | 2   | 1989-08-08    | 0         | 2.1   | 2.0    |
97 | 
98 | ## Infonavit
99 | 
100 | * In perambulate: fit the train set to a Gaussian, then apply that Gaussian to the test set. Normalize to the train set, then apply to the test set. Normalizing across everything would be cheating.
101 | 
102 | * Managing shape files. Use them for thresholding to bin GPS data.
103 | 
104 | * We'd rather have a static map than a web page. Just a quick sanity check.
105 | 
106 | * Sanity check for CSV, e.g. are there the correct number of delimiters per row. Except we're not importing any more, so we're not doing this for now.
107 | 
108 | * Impute based on previous calculations. Treat the col w/ missing values as the label, train on everything else with RF.
109 | 
110 | * Sanity check to make sure that train/test data includes a distribution of labels corresponding to the real distribution. E.g. if bimodal.
111 | 
112 | ## Labor
113 | 
114 | * Pie charts.
115 | 
116 | ## Police
117 | 
118 | * Pay attention to feature processors: the different sorts of post-processing one might do on a column.
119 | 
120 | ## Sunlight
121 | 
122 | ### Standardized pdf scraping
123 | 
124 | * After some discussion, we've decided that importing data to a structured array in a thorough manner is probably beyond our scope. It would be best to let our clients do it (probably w/ Pandas). PDF scraping is irrelevant.
125 | 
126 | ### Smith-Waterman
127 | 
128 | * Sunlight developed a faster implementation than is publicly available. This won't go in eights, but we should help make sure it gets released.
129 | 
130 | ### JIT
131 | 
132 | * Look at Numba. Does it make things faster if there's no effort involved? Compare to PyPy.
133 | 
134 | ### TF-IDF Score
135 | 
136 | * A standard tool for topic modeling, clustering, and developing feature vectors. Account for its existence. Maybe implement it.
137 | 
138 | ### TIKA
139 | 
140 | * Import documents into a search engine. Not directly relevant, but consider.
141 | 
142 | 
143 | ## World Bank
144 | 
145 | * Entity resolution. Different people call different entities the same thing. For example, a series of different nicknames for the same college.
146 | 
147 | * Figuring out what to name automatically generated columns.
148 | 149 | * Feature aggregation into percent “what percent of a suppliers contracts were in Africa before a given contract”. Powers of sets. Things to aggregate over are cross product of a set. See Elissa’s slide 150 | 151 | * SQL-esque joins 152 | 153 | ## Kirsten 154 | 155 | * Detect and remove colinear columns 156 | 157 | * col renaming. e.g. year to grade level 158 | 159 | ## Feature generation 160 | 161 | * When we generate a feature, keep track of the column name and what it means in metadata that is attached to the structured array. We'll keep metadata in structured arrays by doing a thin subclass like this: 162 | 163 | ``` 164 | >>> class BetterSA(np.ndarray): 165 | ... def __init__(self, *args, **kwargs): 166 | ... super(BetterSA, self).__init__(*args, **kwargs) 167 | ... self.meta = 'metadata' 168 | ``` 169 | -------------------------------------------------------------------------------- /doc/readme.rst: -------------------------------------------------------------------------------- 1 | omg there are so many 2 | 3 | http://patorjk.com/software/taag/#p=display&v=3&f=3D-ASCII&t=eights 4 | 5 | 6 | __.....__ .--. . 7 | .-'' '. |__| .--./) .'| 8 | / .-''"'-. `. .--. /.''\\ < | .| 9 | / /________\ \| || | | | | | .' |_ 10 | | || | \`-' / | | .'''-. .' | _ 11 | \ .-------------'| | /("'` | |/.'''. \'--. .-' .' | 12 | \ '-.____...---.| | \ '---. | / | | | | . | / 13 | `. .' |__| /'""'.\ | | | | | | .'.'| |// 14 | `''-...... -' || ||| | | | | '.'.'.'.-' / 15 | \'. __// | '. | '. | / .' \_.' 16 | `'---' '---' '---' `'-'' 17 | 18 | _ _ _ 19 | (_) | | | | 20 | ___ _ __ _| |__ | |_ ___ 21 | / _ \ |/ _` | '_ \| __/ __| 22 | | __/ | (_| | | | | |_\__ \ 23 | \___|_|\__, |_| |_|\__|___/ 24 | __/ | 25 | |___/ 26 | 27 | _______ _________ _______ _________ _______ 28 | ( ____ \\__ __/( ____ \|\ /|\__ __/( ____ \ 29 | | ( \/ ) ( | ( \/| ) ( | ) ( | ( \/ 30 | | (__ | | | | | (___) | | | | (_____ 31 | | __) | | | | ____ | ___ | | | (_____ ) 32 | | ( | | | | \_ )| ( ) | | | ) | 33 | | (____/\___) (___| (___) || ) ( | | | /\____) | 34 | (_______/\_______/(_______)|/ \| )_( \_______) 35 | 36 | 37 | 38 | By default we use pass structured arrays from within functions. However to accommodate both list of list and list of nd.arrays, we cast structured arrays as: 39 | name, M = sa_to_nd(M) 40 | 41 | For our the rapid analysis of different CLF's in SKLEARN we use dictionaries of dictionaries of Lists. Where lists are the slice indices, the outermost dictionary is the test, the inner dictionary is the run. For instance, Test["sweepParamtersize"]['one'] == nd.array([1,2,3]) 42 | 43 | M is a structured array 44 | Y is the target in SKLEARN parlance. It is the known labels. 
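
A minimal sketch of the M / Y convention just described (plain numpy only; sa_to_nd is the
helper named above, shown here only as used, not defined)::

  import numpy as np

  # M: a structured array, one named column per feature
  M = np.array([(70.1, 150.0), (62.3, 120.0)],
               dtype=[('height', float), ('weight', float)])

  # Y: the known labels -- always a one-dimensional numpy array
  # aligned row-for-row with M
  Y = np.array([0, 1])

  # the casting convention from above: recover the column names
  # plus a plain nd.array view of M
  name, M_nd = sa_to_nd(M)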
45 | 
46 | 
47 | 
48 | supported plots: (ROC, PER_RECALL, ACC, N_TOP_FEAT, AUC)
49 | supported clfs: (RF, SVM, DESC_TREE, ADA_BOOST)
50 | supported subsetting: (LEAVE_ONE_COL_OUT, SWEEP_TRAINING_SIZE, )
51 | supported cv: (K_FOLD, STRAT_ACTUAL_K_FOLD, STRAT_EVEN_K_FOLD)
52 | 
53 | 
54 | plots = ['roc', 'acc']
55 | clfs = [('random forest', ['RF PARMS'] ),
56 |         ('svm', ['SVM PARMS'] )]
57 | 
58 | subsets = [('leave one out col', ['PARMS'] ),
59 |            ('sweep training size', ['PARMS'] )]
60 | 
61 | cv = [('cv', ['parms']),
62 |       ('stratified cv', ['parms'])]
63 | 
64 | runOne = Experiment(plots, clfs, subsets, cv)
65 | 
66 | exp = Experiment(
67 |     [RF: {'depth': [10, 100],
68 |           'n_trees': [40, 50]},
69 |      SVM: {'param_1': [1, 2],
70 |            'param_2': ['a', 'b']}],
71 |     [LEAVE_ONE_COL_OUT: {'col_names': ['f0', 'f1', 'f2', 'f3']},
72 |      SWEEP_TRAINING_SIZE: {'sizes': (10, 20, 40)}
73 |     ],
74 |     [STRAT_ACTUAL_K_FOLD : {'y': y}])
75 | 
76 | 
77 | M is our matrix to train our ML algo on; it's always a structured array (or an nd.array, or a list of one-dimension arrays).
78 | Labels are supervised learning's gold-standard labels. It is the TRUTH. Always a one-dim numpy.array.
79 | If we use a collection of columns that is not the full matrix, it is cols.
80 | col is the one-dimension equivalent of cols; it has the same type as labels.
81 | When we pass in a classifier, it's an sklearn base estimator class, as opposed to an instance.
82 | When we are passing a cross-validation type, it's an sklearn partition iterator.
83 | When we pass a subset method, it is an iterator for which each iteration returns a set of indices.
84 | Parameters for sklearn etc. are always passed as dictionaries of string to something.
85 | There will be other stuff, but the above is carved in rice paper.
86 | 
--------------------------------------------------------------------------------
/doc/sample_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/doc/sample_report.pdf
--------------------------------------------------------------------------------
/doc/sweeping_architecture.odg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/doc/sweeping_architecture.odg
--------------------------------------------------------------------------------
/doc/sweeping_architecture.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/doc/sweeping_architecture.pdf
--------------------------------------------------------------------------------
/doc/tutorial.py:
--------------------------------------------------------------------------------
 1 | # Required imports
 2 | import numpy as np
 3 | 
 4 | import eights.investigate as inv
 5 | 
 6 | 
 7 | 
 8 | # Investigate
 9 | M, labels = inv.open_csv(file_loc)
10 | 
11 | # cast to NumPy structured arrays
12 | 
13 | # Descriptive statistics
14 | inv.describe_cols(M)
15 | inv.crosstab
16 | 
17 | 
18 | inv.plot_correlation_matrix
19 | inv.plot_correlation_scatter_plot
20 | inv.plot_box_plot
21 | 
22 | 
23 | # Decontaminate data
24 | import eights.decontaminate as dec
25 | replace_with_n_bins
26 | replace_missing_vals
27 | 
28 | # Generate features
29 | def is_this_word_in(a_text, word):
30 |     return word in a_text
31 | 
32 | M = where_all_are_true(
33 |     M,
34 |     [(val_eq, 'open_col', None),
35 |      (val_eq, 'click', None)],
36 |     'n_click,n_open')
37 | 
38 | M = where_all_are_true(
39 |     M,
40 |     [(is_this_word_in, 'email_text', 'unsubscribe'),
41 |      (val_eq, 'click', None)],
42 |     '1_click,n_open')
43 | 
44 | # Truncate
45 | M1 = remove_rows(
46 |     M,
47 |     [(is_this_word_in, 'email_text', 'unsubscribe'),
48 |      (val_eq, 'click', None)],
49 |     '1_click,n_open')
50 | 
51 | # Perambulate/operate
52 | 
53 | experiment = run_std_classifiers(M1, labels)
54 | 
55 | # exp is identical to experiment
56 | 
57 | exp = Experiment(
58 |     M1,
59 |     labels,
60 |     clfs={AdaBoostClassifier: {'n_estimators': [20, 50, 100]},
61 |           RandomForestClassifier: {'n_estimators': [10, 30, 50], 'max_depth': [None, 4, 7, 15], 'n_jobs': [1]},
62 |           LogisticRegression: {'C': [1.0, 2.0, 0.5, 0.25], 'penalty': ['l1', 'l2']},
63 |           DecisionTreeClassifier: {'max_depth': [None, 4, 7, 15, 25]},
64 |           SVC: {'kernel': ['linear', 'rbf']},
65 |           DummyClassifier: {'strategy': ['stratified', 'most_frequent', 'uniform']}
66 |           },
67 |     cvs={StratifiedKFold: {}}
68 | )
69 | 
70 | # Communicate
71 | exp.report()
72 | 
73 | 
74 | M = add_these(
75 |     [[(val_eq, 'open_col', None), (val_eq, 'click', None)], 'n_click,n_open'],
76 |     [[(val_eq, 'open_col', None), (val_eq, 'click', None)], 'n_click,n_open']
77 | )
--------------------------------------------------------------------------------
/eights/__init__.py:
--------------------------------------------------------------------------------
1 | import investigate as inv
2 | import communicate as com
3 | import decontaminate as dec
4 | import generate as gen
5 | import operate as op
6 | import perambulate as per
7 | import truncate as tr
--------------------------------------------------------------------------------
/eights/communicate/__init__.py:
--------------------------------------------------------------------------------
1 | from communicate import *
--------------------------------------------------------------------------------
/eights/communicate/communicate.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import shutil
 3 | import StringIO
 4 | import cgi
 5 | import uuid
 6 | import abc
 7 | from datetime import datetime
 8 | import numpy as np
 9 | import matplotlib.pyplot as plt
10 | import matplotlib.dates
11 | from matplotlib.pylab import boxplot
12 | 
13 | 
14 | from sklearn.grid_search import GridSearchCV
15 | from sklearn.neighbors.kde import KernelDensity
16 | import pdfkit
17 | 
18 | from sklearn.metrics import roc_curve
19 | from sklearn.metrics import roc_auc_score
20 | from sklearn.metrics import precision_recall_curve
21 | from ..perambulate import Experiment
22 | from ..utils import is_sa, is_nd, cast_np_sa_to_nd, convert_to_sa
23 | from ..utils import cast_list_of_list_to_sa
24 | from communicate_helper import *
25 | from communicate_helper import _feature_pair_report
26 | 
27 | 
28 | def print_matrix_row_col(M, row_labels=None, col_labels=None):
29 |     M = convert_to_sa(M, col_names=col_labels)
30 |     if row_labels is None:
31 |         row_labels = xrange(M.shape[0])
32 |     col_labels = M.dtype.names
33 |     # From http://stackoverflow.com/questions/9535954/python-printing-lists-as-tabular-data
34 |     row_format =' '.join(['{:>15}' for _ in xrange(len(col_labels) + 1)])
35 |     print row_format.format("", *col_labels)
36 |     for row_name, row in zip(row_labels, M):
37 |         print row_format.format(row_name, *row)
38 | 
39 | def plot_simple_histogram(col, verbose=True):
40 |     hist, bins = 
np.histogram(col, bins=50) 41 | width = 0.7 * (bins[1] - bins[0]) 42 | center = (bins[:-1] + bins[1:]) / 2 43 | f = plt.figure() 44 | plt.bar(center, hist, align='center', width=width) 45 | if verbose: 46 | plt.show() 47 | return f 48 | 49 | # all of the below take output from any func in perambulate or operate 50 | 51 | 52 | def plot_prec_recall(labels, score, title='Prec/Recall', verbose=True): 53 | # adapted from Rayid's prec/recall code 54 | y_true = labels 55 | y_score = score 56 | precision_curve, recall_curve, pr_thresholds = precision_recall_curve( 57 | y_true, 58 | y_score) 59 | precision_curve = precision_curve[:-1] 60 | recall_curve = recall_curve[:-1] 61 | pct_above_per_thresh = [] 62 | number_scored = len(y_score) 63 | for value in pr_thresholds: 64 | num_above_thresh = len(y_score[y_score>=value]) 65 | pct_above_thresh = num_above_thresh / float(number_scored) 66 | pct_above_per_thresh.append(pct_above_thresh) 67 | pct_above_per_thresh = np.array(pct_above_per_thresh) 68 | fig = plt.figure() 69 | ax1 = plt.gca() 70 | ax1.plot(pct_above_per_thresh, precision_curve, 'b') 71 | ax1.set_xlabel('percent of population') 72 | ax1.set_ylabel('precision', color='b') 73 | ax2 = ax1.twinx() 74 | ax2.plot(pct_above_per_thresh, recall_curve, 'r') 75 | ax2.set_ylabel('recall', color='r') 76 | plt.title(title) 77 | if verbose: 78 | fig.show() 79 | return fig 80 | 81 | def plot_roc(labels, score, title='ROC', verbose=True): 82 | # adapted from Rayid's prec/recall code 83 | fpr, tpr, thresholds = roc_curve(labels, score) 84 | fpr = fpr 85 | tpr = tpr 86 | pct_above_per_thresh = [] 87 | number_scored = len(score) 88 | for value in thresholds: 89 | num_above_thresh = len(score[score>=value]) 90 | pct_above_thresh = num_above_thresh / float(number_scored) 91 | pct_above_per_thresh.append(pct_above_thresh) 92 | pct_above_per_thresh = np.array(pct_above_per_thresh) 93 | 94 | fig = plt.figure() 95 | ax1 = plt.gca() 96 | ax1.plot(pct_above_per_thresh, fpr, 'b') 97 | ax1.set_xlabel('percent of population') 98 | ax1.set_ylabel('fpr', color='b') 99 | ax2 = ax1.twinx() 100 | ax2.plot(pct_above_per_thresh, tpr, 'r') 101 | ax2.set_ylabel('tpr', color='r') 102 | plt.title(title) 103 | if verbose: 104 | fig.show() 105 | return fig 106 | 107 | def plot_box_plot(col, col_name=None, verbose=True): 108 | """Makes a box plot for a feature 109 | comment 110 | 111 | Parameters 112 | ---------- 113 | col : np.array 114 | 115 | Returns 116 | ------- 117 | matplotlib.figure.Figure 118 | 119 | """ 120 | 121 | fig = plt.figure() 122 | boxplot(col) 123 | if col_name: 124 | plt.title(col_name) 125 | #add col_name to graphn 126 | if verbose: 127 | plt.show() 128 | return fig 129 | 130 | def get_top_features(clf, M=None, col_names=None, n=10, verbose=True): 131 | scores = clf.feature_importances_ 132 | if col_names is None: 133 | if is_sa(M): 134 | col_names = M.dtype.names 135 | else: 136 | col_names = ['f{}'.format(i) for i in xrange(len(scores))] 137 | ranked_name_and_score = [(col_names[x], scores[x]) for x in 138 | scores.argsort()[::-1]] 139 | ranked_name_and_score = convert_to_sa( 140 | ranked_name_and_score[:n], 141 | col_names=('feat_name', 'score')) 142 | if verbose: 143 | print_matrix_row_col(ranked_name_and_score) 144 | return ranked_name_and_score 145 | 146 | # TODO features form top % of clfs 147 | 148 | def get_roc_auc(labels, score, verbose=True): 149 | auc_score = roc_auc_score(labels, score) 150 | if verbose: 151 | print 'ROC AUC: {}'.format(auc_score) 152 | return auc_score 153 | 154 | def 
plot_correlation_matrix(M, verbose=True): 155 | """Plot correlation between variables in M 156 | 157 | Parameters 158 | ---------- 159 | M : numpy structured array 160 | 161 | Returns 162 | ------- 163 | matplotlib.figure.Figure 164 | 165 | """ 166 | # http://glowingpython.blogspot.com/2012/10/visualizing-correlation-matrices.html 167 | # TODO work on structured arrays or not 168 | # TODO ticks are col names 169 | if is_sa(M): 170 | names = M.dtype.names 171 | M = cast_np_sa_to_nd(M) 172 | n_cols = M.shape[1] 173 | else: 174 | if is_nd(M): 175 | n_cols = M.shape[1] 176 | else: # list of arrays 177 | n_cols = len(M[0]) 178 | names = ['f{}'.format(i) for i in xrange(n_cols)] 179 | 180 | #set rowvar =0 for rows are items, cols are features 181 | cc = np.corrcoef(M, rowvar=0) 182 | 183 | fig = plt.figure() 184 | plt.pcolor(cc) 185 | plt.colorbar() 186 | plt.yticks(np.arange(0.5, M.shape[1] + 0.5), range(0, M.shape[1])) 187 | plt.xticks(np.arange(0.5, M.shape[1] + 0.5), range(0, M.shape[1])) 188 | if verbose: 189 | plt.show() 190 | return fig 191 | 192 | def plot_correlation_scatter_plot(M, verbose=True): 193 | """Makes a grid of scatter plots representing relationship between variables 194 | 195 | Each scatter plot is one variable plotted against another variable 196 | 197 | Parameters 198 | ---------- 199 | M : numpy structured array 200 | 201 | Returns 202 | ------- 203 | matplotlib.figure.Figure 204 | 205 | """ 206 | # TODO work for all three types that M might be 207 | # TODO ignore classification variables 208 | # adapted from the excellent 209 | # http://stackoverflow.com/questions/7941207/is-there-a-function-to-make-scatterplot-matrices-in-matplotlib 210 | 211 | M = convert_to_sa(M) 212 | 213 | numdata = M.shape[0] 214 | numvars = len(M.dtype) 215 | names = M.dtype.names 216 | fig, axes = plt.subplots(numvars, numvars) 217 | fig.subplots_adjust(hspace=0.05, wspace=0.05) 218 | 219 | for ax in axes.flat: 220 | # Hide all ticks and labels 221 | ax.xaxis.set_visible(False) 222 | ax.yaxis.set_visible(False) 223 | 224 | # Set up ticks only on one side for the "edge" subplots... 225 | if ax.is_first_col(): 226 | ax.yaxis.set_ticks_position('left') 227 | if ax.is_last_col(): 228 | ax.yaxis.set_ticks_position('right') 229 | if ax.is_first_row(): 230 | ax.xaxis.set_ticks_position('top') 231 | if ax.is_last_row(): 232 | ax.xaxis.set_ticks_position('bottom') 233 | 234 | # Plot the M. 235 | for i, j in zip(*np.triu_indices_from(axes, k=1)): 236 | for x, y in [(i,j), (j,i)]: 237 | axes[x,y].plot(M[M.dtype.names[x]], M[M.dtype.names[y]], '.') 238 | 239 | # Label the diagonal subplots... 240 | for i, label in enumerate(names): 241 | axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction', 242 | ha='center', va='center') 243 | 244 | # Turn on the proper x or y axes ticks. 245 | for i, j in zip(range(numvars), it.cycle((-1, 0))): 246 | axes[j,i].xaxis.set_visible(True) 247 | axes[i,j].yaxis.set_visible(True) 248 | if verbose: 249 | plt.show() 250 | return fig 251 | 252 | def plot_kernel_density(col, n=None, missing_val=np.nan, verbose=True): 253 | #address pass entire matrix 254 | # TODO respect missing_val 255 | # TODO what does n do? 
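    # NB: as the body below stands, neither `n` nor `missing_val` is used --
    # the KDE is fit on the full column exactly as passed in, so both TODOs
    # above still apply.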
256 | x_grid = np.linspace(min(col), max(col), 1000) 257 | 258 | grid = GridSearchCV(KernelDensity(), {'bandwidth': np.linspace(0.1,1.0,30)}, cv=20) # 20-fold cross-validation 259 | grid.fit(col[:, None]) 260 | 261 | kde = grid.best_estimator_ 262 | pdf = np.exp(kde.score_samples(x_grid[:, None])) 263 | 264 | fig, ax = plt.subplots() 265 | #fig = plt.figure() 266 | ax.plot(x_grid, pdf, linewidth=3, alpha=0.5, label='bw=%.2f' % kde.bandwidth) 267 | ax.hist(col, 30, fc='gray', histtype='stepfilled', alpha=0.3, normed=True) 268 | ax.legend(loc='upper left') 269 | ax.set_xlim(min(col), max(col)) 270 | if verbose: 271 | plt.show() 272 | return fig 273 | 274 | def plot_on_timeline(col, verbose=True): 275 | """Plots points on a timeline 276 | 277 | Parameters 278 | ---------- 279 | col : np.array 280 | 281 | Returns 282 | ------- 283 | matplotlib.figure.Figure 284 | """ 285 | # http://stackoverflow.com/questions/1574088/plotting-time-in-python-with-matplotlib 286 | if is_nd(col): 287 | col = col.astype(datetime) 288 | dates = matplotlib.dates.date2num(col) 289 | fig = plt.figure() 290 | plt.plot_date(dates, [0] * len(dates)) 291 | if verbose: 292 | plt.show() 293 | return fig 294 | 295 | def feature_pairs_in_rf(rf, weight_by_depth=None, verbose=True, n=10): 296 | """Describes the frequency of features appearing subsequently in each tree 297 | in a random forest""" 298 | # weight be depth is a vector. The 0th entry is the weight of being at 299 | # depth 0; the 1st entry is the weight of being at depth 1, etc. 300 | # If not provided, weights are linear with negative depth. If 301 | # the provided vector is not as long as the number of depths, then 302 | # remaining depths are weighted with 0 303 | # If verbose, will only print the first n results 304 | 305 | 306 | pairs_by_est = [feature_pairs_in_tree(est) for est in rf.estimators_] 307 | pairs_by_depth = [list(it.chain(*pair_list)) for pair_list in 308 | list(it.izip_longest(*pairs_by_est, fillvalue=[]))] 309 | pairs_flat = list(it.chain(*pairs_by_depth)) 310 | depths_by_pair = {} 311 | for depth, pairs in enumerate(pairs_by_depth): 312 | for pair in pairs: 313 | try: 314 | depths_by_pair[pair] += [depth] 315 | except KeyError: 316 | depths_by_pair[pair] = [depth] 317 | counts_by_pair=Counter(pairs_flat) 318 | count_pairs_by_depth = [Counter(pairs) for pairs in pairs_by_depth] 319 | 320 | depth_len = len(pairs_by_depth) 321 | if weight_by_depth is None: 322 | weight_by_depth = [(depth_len - float(depth)) / depth_len for depth in 323 | xrange(depth_len)] 324 | weight_filler = it.repeat(0.0, depth_len - len(weight_by_depth)) 325 | weights = list(it.chain(weight_by_depth, weight_filler)) 326 | 327 | average_depth_by_pair = {pair: float(sum(depths)) / len(depths) for 328 | pair, depths in depths_by_pair.iteritems()} 329 | 330 | weighted = {pair: sum([weights[depth] for depth in depths]) 331 | for pair, depths in depths_by_pair.iteritems()} 332 | 333 | if verbose: 334 | print '=' * 80 335 | print 'RF Subsequent Pair Analysis' 336 | print '=' * 80 337 | print 338 | _feature_pair_report( 339 | counts_by_pair.most_common(), 340 | 'Overall Occurrences', 341 | 'occurrences', 342 | n=n) 343 | _feature_pair_report( 344 | sorted([item for item in average_depth_by_pair.iteritems()], 345 | key=lambda item: item[1]), 346 | 'Average depth', 347 | 'average depth', 348 | 'Max depth was {}'.format(depth_len - 1), 349 | n=n) 350 | _feature_pair_report( 351 | sorted([item for item in weighted.iteritems()], 352 | key=lambda item: item[1]), 353 | 'Occurrences weighted 
by depth', 354 | 'sum weight', 355 | 'Weights for depth 0, 1, 2, ... were: {}'.format(weights), 356 | n=n) 357 | 358 | for depth, pairs in enumerate(count_pairs_by_depth): 359 | _feature_pair_report( 360 | pairs.most_common(), 361 | 'Occurrences at depth {}'.format(depth), 362 | 'occurrences', 363 | n=n) 364 | 365 | 366 | return (counts_by_pair, count_pairs_by_depth, average_depth_by_pair, 367 | weighted) 368 | 369 | def html_escape(s): 370 | """Returns a string with all its html-averse characters html escaped""" 371 | return cgi.escape(s).encode('ascii', 'xmlcharrefreplace') 372 | 373 | def html_format(fmt, *args, **kwargs): 374 | clean_args = [html_escape(str(arg)) for arg in args] 375 | clean_kwargs = {key: html_escape(str(kwargs[key])) for 376 | key in kwargs} 377 | return fmt.format(*clean_args, **clean_kwargs) 378 | 379 | def np_to_html_table(sa, fout, show_shape=False): 380 | if show_shape: 381 | fout.write('
<p>table of shape: ({},{})</p>'.format(
382 |             len(sa),
383 |             len(sa.dtype)))
384 |     fout.write('<table>\n')
385 |     header = '<tr>{}</tr>\n'.format(
386 |         ''.join(
387 |             [html_format(
388 |                 '<th>{}</th>',
389 |                 name) for
390 |              name in sa.dtype.names]))
391 |     fout.write(header)
392 |     data = '\n'.join(
393 |         ['<tr>{}</tr>'.format(
394 |             ''.join(
395 |                 [html_format(
396 |                     '<td>{}</td>',
397 |                     cell) for
398 |                  cell in row])) for
399 |              row in sa])
400 |     fout.write(data)
401 |     fout.write('\n')
402 |     fout.write('</table>
')
403 | 
404 | class ReportError(Exception):
405 |     pass
406 | 
407 | class Report(object):
408 | 
409 |     def __init__(self, exp=None, report_path='report.pdf'):
410 |         self.__exp = exp
411 |         if exp is not None:
412 |             self.__back_indices = {trial: i for i, trial in enumerate(exp.trials)}
413 |         self.__objects = []
414 |         self.__tmp_folder = 'eights_temp'
415 |         if not os.path.exists(self.__tmp_folder):
416 |             os.mkdir(self.__tmp_folder)
417 |         self.__html_src_path = os.path.join(self.__tmp_folder,
418 |                                             '{}.html'.format(uuid.uuid4()))
419 |         self.__report_path = report_path
420 | 
421 |     def to_pdf(self, options={}, verbose=True):
422 |         # Options are pdfkit.from_url options. See
423 |         # https://pypi.python.org/pypi/pdfkit
424 |         if verbose:
425 |             print 'Generating report...'
426 |         with open(self.__html_src_path, 'w') as html_out:
427 |             html_out.write(self.__get_header())
428 |             html_out.write('\n'.join(self.__objects))
429 |             html_out.write(self.__get_footer())
430 |         if not verbose:
431 |             options['quiet'] = ''
432 |         pdfkit.from_url(self.__html_src_path, self.__report_path,
433 |                         options=options)
434 |         report_path = self.get_report_path()
435 |         if verbose:
436 |             print 'Report written to {}'.format(report_path)
437 |         return report_path
438 | 
439 |     def get_report_path(self):
440 |         return os.path.abspath(self.__report_path)
441 | 
442 |     def __get_header(self):
443 |         # Thanks to http://stackoverflow.com/questions/13516534/how-to-avoid-page-break-inside-table-row-for-wkhtmltopdf
444 |         # For not page breaking in the middle of tables
445 |         return ('<html>\n'
446 |                 '<head>\n'
447 |                 '<style>\n'
448 |                 'table, tr, td, th { page-break-inside: avoid !important; }\n'
465 |                 '</style></head>\n'
466 |                 '<body>\n')
467 | 
468 |     def add_subreport(self, subreport):
469 |         self.__objects += subreport.__objects
470 | 
471 |     def __get_footer(self):
472 |         return '\n</body>\n</html>\n'
473 | 
474 |     def add_heading(self, heading, level=2):
475 |         self.__objects.append(html_format(
476 |             '<h{}>{}</h{}>',
477 |             level,
478 |             heading,
479 |             level))
480 | 
481 |     def add_text(self, text):
482 |         self.__objects.append(html_format(
483 |             '<p>{}</p>',
484 |             text))
485 | 
486 |     def add_table(self, M):
487 |         sio = StringIO.StringIO()
488 |         np_to_html_table(M, sio)
489 |         self.__objects.append(sio.getvalue())
490 | 
491 |     def add_fig(self, fig):
492 |         # So we don't get pages with nothing but one figure on them
493 |         fig.set_figheight(5.0)
494 |         filename = 'fig_{}.png'.format(str(uuid.uuid4()))
495 |         path = os.path.join(self.__tmp_folder, filename)
496 |         fig.savefig(path)
497 |         self.__objects.append('<img src="{}">'.format(filename))
498 | 
499 |     def add_summary_graph(self, measure):
500 |         if self.__exp is None:
501 |             raise ReportError('No experiment provided for this report. '
502 |                               'Cannot add summary graphs.')
503 |         results = [(trial, score, self.__back_indices[trial]) for
504 |                    trial, score in getattr(self.__exp, measure)().iteritems()]
505 |         results_sorted = sorted(
506 |             results,
507 |             key=lambda result: result[1],
508 |             reverse=True)
509 |         y = [result[1] for result in results_sorted]
510 |         x = xrange(len(results))
511 |         fig = plt.figure()
512 |         plt.bar(x, y)
513 |         maxy = max(y)
514 |         for rank, result in enumerate(results_sorted):
515 |             plt.text(rank, result[1], '{}'.format(result[2]))
516 |         plt.ylabel(measure)
517 |         self.add_fig(fig)
518 |         plt.close()
519 | 
520 |     def add_summary_graph_roc_auc(self):
521 |         self.add_summary_graph('roc_auc')
522 | 
523 |     def add_summary_graph_average_score(self):
524 |         self.add_summary_graph('average_score')
525 | 
526 |     def add_graph_for_best(self, func_name):
527 |         if self.__exp is None:
528 |             raise ReportError('No experiment provided for this report. '
529 |                               'Cannot add graph for best trial.')
530 |         best_trial = max(
531 |             self.__exp.trials,
532 |             key=lambda trial: trial.average_score())
533 |         fig = getattr(best_trial, func_name)()
534 |         self.add_fig(fig)
535 |         self.add_text('Best trial is trial {} ({})'.format(
536 |             self.__back_indices[best_trial],
537 |             best_trial))
538 |         plt.close()
539 | 
540 |     def add_graph_for_best_roc(self):
541 |         self.add_graph_for_best('roc_curve')
542 | 
543 |     def add_graph_for_best_prec_recall(self):
544 |         self.add_graph_for_best('prec_recall_curve')
545 | 
546 |     def add_legend(self):
547 |         if self.__exp is None:
548 |             raise ReportError('No experiment provided for this report. 
' 549 | 'Cannot add legend.') 550 | list_of_tuple = [(str(i), str(trial)) for i, trial in 551 | enumerate(self.__exp.trials)] 552 | table = cast_list_of_list_to_sa(list_of_tuple, col_names=('Id', 'Trial')) 553 | # display 10 at a time to give pdfkit an easier time with page breaks 554 | start_row = 0 555 | n_trials = len(list_of_tuple) 556 | while start_row < n_trials: 557 | self.add_table(table[start_row:start_row+9]) 558 | start_row += 9 559 | 560 | 561 | -------------------------------------------------------------------------------- /eights/communicate/communicate_helper.py: -------------------------------------------------------------------------------- 1 | from sklearn.tree._tree import TREE_LEAF 2 | from collections import Counter 3 | import itertools as it 4 | 5 | def _feature_pair_report(pair_and_values, 6 | description='pairs', 7 | measurement='value', 8 | note=None, 9 | n=10): 10 | print '-' * 80 11 | print description 12 | print '-' * 80 13 | print 'feature pair : {}'.format(measurement) 14 | for pair, value in it.islice(pair_and_values, n): 15 | print '{} : {}'.format(pair, value) 16 | if note is not None: 17 | print '* {}'.format(note) 18 | print 19 | 20 | 21 | def feature_pairs_in_tree(dt): 22 | """Lists subsequent features sorted by importance 23 | 24 | Parameters 25 | ---------- 26 | dt : sklearn.tree.DecisionTreeClassifer 27 | 28 | Returns 29 | ------- 30 | list of list of tuple of int : 31 | Going from inside to out: 32 | 33 | 1. Each int is a feature that a node split on 34 | 35 | 2. If two ints appear in the same tuple, then there was a node 36 | that split on the second feature immediately below a node 37 | that split on the first feature 38 | 39 | 3. Tuples appearing in the same inner list appear at the same 40 | depth in the tree 41 | 42 | 4. 
The outer list describes the entire tree 43 | 44 | """ 45 | t = dt.tree_ 46 | feature = t.feature 47 | children_left = t.children_left 48 | children_right = t.children_right 49 | result = [] 50 | if t.children_left[0] == TREE_LEAF: 51 | return result 52 | next_queue = [0] 53 | while next_queue: 54 | this_queue = next_queue 55 | next_queue = [] 56 | results_this_depth = [] 57 | while this_queue: 58 | node = this_queue.pop() 59 | left_child = children_left[node] 60 | right_child = children_right[node] 61 | if children_left[left_child] != TREE_LEAF: 62 | results_this_depth.append(tuple(sorted( 63 | (feature[node], 64 | feature[left_child])))) 65 | next_queue.append(left_child) 66 | if children_left[right_child] != TREE_LEAF: 67 | results_this_depth.append(tuple(sorted( 68 | (feature[node], 69 | feature[right_child])))) 70 | next_queue.append(right_child) 71 | result.append(results_this_depth) 72 | result.pop() # The last results are always empty 73 | return result 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /eights/decontaminate/__init__.py: -------------------------------------------------------------------------------- 1 | from decontaminate import * 2 | 3 | -------------------------------------------------------------------------------- /eights/decontaminate/decontaminate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import preprocessing 3 | from sklearn.preprocessing import Imputer 4 | from eights.utils import convert_to_sa 5 | 6 | def label_encode(M): 7 | """ 8 | Changes string cols to integers so that there is a 1-1 mapping between 9 | strings and ints 10 | """ 11 | 12 | M = convert_to_sa(M) 13 | le = preprocessing.LabelEncoder() 14 | new_dtype = [] 15 | result_arrays = [] 16 | for (col_name, fmt) in M.dtype.descr: 17 | if 'S' in fmt: 18 | result_arrays.append(le.fit_transform(M[col_name])) 19 | new_dtype.append((col_name, int)) 20 | else: 21 | result_arrays.append(M[col_name]) 22 | new_dtype.append((col_name, fmt)) 23 | return np.array(zip(*result_arrays), dtype=new_dtype) 24 | 25 | def replace_missing_vals(M, strategy, missing_val=np.nan, constant=0): 26 | # TODO support times, strings 27 | M = convert_to_sa(M) 28 | 29 | if strategy not in ['mean', 'median', 'most_frequent', 'constant']: 30 | raise ValueError('Invalid strategy') 31 | 32 | M_cp = M.copy() 33 | 34 | if strategy == 'constant': 35 | 36 | try: 37 | missing_is_nan = np.isnan(missing_val) 38 | except TypeError: 39 | # missing_val is not a float 40 | missing_is_nan = False 41 | 42 | if missing_is_nan: # we need to be careful about handling nan 43 | for col_name, col_type in M_cp.dtype.descr: 44 | if 'f' in col_type: 45 | col = M_cp[col_name] 46 | col[np.isnan(col)] = constant 47 | return M_cp 48 | 49 | for col_name, col_type in M_cp.dtype.descr: 50 | if 'i' in col_type or 'f' in col_type: 51 | col = M_cp[col_name] 52 | col[col == missing_val] = constant 53 | return M_cp 54 | 55 | # we're doing one of the sklearn imputer strategies 56 | imp = Imputer(missing_values=missing_val, strategy=strategy, axis=1) 57 | for col_name, col_type in M_cp.dtype.descr: 58 | if 'f' in col_type or 'i' in col_type: 59 | # The Imputer only works on float and int columns 60 | col = M_cp[col_name] 61 | col[:] = imp.fit_transform(col) 62 | return M_cp 63 | 64 | 65 | -------------------------------------------------------------------------------- /eights/generate/__init__.py: 
-------------------------------------------------------------------------------- 1 | from generate import * -------------------------------------------------------------------------------- /eights/generate/generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import cross_validation 3 | from ..utils import append_cols, distance 4 | from uuid import uuid4 5 | 6 | 7 | def where_all_are_true(M, arguments, generated_name=None): 8 | if generated_name is None: 9 | generated_name = str(uuid4()) 10 | to_select = np.ones(M.size, dtype=bool) 11 | for arg_set in arguments: 12 | lambd, col_name, vals = (arg_set['func'], arg_set['col_name'], 13 | arg_set['vals']) 14 | to_select = np.logical_and(to_select, lambd(M, col_name, vals)) 15 | return append_cols(M, to_select, generated_name) 16 | 17 | # where_all_are_true( 18 | # M, 19 | # [{'func': val_eq, 'col_name': 'f1', 'vals': 4}, 20 | # {'func': val_between, 'col_name': 'f7', 'vals': (1.2, 2.5)}] 21 | 22 | def is_outlier(M, col_name, boundary): 23 | std = np.std(M[col_name]) 24 | mean = np.mean(M[col_name]) 25 | return (np.logical_or( (mean-3*std)>M[col_name], (mean+3*std) boundary 39 | 40 | def val_between(M, col_name, boundary): 41 | return np.logical_and(boundary[0] <= M[col_name], M[col_name] <= boundary[1]) 42 | 43 | 44 | 45 | def generate_bin(col, num_bins): 46 | """Generates a column of categories, where each category is a bin. 47 | 48 | Parameters 49 | ---------- 50 | col : np.array 51 | 52 | Returns 53 | ------- 54 | np.array 55 | 56 | Examples 57 | -------- 58 | >>> M = np.array([0.1, 3.0, 0.0, 1.2, 2.5, 1.7, 2]) 59 | >>> generate_bin(M, 3) 60 | [0 3 0 1 2 1 2] 61 | 62 | """ 63 | 64 | minimum = float(min(col)) 65 | maximum = float(max(col)) 66 | distance = float(maximum - minimum) 67 | return [int((x - minimum) / distance * num_bins) for x in col] 68 | 69 | def normalize(col, mean=None, stddev=None, return_fit=False): 70 | """ 71 | 72 | Generate a normalized column. 73 | 74 | Normalize both mean and std dev. 75 | 76 | Parameters 77 | ---------- 78 | col : np.array 79 | mean : float or None 80 | Mean to use for fit. If none, will use 0 81 | stddev : float or None 82 | return_fit : boolean 83 | If True, returns tuple of fitted col, mean, and standard dev of fit. 
If False, only returns fitted col
 85 |     Returns
 86 |     -------
 87 |     np.array or (np.array, float, float)
 88 | 
 89 |     """
 90 |     # see infonavit for applying to different set than we fit on
 91 |     # https://github.com/dssg/infonavit-public/blob/master/pipeline_src/preprocessing.py#L99
 92 |     # Logic is from sklearn StandardScaler, but I didn't use sklearn because
 93 |     # I want to pass in mean and stddev rather than a fitted StandardScaler
 94 |     # https://github.com/scikit-learn/scikit-learn/blob/a95203b/sklearn/preprocessing/data.py#L276
 95 |     if mean is None:
 96 |         mean = np.mean(col)
 97 |     if stddev is None:
 98 |         stddev = np.std(col)
 99 |     res = (col - mean) / stddev
100 |     if return_fit:
101 |         return (res, mean, stddev)
102 |     else:
103 |         return res
104 | 
105 | def distance_from_point(lat_origin, lng_origin, lat_col, lng_col):
106 |     """ Generates a column of how far each record is from the origin"""
107 |     return distance(lat_origin, lng_origin, lat_col, lng_col)
108 | 
109 | @np.vectorize
110 | def combine_sum(*args):
111 |     return sum(args)
112 | 
113 | @np.vectorize
114 | def combine_mean(*args):
115 |     return np.mean(args)
116 | 
117 | def combine_cols(M, lambd, col_names, generated_name):
118 |     new_col = lambd(*[M[name] for name in col_names])
119 |     return append_cols(M, new_col, generated_name)
120 | 
--------------------------------------------------------------------------------
/eights/investigate/__init__.py:
--------------------------------------------------------------------------------
1 | from investigate import *
2 | 
--------------------------------------------------------------------------------
/eights/investigate/investigate.py:
--------------------------------------------------------------------------------
 1 | import itertools as it
 2 | import numpy as np
 3 | 
 4 | import sklearn
 5 | 
 6 | from collections import Counter
 7 | import matplotlib.pyplot as plt
 8 | 
 9 | from sklearn import cross_validation
10 | from sklearn.ensemble import RandomForestClassifier
11 | from sklearn.neighbors import KernelDensity
12 | from sklearn.grid_search import GridSearchCV
13 | 
14 | from .investigate_helper import *
15 | from ..communicate import *
16 | from ..utils import is_sa
17 | 
18 | 
19 | __describe_cols_metrics = [('Count', len),
20 |                            ('Mean', np.mean),
21 |                            ('Standard Dev', np.std),
22 |                            ('Minimum', min),
23 |                            ('Maximum', max)]
24 | 
25 | __describe_cols_fill = [np.nan] * len(__describe_cols_metrics)
26 | 
27 | def describe_cols(M):
28 |     """Takes a structured array or list of np.ndarrays and returns summary statistics
29 | 
30 |     Parameters
31 |     ----------
32 |     M : structured array or list of numpy ND arrays
33 | 
34 |     Returns
35 |     -------
36 |     numpy structured array
37 |         One row per column of M: count, mean, standard deviation,
38 |         minimum, and maximum for numeric columns
39 | 
40 |     """
41 |     M = convert_to_sa(M)
42 |     descr_rows = []
43 |     for col_name, col_type in M.dtype.descr:
44 |         if 'f' in col_type or 'i' in col_type:
45 |             col = M[col_name]
46 |             row = [col_name] + [func(col) for _, func in
47 |                                 __describe_cols_metrics]
48 |         else:
49 |             row = [col_name] + __describe_cols_fill
50 |         descr_rows.append(row)
51 |     col_names = ['Column Name'] + [col_name for col_name, _ in
52 |                                    __describe_cols_metrics]
53 |     return convert_to_sa(descr_rows, col_names=col_names)
54 | 
55 | 
56 | def crosstab(col1, col2):
57 |     """
58 |     Makes a crosstab of col1 and col2. This is represented as a
59 |     structured array with the following properties:
60 | 
61 |     1. The first column is the value of col1 being crossed
62 |     2. 
The name of every column except the first is the value of col2 being 63 | crossed 64 | 3. To find the number of cooccurences of x from col1 and y in col2, 65 | find the row that has 'x' in col1 and the column named 'y'. The 66 | corresponding cell is the number of cooccurrences of x and y 67 | """ 68 | col1 = np.array(col1) 69 | col2 = np.array(col2) 70 | col1_unique = np.unique(col1) 71 | col2_unique = np.unique(col2) 72 | crosstab_rows = [] 73 | for col1_val in col1_unique: 74 | loc_col1_val = np.where(col1==col1_val)[0] 75 | col2_vals = col2[loc_col1_val] 76 | cnt = Counter(col2_vals) 77 | counts = [cnt[col2_val] if cnt.has_key(col2_val) else 0 for col2_val 78 | in col2_unique] 79 | crosstab_rows.append(['{}'.format(col1_val)] + counts) 80 | col_names = ['col1_value'] + ['{}'.format(col2_val) for col2_val in 81 | col2_unique] 82 | return convert_to_sa(crosstab_rows, col_names=col_names) 83 | 84 | def connect_sql(con_str, allow_caching=False, cache_dir='.'): 85 | return SQLConnection(con_str, allow_caching, cache_dir) 86 | 87 | 88 | 89 | #Plots of desrcptive statsitics 90 | from ..communicate.communicate import plot_correlation_matrix 91 | from ..communicate.communicate import plot_correlation_scatter_plot 92 | from ..communicate.communicate import plot_kernel_density 93 | from ..communicate.communicate import plot_on_timeline 94 | from ..communicate.communicate import plot_box_plot 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /eights/investigate/investigate_helper.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import urllib2 3 | import os 4 | import cPickle 5 | from collections import Counter 6 | import numpy as np 7 | import sqlalchemy as sqla 8 | from ..utils import * 9 | import itertools as it 10 | from datetime import datetime 11 | 12 | 13 | __special_csv_strings = {'': None, 14 | 'True': True, 15 | 'False': False} 16 | 17 | def __correct_csv_cell_type(cell): 18 | # Change strings in CSV to appropriate Python objects 19 | try: 20 | return __special_csv_strings[cell] 21 | except KeyError: 22 | pass 23 | try: 24 | return int(cell) 25 | except ValueError: 26 | pass 27 | try: 28 | return float(cell) 29 | except ValueError: 30 | pass 31 | try: 32 | return parse(cell) 33 | except (TypeError, ValueError): 34 | pass 35 | return cell 36 | 37 | def open_csv_url_as_list(url_loc, delimiter=','): 38 | response = urllib2.urlopen(url_loc) 39 | cr = csv.reader(response, delimiter=delimiter) 40 | return list(cr) 41 | 42 | def open_csv_as_list(file_loc, delimiter=',', return_col_names=False): 43 | # infers types 44 | with open(file_loc, 'rU') as f: 45 | reader = csv.reader(f, delimiter=delimiter) 46 | names = reader.next() # skip header 47 | data = [[__correct_csv_cell_type(cell) for cell in row] for 48 | row in reader] 49 | if return_col_names: 50 | return data, names 51 | return data 52 | 53 | def open_csv_as_structured_array(file_loc, delimiter=','): 54 | python_list, names = open_csv_as_list(file_loc, delimiter, True) 55 | return cast_list_of_list_to_sa(python_list, names) 56 | 57 | def convert_fixed_width_list_to_CSV_list(data, list_of_widths): 58 | #assumes you loaded a fixed with thing into a list of list csv. 59 | #not clear what this does with the 's's... 
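    # NB: `struct` is never imported in this module, so struct.unpack below
    # will raise NameError unless it happens to arrive via the wildcard
    # `from ..utils import *`; the same caveat applies to `parse` (presumably
    # dateutil.parser.parse) used in __correct_csv_cell_type above.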
60 | s = "s".join([str(s) for s in list_of_widths]) 61 | s = s + 's' 62 | out = [] 63 | for x in data: 64 | out.append(struct.unpack(s, x[0])) 65 | return out 66 | 67 | # let's not use this any more 68 | #def set_structured_array_datetime_as_day(first_pass,file_loc, delimiter=','): 69 | # date_cols = [] 70 | # int_cols = [] 71 | # new_dtype = [] 72 | # for i, (col_name, col_dtype) in enumerate(first_pass.dtype.descr): 73 | # if 'S' in col_dtype: 74 | # col = first_pass[col_name] 75 | # if np.any(validate_time(col)): 76 | # date_cols.append(i) 77 | # # TODO better inference 78 | # # col_dtype = 'M8[D]' 79 | # col_dtype = np.datetime64(col[0]).dtype 80 | # elif 'i' in col_dtype: 81 | # int_cols.append(i) 82 | # new_dtype.append((col_name, col_dtype)) 83 | # 84 | # converter = {i: str_to_time for i in date_cols} 85 | # missing_values = {i: '' for i in int_cols} 86 | # filling_values = {i: -999 for i in int_cols} 87 | # return np.genfromtxt(file_loc, dtype=new_dtype, names=True, delimiter=delimiter, 88 | # converters=converter, missing_values=missing_values, 89 | # filling_values=filling_values) 90 | 91 | 92 | 93 | def describe_column(col): 94 | if col.dtype.kind not in ['f','i']: 95 | return {} 96 | cnt = len(col) 97 | mean = np.mean(np.array(col)) 98 | std = np.std(np.array(col)) 99 | mi = min(col) 100 | mx = max(col) 101 | return {'Count:' : cnt,'Mean:': mean, 'Standard Dev:': std, 'Minimal ': mi,'Maximal:': mx} 102 | 103 | 104 | 105 | class SQLConnection(object): 106 | # Intended to vaguely implement DBAPI 2 107 | # If allow_caching is True, will pickle results in cache_dir and reuse 108 | # them if it encounters an identical query twice. 109 | def __init__(self, con_str, allow_caching=False, cache_dir='.'): 110 | self.__engine = sqla.create_engine(con_str) 111 | self.__cache_dir = cache_dir 112 | if allow_caching: 113 | self.execute = self.__execute_with_cache 114 | 115 | def __sql_to_sa(self, exec_str): 116 | raw_python = self.__engine.execute(exec_str) 117 | return cast_list_of_list_to_sa( 118 | raw_python.fetchall(), 119 | [str(key) for key in raw_python.keys()]) 120 | 121 | def __execute_with_cache(self, exec_str, invalidate_cache=False): 122 | pkl_file_name = os.path.join( 123 | self.__cache_dir, 124 | 'eights_cache_{}.pkl'.format(hash(exec_str))) 125 | if os.path.exists(pkl_file_name) and not invalidate_cache: 126 | with open(pkl_file_name) as fin: 127 | return cPickle.load(fin) 128 | ret = self.__sql_to_sa(exec_str) 129 | with open(pkl_file_name, 'w') as fout: 130 | cPickle.dump(ret, fout) 131 | return ret 132 | 133 | def execute(self, exec_str, invalidate_cache=False): 134 | return self.__sql_to_sa(exec_str) 135 | -------------------------------------------------------------------------------- /eights/operate/__init__.py: -------------------------------------------------------------------------------- 1 | from operate import * -------------------------------------------------------------------------------- /eights/operate/operate.py: -------------------------------------------------------------------------------- 1 | 2 | from sklearn.ensemble import (AdaBoostClassifier, 3 | RandomForestClassifier, 4 | ExtraTreesClassifier, 5 | GradientBoostingClassifier) 6 | from sklearn.linear_model import (LogisticRegression, 7 | RidgeClassifier, 8 | SGDClassifier, 9 | Perceptron, 10 | PassiveAggressiveClassifier) 11 | from sklearn.cross_validation import (StratifiedKFold, 12 | KFold) 13 | from sklearn.naive_bayes import (BernoulliNB, 14 | MultinomialNB, 15 | GaussianNB) 16 | from 
sklearn.neighbors import(KNeighborsClassifier, 17 | NearestCentroid) 18 | from sklearn.tree import DecisionTreeClassifier 19 | from sklearn.svm import SVC 20 | from sklearn.dummy import DummyClassifier 21 | 22 | from ..utils import remove_cols 23 | from ..perambulate import Experiment 24 | 25 | std_clfs = [{'clf': AdaBoostClassifier, 'n_estimators': [20,50,100]}, 26 | {'clf': RandomForestClassifier, 27 | 'n_estimators': [10,30,50], 28 | 'max_features': ['sqrt','log2'], 29 | 'max_depth': [None,4,7,15], 30 | 'n_jobs':[1]}, 31 | {'clf': LogisticRegression, 32 | 'C': [1.0,2.0,0.5,0.25], 33 | 'penalty': ['l1','l2']}, 34 | {'clf': DecisionTreeClassifier, 35 | 'max_depth': [None,4,7,15,25]}, 36 | {'clf': SVC, 'kernel': ['linear','rbf'], 37 | 'probability': [True]}, 38 | {'clf': DummyClassifier, 39 | 'strategy': ['stratified','most_frequent','uniform']}] 40 | 41 | DBG_std_clfs = [{'clf': AdaBoostClassifier, 'n_estimators': [20]}, 42 | {'clf': RandomForestClassifier, 43 | 'n_estimators': [10], 44 | 'max_features': ['sqrt'], 45 | 'max_depth': [None], 46 | 'n_jobs':[1]}, 47 | {'clf': LogisticRegression, 48 | 'C': [1.0], 49 | 'penalty': ['l1']}, 50 | {'clf': DecisionTreeClassifier, 51 | 'max_depth': [None]}, 52 | {'clf': DummyClassifier, 53 | 'strategy': ['stratified','most_frequent']}] 54 | 55 | 56 | rg_clfs= [{'clf': RandomForestClassifier, 57 | 'n_estimators': [1,10,100,1000,10000], 58 | 'max_depth': [1,5,10,20,50,100], 59 | 'max_features': ['sqrt','log2'], 60 | 'min_samples_split': [2,5,10], 61 | 'n_jobs': [1]}, 62 | {'clf': LogisticRegression, 63 | 'penalty': ['l1','l2'], 64 | 'C': [0.00001,0.0001,0.001,0.01,0.1,1,10]}, 65 | {'clf': SGDClassifier, 66 | 'loss':['hinge','log','perceptron'], 67 | 'penalty':['l2','l1','elasticnet']}, 68 | {'clf': ExtraTreesClassifier, 69 | 'n_estimators': [1,10,100,1000,10000], 70 | 'criterion' : ['gini', 'entropy'], 71 | 'max_depth': [1,5,10,20,50,100], 72 | 'max_features': ['sqrt','log2'], 73 | 'min_samples_split': [2,5,10], 74 | 'n_jobs': [1]}, 75 | {'clf': AdaBoostClassifier, 76 | 'algorithm' :['SAMME', 'SAMME.R'], 77 | 'n_estimators': [1,10,100,1000,10000], 78 | 'base_estimator': [DecisionTreeClassifier(max_depth=1)]}, 79 | {'clf': GradientBoostingClassifier, 80 | 'n_estimators': [1,10,100,1000,10000], 81 | 'learning_rate' : [0.001,0.01,0.05,0.1,0.5], 82 | 'subsample' : [0.1,0.5,1.0], 83 | 'max_depth': [1,3,5,10,20,50,100]}, 84 | {'clf': GaussianNB }, 85 | {'clf': DecisionTreeClassifier, 86 | 'criterion': ['gini', 'entropy'], 87 | 'max_depth': [1,5,10,20,50,100], 88 | 'max_features': ['sqrt','log2'], 89 | 'min_samples_split': [2,5,10]}, 90 | {'clf':SVC, 91 | 'C': [0.00001,0.0001,0.001,0.01,0.1,1,10], 92 | 'kernel': ['linear'], 93 | 'probability': [True]}, 94 | {'clf': KNeighborsClassifier, 95 | 'n_neighbors':[1,5,10,25,50,100], 96 | 'weights': ['uniform','distance'], 97 | 'algorithm':['auto','ball_tree','kd_tree']}] 98 | 99 | DBG_rg_clfs= [{'clf': RandomForestClassifier, 100 | 'n_estimators': [1], 101 | 'max_depth': [1], 102 | 'max_features': ['sqrt'], 103 | 'min_samples_split': [2], 104 | 'n_jobs': [1]}, 105 | {'clf': LogisticRegression, 106 | 'penalty': ['l1'], 107 | 'C': [0.00001]}, 108 | {'clf': SGDClassifier, 109 | 'loss':['log'], # hinge doesn't have predict_proba 110 | 'penalty':['l2']}, 111 | {'clf': ExtraTreesClassifier, 112 | 'n_estimators': [1], 113 | 'criterion' : ['gini'], 114 | 'max_depth': [1], 115 | 'max_features': ['sqrt'], 116 | 'min_samples_split': [2], 117 | 'n_jobs': [1]}, 118 | {'clf': AdaBoostClassifier, 119 | 'algorithm' :['SAMME'], 120 | 
'n_estimators': [1], 121 | 'base_estimator': [DecisionTreeClassifier(max_depth=1)]}, 122 | {'clf': GradientBoostingClassifier, 123 | 'n_estimators': [1], 124 | 'learning_rate' : [0.001], 125 | 'subsample' : [0.1], 126 | 'max_depth': [1]}, 127 | {'clf': GaussianNB }, 128 | {'clf': DecisionTreeClassifier, 129 | 'criterion': ['gini'], 130 | 'max_depth': [1], 131 | 'max_features': ['sqrt'], 132 | 'min_samples_split': [2]}, 133 | {'clf':SVC, 134 | 'C': [0.00001], 135 | 'kernel': ['linear'], 136 | 'probability': [True]}, 137 | {'clf': KNeighborsClassifier, 138 | 'n_neighbors':[1], 139 | 'weights': ['uniform'], 140 | 'algorithm':['auto']}] 141 | 142 | alt_clfs = [{'clf': RidgeClassifier, 'tol':[1e-2], 'solver':['lsqr']}, 143 | {'clf': SGDClassifier, 'alpha':[.0001], 'n_iter':[50],'penalty':['l1', 'l2', 'elasticnet']}, 144 | {'clf': Perceptron, 'n_iter':[50]}, 145 | {'clf': PassiveAggressiveClassifier, 'n_iter':[50]}, 146 | {'clf': BernoulliNB, 'alpha':[.01]}, 147 | {'clf': MultinomialNB, 'alpha':[.01]}, 148 | {'clf': KNeighborsClassifier, 'n_neighbors':[10]}, 149 | {'clf': NearestCentroid}] 150 | 151 | -------------------------------------------------------------------------------- /eights/perambulate/__init__.py: -------------------------------------------------------------------------------- 1 | from perambulate import * -------------------------------------------------------------------------------- /eights/perambulate/perambulate.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import json 3 | import copy 4 | import abc 5 | import datetime 6 | import itertools as it 7 | import numpy as np 8 | import csv 9 | import os 10 | 11 | from collections import Counter 12 | 13 | from sklearn.ensemble import RandomForestClassifier 14 | from sklearn.tree import DecisionTreeClassifier 15 | from sklearn.ensemble import AdaBoostClassifier 16 | from sklearn.cross_validation import KFold, StratifiedKFold 17 | from sklearn.cross_validation import _PartitionIterator 18 | 19 | from joblib import Parallel, delayed 20 | from multiprocessing import cpu_count 21 | 22 | from .perambulate_helper import * 23 | import eights.utils as utils 24 | 25 | def _run_trial(trial): 26 | return trial.run() 27 | 28 | class Experiment(object): 29 | def __init__( 30 | self, 31 | M, 32 | y, 33 | clfs=[{'clf': RandomForestClassifier}], 34 | subsets=[{'subset': SubsetNoSubset}], 35 | cvs=[{'cv': NoCV}], 36 | trials=None): 37 | if utils.is_sa(M): 38 | self.col_names = M.dtype.names 39 | self.M = utils.cast_np_sa_to_nd(M) 40 | else: # assuming an nd_array 41 | self.M = M 42 | self.col_names = ['f{}'.format(i) for i in xrange(M.shape[1])] 43 | self.y = y 44 | self.clfs = clfs 45 | self.subsets = subsets 46 | self.cvs = cvs 47 | self.trials = trials 48 | 49 | def __repr__(self): 50 | return 'Experiment(clfs={}, subsets={}, cvs={})'.format( 51 | self.clfs, 52 | self.subsets, 53 | self.cvs) 54 | 55 | 56 | def __run_all_trials(self, trials): 57 | # TODO parallelize on Runs too 58 | #return Parallel(n_jobs=cpu_count())(delayed(_run_trial)(t) 59 | # for t in trials) 60 | return [_run_trial(t) for t in trials] 61 | 62 | def __copy(self, trials): 63 | return Experiment( 64 | self.M, 65 | self.y, 66 | self.clfs, 67 | self.subsets, 68 | self.cvs, 69 | trials) 70 | 71 | def __transpose_dict_of_lists(self, dol): 72 | # http://stackoverflow.com/questions/5228158/cartesian-product-of-a-dictionary-of-lists 73 | return (dict(it.izip(dol, x)) for 74 | x in it.product(*dol.itervalues())) 75 | 76 | 
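    # Illustrative note (not in the original source): __transpose_dict_of_lists
    # expands a dict of lists into the Cartesian product of its values, e.g.
    # {'n_estimators': [10, 100], 'max_depth': [1]} yields
    # {'n_estimators': 10, 'max_depth': 1} and {'n_estimators': 100, 'max_depth': 1}.
    # run() below uses it to enumerate every clf/subset/cv combination as a Trial.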
def slice_on_dimension(self, dimension, value, trials=None): 77 | self.run() 78 | return self.__copy([trial for trial in self.trials if 79 | trial[dimension] == value]) 80 | 81 | def iterate_over_dimension(self, dimension): 82 | by_dim = {} 83 | for trial in self.trials: 84 | val_of_dim = trial[dimension] 85 | try: 86 | by_dim[val_of_dim].append(trial) 87 | except KeyError: 88 | by_dim[val_of_dim] = [trial] 89 | for val_of_dim, trials_this_dim in by_dim.iteritems(): 90 | yield (val_of_dim, self.__copy(trials_this_dim)) 91 | 92 | 93 | def slice_by_best_score(self, dimension): 94 | self.run() 95 | categories = {} 96 | other_dims = list(dimensions) 97 | other_dims.remove(dimension) 98 | for trial in self.trials: 99 | # http://stackoverflow.com/questions/5884066/hashing-a-python-dictionary 100 | key = repr([trial[dim] for dim in other_dims]) 101 | try: 102 | categories[key].append(trial) 103 | except KeyError: 104 | categories[key] = [trial] 105 | result = [] 106 | for key in categories: 107 | result.append(max( 108 | categories[key], 109 | key=lambda trial: trial.average_score())) 110 | return self.__copy(result) 111 | 112 | def has_run(self): 113 | return self.trials is not None 114 | 115 | def run(self): 116 | if self.has_run(): 117 | return self.trials 118 | trials = [] 119 | for clf_args in self.clfs: 120 | clf = clf_args['clf'] 121 | all_clf_ps = clf_args.copy() 122 | del all_clf_ps['clf'] 123 | for clf_params in self.__transpose_dict_of_lists(all_clf_ps): 124 | for subset_args in self.subsets: 125 | subset = subset_args['subset'] 126 | all_sub_ps = subset_args.copy() 127 | del all_sub_ps['subset'] 128 | for subset_params in self.__transpose_dict_of_lists(all_sub_ps): 129 | for cv_args in self.cvs: 130 | cv = cv_args['cv'] 131 | all_cv_ps = cv_args.copy() 132 | del all_cv_ps['cv'] 133 | for cv_params in self.__transpose_dict_of_lists(all_cv_ps): 134 | trial = Trial( 135 | M=self.M, 136 | y=self.y, 137 | col_names=self.col_names, 138 | clf=clf, 139 | clf_params=clf_params, 140 | subset=subset, 141 | subset_params=subset_params, 142 | cv=cv, 143 | cv_params=cv_params) 144 | trials.append(trial) 145 | trials = self.__run_all_trials(trials) 146 | self.trials = trials 147 | return trials 148 | 149 | def average_score(self): 150 | self.run() 151 | return {trial: trial.average_score() for trial in self.trials} 152 | 153 | def roc_auc(self): 154 | self.run() 155 | return {trial: trial.roc_auc() for trial in self.trials} 156 | 157 | @staticmethod 158 | def csv_header(): 159 | return Trial.csv_header() 160 | 161 | def make_report( 162 | self, 163 | report_file_name='report.pdf', 164 | dimension=None, 165 | return_report_object=False, 166 | verbose=True): 167 | # TODO make this more flexible 168 | from ..communicate import Report 169 | self.run() 170 | if dimension is None: 171 | dim_iter = [(None, self)] 172 | else: 173 | dim_iter = self.iterate_over_dimension(dimension) 174 | rep = Report(self, report_file_name) 175 | rep.add_heading('Eights Report {}'.format(datetime.datetime.now()), 1) 176 | for val_of_dim, sub_exp in dim_iter: 177 | sub_rep = Report(sub_exp) 178 | if val_of_dim is not None: 179 | sub_rep.add_heading('Subreport for {} = {}'.format( 180 | dimension_descr[dimension], 181 | val_of_dim), 1) 182 | sub_rep.add_heading('Roc AUCs', 3) 183 | sub_rep.add_summary_graph_roc_auc() 184 | sub_rep.add_heading('Average Scores', 3) 185 | sub_rep.add_summary_graph_average_score() 186 | sub_rep.add_heading('ROC for best trial', 3) 187 | sub_rep.add_graph_for_best_roc() 188 | 
            sub_rep.add_heading('Prec recall for best trial', 3)
189 |             sub_rep.add_graph_for_best_prec_recall()
190 |             sub_rep.add_heading('Legend', 3)
191 |             sub_rep.add_legend()
192 |             rep.add_subreport(sub_rep)
193 |         returned_report_file_name = rep.to_pdf(verbose=verbose)
194 |         if return_report_object:
195 |             return (returned_report_file_name, rep)
196 |         return returned_report_file_name
197 | 
198 |     def make_csv(self, file_name='report.csv'):
199 |         self.run()
200 |         with open(file_name, 'w') as fout:
201 |             writer = csv.writer(fout)
202 |             writer.writerow(self.csv_header())
203 |             for trial in self.trials:
204 |                 writer.writerows(trial.csv_rows())
205 |         return os.path.abspath(file_name)
206 | 
207 | 
208 | def random_subset_of_columns(M, number_to_select):
209 |     col_names = M.dtype.names
210 |     selected_columns = np.random.choice(len(col_names), number_to_select, replace=False)
211 |     names = [col_names[i] for i in selected_columns]
212 |     return names
213 | 
214 | 
215 | 
216 | 
217 | 
218 | 
219 | 
--------------------------------------------------------------------------------
/eights/truncate/__init__.py:
--------------------------------------------------------------------------------
1 | from truncate import *
--------------------------------------------------------------------------------
/eights/truncate/truncate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import Counter
3 | from truncate_helper import *
4 | from ..utils import remove_cols
5 | 
6 | def remove_col_where(M, arguments):
7 |     to_remove = np.ones(len(M.dtype), dtype=bool)
8 |     for arg_set in arguments:
9 |         lambd, vals = (arg_set['func'], arg_set['vals'])
10 |         to_remove = np.logical_and(to_remove, lambd(M, vals))
11 |     remove_col_names = [col_name for col_name, flagged in zip(M.dtype.names, to_remove) if flagged]
12 |     return remove_cols(M, remove_col_names)
13 | 
14 | def all_equal_to(M, boundary):
15 |     return [np.all(M[col_name] == boundary) for col_name in M.dtype.names]
16 | 
17 | def all_same_value(M, boundary=None):
18 |     return [np.all(M[col_name]==M[col_name][0]) for col_name in M.dtype.names]
19 | 
20 | def fewer_then_n_nonzero_in_col(M, boundary):
21 |     return [len(np.where(M[col_name]!=0)[0]) < boundary for col_name in M.dtype.names]
22 | 
23 | def remove_rows_where(M, lamd, col_name, vals):
24 |     to_remove = lamd(M, col_name, vals)
25 |     to_keep = np.logical_not(to_remove)
26 |     return M[to_keep]
27 | 
28 | 
29 | 
30 | from ..generate.generate import val_eq
31 | from ..generate.generate import val_lt
32 | from ..generate.generate import val_gt
33 | from ..generate.generate import val_between
34 | from ..generate.generate import is_outlier
35 | 
36 | 
37 | 
38 | 
39 | 
40 | #def fewer_then_n_nonzero_in_col(M, boundary):
41 | #    col_names = M.dtype.names
42 | #    num_rows =M.shape[0]
43 | #    l = [sum(M[n]==0) for n in col_names]
44 | #    remove_these_columns = np.where(np.array(l)>=(num_rows-boundary))[0]
45 | #    names = [col_names[i] for i in remove_these_columns]
46 | #    return remove_cols(M, names)
47 | 
48 | 
49 | 
50 | 
51 | 
--------------------------------------------------------------------------------
/eights/truncate/truncate_helper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import Counter
3 | 
4 | 
5 | #checks
6 | 
7 | def is_within_region(L, point):
8 |     import matplotlib.path as mplPath
9 |     bbPath = mplPath.Path(np.array(L))
10 |     return bbPath.contains_point(point)
11 | 
12 | #remove
13 | def remove_these_columns(M, list_of_col_to_remove):
14 |     return M[[col for col in M.dtype.names if col not in list_of_col_to_remove]]
15 | 
16 | 
17 | 
18 | 
19 | def col_has_all_same_val(col):
20 |     return np.all(col==col[0])
21 | 
22 | def col_has_one_unique_val(col):
23 |     d = Counter(col)
24 |     if len(d) == 2: #ignores case for -999, null
25 |         return (1 in d.values())
26 |     return False
27 | 
28 | def col_has_lt_threshold_unique_values(col, threshold):
29 |     d = Counter(col)
30 |     vals = sorted(d.values())
31 |     return (sum(vals[:-1]) < threshold)
32 | 
33 | 
--------------------------------------------------------------------------------
/eights/utils.py:
--------------------------------------------------------------------------------
1 | # Functions collected here are operations you SHOULD be able to do in plain Python/numpy syntax but cannot.
2 | import numpy as np
3 | import numpy.lib.recfunctions as nprf
4 | import matplotlib.mlab
5 | import itertools as it
6 | from datetime import datetime
7 | from dateutil.parser import parse
8 | 
9 | NOT_A_TIME = np.datetime64('NaT')
10 | 
11 | def utf_to_ascii(s):
12 |     # http://stackoverflow.com/questions/4299675/python-script-to-convert-from-utf-8-to-ascii
13 |     if isinstance(s, unicode):
14 |         return s.encode('ascii', 'replace')
15 |     return s
16 | 
17 | @np.vectorize
18 | def validate_time(date_text):
19 |     return __str_to_datetime(date_text) != NOT_A_TIME
20 | 
21 | def str_to_time(date_text):
22 |     return __str_to_datetime(date_text)
23 | 
24 | def invert_dictionary(aDict):
25 |     return {v: k for k, v in aDict.items()}
26 | 
27 | 
28 | TYPE_PRECEDENCE = {type(None): 0,
29 |                    bool: 100,
30 |                    np.bool_: 101,
31 |                    int: 200,
32 |                    long: 300,
33 |                    np.int64: 301,
34 |                    float: 400,
35 |                    np.float64: 401,
36 |                    str: 500,
37 |                    np.string_: 501,
38 |                    unicode: 600,
39 |                    np.unicode_: 601,
40 |                    datetime: 700,
41 |                    np.datetime64: 701}
42 | 
43 | def __primitive_clean(cell, expected_type, alt):
44 |     if cell is None:
45 |         return alt
46 |     try:
47 |         return expected_type(cell)
48 |     except (TypeError, ValueError):
49 |         return alt
50 | 
51 | def __datetime_clean(cell):
52 |     # Because, unlike primitives, we can't cast random objects to datetimes
53 |     if isinstance(cell, datetime):
54 |         return cell
55 |     if isinstance(cell, basestring):
56 |         return str_to_time(cell)
57 |     return NOT_A_TIME
58 | 
59 | def __datetime64_clean(cell):
60 |     try:
61 |         # Not dealing with resolution; everything is normalized to microseconds ('us')
62 |         return np.datetime64(cell).astype('M8[us]')
63 |     except (TypeError, ValueError):
64 |         return NOT_A_TIME
65 | 
66 | 
67 | CLEAN_FUNCTIONS = {type(None): lambda cell: '',
68 |                    bool: lambda cell: __primitive_clean(cell, bool, False),
69 |                    np.bool_: lambda cell: __primitive_clean(cell, np.bool_,
70 |                                                             np.bool_(False)),
71 |                    int: lambda cell: __primitive_clean(cell, int, -999),
72 |                    long: lambda cell: __primitive_clean(cell, long, -999L),
73 |                    np.int64: lambda cell: __primitive_clean(cell, np.int64,
74 |                                                             np.int64(-999L)),
75 |                    float: lambda cell: __primitive_clean(cell, float, np.nan),
76 |                    np.float64: lambda cell: __primitive_clean(cell, np.float64,
77 |                                                               np.nan),
78 |                    str: lambda cell: __primitive_clean(cell, str, ''),
79 |                    np.string_: lambda cell: __primitive_clean(cell, np.string_,
80 |                                                               np.string_('')),
81 |                    unicode: lambda cell: __primitive_clean(cell, unicode, u''),
82 |                    np.unicode_: lambda cell: __primitive_clean(
83 |                        cell,
84 |                        np.unicode_,
85 |                        np.unicode_('')),
86 |                    datetime: __datetime_clean,
87 |                    np.datetime64: __datetime64_clean}
88 | 
89 | STR_TYPE_LETTERS = {str: 'S',
90 |                     np.string_: 'S',
91 |                     unicode: 'U',
92 |                     np.unicode_: 'U'}
93 | 
94 | 
95 | def __str_to_datetime(s):
96 |     # Invalid time if the string is too short
97 |     # This prevents empty strings from being times
98 |     # as well as odd short strings like 'a'
99 |     if len(s) < 6:
100 |         return NOT_A_TIME
101 |     # Invalid time if the string is just a number
102 |     try:
103 |         float(s)
104 |         return NOT_A_TIME
105 |     except ValueError:
106 |         pass
107 |     # Invalid time if dateutil.parser.parse can't parse it
108 |     try:
109 |         return parse(s)
110 |     except (TypeError, ValueError):
111 |         return NOT_A_TIME
112 | 
113 | def __str_col_to_datetime(col):
114 |     col_dtimes = [__str_to_datetime(s) for s in col]
115 |     valid_dtimes = [dt for dt in col_dtimes if dt != NOT_A_TIME]
116 |     # If there is even one valid datetime, we're calling this a datetime col
117 |     return (bool(valid_dtimes), col_dtimes)
118 | 
119 | def cast_list_of_list_to_sa(L, col_names=None, dtype=None):
120 |     # TODO utils.cast_list_of_list_to_sa is redundant J: NOT agreed
121 |     n_cols = len(L[0])
122 |     if col_names is None:
123 |         col_names = ['f{}'.format(i) for i in xrange(n_cols)]
124 |     dtypes = []
125 |     cleaned_cols = []
126 |     if dtype is None:
127 |         for idx, col in enumerate(it.izip(*L)):
128 |             dom_type = type(max(
129 |                 col,
130 |                 key=lambda cell: TYPE_PRECEDENCE[type(cell)]))
131 |             if dom_type in (bool, np.bool_, int, long, np.int64, float,
132 |                             np.float64):
133 |                 dtypes.append(dom_type)
134 |                 cleaned_cols.append(map(CLEAN_FUNCTIONS[dom_type], col))
135 |             elif dom_type == datetime:
136 |                 dtypes.append('M8[us]')
137 |                 cleaned_cols.append(map(CLEAN_FUNCTIONS[dom_type], col))
138 |             elif dom_type == np.datetime64:
139 |                 dtypes.append('M8[us]')
140 |                 cleaned_cols.append(map(CLEAN_FUNCTIONS[dom_type], col))
141 |             elif dom_type in (str, unicode, np.string_, np.unicode_):
142 |                 cleaned_col = map(CLEAN_FUNCTIONS[dom_type], col)
143 |                 is_datetime, dt_col = __str_col_to_datetime(cleaned_col)
144 |                 if is_datetime:
145 |                     dtypes.append('M8[us]')
146 |                     cleaned_cols.append(dt_col)
147 |                 else:
148 |                     max_len = max(
149 |                         len(max(cleaned_col,
150 |                                 key=lambda cell: len(dom_type(cell)))),
151 |                         1)
152 |                     dtypes.append('|{}{}'.format(
153 |                         STR_TYPE_LETTERS[dom_type],
154 |                         max_len))
155 |                     cleaned_cols.append(cleaned_col)
156 |             elif dom_type == type(None):
157 |                 # A column that is all None becomes a column of empty strings
158 |                 dtypes.append('|S1')
159 |                 cleaned_cols.append([''] * len(col))
160 |             else:
161 |                 raise ValueError(
162 |                     'Type of col: {} could not be determined'.format(
163 |                         col_names[idx]))
164 | 
165 |     return np.fromiter(it.izip(*cleaned_cols),
166 |                        dtype={'names': col_names,
167 |                               'formats': dtypes})
168 | 
169 | def convert_to_sa(M, col_names=None):
170 |     """Converts a list of lists or a np ndarray to a structured array
171 |     Parameters
172 |     ----------
173 |     M : List of List or np.ndarray
174 |         This is the matrix M that is assumed to be the basis for the ML algorithm
175 |     col_names : list of str or None
176 |         Column names for new sa. If M is already a structured array, col_names
177 |         will be ignored. If M is not a structured array and col_names is None,
178 |         names will be generated
179 | 
180 |     Returns
181 |     -------
182 |     temp : Numpy Structured array
183 |         This is the matrix of an appropriate type that eights expects.
184 | 
185 |     """
186 |     if is_sa(M):
187 |         return M
188 | 
189 |     if is_nd(M):
190 |         return cast_np_nd_to_sa(M, names=col_names)
191 | 
192 |     if isinstance(M, list):
193 |         return cast_list_of_list_to_sa(M, col_names=col_names)
194 |     # TODO make sure this function ^ ensures list of /lists/
195 | 
196 |     raise ValueError('Can\'t cast to sa')
197 | 
198 | __type_permissiveness_ranks = {'b': 0, 'M': 100, 'm': 100, 'i': 200, 'f': 300, 'S': 400}
199 | def __type_permissiveness(dtype):
200 |     # TODO handle other types
201 |     return __type_permissiveness_ranks[dtype.kind] + dtype.itemsize
202 | 
203 | def np_dtype_is_homogeneous(A):
204 |     """True iff dtype is nonstructured or every sub dtype is the same"""
205 |     # http://stackoverflow.com/questions/3787908/python-determine-if-all-items-of-a-list-are-the-same-item
206 |     if not is_sa(A):
207 |         return True
208 |     dtype = A.dtype
209 |     first_dtype = dtype[0]
210 |     return all(dtype[i] == first_dtype for i in xrange(len(dtype)))
211 | 
212 | def cast_np_nd_to_sa(nd, dtype=None, names=None):
213 |     """
214 |     Returns a view of a numpy, single-type, 0, 1 or 2-dimensional array as a
215 |     structured array
216 |     Parameters
217 |     ----------
218 |     nd : numpy.ndarray
219 |         The array to view
220 |     dtype : numpy.dtype or None (optional)
221 |         The type of the structured array. If not provided, or None, nd.dtype is
222 |         used for all columns.
223 |         If the dtype requested is not homogeneous and the datatype of each
224 |         column is not identical to nd.dtype, this operation may involve copying
225 |         and conversion. Consequently, this operation should be avoided with
226 |         heterogeneous or different datatypes.
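        Example (illustrative; assuming the default generated names):
        cast_np_nd_to_sa(np.array([[1., 2.], [3., 4.]])) returns a
        structured array with two float64 fields named 'f0' and 'f1'.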
227 |     Returns
228 |     -------
229 |     A structured numpy.ndarray
230 |     """
231 |     if nd.ndim not in (0, 1, 2):
232 |         raise TypeError('cast_np_nd_to_sa only takes 0, 1 or 2-dimensional arrays')
233 |     nd_dtype = nd.dtype
234 |     if nd.ndim <= 1:
235 |         nd = nd.reshape(nd.size, 1)
236 |     if dtype is None:
237 |         n_cols = nd.shape[1]
238 |         if names is None:
239 |             names = map('f{}'.format, xrange(n_cols))
240 |         dtype = np.dtype({'names': names,'formats': [nd_dtype for i in xrange(n_cols)]})
241 |         return nd.reshape(nd.size).view(dtype)
242 |     type_len = nd_dtype.itemsize
243 |     #import pdb; pdb.set_trace()
244 |     if all(dtype[i] == nd_dtype for i in xrange(len(dtype))):
245 |         return nd.reshape(nd.size).view(dtype)
246 |     #import pdb; pdb.set_trace()
247 |     # if the user requests an incompatible type, we have to convert
248 |     cols = (nd[:,i].astype(dtype[i]) for i in xrange(len(dtype)))
249 |     return np.array(it.izip(*cols), dtype=dtype)
250 | 
251 | def cast_np_sa_to_nd(sa):
252 |     """
253 | 
254 |     Returns a view of a numpy structured array as a single-type 1 or
255 |     2-dimensional array. If the resulting nd array would be a column vector,
256 |     returns a 1-d array instead. If the resulting array would have a single
257 |     entry, returns a 0-d array instead.
258 |     All elements are converted to the most permissive type. Permissiveness
259 |     is determined first by finding the most permissive type in the ordering:
260 |     datetime64 < int < float < string
261 |     then by selecting the longest type length among all columns with that
262 |     type.
263 |     If the sa does not have a homogeneous datatype already, this may require
264 |     copying and type conversion rather than just casting. Consequently, this
265 |     operation should be avoided for heterogeneous arrays.
266 |     Based on http://wiki.scipy.org/Cookbook/Recarray.
267 |     Parameters
268 |     ----------
269 |     sa : numpy.ndarray
270 |         The structured array to view
271 |     Returns
272 |     -------
273 |     np.ndarray
274 |     """
275 |     if not is_sa(sa):
276 |         return sa
277 |     dtype = sa.dtype
278 |     if len(dtype) == 1:
279 |         if sa.size == 1:
280 |             return sa.view(dtype=dtype[0]).reshape(())
281 |         return sa.view(dtype=dtype[0]).reshape(len(sa))
282 |     if np_dtype_is_homogeneous(sa):
283 |         return sa.view(dtype=dtype[0]).reshape(len(sa), -1)
284 |     # If type isn't homogeneous, we have to convert
285 |     dtype_it = (dtype[i] for i in xrange(len(dtype)))
286 |     most_permissive = max(dtype_it, key=__type_permissiveness)
287 |     col_names = dtype.names
288 |     cols = (sa[col_name].astype(most_permissive) for col_name in col_names)
289 |     nd = np.column_stack(cols)
290 |     return nd
291 | 
292 | def distance(lat_1, lon_1, lat_2, lon_2):
293 |     """
294 |     Calculate the great circle distance between two points
295 |     on the earth (specified in decimal degrees)
296 |     from:
297 |     http://stackoverflow.com/questions/4913349/haversine-formula-in-python-bearing-and-distance-between-two-gps-points
298 |     """
299 |     # convert decimal degrees to radians
300 | 
301 |     lon_1, lat_1, lon_2, lat_2 = map(np.radians, [lon_1, lat_1, lon_2, lat_2])
302 | 
303 |     # haversine formula
304 |     dlon = lon_2 - lon_1
305 |     dlat = lat_2 - lat_1
306 |     a = np.sin(dlat/2)**2 + np.cos(lat_1) * np.cos(lat_2) * np.sin(dlon/2)**2
307 |     c = 2 * np.arcsin(np.sqrt(a))
308 |     r = 6371 # 6371 Radius of earth in kilometers.
Use 3956 for miles 309 | return c * r 310 | 311 | def dist_less_than(lat_1, lon_1, lat_2, lon_2, threshold): 312 | """single line description 313 | Parameters 314 | ---------- 315 | val : float 316 | miles 317 | Returns 318 | ------- 319 | boolean 320 | 321 | """ 322 | return (distance(lat_1, lon_1, lat_2, lon_2) < threshold) 323 | 324 | def is_sa(M): 325 | return is_nd(M) and M.dtype.names is not None 326 | 327 | def is_nd(M): 328 | return isinstance(M, np.ndarray) 329 | 330 | def stack_rows(*args): 331 | return nprf.stack_arrays(args, usemask=False) 332 | 333 | def sa_from_cols(cols): 334 | # TODO take col names 335 | return nprf.merge_arrays(cols, usemask=False) 336 | 337 | def append_cols(M, cols, names): 338 | return nprf.append_fields(M, names, data=cols, usemask=False) 339 | 340 | def remove_cols(M, col_names): 341 | return nprf.drop_fields(M, col_names, usemask=False) 342 | 343 | def __fill_by_descr(s): 344 | if 'b' in s: 345 | return False 346 | if 'i' in s: 347 | return -999 348 | if 'f' in s: 349 | return np.nan 350 | if 'S' in s: 351 | return '' 352 | if 'U' in s: 353 | return u'' 354 | if 'M' in s or 'm' in s: 355 | return np.datetime64('NaT') 356 | raise ValueError('Unrecognized description {}'.format(s)) 357 | 358 | def join(left, right, how, left_on, right_on, suffixes=('_x', '_y')): 359 | """ 360 | approximates Pandas DataFrame.merge 361 | http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.merge.html 362 | implements a hash join 363 | http://blogs.msdn.com/b/craigfr/archive/2006/08/10/687630.aspx 364 | """ 365 | 366 | # left_on and right_on can both be strings or lists 367 | if isinstance(left_on, basestring): 368 | left_on = [left_on] 369 | if isinstance(right_on, basestring): 370 | right_on = [right_on] 371 | 372 | # assemble dtype for the merged array 373 | # Rules for naming columns in the new table, as inferred from Pandas: 374 | # 1. If a joined on column has the same name in both tables, it appears 375 | # in the joined table once under that name (no suffix) 376 | # 2. Otherwise, every column from each table will appear in the joined 377 | # table, whether they are joined on or not. If both tables share a 378 | # column name, the name will appear twice with suffixes. If a column 379 | # name appears only in one table, it will appear without a suffix. 
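    # Illustrative example (not in the original source): joining two tables
    # that both have columns ('id', 'val') with left_on='id', right_on='id'
    # and the default suffixes yields columns ['val_x', 'id', 'val_y'];
    # joining on left_on='lid', right_on='rid' instead keeps both 'lid' and
    # 'rid' unsuffixed, because neither name appears in the other table.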
380 | frozenset_left_on = frozenset(left_on) 381 | frozenset_right_on = frozenset(right_on) 382 | frozenset_shared_on = frozenset_left_on.intersection(frozenset_right_on) 383 | shared_on = list(frozenset_shared_on) 384 | # get arrays without shared join columns 385 | left_names = left.dtype.names 386 | right_names = right.dtype.names 387 | frozenset_left_names = frozenset(left.dtype.names).difference( 388 | frozenset_shared_on) 389 | left_names = list(frozenset_left_names) 390 | frozenset_right_names = frozenset(right.dtype.names).difference( 391 | frozenset_shared_on) 392 | right_names = list(frozenset_right_names) 393 | left_no_idx = left[left_names] 394 | right_no_idx = right[right_names] 395 | left_names_w_suffix = [col_name + suffixes[0] if 396 | col_name in frozenset_right_names else 397 | col_name for 398 | col_name in left_names] 399 | right_names_w_suffix = [col_name + suffixes[1] if 400 | col_name in frozenset_left_names else 401 | col_name for 402 | col_name in right_names] 403 | col_names = (left_names_w_suffix + shared_on + right_names_w_suffix) 404 | col_dtypes = ([left[left_col].dtype for left_col in left_names] + 405 | [left[shared_on_col].dtype for shared_on_col in shared_on] + 406 | [right[right_col].dtype for right_col in right_names]) 407 | take_all_right_rows = how in ('outer', 'right') 408 | take_all_left_rows = how in ('outer', 'left') 409 | # data to fill in if we're doing an outer join and one of the sides is 410 | # missing 411 | left_fill = tuple([__fill_by_descr(dtype) for _, dtype in 412 | left_no_idx.dtype.descr]) 413 | right_fill = tuple([__fill_by_descr(dtype) for _, dtype in 414 | right_no_idx.dtype.descr]) 415 | 416 | # Make a hash of the first join column in the left table 417 | left_col = left[left_on[0]] 418 | hashed_col = {} 419 | for left_idx, left_cell in enumerate(left_col): 420 | try: 421 | rows = hashed_col[left_cell] 422 | except KeyError: 423 | rows = [] 424 | hashed_col[left_cell] = rows 425 | rows.append(left_idx) 426 | 427 | # Pick out columns that we will be joining on beyond the 0th 428 | extra_left_cols = [left[left_on_name] for left_on_name in left_on[1:]] 429 | extra_right_cols = [right[right_on_name] for right_on_name in right_on[1:]] 430 | extra_contraint_cols = zip(extra_left_cols, extra_right_cols) 431 | 432 | rows_new_table = [] 433 | right_col = right[right_on[0]] 434 | # keep track of used left rows so we can include all the rows if we're 435 | # doing a left or outer join 436 | left_rows_used = set() 437 | # Iterate through every row in the right table 438 | for right_idx, right_cell in enumerate(right_col): 439 | has_match = False 440 | # See if we have matches from the hashed col of the left table 441 | try: 442 | left_matches = hashed_col[right_cell] 443 | 444 | for left_idx in left_matches: 445 | # If all the constraints are met, we have a match 446 | if all([extra_left_col[left_idx] == extra_right_col[right_idx] 447 | for extra_left_col, extra_right_col in 448 | extra_contraint_cols]): 449 | has_match = True 450 | rows_new_table.append( 451 | tuple(left_no_idx[left_idx]) + 452 | tuple([left[shared_on_col][left_idx] 453 | for shared_on_col in shared_on]) + 454 | tuple(right_no_idx[right_idx])) 455 | left_rows_used.add(left_idx) 456 | # No match found for this right row 457 | except KeyError: 458 | pass 459 | # If we're doing a right or outer join and we didn't find a match, add 460 | # this row from the right table, filled with type-appropriate versions 461 | # of NULL from the left table 462 | if (not has_match) and 
take_all_right_rows: 463 | rows_new_table.append(left_fill + 464 | tuple([right[shared_on_col][right_idx] for shared_on_col in 465 | shared_on]) + 466 | tuple(right_no_idx[right_idx])) 467 | 468 | # if we're doing a left or outer join, we have to add all rows from the 469 | # left table, using type-appropriate versions of NULL for the right table 470 | if take_all_left_rows: 471 | left_rows_unused = [i for i in xrange(len(left)) if i not in 472 | left_rows_used] 473 | for unused_left_idx in left_rows_unused: 474 | rows_new_table.append( 475 | tuple(left_no_idx[unused_left_idx]) + 476 | tuple([left[shared_on_col][unused_left_idx] 477 | for shared_on_col in shared_on]) + 478 | right_fill) 479 | 480 | return np.array(rows_new_table, dtype={'names': col_names, 481 | 'formats': col_dtypes}) 482 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='eights', 5 | version='0.0.1', 6 | url='https://github.com/dssg/eights', 7 | author='Center for Data Science and Public Policy', 8 | description='A library and workflow template for machine learning', 9 | packages=find_packages(), 10 | install_requires=('numpy', 11 | 'scikit-learn', 12 | 'matplotlib', 13 | 'SQLAlchemy', 14 | 'joblib', 15 | 'pdfkit'), 16 | zip_safe=False) 17 | -------------------------------------------------------------------------------- /test_sklearn_iris.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn.datasets 3 | 4 | from eights.investigate import (cast_np_nd_to_sa, describe_cols,) 5 | from eights.communicate import (plot_correlation_scatter_plot, 6 | plot_correlation_matrix, 7 | plot_kernel_density, 8 | plot_box_plot) 9 | 10 | #import numpy array 11 | M = sklearn.datasets.load_iris().data 12 | labels = sklearn.datasets.load_iris().target 13 | 14 | M = cast_np_nd_to_sa(M) 15 | 16 | 17 | #M is multi class, we want to remove those rows. 
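# Clarifying note (not in the original source): the iris target has three
# classes (0, 1, 2); keeping only the rows where labels != 2 below turns this
# into a binary classification problem.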
18 | keep_index = np.where(labels!=2) 19 | 20 | labels = labels[keep_index] 21 | M = M[keep_index] 22 | 23 | 24 | 25 | 26 | if False: 27 | for x in describe_cols(M): 28 | print x 29 | 30 | if False: 31 | plot_correlation_scatter_plot(M) 32 | plot_correlation_matrix(M) 33 | plot_kernel_density(M['f0']) #no designation of col name 34 | plot_box_plot(M['f0']) #no designation of col name 35 | 36 | 37 | if False: 38 | from eights.generate import val_between, where_all_are_true, append_cols #val_btwn, where 39 | #generate a composite rule 40 | M = where_all_are_true(M, 41 | [{'func': val_between, 42 | 'col_name': 'f0', 43 | 'vals': (3.5, 5.0)}, 44 | {'func': val_between, 45 | 'col_name': 'f1', 46 | 'vals': (2.7, 3.1)} 47 | ], 48 | 'a new col_name') 49 | 50 | #new eval function 51 | def rounds_to_val(M, col_name, boundary): 52 | return (np.round(M[col_name]) == boundary) 53 | 54 | M = where_all_are_true(M, 55 | [{'func': rounds_to_val, 56 | 'col_name': 'f0', 57 | 'vals': 5}], 58 | 'new_col') 59 | 60 | from eights.truncate import (fewer_then_n_nonzero_in_col, 61 | remove_rows_where, 62 | remove_cols, 63 | val_eq) 64 | #remove Useless row 65 | M = fewer_then_n_nonzero_in_col(M,1) 66 | M = append_cols(M, labels, 'labels') 67 | M = remove_rows_where(M, val_eq, 'labels', 2) 68 | labels=M['labels'] 69 | M = remove_cols(M, 'labels') 70 | 71 | 72 | from eights.operate import run_std_classifiers, run_alt_classifiers #run_alt_classifiers not working yet 73 | exp = run_std_classifiers(M,labels) 74 | exp.make_csv() 75 | import pdb; pdb.set_trace() 76 | 77 | 78 | ####################Communicate####################### 79 | 80 | 81 | 82 | #Pretend .1 is wrong so set all values of .1 in M[3] as .2 83 | # make a new column where its a test if col,val, (3,.2), (2,1.4) is true. 84 | 85 | 86 | import pdb; pdb.set_trace() 87 | 88 | #from decontaminate import remove_null, remove_999, case_fix, truncate 89 | #from generate import donut 90 | #from aggregate import append_on_right, append_on_bottom 91 | #from truncate import remove 92 | #from operate import run_list, fiveFunctions 93 | #from communicate import graph_all, results_invtestiage 94 | 95 | #investiage 96 | #M_orginal = csv_open(file_loc, file_descpiption) # this is our original files 97 | #results = eights.investigate.describe_all(M_orginal) 98 | #results_invtestiage(results) 99 | 100 | #decontaminate 101 | #aggregate 102 | #generate 103 | #M = np.array([]) #this is the master Matrix we train on. 
104 | #labels = np.array([]) # these are the labels we train on
105 | 
106 | #truncate
107 | #models = [] #list of functions
108 | 
109 | #operate
110 | 
111 | #communicate
112 | 
113 | 
114 | #func_list = [sklearn.randomforest,sklearn.gaussian, ]
115 | 
116 | 
117 | #If main:
118 | #run on single csv
--------------------------------------------------------------------------------
/test_wine.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import numpy as np
4 | import sklearn.datasets
5 | 
6 | from eights.investigate import (cast_np_nd_to_sa, describe_cols,open_csv_url_as_list)
7 | from eights.communicate import (plot_correlation_scatter_plot,
8 |                                 plot_correlation_matrix,
9 |                                 plot_kernel_density,
10 |                                 plot_box_plot)
11 | 
12 | from eights.operate import run_std_classifiers
13 | 
14 | 
15 | data = open_csv_url_as_list(
16 |     'http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv',
17 |     delimiter=';')
18 | 
19 | col_names = data[0][:-1]
20 | labels = np.array([int(x[-1]) for x in data[1:]])
21 | #make this problem binary
22 | labels = np.array([0 if x < np.average(labels) else 1 for x in labels])
23 | dtype = np.dtype({'names': col_names,'formats': [float] * len(col_names)})
24 | M = cast_np_nd_to_sa(np.array([x[:-1] for x in data[1:]],dtype='float'), dtype)
25 | 
26 | 
27 | 
28 | import pdb; pdb.set_trace()
29 | if False:
30 |     for x in describe_cols(M):
31 |         print x
32 | 
33 | if False:
34 |     plot_correlation_scatter_plot(M)
35 |     plot_correlation_matrix(M)
36 |     plot_kernel_density(M['f0']) #no designation of col name
37 |     plot_box_plot(M['f0']) #no designation of col name
38 | 
39 | 
40 | 
41 | 
42 | exp = run_std_classifiers(M,labels)
43 | exp.make_csv()
44 | 
45 | import pdb; pdb.set_trace()
46 | 
47 | 
48 | 
49 | 
50 | 
51 | 
52 | 
53 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
/tests/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/tests/.DS_Store
--------------------------------------------------------------------------------
/tests/data/full_test.csv:
--------------------------------------------------------------------------------
1 | id,bool_1,bool_2,dt_1,dt_2,num_1,str,num_2
2 | 100,,1,2012-06-06,2004-03-18,613,str_163,197
3 | 101,1,0,2002-12-28,2014-06-23,644,str_150,38
4 | 102,0,1,2006-11-25,2005-02-24,126,str_35,181
5 | 103,0,1,2008-10-12,2015-09-26,674,str_44,71
6 | 104,1,0,2001-02-14,2003-07-03,1,str_99,4
7 | 105,0,0,2000-05-22,2005-12-26,860,str_24,32
8 | 106,1,0,2008-10-18,2011-03-25,447,str_129,147
9 | 107,1,1,2001-03-06,2015-08-07,896,str_28,185
10 | 108,,0,2014-02-09,2004-07-03,381,str_93,97
11 | 109,0,1,2005-05-27,2015-08-06,374,str_35,133
12 | 110,1,0,2008-01-18,2001-08-03,373,str_89,44
13 | 111,1,0,2009-03-10,2011-11-21,293,str_99,18
14 | 112,0,0,2011-09-25,2003-08-27,593,str_5,125
15 | 113,0,0,2004-10-09,2002-10-06,900,str_163,189
16 | 114,1,0,2011-10-15,2004-03-31,509,str_7,20
17 | 115,0,0,2010-07-12,2012-10-02,637,str_66,90
18 | 116,1,0,2010-07-17,2000-12-03,611,str_131,84
19 | 117,0,1,2002-09-01,2005-10-09,772,str_86,120
20 | 118,1,1,2012-11-04,2005-11-15,689,str_169,46
21 | 119,0,0,2006-03-12,2004-01-31,23,str_162,152
22 | 120,1,1,2010-07-28,2000-05-10,568,str_182,95
23 | 121,,0,2011-10-26,2010-08-03,96,str_54,26
24 | 122,1,0,2008-07-25,2013-10-12,932,str_191,11
25 | 123,1,0,2006-06-10,2010-05-24,533,str_147,185
26 | 
124,0,0,2004-06-29,2000-08-23,591,str_14,183 27 | 125,1,1,2002-12-04,2005-08-27,18,str_188,134 28 | -------------------------------------------------------------------------------- /tests/data/mixed.csv: -------------------------------------------------------------------------------- 1 | id,name,height 2 | 0,Jim,5.6 3 | 1,Jill,5.5 4 | -------------------------------------------------------------------------------- /tests/data/small.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/tests/data/small.db -------------------------------------------------------------------------------- /tests/data/test_communicate_ref.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/tests/data/test_communicate_ref.pdf -------------------------------------------------------------------------------- /tests/data/test_operate_std.pkl: -------------------------------------------------------------------------------- 1 | (dp1 2 | S"Trial(clf=, clf_params={'max_depth': None}, subset=, subset_params={}, cv=, cv_params={})" 3 | p2 4 | cnumpy.core.multiarray 5 | scalar 6 | p3 7 | (cnumpy 8 | dtype 9 | p4 10 | (S'f8' 11 | I0 12 | I1 13 | tRp5 14 | (I3 15 | S'<' 16 | NNNI-1 17 | I-1 18 | I0 19 | tbS'\x87\xbd\x96\xfb\x8e\xdc\xe2?' 20 | tRp6 21 | sS"Trial(clf=, clf_params={'strategy': 'stratified'}, subset=, subset_params={}, cv=, cv_params={})" 22 | p7 23 | g3 24 | (g5 25 | S"\xd3\x08\xe2F\xda'\xde?" 26 | tRp8 27 | sS"Trial(clf=, clf_params={'n_estimators': 10, 'max_features': 'sqrt', 'n_jobs': 1, 'max_depth': None}, subset=, subset_params={}, cv=, cv_params={})" 28 | p9 29 | g3 30 | (g5 31 | S'\xa5\xdb\xb4\x19\xad\xfa\xe0?' 32 | tRp10 33 | sS"Trial(clf=, clf_params={'n_estimators': 20}, subset=, subset_params={}, cv=, cv_params={})" 34 | p11 35 | g3 36 | (g5 37 | S'\x98*x\xae\x87\xec\xdf?' 38 | tRp12 39 | sS"Trial(clf=, clf_params={'strategy': 'most_frequent'}, subset=, subset_params={}, cv=, cv_params={})" 40 | p13 41 | g3 42 | (g5 43 | S'H1\xd4_\x8eH\xe1?' 44 | tRp14 45 | sS"Trial(clf=, clf_params={'penalty': 'l1', 'C': 1.0}, subset=, subset_params={}, cv=, cv_params={})" 46 | p15 47 | g3 48 | (g5 49 | S'M\x9b\xd1\xaa\x0f\xa3\xe0?' 50 | tRp16 51 | s. 
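A minimal sketch of how a reference pickle like test_operate_std.pkl above can
be checked against a fresh run (the actual test in tests/test_operate.py is not
shown in this dump, so keying trials by repr() is an assumption based on the
pickle's contents)::

    import cPickle
    from eights.operate import run_std_classifiers

    # M and labels prepared as in test_wine.py
    exp = run_std_classifiers(M, labels)
    result = {repr(trial): score
              for trial, score in exp.average_score().iteritems()}
    with open('tests/data/test_operate_std.pkl', 'rb') as fin:
        ref = cPickle.load(fin)
    assert set(result) == set(ref)  # the same set of trials was run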
-------------------------------------------------------------------------------- /tests/data/test_perambulate/make_csv.csv: -------------------------------------------------------------------------------- 1 | clf,clf_C,clf_algorithm,clf_base_estimator,clf_bootstrap,clf_cache_size,clf_class_weight,clf_coef0,clf_compute_importances,clf_constant,clf_criterion,clf_degree,clf_dual,clf_fit_intercept,clf_gamma,clf_intercept_scaling,clf_kernel,clf_learning_rate,clf_max_depth,clf_max_features,clf_max_iter,clf_max_leaf_nodes,clf_min_density,clf_min_samples_leaf,clf_min_samples_split,clf_n_estimators,clf_n_jobs,clf_oob_score,clf_penalty,clf_probability,clf_random_state,clf_shrinking,clf_splitter,clf_strategy,clf_tol,clf_verbose,subset,subset_cols_to_exclude,subset_max_grades,subset_n_subsets,subset_num_rows,subset_proportions_positive,subset_random_state,subset_subset_size,cv,cv_col_name,cv_col_name,cv_expanding_train,cv_inc_value,cv_indices,cv_n_folds,cv_random_state,cv_shuffle,cv_test_start,cv_test_window_size,cv_train_start,cv_train_window_size,subset_note_excluded_col,subset_note_max_grade,subset_note_prop_positive,subset_note_rows,subset_note_sample_num,cv_note_fold,cv_note_test_end,cv_note_test_start,cv_note_train_end,cv_note_train_start,f1_score,roc_auc,prec@1%,prec@2%,prec@5%,prec@10%,prec@20%,feature_ranked_0,feature_ranked_1,feature_ranked_2,feature_ranked_3,feature_ranked_4,feature_ranked_5,feature_ranked_6,feature_ranked_7,feature_ranked_8,feature_ranked_9,feature_score_0,feature_score_1,feature_score_2,feature_score_3,feature_score_4,feature_score_5,feature_score_6,feature_score_7,feature_score_8,feature_score_9 2 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,0,,,,,0.53061224489795922,0.58253205128205132,0.9444444444444444,0.8888888888888888,0.7222222222222221,0.6,0.46794871794871795,f1,f2,f0,f3,f4,,,,,,0.32975670539712815,0.21783633757608575,0.20167114777814693,0.14726882290197957,0.1034669863466596 3 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,1,,,,,0.6071428571428571,0.66666666666666663,1.0,1.0,1.0,0.7916666666666667,0.9,f0,f1,f2,f3,f4,,,,,,0.32431843360321888,0.26233144232559025,0.19900421285846664,0.13025341594729301,0.084092495265431236 4 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,0,,,,,0.53571428571428581,0.50642312324367733,0.0,0.5,0.8,0.6,0.65,f2,f1,f0,f3,f4,,,,,,0.27586232488785944,0.2172871784002634,0.17639356802119194,0.17068629484537451,0.15977063384531065 5 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,1,,,,,0.46846846846846846,0.42513046969088719,0.0,0.5,0.4,0.5,0.45,f0,f3,f2,f1,f4,,,,,,0.24753246382660662,0.21293260116956142,0.21229504301200014,0.21002277778160533,0.11721711421022651 6 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,0,,,,,0.57142857142857151,0.5625,1.0,1.0,1.0,0.9,0.7047619047619048,f1,f0,f2,f3,f4,,,,,,0.30478408991866651,0.27044951036495857,0.18325750177689373,0.13632059259106671,0.10518830534841442 7 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,1,,,,,0.43749999999999994,0.43566176470588236,1.0,1.0,0.675,0.38333333333333336,0.4228571428571428,f3,f1,f2,f0,f4,,,,,,0.25898956576865445,0.23822260159270395,0.20478252617004356,0.16538833602247799,0.13261697044612003 8 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 
200]",,0,,,,,,,,3,,,,,,,,,,100,,2,,,,,0.47058823529411764,0.4797794117647059,0.6699999999999999,0.33999999999999997,0.325,0.6916666666666667,0.6095238095238095,f0,f2,f1,f3,f4,,,,,,0.31968998098113521,0.20909498898330217,0.20873721998661438,0.14459446604790366,0.11788334400104464 9 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,0,,,,,0.44117647058823528,0.39583333333333326,1.0,1.0,0.9,0.5904761904761904,0.4417582417582417,f2,f4,f1,f3,f0,,,,,,0.23484891193747606,0.22987804031880693,0.21584627745110163,0.16689666075974247,0.15253010953287297 10 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,1,,,,,0.59259259259259267,0.44285714285714284,0.33999999999999997,0.0,0.325,0.6095238095238095,0.39340659340659345,f1,f0,f3,f2,f4,,,,,,0.23320808213120975,0.22120339466173547,0.20802019574093614,0.19516808963832621,0.14240023782779246 11 | ,,,,,,,,,,,,,,,,,,5,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,2,,,,,0.5,0.40000000000000002,1.0,0.84,0.38333333333333336,0.45714285714285713,0.39340659340659345,f0,f1,f2,f3,f4,,,,,,0.25040188372773187,0.2218556743990574,0.18579449846558521,0.17478755476804689,0.16716038863957866 12 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,0,,,,,0.53333333333333333,0.61458333333333326,0.9444444444444444,0.8888888888888888,0.7222222222222221,0.6,0.5408163265306123,f1,f2,f0,f3,f4,,,,,,0.30133021121919307,0.23112491512134198,0.20744625487560162,0.14369264178350444,0.11640597700035889 13 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,1,,,,,0.5490196078431373,0.58733974358974361,1.0,1.0,1.0,1.0,0.8396825396825396,f0,f1,f2,f3,f4,,,,,,0.31538687588603648,0.25814614625471521,0.1766143486752913,0.13122155649553124,0.11863107268842576 14 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,0,,,,,0.45544554455445557,0.41589723002810119,0.92,0.84,0.6,0.4970588235294118,0.3891402714932127,f2,f1,f0,f4,f3,,,,,,0.25990245434566417,0.24271990248826655,0.18731568413402738,0.16303430332206129,0.14702765570998061 15 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,1,,,,,0.47058823529411764,0.45885186672019262,1.0,0.9363636363636364,0.7454545454545455,0.42727272727272725,0.44960474308300397,f1,f0,f2,f3,f4,,,,,,0.229641156726999,0.21869648166708436,0.20353255434063353,0.18399347294541454,0.16413633431986849 16 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,0,,,,,0.56250000000000011,0.58333333333333337,1.0,1.0,1.0,0.9066666666666667,0.764,f0,f1,f2,f4,f3,,,,,,0.28308477324015008,0.27822858418872959,0.16607330792423597,0.13645981957739037,0.13615351506949408 17 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,1,,,,,0.29629629629629634,0.36948529411764708,0.835,0.6699999999999999,0.17499999999999993,0.06500000000000002,0.23000000000000004,f3,f1,f2,f0,f4,,,,,,0.26096680352639806,0.21665124295938062,0.18861882152820236,0.16921312942108349,0.1645500025649354 18 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,2,,,,,0.55172413793103448,0.60110294117647056,0.9175,0.835,0.5874999999999999,0.5,0.5866666666666667,f0,f1,f2,f4,f3,,,,,,0.32921793517928005,0.25233764257953051,0.20785385053419528,0.11360915075813882,0.096981420948855285 19 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 
200]",,0,,,,,,,,3,,,,,,,,,,200,,0,,,,,0.38709677419354838,0.45442708333333337,0.31999999999999995,0.12,0.6888888888888889,0.7996336996336997,0.5349112426035503,f4,f1,f2,f0,f3,,,,,,0.25116173678552117,0.20841230124757323,0.19998562496746489,0.19230999334839732,0.14813034365104333 20 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,1,,,,,0.53333333333333333,0.46820276497695851,0.8533333333333333,0.7066666666666666,0.3378787878787879,0.3878787878787879,0.4588932806324111,f0,f3,f1,f2,f4,,,,,,0.24723024802943988,0.22726973422666061,0.18732245320429688,0.18038186346475077,0.1577957010748518 21 | ,,,,,,,,,,,,,,,,,,25,,,,,,,10,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,2,,,,,0.37931034482758619,0.4824884792626728,0.835,0.6699999999999999,0.6499999999999999,0.4666666666666666,0.37681159420289856,f1,f0,f2,f3,f4,,,,,,0.23334845282323982,0.2210159330303286,0.18738559797935758,0.18506931294859097,0.17318070321848297 22 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,0,,,,,0.59649122807017552,0.56891025641025639,1.0,1.0,0.5833333333333334,0.6,0.6,f1,f0,f2,f4,f3,,,,,,0.25440571902128428,0.22728139336663561,0.18420264475619511,0.17374432238714416,0.16036592046874079 23 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,1,,,,,0.56603773584905659,0.58653846153846156,1.0,1.0,1.0,0.6,0.5,f0,f2,f1,f3,f4,,,,,,0.28393626364839969,0.19274521663412744,0.19191663812597715,0.17437211724597249,0.15702976434552329 24 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,0,,,,,0.4464285714285714,0.39863508631071859,0.0,0.5,0.4,0.4,0.45,f2,f3,f1,f0,f4,,,,,,0.23382865091380045,0.22045839420515917,0.1978707672693146,0.18135584093151727,0.1664863466802084 25 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,1,,,,,0.5043478260869565,0.40064231232436776,1.0,1.0,0.4,0.4,0.4,f3,f1,f2,f0,f4,,,,,,0.23519844406512624,0.2144100484693974,0.20186183495579818,0.18351556865838281,0.16501410385129542 26 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,0,,,,,0.43750000000000006,0.55555555555555558,1.0,1.0,1.0,0.6,0.5571428571428572,f1,f0,f2,f3,f4,,,,,,0.2424796057733653,0.23230239528747806,0.21614440755657705,0.15911408835861571,0.14995950302396399 27 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,1,,,,,0.34285714285714286,0.31617647058823534,0.6699999999999999,0.33999999999999997,0.0,0.07500000000000002,0.30476190476190473,f3,f1,f2,f0,f4,,,,,,0.2369027252009365,0.22425214168868976,0.20525560745085134,0.17511395881232616,0.15847556684719635 28 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,2,,,,,0.40000000000000008,0.4779411764705882,1.0,1.0,1.0,0.6166666666666666,0.5428571428571428,f0,f3,f2,f1,f4,,,,,,0.29793584094879361,0.19145300917599337,0.18590234214306417,0.1681054688073802,0.15660333892476866 29 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,0,,,,,0.45945945945945943,0.39756944444444442,0.31999999999999995,0.0,0.3,0.40952380952380957,0.48461538461538467,f1,f2,f4,f3,f0,,,,,,0.23140663928463631,0.22326947335714642,0.19224051774910195,0.18112544249448853,0.17195792711462687 30 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 
200]",,0,,,,,,,,3,,,,,,,,,,200,,1,,,,,0.57142857142857151,0.44516129032258067,0.33999999999999997,0.16,0.3083333333333333,0.5428571428571428,0.45494505494505494,f3,f1,f2,f4,f0,,,,,,0.25541938189397301,0.20721034938285371,0.18854126685371919,0.18222104826434771,0.16660795360510655 31 | ,,,,,,,,,,,,,,,,,,5,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,2,,,,,0.50666666666666671,0.40829493087557606,0.33999999999999997,0.16,0.3083333333333333,0.30476190476190473,0.45494505494505494,f3,f1,f2,f0,f4,,,,,,0.2403061180426124,0.22060609132390133,0.19100354816638088,0.18552459823309697,0.16255964423400862 32 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,0,,,,,0.6428571428571429,0.58173076923076927,1.0,1.0,0.8333333333333333,0.7083333333333333,0.5505050505050505,f1,f0,f2,f3,f4,,,,,,0.2419002265279718,0.23078744518391048,0.19297183420958958,0.17129289430580866,0.16304759977271932 33 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,1,,,,,0.53846153846153844,0.60016025641025639,1.0,1.0,0.8333333333333333,0.6,0.5,f0,f1,f2,f4,f3,,,,,,0.26863509180838524,0.19333044271245448,0.19081127008117474,0.17593514035749741,0.17128805504048816 34 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,0,,,,,0.4504504504504504,0.38960256924929743,1.0,0.5,0.42857142857142855,0.4,0.474937343358396,f3,f2,f1,f4,f0,,,,,,0.22350090210911383,0.21427697694156828,0.20326124269908244,0.18167124277240199,0.17728963547783352 35 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,1,,,,,0.5565217391304349,0.4225210758731433,0.75,0.5,0.4,0.4,0.3,f1,f2,f3,f0,f4,,,,,,0.21782429285079583,0.21539683143790089,0.21192161041140306,0.18158368655797716,0.17327357874192301 36 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,0,,,,,0.41379310344827586,0.57291666666666663,1.0,1.0,0.6499999999999999,0.7,0.5904761904761904,f0,f1,f2,f3,f4,,,,,,0.23613845551949267,0.22881939308298854,0.20445356856332519,0.1660393383069875,0.16454924452720629 37 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,1,,,,,0.37500000000000006,0.34375,0.6699999999999999,0.33999999999999997,0.0,0.07500000000000002,0.3904761904761905,f3,f1,f2,f0,f4,,,,,,0.23362103650846688,0.22018174526372533,0.21246634133663858,0.17025532441102048,0.16347555248014856 38 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,2,,,,,0.42857142857142855,0.50551470588235292,1.0,1.0,0.675,0.6166666666666666,0.6095238095238095,f0,f2,f3,f1,f4,,,,,,0.28781595254211167,0.19175822985045801,0.19014553023837716,0.17294080239743859,0.1573394849716146 39 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,0,,,,,0.42666666666666664,0.41753472222222221,0.31999999999999995,0.06,0.4,0.42,0.4822222222222222,f2,f1,f3,f4,f0,,,,,,0.21706156046502501,0.21099040600646657,0.19816746113588174,0.18973376872587752,0.18404680366674925 40 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,1,,,,,0.5609756097560975,0.41797235023041479,0.33999999999999997,0.0,0.3733333333333333,0.5771428571428571,0.45494505494505494,f3,f1,f4,f2,f0,,,,,,0.24143701379631813,0.19531390439039281,0.19191494902373649,0.18846364339222813,0.18287048939732442 41 | ,,,,,,,,,,,,,,,,,,25,,,,,,,100,,,,,0,,,,,,,,,,"[100, 
200]",,0,,,,,,,,3,,,,,,,,,,200,,2,,,,,0.40000000000000008,0.41152073732718886,1.0,0.84,0.38333333333333336,0.3083333333333333,0.3861538461538462,f3,f1,f2,f0,f4,,,,,,0.22537570460810358,0.21089279462369967,0.19856125916151954,0.19101735773761735,0.17415288386906005 42 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,0,,,,,0.48000000000000004,0.52083333333333326,1.0,1.0,0.5833333333333334,0.8,0.7,,,,,,,,,, 43 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,1,,,,,0.54838709677419362,0.55769230769230771,0.5,0.0,0.41666666666666663,0.4,0.6,,,,,,,,,, 44 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,0,,,,,0.57377049180327866,0.52308309915696505,1.0,1.0,0.8,0.7,0.55,,,,,,,,,, 45 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,1,,,,,0.69281045751633985,0.55961461260537937,0.0,0.0,0.4,0.6,0.55,,,,,,,,,, 46 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,0,,,,,0.43750000000000006,0.61805555555555558,1.0,1.0,1.0,0.9,0.8523809523809524,,,,,,,,,, 47 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,1,,,,,0.35294117647058826,0.1875,0.6699999999999999,0.33999999999999997,0.0,0.0,0.0,,,,,,,,,, 48 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,2,,,,,0.42105263157894735,0.72426470588235281,1.0,1.0,1.0,0.9249999999999999,0.6952380952380952,,,,,,,,,, 49 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,0,,,,,0.59999999999999998,0.56510416666666663,1.0,1.0,0.9,0.7047619047619048,0.8098901098901098,,,,,,,,,, 50 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,1,,,,,0.69306930693069302,0.51152073732718895,0.33999999999999997,0.16,0.6916666666666667,0.45714285714285713,0.5450549450549451,,,,,,,,,, 51 | ,,,,,,,,,,,,,,,,linear,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,2,,,,,0.5609756097560975,0.57880184331797224,0.33999999999999997,0.16,0.38333333333333336,0.45714285714285713,0.6208791208791209,,,,,,,,,, 52 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,0,,,,,0.68421052631578949,0.52403846153846156,1.0,1.0,0.8333333333333333,0.8,0.7,,,,,,,,,, 53 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,100,,1,,,,,0.68421052631578949,0.5689102564102565,0.5,0.0,0.41666666666666663,0.4,0.6,,,,,,,,,, 54 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,0,,,,,0.69281045751633985,0.52709755118426338,1.0,1.0,0.8,0.7,0.55,,,,,,,,,, 55 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,2,,,,,,,,,,200,,1,,,,,0.69281045751633985,0.54235246888799682,0.0,0.0,0.4,0.5306818181818183,0.65,,,,,,,,,, 56 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,0,,,,,0.66666666666666663,0.63888888888888895,1.0,1.0,1.0,1.0,0.738095238095238,,,,,,,,,, 57 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,1,,,,,0.33333333333333337,0.20955882352941174,0.6699999999999999,0.33999999999999997,0.0,0.0,0.15238095238095237,,,,,,,,,, 58 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,100,,2,,,,,0.67999999999999994,0.74632352941176472,1.0,1.0,1.0,1.0,0.6952380952380952,,,,,,,,,, 59 | 
,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,0,,,,,0.6923076923076924,0.56727430555555558,1.0,0.82,0.7,0.8523809523809524,0.7054945054945055,,,,,,,,,, 60 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,1,,,,,0.69306930693069302,0.48571428571428577,0.33999999999999997,0.16,0.6166666666666666,0.6952380952380952,0.5450549450549451,,,,,,,,,, 61 | ,,,,,,,,,,,,,,,,rbf,,,,,,,,,,,,,True,0,,,,,,,,,,"[100, 200]",,0,,,,,,,,3,,,,,,,,,,200,,2,,,,,0.69306930693069302,0.60092165898617522,0.33999999999999997,0.0,0.38333333333333336,0.6095238095238095,0.7582417582417582,,,,,,,,,, 62 | -------------------------------------------------------------------------------- /tests/data/test_perambulate/run_experiment.pkl: -------------------------------------------------------------------------------- 1 | (dp1 2 | S"Trial(clf=, clf_params={'random_state': 0}, subset=, subset_params={'subset_size': 80, 'random_state': 0}, cv=, cv_params={})" 3 | p2 4 | cnumpy.core.multiarray 5 | scalar 6 | p3 7 | (cnumpy 8 | dtype 9 | p4 10 | (S'f8' 11 | I0 12 | I1 13 | tRp5 14 | (I3 15 | S'<' 16 | NNNI-1 17 | I-1 18 | I0 19 | tbS'`\x07cV\x1a=\xee?' 20 | tRp6 21 | sS"Trial(clf=, clf_params={'random_state': 0}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 22 | p7 23 | g3 24 | (g5 25 | S'\xa2\xbd\x84\xf6\x12\xda\xeb?' 26 | tRp8 27 | sS"Trial(clf=, clf_params={'random_state': 0}, subset=, subset_params={'subset_size': 40, 'random_state': 0}, cv=, cv_params={})" 28 | p9 29 | g3 30 | (g5 31 | S'\xca:W\x1e\x90\xac\xeb?' 32 | tRp10 33 | sS"Trial(clf=, clf_params={'random_state': 0}, subset=, subset_params={'subset_size': 60, 'random_state': 0}, cv=, cv_params={})" 34 | p11 35 | g3 36 | (g5 37 | S'\x8e\xe38\x8e\xe38\xee?' 38 | tRp12 39 | sS"Trial(clf=, clf_params={'random_state': 0}, subset=, subset_params={'subset_size': 100, 'random_state': 0}, cv=, cv_params={})" 40 | p13 41 | g3 42 | (g5 43 | S'\x0b\xbb\xcfwm\x99\xee?' 44 | tRp14 45 | s. -------------------------------------------------------------------------------- /tests/data/test_perambulate/slice_by_best_score.pkl: -------------------------------------------------------------------------------- 1 | (dp1 2 | S"Trial(clf=, clf_params={'kernel': 'linear', 'random_state': 0}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 3 | p2 4 | cnumpy.core.multiarray 5 | scalar 6 | p3 7 | (cnumpy 8 | dtype 9 | p4 10 | (S'f8' 11 | I0 12 | I1 13 | tRp5 14 | (I3 15 | S'<' 16 | NNNI-1 17 | I-1 18 | I0 19 | tbS'\x8e\xe38\x8e\xe38\xee?' 20 | tRp6 21 | sS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 40, 'random_state': 0}, cv=, cv_params={})" 22 | p7 23 | g3 24 | (g5 25 | S'\x8bQ\xc3\xdf\xa6\x18\xed?' 26 | tRp8 27 | sS"Trial(clf=, clf_params={'kernel': 'linear', 'random_state': 0}, subset=, subset_params={'subset_size': 40, 'random_state': 0}, cv=, cv_params={})" 28 | p9 29 | g3 30 | (g5 31 | S'\xde\xdd\xdd\xdd\xdd\xdd\xed?' 32 | tRp10 33 | sS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 34 | p11 35 | g3 36 | (g5 37 | S'\x0b\xed%\xb4\x97\xd0\xee?' 38 | tRp12 39 | s. 
-------------------------------------------------------------------------------- /tests/data/test_perambulate/slice_on_dimension_clf.pkl: -------------------------------------------------------------------------------- 1 | (lp1 2 | S"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 3 | p2 4 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 40, 'random_state': 0}, cv=, cv_params={})" 5 | p3 6 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 60, 'random_state': 0}, cv=, cv_params={})" 7 | p4 8 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 80, 'random_state': 0}, cv=, cv_params={})" 9 | p5 10 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 100, 'random_state': 0}, cv=, cv_params={})" 11 | p6 12 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 13 | p7 14 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 40, 'random_state': 0}, cv=, cv_params={})" 15 | p8 16 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 60, 'random_state': 0}, cv=, cv_params={})" 17 | p9 18 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 80, 'random_state': 0}, cv=, cv_params={})" 19 | p10 20 | aS"Trial(clf=, clf_params={'n_estimators': 10, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 100, 'random_state': 0}, cv=, cv_params={})" 21 | p11 22 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 23 | p12 24 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 40, 'random_state': 0}, cv=, cv_params={})" 25 | p13 26 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 60, 'random_state': 0}, cv=, cv_params={})" 27 | p14 28 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 80, 'random_state': 0}, cv=, cv_params={})" 29 | p15 30 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 1}, subset=, subset_params={'subset_size': 100, 'random_state': 0}, cv=, cv_params={})" 31 | p16 32 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 33 | p17 34 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 40, 'random_state': 0}, cv=, cv_params={})" 35 | p18 36 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 60, 'random_state': 0}, cv=, cv_params={})" 37 | p19 38 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 
10}, subset=, subset_params={'subset_size': 80, 'random_state': 0}, cv=, cv_params={})" 39 | p20 40 | aS"Trial(clf=, clf_params={'n_estimators': 100, 'random_state': 0, 'max_depth': 10}, subset=, subset_params={'subset_size': 100, 'random_state': 0}, cv=, cv_params={})" 41 | p21 42 | a. -------------------------------------------------------------------------------- /tests/data/test_perambulate/slice_on_dimension_subset_params.pkl: -------------------------------------------------------------------------------- 1 | (lp1 2 | . -------------------------------------------------------------------------------- /tests/data/test_perambulate/sliding_windows.csv: -------------------------------------------------------------------------------- 1 | clf,clf_C,clf_algorithm,clf_base_estimator,clf_bootstrap,clf_cache_size,clf_class_weight,clf_coef0,clf_compute_importances,clf_constant,clf_criterion,clf_degree,clf_dual,clf_fit_intercept,clf_gamma,clf_intercept_scaling,clf_kernel,clf_learning_rate,clf_max_depth,clf_max_features,clf_max_iter,clf_max_leaf_nodes,clf_min_density,clf_min_samples_leaf,clf_min_samples_split,clf_n_estimators,clf_n_jobs,clf_oob_score,clf_penalty,clf_probability,clf_random_state,clf_shrinking,clf_splitter,clf_strategy,clf_tol,clf_verbose,subset,subset_cols_to_exclude,subset_max_grades,subset_n_subsets,subset_num_rows,subset_proportions_positive,subset_random_state,subset_subset_size,cv,cv_col_name,cv_col_name,cv_expanding_train,cv_inc_value,cv_indices,cv_n_folds,cv_random_state,cv_shuffle,cv_test_start,cv_test_window_size,cv_train_start,cv_train_window_size,subset_note_excluded_col,subset_note_max_grade,subset_note_prop_positive,subset_note_rows,subset_note_sample_num,cv_note_fold,cv_note_test_end,cv_note_test_start,cv_note_train_end,cv_note_train_start,f1_score,roc_auc,prec@1%,prec@2%,prec@5%,prec@10%,prec@20%,feature_ranked_0,feature_ranked_1,feature_ranked_2,feature_ranked_3,feature_ranked_4,feature_ranked_5,feature_ranked_6,feature_ranked_7,feature_ranked_8,feature_ranked_9,feature_score_0,feature_score_1,feature_score_2,feature_score_3,feature_score_4,feature_score_5,feature_score_6,feature_score_7,feature_score_8,feature_score_9 2 | ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,,,2,,,,,2,2,0,2,,,,,,,3,2,1,0,0.0,0.0,0.98,0.96,0.9,0.8,0.6,id,year,,,,,,,,,0.29999999999999999,0.20000000000000001 3 | ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,,,2,,,,,2,2,0,2,,,,,,,5,4,3,2,0.0,0.0,0.98,0.96,0.9,0.8,0.6,id,year,,,,,,,,,0.29999999999999999,0.20000000000000001 4 | ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,,,2,,,,,2,2,0,2,,,,,,,7,6,5,4,0.0,0.5,0.995,0.99,0.975,0.95,0.9,id,year,,,,,,,,,0.29999999999999999,0.20000000000000001 5 | ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,,,2,,,,,2,2,0,2,,,,,,,9,8,7,6,0.0,1.0,1.0,1.0,1.0,1.0,1.0,id,year,,,,,,,,,0.29999999999999999,0.20000000000000001 6 | ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,year,,2,,,,,1999,2,1997,2,,,,,,,2000,1999,1998,1997,0.66666666666666663,0.5,0.995,0.99,0.975,0.95,0.9,id,year,,,,,,,,,0.45000000000000001,0.14999999999999999 7 | ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,year,,2,,,,,1999,2,1997,2,,,,,,,2002,2001,2000,1999,0.0,0.0,0.98,0.96,0.9,0.8,0.6,id,year,,,,,,,,,0.5,0.0 8 | ,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,year,,2,,,,,1999,2,1997,2,,,,,,,2004,2003,2002,2001,0.66666666666666663,0.5,0.995,0.99,0.975,0.95,0.9,id,year,,,,,,,,,0.29999999999999999,0.20000000000000001 9 | 
,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,0,,,,,,,,,,,,,,,,year,,2,,,,,1999,2,1997,2,,,,,,,2006,2005,2004,2003,0.0,0,0.99,0.98,0.95,0.9,0.8,id,year,,,,,,,,,0.5,0.0 10 | -------------------------------------------------------------------------------- /tests/data/test_perambulate/test_subsetting.pkl: -------------------------------------------------------------------------------- 1 | (dp1 2 | S"Trial(clf=, clf_params={}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 3 | p2 4 | c__builtin__ 5 | frozenset 6 | p3 7 | ((lp4 8 | S"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'sample_num': 0}, cv_note={'fold': 0})]" 9 | p5 10 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'sample_num': 1}, cv_note={'fold': 0})]" 11 | p6 12 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'sample_num': 2}, cv_note={'fold': 0})]" 13 | p7 14 | atRp8 15 | sS"Trial(clf=, clf_params={}, subset=, subset_params={'num_rows': [10, 20, 30], 'random_state': 0}, cv=, cv_params={})" 16 | p9 17 | g3 18 | ((lp10 19 | S"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'rows': 30}, cv_note={'fold': 0})]" 20 | p11 21 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'rows': 10}, cv_note={'fold': 0})]" 22 | p12 23 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'rows': 20}, cv_note={'fold': 0})]" 24 | p13 25 | atRp14 26 | sS"Trial(clf=, clf_params={}, subset=, subset_params={'proportions_positive': [0.5, 0.75, 0.9], 'random_state': 0, 'subset_size': 10}, cv=, cv_params={})" 27 | p15 28 | g3 29 | ((lp16 30 | S"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note=prop_positive=0.5, cv_note={'fold': 0})]" 31 | p17 32 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, 
max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note=prop_positive=0.9, cv_note={'fold': 0})]" 33 | p18 34 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note=prop_positive=0.75, cv_note={'fold': 0})]" 35 | p19 36 | atRp20 37 | sS"Trial(clf=, clf_params={}, subset=, subset_params={'subset_size': 20, 'random_state': 0}, cv=, cv_params={})" 38 | p21 39 | g3 40 | ((lp22 41 | S"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'sample_num': 0}, cv_note={'fold': 0})]" 42 | p23 43 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'sample_num': 1}, cv_note={'fold': 0})]" 44 | p24 45 | aS"[Run(clf=RandomForestClassifier(bootstrap=True, compute_importances=None,\n criterion='gini', max_depth=None, max_features='auto',\n max_leaf_nodes=None, min_density=None, min_samples_leaf=1,\n min_samples_split=2, n_estimators=10, n_jobs=1,\n oob_score=False, random_state=None, verbose=0), subset_note={'sample_num': 2}, cv_note={'fold': 0})]" 46 | p25 47 | atRp26 48 | s. 
-------------------------------------------------------------------------------- /tests/data/test_perambulate_ref.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dssg/eights/9f12f9fb60984b8da2270e0df809fa09027336e5/tests/data/test_perambulate_ref.pdf -------------------------------------------------------------------------------- /tests/test_all.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import unittest 3 | 4 | test_modules = ['test_investigate', 5 | 'test_decontaminate', 6 | 'test_generate', 7 | 'test_utils', 8 | 'test_communicate', 9 | 'test_perambulate', 10 | 'test_truncate', 11 | 'test_operate'] 12 | if __name__ == '__main__': 13 | suite = unittest.defaultTestLoader.loadTestsFromNames(test_modules) 14 | unittest.TextTestRunner().run(suite) 15 | -------------------------------------------------------------------------------- /tests/test_communicate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from datetime import datetime 3 | from collections import Counter 4 | import eights.communicate as comm 5 | from eights.communicate.communicate import feature_pairs_in_tree 6 | from eights.communicate.communicate import feature_pairs_in_rf 7 | from eights import utils 8 | from sklearn import datasets 9 | from sklearn.ensemble import RandomForestClassifier 10 | from sklearn.cross_validation import train_test_split 11 | from sklearn.metrics import roc_auc_score 12 | import utils_for_tests as uft 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | 16 | REPORT_PATH=uft.path_of_data('test_communicate.pdf') 17 | SUBREPORT_PATH=uft.path_of_data('test_communicate_sub.pdf') 18 | REFERENCE_REPORT_PATH=uft.path_of_data('test_communicate_ref.pdf') 19 | 20 | class TestCommunicate(unittest.TestCase): 21 | @classmethod 22 | def setUpClass(cls): 23 | cls.report = comm.Report(report_path=REPORT_PATH) 24 | 25 | @classmethod 26 | def tearDownClass(cls): 27 | report_path = cls.report.to_pdf(verbose=False) 28 | uft.print_in_box( 29 | 'Test communicate visual regression tests', 30 | ['graphical output available at:', 31 | report_path, 32 | 'Reference available at:', 33 | REFERENCE_REPORT_PATH]) 34 | 35 | def add_fig_to_report(self, fig, heading): 36 | self.report.add_heading(heading) 37 | self.report.add_fig(fig) 38 | 39 | def test_print_matrix_row_col(self): 40 | M = [(1, 2, 3), (4, 5, 6), (7, 8, 'STRING')] 41 | ctrl = """ 42 | f0 f1 f2 43 | 0 1 2 3 44 | 1 4 5 6 45 | 2 7 8 STRING 46 | """.strip() 47 | with uft.rerout_stdout() as get_stdout: 48 | comm.print_matrix_row_col(M) 49 | self.assertEqual(get_stdout().strip(), ctrl) 50 | M = np.array([(1000, 'Bill'), (2000, 'Sam'), (3000, 'James')], 51 | dtype=[('number', float), ('name', 'S5')]) 52 | row_labels = [name[0] for name in M['name']] 53 | ctrl = """ 54 | number name 55 | B 1000.0 Bill 56 | S 2000.0 Sam 57 | J 3000.0 James 58 | """.strip() 59 | with uft.rerout_stdout() as get_stdout: 60 | comm.print_matrix_row_col(M, row_labels=row_labels) 61 | self.assertEqual(get_stdout().strip(), ctrl) 62 | 63 | 64 | 65 | def test_plot_simple_histogram(self): 66 | np.random.seed(0) 67 | data = np.random.normal(size=(1000,)) 68 | fig = comm.plot_simple_histogram(data, verbose=False) 69 | self.add_fig_to_report(fig, 'plot_simple_histogram') 70 | 71 | def test_plot_prec_recall(self): 72 | M, labels = uft.generate_correlated_test_matrix(1000) 73 | M_train, M_test, 
labels_train, labels_test = train_test_split( 74 | M, 75 | labels) 76 | clf = RandomForestClassifier(random_state=0) 77 | clf.fit(M_train, labels_train) 78 | score = clf.predict_proba(M_test)[:,-1] 79 | fig = comm.plot_prec_recall(labels_test, score, verbose=False) 80 | self.add_fig_to_report(fig, 'plot_prec_recall') 81 | 82 | def test_plot_roc(self): 83 | M, labels = uft.generate_correlated_test_matrix(1000) 84 | M_train, M_test, labels_train, labels_test = train_test_split( 85 | M, 86 | labels) 87 | clf = RandomForestClassifier(random_state=0) 88 | clf.fit(M_train, labels_train) 89 | score = clf.predict_proba(M_test)[:,-1] 90 | fig = comm.plot_roc(labels_test, score, verbose=False) 91 | self.add_fig_to_report(fig, 'plot_roc') 92 | 93 | def test_plot_box_plot(self): 94 | np.random.seed(0) 95 | data = np.random.normal(size=(1000,)) 96 | fig = comm.plot_box_plot(data, col_name='box', verbose=False) 97 | self.add_fig_to_report(fig, 'plot_box_plot') 98 | 99 | def test_get_top_features(self): 100 | M, labels = uft.generate_test_matrix(1000, 15, random_state=0) 101 | M = utils.cast_np_sa_to_nd(M) 102 | M_train, M_test, labels_train, labels_test = train_test_split( 103 | M, 104 | labels) 105 | clf = RandomForestClassifier(random_state=0) 106 | clf.fit(M_train, labels_train) 107 | res = comm.get_top_features(clf, M, verbose=False) 108 | ctrl = utils.convert_to_sa( 109 | [('f5', 0.0773838526068), 110 | ('f13', 0.0769596713039), 111 | ('f8', 0.0751584839431), 112 | ('f6', 0.0730815879102), 113 | ('f11', 0.0684456133071), 114 | ('f9', 0.0666747414603), 115 | ('f10', 0.0659621889608), 116 | ('f7', 0.0657988099065), 117 | ('f2', 0.0634000069218), 118 | ('f0', 0.0632912268319)], 119 | col_names=('feat_name', 'score')) 120 | self.assertTrue(uft.array_equal(ctrl, res)) 121 | 122 | def test_get_roc_auc(self): 123 | M, labels = uft.generate_correlated_test_matrix(1000) 124 | M_train, M_test, labels_train, labels_test = train_test_split( 125 | M, 126 | labels) 127 | clf = RandomForestClassifier(random_state=0) 128 | clf.fit(M_train, labels_train) 129 | score = clf.predict_proba(M_test)[:,-1] 130 | self.assertTrue(np.allclose( 131 | comm.get_roc_auc(labels_test, score, verbose=False), 132 | roc_auc_score(labels_test, score))) 133 | 134 | def test_plot_correlation_matrix(self): 135 | col1 = range(10) 136 | col2 = [cell * 3 + 1 for cell in col1] 137 | col3 = [1, 5, 8, 4, 1, 8, 5, 9, 0, 1] 138 | sa = utils.convert_to_sa( 139 | zip(col1, col2, col3), 140 | col_names=['base', 'linear_trans', 'no_correlation']) 141 | fig = comm.plot_correlation_matrix(sa, verbose=False) 142 | self.add_fig_to_report(fig, 'plot_correlation_matrix') 143 | 144 | def test_plot_correlation_scatter_plot(self): 145 | col1 = range(10) 146 | col2 = [cell * 3 + 1 for cell in col1] 147 | col3 = [1, 5, 8, 4, 1, 8, 5, 9, 0, 1] 148 | sa = utils.convert_to_sa( 149 | zip(col1, col2, col3), 150 | col_names=['base', 'linear_trans', 'no_correlation']) 151 | fig = comm.plot_correlation_scatter_plot(sa, verbose=False) 152 | self.add_fig_to_report(fig, 'plot_correlation_scatter_plot') 153 | 154 | def test_plot_kernel_density(self): 155 | np.random.seed(0) 156 | data = np.random.normal(size=(1000,)) 157 | fig = comm.plot_kernel_density(data, verbose=False) 158 | self.add_fig_to_report(fig, 'plot_kernel_density') 159 | 160 | 161 | def test_plot_on_timeline(self): 162 | dates = [datetime(2015, 1, 1), 163 | datetime(2015, 2, 1), 164 | datetime(2015, 6, 1), 165 | datetime(2015, 6, 15), 166 | datetime(2015, 9, 2), 167 | datetime(2016, 1, 5)] 168 | fig1 = 
comm.plot_on_timeline(dates, verbose=False) 169 | self.add_fig_to_report(fig1, 'plot_on_timeline_1') 170 | dates = np.array(dates, dtype='M8[us]') 171 | fig2 = comm.plot_on_timeline(dates, verbose=False) 172 | self.add_fig_to_report(fig2, 'plot_on_timeline_2') 173 | 174 | def test_report(self): 175 | subrep = comm.Report(report_path=SUBREPORT_PATH) 176 | self.assertEqual(subrep.get_report_path(), SUBREPORT_PATH) 177 | subrep.add_heading('Subreport', level=3) 178 | subrep.add_text( 179 | (u'Sample text.\n' 180 | u'<b>HTML tags should render literally</b>
\n')) 181 | subrep.add_heading('Sample table', level=4) 182 | sample_table = np.array( 183 | [(1, datetime(2015, 1, 1), 'New Years Day'), 184 | (2, datetime(2015, 2, 14), 'Valentines Day'), 185 | (3, datetime(2015, 3, 15), 'The Ides of March')], 186 | dtype=[('idx', int), ('day', 'M8[us]'), ('Name', 'S17')]) 187 | subrep.add_table(sample_table) 188 | sample_fig = plt.figure() 189 | plt.plot([1, 2, 3], [1, 2, 3]) 190 | plt.title('Sample fig') 191 | subrep.add_heading('Sample figure', level=4) 192 | subrep.add_fig(sample_fig) 193 | self.report.add_heading('report') 194 | self.report.add_subreport(subrep) 195 | 196 | def test_feature_pairs_in_tree(self): 197 | iris = datasets.load_iris() 198 | rf = RandomForestClassifier(random_state=0) 199 | rf.fit(iris.data, iris.target) 200 | dt = rf.estimators_[0] 201 | result = feature_pairs_in_tree(dt) 202 | ctrl = [[(2, 3)], [(2, 3), (0, 2)], [(0, 2), (1, 3)]] 203 | self.assertEqual(result, ctrl) 204 | 205 | def test_feature_pairs_in_rf(self): 206 | iris = datasets.load_iris() 207 | rf = RandomForestClassifier(random_state=0) 208 | rf.fit(iris.data, iris.target) 209 | results = feature_pairs_in_rf(rf, [1, 0.5], verbose=False) 210 | # TODO make sure these results are actually correct 211 | ctrl_cts_by_pair = Counter( 212 | {(2, 3): 16, (0, 2): 14, (0, 3): 12, (3, 3): 7, (2, 2): 6, 213 | (0, 1): 4, (1, 2): 3, (1, 3): 3, (0, 0): 2, (1, 1): 1}) 214 | ctrl_ct_pairs_by_depth = [ 215 | Counter({(2, 3): 3, (0, 3): 3, (3, 3): 2, (2, 2): 2, (0, 1): 1, 216 | (0, 0): 1}), 217 | Counter({(0, 2): 7, (2, 3): 5, (3, 3): 2, (2, 2): 2, (0, 3): 2, 218 | (1, 1): 1}), 219 | Counter({(2, 3): 5, (0, 2): 5, (2, 2): 2, (0, 3): 2, (1, 2): 1, 220 | (0, 1): 1, (1, 3): 1, (3, 3): 1, (0, 0): 1}), 221 | Counter({(0, 3): 3, (1, 2): 2, (2, 3): 2, (0, 1): 1, (1, 3): 1, 222 | (3, 3): 1}), 223 | Counter({(0, 1): 1, (1, 3): 1, (3, 3): 1, (0, 2): 1}), 224 | Counter({(0, 3): 1, (2, 3): 1, (0, 2): 1}), 225 | Counter({(0, 3): 1})] 226 | ctrl_av_depth_by_pair = { 227 | (0, 1): 2.25, (1, 2): 2.6666666666666665, (0, 0): 1.0, 228 | (3, 3): 1.5714285714285714, (0, 2): 1.8571428571428572, 229 | (1, 3): 3.0, (2, 3): 1.625, (2, 2): 1.0, 230 | (0, 3): 2.1666666666666665, (1, 1): 1.0} 231 | ctrl_weighted= { 232 | (0, 1): 1.0, (1, 2): 0.0, (0, 0): 1.0, (3, 3): 3.0, (0, 2): 3.5, 233 | (1, 3): 0.0, (2, 3): 5.5, (2, 2): 3.0, (0, 3): 4.0, (1, 1): 0.5} 234 | for result, ctrl in zip( 235 | results, 236 | (ctrl_cts_by_pair, ctrl_ct_pairs_by_depth, 237 | ctrl_av_depth_by_pair, ctrl_weighted)): 238 | self.assertEqual(result, ctrl) 239 | 240 | if __name__ == '__main__': 241 | unittest.main() 242 | -------------------------------------------------------------------------------- /tests/test_decontaminate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from numpy.random import rand 4 | 5 | from eights.decontaminate import label_encode 6 | from eights.decontaminate import replace_missing_vals 7 | import utils_for_tests 8 | 9 | from collections import Counter 10 | class TestDecontaminate(unittest.TestCase): 11 | 12 | def test_label_encoding(self): 13 | M = np.array( 14 | [('a', 0, 'Martin'), 15 | ('b', 1, 'Tim'), 16 | ('b', 2, 'Martin'), 17 | ('c', 3, 'Martin')], 18 | dtype=[('letter', 'S1'), ('idx', int), ('name', 'S6')]) 19 | ctrl = np.array( 20 | [(0, 0, 0), 21 | (1, 1, 1), 22 | (1, 2, 0), 23 | (2, 3, 0)], 24 | dtype=[('letter', int), ('idx', int), ('name', int)]) 25 | self.assertTrue(np.array_equal(ctrl, label_encode(M))) 26 | 27 | 
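# A minimal sketch of what the test above exercises, assuming label_encode
# simply assigns integer codes to the values of each string column (here
# 'a' -> 0, 'b' -> 1, 'c' -> 2 and 'Martin' -> 0, 'Tim' -> 1) while
# passing already-numeric columns such as 'idx' through unchanged:
#
#   label_encode(M)['letter']   # array([0, 1, 1, 2])
#   label_encode(M)['idx']      # array([0, 1, 2, 3]), untouched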
def test_replace_missing_vals(self): 28 | M = np.array([('a', 0, 0.0, 0.1), 29 | ('b', 1, 1.0, np.nan), 30 | ('', -999, np.nan, 0.0), 31 | ('d', 1, np.nan, 0.2), 32 | ('', -999, 2.0, np.nan)], 33 | dtype=[('str', 'S1'), ('int', int), ('float1', float), 34 | ('float2', float)]) 35 | 36 | ctrl = M.copy() 37 | ctrl['float1'] = np.array([0.0, 1.0, -1.0, -1.0, 2.0]) 38 | ctrl['float2'] = np.array([0.1, -1.0, 0.0, 0.2, -1.0]) 39 | res = replace_missing_vals(M, 'constant', constant=-1.0) 40 | self.assertTrue(np.array_equal(ctrl, res)) 41 | 42 | ctrl = M.copy() 43 | ctrl['int'] = np.array([100, 1, -999, 1, -999]) 44 | ctrl['float1'] = np.array([100, 1.0, np.nan, np.nan, 2.0]) 45 | ctrl['float2'] = np.array([0.1, np.nan, 100, 0.2, np.nan]) 46 | res = replace_missing_vals(M, 'constant', missing_val=0, constant=100) 47 | self.assertTrue(utils_for_tests.array_equal(ctrl, res)) 48 | 49 | ctrl = M.copy() 50 | ctrl['int'] = np.array([0, 1, 1, 1, 1]) 51 | res = replace_missing_vals(M, 'most_frequent', missing_val=-999) 52 | self.assertTrue(utils_for_tests.array_equal(ctrl, res)) 53 | 54 | ctrl = M.copy() 55 | ctrl['float1'] = np.array([0.0, 1.0, 1.0, 1.0, 2.0]) 56 | ctrl['float2'] = np.array([0.1, 0.1, 0.0, 0.2, 0.1]) 57 | res = replace_missing_vals(M, 'mean', missing_val=np.nan) 58 | self.assertTrue(utils_for_tests.array_equal(ctrl, res)) 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | 63 | 64 | -------------------------------------------------------------------------------- /tests/test_generate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from numpy.random import rand 4 | import eights.investigate 5 | import eights.utils 6 | from eights.generate import generate_bin 7 | from eights.generate import normalize 8 | from eights.generate import distance_from_point 9 | from eights.generate import where_all_are_true, val_eq, val_lt, val_between 10 | from eights.generate import combine_sum, combine_mean, combine_cols 11 | 12 | class TestGenerate(unittest.TestCase): 13 | 14 | def test_generate_bin(self): 15 | M = [1, 1, 1, 3, 3, 3, 5, 5, 5, 5, 2, 6] 16 | ctrl = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 0, 3] 17 | self.assertTrue(np.array_equal(ctrl, generate_bin(M, 3))) 18 | M = np.array([0.1, 3.0, 0.0, 1.2, 2.5, 1.7, 2]) 19 | ctrl = [0, 3, 0, 1, 2, 1, 2] 20 | self.assertTrue(np.array_equal(ctrl, generate_bin(M, 3))) 21 | 22 | 23 | def test_where_all_are_true(self): 24 | M = [[1,2,3], [2,3,4], [3,4,5]] 25 | col_names = ['heigh','weight', 'age'] 26 | lables= [0,0,1] 27 | M = eights.utils.cast_list_of_list_to_sa( 28 | M, 29 | col_names=col_names) 30 | 31 | arguments = [{'func': val_eq, 'col_name': 'heigh', 'vals': 1}, 32 | {'func': val_lt, 'col_name': 'weight', 'vals': 3}, 33 | {'func': val_between, 'col_name': 'age', 'vals': 34 | (3, 4)}] 35 | 36 | res = where_all_are_true( 37 | M, 38 | arguments, 39 | 'eq_to_stuff') 40 | ctrl = np.array( 41 | [(1, 2, 3, True), (2, 3, 4, False), (3, 4, 5, False)], 42 | dtype=[('heigh', ' args[0] 22 | # Create a boolean column where f2 is between 4 and 5 and f3 > 1.9 23 | where_col = where( 24 | M, 25 | [between, gt], 26 | ['f2', 'f3'], 27 | [(4, 5), (`1.9`,)]) 28 | 29 | # Append the column we just made to M 30 | M = sa_append(M, where_col) 31 | 32 | # remove columns f2 and f3 33 | M = sa_remove_col(M, ['f2', 'f3']) 34 | 35 | M_train, M_test, y_train, y_test = train_test_split(M, y) 36 | 37 | os.remove('report.pdf') 38 | 39 | # run classifiers over our modified data and generate a report 40 | 
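# run_std_classifiers is presumably eights.operate's standard sweep: it is
# expected to fit a fixed battery of sklearn classifiers on (M_train,
# y_train), score them on (M_test, y_test), and write the resulting
# figures and tables to the PDF path given below; the assertion that
# follows only verifies that the report file was produced.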
run_std_classifiers(M_train, M_test, y_train, y_test, 'report.pdf') 41 | 42 | self.assertTrue(os.path.isfile('report.pdf')) 43 | 44 | -------------------------------------------------------------------------------- /tests/test_truncate.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from collections import Counter 4 | 5 | from eights.utils import remove_cols, cast_list_of_list_to_sa 6 | 7 | from eights.truncate.truncate_helper import (col_has_all_same_val) 8 | 9 | from eights.truncate.truncate import (remove_col_where, 10 | all_equal_to, 11 | all_same_value, 12 | fewer_then_n_nonzero_in_col, 13 | remove_rows_where, 14 | val_eq) 15 | 16 | class TestTruncate(unittest.TestCase): 17 | def test_are_all_col_equal(self): 18 | M = cast_list_of_list_to_sa( 19 | [[1,2,3], [1,3,4], [1,4,5]], 20 | col_names=['height','weight', 'age']) 21 | 22 | arguments = [{'func': all_equal_to, 'vals': 1}] 23 | M = remove_col_where(M, arguments) 24 | correct = cast_list_of_list_to_sa( 25 | [[2,3], [3,4], [4,5]], 26 | col_names=['weight', 'age']) 27 | self.assertTrue(np.array_equal(M, correct)) 28 | 29 | def test_all_same_value(self): 30 | M = cast_list_of_list_to_sa( 31 | [[1,2,3], [1,3,4], [1,4,5]], 32 | col_names=['height','weight', 'age']) 33 | arguments = [{'func': all_same_value, 'vals': None}] 34 | M = remove_col_where(M, arguments) 35 | correct = cast_list_of_list_to_sa( 36 | [[2,3], [3,4], [4,5]], 37 | col_names=['weight', 'age']) 38 | self.assertTrue(np.array_equal(M, correct)) 39 | 40 | def test_fewer_then_n_nonzero_in_col(self): 41 | M = cast_list_of_list_to_sa( 42 | [[0,2,3], [0,3,4], [1,4,5]], 43 | col_names=['height','weight', 'age']) 44 | arguments = [{'func': fewer_then_n_nonzero_in_col, 'vals': 2}] 45 | M = remove_col_where(M, arguments) 46 | correct = cast_list_of_list_to_sa( 47 | [[2,3], [3,4], [4,5]], 48 | col_names=['weight', 'age']) 49 | self.assertTrue(np.array_equal(M, correct)) 50 | 51 | def test_remove_row(self): 52 | M = cast_list_of_list_to_sa( 53 | [[0,2,3], [0,3,4], [1,4,5]], 54 | col_names=['height','weight', 'age']) 55 | 56 | M = remove_rows_where(M, val_eq, 'weight', 3) 57 | correct = cast_list_of_list_to_sa( 58 | [[0, 2, 3], [1, 4, 5]], 59 | col_names=['height','weight', 'age']) 60 | self.assertTrue(np.array_equal(M, correct)) 61 | 62 | 63 | 64 | if __name__ == '__main__': 65 | unittest.main() 66 | 67 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from eights import utils 3 | import utils_for_tests 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | class TestUtils(unittest.TestCase): 10 | def __sa_check(self, sa1, sa2): 11 | # This works even if both rows and columns are in different 12 | # orders in the two arrays 13 | frozenset_sa1_names = frozenset(sa1.dtype.names) 14 | frozenset_sa2_names = frozenset(sa2.dtype.names) 15 | self.assertEqual(frozenset_sa1_names, 16 | frozenset_sa2_names) 17 | sa2_reordered = sa2[list(sa1.dtype.names)] 18 | sa1_set = {tuple(row) for row in sa1} 19 | sa2_set = {tuple(row) for row in sa2_reordered} 20 | self.assertEqual(sa1_set, sa2_set) 21 | 22 | 23 | def test_utf_to_ascii(self): 24 | s = u'\u03BBf.(\u03BBx.f(x x)) (\u05DC.f(x x))' 25 | ctrl = '?f.(?x.f(x x)) (?.f(x x))' 26 | res = utils.utf_to_ascii(s)
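# utf_to_ascii evidently degrades the unicode input to a plain byte
# string, substituting '?' for every non-ASCII code point (the Greek
# lambdas and the Hebrew lamed above); the two assertions below check
# exactly that type and content.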
27 | self.assertTrue(isinstance(res, str)) 28 | self.assertEqual(ctrl, res) 29 | 30 | def test_validate_time(self): 31 | trials = [('2014-12-12', True), 32 | ('1/2/1999 8:23PM', True), 33 | ('1988-08-15T13:43:01.123', True), 34 | ('2014-14-12', False), # invalid month 35 | ('2012', False), # Just a number 36 | ('a', False), # dateutil interprets this as now 37 | ] 38 | 39 | for (s, ctrl) in trials: 40 | self.assertEqual(utils.validate_time(s), ctrl) 41 | 42 | def test_str_to_time(self): 43 | trials = [('2014-12-12', datetime(2014, 12, 12)), 44 | ('1/2/1999 8:23PM', datetime(1999, 1, 2, 20, 23)), 45 | ('1988-08-15T13:43:01.123', 46 | datetime(1988, 8, 15, 13, 43, 1, 123000)), 47 | ] 48 | 49 | for (s, ctrl) in trials: 50 | self.assertEqual(utils.str_to_time(s), ctrl) 51 | 52 | def test_cast_list_of_list_to_sa(self): 53 | L = [[None, None, None], 54 | ['a', 5, None], 55 | ['ab', 'x', None]] 56 | ctrl = np.array( 57 | [('', '', ''), 58 | ('a', '5', ''), 59 | ('ab', 'x', '')], 60 | dtype=[('f0', 'S2'), 61 | ('f1', 'S1'), 62 | ('f2', 'S1')]) 63 | conv = utils.cast_list_of_list_to_sa(L) 64 | self.assertTrue(np.array_equal(conv, ctrl)) 65 | L = [[None, u'\u05dd\u05d5\u05dc\u05e9', 4.0, 7], 66 | [2, 'hello', np.nan, None], 67 | [4, None, None, 14L]] 68 | ctrl = np.array( 69 | [(-999, u'\u05dd\u05d5\u05dc\u05e9', 4.0, 7), 70 | (2, u'hello', np.nan, -999L), 71 | (4, u'', np.nan, 14L)], 72 | dtype=[('int', int), ('ucode', 'U5'), ('float', float), 73 | ('long', long)]) 74 | conv = utils.cast_list_of_list_to_sa( 75 | L, 76 | col_names=['int', 'ucode', 'float', 'long']) 77 | self.assertTrue(utils_for_tests.array_equal(ctrl, conv)) 78 | 79 | def test_convert_to_sa(self): 80 | # already a structured array 81 | sa = np.array([(1, 1.0, 'a', datetime(2015, 01, 01)), 82 | (2, 2.0, 'b', datetime(2016, 01, 01))], 83 | dtype=[('int', int), ('float', float), ('str', 'S1'), 84 | ('date', 'M8[s]')]) 85 | self.assertTrue(np.array_equal(sa, utils.convert_to_sa(sa))) 86 | 87 | # homogeneous array no col names provided 88 | nd = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 89 | ctrl = np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)], 90 | dtype=[('f0', int), ('f1', int), ('f2', int)]) 91 | self.assertTrue(np.array_equal(ctrl, utils.convert_to_sa(nd))) 92 | 93 | # homogeneous array with col names provided 94 | nd = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 95 | ctrl = np.array([(1, 2, 3), (4, 5, 6), (7, 8, 9)], 96 | dtype=[('i0', int), ('i1', int), ('i2', int)]) 97 | self.assertTrue(np.array_equal(ctrl, utils.convert_to_sa( 98 | nd, 99 | col_names=['i0', 'i1', 'i2']))) 100 | 101 | # list of lists no col name provided 102 | lol = [[1, 1, None], 103 | ['abc', 2, 3.4]] 104 | ctrl = np.array([('1', 1, np.nan), 105 | ('abc', 2, 3.4)], 106 | dtype=[('f0', 'S3'), ('f1', int), ('f2', float)]) 107 | res = utils.convert_to_sa(lol) 108 | self.assertTrue(utils_for_tests.array_equal(ctrl, res)) 109 | 110 | # list of lists with col name provided 111 | lol = [['hello', 1.2, datetime(2012, 1, 1), None], 112 | [1.3, np.nan, None, '2013-01-01'], 113 | [1.4, 1.5, '2014-01-01', 'NO_SUCH_RECORD']] 114 | ctrl = np.array([('hello', 1.2, datetime(2012, 1, 1), utils.NOT_A_TIME), 115 | ('1.3', np.nan, utils.NOT_A_TIME, datetime(2013, 1, 1)), 116 | ('1.4', 1.5, datetime(2014, 1, 1), utils.NOT_A_TIME)], 117 | dtype=[('i0', 'S5'), ('i1', float), ('i2', 'M8[us]'), 118 | ('i3', 'M8[us]')]) 119 | res = utils.convert_to_sa(lol, col_names = ['i0', 'i1', 'i2', 'i3']) 120 | self.assertTrue(utils_for_tests.array_equal(ctrl, res)) 121 | 122 | def 
test_np_dtype_is_homogeneous(self): 123 | sa = np.array([(1, 'a', 2)], dtype=[('f0', int), ('f1', 'S1'), 124 | ('f2', int)]) 125 | self.assertFalse(utils.np_dtype_is_homogeneous(sa)) 126 | 127 | sa = np.array([('aa', 'a')], dtype=[('f0', 'S2'), ('f1', 'S1')]) 128 | self.assertFalse(utils.np_dtype_is_homogeneous(sa)) 129 | 130 | sa = np.array([(1, 2, 3)], dtype=[('f0', int), ('f1', int), 131 | ('f2', int)]) 132 | self.assertTrue(utils.np_dtype_is_homogeneous(sa)) 133 | 134 | 135 | def test_nd_to_sa_w_type(self): 136 | nd = np.array([[1, 2, 3], [4, 5, 6]], dtype=int) 137 | dtype = np.dtype({'names': map('f{}'.format, xrange(3)), 138 | 'formats': [int] * 3}) 139 | control = np.array([(1, 2, 3), (4, 5, 6)], dtype=dtype) 140 | result = utils.cast_np_nd_to_sa(nd, dtype) 141 | self.assertEqual(control.dtype, result.dtype) 142 | self.assertTrue(np.array_equal(result, control)) 143 | 144 | def test_nd_to_sa_no_type(self): 145 | nd = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=float) 146 | dtype = np.dtype({'names': map('f{}'.format, xrange(3)), 147 | 'formats': [float] * 3}) 148 | control = np.array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], dtype=dtype) 149 | result = utils.cast_np_nd_to_sa(nd) 150 | self.assertEqual(control.dtype, result.dtype) 151 | self.assertTrue(np.array_equal(result, control)) 152 | 153 | def test_sa_to_nd(self): 154 | dtype = np.dtype({'names': map('f{}'.format, xrange(3)), 155 | 'formats': [float] * 3}) 156 | sa = np.array([(-1.0, 2.0, -1.0), (0.0, -1.0, 2.0)], dtype=dtype) 157 | control = np.array([[-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]], 158 | dtype=float) 159 | result = utils.cast_np_sa_to_nd(sa) 160 | self.assertTrue(np.array_equal(result, control)) 161 | 162 | def test_is_sa(self): 163 | nd = np.array([[1, 2, 3], [4, 5, 6]], dtype=int) 164 | dtype = np.dtype({'names': map('f{}'.format, xrange(3)), 165 | 'formats': [float] * 3}) 166 | sa = np.array([(-1.0, 2.0, -1.0), (0.0, -1.0, 2.0)], dtype=dtype) 167 | self.assertFalse(utils.is_sa(nd)) 168 | self.assertTrue(utils.is_sa(sa)) 169 | 170 | def test_is_nd(self): 171 | nd = np.array([[1, 2, 3], [4, 5, 6]], dtype=int) 172 | dtype = np.dtype({'names': map('f{}'.format, xrange(3)), 173 | 'formats': [float] * 3}) 174 | sa = np.array([(-1.0, 2.0, -1.0), (0.0, -1.0, 2.0)], dtype=dtype) 175 | self.assertTrue(utils.is_nd(nd)) 176 | self.assertTrue(utils.is_nd(sa)) 177 | 178 | def test_distance(self): 179 | # Coords according to https://tools.wmflabs.org/geohack/ 180 | # Minneapolis 181 | lat1 = 44.98 182 | lng1 = -93.27 183 | 184 | # Chicago 185 | lat2 = 41.84 186 | lng2 = -87.68 187 | 188 | # Sao Paulo 189 | lat3 = -23.55 190 | lng3 = -46.63 191 | 192 | # distances from http://www.movable-type.co.uk/scripts/latlong.html 193 | # (Rounds to nearest km) 194 | 195 | self.assertTrue(np.allclose(utils.distance(lat1, lng1, lat2, lng2), 196 | 570.6, atol=1, rtol=0)) 197 | self.assertTrue(np.allclose(utils.distance(lat1, lng1, lat3, lng3), 198 | 8966, atol=1, rtol=0)) 199 | 200 | def test_dist_less_than(self): 201 | # Minneapolis 202 | lat1 = 44.98 203 | lng1 = -93.27 204 | 205 | # Chicago 206 | lat2 = 41.84 207 | lng2 = -87.68 208 | 209 | self.assertTrue(utils.dist_less_than(lat1, lng1, lat2, lng2, 600)) 210 | self.assertFalse(utils.dist_less_than(lat1, lng1, lat2, lng2, 500)) 211 | 212 | def test_stack_rows(self): 213 | dtype = [('id', int), ('name', 'S1')] 214 | M1 = np.array([(1, 'a'), (2, 'b')], dtype=dtype) 215 | M2 = np.array([(3, 'c'), (4, 'd'), (5, 'e')], dtype=dtype) 216 | ctrl = np.array([(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), 
(5, 'e')], 217 | dtype=dtype) 218 | res = utils.stack_rows(M1, M2) 219 | self.assertTrue(np.array_equal(ctrl, res)) 220 | 221 | def test_from_cols(self): 222 | col1 = np.array([1, 2, 3]) 223 | col2 = np.array([4.0, 5.0, 6.0]) 224 | ctrl = np.array( 225 | [(1, 4.0), (2, 5.0), (3, 6.0)], 226 | dtype=[('f0', int), ('f1', float)]) 227 | res = utils.sa_from_cols([col1, col2]) 228 | self.assertTrue(np.array_equal(ctrl, res)) 229 | 230 | def test_append_cols(self): 231 | M = np.array([(1, 'a'), (2, 'b')], dtype=[('int', int), ('str', 'S1')]) 232 | col1 = np.array([1.0, 2.0]) 233 | col2 = np.array([datetime(2015, 12, 12), datetime(2015, 12, 13)], 234 | dtype='M8[us]') 235 | 236 | ctrl = np.array( 237 | [(1, 'a', 1.0), (2, 'b', 2.0)], 238 | dtype=[('int', int), ('str', 'S1'), ('float', float)]) 239 | res = utils.append_cols(M, col1, 'float') 240 | self.assertTrue(np.array_equal(ctrl, res)) 241 | 242 | ctrl = np.array( 243 | [(1, 'a', 1.0, datetime(2015, 12, 12)), 244 | (2, 'b', 2.0, datetime(2015, 12, 13))], 245 | dtype=[('int', int), ('str', 'S1'), ('float', float), 246 | ('dt', 'M8[us]')]) 247 | res = utils.append_cols(M, [col1, col2], ['float', 'dt']) 248 | self.assertTrue(np.array_equal(ctrl, res)) 249 | 250 | def test_remove_cols(self): 251 | M = np.array( 252 | [(1, 'a', 1.0, datetime(2015, 12, 12)), 253 | (2, 'b', 2.0, datetime(2015, 12, 13))], 254 | dtype=[('int', int), ('str', 'S1'), ('float', float), 255 | ('dt', 'M8[us]')]) 256 | 257 | ctrl = np.array( 258 | [(1, 'a', 1.0), (2, 'b', 2.0)], 259 | dtype=[('int', int), ('str', 'S1'), ('float', float)]) 260 | res = utils.remove_cols(M, 'dt') 261 | self.assertTrue(np.array_equal(ctrl, res)) 262 | 263 | ctrl = np.array([(1, 'a'), (2, 'b')], dtype=[('int', int), 264 | ('str', 'S1')]) 265 | res = utils.remove_cols(M, ['dt', 'float']) 266 | self.assertTrue(np.array_equal(ctrl, res)) 267 | 268 | def test_join(self): 269 | # test basic inner join 270 | a1 = np.array([(0, 'Lisa', 2), 271 | (1, 'Bill', 1), 272 | (2, 'Fred', 2), 273 | (3, 'Samantha', 2), 274 | (4, 'Augustine', 1), 275 | (5, 'William', 0)], dtype=[('id', int), 276 | ('name', 'S64'), 277 | ('dept_id', int)]) 278 | a2 = np.array([(0, 'accts receivable'), 279 | (1, 'accts payable'), 280 | (2, 'shipping')], dtype=[('id', int), 281 | ('name', 'S64')]) 282 | ctrl = pd.DataFrame(a1).merge( 283 | pd.DataFrame(a2), 284 | left_on='dept_id', 285 | right_on='id').to_records(index=False) 286 | res = utils.join(a1, a2, 'inner', 'dept_id', 'id') 287 | self.__sa_check(ctrl, res) 288 | 289 | # test column naming rules 290 | a1 = np.array([(0, 'a', 1, 2, 3)], dtype=[('idx0', int), 291 | ('name', 'S1'), 292 | ('a1_idx1', int), 293 | ('idx2', int), 294 | ('idx3', int)]) 295 | a2 = np.array([(0, 'b', 1, 2, 3)], dtype=[('idx0', int), 296 | ('name', 'S1'), 297 | ('a2_idx1', int), 298 | ('idx2', int), 299 | ('idx3', int)]) 300 | pd1 = pd.DataFrame(a1) 301 | pd2 = pd.DataFrame(a2) 302 | ctrl = pd1.merge( 303 | pd2, 304 | left_on=['idx0', 'a1_idx1', 'idx2'], 305 | right_on=['idx0', 'a2_idx1', 'idx2'], 306 | suffixes=['_left', '_right']).to_records(index=False) 307 | res = utils.join( 308 | a1, 309 | a2, 310 | 'inner', 311 | left_on=['idx0', 'a1_idx1', 'idx2'], 312 | right_on=['idx0', 'a2_idx1', 'idx2'], 313 | suffixes=['_left', '_right']) 314 | self.__sa_check(ctrl, res) 315 | 316 | # outer joins 317 | a1 = np.array( 318 | [(0, 'a1_0', 0), 319 | (1, 'a1_1', 1), 320 | (1, 'a1_2', 2), 321 | (2, 'a1_3', 3), 322 | (3, 'a1_4', 4)], 323 | dtype=[('key', int), ('label', 'S64'), ('idx', int)]) 324 | a2 = np.array( 325 
| [(0, 'a2_0', 0), 326 | (1, 'a2_1', 1), 327 | (2, 'a2_2', 2), 328 | (2, 'a2_3', 3), 329 | (4, 'a2_4', 4)], 330 | dtype=[('key', int), ('label', 'S64'), ('idx', int)]) 331 | #for how in ('inner', 'left', 'right', 'outer'): 332 | merged_dtype = [('key', int), ('label_x', 'S64'), ('idx_x', int), 333 | ('label_y', 'S64'), ('idx_y', int)] 334 | merge_algos = ('inner', 'left', 'right', 'outer') 335 | merged_data = [[(0, 'a1_0', 0, 'a2_0', 0), 336 | (1, 'a1_1', 1, 'a2_1', 1), 337 | (1, 'a1_2', 2, 'a2_1', 1), 338 | (2, 'a1_3', 3, 'a2_2', 2), 339 | (2, 'a1_3', 3, 'a2_3', 3)], 340 | [(0, 'a1_0', 0, 'a2_0', 0), 341 | (1, 'a1_1', 1, 'a2_1', 1), 342 | (1, 'a1_2', 2, 'a2_1', 1), 343 | (2, 'a1_3', 3, 'a2_2', 2), 344 | (2, 'a1_3', 3, 'a2_3', 3), 345 | (3, 'a1_4', 4, '', -999)], 346 | [(0, 'a1_0', 0, 'a2_0', 0), 347 | (1, 'a1_1', 1, 'a2_1', 1), 348 | (1, 'a1_2', 2, 'a2_1', 1), 349 | (2, 'a1_3', 3, 'a2_2', 2), 350 | (2, 'a1_3', 3, 'a2_3', 3), 351 | (4, '', -999, 'a2_4', 4)], 352 | [(0, 'a1_0', 0, 'a2_0', 0), 353 | (1, 'a1_1', 1, 'a2_1', 1), 354 | (1, 'a1_2', 2, 'a2_1', 1), 355 | (2, 'a1_3', 3, 'a2_2', 2), 356 | (2, 'a1_3', 3, 'a2_3', 3), 357 | (4, '', -999, 'a2_4', 4), 358 | (3, 'a1_4', 4, '', -999)]] 359 | for how, data in zip(merge_algos, merged_data): 360 | res = utils.join( 361 | a1, 362 | a2, 363 | how, 364 | left_on='key', 365 | right_on='key') 366 | ctrl = np.array(data, dtype=merged_dtype) 367 | self.__sa_check(ctrl, res) 368 | 369 | 370 | if __name__ == '__main__': 371 | unittest.main() 372 | -------------------------------------------------------------------------------- /tests/test_utils_for_tests.py: -------------------------------------------------------------------------------- 1 | 2 | import unittest 3 | import utils_for_tests 4 | 5 | class TestUtilsForTests(unittest.TestCase): 6 | def test_generate_matrix(self): 7 | M, y = utils_for_tests.generate_test_matrix(100, 5, 3, [float, str, int]) 8 | print M 9 | print y 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /tests/utils_for_tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import itertools as it 4 | import numpy as np 5 | import string 6 | import eights.utils 7 | from numpy.random import rand, seed 8 | from contextlib import contextmanager 9 | from StringIO import StringIO 10 | 11 | TESTS_PATH = os.path.dirname(os.path.realpath(sys.argv[0])) 12 | DATA_PATH = os.path.join(TESTS_PATH, 'data') 13 | EIGHTS_PATH = os.path.join(TESTS_PATH, '..') 14 | 15 | def path_of_data(filename): 16 | return os.path.join(DATA_PATH, filename) 17 | 18 | def generate_test_matrix(n_rows, n_cols=1, n_classes=2, types=[], random_state=None): 19 | full_types = list(it.chain(types, it.repeat(float, n_cols - len(types)))) 20 | np.random.seed(random_state) 21 | cols = [] 22 | for col_type in full_types: 23 | if col_type is int: 24 | col = np.random.randint(100, size=n_rows) 25 | elif issubclass(col_type, basestring): 26 | col = np.random.choice(list(string.uppercase), size=n_rows) 27 | else: 28 | col = np.random.random(size=n_rows) 29 | cols.append(col) 30 | labels = np.random.randint(n_classes, size=n_rows) 31 | M = eights.utils.sa_from_cols(cols) 32 | return M, labels 33 | 34 | def generate_correlated_test_matrix(n_rows): 35 | seed(0) 36 | M = rand(n_rows, 1) 37 | y = rand(n_rows) < M[:,0] 38 | return M, y 39 | 40 | def array_equal(M1, M2, eps=1e-5): 41 | """ 42 | unlike np.array_equal, works
correctly for nan and ignores floating 43 | point errors up to eps 44 | """ 45 | if M1.dtype != M2.dtype: 46 | return False 47 | for col_name, col_type in M1.dtype.descr: 48 | M1_col = M1[col_name] 49 | M2_col = M2[col_name] 50 | if 'f' not in col_type: 51 | if not(np.array_equal(M1_col, M2_col)): 52 | return False 53 | else: 54 | if not (np.all(np.logical_or( 55 | abs(M1_col - M2_col) < eps, 56 | np.logical_and(np.isnan(M1_col), np.isnan(M2_col))))): 57 | return False 58 | return True 59 | 60 | @contextmanager 61 | def rerout_stdout(): 62 | """ 63 | print statements within the context are rerouted to a StringIO, which 64 | can be examined with the method that is yielded here. 65 | 66 | Examples 67 | -------- 68 | >>> print 'This text appears in the console' 69 | This text appears in the console 70 | >>> with rerout_stdout() as get_rerouted_stdout: 71 | ... print 'This text does not appear in the console' 72 | ... # get_rerouted_stdout is a function that gets our rerouted output 73 | ... assert(get_rerouted_stdout().strip() == 'This text does not appear in the console') 74 | >>> print 'This text also appears in the console' 75 | This text also appears in the console 76 | """ 77 | # based on http://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python 78 | saved_stdout = sys.stdout 79 | try: 80 | out = StringIO() 81 | sys.stdout = out 82 | yield out.getvalue 83 | finally: 84 | sys.stdout = saved_stdout 85 | 86 | def print_in_box(heading, text): 87 | """ Prints text in a nice box. 88 | 89 | Parameters 90 | ---------- 91 | heading : str 92 | text : str or list of str 93 | if a list of str, each item of the list gets its own line 94 | """ 95 | if isinstance(text, basestring): 96 | text = text.split('\n') 97 | str_len = max(len(heading), max([len(line) for line in text])) 98 | meta_fmt = ('{{border}}{{space}}' 99 | '{{{{content:{{fill}}{{align}}{str_len}}}}}' 100 | '{{space}}{{border}}\n').format(str_len=str_len) 101 | boundary = meta_fmt.format( 102 | fill='-', 103 | align='^', 104 | border='+', 105 | space='-').format( 106 | content='') 107 | heading_line = meta_fmt.format( 108 | fill='', 109 | align='^', 110 | border='|', 111 | space=' ').format( 112 | content=heading) 113 | line_fmt = meta_fmt.format( 114 | fill='', 115 | align='<', 116 | border='|', 117 | space=' ') 118 | sys.stdout.write('\n') 119 | sys.stdout.write(boundary) 120 | sys.stdout.write(heading_line) 121 | sys.stdout.write(boundary.replace('-', '=')) 122 | sys.stdout.write(''.join([line_fmt.format(content=line) for line in text])) 123 | sys.stdout.write(boundary) 124 | sys.stdout.write('\n') 125 | --------------------------------------------------------------------------------