├── .gitignore ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── pom.xml ├── python ├── Makefile ├── setup.py ├── source │ ├── conf.py │ ├── index.rst │ ├── modules.rst │ └── sparkts.rst └── sparkts │ ├── __init__.py │ ├── datetimeindex.py │ ├── test │ ├── __init__.py │ ├── resources │ │ └── conf │ │ │ └── log4j.properties │ ├── test_datetimeindex.py │ ├── test_timeseriesrdd.py │ └── test_utils.py │ ├── timeseriesrdd.py │ └── utils.py ├── scalastyle-config.xml └── src ├── main └── scala │ └── com │ ├── cloudera │ ├── finance │ │ ├── MultivariateTDistribution.scala │ │ ├── SerializableMultivariateNormalDistribution.scala │ │ ├── Util.scala │ │ ├── YahooParser.scala │ │ ├── examples │ │ │ ├── HistoricalValueAtRiskExample.scala │ │ │ ├── MonteCarloValueAtRiskExample.scala │ │ │ └── SimpleTickDataExample.scala │ │ └── risk │ │ │ └── FilteredHistoricalFactorDistribution.scala │ └── sparkts │ │ ├── ARIMA.scala │ │ ├── Autoregression.scala │ │ ├── DateTimeIndex.scala │ │ ├── EWMA.scala │ │ ├── EasyPlot.scala │ │ ├── Frequency.scala │ │ ├── GARCH.scala │ │ ├── Lag.scala │ │ ├── PythonConnector.scala │ │ ├── TimeSeries.scala │ │ ├── TimeSeriesKryoRegistrator.scala │ │ ├── TimeSeriesModel.scala │ │ ├── TimeSeriesRDD.scala │ │ ├── TimeSeriesStatisticalTests.scala │ │ ├── TimeSeriesUtils.scala │ │ └── UnivariateTimeSeries.scala │ ├── redislabs │ └── provider │ │ ├── RedisConfig.scala │ │ ├── redis │ │ ├── package.scala │ │ ├── partitioner │ │ │ ├── RedisPartition.scala │ │ │ └── RedisPartitioner.scala │ │ ├── rdd │ │ │ └── RedisRDD.scala │ │ └── redisFunctions.scala │ │ ├── sql │ │ ├── DefaultSource.scala │ │ ├── RedisSQLFunctions.scala │ │ └── package.scala │ │ └── util │ │ ├── GenerateWorkdayTestData.scala │ │ └── ImportTimeSeriesData.scala │ ├── run │ └── Main.scala │ └── test │ └── sql │ ├── DefaultSource.scala │ └── rrun.scala ├── site ├── markdown │ ├── docs │ │ └── users.md │ └── index.md └── site.xml └── test ├── resources ├── GOOG.csv ├── R_ARIMA_DataSet1.csv └── R_ARIMA_DataSet2.csv └── scala └── com └── cloudera ├── finance └── YahooParserSuite.scala └── sparkts ├── ARIMASuite.scala ├── AugmentedDickeyFullerSuite.scala ├── AutoregressionSuite.scala ├── BusinessDayFrequencySuite.scala ├── DateTimeIndexSuite.scala ├── EWMASuite.scala ├── FillSuite.scala ├── GARCHSuite.scala ├── LocalSparkContext.scala ├── RebaseSuite.scala ├── TimeSeriesRDDSuite.scala ├── TimeSeriesStatisticalTestsSuite.scala ├── TimeSeriesSuite.scala └── UnivariateTimeSeriesSuite.scala /.gitignore: -------------------------------------------------------------------------------- 1 | .classpath 2 | .project 3 | .settings 4 | .cache 5 | target 6 | *.iml 7 | .idea 8 | scalastyle-output.xml 9 | *.pyc 10 | *.swp 11 | python/build/ 12 | python/dist/ 13 | python/*.egg-info 14 | *.log 15 | 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | sudo: false 3 | install: mvn ${SPARK} ${JAVA} ${SCALA} -Dmaven.javadoc.skip=true -DskipTests=true -B -V install 4 | script: mvn ${SPARK} ${JAVA} ${SCALA} -Dmaven.javadoc.skip=true -q -B verify 5 | env: 6 | global: 7 | - MAVEN_OPTS=-Xmx2g 8 | matrix: 9 | include: 10 | # Covers Java 7, Open JDK, and Spark 1.3.x 11 | - jdk: openjdk7 12 | # Covers Java 8, Oracle JDK, and Spark 1.4.x 13 | - jdk: oraclejdk8 14 | env: SPARK=-Dspark.version=1.4.1 JAVA=-Djava.version=1.8 15 | # Covers Scala 2.11 16 | - jdk: oraclejdk7 17 | env: 
SCALA=-Pscala-2.11 SPARK=-Dspark.version=1.4.1 18 | cache: 19 | directories: 20 | - $HOME/.m2 21 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributions via GitHub pull requests are gladly accepted from their original author. Along with any pull requests, please state that the contribution is your original work and that you license the work to the project under the project's open source license. Whether or not you state this explicitly, by submitting any copyrighted material via pull request, email, or other means you agree to license the material under the project's open source license and warrant that you have the legal authority to do so. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/RedisLabs/spark-timeseries.svg)](https://travis-ci.org/RedisLabs/spark-timeseries) 2 | 3 | spark-timeseries 4 | ============= 5 | 6 | A Scala / Python library for interacting with time series data on Apache Spark. 7 | 8 | This fork extends [Cloudera's spark-timeseries](https://github.com/cloudera/spark-timeseries) library with: 9 | 10 | * A `RedisTimeSeriesRDD` that uses Redis' Sorted Sets as its data source 11 | * A utility for generating sample dataset files - `com.redislabs.provider.util.GenerateWorkdayTestData` 12 | * An import tool that loads time series files into Redis - `com.redislabs.provider.util.ImportTimeSeriesData` 13 | * A test suite that compares execution times of common queries under different caching strategies - `src/main/scala/com/run/Main.scala` 14 | 15 | **Note:** The Redis extension is currently implemented only in the Scala library. 16 | 17 | Time series storage in Redis 18 | --- 19 | 20 | Each time series in the dataset is stored as a Redis Sorted Set. The Sorted Set's key name is the name of the time series. Each member of the set holds a sampled value, and its score is set to the corresponding timestamp. Filters on keys, column names, and date ranges are pushed down to Redis via the `RedisTimeSeriesRDD`, so data is processed close to where it is stored and traffic between Spark's workers and the storage is reduced. A `RedisTimeSeriesRDD` can be transformed into a `TimeSeriesRDD` to obtain that class's full functionality. 21 | -------------------------------------------------------------------------------- /python/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = build 9 | 10 | # Internal variables.
11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 14 | # the i18n builder cannot share the environment and doctrees with the others 15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 16 | 17 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 18 | 19 | help: 20 | @echo "Please use \`make ' where is one of" 21 | @echo " html to make standalone HTML files" 22 | @echo " dirhtml to make HTML files named index.html in directories" 23 | @echo " singlehtml to make a single large HTML file" 24 | @echo " pickle to make pickle files" 25 | @echo " json to make JSON files" 26 | @echo " htmlhelp to make HTML files and a HTML help project" 27 | @echo " qthelp to make HTML files and a qthelp project" 28 | @echo " devhelp to make HTML files and a Devhelp project" 29 | @echo " epub to make an epub" 30 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 31 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 32 | @echo " text to make text files" 33 | @echo " man to make manual pages" 34 | @echo " texinfo to make Texinfo files" 35 | @echo " info to make Texinfo files and run them through makeinfo" 36 | @echo " gettext to make PO message catalogs" 37 | @echo " changes to make an overview of all changed/added/deprecated items" 38 | @echo " linkcheck to check all external links for integrity" 39 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 40 | 41 | clean: 42 | -rm -rf $(BUILDDIR)/* 43 | 44 | html: 45 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 46 | @echo 47 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 48 | 49 | dirhtml: 50 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 51 | @echo 52 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 53 | 54 | singlehtml: 55 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 56 | @echo 57 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 58 | 59 | pickle: 60 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 61 | @echo 62 | @echo "Build finished; now you can process the pickle files." 63 | 64 | json: 65 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 66 | @echo 67 | @echo "Build finished; now you can process the JSON files." 68 | 69 | htmlhelp: 70 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 71 | @echo 72 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 73 | ".hhp project file in $(BUILDDIR)/htmlhelp." 74 | 75 | qthelp: 76 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 77 | @echo 78 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 79 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 80 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/spark-timeseries.qhcp" 81 | @echo "To view the help file:" 82 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/spark-timeseries.qhc" 83 | 84 | devhelp: 85 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 86 | @echo 87 | @echo "Build finished." 
88 | @echo "To view the help file:" 89 | @echo "# mkdir -p $$HOME/.local/share/devhelp/spark-timeseries" 90 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/spark-timeseries" 91 | @echo "# devhelp" 92 | 93 | epub: 94 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 95 | @echo 96 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 97 | 98 | latex: 99 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 100 | @echo 101 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 102 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 103 | "(use \`make latexpdf' here to do that automatically)." 104 | 105 | latexpdf: 106 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 107 | @echo "Running LaTeX files through pdflatex..." 108 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 109 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 110 | 111 | text: 112 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 113 | @echo 114 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 115 | 116 | man: 117 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 118 | @echo 119 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 120 | 121 | texinfo: 122 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 123 | @echo 124 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 125 | @echo "Run \`make' in that directory to run these through makeinfo" \ 126 | "(use \`make info' here to do that automatically)." 127 | 128 | info: 129 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 130 | @echo "Running Texinfo files through makeinfo..." 131 | make -C $(BUILDDIR)/texinfo info 132 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 133 | 134 | gettext: 135 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 136 | @echo 137 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 138 | 139 | changes: 140 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 141 | @echo 142 | @echo "The overview file is in $(BUILDDIR)/changes." 143 | 144 | linkcheck: 145 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 146 | @echo 147 | @echo "Link check complete; look for any errors in the above output " \ 148 | "or in $(BUILDDIR)/linkcheck/output.txt." 149 | 150 | doctest: 151 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 152 | @echo "Testing of doctests in the sources finished, look at the " \ 153 | "results in $(BUILDDIR)/doctest/output.txt." 154 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | VERSION = '0.0.1' 5 | 6 | setup( 7 | name='sparktimeseries', 8 | version=VERSION, 9 | packages=find_packages() 10 | ) 11 | -------------------------------------------------------------------------------- /python/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # spark-timeseries documentation build configuration file, created by 4 | # sphinx-quickstart on Tue Sep 15 17:34:03 2015. 5 | # 6 | # This file is execfile()d with the current directory set to its containing dir. 7 | # 8 | # Note that not all possible configuration values are present in this 9 | # autogenerated file. 
10 | # 11 | # All configuration values have a default; values that are commented out 12 | # serve to show the default. 13 | 14 | import sys, os 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | #sys.path.insert(0, os.path.abspath('.')) 20 | 21 | # -- General configuration ----------------------------------------------------- 22 | 23 | # If your documentation needs a minimal Sphinx version, state it here. 24 | #needs_sphinx = '1.0' 25 | 26 | # Add any Sphinx extension module names here, as strings. They can be extensions 27 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 28 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon'] 29 | 30 | # Add any paths that contain templates here, relative to this directory. 31 | templates_path = ['_templates'] 32 | 33 | # The suffix of source filenames. 34 | source_suffix = '.rst' 35 | 36 | # The encoding of source files. 37 | #source_encoding = 'utf-8-sig' 38 | 39 | # The master toctree document. 40 | master_doc = 'index' 41 | 42 | # General information about the project. 43 | project = u'spark-timeseries' 44 | copyright = u'2015, Sandy Ryza' 45 | 46 | # The version info for the project you're documenting, acts as replacement for 47 | # |version| and |release|, also used in various other places throughout the 48 | # built documents. 49 | # 50 | # The short X.Y version. 51 | version = '0.0.1' 52 | # The full version, including alpha/beta/rc tags. 53 | release = '0.0.1' 54 | 55 | # The language for content autogenerated by Sphinx. Refer to documentation 56 | # for a list of supported languages. 57 | #language = None 58 | 59 | # There are two options for replacing |today|: either, you set today to some 60 | # non-false value, then it is used: 61 | #today = '' 62 | # Else, today_fmt is used as the format for a strftime call. 63 | #today_fmt = '%B %d, %Y' 64 | 65 | # List of patterns, relative to source directory, that match files and 66 | # directories to ignore when looking for source files. 67 | exclude_patterns = [] 68 | 69 | # The reST default role (used for this markup: `text`) to use for all documents. 70 | #default_role = None 71 | 72 | # If true, '()' will be appended to :func: etc. cross-reference text. 73 | #add_function_parentheses = True 74 | 75 | # If true, the current module name will be prepended to all description 76 | # unit titles (such as .. function::). 77 | #add_module_names = True 78 | 79 | # If true, sectionauthor and moduleauthor directives will be shown in the 80 | # output. They are ignored by default. 81 | #show_authors = False 82 | 83 | # The name of the Pygments (syntax highlighting) style to use. 84 | pygments_style = 'sphinx' 85 | 86 | # A list of ignored prefixes for module index sorting. 87 | #modindex_common_prefix = [] 88 | 89 | 90 | # -- Options for HTML output --------------------------------------------------- 91 | 92 | # The theme to use for HTML and HTML Help pages. See the documentation for 93 | # a list of builtin themes. 94 | html_theme = 'default' 95 | 96 | # Theme options are theme-specific and customize the look and feel of a theme 97 | # further. For a list of options available for each theme, see the 98 | # documentation. 99 | #html_theme_options = {} 100 | 101 | # Add any paths that contain custom themes here, relative to this directory. 
102 | #html_theme_path = [] 103 | 104 | # The name for this set of Sphinx documents. If None, it defaults to 105 | # " v documentation". 106 | #html_title = None 107 | 108 | # A shorter title for the navigation bar. Default is the same as html_title. 109 | #html_short_title = None 110 | 111 | # The name of an image file (relative to this directory) to place at the top 112 | # of the sidebar. 113 | #html_logo = None 114 | 115 | # The name of an image file (within the static path) to use as favicon of the 116 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 117 | # pixels large. 118 | #html_favicon = None 119 | 120 | # Add any paths that contain custom static files (such as style sheets) here, 121 | # relative to this directory. They are copied after the builtin static files, 122 | # so a file named "default.css" will overwrite the builtin "default.css". 123 | html_static_path = ['_static'] 124 | 125 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 126 | # using the given strftime format. 127 | #html_last_updated_fmt = '%b %d, %Y' 128 | 129 | # If true, SmartyPants will be used to convert quotes and dashes to 130 | # typographically correct entities. 131 | #html_use_smartypants = True 132 | 133 | # Custom sidebar templates, maps document names to template names. 134 | #html_sidebars = {} 135 | 136 | # Additional templates that should be rendered to pages, maps page names to 137 | # template names. 138 | #html_additional_pages = {} 139 | 140 | # If false, no module index is generated. 141 | #html_domain_indices = True 142 | 143 | # If false, no index is generated. 144 | #html_use_index = True 145 | 146 | # If true, the index is split into individual pages for each letter. 147 | #html_split_index = False 148 | 149 | # If true, links to the reST sources are added to the pages. 150 | #html_show_sourcelink = True 151 | 152 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 153 | #html_show_sphinx = True 154 | 155 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 156 | #html_show_copyright = True 157 | 158 | # If true, an OpenSearch description file will be output, and all pages will 159 | # contain a tag referring to it. The value of this option must be the 160 | # base URL from which the finished HTML is served. 161 | #html_use_opensearch = '' 162 | 163 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 164 | #html_file_suffix = None 165 | 166 | # Output file base name for HTML help builder. 167 | htmlhelp_basename = 'spark-timeseriesdoc' 168 | 169 | 170 | # -- Options for LaTeX output -------------------------------------------------- 171 | 172 | latex_elements = { 173 | # The paper size ('letterpaper' or 'a4paper'). 174 | #'papersize': 'letterpaper', 175 | 176 | # The font size ('10pt', '11pt' or '12pt'). 177 | #'pointsize': '10pt', 178 | 179 | # Additional stuff for the LaTeX preamble. 180 | #'preamble': '', 181 | } 182 | 183 | # Grouping the document tree into LaTeX files. List of tuples 184 | # (source start file, target name, title, author, documentclass [howto/manual]). 185 | latex_documents = [ 186 | ('index', 'spark-timeseries.tex', u'spark-timeseries Documentation', 187 | u'Sandy Ryza', 'manual'), 188 | ] 189 | 190 | # The name of an image file (relative to this directory) to place at the top of 191 | # the title page. 192 | #latex_logo = None 193 | 194 | # For "manual" documents, if this is true, then toplevel headings are parts, 195 | # not chapters. 
196 | #latex_use_parts = False 197 | 198 | # If true, show page references after internal links. 199 | #latex_show_pagerefs = False 200 | 201 | # If true, show URL addresses after external links. 202 | #latex_show_urls = False 203 | 204 | # Documents to append as an appendix to all manuals. 205 | #latex_appendices = [] 206 | 207 | # If false, no module index is generated. 208 | #latex_domain_indices = True 209 | 210 | 211 | # -- Options for manual page output -------------------------------------------- 212 | 213 | # One entry per manual page. List of tuples 214 | # (source start file, name, description, authors, manual section). 215 | man_pages = [ 216 | ('index', 'spark-timeseries', u'spark-timeseries Documentation', 217 | [u'Sandy Ryza'], 1) 218 | ] 219 | 220 | # If true, show URL addresses after external links. 221 | #man_show_urls = False 222 | 223 | 224 | # -- Options for Texinfo output ------------------------------------------------ 225 | 226 | # Grouping the document tree into Texinfo files. List of tuples 227 | # (source start file, target name, title, author, 228 | # dir menu entry, description, category) 229 | texinfo_documents = [ 230 | ('index', 'spark-timeseries', u'spark-timeseries Documentation', 231 | u'Sandy Ryza', 'spark-timeseries', 'One line description of project.', 232 | 'Miscellaneous'), 233 | ] 234 | 235 | # Documents to append as an appendix to all manuals. 236 | #texinfo_appendices = [] 237 | 238 | # If false, no module index is generated. 239 | #texinfo_domain_indices = True 240 | 241 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 242 | #texinfo_show_urls = 'footnote' 243 | -------------------------------------------------------------------------------- /python/source/index.rst: -------------------------------------------------------------------------------- 1 | .. spark-timeseries documentation master file, created by 2 | sphinx-quickstart on Tue Sep 15 17:34:03 2015. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to spark-timeseries's documentation! 7 | ============================================ 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | 23 | -------------------------------------------------------------------------------- /python/source/modules.rst: -------------------------------------------------------------------------------- 1 | sparkts 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | sparkts 8 | -------------------------------------------------------------------------------- /python/source/sparkts.rst: -------------------------------------------------------------------------------- 1 | sparkts package 2 | =============== 3 | 4 | Submodules 5 | ---------- 6 | 7 | sparkts.datetimeindex module 8 | ---------------------------- 9 | 10 | .. automodule:: sparkts.datetimeindex 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | sparkts.timeseriesrdd module 16 | ---------------------------- 17 | 18 | .. automodule:: sparkts.timeseriesrdd 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | sparkts.utils module 24 | -------------------- 25 | 26 | .. automodule:: sparkts.utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. 
automodule:: sparkts 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /python/sparkts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RedisLabs/spark-timeseries/51aea2d514eb4c07b27b4b1bc3f978147bd99009/python/sparkts/__init__.py -------------------------------------------------------------------------------- /python/sparkts/datetimeindex.py: -------------------------------------------------------------------------------- 1 | from py4j.java_gateway import java_import 2 | from utils import datetime_to_millis 3 | import numpy as np 4 | import pandas as pd 5 | 6 | class DateTimeIndex(object): 7 | """ 8 | A DateTimeIndex maintains a bi-directional mapping between integers and an ordered collection of 9 | date-times. Multiple date-times may correspond to the same integer, implying multiple samples 10 | at the same date-time. 11 | 12 | To avoid confusion between the meaning of "index" as it appears in "DateTimeIndex" and "index" 13 | as a location in an array, in the context of this class, we use "location", or "loc", to refer 14 | to the latter. 15 | """ 16 | 17 | def __init__(self, jdt_index): 18 | self._jdt_index = jdt_index 19 | 20 | def __len__(self): 21 | """Returns the number of timestamps included in the index.""" 22 | return self._jdt_index.size() 23 | 24 | def first(self): 25 | """Returns the earliest timestamp in the index, as a Pandas Timestamp.""" 26 | millis = self._jdt_index.first().getMillis() 27 | return pd.Timestamp(millis * 1000000) 28 | 29 | def last(self): 30 | """Returns the latest timestamp in the index, as a Pandas Timestamp.""" 31 | millis = self._jdt_index.last().getMillis() 32 | return pd.Timestamp(millis * 1000000) 33 | 34 | def __getitem__(self, val): 35 | # TODO: throw an error if the step size is defined 36 | if isinstance(val, slice): 37 | start = datetime_to_millis(val.start) 38 | stop = datetime_to_millis(val.stop) 39 | jdt_index = self._jdt_index.slice(start, stop) 40 | return DateTimeIndex(jdt_index) 41 | else: 42 | return self._jdt_index.locAtDateTime(datetime_to_millis(val)) 43 | 44 | def islice(self, start, end): 45 | """ 46 | Returns a new DateTimeIndex, containing a subslice of the timestamps in this index, 47 | as specified by the given integer start and end locations. 48 | 49 | Parameters 50 | ---------- 51 | start : int 52 | The location of the start of the range, inclusive. 53 | end : int 54 | The location of the end of the range, exclusive. 
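For example, on the three-day-frequency index used in test_datetimeindex.py (starting 2015-04-10), ``islice(1, 4)`` selects the same three timestamps as slicing by date from 2015-04-13 to 2015-04-19.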
55 | """ 56 | jdt_index = self._jdt_index.islice(start, end) 57 | return DateTimeIndex(jdt_index=jdt_index) 58 | 59 | def datetime_at_loc(self, loc): 60 | """Returns the timestamp at the given integer location as a Pandas Timestamp.""" 61 | millis = self._jdt_index.dateTimeAtLoc(loc).getMillis() 62 | return pd.Timestamp(millis * 1000000) 63 | 64 | def to_pandas_index(self): 65 | """Returns a pandas.DatetimeIndex representing the same date-times""" 66 | # TODO: we can probably speed this up for uniform indices 67 | arr = self._jdt_index.toMillisArray() 68 | arr = [arr[i] * 1000000 for i in xrange(len(self))] 69 | return pd.DatetimeIndex(arr) 70 | 71 | def __eq__(self, other): 72 | return self._jdt_index.equals(other._jdt_index) 73 | 74 | def __ne__(self, other): 75 | return not self.__eq__(other) 76 | 77 | def __repr__(self): 78 | return self._jdt_index.toString() 79 | 80 | class _Frequency(object): 81 | def __eq__(self, other): 82 | return self._jfreq.equals(other._jfreq) 83 | 84 | def __ne__(self, other): 85 | return not self.__eq__(other) 86 | 87 | class DayFrequency(_Frequency): 88 | """ 89 | A frequency that can be used for a uniform DateTimeIndex, where the period is given in days. 90 | """ 91 | 92 | def __init__(self, days, sc): 93 | self._jfreq = sc._jvm.com.cloudera.sparkts.DayFrequency(days) 94 | 95 | def days(self): 96 | return self._jfreq.days() 97 | 98 | class HourFrequency(_Frequency): 99 | """ 100 | A frequency that can be used for a uniform DateTimeIndex, where the period is given in hours. 101 | """ 102 | 103 | def __init__(self, hours, sc): 104 | self._jfreq = sc._jvm.com.cloudera.sparkts.HourFrequency(hours) 105 | 106 | def hours(self): 107 | return self._jfreq.hours() 108 | 109 | class BusinessDayFrequency(object): 110 | """ 111 | A frequency that can be used for a uniform DateTimeIndex, where the period is given in 112 | business days. 113 | """ 114 | 115 | def __init__(self, bdays, sc): 116 | self._jfreq = sc._jvm.com.cloudera.sparkts.BusinessDayFrequency(bdays) 117 | 118 | def __eq__(self, other): 119 | return self._jfreq.equals(other._jfreq) 120 | 121 | def __ne__(self, other): 122 | return not self.__eq__(other) 123 | 124 | def uniform(start, end=None, periods=None, freq=None, sc=None): 125 | """ 126 | Instantiates a uniform DateTimeIndex. 127 | 128 | Either end or periods must be specified. 129 | 130 | Parameters 131 | ---------- 132 | start : string, long (millis from epoch), or Pandas Timestamp 133 | end : string, long (millis from epoch), or Pandas Timestamp 134 | periods : int 135 | freq : a frequency object 136 | sc : SparkContext 137 | """ 138 | dtmodule = sc._jvm.com.cloudera.sparkts.__getattr__('DateTimeIndex$').__getattr__('MODULE$') 139 | if freq is None: 140 | raise ValueError("Missing frequency") 141 | elif end is None and periods is None: 142 | raise ValueError("Need an end date or number of periods") 143 | elif end is not None: 144 | return DateTimeIndex(dtmodule.uniform( \ 145 | datetime_to_millis(start), datetime_to_millis(end), freq._jfreq)) 146 | else: 147 | return DateTimeIndex(dtmodule.uniform( \ 148 | datetime_to_millis(start), periods, freq._jfreq)) 149 | 150 | def irregular(timestamps, sc): 151 | """ 152 | Instantiates an irregular DateTimeIndex.
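For example, ``irregular(pd.date_range('2015-04-10', periods=5, freq='3D'), sc)`` builds an index over exactly those five timestamps (this mirrors the usage in test_datetimeindex.py).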
153 | 154 | Parameters 155 | ---------- 156 | timestamps : a Pandas DateTimeIndex, or an array of strings, longs (nanos from epoch), Pandas 157 | Timestamps 158 | sc : SparkContext 159 | """ 160 | dtmodule = sc._jvm.com.cloudera.sparkts.__getattr__('DateTimeIndex$').__getattr__('MODULE$') 161 | arr = sc._gateway.new_array(sc._jvm.long, len(timestamps)) 162 | for i in xrange(len(timestamps)): 163 | arr[i] = datetime_to_millis(timestamps[i]) 164 | return DateTimeIndex(dtmodule.irregular(arr)) 165 | 166 | -------------------------------------------------------------------------------- /python/sparkts/test/__init__.py: -------------------------------------------------------------------------------- 1 | import subprocess as sub 2 | import os 3 | import logging 4 | 5 | def run_cmd(cmd): 6 | """Execute the command and return the output if successful. If 7 | unsuccessful, print the failed command and its output. 8 | """ 9 | try: 10 | out = sub.check_output(cmd, shell=True, stderr=sub.STDOUT) 11 | return out 12 | except sub.CalledProcessError, err: 13 | logging.error("The failed test setup command was [%s]." % err.cmd) 14 | logging.error("The output of the command was [%s]" % err.output) 15 | raise 16 | 17 | CHECK_SPARK_HOME = """ 18 | if [ -z "$SPARK_HOME" ]; then 19 | echo "Error: SPARK_HOME is not set, can't run tests." 20 | exit -1 21 | fi 22 | """ 23 | os.system(CHECK_SPARK_HOME) 24 | 25 | # Dynamically load project root dir and jars. 26 | project_root = os.getcwd() + "/../" 27 | jars = run_cmd("ls %s/target/sparktimeseries*jar-with-dependencies.jar" % project_root) 28 | 29 | # Set environment variables. 30 | os.environ["PYSPARK_SUBMIT_ARGS"] = \ 31 | ("--jars %s --driver-class-path %s pyspark-shell") % (jars, jars) 32 | 33 | os.environ["SPARK_CONF_DIR"] = "%s/test/resources/conf" % os.getcwd() 34 | 35 | -------------------------------------------------------------------------------- /python/sparkts/test/resources/conf/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=ERROR, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.spark-project.jetty=WARN 10 | log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 13 | -------------------------------------------------------------------------------- /python/sparkts/test/test_datetimeindex.py: -------------------------------------------------------------------------------- 1 | from test_utils import PySparkTestCase 2 | from sparkts.datetimeindex import * 3 | import pandas as pd 4 | 5 | class DateTimeIndexTestCase(PySparkTestCase): 6 | def test_uniform(self): 7 | freq = DayFrequency(3, self.sc) 8 | self.assertEqual(freq.days(), 3) 9 | start = '2015-04-10' 10 | index = uniform(start, periods=5, freq=freq, sc=self.sc) 11 | index2 = uniform(start, end='2015-04-22', freq=freq, sc=self.sc) 12 | self.assertEqual(index, index2) 13 | 14 | self.assertEqual(len(index), 5) 15 | self.assertEqual(index.first(), pd.to_datetime('2015-04-10')) 16 | self.assertEqual(index.last(), 
pd.to_datetime('2015-04-22')) 17 | subbydate = index[pd.to_datetime('2015-04-13'):pd.to_datetime('2015-04-19')] 18 | subbyloc = index.islice(1, 4) 19 | self.assertEqual(subbydate, subbyloc) 20 | self.assertEqual(subbydate.first(), pd.to_datetime('2015-04-13')) 21 | self.assertEqual(subbydate.last(), pd.to_datetime('2015-04-19')) 22 | self.assertEqual(subbydate.datetime_at_loc(0), pd.to_datetime('2015-04-13')) 23 | self.assertEqual(subbydate[pd.to_datetime('2015-04-13')], 0) 24 | 25 | def test_irregular(self): 26 | pd_index = pd.date_range('2015-04-10', periods=5, freq='3D') 27 | dt_index = irregular(pd_index, self.sc) 28 | 29 | self.assertEqual(len(dt_index), 5) 30 | self.assertEqual(dt_index.first(), pd.to_datetime('2015-04-10')) 31 | self.assertEqual(dt_index.last(), pd.to_datetime('2015-04-22')) 32 | subbydate = dt_index[pd.to_datetime('2015-04-13'):pd.to_datetime('2015-04-19')] 33 | subbyloc = dt_index.islice(1, 4) 34 | self.assertEqual(subbydate, subbyloc) 35 | self.assertEqual(subbydate.first(), pd.to_datetime('2015-04-13')) 36 | self.assertEqual(subbydate.last(), pd.to_datetime('2015-04-19')) 37 | self.assertEqual(subbydate.datetime_at_loc(0), pd.to_datetime('2015-04-13')) 38 | self.assertEqual(subbydate[pd.to_datetime('2015-04-13')], 0) 39 | 40 | pd_index2 = dt_index.to_pandas_index() 41 | self.assertTrue(pd_index.equals(pd_index2), str(pd_index) + "!=" + str(pd_index2)) 42 | 43 | -------------------------------------------------------------------------------- /python/sparkts/test/test_timeseriesrdd.py: -------------------------------------------------------------------------------- 1 | from test_utils import PySparkTestCase 2 | from sparkts.timeseriesrdd import * 3 | from sparkts.timeseriesrdd import _TimeSeriesSerializer 4 | from sparkts.datetimeindex import * 5 | import pandas as pd 6 | import numpy as np 7 | from unittest import TestCase 8 | from io import BytesIO 9 | from pyspark.sql import SQLContext 10 | 11 | class TimeSeriesSerializerTestCase(TestCase): 12 | def test_times_series_serializer(self): 13 | serializer = _TimeSeriesSerializer() 14 | stream = BytesIO() 15 | series = [('abc', np.array([4.0, 4.0, 5.0])), ('123', np.array([1.0, 2.0, 3.0]))] 16 | serializer.dump_stream(iter(series), stream) 17 | stream.seek(0) 18 | reconstituted = list(serializer.load_stream(stream)) 19 | self.assertEquals(reconstituted[0][0], series[0][0]) 20 | self.assertEquals(reconstituted[1][0], series[1][0]) 21 | self.assertTrue((reconstituted[0][1] == series[0][1]).all()) 22 | self.assertTrue((reconstituted[1][1] == series[1][1]).all()) 23 | 24 | class TimeSeriesRDDTestCase(PySparkTestCase): 25 | def test_time_series_rdd(self): 26 | freq = DayFrequency(1, self.sc) 27 | start = '2015-04-09' 28 | dt_index = uniform(start, periods=10, freq=freq, sc=self.sc) 29 | vecs = [np.arange(0, 10), np.arange(10, 20), np.arange(20, 30)] 30 | rdd = self.sc.parallelize(vecs).map(lambda x: (str(x[0]), x)) 31 | tsrdd = TimeSeriesRDD(dt_index, rdd) 32 | self.assertEquals(tsrdd.count(), 3) 33 | 34 | contents = tsrdd.collectAsMap() 35 | self.assertEquals(len(contents), 3) 36 | self.assertTrue((contents["0"] == np.arange(0, 10)).all()) 37 | self.assertTrue((contents["10"] == np.arange(10, 20)).all()) 38 | self.assertTrue((contents["20"] == np.arange(20, 30)).all()) 39 | 40 | subslice = tsrdd['2015-04-10':'2015-04-15'] 41 | self.assertEquals(subslice.index(), uniform('2015-04-10', periods=6, freq=freq, sc=self.sc)) 42 | contents = subslice.collectAsMap() 43 | self.assertEquals(len(contents), 3) 44 | 
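# The index starts at 2015-04-09 with a one-day frequency, so the date slice from 2015-04-10 through 2015-04-15 covers locations 1 through 6 of each original series, which is why the expected values below are np.arange(1, 7), np.arange(11, 17), and np.arange(21, 27).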
self.assertTrue((contents["0"] == np.arange(1, 7)).all()) 45 | self.assertTrue((contents["10"] == np.arange(11, 17)).all()) 46 | self.assertTrue((contents["20"] == np.arange(21, 27)).all()) 47 | 48 | def test_to_instants(self): 49 | vecs = [np.arange(x, x + 4) for x in np.arange(0, 20, 4)] 50 | labels = ['a', 'b', 'c', 'd', 'e'] 51 | start = '2015-4-9' 52 | dt_index = uniform(start, periods=4, freq=DayFrequency(1, self.sc), sc=self.sc) 53 | rdd = self.sc.parallelize(zip(labels, vecs), 3) 54 | tsrdd = TimeSeriesRDD(dt_index, rdd) 55 | samples = tsrdd.to_instants().collect() 56 | target_dates = ['2015-4-9', '2015-4-10', '2015-4-11', '2015-4-12'] 57 | self.assertEquals([x[0] for x in samples], [pd.Timestamp(x) for x in target_dates]) 58 | self.assertTrue((samples[0][1] == np.arange(0, 20, 4)).all()) 59 | self.assertTrue((samples[1][1] == np.arange(1, 20, 4)).all()) 60 | self.assertTrue((samples[2][1] == np.arange(2, 20, 4)).all()) 61 | self.assertTrue((samples[3][1] == np.arange(3, 20, 4)).all()) 62 | 63 | def test_to_observations(self): 64 | sql_ctx = SQLContext(self.sc) 65 | vecs = [np.arange(x, x + 4) for x in np.arange(0, 20, 4)] 66 | labels = ['a', 'b', 'c', 'd', 'e'] 67 | start = '2015-4-9' 68 | dt_index = uniform(start, periods=4, freq=DayFrequency(1, self.sc), sc=self.sc) 69 | rdd = self.sc.parallelize(zip(labels, vecs), 3) 70 | tsrdd = TimeSeriesRDD(dt_index, rdd) 71 | 72 | obsdf = tsrdd.to_observations_dataframe(sql_ctx) 73 | tsrdd_from_df = time_series_rdd_from_observations( \ 74 | dt_index, obsdf, 'timestamp', 'key', 'value') 75 | 76 | ts1 = tsrdd.collect() 77 | ts1.sort(key = lambda x: x[0]) 78 | ts2 = tsrdd_from_df.collect() 79 | ts2.sort(key = lambda x: x[0]) 80 | self.assertTrue(all([pair[0][0] == pair[1][0] and (pair[0][1] == pair[1][1]).all() \ 81 | for pair in zip(ts1, ts2)])) 82 | 83 | df1 = obsdf.collect() 84 | df1.sort(key = lambda x: x.value) 85 | df2 = tsrdd_from_df.to_observations_dataframe(sql_ctx).collect() 86 | df2.sort(key = lambda x: x.value) 87 | self.assertEquals(df1, df2) 88 | 89 | def test_filter(self): 90 | vecs = [np.arange(x, x + 4) for x in np.arange(0, 20, 4)] 91 | labels = ['a', 'b', 'c', 'd', 'e'] 92 | start = '2015-4-9' 93 | dt_index = uniform(start, periods=4, freq=DayFrequency(1, self.sc), sc=self.sc) 94 | rdd = self.sc.parallelize(zip(labels, vecs), 3) 95 | tsrdd = TimeSeriesRDD(dt_index, rdd) 96 | filtered = tsrdd.filter(lambda x: x[0] == 'a' or x[0] == 'b') 97 | self.assertEquals(filtered.count(), 2) 98 | # assert it has TimeSeriesRDD functionality: 99 | filtered['2015-04-10':'2015-04-15'].count() 100 | 101 | def test_to_pandas_series_rdd(self): 102 | vecs = [np.arange(x, x + 4) for x in np.arange(0, 20, 4)] 103 | labels = ['a', 'b', 'c', 'd', 'e'] 104 | start = '2015-4-9' 105 | dt_index = uniform(start, periods=4, freq=DayFrequency(1, self.sc), sc=self.sc) 106 | rdd = self.sc.parallelize(zip(labels, vecs), 3) 107 | tsrdd = TimeSeriesRDD(dt_index, rdd) 108 | 109 | series_arr = tsrdd.to_pandas_series_rdd().collect() 110 | 111 | pd_index = dt_index.to_pandas_index() 112 | self.assertEquals(len(vecs), len(series_arr)) 113 | for i in xrange(len(vecs)): 114 | self.assertEquals(series_arr[i][0], labels[i]) 115 | self.assertTrue(pd.Series(vecs[i], pd_index).equals(series_arr[i][1])) 116 | 117 | -------------------------------------------------------------------------------- /python/sparkts/test/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from sparkts.utils import 
add_pyspark_path, quiet_py4j 3 | 4 | add_pyspark_path() 5 | quiet_py4j() 6 | 7 | from pyspark.context import SparkContext 8 | 9 | class PySparkTestCase(unittest.TestCase): 10 | def setUp(self): 11 | class_name = self.__class__.__name__ 12 | self.sc = SparkContext('local', class_name) 13 | self.sc._jvm.System.setProperty("spark.ui.showConsoleProgress", "false") 14 | log4j = self.sc._jvm.org.apache.log4j 15 | log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) 16 | 17 | def tearDown(self): 18 | self.sc.stop() 19 | # To avoid Akka rebinding to the same port, since it doesn't unbind 20 | # immediately on shutdown 21 | self.sc._jvm.System.clearProperty("spark.driver.port") 22 | 23 | 24 | -------------------------------------------------------------------------------- /python/sparkts/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import pandas as pd 5 | 6 | from glob import glob 7 | 8 | def add_pyspark_path(): 9 | """Add PySpark to the library path based on the value of SPARK_HOME. """ 10 | 11 | try: 12 | spark_home = os.environ['SPARK_HOME'] 13 | 14 | sys.path.append(os.path.join(spark_home, 'python')) 15 | py4j_src_zip = glob(os.path.join(spark_home, 'python', 16 | 'lib', 'py4j-*-src.zip')) 17 | if len(py4j_src_zip) == 0: 18 | raise ValueError('py4j source archive not found in %s' 19 | % os.path.join(spark_home, 'python', 'lib')) 20 | else: 21 | py4j_src_zip = sorted(py4j_src_zip)[::-1] 22 | sys.path.append(py4j_src_zip[0]) 23 | except KeyError: 24 | logging.error("""SPARK_HOME was not set. please set it. e.g. 25 | SPARK_HOME='/home/...' ./bin/pyspark [program]""") 26 | exit(-1) 27 | except ValueError as e: 28 | logging.error(str(e)) 29 | exit(-1) 30 | 31 | 32 | def quiet_py4j(): 33 | logger = logging.getLogger('py4j') 34 | logger.setLevel(logging.INFO) 35 | 36 | def datetime_to_millis(dt): 37 | """ 38 | Accept a string, Pandas Timestamp, or long, and return millis since the epoch. 39 | """ 40 | if isinstance(dt, pd.Timestamp): 41 | return dt.value / 1000000 42 | elif isinstance(dt, str): 43 | return pd.Timestamp(dt).value / 1000000 44 | elif isinstance(dt, long): 45 | return dt 46 | 47 | raise ValueError 48 | -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- (XML content not preserved in this extract; only the configuration description, "Scalastyle standard configuration", is recoverable.) -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/MultivariateTDistribution.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc.
licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.finance 17 | 18 | import org.apache.commons.math3.distribution.{AbstractMultivariateRealDistribution, 19 | ChiSquaredDistribution} 20 | import org.apache.commons.math3.random.RandomGenerator 21 | 22 | /** 23 | * @param v The number of degrees of freedom in the distribution. 24 | * @param mu The mean of the distribution. 25 | * @param sigma The covariance matrix of the distribution. 26 | */ 27 | class MultivariateTDistribution( 28 | v: Int, 29 | mu: Array[Double], 30 | sigma: Array[Array[Double]], 31 | rand: RandomGenerator) 32 | extends AbstractMultivariateRealDistribution(rand, mu.length) { 33 | 34 | val normal = 35 | new SerializableMultivariateNormalDistribution(rand, new Array[Double](mu.length), sigma) 36 | val chiSquared = new ChiSquaredDistribution(v) 37 | reseedRandomGenerator(rand.nextLong()) 38 | 39 | override def sample(): Array[Double] = { 40 | val sample = normal.sample() 41 | val sqrtW = math.sqrt(chiSquared.sample()) 42 | 43 | var i = 0 44 | while (i < sample.length) { 45 | sample(i) = mu(i) + sample(i) * sqrtW 46 | i += 1 47 | } 48 | sample 49 | } 50 | 51 | override def density(p1: Array[Double]): Double = throw new UnsupportedOperationException 52 | 53 | override def reseedRandomGenerator(seed: Long): Unit = { 54 | normal.reseedRandomGenerator(seed) 55 | chiSquared.reseedRandomGenerator(seed + 1) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/SerializableMultivariateNormalDistribution.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.finance 17 | 18 | import org.apache.commons.math3.distribution.MultivariateNormalDistribution 19 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator} 20 | 21 | /** 22 | * Version of MultivariateNormalDistribution that can be serialized in closures. 
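 * Spark serializes the closures it ships to executors, so a distribution object captured in such a closure must implement java.io.Serializable; this subclass mixes in Serializable and adds a convenience constructor that defaults to a MersenneTwister generator.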
23 | */ 24 | class SerializableMultivariateNormalDistribution( 25 | rand: RandomGenerator, 26 | means: Array[Double], 27 | covariances: Array[Array[Double]]) 28 | extends MultivariateNormalDistribution(rand, means, covariances) with Serializable { 29 | 30 | def this(means: Array[Double], covariances: Array[Array[Double]]) = { 31 | this(new MersenneTwister(), means, covariances) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/Util.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.finance 17 | 18 | import breeze.linalg._ 19 | 20 | import org.apache.commons.math3.random.RandomGenerator 21 | 22 | object Util { 23 | def sampleWithReplacement(values: Array[Double], rand: RandomGenerator, target: Array[Double]) 24 | : Unit = { 25 | for (i <- target.indices) { 26 | target(i) = values(rand.nextInt(values.length)) 27 | } 28 | } 29 | 30 | def transpose(arr: Array[Array[Double]]): Array[Array[Double]] = { 31 | val mat = new Array[Array[Double]](arr.head.length) 32 | for (i <- arr.head.indices) { 33 | mat(i) = arr.map(_(i)) 34 | } 35 | mat 36 | } 37 | 38 | def matToRowArrs(mat: Matrix[Double]): Array[Array[Double]] = { 39 | val arrs = new Array[Array[Double]](mat.rows) 40 | for (r <- 0 until mat.rows) { 41 | arrs(r) = mat(r to r, 0 to mat.cols - 1).toDenseMatrix.toArray 42 | } 43 | arrs 44 | } 45 | 46 | def arrsToMat(arrs: Iterator[Array[Double]]): DenseMatrix[Double] = { 47 | vecArrsToMats(arrs, arrs.length).next() 48 | } 49 | 50 | def vecArrsToMats(vecArrs: Iterator[Array[Double]], chunkSize: Int) 51 | : Iterator[DenseMatrix[Double]] = { 52 | new Iterator[DenseMatrix[Double]] { 53 | def hasNext: Boolean = vecArrs.hasNext 54 | 55 | def next(): DenseMatrix[Double] = { 56 | val firstVec = vecArrs.next() 57 | val vecLen = firstVec.length 58 | val arr = new Array[Double](chunkSize * vecLen) 59 | System.arraycopy(firstVec, 0, arr, 0, vecLen) 60 | 61 | var i = 1 62 | var offs = 0 63 | while (i < chunkSize && vecArrs.hasNext) { 64 | val vec = vecArrs.next() 65 | System.arraycopy(vec, 0, arr, offs, vecLen) 66 | i += 1 67 | offs += vecLen 68 | } 69 | 70 | new DenseMatrix[Double](i, firstVec.length, arr) 71 | } 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/YahooParser.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. 
You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.finance 17 | 18 | import com.cloudera.sparkts.TimeSeries 19 | import com.cloudera.sparkts.TimeSeries._ 20 | 21 | import org.apache.spark.SparkContext 22 | import org.apache.spark.rdd.RDD 23 | 24 | import org.joda.time.DateTime 25 | import org.joda.time.DateTimeZone.UTC 26 | 27 | object YahooParser { 28 | def yahooStringToTimeSeries(text: String, keyPrefix: String = ""): TimeSeries = { 29 | val lines = text.split('\n') 30 | val labels = lines(0).split(',').tail.map(keyPrefix + _) 31 | val samples = lines.tail.map { line => 32 | val tokens = line.split(',') 33 | val dt = new DateTime(tokens.head, UTC) 34 | (dt, tokens.tail.map(_.toDouble)) 35 | }.reverse 36 | timeSeriesFromIrregularSamples(samples, labels) 37 | } 38 | 39 | def yahooFiles(dir: String, sc: SparkContext): RDD[TimeSeries] = { 40 | sc.wholeTextFiles(dir, 3).map { case (path, text) => 41 | YahooParser.yahooStringToTimeSeries(text, path.split('/').last) 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/examples/HistoricalValueAtRiskExample.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.finance.examples 17 | 18 | import breeze.linalg._ 19 | 20 | import com.cloudera.finance.YahooParser 21 | import com.cloudera.finance.Util._ 22 | import com.cloudera.sparkts._ 23 | import com.cloudera.sparkts.DateTimeIndex._ 24 | import com.cloudera.sparkts.TimeSeriesRDD._ 25 | 26 | import com.github.nscala_time.time.Imports._ 27 | 28 | import org.apache.commons.math3.random.{MersenneTwister, RandomGenerator} 29 | import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression 30 | 31 | import org.apache.spark.{SparkConf, SparkContext} 32 | 33 | object HistoricalValueAtRiskExample { 34 | def main(args: Array[String]): Unit = { 35 | // Read parameters 36 | val factorsDir = if (args.length > 0) args(0) else "factors/" 37 | val instrumentsDir = if (args.length > 1) args(1) else "instruments/" 38 | val numTrials = if (args.length > 2) args(2).toInt else 10000 39 | val parallelism = if (args.length > 3) args(3).toInt else 10 40 | val horizon = if (args.length > 4) args(4).toInt else 1 41 | 42 | // Initialize Spark 43 | val conf = new SparkConf().setMaster("local").setAppName("Historical VaR") 44 | val sc = new SparkContext(conf) 45 | 46 | // Load the data into a TimeSeriesRDD where each series holds 1-day log returns 47 | def loadTS(inputDir: String, lower: DateTime, upper: DateTime): TimeSeriesRDD = { 48 | val histories = YahooParser.yahooFiles(inputDir, sc) 49 | histories.cache() 50 | val start = histories.map(_.index.first).takeOrdered(1).head 51 | val end = histories.map(_.index.last).top(1).head 52 | val dtIndex = uniform(start, end, 1.businessDays) 53 | val tsRdd = timeSeriesRDD(dtIndex, histories). 54 | filter(_._1.endsWith("csvClose")). 55 | filterStartingBefore(lower).filterEndingAfter(upper) 56 | tsRdd.fill("linear"). 57 | slice(lower, upper). 58 | returnRates(). 
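        // returnRates() yields simple per-period rates of return; the map to log(1 + x) below converts them to log returns, which can be summed across the days of the risk horizon.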
59 | mapSeries(_.map(x => math.log(1 + x))) 60 | } 61 | 62 | val year2000 = nextBusinessDay(new DateTime("2008-1-1")) 63 | val year2010 = nextBusinessDay(year2000 + 5.years) 64 | 65 | val instrumentReturns = loadTS(instrumentsDir, year2000, year2010) 66 | val factorReturns = loadTS(factorsDir, year2000, year2010).collectAsTimeSeries() 67 | 68 | // Fit an AR(1) + GARCH(1, 1) model to each factor 69 | val garchModels = factorReturns.mapValues(ARGARCH.fitModel).toMap 70 | val iidFactorReturns = factorReturns.mapSeriesWithKey { case (symbol, series) => 71 | val model = garchModels(symbol) 72 | model.removeTimeDependentEffects(series, DenseVector.zeros[Double](series.length)) 73 | } 74 | 75 | // Generate an RDD of simulations 76 | val seeds = sc.parallelize(0 until numTrials, parallelism) 77 | seeds.map { seed => 78 | val rand = new MersenneTwister(seed) 79 | val factorPaths = simulatedFactorReturns(horizon, rand, iidFactorReturns, garchModels) 80 | } 81 | 82 | // val factorsDist = new FilteredHistoricalFactorDistribution(rand, iidFactorReturns.toArray, 83 | // garchModels.asInstanceOf[Array[TimeSeriesFilter]]) 84 | // val returns = simulationReturns(0L, factorsDist, numTrials, parallelism, sc, 85 | // instrumentReturnsModel) 86 | // returns.cache() 87 | // 88 | // // Calculate VaR and expected shortfall 89 | // val pValues = Array(.01, .03, .05) 90 | // val valueAtRisks = valueAtRisk(returns, pValues) 91 | // println(s"Value at risk at ${pValues.mkString(",")}: ${valueAtRisks.mkString(",")}") 92 | // 93 | // val expectedShortfalls = expectedShortfall(returns, pValues) 94 | // println(s"Expected shortfall at ${pValues.mkString(",")}: ${expectedShortfalls.mkString(",")}") 95 | } 96 | 97 | /** 98 | * Generates paths of factor returns each with the given number of days. 99 | */ 100 | def simulatedFactorReturns( 101 | nDays: Int, 102 | rand: RandomGenerator, 103 | iidSeries: TimeSeries, 104 | models: Map[String, TimeSeriesModel]): Matrix[Double] = { 105 | val mat = DenseMatrix.zeros[Double](nDays, iidSeries.data.cols) 106 | (0 until nDays).foreach { i => 107 | mat(i, ::) := iidSeries.data(rand.nextInt(iidSeries.data.rows), ::) 108 | } 109 | (0 until models.size).foreach { i => 110 | models(iidSeries.keys(i)).addTimeDependentEffects(mat(::, i), mat(::, i)) 111 | } 112 | mat 113 | } 114 | 115 | /** 116 | * Fits a model for each instrument that predicts its returns based on the returns of the factors. 117 | */ 118 | def fitInstrumentReturnsModels(instrumentReturns: TimeSeriesRDD, factorReturns: TimeSeries) { 119 | // Fit factor return -> instrument return predictive models 120 | val linearModels = instrumentReturns.mapValues { series => 121 | val withLag = factorReturns.slice(2 until series.length). 122 | union(series(1 until series.length - 1), "instrlag1"). 
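        // The regressors for each instrument are the factor returns plus the instrument's own returns lagged by one and two periods (the "instrlag1" and "instrlag2" columns assembled here).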
123 | union(series(0 until series.length - 2), "instrlag2") 124 | 125 | // Get factors in a form that we can feed into Commons Math linear regression 126 | val factorObservations = matToRowArrs(withLag.data) 127 | // Run the regression 128 | val regression = new OLSMultipleLinearRegression() 129 | regression.newSampleData(series.toArray, factorObservations) 130 | regression.estimateRegressionParameters() 131 | }.map(_._2).collect() 132 | arrsToMat(linearModels.iterator) 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/examples/MonteCarloValueAtRiskExample.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.finance.examples 17 | 18 | import com.cloudera.finance.Util._ 19 | 20 | import com.cloudera.finance.{SerializableMultivariateNormalDistribution, YahooParser} 21 | import com.cloudera.sparkts.TimeSeriesRDD 22 | import com.cloudera.sparkts.DateTimeIndex._ 23 | import com.cloudera.sparkts.TimeSeriesRDD._ 24 | 25 | import com.github.nscala_time.time.Imports._ 26 | 27 | import org.apache.commons.math3.stat.correlation.Covariance 28 | import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression 29 | 30 | import org.apache.spark.{SparkConf, SparkContext} 31 | 32 | 33 | object MonteCarloValueAtRiskExample { 34 | def main(args: Array[String]): Unit = { 35 | // Read parameters 36 | val numTrials = if (args.length > 0) args(0).toInt else 10000 37 | val parallelism = if (args.length > 1) args(1).toInt else 10 38 | val factorsDir = if (args.length > 2) args(2) else "factors/" 39 | val instrumentsDir = if (args.length > 3) args(3) else "instruments/" 40 | 41 | // Initialize Spark 42 | val conf = new SparkConf().setMaster("local").setAppName("Monte Carlo VaR") 43 | val sc = new SparkContext(conf) 44 | 45 | 46 | def loadTS(inputDir: String, lower: DateTime, upper: DateTime): TimeSeriesRDD = { 47 | val histories = YahooParser.yahooFiles(instrumentsDir, sc) 48 | histories.cache() 49 | val start = histories.map(_.index.first).takeOrdered(1).head 50 | val end = histories.map(_.index.last).top(1).head 51 | val dtIndex = uniform(start, end, 1.businessDays) 52 | val tsRdd = timeSeriesRDD(dtIndex, histories). 53 | filter(_._1.endsWith("Open")). 
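        // keep only series whose observations fully cover the [lower, upper] window
        // (note: yahooFiles above is called with instrumentsDir rather than the inputDir parameter)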
54 | filterStartingBefore(lower).filterEndingAfter(upper) 55 | tsRdd.differences(10) 56 | } 57 | 58 | val year2000 = nextBusinessDay(new DateTime("2000-1-1")) 59 | val year2010 = nextBusinessDay(year2000 + 10.years) 60 | 61 | val instrumentReturns = loadTS(instrumentsDir, year2000, year2010) 62 | val factorReturns = loadTS(factorsDir, year2000, year2010).collectAsTimeSeries() 63 | 64 | // Fit factor return -> instrument return predictive models 65 | val factorObservations = matToRowArrs(factorReturns.data) 66 | val linearModels = instrumentReturns.mapValues { series => 67 | val regression = new OLSMultipleLinearRegression() 68 | regression.newSampleData(series.toArray, factorObservations) 69 | regression.estimateRegressionParameters() 70 | }.map(_._2).collect() 71 | // val instrumentReturnsModel = new LinearInstrumentReturnsModel(arrsToMat(linearModels.iterator)) 72 | // 73 | // 74 | // // Generate an RDD of simulations 75 | // val factorCov = new Covariance(factorObservations).getCovarianceMatrix().getData() 76 | // val factorMeans = factorReturns.univariateSeriesIterator. 77 | // map(factor => factor.sum / factor.size).toArray 78 | // val factorsDist = new SerializableMultivariateNormalDistribution(factorMeans, factorCov) 79 | // val returns = simulationReturns(0L, factorsDist, numTrials, parallelism, sc, 80 | // instrumentReturnsModel) 81 | // returns.cache() 82 | // 83 | // // Calculate VaR and expected shortfall 84 | // val pValues = Array(.01, .03, .05) 85 | // val valueAtRisks = valueAtRisk(returns, pValues) 86 | // println(s"Value at risk at ${pValues.mkString(",")}: ${valueAtRisks.mkString(",")}") 87 | // 88 | // val expectedShortfalls = expectedShortfall(returns, pValues) 89 | // println(s"Expected shortfall at ${pValues.mkString(",")}: ${expectedShortfalls.mkString(",")}") 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/examples/SimpleTickDataExample.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.finance.examples 17 | 18 | import com.cloudera.finance.YahooParser 19 | import com.cloudera.sparkts.DateTimeIndex._ 20 | import com.cloudera.sparkts.TimeSeries 21 | import com.cloudera.sparkts.UnivariateTimeSeries._ 22 | import com.cloudera.sparkts.TimeSeriesRDD._ 23 | import com.cloudera.sparkts.TimeSeriesStatisticalTests._ 24 | 25 | import com.github.nscala_time.time.Imports._ 26 | 27 | import org.apache.spark.{SparkConf, SparkContext} 28 | import org.apache.spark.rdd.RDD 29 | 30 | object SimpleTickDataExample { 31 | def main(args: Array[String]): Unit = { 32 | val inputDir = args(0) 33 | 34 | val conf = new SparkConf().setMaster("local") 35 | val sc = new SparkContext(conf) 36 | 37 | // Load and parse the data 38 | val seriesByFile: RDD[TimeSeries] = YahooParser.yahooFiles(inputDir, sc) 39 | seriesByFile.cache() 40 | 41 | // Merge the series from individual files into a TimeSeriesRDD 42 | val start = seriesByFile.map(_.index.first).takeOrdered(1).head 43 | val end = seriesByFile.map(_.index.last).top(1).head 44 | val dtIndex = uniform(start, end, 1.businessDays) 45 | val tsRdd = timeSeriesRDD(dtIndex, seriesByFile) 46 | println(s"Num time series: ${tsRdd.count()}") 47 | 48 | // Which symbols do we have the oldest data on? 49 | val oldestData: Array[(Int, String)] = tsRdd.mapValues(trimLeading(_).length).map(_.swap).top(5) 50 | println(oldestData.mkString(",")) 51 | 52 | // Only look at open prices, and filter all symbols that don't span the 2000's 53 | val opensRdd = tsRdd.filter(_._1.endsWith("Open")) 54 | val year2000 = nextBusinessDay(new DateTime("2000-1-1")) 55 | val year2010 = nextBusinessDay(year2000 + 10.years) 56 | val filteredRdd = opensRdd.filterStartingBefore(year2000).filterEndingAfter(year2010) 57 | println(s"Remaining after filtration: ${filteredRdd.count()}") 58 | 59 | // Impute missing data with linear interpolation 60 | val filledRdd = filteredRdd.fill("linear") 61 | 62 | // Slice to the 2000's 63 | val slicedRdd = filledRdd.slice(year2000, year2010) 64 | slicedRdd.cache() 65 | 66 | // Find the series with the largest serial correlations 67 | val durbinWatsonStats: RDD[(String, Double)] = slicedRdd.mapValues(dwtest) 68 | durbinWatsonStats.map(_.swap).top(20) 69 | 70 | // Remove serial correlations 71 | val iidRdd = slicedRdd.mapSeries(series => ar(series, 1).removeTimeDependentEffects(series)) 72 | 73 | // Regress a stock against all the others 74 | val samples = iidRdd.toInstants() 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/finance/risk/FilteredHistoricalFactorDistribution.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.finance.risk 17 | 18 | import breeze.linalg._ 19 | 20 | 21 | import org.apache.commons.math3.distribution.MultivariateRealDistribution 22 | import org.apache.commons.math3.random.RandomGenerator 23 | import com.cloudera.sparkts.TimeSeriesFilter 24 | 25 | class FilteredHistoricalFactorDistribution( 26 | rand: RandomGenerator, 27 | iidHistories: Array[Vector[Double]], 28 | filters: Array[TimeSeriesFilter]) extends MultivariateRealDistribution with Serializable { 29 | 30 | def reseedRandomGenerator(seed: Long): Unit = rand.setSeed(seed) 31 | 32 | def density(p1: Array[Double]): Double = throw new UnsupportedOperationException 33 | 34 | def getDimension: Int = iidHistories.length 35 | 36 | def sample(): Array[Double] = { 37 | val numHistories = iidHistories.length 38 | val numTimePoints = iidHistories.head.length 39 | val resampledHistories = Array.ofDim[Double](numHistories, numTimePoints) 40 | for (i <- 0 until numTimePoints) { 41 | val timePoint = rand.nextInt(numTimePoints) 42 | for (j <- 0 until numHistories) { 43 | resampledHistories(j)(i) = iidHistories(j)(timePoint) 44 | } 45 | } 46 | 47 | for (i <- 0 until numHistories) { 48 | filters(i).filter(resampledHistories(i), resampledHistories(i)) 49 | } 50 | 51 | val timePoint = rand.nextInt(numTimePoints) 52 | resampledHistories.map(_(timePoint)) 53 | } 54 | 55 | def sample(p1: Int): Array[Array[Double]] = throw new UnsupportedOperationException 56 | } 57 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/Autoregression.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | 20 | import com.cloudera.finance.Util.matToRowArrs 21 | 22 | import org.apache.commons.math3.random.RandomGenerator 23 | import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression 24 | 25 | object Autoregression { 26 | /** 27 | * Fits an AR(1) model to the given time series 28 | */ 29 | def fitModel(ts: Vector[Double]): ARModel = fitModel(ts, 1) 30 | 31 | /** 32 | * Fits an AR(n) model to a given time series 33 | * @param ts Data to fit 34 | * @param maxLag The autoregressive factor, terms t - 1 through t - maxLag are included 35 | * @param noIntercept A boolean to indicate if the regression should be run without an intercept, 36 | * the default is set to false, so that the OLS includes an intercept term 37 | * @return AR(n) model 38 | */ 39 | def fitModel(ts: Vector[Double], maxLag: Int, noIntercept: Boolean = false): ARModel = { 40 | // This is loosely based off of the implementation in statsmodels: 41 | // https://github.com/statsmodels/statsmodels/blob/master/statsmodels/tsa/ar_model.py 42 | 43 | // Make left hand side 44 | val Y = ts(maxLag until ts.length) 45 | // Make lagged right hand side 46 | val X = Lag.lagMatTrimBoth(ts, maxLag) 47 | 48 | val regression = new OLSMultipleLinearRegression() 49 | regression.setNoIntercept(noIntercept) // drop intercept in regression 50 | regression.newSampleData(Y.toArray, matToRowArrs(X)) 51 | val params = regression.estimateRegressionParameters() 52 | val (c, coeffs) = if (noIntercept) (0.0, params) else (params.head, params.tail) 53 | new ARModel(c, coeffs) 54 | } 55 | } 56 | 57 | class ARModel(val c: Double, val coefficients: Array[Double]) extends TimeSeriesModel { 58 | 59 | def this(c: Double, coef: Double) = this(c, Array(coef)) 60 | 61 | def removeTimeDependentEffects(ts: Vector[Double], destTs: Vector[Double] = null): Vector[Double] = { 62 | val dest = if (destTs == null) DenseVector.zeros[Double](ts.length) else destTs 63 | var i = 0 64 | while (i < ts.length) { 65 | dest(i) = ts(i) - c 66 | var j = 0 67 | while (j < coefficients.length && i - j - 1 >= 0) { 68 | dest(i) -= ts(i - j - 1) * coefficients(j) 69 | j += 1 70 | } 71 | i += 1 72 | } 73 | dest 74 | } 75 | 76 | def addTimeDependentEffects(ts: Vector[Double], destTs: Vector[Double]): Vector[Double] = { 77 | val dest = if (destTs == null) DenseVector.zeros[Double](ts.length) else destTs 78 | var i = 0 79 | while (i < ts.length) { 80 | dest(i) = c + ts(i) 81 | var j = 0 82 | while (j < coefficients.length && i - j - 1 >= 0) { 83 | dest(i) += dest(i - j - 1) * coefficients(j) 84 | j += 1 85 | } 86 | i += 1 87 | } 88 | dest 89 | } 90 | 91 | def sample(n: Int, rand: RandomGenerator): Vector[Double] = { 92 | val vec = new DenseVector(Array.fill[Double](n)(rand.nextGaussian())) 93 | addTimeDependentEffects(vec, vec) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/EWMA.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. 
You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | 20 | import org.apache.commons.math3.analysis.{MultivariateFunction, MultivariateVectorFunction} 21 | import org.apache.commons.math3.optim.{InitialGuess, MaxEval, MaxIter, SimpleValueChecker} 22 | import org.apache.commons.math3.optim.nonlinear.scalar.{GoalType, ObjectiveFunction, 23 | ObjectiveFunctionGradient} 24 | import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer 25 | 26 | /** 27 | * Fits an Exponentially Weight Moving Average model (EWMA) (aka. Simple Exponential Smoothing) to 28 | * a time series. The model is defined as S_t = (1 - a) * X_t + a * S_{t - 1}, where a is the 29 | * smoothing parameter, X is the original series, and S is the smoothed series. For more 30 | * information, please see https://en.wikipedia.org/wiki/Exponential_smoothing. 31 | */ 32 | object EWMA { 33 | /** 34 | * Fits an EWMA model to a time series. Uses the first point in the time series as a starting 35 | * value. Uses sum squared error as an objective function to optimize to find smoothing parameter 36 | * The model for EWMA is recursively defined as S_t = (1 - a) * X_t + a * S_{t-1}, where 37 | * a is the smoothing parameter, X is the original series, and S is the smoothed series 38 | * Note that the optimization is performed as unbounded optimization, although in its formal 39 | * definition the smoothing parameter is <= 1, which corresponds to an inequality bounded 40 | * optimization. 
Given this, the resulting smoothing parameter should always be sanity checked 41 | * https://en.wikipedia.org/wiki/Exponential_smoothing 42 | * @param ts the time series to which we want to fit an EWMA model 43 | * @return EWMA model 44 | */ 45 | def fitModel(ts: Vector[Double]): EWMAModel = { 46 | val optimizer = new NonLinearConjugateGradientOptimizer( 47 | NonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, 48 | new SimpleValueChecker(1e-6, 1e-6)) 49 | val gradient = new ObjectiveFunctionGradient(new MultivariateVectorFunction() { 50 | def value(params: Array[Double]): Array[Double] = { 51 | val g = new EWMAModel(params(0)).gradient(ts) 52 | Array(g) 53 | } 54 | }) 55 | val objectiveFunction = new ObjectiveFunction(new MultivariateFunction() { 56 | def value(params: Array[Double]): Double = { 57 | new EWMAModel(params(0)).sse(ts) 58 | } 59 | }) 60 | // optimization parameters 61 | val initGuess = new InitialGuess(Array(.94)) 62 | val maxIter = new MaxIter(10000) 63 | val maxEval = new MaxEval(10000) 64 | val goal = GoalType.MINIMIZE 65 | // optimization step 66 | val optimal = optimizer.optimize(objectiveFunction, goal, gradient, initGuess, maxIter, maxEval) 67 | val params = optimal.getPoint 68 | new EWMAModel(params(0)) 69 | } 70 | } 71 | 72 | class EWMAModel(val smoothing: Double) extends TimeSeriesModel { 73 | 74 | /** 75 | * Calculates the SSE for a given timeseries ts given the smoothing parameter of the current model 76 | * The forecast for the observation at period t + 1 is the smoothed value at time t 77 | * Source: http://people.duke.edu/~rnau/411avg.htm 78 | * @param ts the time series to fit a EWMA model to 79 | * @return Sum Squared Error 80 | */ 81 | private[sparkts] def sse(ts: Vector[Double]): Double = { 82 | val n = ts.length 83 | val smoothed = new DenseVector(Array.fill(n)(0.0)) 84 | addTimeDependentEffects(ts, smoothed) 85 | 86 | var i = 0 87 | var error = 0.0 88 | var sqrErrors = 0.0 89 | while (i < n - 1) { 90 | error = ts(i + 1) - smoothed(i) 91 | sqrErrors += error * error 92 | i += 1 93 | } 94 | 95 | sqrErrors 96 | } 97 | 98 | /** 99 | * Calculates the gradient of the SSE cost function for our EWMA model 100 | * @param ts 101 | * @return gradient 102 | */ 103 | private[sparkts] def gradient(ts: Vector[Double]): Double = { 104 | val n = ts.length 105 | val smoothed = new DenseVector(Array.fill(n)(0.0)) 106 | addTimeDependentEffects(ts, smoothed) 107 | 108 | var error = 0.0 109 | var prevSmoothed = ts(0) 110 | var prevDSda = 0.0 // derivative of the EWMA function at time t - 1: (d S(t - 1)/ d smoothing) 111 | var dSda = 0.0 // derivative of the EWMA function at time t: (d S(t) / d smoothing) 112 | var dJda = 0.0 // derivative of our SSE cost function 113 | var i = 0 114 | 115 | while (i < n - 1) { 116 | error = ts(i + 1) - smoothed(i) 117 | dSda = ts(i) - prevSmoothed + (1 - smoothing) * prevDSda 118 | dJda += error * dSda 119 | prevDSda = dSda 120 | prevSmoothed = smoothed(i) 121 | i += 1 122 | } 123 | 2 * dJda 124 | } 125 | 126 | override def removeTimeDependentEffects(ts: Vector[Double], dest: Vector[Double] = null) 127 | : Vector[Double] = { 128 | dest(0) = ts(0) // by definition in our model S_0 = X_0 129 | 130 | for (i <- 1 until ts.length) { 131 | dest(i) = (ts(i) - (1 - smoothing) * ts(i - 1)) / smoothing 132 | } 133 | dest 134 | } 135 | 136 | override def addTimeDependentEffects(ts: Vector[Double], dest: Vector[Double]): Vector[Double] = { 137 | dest(0) = ts(0) // by definition in our model S_0 = X_0 138 | 139 | for (i <- 1 until ts.length) { 140 
| dest(i) = smoothing * ts(i) + (1 - smoothing) * dest(i - 1) 141 | } 142 | dest 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/EasyPlot.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | import breeze.plot._ 20 | 21 | import org.apache.commons.math3.distribution.NormalDistribution 22 | 23 | object EasyPlot { 24 | def ezplot(vec: Vector[Double], style: Char): Figure = { 25 | val f = Figure() 26 | val p = f.subplot(0) 27 | p += plot((0 until vec.length).map(_.toDouble).toArray, vec, style = style) 28 | f 29 | } 30 | 31 | def ezplot(vec: Vector[Double]): Figure = ezplot(vec, '-') 32 | 33 | def ezplot(arr: Array[Double], style: Char): Figure = { 34 | val f = Figure() 35 | val p = f.subplot(0) 36 | p += plot(arr.indices.map(_.toDouble).toArray, arr, style = style) 37 | f 38 | } 39 | 40 | def ezplot(arr: Array[Double]): Figure = ezplot(arr, '-') 41 | 42 | def ezplot(vecs: Seq[Vector[Double]], style: Char): Figure = { 43 | val f = Figure() 44 | val p = f.subplot(0) 45 | val first = vecs.head 46 | vecs.foreach { vec => 47 | p += plot((0 until first.length).map(_.toDouble).toArray, vec, style) 48 | } 49 | f 50 | } 51 | 52 | def ezplot(vecs: Seq[Vector[Double]]): Figure = ezplot(vecs, '-') 53 | 54 | /** 55 | * Autocorrelation function plot 56 | * @param data array of data to analyze 57 | * @param maxLag maximum lag for autocorrelation 58 | * @param conf confidence bounds to display 59 | */ 60 | def acfPlot(data: Array[Double], maxLag: Int, conf: Double = 0.95): Figure = { 61 | // calculate correlations and confidence bound 62 | val autoCorrs = UnivariateTimeSeries.autocorr(data, maxLag) 63 | val confVal = calcConfVal(conf, data.length) 64 | 65 | // Basic plot information 66 | val f = Figure() 67 | val p = f.subplot(0) 68 | p.title = "Autocorrelation function" 69 | p.xlabel = "Lag" 70 | p.ylabel = "Autocorrelation" 71 | drawCorrPlot(autoCorrs, confVal, p) 72 | f 73 | } 74 | 75 | /** 76 | * Partial autocorrelation function plot 77 | * @param data array of data to analyze 78 | * @param maxLag maximum lag for partial autocorrelation function 79 | * @param conf confidence bounds to display 80 | */ 81 | def pacfPlot(data: Array[Double], maxLag: Int, conf: Double = 0.95): Figure = { 82 | // create AR(maxLag) model, retrieve coefficients and calculate confidence bound 83 | val model = Autoregression.fitModel(new DenseVector(data), maxLag) 84 | val pCorrs = model.coefficients // partial autocorrelations are the coefficients in AR(n) model 85 | val confVal = calcConfVal(conf, data.length) 86 | 87 | // Basic plot information 88 | val f = Figure() 89 | val p = f.subplot(0) 90 | p.title = "Partial autocorrelation function" 91 | p.xlabel = "Lag" 92 | p.ylabel = "Partial Autocorrelation" 93 | 
drawCorrPlot(pCorrs, confVal, p) 94 | f 95 | } 96 | 97 | private[sparkts] def calcConfVal(conf:Double, n: Int): Double = { 98 | val stdNormDist = new NormalDistribution(0, 1) 99 | val pVal = (1 - conf) / 2.0 100 | stdNormDist.inverseCumulativeProbability(1 - pVal) / Math.sqrt(n) 101 | } 102 | 103 | private[sparkts] def drawCorrPlot(corrs: Array[Double], confVal: Double, p: Plot): Unit = { 104 | // make decimal ticks visible 105 | p.setYAxisDecimalTickUnits() 106 | // plot correlations as vertical lines 107 | val verticalLines = corrs.zipWithIndex.map { case (corr, ix) => 108 | (Array(ix.toDouble + 1, ix.toDouble + 1), Array(0, corr)) 109 | } 110 | verticalLines.foreach { case (xs, ys) => p += plot(xs, ys) } 111 | // plot confidence intervals as horizontal lines 112 | val n = corrs.length 113 | Array(confVal, -1 * confVal).foreach { conf => 114 | val xs = (0 to n).toArray.map(_.toDouble) 115 | val ys = Array.fill(n + 1)(conf) 116 | p += plot(xs, ys, '-', colorcode = "red") 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/Frequency.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import com.github.nscala_time.time.Imports._ 19 | 20 | import org.joda.time.{Days, Hours} 21 | 22 | class BusinessDayRichInt(n: Int) { 23 | def businessDays: BusinessDayFrequency = new BusinessDayFrequency(n) 24 | } 25 | 26 | /** 27 | * A frequency for a uniform index. 28 | */ 29 | trait Frequency extends Serializable { 30 | /** 31 | * Advances the given DateTime by this frequency n times. 32 | */ 33 | def advance(dt: DateTime, n: Int): DateTime 34 | 35 | /** 36 | * The number of times this frequency occurs between the two DateTimes, rounded down. 37 | */ 38 | def difference(dt1: DateTime, dt2: DateTime): Int 39 | } 40 | 41 | 42 | abstract class PeriodFrequency(val period: Period) extends Frequency { 43 | def advance(dt: DateTime, n: Int): DateTime = dt + (n * period) 44 | 45 | override def equals(other: Any): Boolean = { 46 | other match { 47 | case frequency: PeriodFrequency => frequency.period == period 48 | case _ => false 49 | } 50 | } 51 | } 52 | 53 | class DayFrequency(val days: Int) extends PeriodFrequency(days.days) { 54 | 55 | def difference(dt1: DateTime, dt2: DateTime): Int = Days.daysBetween(dt1, dt2).getDays / days 56 | 57 | override def toString: String = s"days $days" 58 | } 59 | 60 | class HourFrequency(val hours: Int) extends PeriodFrequency(hours.hours) { 61 | 62 | def difference(dt1: DateTime, dt2: DateTime): Int = Hours.hoursBetween(dt1, dt2).getHours / hours 63 | 64 | override def toString: String = s"hours $hours" 65 | } 66 | 67 | class BusinessDayFrequency(val days: Int) extends Frequency { 68 | /** 69 | * Advances the given DateTime by n business days. 
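   * A rough sketch of the weekend-skipping arithmetic (the example date is assumed here,
   * not taken from the codebase):
   * {{{
   *   val friday = new DateTime("2015-07-10")         // a Friday
   *   new BusinessDayFrequency(1).advance(friday, 1)  // expected: Monday 2015-07-13
   * }}}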
70 | */ 71 | def advance(dt: DateTime, n: Int): DateTime = { 72 | val dayOfWeek = dt.getDayOfWeek 73 | if (dayOfWeek > 5) { 74 | throw new IllegalArgumentException(s"$dt is not a business day") 75 | } 76 | val totalDays = n * days 77 | val standardWeekendDays = (totalDays / 5) * 2 78 | val remaining = totalDays % 5 79 | val extraWeekendDays = if (dayOfWeek + remaining > 5) 2 else 0 80 | dt + (totalDays + standardWeekendDays + extraWeekendDays).days 81 | } 82 | 83 | def difference(dt1: DateTime, dt2: DateTime): Int = { 84 | if (dt2 < dt1) { 85 | return -difference(dt2, dt1) 86 | } 87 | val daysBetween = Days.daysBetween(dt1, dt2).getDays 88 | val dayOfWeek1 = dt1.getDayOfWeek 89 | if (dayOfWeek1 > 5) { 90 | throw new IllegalArgumentException(s"$dt1 is not a business day") 91 | } 92 | val standardWeekendDays = (daysBetween / 7) * 2 93 | val remaining = daysBetween % 7 94 | val extraWeekendDays = if (dayOfWeek1 + remaining > 5) 2 else 0 95 | (daysBetween - standardWeekendDays - extraWeekendDays) / days 96 | } 97 | 98 | override def equals(other: Any): Boolean = { 99 | other match { 100 | case frequency: BusinessDayFrequency => frequency.days == days 101 | case _ => false 102 | } 103 | } 104 | 105 | override def toString: String = s"businessDays $days" 106 | } 107 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/Lag.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | 20 | object Lag { 21 | /** 22 | * Makes a lag matrix from the given time series with the given lag, trimming both rows and 23 | * columns so that every element in the matrix is full. 24 | */ 25 | private[sparkts] def lagMatTrimBoth(x: Array[Double], maxLag: Int): Array[Array[Double]] = { 26 | lagMatTrimBoth(x, maxLag, false) 27 | } 28 | 29 | /** 30 | * Makes a lag matrix from the given time series with the given lag, trimming both rows and 31 | * columns so that every element in the matrix is full. 32 | */ 33 | private[sparkts] def lagMatTrimBoth(x: Array[Double], maxLag: Int, includeOriginal: Boolean) 34 | : Array[Array[Double]] = { 35 | val numObservations = x.length 36 | val numRows = numObservations - maxLag 37 | val numCols = maxLag + (if (includeOriginal) 1 else 0) 38 | val lagMat = Array.ofDim[Double](numRows, numCols) 39 | 40 | val initialLag = if (includeOriginal) 0 else 1 41 | 42 | for (r <- 0 until numRows) { 43 | for (c <- initialLag to maxLag) { 44 | lagMat(r)(c - initialLag) = x(r + maxLag - c) 45 | } 46 | } 47 | lagMat 48 | } 49 | 50 | /** 51 | * Makes a lag matrix from the given time series with the given lag, trimming both rows and 52 | * columns so that every element in the matrix is full. 
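   * A small sketch of the trimming, with assumed data (not from the test suite):
   * {{{
   *   lagMatTrimBoth(DenseVector(1.0, 2.0, 3.0, 4.0, 5.0), 2)
   *   // yields a 3 x 2 matrix whose rows are (2, 1), (3, 2), (4, 3):
   *   // one row per observation 3, 4, 5, with columns [lag 1, lag 2]
   * }}}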
53 | */ 54 | private[sparkts] def lagMatTrimBoth(x: Vector[Double], maxLag: Int): Matrix[Double] = { 55 | lagMatTrimBoth(x, maxLag, false) 56 | } 57 | 58 | /** 59 | * Makes a lag matrix from the given time series with the given lag, trimming both rows and 60 | * columns so that every element in the matrix is full. 61 | */ 62 | private[sparkts] def lagMatTrimBoth(x: Vector[Double], maxLag: Int, includeOriginal: Boolean) 63 | : Matrix[Double] = { 64 | val numObservations = x.size 65 | val numRows = numObservations - maxLag 66 | val numCols = maxLag + (if (includeOriginal) 1 else 0) 67 | val lagMat = new DenseMatrix[Double](numRows, numCols) 68 | 69 | val initialLag = if (includeOriginal) 0 else 1 70 | 71 | for (r <- 0 until numRows) { 72 | for (c <- initialLag to maxLag) { 73 | lagMat(r, (c - initialLag)) = x(r + maxLag - c) 74 | } 75 | } 76 | lagMat 77 | } 78 | 79 | private[sparkts] def lagMatTrimBoth( 80 | x: Vector[Double], 81 | outputMat: DenseMatrix[Double], 82 | maxLag: Int, 83 | includeOriginal: Boolean): Unit = { 84 | val numObservations = x.size 85 | val numRows = numObservations - maxLag 86 | 87 | val initialLag = if (includeOriginal) 0 else 1 88 | 89 | for (r <- 0 until numRows) { 90 | for (c <- initialLag to maxLag) { 91 | outputMat(r, (c - initialLag)) = x(r + maxLag - c) 92 | } 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/PythonConnector.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import java.nio.ByteBuffer 19 | 20 | import breeze.linalg.{DenseVector, Vector} 21 | 22 | import org.apache.spark.api.java.function.Function 23 | 24 | import org.joda.time.DateTime 25 | 26 | /** 27 | * This file contains utilities used by the spark-timeseries Python bindings to communicate with 28 | * the JVM. BytesToKeyAndSeries and KeyAndSeriesToBytes write and read bytes in the format 29 | * read and written by the Python TimeSeriesSerializer class. 
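 * As read from the serializers below, a (key, series) record is laid out as
 * [4-byte key length][UTF-8 key bytes][4-byte series length][8-byte doubles ...],
 * and an (instant, vector) record as [8-byte millis][4-byte length][8-byte doubles ...].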
30 | */ 31 | 32 | private object PythonConnector { 33 | val INT_SIZE = 4 34 | val DOUBLE_SIZE = 8 35 | val LONG_SIZE = 8 36 | 37 | def putVector(buf: ByteBuffer, vec: Vector[Double]): Unit = { 38 | buf.putInt(vec.length) 39 | var i = 0 40 | while (i < vec.length) { 41 | buf.putDouble(vec(i)) 42 | i += 1 43 | } 44 | } 45 | } 46 | 47 | private class BytesToKeyAndSeries extends Function[Array[Byte], (String, Vector[Double])] { 48 | override def call(arr: Array[Byte]): (String, Vector[Double]) = { 49 | val buf = ByteBuffer.wrap(arr) 50 | val keySize = buf.getInt() 51 | val keyBytes = new Array[Byte](keySize) 52 | buf.get(keyBytes) 53 | 54 | val seriesSize = buf.getInt() 55 | val series = new Array[Double](seriesSize) 56 | var i = 0 57 | while (i < seriesSize) { 58 | series(i) = buf.getDouble() 59 | i += 1 60 | } 61 | (new String(keyBytes, "UTF8"), new DenseVector[Double](series)) 62 | } 63 | } 64 | 65 | private class KeyAndSeriesToBytes extends Function[(String, Vector[Double]), Array[Byte]] { 66 | import PythonConnector._ 67 | 68 | override def call(keyVec: (String, Vector[Double])): Array[Byte] = { 69 | val keyBytes = keyVec._1.getBytes("UTF-8") 70 | val vec = keyVec._2 71 | val arr = new Array[Byte](INT_SIZE + keyBytes.length + INT_SIZE + DOUBLE_SIZE * vec.length) 72 | val buf = ByteBuffer.wrap(arr) 73 | buf.putInt(keyBytes.length) 74 | buf.put(keyBytes) 75 | putVector(buf, vec) 76 | arr 77 | } 78 | } 79 | 80 | private class InstantToBytes extends Function[(DateTime, Vector[Double]), Array[Byte]] { 81 | import PythonConnector._ 82 | 83 | override def call(instant: (DateTime, Vector[Double])): Array[Byte] = { 84 | val arr = new Array[Byte](LONG_SIZE + INT_SIZE + DOUBLE_SIZE * instant._2.length) 85 | val buf = ByteBuffer.wrap(arr) 86 | buf.putLong(instant._1.getMillis) 87 | putVector(buf, instant._2) 88 | arr 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/TimeSeries.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | import breeze.numerics._ 20 | import com.github.nscala_time.time.Imports 21 | import com.github.nscala_time.time.Imports._ 22 | 23 | import scala.collection.immutable.{IndexedSeq, Iterable} 24 | 25 | class TimeSeries(val index: DateTimeIndex, val data: DenseMatrix[Double], 26 | val keys: Array[String]) extends Serializable { 27 | 28 | /** 29 | * IMPORTANT: currently this assumes that the DateTimeIndex is a UniformDateTimeIndex, not an 30 | * Irregular one. This means that this function won't work (yet) on TimeSeries built using 31 | * timeSeriesFromIrregularSamples(). 32 | * 33 | * Lags all individual time series of the TimeSeries instance by up to maxLag amount. 
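   * A call sketch matching the example below (the `ts` value is assumed):
   * {{{
   *   val lagged = ts.lags(2, includeOriginals = true)
   * }}}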
34 | * 35 | * Example input TimeSeries: 36 | * time a b 37 | * 4 pm 1 6 38 | * 5 pm 2 7 39 | * 6 pm 3 8 40 | * 7 pm 4 9 41 | * 8 pm 5 10 42 | * 43 | * With maxLag 2 and includeOriginals = true, we would get: 44 | * time a lag1(a) lag2(a) b lag1(b) lag2(b) 45 | * 6 pm 3 2 1 8 7 6 46 | * 7 pm 4 3 2 9 8 7 47 | * 8 pm 5 4 3 10 9 8 48 | * 49 | */ 50 | def lags(maxLag: Int, includeOriginals: Boolean): TimeSeries = { 51 | val numCols = maxLag * keys.length + (if (includeOriginals) keys.length else 0) 52 | val numRows = data.rows - maxLag 53 | 54 | val laggedData = new DenseMatrix[Double](numRows, numCols) 55 | (0 until data.cols).foreach { colIndex => 56 | val offset = maxLag + (if (includeOriginals) 1 else 0) 57 | val start = colIndex * offset 58 | 59 | Lag.lagMatTrimBoth(data(::, colIndex), laggedData(::, start to (start + offset - 1)), maxLag, 60 | includeOriginals) 61 | } 62 | 63 | val newKeys = keys.indices.map { keyIndex => 64 | val key = keys(keyIndex) 65 | val lagKeys = (1 to maxLag).map(lagOrder => s"lag${lagOrder.toString}($key)").toArray 66 | 67 | if (includeOriginals) Array(key) ++ lagKeys else lagKeys 68 | }.reduce((prev: Array[String], next: Array[String]) => prev ++ next) 69 | 70 | val newDatetimeIndex = index.islice(maxLag, data.rows) 71 | 72 | new TimeSeries(newDatetimeIndex, laggedData, newKeys) 73 | } 74 | 75 | def slice(range: Range): TimeSeries = { 76 | new TimeSeries(index.islice(range), data(range, ::), keys) 77 | } 78 | 79 | def union(vec: Vector[Double], key: String): TimeSeries = { 80 | val mat = DenseMatrix.zeros[Double](data.rows, data.cols + 1) 81 | (0 until data.cols).foreach(c => mat(::, c to c) := data(::, c to c)) 82 | mat(::, -1 to -1) := vec 83 | new TimeSeries(index, mat, keys :+ key) 84 | } 85 | 86 | /** 87 | * Returns a TimeSeries where each time series is differenced with the given order. The new 88 | * TimeSeries will be missing the first n date-times. 89 | */ 90 | def differences(lag: Int): TimeSeries = { 91 | mapSeries(index.islice(lag, index.size), vec => diff(vec.toDenseVector, lag)) 92 | } 93 | 94 | /** 95 | * Returns a TimeSeries where each time series is differenced with order 1. The new TimeSeries 96 | * will be missing the first date-time. 97 | */ 98 | def differences(): TimeSeries = differences(1) 99 | 100 | /** 101 | * Returns a TimeSeries where each time series is quotiented with the given order. The new 102 | * TimeSeries will be missing the first n date-times. 103 | */ 104 | def quotients(lag: Int): TimeSeries = { 105 | mapSeries(index.islice(lag, index.size), vec => UnivariateTimeSeries.quotients(vec, lag)) 106 | } 107 | 108 | /** 109 | * Returns a TimeSeries where each time series is quotiented with order 1. The new TimeSeries will 110 | * be missing the first date-time. 111 | */ 112 | def quotients(): TimeSeries = quotients(1) 113 | 114 | /** 115 | * Returns a return series for each time series. Assumes periodic (as opposed to continuously 116 | * compounded) returns. 
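   * Under the usual periodic-return convention this corresponds to
   * r_t = p_t / p_(t-1) - 1 for consecutive prices (the exact computation is delegated
   * to UnivariateTimeSeries.price2ret).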
117 | */ 118 | def price2ret(): TimeSeries = { 119 | mapSeries(index.islice(1, index.size), vec => UnivariateTimeSeries.price2ret(vec, 1)) 120 | } 121 | 122 | def univariateSeriesIterator(): Iterator[Vector[Double]] = { 123 | new Iterator[Vector[Double]] { 124 | var i = 0 125 | def hasNext: Boolean = i < data.cols 126 | def next(): Vector[Double] = { 127 | i += 1 128 | data(::, i - 1) 129 | } 130 | } 131 | } 132 | 133 | def univariateKeyAndSeriesIterator(): Iterator[(String, Vector[Double])] = { 134 | new Iterator[(String, Vector[Double])] { 135 | var i = 0 136 | def hasNext: Boolean = i < data.cols 137 | def next(): (String, Vector[Double]) = { 138 | i += 1 139 | (keys(i - 1), data(::, i - 1)) 140 | } 141 | } 142 | } 143 | 144 | /** 145 | * Applies a transformation to each series that preserves the time index. 146 | */ 147 | def mapSeries(f: (Vector[Double]) => Vector[Double]): TimeSeries = { 148 | mapSeries(index, f) 149 | } 150 | 151 | /** 152 | * Applies a transformation to each series that preserves the time index. Passes the key along 153 | * with each series. 154 | */ 155 | def mapSeriesWithKey(f: (String, Vector[Double]) => Vector[Double]): TimeSeries = { 156 | val newData = new DenseMatrix[Double](index.size, data.cols) 157 | univariateKeyAndSeriesIterator().zipWithIndex.foreach { case ((key, series), i) => 158 | newData(::, i) := f(key, series) 159 | } 160 | new TimeSeries(index, newData, keys) 161 | } 162 | 163 | /** 164 | * Applies a transformation to each series such that the resulting series align with the given 165 | * time index. 166 | */ 167 | def mapSeries(newIndex: DateTimeIndex, f: (Vector[Double]) => Vector[Double]): TimeSeries = { 168 | val newSize = newIndex.size 169 | val newData = new DenseMatrix[Double](newSize, data.cols) 170 | univariateSeriesIterator().zipWithIndex.foreach { case (vec, i) => 171 | newData(::, i) := f(vec) 172 | } 173 | new TimeSeries(newIndex, newData, keys) 174 | } 175 | 176 | def mapValues[U](f: (Vector[Double]) => U): Seq[(String, U)] = { 177 | univariateKeyAndSeriesIterator().map(ks => (ks._1, f(ks._2))).toSeq 178 | } 179 | 180 | /** 181 | * Gets the first univariate series and its key. 182 | */ 183 | def head(): (String, Vector[Double]) = univariateKeyAndSeriesIterator().next() 184 | } 185 | 186 | object TimeSeries { 187 | def timeSeriesFromIrregularSamples(samples: Seq[(DateTime, Array[Double])], keys: Array[String]) 188 | : TimeSeries = { 189 | val mat = new DenseMatrix[Double](samples.length, samples.head._2.length) 190 | val dts = new Array[Long](samples.length) 191 | for (i <- samples.indices) { 192 | val (dt, values) = samples(i) 193 | dts(i) = dt.getMillis 194 | mat(i to i, ::) := new DenseVector[Double](values) 195 | } 196 | new TimeSeries(new IrregularDateTimeIndex(dts), mat, keys) 197 | } 198 | 199 | /** 200 | * This function should only be called when you can safely make the assumption that the time 201 | * samples are uniform (monotonously increasing) across time. 202 | */ 203 | def timeSeriesFromUniformSamples( 204 | samples: Seq[Array[Double]], 205 | index: UniformDateTimeIndex, 206 | keys: Array[String]): TimeSeries = { 207 | val mat = new DenseMatrix[Double](samples.length, samples.head.length) 208 | 209 | for (i <- samples.indices) { 210 | mat(i to i, ::) := new DenseVector[Double](samples(i)) 211 | } 212 | new TimeSeries(index, mat, keys) 213 | } 214 | } 215 | 216 | trait TimeSeriesFilter extends Serializable { 217 | /** 218 | * Takes a time series of i.i.d. 
observations and filters it to take on this model's 219 | * characteristics. 220 | * @param ts Time series of i.i.d. observations. 221 | * @param dest Array to put the filtered time series, can be the same as ts. 222 | * @return the dest param. 223 | */ 224 | def filter(ts: Array[Double], dest: Array[Double]): Array[Double] 225 | } 226 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/TimeSeriesKryoRegistrator.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import com.esotericsoftware.kryo.{Serializer, Kryo} 19 | import com.esotericsoftware.kryo.io.{Output, Input} 20 | 21 | import org.apache.spark.SparkConf 22 | import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} 23 | 24 | import org.joda.time.DateTime 25 | 26 | class TimeSeriesKryoRegistrator extends KryoRegistrator { 27 | def registerClasses(kryo: Kryo): Unit = { 28 | kryo.register(classOf[TimeSeries]) 29 | kryo.register(classOf[UniformDateTimeIndex]) 30 | kryo.register(classOf[IrregularDateTimeIndex]) 31 | kryo.register(classOf[BusinessDayFrequency]) 32 | kryo.register(classOf[DayFrequency]) 33 | kryo.register(classOf[DateTime], new DateTimeSerializer) 34 | } 35 | } 36 | 37 | class DateTimeSerializer extends Serializer[DateTime] { 38 | def write(kryo: Kryo, out: Output, dt: DateTime) = { 39 | out.writeLong(dt.getMillis, true) 40 | } 41 | 42 | def read(kryo: Kryo, in: Input, clazz: Class[DateTime]): DateTime = { 43 | new DateTime(in.readLong(true)) 44 | } 45 | } 46 | 47 | object TimeSeriesKryoRegistrator { 48 | def registerKryoClasses(conf: SparkConf): Unit = { 49 | conf.set("spark.serializer", classOf[KryoSerializer].getName) 50 | conf.set("spark.kryo.registrator", classOf[TimeSeriesKryoRegistrator].getName) 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/scala/com/cloudera/sparkts/TimeSeriesModel.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | 20 | /** 21 | * Models time dependent effects in a time series. 
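 * As a sketch of the intended contract (with an assumed fitted `model` and series `ts`),
 * removing and then re-adding the time-dependent effects should recover the original series:
 * {{{
 *   val iid = model.removeTimeDependentEffects(ts, DenseVector.zeros[Double](ts.length))
 *   val roundTrip = model.addTimeDependentEffects(iid, DenseVector.zeros[Double](ts.length))
 *   // roundTrip should approximately equal ts
 * }}}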
22 | */ 23 | trait TimeSeriesModel { 24 | /** 25 | * Takes a time series that is assumed to have this model's characteristics and returns a time 26 | * series with time-dependent effects of this model removed. 27 | * 28 | * This is the inverse of [[TimeSeriesModel#addTimeDependentEffects]]. 29 | * 30 | * @param ts Time series of observations with this model's characteristics. 31 | * @param dest Array to put the filtered series, can be the same as ts. 32 | * @return The dest series, for convenience. 33 | */ 34 | def removeTimeDependentEffects(ts: Vector[Double], dest: Vector[Double] = null): Vector[Double] 35 | 36 | /** 37 | * Takes a series of i.i.d. observations and returns a time series based on it with the 38 | * time-dependent effects of this model added. 39 | * 40 | * @param ts Time series of i.i.d. observations. 41 | * @param dest Array to put the filtered series, can be the same as ts. 42 | * @return The dest series, for convenience. 43 | */ 44 | def addTimeDependentEffects(ts: Vector[Double], dest: Vector[Double] = null): Vector[Double] 45 | } 46 | -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/RedisConfig.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider 2 | 3 | import com.redislabs.provider.redis.NodesInfo._ 4 | 5 | class RedisConfig(val ip: String, val port: Int) extends Serializable { 6 | val nodes: java.util.ArrayList[(String, Int)] = new java.util.ArrayList[(String, Int)] 7 | 8 | getNodes((ip, port)).foreach(x => nodes.add((x._1, x._2))) 9 | 10 | /** 11 | * 12 | * @param sPos start position of a slots range 13 | * @param ePos end position of a slots range 14 | * @return list of nodes(addr, port, index, range, startSlot, endSlot), where Inter([startSlot, endSlot], [sPos, ePos]) is not Nil(master only now) 15 | */ 16 | def getNodesBySlots(sPos: Int, ePos: Int) = { 17 | def inter(sPos1: Int, ePos1: Int, sPos2: Int, ePos2: Int):Boolean = { 18 | if (sPos1 <= sPos2) 19 | return ePos1 >= sPos2 20 | else 21 | return ePos2 >= sPos1 22 | } 23 | val node = nodes.get(scala.util.Random.nextInt().abs % nodes.size()) 24 | getSlots((node._1, node._2)).filter(node => inter(sPos, ePos, node._5, node._6)).filter(_._3 == 0) //master only now 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/redis/package.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider 2 | 3 | package object redis extends RedisFunctions -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/redis/partitioner/RedisPartition.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider.redis.partitioner 2 | 3 | import java.net.InetAddress 4 | import java.util 5 | import org.apache.spark.Partition 6 | import com.redislabs.provider._ 7 | 8 | case class RedisPartition(index: Int, 9 | redisConfig: RedisConfig, 10 | slots: (Int, Int)) extends Partition -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/redis/partitioner/RedisPartitioner.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider.redis.partitioner 2 | 
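A brief usage sketch of the RedisConfig class defined above, which the partition classes carry
along to map a slot range onto cluster nodes (the host, port, and slot bounds here are assumed
values, not taken from the repository):

    // Hedged illustration only: assumes a reachable Redis node at 127.0.0.1:6379.
    val config = new RedisConfig("127.0.0.1", 6379)
    // Masters whose slot ranges intersect [0, 8191]; each element is a tuple of
    // (addr, port, index, range, startSlot, endSlot), masters only.
    val nodesForSlots = config.getNodesBySlots(0, 8191)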
-------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/redis/redisFunctions.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider.redis 2 | 3 | import org.apache.spark.SparkContext 4 | import org.apache.spark.rdd.RDD 5 | import redis.clients.jedis.Jedis 6 | import redis.clients.util.{SafeEncoder, JedisClusterCRC16} 7 | import scala.collection.JavaConversions._ 8 | import com.redislabs.provider.redis.rdd._ 9 | import com.redislabs.provider.redis.SaveToRedis._ 10 | import com.redislabs.provider.redis.NodesInfo._ 11 | 12 | class RedisContext(val sc: SparkContext) extends Serializable { 13 | /** 14 | * @param initialHost any addr and port of a cluster or a single server 15 | * @param keyPattern 16 | * @param partitionNum number of partitions 17 | * @return RedisKeysRDD of simple Keys stored in redis server 18 | */ 19 | def fromRedisKeyPattern(initialHost: (String, Int), 20 | keyPattern: String = "*", 21 | partitionNum: Int = 3) = { 22 | new RedisKeysRDD(sc, initialHost, keyPattern, partitionNum); 23 | } 24 | 25 | /** 26 | * @param kvs Pair RDD of K/V 27 | * @param zsetName target zset's name which hold all the kvs 28 | * @param initialHost any addr and port of a cluster or a single server 29 | * save all the kvs to zsetName(zset type) in redis-server 30 | */ 31 | def toRedisZSET(kvs: RDD[(String, String)], 32 | zsetName: String, 33 | initialHost: (String, Int)) = { 34 | val host = getHost(zsetName, initialHost) 35 | kvs.foreachPartition(partition => setZset(host, zsetName, partition)) 36 | } 37 | } 38 | 39 | object NodesInfo { 40 | 41 | /** 42 | * @param initialHost any addr and port of a cluster or a single server 43 | * @return true if the target server is in cluster mode 44 | */ 45 | private def clusterEnable(initialHost: (String, Int)) : Boolean = { 46 | val jedis = new Jedis(initialHost._1, initialHost._2) 47 | val res = jedis.info("cluster").contains("1") 48 | jedis.close 49 | res 50 | } 51 | 52 | /** 53 | * @param hosts list of hosts(addr, port, startSlot, endSlot) 54 | * @param key 55 | * @return host whose slots should involve key 56 | */ 57 | def findHost(hosts: Array[(String, Int, Int, Int)], key: String) = { 58 | val slot = JedisClusterCRC16.getSlot(key) 59 | hosts.filter(host => {host._3 <= slot && host._4 >= slot})(0) 60 | } 61 | /** 62 | * @param key 63 | * @param initialHost any addr and port of a cluster or a single server 64 | * @return host whose slots should involve key 65 | */ 66 | def getHost(key: String, initialHost: (String, Int)) = { 67 | val slot = JedisClusterCRC16.getSlot(key); 68 | val hosts = getSlots(initialHost).filter(x => (x._3 == 0 && x._5 <= slot && x._6 >= slot)).map(x => (x._1, x._2)) 69 | hosts(0) 70 | } 71 | /** 72 | * @param initialHost any addr and port of a cluster or a single server 73 | * @return list of hosts(addr, port, startSlot, endSlot) 74 | */ 75 | def getHosts(initialHost: (String, Int)) = { 76 | getSlots(initialHost).filter(_._3 == 0).map(x => (x._1, x._2, x._5, x._6)) 77 | } 78 | 79 | /** 80 | * @param initialHost any addr and port of a single server 81 | * @return list of nodes(addr, port, index, range, startSlot, endSlot) 82 | */ 83 | private def getNonClusterSlots(initialHost: (String, Int)) = { 84 | getNonClusterNodes(initialHost).map(x=> (x._1, x._2, x._3, x._4, 0, 16383)).toArray 85 | } 86 | /** 87 | * @param initialHost any addr and port of a cluster server 88 | * @return list of nodes(addr, port, 
index, range, startSlot, endSlot) 89 | */ 90 | private def getClusterSlots(initialHost: (String, Int)) = { 91 | val j = new Jedis(initialHost._1, initialHost._2) 92 | val res = j.clusterSlots().asInstanceOf[java.util.List[java.lang.Object]].flatMap { 93 | slotInfoObj => 94 | { 95 | val slotInfo = slotInfoObj.asInstanceOf[java.util.List[java.lang.Object]] 96 | val sPos = slotInfo.get(0).toString.toInt 97 | val ePos = slotInfo.get(1).toString.toInt 98 | (0 until (slotInfo.size - 2)).map(i => { 99 | val node = slotInfo(i + 2).asInstanceOf[java.util.List[java.lang.Object]] 100 | (SafeEncoder.encode(node.get(0).asInstanceOf[Array[scala.Byte]]), 101 | node.get(1).toString.toInt, 102 | i, 103 | slotInfo.size - 2, 104 | sPos, 105 | ePos) 106 | }) 107 | } 108 | }.toArray 109 | j.close 110 | res 111 | } 112 | /** 113 | * @param initialHost any addr and port of a cluster or a single server 114 | * @return list of nodes(addr, port, index, range, startSlot, endSlot) 115 | */ 116 | def getSlots(initialHost: (String, Int)) = { 117 | if (clusterEnable(initialHost)) 118 | getClusterSlots(initialHost) 119 | else 120 | getNonClusterSlots(initialHost) 121 | } 122 | 123 | 124 | /** 125 | * @param initialHost any addr and port of a single server 126 | * @return list of nodes(addr, port, index, range) 127 | */ 128 | private def getNonClusterNodes(initialHost: (String, Int)) = { 129 | var master = initialHost 130 | val j = new Jedis(initialHost._1, initialHost._2) 131 | var replinfo = j.info("Replication").split("\n") 132 | j.close 133 | if (replinfo.filter(_.contains("role:slave")).length != 0){ 134 | val host = replinfo.filter(_.contains("master_host:"))(0).trim.substring(12) 135 | val port = replinfo.filter(_.contains("master_port:"))(0).trim.substring(12).toInt 136 | master = (host, port) 137 | val j = new Jedis(host, port) 138 | replinfo = j.info("Replication").split("\n") 139 | j.close 140 | } 141 | val slaves = replinfo.filter(x => (x.contains("slave") && x.contains("online"))).map(rl => { 142 | val content = rl.substring(rl.indexOf(':') + 1).split(",") 143 | val ip = content(0) 144 | val port = content(1) 145 | (ip.substring(ip.indexOf('=')+1).toString, port.substring(port.indexOf('=')+1).toInt) 146 | }) 147 | val nodes = master +: slaves 148 | val range = nodes.size 149 | (0 until range).map(i => (nodes(i)._1, nodes(i)._2, i, range)).toArray 150 | } 151 | /** 152 | * @param initialHost any addr and port of a cluster server 153 | * @return list of nodes(addr, port, index, range) 154 | */ 155 | private def getClusterNodes(initialHost: (String, Int)) = { 156 | val j = new Jedis(initialHost._1, initialHost._2) 157 | val res = j.clusterSlots().asInstanceOf[java.util.List[java.lang.Object]].flatMap { 158 | slotInfoObj => 159 | { 160 | val slotInfo = slotInfoObj.asInstanceOf[java.util.List[java.lang.Object]].drop(2) 161 | val range = slotInfo.size 162 | (0 until range).map(i => { 163 | var node = slotInfo(i).asInstanceOf[java.util.List[java.lang.Object]] 164 | (SafeEncoder.encode(node.get(0).asInstanceOf[Array[scala.Byte]]), 165 | node.get(1).toString.toInt, 166 | i, 167 | range) 168 | }) 169 | } 170 | }.distinct.toArray 171 | j.close 172 | res 173 | } 174 | 175 | /** 176 | * @param initialHost any addr and port of a cluster or a single server 177 | * @return list of nodes(addr, port, index, range) 178 | */ 179 | def getNodes(initialHost: (String, Int)) = { 180 | if (clusterEnable(initialHost)) 181 | getClusterNodes(initialHost) 182 | else 183 | getNonClusterNodes(initialHost) 184 | } 185 | } 186 | 187 | object 
SaveToRedis { 188 | /** 189 | * @param host addr and port of a target host 190 | * @param zsetName 191 | * @param arr k/vs which should be saved in the target host 192 | * save all the k/vs to zsetName(zset type) to the target host 193 | */ 194 | def setZset(host: (String, Int), zsetName: String, arr: Iterator[(String, String)]) = { 195 | val jedis = new Jedis(host._1, host._2) 196 | val pipeline = jedis.pipelined 197 | arr.foreach(x => pipeline.zadd(zsetName, x._2.toDouble, x._1)) 198 | pipeline.sync 199 | jedis.close 200 | } 201 | } 202 | 203 | trait RedisFunctions { 204 | implicit def toRedisContext(sc: SparkContext): RedisContext = new RedisContext(sc) 205 | } 206 | 207 | -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/sql/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider.sql 2 | 3 | import org.apache.spark.rdd.RDD 4 | import org.apache.spark.sql.{Row, SQLContext} 5 | import org.apache.spark.sql.sources._ 6 | import org.apache.spark.sql.types._ 7 | 8 | import org.joda.time.DateTime 9 | import org.joda.time.DateTimeZone.UTC 10 | 11 | import com.cloudera.sparkts.{HourFrequency, BusinessDayFrequency, DayFrequency, Frequency} 12 | import com.cloudera.sparkts.DateTimeIndex._ 13 | 14 | import java.sql.Timestamp 15 | 16 | import com.redislabs.provider.redis._ 17 | 18 | class FilterParser(filters: Array[Filter], fieldName: Map[String, String]) { 19 | private val filtersByAttr: Map[String, Array[Filter]] = filters.map(f => (getAttr(f), f)).groupBy(_._1).mapValues(a => a.map(p => p._2)) 20 | def getStartTime: String = { 21 | var startTime: Timestamp = new Timestamp(90, 10, 14, 0, 0, 0, 0) 22 | filtersByAttr.getOrElse(fieldName("timestamp"), new Array[Filter](0)).foreach({ 23 | case GreaterThan(attr, v) => startTime = v.asInstanceOf[Timestamp] 24 | case GreaterThanOrEqual(attr, v) => startTime = v.asInstanceOf[Timestamp] 25 | case _ => {} 26 | }) 27 | val st = startTime.toString 28 | st.substring(0, st.indexOf(' ')) 29 | } 30 | def getEndTime: String = { 31 | var endTime: Timestamp = new Timestamp(91, 11, 3, 0, 0, 0, 0) 32 | filtersByAttr.getOrElse(fieldName("timestamp"), new Array[Filter](0)).foreach({ 33 | case LessThan(attr, v) => endTime = v.asInstanceOf[Timestamp] 34 | case LessThanOrEqual(attr, v) => endTime = v.asInstanceOf[Timestamp] 35 | case _ => {} 36 | }) 37 | val et = endTime.toString 38 | et.substring(0, et.indexOf(' ')) 39 | } 40 | def getFrequency(frequency: String): Frequency = { 41 | frequency match { 42 | case "businessDay" => new BusinessDayFrequency(1) 43 | case "day" => new DayFrequency(1) 44 | case "hour" => new HourFrequency(1) 45 | case _ => new BusinessDayFrequency(1) 46 | } 47 | } 48 | def getDateTimeIndex(frequency: String) = { 49 | val start = getStartTime 50 | val end = getEndTime 51 | val freq = getFrequency(frequency) 52 | uniform(new DateTime(start, UTC), new DateTime(end, UTC), freq) 53 | } 54 | private def getAttr(f: Filter): String = { 55 | f match { 56 | case EqualTo(attribute, value) => attribute 57 | case GreaterThan(attribute, value) => attribute 58 | case GreaterThanOrEqual(attribute, value) => attribute 59 | case LessThan(attribute, value) => attribute 60 | case LessThanOrEqual(attribute, value) => attribute 61 | case In(attribute, values) => attribute 62 | case IsNull(attribute) => attribute 63 | case IsNotNull(attribute) => attribute 64 | case StringStartsWith(attribute, value) => attribute 65 | 
case StringEndsWith(attribute, value) => attribute 66 | case StringContains(attribute, value) => attribute 67 | } 68 | } 69 | } 70 | 71 | case class InstantScan(parameters: Map[String, String]) 72 | (@transient val sqlContext: SQLContext) 73 | extends BaseRelation with PrunedFilteredScan { 74 | 75 | val host: String = parameters.getOrElse("host", "127.0.0.1") 76 | val port: Int = parameters.getOrElse("port", "6379").toInt 77 | val prefix: String = parameters.getOrElse("prefix", "TASK") 78 | val frequency: String = parameters.getOrElse("frequency", "businessDay") 79 | 80 | private def getTimeSchema: StructField = { 81 | StructField("instant", TimestampType, nullable = true) 82 | } 83 | 84 | private def getColSchema: Array[StructField] = { 85 | val start = new DateTime("0000-01-03", UTC) 86 | val end = new DateTime("0000-01-03", UTC) 87 | val dtIndex = uniform(start, end, 1.businessDays) 88 | 89 | var rtsRdd = sqlContext.sparkContext.fromRedisKeyPattern((host, port), prefix + "_*").getRedisTimeSeriesRDD(dtIndex) 90 | if (parameters.get("keyPattern") != None) 91 | rtsRdd = rtsRdd.filterKeys(parameters("keyPattern")) 92 | if (parameters.get("startingBefore") != None) 93 | rtsRdd = rtsRdd.filterStartingBefore(new DateTime(parameters("startingBefore"), UTC)) 94 | if (parameters.get("endingAfter") != None) 95 | rtsRdd = rtsRdd.filterEndingAfter(new DateTime(parameters("endingAfter"), UTC)) 96 | return rtsRdd.toTimeSeriesRDD().keys.map(StructField(_, DoubleType, nullable = true)) 97 | } 98 | 99 | override val needConversion: Boolean = false 100 | 101 | val schema: StructType = StructType(getTimeSchema +: getColSchema) 102 | 103 | def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 104 | val requiredColumnsIndex = requiredColumns.map(schema.fieldIndex(_)) 105 | val fp = new FilterParser(filters, Map("timestamp" -> "instant")) 106 | val dtindex = fp.getDateTimeIndex(frequency) 107 | 108 | var rtsRdd = { 109 | if (parameters.get("partition") != None) 110 | sqlContext.sparkContext.fromRedisKeyPattern((host, port), prefix + "_*", parameters("partition").toInt).getRedisTimeSeriesRDD(dtindex) 111 | else 112 | sqlContext.sparkContext.fromRedisKeyPattern((host, port), prefix + "_*").getRedisTimeSeriesRDD(dtindex) 113 | } 114 | 115 | if (parameters.get("keyPattern") != None) 116 | rtsRdd = rtsRdd.filterKeys(parameters("keyPattern")) 117 | if (parameters.get("startingBefore") != None) 118 | rtsRdd = rtsRdd.filterStartingBefore(new DateTime(parameters("startingBefore"), UTC)) 119 | if (parameters.get("endingAfter") != None) 120 | rtsRdd = rtsRdd.filterEndingAfter(new DateTime(parameters("endingAfter"), UTC)) 121 | if (parameters.get("mapSeries") != None) { 122 | parameters("mapSeries").split(",").foreach(ms => rtsRdd = rtsRdd.mapSeries(sqlContext.getMapSeries(ms.trim))) 123 | } 124 | rtsRdd.toTimeSeriesRDD().toInstants().map(x => new Timestamp(x._1.getMillis()) +: x._2.toArray).map{ 125 | candidates => requiredColumnsIndex.map(candidates(_)) 126 | }.map(x => Row.fromSeq(x.toSeq)) 127 | } 128 | } 129 | 130 | case class ObservationScan(parameters: Map[String, String]) 131 | (@transient val sqlContext: SQLContext) 132 | extends BaseRelation with PrunedFilteredScan { 133 | 134 | val host: String = parameters.getOrElse("host", "127.0.0.1") 135 | val port: Int = parameters.getOrElse("port", "6379").toInt 136 | val prefix: String = parameters.getOrElse("prefix", "TASK") 137 | val frequency: String = parameters.getOrElse("frequency", "businessDay") 138 | val tsCol: String = 
parameters.getOrElse("tsCol", "timestamp") 139 | val keyCol: String = parameters.getOrElse("keyCol", "key") 140 | val valueCol: String = parameters.getOrElse("valueCol", "value") 141 | 142 | override val needConversion: Boolean = false 143 | 144 | val schema = new StructType(Array( 145 | new StructField(tsCol, TimestampType), 146 | new StructField(keyCol, StringType), 147 | new StructField(valueCol, DoubleType) 148 | )) 149 | 150 | def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 151 | val requiredColumnsIndex = requiredColumns.map(schema.fieldIndex(_)) 152 | val fp = new FilterParser(filters, Map("timestamp" -> tsCol)) 153 | val dtindex = fp.getDateTimeIndex(frequency) 154 | 155 | var rtsRdd = { 156 | if (parameters.get("partition") != None) 157 | sqlContext.sparkContext.fromRedisKeyPattern((host, port), prefix + "_*", parameters("partition").toInt).getRedisTimeSeriesRDD(dtindex) 158 | else 159 | sqlContext.sparkContext.fromRedisKeyPattern((host, port), prefix + "_*").getRedisTimeSeriesRDD(dtindex) 160 | } 161 | 162 | if (parameters.get("keyPattern") != None) 163 | rtsRdd = rtsRdd.filterKeys(parameters("keyPattern")) 164 | if (parameters.get("startingBefore") != None) 165 | rtsRdd = rtsRdd.filterStartingBefore(new DateTime(parameters("startingBefore"), UTC)) 166 | if (parameters.get("endingAfter") != None) 167 | rtsRdd = rtsRdd.filterEndingAfter(new DateTime(parameters("endingAfter"), UTC)) 168 | if (parameters.get("mapSeries") != None) { 169 | parameters("mapSeries").split(",").foreach(ms => rtsRdd = rtsRdd.mapSeries(sqlContext.getMapSeries(ms.trim))) 170 | } 171 | 172 | rtsRdd.flatMap{ 173 | case (key, series) => { 174 | series.iterator.flatMap { 175 | case (i, value) => 176 | if (value.isNaN) 177 | None 178 | else { 179 | val candidates = (new Timestamp(dtindex.dateTimeAtLoc(i).getMillis)) +: key +: value +: Nil 180 | Some(Row.fromSeq(requiredColumnsIndex.map(candidates(_)).toSeq)) 181 | } 182 | } 183 | } 184 | } 185 | } 186 | } 187 | 188 | class DefaultSource extends RelationProvider { 189 | def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) = { 190 | if (parameters.getOrElse("type", "instant") == "observation") 191 | ObservationScan(parameters)(sqlContext) 192 | else 193 | InstantScan(parameters)(sqlContext) 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/sql/RedisSQLFunctions.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider.sql 2 | 3 | import org.apache.spark.sql.SQLContext 4 | import breeze.linalg.Vector 5 | 6 | import scala.collection.mutable.HashMap 7 | 8 | /** 9 | * RedisSQLContext is used to import and make use of mapSeries function for instant and 10 | * observation dataframe. It works by implicitly casting an SQLContext to RedisSQLContext. 
11 | */ 12 | class RedisSQLContext(val sc: SQLContext) extends Serializable { 13 | def setMapSeries(funcName: String, mapSeries: (Vector[Double] => Vector[Double])) = { 14 | RedisSQLContext.funcMap += (funcName -> mapSeries) 15 | } 16 | def getMapSeries(funcName: String): (Vector[Double] => Vector[Double]) = { 17 | RedisSQLContext.funcMap.getOrElse(funcName, (x: Vector[Double]) => x) 18 | } 19 | } 20 | 21 | object RedisSQLContext { 22 | private val funcMap = new HashMap[String, (Vector[Double] => Vector[Double])]() 23 | } 24 | 25 | trait RedisSQLFunctions { 26 | implicit def toRedisSQLContext(sc: SQLContext): RedisSQLContext = new RedisSQLContext(sc) 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/sql/package.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider 2 | 3 | package object sql extends RedisSQLFunctions 4 | -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/util/GenerateWorkdayTestData.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider.util 2 | 3 | import com.github.nscala_time.time.Imports._ 4 | import org.joda.time.DateTimeZone.UTC 5 | import com.cloudera.sparkts.DateTimeIndex._ 6 | import java.io._ 7 | import java.util._ 8 | /** 9 | * @author sunheehnus 10 | */ 11 | object GenerateWorkdayTestData { 12 | def Generate(dir: String, testCnt: Int, testFileFloor: Int, testFileCeil: Int, startTime: String, endTime: String) { 13 | val rnd = new Random 14 | (1 to testCnt).foreach{ i=> { 15 | GenerateBetween(dir + "/TEST" + i.toString, startTime, endTime, testFileFloor + rnd.nextInt(testFileCeil - testFileFloor) + 1) 16 | }} 17 | } 18 | def GenerateBetween(dir: String, startTime: String, endTime: String, fileNum: Int) { 19 | val folder = new File(dir) 20 | folder.mkdirs() 21 | 22 | val rnd = new Random 23 | for (i <- 1 to fileNum) { 24 | val stMillis = new DateTime(startTime, UTC).getMillis 25 | val edMillis = new DateTime(endTime, UTC).getMillis 26 | 27 | var floor: String = "" 28 | var ceil: String = "" 29 | do { 30 | val time1 = stMillis + (rnd.nextLong.abs % (edMillis - stMillis)) 31 | val time2 = stMillis + (rnd.nextLong.abs % (edMillis - stMillis)) 32 | floor = nextBusinessDay(new DateTime(if (time1 > time2) time2 else time1, UTC)).toString 33 | floor = floor.substring(0, floor.indexOf("T")) 34 | ceil = nextBusinessDay(new DateTime(if (time1 > time2) time1 else time2, UTC)).toString 35 | ceil = ceil.substring(0, ceil.indexOf("T")) 36 | } while (floor == ceil) 37 | GenerateWorkday(dir + "/test" + i.toString + ".csv", floor, ceil, rnd.nextInt(15) + 1) 38 | } 39 | } 40 | 41 | def GenerateWorkday(targetName: String, startTime: String, endTime: String, colNum: Int) { 42 | val rnd = new Random 43 | val floor = new DateTime(startTime, UTC) 44 | val ceil = new DateTime(endTime, UTC) 45 | val dates: ArrayList[DateTime] = new ArrayList[DateTime] 46 | var dummy_day = floor 47 | while (dummy_day != ceil) { 48 | if (rnd.nextInt(100) < 90) { 49 | dates.add(dummy_day) 50 | } 51 | dummy_day = nextBusinessDay(dummy_day + 1.day) 52 | } 53 | val writer = new PrintWriter(new File(targetName)) 54 | writer.write("Date") 55 | for (j <- 1 to colNum) { 56 | writer.write(",Col" + j.toString) 57 | } 58 | writer.write("\n") 59 | for (i <- 0 to dates.size - 1) { 60 | val data = dates.get(dates.size - i - 1) 61 | 
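// rows are written from the newest date back to the oldest (the dates list was built in ascending order), and each value below is formatted to two decimal places; this mirrors the newest-first layout of the Yahoo-style CSVs in the test resources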
writer.write(data.toString.substring(0, data.toString.indexOf("T"))) 62 | for (j <- 1 to colNum) { 63 | val num = rnd.nextInt(10000)/100.0 + 500 * j 64 | writer.write(f",${num}%.2f") 65 | } 66 | writer.write("\n") 67 | } 68 | writer.close 69 | return 70 | } 71 | } -------------------------------------------------------------------------------- /src/main/scala/com/redislabs/provider/util/ImportTimeSeriesData.scala: -------------------------------------------------------------------------------- 1 | package com.redislabs.provider.util 2 | 3 | import org.apache.spark.SparkContext 4 | import org.apache.spark.rdd.RDD 5 | import org.joda.time.DateTime 6 | import org.joda.time.DateTimeZone.UTC 7 | import breeze.linalg._ 8 | import breeze.numerics._ 9 | 10 | import com.redislabs.provider.redis._ 11 | 12 | import redis.clients.jedis._ 13 | import redis.clients.jedis.{ HostAndPort, JedisCluster } 14 | 15 | import com.redislabs.provider.redis.rdd._ 16 | import com.redislabs.provider.redis.SaveToRedis._ 17 | import com.redislabs.provider.redis.NodesInfo._ 18 | 19 | object ImportTimeSeriesData { 20 | def RedisWrite(text: String, keyPrefix: String = "", redisNode: (String, Int)) = { 21 | val lines = text.split('\n') 22 | val label = keyPrefix + lines(0).split(",", 2).tail.head 23 | val samples = lines.tail.map(line => { 24 | val tokens = line.split(",", 2) 25 | val dt = new DateTime(tokens.head, UTC) 26 | (dt, tokens.tail.head) 27 | }) 28 | 29 | val dts = new Array[String](samples.length) 30 | val mat = new Array[String](samples.length) 31 | (0 until samples.length).map(i => { 32 | val (dt, vals) = samples(i) 33 | dts(i) = dt.getMillis.toString 34 | mat(i) = vals 35 | }) 36 | val host = getHost(label, redisNode) 37 | setZset(host, label, (for (j <- 0 to dts.length - 1) yield (j + "_" + mat(j), dts(j).toString)).iterator) 38 | } 39 | def ImportToRedisServer(dir: String, prefix: String, sc: SparkContext, redisNode: (String, Int)) { 40 | sc.wholeTextFiles(dir).foreach { 41 | case (path, text) => RedisWrite(text, prefix + "_RedisTS_" + path.split('/').last + "_RedisTS_", redisNode) 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /src/main/scala/com/test/sql/DefaultSource.scala: -------------------------------------------------------------------------------- 1 | package com.test.sql 2 | 3 | 4 | import com.cloudera.sparkts.DayFrequency 5 | import com.cloudera.sparkts._ 6 | import org.apache.spark.rdd.RDD 7 | import org.apache.spark.sql.{Row, SQLContext} 8 | import org.apache.spark.sql.sources._ 9 | import org.apache.spark.sql.types._ 10 | 11 | import org.joda.time.DateTime 12 | import org.joda.time.DateTimeZone.UTC 13 | import com.cloudera.sparkts.DateTimeIndex._ 14 | 15 | import java.sql.Timestamp 16 | 17 | import breeze.linalg._ 18 | 19 | case class SCAN(parameters: Map[String, String]) 20 | (@transient val sqlContext: SQLContext) 21 | extends BaseRelation with PrunedFilteredScan { 22 | 23 | private def getTimeSchema: StructField = { 24 | StructField("instant", TimestampType, nullable = true) 25 | } 26 | 27 | private def getColSchema: Seq[StructField] = { 28 | (1 to 2048).map(_.toString).map(StructField(_, DoubleType, nullable = true)) 29 | } 30 | 31 | val schema: StructType = StructType(getTimeSchema +: getColSchema) 32 | 33 | def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { 34 | val rdd1 = sqlContext.sparkContext.parallelize(1 to 2048, 3) 35 | val rdd2 = rdd1.map(x => (x.toString, Vector((x to (x + 
2048)).map(_.toDouble).toSeq:_*))) 36 | val start = new DateTime("2010-01-04", UTC) 37 | val end = start.plusDays(2048) 38 | val index = uniform(start, end, new DayFrequency(1)) 39 | val rdd3 = new TimeSeriesRDD(index, rdd2) 40 | 41 | val requiredColumnsIndex = requiredColumns.map(schema.fieldIndex(_)) 42 | 43 | rdd3.toInstants().mapPartitions { case iter => 44 | iter.map(x => new Timestamp(x._1.getMillis) +: x._2.toArray).map { 45 | candidates => requiredColumnsIndex.map(candidates(_)) 46 | }.map(x => Row.fromSeq(x.toSeq)) 47 | } 48 | } 49 | } 50 | 51 | 52 | class DefaultSource extends RelationProvider { 53 | def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) = { 54 | SCAN(parameters)(sqlContext) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/main/scala/com/test/sql/rrun.scala: -------------------------------------------------------------------------------- 1 | package com.test.sql 2 | 3 | import org.apache.spark.sql.SQLContext 4 | import org.apache.spark.{SparkContext, SparkConf} 5 | import com.cloudera.sparkts._ 6 | 7 | import org.joda.time.DateTime 8 | import org.joda.time.DateTimeZone.UTC 9 | import com.cloudera.sparkts.DateTimeIndex._ 10 | 11 | import breeze.linalg._ 12 | 13 | /** 14 | * Created by sunheehnus on 16/1/8. 15 | */ 16 | object rrun extends App { 17 | val conf = new SparkConf().setAppName("test").setMaster("local") 18 | val sc = new SparkContext(conf) 19 | 20 | val sqlContext = new SQLContext(sc) 21 | 22 | val rdd1 = sqlContext.sparkContext.parallelize(1 to 2048, 3) 23 | val rdd2 = rdd1.map(x => (x.toString, Vector((x to (x + 2048)).map(_.toDouble).toSeq:_*))) 24 | val start = new DateTime("2010-01-04", UTC) 25 | val end = start.plusDays(2048) 26 | val index = uniform(start, end, new DayFrequency(1)) 27 | val rdd3 = new TimeSeriesRDD(index, rdd2) 28 | var df = rdd3.toInstantsDataFrame(sqlContext) 29 | 30 | val options = Map("frequency" -> "businessDay", "type" -> "observation", "mapSeries" -> "linear") 31 | val cmpdf = sqlContext.load("com.test.sql", options) 32 | 33 | val t1 = System.currentTimeMillis() 34 | df.collect 35 | val t2 = System.currentTimeMillis() 36 | val t3 = System.currentTimeMillis() 37 | cmpdf.collect 38 | val t4 = System.currentTimeMillis() 39 | println(t2 - t1) 40 | println(t4 - t3) 41 | } 42 | -------------------------------------------------------------------------------- /src/site/markdown/docs/users.md: -------------------------------------------------------------------------------- 1 | title: User Guide 2 | 3 | # User Guide 4 | 5 | ## Terminology 6 | 7 | A variety of terms are used to describe time series data, and many of these apply to conflicting or 8 | overlapping concepts. In the interest of clarity, in spark-timeseries, we stick to the following 9 | set of definitions: 10 | 11 | * **Time Series** - A sequence of floating point values, each linked to a timestamp. 12 | In particular, we try to stick with “time series” as meaning a 13 | univariate time series, although in other contexts it sometimes refers to series with multiple 14 | values at the same timestamp. A notable instance of the latter is the `TimeSeries` class, which 15 | refers to a multivariate time series. In Scala, a time series is usually represented by a Breeze 16 | vector, and in Python, a 1-D numpy array, and has a DateTimeIndex somewhere nearby to link its 17 | values to points in time. 18 | * **Key** - A string label used to identify a time series. 
A TimeSeriesRDD is a distributed 19 | collection of tuples of (key, time series) 20 | * **Instant** - The set of values in a collection of time series corresponding to a single point in 21 | time. 22 | * **Observation** - A tuple of (timestamp, key, value). 23 | 24 | 25 | ## Abstractions 26 | 27 | ### TimeSeriesRDD 28 | 29 | The central abstraction of the library is the `TimeSeriesRDD`, a lazy distributed collection of 30 | univariate series with a conformed time dimension. It is lazy in the sense that it is an RDD: it 31 | encapsulates all the information needed to generate its elements, but doesn't materialize them upon 32 | instantiation. It is distributed in the sense that different univariate series within the collection 33 | can be stored and processed on different nodes. Within each univariate series, observations are not 34 | distributed. The time dimension is conformed in the sense that a single `DateTimeIndex` applies to 35 | all the univariate series. Each univariate series within the RDD has a key to identify it. 36 | 37 | TimeSeriesRDDs then support efficient series-wise operations like slicing, imputing missing values 38 | based on surrounding elements, and training time-series models. For example, in Scala: 39 | 40 | val tsRdd: TimeSeriesRDD = ... 41 | 42 | // Find a sub-slice between two dates 43 | val subslice = tsRdd.slice(new DateTime("2015-4-10"), new DateTime("2015-4-14")) 44 | 45 | // Fill in missing values based on linear interpolation 46 | val filled = subslice.fill("linear") 47 | 48 | // Use an AR(1) model to remove serial correlations 49 | val residuals = filled.mapSeries(series => ar(series, 1).removeTimeDependentEffects(series)) 50 | 51 | Or in Python: 52 | 53 | tsrdd = ... 54 | 55 | # Find a sub-slice between two dates 56 | subslice = tsrdd['2015-04-10':'2015-04-14'] 57 | 58 | # Fill in missing values based on linear interpolation 59 | filled = subslice.fill('linear') 60 | 61 | 62 | ### DateTimeIndex 63 | 64 | The time spanned by a TimeSeriesRDD is encapsulated in a `DateTimeIndex`, which is essentially an 65 | ordered collection of timestamps. DateTimeIndexes come in two flavors: uniform and irregular. 66 | Uniform DateTimeIndexes have a concise representation including a start date, a frequency (i.e. 67 | the interval between two timestamps), and a number of periods. Irregular indices are simply 68 | represented by an ordered collection of timestamps. 69 | -------------------------------------------------------------------------------- /src/site/markdown/index.md: -------------------------------------------------------------------------------- 1 | title: Overview 2 | 3 | # Overview 4 | 5 | Spark-Timeseries is a Python and Scala library for analyzing large-scale time series data sets. It 6 | is hosted [here](https://github.com/cloudera/spark-timeseries). 7 | 8 | Scaladoc is available [here](scaladocs/index.html). 9 | 10 | Python doc is available [here](pydoc/py-modindex.html). 11 | 12 | Spark-Timeseries offers: 13 | 14 | * A set of abstractions for manipulating time series data, similar to what's provided for smaller 15 | data sets in 16 | [Pandas](http://pandas.pydata.org/pandas-docs/dev/timeseries.html), 17 | [Matlab](http://www.mathworks.com/help/matlab/time-series.html), and R's 18 | [zoo](http://cran.r-project.org/web/packages/zoo/index.html) and 19 | [xts](http://cran.r-project.org/web/packages/xts/index.html) packages. 
20 | * Models, tests, and functions that enable dealing with time series from a statistical perspective, 21 | similar to what's provided in [StatsModels](http://statsmodels.sourceforge.net/devel/tsa.html) 22 | and a variety of Matlab and R packages. 23 | 24 | The library is geared towards use cases in finance (munging tick data, building risk models), but 25 | intends to be general enough that other fields with continuous time series data, like meteorology, 26 | can make use of it. 27 | 28 | The library currently expects that individual univariate time series can easily fit in memory on each 29 | machine, but that collections of univariate time series may need to be distributed across many 30 | machines. While time series that violate this expectation pose a bunch of fun distributed 31 | programming problems, they don't tend to come up very often in finance, where an array holding 32 | a value for every minute of every trading day for ten years needs less than a couple million 33 | elements. 34 | 35 | ## Dependencies 36 | 37 | The library sits on a few other excellent Java and Scala libraries. 38 | 39 | * [Breeze](https://github.com/scalanlp/breeze) for NumPy-like, BLAS-able linear algebra. 40 | * [JodaTime](http://www.joda.org/joda-time/) for dates and times. 41 | * [Apache Commons Math](https://commons.apache.org/proper/commons-math/) for general math and 42 | statistics functionality. 43 | * [Apache Spark](https://spark.apache.org/) for distributed computation with in-memory 44 | capabilities. 45 | 46 | 47 | 48 | ## Functionality 49 | 50 | ### Time Series Manipulation 51 | 52 | * Aligning 53 | * Lagging 54 | * Slicing by date-time 55 | * Missing value imputation 56 | * Conversion between different time series data layouts 57 | 58 | ### Time Series Math and Stats 59 | 60 | * Exponentially weighted moving average (EWMA) models 61 | * Autoregressive integrated moving average (ARIMA) models 62 | * Generalized autoregressive conditional heteroskedastic (GARCH) models 63 | * Missing data imputation 64 | * Augmented Dickey-Fuller test 65 | * Durbin-Watson test 66 | * Breusch-Godfrey test 67 | * Breusch-Pagan test 68 | -------------------------------------------------------------------------------- /src/site/site.xml: -------------------------------------------------------------------------------- 1 | 2 | 13 | 17 | 18 | 19 | lt.velykis.maven.skins 20 | reflow-maven-skin 21 | 1.1.1 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | side 30 | 31 | SparkTimeseries 32 | index.html 33 | 34 | bootswatch-united 35 | sidebar 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /src/test/resources/GOOG.csv: -------------------------------------------------------------------------------- 1 | Date,Open,High,Low,Close,Volume,Adj Close 2 | 2014-10-24,544.36,544.88,535.79,539.78,1967700,539.78 3 | 2014-10-23,539.32,547.22,535.85,543.98,2342400,543.98 4 | 2014-10-22,529.89,539.80,528.80,532.71,2911300,532.71 5 | 2014-10-21,525.19,526.79,519.11,526.54,2329900,526.54 6 | 2014-10-20,509.45,521.76,508.10,520.84,2600400,520.84 7 | 2014-10-17,527.25,530.98,508.53,511.17,5524200,511.17 8 | 2014-10-16,519.00,529.43,515.00,524.51,3684300,524.51 9 | 2014-10-15,531.01,532.80,518.30,530.03,3709200,530.03 10 | 2014-10-14,538.90,547.19,533.17,537.94,2216500,537.94 11 | 2014-10-13,544.99,549.50,533.10,533.21,2574600,533.21 12 | 2014-10-10,557.72,565.13,544.05,544.49,3073500,544.49 13 | 2014-10-09,571.18,571.49,559.06,560.88,2517900,560.88 14 | 
2014-10-08,565.57,573.88,557.49,572.50,1985400,572.50 15 | 2014-10-07,574.40,575.27,563.74,563.74,1906100,563.74 16 | 2014-10-06,578.80,581.00,574.44,577.35,1211300,577.35 17 | 2014-10-03,573.05,577.22,572.50,575.28,1138600,575.28 18 | 2014-10-02,567.31,571.91,563.32,570.08,1175200,570.08 19 | 2014-10-01,576.01,577.58,567.01,568.27,1441500,568.27 20 | 2014-09-30,576.93,579.85,572.85,577.36,1617300,577.36 21 | 2014-09-29,571.75,578.19,571.17,576.36,1278900,576.36 22 | 2014-09-26,576.06,579.25,574.66,577.10,1439700,577.10 23 | 2014-09-25,587.55,587.98,574.18,575.06,1920700,575.06 24 | 2014-09-24,581.46,589.63,580.52,587.99,1723400,587.99 25 | 2014-09-23,586.85,586.85,581.00,581.13,1467400,581.13 26 | 2014-09-22,593.82,593.95,583.46,587.37,1684900,587.37 27 | 2014-09-19,591.50,596.48,589.50,596.08,3726400,596.08 28 | 2014-09-18,587.00,589.54,585.00,589.27,1440600,589.27 29 | 2014-09-17,580.01,587.52,578.78,584.77,1688200,584.77 30 | 2014-09-16,572.76,581.50,572.66,579.95,1476300,579.95 31 | 2014-09-15,572.94,574.95,568.21,573.10,1593200,573.10 32 | 2014-09-12,581.00,581.64,574.46,575.62,1597300,575.62 33 | 2014-09-11,580.36,581.81,576.26,581.35,1217700,581.35 34 | 2014-09-10,581.50,583.50,576.94,583.10,974700,583.10 35 | 2014-09-09,588.90,589.00,580.00,581.01,1283700,581.01 36 | 2014-09-08,586.60,591.77,586.30,589.72,1427100,589.72 37 | 2014-09-05,583.98,586.55,581.95,586.08,1627900,586.08 38 | 2014-09-04,580.00,586.00,579.22,581.98,1454200,581.98 39 | 2014-09-03,580.00,582.99,575.00,577.94,1211800,577.94 40 | 2014-09-02,571.85,577.83,571.19,577.33,1574100,577.33 41 | 2014-08-29,571.33,572.04,567.07,571.60,1080800,571.60 42 | 2014-08-28,569.56,573.25,567.10,569.20,1289400,569.20 43 | 2014-08-27,577.27,578.49,570.10,571.00,1698700,571.00 44 | 2014-08-26,581.26,581.80,576.58,577.86,1635200,577.86 45 | 2014-08-25,584.72,585.00,579.00,580.20,1357700,580.20 46 | 2014-08-22,583.59,585.24,580.64,582.56,786900,582.56 47 | 2014-08-21,583.82,584.50,581.14,583.37,912300,583.37 48 | 2014-08-20,585.88,586.70,582.57,584.49,1033900,584.49 49 | 2014-08-19,585.00,587.34,584.00,586.86,976000,586.86 50 | 2014-08-18,576.11,584.51,576.00,582.16,1280600,582.16 51 | 2014-08-15,577.86,579.38,570.52,573.48,1515000,573.48 52 | 2014-08-14,576.18,577.90,570.88,574.65,982800,574.65 53 | 2014-08-13,567.31,575.00,565.75,574.78,1435300,574.78 54 | 2014-08-12,564.52,565.90,560.88,562.73,1537800,562.73 55 | 2014-08-11,569.99,570.49,566.00,567.88,1211400,567.88 56 | 2014-08-08,563.56,570.25,560.35,568.77,1490700,568.77 57 | 2014-08-07,568.00,569.89,561.10,563.36,1107900,563.36 58 | 2014-08-06,561.78,570.70,560.00,566.37,1330700,566.37 59 | 2014-08-05,570.05,571.98,562.61,565.07,1547000,565.07 60 | 2014-08-04,569.04,575.35,564.10,573.15,1423400,573.15 61 | 2014-08-01,570.40,575.96,562.85,566.07,1949900,566.07 62 | 2014-07-31,580.60,583.65,570.00,571.60,2097000,571.60 63 | 2014-07-30,586.55,589.50,584.00,587.42,1013700,587.42 64 | 2014-07-29,588.75,589.70,583.52,585.61,1346200,585.61 65 | 2014-07-28,588.07,592.50,584.75,590.60,984100,590.60 66 | 2014-07-25,590.40,591.86,587.03,589.02,929900,589.02 67 | 2014-07-24,596.45,599.50,591.77,593.35,1032300,593.35 68 | 2014-07-23,593.23,597.85,592.50,595.98,1229800,595.98 69 | 2014-07-22,590.72,599.65,590.60,594.74,1694500,594.74 70 | 2014-07-21,591.75,594.40,585.23,589.47,2056500,589.47 71 | 2014-07-18,593.00,596.80,582.00,595.08,4003200,595.08 72 | 2014-07-17,579.53,580.99,568.61,573.73,3008300,573.73 73 | 2014-07-16,588.00,588.40,582.20,582.66,1393300,582.66 74 | 
2014-07-15,585.74,585.80,576.56,584.78,1618600,584.78 75 | 2014-07-14,582.60,585.21,578.03,584.87,1849000,584.87 76 | 2014-07-11,571.91,580.85,571.42,579.18,1617300,579.18 77 | 2014-07-10,565.91,576.59,565.01,571.10,1353000,571.10 78 | 2014-07-09,571.58,576.72,569.38,576.08,1113700,576.08 79 | 2014-07-08,577.66,579.53,566.14,571.09,1904300,571.09 80 | 2014-07-07,583.76,586.43,579.59,582.25,1061700,582.25 81 | 2014-07-03,583.35,585.01,580.92,584.73,712200,584.73 82 | 2014-07-02,583.35,585.44,580.39,582.34,1053500,582.34 83 | 2014-07-01,578.32,584.40,576.65,582.67,1444000,582.67 84 | 2014-06-30,578.66,579.57,574.75,575.28,1310200,575.28 85 | 2014-06-27,577.18,579.87,573.80,577.24,2230800,577.24 86 | 2014-06-26,581.00,582.45,571.85,576.00,1737200,576.00 87 | 2014-06-25,565.26,579.96,565.22,578.65,1964000,578.65 88 | 2014-06-24,565.19,572.65,561.01,564.62,2201100,564.62 89 | 2014-06-23,555.15,565.00,554.25,564.95,1532600,564.95 90 | 2014-06-20,556.85,557.58,550.39,556.36,4496000,556.36 91 | 2014-06-19,554.24,555.00,548.51,554.90,2450100,554.90 92 | 2014-06-18,544.86,553.56,544.00,553.37,1737000,553.37 93 | 2014-06-17,544.20,545.32,539.33,543.01,1440600,543.01 94 | 2014-06-16,549.26,549.62,541.52,544.28,1697900,544.28 95 | 2014-06-13,552.26,552.30,545.56,551.76,1217200,551.76 96 | 2014-06-12,557.30,557.99,548.46,551.35,1454500,551.35 97 | 2014-06-11,558.00,559.88,555.02,558.84,1097100,558.84 98 | 2014-06-10,560.51,563.60,557.90,560.55,1348000,560.55 99 | 2014-06-09,557.15,562.90,556.04,562.12,1463500,562.12 100 | 2014-06-06,558.06,558.06,548.93,556.33,1732000,556.33 101 | 2014-06-05,546.40,554.95,544.45,553.90,1684500,553.90 102 | 2014-06-04,541.50,548.61,538.75,544.66,1811500,544.66 103 | 2014-06-03,550.99,552.34,542.55,544.94,1861500,544.94 104 | 2014-06-02,560.70,560.90,545.73,553.93,1431100,553.93 105 | 2014-05-30,560.80,561.35,555.91,559.89,1766300,559.89 106 | 2014-05-29,563.35,564.00,558.71,560.08,1350400,560.08 107 | 2014-05-28,564.57,567.84,561.00,561.68,1647500,561.68 108 | 2014-05-27,556.00,566.00,554.35,565.95,2098400,565.95 109 | 2014-05-23,547.26,553.64,543.70,552.70,1926900,552.70 110 | 2014-05-22,541.13,547.60,540.78,545.06,1611400,545.06 111 | 2014-05-21,532.90,539.18,531.91,538.94,1193000,538.94 112 | 2014-05-20,529.74,536.23,526.30,529.77,1779900,529.77 113 | 2014-05-19,519.70,529.78,517.58,528.86,1274300,528.86 114 | 2014-05-16,521.39,521.80,515.44,520.63,1481200,520.63 115 | 2014-05-15,525.70,525.87,517.42,519.98,1699700,519.98 116 | 2014-05-14,533.00,533.00,525.29,526.65,1188500,526.65 117 | 2014-05-13,530.89,536.07,529.51,533.09,1648900,533.09 118 | 2014-05-12,523.51,530.19,519.01,529.92,1907300,529.92 119 | 2014-05-09,510.75,519.90,504.20,518.73,2432800,518.73 120 | 2014-05-08,508.46,517.23,506.45,511.00,2015800,511.00 121 | 2014-05-07,515.79,516.68,503.30,509.96,3215500,509.96 122 | 2014-05-06,525.23,526.81,515.06,515.14,1684400,515.14 123 | 2014-05-05,524.82,528.90,521.32,527.81,1021300,527.81 124 | 2014-05-02,533.76,534.00,525.61,527.93,1683900,527.93 125 | 2014-05-01,527.11,532.93,523.88,531.35,1900300,531.35 126 | 2014-04-30,527.60,528.00,522.52,526.66,1746400,526.66 127 | 2014-04-29,516.90,529.46,516.32,527.70,2691700,527.70 128 | 2014-04-28,517.18,518.60,502.80,517.15,3326400,517.15 129 | 2014-04-25,522.51,524.70,515.42,516.18,2094600,516.18 130 | 2014-04-24,530.07,531.65,522.12,525.16,1878000,525.16 131 | 2014-04-23,533.79,533.87,526.25,526.94,2046700,526.94 132 | 2014-04-22,528.64,537.23,527.51,534.81,2358900,534.81 133 | 
2014-04-21,536.10,536.70,525.60,528.62,2559700,528.62 134 | 2014-04-17,548.81,549.50,531.15,536.10,6790900,536.10 135 | 2014-04-16,543.00,557.00,540.00,556.54,4879900,556.54 136 | 2014-04-15,536.82,538.45,518.46,536.44,3844500,536.44 137 | 2014-04-14,538.25,544.10,529.56,532.52,2568000,532.52 138 | 2014-04-11,532.55,540.00,526.53,530.60,3914100,530.60 139 | 2014-04-10,565.00,565.00,539.90,540.95,4025800,540.95 140 | 2014-04-09,559.62,565.37,552.95,564.14,3321700,564.14 141 | 2014-04-08,542.60,555.00,541.61,554.90,3142600,554.90 142 | 2014-04-07,540.74,548.48,527.15,538.15,4389600,538.15 143 | 2014-04-04,574.65,577.77,543.00,543.14,6351900,543.14 144 | 2014-04-03,569.85,587.28,564.13,569.74,5085200,569.74 145 | 2014-04-02,599.99,604.83,562.19,567.00,146700,567.00 146 | 2014-04-01,558.71,568.45,558.71,567.16,7900,567.16 147 | 2014-03-31,566.89,567.00,556.93,556.97,10800,556.97 148 | 2014-03-28,561.20,566.43,558.67,559.99,41100,559.99 149 | 2014-03-27,568.00,568.00,552.92,558.46,13100,558.46 150 | -------------------------------------------------------------------------------- /src/test/resources/R_ARIMA_DataSet1.csv: -------------------------------------------------------------------------------- 1 | 12.8896942691386 2 | 13.5496440825487 3 | 13.8432744987832 4 | 12.1384361098596 5 | 12.8115609181586 6 | 14.2499628037939 7 | 15.1210259471 8 | 12.7532526308709 9 | 13.2647685601243 10 | 15.5795162488528 11 | 14.7771935240448 12 | 15.484430357205 13 | 16.2792361996444 14 | 13.9856853702286 15 | 10.6461741435826 16 | 9.06490577221341 17 | 10.4289264347516 18 | 11.7386068959506 19 | 13.1307813085612 20 | 12.7703786183829 21 | 11.6787310763953 22 | 13.2582899771967 23 | 12.4272867222566 24 | 10.9368272707903 25 | 10.8170993786496 26 | 9.33055916495176 27 | 10.1960799209371 28 | 11.9367564175161 29 | 14.0161169209224 30 | 13.3151885950953 31 | 11.8918824639671 32 | 11.5454092731078 33 | 11.6281016319998 34 | 11.9323735270844 35 | 12.4524900907391 36 | 12.2612592700595 37 | 12.0870982501761 38 | 14.2634192382215 39 | 14.3550898039082 40 | 13.0092726436564 41 | 13.255540619214 42 | 11.8402977365371 43 | 11.6383461994385 44 | 11.612880018084 45 | 11.261873657545 46 | 12.3777769101735 47 | 13.5853252169882 48 | 13.8217452449092 49 | 13.8108268210019 50 | 12.7420910939141 51 | 13.5912182560401 52 | 14.2095334852662 53 | 13.7191996845487 54 | 10.7650556597763 55 | 9.16915993518643 56 | 10.3627379489681 57 | 12.364353984245 58 | 12.7929124218591 59 | 13.9547002390535 60 | 15.4028250614837 61 | 12.7837787707395 62 | 10.7885745866214 63 | 10.5841029108965 64 | 10.9653654123358 65 | 11.6298561324864 66 | 10.2398832410536 67 | 11.8116930374275 68 | 13.0343676271611 69 | 12.6665320994797 70 | 12.3761882304214 71 | 11.2789149662942 72 | 10.2515119355956 73 | 11.0738614416002 74 | 10.8980951001839 75 | 10.5724172529477 76 | 13.3833982347763 77 | 14.0023203103121 78 | 13.4605129421383 79 | 13.8848354020995 80 | 11.7746100745838 81 | 11.9125748320784 82 | 12.8664721119819 83 | 12.2269039985501 84 | 11.5290275490796 85 | 10.9582368259562 86 | 12.6775534616928 87 | 14.6671527361846 88 | 13.1564146643686 89 | 11.8534308446371 90 | 12.8098397335419 91 | 13.4126772274248 92 | 13.9996370802754 93 | 13.3154271820629 94 | 12.4928836357949 95 | 13.2004736476803 96 | 12.9771377250504 97 | 12.397684761691 98 | 10.5851078264591 99 | 11.6147891665879 100 | 11.4149480586764 101 | 10.4413675982053 102 | 10.8862801065146 103 | 11.8608167289856 104 | 11.4962282909209 105 | 11.2022661589124 106 | 10.587212961078 107 | 
10.9690445191254 108 | 12.6727831270978 109 | 13.8550657457501 110 | 14.1954217201325 111 | 14.5615644002647 112 | 12.2798341843468 113 | 10.0793597307141 114 | 10.4761037248296 115 | 10.8075382946206 116 | 12.150181977985 117 | 13.0587103926024 118 | 13.8597656142493 119 | 12.3873256061758 120 | 11.6640658203657 121 | 10.8922583785562 122 | 11.1927015111456 123 | 12.4343987175884 124 | 12.4832990220212 125 | 13.3970206894246 126 | 12.8269280124323 127 | 11.2536518491248 128 | 10.0705707547333 129 | 9.591462488709 130 | 8.58315187165911 131 | 9.19844570720176 132 | 10.1843110665906 133 | 10.8295981795609 134 | 10.7593035245296 135 | 10.9416339614859 136 | 13.0254063189446 137 | 14.7098276851499 138 | 12.7010895144727 139 | 12.2528580708268 140 | 14.2221406386967 141 | 12.3791258868892 142 | 12.4073817913725 143 | 12.160399443029 144 | 12.7119880498806 145 | 12.9884673085225 146 | 11.7865423166256 147 | 10.125448713323 148 | 10.0880682348585 149 | 10.2563982245703 150 | 12.0533496670539 151 | 13.1795830122743 152 | 11.3418157766975 153 | 11.7036115949342 154 | 12.4666392757497 155 | 11.1160313643741 156 | 10.0268521838249 157 | 10.9233159450163 158 | 12.5576902758968 159 | 11.609657813096 160 | 12.9605506906173 161 | 13.5350203181545 162 | 10.1735695289673 163 | 10.078278514057 164 | 10.9380638474963 165 | 12.1280449175116 166 | 13.9123121356149 167 | 13.8184232300527 168 | 11.3130800758004 169 | 10.4795938366831 170 | 10.0007849924256 171 | 9.51189277514488 172 | 8.34299669257088 173 | 10.3240006182461 174 | 10.8660448386922 175 | 10.1763948903775 176 | 11.4432326066117 177 | 11.9205103369064 178 | 12.9864210969285 179 | 13.0476671523163 180 | 13.4138958426065 181 | 14.3091144782517 182 | 13.2158487554604 183 | 13.2498675579933 184 | 12.4752689956411 185 | 12.6674131479284 186 | 14.1715617677481 187 | 14.3452914025276 188 | 12.8515707709797 189 | 13.2747022407431 190 | 13.1570517864898 191 | 11.2470026222789 192 | 11.4001953418265 193 | 11.6301124422196 194 | 11.9476192114175 195 | 13.3424083372957 196 | 13.2329684568961 197 | 12.0608621439489 198 | 11.7541751670433 199 | 10.8971863256532 200 | 10.4507710417537 201 | 11.3672324563934 202 | 12.0952005335828 203 | 12.8530622290286 204 | 14.0058951032783 205 | 14.1960073490033 206 | 13.9208257393169 207 | 12.614503424801 208 | 11.9035490323366 209 | 13.4532804644801 210 | 14.4472404545782 211 | 13.9603617030752 212 | 13.135719115996 213 | 13.3594343948181 214 | 11.3584312241239 215 | 12.0909436982143 216 | 12.6779897029275 217 | 11.074806275098 218 | 10.6550552796512 219 | 11.9353389586751 220 | 13.0322164161175 221 | 11.2920714820789 222 | 12.2015587656555 223 | 13.7800856680362 224 | 14.7607592030631 225 | 13.40874043513 226 | 12.9354407484928 227 | 13.0846562228706 228 | 12.3068301221388 229 | 12.7266539951227 230 | 12.0203694167863 231 | 12.1949523971903 232 | 14.652604310778 233 | 13.1906962360283 234 | 13.5044169886966 235 | 12.9919718058027 236 | 10.7414297257485 237 | 11.756034024645 238 | 11.7154688503683 239 | 9.58092723712417 240 | 11.0950595602134 241 | 12.9530245829574 242 | 11.6559709317329 243 | 11.9684917762582 244 | 12.3225883912406 245 | 12.0931430454623 246 | 12.4279214610473 247 | 11.1744964488009 248 | 10.9334590802183 249 | 11.3752360748597 250 | 10.3722107058145 251 | -------------------------------------------------------------------------------- /src/test/resources/R_ARIMA_DataSet2.csv: -------------------------------------------------------------------------------- 1 | 0 2 | 0 3 | 0 4 | -0.180503307880698 5 | 
-1.94969098197842 6 | -6.18099684806138 7 | -12.6997093227188 8 | -21.0571250799371 9 | -32.3833614350056 10 | -47.2837096404811 11 | -67.457577581561 12 | -93.4867781887097 13 | -124.320827637665 14 | -158.983588519782 15 | -197.562138089474 16 | -239.116678354532 17 | -282.708330245888 18 | -328.099468471371 19 | -376.227167433833 20 | -427.477566289173 21 | -480.964173852229 22 | -536.018907345746 23 | -593.141481701477 24 | -654.636445884937 25 | -721.615723201632 26 | -796.33334803105 27 | -880.478330633102 28 | -974.67737216725 29 | -1079.69276037491 30 | -1196.53416516889 31 | -1325.47777932074 32 | -1466.79773556181 33 | -1622.39853045261 34 | -1792.72885815011 35 | -1976.835741527 36 | -2174.34054135517 37 | -2386.58621602098 38 | -2615.28426860245 39 | -2860.35971474344 40 | -3123.49922375199 41 | -3405.37915698843 42 | -3706.71598624322 43 | -4026.55347271463 44 | -4365.43685100279 45 | -4724.34729261871 46 | -5102.61605616632 47 | -5501.04389885174 48 | -5919.85316640781 49 | -6358.81709674905 50 | -6818.19039352644 51 | -7298.71091305944 52 | -7799.85895062844 53 | -8321.90409825322 54 | -8865.26104000827 55 | -9428.62873325341 56 | -10009.5958200943 57 | -10607.2289278458 58 | -11220.6405502707 59 | -11850.5756308363 60 | -12496.6817149391 61 | -13159.4981174343 62 | -13838.863029684 63 | -14535.9558486596 64 | -15251.4802695298 65 | -15986.3578502044 66 | -16740.4145395708 67 | -17512.5159381044 68 | -18301.2326446881 69 | -19105.5857434966 70 | -19925.9093050158 71 | -20761.7368262061 72 | -21614.202077827 73 | -22483.1734016081 74 | -23370.0050403598 75 | -24276.0315250375 76 | -25201.6810482927 77 | -26148.4872470723 78 | -27115.5754023423 79 | -28104.1907998681 80 | -29115.0597928702 81 | -30149.3200975536 82 | -31205.6594553029 83 | -32283.1805207946 84 | -33381.9873899269 85 | -34501.4117007251 86 | -35640.5942735005 87 | -36798.9515462218 88 | -37976.2366398949 89 | -39171.7580242572 90 | -40383.1632268179 91 | -41611.1920887188 92 | -42856.1553548089 93 | -44118.4841709302 94 | -45398.4532903614 95 | -46696.0314647898 96 | -48010.0494369411 97 | -49339.6811798047 98 | -50686.2273470197 99 | -52051.578544853 100 | -53435.1631828542 101 | -54836.6545078881 102 | -56254.7958579737 103 | -57690.1036317721 104 | -59142.3108147451 105 | -60612.3734751764 106 | -62099.7880277728 107 | -63605.0453707561 108 | -65127.7089720648 109 | -66667.0052100802 110 | -68224.4599383224 111 | -69799.376598986 112 | -71390.4216048615 113 | -72998.6495198295 114 | -74623.1875065833 115 | -76264.2739260045 116 | -77922.3173652082 117 | -79596.4566830895 118 | -81286.4298762948 119 | -82991.1815922744 120 | -84709.7619601005 121 | -86440.7670576735 122 | -88182.9948570552 123 | -89936.7365400596 124 | -91701.8854980302 125 | -93478.4328943552 126 | -95267.5806572838 127 | -97068.9442216392 128 | -98883.3136556576 129 | -100710.623162228 130 | -102551.885612542 131 | -104407.677513406 132 | -106279.278656122 133 | -108165.50122688 134 | -110065.425947232 135 | -111980.922945618 136 | -113913.073417335 137 | -115862.473784296 138 | -117830.199126792 139 | -119815.950726947 140 | -121818.903700921 141 | -123838.24558648 142 | -125872.888137441 143 | -127924.315728698 144 | -129994.068612078 145 | -132084.351077148 146 | -134194.085022098 147 | -136322.603824125 148 | -138468.76710943 149 | -140631.831052275 150 | -142811.587539243 151 | -145006.678391178 152 | -147216.748816818 153 | -149442.172473417 154 | -151683.277449301 155 | -153938.958613221 156 | -156209.849671467 157 | 
-158496.990247831 158 | -160802.874695011 159 | -163127.358386244 160 | -165469.169549447 161 | -167829.277781084 162 | -170209.503001443 163 | -172609.508044482 164 | -175029.711685446 165 | -177469.702751507 166 | -179930.076434828 167 | -182411.411525335 168 | -184914.473124188 169 | -187438.437554051 170 | -189984.581160015 171 | -192553.013812854 172 | -195145.13390697 173 | -197762.3658717 174 | -200405.351832348 175 | -203074.030783694 176 | -205767.311914548 177 | -208485.553613946 178 | -211227.592859294 179 | -213992.945924994 180 | -216781.875966249 181 | -219593.484916539 182 | -222427.031987664 183 | -225281.981902055 184 | -228157.606057059 185 | -231055.136058811 186 | -233975.042474648 187 | -236916.745703113 188 | -239878.053674065 189 | -242858.858030339 190 | -245858.938583272 191 | -248877.547766406 192 | -251914.500954835 193 | -254969.67584691 194 | -258042.054504884 195 | -261132.119013744 196 | -264241.282661008 197 | -267371.269555876 198 | -270522.686911965 199 | -273697.301083579 200 | -276896.803317275 201 | -280122.56580966 202 | -283375.908522484 203 | -286655.835850575 204 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/finance/YahooParserSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | package com.cloudera.finance 16 | 17 | import org.scalatest.{FunSuite, ShouldMatchers} 18 | 19 | class YahooParserSuite extends FunSuite with ShouldMatchers { 20 | test("yahoo parser") { 21 | val is = getClass.getClassLoader.getResourceAsStream("GOOG.csv") 22 | val lines = scala.io.Source.fromInputStream(is).getLines().toArray 23 | val text = lines.mkString("\n") 24 | val ts = YahooParser.yahooStringToTimeSeries(text) 25 | ts.data.rows should be (lines.length - 1) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/ARIMASuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | 20 | import com.cloudera.sparkts.UnivariateTimeSeries.differencesOfOrderD 21 | 22 | import org.apache.commons.math3.random.MersenneTwister 23 | 24 | import org.scalatest.FunSuite 25 | import org.scalatest.Matchers._ 26 | 27 | class ARIMASuite extends FunSuite { 28 | test("compare with R") { 29 | // > R.Version()$version.string 30 | // [1] "R version 3.2.0 (2015-04-16)" 31 | // > set.seed(456) 32 | // y <- arima.sim(n=250,list(ar=0.3,ma=0.7),mean = 5) 33 | // write.table(y, file = "resources/R_ARIMA_DataSet1.csv", row.names = FALSE, col.names = FALSE) 34 | val dataFile = getClass.getClassLoader.getResourceAsStream("R_ARIMA_DataSet1.csv") 35 | val rawData = scala.io.Source.fromInputStream(dataFile).getLines().toArray.map(_.toDouble) 36 | val data = new DenseVector(rawData) 37 | 38 | val model = ARIMA.fitModel(1, 0, 1, data) 39 | val Array(c, ar, ma) = model.coefficients 40 | ar should be (0.3 +- 0.05) 41 | ma should be (0.7 +- 0.05) 42 | } 43 | 44 | test("Data sampled from a given model should result in similar model if fit") { 45 | val rand = new MersenneTwister(10L) 46 | val model = new ARIMAModel(2, 1, 2, Array(8.2, 0.2, 0.5, 0.3, 0.1)) 47 | val sampled = model.sample(1000, rand) 48 | val newModel = ARIMA.fitModel(2, 1, 2, sampled) 49 | val Array(c, ar1, ar2, ma1, ma2) = model.coefficients 50 | val Array(cTest, ar1Test, ar2Test, ma1Test, ma2Test) = newModel.coefficients 51 | // intercept is given more leeway 52 | c should be (cTest +- 1) 53 | ar1Test should be (ar1 +- 0.1) 54 | ma1Test should be (ma1 +- 0.1) 55 | ar2Test should be (ar2 +- 0.1) 56 | ma2Test should be (ma2 +- 0.1) 57 | } 58 | 59 | test("Fitting CSS with BOBYQA and conjugate gradient descent should be fairly similar") { 60 | val rand = new MersenneTwister(10L) 61 | val model = new ARIMAModel(2, 1, 2, Array(8.2, 0.2, 0.5, 0.3, 0.1)) 62 | val sampled = model.sample(1000, rand) 63 | val fitWithBOBYQA = ARIMA.fitModel(2, 1, 2, sampled, method = "css-bobyqa") 64 | val fitWithCGD = ARIMA.fitModel(2, 1, 2, sampled, method = "css-cgd") 65 | 66 | val Array(c, ar1, ar2, ma1, ma2) = fitWithBOBYQA.coefficients 67 | val Array(cCGD, ar1CGD, ar2CGD, ma1CGD, ma2CGD) = fitWithCGD.coefficients 68 | 69 | // give more leeway for intercept 70 | cCGD should be (c +- 1) 71 | ar1CGD should be (ar1 +- 0.1) 72 | ar2CGD should be (ar2 +- 0.1) 73 | ma1CGD should be (ma1 +- 0.1) 74 | ma2CGD should be (ma2 +- 0.1) 75 | } 76 | 77 | test("Fitting ARIMA(p, d, q) should be the same as fitting a d-order differenced ARMA(p, q)") { 78 | val rand = new MersenneTwister(10L) 79 | val model = new ARIMAModel(1, 1, 2, Array(0.3, 0.7, 0.1), hasIntercept = false) 80 | val sampled = model.sample(1000, rand) 81 | val arimaModel = ARIMA.fitModel(1, 1, 2, sampled, includeIntercept = false) 82 | val differencedSample = new DenseVector(differencesOfOrderD(sampled, 1).toArray.drop(1)) 83 | val armaModel = ARIMA.fitModel(1, 0, 2, differencedSample, includeIntercept = false) 84 | 85 | val Array(refAR, refMA1, refMA2) = model.coefficients 86 | val Array(iAR, iMA1, iMA2) = arimaModel.coefficients 87 | val Array(ar, ma1, ma2) = armaModel.coefficients 88 | 89 | // ARIMA model should match parameters used to sample, to some extent 90 | iAR should be (refAR +- 0.05) 91 | iMA1 should be (refMA1 +- 0.05) 92 | iMA2 should be (refMA2 +- 0.05) 93 | 94 | // ARMA model parameters of differenced sample should be equal to ARIMA model parameters 95 | ar should be (iAR) 96 | ma1 should be (iMA1) 97 | ma2 should 
be (iMA2) 98 | } 99 | 100 | test("Adding ARIMA effects to series, and removing should return the same series") { 101 | val rand = new MersenneTwister(20L) 102 | val model = new ARIMAModel(1, 1, 2, Array(8.3, 0.1, 0.2, 0.3), hasIntercept = true) 103 | val whiteNoise = new DenseVector(Array.fill(100)(rand.nextGaussian)) 104 | val arimaProcess = new DenseVector(Array.fill(100)(0.0)) 105 | model.addTimeDependentEffects(whiteNoise, arimaProcess) 106 | val closeToWhiteNoise = new DenseVector(Array.fill(100)(0.0)) 107 | model.removeTimeDependentEffects(arimaProcess, closeToWhiteNoise) 108 | 109 | for (i <- 0 until whiteNoise.length) { 110 | val diff = whiteNoise(i) - closeToWhiteNoise(i) 111 | math.abs(diff) should be < 1e-4 112 | } 113 | } 114 | 115 | test("Fitting ARIMA(0, 0, 0) with intercept term results in model with average as parameter") { 116 | val rand = new MersenneTwister(10L) 117 | val sampled = new DenseVector(Array.fill(100)(rand.nextGaussian)) 118 | val model = ARIMA.fitModel(0, 0, 0, sampled) 119 | val mean = sum(sampled) / sampled.length 120 | model.coefficients(0) should be (mean +- 1e-4) 121 | } 122 | 123 | test("Fitting an integrated time series of order 3") { 124 | // > set.seed(10) 125 | // > vals <- arima.sim(list(ma = c(0.2), order = c(0, 3, 1)), 200) 126 | // > arima(order = c(0, 3, 1), vals, method = "CSS") 127 | // 128 | // Call: 129 | // arima(x = vals, order = c(0, 3, 1), method = "CSS") 130 | // 131 | // Coefficients: 132 | // ma1 133 | // 0.2523 134 | // s.e. 0.0623 135 | // 136 | // sigma^2 estimated as 0.9218: part log likelihood = -275.65 137 | // > write.table(y, file = "resources/R_ARIMA_DataSet2.csv", row.names = FALSE, col.names = 138 | // FALSE) 139 | val dataFile = getClass.getClassLoader.getResourceAsStream("R_ARIMA_DataSet2.csv") 140 | val rawData = scala.io.Source.fromInputStream(dataFile).getLines().toArray.map(_.toDouble) 141 | val data = new DenseVector(rawData) 142 | val model = ARIMA.fitModel(0, 3, 1, data) 143 | val Array(c, ma) = model.coefficients 144 | ma should be (0.2 +- 0.05) 145 | } 146 | 147 | test("Stationarity and Invertibility checks") { 148 | // Testing violations of stationarity and invertibility 149 | val model1 = new ARIMAModel(1, 0, 0, Array(0.2, 1.5), hasIntercept = true) 150 | model1.isStationary() should be (false) 151 | model1.isInvertible() should be (true) 152 | 153 | val model2 = new ARIMAModel(0, 0, 1, Array(0.13, 1.8), hasIntercept = true) 154 | model2.isStationary() should be (true) 155 | model2.isInvertible() should be (false) 156 | 157 | // http://www.econ.ku.dk/metrics/Econometrics2_05_II/Slides/07_univariatetimeseries_2pp.pdf 158 | // AR(2) model on slide 31 should be stationary 159 | val model3 = new ARIMAModel(2, 0, 0, Array(0.003359, 1.545, -0.5646), hasIntercept = true) 160 | model3.isStationary() should be (true) 161 | model3.isInvertible() should be (true) 162 | 163 | // http://www.econ.ku.dk/metrics/Econometrics2_05_II/Slides/07_univariatetimeseries_2pp.pdf 164 | // ARIMA(1, 0, 1) model from slide 36 should be stationary and invertible 165 | val model4 = new ARIMAModel(1, 0, 1, Array(-0.09341, 0.857361, -0.300821), hasIntercept = true) 166 | model4.isStationary() should be (true) 167 | model4.isInvertible() should be (true) 168 | } 169 | } 170 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/AugmentedDickeyFullerSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, 
Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | package com.cloudera.sparkts 16 | 17 | import breeze.linalg._ 18 | 19 | import org.apache.commons.math3.random.MersenneTwister 20 | 21 | import org.scalatest.FunSuite 22 | 23 | class AugmentedDickeyFullerSuite extends FunSuite { 24 | test("non-stationary AR model") { 25 | val rand = new MersenneTwister(10L) 26 | val arModel = new ARModel(0.0, .95) 27 | val sample = arModel.sample(500, rand) 28 | 29 | val (adfStat, pValue) = TimeSeriesStatisticalTests.adftest(sample, 1) 30 | assert(!java.lang.Double.isNaN(adfStat)) 31 | assert(!java.lang.Double.isNaN(pValue)) 32 | println("adfStat: " + adfStat) 33 | println("pValue: " + pValue) 34 | } 35 | 36 | test("iid samples") { 37 | val rand = new MersenneTwister(11L) 38 | val iidSample = Array.fill(500)(rand.nextDouble()) 39 | val (adfStat, pValue) = TimeSeriesStatisticalTests.adftest(new DenseVector(iidSample), 1) 40 | assert(!java.lang.Double.isNaN(adfStat)) 41 | assert(!java.lang.Double.isNaN(pValue)) 42 | println("adfStat: " + adfStat) 43 | println("pValue: " + pValue) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/AutoregressionSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import java.util.Random 19 | 20 | import breeze.linalg._ 21 | 22 | import org.apache.commons.math3.random.MersenneTwister 23 | 24 | import org.scalatest.FunSuite 25 | import org.scalatest.Matchers._ 26 | 27 | class AutoregressionSuite extends FunSuite { 28 | test("fit AR(1) model") { 29 | val model = new ARModel(1.5, Array(.2)) 30 | val ts = model.sample(5000, new MersenneTwister(10L)) 31 | val fittedModel = Autoregression.fitModel(ts, 1) 32 | assert(fittedModel.coefficients.length == 1) 33 | assert(math.abs(fittedModel.c - 1.5) < .07) 34 | assert(math.abs(fittedModel.coefficients(0) - .2) < .03) 35 | } 36 | 37 | test("fit AR(2) model") { 38 | val model = new ARModel(1.5, Array(.2, .3)) 39 | val ts = model.sample(5000, new MersenneTwister(10L)) 40 | val fittedModel = Autoregression.fitModel(ts, 2) 41 | assert(fittedModel.coefficients.length == 2) 42 | assert(math.abs(fittedModel.c - 1.5) < .15) 43 | assert(math.abs(fittedModel.coefficients(0) - .2) < .03) 44 | assert(math.abs(fittedModel.coefficients(1) - .3) < .03) 45 | } 46 | 47 | test("add and remove time dependent effects") { 48 | val rand = new Random() 49 | val ts = new DenseVector[Double](Array.fill(1000)(rand.nextDouble())) 50 | val model = new ARModel(1.5, Array(.2, .3)) 51 | val added = model.addTimeDependentEffects(ts, DenseVector.zeros[Double](ts.length)) 52 | val removed = model.removeTimeDependentEffects(added, DenseVector.zeros[Double](ts.length)) 53 | assert((ts - removed).toArray.forall(math.abs(_) < .001)) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/BusinessDayFrequencySuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import com.cloudera.sparkts.DateTimeIndex._ 19 | 20 | import com.github.nscala_time.time.Imports._ 21 | 22 | import org.scalatest.{FunSuite, ShouldMatchers} 23 | 24 | class BusinessDayFrequencySuite extends FunSuite with ShouldMatchers { 25 | test("business days") { 26 | // got the club goin' up, on 27 | val aTuesday = new DateTime("2015-4-7") 28 | // don't cross a weekend 29 | 1.businessDays.advance(aTuesday, 1) should be (aTuesday + 1.days) 30 | 1.businessDays.difference(aTuesday, aTuesday + 1.days) should be (1) 31 | 2.businessDays.advance(aTuesday, 1) should be (aTuesday + 2.days) 32 | 1.businessDays.advance(aTuesday, 2) should be (aTuesday + 2.days) 33 | 2.businessDays.difference(aTuesday, aTuesday + 2.days) should be (1) 34 | 1.businessDays.difference(aTuesday, aTuesday + 2.days) should be (2) 35 | // go exactly a week ahead 36 | 5.businessDays.advance(aTuesday, 1) should be (aTuesday + 7.days) 37 | 1.businessDays.advance(aTuesday, 5) should be (aTuesday + 7.days) 38 | 5.businessDays.difference(aTuesday, aTuesday + 7.days) should be (1) 39 | 1.businessDays.difference(aTuesday, aTuesday + 7.days) should be (5) 40 | // cross a weekend but go less than a week ahead 41 | 4.businessDays.advance(aTuesday, 1) should be (aTuesday + 6.days) 42 | 1.businessDays.advance(aTuesday, 4) should be (aTuesday + 6.days) 43 | 4.businessDays.difference(aTuesday, aTuesday + 6.days) should be (1) 44 | 1.businessDays.difference(aTuesday, aTuesday + 6.days) should be (4) 45 | // go more than a week ahead 46 | 6.businessDays.advance(aTuesday, 1) should be (aTuesday + 8.days) 47 | 1.businessDays.advance(aTuesday, 6) should be (aTuesday + 8.days) 48 | 6.businessDays.difference(aTuesday, aTuesday + 8.days) should be (1) 49 | 1.businessDays.difference(aTuesday, aTuesday + 8.days) should be (6) 50 | // go exactly two weeks ahead 51 | 10.businessDays.advance(aTuesday, 1) should be (aTuesday + 14.days) 52 | 1.businessDays.advance(aTuesday, 10) should be (aTuesday + 14.days) 53 | 10.businessDays.difference(aTuesday, aTuesday + 14.days) should be (1) 54 | 1.businessDays.difference(aTuesday, aTuesday + 14.days) should be (10) 55 | // go more than two weeks ahead 56 | 12.businessDays.advance(aTuesday, 1) should be (aTuesday + 16.days) 57 | 1.businessDays.advance(aTuesday, 12) should be (aTuesday + 16.days) 58 | 12.businessDays.difference(aTuesday, aTuesday + 16.days) should be (1) 59 | 1.businessDays.difference(aTuesday, aTuesday + 16.days) should be (12) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/DateTimeIndexSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import org.scalatest.{FunSuite, ShouldMatchers} 19 | 20 | import com.github.nscala_time.time.Imports._ 21 | 22 | import com.cloudera.sparkts.DateTimeIndex._ 23 | 24 | import org.joda.time.DateTimeZone.UTC 25 | 26 | class DateTimeIndexSuite extends FunSuite with ShouldMatchers { 27 | 28 | test("to / from string") { 29 | val uniformIndex = uniform(new DateTime("1990-04-10"), 5, 2.businessDays) 30 | val uniformStr = uniformIndex.toString 31 | fromString(uniformStr) should be (uniformIndex) 32 | 33 | val irregularIndex = irregular( 34 | Array(new DateTime("1990-04-10"), new DateTime("1990-04-12"), new DateTime("1990-04-13"))) 35 | val irregularStr = irregularIndex.toString 36 | fromString(irregularStr) should be (irregularIndex) 37 | } 38 | 39 | test("uniform") { 40 | val index: DateTimeIndex = uniform(new DateTime("2015-04-10", UTC), 5, 2.days) 41 | index.size should be (5) 42 | index.first should be (new DateTime("2015-04-10", UTC)) 43 | index.last should be (new DateTime("2015-04-18", UTC)) 44 | 45 | def verifySlice(index: DateTimeIndex) = { 46 | index.size should be (2) 47 | index.first should be (new DateTime("2015-04-14", UTC)) 48 | index.last should be (new DateTime("2015-04-16", UTC)) 49 | } 50 | 51 | verifySlice(index.slice(new DateTime("2015-04-14", UTC), new DateTime("2015-04-16", UTC))) 52 | verifySlice(index.slice(new DateTime("2015-04-14", UTC) to new DateTime("2015-04-16", UTC))) 53 | verifySlice(index.islice(2, 4)) 54 | verifySlice(index.islice(2 until 4)) 55 | verifySlice(index.islice(2 to 3)) 56 | } 57 | 58 | test("irregular") { 59 | val index = irregular(Array( 60 | "2015-04-14", "2015-04-15", "2015-04-17", "2015-04-22", "2015-04-25" 61 | ).map(new DateTime(_, UTC))) 62 | index.size should be (5) 63 | index.first should be (new DateTime("2015-04-14", UTC)) 64 | index.last should be (new DateTime("2015-04-25", UTC)) 65 | 66 | def verifySlice(index: DateTimeIndex) = { 67 | index.size should be (3) 68 | index.first should be (new DateTime("2015-04-15", UTC)) 69 | index.last should be (new DateTime("2015-04-22", UTC)) 70 | } 71 | 72 | verifySlice(index.slice(new DateTime("2015-04-15", UTC), new DateTime("2015-04-22", UTC))) 73 | verifySlice(index.slice(new DateTime("2015-04-15", UTC) to new DateTime("2015-04-22", UTC))) 74 | verifySlice(index.islice(1, 4)) 75 | verifySlice(index.islice(1 until 4)) 76 | verifySlice(index.islice(1 to 3)) 77 | 78 | // TODO: test bounds that aren't members of the index 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/EWMASuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | 20 | import org.scalatest.{FunSuite, ShouldMatchers} 21 | 22 | class EWMASuite extends FunSuite with ShouldMatchers { 23 | test("adding time dependent effects") { 24 | val orig = new DenseVector((1 to 10).toArray.map(_.toDouble)) 25 | 26 | val m1 = new EWMAModel(0.2) 27 | val smoothed1 = new DenseVector(Array.fill(10)(0.0)) 28 | m1.addTimeDependentEffects(orig, smoothed1) 29 | 30 | smoothed1(0) should be (orig(0)) 31 | smoothed1(1) should be (m1.smoothing * orig(1) + (1 - m1.smoothing) * smoothed1(0)) 32 | round2Dec(smoothed1.toArray.last) should be (6.54) 33 | 34 | val m2 = new EWMAModel(0.6) 35 | val smoothed2 = new DenseVector(Array.fill(10)(0.0)) 36 | m2.addTimeDependentEffects(orig, smoothed2) 37 | 38 | smoothed2(0) should be (orig(0)) 39 | smoothed2(1) should be (m2.smoothing * orig(1) + (1 - m2.smoothing) * smoothed2(0)) 40 | round2Dec(smoothed2.toArray.last) should be (9.33) 41 | } 42 | 43 | test("removing time dependent effects") { 44 | val smoothed = new DenseVector(Array(1.0, 1.2, 1.56, 2.05, 2.64, 3.31, 4.05, 4.84, 5.67, 6.54)) 45 | 46 | val m1 = new EWMAModel(0.2) 47 | val orig1 = new DenseVector(Array.fill(10)(0.0)) 48 | m1.removeTimeDependentEffects(smoothed, orig1) 49 | 50 | round2Dec(orig1(0)) should be (1.0) 51 | orig1.toArray.last.toInt should be(10) 52 | } 53 | 54 | test("fitting EWMA model") { 55 | // We reproduce the example in ch 7.1 from 56 | // https://www.otexts.org/fpp/7/1 57 | val oil = Array(446.7, 454.5, 455.7, 423.6, 456.3, 440.6, 425.3, 485.1, 506.0, 526.8, 58 | 514.3, 494.2) 59 | val model = EWMA.fitModel(new DenseVector(oil)) 60 | val truncatedSmoothing = (model.smoothing * 100.0).toInt 61 | truncatedSmoothing should be (89) // approximately 0.89 62 | } 63 | 64 | private def round2Dec(x: Double): Double = { 65 | (x * 100).round / 100.00 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/FillSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import scala.Double.NaN 19 | 20 | import com.cloudera.sparkts.UnivariateTimeSeries._ 21 | 22 | import org.scalatest.{FunSuite, ShouldMatchers} 23 | 24 | class FillSuite extends FunSuite with ShouldMatchers { 25 | ignore("nearest") { 26 | fillNearest(Array(1.0)) should be (Array(1.0)) 27 | fillNearest(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0)) 28 | fillNearest(Array(1.0, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 2.0, 2.0)) 29 | // round down to previous 30 | fillNearest(Array(1.0, NaN, 2.0)) should be (Array(1.0, 1.0, 2.0)) 31 | fillNearest(Array(1.0, NaN, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 1.0, 2.0, 2.0)) 32 | fillNearest(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 1.0, 3.0, 3.0, 2.0)) 33 | } 34 | 35 | test("previous") { 36 | fillPrevious(Array(1.0)) should be (Array(1.0)) 37 | fillPrevious(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0)) 38 | fillPrevious(Array(1.0, NaN, 2.0)) should be (Array(1.0, 1.0, 2.0)) 39 | fillPrevious(Array(1.0, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 1.0, 2.0)) 40 | fillPrevious(Array(1.0, NaN, NaN, NaN, 2.0)) should be (Array(1.0, 1.0, 1.0, 1.0, 2.0)) 41 | fillPrevious(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 1.0, 3.0, 3.0, 2.0)) 42 | } 43 | 44 | test("next") { 45 | fillNext(Array(1.0)) should be (Array(1.0)) 46 | fillNext(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0)) 47 | fillNext(Array(1.0, NaN, 2.0)) should be (Array(1.0, 2.0, 2.0)) 48 | fillNext(Array(1.0, NaN, NaN, 2.0)) should be (Array(1.0, 2.0, 2.0, 2.0)) 49 | fillNext(Array(1.0, NaN, NaN, NaN, 2.0)) should be (Array(1.0, 2.0, 2.0, 2.0, 2.0)) 50 | fillNext(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 3.0, 3.0, 2.0, 2.0)) 51 | } 52 | 53 | test("linear") { 54 | fillLinear(Array(1.0)) should be (Array(1.0)) 55 | fillLinear(Array(1.0, 1.0, 2.0)) should be (Array(1.0, 1.0, 2.0)) 56 | fillLinear(Array(1.0, NaN, 2.0)) should be (Array(1.0, 1.5, 2.0)) 57 | fillLinear(Array(2.0, NaN, 1.0)) should be (Array(2.0, 1.5, 1.0)) 58 | fillLinear(Array(1.0, NaN, NaN, 4.0)) should be (Array(1.0, 2.0, 3.0, 4.0)) 59 | fillLinear(Array(1.0, NaN, NaN, NaN, 5.0)) should be (Array(1.0, 2.0, 3.0, 4.0, 5.0)) 60 | fillLinear(Array(1.0, NaN, 3.0, NaN, 2.0)) should be (Array(1.0, 2.0, 3.0, 2.5, 2.0)) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/GARCHSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 
14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg._ 19 | 20 | import org.apache.commons.math3.random.MersenneTwister 21 | 22 | import org.scalatest.FunSuite 23 | 24 | class GARCHSuite extends FunSuite { 25 | test("GARCH log likelihood") { 26 | val model = new GARCHModel(.2, .3, .4) 27 | val rand = new MersenneTwister(5L) 28 | val n = 10000 29 | 30 | val ts = new DenseVector(model.sample(n, rand)) 31 | val logLikelihoodWithRightModel = model.logLikelihood(ts) 32 | 33 | val logLikelihoodWithWrongModel1 = new GARCHModel(.3, .4, .5).logLikelihood(ts) 34 | val logLikelihoodWithWrongModel2 = new GARCHModel(.25, .35, .45).logLikelihood(ts) 35 | val logLikelihoodWithWrongModel3 = new GARCHModel(.1, .2, .3).logLikelihood(ts) 36 | 37 | assert(logLikelihoodWithRightModel > logLikelihoodWithWrongModel1) 38 | assert(logLikelihoodWithRightModel > logLikelihoodWithWrongModel2) 39 | assert(logLikelihoodWithRightModel > logLikelihoodWithWrongModel3) 40 | assert(logLikelihoodWithWrongModel2 > logLikelihoodWithWrongModel1) 41 | } 42 | 43 | test("gradient") { 44 | val alpha = 0.3 45 | val beta = 0.4 46 | val omega = 0.2 47 | val genModel = new GARCHModel(omega, alpha, beta) 48 | val rand = new MersenneTwister(5L) 49 | val n = 10000 50 | 51 | val ts = new DenseVector(genModel.sample(n, rand)) 52 | 53 | val gradient1 = new GARCHModel(omega + .1, alpha + .05, beta + .1).gradient(ts) 54 | assert(gradient1.forall(_ < 0.0)) 55 | val gradient2 = new GARCHModel(omega - .1, alpha - .05, beta - .1).gradient(ts) 56 | assert(gradient2.forall(_ > 0.0)) 57 | } 58 | 59 | test("fit model") { 60 | val omega = 0.2 61 | val alpha = 0.3 62 | val beta = 0.5 63 | val genModel = new ARGARCHModel(0.0, 0.0, alpha, beta, omega) 64 | val rand = new MersenneTwister(5L) 65 | val n = 10000 66 | 67 | val ts = new DenseVector(genModel.sample(n, rand)) 68 | 69 | val model = GARCH.fitModel(ts) 70 | assert(model.omega - omega < .1) // TODO: we should be able to be more accurate 71 | assert(model.alpha - alpha < .02) 72 | assert(model.beta - beta < .02) 73 | } 74 | 75 | test("fit model 2") { 76 | val ts = DenseVector(0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1, 77 | 0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00, 78 | -0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2, 79 | -0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0, 80 | -0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1, 81 | 0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1, 82 | 0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01, 83 | 0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1, 84 | -0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1, 85 | 0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00, 86 | -0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2, 87 | -0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0, 88 | -0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1, 89 | 0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1, 90 | 0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00, 91 | 
-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2, 92 | -0.1,0.1,0.0,-0.01,0.00,-0.1,0.1,-0.2,-0.1,0.1,0.0,-0.01,0.00,-0.1) 93 | val model = ARGARCH.fitModel(ts) 94 | println(s"alpha: ${model.alpha}") 95 | println(s"beta: ${model.beta}") 96 | println(s"omega: ${model.omega}") 97 | println(s"c: ${model.c}") 98 | println(s"phi: ${model.phi}") 99 | } 100 | 101 | test("standardize and filter") { 102 | val model = new ARGARCHModel(40.0, .4, .2, .3, .4) 103 | val rand = new MersenneTwister(5L) 104 | val n = 10000 105 | 106 | val ts = new DenseVector(model.sample(n, rand)) 107 | 108 | // de-heteroskedasticize 109 | val standardized = model.removeTimeDependentEffects(ts, DenseVector.zeros[Double](n)) 110 | // heteroskedasticize 111 | val filtered = model.addTimeDependentEffects(standardized, DenseVector.zeros[Double](n)) 112 | 113 | assert((filtered - ts).toArray.forall(math.abs(_) < .001)) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/LocalSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import org.apache.spark.SparkContext 19 | 20 | import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite} 21 | 22 | /** Manages a local `sc` `SparkContext` variable, correctly stopping it after each test. */ 23 | trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => 24 | 25 | @transient var sc: SparkContext = _ 26 | 27 | override def beforeAll() { 28 | super.beforeAll() 29 | } 30 | 31 | override def afterEach() { 32 | resetSparkContext() 33 | super.afterEach() 34 | } 35 | 36 | def resetSparkContext() = { 37 | LocalSparkContext.stop(sc) 38 | sc = null 39 | } 40 | 41 | } 42 | 43 | object LocalSparkContext { 44 | def stop(sc: SparkContext) { 45 | if (sc != null) { 46 | sc.stop() 47 | } 48 | // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown 49 | System.clearProperty("spark.driver.port") 50 | } 51 | 52 | /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ 53 | def withSpark[T](sc: SparkContext)(f: SparkContext => T) = { 54 | try { 55 | f(sc) 56 | } finally { 57 | stop(sc) 58 | } 59 | } 60 | 61 | } -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/RebaseSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. 
You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import scala.Double.NaN 19 | 20 | import breeze.linalg._ 21 | 22 | import com.cloudera.sparkts.DateTimeIndex._ 23 | import com.cloudera.sparkts.TimeSeriesUtils._ 24 | 25 | import com.github.nscala_time.time.Imports._ 26 | 27 | import org.scalatest.{FunSuite, ShouldMatchers} 28 | 29 | class RebaseSuite extends FunSuite with ShouldMatchers { 30 | test("iterateWithUniformFrequency single value") { 31 | val baseDT = new DateTime("2015-4-8") 32 | val dts = Array(baseDT) 33 | val values = Array(1.0) 34 | val iter = iterateWithUniformFrequency(dts.zip(values).iterator, 1.days, 47.0) 35 | iter.toArray should be (Array((baseDT, 1.0))) 36 | } 37 | 38 | test("iterateWithUniformFrequency no gaps") { 39 | val baseDT = new DateTime("2015-4-8") 40 | val dts = Array(baseDT, baseDT + 1.days, baseDT + 2.days, baseDT + 3.days) 41 | val values = Array(1.0, 2.0, 3.0, 4.0) 42 | val iter = iterateWithUniformFrequency(dts.zip(values).iterator, 1.days) 43 | iter.toArray should be (Array((baseDT, 1.0), (baseDT + 1.days, 2.0), (baseDT + 2.days, 3.0), 44 | (baseDT + 3.days, 4.0))) 45 | } 46 | 47 | test("iterateWithUniformFrequency multiple gaps") { 48 | val baseDT = new DateTime("2015-4-8") 49 | val dts = Array(baseDT, baseDT + 2.days, baseDT + 5.days) 50 | val values = Array(1.0, 2.0, 3.0) 51 | val iter = iterateWithUniformFrequency(dts.zip(values).iterator, 1.days, 47.0) 52 | iter.toArray should be (Array((baseDT, 1.0), (baseDT + 1.days, 47.0), (baseDT + 2.days, 2.0), 53 | (baseDT + 3.days, 47.0), (baseDT + 4.days, 47.0), (baseDT + 5.days, 3.0))) 54 | } 55 | 56 | test("uniform source same range") { 57 | val vec = new DenseVector((0 until 10).map(_.toDouble).toArray) 58 | val source = uniform(new DateTime("2015-4-8"), vec.length, 1.days) 59 | val target = source 60 | val rebased = rebase(source, target, vec, NaN) 61 | rebased.length should be (vec.length) 62 | rebased should be (vec) 63 | } 64 | 65 | test("uniform source, target fits in source") { 66 | val vec = new DenseVector((0 until 10).map(_.toDouble).toArray) 67 | val source = uniform(new DateTime("2015-4-8"), vec.length, 1.days) 68 | val target = uniform(new DateTime("2015-4-9"), 5, 1.days) 69 | val rebased = rebase(source, target, vec, NaN) 70 | rebased should be (new DenseVector(Array(1.0, 2.0, 3.0, 4.0, 5.0))) 71 | } 72 | 73 | test("uniform source, target overlaps source ") { 74 | val vec = new DenseVector((0 until 10).map(_.toDouble).toArray) 75 | val source = uniform(new DateTime("2015-4-8"), vec.length, 1.days) 76 | val targetBefore = uniform(new DateTime("2015-4-4"), 8, 1.days) 77 | val targetAfter = uniform(new DateTime("2015-4-11"), 8, 1.days) 78 | val rebasedBefore = rebase(source, targetBefore, vec, NaN) 79 | val rebasedAfter = rebase(source, targetAfter, vec, NaN) 80 | assertArraysEqualWithNaN( 81 | rebasedBefore.valuesIterator.toArray, 82 | Array(NaN, NaN, NaN, NaN, 0.0, 1.0, 2.0, 3.0)) 83 | assertArraysEqualWithNaN( 84 | rebasedAfter.valuesIterator.toArray, 85 | Array(3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, NaN)) 86 | } 87 | 88 | test("uniform source, source fits in target") { 89 | val vec = new DenseVector((0 until 4).map(_.toDouble).toArray) 90 | 
val source = uniform(new DateTime("2015-4-8"), vec.length, 1.days) 91 | val target = uniform(new DateTime("2015-4-7"), 8, 1.days) 92 | val rebased = rebase(source, target, vec, NaN) 93 | assertArraysEqualWithNaN( 94 | rebased.valuesIterator.toArray, 95 | Array(NaN, 0.0, 1.0, 2.0, 3.0, NaN, NaN, NaN)) 96 | } 97 | 98 | test("irregular source same range") { 99 | val vec = new DenseVector((4 until 10).map(_.toDouble).toArray) 100 | val source = irregular((4 until 10).map(d => new DateTime(s"2015-4-$d")).toArray) 101 | vec.size should be (source.size) 102 | val target = uniform(new DateTime("2015-4-4"), vec.length, 1.days) 103 | val rebased = rebase(source, target, vec, NaN) 104 | rebased should be (vec) 105 | } 106 | 107 | test("irregular source, hole gets filled default value") { 108 | val dt = new DateTime("2015-4-10") 109 | val source = irregular(Array(dt, dt + 1.days, dt + 3.days)) 110 | val target = uniform(dt, 4, 1.days) 111 | val vec = new DenseVector(Array(1.0, 2.0, 3.0)) 112 | val rebased = rebase(source, target, vec, 47.0) 113 | rebased.toArray should be (Array(1.0, 2.0, 47.0, 3.0)) 114 | } 115 | 116 | test("irregular source, target fits in source") { 117 | val dt = new DateTime("2015-4-10") 118 | val source = irregular(Array(dt, dt + 1.days, dt + 3.days)) 119 | val target = uniform(dt + 1.days, 2, 1.days) 120 | val vec = new DenseVector(Array(1.0, 2.0, 3.0)) 121 | val rebased = rebase(source, target, vec, 47.0) 122 | rebased.toArray should be (Array(2.0, 47.0)) 123 | } 124 | 125 | test("irregular source, target overlaps source ") { 126 | val dt = new DateTime("2015-4-10") 127 | val source = irregular(Array(dt, dt + 1.days, dt + 3.days)) 128 | val targetBefore = uniform(new DateTime("2015-4-8"), 4, 1.days) 129 | val vec = new DenseVector(Array(1.0, 2.0, 3.0)) 130 | val rebasedBefore = rebase(source, targetBefore, vec, 47.0) 131 | rebasedBefore.toArray should be (Array(47.0, 47.0, 1.0, 2.0)) 132 | val targetAfter = uniform(new DateTime("2015-4-11"), 5, 1.days) 133 | val rebasedAfter = rebase(source, targetAfter, vec, 47.0) 134 | rebasedAfter.toArray should be (Array(2.0, 47.0, 3.0, 47.0, 47.0)) 135 | } 136 | 137 | test("irregular source, source fits in target") { 138 | val dt = new DateTime("2015-4-10") 139 | val source = irregular(Array(dt, dt + 1.days, dt + 3.days)) 140 | val target = uniform(dt - 2.days, 7, 1.days) 141 | val vec = new DenseVector(Array(1.0, 2.0, 3.0)) 142 | val rebased = rebase(source, target, vec, 47.0) 143 | rebased.toArray should be (Array(47.0, 47.0, 1.0, 2.0, 47.0, 3.0, 47.0)) 144 | } 145 | 146 | test("irregular source, irregular target") { 147 | // Triples of source index, target index, expected output 148 | // Assumes that at time i, value of source series is i 149 | val cases = Array( 150 | (Array(1, 2, 3), Array(1, 2, 3), Array(1, 2, 3)), 151 | (Array(1, 2, 3), Array(1, 2), Array(1, 2)), 152 | (Array(1, 2), Array(1, 2, 3), Array(1, 2, -1)), 153 | (Array(2, 3), Array(1, 2, 3), Array(-1, 2, 3)), 154 | (Array(1, 2), Array(2, 3), Array(2, -1)), 155 | (Array(1, 2, 3), Array(1, 3), Array(1, 3)), 156 | (Array(1, 2, 3, 4), Array(1, 3), Array(1, 3)), 157 | (Array(1, 2, 3, 4), Array(1, 4), Array(1, 4)), 158 | (Array(1, 2, 3, 4), Array(2, 4), Array(2, 4)), 159 | (Array(1, 2, 3, 4), Array(2, 3), Array(2, 3)), 160 | (Array(1, 2, 3, 4), Array(1, 3, 4), Array(1, 3, 4)) 161 | ) 162 | 163 | cases.foreach { case (source, target, expected) => 164 | val sourceIndex = irregular(source.map(x => new DateTime(s"2015-04-0$x"))) 165 | val targetIndex = irregular(target.map(x => new 
DateTime(s"2015-04-0$x"))) 166 | val vec = new DenseVector[Double](source.map(_.toDouble)) 167 | val expectedVec = new DenseVector[Double](expected.map(_.toDouble)) 168 | rebase(sourceIndex, targetIndex, vec, -1) should be (expectedVec) 169 | } 170 | } 171 | 172 | private def assertArraysEqualWithNaN(arr1: Array[Double], arr2: Array[Double]): Unit = { 173 | assert(arr1.zip(arr2).forall { case (d1, d2) => 174 | d1 == d2 || (d1.isNaN && d2.isNaN) 175 | }, s"${arr1.mkString(",")} != ${arr2.mkString(",")}") 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/TimeSeriesStatisticalTestsSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * Version 2.0 (the "License"). You may not use this file except in 7 | * compliance with the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 13 | * the specific language governing permissions and limitations under the 14 | * License. 15 | */ 16 | 17 | package com.cloudera.sparkts 18 | 19 | import breeze.linalg._ 20 | 21 | import com.cloudera.sparkts.TimeSeriesStatisticalTests._ 22 | 23 | import org.apache.commons.math.stat.regression.OLSMultipleLinearRegression 24 | import org.apache.commons.math3.random.MersenneTwister 25 | 26 | import org.scalatest.{FunSuite, ShouldMatchers} 27 | 28 | class TimeSeriesStatisticalTestsSuite extends FunSuite with ShouldMatchers { 29 | test("breusch-godfrey") { 30 | // Replicating the example provided by R package lmtest for bgtest 31 | val rand = new MersenneTwister(5L) 32 | val n = 100 33 | val coef = 0.5 // coefficient for lagged series 34 | val x = Array.fill(n / 2)(Array(1.0, -1.0)).flatten 35 | // stationary series 36 | val y1 = x.map(_ + 1 + rand.nextGaussian()) 37 | // AR(1) series, recursive filter with coef coefficient 38 | val y2 = y1.scanLeft(0.0) { case (prior, curr) => prior * coef + curr }.tail 39 | 40 | val pthreshold = 0.05 41 | 42 | val OLS1 = new OLSMultipleLinearRegression() 43 | OLS1.newSampleData(y1, x.map(Array(_))) 44 | val resids1 = OLS1.estimateResiduals() 45 | 46 | val OLS2 = new OLSMultipleLinearRegression() 47 | OLS2.newSampleData(y2, x.map(Array(_))) 48 | val resids2 = OLS2.estimateResiduals() 49 | 50 | // there should be no evidence of serial correlation 51 | bgtest(new DenseVector(resids1), new DenseMatrix(x.length, 1, x), 1)._2 should be > pthreshold 52 | bgtest(new DenseVector(resids1), new DenseMatrix(x.length, 1, x), 4)._2 should be > pthreshold 53 | // there should be evidence of serial correlation 54 | bgtest(new DenseVector(resids2), new DenseMatrix(x.length, 1, x), 1)._2 should be < pthreshold 55 | bgtest(new DenseVector(resids2), new DenseMatrix(x.length, 1, x), 4)._2 should be < pthreshold 56 | } 57 | 58 | test("breusch-pagan") { 59 | // Replicating the example provided by R package lmtest for bptest 60 | val rand = new MersenneTwister(5L) 61 | val n = 100 62 | val x = Array.fill(n / 2)(Array(-1.0, 1.0)).flatten 63 | 64 | // homoskedastic residuals with variance 1 throughout 65 | val err1 = Array.fill(n)(rand.nextGaussian) 66 | 
// heteroskedastic residuals with alternating variance of 1 and 4 67 | val varFactor = 2 68 | val err2 = err1.zipWithIndex.map { case (xi, i) => if(i % 2 == 0) xi * varFactor else xi } 69 | 70 | // generate dependent variables 71 | val y1 = x.zip(err1).map { case (xi, ei) => xi + ei + 1 } 72 | val y2 = x.zip(err2).map { case (xi, ei) => xi + ei + 1 } 73 | 74 | // create models and calculate residuals 75 | val OLS1 = new OLSMultipleLinearRegression() 76 | OLS1.newSampleData(y1, x.map(Array(_))) 77 | val resids1 = OLS1.estimateResiduals() 78 | 79 | val OLS2 = new OLSMultipleLinearRegression() 80 | OLS2.newSampleData(y2, x.map(Array(_))) 81 | val resids2 = OLS2.estimateResiduals() 82 | 83 | val pthreshold = 0.05 84 | // there should be no evidence of heteroskedasticity 85 | bptest(new DenseVector(resids1), new DenseMatrix(x.length, 1, x))._2 should be > pthreshold 86 | // there should be evidence of heteroskedasticity 87 | bptest(new DenseVector(resids2), new DenseMatrix(x.length, 1, x))._2 should be < pthreshold 88 | } 89 | 90 | test("ljung-box test") { 91 | val rand = new MersenneTwister(5L) 92 | val n = 100 93 | val indep = Array.fill(n)(rand.nextGaussian) 94 | val vecIndep = new DenseVector(indep) 95 | val (stat1, pval1) = lbtest(vecIndep, 1) 96 | pval1 should be > 0.05 97 | 98 | // serially correlated 99 | val coef = 0.3 100 | val dep = indep.scanLeft(0.0) { case (prior, curr) => prior * coef + curr }.tail 101 | val vecDep = new DenseVector(dep) 102 | val (stat2, pval2) = lbtest(vecDep, 2) 103 | pval2 should be < 0.05 104 | } 105 | 106 | test("KPSS test: R equivalence") { 107 | // Note that we only test the statistic, as in contrast with the R tseries implementation 108 | // we do not calculate a p-value, but rather return a map of appropriate critical values 109 | // R's tseries linearly interpolates between critical values. 110 | 111 | // version.string R version 3.2.0 (2015-04-16) 112 | // > set.seed(10) 113 | // > library(tseries) 114 | // > v <- rnorm(20) 115 | // > kpss.test(v, "Level") 116 | // 117 | // KPSS Test for Level Stationarity 118 | // 119 | // data: v 120 | // KPSS Level = 0.27596, Truncation lag parameter = 1, p-value = 0.1 121 | // 122 | // Warning message: 123 | // In kpss.test(v, "Level") : p-value greater than printed p-value 124 | // > kpss.test(v, "Trend") 125 | // 126 | // KPSS Test for Trend Stationarity 127 | // 128 | // data: v 129 | // KPSS Trend = 0.05092, Truncation lag parameter = 1, p-value = 0.1 130 | // 131 | // Warning message: 132 | // In kpss.test(v, "Trend") : p-value greater than printed p-value 133 | 134 | val arr = Array(0.0187461709418264, -0.184252542069064, -1.37133054992251, -0.599167715783718, 135 | 0.294545126567508, 0.389794300700167, -1.20807617542949, -0.363676017470862, 136 | -1.62667268170309, -0.256478394123992, 1.10177950308713, 0.755781508027337, 137 | -0.238233556018718, 0.98744470341339, 0.741390128383824, 0.0893472664958216, 138 | -0.954943856152377, -0.195150384667239, 0.92552126209408, 0.482978524836611 139 | ) 140 | val dv = new DenseVector(arr) 141 | val cTest = kpsstest(dv, "c")._1 142 | val ctTest = kpsstest(dv, "ct")._1 143 | 144 | cTest should be (0.2759 +- 1e-4) 145 | ctTest should be (0.05092 +- 1e-4) 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/TimeSeriesSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 
3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import breeze.linalg.{DenseVector, DenseMatrix} 19 | import com.cloudera.sparkts.TimeSeries._ 20 | 21 | import com.github.nscala_time.time.Imports._ 22 | 23 | import org.scalatest.{FunSuite, ShouldMatchers} 24 | 25 | import scala.collection.immutable.IndexedSeq 26 | 27 | class TimeSeriesSuite extends FunSuite with ShouldMatchers { 28 | test("timeSeriesFromIrregularSamples") { 29 | val dt = new DateTime("2015-4-8") 30 | val samples = Array( 31 | ((dt, Array(1.0, 2.0, 3.0))), 32 | ((dt + 1.days, Array(4.0, 5.0, 6.0))), 33 | ((dt + 2.days, Array(7.0, 8.0, 9.0))), 34 | ((dt + 4.days, Array(10.0, 11.0, 12.0))) 35 | ) 36 | 37 | val labels = Array("a", "b", "c", "d") 38 | val ts = timeSeriesFromIrregularSamples(samples, labels) 39 | ts.data.valuesIterator.toArray should be ((1 to 12).map(_.toDouble).toArray) 40 | } 41 | 42 | test("lagsIncludingOriginals") { 43 | val originalIndex = new UniformDateTimeIndex(0, 5, new DayFrequency(1)) 44 | 45 | val data = DenseMatrix((1.0, 6.0), (2.0, 7.0), (3.0, 8.0), (4.0, 9.0), (5.0, 10.0)) 46 | 47 | val originalTimeSeries = new TimeSeries(originalIndex, data, Array("a", "b")) 48 | 49 | val laggedTimeSeries = originalTimeSeries.lags(2, true) 50 | 51 | laggedTimeSeries.keys should be (Array("a", "lag1(a)", "lag2(a)", "b", "lag1(b)", "lag2(b)")) 52 | laggedTimeSeries.index.size should be (3) 53 | laggedTimeSeries.data should be (DenseMatrix((3.0, 2.0, 1.0, 8.0, 7.0, 6.0), 54 | (4.0, 3.0, 2.0, 9.0, 8.0, 7.0), (5.0, 4.0, 3.0, 10.0, 9.0, 8.0))) 55 | } 56 | 57 | test("lagsExcludingOriginals") { 58 | val originalIndex = new UniformDateTimeIndex(0, 5, new DayFrequency(1)) 59 | 60 | val data = DenseMatrix((1.0, 6.0), (2.0, 7.0), (3.0, 8.0), (4.0, 9.0), (5.0, 10.0)) 61 | 62 | val originalTimeSeries = new TimeSeries(originalIndex, data, Array("a", "b")) 63 | 64 | val laggedTimeSeries = originalTimeSeries.lags(2, false) 65 | 66 | laggedTimeSeries.keys should be (Array("lag1(a)", "lag2(a)", "lag1(b)", "lag2(b)")) 67 | laggedTimeSeries.index.size should be (3) 68 | laggedTimeSeries.data should be (DenseMatrix((2.0, 1.0, 7.0, 6.0), (3.0, 2.0, 8.0, 7.0), 69 | (4.0, 3.0, 9.0, 8.0))) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/test/scala/com/cloudera/sparkts/UnivariateTimeSeriesSuite.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2015, Cloudera, Inc. All Rights Reserved. 3 | * 4 | * Cloudera, Inc. licenses this file to you under the Apache License, 5 | * Version 2.0 (the "License"). You may not use this file except in 6 | * compliance with the License. You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | * CONDITIONS OF ANY KIND, either express or implied. 
See the License for 12 | * the specific language governing permissions and limitations under the 13 | * License. 14 | */ 15 | 16 | package com.cloudera.sparkts 17 | 18 | import scala.Double.NaN 19 | 20 | import breeze.linalg._ 21 | 22 | import com.cloudera.sparkts.UnivariateTimeSeries._ 23 | 24 | import org.apache.commons.math3.random.MersenneTwister 25 | 26 | import org.scalatest.{FunSuite, ShouldMatchers} 27 | 28 | class UnivariateTimeSeriesSuite extends FunSuite with ShouldMatchers { 29 | test("lagIncludeOriginalsTrue") { 30 | val lagMatrix = UnivariateTimeSeries.lag(Vector(1.0, 2.0, 3.0, 4.0, 5.0), 2, true) 31 | lagMatrix should be (DenseMatrix((3.0, 2.0, 1.0), (4.0, 3.0, 2.0), (5.0, 4.0, 3.0))) 32 | } 33 | 34 | test("lagIncludeOriginalsFalse") { 35 | val lagMatrix = UnivariateTimeSeries.lag(Vector(1.0, 2.0, 3.0, 4.0, 5.0), 2, false) 36 | lagMatrix should be (DenseMatrix((2.0, 1.0), (3.0, 2.0), (4.0, 3.0))) 37 | } 38 | 39 | test("lastNotNaN") { 40 | lastNotNaN(new DenseVector(Array(1.0, 2.0, 3.0, 4.0))) should be (3) 41 | lastNotNaN(new DenseVector(Array(1.0, 2.0, NaN, 4.0))) should be (3) 42 | lastNotNaN(new DenseVector(Array(1.0, 2.0, 3.0, NaN))) should be (2) 43 | } 44 | 45 | test("autocorr") { 46 | val rand = new MersenneTwister(5L) 47 | val iidAutocorr = autocorr(Array.fill(10000)(rand.nextDouble * 5.0), 3) 48 | iidAutocorr.foreach(math.abs(_) should be < .03) 49 | 50 | val arModel = new ARModel(1.5, Array(.2)) 51 | val arSeries = arModel.sample(10000, rand) 52 | val arAutocorr = autocorr(arSeries, 3) 53 | math.abs(.2 - arAutocorr(0)) should be < 0.02 54 | arAutocorr(1) should be < 0.06 55 | arAutocorr(2) should be < 0.06 56 | arAutocorr(1) should be > 0.0 57 | arAutocorr(2) should be > 0.0 58 | } 59 | 60 | test("upsampling") { 61 | // replicating upsampling examples 62 | // from http://www.mathworks.com/help/signal/ref/upsample.html?searchHighlight=upsample 63 | val y = new DenseVector(Array(1.0, 2.0, 3.0, 4.0)) 64 | val yUp1 = upsample(y, 3, useZero = true).toArray 65 | yUp1 should be (Array(1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0)) 66 | 67 | val yUp2 = upsample(y, 3, useZero = true, phase = 2).toArray 68 | yUp2 should be (Array(0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0)) 69 | } 70 | 71 | test("downsampling") { 72 | // replicating downsampling examples 73 | // from http://www.mathworks.com/help/signal/ref/downsample.html?searchHighlight=downsample 74 | val y = new DenseVector((1 to 10).toArray.map(_.toDouble)) 75 | val yDown1 = downsample(y, 3).toArray 76 | yDown1 should be (Array(1.0, 4.0, 7.0, 10.0)) 77 | 78 | val yDown2 = downsample(y, 3, phase = 2).toArray 79 | yDown2 should be (Array(3.0, 6.0, 9.0)) 80 | 81 | } 82 | 83 | test("signal reconstruction with spline") { 84 | // If we have a frequent signal, downsample it (at a rate that doesn't cause aliasing) 85 | // and we upsample, and apply a filter (interpolation), then the result should be fairly 86 | // close to the original signal. 
In our case, we drop NAs that are not filled by interpolation 87 | // (i.e no extrapolation) 88 | 89 | val y = (1 to 1000).toArray.map(_.toDouble / 100.0).map(Math.sin) 90 | val vy = new DenseVector(y) 91 | val lessFreq = downsample(vy, 100) 92 | val moreFreq = upsample(lessFreq, 100) 93 | 94 | // work on copies 95 | val splineY = fillSpline(new DenseVector(moreFreq.toArray)).toArray 96 | val lineY = fillLinear(new DenseVector(moreFreq.toArray)).toArray 97 | 98 | val MSE = (est: Array[Double], obs: Array[Double]) => { 99 | val errs = est.zip(obs).filter(!_._1.isNaN).map { case (yhat, yi) => (yhat - yi) * (yhat - yi) } 100 | errs.sum / errs.length 101 | } 102 | 103 | val sE = MSE(splineY, y) 104 | val lE = MSE(lineY, y) 105 | 106 | // a cubic spline should be better than linear interpolation 107 | sE should be < lE 108 | } 109 | 110 | test("differencing at lag") { 111 | val rand = new MersenneTwister(10L) 112 | val n = 100 113 | val sampled = new DenseVector(Array.fill(n)(rand.nextGaussian)) 114 | val lag = 5 115 | val diffed = differencesAtLag(sampled, lag) 116 | val invDiffed = inverseDifferencesAtLag(diffed, lag) 117 | 118 | for (i <- 0 until n) { 119 | sampled(i) should be (invDiffed(i) +- 1e-6) 120 | } 121 | 122 | diffed(10) should be (sampled(10) - sampled(5)) 123 | diffed(99) should be (sampled(99) - sampled(94)) 124 | } 125 | 126 | test("differencing of order d") { 127 | val rand = new MersenneTwister(10L) 128 | val n = 100 129 | val sampled = new DenseVector(Array.fill(n)(rand.nextGaussian)) 130 | // differencing at order 1 and lag 1 should be the same 131 | val diffedOfOrder1 = differencesOfOrderD(sampled, 1) 132 | val diffedAtLag1 = differencesAtLag(sampled, 1) 133 | 134 | for (i <- 0 until n) { 135 | diffedAtLag1(i) should be (diffedOfOrder1(i) +- 1e-6) 136 | } 137 | 138 | // differencing at order and inversing should return the original series 139 | val diffedOfOrder5 = differencesOfOrderD(sampled, 5) 140 | val invDiffedOfOrder5 = inverseDifferencesOfOrderD(diffedOfOrder5, 5) 141 | 142 | for (i <- 0 until n) { 143 | invDiffedOfOrder5(i) should be (sampled(i) +- 1e-6) 144 | } 145 | 146 | // Differencing of order n + 1 should be the same as differencing one time a 147 | // vector that has already been differenced to order n 148 | val diffedOfOrder6 = differencesOfOrderD(sampled, 6) 149 | val diffedOneMore = differencesOfOrderD(diffedOfOrder5, 1) 150 | // compare start at index = 6 151 | for (i <- 6 until n) { 152 | diffedOfOrder6(i) should be (diffedOneMore(i) +- 1e-6) 153 | } 154 | } 155 | } 156 | 157 | --------------------------------------------------------------------------------