├── .gitattributes ├── .gitignore ├── Media └── .gitignore ├── README.md ├── Scripts ├── post1.py ├── post10.py ├── post5-1.py ├── post5.py ├── post7.py ├── post8.py ├── post9.py ├── statsapis.py ├── untitled2.py └── updateBip.py ├── setup.py └── statcast ├── __init__.py ├── better ├── __init__.py ├── base.py ├── declassifier.py ├── kde.py ├── kdr.py ├── mixed.py ├── randomforest.py ├── sm.py ├── spark.py └── utils.py ├── bip.py ├── data ├── bandwidths2015.csv ├── bandwidths2016.csv ├── blackontrans.mplstyle ├── logos │ ├── ARI.png │ ├── ATL.png │ ├── BAL.png │ ├── BOS.png │ ├── CHC.png │ ├── CIN.png │ ├── CLE.png │ ├── COL.png │ ├── CWS.png │ ├── DET.png │ ├── HOU.png │ ├── KC.png │ ├── LAA.png │ ├── LAD.png │ ├── MIA.png │ ├── MIL.png │ ├── MIN.png │ ├── NYM.png │ ├── NYY.png │ ├── OAK.png │ ├── PHI.png │ ├── PIT.png │ ├── SD.png │ ├── SEA.png │ ├── SF.png │ ├── STL.png │ ├── TB.png │ ├── TEX.png │ ├── TOR.png │ └── WSH.png └── personal.mplstyle ├── database ├── __init__.py ├── bbsavant.py ├── database.py ├── gd_game_events.py ├── gd_scoreboards.py ├── gd_weather.py └── gddb.py ├── plot.py └── tools ├── __init__.py ├── convolution.py ├── fixpath.py ├── montecarlo.py └── plot.py /.gitattributes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/.gitattributes -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | 10 | # Packages # 11 | ############ 12 | # it's better to unpack these files and commit the raw source 13 | # git has its own built in compression methods 14 | *.7z 15 | *.dmg 16 | *.gz 17 | *.iso 18 | *.jar 19 | *.rar 20 | *.tar 21 | *.zip 22 | 23 | # Logs and databases # 24 | ###################### 25 | *.log 26 | *.sql 27 | *.sqlite 28 | 29 | # OS generated files # 30 | ###################### 31 | .DS_Store 32 | .DS_Store? 33 | ._* 34 | .Spotlight-V100 35 | .Trashes 36 | ehthumbs.db 37 | Thumbs.db 38 | Desktop.ini 39 | 40 | # codekit # 41 | ########### 42 | .sass-cache/ 43 | .codekit-config.json 44 | config.codekit 45 | 46 | # Repo Specific # 47 | ################# 48 | **/__pycache__ -------------------------------------------------------------------------------- /Media/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # statcast 2 | Using machine learning to play around with MLB statcast data 3 | -------------------------------------------------------------------------------- /Scripts/post1.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | import pandas as pd 4 | from matplotlib import pyplot as plt 5 | 6 | from statcast.plot import plotMLBLogos 7 | 8 | # %% Plot 2002 Salary vs Wins 9 | 10 | plt.style.use('blackontrans') 11 | 12 | df = pd.read_csv('/Users/mattfay/Downloads/MLB02.csv', 13 | names=['team', 'winP', 'money'], index_col=0) 14 | 15 | fig = plt.figure() 16 | ax = fig.add_subplot(1, 1, 1) 17 | plotMLBLogos(df.money * 1e-6, df.winP * 1e2, sizes=35, ax=ax) 18 | 19 | ax.set_xlabel('Salary (Million $)') 20 | ax.set_ylabel('Winning Percentage (%)') 21 | ax.set_title('2002 MLB Regular Season') 22 | ax.set_ybound(ax.get_ylim()[0] - 1.2, ax.get_ylim()[1] + 1.2) 23 | 24 | fig.savefig('MLB02.png') 25 | -------------------------------------------------------------------------------- /Scripts/post10.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | import os 4 | import datetime 5 | 6 | import requests 7 | from pyspark import SparkContext 8 | 9 | from sklearn.model_selection import StratifiedKFold 10 | from sklearn.metrics import log_loss 11 | from sklearn.preprocessing import LabelBinarizer 12 | 13 | from statcast.bip import Bip 14 | from statcast.better.spark import cross_val_predict 15 | from statcast.tools.plot import plotPrecRec, plotPrecRecMN, plotResiduals 16 | from statcast.better.declassifier import KDEClassifier 17 | 18 | 19 | # %% Create Spark Context 20 | 21 | sc = SparkContext(appName='post10') 22 | 23 | # %% 24 | 25 | bip = Bip(years=(2016,), n_jobs=-1) 26 | 27 | # %% 28 | 29 | xLabels = ['hit_speed', 'hit_angle', 'sprayAngle'] 30 | fancyLabels = ['Exit Velocity', 'Launch Angle', 'Spray Angle'] 31 | units = ['mph', 'degrees', 'degrees'] 32 | yLabel = 'events' 33 | 34 | subData = bip.data.loc[~bip.data['exclude'], xLabels + [yLabel]] 35 | 36 | outs = ['Bunt Groundout', 'Double Play', 'Fielders Choice', 37 | 'Fielders Choice Out', 'Flyout', 'Forceout', 'Grounded Into DP', 38 | 'Groundout', 'Lineout', 'Pop Out', 'Runner Out', 'Sac Bunt', 39 | 'Sac Fly', 'Sac Fly DP', 'Triple Play', 'Bunt Pop Out', 'Bunt Lineout', 40 | 'Sacrifice Bunt DP'] 41 | 42 | subData['events'] = subData['events'].cat.add_categories(['Out']) 43 | 44 | for out in outs: 45 | subData.loc[subData['events'] == out, 'events'] = 'Out' 46 | 47 | subData['events'] = subData['events'].cat.remove_unused_categories() 48 | 49 | X1 = subData.loc[:, xLabels[:-1]] 50 | X2 = subData.loc[:, xLabels] 51 | y1 = subData[yLabel] == 'Home Run' 52 | y2 = subData[yLabel] 53 | 54 | skf = StratifiedKFold(n_splits=10, shuffle=True) 55 | 56 | kdc = KDEClassifier(kdeParams=dict(kernel='gaussian'), 57 | n_jobs=-1) 58 | 59 | y11p = cross_val_predict(kdc, X1, y1, cv=skf, n_jobs=sc, 60 | method='predict_proba') 61 | y21p = cross_val_predict(kdc, X2, y1, cv=skf, n_jobs=sc, 62 | method='predict_proba') 63 | y12p = cross_val_predict(kdc, X1, y2, cv=skf, n_jobs=sc, 64 | method='predict_proba') 65 | y22p = cross_val_predict(kdc, X2, y2, cv=skf, n_jobs=sc, 66 | method='predict_proba') 67 | 68 | y11p = y11p[:, 1] 69 | y21p = y21p[:, 1] 70 | 71 | # %% Log-loss 72 | 73 | logL11 = log_loss(y1, y11p) 74 | logL21 = log_loss(y1, y21p) 75 | logL12 = log_loss(y2, y12p) 76 | logL22 = log_loss(y2, y22p) 77 | 78 | # %% Plot Precision-Recall Curve 79 | 80 | fig = plotPrecRec(y1, y11p, label='EV + LA: LL={:.2f}'.format(logL11)) 81 | ax = fig.gca() 82 | plotPrecRec(y1, y21p, ax=ax, label='EV + LA + SA: LL={:.2f}'.format(logL21)) 83 | ax.legend() 84 | ax.set_title('KDC Homerun Classifier') 85 | fig.savefig('KDC HR Prec-Rec Curve') 86 | 87 | fig = plotPrecRecMN(y2, y12p) 88 | fig.gca().set_title('EV + LA: LL={:.2f}'.format(logL12)) 89 | fig.savefig('KDC(EV + LA) Hit Prec-Rec Curves') 90 | fig = plotPrecRecMN(y2, y22p) 91 | fig.gca().set_title('EV + LA + SA: LL={:.2f}'.format(logL22)) 92 | fig.savefig('KDC(EV + LA + SA) Hit Prec-Rec Curves') 93 | 94 | # %% Plot Residuals 95 | 96 | figs11 = plotResiduals(X1.values, y1 * 100, y11p * 100, 97 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 98 | yLabels=['Home Run'], yUnits=['%'], pltParams={'ms': 1}) 99 | for label, fig in zip(xLabels[:-1], figs11): 100 | fig.gca().set_title('HR(EV + LA) Classifier') 101 | fig.savefig('KDC(EV + LA) HR Residuals over {}'.format(label)) 102 | figs21 = plotResiduals(X2.values, y1 * 100, y21p * 100, 103 | xLabels=fancyLabels, xUnits=units, 104 | yLabels=['Home Run'], yUnits=['%'], pltParams={'ms': 1}) 105 | for label, fig in zip(xLabels, figs21): 106 | fig.gca().set_title('HR(EV + LA + SA) Classifier') 107 | fig.savefig('KDC(EV + LA + SA) HR Residuals over {}'.format(label)) 108 | 109 | Y2 = LabelBinarizer().fit_transform(y2) 110 | y2Labels = sorted(y2.cat.categories) 111 | 112 | figs12 = plotResiduals(X1.values, Y2 * 100, y12p * 100, 113 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 114 | yLabels=y2Labels, yUnits=['%'] * Y2.shape[1], 115 | pltParams={'ms': 1}) 116 | for label, fig in zip(xLabels[:-1], figs12): 117 | fig.get_axes()[0].set_title('Hit(EV + LA) Classifier') 118 | fig.savefig('KDC(EV + LA) Hit Residuals over {}'.format(label)) 119 | figs22 = plotResiduals(X2.values, Y2 * 100, y22p * 100, 120 | xLabels=fancyLabels, xUnits=units, 121 | yLabels=y2Labels, yUnits=['%'] * Y2.shape[1], 122 | pltParams={'ms': 1}) 123 | for label, fig in zip(xLabels, figs22): 124 | fig.get_axes()[0].set_title('Hit(EV + LA + SA) Classifier') 125 | fig.savefig('KDC(EV + LA + SA) Hit Residuals over {}'.format(label)) 126 | 127 | # %% Transfer results to S3 128 | 129 | instanceID = requests. \ 130 | get('http://169.254.169.254/latest/meta-data/instance-id').text 131 | dtStr = datetime.datetime.utcnow().strftime('%Y-%m-%d--%H-%M-%S') 132 | os.system('aws s3 sync . s3://mf-first-bucket/output/{}/{}'. 133 | format(instanceID, dtStr)) 134 | 135 | # %% Stop Spark Context 136 | 137 | sc.stop() 138 | -------------------------------------------------------------------------------- /Scripts/post5-1.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | from matplotlib import pyplot as plt 4 | 5 | from statcast.bip import Bip 6 | from statcast.tools.plot import correlationPlot 7 | from statcast.better.utils import findTrainSplit 8 | 9 | # %% Plot correlation of imputing model 10 | 11 | years = (2016, 2015) 12 | 13 | labels = ['Exit Velocity', 'Launch Angle', 'Hit Distance'] 14 | units = ['mph', 'degrees', 'feet'] 15 | 16 | for year in years: 17 | 18 | bip = Bip(years=(year,), n_jobs=-1) 19 | 20 | testData = bip.data.loc[~bip.data.exclude & ~bip.data.scImputed, :] 21 | 22 | testY = bip.scImputer.createY(testData) 23 | testYp = bip.scImputer.predictD(testData) 24 | 25 | labelsYr = ['{} {}'.format(label, year) for label in labels] 26 | 27 | figs = correlationPlot(testY, 28 | testYp, 29 | labels=labelsYr, 30 | units=units, 31 | ms=0.7) 32 | 33 | for fig, label in zip(figs, labels): 34 | fig.savefig('{} Correlation {}'.format(label, year)) 35 | 36 | # %% Plot Tree Curve 37 | 38 | fig = plt.figure() 39 | ax = fig.add_subplot(1, 1, 1) 40 | ax.plot(bip.scImputer.treeScores_.index.values, 41 | bip.scImputer.treeScores_.values, 'o-') 42 | ax.set_xlabel('Number of Trees') 43 | ax.set_ylabel('Out of Bag Score (R^2)') 44 | ax.xaxis.set_ticklabels(ax.xaxis.get_majorticklocs().astype(int)) 45 | 46 | fig.savefig('Number of Trees Example') 47 | 48 | # %% Plot RFE Scores 49 | 50 | fig = plt.figure() 51 | ax = fig.add_subplot(1, 1, 1) 52 | ax.plot(bip.scImputer.rfeResults_.columns.shape[0] - 1 - 53 | bip.scImputer.rfeResults_['scores'].index, 54 | [-thing.mean() for thing in bip.scImputer.rfeResults_['scores']]) 55 | ax.set_xlabel('Number of Features') 56 | ax.set_ylabel('Cross-validation Score (RMS Error)') 57 | 58 | fig.savefig('RFE Example') 59 | 60 | # %% Plot Learning Curve 61 | 62 | trainData = bip.data.loc[~bip.data.exclude & ~bip.data.scImputed, :] 63 | 64 | findTrainSplit(bip.scImputer, trainData, cv=10, n_jobs=-1, scoreThresh=0.2) 65 | 66 | fig = plt.figure() 67 | ax = fig.add_subplot(1, 1, 1) 68 | ax.plot(bip.scImputer.trainSplitResults_['size'].values, 69 | [-thing.mean() for thing in bip.scImputer.trainSplitResults_.score], 70 | 'o-') 71 | ax.set_xlabel('Number of Datapoints') 72 | ax.set_ylabel('Cross-validation Score (RMS Error)') 73 | ax.set_xscale('log') 74 | 75 | fig.savefig('Learning Curve Example') 76 | -------------------------------------------------------------------------------- /Scripts/post5.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | import os 4 | import datetime 5 | 6 | import requests 7 | from pyspark import SparkContext 8 | 9 | from statcast.bip import Bip 10 | 11 | 12 | # %% Create Spark Context 13 | 14 | sc = SparkContext(appName="post5") 15 | 16 | # %% Load data, plot histograms of statcast data 17 | 18 | years = (2015, 2016) 19 | 20 | for year in years: 21 | bip = Bip(years=(year,), n_jobs=sc) 22 | bip.plotSCHistograms() 23 | 24 | # %% Transfer results to S3 25 | 26 | instanceID = requests. \ 27 | get('http://169.254.169.254/latest/meta-data/instance-id').text 28 | dtStr = datetime.datetime.utcnow().strftime('%Y-%m-%d--%H-%M-%S') 29 | os.system('aws s3 sync . s3://mf-first-bucket/output/{}/{}'. 30 | format(instanceID, dtStr)) 31 | 32 | # %% Stop Spark Context 33 | 34 | sc.stop() 35 | -------------------------------------------------------------------------------- /Scripts/post7.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | from scipy import stats 4 | from matplotlib import pyplot as plt 5 | 6 | from statcast.bip import Bip 7 | from statcast.plot import plotMLBLogos 8 | from statcast.tools.plot import addText 9 | 10 | 11 | # %% 12 | 13 | bip15 = Bip(years=(2015,), n_jobs=-1) 14 | bip16 = Bip(years=(2016,), n_jobs=-1) 15 | 16 | # %% Plot Correlations 17 | 18 | labels = ['hit_speed', 'hit_angle', 'hit_distance_sc'] 19 | units = ['mph', 'degrees', 'feet'] 20 | fancyLabels = ['Exit Velocity', 'Launch Angle', 'Hit Distance'] 21 | 22 | for i, (label, unit, fancyLabel) in enumerate(zip(labels, units, fancyLabels)): 23 | if '(scImputed||home_team)' in bip15.scFactorMdl.formulas[i]: 24 | x = bip15.scFactorMdl.factors_[label]['home_team']['(Intercept)'] + \ 25 | bip15.scFactorMdl.factors_[label]['home_team']['scImputedFALSE'] 26 | missing15 = False 27 | else: 28 | x = bip15.scFactorMdl.factors_[label]['home_team']['(Intercept)'] 29 | missing15 = True 30 | if '(scImputed||home_team)' in bip16.scFactorMdl.formulas[i]: 31 | y = bip16.scFactorMdl.factors_[label]['home_team']['(Intercept)'] + \ 32 | bip16.scFactorMdl.factors_[label]['home_team']['scImputedFALSE'] 33 | missing16 = False 34 | else: 35 | y = bip16.scFactorMdl.factors_[label]['home_team']['(Intercept)'] 36 | missing16 = True 37 | 38 | fig = plt.figure() 39 | ax = fig.add_subplot(1, 1, 1) 40 | ax.plot(x, y, alpha=0) 41 | 42 | axLims = list(ax.axis()) 43 | axLims[0] = axLims[2] = min(axLims[0::2]) 44 | axLims[1] = axLims[3] = max(axLims[1::2]) 45 | ax.axis(axLims) 46 | 47 | plotMLBLogos(x, y, ax=ax) 48 | 49 | ax.plot(axLims[:2], axLims[2:], 50 | '--', color=plt.rcParams['lines.color'], linewidth=1) 51 | 52 | ax.set_title('{} ({}) Venue Bias'.format(fancyLabel, unit)) 53 | ax.set_xlabel('2015 Season') 54 | ax.set_ylabel('2016 Season') 55 | 56 | r2 = stats.pearsonr(x, y)[0] ** 2 57 | labels = ['R2: {:.2f}'.format(r2)] 58 | addText(ax, labels, loc='lower right') 59 | 60 | fig.savefig('{} 2015-2016 Correlation'.format(fancyLabel)) 61 | 62 | if missing15 or missing16: 63 | continue 64 | 65 | x = bip15.scFactorMdl.factors_[label]['home_team']['(Intercept)'] + \ 66 | bip15.scFactorMdl.factors_[label]['home_team']['scImputedTRUE'] 67 | y = bip16.scFactorMdl.factors_[label]['home_team']['(Intercept)'] + \ 68 | bip16.scFactorMdl.factors_[label]['home_team']['scImputedTRUE'] 69 | 70 | fig = plt.figure() 71 | ax = fig.add_subplot(1, 1, 1) 72 | ax.plot(x, y, alpha=0) 73 | 74 | axLims = list(ax.axis()) 75 | axLims[0] = axLims[2] = min(axLims[0::2]) 76 | axLims[1] = axLims[3] = max(axLims[1::2]) 77 | ax.axis(axLims) 78 | 79 | plotMLBLogos(x, y, ax=ax) 80 | 81 | ax.plot(axLims[:2], axLims[2:], 82 | '--', color=plt.rcParams['lines.color'], linewidth=1) 83 | 84 | ax.set_title('Missing {} ({}) Venue Bias'.format(fancyLabel, unit)) 85 | ax.set_xlabel('2015 Season') 86 | ax.set_ylabel('2016 Season') 87 | 88 | r2 = stats.pearsonr(x, y)[0] ** 2 89 | labels = ['R2: {:.2f}'.format(r2)] 90 | addText(ax, labels, loc='lower right') 91 | 92 | fig.savefig('Missing {} 2015-2016 Correlation'.format(fancyLabel)) 93 | -------------------------------------------------------------------------------- /Scripts/post8.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | import numpy as np 4 | 5 | from statsmodels import api as sm 6 | from sklearn.model_selection import StratifiedKFold 7 | from sklearn.metrics import log_loss 8 | from sklearn.preprocessing import LabelBinarizer 9 | 10 | from statcast.bip import Bip 11 | from statcast.better.sm import BetterGLM, BetterMNLogit 12 | from statcast.better.spark import cross_val_predict 13 | from statcast.tools.plot import plotPrecRec, plotPrecRecMN, plotResiduals 14 | 15 | 16 | # %% 17 | 18 | bip = Bip(years=(2016,), n_jobs=-1) 19 | 20 | # %% 21 | 22 | xLabels = ['hit_speed', 'hit_angle', 'sprayAngle'] 23 | fancyLabels = ['Exit Velocity', 'Launch Angle', 'Spray Angle'] 24 | units = ['mph', 'degrees', 'degrees'] 25 | yLabel = 'events' 26 | 27 | subData = bip.data.loc[~bip.data['exclude'], xLabels + [yLabel]] 28 | 29 | outs = ['Bunt Groundout', 'Double Play', 'Fielders Choice', 30 | 'Fielders Choice Out', 'Flyout', 'Forceout', 'Grounded Into DP', 31 | 'Groundout', 'Lineout', 'Pop Out', 'Runner Out', 'Sac Bunt', 32 | 'Sac Fly', 'Sac Fly DP', 'Triple Play', 'Bunt Pop Out', 'Bunt Lineout', 33 | 'Sacrifice Bunt DP'] 34 | 35 | subData['events'] = subData['events'].cat.add_categories(['Out']) 36 | 37 | for out in outs: 38 | subData.loc[subData['events'] == out, 'events'] = 'Out' 39 | 40 | subData['events'] = subData['events'].cat.remove_unused_categories() 41 | 42 | X1 = subData.loc[:, xLabels[:-1]] 43 | X2 = subData.loc[:, xLabels] 44 | y1 = subData[yLabel] == 'Home Run' 45 | y2 = subData[yLabel] 46 | 47 | skf = StratifiedKFold(n_splits=10, shuffle=True) 48 | 49 | glm = BetterGLM( 50 | SMParams={'family': sm.families.Binomial(sm.families.links.probit)}) 51 | mnLogit = BetterMNLogit() 52 | 53 | y11p = cross_val_predict(glm, X1, y1, cv=skf, n_jobs=-1, 54 | method='predict_proba') 55 | y21p = cross_val_predict(glm, X2, y1, cv=skf, n_jobs=-1, 56 | method='predict_proba') 57 | y12p = cross_val_predict(mnLogit, X1, y2, cv=skf, n_jobs=-1, 58 | method='predict_proba') 59 | y22p = cross_val_predict(mnLogit, X2, y2, cv=skf, n_jobs=-1, 60 | method='predict_proba') 61 | 62 | # %% Log-loss 63 | 64 | L11 = np.exp(-log_loss(y1, y11p)) 65 | L21 = np.exp(-log_loss(y1, y21p)) 66 | L12 = np.exp(-log_loss(y2, y12p)) 67 | L22 = np.exp(-log_loss(y2, y22p)) 68 | 69 | # %% Plot Precision-Recall Curve 70 | 71 | fig = plotPrecRec(y1, y11p, label='EV + LA: L={:.0f}%'.format(L11 * 100)) 72 | ax = fig.gca() 73 | plotPrecRec(y1, y21p, ax=ax, label='EV + LA + SA: L={:.0f}%'.format(L21 * 100)) 74 | ax.legend() 75 | ax.set_title('GLM Homerun Classifier') 76 | fig.savefig('GLM HR Prec-Rec Curve') 77 | 78 | fig = plotPrecRecMN(y2, y12p) 79 | fig.gca().set_title('EV + LA: L={:.0f}%'.format(L12 * 100)) 80 | fig.savefig('GLM(EV + LA) Hit Prec-Rec Curves') 81 | fig = plotPrecRecMN(y2, y22p) 82 | fig.gca().set_title('EV + LA + SA: L={:.0f}%'.format(L22 * 100)) 83 | fig.savefig('GLM(EV + LA + SA) Hit Prec-Rec Curves') 84 | 85 | # %% Plot Residuals 86 | 87 | figs11 = plotResiduals(X1.values, y1 * 100, y11p * 100, 88 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 89 | yLabels=['Home Run'], yUnits=['%'], pltParams={'ms': 1}) 90 | for label, fig in zip(xLabels[:-1], figs11): 91 | fig.gca().set_title('HR(EV + LA) Classifier') 92 | fig.savefig('GLM(EV + LA) HR Residuals over {}'.format(label)) 93 | figs21 = plotResiduals(X2.values, y1 * 100, y21p * 100, 94 | xLabels=fancyLabels, xUnits=units, 95 | yLabels=['Home Run'], yUnits=['%'], pltParams={'ms': 1}) 96 | for label, fig in zip(xLabels, figs21): 97 | fig.gca().set_title('HR(EV + LA + SA) Classifier') 98 | fig.savefig('GLM(EV + LA + SA) HR Residuals over {}'.format(label)) 99 | 100 | Y2 = LabelBinarizer().fit_transform(y2) 101 | 102 | figs12 = plotResiduals(X1.values, Y2 * 100, y12p * 100, 103 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 104 | yLabels=y2.cat.categories, yUnits=['%'] * Y2.shape[1], 105 | pltParams={'ms': 1}) 106 | for label, fig in zip(xLabels[:-1], figs12): 107 | fig.get_axes()[0].set_title('Hit(EV + LA) Classifier') 108 | fig.savefig('GLM(EV + LA) Hit Residuals over {}'.format(label)) 109 | figs22 = plotResiduals(X2.values, Y2 * 100, y22p * 100, 110 | xLabels=fancyLabels, xUnits=units, 111 | yLabels=y2.cat.categories, yUnits=['%'] * Y2.shape[1], 112 | pltParams={'ms': 1}) 113 | for label, fig in zip(xLabels, figs22): 114 | fig.get_axes()[0].set_title('Hit(EV + LA + SA) Classifier') 115 | fig.savefig('GLM(EV + LA + SA) Hit Residuals over {}'.format(label)) 116 | -------------------------------------------------------------------------------- /Scripts/post9.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | from sklearn.model_selection import StratifiedKFold 4 | from sklearn.metrics import log_loss 5 | from sklearn.preprocessing import LabelBinarizer 6 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, \ 7 | QuadraticDiscriminantAnalysis 8 | 9 | from statcast.bip import Bip 10 | from statcast.better.spark import cross_val_predict 11 | from statcast.tools.plot import plotPrecRec, plotPrecRecMN, plotResiduals 12 | 13 | 14 | # %% 15 | 16 | bip = Bip(years=(2016,), n_jobs=-1) 17 | 18 | # %% 19 | 20 | xLabels = ['hit_speed', 'hit_angle', 'sprayAngle'] 21 | fancyLabels = ['Exit Velocity', 'Launch Angle', 'Spray Angle'] 22 | units = ['mph', 'degrees', 'degrees'] 23 | yLabel = 'events' 24 | 25 | subData = bip.data.loc[~bip.data['exclude'], xLabels + [yLabel]] 26 | 27 | outs = ['Bunt Groundout', 'Double Play', 'Fielders Choice', 28 | 'Fielders Choice Out', 'Flyout', 'Forceout', 'Grounded Into DP', 29 | 'Groundout', 'Lineout', 'Pop Out', 'Runner Out', 'Sac Bunt', 30 | 'Sac Fly', 'Sac Fly DP', 'Triple Play', 'Bunt Pop Out', 'Bunt Lineout', 31 | 'Sacrifice Bunt DP'] 32 | 33 | subData['events'] = subData['events'].cat.add_categories(['Out']) 34 | 35 | for out in outs: 36 | subData.loc[subData['events'] == out, 'events'] = 'Out' 37 | 38 | subData['events'] = subData['events'].cat.remove_unused_categories() 39 | 40 | X1 = subData.loc[:, xLabels[:-1]] 41 | X2 = subData.loc[:, xLabels] 42 | y1 = subData[yLabel] == 'Home Run' 43 | y2 = subData[yLabel] 44 | 45 | skf = StratifiedKFold(n_splits=10, shuffle=True) 46 | 47 | lda = LinearDiscriminantAnalysis() 48 | qda = QuadraticDiscriminantAnalysis() 49 | 50 | y11pl = cross_val_predict(lda, X1, y1, cv=skf, n_jobs=-1, 51 | method='predict_proba') 52 | y21pl = cross_val_predict(lda, X2, y1, cv=skf, n_jobs=-1, 53 | method='predict_proba') 54 | y12pl = cross_val_predict(lda, X1, y2, cv=skf, n_jobs=-1, 55 | method='predict_proba') 56 | y22pl = cross_val_predict(lda, X2, y2, cv=skf, n_jobs=-1, 57 | method='predict_proba') 58 | 59 | y11pq = cross_val_predict(qda, X1, y1, cv=skf, n_jobs=-1, 60 | method='predict_proba') 61 | y21pq = cross_val_predict(qda, X2, y1, cv=skf, n_jobs=-1, 62 | method='predict_proba') 63 | y12pq = cross_val_predict(qda, X1, y2, cv=skf, n_jobs=-1, 64 | method='predict_proba') 65 | y22pq = cross_val_predict(qda, X2, y2, cv=skf, n_jobs=-1, 66 | method='predict_proba') 67 | 68 | y11pl = y11pl[:, 1] 69 | y21pl = y21pl[:, 1] 70 | y11pq = y11pq[:, 1] 71 | y21pq = y21pq[:, 1] 72 | 73 | # %% Log-loss 74 | 75 | logL11l = log_loss(y1, y11pl) 76 | logL21l = log_loss(y1, y21pl) 77 | logL12l = log_loss(y2, y12pl) 78 | logL22l = log_loss(y2, y22pl) 79 | 80 | logL11q = log_loss(y1, y11pq) 81 | logL21q = log_loss(y1, y21pq) 82 | logL12q = log_loss(y2, y12pq) 83 | logL22q = log_loss(y2, y22pq) 84 | 85 | # %% Plot Precision-Recall Curve 86 | 87 | fig = plotPrecRec(y1, y11pl, label='EV + LA: LL={:.2f}'.format(logL11l)) 88 | ax = fig.gca() 89 | plotPrecRec(y1, y21pl, ax=ax, label='EV + LA + SA: LL={:.2f}'.format(logL21l)) 90 | ax.legend() 91 | ax.set_title('LDA Homerun Classifier') 92 | fig.savefig('LDA HR Prec-Rec Curve') 93 | 94 | fig = plotPrecRecMN(y2, y12pl) 95 | fig.gca().set_title('LDA(EV + LA): LL={:.2f}'.format(logL12l)) 96 | fig.savefig('LDA(EV + LA) Hit Prec-Rec Curves') 97 | fig = plotPrecRecMN(y2, y22pl) 98 | fig.gca().set_title('LDA(EV + LA + SA): LL={:.2f}'.format(logL22l)) 99 | fig.savefig('LDA(EV + LA + SA) Hit Prec-Rec Curves') 100 | 101 | fig = plotPrecRec(y1, y11pq, label='EV + LA: LL={:.2f}'.format(logL11q)) 102 | ax = fig.gca() 103 | plotPrecRec(y1, y21pq, ax=ax, label='EV + LA + SA: LL={:.2f}'.format(logL21q)) 104 | ax.legend() 105 | ax.set_title('QDA Homerun Classifier') 106 | fig.savefig('QDA HR Prec-Rec Curve') 107 | 108 | fig = plotPrecRecMN(y2, y12pq) 109 | fig.gca().set_title('QDA(EV + LA): LL={:.2f}'.format(logL12q)) 110 | fig.savefig('QDA(EV + LA) Hit Prec-Rec Curves') 111 | fig = plotPrecRecMN(y2, y22pq) 112 | fig.gca().set_title('QDA(EV + LA + SA): LL={:.2f}'.format(logL22q)) 113 | fig.savefig('QDA(EV + LA + SA) Hit Prec-Rec Curves') 114 | 115 | # %% Plot Residuals 116 | 117 | figs11l = plotResiduals(X1.values, y1 * 100, y11pl * 100, 118 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 119 | yLabels=['Home Run'], yUnits=['%'], 120 | pltParams={'ms': 1}) 121 | for label, fig in zip(xLabels[:-1], figs11l): 122 | fig.gca().set_title('LDA(EV + LA) HR Classifier') 123 | fig.savefig('LDA(EV + LA) HR Residuals over {}'.format(label)) 124 | figs21l = plotResiduals(X2.values, y1 * 100, y21pl * 100, 125 | xLabels=fancyLabels, xUnits=units, 126 | yLabels=['Home Run'], yUnits=['%'], 127 | pltParams={'ms': 1}) 128 | for label, fig in zip(xLabels, figs21l): 129 | fig.gca().set_title('LDA(EV + LA + SA) HR Classifier') 130 | fig.savefig('LDA(EV + LA + SA) HR Residuals over {}'.format(label)) 131 | 132 | Y2 = LabelBinarizer().fit_transform(y2) 133 | y2Labels = sorted(y2.cat.categories) 134 | 135 | figs12l = plotResiduals(X1.values, Y2 * 100, y12pl * 100, 136 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 137 | yLabels=y2Labels, yUnits=['%'] * Y2.shape[1], 138 | pltParams={'ms': 1}) 139 | for label, fig in zip(xLabels[:-1], figs12l): 140 | fig.get_axes()[0].set_title('LDA(EV + LA) Hit Classifier') 141 | fig.savefig('LDA(EV + LA) Hit Residuals over {}'.format(label)) 142 | figs22l = plotResiduals(X2.values, Y2 * 100, y22pl * 100, 143 | xLabels=fancyLabels, xUnits=units, 144 | yLabels=y2Labels, yUnits=['%'] * Y2.shape[1], 145 | pltParams={'ms': 1}) 146 | for label, fig in zip(xLabels, figs22l): 147 | fig.get_axes()[0].set_title('LDA(EV + LA + SA) Hit Classifier') 148 | fig.savefig('LDA(EV + LA + SA) Hit Residuals over {}'.format(label)) 149 | 150 | figs11q = plotResiduals(X1.values, y1 * 100, y11pq * 100, 151 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 152 | yLabels=['Home Run'], yUnits=['%'], 153 | pltParams={'ms': 1}) 154 | for label, fig in zip(xLabels[:-1], figs11q): 155 | fig.gca().set_title('QDA(EV + LA) HR Classifier') 156 | fig.savefig('QDA(EV + LA) HR Residuals over {}'.format(label)) 157 | figs21q = plotResiduals(X2.values, y1 * 100, y21pq * 100, 158 | xLabels=fancyLabels, xUnits=units, 159 | yLabels=['Home Run'], yUnits=['%'], 160 | pltParams={'ms': 1}) 161 | for label, fig in zip(xLabels, figs21q): 162 | fig.gca().set_title('QDA(EV + LA + SA) HR Classifier') 163 | fig.savefig('QDA(EV + LA + SA) HR Residuals over {}'.format(label)) 164 | 165 | figs12q = plotResiduals(X1.values, Y2 * 100, y12pq * 100, 166 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 167 | yLabels=y2Labels, yUnits=['%'] * Y2.shape[1], 168 | pltParams={'ms': 1}) 169 | for label, fig in zip(xLabels[:-1], figs12q): 170 | fig.get_axes()[0].set_title('QDA(EV + LA) Hit Classifier') 171 | fig.savefig('QDA(EV + LA) Hit Residuals over {}'.format(label)) 172 | figs22q = plotResiduals(X2.values, Y2 * 100, y22pq * 100, 173 | xLabels=fancyLabels, xUnits=units, 174 | yLabels=y2Labels, yUnits=['%'] * Y2.shape[1], 175 | pltParams={'ms': 1}) 176 | for label, fig in zip(xLabels, figs22q): 177 | fig.get_axes()[0].set_title('QDA(EV + LA + SA) Hit Classifier') 178 | fig.savefig('QDA(EV + LA + SA) Hit Residuals over {}'.format(label)) 179 | -------------------------------------------------------------------------------- /Scripts/statsapis.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | %reset -f 4 | 5 | #%% 6 | 7 | import requests 8 | import io 9 | import json 10 | import pandas as pd 11 | import sqlalchemy as sa 12 | 13 | baseUrl = 'http://statsapi.mlb.com/api/v1/game/{!s}/feed/color' 14 | sqlFlavor = 'sqlite' 15 | tblName = 'raw' 16 | dbName = 'statsapiv1.db' 17 | 18 | engine = sa.create_engine(sqlFlavor + ':///' + dbName) 19 | 20 | game_pk = 487637 21 | 22 | r = requests.get(baseUrl.format(game_pk)) 23 | data = json.load(io.StringIO(r.text)) 24 | 25 | # %% 26 | 27 | import requests 28 | 29 | 30 | baseUrl = 'http://statsapi.mlb.com/api/v1/game/{!s}/feed/color' 31 | game_pk = 487637 32 | 33 | r = requests.get(baseUrl.format(game_pk)) 34 | data = r.json() 35 | items = data['items'] 36 | 37 | 38 | #%% 39 | 40 | try: 41 | items = data['items'] 42 | except KeyError as e: 43 | print('Not a valid game_pk') 44 | raise e 45 | 46 | items.reverse() 47 | game = pd.DataFrame() 48 | 49 | for i in items: 50 | if i['group'] == 'playByPlay': 51 | i.update(i['data']) 52 | del i['data'] 53 | game = game.append(i, ignore_index=True) 54 | 55 | #%% 56 | 57 | %reset -f 58 | 59 | import requests 60 | import json 61 | import io 62 | 63 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_09/day_03/gid_2008_09_03_anamlb_detmlb_1/game_events.json' 64 | url = 'http://gd2.mlb.com/components/game/mlb/year_2016/month_09/day_03/gid_2016_09_03_anamlb_seamlb_1/game_events.json' 65 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_09/day_03/gid_2008_09_03_anamlb_detmlb_1/plays.json' 66 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2016/month_09/day_03/gid_2016_09_03_anamlb_seamlb_1/plays.json' 67 | 68 | r = requests.get(url) 69 | data = json.load(io.StringIO(r.text)) 70 | 71 | for (key,val) in data['data']['game'].items(): 72 | print('Key: {}, Value: {}'.format(key,val)) 73 | print('------------------------') 74 | 75 | #%% 76 | 77 | %reset -f 78 | 79 | import requests 80 | import io 81 | import xml.etree.ElementTree as ET 82 | 83 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2007/month_09/day_03/gid_2007_09_03_oakmlb_anamlb_1/plays.xml' 84 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2016/month_03/day_31/gid_2016_03_31_oakmlb_sfnmlb_1/plays.xml' 85 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_09/day_03/gid_2008_09_03_anamlb_detmlb_1/eventLog.xml' 86 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_09/day_03/gid_2008_09_03_anamlb_detmlb_1/game_events.xml' 87 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2016/month_09/day_03/gid_2016_09_03_anamlb_seamlb_1/game_events.xml' 88 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_02/day_26/master_scoreboard.xml' 89 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_02/day_26/gid_2008_02_29_balmlb_flomlb_1/game_events.xml' 90 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_02/day_29/master_scoreboard.xml' 91 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_02/day_28/gid_2008_02_28_bocbbc_bosmlb_1/game_events.xml' 92 | #url = 'http://gd2.mlb.com/components/game/mlb/year_2016/month_09/day_03/master_scoreboard.xml' 93 | url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_03/day_01/master_scoreboard.xml' 94 | 95 | r = requests.get(url) 96 | tree = ET.parse(io.StringIO(r.text)) 97 | root = tree.getroot() 98 | 99 | def printxml(root, space='', recurse=True, printAttrs=True): 100 | print(space + 'Tag: {}'.format(root.tag)) 101 | if printAttrs: 102 | print(space + 'Attributes: {}'.format(root.attrib)) 103 | if len(root) > 0 and recurse: 104 | print(space + 'Children:') 105 | if recurse is not True: 106 | recurse -= 1 107 | for child in root: 108 | printxml(child, space + ' ', recurse, printAttrs) 109 | print(space + '--------------') 110 | 111 | printxml(root) 112 | 113 | #%% 114 | 115 | %reset -f 116 | 117 | import requests 118 | import io 119 | 120 | url = 'http://gd2.mlb.com/components/game/mlb/year_2008/month_09/day_03/' 121 | 122 | r = requests.get(url) 123 | 124 | import html 125 | 126 | class MyHTMLParser(html.parser.HTMLParser): 127 | 128 | links = [] 129 | 130 | def handle_starttag(self, tag, attrs): 131 | # Only parse the 'anchor' tag. 132 | if tag == "a": 133 | # Check the list of defined attributes. 134 | for name, value in attrs: 135 | # If href is defined, print it. 136 | if name == "href": 137 | print(name + ': ' + value) 138 | 139 | parser = MyHTMLParser() 140 | parser.feed(r.text) -------------------------------------------------------------------------------- /Scripts/untitled2.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | from sklearn.model_selection import StratifiedKFold 4 | from sklearn.metrics import log_loss 5 | from sklearn.base import clone 6 | from sklearn.preprocessing import LabelBinarizer 7 | 8 | from statcast.bip import Bip 9 | from statcast.tools.plot import plotPrecRec, plotPrecRecMN, plotResiduals 10 | from statcast.better.declassifier import KDEClassifier 11 | 12 | 13 | # %% 14 | 15 | bip = Bip(years=(2016,), n_jobs=-1) 16 | 17 | # %% 18 | 19 | xLabels = ['hit_speed', 'hit_angle', 'sprayAngle'] 20 | fancyLabels = ['Exit Velocity', 'Launch Angle', 'Spray Angle'] 21 | units = ['mph', 'degrees', 'degrees'] 22 | yLabel = 'events' 23 | 24 | subData = bip.data.loc[~bip.data['exclude'], xLabels + [yLabel]] 25 | 26 | outs = ['Bunt Groundout', 'Double Play', 'Fielders Choice', 27 | 'Fielders Choice Out', 'Flyout', 'Forceout', 'Grounded Into DP', 28 | 'Groundout', 'Lineout', 'Pop Out', 'Runner Out', 'Sac Bunt', 29 | 'Sac Fly', 'Sac Fly DP', 'Triple Play', 'Bunt Pop Out', 'Bunt Lineout', 30 | 'Sacrifice Bunt DP'] 31 | 32 | subData['events'] = subData['events'].cat.add_categories(['Out']) 33 | 34 | for out in outs: 35 | subData.loc[subData['events'] == out, 'events'] = 'Out' 36 | 37 | subData['events'] = subData['events'].cat.remove_unused_categories() 38 | 39 | X1 = subData.loc[:, xLabels[:-1]].values 40 | X2 = subData.loc[:, xLabels].values 41 | y1 = (subData[yLabel] == 'Home Run').values 42 | y2 = subData[yLabel].values 43 | 44 | skf = StratifiedKFold(n_splits=10, shuffle=True) 45 | test, train = next(skf.split(X1, y1)) 46 | 47 | kdc = KDEClassifier(kdeParams=dict(kernel='gaussian'), 48 | n_jobs=-1) 49 | est11 = clone(kdc).fit(X1[train], y1[train]) 50 | est12 = clone(kdc).fit(X1[train], y2[train]) 51 | est21 = clone(kdc).fit(X2[train], y1[train]) 52 | est22 = clone(kdc).fit(X2[train], y2[train]) 53 | 54 | y11p = est11.predict_proba(X1[test]) 55 | y12p = est12.predict_proba(X1[test]) 56 | y21p = est21.predict_proba(X2[test]) 57 | y22p = est22.predict_proba(X2[test]) 58 | 59 | y11p = y11p[:, 1] 60 | y21p = y21p[:, 1] 61 | 62 | # %% Log-loss 63 | 64 | logL11 = log_loss(y1[test], y11p) 65 | logL21 = log_loss(y1[test], y21p) 66 | logL12 = log_loss(y2[test], y12p) 67 | logL22 = log_loss(y2[test], y22p) 68 | 69 | # %% Plot Precision-Recall Curve 70 | 71 | fig = plotPrecRec(y1[test], y11p, label='EV + LA: LL={:.2f}'.format(logL11)) 72 | ax = fig.gca() 73 | plotPrecRec(y1[test], y21p, ax=ax, label='EV + LA + SA: LL={:.2f}'.format(logL21)) 74 | ax.legend() 75 | ax.set_title('KDC Homerun Classifier') 76 | #fig.savefig('KDC HR Prec-Rec Curve') 77 | 78 | fig = plotPrecRecMN(y2[test], y12p) 79 | fig.gca().set_title('EV + LA: LL={:.2f}'.format(logL12)) 80 | #fig.savefig('KDC(EV + LA) Hit Prec-Rec Curves') 81 | fig = plotPrecRecMN(y2[test], y22p) 82 | fig.gca().set_title('EV + LA + SA: LL={:.2f}'.format(logL22)) 83 | #fig.savefig('KDC(EV + LA + SA) Hit Prec-Rec Curves') 84 | 85 | # %% Plot Residuals 86 | 87 | figs11 = plotResiduals(X1[test], y1[test] * 100, y11p * 100, 88 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 89 | yLabels=['Home Run'], yUnits=['%'], pltParams={'ms': 1}) 90 | for label, fig in zip(xLabels[:-1], figs11): 91 | fig.gca().set_title('HR(EV + LA) Classifier') 92 | # fig.savefig('KDC(EV + LA) HR Residuals over {}'.format(label)) 93 | figs21 = plotResiduals(X2[test], y1[test] * 100, y21p * 100, 94 | xLabels=fancyLabels, xUnits=units, 95 | yLabels=['Home Run'], yUnits=['%'], pltParams={'ms': 1}) 96 | for label, fig in zip(xLabels, figs21): 97 | fig.gca().set_title('HR(EV + LA + SA) Classifier') 98 | # fig.savefig('KDC(EV + LA + SA) HR Residuals over {}'.format(label)) 99 | 100 | Y2 = LabelBinarizer().fit_transform(y2) 101 | 102 | figs12 = plotResiduals(X1[test], Y2[test] * 100, y12p * 100, 103 | xLabels=fancyLabels[:-1], xUnits=units[:-1], 104 | yLabels=est12.classes_, yUnits=['%'] * Y2.shape[1], 105 | pltParams={'ms': 1}) 106 | for label, fig in zip(xLabels[:-1], figs12): 107 | fig.get_axes()[0].set_title('Hit(EV + LA) Classifier') 108 | # fig.savefig('KDC(EV + LA) Hit Residuals over {}'.format(label)) 109 | figs22 = plotResiduals(X2[test], Y2[test] * 100, y22p * 100, 110 | xLabels=fancyLabels, xUnits=units, 111 | yLabels=est22.classes_, yUnits=['%'] * Y2.shape[1], 112 | pltParams={'ms': 1}) 113 | for label, fig in zip(xLabels, figs22): 114 | fig.get_axes()[0].set_title('Hit(EV + LA + SA) Classifier') 115 | # fig.savefig('KDC(EV + LA + SA) Hit Residuals over {}'.format(label)) 116 | -------------------------------------------------------------------------------- /Scripts/updateBip.py: -------------------------------------------------------------------------------- 1 | # %% Imports 2 | 3 | import os 4 | import datetime 5 | 6 | import requests 7 | from pyspark import SparkContext 8 | 9 | from statcast.bip import Bip 10 | 11 | 12 | # %% Create Spark Context 13 | 14 | sc = SparkContext(appName="updateBip") 15 | 16 | # %% Load data, plot histograms of statcast data 17 | 18 | years = (2015, 2016) 19 | 20 | for year in years: 21 | bip = Bip(years=(year,), n_jobs=sc) 22 | 23 | # %% Transfer results to S3 24 | 25 | instanceID = requests. \ 26 | get('http://169.254.169.254/latest/meta-data/instance-id').text 27 | dtStr = datetime.datetime.utcnow().strftime('%Y-%m-%d--%H-%M-%S') 28 | os.system('aws s3 sync . s3://mf-first-bucket/output/{}/{}'. 29 | format(instanceID, dtStr)) 30 | 31 | # %% Stop Spark Context 32 | 33 | sc.stop() 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='statcast', 5 | version='1.0.0', 6 | description='A baseball Python project', 7 | long_description='Collecting, storing, manipulating, and visualizing ' 8 | 'baseball data, mostly from statcast.', 9 | url='https://github.com/matteosox/statcast', 10 | author='Matt Fay', 11 | author_email='matt.e.fay@gmail.com', 12 | classifiers=['Programming Language :: Python :: 3.5', 13 | 'Development Status :: 2 - Pre-Alpha', 14 | 'Natural Language :: English'], 15 | keywords='baseball statcast mlb sabermetrics', 16 | packages=find_packages(), 17 | package_data={'statcast': ['data/*.*', 'data/logos/*.png']}) 18 | -------------------------------------------------------------------------------- /statcast/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/__init__.py -------------------------------------------------------------------------------- /statcast/better/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/better/__init__.py -------------------------------------------------------------------------------- /statcast/better/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from inspect import signature, Signature, Parameter 3 | 4 | from sklearn.externals import joblib 5 | 6 | from ..tools.fixpath import findFile 7 | 8 | addParams = [Parameter('name', Parameter.POSITIONAL_OR_KEYWORD, default=''), 9 | Parameter('xLabels', Parameter.POSITIONAL_OR_KEYWORD, default=()), 10 | Parameter('yLabels', Parameter.POSITIONAL_OR_KEYWORD, default=())] 11 | xMethods = ['predict', 'predict_proba', 'transform', 'decision_function', 12 | 'predict_log_proba', 'score_samples'] 13 | xyMethods = ['fit', 'fit_predict', 'fit_transform', 'score', 'partial_fit'] 14 | fitMethods = ['fit', 'fit_predict', 'fit_transform', 'partial_fit'] 15 | 16 | 17 | def addXMethod(clsobj, method): 18 | '''Doc String''' 19 | 20 | if hasattr(clsobj, method) and not hasattr(clsobj, method + 'D'): 21 | methodSig = signature(getattr(clsobj, method)) 22 | methodParams = [p for p in methodSig.parameters.values()] 23 | methodDParams = methodParams.copy() 24 | if methodDParams[0].name != 'self': 25 | methodDParams.insert(0, 26 | Parameter('self', 27 | Parameter.POSITIONAL_ONLY)) 28 | if methodDParams[1].kind is Parameter.VAR_POSITIONAL: 29 | methodDParams.insert(1, Parameter('data', 30 | Parameter. 31 | POSITIONAL_OR_KEYWORD)) 32 | else: 33 | methodDParams[1] = \ 34 | Parameter('data', Parameter.POSITIONAL_OR_KEYWORD) 35 | 36 | methodDSig = Signature(methodDParams) 37 | 38 | def methodD(self, data, *args, **kwargs): 39 | X = self.createX(data) 40 | return getattr(self, method)(X, *args, **kwargs) 41 | methodD.__signature__ = methodDSig 42 | setattr(clsobj, method + 'D', methodD) 43 | 44 | 45 | def addXYMethod(clsobj, method): 46 | '''Doc String''' 47 | 48 | if hasattr(clsobj, method) and not hasattr(clsobj, method + 'D'): 49 | methodSig = signature(getattr(clsobj, method)) 50 | methodParams = [p for p in methodSig.parameters.values()] 51 | methodDParams = methodParams.copy() 52 | if methodDParams[0].name != 'self': 53 | methodDParams.insert(0, 54 | Parameter('self', 55 | Parameter.POSITIONAL_ONLY)) 56 | if methodDParams[1].kind is Parameter.VAR_POSITIONAL: 57 | methodDParams.insert(1, Parameter('data', 58 | Parameter. 59 | POSITIONAL_OR_KEYWORD)) 60 | else: 61 | methodDParams[1] = \ 62 | Parameter('data', Parameter.POSITIONAL_OR_KEYWORD) 63 | if methodDParams[2] is Parameter.VAR_POSITIONAL: 64 | pass 65 | else: 66 | del methodDParams[2] 67 | methodDSig = Signature(methodDParams) 68 | 69 | def methodD(self, data, *args, **kwargs): 70 | X = self.createX(data) 71 | Y = self.createY(data) 72 | return getattr(self, method)(X, Y, *args, **kwargs) 73 | methodD.__signature__ = methodDSig 74 | setattr(clsobj, method + 'D', methodD) 75 | 76 | 77 | def addXMethods(clsobj): 78 | '''Doc String''' 79 | 80 | for method in xMethods: 81 | addXMethod(clsobj, method) 82 | 83 | 84 | def addXYMethods(clsobj): 85 | '''Doc String''' 86 | 87 | for method in xyMethods: 88 | addXYMethod(clsobj, method) 89 | 90 | 91 | class BetterMetaClass(abc.ABCMeta): 92 | ''' Doc String''' 93 | 94 | def __new__(cls, clsname, bases, clsdict): 95 | '''Doc String''' 96 | 97 | clsobj = super().__new__(cls, clsname, bases, clsdict) 98 | oldInit = clsobj.__init__ 99 | oldParams = [param for param in signature(oldInit).parameters.values()] 100 | 101 | firstIn = oldParams.pop(0) 102 | if firstIn.name != 'self': 103 | raise RuntimeError('BetterModels init must have self first ' 104 | 'argument') 105 | if oldInit is object.__init__: 106 | oldParams = [] 107 | if addParams == oldParams[:len(addParams)]: 108 | newParams = [] 109 | else: 110 | newParams = addParams 111 | if any(nP == oP for nP in newParams for oP in oldParams): 112 | raise RuntimeError('BetterModels cannot define {} as inputs to ' 113 | 'init'.format(', '.join([nP.name 114 | for nP in newParams]))) 115 | try: 116 | if any(p.kind != Parameter.POSITIONAL_OR_KEYWORD 117 | for p in clsobj._params): 118 | raise RuntimeError('BetterModels can only add positional or ' 119 | 'keyword inputs to init from _params') 120 | elif any(p.default is Parameter.empty 121 | for p in clsobj._params): 122 | raise RuntimeError('BetterModels init inputs must have ' 123 | 'defaults') 124 | except AttributeError: 125 | raise RuntimeError('BetterModels _params class attribute must be ' 126 | 'a list of Parameters from the inspect module') 127 | 128 | customParams = [param for param in clsobj._params 129 | if param not in newParams] 130 | newSig = Signature([firstIn] + newParams + customParams + oldParams) 131 | 132 | def newInit(self, *args, **kwargs): 133 | '''Doc String''' 134 | 135 | bound = newSig.bind(self, *args, **kwargs) 136 | bound.apply_defaults() 137 | bound.arguments.popitem(False) 138 | for dummy in range(len(newParams)): 139 | name, val = bound.arguments.popitem(False) 140 | setattr(self, name, val) 141 | for dummy in range(len(customParams)): 142 | name, val = bound.arguments.popitem(False) 143 | setattr(self, name, val) 144 | 145 | oldInit(self, *bound.args, **bound.kwargs) 146 | 147 | newInit.__signature__ = newSig 148 | setattr(clsobj, '__init__', newInit) 149 | 150 | addXMethods(clsobj) 151 | addXYMethods(clsobj) 152 | 153 | return clsobj 154 | 155 | 156 | class BetterModel(metaclass=BetterMetaClass): 157 | '''Doc String''' 158 | 159 | _params = [] 160 | 161 | def createX(self, data): 162 | '''Doc String''' 163 | 164 | return data[self.xLabels] 165 | 166 | def createY(self, data): 167 | '''Doc String''' 168 | 169 | return data[self.yLabels] 170 | 171 | def save(self, path=None): 172 | '''Doc String''' 173 | 174 | if path is None: 175 | path = self.name 176 | joblib.dump(self, path + '.pkl') 177 | 178 | def load(self, name=None, filePath=None, searchDirs=None): 179 | '''Doc String''' 180 | 181 | if filePath is None: 182 | if name is None: 183 | name = self.name 184 | 185 | if searchDirs is None: 186 | filePath = findFile(name + '.pkl') 187 | else: 188 | filePath = findFile(name + '.pkl', searchDirs) 189 | 190 | return joblib.load(filePath) 191 | -------------------------------------------------------------------------------- /statcast/better/declassifier.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import warnings 3 | from inspect import Parameter 4 | 5 | import numpy as np 6 | 7 | from sklearn.base import BaseEstimator, ClassifierMixin 8 | from sklearn.utils import check_array, check_X_y 9 | from sklearn.utils.multiclass import check_classification_targets 10 | from sklearn.utils.fixes import bincount 11 | from sklearn.utils.validation import check_is_fitted 12 | 13 | from .base import BetterModel 14 | from .kde import BetterKernelDensity 15 | 16 | 17 | class DensityEstimationClassifier(BaseEstimator, ClassifierMixin, 18 | metaclass=abc.ABCMeta): 19 | 20 | _params = [Parameter('priors', Parameter.POSITIONAL_OR_KEYWORD, 21 | default=None)] 22 | 23 | @abc.abstractmethod 24 | def fit(self, X, y): 25 | '''Doc String''' 26 | 27 | pass 28 | 29 | def _prefit(self, X, y): 30 | '''Doc String''' 31 | 32 | X, y = check_X_y(X, y) 33 | check_classification_targets(y) 34 | self.classes_, y = np.unique(y, return_inverse=True) 35 | n_samples, n_features = X.shape 36 | n_classes = len(self.classes_) 37 | if n_classes < 2: 38 | raise ValueError('y has less than 2 classes') 39 | if self.priors is None: 40 | self.priors_ = bincount(y) / float(n_samples) 41 | else: 42 | self.priors_ = self.priors 43 | 44 | if (self.priors_ < 0).any(): 45 | raise ValueError("priors must be non-negative") 46 | if self.priors_.sum() != 1: 47 | warnings.warn("The priors do not sum to 1. Renormalizing", 48 | UserWarning) 49 | self.priors_ = self.priors_ / self.priors_.sum() 50 | 51 | return X, y 52 | 53 | @abc.abstractmethod 54 | def _estimateDensities(self, X): 55 | '''Doc String''' 56 | 57 | pass 58 | 59 | def predict_proba(self, X): 60 | '''Doc String''' 61 | 62 | X = check_array(X) 63 | check_is_fitted(self, ['priors_', 'classes_']) 64 | F = self._estimateDensities(X) 65 | num = F * self.priors_ 66 | den = num.sum(1, keepdims=True) 67 | den0 = den == 0 68 | den[den0] = 1 69 | probs = num / den 70 | probs[den0.flatten()] = self.priors_ 71 | return probs 72 | 73 | def predict(self, X): 74 | '''Doc String''' 75 | 76 | probs = self.predict_proba(X) 77 | y = self.classes_.take(probs.argmax(1)) 78 | return y 79 | 80 | 81 | class KDEClassifier(DensityEstimationClassifier, BetterModel): 82 | 83 | _params = DensityEstimationClassifier._params 84 | _params.extend([Parameter('kdeParams', Parameter.POSITIONAL_OR_KEYWORD, 85 | default={}), 86 | Parameter('cv', Parameter.POSITIONAL_OR_KEYWORD, 87 | default=None), 88 | Parameter('n_jobs', Parameter.POSITIONAL_OR_KEYWORD, 89 | default=1)]) 90 | 91 | def fit(self, X, y): 92 | '''Doc String''' 93 | 94 | X, y = self._prefit(X, y) 95 | nClasses = len(self.classes_) 96 | self.kdes_ = [] 97 | for i in range(nClasses): 98 | kde = BetterKernelDensity(**self.kdeParams) 99 | kde.fit(X[y == i, :]) 100 | kde.selectBandwidth(n_jobs=self.n_jobs, cv=self.cv) 101 | self.kdes_.append(kde) 102 | 103 | return self 104 | 105 | def _estimateDensities(self, X): 106 | '''Doc String''' 107 | 108 | return np.concatenate([kde.predict(X)[:, None] for kde in self.kdes_], 109 | axis=1) 110 | -------------------------------------------------------------------------------- /statcast/better/kde.py: -------------------------------------------------------------------------------- 1 | from inspect import Parameter 2 | 3 | import numpy as np 4 | 5 | from sklearn.neighbors.kde import KernelDensity 6 | from sklearn.utils.validation import check_array, check_is_fitted 7 | from sklearn.model_selection import check_cv 8 | 9 | from scipy import stats 10 | from scipy.special import gamma 11 | from scipy.spatial.distance import pdist 12 | 13 | from .base import BetterModel 14 | from .spark import gridCVScoresAlt 15 | 16 | 17 | def ballVol(r, n): 18 | return np.pi ** (n / 2) * r ** n / gamma(n / 2 + 1) 19 | 20 | 21 | def epanechnikov(X, h): 22 | d = X.shape[1] 23 | x2 = ((X / h) ** 2).sum(1, keepdims=True) 24 | y = (d + 2) * ((1 - x2) * ((1 - x2) > 0)) / (2 * ballVol(1, d)) 25 | return y * h ** -d 26 | 27 | 28 | def tophat(X, h): 29 | d = X.shape[1] 30 | x = np.sqrt(((X / h) ** 2).sum(1, keepdims=True)) 31 | y = (x < 1) / ballVol(1, d) 32 | return y * h ** -d 33 | 34 | 35 | def gaussian(X, h): 36 | d = X.shape[1] 37 | x2 = ((X / h) ** 2).sum(1, keepdims=True) 38 | y = np.exp(-x2 / 2) / (2 * np.pi) ** (d / 2) 39 | return y * h ** -d 40 | 41 | 42 | def exponential(X, h): 43 | d = X.shape[1] 44 | x = np.sqrt(((X / h) ** 2).sum(1, keepdims=True)) 45 | y = np.exp(-x) / ballVol(1, d) / np.math.factorial(d) 46 | return y * h ** -d 47 | 48 | 49 | def linear(X, h): 50 | d = X.shape[1] 51 | if d > 1: 52 | raise NotImplementedError('Linear kernel not implemented for dims > 1') 53 | x = np.sqrt(((X / h) ** 2).sum(1, keepdims=True)) 54 | y = (1 - x) * (x < 1) 55 | return y * h ** -d 56 | 57 | 58 | def cosine(X, h): 59 | d = X.shape[1] 60 | if d > 1: 61 | raise NotImplementedError('Cosine kernel not implemented for dims > 1') 62 | x = np.sqrt(((X / h) ** 2).sum(1, keepdims=True)) 63 | y = np.cos(x * np.pi / 2) * (x < 1) * np.pi / 4 64 | return y * h ** -d 65 | 66 | 67 | kernelFunctions = {'gaussian': gaussian, 68 | 'tophat': tophat, 69 | 'epanechnikov': epanechnikov, 70 | 'exponential': exponential, 71 | 'linear': linear, 72 | 'cosine': cosine} 73 | 74 | 75 | class BetterKernelDensity(KernelDensity, BetterModel): 76 | '''Doc String''' 77 | 78 | _params = [Parameter('normalize', Parameter.POSITIONAL_OR_KEYWORD, 79 | default=True)] 80 | 81 | def _se(self, X): 82 | '''Doc String''' 83 | 84 | trainX = np.array(self.tree_.data) 85 | n = trainX.shape[0] 86 | 87 | Y = np.array([self._kernelFunction(X - row) 88 | for row in trainX]) 89 | s2 = Y.var(axis=0, ddof=1) 90 | return np.sqrt(s2 / n).flatten() 91 | 92 | def fit(self, X, y=None): 93 | '''Doc String''' 94 | 95 | if self.normalize: 96 | X = check_array(X) 97 | U, s, V = np.linalg.svd(X, full_matrices=False) 98 | self.invH_ = np.diag(np.sqrt(X.shape[0]) / s).dot(V) 99 | self.detH_ = 1 / np.prod(np.sqrt(X.shape[0]) / s) 100 | X = X.dot(self.invH_.T) 101 | 102 | return super().fit(X, y) 103 | 104 | def predict(self, X): 105 | ''' Doc String''' 106 | 107 | return np.exp(self.score_samples(X)) 108 | 109 | def score_samples(self, X): 110 | '''Doc String''' 111 | 112 | if self.normalize: 113 | X = check_array(X) 114 | X = X.dot(self.invH_.T) 115 | return super().score_samples(X) - np.log(self.detH_) 116 | 117 | return super().score_samples(X) 118 | 119 | def confidence(self, X, alpha=0.05): 120 | '''Doc String''' 121 | 122 | check_is_fitted(self, ['tree_']) 123 | trainX = np.array(self.tree_.data) 124 | X = check_array(X) 125 | a, b = trainX.min(0), trainX.max(0) 126 | if self.kernel == 'gaussian': 127 | w = 6 128 | elif self.kernel == 'exponential': 129 | w = 2 * stats.expon.ppf(2 * (stats.norm.cdf(3) - 0.5)) 130 | else: 131 | w = 2 132 | m = np.prod(b - a) / (w * self.bandwidth) ** X.shape[1] 133 | if self.normalize: 134 | m *= self.detH_ 135 | q = stats.norm.ppf((1 + (1 - alpha) ** (1 / m)) / 2) 136 | 137 | f = self.predict(X) 138 | se = self._se(X) 139 | return f - q * se, f + q * se 140 | 141 | def confidenceD(self, data, alpha=0.05): 142 | '''Doc String''' 143 | 144 | X = self.createX(data) 145 | return self.confidence(X, alpha) 146 | 147 | def _kernelFunction(self, X): 148 | '''Doc String''' 149 | 150 | if self.normalize: 151 | X = check_array(X) 152 | X = X.dot(self.invH_.T) 153 | scale = self.detH_ 154 | else: 155 | scale = 1 156 | return kernelFunctions[self.kernel](X, self.bandwidth) / scale 157 | 158 | def selectBandwidth(self, bandwidths=None, n_jobs=1, cv=None): 159 | '''Doc String''' 160 | 161 | check_is_fitted(self, ['tree_']) 162 | trainX = np.array(self.tree_.data) 163 | 164 | nSplits = check_cv(cv).get_n_splits() 165 | 166 | if trainX.shape[0] == 1: 167 | self.bandwidth = 1 168 | self.cv_results_ = None 169 | return self 170 | elif trainX.shape[0] < nSplits: 171 | cv = nSplits = trainX.shape[0] 172 | 173 | scale = ((nSplits - 1) / nSplits) ** (-1 / (4 + trainX.shape[1])) 174 | 175 | if bandwidths is None: 176 | if trainX.shape[0] > 1000: 177 | subs = np.random.randint(0, trainX.shape[0], size=(1000,)) 178 | bandMax = pdist(trainX[subs]).mean() 179 | else: 180 | bandMax = pdist(trainX).mean() 181 | nnDists = self.tree_.query(trainX, k=2)[0][:, 1] 182 | if self.kernel in ['gaussian', 'exponential']: 183 | bandMin = nnDists.mean() 184 | else: 185 | bandMin = nnDists.max() * 1.02 186 | bandwidths = np.logspace(np.log10(bandMin), np.log10(bandMax), 187 | num=5) 188 | 189 | parameters = {'bandwidth': bandwidths * scale} 190 | results = gridCVScoresAlt(self, parameters, trainX, 191 | n_jobs=n_jobs, cv=cv) 192 | totalScores = [scores.sum() for scores in results['scores']] 193 | 194 | if not np.isfinite(totalScores).any(): 195 | self.bandwidth = bandMax 196 | else: 197 | bestInd = np.argmax(totalScores) 198 | bestBand = results.loc[bestInd, 'params']['bandwidth'] 199 | self.bandwidth = bestBand / scale 200 | self.cv_results_ = results 201 | return self 202 | -------------------------------------------------------------------------------- /statcast/better/kdr.py: -------------------------------------------------------------------------------- 1 | from inspect import signature 2 | 3 | import numpy as np 4 | 5 | from sklearn.base import BaseEstimator, RegressorMixin 6 | from sklearn.utils.validation import check_array, check_X_y, check_is_fitted 7 | from sklearn.metrics import mean_squared_error 8 | 9 | from scipy import stats 10 | 11 | from .base import BetterModel 12 | from .kde import BetterKernelDensity 13 | 14 | 15 | class BetterKDR(BaseEstimator, RegressorMixin, BetterModel): 16 | '''Doc String''' 17 | 18 | _params = [param for param 19 | in signature(BetterKernelDensity).parameters.values()] 20 | 21 | def __init__(self): 22 | '''Doc String''' 23 | 24 | self.kde = BetterKernelDensity(**self.get_params()) 25 | 26 | def fit(self, X, Y): 27 | '''Doc String''' 28 | 29 | self._flowParams() 30 | check_X_y(X, Y, multi_output=True, dtype=None) 31 | self.kde.fit(X) 32 | self.trainY_ = Y.copy() 33 | return self 34 | 35 | def _flowParams(self, up=False): 36 | '''Doc String''' 37 | 38 | if up: 39 | self.set_params(**self.kde.get_params()) 40 | else: 41 | self.kde.set_params(**self.get_params()) 42 | 43 | def _weights(self, X): 44 | '''Doc String''' 45 | 46 | self._flowParams() 47 | trainX = np.array(self.kde.tree_.data) 48 | num = np.hstack([self.kde._kernelFunction(np.tile( 49 | row, (trainX.shape[0], 1)) - trainX) for row in X]).T 50 | den = np.tile(self.kde.predict(X)[:, None] * 51 | trainX.shape[0], (1, trainX.shape[0])) 52 | den[den == 0] = 1 53 | W = num / den 54 | W[(W == 0).all(1), :] = 1 / trainX.shape[0] 55 | return W 56 | 57 | def predict(self, X): 58 | '''Doc String''' 59 | 60 | check_is_fitted(self, ['trainY_']) 61 | check_array(X, dtype=None) 62 | W = self._weights(X) 63 | return W.dot(self.trainY_) 64 | 65 | def score(self, X, Y, sample_weight=None): 66 | '''Doc String''' 67 | 68 | X, Y = check_X_y(X, Y, multi_output=True, dtype=None) 69 | Yp = self.predict(X) 70 | return -np.sqrt([mean_squared_error(y, yp, sample_weight) 71 | for y, yp in zip(Y.T, Yp.T)]).mean() 72 | 73 | def risk(self): 74 | '''Doc String''' 75 | 76 | check_is_fitted(self, ['trainY_']) 77 | self._flowParams() 78 | trainX = np.array(self.kde.tree_.data) 79 | k0 = self.kde._kernelFunction(np.zeros((1, trainX.shape[1]))) 80 | den = np.tile(1 - k0 / (self.kde.predict(trainX)[:, None] * 81 | trainX.shape[0]), 82 | (1, self.trainY_.shape[1])) 83 | num = self.trainY_ - self.predict(trainX) 84 | risk = ((num / den) ** 2).mean() 85 | if np.isnan(risk): 86 | risk = np.inf 87 | return risk 88 | 89 | def confidence(self, X, alpha=0.05): 90 | '''Doc String''' 91 | 92 | check_is_fitted(self, ['trainY_']) 93 | trainX = np.array(self.kde.tree_.data) 94 | X = check_array(X) 95 | a, b = trainX.min(0), trainX.max(0) 96 | if self.kernel == 'gaussian': 97 | w = 6 98 | elif self.kernel == 'exponential': 99 | w = 2 * stats.expon.ppf(2 * (stats.norm.cdf(3) - 0.5)) 100 | else: 101 | w = 2 102 | m = np.prod(b - a) / (w * self.bandwidth) ** X.shape[1] 103 | if self.normalize: 104 | m *= self.kde.detH_ 105 | q = stats.norm.ppf((1 + (1 - alpha) ** (1 / m)) / 2) 106 | 107 | f = self.predict(X) 108 | se = np.sqrt((self._weights(X) ** 2).sum(1, keepdims=True) * 109 | self.risk()) 110 | return f - q * se, f + q * se 111 | 112 | def confidenceD(self, data, alpha=0.05): 113 | '''Doc String''' 114 | 115 | X = self.createX(data) 116 | return self.confidence(X, alpha) 117 | 118 | def selectBandwidth(self, bandwidths=None, n_jobs=1, cv=None): 119 | '''Doc String''' 120 | 121 | self._flowParams() 122 | self.kde.selectBandwidth(bandwidths, n_jobs, cv) 123 | self._flowParams(up=True) 124 | return self 125 | -------------------------------------------------------------------------------- /statcast/better/mixed.py: -------------------------------------------------------------------------------- 1 | from inspect import Parameter 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from sklearn.base import BaseEstimator, RegressorMixin 7 | from sklearn.utils.validation import check_X_y, check_array, check_is_fitted 8 | from sklearn.metrics import mean_squared_error 9 | 10 | from rpy2.robjects.packages import importr 11 | from rpy2.robjects import pandas2ri 12 | 13 | from .base import BetterModel 14 | from .spark import GridSearchCV 15 | 16 | pandas2ri.activate() 17 | rLME4 = importr('lme4') 18 | 19 | 20 | class BetterLME4(BaseEstimator, RegressorMixin, BetterModel): 21 | 22 | _params = [Parameter('formulas', Parameter.POSITIONAL_OR_KEYWORD, 23 | default=()), 24 | Parameter('LME4Params', Parameter.POSITIONAL_OR_KEYWORD, 25 | default={})] 26 | 27 | def fit(self, X, Y): 28 | '''Doc String''' 29 | 30 | check_X_y(X, Y, multi_output=True, dtype=None) 31 | data = pd.concat((pd.DataFrame(X, columns=self.xLabels), 32 | pd.DataFrame(Y, columns=self.yLabels)), axis=1) 33 | return self.fitD(data) 34 | 35 | def fitD(self, data): 36 | '''Doc String''' 37 | 38 | if not any((self.xLabels, self.yLabels, self.formulas)): 39 | raise RuntimeError('betterLME must have xLabels, yLabels, and ' 40 | 'formulas defined') 41 | if isinstance(self.formulas, str): 42 | self.formulas = (self.formulas,) 43 | if len(self.formulas) != len(self.yLabels): 44 | if len(self.formulas) == 1: 45 | self.formulas = self.formulas * len(self.yLabels) 46 | else: 47 | raise RuntimeError('formulas must be a single string, or a ' 48 | 'tuple the same length as yLabels') 49 | X, Y = self.createX(data), self.createY(data) 50 | check_X_y(X, Y, multi_output=True, dtype=None) 51 | self.models_ = {} 52 | self.factors_ = {} 53 | for ii, yLabel in enumerate(self.yLabels): 54 | subData = pd.concat((X, Y[yLabel]), axis=1) 55 | formula = yLabel + ' ~ ' + self.formulas[ii] 56 | model = rLME4.lmer(formula=formula, 57 | data=subData, 58 | **self.LME4Params) 59 | self.models_[yLabel] = model 60 | self.factors_[yLabel] = self._factor(model) 61 | 62 | return self 63 | 64 | def createX(self, data): 65 | '''Doc String''' 66 | 67 | X = data[self.xLabels].copy() 68 | for xLabel in self.xLabels: 69 | if X[xLabel].dtype is pd.api.types.CategoricalDtype(): 70 | X[xLabel] = X[xLabel].astype(X[xLabel].cat.categories.dtype) 71 | 72 | return X 73 | 74 | @staticmethod 75 | def _factor(model): 76 | '''Doc String''' 77 | 78 | rEffs = rLME4.random_effects(model) 79 | fEffs = rLME4.fixed_effects(model) 80 | factor = {} 81 | 82 | for elem, name in zip(rEffs, rEffs.names): 83 | factor[name] = pandas2ri.ri2py(elem) 84 | 85 | for elem, name in zip(fEffs, fEffs.names): 86 | factor[name] = elem 87 | 88 | return factor 89 | 90 | def predictD(self, data): 91 | '''Doc String''' 92 | 93 | check_is_fitted(self, ['models_', 'factors_']) 94 | 95 | X = self.createX(data) 96 | check_array(X, dtype=None) 97 | 98 | Y = pd.DataFrame() 99 | for yLabel in self.yLabels: 100 | Y[yLabel] = \ 101 | pandas2ri.ri2py(rLME4.predict_merMod(self.models_[yLabel], 102 | newdata=X, 103 | allow_new_levels=True)) 104 | 105 | return Y 106 | 107 | def predict(self, X): 108 | '''Doc String''' 109 | 110 | check_array(X, dtype=None) 111 | return self.predictD(pd.DataFrame(X, columns=self.xLabels)) 112 | 113 | def score(self, X, Y, sample_weight=None): 114 | '''Doc String''' 115 | 116 | Y = check_array(Y) 117 | Yp = self.predict(X) 118 | return -np.sqrt([mean_squared_error(y, yp, sample_weight) 119 | for y, yp in zip(Y.T, Yp.values.T)]).mean() 120 | 121 | def chooseFormula(self, data, formulas, n_jobs=1, cv=None, refit=True): 122 | '''Doc String''' 123 | 124 | self.cv_results_ = [] 125 | yLabels = self.yLabels 126 | formulaChoices = [] 127 | param_grid = {'formulas': formulas} 128 | for yLabel in yLabels: 129 | self.yLabels = [yLabel] 130 | 131 | result = GridSearchCV(self, param_grid, 132 | n_jobs=n_jobs, cv=cv, refit=False). \ 133 | fit(self.createX(data), self.createY(data)) 134 | formulaChoices.append(result.best_params_['formulas']) 135 | self.cv_results_.append(result.cv_results_) 136 | 137 | self.formulas = formulaChoices 138 | self.yLabels = yLabels 139 | 140 | if refit: 141 | self.fitD(data) 142 | 143 | return self 144 | -------------------------------------------------------------------------------- /statcast/better/randomforest.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from inspect import Parameter 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from sklearn.ensemble import RandomForestRegressor 8 | from sklearn.utils.validation import check_array 9 | from sklearn.metrics import mean_squared_error 10 | 11 | from .base import BetterModel 12 | 13 | 14 | class BetterRandomForestRegressor(RandomForestRegressor, BetterModel): 15 | '''Doc String''' 16 | 17 | @property 18 | def feature_importances_(self): 19 | '''Doc String''' 20 | 21 | if not self.xLabels: 22 | return super().feature_importances_ 23 | ftImps = pd.Series() 24 | ftImpsComplete = pd.Series(super().feature_importances_, 25 | index=self.colNames_). \ 26 | sort_values(ascending=False) 27 | 28 | for xLabel in self.xLabels: 29 | ftImps[xLabel] = \ 30 | ftImpsComplete[ftImpsComplete.index.str.startswith(xLabel)]. \ 31 | sum() 32 | 33 | ftImps.sort_values(ascending=False, inplace=True) 34 | return ftImps 35 | 36 | def score(self, X, Y, sample_weight=None): 37 | '''Doc String''' 38 | 39 | Y = check_array(Y) 40 | Yp = self.predict(X) 41 | return -np.sqrt([mean_squared_error(y, yp, sample_weight) 42 | for y, yp in zip(Y.T, Yp.T)]).mean() 43 | 44 | def createX(self, data): 45 | '''Doc String''' 46 | 47 | X = pd.DataFrame() 48 | for xLabel in self.xLabels: 49 | if data[xLabel].dtype is pd.api.types.CategoricalDtype(): 50 | for cat in data[xLabel].cat.categories: 51 | X[xLabel + '|' + str(cat)] = data[xLabel] == cat 52 | X[xLabel + '|' + 'null'] = data[xLabel].isnull() 53 | else: 54 | X[xLabel] = data[xLabel] 55 | 56 | return X 57 | 58 | def fit(self, X, y, sample_weight=None): 59 | '''Doc String''' 60 | 61 | if self.xLabels: 62 | try: 63 | self.colNames_ = list(X.columns) 64 | except: 65 | warnings.warn('Feature importances may not work when using a ' 66 | 'better model with labels, but calling fit', 67 | UserWarning) 68 | self.colNames_ = self.xLabels 69 | 70 | return super().fit(X, y, sample_weight) 71 | 72 | def _set_oob_score(self, X, Y): 73 | '''Doc String''' 74 | 75 | super()._set_oob_score(X, Y) 76 | Yp = self.oob_prediction_ 77 | return -np.sqrt([mean_squared_error(y, yp) 78 | for y, yp in zip(Y.T, Yp.T)]).mean() 79 | 80 | 81 | class TreeSelectingRFRegressor(BetterRandomForestRegressor, BetterModel): 82 | '''Doc String''' 83 | 84 | _params = [Parameter('treeThreshold', Parameter.POSITIONAL_OR_KEYWORD, 85 | default=0.99)] 86 | 87 | def fit(self, X, Y, sample_weight=None): 88 | '''Doc String''' 89 | 90 | warnStr = "Some inputs do not have OOB scores. " \ 91 | "This probably means too few trees were used " \ 92 | "to compute any reliable oob estimates." 93 | with warnings.catch_warnings(record=True) as caughtWarnings: 94 | warnings.filterwarnings('ignore', message=warnStr) 95 | self.warm_start = False 96 | self.oob_score = True 97 | self.n_estimators = 10 98 | scores = [super().fit(X, Y, sample_weight).oob_score_] 99 | nTrees = [self.n_estimators] 100 | for caughtWarning in caughtWarnings.copy(): 101 | if caughtWarning.message == warnStr: 102 | oobWarns = [True] 103 | caughtWarnings.remove(caughtWarning) 104 | break 105 | else: 106 | oobWarns = [False] 107 | self.warm_start = True 108 | 109 | while True: 110 | self.n_estimators = \ 111 | np.round(self.n_estimators * 1.2).astype(int) 112 | scores.append(super().fit(X, Y, sample_weight).oob_score_) 113 | for caughtWarning in caughtWarnings.copy(): 114 | if caughtWarning.message == warnStr: 115 | oobWarns.append(True) 116 | caughtWarnings.remove(caughtWarning) 117 | break 118 | else: 119 | oobWarns.append(False) 120 | nTrees.append(self.n_estimators) 121 | if (scores[-2] > (self.treeThreshold * scores[-1])) and \ 122 | (not any(oobWarns[-2:])): 123 | break 124 | 125 | for caughtWarning in caughtWarnings: 126 | warnings.showwarning(caughtWarning.message, caughtWarning.category, 127 | caughtWarning.filename, caughtWarning.lineno) 128 | 129 | self.treeScores_ = pd.Series(scores, index=nTrees) 130 | return self 131 | -------------------------------------------------------------------------------- /statcast/better/sm.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from inspect import Parameter 4 | 5 | import numpy as np 6 | 7 | from statsmodels import api as sm 8 | 9 | from sklearn.base import BaseEstimator 10 | from sklearn.utils.validation import check_X_y, check_array, check_is_fitted 11 | from sklearn.utils.multiclass import check_classification_targets 12 | 13 | from .base import BetterModel 14 | 15 | 16 | class BetterSM(BaseEstimator, metaclass=abc.ABCMeta): 17 | 18 | _params = [Parameter('addConstant', Parameter.POSITIONAL_OR_KEYWORD, 19 | default=True), 20 | Parameter('SMParams', Parameter.POSITIONAL_OR_KEYWORD, 21 | default={})] 22 | 23 | @abc.abstractmethod 24 | def mdlClass(): 25 | pass 26 | 27 | def fit(self, X, y): 28 | '''Doc String''' 29 | 30 | X, y = check_X_y(X, y) 31 | check_classification_targets(y) 32 | self.classes_, y = np.unique(y, return_inverse=True) 33 | n_samples, n_features = X.shape 34 | n_classes = len(self.classes_) 35 | if n_classes < 2: 36 | raise ValueError('y has less than 2 classes') 37 | if self.addConstant: 38 | X = sm.tools.tools.add_constant(X) 39 | self.mdl_ = self.mdlClass(y, X, **self.SMParams) 40 | self.results_ = self.mdl_.fit() 41 | return self 42 | 43 | def predict_proba(self, X): 44 | '''Doc String''' 45 | 46 | check_is_fitted(self, ['mdl_', 'results_']) 47 | check_array(X, dtype=None) 48 | 49 | if self.addConstant: 50 | X = sm.tools.tools.add_constant(X) 51 | 52 | return self.results_.predict(X) 53 | 54 | def predict(self, X): 55 | '''Doc String''' 56 | 57 | probs = self.predict_proba(X) 58 | y = self.classes_.take(probs.argmax(1)) 59 | return y 60 | 61 | 62 | class BetterGLM(BetterSM, BetterModel): 63 | 64 | mdlClass = sm.GLM 65 | 66 | 67 | class BetterMNLogit(BetterSM, BetterModel): 68 | 69 | mdlClass = sm.MNLogit 70 | -------------------------------------------------------------------------------- /statcast/better/spark.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | 3 | from collections import Sized, defaultdict 4 | from functools import partial 5 | 6 | import pandas as pd 7 | import numpy as np 8 | import scipy.sparse as sp 9 | from sklearn import model_selection 10 | from sklearn.model_selection import ParameterGrid, check_cv 11 | from sklearn.model_selection._validation import _fit_and_score, \ 12 | _fit_and_predict, _check_is_permutation, _score 13 | from sklearn.base import is_classifier, clone 14 | from sklearn.metrics.scorer import check_scoring 15 | from sklearn.utils.validation import indexable, _num_samples 16 | from sklearn.utils.fixes import rankdata, MaskedArray 17 | from sklearn.utils.metaestimators import _safe_split 18 | from sklearn.externals.joblib import Parallel, delayed 19 | 20 | try: 21 | import pyspark 22 | except ImportError: 23 | sparkRuns = False 24 | else: 25 | sparkRuns = True 26 | 27 | 28 | class GridSearchCV(model_selection.GridSearchCV): 29 | '''Doc String''' 30 | 31 | def fit(self, X, y=None, groups=None): 32 | '''Doc String''' 33 | 34 | if isinstance(self.n_jobs, int): 35 | super().fit(X, y, groups) 36 | self.cv_results_ = pd.DataFrame(self.cv_results_) 37 | return self 38 | elif sparkRuns: 39 | if not isinstance(self.n_jobs, pyspark.SparkContext): 40 | raise RuntimeError('n_jobs parameter was not an int, meaning ' 41 | 'it should have been a SparkContext, but ' 42 | 'it was not.') 43 | return self._scFit(X, y, groups, ParameterGrid(self.param_grid)) 44 | else: 45 | raise RuntimeError('n_jobs parameter was not an int, meaning it ' 46 | 'should have been a SparkContext, but spark ' 47 | 'was unable to be imported.') 48 | 49 | def _scFit(self, X, y, groups, parameter_iterable): 50 | '''Doc String''' 51 | 52 | estimator = self.estimator 53 | cv = check_cv(self.cv, y, classifier=is_classifier(estimator)) 54 | self.scorer_ = check_scoring(self.estimator, scoring=self.scoring) 55 | 56 | X, y, groups = indexable(X, y, groups) 57 | n_splits = cv.get_n_splits(X, y, groups) 58 | if self.verbose > 0 and isinstance(parameter_iterable, Sized): 59 | n_candidates = len(parameter_iterable) 60 | print("Fitting {0} folds for each of {1} candidates, totalling" 61 | " {2} fits".format(n_splits, n_candidates, 62 | n_candidates * n_splits)) 63 | 64 | base_estimator = clone(self.estimator) 65 | 66 | cv_iter = list(cv.split(X, y, groups)) 67 | param_grid = [(parameters, train, test) 68 | for parameters in parameter_iterable 69 | for (train, test) in cv_iter] 70 | # Because the original python code expects a certain order for the 71 | # elements, we need to respect it. 72 | indexed_param_grid = list(zip(range(len(param_grid)), param_grid)) 73 | par_param_grid = self.n_jobs.parallelize(indexed_param_grid, 74 | len(indexed_param_grid)) 75 | X_bc = self.n_jobs.broadcast(X) 76 | y_bc = self.n_jobs.broadcast(y) 77 | 78 | scorer = self.scorer_ 79 | verbose = self.verbose 80 | fit_params = self.fit_params 81 | return_train_score = self.return_train_score 82 | error_score = self.error_score 83 | fas = _fit_and_score 84 | 85 | def fun(tup): 86 | (index, (parameters, train, test)) = tup 87 | local_estimator = clone(base_estimator) 88 | local_X = X_bc.value 89 | local_y = y_bc.value 90 | res = fas(local_estimator, local_X, local_y, scorer, train, test, 91 | verbose, parameters, 92 | fit_params=fit_params, 93 | return_train_score=return_train_score, 94 | return_n_test_samples=True, 95 | return_times=True, 96 | return_parameters=True, 97 | error_score=error_score) 98 | return (index, res) 99 | indexed_out0 = dict(par_param_grid.map(fun).collect()) 100 | out = [indexed_out0[idx] for idx in range(len(param_grid))] 101 | 102 | X_bc.unpersist() 103 | y_bc.unpersist() 104 | 105 | # if one choose to see train score, "out" will contain train score info 106 | if self.return_train_score: 107 | (train_scores, test_scores, test_sample_counts, 108 | fit_time, score_time, parameters) = zip(*out) 109 | else: 110 | (test_scores, test_sample_counts, 111 | fit_time, score_time, parameters) = zip(*out) 112 | 113 | candidate_params = parameters[::n_splits] 114 | n_candidates = len(candidate_params) 115 | 116 | results = dict() 117 | 118 | def _store(key_name, array, weights=None, splits=False, rank=False): 119 | """A small helper to store the scores/times to the cv_results_""" 120 | array = np.array(array, dtype=np.float64).reshape(n_candidates, 121 | n_splits) 122 | if splits: 123 | for split_i in range(n_splits): 124 | results["split%d_%s" 125 | % (split_i, key_name)] = array[:, split_i] 126 | 127 | array_means = np.average(array, axis=1, weights=weights) 128 | results['mean_%s' % key_name] = array_means 129 | # Weighted std is not directly available in numpy 130 | array_stds = np.sqrt(np.average((array - 131 | array_means[:, np.newaxis]) ** 2, 132 | axis=1, weights=weights)) 133 | results['std_%s' % key_name] = array_stds 134 | 135 | if rank: 136 | results["rank_%s" % key_name] = np.asarray( 137 | rankdata(-array_means, method='min'), dtype=np.int32) 138 | 139 | # Computed the (weighted) mean and std for test scores alone 140 | # NOTE test_sample counts (weights) remain the same for all candidates 141 | test_sample_counts = np.array(test_sample_counts[:n_splits], 142 | dtype=np.int) 143 | 144 | _store('test_score', test_scores, splits=True, rank=True, 145 | weights=test_sample_counts if self.iid else None) 146 | if self.return_train_score: 147 | _store('train_score', train_scores, splits=True) 148 | _store('fit_time', fit_time) 149 | _store('score_time', score_time) 150 | 151 | best_index = np.flatnonzero(results["rank_test_score"] == 1)[0] 152 | best_parameters = candidate_params[best_index] 153 | 154 | # Use one MaskedArray and mask all the places where the param is not 155 | # applicable for that candidate. Use defaultdict as each candidate may 156 | # not contain all the params 157 | param_results = defaultdict(partial(MaskedArray, 158 | np.empty(n_candidates,), 159 | mask=True, 160 | dtype=object)) 161 | for cand_i, params in enumerate(candidate_params): 162 | for name, value in params.items(): 163 | # An all masked empty array gets created for the key 164 | # `"param_%s" % name` at the first occurence of `name`. 165 | # Setting the value at an index also unmasks that index 166 | param_results["param_%s" % name][cand_i] = value 167 | 168 | results.update(param_results) 169 | 170 | # Store a list of param dicts at the key 'params' 171 | results['params'] = candidate_params 172 | 173 | self.cv_results_ = pd.DataFrame(results) 174 | self.best_index_ = best_index 175 | self.n_splits_ = n_splits 176 | 177 | if self.refit: 178 | # fit the best estimator using the entire dataset 179 | # clone first to work around broken estimators 180 | best_estimator = clone(base_estimator).set_params( 181 | **best_parameters) 182 | if y is not None: 183 | best_estimator.fit(X, y, **self.fit_params) 184 | else: 185 | best_estimator.fit(X, **self.fit_params) 186 | self.best_estimator_ = best_estimator 187 | return self 188 | 189 | 190 | def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, 191 | n_jobs=1, verbose=0, fit_params=None, 192 | pre_dispatch='2*n_jobs'): 193 | '''Doc String''' 194 | 195 | if isinstance(n_jobs, int): 196 | return model_selection.cross_val_score(estimator, X, y, groups, 197 | scoring, cv, n_jobs, verbose, 198 | fit_params, pre_dispatch) 199 | elif sparkRuns: 200 | sc = n_jobs 201 | if not isinstance(sc, pyspark.SparkContext): 202 | raise RuntimeError('n_jobs parameter was not an int, meaning ' 203 | 'it should have been a SparkContext, but ' 204 | 'it was not.') 205 | gs = GridSearchCV(estimator=estimator, 206 | param_grid={}, 207 | scoring=scoring, 208 | fit_params=fit_params, 209 | n_jobs=n_jobs, 210 | iid=True, 211 | refit=False, 212 | cv=cv, 213 | verbose=verbose, 214 | pre_dispatch=pre_dispatch, 215 | error_score='raise').fit(X, y, groups) 216 | df = pd.DataFrame(gs.cv_results_) 217 | score = df.loc[0, ['split{}_test_score'.format(ii) 218 | for ii in range(gs.n_splits_)]].astype(float).values 219 | return score 220 | else: 221 | raise RuntimeError('n_jobs parameter was not an int, meaning it ' 222 | 'should have been a SparkContext, but spark ' 223 | 'was unable to be imported.') 224 | 225 | 226 | def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1, 227 | verbose=0, fit_params=None, pre_dispatch='2*n_jobs', 228 | method='predict'): 229 | '''Doc String''' 230 | 231 | if isinstance(n_jobs, int): 232 | return model_selection.cross_val_predict(estimator, X, y, groups, 233 | cv, n_jobs, verbose, 234 | fit_params, pre_dispatch, 235 | method) 236 | elif not sparkRuns: 237 | raise RuntimeError('n_jobs parameter was not an int, meaning it ' 238 | 'should have been a SparkContext, but spark ' 239 | 'was unable to be imported.') 240 | elif not isinstance(n_jobs, pyspark.SparkContext): 241 | raise RuntimeError('n_jobs parameter was not an int, meaning ' 242 | 'it should have been a SparkContext, but ' 243 | 'it was not.') 244 | sc = n_jobs 245 | X, y, groups = indexable(X, y, groups) 246 | 247 | cv = check_cv(cv, y, classifier=is_classifier(estimator)) 248 | cv_iter = list(cv.split(X, y, groups)) 249 | 250 | # Ensure the estimator has implemented the passed decision function 251 | if not callable(getattr(estimator, method)): 252 | raise AttributeError('{} not implemented in estimator' 253 | .format(method)) 254 | 255 | inds = [tup for tup in cv_iter] 256 | # Because the original python code expects a certain order for the 257 | # elements, we need to respect it. 258 | numInds = list(zip(range(len(inds)), inds)) 259 | parNumInds = sc.parallelize(numInds, len(numInds)) 260 | X_bc = sc.broadcast(X) 261 | y_bc = sc.broadcast(y) 262 | 263 | fap = _fit_and_predict 264 | 265 | def fun(tup): 266 | (index, (train, test)) = tup 267 | local_estimator = clone(estimator) 268 | local_X = X_bc.value 269 | local_y = y_bc.value 270 | pred = fap(local_estimator, local_X, local_y, train, test, 271 | verbose, fit_params, method) 272 | return (index, pred) 273 | indexed_out0 = dict(parNumInds.map(fun).collect()) 274 | prediction_blocks = [indexed_out0[idx] for idx in range(len(inds))] 275 | 276 | X_bc.unpersist() 277 | y_bc.unpersist() 278 | 279 | # Concatenate the predictions 280 | predictions = [pred_block_i for pred_block_i, _ 281 | in prediction_blocks] 282 | test_indices = np.concatenate([indices_i 283 | for _, indices_i 284 | in prediction_blocks]) 285 | 286 | if not _check_is_permutation(test_indices, _num_samples(X)): 287 | raise ValueError('cross_val_predict only works for partitions') 288 | 289 | inv_test_indices = np.empty(len(test_indices), dtype=int) 290 | inv_test_indices[test_indices] = np.arange(len(test_indices)) 291 | 292 | # Check for sparse predictions 293 | if sp.issparse(predictions[0]): 294 | predictions = sp.vstack(predictions, 295 | format=predictions[0].format) 296 | else: 297 | predictions = np.concatenate(predictions) 298 | return predictions[inv_test_indices] 299 | 300 | 301 | def _fit_and_score_grid(estimator, X, y, scorer, train, test, grid, 302 | fit_params, error_score='raise'): 303 | '''Doc String''' 304 | 305 | X_train, y_train = _safe_split(estimator, X, y, train) 306 | X_test, y_test = _safe_split(estimator, X, y, test, train) 307 | 308 | try: 309 | estimator.fit(X_train, y_train, **fit_params) 310 | except Exception as e: 311 | if error_score == 'raise': 312 | raise 313 | elif isinstance(error_score, numbers.Number): 314 | scores = [error_score] * len(grid) 315 | else: 316 | raise ValueError("error_score must be the string 'raise' or a" 317 | " numeric value. (Hint: if using 'raise', please" 318 | " make sure that it has been spelled correctly.)") 319 | 320 | else: 321 | origParams = estimator.get_params() 322 | scores = [_score(estimator.set_params(**params), 323 | X_test, y_test, scorer) for params in grid] 324 | estimator.set_params(**origParams) 325 | 326 | return scores 327 | 328 | 329 | def gridCVScoresAlt(estimator, param_grid, X, y=None, groups=None, 330 | scoring=None, fit_params={}, n_jobs=1, iid=True, 331 | cv=None, verbose=0, pre_dispatch='2*n_jobs', 332 | error_score='raise'): 333 | '''Function to minimally do GridSearchCV, but parallelization is by cv, not 334 | also by param set. Only works if changing paramaters can be done AFTER 335 | fitting.''' 336 | 337 | fullGrid = ParameterGrid(param_grid) 338 | cv = check_cv(cv, y, classifier=is_classifier(estimator)) 339 | scorer = check_scoring(estimator, scoring=scoring) 340 | 341 | X, y, groups = indexable(X, y, groups) 342 | 343 | base_estimator = clone(estimator) 344 | 345 | cv_iter = list(cv.split(X, y, groups)) 346 | 347 | if isinstance(n_jobs, int): 348 | 349 | out = Parallel( 350 | n_jobs=n_jobs, verbose=verbose, 351 | pre_dispatch=pre_dispatch 352 | )(delayed(_fit_and_score_grid)(clone(base_estimator), X, y, scorer, 353 | train, test, fullGrid, 354 | fit_params=fit_params, 355 | error_score=error_score) 356 | for train, test in cv_iter) 357 | 358 | elif sparkRuns: 359 | if not isinstance(n_jobs, pyspark.SparkContext): 360 | raise RuntimeError('n_jobs parameter was not an int, meaning ' 361 | 'it should have been a SparkContext, but ' 362 | 'it was not.') 363 | 364 | inds = [tup for tup in cv_iter] 365 | numInds = list(zip(range(len(inds)), inds)) 366 | parNumInds = n_jobs.parallelize(numInds, len(numInds)) 367 | X_bc = n_jobs.broadcast(X) 368 | y_bc = n_jobs.broadcast(y) 369 | 370 | fasg = _fit_and_score_grid 371 | 372 | def fun(tup): 373 | (index, (train, test)) = tup 374 | local_estimator = clone(estimator) 375 | local_X = X_bc.value 376 | local_y = y_bc.value 377 | scores = fasg(local_estimator, local_X, local_y, scorer, 378 | train, test, fullGrid, fit_params=fit_params, 379 | error_score=error_score) 380 | return (index, scores) 381 | indexed_out0 = dict(parNumInds.map(fun).collect()) 382 | out = [indexed_out0[idx] for idx in range(len(inds))] 383 | 384 | X_bc.unpersist() 385 | y_bc.unpersist() 386 | else: 387 | raise RuntimeError('n_jobs parameter was not an int, meaning it ' 388 | 'should have been a SparkContext, but spark ' 389 | 'was unable to be imported.') 390 | 391 | return pd.DataFrame({'params': [params for params in fullGrid], 392 | 'scores': [np.array(score) for score in zip(*out)]}) 393 | -------------------------------------------------------------------------------- /statcast/better/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pandas as pd 4 | 5 | import numpy as np 6 | from sklearn.model_selection import train_test_split 7 | 8 | from .spark import cross_val_score 9 | 10 | 11 | def otherRFE(estimator, data, step=1, cv=None, scoring=None, scoreThresh=2e-2, 12 | n_jobs=1): 13 | '''Doc String''' 14 | 15 | intStep = step 16 | 17 | def score(): 18 | estimator.fitD(data) 19 | return cross_val_score(estimator, estimator.createX(data), 20 | estimator.createY(data), scoring=scoring, 21 | cv=cv, n_jobs=n_jobs) 22 | 23 | scores = [score()] 24 | bestScore = np.mean(scores[-1]) 25 | threshold = bestScore * (1 - np.sign(bestScore) * scoreThresh) 26 | ftImps = estimator.feature_importances_ 27 | results = pd.DataFrame({k: v for k, v in ftImps.iteritems()}, index=(0,)) 28 | while True: 29 | if len(ftImps.index) == 1: 30 | break 31 | if isinstance(step, float): 32 | intStep = int(min(np.floor(len(ftImps.index) * step), 1)) 33 | if intStep > len(ftImps.index): 34 | estimator.xLabels = list(ftImps.index[:-1]) 35 | else: 36 | estimator.xLabels = list(ftImps.index[:-intStep]) 37 | scores.append(score()) 38 | ftImps = estimator.feature_importances_ 39 | results = results.append(pd.DataFrame({k: v 40 | for k, v in ftImps.iteritems()}, 41 | index=(0,)), 42 | ignore_index=True) 43 | if np.mean(scores[-1]) <= threshold: 44 | estimator.xLabels = list(results.columns[~results.iloc[-2, :]. 45 | isnull()]) 46 | estimator.fitD(data) 47 | break 48 | elif np.mean(scores[-1]) > bestScore: 49 | bestScore = np.mean(scores[-1]) 50 | threshold = bestScore * (1 - np.sign(bestScore) * scoreThresh) 51 | 52 | results.sort_values(by=0, axis=1, inplace=True) 53 | results['scores'] = scores 54 | estimator.rfeResults_ = results.iloc[:, ::-1] 55 | return estimator 56 | 57 | 58 | def findTrainSplit(estimator, data, maxTrain=1.0, scoreThresh=2e-2, 59 | groups=None, scoring=None, cv=None, n_jobs=1, 60 | verbose=0, fit_params=None, pre_dispatch='2*n_jobs'): 61 | '''Doc String''' 62 | 63 | if isinstance(maxTrain, float) & (maxTrain == 1): 64 | subData = data 65 | else: 66 | subData, dummy = train_test_split(data, train_size=maxTrain) 67 | 68 | def score(data): 69 | return cross_val_score(estimator, estimator.createX(data), 70 | estimator.createY(data), 71 | groups, scoring, cv, n_jobs, verbose, 72 | fit_params, pre_dispatch) 73 | 74 | scores = [score(subData)] 75 | 76 | def getScore(ind): 77 | return np.mean(scores[ind]) 78 | 79 | def bestScore(): 80 | return max([np.mean(score) for score in scores]) 81 | 82 | def threshold(): 83 | return bestScore() * (1 - np.sign(bestScore()) * scoreThresh) 84 | 85 | trainSizes = [subData.shape[0]] 86 | while trainSizes[-1] > 1: 87 | trainSizes.append(int(np.round(trainSizes[-1] / 2))) 88 | subData, dummy = train_test_split(data, train_size=trainSizes[-1]) 89 | scores.append(score(subData)) 90 | if getScore(-1) <= threshold(): 91 | lowerBound = (trainSizes[-1], getScore(-1)) 92 | upperBound = (trainSizes[-2], getScore(-2)) 93 | break 94 | else: 95 | warnings.warn('Training size of 1 found') 96 | estimator.fitD(subData) 97 | estimator.trainSplitResults_ = pd.DataFrame({'size': trainSizes, 98 | 'score': scores}). \ 99 | sort_values(by='size') 100 | return estimator 101 | 102 | for dummy in range(3): 103 | trainSizes.append(int(np.round(trainSizes[-1] * 2 ** (1/4)))) 104 | subData, dummy = train_test_split(data, train_size=trainSizes[-1]) 105 | scores.append(score(subData)) 106 | if getScore(-1) >= threshold(): 107 | upperBound = (trainSizes[-1], getScore(-1)) 108 | break 109 | else: 110 | lowerBound = (trainSizes[-1], getScore(-1)) 111 | 112 | size = int(np.round(np.interp(threshold(), 113 | (lowerBound[1], upperBound[1]), 114 | (lowerBound[0], upperBound[0])))) 115 | subData, dummy = train_test_split(data, train_size=size) 116 | estimator.fitD(subData) 117 | estimator.trainSplitResults_ = pd.DataFrame({'size': trainSizes, 118 | 'score': scores}). \ 119 | sort_values(by='size') 120 | return estimator, subData 121 | -------------------------------------------------------------------------------- /statcast/bip.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import numpy as np 5 | from matplotlib.lines import Line2D 6 | from sklearn.base import clone 7 | 8 | from .database.bbsavant import DB as SavantDB 9 | from .database.gd_weather import DB as WeatherDB 10 | 11 | from .better.randomforest import TreeSelectingRFRegressor 12 | from .better.mixed import BetterLME4 13 | from .better.utils import findTrainSplit, otherRFE 14 | from .tools.plot import plotKDHist 15 | 16 | from . import __path__ 17 | 18 | 19 | savantDB = SavantDB('fast') 20 | weatherDB = WeatherDB('fast') 21 | weatherData = pd.read_sql_query( 22 | '''SELECT * 23 | FROM {}'''.format(weatherDB._tblName), weatherDB.engine) 24 | 25 | _storagePath = os.path.join(__path__[0], 'data') 26 | 27 | _scImputer = \ 28 | TreeSelectingRFRegressor(xLabels=['start_speed', 29 | 'x0', 30 | 'z0', 31 | 'events', 32 | 'zone', 33 | 'hit_location', 34 | 'bb_type', 35 | 'balls', 36 | 'strikes', 37 | 'pfx_x', 38 | 'pfx_z', 39 | 'px', 40 | 'pz', 41 | 'hc_x', 42 | 'hc_y', 43 | 'vx0', 44 | 'vy0', 45 | 'vz0', 46 | 'effective_speed', 47 | 'sprayAngle', 48 | 'hitDistanceGD'], 49 | yLabels=['hit_speed', 50 | 'hit_angle', 51 | 'hit_distance_sc'], 52 | oob_score=True, 53 | n_jobs=-1) 54 | _scFactorMdl = \ 55 | BetterLME4( 56 | xLabels=['batter', 'pitcher', 'gdTemp', 'home_team', 'scImputed'], 57 | yLabels=['hit_speed', 'hit_angle', 'hit_distance_sc'], 58 | formulas=('(1|batter) + (1|pitcher) + scImputed + (1|home_team)', 59 | '(1|batter) + (1|pitcher) + gdTemp + scImputed + ' 60 | '(1|home_team)', 61 | '(1|batter) + (1|pitcher) + gdTemp + scImputed + ' 62 | '(scImputed||home_team)', 63 | '(1|batter) + (1|pitcher) + scImputed + ' 64 | '(scImputed||home_team)')) 65 | 66 | 67 | class Bip(): 68 | '''Doc String''' 69 | 70 | def __init__(self, years, scImputerName=None, scFactorMdlName=None, 71 | n_jobs=-1): 72 | '''Doc String''' 73 | 74 | self.n_jobs = n_jobs 75 | self.years = years 76 | 77 | self._initData(years) 78 | 79 | self._initSCImputer(scImputerName=scImputerName) 80 | self._imputeSCData() 81 | 82 | self._initSCFactorMdl(scFactorMdlName=scFactorMdlName) 83 | 84 | def _initData(self, years): 85 | '''Doc String''' 86 | 87 | self.data = pd.DataFrame() 88 | for year in years: 89 | rawD = pd.read_sql_query( 90 | '''SELECT * 91 | FROM {} 92 | WHERE type = 'X' 93 | AND game_year = {} 94 | AND game_type = 'R ' '''.format(savantDB._tblName, year), 95 | savantDB.engine) 96 | self.data = self.data.append(rawD, ignore_index=True) 97 | 98 | self.data['sprayAngle'] = \ 99 | (np.arctan2(208 - self.data.hc_y, self.data.hc_x - 128) / 100 | (2 * np.pi) * 360 + 90) % 360 - 180 101 | self.data['hitDistanceGD'] = np.sqrt((self.data.hc_x - 128) ** 2 + 102 | (208 - self.data.hc_y) ** 2) 103 | 104 | self.data[['on_3b', 'on_2b', 'on_1b']] = \ 105 | self.data[['on_3b', 'on_2b', 'on_1b']]. \ 106 | fillna(value=0).astype('int') 107 | self.data['baseState'] = \ 108 | (self.data[['on_3b', 'on_2b', 'on_1b']] == 0). \ 109 | replace([True, False], ['_', 'X']).sum(axis=1) 110 | 111 | temps = pd.Series(weatherData.temp.values, index=weatherData.game_pk) 112 | temps = temps[~temps.index.duplicated(keep='first')] 113 | self.data['gdTemp'] = temps.loc[self.data.game_pk].values 114 | 115 | excludeEvents = ['Batter Interference', 'Hit By Pitch', 'Strikeout', 116 | 'Walk', 'Fan Intereference', 'Field Error', 117 | 'Catcher Interference', 'Fan interference'] 118 | self.data['exclude'] = self.data.events.isin(excludeEvents) 119 | 120 | categories = ['pitch_type', 'batter', 'pitcher', 'events', 'zone', 121 | 'stand', 'p_throws', 'home_team', 'away_team', 122 | 'hit_location', 'bb_type', 'on_3b', 'on_2b', 'on_1b', 123 | 'inning_topbot', 'catcher', 'umpire', 'game_pk', 124 | 'baseState'] 125 | for category in categories: 126 | self.data[category] = self.data[category].astype('category') 127 | 128 | zeroIsMissingCols = ['hit_speed', 'hit_angle', 'hit_distance_sc'] 129 | for col in zeroIsMissingCols: 130 | self.data.loc[self.data[col] == 0, col] = np.nan 131 | 132 | self.data['missing'] = [', '.join(self.data.columns[row]) 133 | for row in self.data.isnull().values] 134 | 135 | self.data['scImputed'] = self.missing(_scImputer.yLabels) 136 | 137 | self.data.fillna(self.data.median(), inplace=True) 138 | 139 | def _imputeSCData(self): 140 | '''Doc String''' 141 | 142 | imputeData = self.data[~self.data.exclude & self.data.scImputed] 143 | imputeY = pd.DataFrame(self.scImputer.predictD(imputeData), 144 | columns=self.scImputer.yLabels) 145 | 146 | for label in self.scImputer.yLabels: 147 | imputeThisCol = self.data.missing.map(lambda x: label in x) 148 | self.data.loc[~self.data.exclude & imputeThisCol, label] = \ 149 | imputeY.loc[imputeThisCol[~self.data.exclude & 150 | self.data.scImputed].values, 151 | label].values 152 | 153 | def _initSCImputer(self, scImputerName=None): 154 | '''Doc String''' 155 | 156 | if scImputerName == 'new': 157 | self._createSCImputer() 158 | elif scImputerName is not None: 159 | self.scImputer = _scImputer.load(scImputerName) 160 | else: 161 | name = 'scImputer{}'.format('_'.join(str(year) 162 | for year in self.years)) 163 | try: 164 | self.scImputer = \ 165 | _scImputer.load(name=name, searchDirs=(_storagePath,)) 166 | except FileNotFoundError: 167 | self._createSCImputer() 168 | self.scImputer.name = name 169 | try: 170 | self.scImputer.save(os.path.join(_storagePath, 171 | self.scImputer.name)) 172 | except PermissionError: 173 | self.scImputer.save(self.scImputer.name) 174 | 175 | def _createSCImputer(self): 176 | '''Doc String''' 177 | 178 | trainData = self.data[~self.data.exclude & ~self.data.scImputed] 179 | scImputer = clone(_scImputer) 180 | self.scImputer, subTrainData = findTrainSplit(scImputer, trainData, 181 | n_jobs=self.n_jobs) 182 | otherRFE(self.scImputer, subTrainData, cv=10, n_jobs=self.n_jobs) 183 | findTrainSplit(self.scImputer, trainData, cv=10, n_jobs=self.n_jobs) 184 | 185 | def _initSCFactorMdl(self, scFactorMdlName=None): 186 | '''Doc String''' 187 | 188 | if scFactorMdlName == 'new': 189 | self._createSCFactorMdl() 190 | elif scFactorMdlName is not None: 191 | self.scFactorMdl = _scFactorMdl.load(scFactorMdlName) 192 | else: 193 | name = 'scFactorMdl{}'.format('_'.join(str(year) 194 | for year in self.years)) 195 | try: 196 | self.scFactorMdl = \ 197 | _scFactorMdl.load(name=name, searchDirs=(_storagePath,)) 198 | except FileNotFoundError: 199 | self._createSCFactorMdl() 200 | self.scFactorMdl.name = name 201 | try: 202 | self.scFactorMdl.save(os.path.join(_storagePath, 203 | self.scFactorMdl.name)) 204 | except PermissionError: 205 | self.scFactorMdl.save(self.scFactorMdl.name) 206 | 207 | def _createSCFactorMdl(self): 208 | '''Doc String''' 209 | 210 | trainData = self.data[~self.data.exclude] 211 | scFactorMdl = clone(_scFactorMdl) 212 | self.scFactorMdl = scFactorMdl.chooseFormula(trainData, 213 | _scFactorMdl.formulas, 214 | n_jobs=self.n_jobs, 215 | cv=10) 216 | 217 | def missing(self, columns): 218 | '''Doc String''' 219 | 220 | return self.data.missing.map(lambda x: 221 | any(y in x 222 | for y in columns)) 223 | 224 | def plotSCHistograms(self): 225 | '''Doc String''' 226 | 227 | labels = ['Exit Velocity', 'Launch Angle', 'Hit Distance'] 228 | units = ['mph', 'degrees', 'feet'] 229 | 230 | testData = self.data.loc[~self.data.exclude & ~self.data.scImputed, :] 231 | imputeData = self.data.loc[~self.data.exclude & self.data.scImputed, :] 232 | 233 | testY = self.scImputer.createY(testData).values.T 234 | testYp = self.scImputer.predictD(testData).T 235 | imputeY = self.scImputer.predictD(imputeData).T 236 | 237 | del testData, imputeData 238 | 239 | name = 'bandwidths{}.csv'.format('_'.join(str(year) 240 | for year in self.years)) 241 | 242 | try: 243 | bandwidths = pd.read_csv(os.path.join(_storagePath, name), 244 | index_col=0) 245 | saveFlag = False 246 | except FileNotFoundError: 247 | bandwidths = pd.DataFrame({'test': None, 248 | 'testP': None, 249 | 'impute': None}, 250 | index=self.scImputer.yLabels) 251 | saveFlag = True 252 | 253 | for testy, testyp, imputey, label, unit, yLabel in \ 254 | zip(testY, testYp, imputeY, labels, units, 255 | self.scImputer.yLabels): 256 | 257 | fig, kde = plotKDHist(testy, 258 | bandwidth=bandwidths.loc[yLabel, 'test'], 259 | cv=10, n_jobs=self.n_jobs) 260 | if saveFlag: 261 | bandwidths.loc[yLabel, 'test'] = kde.bandwidth 262 | del kde 263 | ax = fig.gca() 264 | 265 | ax, kde = plotKDHist(testyp, ax=ax, 266 | bandwidth=bandwidths.loc[yLabel, 'testP'], 267 | cv=10, n_jobs=self.n_jobs) 268 | if saveFlag: 269 | bandwidths.loc[yLabel, 'testP'] = kde.bandwidth 270 | del kde 271 | 272 | ax, kde = plotKDHist(imputey, ax=ax, 273 | bandwidth=bandwidths.loc[yLabel, 'impute'], 274 | cv=10, n_jobs=self.n_jobs) 275 | if saveFlag: 276 | bandwidths.loc[yLabel, 'impute'] = kde.bandwidth 277 | del kde 278 | 279 | ax.set_xlim(left=min(testy.min(), testyp.min(), imputey.min()), 280 | right=max(testy.max(), testyp.max(), imputey.max())) 281 | ax.set_ylim(bottom=0, auto=True) 282 | 283 | ax.set_xlabel(label + ' ({})'.format(unit)) 284 | lines = [child for child in ax.get_children() 285 | if isinstance(child, Line2D)] 286 | ax.legend(handles=lines, 287 | labels=('Test Data', 288 | 'Test Data Imputed', 289 | 'Missing Data Imputed'), loc='best') 290 | 291 | fig.savefig('{} {} Histogram'. 292 | format(', '.join(str(year) for year in self.years), 293 | label)) 294 | if saveFlag: 295 | try: 296 | bandwidths.to_csv(os.path.join(_storagePath, name)) 297 | except PermissionError: 298 | bandwidths.to_csv(name) 299 | -------------------------------------------------------------------------------- /statcast/data/bandwidths2015.csv: -------------------------------------------------------------------------------- 1 | ,impute,test,testP 2 | hit_speed,0.921317728184,5.98285281818,0.910468710997 3 | hit_angle,1.10429468351,3.64918147798,1.07985007296 4 | hit_distance_sc,16.2964047595,6.39317084182,3.58926521471 5 | -------------------------------------------------------------------------------- /statcast/data/bandwidths2016.csv: -------------------------------------------------------------------------------- 1 | ,impute,test,testP 2 | hit_speed,2.45730107378,2.37418702842,1.38567295995 3 | hit_angle,2.07916337246,3.79503670649,1.24370776311 4 | hit_distance_sc,5.62536683407,18.0768827289,3.47030322327 5 | -------------------------------------------------------------------------------- /statcast/data/blackontrans.mplstyle: -------------------------------------------------------------------------------- 1 | lines.linewidth : 1.5 # line width in points 2 | lines.color : 1 # has no affect on plot(); see axes.prop_cycle 3 | 4 | font.size : 12 5 | font.sans-serif : Helvetica Neue, Helvetica, Bitstream Vera Sans, Lucida Grande, Verdana, Geneva, Lucid, Arial, Avant Garde, sans-serif 6 | text.color : 1 7 | 8 | axes.titlesize : 20 # fontsize of the axes title 9 | axes.labelsize : 16 # fontsize of the x any y labels 10 | axes.labelcolor : 1 11 | axes.facecolor : 0.0667 12 | axes.edgecolor : 1 13 | axes.prop_cycle : cycler('color', ['4b78ca', '467c5a', '7b003d', 'e1883b', 'd9231d', 'eed180', 'd34389', '91b372', 'e7b8cf']) 14 | 15 | xtick.color : 1 # color of the tick labels 16 | xtick.labelsize : 12 # fontsize of the tick labels 17 | 18 | ytick.color : 1 # color of the tick labels 19 | ytick.labelsize : 12 # fontsize of the tick labels 20 | 21 | legend.fontsize : 16 22 | legend.handletextpad : 0.4 23 | legend.handlelength : 1.5 24 | 25 | figure.figsize : 10.21, 5.71 # figure size in inches 26 | figure.titlesize : 20 27 | figure.facecolor : 0.0667 28 | figure.dpi : 125 29 | 30 | savefig.dpi : 125 # figure dots per inch 31 | savefig.transparent : True -------------------------------------------------------------------------------- /statcast/data/logos/ARI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/ARI.png -------------------------------------------------------------------------------- /statcast/data/logos/ATL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/ATL.png -------------------------------------------------------------------------------- /statcast/data/logos/BAL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/BAL.png -------------------------------------------------------------------------------- /statcast/data/logos/BOS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/BOS.png -------------------------------------------------------------------------------- /statcast/data/logos/CHC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/CHC.png -------------------------------------------------------------------------------- /statcast/data/logos/CIN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/CIN.png -------------------------------------------------------------------------------- /statcast/data/logos/CLE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/CLE.png -------------------------------------------------------------------------------- /statcast/data/logos/COL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/COL.png -------------------------------------------------------------------------------- /statcast/data/logos/CWS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/CWS.png -------------------------------------------------------------------------------- /statcast/data/logos/DET.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/DET.png -------------------------------------------------------------------------------- /statcast/data/logos/HOU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/HOU.png -------------------------------------------------------------------------------- /statcast/data/logos/KC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/KC.png -------------------------------------------------------------------------------- /statcast/data/logos/LAA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/LAA.png -------------------------------------------------------------------------------- /statcast/data/logos/LAD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/LAD.png -------------------------------------------------------------------------------- /statcast/data/logos/MIA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/MIA.png -------------------------------------------------------------------------------- /statcast/data/logos/MIL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/MIL.png -------------------------------------------------------------------------------- /statcast/data/logos/MIN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/MIN.png -------------------------------------------------------------------------------- /statcast/data/logos/NYM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/NYM.png -------------------------------------------------------------------------------- /statcast/data/logos/NYY.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/NYY.png -------------------------------------------------------------------------------- /statcast/data/logos/OAK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/OAK.png -------------------------------------------------------------------------------- /statcast/data/logos/PHI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/PHI.png -------------------------------------------------------------------------------- /statcast/data/logos/PIT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/PIT.png -------------------------------------------------------------------------------- /statcast/data/logos/SD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/SD.png -------------------------------------------------------------------------------- /statcast/data/logos/SEA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/SEA.png -------------------------------------------------------------------------------- /statcast/data/logos/SF.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/SF.png -------------------------------------------------------------------------------- /statcast/data/logos/STL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/STL.png -------------------------------------------------------------------------------- /statcast/data/logos/TB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/TB.png -------------------------------------------------------------------------------- /statcast/data/logos/TEX.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/TEX.png -------------------------------------------------------------------------------- /statcast/data/logos/TOR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/TOR.png -------------------------------------------------------------------------------- /statcast/data/logos/WSH.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/data/logos/WSH.png -------------------------------------------------------------------------------- /statcast/data/personal.mplstyle: -------------------------------------------------------------------------------- 1 | lines.linewidth : 1.5 # line width in points 2 | lines.color : 0.2 # has no affect on plot(); see axes.prop_cycle 3 | 4 | font.size : 12 5 | font.sans-serif : Helvetica Neue, Helvetica, Bitstream Vera Sans, Lucida Grande, Verdana, Geneva, Lucid, Arial, Avant Garde, sans-serif 6 | text.color : 0.2 7 | 8 | axes.titlesize : 20 # fontsize of the axes title 9 | axes.labelsize : 16 # fontsize of the x any y labels 10 | axes.labelcolor : 0.2 11 | axes.edgecolor : 0.2 12 | axes.prop_cycle : cycler('color', ['4b78ca', '467c5a', '7b003d', 'e1883b', 'd9231d', 'eed180', 'd34389', '91b372', 'e7b8cf']) 13 | 14 | xtick.color : 0.2 # color of the tick labels 15 | xtick.labelsize : 12 # fontsize of the tick labels 16 | 17 | ytick.color : 0.2 # color of the tick labels 18 | ytick.labelsize : 12 # fontsize of the tick labels 19 | 20 | legend.fontsize : 16 21 | legend.handletextpad : 0.4 22 | legend.handlelength : 1.5 23 | 24 | figure.figsize : 10.21, 5.71 # figure size in inches 25 | figure.titlesize : 20 26 | 27 | savefig.dpi : 125 # figure dots per inch 28 | -------------------------------------------------------------------------------- /statcast/database/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/database/__init__.py -------------------------------------------------------------------------------- /statcast/database/bbsavant.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime as dt 3 | 4 | import pandas as pd 5 | import sqlalchemy as sa 6 | 7 | from .database import Database 8 | 9 | 10 | _string = sa.types.String 11 | _integer = sa.types.Integer 12 | _float = sa.types.Float 13 | _date = sa.types.Date 14 | 15 | _baseURL = ''' 16 | https://baseballsavant.mlb.com/statcast_search/csv? 17 | all=true& 18 | hfPT=& 19 | hfZ=& 20 | hfGT=R%7CPO%7CS%7C& 21 | hfPR=& 22 | hfAB=& 23 | stadium={venue}& 24 | hfBBT=& 25 | hfBBL=& 26 | hfC=& 27 | season=all& 28 | player_type=batter& 29 | hfOuts=& 30 | pitcher_throws=& 31 | batter_stands=& 32 | start_speed_gt=& 33 | start_speed_lt=& 34 | perceived_speed_gt=& 35 | perceived_speed_lt=& 36 | spin_rate_gt=& 37 | spin_rate_lt=& 38 | exit_velocity_gt=& 39 | exit_velocity_lt=& 40 | launch_angle_gt=& 41 | launch_angle_lt=& 42 | distance_gt=& 43 | distance_lt=& 44 | batted_ball_angle_gt=& 45 | batted_ball_angle_lt=& 46 | game_date_gt={date}& 47 | game_date_lt={date}& 48 | team=& 49 | position=& 50 | hfRO=& 51 | home_road=& 52 | hfInn=& 53 | min_pitches=0& 54 | min_results=0& 55 | group_by=name-event& 56 | sort_col=pitches& 57 | player_event_sort=start_speed& 58 | sort_order=desc& 59 | min_abs=0& 60 | xba_gt=& 61 | xba_lt=& 62 | px1=& 63 | px2=& 64 | pz1=& 65 | pz2=& 66 | ss_gt=& 67 | ss_lt=& 68 | is_barrel=& 69 | type=details& 70 | '''.replace('\n', '') 71 | 72 | _venues = [ 73 | 'LAA', 74 | 'HOU', 75 | 'OAK', 76 | 'TOR', 77 | 'ATL', 78 | 'MIL', 79 | 'STL', 80 | 'CHC', 81 | 'ARI', 82 | 'LAD', 83 | 'SF', 84 | 'CLE', 85 | 'SEA', 86 | 'MIA', 87 | 'NYM', 88 | 'WSH', 89 | 'BAL', 90 | 'SD', 91 | 'PHI', 92 | 'PIT', 93 | 'TEX', 94 | 'TB', 95 | 'BOS', 96 | 'CIN', 97 | 'COL', 98 | 'KC', 99 | 'DET', 100 | 'MIN', 101 | 'CWS', 102 | 'NYY'] 103 | 104 | 105 | class DB(Database): 106 | '''Doc String''' 107 | 108 | dbName = 'bbsavant' 109 | _username = 'matt' 110 | _password = 'gratitude' 111 | _host = 'baseball.cxx9lqfsabek.us-west-2.rds.amazonaws.com' 112 | _port = 5432 113 | _drivername = 'postgresql' 114 | startDate = dt.date(2008, 1, 1) 115 | _itemKeyName = 'game_pk' 116 | _tblDTypes = dict( 117 | pitch_type=_string, 118 | pitch_id=_integer, 119 | game_date=_date, 120 | start_speed=_float, 121 | x0=_float, 122 | z0=_float, 123 | player_name=_string, 124 | batter=_integer, 125 | pitcher=_integer, 126 | events=_string, 127 | description=_string, 128 | spin_dir=_float, 129 | spin_rate=_float, 130 | break_angle=_float, 131 | break_length=_float, 132 | zone=_integer, 133 | des=_string, 134 | game_type=_string, 135 | stand=_string, 136 | p_throws=_string, 137 | home_team=_string, 138 | away_team=_string, 139 | type=_string, 140 | hit_location=_integer, 141 | bb_type=_integer, 142 | balls=_integer, 143 | strikes=_integer, 144 | game_year=_integer, 145 | pfx_x=_float, 146 | pfx_z=_float, 147 | px=_float, 148 | pz=_float, 149 | on_3b=_integer, 150 | on_2b=_integer, 151 | on_1b=_integer, 152 | outs_when_up=_integer, 153 | inning=_integer, 154 | inning_topbot=_string, 155 | hc_x=_float, 156 | hc_y=_float, 157 | tfs=_integer, 158 | tfs_zulu=_string, 159 | catcher=_integer, 160 | umpire=_integer, 161 | sv_id=_string, 162 | vx0=_float, 163 | vy0=_float, 164 | vz0=_float, 165 | ax=_float, 166 | ay=_float, 167 | az=_float, 168 | sz_top=_float, 169 | sz_bot=_float, 170 | hit_distance_sc=_integer, 171 | hit_speed=_float, 172 | hit_angle=_float, 173 | effective_speed=_float, 174 | release_spin_rate=_float, 175 | release_extension=_float, 176 | game_pk=_integer) 177 | 178 | def _getItems(self, d): 179 | '''Doc string''' 180 | 181 | items = [] 182 | itemKeys = [] 183 | for v in _venues: 184 | for dummy in range(100): 185 | try: 186 | data = pd.read_csv(_baseURL.format(date=d, venue=v), 187 | parse_dates=[2], 188 | na_values='null') 189 | except Exception as e: 190 | self.logger.debug( 191 | '{!r} occurred while trying to dowload {} {}.'. 192 | format(e, v, d)) 193 | time.sleep(5) 194 | else: 195 | if not data.empty: 196 | game_pks = data.game_pk.unique() 197 | itemKeys.extend(game_pks) 198 | for game_pk in game_pks: 199 | items.append(data.iloc[ 200 | data.game_pk.values == game_pk, :]) 201 | break 202 | else: 203 | self.logger.error( 204 | 'Unable to download {} {} after {} attempts.'. 205 | format(v, d, dummy + 1)) 206 | 207 | return (items, itemKeys) 208 | -------------------------------------------------------------------------------- /statcast/database/database.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import abc 3 | 4 | from pathlib import Path 5 | import datetime as dt 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import sqlalchemy as sa 10 | 11 | from ..tools.fixpath import findFile 12 | 13 | 14 | _string = sa.types.String 15 | _integer = sa.types.Integer 16 | _float = sa.types.Float 17 | _date = sa.types.Date 18 | _binary = sa.types.Binary 19 | 20 | 21 | class Database(metaclass=abc.ABCMeta): 22 | '''Doc String''' 23 | 24 | _username = None 25 | _password = None 26 | _host = None 27 | _port = None 28 | _tblName = 'raw' 29 | _updtTblName = 'updates' 30 | _updtTblDTypes = {'cmd': _string, 'dateFrom': _date, 'dateTo': _date} 31 | 32 | @abc.abstractmethod 33 | def _drivername(): 34 | pass 35 | 36 | @abc.abstractmethod 37 | def _tblDTypes(): 38 | pass 39 | 40 | @abc.abstractmethod 41 | def dbName(): 42 | pass 43 | 44 | @abc.abstractmethod 45 | def startDate(): 46 | pass 47 | 48 | @abc.abstractmethod 49 | def _itemKeyName(): 50 | pass 51 | 52 | @abc.abstractmethod 53 | def _getItems(self, date): 54 | pass 55 | 56 | def __init__(self, fast=False): 57 | '''Doc string''' 58 | 59 | self.logger = logging.getLogger(__name__) 60 | self.logger.setLevel(logging.DEBUG) 61 | 62 | fmt = logging.Formatter( 63 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 64 | 65 | sH = logging.StreamHandler() 66 | sH.setLevel(logging.WARNING) 67 | sH.setFormatter(fmt) 68 | self.logger.addHandler(sH) 69 | 70 | # Local database 71 | if self._host is None: 72 | dbPath = findFile(self.dbName + '.db') 73 | if not dbPath: 74 | dbPath = self.dbName + '.db' 75 | logPath = str(Path(dbPath).with_name(self.dbName + '.log')) 76 | else: 77 | dbPath = self.dbName 78 | logPath = self.dbName + '.log' 79 | 80 | fH = logging.FileHandler(logPath) 81 | fH.setLevel(logging.DEBUG) 82 | fH.setFormatter(fmt) 83 | self.logger.addHandler(fH) 84 | 85 | url = sa.engine.url.URL(drivername=self._drivername, 86 | username=self._username, 87 | password=self._password, 88 | host=self._host, 89 | port=self._port, 90 | database=dbPath) 91 | self.engine = sa.create_engine(url) 92 | 93 | if self._drivername == 'postgresql': 94 | tempURL = sa.engine.url.URL(drivername=self._drivername, 95 | username=self._username, 96 | password=self._password, 97 | host=self._host, 98 | port=self._port, 99 | database='postgres') 100 | tempEngine = sa.create_engine(tempURL) 101 | dbs = pd.read_sql_query( 102 | '''SELECT datname FROM pg_database 103 | WHERE datistemplate = false''', tempEngine) 104 | if not (self.dbName == dbs).any()[0]: 105 | tempConnection = tempEngine.connect() 106 | tempConnection.execute('commit') 107 | tempConnection.execute( 108 | 'create database "{}"'.format(self.dbName)) 109 | self._init0() 110 | return 111 | 112 | if not self.engine.has_table(self._tblName): 113 | self._init0() 114 | return 115 | 116 | self.lastUpdate = pd.read_sql_query( 117 | 'SELECT "dateTo" FROM "{}" ORDER BY "dateTo" DESC LIMIT 1'. 118 | format(self._updtTblName), 119 | self.engine, parse_dates=['dateTo']).dateTo.iloc[0].date() 120 | 121 | if fast: 122 | self.itemKeys = None 123 | return 124 | 125 | self.itemKeys = list(pd.read_sql_query( 126 | 'SELECT DISTINCT "{}" FROM "{}"'.format(self._itemKeyName, 127 | self._tblName), 128 | self.engine)[self._itemKeyName]) 129 | 130 | if not self.lastUpdate == dt.date.today(): 131 | self.update() 132 | 133 | def _init0(self): 134 | '''Doc string''' 135 | 136 | self.logger.info('Initializing database') 137 | self.itemKeys = [] 138 | 139 | self._update(self.startDate) 140 | 141 | def _addItem(self, item, itemKey, replace=False): 142 | '''Doc String''' 143 | 144 | if replace: 145 | self._rmItem(itemKey) 146 | 147 | try: 148 | item.to_sql(self._tblName, self.engine, if_exists='append', 149 | index=False, dtype=self._tblDTypes) 150 | except Exception as e: 151 | bads = self._checkItem(item) 152 | if not bads: 153 | raise e 154 | else: 155 | self._fixItem(item, bads, itemKey) 156 | 157 | item.to_sql(self._tblName, self.engine, if_exists='append', 158 | index=False, dtype=self._tblDTypes) 159 | 160 | self.itemKeys.append(itemKey) 161 | 162 | def _addDate(self, d, replace=False): 163 | '''Doc string''' 164 | 165 | (items, itemKeys) = self._getItems(d) 166 | for (item, itemKey) in zip(items, itemKeys): 167 | self._addItem(item, itemKey, replace) 168 | 169 | def _addDates(self, dates, replace=False): 170 | '''Doc string''' 171 | 172 | for date in dates: 173 | self._addDate(date, replace) 174 | 175 | def _addDateRng(self, start, end=dt.date.today(), step=1, replace=False): 176 | '''Doc string''' 177 | 178 | dates = [start + dt.timedelta(ii) 179 | for ii in range(0, (end - start).days, step)] 180 | self._addDates(dates, replace) 181 | 182 | def _rmItem(self, itemKey): 183 | '''Doc String''' 184 | 185 | self.engine.execute('DELETE FROM "{}" WHERE "{}" = {}'. 186 | format(self._tblName, self._itemKeyName, itemKey)) 187 | if itemKey in self.itemKeys: 188 | self.itemKeys.remove(itemKey) 189 | 190 | def _rmDate(self, d): 191 | '''Doc String''' 192 | 193 | (items, itemKeys) = self._getItems(d) 194 | for itemKey in itemKeys: 195 | self._rmItem(itemKey) 196 | 197 | def _rmDates(self, dates): 198 | '''Doc String''' 199 | 200 | for date in dates: 201 | self._rmDate(date) 202 | 203 | def _rmDateRng(self, start, end=dt.date.today(), step=1): 204 | '''Doc String''' 205 | 206 | dates = [start + dt.timedelta(ii) 207 | for ii in range(0, (end - start).days, step)] 208 | self._rmDates(dates) 209 | 210 | def _addUpdate(self, cmd, dateFrom, dateTo): 211 | '''Doc String''' 212 | 213 | if self.engine.has_table(self._updtTblName): 214 | count = self.engine.execute( 215 | 'SELECT COUNT(*) FROM "{}"'.format(self._updtTblName)). \ 216 | fetchone() 217 | else: 218 | count = (0,) 219 | update = pd.DataFrame({'cmd': cmd, 220 | 'dateFrom': dateFrom, 221 | 'dateTo': dateTo}, 222 | index=count) 223 | update.to_sql('updates', self.engine, if_exists='append', 224 | dtype=self._updtTblDTypes) 225 | self.lastUpdate = dateTo 226 | 227 | def _update(self, start, end=dt.date.today(), replaceStart=False): 228 | '''Doc String''' 229 | 230 | if replaceStart: 231 | self._rmDate(start) 232 | 233 | self._addDateRng(start, end) 234 | self._addUpdate('update', start, end) 235 | self.logger.info('Updated database') 236 | 237 | def update(self): 238 | '''Doc String''' 239 | 240 | if not self.engine.has_table(self._tblName): 241 | print('Database not yet initialized') 242 | return 243 | 244 | self._update(self.lastUpdate, replaceStart=True) 245 | 246 | def loadItem(self, itemKey): 247 | '''Doc String''' 248 | 249 | if itemKey not in self.itemKeys: 250 | print('Item key {} not found in database'.format(itemKey)) 251 | return pd.DataFrame() 252 | 253 | return pd.read_sql_query( 254 | 'SELECT * FROM "{}" WHERE "{}" = {}'.format( 255 | self._tblName, self._itemKeyName, itemKey), 256 | self.engine, parse_dates=[k for k, v in self._tblDTypes.items() 257 | if v == _date]) 258 | 259 | def _checkItem(self, item): 260 | '''Doc String''' 261 | 262 | bads = [] 263 | for col, sqlType in self._tblDTypes.items(): 264 | ser = item.loc[:, col] 265 | if sqlType is _string: 266 | checkFunc = str 267 | elif sqlType is _integer: 268 | checkFunc = int 269 | elif sqlType is _float: 270 | checkFunc = float 271 | elif sqlType is _date: 272 | checkFunc = pd.to_datetime 273 | elif sqlType is _binary: 274 | checkFunc = bool 275 | else: 276 | raise TypeError('An invalid datatype {} was supplied for ' 277 | 'column {}'.format(sqlType, col)) 278 | for ind, elem in ser.iloc[~(ser.isnull().values)].items(): 279 | try: 280 | checkFunc(elem) 281 | except Exception as e: 282 | bads.append((ind, col, e)) 283 | return bads 284 | 285 | def _fixItem(self, item, bads, itemKey): 286 | '''Doc String''' 287 | 288 | for bad in bads: 289 | elem = item.loc[bad[0], bad[1]] 290 | item.loc[bad[0], bad[1]] = np.nan 291 | self.logger.warning( 292 | 'Bad element {} at row {}, column {} of item {} was replaced ' 293 | 'with np.nan. This exception was raised: {!r}'.format(elem, 294 | bad[0], 295 | bad[1], 296 | itemKey, 297 | bad[2])) 298 | -------------------------------------------------------------------------------- /statcast/database/gd_game_events.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | 3 | import xml.etree.ElementTree as ET 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import sqlalchemy as sa 8 | 9 | from .gddb import GdDatabase 10 | 11 | 12 | _string = sa.types.String 13 | _integer = sa.types.Integer 14 | _float = sa.types.Float 15 | _date = sa.types.Date 16 | _binary = sa.types.Binary 17 | 18 | 19 | class DB(GdDatabase): 20 | '''Doc String''' 21 | 22 | dbName = 'gdGameEvents' 23 | startDate = dt.date(2008, 1, 1) 24 | _fileName = 'game_events.xml' 25 | _tblDTypes = dict( 26 | game_pk=_integer, 27 | inning=_integer, 28 | inning_topbot=_string, 29 | entry=_string, 30 | away_team_runs=_integer, 31 | b=_integer, 32 | b1=_integer, 33 | b2=_integer, 34 | b3=_integer, 35 | batter=_integer, 36 | des=_string, 37 | des_es=_string, 38 | events=_string, 39 | events_es=_string, 40 | event_num=_integer, 41 | home_team_runs=_integer, 42 | num=_integer, 43 | o=_integer, 44 | pitcher=_integer, 45 | play_guid=_string, 46 | pitch=_integer, 47 | player=_integer, 48 | rbi=_integer, 49 | s=_integer, 50 | score=_string, 51 | tfs=_integer, 52 | tfs_zulu=_string, 53 | pitch_des=_string, 54 | pitch_des_es=_string, 55 | pitch_type=_string, 56 | start_speed=_float, 57 | sv_id=_string, 58 | type=_string) 59 | 60 | def _parseFile(self, file, itemKey): 61 | '''Doc string''' 62 | 63 | rowDict1 = dict.fromkeys(self._tblDTypes.keys(), np.nan) 64 | 65 | tree = ET.parse(file) 66 | root = tree.getroot() 67 | innings = root.findall('inning') 68 | rowDict1[self._itemKeyName] = itemKey 69 | df = pd.DataFrame() 70 | for inning in innings: 71 | rowDict1['inning'] = inning.get('num') 72 | for innHalf in inning: 73 | if innHalf.tag == 'top': 74 | rowDict1['inning_topbot'] = 'top' 75 | else: 76 | rowDict1['inning_topbot'] = 'bot' 77 | for entry in innHalf: 78 | rowDict2 = rowDict1.copy() 79 | rowDict2['entry'] = entry.tag 80 | rowDict2['tfs'] = entry.attrib.pop('start_tfs', np.nan) 81 | rowDict2['tfs_zulu'] = \ 82 | entry.attrib.pop('start_tfs_zulu', np.nan) 83 | rowDict2['events'] = \ 84 | '::'.join(entry.attrib.pop(key) 85 | for key in sorted(tuple(entry.attrib.keys())) 86 | if 'event' in key and 87 | not key.endswith(('_es', '_num'))) 88 | rowDict2['events_es'] = \ 89 | '::'.join(entry.attrib.pop(key) 90 | for key in sorted(tuple(entry.attrib.keys())) 91 | if 'event' in key and 92 | key.endswith('_es')) 93 | rowDict2.update(entry.attrib) 94 | pitches = entry.findall('pitch') 95 | if entry.tag == 'atbat' and len(pitches) > 0: 96 | for pitch in pitches: 97 | rowDict3 = rowDict2.copy() 98 | rowDict3['pitch_des'] = \ 99 | pitch.attrib.pop('des', np.nan) 100 | rowDict3['pitch_des_es'] = \ 101 | pitch.attrib.pop('des_es', np.nan) 102 | rowDict3.update(pitch.attrib) 103 | df = df.append(pd.DataFrame(rowDict3, index=(0,)), 104 | ignore_index=True) 105 | else: 106 | df = df.append(pd.DataFrame(rowDict2, index=(0,)), 107 | ignore_index=True) 108 | 109 | df.replace('', np.nan, inplace=True) 110 | return df 111 | 112 | def _fixItem(self, item, bads, itemKey): 113 | '''Doc String''' 114 | 115 | for ind, col, err in bads: 116 | if col in ('b1', 'b2', 'b3'): 117 | badElem = item.loc[ind, col] 118 | goodElem = item.loc[ind, col].split(' ')[-1] 119 | item.loc[ind, col] = goodElem 120 | self.logger.info( 121 | 'Bad element {} at row {}, column {} of item {} was ' 122 | 'replaced with {}.'.format(badElem, ind, col, 123 | itemKey, goodElem)) 124 | else: 125 | super()._fixItem(item, bads, itemKey) 126 | -------------------------------------------------------------------------------- /statcast/database/gd_scoreboards.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import datetime as dt 4 | import xml.etree.ElementTree as ET 5 | 6 | import requests 7 | import pandas as pd 8 | import sqlalchemy as sa 9 | 10 | from .database import Database 11 | 12 | 13 | _string = sa.types.String 14 | _integer = sa.types.Integer 15 | _float = sa.types.Float 16 | _date = sa.types.Date 17 | _binary = sa.types.Binary 18 | 19 | _baseURL = \ 20 | 'http://gd2.mlb.com/components/game/mlb/year_{yyyy}/month_{mm}/day_{dd}/{}' 21 | 22 | dailyScoreboard = 'master_scoreboard.xml' 23 | 24 | 25 | class DB(Database): 26 | '''Doc String''' 27 | 28 | dbName = 'gdScoreboardGames' 29 | startDate = dt.date(2008, 1, 1) 30 | _itemKeyName = 'game_pk' 31 | _username = 'matt' 32 | _password = 'gratitude' 33 | _host = 'baseball.cxx9lqfsabek.us-west-2.rds.amazonaws.com' 34 | _port = 5432 35 | _drivername = 'postgresql' 36 | _tblDTypes = dict( 37 | ampm=_string, 38 | aw_lg_ampm=_string, 39 | away_ampm=_string, 40 | away_code=_string, 41 | away_division=_string, 42 | away_file_code=_string, 43 | away_games_back=_string, 44 | away_games_back_wildcard=_string, 45 | away_league_id=_integer, 46 | away_league_id_spring=_integer, 47 | away_loss=_integer, 48 | away_name_abbrev=_string, 49 | away_split_squad=_string, 50 | away_sport_code=_string, 51 | away_team_city=_string, 52 | away_team_id=_integer, 53 | away_team_name=_string, 54 | away_time=_string, 55 | away_time_zone=_string, 56 | away_win=_integer, 57 | day=_string, 58 | description=_string, 59 | double_header_sw=_string, 60 | first_pitch_et=_string, 61 | game_data_directory=_string, 62 | game_nbr=_integer, 63 | game_pk=_integer, 64 | game_type=_string, 65 | gameday=_string, 66 | gameday_sw=_string, 67 | hm_lg_ampm=_string, 68 | home_ampm=_string, 69 | home_code=_string, 70 | home_division=_string, 71 | home_file_code=_string, 72 | home_games_back=_string, 73 | home_games_back_wildcard=_string, 74 | home_league_id=_integer, 75 | home_league_id_spring=_integer, 76 | home_loss=_integer, 77 | home_name_abbrev=_string, 78 | home_split_squad=_string, 79 | home_sport_code=_string, 80 | home_team_city=_string, 81 | home_team_id=_integer, 82 | home_team_name=_string, 83 | home_time=_string, 84 | home_time_zone=_string, 85 | home_win=_integer, 86 | id=_string, 87 | if_necessary=_string, 88 | league=_string, 89 | location=_string, 90 | original_date=_string, 91 | resume_ampm=_string, 92 | resume_away_ampm=_string, 93 | resume_away_time=_string, 94 | resume_date=_string, 95 | resume_home_ampm=_string, 96 | resume_home_time=_string, 97 | resume_time=_string, 98 | resume_time_date=_string, 99 | resume_time_date_aw_lg=_string, 100 | resume_time_date_hm_lg=_string, 101 | scheduled_innings=_integer, 102 | ser_games=_integer, 103 | ser_home_nbr=_integer, 104 | series=_string, 105 | series_num=_string, 106 | tbd_flag=_string, 107 | tiebreaker_sw=_string, 108 | time=_string, 109 | time_aw_lg=_string, 110 | time_date=_string, 111 | time_date_aw_lg=_string, 112 | time_date_hm_lg=_string, 113 | time_hm_lg=_string, 114 | time_zone=_string, 115 | time_zone_aw_lg=_integer, 116 | time_zone_hm_lg=_integer, 117 | tz_aw_lg_gen=_string, 118 | tz_hm_lg_gen=_string, 119 | venue=_string, 120 | venue_id=_integer, 121 | venue_w_chan_loc=_string, 122 | b=_string, 123 | ind=_string, 124 | inning=_integer, 125 | inning_state=_string, 126 | is_no_hitter=_string, 127 | is_perfect_game=_string, 128 | note=_string, 129 | o=_integer, 130 | reason=_string, 131 | s=_integer, 132 | status=_string, 133 | top_inning=_string) 134 | 135 | def _getItems(self, d): 136 | '''Doc string''' 137 | 138 | items = [] 139 | itemKeys = [] 140 | 141 | r = requests.get(_baseURL.format(dailyScoreboard, 142 | yyyy=d.strftime('%Y'), 143 | mm=d.strftime('%m'), 144 | dd=d.strftime('%d'))) 145 | if r.status_code != 200: 146 | return (items, itemKeys) 147 | 148 | tree = ET.parse(io.StringIO(r.text)) 149 | root = tree.getroot() 150 | games = root.findall('game') 151 | rowDict1 = dict.fromkeys(self._tblDTypes.keys()) 152 | 153 | for game in games: 154 | itemKey = int(game.attrib['game_pk']) 155 | rowDict2 = rowDict1.copy() 156 | rowDict2.update(game.attrib) 157 | status = game.find('status') 158 | if status: 159 | rowDict2.update(status.attrib) 160 | df = pd.DataFrame(rowDict2, index=(0,)) 161 | 162 | itemKeys.append(itemKey) 163 | items.append(df) 164 | 165 | return (items, itemKeys) 166 | -------------------------------------------------------------------------------- /statcast/database/gd_weather.py: -------------------------------------------------------------------------------- 1 | import datetime as dt 2 | import xml.etree.ElementTree as ET 3 | 4 | import pandas as pd 5 | import sqlalchemy as sa 6 | 7 | from .gddb import GdDatabase 8 | 9 | 10 | _string = sa.types.String 11 | _integer = sa.types.Integer 12 | _float = sa.types.Float 13 | _date = sa.types.Date 14 | 15 | 16 | class DB(GdDatabase): 17 | '''Doc String''' 18 | 19 | dbName = 'gdWeather' 20 | startDate = dt.date(2008, 1, 1) 21 | _fileName = 'plays.xml' 22 | _tblDTypes = dict( 23 | condition=_string, 24 | temp=_integer, 25 | wind=_string, 26 | game_pk=_integer) 27 | 28 | def _parseFile(self, file, itemKey): 29 | '''Doc string''' 30 | 31 | tree = ET.parse(file) 32 | root = tree.getroot() 33 | weather = root.find('weather').attrib 34 | 35 | df = pd.DataFrame({'condition': weather['condition'], 36 | 'temp': int(weather['temp']), 37 | 'wind': weather['wind'], 38 | self._itemKeyName: itemKey}, index=(0,)) 39 | 40 | return df 41 | -------------------------------------------------------------------------------- /statcast/database/gddb.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import time 3 | import io 4 | 5 | import xml.etree.ElementTree as ET 6 | 7 | import requests 8 | 9 | from .database import Database 10 | 11 | 12 | _baseurl = \ 13 | 'http://gd2.mlb.com/components/game/mlb/year_{yyyy}/month_{mm}/day_{dd}/{}' 14 | 15 | dailyScoreboard = 'master_scoreboard.xml' 16 | 17 | 18 | class GdDatabase(Database, metaclass=abc.ABCMeta): 19 | '''Doc String''' 20 | 21 | _itemKeyName = 'game_pk' 22 | _username = 'matt' 23 | _password = 'gratitude' 24 | _host = 'baseball.cxx9lqfsabek.us-west-2.rds.amazonaws.com' 25 | _port = 5432 26 | _drivername = 'postgresql' 27 | 28 | @abc.abstractmethod 29 | def _fileName(): 30 | pass 31 | 32 | @abc.abstractmethod 33 | def _parseFile(self, file, itemKey): 34 | pass 35 | 36 | def _getItems(self, d): 37 | '''Doc string''' 38 | 39 | items = [] 40 | itemKeys = [] 41 | url = _baseurl.format(dailyScoreboard, 42 | yyyy=d.strftime('%Y'), 43 | mm=d.strftime('%m'), 44 | dd=d.strftime('%d')) 45 | 46 | for dummy in range(100): 47 | try: 48 | r1 = requests.get(url) 49 | except Exception as e: 50 | self.logger.debug( 51 | '{!r} occurred while trying to dowload {}.'. 52 | format(e, url)) 53 | time.sleep(5) 54 | else: 55 | break 56 | else: 57 | self.logger.error( 58 | 'Unable to download {} after {} attempts.'. 59 | format(url, dummy + 1)) 60 | return (items, itemKeys) 61 | 62 | if r1.status_code != 200: 63 | return (items, itemKeys) 64 | 65 | tree = ET.parse(io.StringIO(r1.text)) 66 | root = tree.getroot() 67 | games = root.findall('game') 68 | 69 | for game in games: 70 | itemKey = int(game.attrib['game_pk']) 71 | gid = game.attrib['gameday'] 72 | url = _baseurl.format('gid_' + gid + '/' + self._fileName, 73 | yyyy=d.strftime('%Y'), 74 | mm=d.strftime('%m'), 75 | dd=d.strftime('%d')) 76 | 77 | for dummy in range(100): 78 | try: 79 | r2 = requests.get(url) 80 | except Exception as e: 81 | self.logger.debug( 82 | '{!r} occurred while trying to dowload {}.'. 83 | format(e, url)) 84 | time.sleep(5) 85 | else: 86 | break 87 | else: 88 | self.logger.error( 89 | 'Unable to download {} after {} attempts.'. 90 | format(url, dummy + 1)) 91 | continue 92 | 93 | if r2.status_code != 200: 94 | try: 95 | status = game.find('status').attrib['status'] 96 | except: 97 | status = None 98 | 99 | gidParts = gid.rsplit('_', 3) 100 | gameDate = gidParts[0] 101 | awayLg = gidParts[1][3:] 102 | awayTm = gidParts[1][:3] 103 | homeLg = gidParts[2][3:] 104 | homeTm = gidParts[2][:3] 105 | 106 | if not awayLg == homeLg == 'mlb': 107 | self.logger.info( 108 | ''' 109 | Received {} status code while trying to retrieve {} for gid = {}, 110 | game_pk = {} at address {}, determined game involved non-MLB team'''. 111 | format(r2.status_code, self._fileName, gid, itemKey, 112 | url)) 113 | elif awayTm == homeTm: 114 | self.logger.info( 115 | ''' 116 | Received {} status code while trying to retrieve {} for gid = {}, 117 | game_pk = {} at address {}, determined game is intra-squad'''. 118 | format(r2.status_code, self._fileName, gid, itemKey, 119 | url)) 120 | elif not d.strftime('%Y_%m_%d') == gameDate: 121 | self.logger.info( 122 | ''' 123 | Received {} status code while trying to retrieve {} for gid = {}, 124 | game_pk = {} at address {}, determined game occurred on different date'''. 125 | format(r2.status_code, self._fileName, gid, itemKey, 126 | url)) 127 | elif status is None: 128 | self.logger.warning( 129 | ''' 130 | Received {} status code while trying to retrieve {} for gid = {}, 131 | game_pk = {} at address {}, could not determine game status'''. 132 | format(r2.status_code, self._fileName, gid, itemKey, 133 | url)) 134 | elif not status == 'Final': 135 | self.logger.info( 136 | ''' 137 | Received {} status code while trying to retrieve {} for gid = {}, 138 | game_pk = {} at address {}, determined game status was {}'''. 139 | format(r2.status_code, self._fileName, gid, itemKey, 140 | url, status)) 141 | else: 142 | self.logger.warning( 143 | ''' 144 | Received {} status code while trying to retrieve {} for gid = {}, 145 | game_pk = {} at address {}, could not determine cause'''. 146 | format(r2.status_code, self._fileName, gid, itemKey, 147 | url)) 148 | continue 149 | 150 | itemKeys.append(itemKey) 151 | items.append(self._parseFile(io.StringIO(r2.text), itemKey)) 152 | 153 | return (items, itemKeys) 154 | -------------------------------------------------------------------------------- /statcast/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .tools.fixpath import findFile 4 | from .tools.plot import plotImages 5 | 6 | from . import __path__ 7 | 8 | _logoPath = os.path.join(__path__[0], 'data', 'logos') 9 | 10 | 11 | def plotMLBLogos(X, Y, sizes=20, alphas=1, ax=None): 12 | '''Doc String''' 13 | 14 | images = [findFile('{}.png'.format(team.strip()), searchDirs=_logoPath) 15 | for team in X.index] 16 | thing = plotImages(X, Y, images, sizes, alphas, ax) 17 | if ax is None: 18 | ax = thing.gca() 19 | ax.set_xlabel(X.name) 20 | ax.set_ylabel(Y.name) 21 | return thing 22 | -------------------------------------------------------------------------------- /statcast/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matteosox/statcast/4b709fcd26532d9d609e7f6d7086660d8f35de8a/statcast/tools/__init__.py -------------------------------------------------------------------------------- /statcast/tools/convolution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def conv(f, xf, g, xg, mode='full', dx=None): 5 | '''Doc String''' 6 | 7 | dxf, dxg = np.diff(xf[:2]), np.diff(xg[:2]) 8 | if not dx: 9 | if dxf == dxg: 10 | dx = dxf 11 | else: 12 | dx = dxf * dxg 13 | 14 | xfr = np.arange(xf[0], xf[-1] + dx / 2, dx) 15 | xgr = np.arange(xg[0], xg[-1] + dx / 2, dx) 16 | 17 | fr = np.interp(xfr, xf, f, right=0) 18 | gr = np.interp(xgr, xg, g, right=0) 19 | 20 | h = np.convolve(fr, gr, mode=mode) * dx 21 | xh = np.arange(xfr[0] + xgr[0], xfr[-1] + xgr[-1] + dx / 2, dx) 22 | if mode == 'same': 23 | skip = min(len(xfr), len(xgr)) // 2 24 | l = max(len(xfr), len(xgr)) 25 | xh = xh[skip:(skip + l)] 26 | elif mode == 'valid': 27 | skip = min(len(xfr), len(xgr)) - 1 28 | xh = xh[skip:(-1 - skip)] 29 | return h, xh 30 | -------------------------------------------------------------------------------- /statcast/tools/fixpath.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from pathlib import Path 4 | 5 | 6 | _basePath = sys.path.copy() 7 | 8 | 9 | def resetPath(): 10 | '''Doc String''' 11 | 12 | sys.path = _basePath.copy() 13 | 14 | 15 | def findFile(fileName, searchDirs=sys.path, subs=False, 16 | findAll=False): 17 | '''Doc String''' 18 | 19 | if searchDirs is None: 20 | searchDirs = [Path.cwd().anchor] 21 | subs = True 22 | elif isinstance(searchDirs, str): 23 | searchDirs = [searchDirs] 24 | filePaths = [] 25 | for searchDir in searchDirs: 26 | for root, subDirs, files in os.walk(os.path.abspath(searchDir)): 27 | if fileName in files: 28 | filePath = os.path.join(root, fileName) 29 | if not findAll: 30 | return filePath 31 | filePaths.append(filePath) 32 | if subs: 33 | absSubDirs = [os.path.join(root, subDir) 34 | for subDir in subDirs] 35 | filePaths.extend(findFile(fileName, 36 | searchPaths=absSubDirs, 37 | subs=subs, 38 | findAll=findAll)) 39 | if findAll: 40 | return filePaths 41 | raise FileNotFoundError('Could not find {}'.format(fileName)) 42 | 43 | 44 | def addPath(path=Path.cwd(), subs=True, reset=False): 45 | '''Doc String''' 46 | 47 | if reset: 48 | resetPath() 49 | if isinstance(path, (tuple, list)): 50 | for p in reversed(path): 51 | addPath(path=p, subs=subs) 52 | return 53 | 54 | path = str(path) 55 | if path in sys.path: 56 | sys.path.remove(path) 57 | 58 | if subs: 59 | for child in Path(path).iterdir(): 60 | if child.is_dir(): 61 | addPath(child, subs=subs) 62 | sys.path.insert(0, str(path)) 63 | -------------------------------------------------------------------------------- /statcast/tools/montecarlo.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | 4 | eps = np.sqrt(np.spacing(1)) 5 | 6 | 7 | def rectSampler(lims): 8 | '''Doc String''' 9 | 10 | x0 = np.array([lim[0] for lim in lims]) 11 | d = np.array([lim[1] - lim[0] for lim in lims]) 12 | ndim = x0.size 13 | vol = d.prod() 14 | 15 | return lambda n: ( 16 | np.random.rand(n, ndim) * np.tile(d, (n, 1)) + np.tile(x0, (n, 1)) 17 | ), vol 18 | 19 | 20 | def integrate(f, sampler, vol, relErr=0, absErr=eps, n0=100): 21 | '''Doc String''' 22 | 23 | X = sampler(n0) 24 | y = f(X) 25 | 26 | while True: 27 | mu = y.mean() 28 | std = np.std(y, ddof=1) 29 | i = vol * mu 30 | se = vol * std / np.sqrt(y.size) 31 | if se <= absErr: 32 | break 33 | if i != 0: 34 | if np.abs(se / i) <= relErr: 35 | break 36 | 37 | nA = (vol * std / absErr) ** 2 38 | if relErr > 0: 39 | nR = (std / relErr / mu) ** 2 40 | else: 41 | nR = nA + 1 42 | n = np.ceil(min(nA, nR) - y.size).astype(int) 43 | X = sampler(n) 44 | y = np.concatenate([y, f(X)]) 45 | 46 | return i, se 47 | 48 | 49 | class Region(): 50 | '''Doc String''' 51 | 52 | def __init__(self, f, lims, X=None, y=None): 53 | '''Doc String''' 54 | 55 | self.X = X 56 | self.y = y 57 | 58 | self.lims = lims 59 | self.f = f 60 | self.sampler, self.vol = rectSampler(lims) 61 | 62 | def sample(self, n): 63 | '''Doc String''' 64 | 65 | if n == 0: 66 | return 67 | X = self.sampler(n) 68 | y = self.f(X) 69 | 70 | if self.X is not None: 71 | self.X = np.concatenate([self.X, X]) 72 | else: 73 | self.X = X 74 | 75 | if self.y is not None: 76 | self.y = np.concatenate([self.y, y]) 77 | else: 78 | self.y = y 79 | 80 | @property 81 | def integral(self): 82 | return self.y.mean() * self.vol 83 | 84 | def se(self, n=None): 85 | if n is None: 86 | n = self.y.size 87 | return self.vol * self.y.std(ddof=1) / np.sqrt(n) 88 | 89 | def seImprovement(self, n): 90 | return self.se() - self.se(self.y.size + n) 91 | 92 | def split(self, n): 93 | '''Doc String''' 94 | 95 | variance = np.inf 96 | nOld = self.y.size 97 | nNew = n + nOld 98 | 99 | for i, (xMin, xMax) in enumerate(self.lims): 100 | xMid = (xMin + xMax) / 2 101 | 102 | inds = self.X[:, i] <= xMid 103 | lStd = self.y[inds].std(ddof=1) 104 | rStd = self.y[~inds].std(ddof=1) 105 | lNOld = inds.sum() 106 | rNOld = nOld - lNOld 107 | lNNew = min(max(np.round(lStd / (lStd + rStd) * nNew).astype(int), 108 | lNOld), lNOld + n) 109 | rNNew = nNew - lNNew 110 | newVar = lStd ** 2 / (4 * lNNew) + rStd ** 2 / (4 * rNNew) 111 | if newVar < variance: 112 | variance = newVar 113 | splitInd = i 114 | lNAdd = lNNew - lNOld 115 | rNAdd = rNNew - rNOld 116 | 117 | lLims = deepcopy(self.lims) 118 | rLims = deepcopy(self.lims) 119 | xMid = sum(self.lims[splitInd]) / 2 120 | lLims[splitInd][1] = xMid 121 | rLims[splitInd][0] = xMid 122 | inds = self.X[:, splitInd] <= xMid 123 | lX = self.X[inds, :] 124 | ly = self.y[inds] 125 | rX = self.X[~inds, :] 126 | ry = self.y[~inds] 127 | 128 | lRegion = Region(self.f, lLims, lX, ly) 129 | lRegion.sample(lNAdd) 130 | rRegion = Region(self.f, rLims, rX, ry) 131 | rRegion.sample(rNAdd) 132 | 133 | return lRegion, rRegion 134 | 135 | 136 | def stratifiedIntegrate(f, lims, relErr=0, absErr=eps, n=100): 137 | '''Doc String''' 138 | 139 | region = Region(f, lims) 140 | region.sample(4 * n) 141 | regions = [region] 142 | se2s = [region.se() ** 2 for region in regions] 143 | ints = [region.integral for region in regions] 144 | seImps = [region.seImprovement(2 * n) for region in regions] 145 | 146 | while True: 147 | i = sum(ints) 148 | se = np.sqrt(sum(se2s)) 149 | if se <= absErr: 150 | break 151 | if i != 0: 152 | if np.abs(se / i) <= relErr: 153 | break 154 | ind = np.argmax(seImps) 155 | region = regions.pop(ind) 156 | del se2s[ind] 157 | del ints[ind] 158 | del seImps[ind] 159 | newRegions = region.split(2 * n) 160 | regions.extend(newRegions) 161 | se2s.extend([newRegion.se() ** 2 for newRegion in newRegions]) 162 | ints.extend([newRegion.integral for newRegion in newRegions]) 163 | seImps.extend([newRegion.seImprovement(2 * n) 164 | for newRegion in newRegions]) 165 | 166 | return i, se 167 | -------------------------------------------------------------------------------- /statcast/tools/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from scipy import stats 6 | from matplotlib import pyplot as plt 7 | from matplotlib import lines as mlines 8 | 9 | try: 10 | from PIL import Image as pilimg 11 | except ImportError: 12 | imaging = False 13 | else: 14 | imaging = True 15 | 16 | from sklearn.metrics import precision_recall_curve 17 | 18 | from ..better.kde import BetterKernelDensity 19 | 20 | from . import __path__ 21 | 22 | plt.style.use(os.path.join(os.path.dirname(__path__[0]), 23 | 'data', 'blackontrans.mplstyle')) 24 | 25 | 26 | def correlationPlot(Y, Yp, labels=None, units=None, **plotParams): 27 | '''Doc String''' 28 | 29 | # Handle Pandas DataFrames 30 | if isinstance(Y, pd.DataFrame): 31 | if labels is None: 32 | labels = list(Y.columns) 33 | Y = Y.values 34 | if isinstance(Yp, pd.DataFrame): 35 | Yp = Yp.values 36 | 37 | # Handle 1D arrays 38 | if Y.ndim == 1: 39 | Y = Y[:, None] 40 | if Yp.ndim == 1: 41 | Yp = Yp[:, None] 42 | 43 | # Handle row vectors 44 | if Y.shape[0] == 1: 45 | Y = Y.T 46 | if Yp.shape[0] == 1: 47 | Yp = Yp.T 48 | 49 | # Handle no label or unit inputs 50 | if labels is None: 51 | labels = [None for dummy in range(Y.shape[1])] 52 | if units is None: 53 | units = [None for dummy in range(Y.shape[1])] 54 | 55 | figs = [] 56 | 57 | for y, yp, label, unit in zip(Y.T, Yp.T, 58 | labels, units): 59 | rmsErr = np.sqrt(np.mean((y - yp) ** 2)) 60 | r2 = stats.pearsonr(y, yp)[0] ** 2 61 | mae = np.mean(np.abs(y - yp)) 62 | 63 | fig = plt.figure() 64 | ax = fig.add_subplot(1, 1, 1) 65 | ax.plot(y, yp, '.', **plotParams) 66 | 67 | if unit is not None: 68 | ax.set_xlabel('Actual ({})'.format(unit)) 69 | ax.set_ylabel('Prediction ({})'.format(unit)) 70 | else: 71 | ax.set_xlabel('Actual') 72 | ax.set_ylabel('Prediction') 73 | 74 | if label is not None: 75 | ax.set_title(label) 76 | 77 | axLims = list(ax.axis()) 78 | axLims[0] = axLims[2] = min(axLims[0::2]) 79 | axLims[1] = axLims[3] = max(axLims[1::2]) 80 | ax.axis(axLims) 81 | ax.plot(axLims[:2], axLims[2:], 82 | '--', color=plt.rcParams['lines.color'], linewidth=1) 83 | 84 | labels = ['{}: {:.2f}'.format(name, stat) 85 | for name, stat in zip(['RMSE', 'R2', 'MAE'], 86 | [rmsErr, r2, mae])] 87 | addText(ax, labels, loc='lower right') 88 | 89 | figs.append(fig) 90 | 91 | return figs 92 | 93 | 94 | def addText(ax, text, loc='best', **kwargs): 95 | '''Doc String''' 96 | 97 | if 'right' in loc: 98 | markerfirst = False 99 | else: 100 | markerfirst = True 101 | 102 | handles = [mlines.Line2D([], [], alpha=0.0)] * len(text) 103 | ax.legend(handles=handles, labels=text, loc=loc, frameon=False, 104 | handlelength=0, handletextpad=0, markerfirst=markerfirst, 105 | **kwargs) 106 | return 107 | 108 | 109 | def plotKDHist(data, kernel='epanechnikov', bandwidth=None, alpha=5e-2, 110 | ax=None, n_jobs=1, cv=None): 111 | '''Doc String''' 112 | 113 | if data.ndim < 2: 114 | data = data[:, None] 115 | xmin, xmax = min(data), max(data) 116 | 117 | if bandwidth is None: 118 | kde = BetterKernelDensity(kernel=kernel, rtol=1e-4, 119 | normalize=False).fit(data) 120 | kde.selectBandwidth(n_jobs=n_jobs, cv=cv) 121 | else: 122 | kde = BetterKernelDensity(kernel=kernel, rtol=1e-4, 123 | normalize=False, 124 | bandwidth=bandwidth).fit(data) 125 | 126 | xFit = np.linspace(xmin - kde.bandwidth, xmax + kde.bandwidth, 1e3) 127 | yFit = kde.predict(xFit[:, None]) 128 | fitL, fitU = kde.confidence(xFit[:, None], alpha=alpha) 129 | 130 | if ax is not None: 131 | ax.plot(xFit, yFit * 100) 132 | ax.fill_between(xFit, fitU * 100, fitL * 100, alpha=0.35, lw=0, 133 | label='{:.0f}% Confidence Interval'. 134 | format(1e2 * (1 - alpha))) 135 | return ax, kde 136 | 137 | fig = plt.figure() 138 | ax = fig.add_subplot(1, 1, 1) 139 | ax.plot(xFit, yFit * 100) 140 | ax.fill_between(xFit, fitU * 100, fitL * 100, alpha=0.35, lw=0, 141 | label='{:.0f}% Confidence Interval'. 142 | format(1e2 * (1 - alpha))) 143 | 144 | try: 145 | ax.set_xlabel(data.name) 146 | except AttributeError: 147 | pass 148 | 149 | ax.set_ylabel('Probability Density (%)') 150 | ax.set_xlim(left=xmin - kde.bandwidth, right=xmax + kde.bandwidth) 151 | ax.set_ylim(bottom=0, auto=True) 152 | return fig, kde 153 | 154 | 155 | def plotPrecRec(y, yp, ax=None, label=None): 156 | '''Doc String''' 157 | 158 | if ax is None: 159 | fig = plt.figure() 160 | ax = fig.add_subplot(1, 1, 1) 161 | 162 | if y.shape[0] != yp.shape[0]: 163 | raise Exception('Number of rows in y & yp must match') 164 | 165 | prec, rec, _ = precision_recall_curve(y, yp) 166 | 167 | ax.plot(rec, prec, label=label) 168 | ax.set_xlabel('Recall') 169 | ax.set_ylabel('Precision') 170 | ax.set_xlim(0, 1) 171 | ax.set_ylim(0, 1) 172 | 173 | try: 174 | return fig 175 | except NameError: 176 | return ax 177 | 178 | 179 | def plotPrecRecMN(y, yp, ax=None, labels=None): 180 | '''Doc String''' 181 | 182 | if ax is None: 183 | fig = plt.figure() 184 | ax = fig.add_subplot(1, 1, 1) 185 | 186 | if y is pd.api.types.CategoricalDtype(): 187 | classes = sorted(y.cat.categories) 188 | else: 189 | classes = np.unique(y) 190 | 191 | if labels is None: 192 | labels = classes 193 | 194 | if not classes.shape[0] == yp.shape[1] == len(labels): 195 | raise Exception('Number of classes in y must match number of ' 196 | 'columns in yp and number labels.') 197 | elif y.shape[0] != yp.shape[0]: 198 | raise Exception('Number of rows in y & yp must match.') 199 | 200 | prec = {} 201 | rec = {} 202 | Y = [] 203 | 204 | for i, klass in enumerate(classes): 205 | yi = y == klass 206 | Y.append(yi) 207 | prec[klass], rec[klass], _ = precision_recall_curve(yi, yp[:, i]) 208 | 209 | Y = np.array(Y) 210 | prec['micro'], rec['micro'], _ = \ 211 | precision_recall_curve(Y.ravel(), yp.ravel()) 212 | 213 | ax.plot(rec['micro'] * 100, prec['micro'] * 100, label='micro-average') 214 | 215 | for klass, label in zip(classes, labels): 216 | ax.plot(rec[klass] * 100, prec[klass] * 100, label=label) 217 | 218 | ax.legend() 219 | ax.set_xlabel('Recall (%)') 220 | ax.set_ylabel('Precision (%)') 221 | ax.set_xlim(0, 100) 222 | ax.set_ylim(0, 100) 223 | 224 | try: 225 | return fig 226 | except NameError: 227 | return ax 228 | 229 | 230 | def plotResiduals(X, Y, Yp, xLabels=None, xUnits=None, yLabels=None, 231 | yUnits=None, pltParams={}): 232 | 233 | if X.ndim == 1: 234 | X = X[:, None] 235 | 236 | if Y.ndim == 1: 237 | Y = Y[:, None] 238 | 239 | if Yp.ndim == 1: 240 | Yp = Yp[:, None] 241 | 242 | if Y.shape != Yp.shape: 243 | raise Exception('Y ({}) and Yp ({}) ' 244 | 'must have same dimensions'.format(Y.shape, Yp.shape)) 245 | 246 | if X.shape[0] != Y.shape[0]: 247 | raise Exception('X ({}) and Y ({}) must have same number of ' 248 | 'rows'.format(X.shape[0], Y.shape[0])) 249 | 250 | if xLabels is None: 251 | try: 252 | xLabels = list(X.columns) 253 | except AttributeError: 254 | xLabels = [''] * X.shape[1] 255 | elif len(xLabels) != X.shape[1]: 256 | raise Exception('xLabels must have same length ({}) as the number of ' 257 | 'columns of X ({})'.format(len(xLabels), X.shape[1])) 258 | else: 259 | xLabels = list(xLabels) 260 | 261 | if xUnits is not None: 262 | if len(xUnits) != len(xLabels): 263 | raise Exception('xUnits must have same length ({}) ' 264 | 'as xLabels ({}) if supplied'.format(len(xUnits), 265 | len(xLabels))) 266 | for i, xUnit in enumerate(xUnits): 267 | xLabels[i] += ' ({})'.format(xUnit) 268 | 269 | if yLabels is None: 270 | try: 271 | yLabels = list(Y.columns) 272 | except AttributeError: 273 | yLabels = [''] * Y.shape[1] 274 | elif len(yLabels) != Y.shape[1]: 275 | raise Exception('yLabels must have same length ({}) as the number of ' 276 | 'columns of Y ({})'.format(len(yLabels), Y.shape[1])) 277 | else: 278 | yLabels = list(yLabels) 279 | 280 | yLabels = [yLabel + ' Error' for yLabel in yLabels] 281 | 282 | if yUnits is not None: 283 | if len(yUnits) != len(yLabels): 284 | raise Exception('yUnits must have same length ({}) ' 285 | 'as yLabels ({}) if supplied'.format(len(yUnits), 286 | len(yLabels))) 287 | for i, yUnit in enumerate(yUnits): 288 | yLabels[i] += ' ({})'.format(yUnit) 289 | 290 | figs = [] 291 | for i, xLabel in enumerate(xLabels): 292 | fig = plt.figure(figsize=(10.21, 3 * len(yLabels))) 293 | for j, yLabel in enumerate(yLabels): 294 | ax = fig.add_subplot(len(yLabels), 1, j + 1) 295 | ax.plot(X[:, i], Yp[:, j] - Y[:, j], '.', **pltParams) 296 | ax.set_ylabel(yLabel) 297 | ax.set_xlabel(xLabel) 298 | figs.append(fig) 299 | 300 | return figs 301 | 302 | 303 | if imaging: 304 | def plotImages(X, Y, images, sizes=20, alphas=1, ax=None): 305 | '''Doc String''' 306 | 307 | if not isinstance(sizes, (list, tuple)): 308 | sizes = (sizes,) * len(X) 309 | if not isinstance(images, (list, tuple)): 310 | images = (images,) * len(X) 311 | if not isinstance(alphas, (list, tuple)): 312 | alphas = (alphas,) * len(X) 313 | 314 | ims = [pilimg.open(image) for image in images] 315 | 316 | if ax is None: 317 | fig = plt.figure() 318 | ax = fig.add_subplot(1, 1, 1) 319 | 320 | # plot and set axis limits 321 | ax.plot(X, Y, 'o', mfc='None', mec='None', markersize=max(sizes)) 322 | ax.axis(ax.axis()) 323 | 324 | for size, im, alpha, x, y in zip(sizes, ims, alphas, X, Y): 325 | offsetsPx = np.array([sz / max(im.size) * size / 72 / 2 * 326 | ax.get_figure().dpi for sz in im.size]) 327 | pxPerUnit = ax.transData.transform((1, 1)) - \ 328 | ax.transData.transform((0, 0)) 329 | offsetsUnit = offsetsPx / pxPerUnit 330 | extent = (x - offsetsUnit[0], x + offsetsUnit[0], 331 | y - offsetsUnit[1], y + offsetsUnit[1]) 332 | ax.imshow(im, alpha=alpha, extent=extent, aspect='auto', 333 | interpolation='bilinear') 334 | 335 | try: 336 | return fig 337 | except NameError: 338 | return ax 339 | --------------------------------------------------------------------------------