├── .gitattributes
├── .github
└── workflows
│ └── build.yml
├── .gitignore
├── CHANGELOG.md
├── Gemfile
├── LICENSE-MIT
├── README.md
├── Rakefile
├── UNLICENSE
├── ext
└── stl
│ ├── ext.cpp
│ ├── extconf.rb
│ └── stl.hpp
├── lib
├── stl-rb.rb
├── stl.rb
└── stl
│ └── version.rb
├── stl-rb.gemspec
└── test
├── stl_test.rb
└── test_helper.rb
/.gitattributes:
--------------------------------------------------------------------------------
1 | ext/stl/stl.hpp linguist-vendored
2 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: build
2 | on: [push, pull_request]
3 | jobs:
4 | build:
5 | strategy:
6 | fail-fast: false
7 | matrix:
8 | os: [ubuntu-latest, macos-latest, windows-latest]
9 | runs-on: ${{ matrix.os }}
10 | steps:
11 | - uses: actions/checkout@v4
12 | - uses: ruby/setup-ruby@v1
13 | with:
14 | ruby-version: 3.4
15 | bundler-cache: true
16 | - run: bundle exec rake compile
17 | - run: bundle exec rake test
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /.bundle/
2 | /.yardoc
3 | /_yardoc/
4 | /coverage/
5 | /doc/
6 | /pkg/
7 | /spec/reports/
8 | /tmp/
9 | *.lock
10 | *.bundle
11 | *.so
12 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | ## 0.3.0 (2024-10-22)
2 |
3 | - Dropped support for Ruby < 3.1
4 |
5 | ## 0.2.3 (2024-01-26)
6 |
7 | - Fixed bug with `inner_loops` and `outer_loops`
8 |
9 | ## 0.2.2 (2023-06-20)
10 |
11 | - Fixed bug when jump > 1
12 |
13 | ## 0.2.1 (2023-05-11)
14 |
15 | - Fixed error on Fedora
16 |
17 | ## 0.2.0 (2023-02-01)
18 |
19 | - Raise error when `period` is less than 2
20 | - Dropped support for Ruby < 2.7
21 |
22 | ## 0.1.3 (2021-12-16)
23 |
24 | - Fixed installation error on macOS 12
25 |
26 | ## 0.1.2 (2021-10-24)
27 |
28 | - Added `seasonal_strength` and `trend_strength` methods
29 | - Improved plot width and height
30 |
31 | ## 0.1.1 (2021-10-20)
32 |
33 | - Added `plot` method
34 |
35 | ## 0.1.0 (2021-10-16)
36 |
37 | - First release
38 |
--------------------------------------------------------------------------------
/Gemfile:
--------------------------------------------------------------------------------
1 | source "https://rubygems.org"
2 |
3 | gemspec # runtime dependencies come from the gemspec
4 |
5 | gem "rake"
6 | gem "rake-compiler" # presumably provides rake/extensiontask required by the Rakefile — confirm
7 | gem "minitest", ">= 5"
8 | gem "vega" # the README lists Vega as the gem needed for Stl.plot
9 |
--------------------------------------------------------------------------------
/LICENSE-MIT:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2021-2024 Contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # STL Ruby
2 |
3 | Seasonal-trend decomposition for Ruby
4 |
5 | [![Build Status](https://github.com/ankane/stl-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/stl-ruby/actions)
6 |
7 | ## Installation
8 |
9 | Add this line to your application’s Gemfile:
10 |
11 | ```ruby
12 | gem "stl-rb"
13 | ```
14 |
15 | ## Getting Started
16 |
17 | Decompose a time series
18 |
19 | ```ruby
20 | series = {
21 | Date.parse("2025-01-01") => 100,
22 | Date.parse("2025-01-02") => 150,
23 | Date.parse("2025-01-03") => 136,
24 | # ...
25 | }
26 |
27 | Stl.decompose(series, period: 7)
28 | ```
29 |
30 | Works great with [Groupdate](https://github.com/ankane/groupdate)
31 |
32 | ```ruby
33 | series = User.group_by_day(:created_at).count
34 | Stl.decompose(series, period: 7)
35 | ```
36 |
37 | Series can also be an array without times
38 |
39 | ```ruby
40 | series = [100, 150, 136, ...]
41 | Stl.decompose(series, period: 7)
42 | ```
43 |
44 | Use robustness iterations
45 |
46 | ```ruby
47 | Stl.decompose(series, period: 7, robust: true)
48 | ```
49 |
50 | ## Options
51 |
52 | Pass options
53 |
54 | ```ruby
55 | Stl.decompose(
56 | series,
57 | period: 7, # period of the seasonal component
58 | seasonal_length: 7, # length of the seasonal smoother
59 | trend_length: 15, # length of the trend smoother
60 | low_pass_length: 7, # length of the low-pass filter
61 | seasonal_degree: 0, # degree of locally-fitted polynomial in seasonal smoothing
62 | trend_degree: 1, # degree of locally-fitted polynomial in trend smoothing
63 | low_pass_degree: 1, # degree of locally-fitted polynomial in low-pass smoothing
64 | seasonal_jump: 1, # skipping value for seasonal smoothing
65 | trend_jump: 2, # skipping value for trend smoothing
66 | low_pass_jump: 1, # skipping value for low-pass smoothing
67 | inner_loops: 2, # number of loops for updating the seasonal and trend components
68 | outer_loops: 0, # number of iterations of robust fitting
69 | robust: false # if robustness iterations are to be used
70 | )
71 | ```
72 |
73 | ## Plotting
74 |
75 | Add [Vega](https://github.com/ankane/vega) to your application’s Gemfile:
76 |
77 | ```ruby
78 | gem "vega"
79 | ```
80 |
81 | And use:
82 |
83 | ```ruby
84 | Stl.plot(series, decompose_result)
85 | ```
86 |
87 | ## Strength
88 |
89 | Get the seasonal strength
90 |
91 | ```ruby
92 | Stl.seasonal_strength(decompose_result)
93 | ```
94 |
95 | Get the trend strength
96 |
97 | ```ruby
98 | Stl.trend_strength(decompose_result)
99 | ```
100 |
101 | ## Credits
102 |
103 | This library was ported from the [Fortran implementation](https://www.netlib.org/a/stl).
104 |
105 | ## References
106 |
107 | - [STL: A Seasonal-Trend Decomposition Procedure Based on Loess](https://www.scb.se/contentassets/ca21efb41fee47d293bbee5bf7be7fb3/stl-a-seasonal-trend-decomposition-procedure-based-on-loess.pdf)
108 | - [Measuring strength of trend and seasonality](https://otexts.com/fpp2/seasonal-strength.html)
109 |
110 | ## History
111 |
112 | View the [changelog](https://github.com/ankane/stl-ruby/blob/master/CHANGELOG.md)
113 |
114 | ## Contributing
115 |
116 | Everyone is encouraged to help improve this project. Here are a few ways you can help:
117 |
118 | - [Report bugs](https://github.com/ankane/stl-ruby/issues)
119 | - Fix bugs and [submit pull requests](https://github.com/ankane/stl-ruby/pulls)
120 | - Write, clarify, or fix documentation
121 | - Suggest or add new features
122 |
123 | To get started with development:
124 |
125 | ```sh
126 | git clone https://github.com/ankane/stl-ruby.git
127 | cd stl-ruby
128 | bundle install
129 | bundle exec rake compile
130 | bundle exec rake test
131 | ```
132 |
--------------------------------------------------------------------------------
/Rakefile:
--------------------------------------------------------------------------------
1 | require "bundler/gem_tasks"
2 | require "rake/testtask"
3 | require "rake/extensiontask"
4 |
5 | task default: :test # running `rake` with no arguments runs the tests
6 | Rake::TestTask.new do |t|
7 | t.libs << "test"
8 | t.pattern = "test/**/*_test.rb" # every *_test.rb file under test/
9 | end
10 |
11 | Rake::ExtensionTask.new("stl") do |ext| # native extension sources live in ext/stl
12 | ext.name = "ext" # the compiled library is named "ext"
13 | ext.lib_dir = "lib/stl" # and is placed in lib/stl
14 | end
15 |
16 | task :remove_ext do # deletes a previously compiled lib/stl/ext.bundle, if one exists
17 | path = "lib/stl/ext.bundle" # NOTE(review): only the .bundle file is handled; .gitignore also lists *.so — confirm this is intended
18 | File.unlink(path) if File.exist?(path)
19 | end
20 |
21 | Rake::Task["build"].enhance [:remove_ext] # make remove_ext a prerequisite of the build task
22 |
--------------------------------------------------------------------------------
/UNLICENSE:
--------------------------------------------------------------------------------
1 | This is free and unencumbered software released into the public domain.
2 |
3 | Anyone is free to copy, modify, publish, use, compile, sell, or
4 | distribute this software, either in source code form or as a compiled
5 | binary, for any purpose, commercial or non-commercial, and by any
6 | means.
7 |
8 | In jurisdictions that recognize copyright laws, the author or authors
9 | of this software dedicate any and all copyright interest in the
10 | software to the public domain. We make this dedication for the benefit
11 | of the public at large and to the detriment of our heirs and
12 | successors. We intend this dedication to be an overt act of
13 | relinquishment in perpetuity of all present and future rights to this
14 | software under copyright law.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
24 | For more information, please refer to <https://unlicense.org>
25 |
--------------------------------------------------------------------------------
/ext/stl/ext.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "stl.hpp"
5 |
6 | Rice::Array to_a(std::vector& x) { // copies every element of x into a new Ruby array and returns it
7 | auto a = Rice::Array(); // NOTE(review): the vector's element type is not visible in this listing (template arguments look stripped)
8 | for (auto v : x) {
9 | a.push(v);
10 | }
11 | return a;
12 | }
13 |
14 | extern "C"
15 | void Init_ext() { // defines the Ruby module Stl and the class Stl::StlParams
16 | auto rb_mStl = Rice::define_module("Stl");
17 |
18 | Rice::define_class_under(rb_mStl, "StlParams") // NOTE(review): template arguments appear stripped here and on the next line — confirm against the real file
19 | .define_constructor(Rice::Constructor())
20 | .define_method("seasonal_length", &stl::StlParams::seasonal_length) // each setter is bound to the C++ member function of the same name
21 | .define_method("trend_length", &stl::StlParams::trend_length)
22 | .define_method("low_pass_length", &stl::StlParams::low_pass_length)
23 | .define_method("seasonal_degree", &stl::StlParams::seasonal_degree)
24 | .define_method("trend_degree", &stl::StlParams::trend_degree)
25 | .define_method("low_pass_degree", &stl::StlParams::low_pass_degree)
26 | .define_method("seasonal_jump", &stl::StlParams::seasonal_jump)
27 | .define_method("trend_jump", &stl::StlParams::trend_jump)
28 | .define_method("low_pass_jump", &stl::StlParams::low_pass_jump)
29 | .define_method("inner_loops", &stl::StlParams::inner_loops)
30 | .define_method("outer_loops", &stl::StlParams::outer_loops)
31 | .define_method("robust", &stl::StlParams::robust)
32 | .define_method(
33 | "fit",
34 | [](stl::StlParams& self, std::vector series, size_t period, bool weights) { // weights: whether the result hash should include :weights
35 | auto result = self.fit(series, period);
36 |
37 | auto ret = Rice::Hash(); // result is a hash with symbol keys, each holding an array
38 | ret[Rice::Symbol("seasonal")] = to_a(result.seasonal);
39 | ret[Rice::Symbol("trend")] = to_a(result.trend);
40 | ret[Rice::Symbol("remainder")] = to_a(result.remainder);
41 | if (weights) { // :weights is only added on request
42 | ret[Rice::Symbol("weights")] = to_a(result.weights);
43 | }
44 | return ret;
45 | });
46 | }
47 |
--------------------------------------------------------------------------------
/ext/stl/extconf.rb:
--------------------------------------------------------------------------------
1 | require "mkmf-rice"
2 |
3 | $CXXFLAGS += " -std=c++17 $(optflags)" # compile as C++17 and append $(optflags) to the C++ flags
4 |
5 | create_makefile("stl/ext") # generate the Makefile for the stl/ext target
6 |
--------------------------------------------------------------------------------
/ext/stl/stl.hpp:
--------------------------------------------------------------------------------
1 | /*!
2 | * STL C++ v0.2.0
3 | * https://github.com/ankane/stl-cpp
4 | * Unlicense OR MIT License
5 | *
6 | * Ported from https://www.netlib.org/a/stl
7 | *
8 | * Cleveland, R. B., Cleveland, W. S., McRae, J. E., & Terpenning, I. (1990).
9 | * STL: A Seasonal-Trend Decomposition Procedure Based on Loess.
10 | * Journal of Official Statistics, 6(1), 3-33.
11 | *
12 | * Bandara, K., Hyndman, R. J., & Bergmeir, C. (2021).
13 | * MSTL: A Seasonal-Trend Decomposition Algorithm for Time Series with Multiple Seasonal Patterns.
14 | * arXiv:2107.13462 [stat.AP]. https://doi.org/10.48550/arXiv.2107.13462
15 | */
16 |
17 | #pragma once
18 |
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 | #include
25 | #include
26 |
27 | #if __cplusplus >= 202002L
28 | #include
29 | #endif
30 |
31 | namespace stl {
32 |
33 | namespace {
34 |
35 | template // NOTE(review): the template parameter list appears stripped from this listing (T is used below)
36 | bool est(const T* y, size_t n, size_t len, int ideg, T xs, T* ys, size_t nleft, size_t nright, T* w, bool userw, const T* rw) { // weighted local fit at position xs over points nleft..nright (1-based); writes the fitted value to *ys and returns false when the weights sum to <= 0
37 | auto range = ((T) n) - 1.0;
38 | auto h = std::max(xs - ((T) nleft), ((T) nright) - xs); // distance from xs to the farther edge of the window
39 |
40 | if (len > n) { // window longer than the series: widen h by half the excess
41 | h += (T) ((len - n) / 2);
42 | }
43 |
44 | auto h9 = 0.999 * h; // points farther than h9 keep weight 0
45 | auto h1 = 0.001 * h; // points closer than h1 get weight 1
46 |
47 | // compute weights
48 | auto a = 0.0; // running sum of the weights
49 | for (auto j = nleft; j <= nright; j++) {
50 | w[j - 1] = 0.0;
51 | auto r = std::abs(((T) j) - xs);
52 | if (r <= h9) {
53 | if (r <= h1) {
54 | w[j - 1] = 1.0;
55 | } else {
56 | w[j - 1] = (T) std::pow(1.0 - std::pow(r / h, 3), 3); // (1 - (r/h)^3)^3
57 | }
58 | if (userw) { // scale by the caller-supplied weights
59 | w[j - 1] *= rw[j - 1];
60 | }
61 | a += w[j - 1];
62 | }
63 | }
64 |
65 | if (a <= 0.0) { // no usable weight: report failure and leave *ys untouched
66 | return false;
67 | } else { // weighted least squares
68 | for (auto j = nleft; j <= nright; j++) { // make sum of w(j) == 1
69 | w[j - 1] /= (T) a;
70 | }
71 |
72 | if (h > 0.0 && ideg > 0) { // use linear fit
73 | auto a = 0.0; // shadows the outer a; here it is the weighted mean of the x positions
74 | for (auto j = nleft; j <= nright; j++) { // weighted center of x values
75 | a += w[j - 1] * ((T) j);
76 | }
77 | auto b = xs - a;
78 | auto c = 0.0; // weighted sum of squared deviations of x from a
79 | for (auto j = nleft; j <= nright; j++) {
80 | c += w[j - 1] * std::pow(((T) j) - a, 2);
81 | }
82 | if (std::sqrt(c) > 0.001 * range) {
83 | b /= c;
84 |
85 | // points are spread out enough to compute slope
86 | for (auto j = nleft; j <= nright; j++) {
87 | w[j - 1] *= (T) (b * (((T) j) - a) + 1.0);
88 | }
89 | }
90 | }
91 |
92 | *ys = 0.0; // fitted value is the weighted sum of y over the window
93 | for (auto j = nleft; j <= nright; j++) {
94 | *ys += w[j - 1] * y[j - 1];
95 | }
96 |
97 | return true;
98 | }
99 | }
100 |
101 | template
102 | void ess(const T* y, size_t n, size_t len, int ideg, size_t njump, bool userw, const T* rw, T* ys, T* res) { // smooths y[0..n) into ys: est is evaluated at every newnj-th position and the skipped positions are interpolated linearly; res is the weight scratch buffer handed to est
103 | if (n < 2) { // fewer than two points: copy the single value through
104 | ys[0] = y[0];
105 | return;
106 | }
107 |
108 | size_t nleft = 0;
109 | size_t nright = 0;
110 |
111 | auto newnj = std::min(njump, n - 1); // jump is capped at n - 1
112 | if (len >= n) { // window covers the whole series
113 | nleft = 1;
114 | nright = n;
115 | for (size_t i = 1; i <= n; i += newnj) {
116 | auto ok = est(y, n, len, ideg, (T) i, &ys[i - 1], nleft, nright, res, userw, rw);
117 | if (!ok) { // fit failed: fall back to the raw value
118 | ys[i - 1] = y[i - 1];
119 | }
120 | }
121 | } else if (newnj == 1) { // newnj equal to one, len less than n
122 | auto nsh = (len + 1) / 2;
123 | nleft = 1;
124 | nright = len;
125 | for (size_t i = 1; i <= n; i++) { // fitted value at i
126 | if (i > nsh && nright != n) { // slide the window right until it reaches the end of the series
127 | nleft += 1;
128 | nright += 1;
129 | }
130 | auto ok = est(y, n, len, ideg, (T) i, &ys[i - 1], nleft, nright, res, userw, rw);
131 | if (!ok) {
132 | ys[i - 1] = y[i - 1];
133 | }
134 | }
135 | } else { // newnj greater than one, len less than n
136 | auto nsh = (len + 1) / 2;
137 | for (size_t i = 1; i <= n; i += newnj) { // fitted value at i
138 | if (i < nsh) { // near the start: window pinned to the first len points
139 | nleft = 1;
140 | nright = len;
141 | } else if (i >= n - nsh + 1) { // near the end: window pinned to the last len points
142 | nleft = n - len + 1;
143 | nright = n;
144 | } else { // interior: window of len points positioned around i
145 | nleft = i - nsh + 1;
146 | nright = len + i - nsh;
147 | }
148 | auto ok = est(y, n, len, ideg, (T) i, &ys[i - 1], nleft, nright, res, userw, rw);
149 | if (!ok) {
150 | ys[i - 1] = y[i - 1];
151 | }
152 | }
153 | }
154 |
155 | if (newnj != 1) { // positions were skipped: fill them by linear interpolation
156 | for (size_t i = 1; i <= n - newnj; i += newnj) {
157 | auto delta = (ys[i + newnj - 1] - ys[i - 1]) / ((T) newnj);
158 | for (auto j = i + 1; j <= i + newnj - 1; j++) {
159 | ys[j - 1] = ys[i - 1] + delta * ((T) (j - i));
160 | }
161 | }
162 | auto k = ((n - 1) / newnj) * newnj + 1; // last position that was fitted directly
163 | if (k != n) { // the final point was skipped: fit it, then interpolate between k and n
164 | auto ok = est(y, n, len, ideg, (T) n, &ys[n - 1], nleft, nright, res, userw, rw);
165 | if (!ok) {
166 | ys[n - 1] = y[n - 1];
167 | }
168 | if (k != n - 1) {
169 | auto delta = (ys[n - 1] - ys[k - 1]) / ((T) (n - k));
170 | for (auto j = k + 1; j <= n - 1; j++) {
171 | ys[j - 1] = ys[k - 1] + delta * ((T) (j - k));
172 | }
173 | }
174 | }
175 | }
176 | }
177 |
178 | template
179 | void ma(const T* x, size_t n, size_t len, T* ave) { // moving average of window len over x[0..n); writes n - len + 1 values to ave
180 | auto newn = n - len + 1; // number of averages produced
181 | double flen = (T) len;
182 | double v = 0.0; // running window sum, kept in double
183 |
184 | // get the first average
185 | for (size_t i = 0; i < len; i++) {
186 | v += x[i];
187 | }
188 |
189 | ave[0] = (T) (v / flen);
190 | if (newn > 1) {
191 | size_t k = len; // index of the element entering the window
192 | size_t m = 0; // index of the element leaving the window
193 | for (size_t j = 1; j < newn; j++) {
194 | // window down the array
195 | v = v - x[m] + x[k];
196 | ave[j] = (T) (v / flen);
197 | k += 1;
198 | m += 1;
199 | }
200 | }
201 | }
202 |
203 | template
204 | void fts(const T* x, size_t n, size_t np, T* trend, T* work) { // applies three moving averages in sequence (windows np, np, 3); the final result is left in trend and work is scratch
205 | ma(x, n, np, trend); // x -> trend, window np
206 | ma(trend, n - np + 1, np, work); // trend -> work, window np
207 | ma(work, n - 2 * np + 2, 3, trend); // work -> trend, window 3
208 | }
209 |
210 | template
211 | void rwts(const T* y, size_t n, const T* fit, T* rw) { // fills rw[0..n) with robustness weights derived from the residuals |y[i] - fit[i]|
212 | for (size_t i = 0; i < n; i++) { // rw temporarily holds the absolute residuals
213 | rw[i] = std::abs(y[i] - fit[i]);
214 | }
215 |
216 | auto mid1 = (n - 1) / 2; // the two middle indices (equal when n is odd)
217 | auto mid2 = n / 2;
218 |
219 | // sort
220 | std::sort(rw, rw + n);
221 |
222 | auto cmad = 3.0 * (rw[mid1] + rw[mid2]); // 6 * median abs resid
223 | auto c9 = 0.999 * cmad; // residuals above c9 get weight 0
224 | auto c1 = 0.001 * cmad; // residuals up to c1 get weight 1
225 |
226 | for (size_t i = 0; i < n; i++) {
227 | auto r = std::abs(y[i] - fit[i]); // recomputed because rw was reordered by the sort
228 | if (r <= c1) {
229 | rw[i] = 1.0;
230 | } else if (r <= c9) {
231 | rw[i] = (T) std::pow(1.0 - std::pow(r / cmad, 2), 2); // (1 - (r/cmad)^2)^2
232 | } else {
233 | rw[i] = 0.0;
234 | }
235 | }
236 | }
237 |
238 | template
239 | void ss(const T* y, size_t n, size_t np, size_t ns, int isdeg, size_t nsjump, bool userw, T* rw, T* season, T* work1, T* work2, T* work3, T* work4) { // smooths each of the np subseries of y (every np-th value) and writes it, extended by one point at each end, into season
240 | for (size_t j = 1; j <= np; j++) {
241 | size_t k = (n - j) / np + 1; // number of points in subseries j
242 |
243 | for (size_t i = 1; i <= k; i++) { // gather subseries j into work1
244 | work1[i - 1] = y[(i - 1) * np + j - 1];
245 | }
246 | if (userw) { // gather the matching weights into work3
247 | for (size_t i = 1; i <= k; i++) {
248 | work3[i - 1] = rw[(i - 1) * np + j - 1];
249 | }
250 | }
251 | ess(work1, k, ns, isdeg, nsjump, userw, work3, work2 + 1, work4); // smoothed values go to work2[1..k]; work2[0] and work2[k + 1] are filled below
252 | T xs = 0.0; // fit one position before the start of the subseries
253 | auto nright = std::min(ns, k);
254 | auto ok = est(work1, k, ns, isdeg, xs, &work2[0], 1, nright, work4, userw, work3);
255 | if (!ok) { // fit failed: reuse the first smoothed value
256 | work2[0] = work2[1];
257 | }
258 | xs = k + 1; // fit one position past the end of the subseries
259 | size_t nleft = (size_t) std::max(1, (int) k - (int) ns + 1);
260 | ok = est(work1, k, ns, isdeg, xs, &work2[k + 1], nleft, k, work4, userw, work3);
261 | if (!ok) { // fit failed: reuse the last smoothed value
262 | work2[k + 1] = work2[k];
263 | }
264 | for (size_t m = 1; m <= k + 2; m++) { // scatter the k + 2 values back with stride np
265 | season[(m - 1) * np + j - 1] = work2[m - 1];
266 | }
267 | }
268 | }
269 |
270 | template
271 | void onestp(const T* y, size_t n, size_t np, size_t ns, size_t nt, size_t nl, int isdeg, int itdeg, int ildeg, size_t nsjump, size_t ntjump, size_t nljump, size_t ni, bool userw, T* rw, T* season, T* trend, T* work1, T* work2, T* work3, T* work4, T* work5) { // runs ni passes that update season and trend in place; work1..work5 are scratch buffers
272 | for (size_t j = 0; j < ni; j++) {
273 | for (size_t i = 0; i < n; i++) { // detrend
274 | work1[i] = y[i] - trend[i];
275 | }
276 |
277 | ss(work1, n, np, ns, isdeg, nsjump, userw, rw, work2, work3, work4, work5, season); // smoothed subseries land in work2; the season buffer is passed as ss's last scratch argument
278 | fts(work2, n + 2 * np, np, work3, work1); // moving-average filter of work2, result in work3
279 | ess(work3, n, nl, ildeg, nljump, false, work4, work1, work5); // smooth work3 with the low-pass settings (no user weights), result in work1
280 | for (size_t i = 0; i < n; i++) { // seasonal = smoothed subseries (offset by np) minus the low-pass result
281 | season[i] = work2[np + i] - work1[i];
282 | }
283 | for (size_t i = 0; i < n; i++) { // deseasonalize
284 | work1[i] = y[i] - season[i];
285 | }
286 | ess(work1, n, nt, itdeg, ntjump, userw, rw, trend, work3); // smooth the deseasonalized series into trend, weighted by rw when userw is set
287 | }
288 | }
289 |
290 | template
291 | void stl(const T* y, size_t n, size_t np, size_t ns, size_t nt, size_t nl, int isdeg, int itdeg, int ildeg, size_t nsjump, size_t ntjump, size_t nljump, size_t ni, size_t no, T* rw, T* season, T* trend) { // validates the parameters (throws std::invalid_argument), then alternates onestp with robustness-weight updates; results are written to rw, season and trend
292 | if (ns < 3) {
293 | throw std::invalid_argument("seasonal_length must be at least 3");
294 | }
295 | if (nt < 3) {
296 | throw std::invalid_argument("trend_length must be at least 3");
297 | }
298 | if (nl < 3) {
299 | throw std::invalid_argument("low_pass_length must be at least 3");
300 | }
301 | if (np < 2) {
302 | throw std::invalid_argument("period must be at least 2");
303 | }
304 |
305 | if (isdeg != 0 && isdeg != 1) {
306 | throw std::invalid_argument("seasonal_degree must be 0 or 1");
307 | }
308 | if (itdeg != 0 && itdeg != 1) {
309 | throw std::invalid_argument("trend_degree must be 0 or 1");
310 | }
311 | if (ildeg != 0 && ildeg != 1) {
312 | throw std::invalid_argument("low_pass_degree must be 0 or 1");
313 | }
314 |
315 | if (ns % 2 != 1) {
316 | throw std::invalid_argument("seasonal_length must be odd");
317 | }
318 | if (nt % 2 != 1) {
319 | throw std::invalid_argument("trend_length must be odd");
320 | }
321 | if (nl % 2 != 1) {
322 | throw std::invalid_argument("low_pass_length must be odd");
323 | }
324 |
325 | auto work1 = std::vector(n + 2 * np); // five scratch buffers, each n + 2 * np long
326 | auto work2 = std::vector(n + 2 * np);
327 | auto work3 = std::vector(n + 2 * np);
328 | auto work4 = std::vector(n + 2 * np);
329 | auto work5 = std::vector(n + 2 * np);
330 |
331 | auto userw = false; // the first pass runs without robustness weights
332 | size_t k = 0; // number of onestp calls completed
333 |
334 | while (true) {
335 | onestp(y, n, np, ns, nt, nl, isdeg, itdeg, ildeg, nsjump, ntjump, nljump, ni, userw, rw, season, trend, work1.data(), work2.data(), work3.data(), work4.data(), work5.data());
336 | k += 1;
337 | if (k > no) { // onestp has now run no + 1 times: done
338 | break;
339 | }
340 | for (size_t i = 0; i < n; i++) { // current fit = trend + seasonal
341 | work1[i] = trend[i] + season[i];
342 | }
343 | rwts(y, n, work1.data(), rw); // recompute robustness weights from the residuals
344 | userw = true;
345 | }
346 |
347 | if (no <= 0) { // no robustness iterations: report every weight as 1
348 | for (size_t i = 0; i < n; i++) {
349 | rw[i] = 1.0;
350 | }
351 | }
352 | }
353 |
354 | template
355 | double var(const std::vector& series) { // variance of series, dividing by size - 1
356 | auto mean = std::accumulate(series.begin(), series.end(), 0.0) / series.size();
357 | double sum = 0.0; // sum of squared deviations from the mean
358 | for (auto v : series) {
359 | double diff = v - mean;
360 | sum += diff * diff;
361 | }
362 | return sum / (series.size() - 1); // NOTE(review): the divisor is zero for a one-element series
363 | }
364 |
365 | template
366 | double strength(const std::vector& component, const std::vector& remainder) { // returns max(0, 1 - var(remainder) / var(component + remainder))
367 | std::vector sr; // element-wise sum of component and remainder
368 | sr.reserve(remainder.size());
369 | for (size_t i = 0; i < remainder.size(); i++) {
370 | sr.push_back(component[i] + remainder[i]);
371 | }
372 | return std::max(0.0, 1.0 - var(remainder) / var(sr));
373 | }
374 |
375 | }
376 |
377 | /// A STL result.
378 | template
379 | class StlResult { // holds the component vectors as public members
380 | public:
381 | /// Returns the seasonal component.
382 | std::vector seasonal;
383 |
384 | /// Returns the trend component.
385 | std::vector trend;
386 |
387 | /// Returns the remainder.
388 | std::vector remainder;
389 |
390 | /// Returns the weights.
391 | std::vector weights;
392 |
393 | /// Returns the seasonal strength.
394 | inline double seasonal_strength() const { // strength() of the seasonal component against the remainder
395 | return strength(seasonal, remainder);
396 | }
397 |
398 | /// Returns the trend strength.
399 | inline double trend_strength() const { // strength() of the trend component against the remainder
400 | return strength(trend, remainder);
401 | }
402 | };
403 |
404 | /// A set of STL parameters.
405 | class StlParams { // unset optional members are given defaults inside fit()
406 | public:
407 | /// @private
408 | std::optional ns_ = std::nullopt; // seasonal smoother length; declared public, and read directly by mstl() later in this file
409 | private:
410 | std::optional nt_ = std::nullopt; // trend smoother length
411 | std::optional nl_ = std::nullopt; // low-pass filter length
412 | int isdeg_ = 0; // seasonal degree
413 | int itdeg_ = 1; // trend degree
414 | std::optional ildeg_ = std::nullopt; // low-pass degree
415 | std::optional nsjump_ = std::nullopt; // seasonal jump
416 | std::optional ntjump_ = std::nullopt; // trend jump
417 | std::optional nljump_ = std::nullopt; // low-pass jump
418 | std::optional ni_ = std::nullopt; // inner loops
419 | std::optional no_ = std::nullopt; // outer loops
420 | bool robust_ = false;
421 |
422 | public:
423 | /// Sets the length of the seasonal smoother.
424 | inline StlParams seasonal_length(size_t length) { // like every setter here: mutates this object, then returns a copy of it by value
425 | this->ns_ = length;
426 | return *this;
427 | }
428 |
429 | /// Sets the length of the trend smoother.
430 | inline StlParams trend_length(size_t length) {
431 | this->nt_ = length;
432 | return *this;
433 | }
434 |
435 | /// Sets the length of the low-pass filter.
436 | inline StlParams low_pass_length(size_t length) {
437 | this->nl_ = length;
438 | return *this;
439 | }
440 |
441 | /// Sets the degree of locally-fitted polynomial in seasonal smoothing.
442 | inline StlParams seasonal_degree(int degree) {
443 | this->isdeg_ = degree;
444 | return *this;
445 | }
446 |
447 | /// Sets the degree of locally-fitted polynomial in trend smoothing.
448 | inline StlParams trend_degree(int degree) {
449 | this->itdeg_ = degree;
450 | return *this;
451 | }
452 |
453 | /// Sets the degree of locally-fitted polynomial in low-pass smoothing.
454 | inline StlParams low_pass_degree(int degree) {
455 | this->ildeg_ = degree;
456 | return *this;
457 | }
458 |
459 | /// Sets the skipping value for seasonal smoothing.
460 | inline StlParams seasonal_jump(size_t jump) {
461 | this->nsjump_ = jump;
462 | return *this;
463 | }
464 |
465 | /// Sets the skipping value for trend smoothing.
466 | inline StlParams trend_jump(size_t jump) {
467 | this->ntjump_ = jump;
468 | return *this;
469 | }
470 |
471 | /// Sets the skipping value for low-pass smoothing.
472 | inline StlParams low_pass_jump(size_t jump) {
473 | this->nljump_ = jump;
474 | return *this;
475 | }
476 |
477 | /// Sets the number of loops for updating the seasonal and trend components.
478 | inline StlParams inner_loops(size_t loops) {
479 | this->ni_ = loops;
480 | return *this;
481 | }
482 |
483 | /// Sets the number of iterations of robust fitting.
484 | inline StlParams outer_loops(size_t loops) {
485 | this->no_ = loops;
486 | return *this;
487 | }
488 |
489 | /// Sets whether robustness iterations are to be used.
490 | inline StlParams robust(bool robust) {
491 | this->robust_ = robust;
492 | return *this;
493 | }
494 |
495 | /// Decomposes a time series from an array.
496 | template
497 | StlResult fit(const T* series, size_t series_size, size_t period) const; // defined after the class
498 |
499 | /// Decomposes a time series from a vector.
500 | template
501 | StlResult fit(const std::vector& series, size_t period) const;
502 |
503 | #if __cplusplus >= 202002L
504 | /// Decomposes a time series from a span.
505 | template
506 | StlResult fit(std::span series, size_t period) const; // only declared when compiling as C++20 or later
507 | #endif
508 | };
509 |
510 | /// Creates a new set of STL parameters.
511 | inline StlParams params() { // returns a default-constructed StlParams
512 | return StlParams();
513 | }
514 |
515 | template
516 | StlResult StlParams::fit(const T* series, size_t series_size, size_t period) const { // fills in defaults for unset parameters, runs stl(), and returns seasonal, trend, remainder and weights
517 | auto y = series;
518 | auto np = period;
519 | auto n = series_size;
520 |
521 | if (n < 2 * np) { // at least two full periods of data are required
522 | throw std::invalid_argument("series has less than two periods");
523 | }
524 |
525 | auto ns = this->ns_.value_or(np); // seasonal length defaults to the period
526 |
527 | auto isdeg = this->isdeg_;
528 | auto itdeg = this->itdeg_;
529 |
530 | auto res = StlResult { // seasonal, trend and weights are pre-sized to n; remainder starts empty and is filled at the end
531 | std::vector(n),
532 | std::vector(n),
533 | std::vector(),
534 | std::vector(n)
535 | };
536 |
537 | auto ildeg = this->ildeg_.value_or(itdeg); // low-pass degree defaults to the trend degree
538 | auto newns = std::max(ns, (size_t) 3); // seasonal length is forced to at least 3 and odd
539 | if (newns % 2 == 0) {
540 | newns += 1;
541 | }
542 |
543 | auto newnp = std::max(np, (size_t) 2); // period is forced to at least 2
544 | auto nt = (size_t) ceil((1.5 * newnp) / (1.0 - 1.5 / (float) newns)); // default trend length
545 | nt = this->nt_.value_or(nt);
546 | nt = std::max(nt, (size_t) 3); // trend length is forced to at least 3 and odd, even when user-supplied
547 | if (nt % 2 == 0) {
548 | nt += 1;
549 | }
550 |
551 | auto nl = this->nl_.value_or(newnp); // low-pass length defaults to the period
552 | if (nl % 2 == 0 && !this->nl_.has_value()) { // only the default is bumped to odd; an even user value is passed through to stl(), which rejects it
553 | nl += 1;
554 | }
555 |
556 | auto ni = this->ni_.value_or(this->robust_ ? 1 : 2); // inner loops: default 1 when robust, otherwise 2
557 | auto no = this->no_.value_or(this->robust_ ? 15 : 0); // outer loops: default 15 when robust, otherwise 0
558 |
559 | auto nsjump = this->nsjump_.value_or((size_t) ceil(((float) newns) / 10.0)); // each jump defaults to ceil(length / 10)
560 | auto ntjump = this->ntjump_.value_or((size_t) ceil(((float) nt) / 10.0));
561 | auto nljump = this->nljump_.value_or((size_t) ceil(((float) nl) / 10.0));
562 |
563 | stl(y, n, newnp, newns, nt, nl, isdeg, itdeg, ildeg, nsjump, ntjump, nljump, ni, no, res.weights.data(), res.seasonal.data(), res.trend.data());
564 |
565 | res.remainder.reserve(n);
566 | for (size_t i = 0; i < n; i++) { // remainder = series - seasonal - trend
567 | res.remainder.push_back(y[i] - res.seasonal[i] - res.trend[i]);
568 | }
569 |
570 | return res;
571 | }
572 |
573 | template
574 | StlResult StlParams::fit(const std::vector& series, size_t period) const { // vector overload: forwards data() and size() to the pointer overload
575 | return StlParams::fit(series.data(), series.size(), period);
576 | }
577 |
578 | #if __cplusplus >= 202002L
579 | template
580 | StlResult StlParams::fit(std::span series, size_t period) const { // span overload (C++20 and later): forwards data() and size() to the pointer overload
581 | return StlParams::fit(series.data(), series.size(), period);
582 | }
583 | #endif
584 |
585 | /// A MSTL result.
586 | template
587 | class MstlResult { // holds the component vectors as public members
588 | public:
589 | /// Returns the seasonal component.
590 | std::vector> seasonal; // NOTE(review): looks like a vector of vectors with its template arguments stripped from this listing
591 |
592 | /// Returns the trend component.
593 | std::vector trend;
594 |
595 | /// Returns the remainder.
596 | std::vector remainder;
597 |
598 | /// Returns the seasonal strength.
599 | inline std::vector seasonal_strength() const { // one strength value per entry of seasonal, in the same order
600 | std::vector res;
601 | for (auto& s : seasonal) {
602 | res.push_back(strength(s, remainder));
603 | }
604 | return res;
605 | }
606 |
607 | /// Returns the trend strength.
608 | inline double trend_strength() const { // strength() of the trend component against the remainder
609 | return strength(trend, remainder);
610 | }
611 | };
612 |
613 | /// A set of MSTL parameters.
614 | class MstlParams {
615 | size_t iterate_ = 2; // number of iterations, default 2
616 | std::optional lambda_ = std::nullopt; // Box-Cox lambda; unset by default
617 | std::optional> swin_ = std::nullopt; // seasonal smoother lengths; unset by default
618 | StlParams stl_params_; // STL parameters stored by stl_params()
619 |
620 | public:
621 | /// Sets the number of iterations.
622 | inline MstlParams iterations(size_t iterations) { // like every setter here: mutates this object, then returns a copy of it by value
623 | this->iterate_ = iterations;
624 | return *this;
625 | }
626 |
627 | /// Sets lambda for Box-Cox transformation.
628 | inline MstlParams lambda(float lambda) {
629 | this->lambda_ = lambda;
630 | return *this;
631 | }
632 |
633 | /// Sets the lengths of the seasonal smoothers.
634 | inline MstlParams seasonal_lengths(const std::vector& lengths) {
635 | this->swin_ = lengths;
636 | return *this;
637 | }
638 |
639 | /// Sets the STL parameters.
640 | inline MstlParams stl_params(const StlParams& stl_params) {
641 | this->stl_params_ = stl_params;
642 | return *this;
643 | }
644 |
645 | /// Decomposes a time series from an array.
646 | template
647 | MstlResult fit(const T* series, size_t series_size, const size_t* periods, size_t periods_size) const; // declaration only; no definition is visible in this part of the file
648 |
649 | /// Decomposes a time series from a vector.
650 | template
651 | MstlResult fit(const std::vector& series, const std::vector& periods) const;
652 |
653 | #if __cplusplus >= 202002L
654 | /// Decomposes a time series from a span.
655 | template
656 | MstlResult fit(std::span series, std::span periods) const; // only declared when compiling as C++20 or later
657 | #endif
658 | };
659 |
660 | /// Creates a new set of MSTL parameters.
661 | inline MstlParams mstl_params() { // returns a default-constructed MstlParams
662 | return MstlParams();
663 | }
664 |
665 | namespace {
666 |
667 | template
668 | std::vector box_cox(const T* y, size_t y_size, float lambda) { // returns a new vector with (y^lambda - 1) / lambda for each value, or log(y) when lambda is 0
669 | std::vector res;
670 | res.reserve(y_size);
671 | if (lambda != 0.0) {
672 | for (size_t i = 0; i < y_size; i++) {
673 | res.push_back((T) (std::pow(y[i], lambda) - 1.0) / lambda); // the (T) cast applies to the numerator only, before the division
674 | }
675 | } else { // lambda == 0: natural log
676 | for (size_t i = 0; i < y_size; i++) {
677 | res.push_back(std::log(y[i]));
678 | }
679 | }
680 | return res;
681 | }
682 |
683 | template
684 | std::tuple, std::vector, std::vector>> mstl(
685 | const T* x,
686 | size_t k,
687 | const size_t* seas_ids,
688 | size_t seas_size,
689 | size_t iterate,
690 | std::optional lambda,
691 | const std::optional>& swin,
692 | const StlParams& stl_params
693 | ) {
694 | // keep track of indices instead of sorting seas_ids
695 | // so order is preserved with seasonality
696 | std::vector indices;
697 | for (size_t i = 0; i < seas_size; i++) {
698 | indices.push_back(i);
699 | }
700 | std::sort(indices.begin(), indices.end(), [&seas_ids](size_t a, size_t b) {
701 | return seas_ids[a] < seas_ids[b];
702 | });
703 |
704 | if (seas_size == 1) {
705 | iterate = 1;
706 | }
707 |
708 | std::vector> seasonality;
709 | seasonality.reserve(seas_size);
710 | std::vector trend;
711 |
712 | auto deseas = lambda.has_value() ? box_cox(x, k, lambda.value()) : std::vector(x, x + k);
713 |
714 | if (seas_size != 0) {
715 | for (size_t i = 0; i < seas_size; i++) {
716 | seasonality.push_back(std::vector());
717 | }
718 |
719 | for (size_t j = 0; j < iterate; j++) {
720 | for (size_t i = 0; i < indices.size(); i++) {
721 | auto idx = indices[i];
722 |
723 | if (j > 0) {
724 | for (size_t ii = 0; ii < deseas.size(); ii++) {
725 | deseas[ii] += seasonality[idx][ii];
726 | }
727 | }
728 |
729 | StlResult fit;
730 | if (swin) {
731 | StlParams clone = stl_params;
732 | fit = clone.seasonal_length((*swin)[idx]).fit(deseas, seas_ids[idx]);
733 | } else if (stl_params.ns_.has_value()) {
734 | fit = stl_params.fit(deseas, seas_ids[idx]);
735 | } else {
736 | StlParams clone = stl_params;
737 | fit = clone.seasonal_length(7 + 4 * (i + 1)).fit(deseas, seas_ids[idx]);
738 | }
739 |
740 | seasonality[idx] = fit.seasonal;
741 | trend = fit.trend;
742 |
743 | for (size_t ii = 0; ii < deseas.size(); ii++) {
744 | deseas[ii] -= seasonality[idx][ii];
745 | }
746 | }
747 | }
748 | } else {
749 | // TODO use Friedman's Super Smoother for trend
750 | throw std::invalid_argument("periods must not be empty");
751 | }
752 |
753 | std::vector remainder;
754 | remainder.reserve(k);
755 | for (size_t i = 0; i < k; i++) {
756 | remainder.push_back(deseas[i] - trend[i]);
757 | }
758 |
759 | return std::make_tuple(trend, remainder, seasonality);
760 | }
761 |
762 | }
763 |
764 | template
765 | MstlResult MstlParams::fit(const T* series, size_t series_size, const size_t* periods, size_t periods_size) const {
766 | // return error to be consistent with stl
767 | // and ensure seasonal is always same length as periods
768 | for (size_t i = 0; i < periods_size; i++) {
769 | if (periods[i] < 2) {
770 | throw std::invalid_argument("periods must be at least 2");
771 | }
772 | }
773 |
774 | // return error to be consistent with stl
775 | // and ensure seasonal is always same length as periods
776 | for (size_t i = 0; i < periods_size; i++) {
777 | if (series_size < periods[i] * 2) {
778 | throw std::invalid_argument("series has less than two periods");
779 | }
780 | }
781 |
782 | if (lambda_.has_value()) {
783 | auto lambda = lambda_.value();
784 | if (lambda < 0 || lambda > 1) {
785 | throw std::invalid_argument("lambda must be between 0 and 1");
786 | }
787 | }
788 |
789 | if (swin_.has_value()) {
790 | auto swin = swin_.value();
791 | if (swin.size() != periods_size) {
792 | throw std::invalid_argument("seasonal_lengths must have the same length as periods");
793 | }
794 | }
795 |
796 | auto [trend, remainder, seasonal] = mstl(
797 | series,
798 | series_size,
799 | periods,
800 | periods_size,
801 | iterate_,
802 | lambda_,
803 | swin_,
804 | stl_params_
805 | );
806 |
807 | return MstlResult {
808 | seasonal,
809 | trend,
810 | remainder
811 | };
812 | }
813 |
814 | template
815 | MstlResult MstlParams::fit(const std::vector& series, const std::vector& periods) const {
816 | return MstlParams::fit(series.data(), series.size(), periods.data(), periods.size());
817 | }
818 |
#if __cplusplus >= 202002L
/// Span overload (C++20 and later): forwards the spans' storage and sizes
/// to the pointer-based fit(), so validation and errors are identical.
template<typename T>
MstlResult<T> MstlParams::fit(std::span<const T> series, std::span<const size_t> periods) const {
    return MstlParams::fit(series.data(), series.size(), periods.data(), periods.size());
}
#endif
825 |
826 | }
827 |
--------------------------------------------------------------------------------
/lib/stl-rb.rb:
--------------------------------------------------------------------------------
1 | require_relative "stl"
2 |
--------------------------------------------------------------------------------
/lib/stl.rb:
--------------------------------------------------------------------------------
1 | # ext
2 | require "stl/ext"
3 |
4 | # modules
5 | require_relative "stl/version"
6 |
# Seasonal-trend decomposition helpers: running a decomposition, plotting it,
# and summarizing how strong the seasonal and trend components are.
module Stl
  class << self
    # Decomposes a series with STL.
    #
    # Each optional tuning value is passed to the StlParams setter of the
    # same name, and only when it is not nil.
    #
    # @param series [Array, Hash] values in order, or a Hash whose values are
    #   used in ascending key order
    # @param period [Integer] seasonal period; must be at least 2
    # @param robust [Boolean] passed to StlParams#robust
    # @return the value of StlParams#fit; callers in this repo read
    #   :seasonal, :trend, :remainder (and :weights for robust fits) from it
    # @raise [ArgumentError] if period is less than 2
    def decompose(
      series, period:,
      seasonal_length: nil, trend_length: nil, low_pass_length: nil,
      seasonal_degree: nil, trend_degree: nil, low_pass_degree: nil,
      seasonal_jump: nil, trend_jump: nil, low_pass_jump: nil,
      inner_loops: nil, outer_loops: nil, robust: false
    )
      if period < 2
        raise ArgumentError, "period must be greater than 1"
      end

      params = StlParams.new

      params.seasonal_length(seasonal_length) unless seasonal_length.nil?
      params.trend_length(trend_length) unless trend_length.nil?
      params.low_pass_length(low_pass_length) unless low_pass_length.nil?

      params.seasonal_degree(seasonal_degree) unless seasonal_degree.nil?
      params.trend_degree(trend_degree) unless trend_degree.nil?
      params.low_pass_degree(low_pass_degree) unless low_pass_degree.nil?

      params.seasonal_jump(seasonal_jump) unless seasonal_jump.nil?
      params.trend_jump(trend_jump) unless trend_jump.nil?
      params.low_pass_jump(low_pass_jump) unless low_pass_jump.nil?

      params.inner_loops(inner_loops) unless inner_loops.nil?
      params.outer_loops(outer_loops) unless outer_loops.nil?
      params.robust(robust) unless robust.nil?

      # hash input is ordered by key so the values form a time-ordered array
      y = series.is_a?(Hash) ? series.sort_by { |k, _| k }.map(&:last) : series

      # NOTE(review): the third argument looks like a "weights wanted" flag
      # for the extension - confirm against ext.cpp
      params.fit(y, period, outer_loops.nil? ? robust : outer_loops > 0)
    end

    # Builds a Vega-Lite chart with four stacked line plots: the original
    # series and its seasonal, trend, and remainder components.
    #
    # @param series [Array, Hash] the series that was passed to decompose
    # @param result the decomposition result; :seasonal, :trend and
    #   :remainder are read by index
    # @return the chart built through Vega.lite
    def plot(series, result)
      require "date" # Date is referenced below and is not loaded by default
      require "vega"

      hash_series = series.is_a?(Hash)

      # [x value, y value] pairs in plotting order; hash input is sorted by
      # key, the same order decompose uses, so positions line up with result
      points =
        if hash_series
          series.sort_by { |k, _| k }
        else
          series.each_with_index.map { |v, i| [i, v] }
        end

      data =
        points.each_with_index.map do |(key, value), i|
          {
            x: hash_series ? iso8601(key) : key,
            series: value,
            seasonal: result[:seasonal][i],
            trend: result[:trend][i],
            remainder: result[:remainder][i]
          }
        end

      if hash_series
        x = {field: "x", type: "temporal"}
        # symbol key, consistent with the other keys of this hash
        x[:scale] = {type: "utc"} if series.keys.first.is_a?(Date)
      else
        x = {field: "x", type: "quantitative"}
      end
      x[:axis] = {title: nil, labelFontSize: 12}

      charts =
        ["series", "seasonal", "trend", "remainder"].map do |field|
          {
            mark: {type: "line"},
            encoding: {
              x: x,
              y: {field: field, type: "quantitative", scale: {zero: false}, axis: {labelFontSize: 12}}
            },
            width: "container",
            height: 100
          }
        end

      Vega.lite
        .data(data)
        .vconcat(charts)
        .config(autosize: {type: "fit-x", contains: "padding"})
        .width(nil) # prevents warning
        .height(nil) # prevents warning and sets div height to auto
    end

    # Strength of the seasonal component:
    # max(0, 1 - var(remainder) / var(seasonal + remainder)).
    #
    # @param result decomposition result with :seasonal and :remainder
    # @return [Numeric] value between 0 and 1; 0 when the ratio is undefined
    def seasonal_strength(result)
      strength(result[:seasonal], result[:remainder])
    end

    # Strength of the trend component:
    # max(0, 1 - var(remainder) / var(trend + remainder)).
    #
    # @param result decomposition result with :trend and :remainder
    # @return [Numeric] value between 0 and 1; 0 when the ratio is undefined
    def trend_strength(result)
      strength(result[:trend], result[:remainder])
    end

    private

    # Shared formula behind seasonal_strength and trend_strength.
    def strength(component, remainder)
      combined = component.zip(remainder).map { |a, b| a + b }
      ratio = var(remainder) / var(combined)
      # 0/0 (for example a perfectly flat series) gives NaN, and comparing
      # NaN inside Array#max raises ArgumentError; report no strength instead
      return 0 if ratio.nan?

      [0, 1 - ratio].max
    end

    # Formats a Date as YYYY-MM-DD and anything else (e.g. Time) as a
    # timestamp with milliseconds and UTC offset.
    def iso8601(v)
      if v.is_a?(Date)
        v.strftime("%Y-%m-%d")
      else
        v.strftime("%Y-%m-%dT%H:%M:%S.%L%z")
      end
    end

    # Sample variance (divides by n - 1).
    def var(series)
      mean = series.sum / series.size.to_f
      series.sum { |v| (v - mean) ** 2 } / (series.size.to_f - 1)
    end
  end
end
129 |
--------------------------------------------------------------------------------
/lib/stl/version.rb:
--------------------------------------------------------------------------------
# Namespace of the stl-rb gem.
module Stl
  # Gem version string; the gemspec reads it as the gem version.
  VERSION = "0.3.0"
end
4 |
--------------------------------------------------------------------------------
/stl-rb.gemspec:
--------------------------------------------------------------------------------
1 | require_relative "lib/stl/version"
2 |
# Gem metadata and packaging rules for stl-rb.
Gem::Specification.new do |spec|
  spec.name          = "stl-rb"
  spec.version       = Stl::VERSION
  spec.summary       = "Seasonal-trend decomposition for Ruby"
  spec.homepage      = "https://github.com/ankane/stl-ruby"
  spec.license       = "Unlicense OR MIT"

  spec.author        = "Andrew Kane"
  spec.email         = "andrew@ankane.org"

  # The license files (LICENSE-MIT, UNLICENSE) have no extension, so the
  # "*.{md,txt}" pattern alone would leave them out of the packaged gem.
  spec.files         = Dir["*.{md,txt}", "LICENSE*", "UNLICENSE*", "{ext,lib}/**/*"]
  spec.require_path  = "lib"
  spec.extensions    = ["ext/stl/extconf.rb"]

  spec.required_ruby_version = ">= 3.1"

  spec.add_dependency "rice", ">= 4.3.3"
end
21 |
--------------------------------------------------------------------------------
/test/stl_test.rb:
--------------------------------------------------------------------------------
1 | require_relative "test_helper"
2 |
# Exercises Stl.decompose, Stl.plot and the strength helpers against a fixed
# 30-point series.
class StlTest < Minitest::Test
  # Hash input (date => value) must give the same result as the plain array.
  def test_hash
    assert_default_decomposition Stl.decompose(dated_series, period: 7)
  end

  def test_array
    assert_default_decomposition Stl.decompose(series, period: 7)
  end

  # Robust fitting changes the components and also returns weights.
  def test_robust
    result = Stl.decompose(series, period: 7, robust: true)
    assert_elements_in_delta [0.14922355, 0.47939026, -1.833231, 1.7411387, 0.8200711], result[:seasonal].first(5)
    assert_elements_in_delta [5.397365, 5.4745436, 5.5517216, 5.6499176, 5.748114], result[:trend].first(5)
    assert_elements_in_delta [-0.5465884, 3.0460663, -1.7184906, 1.6089439, -6.5681853], result[:remainder].first(5)
    assert_elements_in_delta [0.99374926, 0.8129377, 0.9385952, 0.9458036, 0.29742217], result[:weights].first(5)
  end

  # A pattern repeated exactly leaves no remainder and a flat trend.
  def test_repeating
    series = 24.times.to_a.shuffle * 8
    result = Stl.decompose(series, period: 24)
    assert_elements_in_delta [0] * series.size, result[:remainder]
    assert_elements_in_delta [11.5] * series.size, result[:trend]
  end

  def test_period_one
    error = assert_raises(ArgumentError) { Stl.decompose(series, period: 1) }
    assert_equal "period must be greater than 1", error.message
  end

  def test_too_few_periods
    error = assert_raises(ArgumentError) { Stl.decompose(series, period: 16) }
    assert_equal "series has less than two periods", error.message
  end

  def test_bad_seasonal_degree
    error = assert_raises(ArgumentError) { Stl.decompose(series, period: 7, seasonal_degree: 2) }
    assert_equal "seasonal_degree must be 0 or 1", error.message
  end

  def test_plot_hash
    by_date = dated_series
    result = Stl.decompose(by_date, period: 7)
    assert_kind_of Vega::LiteChart, Stl.plot(by_date, result)
  end

  def test_plot_array
    result = Stl.decompose(series, period: 7)
    assert_kind_of Vega::LiteChart, Stl.plot(series, result)
  end

  def test_seasonal_strength
    result = Stl.decompose(series, period: 7)
    assert_in_delta 0.284111676315015, Stl.seasonal_strength(result)
  end

  # A purely periodic series has maximal seasonal strength.
  def test_seasonal_strength_max
    series = 30.times.map { |i| i % 7 }
    result = Stl.decompose(series, period: 7)
    assert_in_delta 1, Stl.seasonal_strength(result)
  end

  def test_trend_strength
    result = Stl.decompose(series, period: 7)
    assert_in_delta 0.16384245231864702, Stl.trend_strength(result)
  end

  # A straight line has maximal trend strength.
  def test_trend_strength_max
    series = 30.times.to_a
    result = Stl.decompose(series, period: 7)
    assert_in_delta 1, Stl.trend_strength(result)
  end

  # Fixed 30-value sample series shared by the tests.
  def series
    [
      5.0, 9.0, 2.0, 9.0, 0.0, 6.0, 3.0, 8.0, 5.0, 8.0,
      7.0, 8.0, 8.0, 0.0, 2.0, 5.0, 0.0, 5.0, 6.0, 7.0,
      3.0, 6.0, 1.0, 4.0, 4.0, 4.0, 3.0, 7.0, 5.0, 8.0
    ]
  end

  private

  # The sample series keyed by consecutive dates starting today.
  def dated_series
    start = Date.today
    series.each_with_index.to_h { |value, offset| [start + offset, value] }
  end

  # Expected leading values for the non-robust decomposition with period 7.
  def assert_default_decomposition(result)
    assert_elements_in_delta [0.36926576, 0.75655484, -1.3324139, 1.9553658, -0.6044802], result[:seasonal].first(5)
    assert_elements_in_delta [4.804099, 4.9097075, 5.015316, 5.16045, 5.305584], result[:trend].first(5)
    assert_elements_in_delta [-0.17336464, 3.3337379, -1.6829021, 1.8841844, -4.7011037], result[:remainder].first(5)
  end
end
98 |
--------------------------------------------------------------------------------
/test/test_helper.rb:
--------------------------------------------------------------------------------
1 | require "bundler/setup"
2 | Bundler.require(:default)
3 | require "minitest/autorun"
4 | require "minitest/pride"
5 | require "date"
6 |
# Adds an element-wise floating point comparison to every test case.
class Minitest::Test
  # Asserts that both collections have the same size and that each pair of
  # elements at the same position is within assert_in_delta's default delta.
  def assert_elements_in_delta(expected, actual)
    assert_equal expected.size, actual.size
    expected.each_with_index do |exp, idx|
      assert_in_delta exp, actual[idx]
    end
    nil
  end
end
15 |
--------------------------------------------------------------------------------