├── README.md
├── example.py
└── pystan_sklearn
    └── __init__.py

/README.md:
--------------------------------------------------------------------------------
# pystan-sklearn

This module provides a scikit-learn estimator class based on Stan (http://github.com/stan-dev/pystan).
It makes scikit-learn's model-selection tools, such as cross-validation and grid search, available for fitting and checking Stan models.

Run example.py from the root pystan-sklearn directory for an example of a grid search over the 'mu' hyperparameter in the 'Eight Schools' model.
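
The intended usage pattern (a minimal sketch only; `MyModelEstimator`, `model_code`, `data`, `param_grid`, `X` and `y` are placeholders, and example.py is the full, runnable version) is to subclass `StanEstimator`, implement `make_data`, `predict_` and `score_` for your model, and hand the estimator to scikit-learn:

```python
from sklearn.model_selection import GridSearchCV
from pystan_sklearn import StanEstimator

class MyModelEstimator(StanEstimator):
    def make_data(self, search_data=None):
        # Build and return the Stan data dictionary for this model.
        ...

    def predict_(self, X, i):
        # Return the model's prediction for the single sample at index i.
        ...

    def score_(self, prediction, y):
        # Score the predictions against the observations y,
        # e.g. as a summed log-likelihood.
        ...

estimator = MyModelEstimator()
estimator.set_model(model_code)   # compile the Stan program
estimator.set_data(data)          # expose the data dictionary as attributes
grid = GridSearchCV(estimator, param_grid, cv=10)
grid.fit(X, y)
```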
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy.stats import norm
from sklearn.model_selection import ShuffleSplit, GridSearchCV
from pystan_sklearn import StanEstimator

#############################################################
# All of this comes from the eight schools example.
schools_code = """
data {
    int J; // number of schools
    real y[J]; // estimated treatment effects
    real sigma[J]; // s.e. of effect estimates
}
parameters {
    real mu;
    real tau;
    real eta[J];
}
transformed parameters {
    real theta[J];
    for (j in 1:J)
        theta[j] = mu + tau * eta[j];
}
model {
    eta ~ normal(0, 1);
    y ~ normal(theta, sigma);
}
"""

schools_dat = {'J': 8,
               'y': [28, 8, -3, 7, -1, 1, 18, 12],
               'sigma': [15, 10, 16, 11, 9, 11, 10, 18]}
#############################################################

# First we have to make an estimator specific to our model.
# For now, I don't have a good way of generating this automatically
# from the model code.
class EightSchoolsEstimator(StanEstimator):
    # Implement a make_data method for the estimator.
    # This tells the sklearn estimator what to pass along as data
    # to the Stan model.
    # This is trivial here but can be more complex for larger models.
    def make_data(self, search_data=None):
        data = dict(schools_dat)  # copy, so the module-level dict is not mutated
        if search_data:
            data.update({key: value[0] for key, value in search_data.items()})
        return data

    # Implement a predict_ method for the estimator.
    # This tells the sklearn estimator how to make a prediction for one sample.
    # It is based on the transformed parameter theta in the model above.
    def predict_(self, X, j):
        theta_j = self.mu + self.tau * self.eta[j]
        return (theta_j, self.sigma[j])

    # Implement a score_ method for the estimator.
    # This tells the sklearn estimator how to score the observed samples
    # against the predictions from the model.
    # Here it is the summed log-likelihood under the fitted theta and sigma.
    def score_(self, prediction, y):
        likelihoods = np.zeros(len(y))
        for j, (theta_j, sigma_j) in enumerate(prediction):
            likelihoods[j] = norm.pdf(y[j], theta_j, sigma_j)
        return np.log(likelihoods).sum()

# Initialize the estimator instance.
estimator = EightSchoolsEstimator()
# Compile the model code.
estimator.set_model(schools_code)

# Search over these values of 'mu'.
search_data = {'mu': [0.3, 1.0, 3.0]}
# Create a data dictionary for use with the estimator.
# Note that 'data' means different things in sklearn and in Stan.
data = estimator.make_data(search_data=search_data)
# Set the data (this also sets the entries as estimator attributes).
estimator.set_data(data)

# Set the y data.
# Use the observed effects from the Stan code here (i.e. "y").
y = data['y']
# Set the X data, i.e. the covariates.
# In this example there is no X data, so we just use an array of ones.
X = np.ones((len(y), 1))

# Fraction of the data held out for testing.
test_size = 2.0 / len(y)
# A cross-validation splitter from sklearn; the splits are over the J schools.
cv = ShuffleSplit(n_splits=10, test_size=test_size)
# A grid search over the parameter values, from sklearn.
grid = GridSearchCV(estimator, search_data, cv=cv)

# Fit the model over the parameter grid.
grid.fit(X, y)

# Print the parameter values with the best score (best predictive accuracy).
print(grid.best_params_)
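
# Not part of the original example, but the fitted GridSearchCV object also
# exposes the scores themselves; best_score_ and cv_results_ are standard
# scikit-learn attributes.
print(grid.best_score_)
print(grid.cv_results_['mean_test_score'])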
--------------------------------------------------------------------------------
/pystan_sklearn/__init__.py:
--------------------------------------------------------------------------------
import numpy as np
import pystan
from sklearn.base import BaseEstimator

class StanModel_(pystan.StanModel):
    def __del__(self):
        """
        __del__ is called carelessly by sklearn's GridSearchCV as it creates
        and destroys copies of the estimator, which causes the directory
        containing the compiled Stan code to be deleted. It is replaced here
        with an empty method to avoid that problem, at the cost of a potential
        proliferation of temporary directories.
        """
        pass

class StanEstimator(BaseEstimator):
    """
    An sklearn estimator class derived for use with pystan.
    """
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def set_model(self, code):
        """
        Sets and compiles a Stan model for this estimator.
        """
        self.model = StanModel_(model_code=code)

    def set_data(self, *args, **kwargs):
        """
        Sets the data for use with this estimator.
        Uses the data dictionary given as the first positional argument or as
        the 'data' keyword argument if provided, otherwise falls back to the
        'make_data' method.
        """
        data = kwargs.get('data', args[0] if args else None)
        if data is None:
            data = self.make_data()
        self.data = data
        for key, value in data.items():
            setattr(self, key, value)

    def make_data(self, *args, **kwargs):
        """
        A model-specific method for constructing the data to be used by the
        model. May be limited to the data passed to the Stan model's fitter,
        or may include other items as well. Should return a dictionary.
        """
        raise NotImplementedError("")

    def optimize(self, X, y):
        """
        Optimizes the estimator based on covariates X and observations y.
        """
        # Refresh the data dict from the estimator attributes, so that any
        # values set by sklearn (e.g. during a grid search) are passed to Stan.
        for key in self.data.keys():
            self.data[key] = getattr(self, key)
        self.best = self.model.optimizing(data=self.data)

    def get_params(self, deep=False):
        """
        Gets model parameters. These are just attributes of the estimator
        as set in __init__ and possibly in other methods.
        """
        return self.__dict__

    def fit(self, X, y):
        """
        Fits the estimator based on covariates X and observations y.
        """
        self.optimize(X, y)
        for key, value in self.best.items():
            setattr(self, key, value)
        return self

    def transform(self, X, y=None, **fit_params):
        """
        Performs a transform step on the covariates after fitting.
        In the basic form here it just returns the covariates.
        """
        return X

    def predict(self, X):
        """
        Generates a prediction based on X, the array of covariates.
        """
        n_samples = X.shape[0]
        prediction = []
        for i in range(n_samples):
            prediction.append(self.predict_(X, i))
        return prediction

    def predict_(self, X, i):
        """
        Generates a prediction for one sample, based on X, the array of
        covariates, and i, a point (1D) or row (2D) in that array.
        This must be implemented for each model.
        """
        raise NotImplementedError("")

    def score(self, X, y):
        """
        Generates a score for the prediction based on X, the array of
        covariates, and y, the observations.
        """
        prediction = self.predict(X)
        return self.score_(prediction, y)

    def score_(self, prediction, y):
        """
        Generates a score based on the prediction (from X) and the
        observations y.
        """
        raise NotImplementedError("")

    @classmethod
    def get_posterior_mean(cls, fit):
        """
        Implemented because get_posterior_mean is (was?) broken in pystan:
        https://github.com/stan-dev/pystan/issues/107
        """
        means = {}
        x = fit.extract()
        # Skip the last entry, which is the log-posterior (lp__) rather than
        # a model parameter.
        for key, value in list(x.items())[:-1]:
            means[key] = value.mean(axis=0)
        return means
--------------------------------------------------------------------------------