├── Tree based methods for prediction.pdf ├── data ├── PhysionetChallenge2012-set-a.csv.gz └── PhysionetChallenge2012-set-b-no-outcome.csv.gz ├── README.md ├── LICENSE ├── requirements.txt ├── .gitignore ├── trees.Rmd ├── pipeline_introduction.ipynb └── etc └── presentation-plots.ipynb /Tree based methods for prediction.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alistairewj/tree-prediction-tutorial/HEAD/Tree based methods for prediction.pdf -------------------------------------------------------------------------------- /data/PhysionetChallenge2012-set-a.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alistairewj/tree-prediction-tutorial/HEAD/data/PhysionetChallenge2012-set-a.csv.gz -------------------------------------------------------------------------------- /data/PhysionetChallenge2012-set-b-no-outcome.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alistairewj/tree-prediction-tutorial/HEAD/data/PhysionetChallenge2012-set-b-no-outcome.csv.gz -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tree-prediction-tutorial 2 | 3 | Tutorial on prediction of medical data using tree-based models 4 | 5 | ## Setting up the environment 6 | 7 | Most requirements should be met with a simple pip install: 8 | 9 | `pip install pydotplus numpy pandas scikit-learn matplotlib jupyter` 10 | 11 | If you want the exact versions of all the packages, you can use the `requirements.txt` file: 12 | 13 | `pip install -r requirements.txt` 14 | 15 | ## Running the notebook 16 | 17 | The `pipeline_introduction.ipynb` notebook should run out of the box. Its first cell contains a link for running the notebook on [Google's Colaboratory](https://colab.research.google.com/notebooks/welcome.ipynb). 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Alistair Johnson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backcall==0.1.0 2 | bleach==3.1.4 3 | cycler==0.10.0 4 | decorator==4.3.0 5 | defusedxml==0.5.0 6 | entrypoints==0.2.3 7 | ipykernel==5.1.0 8 | ipython==7.0.1 9 | ipython-genutils==0.2.0 10 | ipywidgets==7.4.2 11 | jedi==0.13.1 12 | Jinja2==2.10 13 | jsonschema==2.6.0 14 | jupyter==1.0.0 15 | jupyter-client==5.2.3 16 | jupyter-console==6.0.0 17 | jupyter-core==4.4.0 18 | kiwisolver==1.0.1 19 | MarkupSafe==1.0 20 | matplotlib==3.0.0 21 | mistune==0.8.4 22 | nbconvert==5.4.0 23 | nbformat==4.4.0 24 | notebook==6.1.5 25 | numpy==1.15.3 26 | pandas==0.23.4 27 | pandocfilters==1.4.2 28 | parso==0.3.1 29 | pexpect==4.6.0 30 | pickleshare==0.7.5 31 | prometheus-client==0.4.2 32 | prompt-toolkit==2.0.6 33 | ptyprocess==0.6.0 34 | pydotplus==2.0.2 35 | Pygments==2.2.0 36 | pyparsing==2.2.2 37 | python-dateutil==2.7.3 38 | pytz==2018.5 39 | pyzmq==17.1.2 40 | qtconsole==4.4.2 41 | scikit-learn==0.20.0 42 | scipy==1.1.0 43 | Send2Trash==1.5.0 44 | simplegeneric==0.8.1 45 | six==1.11.0 46 | sklearn==0.0 47 | terminado==0.8.1 48 | testpath==0.4.2 49 | tornado==5.1.1 50 | traitlets==4.3.2 51 | wcwidth==0.1.7 52 | webencodings==0.5.1 53 | widgetsnbextension==3.4.2 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /trees.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "trees" 3 | author: "Alistair Johnson" 4 | date: "October 23, 2018" 5 | output: html_document 6 | --- 7 | 8 | ```{r setup, include=FALSE} 9 | install.packages("mlbench") 10 | install.packages("e1071") 11 | install.packages("xgboost") 12 | install.packages("fastAdaboost") 13 | install.packages("caret") 14 | install.packages("rpart") 15 | install.packages("rpart.plot") 16 | install.packages("adabag") 17 | install.packages("randomForest") 18 | install.packages("gbm") 19 | install.packages("ROCR") 20 | install.packages("RCurl") 21 | library(rpart) 22 | library(rpart.plot) 23 | library(adabag) 24 | library(randomForest) 25 | library(MASS) 26 | library(gbm) 27 | library(ROCR) 28 | library(RCurl) 29 | ``` 30 | 31 | Many thanks to Brian Healy for the script which eventually became this R Markdown! 32 | 33 | # Dataset 34 | 35 | The dataset we'll use is a classic: Fisher's iris. This data was collected by Edgar Anderson and was used by Fisher to demonstrate Linear Discriminant Analysis (LDA). We won't talk about LDA in this tutorial but it's an interesting technique worth learning about! 36 | 37 | The iris dataset includes the petal and sepal measurements for three types of flowers. The below code: 38 | 39 | * loads the built-in `iris` dataset 40 | * keeps only two predictor columns (petal length and sepal length) plus the species label 41 | * drops the setosa observations and re-levels the `Species` factor 42 | * prints the first few rows of the resulting `practice` data frame 43 | 44 | Note that we only use two columns of data because we'd like to visualize the classifier. Also, we drop setosa because we'd like to focus on two plants: versicolour and virginica. The data for these plants are not linearly separable (i.e. you cannot draw a straight line through the data to split the two plants into groups). 45 | 46 | ```{r iris} 47 | # Creating small dataset with only two species and two predictors 48 | practice <- iris[iris$Species!="setosa",names(iris) %in% c("Petal.Length","Sepal.Length","Species")] 49 | practice$Species<-factor(practice$Species) 50 | head(practice) 51 | ``` 52 | 53 | We use `head` to look at the first few rows - and can see we have a variety of sepal length and petal length measurements for versicolor. 54 | 55 | ## Decision tree 56 | 57 | Let's build the simplest tree model we can think of: a classification tree with only one split. Decision trees of this form are commonly referred to under the umbrella term Classification and Regression Trees (CART) [1]. 
While we will only be looking at classification here, regression isn't too different. After grouping the data (which is essentially what a decision tree does), classification involves assigning all members of the group to the majority class of that group during training. Regression is the same, except you would assign the average value, not the majority. In the case of a decision tree with one split, often called a "stump", the model will partition the data into two groups, and assign classes for those two groups based on majority vote. There are many parameters available to control how `rpart` grows the tree; by specifying `maxdepth=1` in `rpart.control` we will build a decision tree with only one split - i.e. of depth 1. 58 | 59 | [1] L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees. Wadsworth, Belmont, CA, 1984. 60 | 61 | ```{r dt, echo=FALSE} 62 | # Fitting decision tree with only one split 63 | iris.dt<-rpart(Species~Sepal.Length+Petal.Length, data=practice, method="class", 64 | control=rpart.control(maxdepth=1)) 65 | ``` 66 | 67 | Since our model is so simple, we can actually look at the full decision tree. 68 | 69 | ```{r dtplot, echo=FALSE} 70 | # Plotting the tree 71 | plot(iris.dt, uniform=TRUE, main="Classification Iris dataset") 72 | text(iris.dt, use.n=TRUE, all=TRUE, cex=.8) 73 | prp(iris.dt, faclen = 0, cex = 0.8, extra = 1) 74 | practice$pointcolor<-ifelse(practice$Species=="versicolor","orange","blue") 75 | 76 | # print the tree 77 | print(iris.dt) 78 | ``` 79 | 80 | Here we see three nodes: one at the top, one in the lower left, and one in the lower right. It's easier to read this from the print out, however. 81 | 82 | The top node (#1) is the root of the tree: it contains all the data. Next to "root", the number "100" reminds us how many rows of data are assessed at this node, the next number "50" tells us how many misclassifications there are, and the string tells us the majority class "versicolor". In brackets we can see this node contains a 50:50 class balance (0.5, 0.5). In our iris data, that translates to 50 versicolour and 50 virginica. 83 | 84 | In this tree we've moved observations with petal length < 4.75 cm to the bottom left node (2), and all observations with petal length >= 4.75 cm are moved into the bottom right node (3). Looking in the two nodes, we can also see that the class balance in brackets is much better for both, indicating that these nodes are more homogeneous. Looking at the printout for node (2), we can see that the left node holds 44 versicolor observations and only 1 virginica. This is much better than the 50/50 split we had earlier! 85 | 86 | Let's take a look at what this decision boundary actually looks like. 87 | 88 | ```{r, echo=FALSE} 89 | ##Plotting the actual data (circles) and model prediction (x) 90 | practice$pred.dt<-predict(iris.dt,practice,type="class") 91 | practice$pointcolor.dt<-ifelse(practice$pred.dt=="versicolor","orange","blue") 92 | plot(x=practice$Sepal.Length,y=practice$Petal.Length,main="Decision tree",xlab="sepal length (cm)", 93 | ylab="petal length (cm)",col=practice$pointcolor) 94 | lines(x=c(4.5, 8.0),y=c(4.75, 4.75)) 95 | points(x=practice$Sepal.Length,y=practice$Petal.Length,col=practice$pointcolor.dt,pch=4) 96 | ``` 97 | 98 | We can see a blue circle with an orange X on the far left - the 1 virginica with petal length < 4.75 cm which we misclassified as versicolor. 99 | 100 | Of course we are using a very simple model - let's see what happens when we let the tree grow deeper. 
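Before we do, a quick aside on the criterion the tree is optimizing. The sketch below is not part of the original script - it is purely illustrative, using the `practice` data frame from above - and computes the Gini impurity of the root node and the weighted impurity of the two child nodes for the petal length < 4.75 cm split. `rpart` favours the split that reduces this quantity the most.

```{r gini-aside}
# Gini impurity of a vector of class labels: 1 - sum(p_k^2)
gini <- function(y) {
  p <- prop.table(table(y))
  1 - sum(p^2)
}

left  <- practice$Species[practice$Petal.Length <  4.75]
right <- practice$Species[practice$Petal.Length >= 4.75]

gini(practice$Species)   # root node: 0.5 for a 50/50 class balance
# weighted impurity of the two child nodes - much lower after the split
(length(left) * gini(left) + length(right) * gini(right)) / nrow(practice)
```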
101 | 102 | ```{r, echo=FALSE} 103 | ##Fitting decision tree with many splits 104 | iris.dt2<-rpart(Species~Sepal.Length+Petal.Length,data=practice,method="class", 105 | control=rpart.control(maxdepth=6,minsplit=1,cp=0)) 106 | practice$pred.dt2<-predict(iris.dt2,practice,type="class") 107 | practice$pointcolor.dt2<-ifelse(practice$pred.dt2=="versicolor","orange","blue") 108 | plot(x=practice$Sepal.Length,y=practice$Petal.Length,main="Decision tree-v2",xlab="sepal length (cm)", 109 | ylab="petal length (cm)",col=practice$pointcolor) 110 | points(x=practice$Sepal.Length,y=practice$Petal.Length,col=practice$pointcolor.dt2,pch=4) 111 | ``` 112 | Now our tree is more complicated - we can see a few vertical boundaries as well as the horizontal one from before. Some of these we may like - for example the movement of the boundary upward around a sepal length of ~6.7 cm. However, some appear unnatural; the vertical bar of classification around a sepal length of 6.1 cm, for example. Let's look at the tree itself. 113 | 114 | ```{r, echo=FALSE} 115 | print(iris.dt2) 116 | ``` 117 | 118 | At the bottom, we see that nodes (14) and beyond are the culprits. 119 | 120 | (7) 42 obs, 1 misclassification -> split on sepal length < 6.05 121 | (14) 5 obs, 1 misclassification -> split on sepal length >= 5.95 122 | (28) 1 obs, 0 misclassifications - finished! 123 | 124 | Having an entire rule based upon this one observation seems silly, but it's perfectly logical: at the moment the only objective the algorithm cares about is minimizing node impurity (the Gini index) - and we can see the class balance is better at node (28) than at nodes (14) and (7)! 125 | 126 | # Boosting 127 | 128 | The premise of boosting is the combination of many weak learners to form a single "strong" learner. In a nutshell, boosting involves building models iteratively, and at each step we focus on the data we performed poorly on. In our context, we'll use decision trees, so the first step would be to build a tree using the data. Next, we'd look at the data that we misclassified, and re-weight the data so that the next tree tries harder to classify those observations correctly, possibly at the cost of getting some of the other data wrong this time. 129 | 130 |  139 | 140 | Let's take a look at the final decision surface. 141 | 142 | 143 | ```{r, echo=FALSE} 144 | ##Fitting boosting model 145 | set.seed(1) 146 | iris.boost<-boosting(Species~Sepal.Length+Petal.Length,data=practice) 147 | practice$pred.boost<-predict(iris.boost,practice)$class 148 | practice$pointcolor.boost<-ifelse(practice$pred.boost=="versicolor","orange","blue") 149 | plot(x=practice$Sepal.Length,y=practice$Petal.Length,main="Boosting",xlab="sepal length (cm)", 150 | ylab="petal length (cm)",col=practice$pointcolor) 151 | points(x=practice$Sepal.Length,y=practice$Petal.Length,col=practice$pointcolor.boost,pch=4) 152 | ``` 153 | 154 | 157 | 158 | With boosting, we iteratively changed the dataset to have new trees focus on the "difficult" observations. The next approach we discuss is similar as it also involves using changed versions of our dataset to build new trees. 
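Before moving on, here is a small, purely illustrative sketch (not part of the original script) of the re-weighting step described above - a single AdaBoost-style round done by hand on the `practice` data. The `boosting()` call above performs many such rounds internally, so this is only to make the idea concrete.

```{r boosting-by-hand}
# one AdaBoost-style round: fit a stump, then up-weight the points it gets wrong
w <- rep(1 / nrow(practice), nrow(practice))            # start with equal weights
stump <- rpart(Species ~ Sepal.Length + Petal.Length, data = practice,
               weights = w, method = "class",
               control = rpart.control(maxdepth = 1))
pred    <- predict(stump, practice, type = "class")
mistake <- pred != practice$Species
err     <- sum(w * mistake)                             # weighted error rate
alpha   <- 0.5 * log((1 - err) / err)                   # weight given to this weak learner
w       <- w * exp(alpha * mistake)                     # increase weight of misclassified points
w       <- w / sum(w)                                   # re-normalize for the next round
```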
159 | 160 | 161 | ```{r, echo=FALSE} 162 | ##Fitting bagging model 163 | iris.bag<-randomForest(Species~Sepal.Length+Petal.Length,data=practice, mtry=2) 164 | practice$pred.bag<-predict(iris.bag,practice) 165 | practice$pointcolor.bag<-ifelse(practice$pred.bag=="versicolor","orange","blue") 166 | plot(x=practice$Sepal.Length,y=practice$Petal.Length,main="Bagging",xlab="sepal length (cm)", 167 | ylab="petal length (cm)",col=practice$pointcolor) 168 | points(x=practice$Sepal.Length,y=practice$Petal.Length,col=practice$pointcolor.bag,pch=4) 169 | ``` 170 | 171 | # Bagging / Random Forest 172 | 173 | Bootstrap aggregation, or "Bagging", is another form of ensemble learning where we aim to build a single good model by combining many models together. With AdaBoost, we modified the data to focus on hard-to-classify observations. We can imagine this as a form of resampling the data for each new tree. For example, say we have three observations: A, B, and C, [A, B, C]. If we correctly classify observations [A, B], but incorrectly classify C, then AdaBoost involves building a new tree that focuses on C. Equivalently, we could say AdaBoost builds a new tree using the dataset [A, B, C, C, C], where we have intentionally repeated observation C 3 times so that the algorithm thinks it is 3 times as important as the other observations. Before we move on, convince yourself that this makes sense. 174 | 175 | Bagging involves a similar approach, except we don't selectively choose which observations to focus on; instead, we randomly resample the data (with replacement) each time. As you can see, while this is a similar process to AdaBoost, the concept is quite different. Whereas before we aimed to iteratively improve our overall model with new trees, we now build trees on what we hope are independent datasets. 176 | 177 | Let's take a step back, and think about a practical example. Say we wanted a good model of heart disease. If we saw researchers build a model from a dataset of patients from their hospital, we would be happy. If they then acquired a new dataset from new patients, and built a new model, we'd be inclined to feel that the combination of the two models would be better than any one individually. This exact scenario is what bagging aims to replicate, except instead of actually going out and collecting new datasets, we instead use bootstrapping to create new sets of data from our current dataset. If you are unfamiliar with bootstrapping, you can treat it as "magic" for now (and if you are familiar with the bootstrap, you already know it's magic). 178 | 179 | 188 | 189 | The Random Forest takes the previous ideas one step further: instead of just resampling our data, we also randomly select only a fraction of the features to consider at each split. It turns out that this subselection tends to improve the performance of our models. The odds of an individual tree being very good or very bad are higher (i.e. the variance of the individual trees is increased), but because the trees are now less correlated with one another, averaging them tends to give a final model with better overall performance. 190 | 191 | Let's train the model now. 
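(As an aside, if the bootstrap still feels like magic, the short illustrative sketch below - not part of the original script - shows what a single bootstrap resample of `practice` looks like; `randomForest()` does this internally for every tree it grows.)

```{r bootstrap-aside}
set.seed(42)
idx  <- sample(nrow(practice), replace = TRUE)  # row indices drawn *with* replacement
boot <- practice[idx, ]                         # same size as the original, some rows repeated
mean(!(1:nrow(practice) %in% idx))              # roughly a third of rows are left out ("out-of-bag")
```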
192 | 193 | ```{r, echo=FALSE} 194 | ##Fitting random forest model 195 | iris.rf<-randomForest(Species~Sepal.Length+Petal.Length,data=practice) 196 | practice$pred.rf<-predict(iris.rf,practice) 197 | practice$pointcolor.rf<-ifelse(practice$pred.rf=="versicolor","orange","blue") 198 | plot(x=practice$Sepal.Length,y=practice$Petal.Length,main="Random forest",xlab="sepal length (cm)", 199 | ylab="petal length (cm)",col=practice$pointcolor) 200 | points(x=practice$Sepal.Length,y=practice$Petal.Length,col=practice$pointcolor.rf,pch=4) 201 | ``` 202 | 203 | The visualization doesn't really show us the power of Random Forests, but we'll quantitatively evaluate them soon enough. 204 | 205 | 206 | # Running through a slightly harder dataset 207 | 208 | We've now learned the basics of the various tree methods and have visualized most of them on the Fisher iris data. We now move on to a harder classification problem involving the identification of breast cancer tumours from features describing cell nuclei of breast mass. The goal is to classify whether the mass is cancerous or not. 209 | 210 | ```{r, echo=FALSE} 211 | fn <- getURL('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data', ssl.verifyhost=FALSE, ssl.verifypeer=FALSE) 212 | 213 | # Wisconsin dataset 214 | wisc<-read.table(textConnection(fn), sep=",",header=F) 215 | names(wisc)<-c("ID","outcome", 216 | "radius_mean","texture_mean","perimeter_mean","area_mean","smoothness_mean","compactness_mean", 217 | "concavity_mean","concave_points_mean","symmetry_mean","fractal_dimension_mean", 218 | "radius_se","texture_se","perimeter_se","area_se","smoothness_se","compactness_se", 219 | "concavity_se","concave_points_se","symmetry_se","fractal_dimension_se", 220 | "radius_worst","texture_worst","perimeter_worst","area_worst","smoothness_worst","compactness_worst", 221 | "concavity_worst","concave_points_worst","symmetry_worst","fractal_dimension_worst") 222 | ``` 223 | 224 | A great package is caret; which allows for easy cross-validation (saves train/val/test split headache) and optimization of hyperparameters. 225 | 226 | ```{r, echo=TRUE} 227 | library(mlbench) 228 | library(caret) 229 | ``` 230 | 231 | The following trains a tree, AdaBoost, RF, and gradient boosting model on the entire dataset. 232 | 233 | ```{r, echo=FALSE} 234 | set.seed(1123) 235 | 236 | # prepare training scheme 237 | control <- trainControl(method="repeatedcv", number=5, repeats=3, classProbs = TRUE, summaryFunction = twoClassSummary) 238 | # train the tree model 239 | modelTree <- train(outcome~., data=wisc, method="rpart", metric="ROC", trControl=control) 240 | # train adaboost 241 | modelAdaboost <- train(outcome~., data=wisc, method="adaboost", metric="ROC", trControl=control) 242 | # train the RF model 243 | modelRF <- train(outcome~., data=wisc, method="rf", metric="ROC", trControl=control) 244 | # train the GBM model (with xgboost) 245 | modelGbm <- train(outcome~., data=wisc, method="xgbTree", metric="ROC", trControl=control, verbose=FALSE) 246 | 247 | # store results in a single object using resamples 248 | results <- resamples(list(tree=modelTree, adaboost=modelAdaboost, RF=modelRF, GBM=modelGbm)) 249 | 250 | # summarize the distributions 251 | summary(results) 252 | # boxplots of results 253 | bwplot(results) 254 | # dot plots of results 255 | dotplot(results) 256 | ``` 257 | 258 | 259 | GBM is working quite well! 
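If you want to dig deeper than the plots, the caret objects themselves store the winning hyperparameters and the per-resample performance. For example (illustrative only, using the models fit above):

```{r inspect-caret}
modelGbm$bestTune   # hyperparameters of the best xgboost configuration
modelRF$results     # ROC / sensitivity / specificity for each value of mtry
```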
260 | 261 | # Exercise 262 | 263 | We'll now practice using these models on a dataset acquired from patients admitted to intensive care units at the Beth Israel Deaconness Medical Center in Boston, MA. All patients in the cohort stayed for at least 48 hours, and the goal of the prediction task is to predict in-hospital mortality. This data is a subset of a publicly accessible ICU database: MIMIC. If you're interested, you can read more about MIMIC here. The particular dataset we are using is described in more detail here: http://physionet.org/challenge/2012/ 264 | 265 | The data is originally provided as a time series of observations for a number of variables, but to simplify the analysis, we've done some preprocessing to get a single row for each patient. The following cell will download the data from online and load it into a dataframe 266 | 267 | ```{r, echo=FALSE} 268 | fn = gzcon(url('https://github.com/alistairewj/tree-prediction-tutorial/raw/master/data/PhysionetChallenge2012-set-a.csv.gz', method="libcurl"), text=TRUE) 269 | seta <- read.table(fn, sep=',', header=TRUE) 270 | 271 | names(seta)[names(seta) == 'In.hospital_death'] <- 'death' 272 | 273 | rownames(seta) <- seta$recordid 274 | head(seta) 275 | ``` 276 | 277 | The first columns are: 278 | 279 | * recordid - random ID for each patient 280 | * SAPS.I - A severity of illness score (higher means sicker) 281 | * SOFA - An organ failure score (higher means sicker) 282 | * Length_of_stay - how long they stayed in the ICU 283 | * Survival - if they survived, this is -1. If they died, it's the number of days until their death 284 | * In.hospital_death - 0/1 if they died in hospital (this is our target) 285 | 286 | If we use length of stay/survival in our models, we are cheating! We don't know them until much later in the patient stay - and much later than our ideal time for prediction. 287 | 288 | We also don't want recordid since it has no physical meaning. 289 | 290 | ```{r} 291 | # drop columns we don't want 292 | drop_columns <- c("Survival", "Length_of_stay", "recordid") 293 | seta <- seta[, !(names(seta) %in% drop_columns)] 294 | head(seta) 295 | ``` 296 | 297 | Much better! Now to try some models. 298 | 299 | ```{r, echo=FALSE} 300 | set.seed(1123) 301 | # prepare training scheme 302 | control <- trainControl(method="repeatedcv", number=5, repeats=3, classProbs = TRUE, summaryFunction = twoClassSummary) 303 | # train the tree model 304 | modelTree <- train(death~., data=seta, method="rpart", metric="ROC", trControl=control) 305 | ``` 306 | 307 | Ah, it errors! Clearly our models cannot handle missing data. We haven't dealt with this before, but it is a challenging issue with medical data. In general there are three types of missing data: 308 | 309 | 1. Missing completely at random (MCAR) 310 | * The data is missing for reasons unrelated to the data 311 | * a power outage results in losing vital sign data 312 | 2. Missing at random (MAR) 313 | * The data is missing for reasons related to the data, but not the missing observation 314 | * we don't collect lactate measurements on admission to a medical ICU, but we collect them for cardiac ICU 315 | 3. Missing not at random (MNAR) 316 | * The data is missing, and the reason it is missing depends on the value 317 | * a doctor does not order the Troponin-I lab test, because they believe it to be normal 318 | 319 | The hardest case to deal with is MNAR, and unfortunately, that is the most common in the medical domain. 
Still, we have to do something, so we often use approaches which are theoretically invalid under MNAR but in practice work acceptably well. 320 | 321 | Below, we'll replace missing data with the average value for the training population. 322 | 323 | ```{r} 324 | for(i in 1:ncol(seta)){ 325 | seta[is.na(seta[,i]), i] <- mean(seta[,i], na.rm = TRUE) 326 | } 327 | 328 | seta$death <- factor(seta$death) 329 | levels(seta$death) <- c("alive", "dead") 330 | head(seta) 331 | ``` 332 | 333 | Now that the missing data is handled, we can try to build the above tree models using the ICU data! 334 | 335 | ```{r, echo=FALSE} 336 | set.seed(1123) 337 | 338 | # prepare training scheme 339 | control <- trainControl(method="repeatedcv", number=5, repeats=3, classProbs = TRUE, summaryFunction = twoClassSummary) 340 | # train the tree model 341 | modelTree <- train(death~., data=seta, method="rpart", metric="ROC", trControl=control) 342 | # train adaboost 343 | modelAdaboost <- train(death~., data=seta, method="adaboost", metric="ROC", trControl=control) 344 | # train the RF model 345 | modelRF <- train(death~., data=seta, method="rf", metric="ROC", trControl=control) 346 | # train the GBM model (with xgboost) 347 | modelGbm <- train(death~., data=seta, method="xgbTree", metric="ROC", trControl=control, verbose=FALSE) 348 | results <- resamples(list(tree=modelTree, adaboost=modelAdaboost, RF=modelRF, GBM=modelGbm)) # store results in a single object 349 | # summarize the distributions 350 | summary(results) 351 | # boxplots of results 352 | bwplot(results) 353 | # dot plots of results 354 | dotplot(results) 355 | ``` 356 | 357 | # Challenge 358 | 359 | Now try to build your own model that performs well! Use cross-validation on set-a (4000 patients) to get a good model. Then apply that model on set-b (a distinct set of 4000 patients). 360 | The below code loads in set-b. Note that the outcome isn't available for set-b :) 361 | 362 | Some things to think about: 363 | 364 | * Are there other ways to impute missing data? 365 | * Have we thought about the features in our data, and how we are using them? 366 | * Have we visualized the data? Are there any obvious outliers which may fool our model? (note: a lot were removed by custom preprocessing I did, but some may remain) 367 | * Are there parameters of our model which we could change? 
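As one starting point for the first question above, here is a small illustrative sketch (not part of the original script) of median imputation, which is often more robust than the mean for skewed lab values. In a careful analysis you would estimate the medians on the training data (set-a) only and re-use them for set-b.

```{r median-impute}
impute_median <- function(df, medians = NULL) {
  # compute column medians if they are not supplied (e.g. from the training set)
  if (is.null(medians)) {
    medians <- sapply(df, function(x) if (is.numeric(x)) median(x, na.rm = TRUE) else NA)
  }
  for (i in seq_along(df)) {
    if (is.numeric(df[[i]])) {
      df[[i]][is.na(df[[i]])] <- medians[i]
    }
  }
  df
}
# e.g., once set-b is loaded below: setb <- impute_median(setb)
```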
368 | 369 | ```{r, echo=FALSE} 370 | fn = gzcon(url('https://github.com/alistairewj/tree-prediction-tutorial/raw/master/data/PhysionetChallenge2012-set-b-no-outcome.csv.gz', method="libcurl"), text=TRUE) 371 | setb <- read.table(fn, sep=',', header=TRUE) 372 | 373 | rownames(setb) <- setb$recordid 374 | drop_columns <- c("recordid") 375 | setb <- setb[, !(names(setb) %in% drop_columns)] 376 | head(setb) 377 | ``` 378 | -------------------------------------------------------------------------------- /pipeline_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "YI3bzd4hGSFb" 7 | }, 8 | "source": [ 9 | "# Pipeline\n", 10 | "\n", 11 | "The goal of this notebook is to learn about sklearn's pipeline and some useful data science tips.\n", 12 | "\n", 13 | "[View in colab](https://colab.research.google.com/github/alistairewj/tree-prediction-tutorial/blob/master/pipeline_introduction.ipynb)" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "id": "m_s5R2dgGSFk" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "# Import libraries\n", 25 | "from datetime import timedelta\n", 26 | "import collections\n", 27 | "import os\n", 28 | "import errno\n", 29 | "\n", 30 | "import sklearn\n", 31 | "from sklearn import tree\n", 32 | "from sklearn import ensemble\n", 33 | "from sklearn import metrics\n", 34 | "from sklearn.model_selection import train_test_split\n", 35 | "from sklearn.model_selection import cross_val_score\n", 36 | "from sklearn import datasets\n", 37 | "\n", 38 | "import numpy as np\n", 39 | "import pandas as pd\n", 40 | "import pydotplus\n", 41 | "import matplotlib\n", 42 | "import matplotlib.pyplot as plt\n", 43 | "\n", 44 | "# used to display trees\n", 45 | "from IPython.display import Image\n", 46 | "\n", 47 | "%matplotlib inline\n", 48 | "plt.style.use('ggplot')\n", 49 | "\n", 50 | "plt.rcParams.update({'font.size': 20})" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "source": [ 56 | "def make_colormap(seq):\n", 57 | " \"\"\"Return a LinearSegmentedColormap\n", 58 | " seq: a sequence of floats and RGB-tuples. 
The floats should be increasing\n", 59 | " and in the interval (0,1).\n", 60 | " \"\"\"\n", 61 | " seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3]\n", 62 | " cdict = {'red': [], 'green': [], 'blue': []}\n", 63 | " for i, item in enumerate(seq):\n", 64 | " if isinstance(item, float):\n", 65 | " r1, g1, b1 = seq[i - 1]\n", 66 | " r2, g2, b2 = seq[i + 1]\n", 67 | " cdict['red'].append([item, r1, r2])\n", 68 | " cdict['green'].append([item, g1, g2])\n", 69 | " cdict['blue'].append([item, b1, b2])\n", 70 | " return matplotlib.colors.LinearSegmentedColormap('CustomMap', cdict)\n", 71 | "\n", 72 | "def plot_model_pred_2d(mdl, X, y, cm=None, cbar=True, xlabel=None, ylabel=None):\n", 73 | " # look at the regions in a 2d plot\n", 74 | " # based on scikit-learn tutorial plot_iris.html\n", 75 | "\n", 76 | " # get minimum and maximum values\n", 77 | " x0_min = X[:, 0].min()\n", 78 | " x0_max = X[:, 0].max()\n", 79 | " x1_min = X[:, 1].min()\n", 80 | " x1_max = X[:, 1].max()\n", 81 | "\n", 82 | " xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 100),\n", 83 | " np.linspace(x1_min, x1_max, 100))\n", 84 | "\n", 85 | " Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 86 | " Z = Z.reshape(xx.shape)\n", 87 | "\n", 88 | " if not cm:\n", 89 | " # custom colormap\n", 90 | " #e58139f9 - orange\n", 91 | " #399de5e0 - to blue\n", 92 | " s = list()\n", 93 | "\n", 94 | " lo = np.array(matplotlib.colors.to_rgb('#e5813900'))\n", 95 | " hi = np.array(matplotlib.colors.to_rgb('#399de5e0'))\n", 96 | "\n", 97 | " for i in range(255):\n", 98 | " s.append( list((hi-lo)*(float(i)/255)+lo) )\n", 99 | " cm = make_colormap(s)\n", 100 | "\n", 101 | " # plot the contour - colouring different regions\n", 102 | " cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 103 | "\n", 104 | " # plot the individual data points - colouring by the *true* outcome\n", 105 | " color = y.ravel()\n", 106 | " plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k', linewidth=2,\n", 107 | " marker='o', s=60, cmap=cm)\n", 108 | "\n", 109 | " if xlabel is not None:\n", 110 | " plt.xlabel(xlabel)\n", 111 | " if ylabel is not None:\n", 112 | " plt.ylabel(ylabel)\n", 113 | " plt.axis(\"tight\")\n", 114 | " #plt.clim([-1.5,1.5])\n", 115 | " if cbar:\n", 116 | " plt.colorbar()\n", 117 | "\n", 118 | "def create_graph(mdl, cmap=None, feat=None):\n", 119 | " # cmap is a colormap\n", 120 | " # e.g. 
cmap = matplotlib.cm.coolwarm( np.linspace(0.0, 1.0, 256, dtype=float) )\n", 121 | " tree_graph = tree.export_graphviz(mdl, out_file=None,\n", 122 | " feature_names=feat,\n", 123 | " filled=True, rounded=True)\n", 124 | " graph = pydotplus.graphviz.graph_from_dot_data(tree_graph)\n", 125 | "\n", 126 | " # get colormap\n", 127 | " if cmap:\n", 128 | " # remove transparency\n", 129 | " if cmap.shape[1]==4:\n", 130 | " cmap = cmap[:,0:2]\n", 131 | "\n", 132 | " nodes = graph.get_node_list()\n", 133 | " for node in nodes:\n", 134 | " if node.get_label():\n", 135 | " # get number of samples in group 1 and group 2\n", 136 | " num_samples = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]\n", 137 | "\n", 138 | " # proportion that is class 2\n", 139 | " cm_value = float(num_samples[1]) / float(sum(num_samples))\n", 140 | " # convert to (R, G, B, alpha) tuple\n", 141 | " cm_value = matplotlib.cm.coolwarm(cm_value)\n", 142 | " cm_value = [int(np.ceil(255*x)) for x in cm_value]\n", 143 | " color = '#{:02x}{:02x}{:02x}'.format(cm_value[0], cm_value[1], cm_value[2])\n", 144 | " node.set_fillcolor(color)\n", 145 | "\n", 146 | " Image(graph.create_png())\n", 147 | " return graph" 148 | ], 149 | "metadata": { 150 | "id": "2jzzHqg_SEmq" 151 | }, 152 | "execution_count": null, 153 | "outputs": [] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "source": [ 158 | "# Exercise\n", 159 | "\n", 160 | "We'll practice using pipeline on a dataset acquired from patients admitted to intensive care units at the Beth Israel Deaconness Medical Center in Boston, MA. All patients in the cohort stayed for at least 48 hours, and the goal of the prediction task is to predict in-hospital mortality. This data is a subset of a publicly accessible ICU database: MIMIC. If you're interested, you can read more about MIMIC [here](https://mimic.physionet.org).\n", 161 | "The particular dataset we are using is described in more detail here: http://physionet.org/challenge/2012/\n", 162 | "\n", 163 | "The data is originally provided as a time series of observations for a number of variables, but to simplify the analysis, we've done some preprocessing to get a single row for each patient.\n", 164 | "The following cell will check if the data is available here. If it's not, it will download it to the subfolder `data` in the same folder as this notebook.\n", 165 | "\n", 166 | "The goal of this challenge was to predict mortality, i.e. whether a patient died in the hospital at the end of their stay." 167 | ], 168 | "metadata": { 169 | "id": "V9Rpf2dbPYF6" 170 | } 171 | }, 172 | { 173 | "cell_type": "code", 174 | "source": [ 175 | "url = 'https://github.com/alistairewj/tree-prediction-tutorial/raw/master/data/PhysionetChallenge2012-set-a.csv.gz'\n", 176 | "seta = pd.read_csv(url, sep=',', header=0, compression='gzip')\n", 177 | "seta.set_index('recordid', inplace=True)\n", 178 | "seta.head()" 179 | ], 180 | "metadata": { 181 | "id": "DtKWMJ9jGKA9" 182 | }, 183 | "execution_count": null, 184 | "outputs": [] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "source": [ 189 | "Above, we can see we have a few variables:\n", 190 | "\n", 191 | "- SAPS-I\n", 192 | "- SOFA\n", 193 | "- Length_of_stay\n", 194 | "- Survival\n", 195 | "- In-hospital_death\n", 196 | "- Age\n", 197 | "- Gender\n", 198 | "- Height\n", 199 | "- Weight\n", 200 | "- CCU\n", 201 | "- ...\n", 202 | "\n", 203 | "Q1: How would I understand what these variables mean?" 
204 | ], 205 | "metadata": { 206 | "id": "mJm6IorcQCDM" 207 | } 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "source": [ 212 | "A1: Look at the documentation! http://physionet.org/challenge/2012/" 213 | ], 214 | "metadata": { 215 | "id": "SxVW_4PqQWcV" 216 | } 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "source": [ 221 | "## Avoid data leakage\n", 222 | "\n", 223 | "Our first step will be to really understand our data through EDA. We can look at all the columns. Since there's so many, we print 10 per row." 224 | ], 225 | "metadata": { 226 | "id": "sTj54ZedPhKq" 227 | } 228 | }, 229 | { 230 | "cell_type": "code", 231 | "source": [ 232 | "# print all the columns\n", 233 | "for i in range(0, len(seta.columns), 10):\n", 234 | " print(list(seta.columns[i:i+10]))" 235 | ], 236 | "metadata": { 237 | "id": "QyeJrO98R0QV" 238 | }, 239 | "execution_count": null, 240 | "outputs": [] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "source": [ 245 | "Q2: Are there columns which would cause data leakage?" 246 | ], 247 | "metadata": { 248 | "id": "reUaCZ7iTuBK" 249 | } 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "source": [ 254 | "A2: Yes! We should not use the survival column for in-hospital mortality prediction." 255 | ], 256 | "metadata": { 257 | "id": "GDPca41iTvqU" 258 | } 259 | }, 260 | { 261 | "cell_type": "code", 262 | "source": [ 263 | "# Alternative A2: We can see this with a decision tree.\n", 264 | "mdl = ensemble.GradientBoostingClassifier(n_estimators=10)\n", 265 | "X, y = seta[['Length_of_stay']].values, seta[['In-hospital_death']].values\n", 266 | "mdl = tree.DecisionTreeClassifier(max_depth=1)\n", 267 | "mdl = mdl.fit(X,y)\n", 268 | "graph = create_graph(mdl, feat=['Length_of_stay'])\n", 269 | "Image(graph.create_png())" 270 | ], 271 | "metadata": { 272 | "id": "6ysWlLFpT_pb" 273 | }, 274 | "execution_count": null, 275 | "outputs": [] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "source": [ 280 | "# remove cheating variables from training data\n", 281 | "X = seta.drop(['In-hospital_death', 'Survival', 'Length_of_stay'], axis=1).values\n", 282 | "features = list(seta.drop(['In-hospital_death', 'Survival', 'Length_of_stay'], axis=1).columns)\n", 283 | "y = seta['In-hospital_death'].values" 284 | ], 285 | "metadata": { 286 | "id": "Pepl_2HMT6yO" 287 | }, 288 | "execution_count": null, 289 | "outputs": [] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "source": [ 294 | "We are trying to build a machine learning model to predict an outcome. We would like this to work in the real world when we go out and use it. To do this, we need to estimate the *generalization* error of our model, which we approximate with the error on a set of data we have.\n", 295 | "\n", 296 | "Q3: Why can't we just use the entire dataset to do this? That is, why can't we just do:\n", 297 | "\n", 298 | "```python\n", 299 | "model = train_model(data, target)\n", 300 | "score = score_model(model, data, target)\n", 301 | "print(f'My amazing score is: {score}!')\n", 302 | "```" 303 | ], 304 | "metadata": { 305 | "id": "gOotvIa5STHB" 306 | } 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "source": [ 311 | "A3: Evaluating models on the same data they've been trained on is optimistic and misleading.\n", 312 | "\n", 313 | "Models vary in their amount of flexibility (in statistics, we sometimes say models are high variance, which is somewhat confusing terminology). It can be helpful to imagine two extreme scenarios. 
Imagine our model always predicts the average of the outcome, which in this case is 15%:\n", 314 | "\n", 315 | "```python\n", 316 | "model = lambda x: return 0.15\n", 317 | "```\n", 318 | "\n", 319 | "Every time we ask the model to predict mortality for a patient, it returns 0.15. It doesn't matter what data we provide, this model won't overfit.\n", 320 | "\n", 321 | "Now imagine a model that memorizes the data and returns the result it's already seen:\n", 322 | "\n", 323 | "```python\n", 324 | "model = lambda x: return target[x]\n", 325 | "```\n", 326 | "\n", 327 | "This model will perfectly predict our training set, but will do terribly on any external dataset.\n", 328 | "\n", 329 | "For this reason, we hold out some data for evaluation." 330 | ], 331 | "metadata": { 332 | "id": "PiUuyyWmSwfa" 333 | } 334 | }, 335 | { 336 | "cell_type": "code", 337 | "source": [ 338 | "# here we only use the first 3000 observations as our training set\n", 339 | "y_train = y[0:3000]\n", 340 | "X_train = X[0:3000, :]\n", 341 | "\n", 342 | "y_test = y[3000:]\n", 343 | "X_test = X[3000:, :]\n", 344 | "\n", 345 | "print('Training size: {} - {:6d} missing observations'.format(X_train.shape,\n", 346 | " np.sum(np.sum(np.isnan(X_train)))))\n", 347 | "print('Test size: {} - {:6d} missing observations'.format(X_test.shape,\n", 348 | " np.sum(np.sum(np.isnan(X_test)))))" 349 | ], 350 | "metadata": { 351 | "id": "QOYz3uR8Tpmc" 352 | }, 353 | "execution_count": null, 354 | "outputs": [] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "source": [ 359 | "## Missing data\n", 360 | "\n", 361 | "The outcome is the first column `'In-hospital_death'`. Most of the rest of the data are features we can use to predict this binary outcome (while avoiding data leakage).\n", 362 | "\n", 363 | "You'll note that the above has a lot of missing data! It is a challenging issue with medical data. In general there are three types of missing data:\n", 364 | "\n", 365 | "1. Missing completely at random (MCAR)\n", 366 | " * The data is missing for reasons *unrelated* to the data\n", 367 | " * a power outage results in losing vital sign data\n", 368 | "2. Missing at random (MAR)\n", 369 | " * The data is missing for reasons related to the data, but not the missing observation\n", 370 | " * we don't collect lactate measurements on admission to a medical ICU, but we collect them for cardiac ICU\n", 371 | "3. Missing not at random (MNAR)\n", 372 | " * The data is missing, and the reason it is missing *depends* on the value\n", 373 | " * a doctor does not order the Troponin-I lab test, because they believe it to be normal\n", 374 | "\n", 375 | "The hardest case to deal with is MNAR, and unfortunately, that is the most common in the medical domain. Still, we have to do something, so we often use approaches which are theoretically invalid under MNAR but in practice work acceptably well.\n", 376 | "\n", 377 | "Below, we'll replace missing data with the average value for the training population." 378 | ], 379 | "metadata": { 380 | "id": "7ghcSDMkRGFh" 381 | } 382 | }, 383 | { 384 | "cell_type": "code", 385 | "source": [ 386 | "mu = np.nanmean(X, axis=0)\n", 387 | "for i in range(5):\n", 388 | " print(f'{features[i]}: {mu[i]:2.1f}')" 389 | ], 390 | "metadata": { 391 | "id": "qDORwddBRXxB" 392 | }, 393 | "execution_count": null, 394 | "outputs": [] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "source": [ 399 | "Q4: What's wrong with the above approach?" 
400 | ], 401 | "metadata": { 402 | "id": "A7yUeyqURYAI" 403 | } 404 | }, 405 | { 406 | "cell_type": "code", 407 | "source": [ 408 | "#A4: Data leakage! We use the test set values for estimating the mean.\n", 409 | "\n", 410 | "# since decision trees do not handle missing data, we impute it here\n", 411 | "mu = np.nanmean(X_train, axis=0)\n", 412 | "\n", 413 | "for i in range(X_train.shape[1]):\n", 414 | " idxMissing = np.isnan(X_train[:, i])\n", 415 | " X_train[idxMissing, i] = mu[i]\n", 416 | "\n", 417 | " idxMissing = np.isnan(X_test[:, i])\n", 418 | " X_test[idxMissing, i] = mu[i]\n", 419 | "\n", 420 | "# now we should find that we have no more missing data!\n", 421 | "\n", 422 | "print('Training size: {} - {:6d} missing observations'.format(X_train.shape,\n", 423 | " np.sum(np.sum(np.isnan(X_train)))))\n", 424 | "print('Test size: {} - {:6d} missing observations'.format(X_test.shape,\n", 425 | " np.sum(np.sum(np.isnan(X_test)))))" 426 | ], 427 | "metadata": { 428 | "id": "AQiPUXLnPuqS" 429 | }, 430 | "execution_count": null, 431 | "outputs": [] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "source": [ 436 | "## Idempotency\n", 437 | "\n", 438 | "Q5: What's wrong with keeping the above code as is?" 439 | ], 440 | "metadata": { 441 | "id": "nQPz5CHTj7J-" 442 | } 443 | }, 444 | { 445 | "cell_type": "markdown", 446 | "source": [ 447 | "A5: I can very easily screw it all up by re-running everything out of order.\n", 448 | "\n", 449 | "If working locally, this can be avoided by writing functions in a separate .py file, and using `%load_ext` and `%autoreload 2`." 450 | ], 451 | "metadata": { 452 | "id": "l6kRS7GPkBYI" 453 | } 454 | }, 455 | { 456 | "cell_type": "code", 457 | "source": [ 458 | "def prepare_dataset(seta):\n", 459 | " # remove cheating variables from training data\n", 460 | " X = seta.drop(['In-hospital_death', 'Survival', 'Length_of_stay'], axis=1).values\n", 461 | " y = seta['In-hospital_death'].values\n", 462 | "\n", 463 | " # here we only use the first 2500 observations as our training set\n", 464 | " y_train = y[0:2500]\n", 465 | " X_train = X[0:2500, :]\n", 466 | "\n", 467 | " y_test = y[2500:]\n", 468 | " X_test = X[2500:, :]\n", 469 | "\n", 470 | " features = list(seta.drop(['In-hospital_death', 'Survival', 'Length_of_stay'], axis=1).columns)\n", 471 | "\n", 472 | " return X_train, y_train, X_test, y_test, features" 473 | ], 474 | "metadata": { 475 | "id": "FU5twtbokNa4" 476 | }, 477 | "execution_count": null, 478 | "outputs": [] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "source": [ 483 | "## Decision trees have high \"variance\"\n", 484 | "\n", 485 | "It will be useful to demonstrate how decision trees have high \"variance\". In this context, variance refers to a property of some models to have a wide range of performance given random samples of data. Let's take a look at randomly slicing the data we have too see what that means." 
486 | ], 487 | "metadata": { 488 | "id": "NBSPRd5eRtqw" 489 | } 490 | }, 491 | { 492 | "cell_type": "code", 493 | "source": [ 494 | "np.random.seed(123)\n", 495 | "\n", 496 | "fig = plt.figure(figsize=[12,3])\n", 497 | "for i in range(3):\n", 498 | " ax = fig.add_subplot(1,3,i+1)\n", 499 | "\n", 500 | " # generate indices in a random order\n", 501 | " idx = np.random.permutation(X_train.shape[0])\n", 502 | "\n", 503 | " # only use the first 50\n", 504 | " idx = idx[:50]\n", 505 | " X_temp = X_train[idx, :2]\n", 506 | " y_temp = y_train[idx]\n", 507 | "\n", 508 | " # initialize the model\n", 509 | " mdl = tree.DecisionTreeClassifier(max_depth=5)\n", 510 | "\n", 511 | " # train the model using the dataset\n", 512 | " mdl = mdl.fit(X_temp, y_temp)\n", 513 | "\n", 514 | " # only specify labels once for clarity\n", 515 | " xlabel = features[0] if i == 1 else None\n", 516 | " ylabel = features[1] if i == 0 else None\n", 517 | "\n", 518 | " plot_model_pred_2d(mdl, X_temp, y_temp, xlabel=xlabel, ylabel=ylabel, cbar=False)\n", 519 | "\n", 520 | "plt.show()" 521 | ], 522 | "metadata": { 523 | "id": "G6Kem-jZRu0l" 524 | }, 525 | "execution_count": null, 526 | "outputs": [] 527 | }, 528 | { 529 | "cell_type": "markdown", 530 | "source": [ 531 | "Above we can see that we are using random subsets of data, and as a result, our decision boundary can change quite a bit.\n", 532 | "\n", 533 | "Let's build the model on the full training set." 534 | ], 535 | "metadata": { 536 | "id": "rNX5-gpNjWpO" 537 | } 538 | }, 539 | { 540 | "cell_type": "code", 541 | "source": [ 542 | "# Instantiate a decision tree classifier\n", 543 | "mdl_dt = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')\n", 544 | "\n", 545 | "# Fit the model to the training data\n", 546 | "mdl_dt = mdl_dt.fit(X_train, y_train)\n", 547 | "\n", 548 | "# evaluate the model on the test set\n", 549 | "yhat = mdl_dt.predict_proba(X_test)[:,1]\n", 550 | "score = metrics.roc_auc_score(y_test, yhat)\n", 551 | "\n", 552 | "print(f'Model AUROC on the test set: {score:1.3f}')" 553 | ], 554 | "metadata": { 555 | "id": "rieypeXtPunN" 556 | }, 557 | "execution_count": null, 558 | "outputs": [] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "source": [ 563 | "Q5: This was just one performance measure on a random 1000 cases. Are there techniques we can use to better evaluate how well this model would do on a held-out dataset?" 564 | ], 565 | "metadata": { 566 | "id": "4horL1cujzFj" 567 | } 568 | }, 569 | { 570 | "cell_type": "markdown", 571 | "source": [ 572 | "A5: Of course :)\n", 573 | "\n", 574 | "Cross-validation is a very common one." 575 | ], 576 | "metadata": { 577 | "id": "EanRgcBBj2CM" 578 | } 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "source": [ 583 | "## Improving the model\n", 584 | "\n", 585 | "Let's say we wanted to improve this model, either by:\n", 586 | "\n", 587 | "- doing cross-validation\n", 588 | "- tuning hyperparameters\n", 589 | "- changing the model type\n", 590 | "\n", 591 | "We would have to be very careful about the data leakage step above, particularly with respect to missing data. We'd have to write for loops and constantly reshuffle the data around. This would be tedious, and very error prone.\n", 592 | "\n", 593 | "Enter scikit-learn pipelines." 594 | ], 595 | "metadata": { 596 | "id": "TJtzYRh1k1PE" 597 | } 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "source": [ 602 | "## Pipeline\n", 603 | "\n", 604 | "A really great way of building models is to use `pipeline` from scikit-learn. 
This allows us to define the steps in our preprocessing *with* the ultimate model. It's a great feature that simplifies a lot of the tedium in preprocessing. Here's an example imputing missing data." 605 | ], 606 | "metadata": { 607 | "id": "La_Cac7zlVuv" 608 | } 609 | }, 610 | { 611 | "cell_type": "code", 612 | "source": [], 613 | "metadata": { 614 | "id": "hG0UGdQ1lXR6" 615 | }, 616 | "execution_count": null, 617 | "outputs": [] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "source": [ 622 | "# use pipeline to automatically apply preprocessing\n", 623 | "from sklearn.compose import ColumnTransformer\n", 624 | "from sklearn.pipeline import Pipeline\n", 625 | "from sklearn.impute import SimpleImputer\n", 626 | "from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder\n", 627 | "from sklearn.linear_model import LogisticRegression\n", 628 | "\n", 629 | "# choose the classifier\n", 630 | "base_mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')\n", 631 | "\n", 632 | "# create a pipeline which imputes missing data, then runs the model\n", 633 | "mdl = Pipeline(\n", 634 | " [\n", 635 | " (\"preprocessor\", SimpleImputer(missing_values=np.nan, strategy='mean')),\n", 636 | " ('model', base_mdl)\n", 637 | " ]\n", 638 | ")\n", 639 | "\n", 640 | "# evaluate the model on the data - same as before!\n", 641 | "mdl = mdl.fit(X_train, y_train)\n", 642 | "\n", 643 | "# evaluate the model on the test set\n", 644 | "yhat = mdl.predict_proba(X_test)[:, 1]\n", 645 | "score = metrics.roc_auc_score(y_test, yhat)\n", 646 | "\n", 647 | "print(f'Model AUROC on the test set: {score:1.3f}')" 648 | ], 649 | "metadata": { 650 | "id": "sIjMVC7RliTM" 651 | }, 652 | "execution_count": null, 653 | "outputs": [] 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "source": [ 658 | "Now that we have a pipeline setup, we can actually use it with the original data, as our pipeline will impute the data for us. We can use another scikit-learn tool, cross_val_score, to evaluate the AUROC across cross-validation folds." 659 | ], 660 | "metadata": { 661 | "id": "O7ZRJNipmBSL" 662 | } 663 | }, 664 | { 665 | "cell_type": "code", 666 | "source": [ 667 | "from sklearn import metrics\n", 668 | "from sklearn.model_selection import cross_val_score\n", 669 | "scores = cross_val_score(mdl, X, y, cv=5, scoring='roc_auc')\n", 670 | "print(\"AUROC: {:0.3f} [{:0.3f}, {:0.3f}]\".format(np.mean(scores), np.min(scores), np.max(scores)))" 671 | ], 672 | "metadata": { 673 | "id": "kpPlThF-mENy" 674 | }, 675 | "execution_count": null, 676 | "outputs": [] 677 | }, 678 | { 679 | "cell_type": "markdown", 680 | "source": [ 681 | "If we want to tune hyper-parameters, there is this helpful guide:\n", 682 | "https://scikit-learn.org/stable/modules/grid_search.html#grid-search\n", 683 | "\n", 684 | "Let's try to tune the max_depth parameter of our decision tree." 
685 | ], 686 | "metadata": { 687 | "id": "j7XmPFaMmY1J" 688 | } 689 | }, 690 | { 691 | "cell_type": "code", 692 | "source": [ 693 | "from pprint import pprint\n", 694 | "from sklearn.model_selection import GridSearchCV\n", 695 | "\n", 696 | "param_grid = {'model__max_depth': [None, 3, 5, 10]}\n", 697 | "\n", 698 | "grid_search = GridSearchCV(\n", 699 | " estimator=mdl,\n", 700 | " param_grid=param_grid,\n", 701 | " scoring='roc_auc',\n", 702 | " cv=5\n", 703 | ")\n", 704 | "\n", 705 | "print(\"Performing grid search...\")\n", 706 | "print(\"Hyperparameters to be evaluated:\")\n", 707 | "pprint(param_grid)\n", 708 | "\n", 709 | "grid_search.fit(X, y)\n", 710 | "\n", 711 | "\n", 712 | "print(\"\\n===\\nBest parameters combination found:\")\n", 713 | "best_parameters = grid_search.best_estimator_.get_params()\n", 714 | "for param_name in sorted(best_parameters.keys()):\n", 715 | " print(f\"{param_name}: {best_parameters[param_name]}\")\n", 716 | "print(\"===\\n\")\n", 717 | "test_auroc = grid_search.score(X_test, y_test)\n", 718 | "print(\n", 719 | " \"AUROC of the best parameters using the inner CV of \"\n", 720 | " f\"the random search: {grid_search.best_score_:.3f}\"\n", 721 | ")\n", 722 | "print(f\"AUROC on test set: {test_auroc:.3f}\")" 723 | ], 724 | "metadata": { 725 | "id": "h9-PdkDsmedU" 726 | }, 727 | "execution_count": null, 728 | "outputs": [] 729 | }, 730 | { 731 | "cell_type": "markdown", 732 | "source": [ 733 | "Finally, we can make custom pipelins which preprocess segments of the data separately." 734 | ], 735 | "metadata": { 736 | "id": "QLymMpMsmT0O" 737 | } 738 | }, 739 | { 740 | "cell_type": "code", 741 | "source": [ 742 | "# use pipeline to automatically apply preprocessing\n", 743 | "from sklearn.compose import ColumnTransformer\n", 744 | "from sklearn.pipeline import Pipeline\n", 745 | "from sklearn.impute import SimpleImputer\n", 746 | "from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder\n", 747 | "from sklearn.linear_model import LogisticRegression\n", 748 | "\n", 749 | "# Let's define the features we want to use here\n", 750 | "\n", 751 | "# below says \"all features in the dataset, except these three\"\n", 752 | "# model_features = [x for x in seta.columns if x not in ['Length_of_stay', 'Survival', 'In-hospital_death']]\n", 753 | "\n", 754 | "model_features = [\n", 755 | " 'Age', 'Gender', 'Height', 'Weight',\n", 756 | " 'CCU', 'SysABP_last', 'TroponinI_last', 'MechVentDuration'\n", 757 | "]\n", 758 | "target_feature = 'In-hospital_death'\n", 759 | "\n", 760 | "base_mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best', max_depth=3)\n", 761 | "\n", 762 | "# our pipeline will\n", 763 | "# (1) impute 0 if MechVentDuration is missing - since that implies no mech vent\n", 764 | "# (2) impute the mean for continuous variables - Imputer\n", 765 | "mechvent_features = [x for x in ['MechVentStartTime', 'MechVentDuration', 'MechVentLast8Hour'] if x in model_features]\n", 766 | "\n", 767 | "# be sure to exclude the mechvent features from our numeric features\n", 768 | "numeric_features = [x for x in model_features if x not in mechvent_features]\n", 769 | "\n", 770 | "# We create separate preprocessing pipelines for numeric and categorical data.\n", 771 | "numeric_transformer = Pipeline(steps=[\n", 772 | " ('imputer', SimpleImputer(missing_values=np.nan, strategy='mean')),\n", 773 | " ('scaler', StandardScaler())])\n", 774 | "\n", 775 | "mechvent_transformer = Pipeline(steps=[\n", 776 | " ('imputer', SimpleImputer(strategy='constant', 
fill_value=0))\n", 777 | "])\n", 778 | "\n", 779 | "# You could also consider a transformer which converts categorical data into a bunch of features of 0s/1s\n", 780 | "# so called \"one-hot\" encoding\n", 781 | "#categorical_transformer = Pipeline(steps=[\n", 782 | "# ('imputer', SimpleImputer(strategy='constant', fill_value=0)),\n", 783 | "# ('onehot', OneHotEncoder(handle_unknown='ignore'))\n", 784 | "#])\n", 785 | "\n", 786 | "preprocessor = ColumnTransformer(\n", 787 | " transformers=[\n", 788 | " ('num', numeric_transformer, numeric_features),\n", 789 | " ('mv', mechvent_transformer, mechvent_features)])\n", 790 | "\n", 791 | "mdl = Pipeline([(\"preprocessor\", preprocessor),\n", 792 | " ('model', base_mdl)])" 793 | ], 794 | "metadata": { 795 | "id": "0DdyaFImlZ9_" 796 | }, 797 | "execution_count": null, 798 | "outputs": [] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "source": [ 803 | "Fit the above pipeline on the data, and evaluate it in cross-validation." 804 | ], 805 | "metadata": { 806 | "id": "xFqm62Wanu2R" 807 | } 808 | }, 809 | { 810 | "cell_type": "code", 811 | "source": [ 812 | "scores = cross_val_score(mdl, seta[model_features], seta[target_feature], cv=5, scoring='roc_auc')\n", 813 | "print(\"AUROC: {:0.3f} [{:0.3f}, {:0.3f}]\".format(np.mean(scores), np.min(scores), np.max(scores)))" 814 | ], 815 | "metadata": { 816 | "id": "fZMilozmnwiA" 817 | }, 818 | "execution_count": null, 819 | "outputs": [] 820 | }, 821 | { 822 | "cell_type": "markdown", 823 | "source": [ 824 | "## Keep exploring!\n", 825 | "\n", 826 | "* Are there other ways to impute missing data?\n", 827 | "* Have we thought about the features in our data, and how we are using them?\n", 828 | "* Have we visualized the data? Are there any obvious outliers which may fool our model?\n", 829 | " * Note: a lot of outliers were removed by custom preprocessing I did, but some may remain\n", 830 | "* Are there parameters of our model which we could change?\n", 831 | "* Is there a systematic way of choosing the parameters of our model?\n", 832 | "\n", 833 | "The below code downloads a second set of data - `set-b`. This is the same type of data from a distinct 4000 patients, but this time you don't have the answers!" 834 | ], 835 | "metadata": { 836 | "id": "RDfNUfw3lMXz" 837 | } 838 | }, 839 | { 840 | "cell_type": "code", 841 | "source": [ 842 | "url = 'https://github.com/alistairewj/tree-prediction-tutorial/raw/master/data/PhysionetChallenge2012-set-b-no-outcome.csv.gz'\n", 843 | "setb = pd.read_csv(url, sep=',', header=0, compression='gzip')\n", 844 | "setb.set_index('recordid', inplace=True)\n", 845 | "setb.head()" 846 | ], 847 | "metadata": { 848 | "id": "hQC_V0_qlMd_" 849 | }, 850 | "execution_count": null, 851 | "outputs": [] 852 | } 853 | ], 854 | "metadata": { 855 | "kernelspec": { 856 | "display_name": "Python 3", 857 | "name": "python3" 858 | }, 859 | "language_info": { 860 | "name": "python" 861 | }, 862 | "colab": { 863 | "provenance": [] 864 | } 865 | }, 866 | "nbformat": 4, 867 | "nbformat_minor": 0 868 | } -------------------------------------------------------------------------------- /etc/presentation-plots.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Miscellaneous plots\n", 8 | "\n", 9 | "This notebook contains miscellaneous plots I used to present tree based models." 
10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "collapsed": true 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "from __future__ import print_function\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import pandas as pd\n", 24 | "from sklearn import linear_model\n", 25 | "from sklearn import tree\n", 26 | "from sklearn import ensemble\n", 27 | "from sklearn import metrics\n", 28 | "from sklearn.model_selection import train_test_split\n", 29 | "from sklearn.model_selection import cross_val_score\n", 30 | "\n", 31 | "from sklearn import datasets\n", 32 | "import pydotplus\n", 33 | "\n", 34 | "import matplotlib.pyplot as plt\n", 35 | "import matplotlib\n", 36 | "\n", 37 | "# used to display trees\n", 38 | "from IPython.display import Image\n", 39 | "%matplotlib inline\n", 40 | "plt.style.use('ggplot')" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## Boosting example" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "collapsed": false 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "\n", 59 | "\n", 60 | "xi = np.arange(0,51,1)\n", 61 | "\n", 62 | "# exponential that converges to 1\n", 63 | "yi = 1-np.exp(-0.05*xi)\n", 64 | "\n", 65 | "plt.figure(figsize=[12,8])\n", 66 | "plt.plot(xi, 0.3*yi+0.2, lw=4, label='Weak learner')\n", 67 | "plt.plot(xi, -0.15*yi+0.2, lw=4, label='Overall ensemble')\n", 68 | "plt.ylim([0, 0.5])\n", 69 | "plt.yticks([0, 0.5], fontsize=24)\n", 70 | "plt.xticks(np.arange(0,51,10), fontsize=24)\n", 71 | "\n", 72 | "# add text\n", 73 | "plt.arrow(20, 0.3, -10, 0, width=0.005, head_length=1, color='k')\n", 74 | "plt.text(20.5, 0.3, 'Each tree has higher error', fontsize=24, va='center')\n", 75 | "\n", 76 | "plt.arrow(20, 0.15, -10, 0, width=0.005, head_length=1, color='k')\n", 77 | "plt.text(20.5, 0.15, 'Ensemble has lower error', fontsize=24, va='center')\n", 78 | "plt.ylabel('Error', fontsize=24)\n", 79 | "plt.xlabel('Number of trees', fontsize=24)\n", 80 | "plt.show()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": { 87 | "collapsed": true 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "def make_colormap(seq):\n", 92 | " \"\"\"Return a LinearSegmentedColormap\n", 93 | " seq: a sequence of floats and RGB-tuples. 
The floats should be increasing\n", 94 | " and in the interval (0,1).\n", 95 | " \"\"\"\n", 96 | " seq = [(None,) * 3, 0.0] + list(seq) + [1.0, (None,) * 3]\n", 97 | " cdict = {'red': [], 'green': [], 'blue': []}\n", 98 | " for i, item in enumerate(seq):\n", 99 | " if isinstance(item, float):\n", 100 | " r1, g1, b1 = seq[i - 1]\n", 101 | " r2, g2, b2 = seq[i + 1]\n", 102 | " cdict['red'].append([item, r1, r2])\n", 103 | " cdict['green'].append([item, g1, g2])\n", 104 | " cdict['blue'].append([item, b1, b2])\n", 105 | " return matplotlib.colors.LinearSegmentedColormap('CustomMap', cdict)\n", 106 | "\n", 107 | "# OLD PURPLE colormap\n", 108 | "\n", 109 | "# colormap\n", 110 | "#cm = plt.cm.get_cmap(name='Purples',lut=2) # dummy initialization\n", 111 | "#c1 = [x/256.0 for x in [224,236,244]]\n", 112 | "#c2 = [x/256.0 for x in [136,86,167]]\n", 113 | "#cm = cm.from_list('custom', [c1,c2], N=2)\n", 114 | "\n", 115 | "\n", 116 | "# NEW custom colormap\n", 117 | "#e58139f9 - orange\n", 118 | "#399de5e0 - to blue\n", 119 | "s = list()\n", 120 | "\n", 121 | "lo = np.array(matplotlib.colors.to_rgb('#e5813900'))\n", 122 | "hi = np.array(matplotlib.colors.to_rgb('#399de5e0'))\n", 123 | "\n", 124 | "for i in range(255):\n", 125 | " s.append( list((hi-lo)*(float(i)/255)+lo) )\n", 126 | "cm = make_colormap(s)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "collapsed": true 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "def plot_model_purple(mdl, X, y, feat):\n", 138 | " plt.figure(figsize=[8,5])\n", 139 | "\n", 140 | " # colormap\n", 141 | " cm = plt.cm.get_cmap(name='Purples',lut=2) # dummy initialization\n", 142 | " c1 = [x/256.0 for x in [224,236,244]]\n", 143 | " c2 = [x/256.0 for x in [136,86,167]]\n", 144 | " cm = cm.from_list('custom', [c1,c2], N=2)\n", 145 | "\n", 146 | " # get minimum and maximum values\n", 147 | " x0_min = X[:, 0].min()\n", 148 | " x0_max = X[:, 0].max()\n", 149 | " x1_min = X[:, 1].min()\n", 150 | " x1_max = X[:, 1].max()\n", 151 | "\n", 152 | " vmin = np.min([x0_min,x1_min])\n", 153 | " vmax = np.max([x0_max,x1_max])\n", 154 | " xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 155 | " np.linspace(x1_min, x1_max, 1000))\n", 156 | "\n", 157 | " Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 158 | " Z = Z.reshape(xx.shape)\n", 159 | "\n", 160 | " # plot the contour - colouring different regions\n", 161 | " cs = plt.contourf(xx, yy, Z, cmap=cm, levels=[0,1,2])\n", 162 | "\n", 163 | " # plot the individual data points - colouring by the *true* outcome\n", 164 | " color = np.asarray(y.ravel(),dtype='float')\n", 165 | " plt.scatter(X[:, 0], X[:, 1], c=color, marker='o',\n", 166 | " s=60, cmap=cm)\n", 167 | "\n", 168 | " plt.xlabel(feat[0],fontsize=24)\n", 169 | " plt.ylabel(feat[1],fontsize=24)\n", 170 | " plt.axis(\"tight\")\n", 171 | "\n", 172 | " plt.colorbar(cs)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "collapsed": true 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "def plot_model_pred_2d_old(mdl, X, y, feat):\n", 184 | " # look at the regions in a 2d plot\n", 185 | " # based on scikit-learn tutorial plot_iris.html\n", 186 | "\n", 187 | " # get minimum and maximum values\n", 188 | " x0_min = X[:, 0].min()\n", 189 | " x0_max = X[:, 0].max()\n", 190 | " x1_min = X[:, 1].min()\n", 191 | " x1_max = X[:, 1].max()\n", 192 | "\n", 193 | " xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 194 | " np.linspace(x1_min, 
x1_max, 1000))\n", 195 | "\n", 196 | " Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 197 | " Z = Z.reshape(xx.shape)\n", 198 | "\n", 199 | " # plot the contour - colouring different regions\n", 200 | " cs = plt.contourf(xx, yy, Z, cmap='hsv')\n", 201 | "\n", 202 | " # plot the individual data points - colouring by the *true* outcome\n", 203 | " color = y.ravel()\n", 204 | " plt.scatter(X[:, 0], X[:, 1], c=color, marker='o', s=40, cmap='Blues')\n", 205 | "\n", 206 | " plt.xlabel(feat[0],fontsize=24)\n", 207 | " plt.ylabel(feat[1],fontsize=24)\n", 208 | " plt.axis(\"tight\")" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": { 215 | "collapsed": true 216 | }, 217 | "outputs": [], 218 | "source": [ 219 | "def plot_model_pred_2d(mdl, X, y, feat, cm=None, plot_colorbar=True):\n", 220 | " # look at the regions in a 2d plot\n", 221 | " # based on scikit-learn tutorial plot_iris.html\n", 222 | " \n", 223 | " # get minimum and maximum values\n", 224 | " x0_min = X[:, 0].min()\n", 225 | " x0_max = X[:, 0].max()\n", 226 | " x1_min = X[:, 1].min()\n", 227 | " x1_max = X[:, 1].max()\n", 228 | "\n", 229 | " xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 100),\n", 230 | " np.linspace(x1_min, x1_max, 100))\n", 231 | "\n", 232 | " Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 233 | " Z = Z.reshape(xx.shape)\n", 234 | " \n", 235 | " if not cm:\n", 236 | " # custom colormap\n", 237 | " #e58139f9 - orange\n", 238 | " #399de5e0 - to blue\n", 239 | " s = list()\n", 240 | "\n", 241 | " lo = np.array(matplotlib.colors.to_rgb('#e5813900'))\n", 242 | " hi = np.array(matplotlib.colors.to_rgb('#399de5e0'))\n", 243 | "\n", 244 | " for i in range(255):\n", 245 | " s.append( list((hi-lo)*(float(i)/255)+lo) )\n", 246 | " cm = make_colormap(s)\n", 247 | " \n", 248 | " # plot the contour - colouring different regions\n", 249 | " cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 250 | "\n", 251 | " # plot the individual data points - colouring by the *true* outcome\n", 252 | " color = y.ravel()\n", 253 | " plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k', linewidth=2,\n", 254 | " marker='o', s=60, cmap=cm)\n", 255 | "\n", 256 | " plt.xlabel(feat[0],fontsize=24)\n", 257 | " plt.ylabel(feat[1],fontsize=24)\n", 258 | " plt.axis(\"tight\")\n", 259 | " if plot_colorbar:\n", 260 | " #plt.clim([-1.5,1.5])\n", 261 | " plt.colorbar()" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": { 268 | "collapsed": true 269 | }, 270 | "outputs": [], 271 | "source": [ 272 | "# real example\n", 273 | "df = datasets.load_iris()" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": { 280 | "collapsed": false 281 | }, 282 | "outputs": [], 283 | "source": [ 284 | "# simple plot with just the points\n", 285 | "\n", 286 | "plt.figure(figsize=[10,8])\n", 287 | "plt.scatter(0, 1, s=400)\n", 288 | "plt.scatter(0, 2,color=lo, s=400)\n", 289 | "plt.scatter(0, 3,color=hi, s=400)\n", 290 | "plt.grid()\n", 291 | "plt.show()" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": { 298 | "collapsed": false 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "# 16 plots, 3 subplots per row and text\n", 303 | "f, ax = plt.subplots(4, 4, figsize=[16,10], sharex='col', sharey='row')\n", 304 | "\n", 305 | "# for this plot, we re-order the data so sepal length is bottom right\n", 306 | "data = df['data']\n", 307 | "data = data[:, ::-1]\n", 308 | "feat = 
df['feature_names']\n", 309 | "feat = feat[::-1]\n", 310 | "\n", 311 | "for i in range(df['data'].shape[1]):\n", 312 | " for j in range(df['data'].shape[1]):\n", 313 | " if i==j:\n", 314 | " ax[i, j].grid()\n", 315 | " else:\n", 316 | " ax[i, j].scatter(data[:50,j], data[:50,i])\n", 317 | " ax[i, j].scatter(data[50:100,j], data[50:100,i],color=lo)\n", 318 | " ax[i, j].scatter(data[100:,j], data[100:,i],color=hi)\n", 319 | " \n", 320 | "# add text to middle plots\n", 321 | "for i in range(df['data'].shape[1]):\n", 322 | " xloc = ax[i,i].get_xlim()\n", 323 | " yloc = ax[i,i].get_ylim()\n", 324 | " ax[i, i].text(np.mean(xloc), np.mean(yloc), feat[i],\n", 325 | " horizontalalignment='center', verticalalignment='center', fontsize=16)\n", 326 | " \n", 327 | "# hide x ticks for top plots\n", 328 | "plt.setp([a.get_xticklabels() for a in ax[0, :]], visible=False)\n", 329 | "\n", 330 | "# hide y ticks for right plots\n", 331 | "plt.setp([a.get_yticklabels() for a in ax[:, 1]], visible=False)\n", 332 | "\n", 333 | "plt.show()" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "collapsed": false 341 | }, 342 | "outputs": [], 343 | "source": [ 344 | "# 16 plots, 3 subplots per row and text\n", 345 | "f, ax = plt.subplots(4, 4, figsize=[16,10], sharex='col', sharey='row')\n", 346 | "\n", 347 | "# for this plot, we re-order the data so sepal length is bottom right\n", 348 | "data = df['data']\n", 349 | "data = data[:, ::-1]\n", 350 | "feat = df['feature_names']\n", 351 | "feat = feat[::-1]\n", 352 | "\n", 353 | "for i in range(df['data'].shape[1]):\n", 354 | " for j in range(df['data'].shape[1]):\n", 355 | " if i==j:\n", 356 | " ax[i, j].grid()\n", 357 | " else:\n", 358 | " #ax[i, j].scatter(data[:50,j], data[:50,i])\n", 359 | " ax[i, j].scatter(data[50:100,j], data[50:100,i],color=lo)\n", 360 | " ax[i, j].scatter(data[100:,j], data[100:,i],color=hi)\n", 361 | " \n", 362 | "# add text to middle plots\n", 363 | "for i in range(df['data'].shape[1]):\n", 364 | " xloc = ax[i,i].get_xlim()\n", 365 | " yloc = ax[i,i].get_ylim()\n", 366 | " ax[i, i].text(np.mean(xloc), np.mean(yloc), feat[i],\n", 367 | " horizontalalignment='center', verticalalignment='center', fontsize=16)\n", 368 | " \n", 369 | "# hide x ticks for top plots\n", 370 | "plt.setp([a.get_xticklabels() for a in ax[0, :]], visible=False)\n", 371 | "\n", 372 | "# hide y ticks for right plots\n", 373 | "plt.setp([a.get_yticklabels() for a in ax[:, 1]], visible=False)\n", 374 | "\n", 375 | "plt.show()" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": { 382 | "collapsed": true 383 | }, 384 | "outputs": [], 385 | "source": [ 386 | "idx = [0,2]\n", 387 | "X = df['data'][50:,idx]\n", 388 | "y = df['target'][50:]\n", 389 | "# scale y to be -1, 1\n", 390 | "y[y==1] = -1\n", 391 | "y[y==2] = 1\n", 392 | "feat = [df['feature_names'][x] for x in idx]" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "collapsed": true 400 | }, 401 | "outputs": [], 402 | "source": [ 403 | "def plot_cleanup():\n", 404 | " ax = plt.gca()\n", 405 | " ax.spines[\"top\"].set_visible(False)\n", 406 | " ax.spines[\"bottom\"].set_visible(False)\n", 407 | " ax.spines[\"right\"].set_visible(False)\n", 408 | " ax.spines[\"left\"].set_visible(False)\n", 409 | " ax.get_xaxis().tick_bottom()\n", 410 | " ax.get_yaxis().tick_left()\n", 411 | " plt.xticks(fontsize=16)\n", 412 | " plt.yticks(fontsize=16)" 413 | ] 414 | }, 415 | 
{ 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": { 419 | "collapsed": false 420 | }, 421 | "outputs": [], 422 | "source": [ 423 | "mdl = linear_model.LogisticRegression()\n", 424 | "mdl = mdl.fit(X,y)\n", 425 | "\n", 426 | "plt.figure(figsize=[12,8])\n", 427 | "\n", 428 | "# plot the individual data points - colouring by the *true* outcome\n", 429 | "color = np.asarray(y.ravel(),dtype='float')\n", 430 | "f = plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 431 | " marker='o', linewidth=2,\n", 432 | " s=60, cmap=cm)\n", 433 | "\n", 434 | "plt.xlabel(feat[0],fontsize=24)\n", 435 | "plt.ylabel(feat[1],fontsize=24)\n", 436 | "plt.axis(\"tight\")\n", 437 | "\n", 438 | "plt.colorbar(f)\n", 439 | "\n", 440 | "# cleanup plot\n", 441 | "plot_cleanup()\n", 442 | "\n", 443 | "plt.show()" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "metadata": { 450 | "collapsed": false 451 | }, 452 | "outputs": [], 453 | "source": [ 454 | "mdl = linear_model.LinearRegression()\n", 455 | "x0 = X[:,0].reshape([100,1])\n", 456 | "x1 = X[:,1].reshape([100,1])\n", 457 | "mdl = mdl.fit(x0, x1)\n", 458 | "\n", 459 | "# get minimum and maximum values\n", 460 | "x0_min = X[:, 0].min()+0.2\n", 461 | "x0_max = X[:, 0].max()-0.2\n", 462 | "\n", 463 | "Z = mdl.predict([[x0_min], [x0_max]])\n", 464 | "\n", 465 | "plt.figure(figsize=[12,8])\n", 466 | "\n", 467 | "# plot the line\n", 468 | "plt.plot([x0_min, x0_max], Z, 'k--', linewidth=3)\n", 469 | "# plot the individual data points - colouring by the *true* outcome\n", 470 | "color = np.asarray(y.ravel(),dtype='float')\n", 471 | "f = plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 472 | " marker='o', linewidth=2,\n", 473 | " s=60, cmap=cm)\n", 474 | "\n", 475 | "plt.xlabel(feat[0],fontsize=24)\n", 476 | "plt.ylabel(feat[1],fontsize=24)\n", 477 | "plt.axis(\"tight\")\n", 478 | "\n", 479 | "plt.colorbar(f)\n", 480 | "\n", 481 | "# cleanup plot\n", 482 | "plot_cleanup()\n", 483 | "plt.show()" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": { 490 | "collapsed": false 491 | }, 492 | "outputs": [], 493 | "source": [ 494 | "print('{} = {:3.2f} * {} + {:3.1f}'.format(feat[1], mdl.coef_[0][0], feat[0], mdl.intercept_[0]))" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "metadata": { 501 | "collapsed": false 502 | }, 503 | "outputs": [], 504 | "source": [ 505 | "mdl = linear_model.LinearRegression()\n", 506 | "mdl = mdl.fit(X,y)\n", 507 | "\n", 508 | "plt.figure(figsize=[12,8])\n", 509 | "\n", 510 | "# get minimum and maximum values\n", 511 | "x0_min = X[:, 0].min()\n", 512 | "x0_max = X[:, 0].max()\n", 513 | "x1_min = X[:, 1].min()\n", 514 | "x1_max = X[:, 1].max()\n", 515 | "\n", 516 | "vmin = np.min([x0_min,x1_min])\n", 517 | "vmax = np.max([x0_max,x1_max])\n", 518 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 519 | " np.linspace(x1_min, x1_max, 1000))\n", 520 | "\n", 521 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 522 | "Z = Z.reshape(xx.shape)\n", 523 | "\n", 524 | "# plot the contour - colouring different regions\n", 525 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 526 | "\n", 527 | "# plot the individual data points - colouring by the *true* outcome\n", 528 | "color = np.asarray(y.ravel(),dtype='float')\n", 529 | "f = plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 530 | " marker='o', linewidth=2,\n", 531 | " s=60, cmap=cm)\n", 532 | "\n", 533 | 
"plt.xlabel(feat[0],fontsize=24)\n", 534 | "plt.ylabel(feat[1],fontsize=24)\n", 535 | "plt.axis(\"tight\")\n", 536 | "\n", 537 | "plt.colorbar(f)\n", 538 | "\n", 539 | "# cleanup plot\n", 540 | "plot_cleanup()\n", 541 | "plt.show()" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "metadata": { 548 | "collapsed": false 549 | }, 550 | "outputs": [], 551 | "source": [ 552 | "mdl = linear_model.LinearRegression()\n", 553 | "mdl = mdl.fit(X,y)\n", 554 | "\n", 555 | "plt.figure(figsize=[12,8])\n", 556 | "\n", 557 | "# get minimum and maximum values\n", 558 | "x0_min = X[:, 0].min()\n", 559 | "x0_max = X[:, 0].max()\n", 560 | "x1_min = X[:, 1].min()\n", 561 | "x1_max = X[:, 1].max()\n", 562 | "\n", 563 | "vmin = np.min([x0_min,x1_min])\n", 564 | "vmax = np.max([x0_max,x1_max])\n", 565 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 566 | " np.linspace(x1_min, x1_max, 1000))\n", 567 | "\n", 568 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 569 | "# round predictions\n", 570 | "Z[Z>=0] = 1\n", 571 | "Z[Z<0] = -1\n", 572 | "Z = Z.reshape(xx.shape)\n", 573 | "\n", 574 | "# plot the contour - colouring different regions\n", 575 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 576 | "\n", 577 | "# plot the individual data points - colouring by the *true* outcome\n", 578 | "color = np.asarray(y.ravel(),dtype='float')\n", 579 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 580 | " marker='o', linewidth=2,\n", 581 | " s=60, cmap=cm)\n", 582 | "\n", 583 | "plt.xlabel(feat[0],fontsize=24)\n", 584 | "plt.ylabel(feat[1],fontsize=24)\n", 585 | "plt.axis(\"tight\")\n", 586 | "\n", 587 | "plt.colorbar(cs)\n", 588 | "\n", 589 | "# cleanup plot\n", 590 | "plot_cleanup()\n", 591 | "plt.show()" 592 | ] 593 | }, 594 | { 595 | "cell_type": "markdown", 596 | "metadata": {}, 597 | "source": [ 598 | "# decision tree" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": null, 604 | "metadata": { 605 | "collapsed": false 606 | }, 607 | "outputs": [], 608 | "source": [ 609 | "# fit a decision tree\n", 610 | "mdl = tree.DecisionTreeClassifier(max_depth=1)\n", 611 | "mdl = mdl.fit(X,y)\n", 612 | "\n", 613 | "plt.figure(figsize=[12,8])\n", 614 | "\n", 615 | "# get minimum and maximum values\n", 616 | "x0_min = X[:, 0].min()\n", 617 | "x0_max = X[:, 0].max()\n", 618 | "x1_min = X[:, 1].min()\n", 619 | "x1_max = X[:, 1].max()\n", 620 | "\n", 621 | "vmin = np.min([x0_min,x1_min])\n", 622 | "vmax = np.max([x0_max,x1_max])\n", 623 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 624 | " np.linspace(x1_min, x1_max, 1000))\n", 625 | "\n", 626 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 627 | "# round predictions\n", 628 | "Z[Z>=0] = 1\n", 629 | "Z[Z<0] = -1\n", 630 | "Z = Z.reshape(xx.shape)\n", 631 | "\n", 632 | "# plot the contour - colouring different regions\n", 633 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 634 | "\n", 635 | "# plot the individual data points - colouring by the *true* outcome\n", 636 | "color = np.asarray(y.ravel(),dtype='float')\n", 637 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 638 | " marker='o', linewidth=2,\n", 639 | " s=60, cmap=cm)\n", 640 | "\n", 641 | "plt.xlabel(feat[0],fontsize=24)\n", 642 | "plt.ylabel(feat[1],fontsize=24)\n", 643 | "plt.axis(\"tight\")\n", 644 | "\n", 645 | "plt.colorbar(cs)\n", 646 | "\n", 647 | "# cleanup plot\n", 648 | "plot_cleanup()\n", 649 | "plt.show()" 650 | ] 651 | }, 652 | { 653 | "cell_type": "code", 654 | "execution_count": null, 
655 | "metadata": { 656 | "collapsed": false 657 | }, 658 | "outputs": [], 659 | "source": [ 660 | "# examine the tree\n", 661 | "tree_graph = tree.export_graphviz(mdl, out_file=None,\n", 662 | " feature_names=feat, \n", 663 | " filled=True, rounded=True) \n", 664 | "graph = pydotplus.graphviz.graph_from_dot_data(tree_graph) \n", 665 | "Image(graph.create_png())" 666 | ] 667 | }, 668 | { 669 | "cell_type": "markdown", 670 | "metadata": {}, 671 | "source": [ 672 | "# fitting a sinusoid" 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": null, 678 | "metadata": { 679 | "collapsed": false 680 | }, 681 | "outputs": [], 682 | "source": [ 683 | "# create a sample dataset of sinusoidal data\n", 684 | "rng = np.random.RandomState(777)\n", 685 | "N = 30\n", 686 | "\n", 687 | "# random points along the time axis for two cycles\n", 688 | "x = np.sort(2 * np.pi * rng.rand(N))\n", 689 | "y_true = np.sin(x)\n", 690 | "# generate the same data with random noise\n", 691 | "y_noise = np.sin(x) + (rng.rand(N)-0.5)*0.3\n", 692 | "\n", 693 | "# reshape x to be the only feature\n", 694 | "x = x.reshape(-1,1)\n", 695 | "\n", 696 | "# fit a decision tree\n", 697 | "mdl = tree.DecisionTreeRegressor(max_depth=5).fit(x, y_noise)\n", 698 | "\n", 699 | "# get test points\n", 700 | "x_test = np.linspace(0, 2*np.pi, 100).reshape(-1,1)\n", 701 | "y_test_pred = mdl.predict(x_test)\n", 702 | "\n", 703 | "plt.figure(figsize=[12,8])\n", 704 | "# plot original sinusoid\n", 705 | "plt.plot(x, y_true, 'k--',linewidth=2, label='Truth')\n", 706 | "\n", 707 | "# noisy test points\n", 708 | "plt.scatter(x, y_noise, marker='o', color='b', alpha=0.8, s=75, linewidth=2,label='Data')\n", 709 | "\n", 710 | "# decision tree decisions\n", 711 | "plt.plot(x_test, y_test_pred, 'r-', linewidth=2,label='Decision tree')\n", 712 | "\n", 713 | "plot_cleanup()\n", 714 | "plt.legend(fontsize=20)\n", 715 | "plt.show()" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": null, 721 | "metadata": { 722 | "collapsed": false 723 | }, 724 | "outputs": [], 725 | "source": [ 726 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')\n", 727 | "mdl = mdl.fit(X,y)\n", 728 | "\n", 729 | "plt.figure(figsize=[12,8])\n", 730 | "\n", 731 | "# get minimum and maximum values\n", 732 | "x0_min = X[:, 0].min()\n", 733 | "x0_max = X[:, 0].max()\n", 734 | "x1_min = X[:, 1].min()\n", 735 | "x1_max = X[:, 1].max()\n", 736 | "\n", 737 | "vmin = np.min([x0_min,x1_min])\n", 738 | "vmax = np.max([x0_max,x1_max])\n", 739 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 740 | " np.linspace(x1_min, x1_max, 1000))\n", 741 | "\n", 742 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 743 | "Z = Z.reshape(xx.shape)\n", 744 | "\n", 745 | "# plot the contour - colouring different regions\n", 746 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 747 | "\n", 748 | "# plot the individual data points - colouring by the *true* outcome\n", 749 | "color = np.asarray(y.ravel(),dtype='float')\n", 750 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 751 | " marker='o', linewidth=2,\n", 752 | " s=60, cmap=cm)\n", 753 | "\n", 754 | "plt.xlabel(feat[0],fontsize=24)\n", 755 | "plt.ylabel(feat[1],fontsize=24)\n", 756 | "plt.axis(\"tight\")\n", 757 | "\n", 758 | "plt.colorbar(f)\n", 759 | "\n", 760 | "# cleanup plot\n", 761 | "plot_cleanup()\n", 762 | "\n", 763 | "plt.show()" 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "execution_count": null, 769 | "metadata": { 770 | "collapsed": false 
771 | }, 772 | "outputs": [], 773 | "source": [ 774 | "# examine the tree\n", 775 | "tree_graph = tree.export_graphviz(mdl, out_file=None,\n", 776 | " feature_names=feat, \n", 777 | " filled=True, rounded=True) \n", 778 | "graph = pydotplus.graphviz.graph_from_dot_data(tree_graph) \n", 779 | "Image(graph.create_png())" 780 | ] 781 | }, 782 | { 783 | "cell_type": "markdown", 784 | "metadata": {}, 785 | "source": [ 786 | "# best splits" 787 | ] 788 | }, 789 | { 790 | "cell_type": "code", 791 | "execution_count": null, 792 | "metadata": { 793 | "collapsed": false 794 | }, 795 | "outputs": [], 796 | "source": [ 797 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')\n", 798 | "mdl = mdl.fit(X,y)\n", 799 | "\n", 800 | "plt.figure(figsize=[12,8])\n", 801 | "\n", 802 | "# get minimum and maximum values\n", 803 | "x0_min = X[:, 0].min()\n", 804 | "x0_max = X[:, 0].max()\n", 805 | "x1_min = X[:, 1].min()\n", 806 | "x1_max = X[:, 1].max()\n", 807 | "\n", 808 | "vmin = np.min([x0_min,x1_min])\n", 809 | "vmax = np.max([x0_max,x1_max])\n", 810 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 811 | " np.linspace(x1_min, x1_max, 1000))\n", 812 | "\n", 813 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 814 | "Z = Z.reshape(xx.shape)\n", 815 | "\n", 816 | "# plot the individual data points - colouring by the *true* outcome\n", 817 | "color = np.asarray(y.ravel(),dtype='float')\n", 818 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 819 | " marker='o', linewidth=2,\n", 820 | " s=60, cmap=cm)\n", 821 | "\n", 822 | "plt.xlabel(feat[0],fontsize=24)\n", 823 | "plt.ylabel(feat[1],fontsize=24)\n", 824 | "plt.axis(\"tight\")\n", 825 | "\n", 826 | "#plt.colorbar(cs)\n", 827 | "\n", 828 | "# cleanup plot\n", 829 | "plot_cleanup()\n", 830 | "\n", 831 | "plt.show()" 832 | ] 833 | }, 834 | { 835 | "cell_type": "code", 836 | "execution_count": null, 837 | "metadata": { 838 | "collapsed": false 839 | }, 840 | "outputs": [], 841 | "source": [ 842 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best')\n", 843 | "mdl = mdl.fit(X,y)\n", 844 | "\n", 845 | "plt.figure(figsize=[12,8])\n", 846 | "\n", 847 | "# get minimum and maximum values\n", 848 | "x0_min = X[:, 0].min()\n", 849 | "x0_max = X[:, 0].max()\n", 850 | "x1_min = X[:, 1].min()\n", 851 | "x1_max = X[:, 1].max()\n", 852 | "\n", 853 | "vmin = np.min([x0_min,x1_min])\n", 854 | "vmax = np.max([x0_max,x1_max])\n", 855 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 856 | " np.linspace(x1_min, x1_max, 1000))\n", 857 | "\n", 858 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 859 | "Z = Z.reshape(xx.shape)\n", 860 | "\n", 861 | "# plot the individual data points - colouring by the *true* outcome\n", 862 | "color = np.asarray(y.ravel(),dtype='float')\n", 863 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 864 | " marker='o', linewidth=2,\n", 865 | " s=60, cmap=cm)\n", 866 | "\n", 867 | "plt.xlabel(feat[0],fontsize=24)\n", 868 | "plt.ylabel(feat[1],fontsize=24)\n", 869 | "plt.axis(\"tight\")\n", 870 | "\n", 871 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 872 | "\n", 873 | "# cleanup plot\n", 874 | "plot_cleanup()\n", 875 | "\n", 876 | "plt.show()" 877 | ] 878 | }, 879 | { 880 | "cell_type": "code", 881 | "execution_count": null, 882 | "metadata": { 883 | "collapsed": false 884 | }, 885 | "outputs": [], 886 | "source": [ 887 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best',max_depth=1)\n", 888 | "mdl = mdl.fit(X,y)\n", 889 
| "\n", 890 | "plt.figure(figsize=[12,8])\n", 891 | "\n", 892 | "# get minimum and maximum values\n", 893 | "x0_min = X[:, 0].min()\n", 894 | "x0_max = X[:, 0].max()\n", 895 | "x1_min = X[:, 1].min()\n", 896 | "x1_max = X[:, 1].max()\n", 897 | "\n", 898 | "vmin = np.min([x0_min,x1_min])\n", 899 | "vmax = np.max([x0_max,x1_max])\n", 900 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 901 | " np.linspace(x1_min, x1_max, 1000))\n", 902 | "\n", 903 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 904 | "Z = Z.reshape(xx.shape)\n", 905 | "\n", 906 | "# plot the contour - colouring different regions\n", 907 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 908 | "\n", 909 | "# plot the individual data points - colouring by the *true* outcome\n", 910 | "color = np.asarray(y.ravel(),dtype='float')\n", 911 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 912 | " marker='o', linewidth=2,\n", 913 | " s=60, cmap=cm)\n", 914 | "\n", 915 | "plt.xlabel(feat[0],fontsize=24)\n", 916 | "plt.ylabel(feat[1],fontsize=24)\n", 917 | "plt.axis(\"tight\")\n", 918 | "\n", 919 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 920 | "\n", 921 | "# cleanup plot\n", 922 | "plot_cleanup()\n", 923 | "\n", 924 | "plt.show()" 925 | ] 926 | }, 927 | { 928 | "cell_type": "code", 929 | "execution_count": null, 930 | "metadata": { 931 | "collapsed": false 932 | }, 933 | "outputs": [], 934 | "source": [ 935 | "# examine the tree\n", 936 | "tree_graph = tree.export_graphviz(mdl, out_file=None,\n", 937 | " feature_names=feat, \n", 938 | " filled=True, rounded=True) \n", 939 | "graph = pydotplus.graphviz.graph_from_dot_data(tree_graph) \n", 940 | "Image(graph.create_png())" 941 | ] 942 | }, 943 | { 944 | "cell_type": "code", 945 | "execution_count": null, 946 | "metadata": { 947 | "collapsed": false 948 | }, 949 | "outputs": [], 950 | "source": [ 951 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best', max_depth=1)\n", 952 | "mdl = mdl.fit(X,y)\n", 953 | "\n", 954 | "plt.figure(figsize=[12,8])\n", 955 | "\n", 956 | "# get minimum and maximum values\n", 957 | "x0_min = X[:, 0].min()\n", 958 | "x0_max = X[:, 0].max()\n", 959 | "x1_min = X[:, 1].min()\n", 960 | "x1_max = X[:, 1].max()\n", 961 | "\n", 962 | "vmin = np.min([x0_min,x1_min])\n", 963 | "vmax = np.max([x0_max,x1_max])\n", 964 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 965 | " np.linspace(x1_min, x1_max, 1000))\n", 966 | "\n", 967 | "X_grid = np.c_[xx.ravel(), yy.ravel()]\n", 968 | "Z = mdl.predict(X_grid)\n", 969 | "\n", 970 | "# customize the prediction using the left side of above tree\n", 971 | "# apply:\n", 972 | "# petal length <= 3.9\n", 973 | "idxUnk = (X_grid[:,1] <= 4.75) & (X_grid[:,0] <= 4.95)\n", 974 | "Z[idxUnk] = 0\n", 975 | "Z = Z.reshape(xx.shape)\n", 976 | "\n", 977 | "# plot the contour - colouring different regions\n", 978 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 979 | "\n", 980 | "# plot the individual data points - colouring by the *true* outcome\n", 981 | "color = np.asarray(y.ravel(),dtype='float')\n", 982 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 983 | " marker='o', linewidth=2,\n", 984 | " s=60, cmap=cm)\n", 985 | "\n", 986 | "plt.xlabel(feat[0],fontsize=24)\n", 987 | "plt.ylabel(feat[1],fontsize=24)\n", 988 | "plt.axis(\"tight\")\n", 989 | "\n", 990 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 991 | "plt.plot([4.95, 4.95],[x1_min, 4.75],'k--',linewidth=3)\n", 992 | "\n", 993 | "# cleanup plot\n", 994 | "plot_cleanup()\n", 
995 | "\n", 996 | "plt.show()" 997 | ] 998 | }, 999 | { 1000 | "cell_type": "code", 1001 | "execution_count": null, 1002 | "metadata": { 1003 | "collapsed": false 1004 | }, 1005 | "outputs": [], 1006 | "source": [ 1007 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best', max_depth=1)\n", 1008 | "mdl = mdl.fit(X,y)\n", 1009 | "\n", 1010 | "plt.figure(figsize=[12,8])\n", 1011 | "\n", 1012 | "# get minimum and maximum values\n", 1013 | "x0_min = X[:, 0].min()\n", 1014 | "x0_max = X[:, 0].max()\n", 1015 | "x1_min = X[:, 1].min()\n", 1016 | "x1_max = X[:, 1].max()\n", 1017 | "\n", 1018 | "vmin = np.min([x0_min,x1_min])\n", 1019 | "vmax = np.max([x0_max,x1_max])\n", 1020 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 1021 | " np.linspace(x1_min, x1_max, 1000))\n", 1022 | "\n", 1023 | "X_grid = np.c_[xx.ravel(), yy.ravel()]\n", 1024 | "Z = mdl.predict(X_grid)\n", 1025 | "\n", 1026 | "# customize the prediction using the left side of above tree\n", 1027 | "idxUnk = (X_grid[:,1] <= 4.75) & (X_grid[:,0] <= 4.95) & (X_grid[:,1] <= 3.9)\n", 1028 | "Z[idxUnk] = -1\n", 1029 | "idxUnk = (X_grid[:,1] <= 4.75) & (X_grid[:,0] <= 4.95) & (X_grid[:,1] > 3.9)\n", 1030 | "Z[idxUnk] = 1\n", 1031 | "Z = Z.reshape(xx.shape)\n", 1032 | "\n", 1033 | "# plot the contour - colouring different regions\n", 1034 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 1035 | "\n", 1036 | "# plot the individual data points - colouring by the *true* outcome\n", 1037 | "color = np.asarray(y.ravel(),dtype='float')\n", 1038 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 1039 | " marker='o', linewidth=2,\n", 1040 | " s=60, cmap=cm)\n", 1041 | "\n", 1042 | "plt.xlabel(feat[0],fontsize=24)\n", 1043 | "plt.ylabel(feat[1],fontsize=24)\n", 1044 | "plt.axis(\"tight\")\n", 1045 | "\n", 1046 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 1047 | "plt.plot([4.95, 4.95],[x1_min, 4.75],'k--',linewidth=3)\n", 1048 | "plt.plot([x0_min, 4.95],[3.9, 3.9],'k--',linewidth=3)\n", 1049 | "\n", 1050 | "# cleanup plot\n", 1051 | "plot_cleanup()\n", 1052 | "\n", 1053 | "plt.show()" 1054 | ] 1055 | }, 1056 | { 1057 | "cell_type": "code", 1058 | "execution_count": null, 1059 | "metadata": { 1060 | "collapsed": false 1061 | }, 1062 | "outputs": [], 1063 | "source": [ 1064 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best', max_depth=1)\n", 1065 | "mdl = mdl.fit(X,y)\n", 1066 | "\n", 1067 | "plt.figure(figsize=[12,8])\n", 1068 | "\n", 1069 | "# get minimum and maximum values\n", 1070 | "x0_min = X[:, 0].min()\n", 1071 | "x0_max = X[:, 0].max()\n", 1072 | "x1_min = X[:, 1].min()\n", 1073 | "x1_max = X[:, 1].max()\n", 1074 | "\n", 1075 | "vmin = np.min([x0_min,x1_min])\n", 1076 | "vmax = np.max([x0_max,x1_max])\n", 1077 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 1078 | " np.linspace(x1_min, x1_max, 1000))\n", 1079 | "\n", 1080 | "X_grid = np.c_[xx.ravel(), yy.ravel()]\n", 1081 | "Z = mdl.predict(X_grid)\n", 1082 | "Z = Z.astype(float)\n", 1083 | "# customize the prediction using the left side of above tree\n", 1084 | "# left side of tree\n", 1085 | "idxUnk = (X_grid[:,1] <= 4.75) & (X_grid[:,0] <= 4.95) & (X_grid[:,1] <= 3.9)\n", 1086 | "Z[idxUnk] = -1\n", 1087 | "idxUnk = (X_grid[:,1] <= 4.75) & (X_grid[:,0] <= 4.95) & (X_grid[:,1] > 3.9)\n", 1088 | "Z[idxUnk] = 1\n", 1089 | "# right side of tree\n", 1090 | "idxUnk = (X_grid[:,1] > 4.75) & (X_grid[:,1] <= 5.15)\n", 1091 | "Z[idxUnk] = 0.5\n", 1092 | "idxUnk = (X_grid[:,1] > 4.75) & (X_grid[:,1] > 5.15)\n", 
1093 | "Z[idxUnk] = 1\n", 1094 | "\n", 1095 | "Z = Z.reshape(xx.shape)\n", 1096 | "\n", 1097 | "# plot the contour - colouring different regions\n", 1098 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 1099 | "\n", 1100 | "# plot the individual data points - colouring by the *true* outcome\n", 1101 | "color = np.asarray(y.ravel(),dtype='float')\n", 1102 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 1103 | " marker='o', linewidth=2,\n", 1104 | " s=60, cmap=cm)\n", 1105 | "\n", 1106 | "plt.xlabel(feat[0],fontsize=24)\n", 1107 | "plt.ylabel(feat[1],fontsize=24)\n", 1108 | "plt.axis(\"tight\")\n", 1109 | "\n", 1110 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 1111 | "\n", 1112 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 1113 | "plt.plot([4.95, 4.95],[x1_min, 4.75],'k--',linewidth=3)\n", 1114 | "plt.plot([x0_min, 4.95],[3.9, 3.9],'k--',linewidth=3)\n", 1115 | "plt.plot([x0_min, x0_max],[5.15,5.15],'k--',linewidth=3)\n", 1116 | "\n", 1117 | "# cleanup plot\n", 1118 | "plot_cleanup()\n", 1119 | "\n", 1120 | "plt.show()" 1121 | ] 1122 | }, 1123 | { 1124 | "cell_type": "code", 1125 | "execution_count": null, 1126 | "metadata": { 1127 | "collapsed": false 1128 | }, 1129 | "outputs": [], 1130 | "source": [ 1131 | "# examine a depth-3 tree\n", 1132 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best', max_depth=3)\n", 1133 | "mdl = mdl.fit(X,y)\n", 1134 | "\n", 1135 | "tree_graph = tree.export_graphviz(mdl, out_file=None,\n", 1136 | " feature_names=feat, \n", 1137 | " filled=True, rounded=True) \n", 1138 | "graph = pydotplus.graphviz.graph_from_dot_data(tree_graph) \n", 1139 | "Image(graph.create_png())" 1140 | ] 1141 | }, 1142 | { 1143 | "cell_type": "code", 1144 | "execution_count": null, 1145 | "metadata": { 1146 | "collapsed": false 1147 | }, 1148 | "outputs": [], 1149 | "source": [ 1150 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best', max_depth=3)\n", 1151 | "mdl = mdl.fit(X,y)\n", 1152 | "\n", 1153 | "plt.figure(figsize=[12,8])\n", 1154 | "\n", 1155 | "# get minimum and maximum values\n", 1156 | "x0_min = X[:, 0].min()\n", 1157 | "x0_max = X[:, 0].max()\n", 1158 | "x1_min = X[:, 1].min()\n", 1159 | "x1_max = X[:, 1].max()\n", 1160 | "\n", 1161 | "vmin = np.min([x0_min,x1_min])\n", 1162 | "vmax = np.max([x0_max,x1_max])\n", 1163 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 1164 | " np.linspace(x1_min, x1_max, 1000))\n", 1165 | "\n", 1166 | "X_grid = np.c_[xx.ravel(), yy.ravel()]\n", 1167 | "Z = mdl.predict(X_grid)\n", 1168 | "Z = Z.astype(float)\n", 1169 | "\n", 1170 | "# customize the prediction using the left side of above tree\n", 1171 | "# left side of tree\n", 1172 | "idxUnk = (X_grid[:,1] <= 4.75) & (X_grid[:,0] <= 4.95) & (X_grid[:,1] <= 3.9)\n", 1173 | "Z[idxUnk] = -1\n", 1174 | "idxUnk = (X_grid[:,1] <= 4.75) & (X_grid[:,0] <= 4.95) & (X_grid[:,1] > 3.9)\n", 1175 | "Z[idxUnk] = 1\n", 1176 | "# right side of tree\n", 1177 | "idxUnk = (X_grid[:,1] > 4.75) & (X_grid[:,1] <= 5.15) & (X_grid[:,0] <= 6.6)\n", 1178 | "Z[idxUnk] = 0.85\n", 1179 | "idxUnk = (X_grid[:,1] > 4.75) & (X_grid[:,1] <= 5.15) & (X_grid[:,0] > 6.6)\n", 1180 | "Z[idxUnk] = -0.6\n", 1181 | "\n", 1182 | "Z = Z.reshape(xx.shape)\n", 1183 | "\n", 1184 | "# plot the contour - colouring different regions\n", 1185 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 1186 | "\n", 1187 | "# plot the individual data points - colouring by the *true* outcome\n", 1188 | "color = np.asarray(y.ravel(),dtype='float')\n", 1189 | 
"plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 1190 | " marker='o', linewidth=2,\n", 1191 | " s=60, cmap=cm)\n", 1192 | "\n", 1193 | "plt.xlabel(feat[0],fontsize=24)\n", 1194 | "plt.ylabel(feat[1],fontsize=24)\n", 1195 | "plt.axis(\"tight\")\n", 1196 | "\n", 1197 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 1198 | "\n", 1199 | "plt.plot([x0_min, x0_max],[4.75,4.75],'k--',linewidth=3)\n", 1200 | "plt.plot([4.95, 4.95],[x1_min, 4.75],'k--',linewidth=3)\n", 1201 | "plt.plot([x0_min, 4.95],[3.9, 3.9],'k--',linewidth=3)\n", 1202 | "plt.plot([x0_min, x0_max],[5.15,5.15],'k--',linewidth=3)\n", 1203 | "\n", 1204 | "# cleanup plot\n", 1205 | "plot_cleanup()\n", 1206 | "\n", 1207 | "plt.show()" 1208 | ] 1209 | }, 1210 | { 1211 | "cell_type": "markdown", 1212 | "metadata": {}, 1213 | "source": [ 1214 | "# algorithm at different spots in the tree" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "code", 1219 | "execution_count": null, 1220 | "metadata": { 1221 | "collapsed": false 1222 | }, 1223 | "outputs": [], 1224 | "source": [ 1225 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best',max_depth=3).fit(X,y)\n", 1226 | "\n", 1227 | "plt.figure(figsize=[12,8])\n", 1228 | "plot_model_pred_2d(mdl, X, y, feat, plot_colorbar=False)\n", 1229 | "\n", 1230 | "# cleanup plot\n", 1231 | "plot_cleanup()\n", 1232 | "\n", 1233 | "plt.show()" 1234 | ] 1235 | }, 1236 | { 1237 | "cell_type": "code", 1238 | "execution_count": null, 1239 | "metadata": { 1240 | "collapsed": false 1241 | }, 1242 | "outputs": [], 1243 | "source": [ 1244 | "# examine a depth-7 tree\n", 1245 | "mdl = tree.DecisionTreeClassifier(criterion='entropy', splitter='best',max_depth=7).fit(X,y)\n", 1246 | "tree_graph = tree.export_graphviz(mdl, out_file=None,\n", 1247 | " feature_names=feat, \n", 1248 | " filled=True, rounded=True) \n", 1249 | "graph = pydotplus.graphviz.graph_from_dot_data(tree_graph) \n", 1250 | "Image(graph.create_png())" 1251 | ] 1252 | }, 1253 | { 1254 | "cell_type": "code", 1255 | "execution_count": null, 1256 | "metadata": { 1257 | "collapsed": false 1258 | }, 1259 | "outputs": [], 1260 | "source": [ 1261 | "plt.figure(figsize=[12,8])\n", 1262 | "plot_model_pred_2d(mdl, X, y, feat, plot_colorbar=False)\n", 1263 | "\n", 1264 | "# cleanup plot\n", 1265 | "plot_cleanup()\n", 1266 | "\n", 1267 | "plt.show()" 1268 | ] 1269 | }, 1270 | { 1271 | "cell_type": "markdown", 1272 | "metadata": {}, 1273 | "source": [ 1274 | "# bootstrapping CDF vs. 
PDF" 1275 | ] 1276 | }, 1277 | { 1278 | "cell_type": "code", 1279 | "execution_count": null, 1280 | "metadata": { 1281 | "collapsed": false 1282 | }, 1283 | "outputs": [], 1284 | "source": [ 1285 | "\n", 1286 | "# Create some test data\n", 1287 | "dx = .1\n", 1288 | "X = np.arange(-2,2,dx)\n", 1289 | "Y = np.exp(-X**2)\n", 1290 | "\n", 1291 | "# Normalize the data to a proper PDF\n", 1292 | "Y /= (dx*Y).sum()\n", 1293 | "\n", 1294 | "# Compute the CDF\n", 1295 | "CY = np.cumsum(Y*dx)\n", 1296 | "\n", 1297 | "# shift the axis over, set the histogram widths\n", 1298 | "factor = 10\n", 1299 | "X = (X*factor)+70\n", 1300 | "hist_w = 1\n", 1301 | "\n", 1302 | "colors = plt.cm.Set1([x/7.0 for x in range(7)])\n", 1303 | "plt.figure(figsize=[12,7])\n", 1304 | "# plot pdf\n", 1305 | "plt.plot(X,Y, linewidth=3, color=colors[1])\n", 1306 | "# plot hist\n", 1307 | "plt.bar(X-(dx*factor/2.),Y,width=dx*factor,linewidth=0.1, facecolor=colors[1],alpha=0.5)\n", 1308 | "# plot CDF\n", 1309 | "plt.plot(X,CY,'--', linewidth=3,color=colors[3])\n", 1310 | "\n", 1311 | "plot_cleanup()\n", 1312 | "plt.show()\n", 1313 | "\n", 1314 | "# no histogram\n", 1315 | "\n", 1316 | "\n", 1317 | "# Create some test data\n", 1318 | "dx = .1\n", 1319 | "X = np.arange(-2,2,dx)\n", 1320 | "Y = np.exp(-X**2)\n", 1321 | "\n", 1322 | "# Normalize the data to a proper PDF\n", 1323 | "Y /= (dx*Y).sum()\n", 1324 | "\n", 1325 | "# Compute the CDF\n", 1326 | "CY = np.cumsum(Y*dx)\n", 1327 | "\n", 1328 | "# shift the axis over, set the histogram widths\n", 1329 | "factor = 10\n", 1330 | "X = (X*factor)+70\n", 1331 | "hist_w = 1\n", 1332 | "\n", 1333 | "colors = plt.cm.Set1([x/7.0 for x in range(7)])\n", 1334 | "plt.figure(figsize=[12,7])\n", 1335 | "# plot pdf\n", 1336 | "plt.plot(X,Y, linewidth=3, color=colors[1])\n", 1337 | "# plot CDF\n", 1338 | "plt.plot(X,CY,'--', linewidth=3,color=colors[3])\n", 1339 | "\n", 1340 | "plot_cleanup()\n", 1341 | "plt.show()" 1342 | ] 1343 | }, 1344 | { 1345 | "cell_type": "code", 1346 | "execution_count": null, 1347 | "metadata": { 1348 | "collapsed": false 1349 | }, 1350 | "outputs": [], 1351 | "source": [ 1352 | "colors = plt.cm.Set1([x/7.0 for x in range(7)])\n", 1353 | "\n", 1354 | "# a sample of data using rand + normrnd\n", 1355 | "np.random.seed(123)\n", 1356 | "Y = np.random.normal(loc=0.0, scale=1.0, size=[50,])\n", 1357 | "plt.figure(figsize=[12,7])\n", 1358 | "n, bins, patches = plt.hist(Y,bins=np.linspace(-5,5,50),color=colors[1],normed=True)\n", 1359 | "\n", 1360 | "\n", 1361 | "plt.xlim([-3,3])\n", 1362 | "plt.ylim([0,1])\n", 1363 | "\n", 1364 | "# reset xticks to be in the \"weight\" range\n", 1365 | "locs, labels = plt.xticks()\n", 1366 | "plt.xticks( locs, [x*factor+70 for x in locs] )\n", 1367 | "\n", 1368 | "plot_cleanup()\n", 1369 | "plt.show()\n", 1370 | "\n", 1371 | "# plot the CDF of the above\n", 1372 | "plt.figure(figsize=[12,7])\n", 1373 | "\n", 1374 | "n, bins, patches = plt.hist(Y,bins=np.linspace(-5,5,50),color=colors[1],normed=True)\n", 1375 | "n, bins, patches = plt.hist(Y,bins=np.linspace(-5,5,50),color=colors[2],normed=True,\n", 1376 | " linewidth=4,histtype='step',cumulative=True)\n", 1377 | "\n", 1378 | "\n", 1379 | "plt.xlim([-3,3])\n", 1380 | "plt.ylim([0,1])\n", 1381 | "\n", 1382 | "# reset xticks to be in the \"weight\" range\n", 1383 | "locs, labels = plt.xticks()\n", 1384 | "plt.xticks( locs, [x*factor+70 for x in locs] )\n", 1385 | "\n", 1386 | "plot_cleanup()\n", 1387 | "\n", 1388 | "plt.show()" 1389 | ] 1390 | }, 1391 | { 1392 | "cell_type": "code", 1393 | 
"execution_count": null, 1394 | "metadata": { 1395 | "collapsed": false 1396 | }, 1397 | "outputs": [], 1398 | "source": [ 1399 | "\n", 1400 | "# Create some test data\n", 1401 | "hist_w = 1\n", 1402 | "\n", 1403 | "dx = .1\n", 1404 | "X = np.arange(-3,3,dx)\n", 1405 | "Y = np.exp(-X**2)\n", 1406 | "\n", 1407 | "# Normalize the data to a proper PDF\n", 1408 | "Y /= (dx*Y).sum()\n", 1409 | "\n", 1410 | "# Compute the CDF\n", 1411 | "CY = np.cumsum(Y*dx)\n", 1412 | "colors = plt.cm.Set1([x/7.0 for x in range(7)])\n", 1413 | "\n", 1414 | "# a sample of data using rand + normrnd\n", 1415 | "np.random.seed(123)\n", 1416 | "Y = np.random.normal(loc=0.0, scale=1.0, size=[50,])\n", 1417 | "\n", 1418 | "\n", 1419 | "\n", 1420 | "\n", 1421 | "colors = plt.cm.Set1([x/7.0 for x in range(7)])\n", 1422 | "plt.figure(figsize=[12,7])\n", 1423 | "\n", 1424 | "\n", 1425 | "# plot CDF\n", 1426 | "n, bins, patches = plt.hist(Y,bins=np.linspace(-5,5,50),color=colors[2],normed=True,\n", 1427 | " linewidth=4,histtype='step',cumulative=True)\n", 1428 | "plt.plot(X,CY,'--', linewidth=3,color=colors[3])\n", 1429 | "\n", 1430 | "plt.xlim([-3,3])\n", 1431 | "plt.ylim([0,1])\n", 1432 | "\n", 1433 | "# reset xticks to be in the \"weight\" range\n", 1434 | "locs, labels = plt.xticks()\n", 1435 | "plt.xticks( locs, [int(x*factor+70) for x in locs] )\n", 1436 | "\n", 1437 | "plot_cleanup()\n", 1438 | "\n", 1439 | "plt.show()" 1440 | ] 1441 | }, 1442 | { 1443 | "cell_type": "code", 1444 | "execution_count": null, 1445 | "metadata": { 1446 | "collapsed": false 1447 | }, 1448 | "outputs": [], 1449 | "source": [ 1450 | "# a sample of data using rand + normrnd\n", 1451 | "for m in range(5):\n", 1452 | " np.random.seed(123)\n", 1453 | " Y = np.random.normal(loc=0.0, scale=1.0, size=[50,])\n", 1454 | "\n", 1455 | " colors = plt.cm.Set1([x/7.0 for x in range(7)])\n", 1456 | " plt.figure(figsize=[12,7])\n", 1457 | "\n", 1458 | " # plot original CDF\n", 1459 | " n, bins, patches = plt.hist(Y,bins=np.linspace(-5,5,50),color=colors[2],normed=True,\n", 1460 | " linewidth=4,histtype='step',cumulative=True)\n", 1461 | "\n", 1462 | " # remove green from our colors\n", 1463 | " colors = [colors[i] for i in range(colors.shape[0]) if i!=2]\n", 1464 | "\n", 1465 | " # plot ~5 repeated samples\n", 1466 | " for i in range(m):\n", 1467 | " # bootstrap sample\n", 1468 | " idx = np.random.randint(0, high=Y.shape[0],size=Y.shape[0])\n", 1469 | " n, bins, patches = plt.hist(Y[idx],bins=np.linspace(-5,5,50),color=colors[i],normed=True,\n", 1470 | " linewidth=0.1,histtype='bar',cumulative=False,alpha=0.5)\n", 1471 | " n, bins, patches = plt.hist(Y[idx],bins=np.linspace(-5,5,50),color=colors[i],normed=True,\n", 1472 | " linewidth=2,histtype='step',cumulative=True)\n", 1473 | "\n", 1474 | " plt.xlim([-3,3])\n", 1475 | " plt.ylim([0,1])\n", 1476 | "\n", 1477 | " # reset xticks to be in the \"weight\" range\n", 1478 | " locs, labels = plt.xticks()\n", 1479 | " plt.xticks( locs, [int(x*factor+70) for x in locs] )\n", 1480 | "\n", 1481 | " plot_cleanup()\n", 1482 | "\n", 1483 | " plt.show()" 1484 | ] 1485 | }, 1486 | { 1487 | "cell_type": "code", 1488 | "execution_count": null, 1489 | "metadata": { 1490 | "collapsed": false, 1491 | "scrolled": false 1492 | }, 1493 | "outputs": [], 1494 | "source": [ 1495 | "# load fisher-iris and build bagging model .. 
showing each individual tree and cumulative result\n", 1496 | "# real example\n", 1497 | "df = datasets.load_iris()\n", 1498 | "\n", 1499 | "idx = [0,2]\n", 1500 | "X = df['data'][50:,idx]\n", 1501 | "y = df['target'][50:]\n", 1502 | "\n", 1503 | "feat = [df['feature_names'][x] for x in idx]\n", 1504 | "\n", 1505 | "# get minimum and maximum values for dataset\n", 1506 | "# these are used in the plotting\n", 1507 | "x0_min = X[:, 0].min()\n", 1508 | "x0_max = X[:, 0].max()\n", 1509 | "x1_min = X[:, 1].min()\n", 1510 | "x1_max = X[:, 1].max()\n", 1511 | "\n", 1512 | "vmin = np.min([x0_min,x1_min])\n", 1513 | "vmax = np.max([x0_max,x1_max])\n", 1514 | "xx, yy = np.meshgrid(np.linspace(x0_min, x0_max, 1000),\n", 1515 | " np.linspace(x1_min, x1_max, 1000))\n", 1516 | "\n", 1517 | "\n", 1518 | "# plot the original data\n", 1519 | "fig = plt.figure(figsize=[8,5])\n", 1520 | "\n", 1521 | "# plot the individual data points - colouring by the *true* outcome\n", 1522 | "color = np.asarray(y.ravel(),dtype='float')\n", 1523 | "plt.scatter(X[:, 0], X[:, 1], c=color, marker='o',\n", 1524 | " s=60, cmap=cm)\n", 1525 | "plt.xlabel(feat[0],fontsize=24)\n", 1526 | "plt.ylabel(feat[1],fontsize=24)\n", 1527 | "plt.axis(\"tight\")\n", 1528 | "# cleanup plot\n", 1529 | "plot_cleanup()\n", 1530 | "# disable ticks\n", 1531 | "plt.xticks([])\n", 1532 | "plt.yticks([])\n", 1533 | "plt.show()\n", 1534 | "\n", 1535 | "\n", 1536 | "\n", 1537 | "\n", 1538 | "np.random.seed(321)\n", 1539 | "clf = tree.DecisionTreeClassifier(max_depth=5)\n", 1540 | "\n", 1541 | "mdls = list()\n", 1542 | "ypred = np.zeros([y.shape[0]])\n", 1543 | "\n", 1544 | "for i in range(5):\n", 1545 | " fig = plt.figure(figsize=[8,5])\n", 1546 | " \n", 1547 | " # random sample\n", 1548 | " idx = np.random.randint(0, X.shape[0], X.shape[0])\n", 1549 | " idxOOB = [x for x in range(X.shape[0]) if x not in idx]\n", 1550 | " \n", 1551 | " # create the estimator\n", 1552 | " mdl = clf.fit(X[idx,:],y[idx])\n", 1553 | " mdls.append(mdl) \n", 1554 | "\n", 1555 | " Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 1556 | " Z = Z.reshape(xx.shape)\n", 1557 | " if i==0:\n", 1558 | " Z_all = Z.astype(float)\n", 1559 | " else:\n", 1560 | " Z_all += Z.astype(float)\n", 1561 | " \n", 1562 | " # plot the contour - colouring different regions according to class\n", 1563 | " cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 1564 | "\n", 1565 | " # plot the individual data points - colouring by the *true* outcome\n", 1566 | " color = np.asarray(y[idx].ravel(),dtype='float')\n", 1567 | " plt.scatter(X[idx, 0], X[idx, 1], c=color, edgecolor='k',\n", 1568 | " marker='o', linewidth=2,\n", 1569 | " s=60, cmap=cm)\n", 1570 | " \n", 1571 | " # plot \"s\" for data points which weren't included\n", 1572 | " color = np.asarray(y[idxOOB].ravel(),dtype='float')\n", 1573 | " plt.scatter(X[idxOOB, 0], X[idxOOB, 1], c=color, edgecolor='gray',\n", 1574 | " marker='s', linewidth=2,\n", 1575 | " s=60, cmap=cm)\n", 1576 | "\n", 1577 | " plt.xlabel(feat[0],fontsize=24)\n", 1578 | " plt.ylabel(feat[1],fontsize=24)\n", 1579 | " plt.axis(\"tight\")\n", 1580 | "\n", 1581 | " #plt.colorbar(cs)\n", 1582 | "\n", 1583 | " # cleanup plot\n", 1584 | " plot_cleanup()\n", 1585 | "\n", 1586 | " # disable ticks\n", 1587 | " plt.xticks([])\n", 1588 | " plt.yticks([])\n", 1589 | " \n", 1590 | " txt = 'Tree {}'.format(i+1)\n", 1591 | " plt.text(7.0, 3.5, txt, fontdict={'fontsize':12})\n", 1592 | " plt.show()\n", 1593 | " \n", 1594 | " \n", 1595 | "print('Final aggregation')\n", 1596 | "Z_all = Z_all / 5.0\n", 
1597 | "Z_all = np.round(Z_all)\n", 1598 | "\n", 1599 | "fig = plt.figure(figsize=[8,5])\n", 1600 | "# plot the contour - colouring different regions according to class\n", 1601 | "cs = plt.contourf(xx, yy, Z_all, cmap=cm)\n", 1602 | "\n", 1603 | "# plot the individual data points - colouring by the *true* outcome\n", 1604 | "color = np.asarray(y.ravel(),dtype='float')\n", 1605 | "\n", 1606 | "plt.scatter(X[:, 0], X[:, 1], c=color, edgecolor='k',\n", 1607 | " marker='o', linewidth=2,\n", 1608 | " s=60, cmap=cm)\n", 1609 | "\n", 1610 | "plt.xlabel(feat[0],fontsize=24)\n", 1611 | "plt.ylabel(feat[1],fontsize=24)\n", 1612 | "plt.axis(\"tight\")\n", 1613 | "\n", 1614 | "#plt.colorbar(cs)\n", 1615 | "\n", 1616 | "# cleanup plot\n", 1617 | "plot_cleanup()\n", 1618 | "\n", 1619 | "# disable ticks\n", 1620 | "plt.xticks([])\n", 1621 | "plt.yticks([])\n", 1622 | "\n", 1623 | "txt = 'All trees'\n", 1624 | "plt.text(7.0, 3.5, txt, fontdict={'fontsize':12})\n", 1625 | "plt.show()\n" 1626 | ] 1627 | }, 1628 | { 1629 | "cell_type": "markdown", 1630 | "metadata": {}, 1631 | "source": [ 1632 | "# random forest" 1633 | ] 1634 | }, 1635 | { 1636 | "cell_type": "code", 1637 | "execution_count": null, 1638 | "metadata": { 1639 | "collapsed": false 1640 | }, 1641 | "outputs": [], 1642 | "source": [ 1643 | "# load fisher-iris and build bagging model .. showing each individual tree and cumulative result\n", 1644 | "# real example\n", 1645 | "df = datasets.load_iris()\n", 1646 | "\n", 1647 | "idx = [0,2]\n", 1648 | "X = df['data'][50:,:]\n", 1649 | "y = df['target'][50:]\n", 1650 | "\n", 1651 | "feat = df['feature_names']\n", 1652 | "\n", 1653 | "# get minimum and maximum values for dataset\n", 1654 | "# these are used in the plotting\n", 1655 | "x_min = 0\n", 1656 | "x_max = 8\n", 1657 | "xx, yy = np.meshgrid(np.linspace(x_min, x_max, 1000),\n", 1658 | " np.linspace(x_min, x_max, 1000))\n", 1659 | "\n", 1660 | "\n", 1661 | "np.random.seed(321)\n", 1662 | "clf = tree.DecisionTreeClassifier(max_depth=5)\n", 1663 | "\n", 1664 | "mdls = list()\n", 1665 | "fig = plt.figure(figsize=[14,10])\n", 1666 | "\n", 1667 | "for i in range(6):\n", 1668 | " ax = fig.add_subplot(2,3,i+1)\n", 1669 | " \n", 1670 | " # random sample of data\n", 1671 | " idx = np.random.randint(0, X.shape[0], X.shape[0])\n", 1672 | " idxOOB = [x for x in range(X.shape[0]) if x not in idx]\n", 1673 | " \n", 1674 | " # random subset of features\n", 1675 | " idxFeat = np.random.permutation(4)[:2]\n", 1676 | " \n", 1677 | " # create the estimator\n", 1678 | " mdl = clf.fit(X[idx,:][:,idxFeat],y[idx])\n", 1679 | " mdls.append(mdl) \n", 1680 | "\n", 1681 | " Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 1682 | " Z = Z.reshape(xx.shape)\n", 1683 | " \n", 1684 | " # plot the contour - colouring different regions according to class\n", 1685 | " cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 1686 | "\n", 1687 | " # plot the individual data points - colouring by the *true* outcome\n", 1688 | " color = np.asarray(y[idx].ravel(),dtype='float')\n", 1689 | " plt.scatter(X[idx, idxFeat[0]], X[idx, idxFeat[1]], c=color, marker='o', edgecolors='k',\n", 1690 | " s=60, cmap=cm)\n", 1691 | " \n", 1692 | " # plot a gray square around for data points which weren't included\n", 1693 | " color = np.asarray(y[idxOOB].ravel(),dtype='float')\n", 1694 | " plt.scatter(X[idxOOB, idxFeat[0]], X[idxOOB, idxFeat[1]], c=color, marker='s',\n", 1695 | " linewidth=2, edgecolors='gray',\n", 1696 | " s=60, cmap=cm)\n", 1697 | " \n", 1698 | "\n", 1699 | " 
plt.xlabel(feat[idxFeat[0]],fontsize=24)\n", 1700 | " plt.ylabel(feat[idxFeat[1]],fontsize=24)\n", 1701 | " plt.axis(\"tight\")\n", 1702 | "\n", 1703 | " #plt.colorbar(cs)\n", 1704 | "\n", 1705 | " # cleanup plot\n", 1706 | " plot_cleanup()\n", 1707 | "\n", 1708 | " # disable ticks\n", 1709 | " plt.xticks([])\n", 1710 | " plt.yticks([])\n", 1711 | " \n", 1712 | " txt = 'Tree {}'.format(i+1)\n", 1713 | " plt.text(3.0, 4.0, txt, fontdict={'fontsize':12,'fontweight':'bold'})\n", 1714 | "plt.show()" 1715 | ] 1716 | }, 1717 | { 1718 | "cell_type": "markdown", 1719 | "metadata": {}, 1720 | "source": [ 1721 | "# random forest overview" 1722 | ] 1723 | }, 1724 | { 1725 | "cell_type": "code", 1726 | "execution_count": null, 1727 | "metadata": { 1728 | "collapsed": false 1729 | }, 1730 | "outputs": [], 1731 | "source": [ 1732 | "# load fisher-iris and build bagging model .. showing each individual tree and cumulative result\n", 1733 | "# real example\n", 1734 | "df = datasets.load_iris()\n", 1735 | "\n", 1736 | "idx = [0,2]\n", 1737 | "X = df['data'][50:,:]\n", 1738 | "y = df['target'][50:]\n", 1739 | "\n", 1740 | "feat = df['feature_names']\n", 1741 | "\n", 1742 | "# get minimum and maximum values for dataset\n", 1743 | "# these are used in the plotting\n", 1744 | "x_min = 0\n", 1745 | "x_max = 8\n", 1746 | "xx, yy = np.meshgrid(np.linspace(x_min, x_max, 1000),\n", 1747 | " np.linspace(x_min, x_max, 1000))\n", 1748 | "\n", 1749 | "\n", 1750 | "np.random.seed(172631)\n", 1751 | "i=0\n", 1752 | "# random sample of data\n", 1753 | "idx = np.random.randint(0, X.shape[0], X.shape[0])\n", 1754 | "idxOOB = [x for x in range(X.shape[0]) if x not in idx]\n", 1755 | "\n", 1756 | "# random subset of features\n", 1757 | "idxFeat = np.random.permutation(4)[:2]\n", 1758 | "\n", 1759 | "# create the estimator\n", 1760 | "mdl = clf.fit(X[idx,:][:,idxFeat],y[idx])\n", 1761 | "mdls.append(mdl) \n", 1762 | "\n", 1763 | "Z = mdl.predict(np.c_[xx.ravel(), yy.ravel()])\n", 1764 | "Z = Z.reshape(xx.shape)\n", 1765 | "\n", 1766 | "# plot the bootstrap sample with the OOB observations\n", 1767 | "fig = plt.figure(figsize=[14,10])\n", 1768 | "# plot the individual data points - colouring by the *true* outcome\n", 1769 | "color = np.asarray(y[idx].ravel(),dtype='float')\n", 1770 | "plt.scatter(X[idx, idxFeat[0]], X[idx, idxFeat[1]], c=color, marker='o',\n", 1771 | " edgecolor='k', linewidth=2,\n", 1772 | " s=60, cmap=cm)\n", 1773 | "# plot \"x\" for data points which weren't included\n", 1774 | "color = np.asarray(y[idxOOB].ravel(),dtype='float')\n", 1775 | "plt.scatter(X[idxOOB, idxFeat[0]], X[idxOOB, idxFeat[1]], c=color, marker='s',\n", 1776 | " linewidth=2, edgecolors='gray',\n", 1777 | " s=60, cmap=cm)\n", 1778 | "plt.xlabel(feat[idxFeat[0]],fontsize=24)\n", 1779 | "plt.ylabel(feat[idxFeat[1]],fontsize=24)\n", 1780 | "plt.axis(\"tight\")\n", 1781 | "# cleanup plot\n", 1782 | "plot_cleanup()\n", 1783 | "plt.xlim([0,8])\n", 1784 | "plt.ylim([0,8])\n", 1785 | "plt.show()\n", 1786 | "\n", 1787 | "# plot the bootstrap sample w/o OOB\n", 1788 | "fig = plt.figure(figsize=[14,10])\n", 1789 | "color = np.asarray(y[idx].ravel(),dtype='float')\n", 1790 | "plt.scatter(X[idx, idxFeat[0]], X[idx, idxFeat[1]], c=color, marker='o',\n", 1791 | " linewidth=2, edgecolor='k',\n", 1792 | " s=60, cmap=cm)\n", 1793 | "plt.xlabel(feat[idxFeat[0]],fontsize=24)\n", 1794 | "plt.ylabel(feat[idxFeat[1]],fontsize=24)\n", 1795 | "plt.axis(\"tight\")\n", 1796 | "# cleanup plot\n", 1797 | "plot_cleanup()\n", 1798 | "#plt.text(3.0, 4.0, txt, 
fontdict={'fontsize':12,'fontweight':'bold'})\n", 1799 | "plt.xlim([0,8])\n", 1800 | "plt.ylim([0,8])\n", 1801 | "plt.show()\n", 1802 | "\n", 1803 | "\n", 1804 | "\n", 1805 | "# plot the bootstrap sample w/o OOB and with decision surface\n", 1806 | "fig = plt.figure(figsize=[14,10])\n", 1807 | "cs = plt.contourf(xx, yy, Z, cmap=cm)\n", 1808 | "color = np.asarray(y[idx].ravel(),dtype='float')\n", 1809 | "plt.scatter(X[idx, idxFeat[0]], X[idx, idxFeat[1]], c=color, marker='o',\n", 1810 | " edgecolor='k', linewidth=2,\n", 1811 | " s=60, cmap=cm)\n", 1812 | "plt.xlabel(feat[idxFeat[0]],fontsize=24)\n", 1813 | "plt.ylabel(feat[idxFeat[1]],fontsize=24)\n", 1814 | "plt.axis(\"tight\")\n", 1815 | "# cleanup plot\n", 1816 | "plot_cleanup()\n", 1817 | "plt.grid()\n", 1818 | "#plt.text(3.0, 4.0, txt, fontdict={'fontsize':12,'fontweight':'bold'})\n", 1819 | "plt.xlim([0,8])\n", 1820 | "plt.ylim([0,8])\n", 1821 | "plt.show()" 1822 | ] 1823 | }, 1824 | { 1825 | "cell_type": "markdown", 1826 | "metadata": {}, 1827 | "source": [ 1828 | "# performance curve of random forest\n", 1829 | "\n", 1830 | "## on training set" 1831 | ] 1832 | }, 1833 | { 1834 | "cell_type": "code", 1835 | "execution_count": null, 1836 | "metadata": { 1837 | "collapsed": false 1838 | }, 1839 | "outputs": [], 1840 | "source": [ 1841 | "df = datasets.load_iris()\n", 1842 | "\n", 1843 | "X = df['data']\n", 1844 | "y = df['target']\n", 1845 | "\n", 1846 | "np.random.seed(321)\n", 1847 | "mdl = ensemble.RandomForestClassifier(n_estimators=50, oob_score=True)\n", 1848 | "mdl = mdl.fit(X,y)\n", 1849 | "\n", 1850 | "from sklearn.ensemble.forest import _generate_unsampled_indices\n", 1851 | "\n", 1852 | "err = list()\n", 1853 | "n_samples = X.shape[0]\n", 1854 | "pred = np.zeros([y.shape[0],50])\n", 1855 | "roll_pred = np.zeros([y.shape[0],50])\n", 1856 | "\n", 1857 | "idx = np.zeros([y.shape[0],50],dtype=bool)\n", 1858 | "for i, estimator in enumerate(mdl.estimators_):\n", 1859 | " # Here at each iteration we obtain out of bag samples for every tree.\n", 1860 | " idxOOB = _generate_unsampled_indices(estimator.random_state, n_samples)\n", 1861 | " \n", 1862 | " # update predictions\n", 1863 | " curr_pred = estimator.predict(X[idxOOB,:])\n", 1864 | " pred[idxOOB,i] = curr_pred\n", 1865 | " idx[idxOOB,i] = True\n", 1866 | "\n", 1867 | " idxFeat = range(i+1)\n", 1868 | " roll_pred[:,i] = np.sum(pred[:,idxFeat]*idx[:,idxFeat],axis=1)\n", 1869 | " idxKeep = np.sum(idx[:, idxFeat],axis=1)\n", 1870 | " \n", 1871 | " roll_pred[idxKeep>0,i] = roll_pred[idxKeep>0,i] / idxKeep[idxKeep>0]\n", 1872 | " \n", 1873 | " # convert from 0/1 to the class labels\n", 1874 | " roll_pred[idxKeep>0,i] = mdl.classes_[np.round(roll_pred[idxKeep>0,i]).astype(int)]\n", 1875 | " \n", 1876 | " \n", 1877 | " # calculate current error\n", 1878 | " err.append( 1.0-np.mean(roll_pred[idxKeep>0,i] == y[idxKeep>0] ) )\n", 1879 | "\n", 1880 | "err = np.asarray(err)\n", 1881 | "plt.figure(figsize=[10,7])\n", 1882 | "plt.plot(range(err.shape[0]),err*100.0,color=colors[1],linewidth=4)\n", 1883 | "plt.ylabel('Number of errors',fontsize=20)\n", 1884 | "plt.xlabel('Number of trees',fontsize=20)\n", 1885 | "plot_cleanup()\n", 1886 | "plt.show()" 1887 | ] 1888 | }, 1889 | { 1890 | "cell_type": "code", 1891 | "execution_count": null, 1892 | "metadata": { 1893 | "collapsed": false 1894 | }, 1895 | "outputs": [], 1896 | "source": [ 1897 | "df = datasets.load_iris()\n", 1898 | "\n", 1899 | "X = df['data']\n", 1900 | "y = df['target']\n", 1901 | "\n", 1902 | "np.random.seed(321)\n", 1903 | 
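The error curve above relies on the private helper `_generate_unsampled_indices` from `sklearn.ensemble.forest`, which exists in the scikit-learn 0.20 pinned in `requirements.txt` but was later moved to the private module `sklearn.ensemble._forest` and given an extra argument, so the import may need adjusting on newer installs. If only the out-of-bag accuracy of the finished forest is needed, rather than the curve as trees are added, it is exposed directly on the fitted model because `oob_score=True` was passed; a minimal sketch, assuming the same iris data:

```python
import numpy as np
from sklearn import datasets, ensemble

iris = datasets.load_iris()
X, y = iris['data'], iris['target']

np.random.seed(321)
mdl = ensemble.RandomForestClassifier(n_estimators=50, oob_score=True)
mdl = mdl.fit(X, y)

# oob_score_ is the accuracy of OOB-only predictions aggregated over the forest
print('OOB accuracy : {:.3f}'.format(mdl.oob_score_))
print('OOB error (%): {:.1f}'.format(100.0 * (1.0 - mdl.oob_score_)))
```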
"mdl = ensemble.RandomForestClassifier(n_estimators=50, oob_score=True)\n", 1904 | "mdl = mdl.fit(X,y)\n", 1905 | "\n", 1906 | "from sklearn.ensemble.forest import _generate_unsampled_indices\n", 1907 | "\n", 1908 | "err = list()\n", 1909 | "n_samples = X.shape[0]\n", 1910 | "pred = np.zeros([y.shape[0],50])\n", 1911 | "roll_pred = np.zeros([y.shape[0],50])\n", 1912 | "idx = np.zeros([y.shape[0],50],dtype=bool)\n", 1913 | "for i, estimator in enumerate(mdl.estimators_):\n", 1914 | " # Here at each iteration we obtain out of bag samples for every tree.\n", 1915 | " idxOOB = _generate_unsampled_indices(estimator.random_state, n_samples)\n", 1916 | " \n", 1917 | " # update predictions\n", 1918 | " curr_pred = estimator.predict(X[idxOOB,:])\n", 1919 | " pred[idxOOB,i] = curr_pred\n", 1920 | " idx[idxOOB,i] = True\n", 1921 | "\n", 1922 | " idxFeat = range(i+1)\n", 1923 | " roll_pred[:,i] = np.sum(pred[:,idxFeat]*idx[:,idxFeat],axis=1)\n", 1924 | " idxKeep = np.sum(idx[:, idxFeat],axis=1)\n", 1925 | " \n", 1926 | " roll_pred[idxKeep>0,i] = roll_pred[idxKeep>0,i] / idxKeep[idxKeep>0]\n", 1927 | " \n", 1928 | " # convert from 0/1 to the class labels\n", 1929 | " roll_pred[idxKeep>0,i] = mdl.classes_[np.round(roll_pred[idxKeep>0,i]).astype(int)]\n", 1930 | " # calculate current error\n", 1931 | " err.append( 1.0-np.mean(roll_pred[idxKeep>0,i] == y[idxKeep>0] ) )\n", 1932 | "\n", 1933 | "err = np.asarray(err)\n", 1934 | "plt.figure(figsize=[10,7])\n", 1935 | "plt.plot(range(err.shape[0]),err*100.0,color=colors[1],linewidth=4)\n", 1936 | "plt.ylabel('Number of errors',fontsize=20)\n", 1937 | "plt.xlabel('Number of trees',fontsize=20)\n", 1938 | "plot_cleanup()\n", 1939 | "plt.show()" 1940 | ] 1941 | }, 1942 | { 1943 | "cell_type": "code", 1944 | "execution_count": null, 1945 | "metadata": { 1946 | "collapsed": false 1947 | }, 1948 | "outputs": [], 1949 | "source": [ 1950 | "importances = mdl.feature_importances_\n", 1951 | "std = np.std([current_tree.feature_importances_ for current_tree in mdl.estimators_],\n", 1952 | " axis=0)\n", 1953 | "indices = np.argsort(importances)[::-1]\n", 1954 | "\n", 1955 | "# Plot the feature importances of the forest\n", 1956 | "plt.figure(figsize=[10,7])\n", 1957 | "plt.barh(range(X.shape[1]), importances[indices],\n", 1958 | " color=colors[0], xerr=std[indices], align=\"center\")\n", 1959 | "plt.yticks(range(X.shape[1]), [feat[i] for i in indices])\n", 1960 | "plt.ylim([-1, X.shape[1]])\n", 1961 | "plot_cleanup()\n", 1962 | "plt.show()" 1963 | ] 1964 | } 1965 | ], 1966 | "metadata": { 1967 | "kernelspec": { 1968 | "display_name": "Python 2", 1969 | "language": "python", 1970 | "name": "python2" 1971 | }, 1972 | "language_info": { 1973 | "codemirror_mode": { 1974 | "name": "ipython", 1975 | "version": 2 1976 | }, 1977 | "file_extension": ".py", 1978 | "mimetype": "text/x-python", 1979 | "name": "python", 1980 | "nbconvert_exporter": "python", 1981 | "pygments_lexer": "ipython2", 1982 | "version": "2.7.12" 1983 | } 1984 | }, 1985 | "nbformat": 4, 1986 | "nbformat_minor": 1 1987 | } 1988 | --------------------------------------------------------------------------------