├── Chapter7.md ├── Chapter7.qmd ├── Chapter7_files └── figure-commonmark │ ├── unnamed-chunk-10-1.png │ ├── unnamed-chunk-13-1.png │ ├── unnamed-chunk-14-1.png │ ├── unnamed-chunk-15-1.png │ ├── unnamed-chunk-17-1.png │ ├── unnamed-chunk-20-1.png │ ├── unnamed-chunk-21-1.png │ ├── unnamed-chunk-22-1.png │ └── unnamed-chunk-23-1.png ├── Chapter8_modeltime.md ├── Chapter8_modeltime.qmd ├── Chapter8_modeltime_files └── figure-commonmark │ ├── unnamed-chunk-10-1.png │ ├── unnamed-chunk-5-1.png │ ├── unnamed-chunk-6-1.png │ ├── unnamed-chunk-7-1.png │ └── unnamed-chunk-9-1.png ├── Chapter8_nixtla_reticulated.md ├── Chapter8_nixtla_reticulated.qmd ├── Chapter8_nixtla_reticulated_files └── figure-commonmark │ ├── unnamed-chunk-6-1.png │ └── unnamed-chunk-8-1.png ├── README.md ├── README.qmd └── book_cover.jpg /Chapter7.md: -------------------------------------------------------------------------------- 1 | # Chapter 7 \| Conformal Prediction for Regression 2 | frankiethull 3 | 4 | ## Chapter 7 to Practical Guide to Applied Conformal Prediction in **R**: 5 | 6 | The following code is based on the recent book release: *Practical Guide 7 | to Applied Conformal Prediction in Python*. After posting a fuzzy GIF on 8 | X & receiving a lot of requests for a blog or GitHub repo, below is 9 | Chapter 7 of the practical guide with applications in R, instead of 10 | Python. 11 | 12 | While the book is not free, the Python code is open-source and located 13 | at the following GitHub repo: 14 | *https://github.com/PacktPublishing/Practical-Guide-to-Applied-Conformal-Prediction/blob/main/Chapter_07.ipynb* 15 | 16 | While this is not a direct copy/paste replica of the Python notebook or 17 | book, it is a lite, supplemental R guide & documentation for R users. 18 | 19 | We will follow the example of calculating conformal prediction intervals 20 | manually, then use the probably package. 
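
As a quick preview of the manual route, here is a minimal base-R sketch of split conformal intervals. It is only an illustration under assumptions: the simulated data, the `lm()` point model, the 600/200/200 split and all object names are made up for this sketch and are not the chapter's dataset or model. The idea is the same though: score the calibration set with absolute residuals, take a finite-sample corrected quantile, and pad each test prediction by that amount.

``` r
# Minimal split conformal sketch on simulated data (illustrative only).
set.seed(42)

n <- 1000
x <- runif(n, 0, 10)
y <- 2 * x + rnorm(n, sd = 1 + 0.2 * x)
sim <- data.frame(x = x, y = y)

# three disjoint pieces: train, calibration, test
idx   <- sample(rep(c("train", "cal", "test"), times = c(600, 200, 200)))
train <- sim[idx == "train", ]
cal   <- sim[idx == "cal", ]
test  <- sim[idx == "test", ]

# any point forecaster works here; lm() keeps the sketch dependency-free
fit <- lm(y ~ x, data = train)

# nonconformity scores: absolute residuals on the calibration set
scores <- abs(cal$y - predict(fit, newdata = cal))

# finite-sample corrected quantile level, capped at 1
alpha   <- 0.05
n_cal   <- length(scores)
q_level <- min(1, ceiling((n_cal + 1) * (1 - alpha)) / n_cal)
q_hat   <- unname(quantile(scores, probs = q_level))

# symmetric interval around each test prediction
pred  <- predict(fit, newdata = test)
lower <- pred - q_hat
upper <- pred + q_hat

# share of test points that land inside their interval
coverage <- mean(test$y >= lower & test$y <= upper)
coverage
```

With exchangeable data the printed coverage should sit near `1 - alpha`, though it will wobble from split to split. Note the interval width is constant here, which is the limitation the CQR section later addresses.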
21 | 22 | ### R setup for tidymodeling: 23 | 24 | ``` r 25 | # using tidymodel framework: 26 | library(tidymodels) # ml modeling api 27 | ``` 28 | 29 | ── Attaching packages ────────────────────────────────────── tidymodels 1.1.0 ── 30 | 31 | ✔ broom 1.0.5 ✔ recipes 1.0.6 32 | ✔ dials 1.2.0 ✔ rsample 1.1.1 33 | ✔ dplyr 1.1.2 ✔ tibble 3.2.1 34 | ✔ ggplot2 3.4.2 ✔ tidyr 1.3.0 35 | ✔ infer 1.0.4 ✔ tune 1.1.1 36 | ✔ modeldata 1.1.0 ✔ workflows 1.1.3 37 | ✔ parsnip 1.1.0 ✔ workflowsets 1.0.1 38 | ✔ purrr 1.0.1 ✔ yardstick 1.2.0 39 | 40 | ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ── 41 | ✖ purrr::discard() masks scales::discard() 42 | ✖ dplyr::filter() masks stats::filter() 43 | ✖ dplyr::lag() masks stats::lag() 44 | ✖ recipes::step() masks stats::step() 45 | • Learn how to get started at https://www.tidymodels.org/start/ 46 | 47 | ``` r 48 | library(probably) # conformal ints 49 | ``` 50 | 51 | Warning: package 'probably' was built under R version 4.3.1 52 | 53 | 54 | Attaching package: 'probably' 55 | 56 | The following objects are masked from 'package:base': 57 | 58 | as.factor, as.ordered 59 | 60 | ``` r 61 | library(dplyr) # pliers keep it tidy 62 | library(ggplot2) # data viz 63 | library(reticulate) # pass the python example dataset :) 64 | ``` 65 | 66 | Warning: package 'reticulate' was built under R version 4.3.1 67 | 68 | ``` r 69 | library(doParallel) # model tuning made fast 70 | ``` 71 | 72 | Loading required package: foreach 73 | 74 | 75 | Attaching package: 'foreach' 76 | 77 | The following objects are masked from 'package:purrr': 78 | 79 | accumulate, when 80 | 81 | Loading required package: iterators 82 | 83 | Loading required package: parallel 84 | 85 | ``` r 86 | # reticulate::py_install("openml", pip = TRUE) 87 | # reticulate::py_install("pandas", pip = TRUE) 88 | ``` 89 | 90 | ### Load Dataset 91 | 92 | get the matching dataset via openml, quick python chunk from the 93 | original ipynb: 94 | 95 | ``` python 96 | 
import openml 97 | import pandas as pd 98 | 99 | # List of datasets from openml https://docs.openml.org/Python-API/ 100 | datasets_df = openml.datasets.list_datasets(output_format="dataframe") 101 | print(datasets_df.head(n=10)) 102 | ``` 103 | 104 | did name ... NumberOfNumericFeatures NumberOfSymbolicFeatures 105 | 2 2 anneal ... 6.0 33.0 106 | 3 3 kr-vs-kp ... 0.0 37.0 107 | 4 4 labor ... 8.0 9.0 108 | 5 5 arrhythmia ... 206.0 74.0 109 | 6 6 letter ... 16.0 1.0 110 | 7 7 audiology ... 0.0 70.0 111 | 8 8 liver-disorders ... 6.0 0.0 112 | 9 9 autos ... 15.0 11.0 113 | 10 10 lymph ... 3.0 16.0 114 | 11 11 balance-scale ... 4.0 1.0 115 | 116 | [10 rows x 16 columns] 117 | 118 | ``` python 119 | datasets_df.set_index('did', inplace = True) 120 | 121 | # California housing dataset https://www.openml.org/search?type=data&status=active&id=43939 122 | dataset = openml.datasets.get_dataset(43939) 123 | ``` 124 | 125 | C:\Users\Frank\AppData\Local\R-MINI~1\envs\R-RETI~1\lib\site-packages\openml\datasets\functions.py:438: FutureWarning: Starting from Version 0.15 `download_data`, `download_qualities`, and `download_features_meta_data` will all be ``False`` instead of ``True`` by default to enable lazy loading. To disable this message until version 0.15 explicitly set `download_data`, `download_qualities`, and `download_features_meta_data` to a bool while calling `get_dataset`. 
126 | warnings.warn( 127 | 128 | ``` python 129 | # Print a summary 130 | print( 131 | f"This is dataset '{dataset.name}', the target feature is " 132 | f"'{dataset.default_target_attribute}'" 133 | ) 134 | ``` 135 | 136 | This is dataset 'california_housing', the target feature is 'median_house_value' 137 | 138 | ``` python 139 | print(f"URL: {dataset.url}") 140 | ``` 141 | 142 | URL: https://api.openml.org/data/v1/download/22102987/california_housing.arff 143 | 144 | ``` python 145 | print(dataset.description[:500]) 146 | ``` 147 | 148 | Median house prices for California districts derived from the 1990 census. 149 | 150 | ``` python 151 | # openml API 152 | X, y, categorical_indicator, attribute_names = dataset.get_data( 153 | dataset_format="array", target=dataset.default_target_attribute 154 | ) 155 | ``` 156 | 157 | :3: FutureWarning: Support for `dataset_format='array'` will be removed in 0.15,start using `dataset_format='dataframe' to ensure your code will continue to work. You can use the dataframe's `to_numpy` function to continue using numpy arrays. 158 | 159 | ``` python 160 | df = pd.DataFrame(X, columns=attribute_names) 161 | df["class"] = y 162 | ``` 163 | 164 | #### pass the python df to R: 165 | 166 | ``` r 167 | df <- py$df 168 | ``` 169 | 170 | data checks: 171 | 172 | ``` r 173 | df |> str() 174 | ``` 175 | 176 | 'data.frame': 20640 obs. of 10 variables: 177 | $ longitude : num -122 -122 -122 -122 -122 ... 178 | $ latitude : num 37.9 37.9 37.8 37.8 37.8 ... 179 | $ housing_median_age: num 41 21 52 52 52 52 52 52 42 52 ... 180 | $ total_rooms : num 880 7099 1467 1274 1627 ... 181 | $ total_bedrooms : num 129 1106 190 235 280 ... 182 | $ population : num 322 2401 496 558 565 ... 183 | $ households : num 126 1138 177 219 259 ... 184 | $ median_income : num 8.33 8.3 7.26 5.64 3.85 ... 185 | $ ocean_proximity : num 3 3 3 3 3 3 3 3 3 3 ... 186 | $ class : num 452600 358500 352100 341300 342200 ... 
187 | - attr(*, "pandas.index")=RangeIndex(start=0, stop=20640, step=1) 188 | 189 | na checks: 190 | 191 | ``` r 192 | colSums(is.na(df)) 193 | ``` 194 | 195 | longitude latitude housing_median_age total_rooms 196 | 0 0 0 0 197 | total_bedrooms population households median_income 198 | 207 0 0 0 199 | ocean_proximity class 200 | 0 0 201 | 202 | ``` r 203 | df <- df |> 204 | na.omit() 205 | ``` 206 | 207 | data processing for regression: 208 | 209 | ``` r 210 | # holdout 10% of data for calibration 211 | cal_holdout <- dplyr::slice_sample(df, prop = .1) 212 | 213 | # proceed typical test/train splitting, a tidymodels workflow based on ipynb: 214 | model_df <- df |> anti_join(cal_holdout) 215 | ``` 216 | 217 | Joining with `by = join_by(longitude, latitude, housing_median_age, 218 | total_rooms, total_bedrooms, population, households, median_income, 219 | ocean_proximity, class)` 220 | 221 | ``` r 222 | split <- model_df |> initial_split(prop = 0.99) 223 | training <- training(split) 224 | testing <- testing(split) 225 | ``` 226 | 227 | model building: 228 | 229 | ``` r 230 | # random forest model spec, specifying 'mode' and 'engine' 231 | rf_model_spec <- 232 | rand_forest(trees = 200, min_n = 5) %>% 233 | set_mode("regression") %>% 234 | set_engine("ranger") 235 | 236 | rf_wflow <- workflow(class ~ ., rf_model_spec) 237 | rf_model_fit <- rf_wflow |> fit(data = training) 238 | ``` 239 | 240 | ## ICP Section 241 | 242 | ``` r 243 | # make point predictions 244 | pred_cal <- rf_model_fit |> predict(cal_holdout) 245 | pred_test <- rf_model_fit |> predict(testing) 246 | 247 | 248 | data.frame( 249 | y = cal_holdout$class, 250 | y_hat = pred_cal$.pred 251 | ) |> 252 | ggplot() + 253 | geom_point(aes(x = y, y = y_hat), color = "darkcyan", alpha = .9) + 254 | theme_minimal() + 255 | labs(title = "Prediction Error for RandomForestRegressor") 256 | ``` 257 | 258 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-10-1.png) 259 | 260 | ``` r 261 | alpha <- 0.05 262 | n_cal 
<- nrow(cal_holdout) 263 | 264 | y_cal <- cal_holdout$class 265 | y_pred_cal <- pred_cal$.pred 266 | 267 | # calculate calibration errors 268 | y_cal_error <- abs(y_cal - y_pred_cal) 269 | 270 | ceiling((n_cal+1)*(1-alpha))/n_cal 271 | ``` 272 | 273 | [1] 0.9505629 274 | 275 | ``` r 276 | # calculate q_hat on the calibration set 277 | q_yhat_cal <- quantile(y_cal_error, ceiling((n_cal+1)*(1-alpha))/n_cal) 278 | q_yhat_cal 279 | ``` 280 | 281 | 95.05629% 282 | 98919.27 283 | 284 | ``` r 285 | ggplot() + 286 | geom_histogram(aes(x = y_cal_error), fill = "lightblue") + 287 | geom_vline(aes(xintercept = q_yhat_cal), color = "red", linetype = 2) + 288 | labs( 289 | title = "Histogram of Calibration Errors", 290 | x = "Calibration Error", 291 | y = "Frequency" 292 | ) + 293 | theme_minimal() 294 | ``` 295 | 296 | `stat_bin()` using `bins = 30`. Pick better value with `binwidth`. 297 | 298 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-13-1.png) 299 | 300 | ``` r 301 | # predicted_df 302 | pred_test |> 303 | mutate( 304 | lower_bound = .pred - q_yhat_cal, 305 | upper_bound = .pred + q_yhat_cal, 306 | actual = testing$class 307 | ) |> 308 | mutate( 309 | index = row_number() 310 | ) |> 311 | ggplot(aes(x = index)) + 312 | geom_ribbon(aes(ymin = lower_bound, 313 | ymax = upper_bound), fill = "grey", 314 | alpha = 0.5) + 315 | geom_line(aes(y = actual, color = "Actual")) + 316 | geom_line(aes(y = .pred, color = "Predicted")) + 317 | theme_minimal() + 318 | labs( 319 | title = "Actual vs Predicted Values with Prediction Interval" 320 | ) + 321 | theme(legend.title = element_blank(), 322 | legend.position = c(.9,.9)) 323 | ``` 324 | 325 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-14-1.png) 326 | 327 | ### using probably 328 | 329 | To do the same routine in a ‘tidy’ way, one can use the *probably* package for 330 | split conformal inference. probably is a tidymodels extension package 331 | that provides various interval and post-calibration modeling techniques. 
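
Whichever route produces the intervals, it is worth checking how often they actually contain the truth. Below is a small sketch of such a check. `interval_summary()` is a hypothetical helper written for this guide, not a function from probably; it only assumes a data frame with `.pred_lower` and `.pred_upper` columns, which are the column names used for the probably predictions in this chapter. The toy numbers are made up.

``` r
# Hypothetical helper (not part of probably): empirical coverage and
# average width for a data frame of prediction intervals.
interval_summary <- function(intervals, truth) {
  stopifnot(
    all(c(".pred_lower", ".pred_upper") %in% names(intervals)),
    nrow(intervals) == length(truth)
  )
  # TRUE where the observed value falls inside its interval (bounds included)
  covered <- truth >= intervals$.pred_lower & truth <= intervals$.pred_upper
  data.frame(
    coverage   = mean(covered),
    mean_width = mean(intervals$.pred_upper - intervals$.pred_lower)
  )
}

# toy usage with made-up numbers
toy <- data.frame(
  .pred_lower = c(1, 2, 3, 4),
  .pred_upper = c(3, 4, 5, 6)
)
res <- interval_summary(toy, truth = c(2, 5, 4, 4.5))
res
```

With the chapter's objects, a call along the lines of `interval_summary(conformal_split_test, testing$class)` should work once the predictions exist; comparing the resulting coverage to the requested `level` is a reasonable sanity check.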
332 | 333 | ``` r 334 | conformal_split <- int_conformal_split(rf_model_fit, 335 | cal_data = cal_holdout) 336 | 337 | conformal_split_test <- predict(conformal_split, testing, level = 0.95) 338 | 339 | conformal_split_test |> 340 | mutate( 341 | actual = testing$class, 342 | index = row_number() 343 | ) |> 344 | ggplot(aes(x = index)) + 345 | geom_ribbon(aes(ymin = .pred_lower, 346 | ymax = .pred_upper), fill = "grey", 347 | alpha = 0.5) + 348 | geom_line(aes(y = actual, color = "Actual")) + 349 | geom_line(aes(y = .pred, color = "Predicted")) + 350 | theme_minimal() + 351 | labs( 352 | title = "Actual vs Predicted Values with Prediction Interval", 353 | subtitle = "Using {probably}" 354 | ) + 355 | theme(legend.title = element_blank(), 356 | legend.position = c(.9,.9)) 357 | ``` 358 | 359 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-15-1.png) 360 | 361 | ## CQR Section 362 | 363 | compute correlation between features and also between features and the 364 | target 365 | 366 | ``` r 367 | df |> 368 | select(-ocean_proximity) |> 369 | select_if(is.numeric) |> 370 | corrr::correlate() |> 371 | #corrr::rearrange() |> 372 | corrr::shave() 373 | ``` 374 | 375 | Correlation computed with 376 | • Method: 'pearson' 377 | • Missing treated using: 'pairwise.complete.obs' 378 | 379 | # A tibble: 9 × 10 380 | term longitude latitude housing_median_age total_rooms total_bedrooms 381 | 382 | 1 longitude NA NA NA NA NA 383 | 2 latitude -0.925 NA NA NA NA 384 | 3 housing_medi… -0.109 0.0119 NA NA NA 385 | 4 total_rooms 0.0455 -0.0367 -0.361 NA NA 386 | 5 total_bedroo… 0.0696 -0.0670 -0.320 0.930 NA 387 | 6 population 0.100 -0.109 -0.296 0.857 0.878 388 | 7 households 0.0565 -0.0718 -0.303 0.919 0.980 389 | 8 median_income -0.0156 -0.0796 -0.118 0.198 -0.00772 390 | 9 class -0.0454 -0.145 0.106 0.133 0.0497 391 | # ℹ 4 more variables: population , households , median_income , 392 | # class 393 | 394 | ``` r 395 | #corrr::rplot() 396 | ``` 397 | 398 | ``` r 399 | df |> 400 | 
ggplot() + 401 | geom_histogram(aes(class), fill = "lightblue") + 402 | theme_minimal() + 403 | labs(title = "histogram of house prices", 404 | x = "median price of houses") + 405 | scale_x_continuous(labels = scales::dollar_format()) 406 | ``` 407 | 408 | `stat_bin()` using `bins = 30`. Pick better value with `binwidth`. 409 | 410 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-17-1.png) 411 | 412 | ### Optimize underlying tree model 413 | 414 | ``` r 415 | folds <- vfold_cv(training, v = 5) 416 | 417 | params_distributions <- 418 | expand.grid( 419 | trees = c(10, 25), 420 | tree_depth = c(3, 10), 421 | mtry = c(50, 100), 422 | learn_rate = c(.01, .2) 423 | ) 424 | 425 | model_recipe <- recipe(class ~ ., training) 426 | 427 | # refer to api documentation on how-to pass quantile objective to various engines 428 | gbm_spec <- 429 | boost_tree( 430 | trees = tune(), 431 | tree_depth = tune(), 432 | mtry = tune(), 433 | learn_rate = tune() 434 | ) |> 435 | set_mode("regression") |> 436 | set_engine("xgboost", nthread = 8) 437 | 438 | # pre training settings --- 439 | cluster <- makePSOCKcluster(8) 440 | registerDoParallel(cluster) 441 | 442 | # model creation --- 443 | gbm_results <- 444 | finetune::tune_race_anova( 445 | workflow() %>% 446 | add_recipe(model_recipe) %>% 447 | add_model(gbm_spec), 448 | resamples = folds, 449 | grid = params_distributions, 450 | control = finetune::control_race(), 451 | metrics = metric_set(rmse) 452 | ) 453 | 454 | # post training settings --- 455 | stopCluster(cluster) 456 | registerDoSEQ() 457 | 458 | finalize_gbm <- workflow() %>% 459 | add_recipe(model_recipe) %>% 460 | add_model(gbm_spec) %>% 461 | finalize_workflow(select_best(gbm_results)) 462 | 463 | best_gbm <- finalize_gbm |> fit(training) 464 | ``` 465 | 
468 | 469 | ``` r 470 | show_best(gbm_results) 471 | ``` 472 | 473 | # A tibble: 2 × 10 474 | mtry trees tree_depth learn_rate .metric .estimator mean n std_err 475 | 476 | 1 50 25 10 0.2 rmse standard 49497. 5 744. 477 | 2 100 25 10 0.2 rmse standard 49497. 5 744. 478 | # ℹ 1 more variable: .config 479 | 480 | boosted model point forecaster with naive & cqr intervals using probably 481 | 482 | ``` r 483 | # naive 484 | xgb_conformal_split <- int_conformal_split(best_gbm, 485 | cal_data = cal_holdout) 486 | 487 | xgb_conformal_split_test <- predict(xgb_conformal_split, testing, level = 0.80) 488 | 489 | # cqr 490 | xgb_conformal_cqr <- int_conformal_quantile(best_gbm, 491 | train_data = training, 492 | cal_data = cal_holdout, 493 | level = 0.80) 494 | 495 | 496 | xgb_conformal_cqr_test <- predict(xgb_conformal_cqr, testing) 497 | ``` 498 | 499 | range plot for naive method: 500 | 501 | ``` r 502 | testing |> 503 | select(class) |> 504 | bind_cols(xgb_conformal_split_test) |> 505 | mutate( 506 | coverage = ifelse(class < .pred_upper & class > .pred_lower, "yes", "no") 507 | ) |> 508 | ggplot() + 509 | geom_segment(aes(x = class, xend = class, 510 | y = .pred_lower, yend = .pred_upper, 511 | color = coverage), alpha = .8) + 512 | geom_point(aes(x = class, y = .pred, 513 | color = coverage), size = 2) + 514 | labs(subtitle = "naive interval", 515 | x = "actual") + 516 | theme_minimal() + 517 | theme(legend.position = "bottom") + 518 | coord_equal() + 519 | geom_abline(slope = 1) 520 | ``` 521 | 522 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-20-1.png) 523 | 524 | ``` r 525 | testing |> 526 | select(class) |> 527 | bind_cols(xgb_conformal_cqr_test) |> 528 | mutate( 529 | coverage = ifelse(class < .pred_upper & class > .pred_lower, "yes", "no") 530 | ) |> 531 | ggplot() + 532 | geom_segment(aes(x = class, xend = class, 533 | y = .pred_lower, yend = .pred_upper, 534 | color = coverage), alpha = .8) + 535 | geom_point(aes(x = class, y = .pred, 536 | color = 
coverage), size = 2) + 537 | labs(subtitle = "cqr interval", 538 | x = "actual") + 539 | theme_minimal() + 540 | theme(legend.position = "bottom") + 541 | coord_equal() + 542 | geom_abline(slope = 1) 543 | ``` 544 | 545 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-21-1.png) 546 | 547 | bin plot 548 | 549 | ``` r 550 | testing |> 551 | select(class) |> 552 | mutate( 553 | bin = ntile(class, n = 10) 554 | ) |> 555 | bind_cols(xgb_conformal_cqr_test) |> 556 | mutate( 557 | coverage = ifelse(class < .pred_upper & class > .pred_lower, "yes", "no") 558 | ) |> 559 | group_by(bin) |> 560 | count(coverage) |> 561 | ggplot() + 562 | geom_col(aes(x = bin, y = n, fill = coverage)) + 563 | labs(title = "CQR: prediction interval coverage", 564 | subtitle = "by binned housing price") + 565 | theme_minimal() 566 | ``` 567 | 568 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-22-1.png) 569 | 570 | ``` r 571 | testing |> 572 | select(class) |> 573 | mutate( 574 | bin = ntile(class, n = 10) 575 | ) |> 576 | bind_cols(xgb_conformal_cqr_test) |> 577 | mutate( 578 | coverage_width = .pred_upper - .pred_lower 579 | ) |> 580 | group_by(bin) |> 581 | summarize( 582 | mean_width = mean(coverage_width) 583 | ) |> 584 | ggplot() + 585 | geom_col(aes(x = bin, y = mean_width), fill = "darkcyan") + 586 | labs(title = "CQR: prediction interval width", 587 | subtitle = "by binned housing price") + 588 | theme_minimal() 589 | ``` 590 | 591 | ![](Chapter7_files/figure-commonmark/unnamed-chunk-23-1.png) 592 | -------------------------------------------------------------------------------- /Chapter7.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Chapter 7 | Conformal Prediction for Regression" 3 | author: "frankiethull" 4 | format: gfm 5 | --- 6 | 7 | ## Chapter 7 to Practical Guide to Applied Conformal Prediction in **R**: 8 | 9 | The following code is based on the recent book release: *Practical Guide to Applied Conformal Prediction in 
Python*. After posting a fuzzy GIF on X & receiving a lot of requests for a blog or GitHub repo, below is Chapter 7 of the practical guide with applications in R, instead of Python. 10 | 11 | While the book is not free, the Python code is open-source and located at the following GitHub repo: 12 | *https://github.com/PacktPublishing/Practical-Guide-to-Applied-Conformal-Prediction/blob/main/Chapter_07.ipynb* 13 | 14 | While this is not a direct copy/paste replica of the Python notebook or book, it is a lite, supplemental R guide & documentation for R users. 15 | 16 | We will follow the example of calculating conformal prediction intervals manually, then use the probably package. 17 | 18 | ### R setup for tidymodeling: 19 | ```{r} 20 | # using tidymodel framework: 21 | library(tidymodels) # ml modeling api 22 | library(probably) # conformal ints 23 | library(dplyr) # pliers keep it tidy 24 | library(ggplot2) # data viz 25 | library(reticulate) # pass the python example dataset :) 26 | library(doParallel) # model tuning made fast 27 | ``` 28 | 29 | ```{r} 30 | # reticulate::py_install("openml", pip = TRUE) 31 | # reticulate::py_install("pandas", pip = TRUE) 32 | ``` 33 | ### Load Dataset 34 | get the matching dataset via openml, quick python chunk from the original ipynb: 35 | ```{python} 36 | import openml 37 | import pandas as pd 38 | 39 | # List of datasets from openml https://docs.openml.org/Python-API/ 40 | datasets_df = openml.datasets.list_datasets(output_format="dataframe") 41 | print(datasets_df.head(n=10)) 42 | 43 | datasets_df.set_index('did', inplace = True) 44 | 45 | # California housing dataset https://www.openml.org/search?type=data&status=active&id=43939 46 | dataset = openml.datasets.get_dataset(43939) 47 | 48 | 49 | # Print a summary 50 | print( 51 | f"This is dataset '{dataset.name}', the target feature is " 52 | f"'{dataset.default_target_attribute}'" 53 | ) 54 | print(f"URL: {dataset.url}") 55 | print(dataset.description[:500]) 56 | 57 | # 
openml API 58 | X, y, categorical_indicator, attribute_names = dataset.get_data( 59 | dataset_format="array", target=dataset.default_target_attribute 60 | ) 61 | df = pd.DataFrame(X, columns=attribute_names) 62 | df["class"] = y 63 | ``` 64 | #### pass the python df to R: 65 | ```{r} 66 | df <- py$df 67 | ``` 68 | 69 | data checks: 70 | ```{r} 71 | df |> str() 72 | ``` 73 | na checks: 74 | ```{r} 75 | colSums(is.na(df)) 76 | ``` 77 | 78 | ```{r} 79 | df <- df |> 80 | na.omit() 81 | ``` 82 | 83 | data processing for regression: 84 | ```{r} 85 | 86 | # holdout 10% of data for calibration 87 | cal_holdout <- dplyr::slice_sample(df, prop = .1) 88 | 89 | # proceed with typical train/test splitting, a tidymodels workflow based on ipynb: 90 | model_df <- df |> anti_join(cal_holdout) 91 | 92 | split <- model_df |> initial_split(prop = 0.99) 93 | training <- training(split) 94 | testing <- testing(split) 95 | ``` 96 | 97 | model building: 98 | ```{r} 99 | # random forest model spec, specifying 'mode' and 'engine' 100 | rf_model_spec <- 101 | rand_forest(trees = 200, min_n = 5) %>% 102 | set_mode("regression") %>% 103 | set_engine("ranger") 104 | 105 | rf_wflow <- workflow(class ~ ., rf_model_spec) 106 | rf_model_fit <- rf_wflow |> fit(data = training) 107 | 108 | 109 | ``` 110 | 111 | ## ICP Section 112 | 113 | ```{r} 114 | # make point predictions 115 | pred_cal <- rf_model_fit |> predict(cal_holdout) 116 | pred_test <- rf_model_fit |> predict(testing) 117 | 118 | 119 | data.frame( 120 | y = cal_holdout$class, 121 | y_hat = pred_cal$.pred 122 | ) |> 123 | ggplot() + 124 | geom_point(aes(x = y, y = y_hat), color = "darkcyan", alpha = .9) + 125 | theme_minimal() + 126 | labs(title = "Prediction Error for RandomForestRegressor") 127 | 128 | ``` 129 | 130 | ```{r} 131 | alpha <- 0.05 132 | n_cal <- nrow(cal_holdout) 133 | 134 | y_cal <- cal_holdout$class 135 | y_pred_cal <- pred_cal$.pred 136 | 137 | # calculate calibration errors 138 | y_cal_error <- abs(y_cal - y_pred_cal) 139 | 
140 | ceiling((n_cal+1)*(1-alpha))/n_cal 141 | ``` 142 | ```{r} 143 | # calculate q_hat on the calibration set 144 | q_yhat_cal <- quantile(y_cal_error, ceiling((n_cal+1)*(1-alpha))/n_cal) 145 | q_yhat_cal 146 | ``` 147 | 148 | ```{r} 149 | ggplot() + 150 | geom_histogram(aes(x = y_cal_error), fill = "lightblue") + 151 | geom_vline(aes(xintercept = q_yhat_cal), color = "red", linetype = 2) + 152 | labs( 153 | title = "Histogram of Calibration Errors", 154 | x = "Calibration Error", 155 | y = "Frequency" 156 | ) + 157 | theme_minimal() 158 | ``` 159 | 160 | 161 | ```{r} 162 | # predicted_df 163 | pred_test |> 164 | mutate( 165 | lower_bound = .pred - q_yhat_cal, 166 | upper_bound = .pred + q_yhat_cal, 167 | actual = testing$class 168 | ) |> 169 | mutate( 170 | index = row_number() 171 | ) |> 172 | ggplot(aes(x = index)) + 173 | geom_ribbon(aes(ymin = lower_bound, 174 | ymax = upper_bound), fill = "grey", 175 | alpha = 0.5) + 176 | geom_line(aes(y = actual, color = "Actual")) + 177 | geom_line(aes(y = .pred, color = "Predicted")) + 178 | theme_minimal() + 179 | labs( 180 | title = "Actual vs Predicted Values with Prediction Interval" 181 | ) + 182 | theme(legend.title = element_blank(), 183 | legend.position = c(.9,.9)) 184 | ``` 185 | ### using probably 186 | 187 | To do the same routine in a 'tidy' way, one can use the *probably* package for split conformal inference. probably is a tidymodels extension package that provides various interval and post-calibration modeling techniques. 
188 | ```{r} 189 | 190 | conformal_split <- int_conformal_split(rf_model_fit, 191 | cal_data = cal_holdout) 192 | 193 | conformal_split_test <- predict(conformal_split, testing, level = 0.95) 194 | 195 | conformal_split_test |> 196 | mutate( 197 | actual = testing$class, 198 | index = row_number() 199 | ) |> 200 | ggplot(aes(x = index)) + 201 | geom_ribbon(aes(ymin = .pred_lower, 202 | ymax = .pred_upper), fill = "grey", 203 | alpha = 0.5) + 204 | geom_line(aes(y = actual, color = "Actual")) + 205 | geom_line(aes(y = .pred, color = "Predicted")) + 206 | theme_minimal() + 207 | labs( 208 | title = "Actual vs Predicted Values with Prediction Interval", 209 | subtitle = "Using {probably}" 210 | ) + 211 | theme(legend.title = element_blank(), 212 | legend.position = c(.9,.9)) 213 | 214 | ``` 215 | 216 | 217 | ## CQR Section 218 | 219 | compute correlation between features and also between features and the target 220 | ```{r} 221 | df |> 222 | select(-ocean_proximity) |> 223 | select_if(is.numeric) |> 224 | corrr::correlate() |> 225 | #corrr::rearrange() |> 226 | corrr::shave() 227 | #corrr::rplot() 228 | ``` 229 | 230 | ```{r} 231 | df |> 232 | ggplot() + 233 | geom_histogram(aes(class), fill = "lightblue") + 234 | theme_minimal() + 235 | labs(title = "histogram of house prices", 236 | x = "median price of houses") + 237 | scale_x_continuous(labels = scales::dollar_format()) 238 | ``` 239 | ### Optimize underlying tree model 240 | ```{r} 241 | 242 | folds <- vfold_cv(training, v = 5) 243 | 244 | params_distributions <- 245 | expand.grid( 246 | trees = c(10, 25), 247 | tree_depth = c(3, 10), 248 | mtry = c(50, 100), 249 | learn_rate = c(.01, .2) 250 | ) 251 | 252 | model_recipe <- recipe(class ~ ., training) 253 | 254 | # refer to api documentation on how-to pass quantile objective to various engines 255 | gbm_spec <- 256 | boost_tree( 257 | trees = tune(), 258 | tree_depth = tune(), 259 | mtry = tune(), 260 | learn_rate = tune() 261 | ) |> 262 | set_mode("regression") 
|> 263 | set_engine("xgboost", nthread = 8) 264 | 265 | # pre training settings --- 266 | cluster <- makePSOCKcluster(8) 267 | registerDoParallel(cluster) 268 | 269 | # model creation --- 270 | gbm_results <- 271 | finetune::tune_race_anova( 272 | workflow() %>% 273 | add_recipe(model_recipe) %>% 274 | add_model(gbm_spec), 275 | resamples = folds, 276 | grid = params_distributions, 277 | control = finetune::control_race(), 278 | metrics = metric_set(rmse) 279 | ) 280 | 281 | # post training settings --- 282 | stopCluster(cluster) 283 | registerDoSEQ() 284 | 285 | finalize_gbm <- workflow() %>% 286 | add_recipe(model_recipe) %>% 287 | add_model(gbm_spec) %>% 288 | finalize_workflow(select_best(gbm_results)) 289 | 290 | best_gbm <- finalize_gbm |> fit(training) 291 | 292 | show_best(gbm_results) 293 | ``` 294 | boosted model point forecaster with naive & cqr intervals using probably 295 | ```{r} 296 | # naive 297 | xgb_conformal_split <- int_conformal_split(best_gbm, 298 | cal_data = cal_holdout) 299 | 300 | xgb_conformal_split_test <- predict(xgb_conformal_split, testing, level = 0.80) 301 | 302 | # cqr 303 | xgb_conformal_cqr <- int_conformal_quantile(best_gbm, 304 | train_data = training, 305 | cal_data = cal_holdout, 306 | level = 0.80) 307 | 308 | 309 | xgb_conformal_cqr_test <- predict(xgb_conformal_cqr, testing) 310 | 311 | 312 | ``` 313 | 314 | range plot for naive method: 315 | ```{r} 316 | testing |> 317 | select(class) |> 318 | bind_cols(xgb_conformal_split_test) |> 319 | mutate( 320 | coverage = ifelse(class < .pred_upper & class > .pred_lower, "yes", "no") 321 | ) |> 322 | ggplot() + 323 | geom_segment(aes(x = class, xend = class, 324 | y = .pred_lower, yend = .pred_upper, 325 | color = coverage), alpha = .8) + 326 | geom_point(aes(x = class, y = .pred, 327 | color = coverage), size = 2) + 328 | labs(subtitle = "naive interval", 329 | x = "actual") + 330 | theme_minimal() + 331 | theme(legend.position = "bottom") + 332 | coord_equal() + 333 | 
geom_abline(slope = 1) 334 | ``` 335 | ```{r} 336 | testing |> 337 | select(class) |> 338 | bind_cols(xgb_conformal_cqr_test) |> 339 | mutate( 340 | coverage = ifelse(class < .pred_upper & class > .pred_lower, "yes", "no") 341 | ) |> 342 | ggplot() + 343 | geom_segment(aes(x = class, xend = class, 344 | y = .pred_lower, yend = .pred_upper, 345 | color = coverage), alpha = .8) + 346 | geom_point(aes(x = class, y = .pred, 347 | color = coverage), size = 2) + 348 | labs(subtitle = "cqr interval", 349 | x = "actual") + 350 | theme_minimal() + 351 | theme(legend.position = "bottom") + 352 | coord_equal() + 353 | geom_abline(slope = 1) 354 | ``` 355 | 356 | bin plot 357 | ```{r} 358 | testing |> 359 | select(class) |> 360 | mutate( 361 | bin = ntile(class, n = 10) 362 | ) |> 363 | bind_cols(xgb_conformal_cqr_test) |> 364 | mutate( 365 | coverage = ifelse(class < .pred_upper & class > .pred_lower, "yes", "no") 366 | ) |> 367 | group_by(bin) |> 368 | count(coverage) |> 369 | ggplot() + 370 | geom_col(aes(x = bin, y = n, fill = coverage)) + 371 | labs(title = "CQR: prediction interval coverage", 372 | subtitle = "by binned housing price") + 373 | theme_minimal() 374 | 375 | ``` 376 | 377 | ```{r} 378 | testing |> 379 | select(class) |> 380 | mutate( 381 | bin = ntile(class, n = 10) 382 | ) |> 383 | bind_cols(xgb_conformal_cqr_test) |> 384 | mutate( 385 | coverage_width = .pred_upper - .pred_lower 386 | ) |> 387 | group_by(bin) |> 388 | summarize( 389 | mean_width = mean(coverage_width) 390 | ) |> 391 | ggplot() + 392 | geom_col(aes(x = bin, y = mean_width), fill = "darkcyan") + 393 | labs(title = "CQR: prediction interval width", 394 | subtitle = "by binned housing price") + 395 | theme_minimal() 396 | ``` 397 | 398 | -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-10-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-10-1.png -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-13-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-13-1.png -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-14-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-14-1.png -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-15-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-15-1.png -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-17-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-17-1.png -------------------------------------------------------------------------------- 
/Chapter7_files/figure-commonmark/unnamed-chunk-20-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-20-1.png -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-21-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-21-1.png -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-22-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-22-1.png -------------------------------------------------------------------------------- /Chapter7_files/figure-commonmark/unnamed-chunk-23-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter7_files/figure-commonmark/unnamed-chunk-23-1.png -------------------------------------------------------------------------------- /Chapter8_modeltime.md: -------------------------------------------------------------------------------- 1 | # Chapter 8 \| Conformal Prediction for Time Series and Forecasting 2 | frankiethull 3 | 4 | ## Chapter 8 to Practical Guide to Applied Conformal Prediction in **R**: 5 | 6 | The following code is based on the recent book 
release: *Practical Guide 7 | to Applied Conformal Prediction in Python*. After posting a fuzzy GIF on 8 | X & receiving a lot of requests for a blog or GitHub repo, below is 9 | Chapter 8 of the practical guide with applications in R, instead of 10 | Python. 11 | 12 | While the book is not free, the Python code is open-source and is located 13 | at the following GitHub repo: 14 | *https://github.com/PacktPublishing/Practical-Guide-to-Applied-Conformal-Prediction/blob/main/Chapter_08_NixtlaStatsforecast.ipynb* 15 | 16 | While this is not a direct copy/paste replica of the Python notebook or 17 | book, it is a lite, supplemental R guide & documentation for R users. 18 | 19 | We will follow the example of time series and forecasting using fable & 20 | conformal prediction intervals using the modeltime package. 21 | 22 | ### R setup for fable & modeltime: 23 | 24 | ``` r 25 | # using tidymodel framework: 26 | library(tidymodels) # ml modeling api 27 | ``` 28 | 29 | ── Attaching packages ────────────────────────────────────── tidymodels 1.1.0 ── 30 | 31 | ✔ broom 1.0.5 ✔ recipes 1.0.6 32 | ✔ dials 1.2.0 ✔ rsample 1.1.1 33 | ✔ dplyr 1.1.2 ✔ tibble 3.2.1 34 | ✔ ggplot2 3.4.2 ✔ tidyr 1.3.0 35 | ✔ infer 1.0.4 ✔ tune 1.1.1 36 | ✔ modeldata 1.1.0 ✔ workflows 1.1.3 37 | ✔ parsnip 1.1.0 ✔ workflowsets 1.0.1 38 | ✔ purrr 1.0.1 ✔ yardstick 1.2.0 39 | 40 | ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ── 41 | ✖ purrr::discard() masks scales::discard() 42 | ✖ dplyr::filter() masks stats::filter() 43 | ✖ dplyr::lag() masks stats::lag() 44 | ✖ recipes::step() masks stats::step() 45 | • Use suppressPackageStartupMessages() to eliminate package startup messages 46 | 47 | ``` r 48 | library(modeltime) # tidy time series 49 | library(fable) # tidy time series 50 | ``` 51 | 52 | Loading required package: fabletools 53 | 54 | 55 | Attaching package: 'fabletools' 56 | 57 | The following object is masked from 'package:yardstick': 58 | 59 | accuracy 60 | 61 | 
The following object is masked from 'package:parsnip': 62 | 63 | null_model 64 | 65 | The following objects are masked from 'package:infer': 66 | 67 | generate, hypothesize 68 | 69 | ``` r 70 | library(timetk) # temporal kit 71 | library(tsibble) # temporal kit 72 | ``` 73 | 74 | 75 | Attaching package: 'tsibble' 76 | 77 | The following objects are masked from 'package:base': 78 | 79 | intersect, setdiff, union 80 | 81 | ``` r 82 | library(dplyr) # pliers keep it tidy 83 | library(ggplot2) # data viz 84 | library(reticulate) # pass the python example dataset :) 85 | ``` 86 | 87 | Warning: package 'reticulate' was built under R version 4.3.1 88 | 89 | ``` r 90 | library(doParallel) # model tuning made fast 91 | ``` 92 | 93 | Loading required package: foreach 94 | 95 | 96 | Attaching package: 'foreach' 97 | 98 | The following objects are masked from 'package:purrr': 99 | 100 | accumulate, when 101 | 102 | Loading required package: iterators 103 | 104 | Loading required package: parallel 105 | 106 | ### Load the dataset 107 | 108 | ``` r 109 | train = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly.csv') 110 | test = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly-test.csv') 111 | ``` 112 | 113 | ``` r 114 | train |> head() 115 | ``` 116 | 117 | unique_id ds y 118 | 1 H1 1 605 119 | 2 H1 2 586 120 | 3 H1 3 586 121 | 4 H1 4 559 122 | 5 H1 5 511 123 | 6 H1 6 443 124 | 125 | ### Train the models 126 | 127 | we will only use the first 4 series of the dataset to reduce the total 128 | computational time. 
129 | 130 | ``` r 131 | n_series <- 4 132 | uids <- paste0("H", seq(1:n_series)) 133 | 134 | train <- train |> filter(unique_id %in% uids) |> group_by(unique_id) 135 | test <- test |> filter(unique_id %in% uids) 136 | ``` 137 | 138 | ``` r 139 | train |> 140 | ggplot() + 141 | geom_line(aes(x = ds, y = y, color = "train")) + 142 | geom_line(inherit.aes = FALSE, 143 | data = test, 144 | aes(x = ds, y = y, color = "test")) + 145 | facet_wrap(~unique_id, scales = "free") + 146 | theme_minimal() + 147 | theme( 148 | legend.position = "top" 149 | ) + 150 | labs(subtitle = "data split") 151 | ``` 152 | 153 | ![](Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-5-1.png) 154 | 155 | #### Create a list of models using fable 156 | 157 | for this example we are using fable library 158 | fable is a ‘tidy’ version of the forecast library. 159 | 160 | Both are user-friendly & have accompanying books (fpp2 & fpp3 by rob 161 | hyndman). \##### plot prediction intervals 162 | 163 | ``` r 164 | train_fbl <- train |> tsibble::as_tsibble(index = ds, key = unique_id) 165 | test_fbl <- test |> tsibble::as_tsibble(index = ds, key = unique_id) 166 | 167 | train_fbl |> 168 | model( 169 | ets = ETS(y), 170 | naive = NAIVE(y), 171 | rw = RW(y), 172 | snaive = SNAIVE(y) 173 | ) |> 174 | forecast(new_data = test_fbl) |> 175 | autoplot() + 176 | geom_line(inherit.aes = FALSE, 177 | data = train_fbl, 178 | aes(x = ds, y = y, color = "train")) + 179 | theme_minimal() + 180 | labs(subtitle = "{fable} predictions") 181 | ``` 182 | 183 | Warning: 4 errors (1 unique) encountered for snaive 184 | [4] Non-seasonal model specification provided, use RW() or provide a different lag specification. 
185 | 186 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 187 | -Inf 188 | 189 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 190 | -Inf 191 | 192 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 193 | -Inf 194 | 195 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 196 | -Inf 197 | 198 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 199 | -Inf 200 | 201 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 202 | -Inf 203 | 204 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 205 | -Inf 206 | 207 | Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning 208 | -Inf 209 | 210 | Warning: Removed 192 rows containing missing values (`()`). 211 | 212 | ![](Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-6-1.png) 213 | 214 | ``` r 215 | train_fbl |> 216 | model( 217 | auto_arima = ARIMA(y) 218 | ) |> 219 | forecast(new_data = test_fbl) |> 220 | autoplot() + 221 | geom_line(inherit.aes = FALSE, 222 | data = train_fbl, 223 | aes(x = ds, y = y, color = "train")) + 224 | theme_minimal() + 225 | labs(subtitle = "AutoARIMA via {fable}") 226 | ``` 227 | 228 | ![](Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-7-1.png) 229 | 230 | The next section switches to a modeltime workflow. modeltime is the 231 | tidymodels for time series. 232 | 233 | #### Conformal Prediction with modeltime 234 | 235 | There are two methods for conformal prediction in modeltime. It is the 236 | only tidy time series library I know of that supports conformal 237 | prediction options internally and by default. 238 | 239 | The default is the quantile method, but a split method is available 240 | as well. 
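To make the split method concrete, here is a minimal base-R sketch of split conformal intervals. It is not modeltime's internal code: the simulated data, the `lm()` model, and names such as `calib_set` and `q_hat` are assumptions made up for illustration. It also assumes exchangeable data, which time series only approximately satisfy.

``` r
# minimal split conformal sketch (illustrative; not modeltime internals)
set.seed(42)

# simulated data: linear signal plus noise (an assumption for this sketch)
n   <- 600
dat <- data.frame(x = runif(n, 0, 10))
dat$y <- 2 * dat$x + rnorm(n)

# three-way split: fit / calibrate / test
fit_set   <- dat[1:300, ]
calib_set <- dat[301:500, ]
test_set  <- dat[501:600, ]

fit <- lm(y ~ x, data = fit_set)

# nonconformity scores: absolute residuals on the calibration set
scores <- abs(calib_set$y - predict(fit, newdata = calib_set))

# conformal quantile with the finite-sample (n + 1) correction
alpha <- 0.20
n_cal <- length(scores)
k     <- ceiling((n_cal + 1) * (1 - alpha))
q_hat <- if (k > n_cal) Inf else sort(scores)[k]

# symmetric interval around each test prediction
pred  <- predict(fit, newdata = test_set)
lower <- pred - q_hat
upper <- pred + q_hat

# empirical coverage on the test set; expected to land near 1 - alpha
coverage <- mean(test_set$y >= lower & test_set$y <= upper)
coverage
```

The interval width here is constant (`2 * q_hat`) because a single quantile of the calibration residuals is applied to every prediction.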
241 | 242 | ##### train models 243 | 244 | ``` r 245 | # let's use for one location: 246 | mt_train <- train |> filter(unique_id == uids[[1]]) |> mutate(ds = as.Date(ds)) 247 | mt_test <- test |> filter(unique_id == uids[[1]]) |> mutate(ds = as.Date(ds)) 248 | 249 | # ETS 250 | ets_fit <- exp_smoothing(seasonal_period = 24) |> 251 | set_engine("ets") |> 252 | fit(y ~ ds, data = mt_train) 253 | 254 | # Auto ARIMA 255 | arima_fit <- arima_reg(seasonal_period = 24) |> 256 | set_engine("auto_arima") |> 257 | fit(y ~ ds, data = mt_train) 258 | 259 | # XGB 260 | xgb_fit <- boost_tree("regression") |> 261 | set_engine("xgboost") |> 262 | fit(y ~ ds, data = mt_train) 263 | 264 | # modeltime workflow 265 | modtime_fcst <- 266 | modeltime_calibrate( 267 | modeltime_table( 268 | xgb_fit, 269 | arima_fit, 270 | ets_fit 271 | ), 272 | new_data = mt_test, 273 | quiet = FALSE, 274 | id = "unique_id" 275 | ) |> 276 | modeltime_forecast( 277 | new_data = mt_test, 278 | conf_interval = 0.80, 279 | conf_method = "conformal_default", 280 | conf_by_id = TRUE, 281 | keep_data = TRUE 282 | ) 283 | ``` 284 | 285 | ##### plot prediction intervals 286 | 287 | ``` r 288 | modtime_fcst |> 289 | ggplot() + 290 | geom_ribbon(aes(x = ds, ymin = .conf_lo, ymax = .conf_hi, fill = .model_desc), 291 | alpha = 0.5) + 292 | geom_line(aes(x = ds, y = .value, color = .model_desc)) + 293 | geom_line(inherit.aes = FALSE, 294 | data = mt_train, 295 | aes(x = as.Date(ds), y = y, color = "train")) + 296 | facet_wrap(~unique_id, scales = "free") + 297 | theme_minimal() + 298 | theme(legend.position = "top") + 299 | labs(subtitle = "{modeltime} Default Conformal Prediction Intervals") 300 | ``` 301 | 302 | ![](Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-9-1.png) 303 | 304 | ``` r 305 | modtime_fcst |> 306 | filter(stringr::str_detect(.model_desc, "ARIMA")) |> 307 | ggplot() + 308 | geom_ribbon(aes(x = ds, ymin = .conf_lo, ymax = .conf_hi, fill = "ARIMA"), 309 | alpha = 0.5) + 310 | geom_line(aes(x 
= ds, y = .value, color = "ARIMA")) + 311 | geom_line(inherit.aes = FALSE, 312 | data = mt_train |> tail(-500), 313 | aes(x = as.Date(ds), y = y, color = "train")) + 314 | facet_wrap(~unique_id, scales = "free") + 315 | theme_minimal() + 316 | theme(legend.position = "top") + 317 | labs(subtitle = "{modeltime} Default Conformal Prediction Intervals with ARIMA") 318 | ``` 319 | 320 | ![](Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-10-1.png) 321 | -------------------------------------------------------------------------------- /Chapter8_modeltime.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Chapter 8 | Conformal Prediction for Time Series and Forecasting" 3 | author: "frankiethull" 4 | format: gfm 5 | --- 6 | 7 | ## Chapter 8 to Practical Guide to Applied Conformal Prediction in **R**: 8 | 9 | The following code is based on the recent book release: *Practical Guide to Applied Conformal Prediction in Python*. After posting a fuzzy GIF on X & receiving a lot of requests for a blog or GitHub repo, below is Chapter 8 of the practical guide with applications in R, instead of Python. 10 | 11 | While the book is not free, the Python code is open-source and is located at the following GitHub repo: 12 | *https://github.com/PacktPublishing/Practical-Guide-to-Applied-Conformal-Prediction/blob/main/Chapter_08_NixtlaStatsforecast.ipynb* 13 | 14 | While this is not a direct copy/paste replica of the Python notebook or book, it is a lite, supplemental R guide & documentation for R users. 15 | 16 | We will follow the example of time series and forecasting using fable & conformal prediction intervals using the modeltime package. 
17 | 18 | ### R setup for fable & modeltime: 19 | ```{r} 20 | # using tidymodel framework: 21 | library(tidymodels) # ml modeling api 22 | library(modeltime) # tidy time series 23 | library(fable) # tidy time series 24 | library(timetk) # temporal kit 25 | library(tsibble) # temporal kit 26 | library(dplyr) # pliers keep it tidy 27 | library(ggplot2) # data viz 28 | library(reticulate) # pass the python example dataset :) 29 | library(doParallel) # model tuning made fast 30 | ``` 31 | 32 | ### Load the dataset 33 | ```{r} 34 | train = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly.csv') 35 | test = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly-test.csv') 36 | ``` 37 | 38 | 39 | ```{r} 40 | train |> head() 41 | ``` 42 | ### Train the models 43 | 44 | 45 | we will only use the first 4 series of the dataset to reduce the total computational time. 46 | 47 | 48 | ```{r} 49 | n_series <- 4 50 | uids <- paste0("H", seq(1:n_series)) 51 | 52 | train <- train |> filter(unique_id %in% uids) |> group_by(unique_id) 53 | test <- test |> filter(unique_id %in% uids) 54 | 55 | ``` 56 | 57 | 58 | ```{r} 59 | train |> 60 | ggplot() + 61 | geom_line(aes(x = ds, y = y, color = "train")) + 62 | geom_line(inherit.aes = FALSE, 63 | data = test, 64 | aes(x = ds, y = y, color = "test")) + 65 | facet_wrap(~unique_id, scales = "free") + 66 | theme_minimal() + 67 | theme( 68 | legend.position = "top" 69 | ) + 70 | labs(subtitle = "data split") 71 | 72 | 73 | ``` 74 | #### Create a list of models using fable 75 | for this example we are using fable library 76 | fable is a 'tidy' version of the forecast library. 77 | 78 | Both are user-friendly & have accompanying books (fpp2 & fpp3 by rob hyndman). 
79 | ##### plot prediction intervals 80 | ```{r} 81 | train_fbl <- train |> tsibble::as_tsibble(index = ds, key = unique_id) 82 | test_fbl <- test |> tsibble::as_tsibble(index = ds, key = unique_id) 83 | 84 | train_fbl |> 85 | model( 86 | ets = ETS(y), 87 | naive = NAIVE(y), 88 | rw = RW(y), 89 | snaive = SNAIVE(y) 90 | ) |> 91 | forecast(new_data = test_fbl) |> 92 | autoplot() + 93 | geom_line(inherit.aes = FALSE, 94 | data = train_fbl, 95 | aes(x = ds, y = y, color = "train")) + 96 | theme_minimal() + 97 | labs(subtitle = "{fable} predictions") 98 | 99 | ``` 100 | 101 | ```{r} 102 | train_fbl |> 103 | model( 104 | auto_arima = ARIMA(y) 105 | ) |> 106 | forecast(new_data = test_fbl) |> 107 | autoplot() + 108 | geom_line(inherit.aes = FALSE, 109 | data = train_fbl, 110 | aes(x = ds, y = y, color = "train")) + 111 | theme_minimal() + 112 | labs(subtitle = "AutoARIMA via {fable}") 113 | 114 | ``` 115 | 116 | 117 | The next section switches to a modeltime workflow. modeltime is the tidymodels for time series. 118 | 119 | #### Conformal Prediction with modeltime 120 | 121 | There are two methods for conformal prediction in modeltime. It is the only tidy time series library I know of that supports conformal prediction options internally and by default. 122 | 123 | The default is the quantile method, but a split method is available as well. 
124 | 125 | ##### train models 126 | ```{r} 127 | # let's use for one location: 128 | mt_train <- train |> filter(unique_id == uids[[1]]) |> mutate(ds = as.Date(ds)) 129 | mt_test <- test |> filter(unique_id == uids[[1]]) |> mutate(ds = as.Date(ds)) 130 | 131 | # ETS 132 | ets_fit <- exp_smoothing(seasonal_period = 24) |> 133 | set_engine("ets") |> 134 | fit(y ~ ds, data = mt_train) 135 | 136 | # Auto ARIMA 137 | arima_fit <- arima_reg(seasonal_period = 24) |> 138 | set_engine("auto_arima") |> 139 | fit(y ~ ds, data = mt_train) 140 | 141 | # XGB 142 | xgb_fit <- boost_tree("regression") |> 143 | set_engine("xgboost") |> 144 | fit(y ~ ds, data = mt_train) 145 | 146 | # modeltime workflow 147 | modtime_fcst <- 148 | modeltime_calibrate( 149 | modeltime_table( 150 | xgb_fit, 151 | arima_fit, 152 | ets_fit 153 | ), 154 | new_data = mt_test, 155 | quiet = FALSE, 156 | id = "unique_id" 157 | ) |> 158 | modeltime_forecast( 159 | new_data = mt_test, 160 | conf_interval = 0.80, 161 | conf_method = "conformal_default", 162 | conf_by_id = TRUE, 163 | keep_data = TRUE 164 | ) 165 | 166 | ``` 167 | 168 | 169 | ##### plot prediction intervals 170 | 171 | ```{r} 172 | 173 | modtime_fcst |> 174 | ggplot() + 175 | geom_ribbon(aes(x = ds, ymin = .conf_lo, ymax = .conf_hi, fill = .model_desc), 176 | alpha = 0.5) + 177 | geom_line(aes(x = ds, y = .value, color = .model_desc)) + 178 | geom_line(inherit.aes = FALSE, 179 | data = mt_train, 180 | aes(x = as.Date(ds), y = y, color = "train")) + 181 | facet_wrap(~unique_id, scales = "free") + 182 | theme_minimal() + 183 | theme(legend.position = "top") + 184 | labs(subtitle = "{modeltime} Default Conformal Prediction Intervals") 185 | 186 | ``` 187 | 188 | ```{r} 189 | modtime_fcst |> 190 | filter(stringr::str_detect(.model_desc, "ARIMA")) |> 191 | ggplot() + 192 | geom_ribbon(aes(x = ds, ymin = .conf_lo, ymax = .conf_hi, fill = "ARIMA"), 193 | alpha = 0.5) + 194 | geom_line(aes(x = ds, y = .value, color = "ARIMA")) + 195 | 
geom_line(inherit.aes = FALSE, 196 | data = mt_train |> tail(-500), 197 | aes(x = as.Date(ds), y = y, color = "train")) + 198 | facet_wrap(~unique_id, scales = "free") + 199 | theme_minimal() + 200 | theme(legend.position = "top") + 201 | labs(subtitle = "{modeltime} Default Conformal Prediction Intervals with ARIMA") 202 | 203 | ``` 204 | -------------------------------------------------------------------------------- /Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-10-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-10-1.png -------------------------------------------------------------------------------- /Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-5-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-5-1.png -------------------------------------------------------------------------------- /Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-6-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-6-1.png -------------------------------------------------------------------------------- /Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-7-1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-7-1.png -------------------------------------------------------------------------------- /Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-9-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter8_modeltime_files/figure-commonmark/unnamed-chunk-9-1.png -------------------------------------------------------------------------------- /Chapter8_nixtla_reticulated.md: -------------------------------------------------------------------------------- 1 | # Chapter 8 \| Conformal Prediction for Time Series and Forecasting 2 | frankiethull 3 | 4 | ## Chapter 8 to Practical Guide to Applied Conformal Prediction in **R**: 5 | 6 | The following code is based on the recent book release: *Practical Guide 7 | to Applied Conformal Prediction in Python*. After posting a fuzzy GIF on 8 | X & receiving a lot of requests for a blog or GitHub repo, below is 9 | Chapter 8 of the practical guide with applications in R, instead of 10 | Python. 11 | 12 | While the book is not free, the Python code is open-source and is located 13 | at the following GitHub repo: 14 | *https://github.com/PacktPublishing/Practical-Guide-to-Applied-Conformal-Prediction/blob/main/Chapter_08_NixtlaStatsforecast.ipynb* 15 | 16 | While this is not a direct copy/paste replica of the Python notebook or 17 | book, it is a lite, supplemental R guide & documentation for R users. 18 | 19 | We will follow the example of time series and forecasting using fable & 20 | conformal prediction intervals using the **nixtla package via 21 | reticulate**. 
22 | 23 | ``` r 24 | # reticulate::py_install("statsforecast", pip = TRUE) 25 | ``` 26 | 27 | ### R setup for nixtla, a Python lib accessed via reticulate: 28 | 29 | ``` r 30 | library(dplyr) # pliers keep it tidy 31 | ``` 32 | 33 | 34 | Attaching package: 'dplyr' 35 | 36 | The following objects are masked from 'package:stats': 37 | 38 | filter, lag 39 | 40 | The following objects are masked from 'package:base': 41 | 42 | intersect, setdiff, setequal, union 43 | 44 | ``` r 45 | library(ggplot2) # data viz 46 | library(reticulate) # pass the python example dataset :) 47 | ``` 48 | 49 | Warning: package 'reticulate' was built under R version 4.3.1 50 | 51 | ``` r 52 | # statsforecast r-to-py API obj 53 | sf <- reticulate::import("statsforecast") 54 | 55 | # or like this for submodules: 56 | #ets <- reticulate::py_run_string("from statsforecast.models import ETS") 57 | ``` 58 | 59 | ### Load the dataset 60 | 61 | ``` r 62 | train = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly.csv') 63 | test = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly-test.csv') 64 | ``` 65 | 66 | ### Train the models 67 | 68 | We will only use the first series of the dataset to reduce the total 69 | computational time. 70 | 71 | ``` r 72 | n_series <- 1 73 | uids <- paste0("H", seq(1:n_series)) 74 | 75 | train <- train |> filter(unique_id %in% uids) |> group_by(unique_id) 76 | test <- test |> filter(unique_id %in% uids) 77 | 78 | horizon <- test |> filter(unique_id == uids[[1]]) |> nrow() 79 | ``` 80 | 81 | ### nixtla model setup as R interfaces 82 | 83 | One thing R coders need to watch for is dtypes: Python often needs 84 | *integers*, not *dbl/numeric*, when setting parameters. 85 | 86 | In R, wrap such values with as.integer(). 87 | 88 | When importing a Python module, explore its submodules using 89 | **\$**. This gives access to the underlying Python tools as APIs inside of 90 | R. 
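The dtype point can be checked in plain R before anything is handed to Python. The sketch below uses only base R; the variable names (`season_dbl`, `season_int`, `levels_int`) are illustrative and not part of the original notebook.

``` r
# R numeric literals are doubles unless explicitly made integer
season_dbl <- 24
season_int <- as.integer(24)   # equivalent to writing 24L

class(season_dbl)  # "numeric"
class(season_int)  # "integer"

# vectors convert element-wise
levels_int <- as.integer(c(80, 90))
is.integer(levels_int)  # TRUE

# as.integer() truncates toward zero, so confirm values are whole first
as.integer(2.9)  # 2
```

reticulate passes R doubles to Python as floats and R integers as ints, which is why arguments such as `season_length` and `h` are wrapped in `as.integer()` in the chunks that follow.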
91 | 92 | ``` r 93 | # compare these to the initial using str() col types are different and nixtla won't throw errors on int types 94 | 95 | train_nix <- train |> mutate(ds = as.integer(ds)) 96 | test_nix <- test |> mutate(ds = as.integer(ds)) 97 | 98 | models <- c(sf$models$ETS(season_length = as.integer(24)), 99 | sf$models$Naive(), 100 | sf$models$SeasonalNaive(season_length = as.integer(24)) 101 | ) 102 | 103 | nixfit <- sf$StatsForecast( 104 | df=train_nix, 105 | models=models, 106 | freq=as.integer(1) 107 | ) 108 | 109 | levels <- c(80, 90) 110 | 111 | nixcast <- nixfit$forecast(h = as.integer(horizon), 112 | level = as.integer(levels)) 113 | ``` 114 | 115 | Warning in py_to_r.pandas.core.frame.DataFrame(result): index contains 116 | duplicated values: row names not set 117 | 118 | ``` r 119 | nixcast |> head() 120 | ``` 121 | 122 | ds ETS ETS-lo-90 ETS-lo-80 ETS-hi-80 ETS-hi-90 Naive Naive-lo-80 123 | 1 701 631.8896 568.9789 582.8741 680.9051 694.8004 684 631.6456 124 | 2 702 559.7509 496.5244 510.4893 609.0123 622.9773 684 609.9597 125 | 3 703 519.2355 455.6948 469.7292 568.7418 582.7761 684 593.3195 126 | 4 704 486.9734 423.1201 437.2235 536.7233 550.8267 684 579.2911 127 | 5 705 464.6974 400.5330 414.7051 514.6896 528.8618 684 566.9319 128 | 6 706 452.2620 387.7880 402.0285 502.4955 516.7360 684 555.7584 129 | Naive-lo-90 Naive-hi-80 Naive-hi-90 SeasonalNaive SeasonalNaive-lo-80 130 | 1 616.8038 736.3544 751.1962 691 582.8238 131 | 2 588.9702 758.0403 779.0298 618 509.8238 132 | 3 567.6128 774.6805 800.3872 563 454.8238 133 | 4 549.6076 788.7089 818.3924 529 420.8238 134 | 5 533.7448 801.0681 834.2552 504 395.8238 135 | 6 519.4036 812.2416 848.5964 489 380.8238 136 | SeasonalNaive-lo-90 SeasonalNaive-hi-80 SeasonalNaive-hi-90 137 | 1 552.1573 799.1762 829.8427 138 | 2 479.1574 726.1762 756.8427 139 | 3 424.1574 671.1762 701.8427 140 | 4 390.1574 637.1762 667.8427 141 | 5 365.1574 612.1762 642.8427 142 | 6 350.1574 597.1762 627.8427 143 | 144 | ### 
plotting prediction intervals 145 | 146 | ``` r 147 | #plotly::ggplotly( 148 | nixcast |> 149 | tidyr::pivot_longer(-ds) |> 150 | ggplot() + 151 | geom_line(aes(x = ds, y = value, color = name)) + 152 | geom_line(inherit.aes = FALSE, 153 | data = train_nix |> tail(24*5), 154 | aes(x = ds, y = y, color = "train")) + 155 | theme_minimal() + 156 | labs(title = "Model results for Nixtla with pred intervals") 157 | ``` 158 | 159 | ![](Chapter8_nixtla_reticulated_files/figure-commonmark/unnamed-chunk-6-1.png) 160 | 161 | ``` r 162 | #) 163 | ``` 164 | 165 | ### Conformal Prediction with Nixtla 166 | 167 | once again, we will initiate models but specify conformal intervals in 168 | the model spec 169 | 170 | ``` r 171 | # conformal intervals are under utils: 172 | conf_int <- sf$utils$ConformalIntervals(h = as.integer(horizon), 173 | n_windows = as.integer(2)) 174 | 175 | # arima ints: 176 | arima_order <- sapply(c(24,0, 12), as.integer) 177 | # Create a list of models and instantiation parameters 178 | conf_models = c( 179 | sf$models$ADIDA(prediction_intervals=conf_int), 180 | sf$models$ARIMA(order=arima_order, 181 | season_length=as.integer(24), 182 | prediction_intervals=conf_int) 183 | ) 184 | 185 | conf_nixfit <- sf$StatsForecast( 186 | df=train_nix, 187 | models=conf_models, 188 | freq=as.integer(1) 189 | ) 190 | 191 | levels <- c(80, 90) 192 | 193 | conf_nixcast <- conf_nixfit$forecast(h = as.integer(horizon), 194 | level = as.integer(levels)) 195 | ``` 196 | 197 | Warning in py_to_r.pandas.core.frame.DataFrame(result): index contains 198 | duplicated values: row names not set 199 | 200 | ``` r 201 | conf_nixcast |> head() 202 | ``` 203 | 204 | ds ADIDA ADIDA-lo-90 ADIDA-lo-80 ADIDA-hi-80 ADIDA-hi-90 ARIMA 205 | 1 701 747.2925 628.30 634.6 859.9851 866.2851 631.4179 206 | 2 702 747.2925 551.20 552.4 942.1851 943.3851 570.8654 207 | 3 703 747.2925 517.65 522.3 972.2851 976.9351 533.0674 208 | 4 704 747.2925 480.35 484.7 1009.8851 1014.2351 505.9376 209 | 5 705 
747.2925 454.10 459.2 1035.3851 1040.4851 490.0570 210 | 6 706 747.2925 441.80 446.6 1047.9851 1052.7850 492.3616 211 | ARIMA-lo-90 ARIMA-lo-80 ARIMA-hi-80 ARIMA-hi-90 212 | 1 602.3105 605.7851 657.0508 660.5254 213 | 2 551.2000 552.4000 589.3306 590.5306 214 | 3 514.3702 515.7404 550.3943 551.7645 215 | 4 480.3500 484.7000 527.1752 531.5253 216 | 5 454.1000 459.2000 520.9139 526.0140 217 | 6 441.8000 446.6000 538.1231 542.9232 218 | 219 | ``` r 220 | conf_nixcast |> 221 | select(ds, starts_with("ARIMA")) |> 222 | ggplot() + 223 | geom_ribbon(aes(x = ds, ymin = `ARIMA-lo-80`, ymax = `ARIMA-hi-80`, fill = "80th-tile"), 224 | alpha = 0.5) + 225 | geom_line(aes(x = ds, y = ARIMA, color = "arima-expected")) + 226 | geom_line(inherit.aes = FALSE, 227 | data = train_nix |> tail(24*5), 228 | aes(x = ds, y = y, color = "train")) + 229 | geom_line(inherit.aes = FALSE, 230 | data = test_nix, 231 | aes(x = ds, y = y, color = "test")) + 232 | theme_minimal() + 233 | labs(title = "Nixtla with Conformal Prediction Intervals") 234 | ``` 235 | 236 | ![](Chapter8_nixtla_reticulated_files/figure-commonmark/unnamed-chunk-8-1.png) 237 | -------------------------------------------------------------------------------- /Chapter8_nixtla_reticulated.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Chapter 8 | Conformal Prediction for Time Series and Forecasting" 3 | author: "frankiethull" 4 | format: gfm 5 | --- 6 | 7 | 8 | ## Chapter 8 to Practical Guide to Applied Conformal Prediction in **R**: 9 | 10 | The following code is based on the recent book release: *Practical Guide to Applied Conformal Prediction in Python*. After posting a fuzzy GIF on X & receiving a lot of requests for a blog or Github repo, below is Chapter 8 of the practical guide with applications in R, instead of Python. 
11 | 12 | While the book is not free, the Python code is open-source and is located at the following GitHub repo: 13 | *https://github.com/PacktPublishing/Practical-Guide-to-Applied-Conformal-Prediction/blob/main/Chapter_08_NixtlaStatsforecast.ipynb* 14 | 15 | While this is not a direct copy/paste replica of the Python notebook or book, it is a lite, supplemental R guide & documentation for R users. 16 | 17 | We will follow the example of time series and forecasting using fable & conformal prediction intervals using the **nixtla package via reticulate**. 18 | 19 | ```{r} 20 | # reticulate::py_install("statsforecast", pip = TRUE) 21 | ``` 22 | 23 | 24 | ### R setup for nixtla, a Python lib accessed via reticulate: 25 | ```{r} 26 | library(dplyr) # pliers keep it tidy 27 | library(ggplot2) # data viz 28 | library(reticulate) # pass the python example dataset :) 29 | 30 | # statsforecast r-to-py API obj 31 | sf <- reticulate::import("statsforecast") 32 | 33 | # or like this for submodules: 34 | #ets <- reticulate::py_run_string("from statsforecast.models import ETS") 35 | ``` 36 | 37 | ### Load the dataset 38 | ```{r} 39 | train = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly.csv') 40 | test = read.csv('https://auto-arima-results.s3.amazonaws.com/M4-Hourly-test.csv') 41 | ``` 42 | 43 | 44 | ### Train the models 45 | 46 | We will only use the first series of the dataset to reduce the total computational time. 47 | ```{r} 48 | n_series <- 1 49 | uids <- paste0("H", seq(1:n_series)) 50 | 51 | train <- train |> filter(unique_id %in% uids) |> group_by(unique_id) 52 | test <- test |> filter(unique_id %in% uids) 53 | 54 | horizon <- test |> filter(unique_id == uids[[1]]) |> nrow() 55 | 56 | ``` 57 | 58 | 59 | ### nixtla model setup as R interfaces 60 | 61 | One thing R coders need to watch for is dtypes: Python often needs *integers*, not *dbl/numeric*, when setting parameters. 

In R, wrap such values with `as.integer()`.

When importing a Python module, explore its submodules using **$**. This gives access to the underlying Python tools as APIs inside of R.
```{r}
# compare these to the originals using str(): the column types differ,
# and nixtla won't throw errors on int types

train_nix <- train |> mutate(ds = as.integer(ds))
test_nix  <- test  |> mutate(ds = as.integer(ds))

models <- c(sf$models$ETS(season_length = as.integer(24)),
            sf$models$Naive(),
            sf$models$SeasonalNaive(season_length = as.integer(24))
            )

nixfit <- sf$StatsForecast(
  df = train_nix,
  models = models,
  freq = as.integer(1)
)

levels <- c(80, 90)

nixcast <- nixfit$forecast(h = as.integer(horizon),
                           level = as.integer(levels))

nixcast |> head()
```

### plotting prediction intervals

```{r}

#plotly::ggplotly(
nixcast |>
  tidyr::pivot_longer(-ds) |>
  ggplot() +
  geom_line(aes(x = ds, y = value, color = name)) +
  geom_line(inherit.aes = FALSE,
            data = train_nix |> tail(24*5),
            aes(x = ds, y = y, color = "train")) +
  theme_minimal() +
  labs(title = "Model results for Nixtla with pred intervals")
#)
```

### Conformal Prediction with Nixtla

Once again, we will instantiate the models, but this time specify conformal intervals in the model spec.
```{r}
# conformal intervals are under utils:
conf_int <- sf$utils$ConformalIntervals(h = as.integer(horizon),
                                        n_windows = as.integer(2))

# arima ints:
arima_order <- c(24L, 0L, 12L)
# Create a list of models and instantiation parameters
conf_models <- c(
  sf$models$ADIDA(prediction_intervals = conf_int),
  sf$models$ARIMA(order = arima_order,
                  season_length = as.integer(24),
                  prediction_intervals = conf_int)
)

conf_nixfit <- sf$StatsForecast(
  df = train_nix,
  models = conf_models,
  freq = as.integer(1)
)

levels <- c(80, 90)

conf_nixcast <- conf_nixfit$forecast(h = as.integer(horizon),
                                     level = as.integer(levels))

conf_nixcast |> head()
```

```{r}
conf_nixcast |>
  select(ds, starts_with("ARIMA")) |>
  ggplot() +
  geom_ribbon(aes(x = ds, ymin = `ARIMA-lo-80`, ymax = `ARIMA-hi-80`, fill = "80th-tile"),
              alpha = 0.5) +
  geom_line(aes(x = ds, y = ARIMA, color = "arima-expected")) +
  geom_line(inherit.aes = FALSE,
            data = train_nix |> tail(24*5),
            aes(x = ds, y = y, color = "train")) +
  geom_line(inherit.aes = FALSE,
            data = test_nix,
            aes(x = ds, y = y, color = "test")) +
  theme_minimal() +
  labs(title = "Nixtla with Conformal Prediction Intervals")

```

--------------------------------------------------------------------------------
/Chapter8_nixtla_reticulated_files/figure-commonmark/unnamed-chunk-6-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter8_nixtla_reticulated_files/figure-commonmark/unnamed-chunk-6-1.png

--------------------------------------------------------------------------------
/Chapter8_nixtla_reticulated_files/figure-commonmark/unnamed-chunk-8-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/Chapter8_nixtla_reticulated_files/figure-commonmark/unnamed-chunk-8-1.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Practical Guide to Applied Conformal Prediction in **R**
frankiethull

The following code is based on the recent book release: *Practical Guide
to Applied Conformal Prediction in Python*. After posting a fuzzy GIF on
X & receiving a lot of requests for a blog or GitHub repo, this repo
includes the practical guide with applications in R, instead of Python.

This is not a direct copy/paste replica of the Python notebook or book;
it is a lite, supplemental R guide & documentation for R users.


--------------------------------------------------------------------------------
/README.qmd:
--------------------------------------------------------------------------------
---
title: "Practical Guide to Applied Conformal Prediction in **R**"
author: "frankiethull"
format: gfm
---

The following code is based on the recent book release: *Practical Guide to Applied Conformal Prediction in Python*. After posting a fuzzy GIF on X & receiving a lot of requests for a blog or GitHub repo, this repo includes the practical guide with applications in R, instead of Python.

This is not a direct copy/paste replica of the Python notebook or book; it is a lite, supplemental R guide & documentation for R users.


![](book_cover.jpg){fig-align="right" width="40%"}

--------------------------------------------------------------------------------
/book_cover.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frankiethull/Practical-Guide-to-Applied-Conformal-Prediction-in-R/197da490a6a13b339eba700b5c5f2dda26957a03/book_cover.jpg
--------------------------------------------------------------------------------