├── .Rbuildignore
├── .directory
├── .github
├── .gitignore
├── FUNDING.yml
└── workflows
│ ├── R-CMD-check.yaml
│ └── test-coverage.yml
├── .gitignore
├── DESCRIPTION
├── LICENSE
├── LICENSE.md
├── NAMESPACE
├── R
├── as-torch-tensor.R
├── as-ts-dataloder.R
├── as-ts-dataset.R
├── as-vector.R
├── categorical.R
├── checks.R
├── data-tiny-m5.R
├── data-weather-pl.R
├── device.R
├── heuristics.R
├── initialization.R
├── metrics.R
├── mlp-impl.R
├── mlp-module.R
├── mlp-parsnip.R
├── nn-mlp.R
├── nn-multi-embedding.R
├── nn-nonlinear.R
├── palette.R
├── parse-formula.R
├── plot.R
├── predict.R
├── prepare-data.R
├── progress-bar.R
├── rnn-impl.R
├── rnn-module.R
├── rnn-parsnip.R
├── static.R
├── torchts-model.R
├── torchts-package.R
├── training-helpers.R
├── ts-dataset.R
├── utils-internal.R
├── utils.R
└── zzz.R
├── README.Rmd
├── README.md
├── _pkgdown.yml
├── codecov.yml
├── data-raw
├── debug-mlp.Rmd
├── prepare-logo.R
├── prepare-m5.R
└── prepare-weather-pl.R
├── data
├── tiny_m5.rda
└── weather_pl.rda
├── docs
├── 404.html
├── LICENSE-text.html
├── LICENSE.html
├── apple-touch-icon-120x120.png
├── apple-touch-icon-152x152.png
├── apple-touch-icon-180x180.png
├── apple-touch-icon-60x60.png
├── apple-touch-icon-76x76.png
├── apple-touch-icon.png
├── articles
│ ├── data-prepare-rnn.html
│ ├── data-prepare-rnn_files
│ │ └── accessible-code-block-0.0.1
│ │ │ └── empty-anchor.js
│ ├── data_prepare_rnn.html
│ ├── data_prepare_rnn_files
│ │ └── accessible-code-block-0.0.1
│ │ │ └── empty-anchor.js
│ ├── index.html
│ ├── missing-data.html
│ ├── missing-data_files
│ │ └── accessible-code-block-0.0.1
│ │ │ └── empty-anchor.js
│ ├── missing_data.html
│ ├── missing_data_files
│ │ └── accessible-code-block-0.0.1
│ │ │ └── empty-anchor.js
│ ├── multivariate-time-series.html
│ ├── multivariate-time-series_files
│ │ └── accessible-code-block-0.0.1
│ │ │ └── empty-anchor.js
│ ├── naming-convention.html
│ ├── naming-convention_files
│ │ └── accessible-code-block-0.0.1
│ │ │ └── empty-anchor.js
│ ├── prepare-tensor.html
│ ├── prepare-tensor_files
│ │ └── accessible-code-block-0.0.1
│ │ │ └── empty-anchor.js
│ ├── univariate-time-series.html
│ └── univariate-time-series_files
│ │ └── accessible-code-block-0.0.1
│ │ └── empty-anchor.js
├── authors.html
├── bootstrap-toc.css
├── bootstrap-toc.js
├── docsearch.css
├── docsearch.js
├── extra.css
├── favicon-16x16.png
├── favicon-32x32.png
├── favicon.ico
├── index.html
├── link.svg
├── logo.svg
├── notes.html
├── pkgdown.css
├── pkgdown.js
├── pkgdown.yml
├── reference
│ ├── Rplot001.png
│ ├── as.vector.torch_tensor.html
│ ├── as_dataset.html
│ ├── as_tensor.html
│ ├── as_ts_dataloader.html
│ ├── as_ts_dataset.html
│ ├── basic_rnn.html
│ ├── basic_rnn_fit.html
│ ├── call_optim.html
│ ├── cat2idx.html
│ ├── check_is_complete.html
│ ├── check_is_new_data_complete.html
│ ├── check_recursion.html
│ ├── clear_outcome.html
│ ├── col_map_out.html
│ ├── deep_factor.html
│ ├── deep_factor_rnn.html
│ ├── dict_replace.html
│ ├── dict_size.html
│ ├── embedding_size.html
│ ├── equal.html
│ ├── figures
│ │ ├── logo-small.png
│ │ ├── logo.png
│ │ ├── logo.svg
│ │ ├── logo_v1.svg
│ │ └── shampoo.svg
│ ├── fit_network.html
│ ├── geometric_pyramid.html
│ ├── get_x.html
│ ├── idx2cat.html
│ ├── index.html
│ ├── init_gate_bias.html
│ ├── invert_scaling.html
│ ├── is_categorical.html
│ ├── lagged_mlp.html
│ ├── make_lagged_mlp.html
│ ├── make_recurrent_network.html
│ ├── make_rnn.html
│ ├── model_mlp.html
│ ├── model_recurrent.html
│ ├── model_rnn.html
│ ├── nn_mlp.html
│ ├── nn_multi_embedding.html
│ ├── nn_nonlinear.html
│ ├── nnf_mae.html
│ ├── nnf_mape.html
│ ├── nnf_smape.html
│ ├── numeric_date.html
│ ├── plot_forecast.html
│ ├── plug.html
│ ├── predictors_spec.html
│ ├── prepare_dl.html
│ ├── print_and_capture.html
│ ├── recurrent_network.html
│ ├── recurrent_network_fit_formula.html
│ ├── remove_model.html
│ ├── rep_if_one_element.html
│ ├── resolve_data.html
│ ├── rnn.html
│ ├── rnn_fit.html
│ ├── rnn_output_size.html
│ ├── scale_params.html
│ ├── set_device.html
│ ├── span_time.html
│ ├── step_cat2idx.html
│ ├── timesteps.html
│ ├── tiny_m5.html
│ ├── torchts-package.html
│ ├── torchts_mlp.html
│ ├── torchts_model.html
│ ├── torchts_palette.html
│ ├── torchts_parse_formula.html
│ ├── torchts_rnn.html
│ ├── train_batch.html
│ ├── ts_dataset.html
│ ├── valid_batch.html
│ ├── weather_pl.html
│ └── which_static.html
└── roadmap.html
├── man
├── as.vector.torch_tensor.Rd
├── as_ts_dataloader.Rd
├── as_ts_dataset.Rd
├── call_optim.Rd
├── check_is_complete.Rd
├── check_is_new_data_complete.Rd
├── check_recursion.Rd
├── clear_outcome.Rd
├── col_map_out.Rd
├── dict_size.Rd
├── embedding_size.Rd
├── figures
│ ├── README-parsnip.api-1.png
│ ├── logo-small.png
│ ├── logo.png
│ ├── logo.svg
│ ├── logo_v1.svg
│ └── shampoo.svg
├── fit_network.Rd
├── geometric_pyramid.Rd
├── get_x.Rd
├── init_gate_bias.Rd
├── is_categorical.Rd
├── lagged_mlp.Rd
├── make_lagged_mlp.Rd
├── make_rnn.Rd
├── model_mlp.Rd
├── model_rnn.Rd
├── nn_mlp.Rd
├── nn_multi_embedding.Rd
├── nn_nonlinear.Rd
├── nnf_mae.Rd
├── nnf_mape.Rd
├── nnf_smape.Rd
├── plot_forecast.Rd
├── prepare_dl.Rd
├── print_and_capture.Rd
├── remove_model.Rd
├── rep_if_one_element.Rd
├── rnn.Rd
├── rnn_output_size.Rd
├── set_device.Rd
├── tiny_m5.Rd
├── torchts-package.Rd
├── torchts_mlp.Rd
├── torchts_model.Rd
├── torchts_palette.Rd
├── torchts_parse_formula.Rd
├── torchts_rnn.Rd
├── train_batch.Rd
├── ts_dataset.Rd
├── valid_batch.Rd
├── weather_pl.Rd
└── which_static.Rd
├── notes.md
├── pkgdown
├── extra.css
└── favicon
│ ├── apple-touch-icon-120x120.png
│ ├── apple-touch-icon-152x152.png
│ ├── apple-touch-icon-180x180.png
│ ├── apple-touch-icon-60x60.png
│ ├── apple-touch-icon-76x76.png
│ ├── apple-touch-icon.png
│ ├── favicon-16x16.png
│ ├── favicon-32x32.png
│ └── favicon.ico
├── tests
├── testthat.R
└── testthat
│ ├── test-as-tensor.R
│ ├── test-as-ts-dataloader.R
│ ├── test-as-ts-dataset.R
│ ├── test-as-vector.R
│ ├── test-categorical.R
│ ├── test-checks.R
│ ├── test-metrics.R
│ ├── test-module-nn-nonlinear.R
│ ├── test-prepare-dl.R
│ ├── test-rnn-impl.R
│ ├── test-rnn-module.R
│ ├── test-torchts-parse-formula.R
│ ├── test-ts-dataset.R
│ └── test-utils.R
├── torchts.Rproj
└── vignettes
├── .gitignore
├── data-prepare-rnn.Rmd
├── naming-convention.Rmd
├── parsnip-api.Rmd
├── torchts-api.Rmd
└── torchts-formula.Rmd
/.Rbuildignore:
--------------------------------------------------------------------------------
1 | ^.*\.Rproj$
2 | ^\.Rproj\.user$
3 | ^LICENSE\.md$
4 | ^doc$
5 | ^Meta$
6 | ^data-dev$
7 | ^README\.Rmd$
8 | ^appveyor\.yml$
9 | ^\.travis\.yml$
10 | ^codecov\.yml$
11 | ^\.github$
12 |
--------------------------------------------------------------------------------
/.directory:
--------------------------------------------------------------------------------
1 | [Dolphin]
2 | Timestamp=2020,12,16,21,53,0
3 | Version=4
4 |
5 | [Settings]
6 | HiddenFilesShown=true
7 |
--------------------------------------------------------------------------------
/.github/.gitignore:
--------------------------------------------------------------------------------
1 | *.html
2 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | custom: https://www.buymeacoffee.com/kjoachimiak
4 |
--------------------------------------------------------------------------------
/.github/workflows/R-CMD-check.yaml:
--------------------------------------------------------------------------------
1 | # Workflow derived from https://github.com/r-lib/actions/tree/master/examples
2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
3 | #
4 | # NOTE: This workflow is overkill for most R packages and
5 | # check-standard.yaml is likely a better choice.
6 | # usethis::use_github_action("check-standard") will install it.
7 | on:
8 | push:
9 | branches: [main, master]
10 | pull_request:
11 | branches: [main, master]
12 |
13 | name: R-CMD-check
14 |
15 | jobs:
16 | R-CMD-check:
17 | runs-on: ${{ matrix.config.os }}
18 |
19 | name: ${{ matrix.config.os }} (${{ matrix.config.r }})
20 |
21 | strategy:
22 | fail-fast: false
23 | matrix:
24 | config:
25 | - {os: macOS-latest, r: 'release'}
26 |
27 | - {os: windows-latest, r: 'release'}
28 | # Use 3.6 to trigger usage of RTools35
29 | - {os: windows-latest, r: '3.6'}
30 |
31 | # Use older ubuntu to maximise backward compatibility
32 | - {os: ubuntu-18.04, r: 'devel', http-user-agent: 'release'}
33 | - {os: ubuntu-18.04, r: 'release'}
34 | - {os: ubuntu-18.04, r: 'oldrel-1'}
35 | - {os: ubuntu-18.04, r: 'oldrel-2'}
36 | - {os: ubuntu-18.04, r: 'oldrel-3'}
37 | - {os: ubuntu-18.04, r: 'oldrel-4'}
38 |
39 | env:
40 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
41 | R_KEEP_PKG_SOURCE: yes
42 |
43 | steps:
44 | - uses: actions/checkout@v2
45 |
46 | - uses: r-lib/actions/setup-pandoc@v1
47 |
48 | - uses: r-lib/actions/setup-r@v1
49 | with:
50 | r-version: ${{ matrix.config.r }}
51 | http-user-agent: ${{ matrix.config.http-user-agent }}
52 | use-public-rspm: true
53 |
54 | - uses: r-lib/actions/setup-r-dependencies@v1
55 | with:
56 | extra-packages: rcmdcheck
57 |
58 | - uses: r-lib/actions/check-r-package@v1
59 |
60 | - name: Show testthat output
61 | if: always()
62 | run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true
63 | shell: bash
64 |
65 | - name: Upload check results
66 | if: failure()
67 | uses: actions/upload-artifact@main
68 | with:
69 | name: ${{ runner.os }}-r${{ matrix.config.r }}-results
70 | path: check
71 |
--------------------------------------------------------------------------------
/.github/workflows/test-coverage.yml:
--------------------------------------------------------------------------------
1 | # Workflow derived from https://github.com/r-lib/actions/tree/master/examples
2 | # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
3 | on:
4 | push:
5 | branches: [main, master]
6 | pull_request:
7 | branches: [main, master]
8 |
9 | name: test-coverage
10 |
11 | jobs:
12 | test-coverage:
13 | runs-on: ubuntu-latest
14 | env:
15 | GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 |
20 | - uses: r-lib/actions/setup-r@v1
21 | with:
22 | use-public-rspm: true
23 |
24 | - uses: r-lib/actions/setup-r-dependencies@v1
25 | with:
26 | extra-packages: covr
27 |
28 | - name: Test coverage
29 | run: covr::codecov()
30 | shell: Rscript {0}
31 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # History files
2 | .Rhistory
3 | .Rapp.history
4 |
5 | # Session Data files
6 | .RData
7 |
8 | # User-specific files
9 | .Ruserdata
10 |
11 | # Example code in package build process
12 | *-Ex.R
13 |
14 | # Output files from R CMD build
15 | /*.tar.gz
16 |
17 | # Output files from R CMD check
18 | /*.Rcheck/
19 |
20 | # RStudio files
21 | .Rproj.user/
22 |
23 | # produced vignettes
24 | vignettes/*.html
25 | vignettes/*.pdf
26 |
27 | data-dev/*
28 |
29 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
30 | .httr-oauth
31 |
32 | # knitr and R markdown default cache directories
33 | *_cache/
34 | /cache/
35 |
36 | # Temporary files created by R markdown
37 | *.utf8.md
38 | *.knit.md
39 |
40 | # R Environment Variables
41 | .Renviron
42 | inst/doc
43 | doc
44 | Meta
45 |
--------------------------------------------------------------------------------
/DESCRIPTION:
--------------------------------------------------------------------------------
1 | Package: torchts
2 | Title: Time Series Models with 'torch'
3 | Version: 0.1.0
4 | Authors@R:
5 | person(given = "Krzysztof",
6 | family = "Joachimiak",
7 | role = c("aut", "cre"),
8 | email = "joachimiak.krzysztof@gmail.com",
9 | comment = c(ORCID = "0000-0003-4780-7947"))
10 | Description: Deep Learning torch models for time series forecasting.
11 | It includes easy-to-use torch models and data transformation utilities
12 | and provides parsnip API to these models.
13 | License: MIT + file LICENSE
14 | Encoding: UTF-8
15 | LazyData: true
16 | BugReports: https://github.com/krzjoa/torchts/issues
17 | URL: https://github.com/krzjoa/torchts, https://krzjoa.github.io/torchts/
18 | Roxygen: list(markdown = TRUE)
19 | RoxygenNote: 7.1.2
20 | Suggests:
21 | covr,
22 | knitr,
23 | rmarkdown,
24 | testthat,
25 | timetk
26 | VignetteBuilder: knitr
27 | Imports:
28 | rsample,
29 | glue,
30 | torch,
31 | dplyr,
32 | parsnip
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | YEAR: 2020
2 | COPYRIGHT HOLDER: Krzysztof Joachimiak
3 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | Copyright (c) 2020 Krzysztof Joachimiak
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/NAMESPACE:
--------------------------------------------------------------------------------
1 | # Generated by roxygen2: do not edit by hand
2 |
3 | S3method(as.vector,torch_tensor)
4 | S3method(as_torch_tensor,data.frame)
5 | S3method(as_torch_tensor,default)
6 | S3method(as_torch_tensor,torch_tensor)
7 | S3method(as_torch_tensor,ts)
8 | S3method(as_ts_dataloader,data.frame)
9 | S3method(as_ts_dataset,data.frame)
10 | S3method(as_ts_dataset,default)
11 | S3method(predict,torchts_mlp)
12 | S3method(predict,torchts_rnn)
13 | S3method(print,torchts_model)
14 | S3method(set_device,dataloader)
15 | S3method(set_device,default)
16 | S3method(set_device,model_spec)
17 | S3method(set_device,nn_module)
18 | S3method(set_device,torchts_model)
19 | export(as_torch_tensor)
20 | export(as_ts_dataloader)
21 | export(as_ts_dataset)
22 | export(clear_outcome)
23 | export(dict_size)
24 | export(embedding_size_fastai)
25 | export(embedding_size_google)
26 | export(is_categorical)
27 | export(lagged_mlp)
28 | export(model_mlp)
29 | export(model_rnn)
30 | export(nn_mlp)
31 | export(nn_nonlinear)
32 | export(nnf_mae)
33 | export(nnf_mape)
34 | export(nnf_smape)
35 | export(plot_forecast)
36 | export(rnn)
37 | export(rnn_output_size)
38 | export(set_device)
39 | export(torchts_get_default_device)
40 | export(torchts_mlp)
41 | export(torchts_rnn)
42 | export(torchts_set_default_device)
43 | export(torchts_show_devices)
44 | export(ts_dataset)
45 | export(which_static)
46 | import(data.table)
47 | importFrom(crayon,col_nchar)
48 | importFrom(crayon,col_substr)
49 | importFrom(dplyr,group_by)
50 | importFrom(ggplot2,aes)
51 | importFrom(ggplot2,geom_line)
52 | importFrom(ggplot2,ggplot)
53 | importFrom(ggplot2,ggtitle)
54 | importFrom(ggplot2,theme_minimal)
55 | importFrom(glue,glue)
56 | importFrom(hms,as.hms)
57 | importFrom(parsnip,fit)
58 | importFrom(parsnip,fit_xy)
59 | importFrom(parsnip,translate)
60 | importFrom(prettyunits,pretty_bytes)
61 | importFrom(prettyunits,vague_dt)
62 | importFrom(recipes,bake)
63 | importFrom(recipes,prep)
64 | importFrom(recipes,recipe)
65 | importFrom(recipes,step_integer)
66 | importFrom(recipes,step_scale)
67 | importFrom(rsample,testing)
68 | importFrom(rsample,training)
69 | importFrom(torch,dataloader)
70 | importFrom(torch,nn_embedding)
71 | importFrom(torch,nn_gru)
72 | importFrom(torch,nn_linear)
73 | importFrom(torch,nn_module)
74 | importFrom(torch,optim_adam)
75 | importFrom(torch,torch_cat)
76 | importFrom(utils,flush.console)
77 |
--------------------------------------------------------------------------------
/R/as-ts-dataloder.R:
--------------------------------------------------------------------------------
1 | #' Quick shortcut to create a torch dataloader based on the given dataset
2 | #'
3 | #' @inheritParams as_ts_dataset
4 | #' @param batch_size (`numeric`) Batch size.
5 | #' @param shuffle (`logical`) Shuffle examples.
6 | #' @param drop_last (`logical`) Set to TRUE to drop the last incomplete batch,
7 | #' if the dataset size is not divisible by the batch size.
8 | #' If FALSE and the size of dataset is not divisible by the batch size,
9 | #' then the last batch will be smaller. (default: TRUE)
10 | #'
11 | #' @importFrom torch dataloader
12 | #'
13 | #' @examples
14 | #' library(rsample)
15 | #' library(dplyr, warn.conflicts = FALSE)
16 | #'
17 | #' suwalki_temp <-
18 | #' weather_pl %>%
19 | #' filter(station == "SWK") %>%
20 | #' select(date, temp = tmax_daily)
21 | #'
22 | #' # Splitting on training and test
23 | #' data_split <- initial_time_split(suwalki_temp)
24 | #'
25 | #' train_dl <-
26 | #' training(data_split) %>%
27 | #' as_ts_dataloader(temp ~ date, timesteps = 20, horizon = 10, batch_size = 32)
28 | #'
29 | #' train_dl
30 | #'
31 | #' dataloader_next(dataloader_make_iter(train_dl))
32 | #'
33 | #' @export
34 | as_ts_dataloader <- function(data, formula, index = NULL,
35 | key = NULL,
36 | predictors = NULL,
37 | outcomes = NULL,
38 | categorical = NULL,
39 | timesteps, horizon = 1,
40 | sample_frac = 1,
41 | batch_size, shuffle = FALSE,
42 | jump = 1, drop_last = TRUE,
43 | ...){
44 | UseMethod("as_ts_dataloader")
45 | }
46 |
47 |
48 | #' @export
49 | as_ts_dataloader.data.frame <- function(data, formula = NULL, index = NULL,
50 | key = NULL, predictors = NULL,
51 | outcomes = NULL, categorical = NULL,
52 | timesteps, horizon = 1, sample_frac = 1,
53 | batch_size, shuffle = FALSE,
54 | jump = 1, drop_last = TRUE, ...){
55 | dataloader(
56 | as_ts_dataset(
57 | data = data,
58 | formula = formula,
59 | index = index,
60 | key = key,
61 | predictors = predictors,
62 | outcomes = outcomes,
63 | categorical = categorical,
64 | timesteps = timesteps,
65 | horizon = horizon,
66 | sample_frac = sample_frac,
67 | jump = jump,
68 | # Extra args
69 | ...),
70 |
71 | # Dataloader args
72 | batch_size = batch_size,
73 | shuffle = shuffle,
74 | drop_last = drop_last
75 | )
76 | }
77 |
--------------------------------------------------------------------------------
/R/as-ts-dataset.R:
--------------------------------------------------------------------------------
1 | #' Create a torch dataset for time series data from a `data.frame`-like object
2 | #'
3 | #' @param data (`data.frame`) An input data.frame object.
4 | #' For now only **single** data frames are handled with no categorical features.
5 | #' @param formula (`formula`) A formula describing, how to use the data
6 | #' @param index (`character`) The index column name.
7 | #' @param key (`character`) The key column name(s). Use only if formula was not specified.
8 | #' @param predictors (`character`) Input variable names. Use only if formula was not specified.
9 | #' @param outcomes (`character`) Target variable names. Use only if formula was not specified.
10 | #' @param categorical (`character`) Categorical features.
11 | #' @param timesteps (`integer`) The time series chunk length.
12 | #' @param horizon (`integer`) Forecast horizon.
13 | #' @param sample_frac (`numeric`) Sample a fraction of rows (default: 1, i.e.: all the rows).
14 | #' @param scale (`logical` or `list`) Scale feature columns. Logical value or a two-element list
15 | #' with values (mean, std).
16 | #'
17 | #' @importFrom recipes recipe step_integer step_scale bake prep
18 | #'
19 | #' @note
20 | #' If `scale` is TRUE, only the input variables are scaled and not the outcome ones.
21 | #'
22 | #' See: [Is it necessary to scale the target value in addition to scaling features for regression analysis? (Cross Validated)](https://stats.stackexchange.com/questions/111467/is-it-necessary-to-scale-the-target-value-in-addition-to-scaling-features-for-re)
23 | #'
24 | #' @examples
25 | #' library(rsample)
26 | #' library(dplyr, warn.conflicts = FALSE)
27 | #'
28 | #' suwalki_temp <-
29 | #' weather_pl %>%
30 | #' filter(station == "SWK")
31 | #'
32 | #'
33 | #'
34 | #' # Splitting on training and test
35 | #' data_split <- initial_time_split(suwalki_temp)
36 | #'
37 | #' train_ds <-
38 | #' training(data_split) %>%
39 | #' as_ts_dataset(tmax_daily ~ date + tmax_daily + rr_type,
40 | #' timesteps = 20, horizon = 1)
41 | #'
42 | #' train_ds[1]
43 | #'
44 | #' train_ds <-
45 | #' training(data_split) %>%
46 | #' as_ts_dataset(tmax_daily ~ date + tmax_daily + rr_type + lead(rr_type),
47 | #' timesteps = 20, horizon = 1)
48 | #'
49 | #' train_ds[1]
50 | #'
51 | #' train_ds <-
52 | #' training(data_split) %>%
53 | #' as_ts_dataset(tmax_daily ~ date + tmax_daily + rr_type + lead(tmin_daily),
54 | #' timesteps = 20, horizon = 1)
55 | #'
56 | #' train_ds[1]
57 | #'
58 | #' @export
59 | as_ts_dataset <- function(data, formula,
60 | timesteps, horizon = 1, sample_frac = 1,
61 | jump = 1, ...){
62 | UseMethod("as_ts_dataset")
63 | }
64 |
65 |
66 | #'@export
67 | as_ts_dataset.default <- function(data, formula,
68 | timesteps, horizon = 1, sample_frac = 1,
69 | jump = 1, ...){
70 | stop(sprintf(
71 |     "Object of class %s is not handled for now.", class(data)
72 | ))
73 | }
74 |
75 | #' @export
76 | as_ts_dataset.data.frame <- function(data, formula = NULL,
77 | timesteps, horizon = 1, sample_frac = 1,
78 | jump = 1, ...){
79 |
80 | # TODO: remove key, index, outcomes etc.
81 | # (define only with formula or parsed formula)?
82 | extra_args <- list(...)
83 |
84 | if (nrow(data) == 0) {
85 | stop("The data object is empty!")
86 | }
87 |
88 | if (is.null(extra_args$parsed_formula))
89 | parsed_formula <- torchts_parse_formula(formula, data = data)
90 | else
91 | parsed_formula <- extra_args$parsed_formula
92 |
93 | # Parsing formula
94 | # TODO: key is not used for now
95 | .past <- list(
96 |
97 | # Numeric time-varying variables
98 | x_num = get_vars(parsed_formula, "predictor", "numeric"),
99 |
100 | # Categorical time-varying variables
101 | x_cat = get_vars(parsed_formula, "predictor", "categorical")
102 | )
103 |
104 | # Future spec: outcomes + predictors
105 | .future <- list(
106 | y = vars_with_role(parsed_formula, "outcome"),
107 | # Possible predictors from the future (e.g. coming holidays)
108 | x_fut_num = get_vars2(parsed_formula, "predictor", "numeric", "lead"),
109 | x_fut_cat = get_vars2(parsed_formula, "predictor", "categorical", "lead")
110 | )
111 |
112 | .index_columns <-
113 | parsed_formula[parsed_formula$.role == "index", ]$.var
114 |
115 | # Removing NULLs
116 | .past <- remove_nulls(.past)
117 | .future <- remove_nulls(.future)
118 |
119 | categorical <-
120 | parsed_formula %>%
121 | filter(.type == 'categorical') %>%
122 | pull(.var)
123 |
124 | data <-
125 | data %>%
126 | arrange(!!.index_columns)
127 |
128 | ts_recipe <-
129 | recipe(data) %>%
130 | step_integer(all_of(categorical)) %>%
131 | prep()
132 |
133 | data <-
134 | ts_recipe %>%
135 | bake(new_data = data)
136 |
137 | if (is.null(.index_columns) | length(.index_columns) == 0)
138 | stop("No time index column defined! Add at least one time-based variable.")
139 |
140 | ts_dataset(
141 | data = data,
142 | timesteps = timesteps,
143 | horizon = horizon,
144 | past = .past,
145 | future = .future,
146 | categorical = c("x_cat", "x_fut_cat"),
147 | sample_frac = sample_frac,
148 | jump = jump,
149 | extras = list(recipe = ts_recipe)
150 | )
151 | }
152 |
--------------------------------------------------------------------------------
/R/as-vector.R:
--------------------------------------------------------------------------------
1 | #' Convert `torch_tensor` to a vector
2 | #'
3 | #' `as.vector.torch_tensor` attempts to coerce a `torch_tensor` into a vector of
4 | #' mode `mode` (the default is to coerce to whichever vector mode is most convenient):
5 | #' if the result is atomic all attributes are removed.
6 | #'
7 | #' @param x (`torch_tensor`) A `torch` tensor
8 | #' @param mode (`character`) A character string with one of possible vector modes:
9 | #' "any", "list", "expression" or other basic types like "character", "integer" etc.
10 | #'
11 | #' @return
12 | #' A vector of desired type.
13 | #' All attributes are removed from the result if it is of an atomic mode,
14 | #' but not in general for a list result.
15 | #'
16 | #' @seealso
17 | #' [base::as.vector]
18 | #'
19 | #' @examples
20 | #' library(torch)
21 | #' library(torchts)
22 | #'
23 | #' x <- torch_tensor(array(10, dim = c(3, 3, 3)))
24 | #' as.vector(x)
25 | #' as.vector(x, mode = "logical")
26 | #' as.vector(x, mode = "character")
27 | #' as.vector(x, mode = "complex")
28 | #' as.vector(x, mode = "list")
29 | #'
30 | #' @export
31 | as.vector.torch_tensor <- function(x, mode = 'any'){
32 | # TODO: dim order, as_tibble, as.data.frame with dims order
33 | as.vector(as.array(x), mode = mode)
34 | }
35 |
--------------------------------------------------------------------------------
/R/categorical.R:
--------------------------------------------------------------------------------
1 | #' Check, if vector is categorical, i.e.
2 | #' if it is logical, factor, character or integer
3 | #'
4 | #' @param x A vector of arbitrary type
5 | #'
6 | #' @return Logical value
7 | #'
8 | #' @examples
9 | #' is_categorical(c(TRUE, FALSE, TRUE, FALSE, FALSE, FALSE, TRUE))
10 | #' is_categorical(1:10)
11 | #' is_categorical((1:10) + 0.1)
12 | #' is_categorical(as.factor(c("Ferrari", "Lamborghini", "Porsche", "McLaren", "Koenigsegg")))
13 | #' is_categorical(c("Ferrari", "Lamborghini", "Porsche", "McLaren", "Koenigsegg"))
14 | #'
15 | #' @export
16 | is_categorical <- function(x){
17 | # TODO: class(x) %in% getOption("torchts_categoricals")
18 | # is.logical(x) |
19 | # is.factor(x) |
20 | # is.character(x) |
21 | # is.integer(x)
22 | any(sapply(
23 | getOption("torchts_categoricals"),
24 | function(cls) inherits(x, cls)
25 | ))
26 | }
27 |
28 | which_categorical <- function(data){
29 | sapply(data, is_categorical)
30 | }
31 |
32 |
33 | #' Return size of categorical variables in the data.frame
34 | #'
35 | #' @param data (`data.frame`) A data.frame containing categorical variables.
36 | #' The function automatically finds categorical variables,
37 | #' calling internally [is_categorical] function.
38 | #'
39 | #' @return Named integer vector with the number of distinct values for each categorical column
40 | #'
41 | #' @examples
42 | #' glimpse(tiny_m5)
43 | #' dict_size(tiny_m5)
44 | #'
45 | #' # We can choose only the features we want - otherwise it automatically
46 | #' # selects logical, factor, character or integer vectors
47 | #'
48 | #' tiny_m5 %>%
49 | #' select(store_id, event_name_1) %>%
50 | #' dict_size()
51 | #'
52 | #' @export
53 | dict_size <- function(data){
54 | cols <- sapply(data, is_categorical)
55 | sapply(as.data.frame(data)[cols], dplyr::n_distinct)
56 | }
57 |
58 | #' @name embedding_size
59 | #' @title Propose the length of embedding vector for each embedded feature.
60 | #'
61 | #' @param x (`integer`) A vector with dictionary size for each feature
62 | #' @param max_size (`integer`) Maximum embedding size returned (default: 100)
63 | #'
64 | #' @description
65 | #' These functions returns proposed embedding sizes for each categorical feature.
66 | #' They are "rule of thumbs", so the are based on empirical rather than theoretical conclusions,
67 | #' and their parameters can look like "magic numbers". Nevertheless, when you don't know what embedding size
68 | #' will be "optimal", it's good to start with such kind of general rules.
69 | #'
70 | #' * **google**
71 | #' Proposed on the [Google Developer](https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html) site
72 | #' \deqn{x^{0.25}}
73 | #'
74 | #' * **fastai**
75 | #' \deqn{1.6 \cdot x^{0.56}}
76 | #'
77 | #'
78 | #' @return Proposed embedding sizes.
79 | #'
80 | #' @examples
81 | #' dict_sizes <- dict_size(tiny_m5)
82 | #' embedding_size_google(dict_sizes)
83 | #' embedding_size_fastai(dict_sizes)
84 | #'
85 | #' @references
86 | #'
87 | #' * [Introducing TensorFlow Feature Columns](https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html)
88 | #' * [fastai - embedding size rule of thumb](https://github.com/fastai/fastai/blob/master/fastai/tabular/model.py)
89 | #'
90 | #'
91 | NULL
92 |
93 | #' @rdname embedding_size
94 | #' @export
95 | embedding_size_google <- function(x, max_size = 100){
96 | pmin(ceiling(x ** 0.25), max_size)
97 | }
98 |
99 | #' @rdname embedding_size
100 | #' @export
101 | embedding_size_fastai <- function(x, max_size = 100){
102 | pmin(round(1.6 * x ** 0.56), max_size)
103 | }
104 |
105 |
--------------------------------------------------------------------------------
/R/checks.R:
--------------------------------------------------------------------------------
1 | #' Check, if recursion should be used in forecasting
2 | check_recursion <- function(object, new_data){
3 |
4 | # TODO: check, if this procedure is sufficient
5 | recursive_mode <- FALSE
6 |
7 | # Check, if outcome is predictor
8 | if (any(object$outcome %in% colnames(new_data))) {
9 |     # Check if there are NA values in the predictor column
10 | if (any(is.na(new_data[object$outcome]))) {
11 | if (nrow(new_data) > object$horizon)
12 | recursive_mode <- TRUE
13 | }
14 | }
15 |
16 | recursive_mode
17 | }
18 |
19 | #' Check if input data contains no NAs.
20 | #' Otherwise, return error.
21 | check_is_complete <- function(data){
22 |
23 | complete_cases <- complete.cases(data)
24 |
25 | if (!all(complete_cases)) {
26 | sample_rows <-
27 | dplyr::slice_sample(data[!complete_cases,], n = 3)
28 | stop("Passed data contains incomplete rows, for example: \n",
29 | print_and_capture(sample_rows))
30 | }
31 |
32 | }
33 |
34 | #' Check if new data has NAs in columns others than predicted outcome
35 | check_is_new_data_complete <- function(object, new_data){
36 |
37 | only_predictors <- setdiff(
38 | object$predictors, object$outcomes
39 | )
40 |
41 | complete_cases <- complete.cases(new_data[only_predictors])
42 |
43 | if (!all(complete_cases)) {
44 | sample_rows <-
45 | dplyr::slice_sample(new_data[!complete_cases,], n = 3)
46 |     stop("Only the outcome variable column is allowed to contain NAs (at its beginning).
47 | NA values in other columns detected.
48 | Passed new data contains incomplete rows, for example: \n",
49 | print_and_capture(sample_rows))
50 | }
51 |
52 | }
53 |
54 | check_length_vs_horizon <- function(object, new_data){
55 | # TODO: adapt to multiple keys
56 |
57 | len <- nrow(new_data)
58 | modulo <- len %% object$horizon
59 |
60 | if (modulo != 0)
61 | message(glue(
62 | "new_data length ({len}) is not a multiple of horizon {object$horizon}.
63 | Forecast output will be shorter by {modulo} timesteps."
64 | ))
65 |
66 | }
67 |
68 |
# Emit a warning message for a discouraged horizon/jump combination.
#
# horizon  - number of forecast timesteps produced per window
# jump     - step by which the sliding window advances between samples
# stateful - whether the RNN carries hidden state across batches
#
# When the network is stateful, windows that do not advance by exactly
# `horizon` timesteps make consecutive hidden states overlap or skip data.
# This is only a message; execution continues unchanged.
check_stateful_vs_jump <- function(horizon, jump, stateful){
  if ((horizon != jump) & stateful)
    message(glue(
      "Horizon is not equal to jump, while stateful flag is TRUE.
      horizon = {horizon}, jump = {jump}.
      It is not recommended, but it will be performed as Your Majesty wishes."
    ))
}
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/R/data-tiny-m5.R:
--------------------------------------------------------------------------------
1 | #' A subset from M5 Walmart Challenge Dataset in one data frame
2 | #'
3 | #' A piece of data cut from the training dataset used in the M5 challenges on Kaggle.
4 | #' M5 is a challenge from a series organized by Spyros Makridakis.
5 | #'
6 | #'
7 | #' @format
8 | #' \describe{
9 | #' \item{item_id}{The id of the product}
10 | #' \item{dept_id}{The id of the department the product belongs to}
11 | #' \item{cat_id}{The id of the category the product belongs to}
12 | #' \item{store_id}{The id of the store where the product is sold}
13 | #' \item{state_id}{The State where the store is located}
14 | #' \item{value}{The number of sold units}
15 | #' \item{date}{The date in a “y-m-d” format}
16 | #' \item{wm_yr_wk}{The id of the week the date belongs to}
17 | #' \item{weekday}{The type of the day (Saturday, Sunday, …, Friday)}
18 | #' \item{wday}{The id of the weekday, starting from Saturday}
19 | #' \item{month}{ The month of the date}
20 | #' \item{year}{The year of the date}
21 | #' \item{event_name_1}{If the date includes an event, the name of this event}
22 | #' \item{event_type_1}{If the date includes an event, the type of this event}
23 | #' \item{event_name_2}{If the date includes a second event, the name of this event}
24 | #' \item{event_type_2}{If the date includes a second event, the type of this event}
#' \item{snap}{A binary variable (0 or 1) indicating whether the stores of CA, TX or WI allow SNAP purchases on the examined date. 1 indicates that SNAP purchases are allowed}
26 | #' \item{sell_price}{The price of the product for the given week/store.
27 | #' The price is provided per week (average across seven days). If not available, this means that the product was not sold during the examined week.
28 | #' Note that although prices are constant at weekly basis, they may change through time (both training and test set)}
29 | #' }
30 | #'
31 | #' @seealso
32 | #' [M5 Forecasting - Accuracy](https://www.kaggle.com/c/m5-forecasting-accuracy)
33 | #'
34 | #' [M5 Forecasting - Uncertainty](https://www.kaggle.com/c/m5-forecasting-uncertainty)
35 | #'
36 | #' [The M5 competition: Background, organization, and implementation](https://www.sciencedirect.com/science/article/pii/S0169207021001187)
37 | #'
38 | #' [Other Walmart datasets in timetk](https://business-science.github.io/timetk/reference/index.html#section-time-series-datasets)
39 | #'
40 | #' @examples
41 | #' # Head of tiny_m5
42 | #' head(tiny_m5)
43 | "tiny_m5"
44 |
--------------------------------------------------------------------------------
/R/data-weather-pl.R:
--------------------------------------------------------------------------------
1 | #' Weather data from Polish "poles of extreme temperatures" in 2001-2020
2 | #'
3 | #' The data comes from IMGW (Institute of Meteorology and Water Management) and
4 | #' was downloaded using the [climate] package. Two places have been chosen:
5 | #' \itemize{
6 | #' \item{TRN - Tarnów ("pole of warmth")}
7 | #' \item{SWK - Suwałki ("pole of cold")}
8 | #' }
9 | #' A subset of columns has been selected and `date` column was added.
10 | #'
11 | #' @format
12 | #' \describe{
13 | #' \item{station}{A place where weather data were measured}
14 | #' \item{date}{Date}
#' \item{tmax_daily}{Maximum daily air temperature [C]}
16 | #' \item{tmin_daily}{Minimum daily air temperature [C]}
17 | #' \item{tmin_soil}{Minimum near surface air temperature [C]}
#' \item{rr_daily}{Total daily precipitation [mm]}
19 | #' \item{rr_type}{Precipitation type [S/W]}
20 | #' \item{rr_daytime}{Total precipitation during day [mm]}
21 | #' \item{rr_nightime}{Total precipitation during night [mm]}
22 | #' \item{press_mean_daily}{Daily mean pressure at station level [hPa]}
23 | #' }
24 | #'
25 | #' @seealso
26 | #' [climate](https://github.com/bczernecki/climate)
27 | #' [IMGW public data](https://danepubliczne.imgw.pl/)
28 | #' [IMGW public data (direct access to folders)](https://danepubliczne.imgw.pl/data/dane_pomiarowo_obserwacyjne/)
29 | #'
30 | #' @examples
31 | #' # Head of weather_pl
32 | #' head(weather_pl)
33 | "weather_pl"
34 |
--------------------------------------------------------------------------------
/R/device.R:
--------------------------------------------------------------------------------
1 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2 | # set_device
3 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 |
#' Set model device.
#'
#' Generic that moves an object (network, tensor, dataloader) to the given
#' device, or records the device choice on a parsnip model specification.
#'
#' @param object A neural network object.
#' @param device (`character`) Selected device, e.g. `"cpu"` or `"cuda"`.
#' @param ... Additional arguments passed on to methods.
#'
#' @return Object of the same class with device set.
#'
#' @examples
#' rnn_net <-
#'   model_rnn(
#'     input_size = 1,
#'     output_size = 1,
#'     hidden_size = 10
#'   ) %>%
#'   set_device("cpu")
#'
#' rnn_net
#'
#' @export
set_device <- function(object, device, ...){
  UseMethod("set_device")
}
27 |
#' @export
set_device.default <- function(object, device, ...){

  # NULL is treated as a no-op: optional sub-objects may legitimately be NULL
  if (is.null(object))
    return(object)

  # Collapse the full class vector into one string: sprintf() would
  # otherwise recycle over multi-class objects and build several messages
  stop(sprintf(
    "Object of class %s has no devices defined!", toString(class(object))
  ))
}
38 |
#' @export
set_device.torchts_model <- function(object, device, ...){
  # Return the whole model, per the documented contract
  # ("Object of the same class with device set"). Previously the bare
  # network was returned, silently dropping the torchts_model wrapper.
  object$net <- set_device(object$net, device, ...)
  object
}
43 |
#' @export
set_device.model_spec <- function(object, device, ...){
  # Record the requested device among the parsnip engine arguments;
  # it is consumed later, when the engine fit is executed
  object[["eng_args"]][["device"]] <- device
  object
}
49 |
#' @export
set_device.dataloader <- function(object, device, ...){
  # A dataloader delegates device placement to its underlying dataset
  object[["dataset"]][["device"]] <- device
  object
}
55 |
56 |
#' @export
set_device.nn_module <- function(object, device, ...){
  # Delegate to the shared worker, which validates the device name
  .set_device(object = object, device = device, ...)
}
61 |
62 |
#' @export
set_device.torch_tensor <- function(object, device, ...){
  # Missing @export added: every sibling set_device method is exported,
  # and without registration S3 dispatch cannot find this method
  .set_device(object, device, ...)
}
66 |
# Internal worker shared by the set_device() methods.
# Validates the device name and relocates `object` via its $cpu()/$cuda()
# methods.
# NOTE(review): assumes `object` exposes $cpu() and $cuda() - true for
# torch tensors and nn_modules; confirm for any new method added.
.set_device <- function(object, device, ...){
  AVAILABLE_DEVICES <- c("cuda", "cpu")

  if (!(device %in% AVAILABLE_DEVICES))
    stop(sprintf(
      "You cannot select %s device.
      Choose 'cpu' or 'cuda' instead.",
      device
    ))

  if (device == "cpu")
    return(object$cpu())

  if (device == "cuda")
    return(object$cuda())

  # Unreachable: the membership check above guarantees one branch returned
}
84 |
85 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
86 | # show_devices
87 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
88 |
#' Show available devices
#'
#' @return A character vector: `"cpu"`, plus `"cuda"` when CUDA is available.
#'
#' @examples
#' torchts_show_devices()
#' @export
torchts_show_devices <- function(){
  devices <- "cpu"
  if (cuda_is_available())
    devices <- c(devices, "cuda")
  devices
}
99 |
100 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
101 | # default device
102 | # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
103 |
#' Set a torch device, which is treated as default for torchts models
#' in the current R session
#'
#' Stores the name in the `torchts_default_device` option; `options()`
#' invisibly returns the previous value.
#'
#' @param device Device name, e.g. `"cpu"` or `"cuda"`
#' @examples
#' torchts_set_default_device("cuda")
#' @export
torchts_set_default_device <- function(device){
  options("torchts_default_device" = device)
}
113 |
#' Get a torch device, which is treated as default for torchts models
#' in the current R session
#'
#' @param device Ignored; kept only for backward compatibility of the signature.
#'
#' @return The value of the `torchts_default_device` option, or `"cpu"`
#'   when the option is unset.
#' @examples
#' torchts_get_default_device()
#' @export
torchts_get_default_device <- function(device){
  # The option name must be a quoted string: the bare symbol
  # `torchts_default_device` referenced a non-existent variable and
  # made this function always error
  getOption("torchts_default_device", default = "cpu")
}
123 |
124 |
125 |
--------------------------------------------------------------------------------
/R/heuristics.R:
--------------------------------------------------------------------------------
#' Geometric pyramid rule
#'
#' @description A simple heuristic to choose a hidden layer size:
#' the geometric mean of the two neighbouring layer sizes, rounded up.
#'
#' @param input_size (`integer`) Input size
#' @param next_layer_size (`integer`) Next layer size
#'
#' @return `ceiling(sqrt(input_size * next_layer_size))`
#'
#' @references
#' [Practical Neural Network Recipes in C++](https://books.google.de/books/about/Practical_Neural_Network_Recipes_in_C++.html?id=7Ez_Pq0sp2EC&redir_esc=y)
#'
geometric_pyramid <- function(input_size, next_layer_size){
  geometric_mean <- sqrt(input_size * next_layer_size)
  ceiling(geometric_mean)
}
14 |
--------------------------------------------------------------------------------
/R/initialization.R:
--------------------------------------------------------------------------------
#' Initialize gates to pass full information
#'
#' Sets the LSTM forget-gate biases to 1 (remember by default) and a GRU
#' gate bias to -1, following common RNN training advice. Modifies the
#' layer's parameters in place; returns nothing.
#'
#' x <- list(rnn_layer = nn_lstm(2, 20))
#' init_gate_bias(x$rnn_layer)
#' x$rnn_layer$parameters$bias_ih_l1
#'
init_gate_bias <- function(rnn_layer){
  # rnn_layer <- nn_lstm(2, 20)

  # https://stackoverflow.com/questions/62198351/why-doesnt-pytorch-allow-inplace-operations-on-leaf-variables
  # https://danijar.com/tips-for-training-recurrent-neural-networks/

  # Forget gate bias.
  # It can take a while for a recurrent network to learn to remember information from the last time step.
  # Initialize biases for LSTM's forget gate to 1 to remember more by default.
  # Similarly, initialize biases for GRU's reset gate to -1.

  if (inherits(rnn_layer, 'nn_gru')) {

    # Initialize reset gate with -1

    # NOTE(review): per the bias layout quoted below, GRU biases have shape
    # (3*hidden_size), so dividing by 4 looks wrong - confirm whether this
    # should be / 3. Also, the reset gate (b_*r) is the FIRST segment, while
    # `indices` selects the SECOND segment (b_*z, the update gate) - verify.
    segment_len <- dim(rnn_layer$parameters$bias_hh_l1) / 4
    indices <- (segment_len+1):(2*segment_len)

    # First part is reset gate
    # ~GRU.bias_ih_l[k] (b_ir|b_iz|b_in), of shape (3*hidden_size)
    #
    # ~GRU.bias_hh_l[k] (b_hr|b_hz|b_hn), of shape (3*hidden_size)

    # If the parameter is a leaf tensor, in-place modification is not allowed,
    # so gradient tracking is switched off around the assignment
    rnn_layer$parameters$bias_hh_l1$requires_grad_(FALSE)
    rnn_layer$.__enclos_env__$private$parameters_$bias_hh_l1[indices] <- -1
    rnn_layer$parameters$bias_hh_l1$requires_grad_(TRUE)

    rnn_layer$parameters$bias_ih_l1$requires_grad_(FALSE)
    rnn_layer$.__enclos_env__$private$parameters_$bias_ih_l1[indices] <- -1
    rnn_layer$parameters$bias_ih_l1$requires_grad_(TRUE)
  }

  if (inherits(rnn_layer, 'nn_lstm')) {
    #' ~LSTM.bias_ih_l[k] - (b_ii|b_if|b_ig|b_io), of shape (4*hidden_size)
    #' ~LSTM.bias_hh_l[k] - (b_hi|b_hf|b_hg|b_ho), of shape (4*hidden_size)
    #'
    # The forget gate (b_*f) is the second of the four segments
    segment_len <- dim(rnn_layer$parameters$bias_ih_l1) / 4
    indices <- (segment_len+1):(2*segment_len)

    # If the parameter is a leaf tensor, in-place modification is not allowed,
    # so gradient tracking is switched off around the assignment
    rnn_layer$parameters$bias_hh_l1$requires_grad_(FALSE)
    rnn_layer$.__enclos_env__$private$parameters_$bias_hh_l1[indices] <- 1
    rnn_layer$parameters$bias_hh_l1$requires_grad_(TRUE)

    rnn_layer$parameters$bias_ih_l1$requires_grad_(FALSE)
    rnn_layer$.__enclos_env__$private$parameters_$bias_ih_l1[indices] <- 1
    rnn_layer$parameters$bias_ih_l1$requires_grad_(TRUE)

  }

  invisible()
}
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/R/metrics.R:
--------------------------------------------------------------------------------
1 | # Metrics
2 |
#' Mean absolute percentage error
#'
#' @param input (`torch_tensor`) A tensor of predicted values
#' @param target (`torch_tensor`) A tensor of actual (ground-truth) values,
#'   with the same shape as the input
#'
#' @details
#' Computed according to the formula:
#' \deqn{MAPE = \frac{1}{n}\displaystyle\sum_{t=1}^{n} \left\|\frac{target - input}{target}\right\|}
#'
#' @seealso
#' [yardstick::mape]
#'
#' @examples
#' input <- c(92, 6.5, 57.69, 15.9, 88.47, 75.01, 5.06, 45.95, 27.8, 70.96)
#' input <- as_tensor(input)
#'
#' target <- c(91.54, 5.87, 58.85, 10.73, 81.47, 75.39, 2.05, 40.95, 27.34, 66.61)
#' target <- as_tensor(target)
#'
#' nnf_mape(input, target)
#'
#' @export
nnf_mape <- function(input, target){
  # Errors relative to the true values, averaged over all elements
  relative_errors <- (target - input) / target
  mean(abs(relative_errors))
}
28 |
29 |
#' Symmetric mean absolute percentage error
#'
#' @param input (`torch_tensor`) A tensor of predicted values
#' @param target (`torch_tensor`) A tensor of actual (ground-truth) values,
#'   with the same shape as the input
#'
#' @details
#' Computed according to the formula:
#' \deqn{SMAPE = \frac{1}{n}\displaystyle\sum_{t=1}^{n} \frac{\left\|input - target\right\|}{(\left\|target\right\| + \left\|input\right\|) *0.5}}
#'
#' @seealso
#' [yardstick::smape]
#'
#' @examples
#' input <- c(92, 6.5, 57.69, 15.9, 88.47, 75.01, 5.06, 45.95, 27.8, 70.96)
#' input <- as_tensor(input)
#'
#' target <- c(91.54, 5.87, 58.85, 10.73, 81.47, 75.39, 2.05, 40.95, 27.34, 66.61)
#' target <- as_tensor(target)
#'
#' nnf_smape(input, target)
#'
#' @export
nnf_smape <- function(input, target){
  # Each absolute error is scaled by the mean magnitude of the
  # prediction and the ground truth, making the metric symmetric
  scaled_errors <- abs(input - target) / ((abs(target) + abs(input)) / 2)
  mean(scaled_errors)
}
59 |
60 |
#' Mean absolute error
#'
#' @param input (`torch_tensor`) A tensor of predicted values
#' @param target (`torch_tensor`) A tensor of actual (ground-truth) values,
#'   with the same shape as the input
#'
#' @details
#'
#' Computed according to the formula:
#' \deqn{MAE = \frac{1}{n}\displaystyle\sum_{t=1}^{n}\left\|target - input\right\|}
#'
#' @seealso
#' [yardstick::mae]
#'
#' @examples
#' input <- c(92, 6.5, 57.69, 15.9, 88.47, 75.01, 5.06, 45.95, 27.8, 70.96)
#' input <- as_tensor(input)
#'
#' target <- c(91.54, 5.87, 58.85, 10.73, 81.47, 75.39, 2.05, 40.95, 27.34, 66.61)
#' target <- as_tensor(target)
#'
#' nnf_mae(input, target)
#'
#' @export
nnf_mae <- function(input, target){
  # abs() is symmetric, so the order of the subtraction is irrelevant
  mean(abs(input - target))
}
87 |
88 |
89 | #' Mean absolute scaled error
90 | #'
91 | #' @param input (`torch_tensor`) A tensor of actual values
92 | #' @param target (`torch_tensor`) A tensor with the same shape as the input
93 | #'
94 | #' @details
95 | #'
96 | #' Computed according to the formula:
97 | #' \deqn{MASE = \displaystyle\frac{MAE_{out-of-sample}}{\frac{1}{n-1}\sum_{i=2}^{n}\left\|a_i - a_{i-1}\right\|} }
98 | #'
99 | #' @seealso
100 | #' [yardstick::mase]
101 | #'
102 | #' @references
103 | #' [Rob J. Hyndman (2006). ANOTHER LOOK AT FORECAST-ACCURACY METRICS
104 | #' FOR INTERMITTENT DEMAND. _Foresight_, 4, 46.](https://robjhyndman.com/papers/foresight.pdf)
105 | #'
106 | #' @examples
107 | #' input <- c(92, 6.5, 57.69, 15.9, 88.47, 75.01, 5.06, 45.95, 27.8, 70.96)
108 | #' input <- as_tensor(input)
109 | #'
110 | #' target <- c(91.54, 5.87, 58.85, 10.73, 81.47, 75.39, 2.05, 40.95, 27.34, 66.61)
111 | #' target <- as_tensor(target)
112 | #'
113 | #' nnf_mae(input, target)
114 | #'
115 | # nnf_mase <- function(y_pred, target, in_sample_actual){
116 | # numerator <- mean(abs(y_pred - target))
117 | # # In-sample naive forecast
118 | #
119 | # denominator <- mean(diffs)
120 | # numerator / denominator
121 | # }
122 |
123 |
124 | #' Weighted Absolute Percentage Error
125 | #'
126 | #' @param input tensor (N,*) where ** means, any number of additional dimensions
127 | #' @param target tensor (N,*) , same shape as the input
128 | #'
129 | #' @details
130 | #' Known also as WMAPE or wMAPE (Weighted Mean Absolute Percentage Error)
131 | #' However, sometimes WAPE and WMAPE metrics are [distinguished](https://www.baeldung.com/cs/mape-vs-wape-vs-wmape).
132 | #'
133 | #' Variant of [nnf_mape()], but weighted with target values.
134 | #'
135 | #' Computed according to the formula:
136 | #' \deqn{MAPE = \frac{1}{n}\displaystyle\sum_{t=1}^{n} \abs{\frac{target - input}{target}}}
137 | #'
138 | # nnf_wape <- function(input, target){
139 | # mean(abs(target - input) / abs(target))
140 | # }
141 |
142 |
143 |
144 |
145 |
--------------------------------------------------------------------------------
/R/mlp-module.R:
--------------------------------------------------------------------------------
1 | #' A configurable feed forward network (Multi-Layer Perceptron)
2 | #' with embedding
3 | #'
4 | #' @importFrom torch torch_cat
5 | #'
6 | #' @examples
7 | #' net <- model_mlp(4, 2, 1)
8 | #' x <- as_tensor(iris[, 1:4])
9 | #' net(x)
10 | #'
11 | #' # With categorical features
12 | #' library(recipes)
13 | #' iris_prep <-
14 | #' recipe(iris) %>%
15 | #' step_integer(Species) %>%
16 | #' prep() %>%
17 | #' juice()
18 | #'
19 | #' iris_prep <- mutate(iris_prep, Species = as.integer(Species))
20 | #'
21 | #' x_num <- as_tensor(iris_prep[, 1:4])
22 | #' x_cat <- as_tensor(dplyr::select(iris_prep, 5))
23 | #'
24 | #' n_unique_values <- dict_size(iris_prep)
25 | #'
26 | #' .init_layer_spec <-
27 | #' init_layer_spec(
28 | #' num_embeddings = n_unique_values,
29 | #' embedding_dim = embedding_size_google(n_unique_values),
30 | #' numeric_in = 4,
31 | #' numeric_out = 2
32 | #' )
33 | #'
34 | #' net <- model_mlp(.init_layer_spec, 2, 1)
35 | #'
36 | #' net(x_num, x_cat)
37 | #'
38 | #' @export
model_mlp <- torch::nn_module(

  "model_mlp",

  # Constructor.
  #
  # ...         : layer sizes passed on to nn_mlp(); if the FIRST element is
  #               a list (built with init_layer_spec()), it describes an input
  #               layer combining numeric features with embeddings of
  #               categorical features
  # horizon     : forecast horizon (stored; currently only referenced by the
  #               commented-out reshape logic in forward())
  # output_size : number of outputs (stored; same remark as horizon)
  # embedding   : currently unused - TODO confirm intent
  # activation  : activation function forwarded to nn_mlp()
  initialize = function(..., horizon, output_size, embedding = NULL,
                        activation = nnf_relu){

    layers <- list(...)

    self$horizon <- horizon
    self$output_size <- output_size

    # If first element is a list, it describes embedding + numerical features
    if (is.list(layers[[1]])) {

      first_layer <- layers[[1]]

      # One embedding table per categorical feature
      self$multiembedding <-
        nn_multi_embedding(
          num_embeddings = first_layer$num_embeddings,
          embedding_dim = first_layer$embedding_dim
        )

      # Linear + activation applied to the numeric features only
      self$initial_layer <-
        nn_nonlinear(
          first_layer$numeric_in,
          first_layer$numeric_out
        )

      # Width of the concatenated (numeric + embedded categorical) output;
      # becomes the input size of the first plain MLP layer
      first_layer_output <-
        first_layer$numeric_out +
        sum(first_layer$embedding_dim)

      layers <- c(
        list(first_layer_output), layers[-1]
      )

    }

    self$mlp <- do.call(
      nn_mlp, c(layers, list(activation = activation))
    )

  },

  # Forward pass.
  # x_num / x_cat         : historical numeric / categorical inputs
  # x_fut_num / x_fut_cat : known-future numeric / categorical inputs
  # NOTE(review): torch_cat() on the future inputs uses the default
  # dim = 1 (first dimension) - confirm this is the intended axis for
  # the shapes supplied by callers.
  forward = function(x_num = NULL, x_cat = NULL, x_fut_num = NULL, x_fut_cat = NULL){

    if (!is.null(x_cat) & !is.null(x_fut_cat))
      x_cat <- torch_cat(list(x_cat, x_fut_cat))

    if (!is.null(x_num) & !is.null(x_fut_num))
      x_num <- torch_cat(list(x_num, x_fut_num))

    # Pass through the initial layer, concatenating numeric features
    # with the embedded categorical ones along the last dimension
    if (!is.null(x_cat)) {

      output <-
        torch_cat(list(
          self$multiembedding(x_cat),
          self$initial_layer(x_num)
        ), dim = -1)
    } else {
      output <- x_num
    }

    # Transform batch_size x (timesteps * features)
    # NOTE(review): `current_shape` is computed but unused - the reshape
    # logic below is commented out.
    current_shape <- dim(output)

    # output <- output$reshape(c(
    #   current_shape[1], current_shape[2] * current_shape[3]
    # ))

    output <- self$mlp(output)

    # Reshape output
    # output <- output$reshape(c(
    #   current_shape[1], self$horizon, self$output_size
    # ))

    output
  }

)
122 |
# Bundle the description of the first (mixed numeric + embedding) layer
# into a single named list, consumed by model_mlp().
init_layer_spec <- function(num_embeddings,
                            embedding_dim,
                            numeric_in,
                            numeric_out){
  spec <- list(
    num_embeddings = num_embeddings,
    embedding_dim = embedding_dim,
    numeric_in = numeric_in,
    numeric_out = numeric_out
  )
  spec
}
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/R/nn-mlp.R:
--------------------------------------------------------------------------------
1 | #' A shortcut to create a feed-forward block (MLP block)
2 | #'
3 | #' @param ... (`nn_module`, `function` `integer`, `character`)
4 | #' An arbitrary number of arguments, than can be:
5 | #' * `nn_module` - e.g. [`torch::nn_relu()`]
6 | #' * `function` - e.g. [`torch::nnf_relu`]
7 | #' * `character` - e.g. `selu`, which is converted to `nnf_selu`
8 | #' * `integer` -
9 | #'
10 | #' @param activation Used if only integers are specified. By default: `nnf_relu`
11 | #'
12 | #' @examples
13 | #' nn_mlp(10, 1)
14 | #' nn_mlp(30, 10, 1)
15 | #'
16 | #' # Simple forward pass
17 | #' net <- nn_mlp(4, 2, 1)
18 | #' x <- as_torch_tensor(iris[, 1:4])
19 | #' net(x)
20 | #'
21 | #' # Simple forward pass with identity function
22 | #' net <- nn_mlp(4, 2, 1, activation = function (x) x)
23 | #' x <- as_torch_tensor(iris[, 1:4])
24 | #' net(x)
25 | #'
26 | #' @export
nn_mlp <- torch::nn_module(

  "nn_mlp",

  # Builds a sequential feed-forward block from a mixed argument list of
  # integers (layer widths) and activations (modules, functions, or a
  # string name such as "selu", resolved to torch::nnf_selu).
  initialize = function(..., activation = nnf_relu){
    layers <- list(...)

    if (!at_least_two_integers(layers))
      stop("Specified layers must contain at least two integer numerics, ",
           "which describe at least one layer (input and output)")

    # Check, if any activation was specified explicitly among `...`;
    # if so, the default activation is not inserted automatically
    if (length(int_elements(layers)) < length(layers))
      activation <- NULL

    int_indices <- which(
      sapply(layers, is_int)
    )

    # For every integer argument, the index of the NEXT integer argument
    # (NA for the last one) - each (current, next) pair defines one nn_linear
    int_table <-
      data.frame(
        .curr = int_indices,
        .next = dplyr::lead(int_indices)
      )

    # Layer is a "candidate layer"
    # The last element in the "layers" table has length equal to 0.
    # See: .is_last() function
    layer_names <- NULL
    n_layers <- length(layers)

    for (i in seq_along(layers)) {

      layer <- layers[[i]]

      if (is_int(layer)) {
        if (!.is_last(i, int_table)) {
          # Integer followed by a later integer -> nn_linear(in, out)
          layer <- nn_linear(layer, .next_int(i, layers, int_table))
        } else {
          # The trailing integer is only the output size of the previous layer
          next
        }
      } else if (is.character(layer)) {
        # Resolve e.g. "relu" to torch::nnf_relu
        layer <- get(glue::glue("nnf_{layer}"),
                     envir = rlang::pkg_env("torch"))
      }

      layer_name <- glue::glue("layer_{i}")

      self[[layer_name]] <- layer

      layer_names <- c(layer_names, layer_name)

      # No activation is appended after the final linear layer
      if (i == n_layers - 1)
        next

      if (!is.null(activation)) {
        activation_layer_name <- glue::glue("layer_{i}_activation")
        self[[activation_layer_name]] <- activation #clone_if_module(activation)
        layer_names <- c(layer_names, activation_layer_name)
      }

    }

    self$layer_names <- layer_names

  },

  # Applies the registered layers in order of creation
  forward = function(x){
    output <- x
    for (ln in self$layer_names) {
      # print(output)
      output <- self[[ln]](output)
    }
    output
  }

)
104 |
# TRUE when `x` is a numeric whole number; FALSE for non-numeric input
# or a numeric with a fractional part.
is_int <- function(x){
  if (!is.numeric(x))
    return(FALSE)
  if (x %% 1 == 0)
    return(TRUE)
  FALSE
}
111 |
# Returns the element of `lst` located at the index that `idx_table`
# records as the next integer position after position `i`.
.next_int <- function(i, lst, idx_table){
  row <- idx_table[idx_table$.curr == i, ]
  lst[[row$.next]]
}
116 |
# TRUE when position `i` is the last integer position recorded in
# `idx_table` (i.e. it has no next integer); FALSE when it has a
# successor or is not present in the table at all.
.is_last <- function(i, idx_table){
  nxt <- idx_table[idx_table$.curr == i, ]$.next
  if (length(nxt) == 0)
    return(FALSE)
  is.na(nxt)
}
124 |
#' TRUE when the list contains at least two whole-number elements
#' (the minimum needed to describe one layer: input and output size).
#'
#' @examples
#' at_least_two_integers(list(2, 'char'))
#' at_least_two_integers(list(2, 'char', 3))
at_least_two_integers <- function(l){
  ints <- int_elements(l)
  length(ints) >= 2
}
132 |
# Keep only the elements of `l` that are whole numbers (see is_int()).
int_elements <- function(l){
  Filter(f = is_int, x = l)
}
136 |
# Return a clone for nn_module objects (which have reference semantics);
# any other value is returned unchanged.
clone_if_module <- function(object){
  if (!inherits(object, 'nn_module'))
    return(object)
  object$clone()
}
143 |
144 |
--------------------------------------------------------------------------------
/R/nn-multi-embedding.R:
--------------------------------------------------------------------------------
1 | #' Create multiple embeddings at once
2 | #'
3 | #' It is especially useful, for dealing with multiple categorical features.
4 | #'
5 | #' @param num_embeddings (`integer`) Size of the dictionary of embeddings.
6 | #' @param embedding_dim (`integer`) The size of each embedding vector.
7 | #' @param padding_idx (`integer`, optional) If given, pads the output with
8 | #' the embedding vector at `padding_idx` (initialized to zeros) whenever it encounters the index.
9 | #' @param max_norm (`numeric`, optional) If given, each embedding vector with norm larger
10 | #' than max_norm is renormalized to have norm max_norm.
11 | #' @param norm_type (`numeric`, optional) The p of the p-norm to compute for the max_norm option. Default 2.
12 | #' @param scale_grad_by_freq (`logical`, optional) If given, this will scale gradients by
13 | #' the inverse of frequency of the words in the mini-batch. Default FALSE.
14 | #' @param sparse (`logical`, optional) If TRUE, gradient w.r.t. weight matrix will be a sparse tensor.
15 | #' @param .weight (`torch_tensor` or `list` of `torch_tensor`) Embeddings weights (in case you want to set it manually).
16 | #'
17 | #' @importFrom torch nn_module nn_embedding
18 | #' @importFrom glue glue
19 | #'
20 | #' @examples
21 | #' library(recipes)
22 | #'
23 | #' data("gss_cat", package = "forcats")
24 | #'
25 | #' gss_cat_transformed <-
26 | #' recipe(gss_cat) %>%
27 | #' step_integer(everything()) %>%
28 | #' prep() %>%
29 | #' juice()
30 | #'
31 | #' gss_cat_transformed <- na.omit(gss_cat_transformed)
32 | #'
33 | #' gss_cat_transformed <-
34 | #' gss_cat_transformed %>%
35 | #' mutate(across(where(is.numeric), as.integer))
36 | #'
37 | #' glimpse(gss_cat_transformed)
38 | #'
39 | #' gss_cat_tensor <- as_tensor(gss_cat_transformed)
40 | #' .dict_size <- dict_size(gss_cat_transformed)
41 | #' .dict_size
42 | #'
43 | #' .embedding_size <- embedding_size_google(.dict_size)
44 | #'
45 | #' embedding_module <-
46 | #' nn_multi_embedding(.dict_size, .embedding_size)
47 | #'
48 | #' # Expected output size
49 | #' sum(.embedding_size)
50 | #'
51 | #' embedding_module(gss_cat_tensor)
52 | #'
53 | #' @export
nn_multi_embedding <- torch::nn_module(

  #' See:
  #' "Optimal number of embeddings"
  #' See: https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html

  "nn_multi_embedding",

  # Creates one nn_embedding per categorical feature. Scalar arguments
  # are recycled so a single value may serve all features.
  initialize = function(num_embeddings, embedding_dim,
                        padding_idx = NULL, max_norm = NULL, norm_type = 2,
                        scale_grad_by_freq = FALSE, sparse = FALSE,
                        .weight = NULL){

    # Check arguments: lengths must match unless one side is scalar
    if (length(num_embeddings) != length(embedding_dim) &
        !(length(num_embeddings) == 1 | length(embedding_dim) == 1)) {
      torch:::value_error("Values has not equal lengths")
    }

    # Recycle the scalar side to the length of the other
    if (length(num_embeddings) > 1 & length(embedding_dim) == 1)
      embedding_dim <- rep(embedding_dim, length(num_embeddings))

    if (length(embedding_dim) > 1 & length(num_embeddings) == 1)
      num_embeddings <- rep(num_embeddings, length(embedding_dim))

    required_len <- max(length(embedding_dim), length(num_embeddings))

    # Broadcast the per-embedding options to one value per feature
    padding_idx <- rep_if_one_element(padding_idx, required_len)
    max_norm <- rep_if_one_element(max_norm, required_len)
    norm_type <- rep_if_one_element(norm_type, required_len)
    scale_grad_by_freq <- rep_if_one_element(scale_grad_by_freq, required_len)
    sparse <- rep_if_one_element(sparse, required_len)

    if (length(.weight) == 1)
      .weight <- rep(list(.weight), required_len)

    self$num_embeddings <- num_embeddings

    # One submodule per feature, registered as embedding_1, embedding_2, ...
    for (idx in seq_along(self$num_embeddings)){

      self[[glue("embedding_{idx}")]] <-
        nn_embedding(
          num_embeddings = num_embeddings[[idx]],
          embedding_dim = embedding_dim[[idx]],
          padding_idx = padding_idx[[idx]],
          max_norm = max_norm[[idx]],
          norm_type = norm_type[[idx]],
          scale_grad_by_freq = scale_grad_by_freq[[idx]],
          sparse = sparse[[idx]],
          .weight = .weight[[idx]]
        )
    }

  },

  # Embeds the idx-th column of `input` (integer category codes) with the
  # idx-th embedding table, then concatenates all results along the last
  # dimension (total width = sum of the embedding_dim values).
  forward = function(input){
    embedded_features <- list()

    for (idx in seq_along(self$num_embeddings)) {
      # `..` keeps all leading dimensions; idx selects the feature column
      embedded_features[[glue("embedding_{idx}")]] <-
        self[[glue("embedding_{idx}")]](input[.., idx])
    }

    torch_cat(embedded_features, dim = -1)
  }
)
120 |
--------------------------------------------------------------------------------
/R/nn-nonlinear.R:
--------------------------------------------------------------------------------
1 | #' Shortcut to create linear layer with nonlinear activation function
2 | #'
3 | #' @param in_features (`integer`) size of each input sample
4 | #' @param out_features (`integer`) size of each output sample
5 | #' @param bias (`logical`) If set to `FALSE`, the layer will not learn an additive bias.
6 | #' Default: `TRUE`
7 | #' @param activation (`nn_module`) A nonlinear activation function (default: [torch::nn_relu()])
8 | #'
9 | #' @examples
10 | #' net <- nn_nonlinear(10, 1)
11 | #' x <- torch_tensor(matrix(1, nrow = 2, ncol = 10))
12 | #' net(x)
13 | #'
14 | #' @export
nn_nonlinear <- torch::nn_module(

  "nn_nonlinear",

  # A plain affine layer followed by a configurable nonlinearity
  initialize = function(in_features, out_features, bias = TRUE, activation = nn_relu()) {
    self$activation <- activation
    self$linear <- nn_linear(in_features, out_features, bias = bias)
  },

  # Apply the linear map first, then the activation
  forward = function(input){
    hidden <- self$linear(input)
    self$activation(hidden)
  }

)
29 |
30 |
--------------------------------------------------------------------------------
/R/palette.R:
--------------------------------------------------------------------------------
#' Colour palette picked from the torchts logo.
#' Used e.g. by plot_forecast() via ggplot2::scale_color_manual().
torchts_palette <-
  c(
    "#ef4c2d", # orange
    "#3f0062", # dark violet
    "#78003c", # red-violet
    "#ff9400", # yellow
    "#7e0027"  # red
  )
10 |
--------------------------------------------------------------------------------
/R/plot.R:
--------------------------------------------------------------------------------
#' Plot forecast vs ground truth
#'
#' @param data A data frame containing the actual values.
#' @param forecast A data frame of forecasts (e.g. from `predict()`); if it
#'   has several columns, the one named after `outcome` is used.
#' @param outcome Unquoted name of the outcome column in `data`.
#' @param index Optional unquoted name of the time index column
#'   (currently only captured; the x-axis is the row number).
#' @param interactive (`logical`) If `TRUE`, return a plotly widget.
#' @param title Plot title.
#' @param ... Currently unused.
#'
#' @importFrom ggplot2 ggplot geom_line aes theme_minimal ggtitle scale_color_manual
#'
#' @export
plot_forecast <- function(data, forecast, outcome,
                          index = NULL, interactive = FALSE,
                          title = "Forecast vs actual values",
                          ...){

  outcome <- as.character(substitute(outcome))

  if (!is.null(index))
    index <- as.character(substitute(index))

  if (ncol(forecast) > 1)
    forecast <- forecast[outcome]

  # NOTE(review): pivot_longer() below assumes the forecast column is
  # named `.pred` - confirm this holds after the multi-column slice above.
  fcast_vs_true <-
    bind_cols(
      n = 1:nrow(data),
      actual = data[[outcome]],
      forecast
    ) %>%
    tidyr::pivot_longer(c(actual, .pred))

  p <-
    ggplot(fcast_vs_true) +
    geom_line(aes(n, value, col = name)) +
    theme_minimal() +
    ggtitle(title) +
    scale_color_manual(values = torchts_palette)

  if (interactive)
    # Pass the plot explicitly: ggplotly() with no argument falls back to
    # ggplot2::last_plot(), which is only set when a plot has been printed
    p <- plotly::ggplotly(p)

  p
}
45 |
--------------------------------------------------------------------------------
/R/predict.R:
--------------------------------------------------------------------------------
# Internal predict work-horse for torchts models.
#
# object   : fitted torchts model (uses $net, $timesteps, $horizon,
#            $outcomes, $parsed_formula, $extras, optionally $device)
# new_data : data frame with history rows followed by the period to
#            forecast; outcome columns may contain NAs there
#
# Returns a tibble with a `.pred` column (single outcome) or one column
# per outcome, front-padded so it aligns row-wise with `new_data`.
torchts_predict <- function(object, new_data, ...){
  # WARNING: Cannot be used parallely for now

  # For now we suppose it's continuous
  # TODO: Check more conditions
  # TODO: keys!!!


  n_outcomes <- length(object$outcomes)
  batch_size <- 1

  # Checks
  check_length_vs_horizon(object, new_data)
  check_is_new_data_complete(object, new_data)
  # Recursive mode: each horizon's forecast is written back into the
  # dataset so the next window can consume it as input
  recursive_mode <- check_recursion(object, new_data)

  # Preparing dataloader (unshuffled, advancing one full horizon per step)
  new_data_dl <-
    as_ts_dataloader(
      new_data,
      timesteps = object$timesteps,
      horizon = object$horizon,
      batch_size = batch_size,
      jump = object$horizon,
      # Extras
      parsed_formula = object$parsed_formula,
      cat_recipe = object$extras$cat_recipe,
      shuffle = FALSE,
      drop_last = FALSE
    )

  net <- object$net

  if (!is.null(object$device)) {
    net <- set_device(net, object$device)
    new_data_dl <- set_device(new_data_dl, object$device)
  }

  # Inference mode (disables dropout etc.)
  net$eval()

  # (number of windows) x horizon x (number of outcomes)
  output_shape <-
    c(length(new_data_dl$dataset), object$horizon, length(object$outcomes))

  preds <- array(0, dim = output_shape)
  iter <- 0

  # b <- dataloader_next(dataloader_make_iter(new_data_dl))

  coro::loop(for (b in new_data_dl) {

    output <- do.call(net, get_x(b))
    preds[iter+1,,] <- as_array(output$cpu())

    if (recursive_mode) {
      # Rows of the underlying dataset covered by this forecast window
      start <- object$timesteps + iter * object$horizon + 1
      end <- object$timesteps + iter * object$horizon + object$horizon
      cols <- unlist(new_data_dl$dataset$outcomes_spec)

      if (length(cols) == 1)
        output <- output$reshape(nrow(output))

      # TODO: insert into dataset even after last forecast for consistency?
      # NOTE(review): `dim(...) == dim(output)` compares two vectors; `if`
      # on a length > 1 condition warns (and errors on R >= 4.2) - this
      # probably wants all(dim(...) == dim(output)). Confirm.
      if (dim(new_data_dl$dataset$data[start:end, mget(object$outcomes)]) == dim(output))
        new_data_dl$dataset$data[start:end, mget(object$outcomes)] <- output
    }

    iter <- iter + 1

  })

  # Make sure that forecast has right length:
  # permute to horizon x window x outcome, then flatten so consecutive
  # windows follow each other in time order
  preds <-
    preds %>%
    aperm(c(2, 1, 3)) %>%
    array(dim = c(output_shape[1] * output_shape[2], output_shape[3]))

  # Adding colnames if more than one outcome
  if (ncol(preds) > 1)
    colnames(preds) <- object$outcomes
  else
    colnames(preds) <- ".pred"

  # browser()

  # Cutting if longer than expected
  preds <- as_tibble(preds)
  preds <- head(preds, nrow(new_data) - object$timesteps)
  # Front-pad so the result aligns row-wise with new_data
  preds <- preprend_empty(preds, object$timesteps)

  preds
}
92 |
--------------------------------------------------------------------------------
/R/prepare-data.R:
--------------------------------------------------------------------------------
#' Prepare dataloaders
#'
#' Builds the training dataloader and, optionally, a validation
#' dataloader from a single data.frame.
#'
#' @inheritParams as_ts_dataset
#' @inheritParams torchts_rnn
#'
#' @param validation Either NULL (no validation set), a numeric fraction
#'   of rows held out for validation, or a ready-made validation data.frame.
#'
#' @return A list with elements `train_dl` and `valid_dl`
#'   (`valid_dl` is NULL when no validation was requested).
prepare_dl <- function(data, formula, index,
                       timesteps, horizon,
                       categorical = NULL,
                       validation = NULL,
                       sample_frac = 1,
                       batch_size, shuffle, jump,
                       parsed_formula = NULL, flatten = FALSE, ...){

  # TODO: use predictors, outcomes instead of parsing formula second time
  valid_dl <- NULL

  if (!is.null(validation)) {

    if(is.numeric(validation)) {

      # A fraction was given: split chronologically into train/validation
      train_len <- floor(nrow(data) * (1 - validation))
      assess_len <- nrow(data) - train_len

      # The validation window keeps `timesteps` extra rows as warm-up
      # NOTE(review): arrange(!!index) unquotes a value, not a column
      # symbol - confirm `index` arrives here as a symbol/quosure
      validation <-
        data %>%
        arrange(!!index) %>%
        tail(timesteps + assess_len)

      data <-
        data %>%
        arrange(!!index) %>%
        head(train_len)

      # data_split <-
      #   timetk::time_series_split(
      #     data = data,
      #     date_var = !!index,
      #     lag = timesteps,
      #     initial = train_len,
      #     assess = assess_len
      #  )

      # data <- rsample::training(data_split)
      # validation <- rsample::testing(data_split)
    }

    # Validation dataloader: no shuffling / subsampling
    valid_dl <-
      as_ts_dataloader(
        data = validation,
        formula = formula,
        timesteps = timesteps,
        horizon = horizon,
        categorical = categorical,
        # sample_frac = sample_frac,
        batch_size = batch_size,
        parsed_formula = parsed_formula
      )

  }

  # Training dataloader
  train_dl <-
    as_ts_dataloader(
      data = data,
      formula = formula,
      timesteps = timesteps,
      horizon = horizon,
      categorical = categorical,
      sample_frac = sample_frac,
      batch_size = batch_size,
      shuffle = shuffle,
      jump = jump,
      parsed_formula = parsed_formula
    )

  list(
    train_dl = train_dl,
    valid_dl = valid_dl
  )
}
80 |
81 |
#' Build an embedding specification for the categorical variables
#' listed in `categorical`, or NULL when there are none.
prepare_categorical <- function(data, categorical){

  # No categorical variables -> no embedding specification
  if (nrow(categorical) == 0)
    return(NULL)

  # Dictionary sizes of the categorical columns and the matching
  # embedding sizes (Google heuristic)
  num_embeddings <- dict_size(data[, mget(categorical$.var)])
  embedding_dims <- embedding_size_google(num_embeddings)

  embedding_spec(
    num_embeddings = num_embeddings,
    embedding_dim  = embedding_dims
  )
}
101 |
102 |
103 |
--------------------------------------------------------------------------------
/R/static.R:
--------------------------------------------------------------------------------
#' Check, which variables are static
#'
#' A variable is considered static when it holds a single constant
#' value within every group defined by `key`.
#'
#' @examples
#' data <- tiny_m5 %>%
#'    dplyr::select(store_id, item_id, state_id,
#'                  weekday, wday, month, year)
#'
#' @export
which_static <- function(data, key, cols = NULL){

  # By default inspect every column
  if (is.null(cols))
    cols <- colnames(data)

  candidate_vars <- setdiff(cols, key)

  # TRUE per group when constant, then TRUE overall
  # when constant within every group
  data %>%
    group_by(across(all_of(key))) %>%
    summarise(across(all_of(candidate_vars), all_the_same)) %>%
    ungroup() %>%
    summarise(across(all_of(candidate_vars), all))
}
22 |
--------------------------------------------------------------------------------
/R/torchts-model.R:
--------------------------------------------------------------------------------
#' Torchts abstract model
#'
#' Constructor for the common structure shared by all torchts models.
#'
#' @param class (`character`) S3 class(es) prepended to "torchts_model".
#' @param net A trained torch `nn_module`.
#' @param index Name of the time index variable.
#' @param key Key (id) variable(s) for panel data.
#' @param outcomes Names of the outcome variables.
#' @param predictors Names of the predictor variables.
#' @param optim The optimizer used during training.
#' @param timesteps (`integer`) Input window length.
#' @param parsed_formula Parsed model formula.
#' @param horizon (`integer`) Forecast horizon.
#' @param device Device the model lives on (e.g. "cpu", "cuda").
#' @param col_map_out Column mapping for the outcome variables.
#' @param extras List of extra objects needed later (e.g. the
#'   categorical recipe read by `torchts_predict` via `object$extras`).
torchts_model <- function(class, net, index, key,
                          outcomes, predictors,
                          optim, timesteps,
                          parsed_formula,
                          horizon, device,
                          col_map_out,
                          extras){
  structure(
    class = c(class, "torchts_model"),
    list(
      net = net,
      index = index,
      key = key,
      outcomes = outcomes,
      predictors = predictors,
      optim = optim,
      timesteps = timesteps,
      parsed_formula = parsed_formula,
      horizon = horizon,
      device = device,
      col_map_out = col_map_out,
      # BUGFIX: `extras` was accepted but silently dropped, while
      # torchts_predict() reads `object$extras$cat_recipe`
      extras = extras
    )
  )
}
26 |
27 |
#' Print method for torchts models: shows the underlying network
#' followed by a bullet list with the model specification.
#' @export
print.torchts_model <- function(x, ...){

  # BUGFIX: collapse multi-element keys into one string, consistently
  # with predictors/outcomes (glue() would otherwise emit one bullet
  # line per key element)
  key <- if (length(x$key) == 0) "NULL" else paste0(x$key, collapse = ", ")
  predictors <- paste0(x$predictors, collapse = ", ")
  outcomes <- paste0(x$outcomes, collapse = ", ")

  print(x$net)
  cat("\n")
  cat("Model specification: \n")
  cli::cat_bullet(glue::glue("key: {key}"))
  cli::cat_bullet(glue::glue("index: {x$index}"))
  cli::cat_bullet(glue::glue("predictors: {predictors}"))
  cli::cat_bullet(glue::glue("outcomes: {outcomes}"))
  cli::cat_bullet(glue::glue("timesteps: {x$timesteps}"))
  cli::cat_bullet(glue::glue("horizon: {x$horizon}"))
  cli::cat_bullet(glue::glue("optimizer: {class(x$optim)[1]}"))

}
47 |
--------------------------------------------------------------------------------
/R/torchts-package.R:
--------------------------------------------------------------------------------
1 | #' torchts: Time Series Models in torch
2 | #'
3 | #' @author Krzysztof Joachimiak
4 | #' @keywords package
5 | #' @name torchts-package
6 | #' @aliases torchts
7 | #' @docType package
8 | NULL
9 |
--------------------------------------------------------------------------------
/R/training-helpers.R:
--------------------------------------------------------------------------------
#' Training helper
#'
#' Runs a single optimization step on one batch and returns the
#' scalar loss value.
train_batch <- function(input, target,
                        net, optimizer,
                        loss_fun = nnf_mse_loss) {

  # Reset gradients before the forward pass
  optimizer$zero_grad()

  predictions <- do.call(net, input)
  batch_loss  <- loss_fun(predictions, target$y)

  # Backpropagate and update the weights
  batch_loss$backward()
  optimizer$step()

  batch_loss$item()
}
15 |
#' Validation helper function
#'
#' Computes the scalar loss on one validation batch
#' (forward pass only, no optimizer step).
valid_batch <- function(net, input, target,
                        loss_fun = nnf_mse_loss) {
  predictions <- do.call(net, input)
  loss_fun(predictions, target$y)$item()
}
24 |
25 |
#' Fit a neural network
#'
#' @param net A torch `nn_module` to train (modified in place -
#'   torch modules are reference objects).
#' @param train_dl Training dataloader.
#' @param valid_dl Optional validation dataloader.
#' @param epochs (`integer`) Number of training epochs.
#' @param optimizer A torch optimizer bound to `net`'s parameters.
#' @param loss_fn Loss function, e.g. `nnf_mse_loss`.
#'
#' @return The trained network.
fit_network <- function(net, train_dl, valid_dl = NULL, epochs,
                        optimizer, loss_fn){

  message("\nTraining started")

  # Info in Keras
  # 938/938 [==============================] - 1s 1ms/step - loss: 0.0563 - acc: 0.9829 - val_loss: 0.1041 - val_acc: 0.9692

  loss_history <- c()

  for (epoch in seq_len(epochs)) {

    net$train()
    train_loss <- c()

    train_pb <- progress_bar$new(
      "Epoch :epoch/:nepochs [:bar] :current/:total (:percent)",
      total = length(train_dl),
      clear = FALSE,
      width = 50
    )

    coro::loop(for (b in train_dl) {
      loss <- train_batch(
        input = get_x(b),
        target = get_y(b),
        net = net,
        optimizer = optimizer,
        loss_fun = loss_fn
      )
      train_loss <- c(train_loss, loss)
      train_pb$tick(tokens = list(epoch = epoch, nepochs = epochs))
    })

    valid_loss_info <- ""

    if (!is.null(valid_dl)) {

      net$eval()
      valid_loss <- c()

      coro::loop(for (b in valid_dl) {
        # BUGFIX: valid_batch() was called as valid_batch(b), passing
        # the whole batch as `net` and omitting input/target/loss -
        # pass the arguments explicitly, as in the training loop
        loss <- valid_batch(
          net = net,
          input = get_x(b),
          target = get_y(b),
          loss_fun = loss_fn
        )
        valid_loss <- c(valid_loss, loss)
      })

      valid_loss_info <- sprintf("validation: %3.5f", mean(valid_loss))
    }

    mean_epoch_loss <- mean(train_loss)
    loss_history <- c(loss_history, mean_epoch_loss)

    message(sprintf(" | train: %3.5f %s \n",
                    mean_epoch_loss, valid_loss_info
    ), appendLF = FALSE)

  }

  net
}
89 |
#' batch <- list(x_num = "aaa", x_cat = "bbb", y = "c")
#' get_x(batch)
get_x <- function(batch){
  # Keep every element whose name begins with "x" (x_num, x_cat, ...)
  is_input <- startsWith(names(batch), "x")
  batch[is_input]
}
95 |
get_y <- function(batch){
  # Keep every element whose name begins with "y"
  is_target <- startsWith(names(batch), "y")
  batch[is_target]
}
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/R/utils-internal.R:
--------------------------------------------------------------------------------
#' An auxiliary function to call an optimizer
#'
#' Builds and evaluates an optimizer call from either a bare expression
#' (e.g. `optim_adam`) or a ready-made quosure, injecting the learning
#' rate and the network parameters.
#'
#' @param optim An optimizer call, or a quosure wrapping one.
#' @param learn_rate Learning rate, forwarded to the optimizer as `lr`.
#' @param params Network parameters to optimize.
call_optim <- function(optim, learn_rate, params){
  # Capture `optim` lazily unless the caller already passed a quosure
  if (!rlang::is_quosure(optim))
    quosure <- rlang::enquo(optim)
  else
    quosure <- optim
  # Extract the function being called...
  fun <- rlang::call_fn(quosure)
  # ...and merge lr/params with the arguments of the original call
  args <- c(
    list(lr = learn_rate,
         params = params),
    rlang::call_args(quosure)
  )
  do.call({fun}, args)
}
15 |
16 |
# NOTE(review): this function looks unfinished/dead. It references the
# global `new_data_dl` instead of its own `dl` argument, never uses
# `output`, and returns the sampler of that global dataloader.
# Verify intent before relying on it (or remove it).
update_dl <- function(dl, output){
  # Target columns of the dataloader's underlying dataset
  target_col <- dl$dataset$target_columns
  new_data_dl$dataset$data[.., target_col][1:30]

  new_data_dl$.index_sampler$sampler

}
24 |
25 |
# Detach an RNN hidden state (a tensor or a list of tensors,
# e.g. LSTM's (h, c) pair) from the computational graph.
detach_hidden_state <- function(hx){
  detach_one <- function(h) h$clone()$detach()
  if (is.list(hx))
    purrr::map(hx, detach_one)
  else
    detach_one(hx)
}
32 |
33 |
#' Repeat the element `output_length` times if its length == 1,
#' otherwise return the input unchanged.
rep_if_one_element <- function(x, output_length){
  if (length(x) != 1)
    return(x)
  rep(x, output_length)
}
41 |
#' Remove parsnip model
#' For development purposes only
remove_model <- function(model = "rnn"){
  # Drop every entry matching `model` from parsnip's model registry
  model_env <- parsnip:::get_model_env()
  matching  <- grep(model, names(model_env), value = TRUE)
  rm(list = matching, envir = model_env)
}
49 |
50 |
# Names of the variables having the given role in the parsed formula
vars_with_role <- function(parsed_formula, role){
  has_role <- parsed_formula$.role == role
  parsed_formula$.var[has_role]
}
54 |
# Variables with the given role and type and no modifier
get_vars <- function(parsed_formula, role, type){
  mask <-
    parsed_formula$.role == role &
    parsed_formula$.type == type &
    is.na(parsed_formula$.modifier)
  parsed_formula[mask, ]$.var
}
60 |
# Variables with the given role, type and modifier
# (missing modifiers are treated as empty strings)
get_vars2 <- function(parsed_formula, role, type, modifier){
  mod <- parsed_formula$.modifier
  parsed_formula$.modifier <- ifelse(is.na(mod), "", mod)
  mask <-
    parsed_formula$.role == role &
    parsed_formula$.type == type &
    parsed_formula$.modifier == modifier
  parsed_formula[mask, ]$.var
}
71 |
# Filter variable names by role and/or class.
# BUGFIX: the class condition compared against `c` (the base function,
# not the `class` argument), and the NULL defaults made the logical
# mask collapse to logical(0), silently selecting nothing. Each filter
# is now applied only when the corresponding argument is provided.
filter_vars <- function(parsed_formula, role = NULL, class = NULL){
  keep <- rep(TRUE, nrow(parsed_formula))
  if (!is.null(role))
    keep <- keep & parsed_formula$.role == role
  if (!is.null(class))
    keep <- keep & parsed_formula$.class == class
  parsed_formula$.var[keep]
}
78 |
79 |
# Render a vector as a comma-separated string
listed <- function(x){
  # TODO: add a truncate option
  paste(x, collapse = ", ")
}
84 |
# TRUE when every element equals the first one
all_the_same <- function(x){
  first_value <- x[1]
  all(x == first_value)
}
88 |
#' https://stackoverflow.com/questions/26083625/how-do-you-include-data-frame-output-inside-warnings-and-errors
print_and_capture <- function(x){
  printed_lines <- capture.output(print(x))
  paste(printed_lines, collapse = "\n")
}
93 |
--------------------------------------------------------------------------------
/R/utils.R:
--------------------------------------------------------------------------------
#' RNN output size
#' @param module (nn_module) A torch `nn_module`
#' @examples
#' gru_layer <- nn_gru(15, 3)
#' rnn_output_size(gru_layer)
#' @export
rnn_output_size <- function(module){
  # Last dimension of the hidden-to-hidden weight matrix
  weight_dims <- dim(module$weight_hh_l1)
  weight_dims[length(weight_dims)]
}
10 |
#' Partially clear outcome variable
#' in new data by overriding with NA values
#'
#' @param data (`data.frame`) New data
#' @param index Date variable
#' @param outcome Outcome (target) variable
#' @param timesteps (`integer`) Number of timesteps used by RNN model
#' @param key A key (id) to group the data.frame (for panel data)
#'
#' @importFrom dplyr group_by
#'
#' @examples
#' tarnow_temp <-
#'  weather_pl %>%
#'  filter(station == "TRN") %>%
#'  select(date, tmax_daily, tmin_daily, press_mean_daily)
#'
#' TIMESTEPS <- 20
#' HORIZON <- 1
#'
#' data_split <-
#'   time_series_split(
#'     tarnow_temp, date,
#'     initial = "18 years",
#'     assess = "2 years",
#'     lag = TIMESTEPS
#'   )
#'
#' cleared_new_data <-
#'    testing(data_split) %>%
#'    clear_outcome(date, tmax_daily, TIMESTEPS)
#'
#' head(cleared_new_data, TIMESTEPS + 10)
#'
#' @export
clear_outcome <- function(data, index, outcome, timesteps, key = NULL){

  # Capture bare column names (NSE)
  index <- as.character(substitute(index))
  outcome <- as.character(substitute(outcome))

  # Several outcomes given as c(a, b) deparse to c("c", "a", "b") -
  # drop the leading "c"
  if (outcome[1] == "c")
    outcome <- outcome[-1]

  if (!is.null(key))
    key <- as.character(substitute(key))

  # BUGFIX: unquoting a character scalar (arrange(!!index),
  # group_by(!!key)) sorts/groups by a constant string - a no-op -
  # instead of by the column. Refer to the columns via all_of().
  data <- arrange(data, across(all_of(index)))

  if (!is.null(key))
    data <- group_by(data, across(all_of(key)))

  # Keep the first `timesteps` outcome values (warm-up window),
  # blank out the rest with NA
  data %>%
    mutate(across(all_of(outcome), ~ c(.x[1:timesteps], rep(NA, n() - timesteps))))
}
62 |
# TRUE when `col` inherits from at least one of `types`
inherits_any <- function(col, types){
  matches <- vapply(types, function(tp) inherits(col, tp), logical(1))
  any(matches)
}
66 |
# For each element of `class`, TRUE when its first class name
# is among `desired_classes` (operates on class-name strings)
inherits_any_char <- function(class, desired_classes){
  result <- sapply(class, function(cls) any(cls[[1]] %in% desired_classes))
  unname(result)
}
72 |
# Replace NULL with zero, pass anything else through
zeroable <- function(x){
  if (is.null(x)) 0 else x
}
79 |
#' Colmap for outcome variable
col_map_out <- function(dataloader){
  outcome_spec <- dataloader$dataset$outcomes_spec
  unlist(outcome_spec)
}
84 |
# Remove NULL and zero-length elements from a list
remove_nulls <- function(x) {
  Filter(function(el) !is.null(el) && length(el) != 0, x)
}
89 |
90 |
# Prepend `n` rows of NA (matching df's columns) to a data frame
preprend_empty <- function(df, n){
  pad <- as_tibble(
    matrix(NA, nrow = n, ncol = ncol(df),
           dimnames = list(NULL, colnames(df)))
  )
  rbind(pad, df)
}
97 |
98 |
99 | # TODO: key_hierarchy
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/R/zzz.R:
--------------------------------------------------------------------------------
# Package load hook: registers the package-wide options used across torchts
.onLoad <- function(libname, pkgname) {

  # Settings
  options(
    # Column classes treated as categorical (embedded) variables
    torchts_categoricals = c("logical", "factor", "character", "integer"),

    # TODO: tochts_time and so on?
    # Column classes treated as date/time indices
    torchts_dates = c("Date", "POSIXt", "POSIXlt", "POSIXct"),

    # Default device
    torchts_default_device = 'cpu'
  )

  # Parsnip models
  # remove_model("rnn")
  # make_rnn()
  # make_lagged_mlp()
}
19 |
--------------------------------------------------------------------------------
/README.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | output: github_document
3 | always_allow_html: true
4 | ---
5 |
6 |
7 |
8 | ```{r, include = FALSE}
9 | knitr::opts_chunk$set(
10 | collapse = TRUE,
11 | comment = "#>",
12 | fig.path = "man/figures/README-",
13 | out.width = "100%"
14 | )
15 | ```
16 |
17 | # torchts
18 |
19 |
20 | [](https://CRAN.R-project.org/package=torchts)
21 | [](https://github.com/krzjoa/torchts/actions)
22 | [](https://codecov.io/gh/krzjoa/torchts?branch=master)
23 | [](https://www.redbubble.com/i/sticker/torchts-R-package-hex-sticker-by-krzjoa/93537989.EJUG5)
25 |
26 |
27 |
28 |
29 | > Time series models with torch
30 |
31 | [](https://www.buymeacoffee.com/kjoachimiak)
32 |
33 | ## Installation
34 |
You can install the released version of torchts from [CRAN](https://CRAN.R-project.org) with `install.packages("torchts")`.

You can install the development version from [GitHub](https://github.com/) with:
38 |
39 | ``` r
40 | # install.packages("devtools")
41 | devtools::install_github("krzjoa/torchts")
42 | ```
43 |
44 | ## parsnip models
45 |
46 | ```{r parsnip.api}
47 | library(torchts)
48 | library(torch)
49 | library(rsample)
50 | library(dplyr, warn.conflicts = FALSE)
51 | library(parsnip)
52 | library(timetk)
53 | library(ggplot2)
54 |
55 | tarnow_temp <-
56 | weather_pl %>%
57 | filter(station == "TRN") %>%
58 | select(date, tmax_daily)
59 |
60 | # Params
61 | EPOCHS <- 3
62 | HORIZON <- 1
63 | TIMESTEPS <- 28
64 |
65 | # Splitting on training and test
66 | data_split <-
67 | time_series_split(
68 | tarnow_temp, date,
69 | initial = "18 years",
70 | assess = "2 years",
71 | lag = TIMESTEPS
72 | )
73 |
74 | # Training
75 | rnn_model <-
76 | rnn(
77 | timesteps = TIMESTEPS,
78 | horizon = HORIZON,
79 | epochs = EPOCHS,
80 | learn_rate = 0.01,
81 | hidden_units = 20,
82 | batch_size = 32,
83 | scale = TRUE
84 | ) %>%
85 | set_device('cpu') %>%
86 | fit(tmax_daily ~ date,
87 | data = training(data_split))
88 |
89 | prediction <-
90 | rnn_model %>%
91 | predict(new_data = testing(data_split))
92 |
93 | plot_forecast(
94 | data = testing(data_split),
95 | forecast = prediction,
96 | outcome = tmax_daily
97 | )
98 | ```
99 |
100 | ## Transforming data.frames to tensors
101 |
102 | In `as_tensor` function we can specify columns, that are used to
103 | create a tensor out of the input `data.frame`. Listed column names
104 | are only used to determine dimension sizes - they are removed after that
105 | and are not present in the final tensor.
106 |
107 | ```{r example}
108 | temperature_pl <-
109 | weather_pl %>%
110 | select(station, date, tmax_daily)
111 |
112 | # Expected shape
113 | c(
114 | n_distinct(temperature_pl$station),
115 | n_distinct(temperature_pl$date),
116 | 1
117 | )
118 |
119 | temperature_tensor <-
120 | temperature_pl %>%
121 | as_tensor(station, date)
122 |
123 | dim(temperature_tensor)
124 | temperature_tensor[1, 1:10]
125 |
126 | temperature_pl %>%
127 | filter(station == "SWK") %>%
128 | arrange(date) %>%
129 | head(10)
130 | ```
131 |
132 | ## Similar projects in Python
133 |
134 | * [PyTorch Forecasting](https://pytorch-forecasting.readthedocs.io/en/stable/)
135 | * [PyTorchTS](https://github.com/zalandoresearch/pytorch-ts)
136 | * [TorchTS](https://rose-stl-lab.github.io/torchTS/)
137 | * [GluonTS ](https://ts.gluon.ai/)
138 | * [sktime-dl](https://github.com/sktime/sktime-dl)
139 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # torchts
5 |
6 |
7 |
8 | [](https://CRAN.R-project.org/package=torchts)
10 | [](https://github.com/krzjoa/torchts/actions)
12 | [](https://codecov.io/gh/krzjoa/torchts?branch=master)
14 | [](https://www.redbubble.com/i/sticker/torchts-R-package-hex-sticker-by-krzjoa/93537989.EJUG5)
16 |
17 |
18 |
19 | > Time series models with torch
20 |
21 | [](https://www.buymeacoffee.com/kjoachimiak)
23 |
24 | ## Installation
25 |
You can install the released version of torchts from
[CRAN](https://CRAN.R-project.org) with `install.packages("torchts")`.

You can install the development version from [GitHub](https://github.com/) with:
30 |
31 | ``` r
32 | # install.packages("devtools")
33 | devtools::install_github("krzjoa/torchts")
34 | ```
35 |
36 | ## parsnip models
37 |
38 | ``` r
39 | library(torchts)
40 | library(torch)
41 | library(rsample)
42 | library(dplyr, warn.conflicts = FALSE)
43 | library(parsnip)
44 | library(timetk)
45 | library(ggplot2)
46 |
47 | tarnow_temp <-
48 | weather_pl %>%
49 | filter(station == "TRN") %>%
50 | select(date, tmax_daily)
51 |
52 | # Params
53 | EPOCHS <- 3
54 | HORIZON <- 1
55 | TIMESTEPS <- 28
56 |
57 | # Splitting on training and test
58 | data_split <-
59 | time_series_split(
60 | tarnow_temp, date,
61 | initial = "18 years",
62 | assess = "2 years",
63 | lag = TIMESTEPS
64 | )
65 |
66 | # Training
67 | rnn_model <-
68 | rnn(
69 | timesteps = TIMESTEPS,
70 | horizon = HORIZON,
71 | epochs = EPOCHS,
72 | learn_rate = 0.01,
73 | hidden_units = 20,
74 | batch_size = 32,
75 | scale = TRUE
76 | ) %>%
77 | set_device('cpu') %>%
78 | fit(tmax_daily ~ date,
79 | data = training(data_split))
80 | #> Warning: Engine set to `torchts`.
81 | #>
82 | #> Training started
83 | #> | train: 0.37756
84 | #> | train: 0.30164
85 | #> | train: 0.28896
86 |
87 | prediction <-
88 | rnn_model %>%
89 | predict(new_data = testing(data_split))
90 |
91 | plot_forecast(
92 | data = testing(data_split),
93 | forecast = prediction,
94 | outcome = tmax_daily
95 | )
96 | #> Warning: Removed 28 row(s) containing missing values (geom_path).
97 | ```
98 |
99 |
100 |
101 | ## Transforming data.frames to tensors
102 |
103 | In `as_tensor` function we can specify columns, that are used to create
104 | a tensor out of the input `data.frame`. Listed column names are only
105 | used to determine dimension sizes - they are removed after that and are
106 | not present in the final tensor.
107 |
108 | ``` r
109 | temperature_pl <-
110 | weather_pl %>%
111 | select(station, date, tmax_daily)
112 |
113 | # Expected shape
114 | c(
115 | n_distinct(temperature_pl$station),
116 | n_distinct(temperature_pl$date),
117 | 1
118 | )
119 | #> [1] 2 7305 1
120 |
121 | temperature_tensor <-
122 | temperature_pl %>%
123 | as_tensor(station, date)
124 |
125 | dim(temperature_tensor)
126 | #> [1] 2 7305 1
127 | temperature_tensor[1, 1:10]
128 | #> torch_tensor
129 | #> -0.2000
130 | #> -1.4000
131 | #> 0.4000
132 | #> 1.0000
133 | #> 0.6000
134 | #> 3.0000
135 | #> 4.0000
136 | #> 1.0000
137 | #> 1.2000
138 | #> 1.4000
139 | #> [ CPUFloatType{10,1} ]
140 |
141 | temperature_pl %>%
142 | filter(station == "SWK") %>%
143 | arrange(date) %>%
144 | head(10)
145 | #> station date tmax_daily
146 | #> 1140 SWK 2001-01-01 -0.2
147 | #> 1230 SWK 2001-01-02 -1.4
148 | #> 2330 SWK 2001-01-03 0.4
149 | #> 2630 SWK 2001-01-04 1.0
150 | #> 2730 SWK 2001-01-05 0.6
151 | #> 2830 SWK 2001-01-06 3.0
152 | #> 2930 SWK 2001-01-07 4.0
153 | #> 3030 SWK 2001-01-08 1.0
154 | #> 3130 SWK 2001-01-09 1.2
155 | #> 2140 SWK 2001-01-10 1.4
156 | ```
157 |
158 | ## Similar projects in Python
159 |
160 | - [PyTorch
161 | Forecasting](https://pytorch-forecasting.readthedocs.io/en/stable/)
162 | - [PyTorchTS](https://github.com/zalandoresearch/pytorch-ts)
163 | - [TorchTS](https://rose-stl-lab.github.io/torchTS/)
164 | - [GluonTS](https://ts.gluon.ai/)
165 |
--------------------------------------------------------------------------------
/_pkgdown.yml:
--------------------------------------------------------------------------------
1 | destination: docs
2 | template:
3 | params:
4 | bootswatch: united
5 | authors:
6 | Krzysztof Joachimiak:
7 | href: https://krzjoa.github.io
8 | reference:
9 | - title: torchts quick API
10 | contents:
11 | - torchts_rnn
12 | - torchts_mlp
13 | - title: parsnip API
14 | contents:
15 | - rnn
16 | - lagged_mlp
17 | - title: Modules
18 | contents:
19 | - model_rnn
20 | - model_mlp
21 | - nn_multi_embedding
22 | - nn_nonlinear
23 | - nn_mlp
24 | - title: Data transformations
25 | contents:
26 | - as_tensor
27 | - ts_dataset
28 | - as_ts_dataset
29 | - as_ts_dataloader
30 | - as.vector.torch_tensor
31 | - title: Metrics
32 | contents:
33 | - nnf_mae
34 | - nnf_mape
35 | - nnf_smape
36 | - title: Utils
37 | contents:
38 | - is_categorical
39 | - dict_size
40 | - embedding_size
41 | - clear_outcome
42 | - set_device
43 | - plot_forecast
44 | - title: Data
45 | contents:
46 | - weather_pl
47 | - tiny_m5
48 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | comment: false
2 |
3 | coverage:
4 | status:
5 | project:
6 | default:
7 | target: auto
8 | threshold: 1%
9 | informational: true
10 | patch:
11 | default:
12 | target: auto
13 | threshold: 1%
14 | informational: true
15 |
--------------------------------------------------------------------------------
/data-raw/debug-mlp.Rmd:
--------------------------------------------------------------------------------
1 | ---
2 | title: "debug"
3 | output: html_document
4 | ---
5 |
6 | ```{r setup, include=FALSE}
7 | knitr::opts_chunk$set(echo = TRUE)
8 | ```
9 |
10 | ## Prepare data
11 |
12 | ```{r libs}
13 | library(torch)
14 | #library(torchts)
15 | library(rsample)
16 | library(dplyr, warn.conflicts = FALSE)
17 | library(ggplot2)
18 | library(parsnip)
19 | library(timetk)
20 |
21 | # write.csv(weather_pl, file = "../weather_pl.csv")
22 |
23 | devtools::load_all()
24 | ```
25 |
26 | ## Data
27 | ```{r data}
28 | tarnow_temp <-
29 | weather_pl %>%
30 | filter(station == 'TRN') %>%
31 | arrange(date)
32 | head(tarnow_temp)
33 |
34 |
35 | train <- tarnow_temp %>%
36 | filter(date < as.Date('2018-01-01'))
37 |
38 | test <- tarnow_temp %>%
39 | filter(date >= as.Date('2018-01-01'))
40 | ```
41 |
42 | ```{r ts_dataset}
43 |
# Minimal sliding-window dataset for a univariate series:
# X = `lookback` steps of input, y = `horizon` steps of target,
# windows advanced by `jump`.
TimeSeriesDataset <- torch::dataset(
  "TimeSeriesDataSet",

  initialize = function(ts, lookback, horizon, jump, trim_last = TRUE){
    # TS
    self$ts <- ts
    self$lookback <- lookback
    self$horizon <- horizon
    self$jump <- jump

    # Non overlapping chunks
    # There is a bug here (author's note, translated from Polish: "Tu jest błąd")
    self$chunk_size <- (lookback + horizon)
    if (trim_last)
      self$length <- (length(ts) - self$chunk_size ) %/% jump
    else
      self$length <- (length(ts) - self$horizon) %/% jump
  },

  .length = function(){
    self$length
  },

  .getitem = function(idx){
    # Input
    first <- (idx - 1) * self$jump + 1
    last_input <- first + self$lookback - 1
    X <- self$ts[first:last_input]

    # Output
    # NOTE(review): y starts at last_input, so the first target value
    # overlaps the final input step - presumably should start at
    # (last_input + 1); confirm intended behavior
    y <- self$ts[last_input:(last_input + self$horizon - 1)]

    X_tensor <- torch_tensor(X, dtype = torch_float32())
    y_tensor <- torch_tensor(y, dtype = torch_float32())

    # Tensors are moved to GPU here - requires CUDA
    return(list(
      X_tensor$squeeze()$cuda(),
      y_tensor$squeeze()$cuda()
    ))
  }
)
85 | ```
86 |
87 |
88 | ```{r declare.vars}
89 | TIMESTEPS <- 28
90 | HORIZON <- 7
91 | ```
92 |
93 | ```{r scale.data}
94 | mean_val <- mean(train$tmax_daily)
95 | sd_val <- sd(train$tmax_daily)
96 |
97 | train_scaled <- (train$tmax_daily - mean_val) / sd_val
98 | test_scaled <- (test$tmax_daily - mean_val) / sd_val
99 | ```
100 |
101 | ```{r create.ds}
102 | train_ds <- TimeSeriesDataset(train_scaled, TIMESTEPS, HORIZON, 1)
103 | test_ds <- TimeSeriesDataset(test_scaled, TIMESTEPS, HORIZON, HORIZON, FALSE)
104 | ```
105 |
106 |
107 | ```{r cmp.data}
108 | test_ds[1][1]
109 | ```
110 |
111 | ```{r creat.net}
# Three-layer perceptron: input -> layers[1] -> layers[2] -> output,
# with ReLU activations between the linear layers
MLP <-
  nn_module(

    "MLP",

    initialize = function(input_size, output_size, layers){
      self$linear_1 <- nn_linear(input_size, layers[1])
      self$activation_1 <- nn_relu()
      self$linear_2 <- nn_linear(layers[1], layers[2])
      self$activation_2 <- nn_relu()
      self$linear_3 <- nn_linear(layers[2], output_size)
    },


    forward = function(X){
      X <- self$activation_1(self$linear_1(X))
      X <- self$activation_2(self$linear_2(X))
      # No activation on the output layer (regression head)
      self$linear_3(X)
    }
  )
132 | ```
133 |
134 | ```{r init.net}
135 | net <- MLP(TIMESTEPS, HORIZON, c(50, 30))
136 | net <- net$cuda()
137 | epochs <- 10
138 | ```
139 |
140 | ```{r optimizer}
141 | optimizer <- optim_adam(net$parameters)
142 | loss_fun <- nn_mse_loss()
143 |
144 | epochs <- 30
145 |
146 | #X, y = next(iter(train_dl))
147 | #X.shape
148 |
149 | ```
150 |
151 | ```{r creating.dls}
152 | train_dl <- dataloader(train_ds, batch_size = 32)
153 | test_dl <- dataloader(test_ds, batch_size = 1)
154 | ```
155 |
156 | ```{r training.loop}
157 | dataloader_next(
158 | dataloader_make_iter(train_dl)
159 | )[1]
160 |
161 |
162 | net$train()
163 |
164 | for (e in seq_len(epochs)) {
165 | train_loss <- 0.0
166 |
167 | coro::loop(for (b in train_dl) {
168 |
169 | X <- b[[1]]
170 | y <- b[[2]]
171 |
172 | optimizer$zero_grad()
173 | target <- net(X)
174 | loss <- loss_fun(target, y)
175 | # Calculate gradients
176 | loss$backward()
177 | # Update Weights
178 | optimizer$step()
179 | # Calculate Loss
180 | train_loss <- train_loss + loss$item()
181 | })
182 | print(glue::glue(
183 | 'Epoch {e} \t\t Training Loss: {train_loss / length(train_dl)}'
184 | ))
185 | }
186 |
187 | ```
188 | ```{r forecast}
# Forecast
# Runs the network over every full-width batch of `test_dl` and
# collects the flattened outputs, padded with `timesteps` NAs.
forecast <- function(net, test_dl, timesteps){
  targets <- rep(NA, timesteps)
  net$eval()
  coro::loop(for(b in test_dl){
    X <- b[[1]]
    # y <- b[[2]]
    # print(dim(X))
    # NOTE(review): compares against the global TIMESTEPS, not the
    # `timesteps` argument - verify which is intended
    if (dim(X)[2] == TIMESTEPS) {
      out <- net(X)$cpu()$flatten()$detach()
      out <- as.vector(out)
      # NOTE(review): c(out, targets) PREPENDS each batch, so later
      # batches end up before earlier ones and the NA padding sits at
      # the end - confirm this ordering is intended
      targets <- c(out, targets)
    }
  })
  targets
}
205 | ```
206 |
207 | ```{r fcast}
208 | fcast <- forecast(net, test_dl, TIMESTEPS)
209 | ```
210 |
211 | ```{r show.fcast}
212 | plot(ts(fcast))
213 | fcast
214 | ```
215 | # Diagnoza
216 | * przepisać i zdebugować ts_dataset
217 | * zdebugować i być może przepisać torchts_mlp
218 | * przyda się wizualizacja tych sieci
219 | * wejściem powinien być data.frame, a nie tensor -> dzięki temu można np. użyć disk.frame a może nawet połączenia do bazy!
220 | * skalować od razu na wejściu do datasetu -> obecny scenariusz niewiele daje(?)
221 |
222 |
223 |
224 |
225 |
--------------------------------------------------------------------------------
/data-raw/prepare-logo.R:
--------------------------------------------------------------------------------
1 | library(ggplot2)
2 | library(data.table)
3 |
4 | devtools::install_github("coolbutuseless/minisvg") # SVG creation
5 | devtools::install_github("coolbutuseless/devout") # Device interface
6 | devtools::install_github("coolbutuseless/devoutsvg") # This package
7 |
8 | shampoo <- read.csv("https://raw.githubusercontent.com/jbrownlee/Datasets/master/shampoo.csv")
9 | setDT(shampoo)
10 | shampoo[, n := 1:.N]
11 |
12 | plt <-
13 | ggplot(shampoo) +
14 | geom_line(aes(n, Sales, lwd = 0.1)) +
15 | theme_void() +
16 | theme(legend.position="none")
17 |
18 | devoutsvg::svgout(
19 | filename = here::here("shampoo.svg"),
20 | width = 8, height = 4,
21 | )
22 |
23 | plt
24 |
25 | invisible(dev.off())
26 |
27 | ggsave("torch.svg")
28 |
29 |
30 | plt
31 |
--------------------------------------------------------------------------------
/data-raw/prepare-weather-pl.R:
--------------------------------------------------------------------------------
1 | suppressMessages(library(dplyr))
2 |
3 | temp_data_set <-
4 | climate::meteo_imgw_daily(year = 2001:2020)
5 |
6 | # temp_data_set %>%
7 | # mutate(date = lubridate::make_date(yy, mm, day)) %>%
8 | # group_by(station) %>%
9 | # summarise(max_date = max(date),
10 | # min_date = min(date), n = n())
11 |
12 | weather_pl <-
13 | temp_data_set %>%
14 | filter(station %in% c("SUWAŁKI", "TARNÓW")) %>%
15 | mutate(date = lubridate::make_date(yy, mm, day)) %>%
16 | select(-rank, -id, -yy, -mm, -day) %>%
17 | select(station, date, starts_with("tm"), starts_with("rr"), starts_with("press"))
18 |
19 | weather_pl <-
20 | weather_pl %>%
21 | mutate(station = case_when(
22 | station == "TARNÓW" ~ "TRN",
23 | station == "SUWAŁKI" ~ "SWK"
24 | ))
25 |
26 | save(weather_pl, file = here::here("data/weather_pl.rda"))
27 |
--------------------------------------------------------------------------------
/data/tiny_m5.rda:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/data/tiny_m5.rda
--------------------------------------------------------------------------------
/data/weather_pl.rda:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/data/weather_pl.rda
--------------------------------------------------------------------------------
/docs/apple-touch-icon-120x120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/docs/apple-touch-icon-120x120.png
--------------------------------------------------------------------------------
/docs/apple-touch-icon-152x152.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/docs/apple-touch-icon-152x152.png
--------------------------------------------------------------------------------
/docs/apple-touch-icon-180x180.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/docs/apple-touch-icon-180x180.png
--------------------------------------------------------------------------------
/docs/apple-touch-icon-60x60.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/docs/apple-touch-icon-60x60.png
--------------------------------------------------------------------------------
/docs/apple-touch-icon-76x76.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/docs/apple-touch-icon-76x76.png
--------------------------------------------------------------------------------
/docs/apple-touch-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krzjoa/torchts/56d3b9177919296823ca8926360d8c25664d0fe5/docs/apple-touch-icon.png
--------------------------------------------------------------------------------
/docs/articles/data-prepare-rnn_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/articles/data_prepare_rnn_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/articles/missing-data_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/articles/missing_data_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/articles/multivariate-time-series_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/articles/naming-convention_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/articles/prepare-tensor_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/articles/univariate-time-series_files/accessible-code-block-0.0.1/empty-anchor.js:
--------------------------------------------------------------------------------
// Hide empty <a> tags within highlighted CodeBlocks for screen reader
// accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786)
// v0.0.1
// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.

document.addEventListener('DOMContentLoaded', function() {
  // Pandoc wraps highlighted code in elements carrying the "sourceCode"
  // class; anchors with no text inside them confuse screen readers, so
  // remove them from the accessibility tree.
  const codeBlocks = document.getElementsByClassName("sourceCode");
  for (const codeBlock of codeBlocks) {
    for (const anchor of codeBlock.getElementsByTagName('a')) {
      if (anchor.innerHTML === "") {
        anchor.setAttribute('aria-hidden', 'true');
      }
    }
  }
});
16 |
--------------------------------------------------------------------------------
/docs/bootstrap-toc.css:
--------------------------------------------------------------------------------
/*!
 * Bootstrap Table of Contents v0.4.1 (http://afeld.github.io/bootstrap-toc/)
 * Copyright 2015 Aidan Feldman
 * Licensed under MIT (https://github.com/afeld/bootstrap-toc/blob/gh-pages/LICENSE.md) */

/* modified from https://github.com/twbs/bootstrap/blob/94b4076dd2efba9af71f0b18d4ee4b163aa9e0dd/docs/assets/css/src/docs.css#L548-L601 */

/* Vendored stylesheet for the pkgdown sidebar table of contents.
   All rules scope to nav[data-toggle='toc']; do not reorder — the
   hover/active rules rely on cascade order to override the base link rule. */

/* All levels of nav */
nav[data-toggle='toc'] .nav > li > a {
  display: block;
  padding: 4px 20px;
  font-size: 13px;
  font-weight: 500;
  color: #767676;
}
/* Hover/focus: shave 1px off the left padding to make room for the
   1px accent border, so the text does not shift. */
nav[data-toggle='toc'] .nav > li > a:hover,
nav[data-toggle='toc'] .nav > li > a:focus {
  padding-left: 19px;
  color: #563d7c;
  text-decoration: none;
  background-color: transparent;
  border-left: 1px solid #563d7c;
}
/* Active item: bold, with a thicker (2px) accent border — again the
   padding is reduced by the border width to keep the text aligned. */
nav[data-toggle='toc'] .nav > .active > a,
nav[data-toggle='toc'] .nav > .active:hover > a,
nav[data-toggle='toc'] .nav > .active:focus > a {
  padding-left: 18px;
  font-weight: bold;
  color: #563d7c;
  background-color: transparent;
  border-left: 2px solid #563d7c;
}

/* Nav: second level (shown on .active) */
nav[data-toggle='toc'] .nav .nav {
  display: none; /* Hide by default, but at >768px, show it */
  padding-bottom: 10px;
}
/* Second-level links: smaller, lighter, indented further. */
nav[data-toggle='toc'] .nav .nav > li > a {
  padding-top: 1px;
  padding-bottom: 1px;
  padding-left: 30px;
  font-size: 12px;
  font-weight: normal;
}
nav[data-toggle='toc'] .nav .nav > li > a:hover,
nav[data-toggle='toc'] .nav .nav > li > a:focus {
  padding-left: 29px;
}
nav[data-toggle='toc'] .nav .nav > .active > a,
nav[data-toggle='toc'] .nav .nav > .active:hover > a,
nav[data-toggle='toc'] .nav .nav > .active:focus > a {
  padding-left: 28px;
  font-weight: 500;
}

/* from https://github.com/twbs/bootstrap/blob/e38f066d8c203c3e032da0ff23cd2d6098ee2dd6/docs/assets/css/src/docs.css#L631-L634 */
/* Reveal the nested list for the currently active section. */
nav[data-toggle='toc'] .nav > .active > ul {
  display: block;
}
61 |
--------------------------------------------------------------------------------
/docs/bootstrap-toc.js:
--------------------------------------------------------------------------------
1 | /*!
2 | * Bootstrap Table of Contents v0.4.1 (http://afeld.github.io/bootstrap-toc/)
3 | * Copyright 2015 Aidan Feldman
4 | * Licensed under MIT (https://github.com/afeld/bootstrap-toc/blob/gh-pages/LICENSE.md) */
5 | (function() {
6 | 'use strict';
7 |
8 | window.Toc = {
9 | helpers: {
10 | // return all matching elements in the set, or their descendants
11 | findOrFilter: function($el, selector) {
12 | // http://danielnouri.org/notes/2011/03/14/a-jquery-find-that-also-finds-the-root-element/
13 | // http://stackoverflow.com/a/12731439/358804
14 | var $descendants = $el.find(selector);
15 | return $el.filter(selector).add($descendants).filter(':not([data-toc-skip])');
16 | },
17 |
18 | generateUniqueIdBase: function(el) {
19 | var text = $(el).text();
20 | var anchor = text.trim().toLowerCase().replace(/[^A-Za-z0-9]+/g, '-');
21 | return anchor || el.tagName.toLowerCase();
22 | },
23 |
24 | generateUniqueId: function(el) {
25 | var anchorBase = this.generateUniqueIdBase(el);
26 | for (var i = 0; ; i++) {
27 | var anchor = anchorBase;
28 | if (i > 0) {
29 | // add suffix
30 | anchor += '-' + i;
31 | }
32 | // check if ID already exists
33 | if (!document.getElementById(anchor)) {
34 | return anchor;
35 | }
36 | }
37 | },
38 |
39 | generateAnchor: function(el) {
40 | if (el.id) {
41 | return el.id;
42 | } else {
43 | var anchor = this.generateUniqueId(el);
44 | el.id = anchor;
45 | return anchor;
46 | }
47 | },
48 |
49 | createNavList: function() {
50 | return $(' ');
51 | },
52 |
53 | createChildNavList: function($parent) {
54 | var $childList = this.createNavList();
55 | $parent.append($childList);
56 | return $childList;
57 | },
58 |
59 | generateNavEl: function(anchor, text) {
60 | var $a = $('');
61 | $a.attr('href', '#' + anchor);
62 | $a.text(text);
63 | var $li = $('