├── .gitattributes ├── .github └── workflows │ └── build.yml ├── .gitignore ├── CHANGELOG.md ├── Gemfile ├── LICENSE-MIT ├── README.md ├── Rakefile ├── UNLICENSE ├── ext └── stl │ ├── ext.cpp │ ├── extconf.rb │ └── stl.hpp ├── lib ├── stl-rb.rb ├── stl.rb └── stl │ └── version.rb ├── stl-rb.gemspec └── test ├── stl_test.rb └── test_helper.rb /.gitattributes: -------------------------------------------------------------------------------- 1 | ext/stl/stl.hpp linguist-vendored 2 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: [push, pull_request] 3 | jobs: 4 | build: 5 | strategy: 6 | fail-fast: false 7 | matrix: 8 | os: [ubuntu-latest, macos-latest, windows-latest] 9 | runs-on: ${{ matrix.os }} 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: ruby/setup-ruby@v1 13 | with: 14 | ruby-version: 3.4 15 | bundler-cache: true 16 | - run: bundle exec rake compile 17 | - run: bundle exec rake test 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.bundle/ 2 | /.yardoc 3 | /_yardoc/ 4 | /coverage/ 5 | /doc/ 6 | /pkg/ 7 | /spec/reports/ 8 | /tmp/ 9 | *.lock 10 | *.bundle 11 | *.so 12 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.3.0 (2024-10-22) 2 | 3 | - Dropped support for Ruby < 3.1 4 | 5 | ## 0.2.3 (2024-01-26) 6 | 7 | - Fixed bug with `inner_loops` and `outer_loops` 8 | 9 | ## 0.2.2 (2023-06-20) 10 | 11 | - Fixed bug when jump > 1 12 | 13 | ## 0.2.1 (2023-05-11) 14 | 15 | - Fixed error on Fedora 16 | 17 | ## 0.2.0 (2023-02-01) 18 | 19 | - Raise error when `period` is less than 2 20 | - Dropped support for Ruby < 2.7 21 | 22 
| ## 0.1.3 (2021-12-16) 23 | 24 | - Fixed installation error on macOS 12 25 | 26 | ## 0.1.2 (2021-10-24) 27 | 28 | - Added `seasonal_strength` and `trend_strength` methods 29 | - Improved plot width and height 30 | 31 | ## 0.1.1 (2021-10-20) 32 | 33 | - Added `plot` method 34 | 35 | ## 0.1.0 (2021-10-16) 36 | 37 | - First release 38 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gemspec 4 | 5 | gem "rake" 6 | gem "rake-compiler" 7 | gem "minitest", ">= 5" 8 | gem "vega" 9 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021-2024 Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # STL Ruby 2 | 3 | Seasonal-trend decomposition for Ruby 4 | 5 | [![Build Status](https://github.com/ankane/stl-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/stl-ruby/actions) 6 | 7 | ## Installation 8 | 9 | Add this line to your application’s Gemfile: 10 | 11 | ```ruby 12 | gem "stl-rb" 13 | ``` 14 | 15 | ## Getting Started 16 | 17 | Decompose a time series 18 | 19 | ```ruby 20 | series = { 21 | Date.parse("2025-01-01") => 100, 22 | Date.parse("2025-01-02") => 150, 23 | Date.parse("2025-01-03") => 136, 24 | # ... 25 | } 26 | 27 | Stl.decompose(series, period: 7) 28 | ``` 29 | 30 | Works great with [Groupdate](https://github.com/ankane/groupdate) 31 | 32 | ```ruby 33 | series = User.group_by_day(:created_at).count 34 | Stl.decompose(series, period: 7) 35 | ``` 36 | 37 | Series can also be an array without times 38 | 39 | ```ruby 40 | series = [100, 150, 136, ...] 
41 | Stl.decompose(series, period: 7) 42 | ``` 43 | 44 | Use robustness iterations 45 | 46 | ```ruby 47 | Stl.decompose(series, period: 7, robust: true) 48 | ``` 49 | 50 | ## Options 51 | 52 | Pass options 53 | 54 | ```ruby 55 | Stl.decompose( 56 | series, 57 | period: 7, # period of the seasonal component 58 | seasonal_length: 7, # length of the seasonal smoother 59 | trend_length: 15, # length of the trend smoother 60 | low_pass_length: 7, # length of the low-pass filter 61 | seasonal_degree: 0, # degree of locally-fitted polynomial in seasonal smoothing 62 | trend_degree: 1, # degree of locally-fitted polynomial in trend smoothing 63 | low_pass_degree: 1, # degree of locally-fitted polynomial in low-pass smoothing 64 | seasonal_jump: 1, # skipping value for seasonal smoothing 65 | trend_jump: 2, # skipping value for trend smoothing 66 | low_pass_jump: 1, # skipping value for low-pass smoothing 67 | inner_loops: 2, # number of loops for updating the seasonal and trend components 68 | outer_loops: 0, # number of iterations of robust fitting 69 | robust: false # if robustness iterations are to be used 70 | ) 71 | ``` 72 | 73 | ## Plotting 74 | 75 | Add [Vega](https://github.com/ankane/vega) to your application’s Gemfile: 76 | 77 | ```ruby 78 | gem "vega" 79 | ``` 80 | 81 | And use: 82 | 83 | ```ruby 84 | Stl.plot(series, decompose_result) 85 | ``` 86 | 87 | ## Strength 88 | 89 | Get the seasonal strength 90 | 91 | ```ruby 92 | Stl.seasonal_strength(decompose_result) 93 | ``` 94 | 95 | Get the trend strength 96 | 97 | ```ruby 98 | Stl.trend_strength(decompose_result) 99 | ``` 100 | 101 | ## Credits 102 | 103 | This library was ported from the [Fortran implementation](https://www.netlib.org/a/stl). 
104 | 105 | ## References 106 | 107 | - [STL: A Seasonal-Trend Decomposition Procedure Based on Loess](https://www.scb.se/contentassets/ca21efb41fee47d293bbee5bf7be7fb3/stl-a-seasonal-trend-decomposition-procedure-based-on-loess.pdf) 108 | - [Measuring strength of trend and seasonality](https://otexts.com/fpp2/seasonal-strength.html) 109 | 110 | ## History 111 | 112 | View the [changelog](https://github.com/ankane/stl-ruby/blob/master/CHANGELOG.md) 113 | 114 | ## Contributing 115 | 116 | Everyone is encouraged to help improve this project. Here are a few ways you can help: 117 | 118 | - [Report bugs](https://github.com/ankane/stl-ruby/issues) 119 | - Fix bugs and [submit pull requests](https://github.com/ankane/stl-ruby/pulls) 120 | - Write, clarify, or fix documentation 121 | - Suggest or add new features 122 | 123 | To get started with development: 124 | 125 | ```sh 126 | git clone https://github.com/ankane/stl-ruby.git 127 | cd stl-ruby 128 | bundle install 129 | bundle exec rake compile 130 | bundle exec rake test 131 | ``` 132 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require "bundler/gem_tasks" 2 | require "rake/testtask" 3 | require "rake/extensiontask" 4 | 5 | task default: :test 6 | Rake::TestTask.new do |t| 7 | t.libs << "test" 8 | t.pattern = "test/**/*_test.rb" 9 | end 10 | 11 | Rake::ExtensionTask.new("stl") do |ext| 12 | ext.name = "ext" 13 | ext.lib_dir = "lib/stl" 14 | end 15 | 16 | task :remove_ext do 17 | path = "lib/stl/ext.bundle" 18 | File.unlink(path) if File.exist?(path) 19 | end 20 | 21 | Rake::Task["build"].enhance [:remove_ext] 22 | -------------------------------------------------------------------------------- /UNLICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 
2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to <https://unlicense.org> 25 | -------------------------------------------------------------------------------- /ext/stl/ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "stl.hpp" 5 | 6 | Rice::Array to_a(std::vector& x) { 7 | auto a = Rice::Array(); 8 | for (auto v : x) { 9 | a.push(v); 10 | } 11 | return a; 12 | } 13 | 14 | extern "C" 15 | void Init_ext() { 16 | auto rb_mStl = Rice::define_module("Stl"); 17 | 18 | Rice::define_class_under(rb_mStl, "StlParams") 19 | .define_constructor(Rice::Constructor()) 20 | .define_method("seasonal_length", &stl::StlParams::seasonal_length) 21 | .define_method("trend_length", &stl::StlParams::trend_length) 22 | .define_method("low_pass_length", &stl::StlParams::low_pass_length) 23 | .define_method("seasonal_degree", &stl::StlParams::seasonal_degree) 24 | .define_method("trend_degree", &stl::StlParams::trend_degree) 25 | .define_method("low_pass_degree", &stl::StlParams::low_pass_degree) 26 | .define_method("seasonal_jump", &stl::StlParams::seasonal_jump) 27 | .define_method("trend_jump", &stl::StlParams::trend_jump) 28 | .define_method("low_pass_jump", &stl::StlParams::low_pass_jump) 29 | .define_method("inner_loops", &stl::StlParams::inner_loops) 30 | .define_method("outer_loops", &stl::StlParams::outer_loops) 31 | .define_method("robust", &stl::StlParams::robust) 32 | .define_method( 33 | "fit", 34 | [](stl::StlParams& self, std::vector series, size_t period, bool weights) { 35 | auto result = self.fit(series, period); 36 | 37 | auto ret = Rice::Hash(); 38 | ret[Rice::Symbol("seasonal")] = to_a(result.seasonal); 39 | ret[Rice::Symbol("trend")] = to_a(result.trend); 40 | ret[Rice::Symbol("remainder")] = to_a(result.remainder); 41 | if (weights) { 42 | ret[Rice::Symbol("weights")] = to_a(result.weights); 43 | } 44 | return ret; 45 | }); 46 | } 47 | 
-------------------------------------------------------------------------------- /ext/stl/extconf.rb: -------------------------------------------------------------------------------- 1 | require "mkmf-rice" 2 | 3 | $CXXFLAGS += " -std=c++17 $(optflags)" 4 | 5 | create_makefile("stl/ext") 6 | -------------------------------------------------------------------------------- /ext/stl/stl.hpp: -------------------------------------------------------------------------------- 1 | /*! 2 | * STL C++ v0.2.0 3 | * https://github.com/ankane/stl-cpp 4 | * Unlicense OR MIT License 5 | * 6 | * Ported from https://www.netlib.org/a/stl 7 | * 8 | * Cleveland, R. B., Cleveland, W. S., McRae, J. E., & Terpenning, I. (1990). 9 | * STL: A Seasonal-Trend Decomposition Procedure Based on Loess. 10 | * Journal of Official Statistics, 6(1), 3-33. 11 | * 12 | * Bandara, K., Hyndman, R. J., & Bergmeir, C. (2021). 13 | * MSTL: A Seasonal-Trend Decomposition Algorithm for Time Series with Multiple Seasonal Patterns. 14 | * arXiv:2107.13462 [stat.AP]. 
https://doi.org/10.48550/arXiv.2107.13462 15 | */ 16 | 17 | #pragma once 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #if __cplusplus >= 202002L 28 | #include 29 | #endif 30 | 31 | namespace stl { 32 | 33 | namespace { 34 | 35 | template 36 | bool est(const T* y, size_t n, size_t len, int ideg, T xs, T* ys, size_t nleft, size_t nright, T* w, bool userw, const T* rw) { 37 | auto range = ((T) n) - 1.0; 38 | auto h = std::max(xs - ((T) nleft), ((T) nright) - xs); 39 | 40 | if (len > n) { 41 | h += (T) ((len - n) / 2); 42 | } 43 | 44 | auto h9 = 0.999 * h; 45 | auto h1 = 0.001 * h; 46 | 47 | // compute weights 48 | auto a = 0.0; 49 | for (auto j = nleft; j <= nright; j++) { 50 | w[j - 1] = 0.0; 51 | auto r = std::abs(((T) j) - xs); 52 | if (r <= h9) { 53 | if (r <= h1) { 54 | w[j - 1] = 1.0; 55 | } else { 56 | w[j - 1] = (T) std::pow(1.0 - std::pow(r / h, 3), 3); 57 | } 58 | if (userw) { 59 | w[j - 1] *= rw[j - 1]; 60 | } 61 | a += w[j - 1]; 62 | } 63 | } 64 | 65 | if (a <= 0.0) { 66 | return false; 67 | } else { // weighted least squares 68 | for (auto j = nleft; j <= nright; j++) { // make sum of w(j) == 1 69 | w[j - 1] /= (T) a; 70 | } 71 | 72 | if (h > 0.0 && ideg > 0) { // use linear fit 73 | auto a = 0.0; 74 | for (auto j = nleft; j <= nright; j++) { // weighted center of x values 75 | a += w[j - 1] * ((T) j); 76 | } 77 | auto b = xs - a; 78 | auto c = 0.0; 79 | for (auto j = nleft; j <= nright; j++) { 80 | c += w[j - 1] * std::pow(((T) j) - a, 2); 81 | } 82 | if (std::sqrt(c) > 0.001 * range) { 83 | b /= c; 84 | 85 | // points are spread out enough to compute slope 86 | for (auto j = nleft; j <= nright; j++) { 87 | w[j - 1] *= (T) (b * (((T) j) - a) + 1.0); 88 | } 89 | } 90 | } 91 | 92 | *ys = 0.0; 93 | for (auto j = nleft; j <= nright; j++) { 94 | *ys += w[j - 1] * y[j - 1]; 95 | } 96 | 97 | return true; 98 | } 99 | } 100 | 101 | template 102 | void ess(const T* y, size_t n, size_t len, int 
ideg, size_t njump, bool userw, const T* rw, T* ys, T* res) { 103 | if (n < 2) { 104 | ys[0] = y[0]; 105 | return; 106 | } 107 | 108 | size_t nleft = 0; 109 | size_t nright = 0; 110 | 111 | auto newnj = std::min(njump, n - 1); 112 | if (len >= n) { 113 | nleft = 1; 114 | nright = n; 115 | for (size_t i = 1; i <= n; i += newnj) { 116 | auto ok = est(y, n, len, ideg, (T) i, &ys[i - 1], nleft, nright, res, userw, rw); 117 | if (!ok) { 118 | ys[i - 1] = y[i - 1]; 119 | } 120 | } 121 | } else if (newnj == 1) { // newnj equal to one, len less than n 122 | auto nsh = (len + 1) / 2; 123 | nleft = 1; 124 | nright = len; 125 | for (size_t i = 1; i <= n; i++) { // fitted value at i 126 | if (i > nsh && nright != n) { 127 | nleft += 1; 128 | nright += 1; 129 | } 130 | auto ok = est(y, n, len, ideg, (T) i, &ys[i - 1], nleft, nright, res, userw, rw); 131 | if (!ok) { 132 | ys[i - 1] = y[i - 1]; 133 | } 134 | } 135 | } else { // newnj greater than one, len less than n 136 | auto nsh = (len + 1) / 2; 137 | for (size_t i = 1; i <= n; i += newnj) { // fitted value at i 138 | if (i < nsh) { 139 | nleft = 1; 140 | nright = len; 141 | } else if (i >= n - nsh + 1) { 142 | nleft = n - len + 1; 143 | nright = n; 144 | } else { 145 | nleft = i - nsh + 1; 146 | nright = len + i - nsh; 147 | } 148 | auto ok = est(y, n, len, ideg, (T) i, &ys[i - 1], nleft, nright, res, userw, rw); 149 | if (!ok) { 150 | ys[i - 1] = y[i - 1]; 151 | } 152 | } 153 | } 154 | 155 | if (newnj != 1) { 156 | for (size_t i = 1; i <= n - newnj; i += newnj) { 157 | auto delta = (ys[i + newnj - 1] - ys[i - 1]) / ((T) newnj); 158 | for (auto j = i + 1; j <= i + newnj - 1; j++) { 159 | ys[j - 1] = ys[i - 1] + delta * ((T) (j - i)); 160 | } 161 | } 162 | auto k = ((n - 1) / newnj) * newnj + 1; 163 | if (k != n) { 164 | auto ok = est(y, n, len, ideg, (T) n, &ys[n - 1], nleft, nright, res, userw, rw); 165 | if (!ok) { 166 | ys[n - 1] = y[n - 1]; 167 | } 168 | if (k != n - 1) { 169 | auto delta = (ys[n - 1] - ys[k - 1]) / ((T) 
(n - k)); 170 | for (auto j = k + 1; j <= n - 1; j++) { 171 | ys[j - 1] = ys[k - 1] + delta * ((T) (j - k)); 172 | } 173 | } 174 | } 175 | } 176 | } 177 | 178 | template 179 | void ma(const T* x, size_t n, size_t len, T* ave) { 180 | auto newn = n - len + 1; 181 | double flen = (T) len; 182 | double v = 0.0; 183 | 184 | // get the first average 185 | for (size_t i = 0; i < len; i++) { 186 | v += x[i]; 187 | } 188 | 189 | ave[0] = (T) (v / flen); 190 | if (newn > 1) { 191 | size_t k = len; 192 | size_t m = 0; 193 | for (size_t j = 1; j < newn; j++) { 194 | // window down the array 195 | v = v - x[m] + x[k]; 196 | ave[j] = (T) (v / flen); 197 | k += 1; 198 | m += 1; 199 | } 200 | } 201 | } 202 | 203 | template 204 | void fts(const T* x, size_t n, size_t np, T* trend, T* work) { 205 | ma(x, n, np, trend); 206 | ma(trend, n - np + 1, np, work); 207 | ma(work, n - 2 * np + 2, 3, trend); 208 | } 209 | 210 | template 211 | void rwts(const T* y, size_t n, const T* fit, T* rw) { 212 | for (size_t i = 0; i < n; i++) { 213 | rw[i] = std::abs(y[i] - fit[i]); 214 | } 215 | 216 | auto mid1 = (n - 1) / 2; 217 | auto mid2 = n / 2; 218 | 219 | // sort 220 | std::sort(rw, rw + n); 221 | 222 | auto cmad = 3.0 * (rw[mid1] + rw[mid2]); // 6 * median abs resid 223 | auto c9 = 0.999 * cmad; 224 | auto c1 = 0.001 * cmad; 225 | 226 | for (size_t i = 0; i < n; i++) { 227 | auto r = std::abs(y[i] - fit[i]); 228 | if (r <= c1) { 229 | rw[i] = 1.0; 230 | } else if (r <= c9) { 231 | rw[i] = (T) std::pow(1.0 - std::pow(r / cmad, 2), 2); 232 | } else { 233 | rw[i] = 0.0; 234 | } 235 | } 236 | } 237 | 238 | template 239 | void ss(const T* y, size_t n, size_t np, size_t ns, int isdeg, size_t nsjump, bool userw, T* rw, T* season, T* work1, T* work2, T* work3, T* work4) { 240 | for (size_t j = 1; j <= np; j++) { 241 | size_t k = (n - j) / np + 1; 242 | 243 | for (size_t i = 1; i <= k; i++) { 244 | work1[i - 1] = y[(i - 1) * np + j - 1]; 245 | } 246 | if (userw) { 247 | for (size_t i = 1; i <= k; i++) 
{ 248 | work3[i - 1] = rw[(i - 1) * np + j - 1]; 249 | } 250 | } 251 | ess(work1, k, ns, isdeg, nsjump, userw, work3, work2 + 1, work4); 252 | T xs = 0.0; 253 | auto nright = std::min(ns, k); 254 | auto ok = est(work1, k, ns, isdeg, xs, &work2[0], 1, nright, work4, userw, work3); 255 | if (!ok) { 256 | work2[0] = work2[1]; 257 | } 258 | xs = k + 1; 259 | size_t nleft = (size_t) std::max(1, (int) k - (int) ns + 1); 260 | ok = est(work1, k, ns, isdeg, xs, &work2[k + 1], nleft, k, work4, userw, work3); 261 | if (!ok) { 262 | work2[k + 1] = work2[k]; 263 | } 264 | for (size_t m = 1; m <= k + 2; m++) { 265 | season[(m - 1) * np + j - 1] = work2[m - 1]; 266 | } 267 | } 268 | } 269 | 270 | template 271 | void onestp(const T* y, size_t n, size_t np, size_t ns, size_t nt, size_t nl, int isdeg, int itdeg, int ildeg, size_t nsjump, size_t ntjump, size_t nljump, size_t ni, bool userw, T* rw, T* season, T* trend, T* work1, T* work2, T* work3, T* work4, T* work5) { 272 | for (size_t j = 0; j < ni; j++) { 273 | for (size_t i = 0; i < n; i++) { 274 | work1[i] = y[i] - trend[i]; 275 | } 276 | 277 | ss(work1, n, np, ns, isdeg, nsjump, userw, rw, work2, work3, work4, work5, season); 278 | fts(work2, n + 2 * np, np, work3, work1); 279 | ess(work3, n, nl, ildeg, nljump, false, work4, work1, work5); 280 | for (size_t i = 0; i < n; i++) { 281 | season[i] = work2[np + i] - work1[i]; 282 | } 283 | for (size_t i = 0; i < n; i++) { 284 | work1[i] = y[i] - season[i]; 285 | } 286 | ess(work1, n, nt, itdeg, ntjump, userw, rw, trend, work3); 287 | } 288 | } 289 | 290 | template 291 | void stl(const T* y, size_t n, size_t np, size_t ns, size_t nt, size_t nl, int isdeg, int itdeg, int ildeg, size_t nsjump, size_t ntjump, size_t nljump, size_t ni, size_t no, T* rw, T* season, T* trend) { 292 | if (ns < 3) { 293 | throw std::invalid_argument("seasonal_length must be at least 3"); 294 | } 295 | if (nt < 3) { 296 | throw std::invalid_argument("trend_length must be at least 3"); 297 | } 298 | if (nl < 
3) { 299 | throw std::invalid_argument("low_pass_length must be at least 3"); 300 | } 301 | if (np < 2) { 302 | throw std::invalid_argument("period must be at least 2"); 303 | } 304 | 305 | if (isdeg != 0 && isdeg != 1) { 306 | throw std::invalid_argument("seasonal_degree must be 0 or 1"); 307 | } 308 | if (itdeg != 0 && itdeg != 1) { 309 | throw std::invalid_argument("trend_degree must be 0 or 1"); 310 | } 311 | if (ildeg != 0 && ildeg != 1) { 312 | throw std::invalid_argument("low_pass_degree must be 0 or 1"); 313 | } 314 | 315 | if (ns % 2 != 1) { 316 | throw std::invalid_argument("seasonal_length must be odd"); 317 | } 318 | if (nt % 2 != 1) { 319 | throw std::invalid_argument("trend_length must be odd"); 320 | } 321 | if (nl % 2 != 1) { 322 | throw std::invalid_argument("low_pass_length must be odd"); 323 | } 324 | 325 | auto work1 = std::vector(n + 2 * np); 326 | auto work2 = std::vector(n + 2 * np); 327 | auto work3 = std::vector(n + 2 * np); 328 | auto work4 = std::vector(n + 2 * np); 329 | auto work5 = std::vector(n + 2 * np); 330 | 331 | auto userw = false; 332 | size_t k = 0; 333 | 334 | while (true) { 335 | onestp(y, n, np, ns, nt, nl, isdeg, itdeg, ildeg, nsjump, ntjump, nljump, ni, userw, rw, season, trend, work1.data(), work2.data(), work3.data(), work4.data(), work5.data()); 336 | k += 1; 337 | if (k > no) { 338 | break; 339 | } 340 | for (size_t i = 0; i < n; i++) { 341 | work1[i] = trend[i] + season[i]; 342 | } 343 | rwts(y, n, work1.data(), rw); 344 | userw = true; 345 | } 346 | 347 | if (no <= 0) { 348 | for (size_t i = 0; i < n; i++) { 349 | rw[i] = 1.0; 350 | } 351 | } 352 | } 353 | 354 | template 355 | double var(const std::vector& series) { 356 | auto mean = std::accumulate(series.begin(), series.end(), 0.0) / series.size(); 357 | double sum = 0.0; 358 | for (auto v : series) { 359 | double diff = v - mean; 360 | sum += diff * diff; 361 | } 362 | return sum / (series.size() - 1); 363 | } 364 | 365 | template 366 | double strength(const 
std::vector& component, const std::vector& remainder) { 367 | std::vector sr; 368 | sr.reserve(remainder.size()); 369 | for (size_t i = 0; i < remainder.size(); i++) { 370 | sr.push_back(component[i] + remainder[i]); 371 | } 372 | return std::max(0.0, 1.0 - var(remainder) / var(sr)); 373 | } 374 | 375 | } 376 | 377 | /// A STL result. 378 | template 379 | class StlResult { 380 | public: 381 | /// Returns the seasonal component. 382 | std::vector seasonal; 383 | 384 | /// Returns the trend component. 385 | std::vector trend; 386 | 387 | /// Returns the remainder. 388 | std::vector remainder; 389 | 390 | /// Returns the weights. 391 | std::vector weights; 392 | 393 | /// Returns the seasonal strength. 394 | inline double seasonal_strength() const { 395 | return strength(seasonal, remainder); 396 | } 397 | 398 | /// Returns the trend strength. 399 | inline double trend_strength() const { 400 | return strength(trend, remainder); 401 | } 402 | }; 403 | 404 | /// A set of STL parameters. 405 | class StlParams { 406 | public: 407 | /// @private 408 | std::optional ns_ = std::nullopt; 409 | private: 410 | std::optional nt_ = std::nullopt; 411 | std::optional nl_ = std::nullopt; 412 | int isdeg_ = 0; 413 | int itdeg_ = 1; 414 | std::optional ildeg_ = std::nullopt; 415 | std::optional nsjump_ = std::nullopt; 416 | std::optional ntjump_ = std::nullopt; 417 | std::optional nljump_ = std::nullopt; 418 | std::optional ni_ = std::nullopt; 419 | std::optional no_ = std::nullopt; 420 | bool robust_ = false; 421 | 422 | public: 423 | /// Sets the length of the seasonal smoother. 424 | inline StlParams seasonal_length(size_t length) { 425 | this->ns_ = length; 426 | return *this; 427 | } 428 | 429 | /// Sets the length of the trend smoother. 430 | inline StlParams trend_length(size_t length) { 431 | this->nt_ = length; 432 | return *this; 433 | } 434 | 435 | /// Sets the length of the low-pass filter. 
436 | inline StlParams low_pass_length(size_t length) { 437 | this->nl_ = length; 438 | return *this; 439 | } 440 | 441 | /// Sets the degree of locally-fitted polynomial in seasonal smoothing. 442 | inline StlParams seasonal_degree(int degree) { 443 | this->isdeg_ = degree; 444 | return *this; 445 | } 446 | 447 | /// Sets the degree of locally-fitted polynomial in trend smoothing. 448 | inline StlParams trend_degree(int degree) { 449 | this->itdeg_ = degree; 450 | return *this; 451 | } 452 | 453 | /// Sets the degree of locally-fitted polynomial in low-pass smoothing. 454 | inline StlParams low_pass_degree(int degree) { 455 | this->ildeg_ = degree; 456 | return *this; 457 | } 458 | 459 | /// Sets the skipping value for seasonal smoothing. 460 | inline StlParams seasonal_jump(size_t jump) { 461 | this->nsjump_ = jump; 462 | return *this; 463 | } 464 | 465 | /// Sets the skipping value for trend smoothing. 466 | inline StlParams trend_jump(size_t jump) { 467 | this->ntjump_ = jump; 468 | return *this; 469 | } 470 | 471 | /// Sets the skipping value for low-pass smoothing. 472 | inline StlParams low_pass_jump(size_t jump) { 473 | this->nljump_ = jump; 474 | return *this; 475 | } 476 | 477 | /// Sets the number of loops for updating the seasonal and trend components. 478 | inline StlParams inner_loops(size_t loops) { 479 | this->ni_ = loops; 480 | return *this; 481 | } 482 | 483 | /// Sets the number of iterations of robust fitting. 484 | inline StlParams outer_loops(size_t loops) { 485 | this->no_ = loops; 486 | return *this; 487 | } 488 | 489 | /// Sets whether robustness iterations are to be used. 490 | inline StlParams robust(bool robust) { 491 | this->robust_ = robust; 492 | return *this; 493 | } 494 | 495 | /// Decomposes a time series from an array. 496 | template 497 | StlResult fit(const T* series, size_t series_size, size_t period) const; 498 | 499 | /// Decomposes a time series from a vector. 
500 | template 501 | StlResult fit(const std::vector& series, size_t period) const; 502 | 503 | #if __cplusplus >= 202002L 504 | /// Decomposes a time series from a span. 505 | template 506 | StlResult fit(std::span series, size_t period) const; 507 | #endif 508 | }; 509 | 510 | /// Creates a new set of STL parameters. 511 | inline StlParams params() { 512 | return StlParams(); 513 | } 514 | 515 | template 516 | StlResult StlParams::fit(const T* series, size_t series_size, size_t period) const { 517 | auto y = series; 518 | auto np = period; 519 | auto n = series_size; 520 | 521 | if (n < 2 * np) { 522 | throw std::invalid_argument("series has less than two periods"); 523 | } 524 | 525 | auto ns = this->ns_.value_or(np); 526 | 527 | auto isdeg = this->isdeg_; 528 | auto itdeg = this->itdeg_; 529 | 530 | auto res = StlResult { 531 | std::vector(n), 532 | std::vector(n), 533 | std::vector(), 534 | std::vector(n) 535 | }; 536 | 537 | auto ildeg = this->ildeg_.value_or(itdeg); 538 | auto newns = std::max(ns, (size_t) 3); 539 | if (newns % 2 == 0) { 540 | newns += 1; 541 | } 542 | 543 | auto newnp = std::max(np, (size_t) 2); 544 | auto nt = (size_t) ceil((1.5 * newnp) / (1.0 - 1.5 / (float) newns)); 545 | nt = this->nt_.value_or(nt); 546 | nt = std::max(nt, (size_t) 3); 547 | if (nt % 2 == 0) { 548 | nt += 1; 549 | } 550 | 551 | auto nl = this->nl_.value_or(newnp); 552 | if (nl % 2 == 0 && !this->nl_.has_value()) { 553 | nl += 1; 554 | } 555 | 556 | auto ni = this->ni_.value_or(this->robust_ ? 1 : 2); 557 | auto no = this->no_.value_or(this->robust_ ? 
15 : 0); 558 | 559 | auto nsjump = this->nsjump_.value_or((size_t) ceil(((float) newns) / 10.0)); 560 | auto ntjump = this->ntjump_.value_or((size_t) ceil(((float) nt) / 10.0)); 561 | auto nljump = this->nljump_.value_or((size_t) ceil(((float) nl) / 10.0)); 562 | 563 | stl(y, n, newnp, newns, nt, nl, isdeg, itdeg, ildeg, nsjump, ntjump, nljump, ni, no, res.weights.data(), res.seasonal.data(), res.trend.data()); 564 | 565 | res.remainder.reserve(n); 566 | for (size_t i = 0; i < n; i++) { 567 | res.remainder.push_back(y[i] - res.seasonal[i] - res.trend[i]); 568 | } 569 | 570 | return res; 571 | } 572 | 573 | template 574 | StlResult StlParams::fit(const std::vector& series, size_t period) const { 575 | return StlParams::fit(series.data(), series.size(), period); 576 | } 577 | 578 | #if __cplusplus >= 202002L 579 | template 580 | StlResult StlParams::fit(std::span series, size_t period) const { 581 | return StlParams::fit(series.data(), series.size(), period); 582 | } 583 | #endif 584 | 585 | /// A MSTL result. 586 | template 587 | class MstlResult { 588 | public: 589 | /// Returns the seasonal component. 590 | std::vector> seasonal; 591 | 592 | /// Returns the trend component. 593 | std::vector trend; 594 | 595 | /// Returns the remainder. 596 | std::vector remainder; 597 | 598 | /// Returns the seasonal strength. 599 | inline std::vector seasonal_strength() const { 600 | std::vector res; 601 | for (auto& s : seasonal) { 602 | res.push_back(strength(s, remainder)); 603 | } 604 | return res; 605 | } 606 | 607 | /// Returns the trend strength. 608 | inline double trend_strength() const { 609 | return strength(trend, remainder); 610 | } 611 | }; 612 | 613 | /// A set of MSTL parameters. 614 | class MstlParams { 615 | size_t iterate_ = 2; 616 | std::optional lambda_ = std::nullopt; 617 | std::optional> swin_ = std::nullopt; 618 | StlParams stl_params_; 619 | 620 | public: 621 | /// Sets the number of iterations. 
622 | inline MstlParams iterations(size_t iterations) { 623 | this->iterate_ = iterations; 624 | return *this; 625 | } 626 | 627 | /// Sets lambda for Box-Cox transformation. 628 | inline MstlParams lambda(float lambda) { 629 | this->lambda_ = lambda; 630 | return *this; 631 | } 632 | 633 | /// Sets the lengths of the seasonal smoothers. 634 | inline MstlParams seasonal_lengths(const std::vector& lengths) { 635 | this->swin_ = lengths; 636 | return *this; 637 | } 638 | 639 | /// Sets the STL parameters. 640 | inline MstlParams stl_params(const StlParams& stl_params) { 641 | this->stl_params_ = stl_params; 642 | return *this; 643 | } 644 | 645 | /// Decomposes a time series from an array. 646 | template 647 | MstlResult fit(const T* series, size_t series_size, const size_t* periods, size_t periods_size) const; 648 | 649 | /// Decomposes a time series from a vector. 650 | template 651 | MstlResult fit(const std::vector& series, const std::vector& periods) const; 652 | 653 | #if __cplusplus >= 202002L 654 | /// Decomposes a time series from a span. 655 | template 656 | MstlResult fit(std::span series, std::span periods) const; 657 | #endif 658 | }; 659 | 660 | /// Creates a new set of MSTL parameters. 
661 | inline MstlParams mstl_params() { 662 | return MstlParams(); 663 | } 664 | 665 | namespace { 666 | 667 | template 668 | std::vector box_cox(const T* y, size_t y_size, float lambda) { 669 | std::vector res; 670 | res.reserve(y_size); 671 | if (lambda != 0.0) { 672 | for (size_t i = 0; i < y_size; i++) { 673 | res.push_back((T) (std::pow(y[i], lambda) - 1.0) / lambda); 674 | } 675 | } else { 676 | for (size_t i = 0; i < y_size; i++) { 677 | res.push_back(std::log(y[i])); 678 | } 679 | } 680 | return res; 681 | } 682 | 683 | template 684 | std::tuple, std::vector, std::vector>> mstl( 685 | const T* x, 686 | size_t k, 687 | const size_t* seas_ids, 688 | size_t seas_size, 689 | size_t iterate, 690 | std::optional lambda, 691 | const std::optional>& swin, 692 | const StlParams& stl_params 693 | ) { 694 | // keep track of indices instead of sorting seas_ids 695 | // so order is preserved with seasonality 696 | std::vector indices; 697 | for (size_t i = 0; i < seas_size; i++) { 698 | indices.push_back(i); 699 | } 700 | std::sort(indices.begin(), indices.end(), [&seas_ids](size_t a, size_t b) { 701 | return seas_ids[a] < seas_ids[b]; 702 | }); 703 | 704 | if (seas_size == 1) { 705 | iterate = 1; 706 | } 707 | 708 | std::vector> seasonality; 709 | seasonality.reserve(seas_size); 710 | std::vector trend; 711 | 712 | auto deseas = lambda.has_value() ? 
box_cox(x, k, lambda.value()) : std::vector(x, x + k); 713 | 714 | if (seas_size != 0) { 715 | for (size_t i = 0; i < seas_size; i++) { 716 | seasonality.push_back(std::vector()); 717 | } 718 | 719 | for (size_t j = 0; j < iterate; j++) { 720 | for (size_t i = 0; i < indices.size(); i++) { 721 | auto idx = indices[i]; 722 | 723 | if (j > 0) { 724 | for (size_t ii = 0; ii < deseas.size(); ii++) { 725 | deseas[ii] += seasonality[idx][ii]; 726 | } 727 | } 728 | 729 | StlResult fit; 730 | if (swin) { 731 | StlParams clone = stl_params; 732 | fit = clone.seasonal_length((*swin)[idx]).fit(deseas, seas_ids[idx]); 733 | } else if (stl_params.ns_.has_value()) { 734 | fit = stl_params.fit(deseas, seas_ids[idx]); 735 | } else { 736 | StlParams clone = stl_params; 737 | fit = clone.seasonal_length(7 + 4 * (i + 1)).fit(deseas, seas_ids[idx]); 738 | } 739 | 740 | seasonality[idx] = fit.seasonal; 741 | trend = fit.trend; 742 | 743 | for (size_t ii = 0; ii < deseas.size(); ii++) { 744 | deseas[ii] -= seasonality[idx][ii]; 745 | } 746 | } 747 | } 748 | } else { 749 | // TODO use Friedman's Super Smoother for trend 750 | throw std::invalid_argument("periods must not be empty"); 751 | } 752 | 753 | std::vector remainder; 754 | remainder.reserve(k); 755 | for (size_t i = 0; i < k; i++) { 756 | remainder.push_back(deseas[i] - trend[i]); 757 | } 758 | 759 | return std::make_tuple(trend, remainder, seasonality); 760 | } 761 | 762 | } 763 | 764 | template 765 | MstlResult MstlParams::fit(const T* series, size_t series_size, const size_t* periods, size_t periods_size) const { 766 | // return error to be consistent with stl 767 | // and ensure seasonal is always same length as periods 768 | for (size_t i = 0; i < periods_size; i++) { 769 | if (periods[i] < 2) { 770 | throw std::invalid_argument("periods must be at least 2"); 771 | } 772 | } 773 | 774 | // return error to be consistent with stl 775 | // and ensure seasonal is always same length as periods 776 | for (size_t i = 0; i < 
periods_size; i++) { 777 | if (series_size < periods[i] * 2) { 778 | throw std::invalid_argument("series has less than two periods"); 779 | } 780 | } 781 | 782 | if (lambda_.has_value()) { 783 | auto lambda = lambda_.value(); 784 | if (lambda < 0 || lambda > 1) { 785 | throw std::invalid_argument("lambda must be between 0 and 1"); 786 | } 787 | } 788 | 789 | if (swin_.has_value()) { 790 | auto swin = swin_.value(); 791 | if (swin.size() != periods_size) { 792 | throw std::invalid_argument("seasonal_lengths must have the same length as periods"); 793 | } 794 | } 795 | 796 | auto [trend, remainder, seasonal] = mstl( 797 | series, 798 | series_size, 799 | periods, 800 | periods_size, 801 | iterate_, 802 | lambda_, 803 | swin_, 804 | stl_params_ 805 | ); 806 | 807 | return MstlResult { 808 | seasonal, 809 | trend, 810 | remainder 811 | }; 812 | } 813 | 814 | template 815 | MstlResult MstlParams::fit(const std::vector& series, const std::vector& periods) const { 816 | return MstlParams::fit(series.data(), series.size(), periods.data(), periods.size()); 817 | } 818 | 819 | #if __cplusplus >= 202002L 820 | template 821 | MstlResult MstlParams::fit(std::span series, std::span periods) const { 822 | return MstlParams::fit(series.data(), series.size(), periods.data(), periods.size()); 823 | } 824 | #endif 825 | 826 | } 827 | -------------------------------------------------------------------------------- /lib/stl-rb.rb: -------------------------------------------------------------------------------- 1 | require_relative "stl" 2 | -------------------------------------------------------------------------------- /lib/stl.rb: -------------------------------------------------------------------------------- 1 | # ext 2 | require "stl/ext" 3 | 4 | # modules 5 | require_relative "stl/version" 6 | 7 | module Stl 8 | class << self 9 | def decompose( 10 | series, period:, 11 | seasonal_length: nil, trend_length: nil, low_pass_length: nil, 12 | seasonal_degree: nil, trend_degree: nil, 
low_pass_degree: nil, 13 | seasonal_jump: nil, trend_jump: nil, low_pass_jump: nil, 14 | inner_loops: nil, outer_loops: nil, robust: false 15 | ) 16 | if period < 2 17 | raise ArgumentError, "period must be greater than 1" 18 | end 19 | 20 | params = StlParams.new 21 | 22 | params.seasonal_length(seasonal_length) unless seasonal_length.nil? 23 | params.trend_length(trend_length) unless trend_length.nil? 24 | params.low_pass_length(low_pass_length) unless low_pass_length.nil? 25 | 26 | params.seasonal_degree(seasonal_degree) unless seasonal_degree.nil? 27 | params.trend_degree(trend_degree) unless trend_degree.nil? 28 | params.low_pass_degree(low_pass_degree) unless low_pass_degree.nil? 29 | 30 | params.seasonal_jump(seasonal_jump) unless seasonal_jump.nil? 31 | params.trend_jump(trend_jump) unless trend_jump.nil? 32 | params.low_pass_jump(low_pass_jump) unless low_pass_jump.nil? 33 | 34 | params.inner_loops(inner_loops) unless inner_loops.nil? 35 | params.outer_loops(outer_loops) unless outer_loops.nil? 36 | params.robust(robust) unless robust.nil? 37 | 38 | if series.is_a?(Hash) 39 | sorted = series.sort_by { |k, _| k } 40 | y = sorted.map(&:last) 41 | else 42 | y = series 43 | end 44 | 45 | params.fit(y, period, outer_loops.nil? ? 
robust : outer_loops > 0) 46 | end 47 | 48 | def plot(series, result) 49 | require "vega" 50 | 51 | data = 52 | if series.is_a?(Hash) 53 | series.sort_by { |k, _| k }.map.with_index do |s, i| 54 | { 55 | x: iso8601(s[0]), 56 | series: s[1], 57 | seasonal: result[:seasonal][i], 58 | trend: result[:trend][i], 59 | remainder: result[:remainder][i] 60 | } 61 | end 62 | else 63 | series.map.with_index do |v, i| 64 | { 65 | x: i, 66 | series: v, 67 | seasonal: result[:seasonal][i], 68 | trend: result[:trend][i], 69 | remainder: result[:remainder][i] 70 | } 71 | end 72 | end 73 | 74 | if series.is_a?(Hash) 75 | x = {field: "x", type: "temporal"} 76 | x["scale"] = {type: "utc"} if series.keys.first.is_a?(Date) 77 | else 78 | x = {field: "x", type: "quantitative"} 79 | end 80 | x[:axis] = {title: nil, labelFontSize: 12} 81 | 82 | charts = 83 | ["series", "seasonal", "trend", "remainder"].map do |field| 84 | { 85 | mark: {type: "line"}, 86 | encoding: { 87 | x: x, 88 | y: {field: field, type: "quantitative", scale: {zero: false}, axis: {labelFontSize: 12}} 89 | }, 90 | width: "container", 91 | height: 100 92 | } 93 | end 94 | 95 | Vega.lite 96 | .data(data) 97 | .vconcat(charts) 98 | .config(autosize: {type: "fit-x", contains: "padding"}) 99 | .width(nil) # prevents warning 100 | .height(nil) # prevents warning and sets div height to auto 101 | end 102 | 103 | def seasonal_strength(result) 104 | sr = result[:seasonal].zip(result[:remainder]).map { |a, b| a + b } 105 | [0, 1 - var(result[:remainder]) / var(sr)].max 106 | end 107 | 108 | def trend_strength(result) 109 | tr = result[:trend].zip(result[:remainder]).map { |a, b| a + b } 110 | [0, 1 - var(result[:remainder]) / var(tr)].max 111 | end 112 | 113 | private 114 | 115 | def iso8601(v) 116 | if v.is_a?(Date) 117 | v.strftime("%Y-%m-%d") 118 | else 119 | v.strftime("%Y-%m-%dT%H:%M:%S.%L%z") 120 | end 121 | end 122 | 123 | def var(series) 124 | mean = series.sum / series.size.to_f 125 | series.sum { |v| (v - mean) ** 2 } / 
(series.size.to_f - 1) 126 | end 127 | end 128 | end 129 | -------------------------------------------------------------------------------- /lib/stl/version.rb: -------------------------------------------------------------------------------- 1 | module Stl 2 | VERSION = "0.3.0" 3 | end 4 | -------------------------------------------------------------------------------- /stl-rb.gemspec: -------------------------------------------------------------------------------- 1 | require_relative "lib/stl/version" 2 | 3 | Gem::Specification.new do |spec| 4 | spec.name = "stl-rb" 5 | spec.version = Stl::VERSION 6 | spec.summary = "Seasonal-trend decomposition for Ruby" 7 | spec.homepage = "https://github.com/ankane/stl-ruby" 8 | spec.license = "Unlicense OR MIT" 9 | 10 | spec.author = "Andrew Kane" 11 | spec.email = "andrew@ankane.org" 12 | 13 | spec.files = Dir["*.{md,txt}", "{ext,lib}/**/*"] 14 | spec.require_path = "lib" 15 | spec.extensions = ["ext/stl/extconf.rb"] 16 | 17 | spec.required_ruby_version = ">= 3.1" 18 | 19 | spec.add_dependency "rice", ">= 4.3.3" 20 | end 21 | -------------------------------------------------------------------------------- /test/stl_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class StlTest < Minitest::Test 4 | def test_hash 5 | today = Date.today 6 | series = self.series.map.with_index.to_h { |v, i| [today + i, v] } 7 | result = Stl.decompose(series, period: 7) 8 | assert_elements_in_delta [0.36926576, 0.75655484, -1.3324139, 1.9553658, -0.6044802], result[:seasonal].first(5) 9 | assert_elements_in_delta [4.804099, 4.9097075, 5.015316, 5.16045, 5.305584], result[:trend].first(5) 10 | assert_elements_in_delta [-0.17336464, 3.3337379, -1.6829021, 1.8841844, -4.7011037], result[:remainder].first(5) 11 | end 12 | 13 | def test_array 14 | result = Stl.decompose(series, period: 7) 15 | assert_elements_in_delta [0.36926576, 0.75655484, -1.3324139, 1.9553658, 
-0.6044802], result[:seasonal].first(5) 16 | assert_elements_in_delta [4.804099, 4.9097075, 5.015316, 5.16045, 5.305584], result[:trend].first(5) 17 | assert_elements_in_delta [-0.17336464, 3.3337379, -1.6829021, 1.8841844, -4.7011037], result[:remainder].first(5) 18 | end 19 | 20 | def test_robust 21 | result = Stl.decompose(series, period: 7, robust: true) 22 | assert_elements_in_delta [0.14922355, 0.47939026, -1.833231, 1.7411387, 0.8200711], result[:seasonal].first(5) 23 | assert_elements_in_delta [5.397365, 5.4745436, 5.5517216, 5.6499176, 5.748114], result[:trend].first(5) 24 | assert_elements_in_delta [-0.5465884, 3.0460663, -1.7184906, 1.6089439, -6.5681853], result[:remainder].first(5) 25 | assert_elements_in_delta [0.99374926, 0.8129377, 0.9385952, 0.9458036, 0.29742217], result[:weights].first(5) 26 | end 27 | 28 | def test_repeating 29 | series = 24.times.to_a.shuffle * 8 30 | result = Stl.decompose(series, period: 24) 31 | assert_elements_in_delta [0] * series.size, result[:remainder] 32 | assert_elements_in_delta [11.5] * series.size, result[:trend] 33 | end 34 | 35 | def test_period_one 36 | error = assert_raises(ArgumentError) do 37 | Stl.decompose(series, period: 1) 38 | end 39 | assert_equal "period must be greater than 1", error.message 40 | end 41 | 42 | def test_too_few_periods 43 | error = assert_raises(ArgumentError) do 44 | Stl.decompose(series, period: 16) 45 | end 46 | assert_equal "series has less than two periods", error.message 47 | end 48 | 49 | def test_bad_seasonal_degree 50 | error = assert_raises(ArgumentError) do 51 | Stl.decompose(series, period: 7, seasonal_degree: 2) 52 | end 53 | assert_equal "seasonal_degree must be 0 or 1", error.message 54 | end 55 | 56 | def test_plot_hash 57 | today = Date.today 58 | series = self.series.map.with_index.to_h { |v, i| [today + i, v] } 59 | result = Stl.decompose(series, period: 7) 60 | assert_kind_of Vega::LiteChart, Stl.plot(series, result) 61 | end 62 | 63 | def test_plot_array 64 | 
result = Stl.decompose(series, period: 7) 65 | assert_kind_of Vega::LiteChart, Stl.plot(series, result) 66 | end 67 | 68 | def test_seasonal_strength 69 | result = Stl.decompose(series, period: 7) 70 | assert_in_delta 0.284111676315015, Stl.seasonal_strength(result) 71 | end 72 | 73 | def test_seasonal_strength_max 74 | series = 30.times.map { |i| i % 7 } 75 | result = Stl.decompose(series, period: 7) 76 | assert_in_delta 1, Stl.seasonal_strength(result) 77 | end 78 | 79 | def test_trend_strength 80 | result = Stl.decompose(series, period: 7) 81 | assert_in_delta 0.16384245231864702, Stl.trend_strength(result) 82 | end 83 | 84 | def test_trend_strength_max 85 | series = 30.times.to_a 86 | result = Stl.decompose(series, period: 7) 87 | assert_in_delta 1, Stl.trend_strength(result) 88 | end 89 | 90 | def series 91 | [ 92 | 5.0, 9.0, 2.0, 9.0, 0.0, 6.0, 3.0, 8.0, 5.0, 8.0, 93 | 7.0, 8.0, 8.0, 0.0, 2.0, 5.0, 0.0, 5.0, 6.0, 7.0, 94 | 3.0, 6.0, 1.0, 4.0, 4.0, 4.0, 3.0, 7.0, 5.0, 8.0 95 | ] 96 | end 97 | end 98 | -------------------------------------------------------------------------------- /test/test_helper.rb: -------------------------------------------------------------------------------- 1 | require "bundler/setup" 2 | Bundler.require(:default) 3 | require "minitest/autorun" 4 | require "minitest/pride" 5 | require "date" 6 | 7 | class Minitest::Test 8 | def assert_elements_in_delta(expected, actual) 9 | assert_equal expected.size, actual.size 10 | expected.zip(actual) do |exp, act| 11 | assert_in_delta exp, act 12 | end 13 | end 14 | end 15 | --------------------------------------------------------------------------------